Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -119,10 +119,11 @@ def parse_dates(query):
|
|
| 119 |
end_date = datetime.now()
|
| 120 |
start_date = end_date - period
|
| 121 |
return start_date, end_date
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
|
|
|
| 126 |
|
| 127 |
def find_closest_symbol(input_symbol):
|
| 128 |
input_symbol = input_symbol.upper()
|
|
@@ -159,11 +160,8 @@ model = AutoModelForCausalLM.from_pretrained(
|
|
| 159 |
)
|
| 160 |
|
| 161 |
def generate_response(user_query, enable_thinking=False):
|
| 162 |
-
print(f"Processing query: {user_query}") # Debugging
|
| 163 |
stock_keywords = ['stock', 'growth', 'investment', 'price', 'return', 'cagr']
|
| 164 |
is_stock_query = any(keyword in user_query.lower() for keyword in stock_keywords)
|
| 165 |
-
print(f"Is stock query: {is_stock_query}") # Debugging
|
| 166 |
-
summary = ""
|
| 167 |
if is_stock_query:
|
| 168 |
# Try to find symbol from company name or ticker
|
| 169 |
symbol = None
|
|
@@ -174,59 +172,34 @@ def generate_response(user_query, enable_thinking=False):
|
|
| 174 |
if not symbol:
|
| 175 |
symbol_match = re.search(r'\b([A-Z]{1,5})\b', user_query.upper())
|
| 176 |
symbol = find_closest_symbol(symbol_match.group(1)) if symbol_match else None
|
| 177 |
-
print(f"Detected symbol: {symbol}") # Debugging
|
| 178 |
if symbol:
|
| 179 |
start_date, end_date = parse_dates(user_query)
|
| 180 |
-
print(f"Parsed dates: {start_date} to {end_date}") # Debugging
|
| 181 |
if start_date is None or end_date is None:
|
| 182 |
-
|
| 183 |
else:
|
| 184 |
hist = fetch_stock_data(symbol, start_date.strftime('%Y-%m-%d'), end_date.strftime('%Y-%m-%d'))
|
| 185 |
-
print(f"Data fetched for {symbol}: { 'Valid' if hist is not None and not hist.empty else 'Empty/None'}") # Debugging
|
| 186 |
if "average price" in user_query.lower():
|
| 187 |
if hist is not None and not hist.empty and 'Close' in hist.columns:
|
| 188 |
avg_price = hist['Close'].mean()
|
| 189 |
-
|
| 190 |
else:
|
| 191 |
-
|
| 192 |
elif "cagr" in user_query.lower() or "return" in user_query.lower():
|
| 193 |
growth_rate = calculate_growth_rate(start_date, end_date, symbol)
|
| 194 |
if growth_rate is not None:
|
| 195 |
-
|
| 196 |
else:
|
| 197 |
-
|
| 198 |
investment_match = re.search(r'\$(\d+)', user_query)
|
| 199 |
if investment_match:
|
| 200 |
principal = float(investment_match.group(1))
|
| 201 |
years = (end_date - start_date).days / 365.25
|
| 202 |
projected = calculate_investment(principal, years)
|
| 203 |
-
|
| 204 |
else:
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
system_prompt = (
|
| 209 |
-
"You are FinChat, a knowledgeable financial advisor. Always respond in a friendly, professional manner. "
|
| 210 |
-
"For greetings like 'Hi' or 'Hello', reply warmly, e.g., 'Hi! I'm FinChat, your financial advisor. What can I help you with today regarding stocks, investments, or advice?' "
|
| 211 |
-
"Provide accurate, concise advice based on the provided data summary. If no data is available, suggest alternatives politely. "
|
| 212 |
-
"Use the data summary to inform your response where relevant."
|
| 213 |
-
)
|
| 214 |
-
|
| 215 |
-
# Integrate thinking if enabled (simple chain-of-thought prompt)
|
| 216 |
-
thinking_prompt = "Think step by step before responding. " if enable_thinking else ""
|
| 217 |
-
|
| 218 |
-
# Prepare messages for the model
|
| 219 |
-
messages = [
|
| 220 |
-
{"role": "system", "content": system_prompt},
|
| 221 |
-
{"role": "user", "content": f"{thinking_prompt}Query: {user_query}\nData summary: {summary if summary else 'No specific data computed; respond generally based on knowledge.'}"}
|
| 222 |
-
]
|
| 223 |
-
|
| 224 |
-
# Generate response using the model
|
| 225 |
-
input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt").to(model.device)
|
| 226 |
-
outputs = model.generate(input_ids, max_new_tokens=200, do_sample=True, temperature=0.7)
|
| 227 |
-
response = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
|
| 228 |
-
|
| 229 |
-
return response
|
| 230 |
|
| 231 |
# Gradio interface setup
|
| 232 |
demo = gr.Interface(
|
|
|
|
| 119 |
end_date = datetime.now()
|
| 120 |
start_date = end_date - period
|
| 121 |
return start_date, end_date
|
| 122 |
+
else:
|
| 123 |
+
# Default to 1 year
|
| 124 |
+
end_date = datetime.now()
|
| 125 |
+
start_date = end_date - timedelta(days=365)
|
| 126 |
+
return start_date, end_date
|
| 127 |
|
| 128 |
def find_closest_symbol(input_symbol):
|
| 129 |
input_symbol = input_symbol.upper()
|
|
|
|
| 160 |
)
|
| 161 |
|
| 162 |
def generate_response(user_query, enable_thinking=False):
|
|
|
|
| 163 |
stock_keywords = ['stock', 'growth', 'investment', 'price', 'return', 'cagr']
|
| 164 |
is_stock_query = any(keyword in user_query.lower() for keyword in stock_keywords)
|
|
|
|
|
|
|
| 165 |
if is_stock_query:
|
| 166 |
# Try to find symbol from company name or ticker
|
| 167 |
symbol = None
|
|
|
|
| 172 |
if not symbol:
|
| 173 |
symbol_match = re.search(r'\b([A-Z]{1,5})\b', user_query.upper())
|
| 174 |
symbol = find_closest_symbol(symbol_match.group(1)) if symbol_match else None
|
|
|
|
| 175 |
if symbol:
|
| 176 |
start_date, end_date = parse_dates(user_query)
|
|
|
|
| 177 |
if start_date is None or end_date is None:
|
| 178 |
+
return "Invalid date range."
|
| 179 |
else:
|
| 180 |
hist = fetch_stock_data(symbol, start_date.strftime('%Y-%m-%d'), end_date.strftime('%Y-%m-%d'))
|
|
|
|
| 181 |
if "average price" in user_query.lower():
|
| 182 |
if hist is not None and not hist.empty and 'Close' in hist.columns:
|
| 183 |
avg_price = hist['Close'].mean()
|
| 184 |
+
return f"${avg_price:.2f}"
|
| 185 |
else:
|
| 186 |
+
return "No data."
|
| 187 |
elif "cagr" in user_query.lower() or "return" in user_query.lower():
|
| 188 |
growth_rate = calculate_growth_rate(start_date, end_date, symbol)
|
| 189 |
if growth_rate is not None:
|
| 190 |
+
return f"{growth_rate:.2f}%"
|
| 191 |
else:
|
| 192 |
+
return "No data."
|
| 193 |
investment_match = re.search(r'\$(\d+)', user_query)
|
| 194 |
if investment_match:
|
| 195 |
principal = float(investment_match.group(1))
|
| 196 |
years = (end_date - start_date).days / 365.25
|
| 197 |
projected = calculate_investment(principal, years)
|
| 198 |
+
return f"${projected:.2f}"
|
| 199 |
else:
|
| 200 |
+
return "Invalid symbol."
|
| 201 |
+
else:
|
| 202 |
+
return "Hello! Ask about stocks or investments."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
|
| 204 |
# Gradio interface setup
|
| 205 |
demo = gr.Interface(
|