AnilNiraula commited on
Commit
08d63f2
·
verified ·
1 Parent(s): 93da63e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -37
app.py CHANGED
@@ -6,6 +6,7 @@ import gradio as gr
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
7
  import pandas as pd
8
  import re
 
9
  import json
10
  import difflib
11
  from functools import lru_cache
@@ -48,54 +49,75 @@ else:
48
  # Hardcoded fallback for recent periods if dataset is incomplete
49
  fallback_returns = {
50
  (2020, 2022): 8.3, # Average annual return based on external data
51
- (2015, 2024): 12.2
 
52
  }
53
 
54
  # Load model and tokenizer at startup
55
  model_name = "./finetuned_model" if os.path.exists("./finetuned_model") else "distilgpt2"
56
- tokenizer = AutoTokenizer.from_pretrained(model_name, clean_up_tokenization_spaces=False)
57
- tokenizer.pad_token = tokenizer.eos_token
58
- model = AutoModelForCausalLM.from_pretrained(model_name).eval()
59
- model = model.to_bettertransformer() # Enable BetterTransformer
60
- sample_input = tokenizer("Average return of S&P 500", return_tensors="pt")["input_ids"]
61
- traced_model = torch.jit.trace(model, sample_input)
62
- traced_model.save("distilgpt2_traced.pt")
63
- model = torch.jit.load("distilgpt2_traced.pt")
 
 
 
 
 
 
 
 
 
64
 
65
- # Response cache with financial data entries
66
- response_cache = {
67
- "hi": "Hello! I'm FinChat, your financial advisor. How can I help with investing?",
68
- "what is the average return rate of the s&p 500 in the past 10 years?": (
69
- "The S&P 500’s average annual return rate from 2015 to 2024 was approximately 12.2%, including dividends, based on historical data."
70
- ),
71
- "what was the average annual return of the s&p 500 between 2020 and 2022?": (
72
- "The S&P 500’s average annual return from 2020 to 2022 was approximately 8.3%, including dividends, with significant volatility due to the COVID-19 recovery and 2022 bear market."
73
- )
74
- }
 
 
 
 
 
 
 
75
 
76
  # Substring matching for cache with exact year matching
77
- def get_closest_cache_key(message, cache_keys):
 
78
  message = message.lower().strip()
79
  year_match = re.search(r'(\d{4})\s*(?:and|to|-|–)\s*(\d{4})', message)
80
  if year_match:
81
  start_year, end_year = year_match.groups()
82
- for key in cache_keys:
83
  if f"{start_year} and {end_year}" in key or f"{start_year} to {end_year}" in key or f"{start_year}–{end_year}" in key:
84
  return key
85
- matches = difflib.get_close_matches(message, cache_keys, n=1, cutoff=0.7)
86
  return matches[0] if matches else None
87
 
88
  # Parse period from user input
89
  def parse_period(query):
 
90
  match = re.search(r'(?:between|from)\s*(\d{4})\s*(?:and|to|-|–)\s*(\d{4})', query, re.IGNORECASE)
91
  if match:
92
  start_year, end_year = map(int, match.groups())
93
  return start_year, end_year, None
 
94
  match = re.search(r'(\d+)-year.*from\s*(\d{4})', query, re.IGNORECASE)
95
  if match:
96
  duration, start_year = map(int, match.groups())
97
  end_year = start_year + duration - 1
98
  return start_year, end_year, duration
 
99
  match = re.search(r'past\s*(\d+)-year|\b(\d+)-year.*(?:return|growth\s*rate)', query, re.IGNORECASE)
100
  if match:
101
  duration = int(match.group(1) or match.group(2))
@@ -109,7 +131,10 @@ def parse_period(query):
109
  def calculate_growth_rate(start_year, end_year, duration=None):
110
  if (start_year, end_year) in fallback_returns:
111
  avg_return = fallback_returns[(start_year, end_year)]
112
- response = f"The S&P 500’s average annual return from {start_year} to {end_year} was approximately {avg_return:.1f}%, including dividends."
 
 
 
113
  return avg_return, response
114
  if df_yearly is None or start_year is None or end_year is None:
115
  return None, "Data not available or invalid period."
