File size: 5,722 Bytes
79e02f2
71ef153
79e02f2
 
 
 
4e7895f
14c8701
2a8517e
71ef153
 
c80d447
cae7f88
 
71ef153
 
 
c764d12
799c24a
76a5ec0
71ef153
 
0e9b0c6
4e7895f
 
 
 
 
 
71ef153
 
 
 
4e7895f
71ef153
4e7895f
71ef153
 
 
 
 
4e7895f
 
 
 
 
 
 
71ef153
 
79e02f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e9b0c6
 
 
 
 
 
79e02f2
 
0e9b0c6
c764d12
79e02f2
 
 
 
2a8517e
79e02f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c764d12
 
 
 
79e02f2
 
 
 
 
c764d12
79e02f2
c764d12
 
 
 
 
 
 
 
 
 
 
 
 
 
79e02f2
 
 
 
 
 
 
 
 
 
 
 
 
 
c764d12
79e02f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b18a919
79e02f2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import streamlit as st
from langchain.llms import CTransformers
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import ChatMessageHistory
from langchain.prompts import PromptTemplate
from langchain_core.runnables import RunnableSequence
from huggingface_hub import hf_hub_download
import os

# --- Model Definitions ---
MODEL_MAP = {
    "TinyLlama (1.1B)": {
        "repo_id": "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
        "filename": "tinyllama-1.1b-chat-v1.0.Q4_0.gguf",
        "type": "llama"
    },
    "Deepseek-Coder (1.3B)": {
        "repo_id": "TheBloke/deepseek-coder-1.3b-base-GGUF",
        "filename": "deepseek-coder-1.3b-base.Q4_K_M.gguf",
        "type": "llama"
    }
}

# --- Model Loading ---
@st.cache_resource
def download_model_from_hub(repo_id, filename):
    st.write(f"Downloading model '{filename}' from Hugging Face Hub...")
    model_path = hf_hub_download(repo_id=repo_id, filename=filename)
    return model_path

@st.cache_resource
def load_llm(model_name):
    model_info = MODEL_MAP[model_name]
    model_path = download_model_from_hub(model_info["repo_id"], model_info["filename"])
    llm = CTransformers(
        model=model_path,
        model_type=model_info["type"],
        config={'max_new_tokens': 2048, 'temperature': 0.7}
    )
    return llm

# ----------------- Streamlit UI and Logic -----------------
st.set_page_config(layout="wide")
st.title("My Local Chatbot")

st.sidebar.header("Settings")
selected_model_name = st.sidebar.selectbox("Choose a Model", list(MODEL_MAP.keys()))

llm = load_llm(selected_model_name)
st.success(f"Model '{selected_model_name}' loaded successfully!")

MAX_HISTORY = st.sidebar.number_input("Max History", min_value=1, max_value=10, value=2, step=1)
CONTEXT_SIZE = st.sidebar.number_input("Context Size", min_value=1024, max_value=16384, value=8192, step=1024)
TEMPERATURE = st.sidebar.slider("Temperature", 0.0, 1.5, 0.7, 0.1)
TOP_P = st.sidebar.slider("Top-p (nucleus sampling)", 0.0, 1.0, 0.9, 0.05)
TOP_K = st.sidebar.slider("Top-k", 0, 100, 40, 5)
MAX_TOKENS = st.sidebar.number_input("Max Tokens", min_value=256, max_value=16384, value=2048, step=256)

if "chat_history" not in st.session_state:
    st.session_state.chat_history = []
if "memory" not in st.session_state:
    st.session_state.memory = ConversationBufferMemory(return_messages=True)
if "summary" not in st.session_state:
    st.session_state.summary = ""

def clear_memory():
    chat_history = ChatMessageHistory()
    st.session_state.memory = ConversationBufferMemory(chat_memory=chat_history)
    st.session_state.chat_history = []
    st.session_state.summary = ""

if st.sidebar.button("Clear Conversation History"):
    clear_memory()

# --- CORRECTED: Summary chain is now defined globally ---
summary_prompt_template = PromptTemplate(
    input_variables=["chat_history"],
    template="You are a summarizer. Summarize the following conversation to preserve key information and context. \n\n{chat_history}"
)

summary_chain = summary_prompt_template | llm

def get_summary(chat_history_str):
    return summary_chain.invoke({"chat_history": chat_history_str})

def summarize_chat():
    if not st.session_state.chat_history:
        return "No chat history to summarize."
    return get_summary(st.session_state.chat_history)

if st.sidebar.button("Summarize Chat"):
    with st.sidebar:
        st.markdown("**Chat Summary:**")
        summary = summarize_chat()
        st.success(summary)

# --- Model-specific prompt templates ---
if "Llama" in selected_model_name:
    template = """[INST]
    You are a helpful assistant.
    Current conversation summary:
    {summary}
    Conversation history:
    {history}
    User: {human_input}
    [/INST]
    Assistant:"""
elif "Deepseek" in selected_model_name:
    template = """<|im_start|>system
    You are a helpful assistant.
    Current conversation summary:
    {summary}<|im_end|>
    <|im_start|>user
    {history}
    {human_input}<|im_end|>
    <|im_start|>assistant
    """

prompt_template = PromptTemplate(
    input_variables=["summary", "history", "human_input"],
    template=template
)

chain = prompt_template | llm

for msg in st.session_state.chat_history:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])

def trim_memory():
    if len(st.session_state.chat_history) > MAX_HISTORY * 2:
        history_to_summarize = st.session_state.chat_history[:(len(st.session_state.chat_history) - MAX_HISTORY * 2)]
        history_str = ""
        for msg in history_to_summarize:
            history_str += f"{msg['role']}: {msg['content']}\n"
        new_summary = summary_chain.invoke({"chat_history": history_str})
        st.session_state.summary += "\n" + new_summary
        st.session_state.chat_history = st.session_state.chat_history[(len(st.session_state.chat_history) - MAX_HISTORY * 2):]

if prompt := st.chat_input("Say something"):
    with st.chat_message("user"):
        st.markdown(prompt)
    st.session_state.chat_history.append({"role": "user", "content": prompt})
    trim_memory()
    formatted_history = ""
    for msg in st.session_state.chat_history:
        formatted_history += f"{msg['role']}: {msg['content']}\n"
    with st.chat_message("assistant"):
        response_container = st.empty()
        full_response = ""
        for chunk in chain.stream({
                "human_input": prompt,
                "history": formatted_history,
                "summary": st.session_state.summary
            }):
            full_response += chunk
            response_container.markdown(full_response)
    st.session_state.chat_history.append({"role": "assistant", "content": full_response})