AnilNiraula committed · verified
Commit f498762 · 1 Parent(s): 29e2fcf

Update app.py

Files changed (1): app.py (+69 −22)

app.py CHANGED
@@ -14,7 +14,8 @@ logger = logging.getLogger(__name__)
 device = torch.device("cpu")
 logger.info(f"Using device: {device}")
 
-# Response cache
+# Load or initialize response cache
+cache_file = "cache.json"
 response_cache = {
     "hi": "Hello! I'm your financial advisor. How can I help with investing?",
     "hello": "Hello! I'm your financial advisor. How can I help with investing?",
@@ -95,9 +96,27 @@ response_cache = {
         "4. Use dollar-cost averaging for regular investments.\n"
         "5. Monitor and diversify your portfolio.\n"
         "Consult a financial planner."
+    ),
+    "steps to invest": (
+        "Here are steps to invest:\n"
+        "1. Educate yourself using Investopedia.\n"
+        "2. Open a brokerage account (e.g., Fidelity).\n"
+        "3. Deposit an initial $100 after savings.\n"
+        "4. Buy an ETF like VOO after research.\n"
+        "5. Use dollar-cost averaging monthly.\n"
+        "Consult a financial planner."
     )
 }
 
+# Load persistent cache
+try:
+    if os.path.exists(cache_file):
+        with open(cache_file, 'r') as f:
+            response_cache.update(json.load(f))
+        logger.info("Loaded persistent cache from cache.json")
+except Exception as e:
+    logger.warning(f"Failed to load cache.json: {e}")
+
 # Load model and tokenizer
 model_name = "distilgpt2"
 try:
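Reviewer note: because the persisted entries are merged with `dict.update`, anything previously saved in cache.json silently overrides the hard-coded defaults on matching keys. A minimal sketch of that precedence (hypothetical values, not the app's data):

import json
from io import StringIO

defaults = {"hi": "built-in greeting", "hello": "built-in greeting"}
persisted = StringIO('{"hi": "saved greeting"}')  # stands in for cache.json

defaults.update(json.load(persisted))
assert defaults["hi"] == "saved greeting"        # disk entry wins
assert defaults["hello"] == "built-in greeting"  # unmatched keys keep defaults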
@@ -124,7 +143,7 @@ prompt_prefix = (
 prefix_tokens = tokenizer(prompt_prefix, return_tensors="pt", truncation=True, max_length=512).to(device)
 
 # Fuzzy matching for cache
-def get_closest_cache_key(message, cache_keys, threshold=0.8):
+def get_closest_cache_key(message, cache_keys, threshold=0.75):
     matches = difflib.get_close_matches(message, cache_keys, n=1, cutoff=threshold)
     return matches[0] if matches else None
 
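Reviewer note: `difflib.get_close_matches` ranks candidates by `SequenceMatcher.ratio()`, so lowering the cutoff from 0.8 to 0.75 admits slightly rougher typos. A quick way to sanity-check a threshold is to print the ratios directly; a sketch with illustrative inputs:

import difflib

keys = ["hi", "hello", "steps to invest"]
for msg in ("helo", "hel"):
    best = max(difflib.SequenceMatcher(None, msg, k).ratio() for k in keys)
    print(msg, round(best, 2), difflib.get_close_matches(msg, keys, n=1, cutoff=0.75))

Here "helo" scores about 0.89 against "hello" and would match at either cutoff, while "hel" scores exactly 0.75 and is only rescued by the new, looser threshold.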
 
@@ -138,12 +157,22 @@ def chat_with_model(user_input, history=None):
     closest_key = cache_key if cache_key in response_cache else get_closest_cache_key(cache_key, cache_keys)
     if closest_key:
         logger.info(f"Cache hit for: {closest_key}")
-        return response_cache[closest_key], history
+        response = response_cache[closest_key]
+        logger.info(f"Chatbot response: {response}")
+        history = history or []
+        history.append({"role": "user", "content": user_input})
+        history.append({"role": "assistant", "content": response})
+        return response, history
 
     # Skip model for short prompts
     if len(user_input.strip()) <= 5:
         logger.info("Short prompt, returning default response")
-        return "Hello! I'm your financial advisor. Ask about investing!", history
+        response = "Hello! I'm your financial advisor. Ask about investing!"
+        logger.info(f"Chatbot response: {response}")
+        history = history or []
+        history.append({"role": "user", "content": user_input})
+        history.append({"role": "assistant", "content": response})
+        return response, history
 
     # Construct prompt
     full_prompt = prompt_prefix + user_input + "\nA:"
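Reviewer note: the cache-hit and short-prompt branches now each repeat the same logging and history bookkeeping. If a third early return ever appears, this could be factored into a small helper; a hypothetical sketch (`_reply` is not part of the commit, and it assumes app.py's module-level `logger`):

def _reply(response, user_input, history):
    # Log the reply, record one user/assistant turn, and return the pair.
    logger.info(f"Chatbot response: {response}")
    history = history or []
    history.append({"role": "user", "content": user_input})
    history.append({"role": "assistant", "content": response})
    return response, history

Each branch would then collapse to, e.g., `return _reply(response_cache[closest_key], user_input, history)`.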
@@ -153,7 +182,7 @@ def chat_with_model(user_input, history=None):
         with torch.cpu.amp.autocast(), torch.inference_mode():
             outputs = model.generate(
                 **inputs,
-                max_new_tokens=50,
+                max_new_tokens=40,
                 min_length=15,
                 do_sample=True,
                 temperature=0.7,
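Reviewer note: trimming `max_new_tokens` from 50 to 40 caps the decode length, and on CPU generation time grows roughly linearly with tokens produced, so this is a direct latency win. A rough way to measure it (assumes `model`, `tokenizer`, and `inputs` are already built as in the app):

import time

for n in (50, 40):
    start = time.perf_counter()
    model.generate(**inputs, max_new_tokens=n, do_sample=True,
                   pad_token_id=tokenizer.eos_token_id)
    print(f"max_new_tokens={n}: {time.perf_counter() - start:.2f}s")

As an aside, recent PyTorch releases deprecate `torch.cpu.amp.autocast()` in favor of `torch.amp.autocast("cpu")`; the old spelling still works but may emit a warning.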
@@ -162,37 +191,55 @@ def chat_with_model(user_input, history=None):
                 pad_token_id=tokenizer.eos_token_id
             )
         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        logger.info("Generated response")
-        torch.cuda.empty_cache()  # Clear memory
         response = response[len(full_prompt):].strip() if response.startswith(full_prompt) else response
+        logger.info(f"Chatbot response: {response}")
+
+        # Update cache and save to file
+        response_cache[cache_key] = response
+        try:
+            with open(cache_file, 'w') as f:
+                json.dump(response_cache, f)
+            logger.info("Updated cache.json")
+        except Exception as e:
+            logger.warning(f"Failed to update cache.json: {e}")
 
         # Update history
         history = history or []
         history.append({"role": "user", "content": user_input})
         history.append({"role": "assistant", "content": response})
+        torch.cuda.empty_cache()  # Clear memory
         return response, history
     except Exception as e:
         logger.error(f"Error generating response: {e}")
-        return f"Error: {str(e)}", history
+        response = f"Error: {str(e)}"
+        logger.info(f"Chatbot response: {response}")
+        history = history or []
+        history.append({"role": "user", "content": user_input})
+        history.append({"role": "assistant", "content": response})
+        return response, history
 
 # Create Gradio interface
 logger.info("Initializing Gradio interface")
-with gr.Blocks() as interface:
-    chatbot = gr.Chatbot(type="messages")
-    msg = gr.Textbox(label="Your message")
-    submit = gr.Button("Send")
-    clear = gr.Button("Clear")
+try:
+    with gr.Blocks() as interface:
+        chatbot = gr.Chatbot(type="messages")
+        msg = gr.Textbox(label="Your message")
+        submit = gr.Button("Send")
+        clear = gr.Button("Clear")
 
-    def submit_message(user_input, history):
-        response, updated_history = chat_with_model(user_input, history)
-        return response, updated_history
+        def submit_message(user_input, history):
+            response, updated_history = chat_with_model(user_input, history)
+            return response, updated_history
 
-    submit.click(
-        fn=submit_message,
-        inputs=[msg, chatbot],
-        outputs=[msg, chatbot]
-    )
-    clear.click(lambda: None, None, chatbot, queue=False)
+        submit.click(
+            fn=submit_message,
+            inputs=[msg, chatbot],
+            outputs=[msg, chatbot]
+        )
+        clear.click(lambda: None, None, chatbot)
+except Exception as e:
+    logger.error(f"Error initializing Gradio interface: {e}")
+    raise
 
 # Launch interface (conditional for Spaces)
 if __name__ == "__main__" and not os.getenv("HF_SPACE"):
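Two reviewer notes on this hunk. First, `torch.cuda.empty_cache()` only trims CUDA's caching allocator and is a no-op when CUDA was never initialized, so on this CPU-only Space the relocated call still does nothing; it could be dropped outright. Second, rewriting cache.json in full after every generated response can leave a truncated file if the process dies mid-write; writing to a temp file and renaming is atomic on POSIX. A hedged sketch of that alternative (a hypothetical helper, not what the commit does):

import json
import os
import tempfile

def save_cache_atomically(cache, path="cache.json"):
    # Write the JSON to a sibling temp file, then atomically swap it in.
    fd, tmp = tempfile.mkstemp(dir=os.path.dirname(path) or ".", suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(cache, f)
        os.replace(tmp, path)  # atomic rename on POSIX filesystems
    except BaseException:
        os.unlink(tmp)  # clean up the temp file if anything fails
        raise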
 
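One last wiring detail: `submit.click` routes the handler's first return value into `msg`, so after sending, the textbox displays the raw model reply instead of clearing. If clearing is the intent, the handler could return an empty string for the textbox slot; a hypothetical variant:

def submit_message(user_input, history):
    _, updated_history = chat_with_model(user_input, history)
    return "", updated_history  # clear the textbox; the reply lives in the chat history

The `chatbot` component still receives the updated messages list either way, so the conversation renders unchanged.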