@@ -117,9 +142,33 @@ def calculate_growth_rate(start_year, end_year, duration=None):
117
  if df_period.empty:
118
  return None, f"No data available for {start_year} to {end_year}."
119
  avg_return = df_period['Return'].mean()
120
- response = f"The S&P 500’s average annual return from {start_year} to {end_year} was approximately {avg_return:.1f}%, including dividends."
 
 
 
121
  return avg_return, response
122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  # Define chat function
124
  def chat_with_model(user_input, history=None, is_processing=False):
125
  try:
@@ -130,8 +179,7 @@ def chat_with_model(user_input, history=None, is_processing=False):
130
 
131
  # Normalize and check cache
132
  cache_key = user_input.lower().strip()
133
- cache_keys = list(response_cache.keys())
134
- closest_key = cache_key if cache_key in response_cache else get_closest_cache_key(cache_key, cache_keys)
135
  if closest_key:
136
  logger.info(f"Cache hit for: {closest_key}")
137
  response = response_cache[closest_key]
@@ -143,10 +191,25 @@ def chat_with_model(user_input, history=None, is_processing=False):
143
  logger.info(f"Response time: {end_time - start_time:.2f} seconds")
144
  return response, history, False, ""
145
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  # Check for period-specific query
147
- start_year, end_year, carica_duration = parse_period(user_input)
148
  if start_year and end_year:
149
- avg_return, response = calculate_growth_rate(start_year, end_year, carica_duration)
150
  if avg_return is not None:
151
  response_cache[cache_key] = response
152
  logger.info(f"Dynamic period query: {start_year}–{end_year}, added to cache")
@@ -158,15 +221,50 @@ def chat_with_model(user_input, history=None, is_processing=False):
158
  logger.info(f"Response time: {end_time - start_time:.2f} seconds")
159
  return response, history, False, ""
160
 
