Update app.py
app.py CHANGED
@@ -14,7 +14,8 @@ logger = logging.getLogger(__name__)
 device = torch.device("cpu")
 logger.info(f"Using device: {device}")
 
-#
+# Load or initialize response cache
+cache_file = "cache.json"
 response_cache = {
     "hi": "Hello! I'm your financial advisor. How can I help with investing?",
     "hello": "Hello! I'm your financial advisor. How can I help with investing?",
@@ -95,9 +96,27 @@ response_cache = {
         "4. Use dollar-cost averaging for regular investments.\n"
         "5. Monitor and diversify your portfolio.\n"
         "Consult a financial planner."
+    ),
+    "steps to invest": (
+        "Here are steps to invest:\n"
+        "1. Educate yourself using Investopedia.\n"
+        "2. Open a brokerage account (e.g., Fidelity).\n"
+        "3. Deposit an initial $100 after savings.\n"
+        "4. Buy an ETF like VOO after research.\n"
+        "5. Use dollar-cost averaging monthly.\n"
+        "Consult a financial planner."
     )
 }
 
+# Load persistent cache
+try:
+    if os.path.exists(cache_file):
+        with open(cache_file, 'r') as f:
+            response_cache.update(json.load(f))
+        logger.info("Loaded persistent cache from cache.json")
+except Exception as e:
+    logger.warning(f"Failed to load cache.json: {e}")
+
 # Load model and tokenizer
 model_name = "distilgpt2"
 try:
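A note on the persistence added here: the hunk relies on os and json being imported at the top of app.py (outside the diff context), and on Spaces the container filesystem is ephemeral, so a cache.json written at runtime disappears on restart unless persistent storage is attached. A minimal standalone sketch of the round-trip the commit implements:

import json
import os

cache_file = "cache.json"  # same name the commit uses
response_cache = {"hi": "Hello! I'm your financial advisor. How can I help with investing?"}

# Save, then reload into a fresh dict, as the app would after a restart.
with open(cache_file, "w") as f:
    json.dump(response_cache, f)

fresh_cache = {}
if os.path.exists(cache_file):
    with open(cache_file) as f:
        fresh_cache.update(json.load(f))
assert fresh_cache == response_cache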
@@ -124,7 +143,7 @@ prompt_prefix = (
 prefix_tokens = tokenizer(prompt_prefix, return_tensors="pt", truncation=True, max_length=512).to(device)
 
 # Fuzzy matching for cache
-def get_closest_cache_key(message, cache_keys, threshold=0.[…]
+def get_closest_cache_key(message, cache_keys, threshold=0.75):
     matches = difflib.get_close_matches(message, cache_keys, n=1, cutoff=threshold)
     return matches[0] if matches else None
 
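For reference, difflib.get_close_matches ranks candidates by SequenceMatcher ratio and drops anything below cutoff, so the 0.75 threshold set here tolerates small typos while rejecting unrelated messages. A quick standalone illustration against keys from the cache above:

import difflib

cache_keys = ["hi", "hello", "steps to invest"]

# "helo" scores roughly 0.89 against "hello", clearing the 0.75 cutoff.
print(difflib.get_close_matches("helo", cache_keys, n=1, cutoff=0.75))            # ['hello']

# An unrelated question falls below the cutoff, so the model handles it.
print(difflib.get_close_matches("what is a 401k", cache_keys, n=1, cutoff=0.75))  # []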
@@ -138,12 +157,22 @@ def chat_with_model(user_input, history=None):
        closest_key = cache_key if cache_key in response_cache else get_closest_cache_key(cache_key, cache_keys)
        if closest_key:
            logger.info(f"Cache hit for: {closest_key}")
-           [… line truncated in the diff view …]
+           response = response_cache[closest_key]
+           logger.info(f"Chatbot response: {response}")
+           history = history or []
+           history.append({"role": "user", "content": user_input})
+           history.append({"role": "assistant", "content": response})
+           return response, history
 
        # Skip model for short prompts
        if len(user_input.strip()) <= 5:
            logger.info("Short prompt, returning default response")
-           [… line truncated in the diff view …]
+           response = "Hello! I'm your financial advisor. Ask about investing!"
+           logger.info(f"Chatbot response: {response}")
+           history = history or []
+           history.append({"role": "user", "content": user_input})
+           history.append({"role": "assistant", "content": response})
+           return response, history
 
        # Construct prompt
        full_prompt = prompt_prefix + user_input + "\nA:"
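The cache-hit and short-prompt branches added above, and the success and error paths added further down, all repeat the same append-and-return lines. A hypothetical helper (finish_turn is my label, not part of the commit) could fold that duplication:

# Hypothetical refactor, not part of this commit: one shared exit path.
def finish_turn(response, history, user_input):
    history = history or []
    history.append({"role": "user", "content": user_input})
    history.append({"role": "assistant", "content": response})
    return response, history

# The cache-hit branch, for example, would collapse to:
#     return finish_turn(response_cache[closest_key], history, user_input)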
@@ -153,7 +182,7 @@ def chat_with_model(user_input, history=None):
        with torch.cpu.amp.autocast(), torch.inference_mode():
            outputs = model.generate(
                **inputs,
-               max_new_tokens=[…],
+               max_new_tokens=40,
                min_length=15,
                do_sample=True,
                temperature=0.7,
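Two hedged notes on the generation settings: in transformers, min_length counts the prompt tokens as well, so with the long prompt_prefix a min_length of 15 is already satisfied before generation starts (min_new_tokens is the per-generation counterpart); and torch.cpu.amp.autocast() is deprecated in recent PyTorch releases in favor of the device-string form. A standalone sketch of the newer spelling:

import torch

# torch.amp.autocast("cpu") is the non-deprecated equivalent of
# torch.cpu.amp.autocast(); CPU autocast defaults to bfloat16.
x = torch.randn(4, 4)
with torch.amp.autocast("cpu"), torch.inference_mode():
    y = x @ x  # matmul runs in reduced precision under autocast
print(y.dtype)  # torch.bfloat16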
@@ -162,37 +191,55 @@ def chat_with_model(user_input, history=None):
                pad_token_id=tokenizer.eos_token_id
            )
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-       logger.info("Generated response")
-       torch.cuda.empty_cache()  # Clear memory
        response = response[len(full_prompt):].strip() if response.startswith(full_prompt) else response
+       logger.info(f"Chatbot response: {response}")
+
+       # Update cache and save to file
+       response_cache[cache_key] = response
+       try:
+           with open(cache_file, 'w') as f:
+               json.dump(response_cache, f)
+           logger.info("Updated cache.json")
+       except Exception as e:
+           logger.warning(f"Failed to update cache.json: {e}")
 
        # Update history
        history = history or []
        history.append({"role": "user", "content": user_input})
        history.append({"role": "assistant", "content": response})
+       torch.cuda.empty_cache()  # Clear memory
        return response, history
    except Exception as e:
        logger.error(f"Error generating response: {e}")
-       [… line truncated in the diff view …]
+       response = f"Error: {str(e)}"
+       logger.info(f"Chatbot response: {response}")
+       history = history or []
+       history.append({"role": "user", "content": user_input})
+       history.append({"role": "assistant", "content": response})
+       return response, history
 
 # Create Gradio interface
 logger.info("Initializing Gradio interface")
-[… 5 lines truncated in the diff view …]
+try:
+    with gr.Blocks() as interface:
+        chatbot = gr.Chatbot(type="messages")
+        msg = gr.Textbox(label="Your message")
+        submit = gr.Button("Send")
+        clear = gr.Button("Clear")
 
-[… 3 lines truncated in the diff view …]
+        def submit_message(user_input, history):
+            response, updated_history = chat_with_model(user_input, history)
+            return response, updated_history
 
-[… 6 lines truncated in the diff view …]
+        submit.click(
+            fn=submit_message,
+            inputs=[msg, chatbot],
+            outputs=[msg, chatbot]
+        )
+        clear.click(lambda: None, None, chatbot)
+except Exception as e:
+    logger.error(f"Error initializing Gradio interface: {e}")
+    raise
 
 # Launch interface (conditional for Spaces)
 if __name__ == "__main__" and not os.getenv("HF_SPACE"):
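Two hedged notes on this last hunk. First, torch.cuda.empty_cache() is a harmless no-op here, since the device pinned at the top of the file is CPU-only. Second, submit.click feeds submit_message's return values (response, updated_history) into outputs=[msg, chatbot], so after each turn the textbox displays the raw response rather than clearing. The conventional Blocks pattern returns an empty string for the textbox; a drop-in variant (an alternative, not what the commit does):

def submit_message(user_input, history):
    # The chatbot renders the reply from history, so the bare response is unused.
    _, updated_history = chat_with_model(user_input, history)
    return "", updated_history  # "" clears the textbox after sending

submit.click(fn=submit_message, inputs=[msg, chatbot], outputs=[msg, chatbot])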