161
- # Model inference
162
- inputs = tokenizer(user_input, return_tensors="pt")
163
- outputs = model.generate(
164
- **inputs,
165
- max_new_tokens=30,
166
- repetition_penalty=2.5,
167
- no_repeat_ngram_size=2
168
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
170
  logger.info(f"Chatbot response: {response}")
171
 
172
  # Update cache
@@ -177,6 +275,7 @@ def chat_with_model(user_input, history=None, is_processing=False):
177
  history = history or []
178
  history.append({"role": "user", "content": user_input})
179
  history.append({"role": "assistant", "content": response})
 
180
  end_time = time.time()
181
  logger.info(f"Response time: {end_time - start_time:.2f} seconds")
182
  return response, history, False, ""
@@ -195,7 +294,7 @@ def chat_with_model(user_input, history=None, is_processing=False):
195
  # Save cache on exit
196
  def save_cache():
197
  try:
198
- with open("cache.json", 'w') as f:
199
  json.dump(response_cache, f, indent=2)
200
  logger.info("Saved cache to cache.json")
201
  except Exception as e:
 
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
7
  import pandas as pd
8
  import re
9
+ import numpy as np
10
  import json
11
  import difflib
12
  from functools import lru_cache
 
49
# Hardcoded fallback for recent periods if dataset is incomplete.
# Keys are (start_year, end_year) tuples; values are average annual
# percentage returns (dividends included) from external data.
_FALLBACK_PERIOD_RETURNS = [
    ((2020, 2022), 8.3),
    ((2015, 2024), 12.2),
    ((2020, 2024), 10.5),
]
fallback_returns = dict(_FALLBACK_PERIOD_RETURNS)
55
 
56
  # Load model and tokenizer at startup
57
  model_name = "./finetuned_model" if os.path.exists("./finetuned_model") else "distilgpt2"
58
+ try:
59
+ logger.info(f"Loading tokenizer for {model_name}")
60
+ tokenizer = AutoTokenizer.from_pretrained(model_name, clean_up_tokenization_spaces=False)
61
+ tokenizer.pad_token = tokenizer.eos_token
62
+ logger.info(f"Loading model {model_name}")
63
+ with torch.inference_mode():
64
+ if os.path.exists("./finetuned_model/distilgpt2_traced.pt"):
65
+ model = torch.jit.load("./finetuned_model/distilgpt2_traced.pt")
66
+ else:
67
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(device)
68
+ sample_input = tokenizer("What was the average annual return of the S&P 500 between 2020 and 2022?", return_tensors="pt")["input_ids"].to(device)
69
+ model = torch.jit.trace(model, sample_input)
70
+ model.save("./finetuned_model/distilgpt2_traced.pt")
71
+ logger.info(f"Successfully loaded model: {model_name}")
72
+ except Exception as e:
73
+ logger.error(f"Error loading model/tokenizer: {e}")
74
+ raise RuntimeError(f"Failed to load model: {str(e)}")
75
 
76
# Pre-tokenize prompt prefix: a fixed few-shot instruction header that is
# prepended to every model-bound query. Assembled from sections purely for
# readability; the joined string is identical to the original literal.
_PROMPT_SECTIONS = (
    "You are FinChat, a financial advisor with expertise in stock market performance. Provide concise, accurate answers with historical data for S&P 500 queries. ",
    "For period-specific queries, use precise year ranges and calculate average annual returns. For investment return queries, use compound interest calculations ",
    "based on historical averages. Avoid repetition and ensure answers are relevant.\n\n",
    "Example 1:\n",
    "Q: What is the S&P 500’s average annual return?\n",
    "A: The S&P 500’s average annual return is ~10–12% over the long term (1927–2025), including dividends.\n\n",
    "Example 2:\n",
    "Q: What will $5,000 be worth in 10 years if invested in the S&P 500?\n",
    "A: Assuming a 10% average annual return, a $5,000 investment in the S&P 500 would grow to approximately $12,974 in 10 years with annual compounding.\n\n",
    "Example 3:\n",
    "Q: What was the average annual return of the S&P 500 between 2020 and 2022?\n",
    "A: The S&P 500’s average annual return from 2020 to 2022 was approximately 8.3%, including dividends, with significant volatility due to the COVID-19 recovery and 2022 bear market.\n\n",
    "Q: ",
)
prompt_prefix = "".join(_PROMPT_SECTIONS)
prefix_tokens = tokenizer(prompt_prefix, return_tensors="pt", truncation=True, max_length=512)["input_ids"].to(device)
93
 
94
# Substring matching for cache with exact year matching
def get_closest_cache_key(message, cache=None):
    """Find the best-matching key in the response cache for *message*.

    Exact year-range mentions (e.g. "2020 and 2022", "2020 to 2022",
    "2020–2022") are matched against cache keys first; otherwise a fuzzy
    match (difflib, cutoff 0.7) is used.

    Args:
        message: Raw user query; lower-cased and stripped internally.
        cache: Optional mapping to search. Defaults to the module-level
            ``response_cache``. Exposed mainly for testing.

    Returns:
        The matching cache key, or ``None`` if nothing is close enough.

    Note: the previous ``@lru_cache(maxsize=100)`` decorator was removed.
    It memoized lookup results while ``response_cache`` keeps mutating at
    runtime, so a message that missed once kept returning the stale
    ``None`` even after a matching key was added to the cache.
    """
    keys = list((response_cache if cache is None else cache).keys())
    message = message.lower().strip()
    year_match = re.search(r'(\d{4})\s*(?:and|to|-|–)\s*(\d{4})', message)
    if year_match:
        start_year, end_year = year_match.groups()
        for key in keys:
            if (f"{start_year} and {end_year}" in key
                    or f"{start_year} to {end_year}" in key
                    or f"{start_year}–{end_year}" in key):
                return key
    matches = difflib.get_close_matches(message, keys, n=1, cutoff=0.7)
    return matches[0] if matches else None
106
 
107
  # Parse period from user input
108
  def parse_period(query):
109
+ # Match specific year ranges (e.g., "between 2020 and 2022", "2020–2022")
110
  match = re.search(r'(?:between|from)\s*(\d{4})\s*(?:and|to|-|–)\s*(\d{4})', query, re.IGNORECASE)
111
  if match:
112
  start_year, end_year = map(int, match.groups())
113
  return start_year, end_year, None
114
+ # Match duration-based queries (e.g., "1-year from 2020", "3-year growth rate")
115
  match = re.search(r'(\d+)-year.*from\s*(\d{4})', query, re.IGNORECASE)
116
  if match:
117
  duration, start_year = map(int, match.groups())
118
  end_year = start_year + duration - 1
119
  return start_year, end_year, duration
120
+ # Match general duration queries (e.g., "past 5 years", "3-year growth rate")
121
  match = re.search(r'past\s*(\d+)-year|\b(\d+)-year.*(?:return|growth\s*rate)', query, re.IGNORECASE)
122
  if match:
123
  duration = int(match.group(1) or match.group(2))
 
131
  def calculate_growth_rate(start_year, end_year, duration=None):
132
  if (start_year, end_year) in fallback_returns:
133
  avg_return = fallback_returns[(start_year, end_year)]
134
+ if duration:
135
+ response = f"The S&P 500’s {duration}-year average annual return from {start_year} to {end_year} was approximately {avg_return:.1f}%, including dividends."
136
+ else:
137
+ response = f"The S&P 500’s average annual return from {start_year} to {end_year} was approximately {avg_return:.1f}%, including dividends."
138
  return avg_return, response
139
  if df_yearly is None or start_year is None or end_year is None:
140
  return None, "Data not available or invalid period."
 
142
  if df_period.empty:
143
  return None, f"No data available for {start_year} to {end_year}."
144
  avg_return = df_period['Return'].mean()
145
+ if duration:
146
+ response = f"The S&P 500’s {duration}-year average annual return from {start_year} to {end_year} was approximately {avg_return:.1f}%, including dividends."
147
+ else:
148
+ response = f"The S&P 500’s average annual return from {start_year} to {end_year} was approximately {avg_return:.1f}%, including dividends."
149
  return avg_return, response
150
 
151
# Parse investment return query
def parse_investment_query(query):
    """Extract (amount, years) from an investment-projection question.

    Handles queries like "What will $5,000 be worth in 10 years if
    invested in the S&P 500?".

    Bug fixed: the previous pattern ``\\$(\\d+)`` stopped at the first
    comma, so "$5,000" was parsed as an amount of 5. The amount group now
    accepts thousands separators (and an optional decimal part), which are
    stripped before conversion.

    Args:
        query: Raw user query string.

    Returns:
        Tuple ``(amount, years)`` as ``(float, int)``, or ``(None, None)``
        when the query does not look like an investment projection.
    """
    match = re.search(
        r'\$([\d,]+(?:\.\d+)?).*?\b(\d+)\s*years?.*\bs&p\s*500',
        query,
        re.IGNORECASE,
    )
    if match:
        amount = float(match.group(1).replace(",", ""))
        years = int(match.group(2))
        return amount, years
    return None, None
159
+
160
# Calculate future value
def calculate_future_value(amount, years):
    """Project the future value of a lump-sum S&P 500 investment.

    Uses annual compounding at a fixed 10% assumed return. No market data
    is consulted, so the previous guard on the ``df_yearly`` dataset was
    removed — it rejected valid projection queries whenever the dataset
    failed to load even though the dataset is never used here.

    Args:
        amount: Initial investment in dollars (float or int).
        years: Holding period in whole years.

    Returns:
        Tuple ``(future_value, response)`` where ``response`` is the
        user-facing answer string, or ``(None, <error message>)`` when
        either input is missing.
    """
    if amount is None or years is None:
        return None, "Data not available or invalid input."
    avg_annual_return = 10.0  # Historical S&P 500 average (1927–2025)
    future_value = amount * (1 + avg_annual_return / 100) ** years
    return future_value, (
        f"Assuming a 10% average annual return, a ${amount:,.0f} investment in the S&P 500 would grow to approximately ${future_value:,.0f} "
        f"in {years} years with annual compounding. This is based on the historical average return of 10–12% (1927–2025). "
        "Future returns vary and are not guaranteed. Consult a financial planner."
    )
171
+
172
  # Define chat function
173
  def chat_with_model(user_input, history=None, is_processing=False):
174
  try:
 
179
 
180
  # Normalize and check cache
181
  cache_key = user_input.lower().strip()
182
+ closest_key = get_closest_cache_key(cache_key)
 
183
  if closest_key:
184
  logger.info(f"Cache hit for: {closest_key}")
185
  response = response_cache[closest_key]
 
191
  logger.info(f"Response time: {end_time - start_time:.2f} seconds")
192
  return response, history, False, ""
193
 
194
+ # Check for investment return query
195
+ amount, years = parse_investment_query(user_input)
196
+ if amount and years:
197
+ future_value, response = calculate_future_value(amount, years)
198
+ if future_value is not None:
199
+ response_cache[cache_key] = response
200
+ logger.info(f"Investment query: ${amount} for {years} years, added to cache")
201
+ logger.info(f"Chatbot response: {response}")
202
+ history = history or []
203
+ history.append({"role": "user", "content": user_input})
204
+ history.append({"role": "assistant", "content": response})
205
+ end_time = time.time()
206
+ logger.info(f"Response time: {end_time - start_time:.2f} seconds")
207
+ return response, history, False, ""
208
+
209
  # Check for period-specific query
210
+ start_year, end_year, duration = parse_period(user_input)
211
  if start_year and end_year:
212
+ avg_return, response = calculate_growth_rate(start_year, end_year, duration)
213
  if avg_return is not None:
214
  response_cache[cache_key] = response
215
  logger.info(f"Dynamic period query: {start_year}–{end_year}, added to cache")
 
221
  logger.info(f"Response time: {end_time - start_time:.2f} seconds")
222
  return response, history, False, ""
223
 
224
+ # Skip model for short prompts
225
+ if len(user_input.strip()) <= 5:
226
+ logger.info("Short prompt, returning default response")
227
+ response = "Hello! I'm FinChat, your financial advisor. Ask about investing!"
228
+ logger.info(f"Chatbot response: {response}")
229
+ history = history or []
230
+ history.append({"role": "user", "content": user_input})
231
+ history.append({"role": "assistant", "content": response})
232
+ end_time = time.time()
233
+ logger.info(f"Response time: {end_time - start_time:.2f} seconds")
234
+ return response, history, False, ""
235
+
236
+ # Construct prompt
237
+ full_prompt = prompt_prefix + user_input + "\nA:"
238
+ try:
239
+ inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=512)["input_ids"].to(device)
240
+ except Exception as e:
241
+ logger.error(f"Error tokenizing input: {e}")
242
+ response = f"Error: Failed to process input: {str(e)}"
243
+ logger.info(f"Chatbot response: {response}")
244
+ history = history or []
245
+ history.append({"role": "user", "content": user_input})
246
+ history.append({"role": "assistant", "content": response})
247
+ end_time = time.time()
248
+ logger.info(f"Response time: {end_time - start_time:.2f} seconds")
249
+ return response, history, False, ""
250
+
251
+ # Generate response
252
+ with torch.inference_mode():
253
+ logger.info("Generating response with model")
254
+ gen_start_time = time.time()
255
+ outputs = model.generate(
256
+ inputs,
257
+ max_new_tokens=20,
258
+ min_length=10,
259
+ do_sample=False,
260
+ repetition_penalty=3.0,
261
+ no_repeat_ngram_size=2,
262
+ pad_token_id=tokenizer.eos_token_id
263
+ )
264
+ gen_end_time = time.time()
265
+ logger.info(f"Generation time: {gen_end_time - gen_start_time:.2f} seconds")
266
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
267
+ response = response[len(full_prompt):].strip() if response.startswith(full_prompt) else response
268
  logger.info(f"Chatbot response: {response}")
269
 
270
  # Update cache
 
275
  history = history or []
276
  history.append({"role": "user", "content": user_input})
277
  history.append({"role": "assistant", "content": response})
278
+ torch.cuda.empty_cache()
279
  end_time = time.time()
280
  logger.info(f"Response time: {end_time - start_time:.2f} seconds")
281
  return response, history, False, ""
 
294
  # Save cache on exit
295
  def save_cache():
296
  try:
297
+ with open("cache.json", "w") as f:
298
  json.dump(response_cache, f, indent=2)
299
  logger.info("Saved cache to cache.json")
300
  except Exception as e: