arjunanand13 commited on
Commit
c23f76f
·
verified ·
1 Parent(s): e11e1da

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +815 -0
app.py CHANGED
@@ -0,0 +1,815 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import yfinance as yf
3
+ import pandas as pd
4
+ import numpy as np
5
+ import plotly.graph_objects as go
6
+ from plotly.subplots import make_subplots
7
+ from datetime import datetime, timedelta
8
+ from langchain_huggingface import HuggingFaceEndpoint
9
+ from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
10
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
11
+ import chromadb
12
+ import requests
13
+ from bs4 import BeautifulSoup
14
+ import warnings
15
+ from typing import Dict, List, Tuple
16
+ import feedparser
17
+ from sentence_transformers import SentenceTransformer
18
+ import faiss
19
+ import json
20
+ import os
21
+
22
+ warnings.filterwarnings('ignore')
23
+
24
+
25
+
26
+ COMPANIES = {
27
+ 'Apple (AAPL)': 'AAPL',
28
+ 'Microsoft (MSFT)': 'MSFT',
29
+ 'Amazon (AMZN)': 'AMZN',
30
+ 'Google (GOOGL)': 'GOOGL',
31
+ 'Meta (META)': 'META',
32
+ 'Tesla (TSLA)': 'TSLA',
33
+ 'NVIDIA (NVDA)': 'NVDA',
34
+ 'JPMorgan Chase (JPM)': 'JPM',
35
+ 'Johnson & Johnson (JNJ)': 'JNJ',
36
+ 'Walmart (WMT)': 'WMT',
37
+ 'Visa (V)': 'V',
38
+ 'Mastercard (MA)': 'MA',
39
+ 'Procter & Gamble (PG)': 'PG',
40
+ 'UnitedHealth (UNH)': 'UNH',
41
+ 'Home Depot (HD)': 'HD',
42
+ 'Bank of America (BAC)': 'BAC',
43
+ 'Coca-Cola (KO)': 'KO',
44
+ 'Pfizer (PFE)': 'PFE',
45
+ 'Disney (DIS)': 'DIS',
46
+ 'Netflix (NFLX)': 'NFLX'
47
+ }
48
+
49
+ # Initialize models
50
+ print("Initializing models...")
51
+ api_token = os.getenv(TOKEN)
52
+ llm = HuggingFaceEndpoint(
53
+ repo_id="mistralai/Mistral-7B-Instruct-v0.2",
54
+ huggingfacehub_api_token=api_token,
55
+ temperature=0.7,
56
+ max_new_tokens=1000
57
+ )
58
+ vader = SentimentIntensityAnalyzer()
59
+ finbert = pipeline("sentiment-analysis",
60
+ model="ProsusAI/finbert")
61
+ print("Models initialized successfully!")
62
+ class AgenticRAGFramework:
63
+ """Main framework coordinating all agents"""
64
+ def __init__(self):
65
+ self.technical_agent = TechnicalAnalysisAgent()
66
+ self.sentiment_agent = SentimentAnalysisAgent()
67
+ self.llama_agent = LLMAgent()
68
+ self.knowledge_base = chromadb.Client()
69
+
70
+ def analyze(self, symbol: str, data: pd.DataFrame) -> Dict:
71
+ """Perform comprehensive analysis"""
72
+ technical_analysis = self.technical_agent.analyze(data)
73
+ sentiment_analysis = self.sentiment_agent.analyze(symbol)
74
+ llm_analysis = self.llama_agent.generate_analysis(
75
+ technical_analysis,
76
+ sentiment_analysis
77
+ )
78
+
79
+ return {
80
+ 'technical_analysis': technical_analysis,
81
+ 'sentiment_analysis': sentiment_analysis,
82
+ 'llm_analysis': llm_analysis
83
+ }
84
+
85
+
86
+ class NewsSource:
87
+ """Base class for news sources"""
88
+ def get_news(self, company: str) -> List[Dict]:
89
+ raise NotImplementedError
90
+
91
+ class FinvizNews(NewsSource):
92
+ """Fetch news from FinViz"""
93
+ def get_news(self, company: str) -> List[Dict]:
94
+ try:
95
+ ticker = company.split('(')[-1].replace(')', '')
96
+ url = f"https://finviz.com/quote.ashx?t={ticker}"
97
+ headers = {
98
+ 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
99
+ }
100
+
101
+ response = requests.get(url, headers=headers)
102
+ soup = BeautifulSoup(response.text, 'html.parser')
103
+ news_table = soup.find('table', {'class': 'news-table'})
104
+
105
+ if not news_table:
106
+ return []
107
+
108
+ news_list = []
109
+ for row in news_table.find_all('tr')[:5]:
110
+ cols = row.find_all('td')
111
+ if len(cols) >= 2:
112
+ date = cols[0].text.strip()
113
+ title = cols[1].a.text.strip()
114
+ link = cols[1].a['href']
115
+
116
+ news_list.append({
117
+ 'title': title,
118
+ 'description': title,
119
+ 'date': date,
120
+ 'source': 'FinViz',
121
+ 'url': link
122
+ })
123
+
124
+ return news_list
125
+ except Exception as e:
126
+ print(f"FinViz Error: {str(e)}")
127
+ return []
128
+
129
+ class MarketWatchNews(NewsSource):
130
+ """Fetch news from MarketWatch"""
131
+ def get_news(self, company: str) -> List[Dict]:
132
+ try:
133
+ ticker = company.split('(')[-1].replace(')', '')
134
+ url = f"https://www.marketwatch.com/investing/stock/{ticker}"
135
+ headers = {
136
+ 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
137
+ }
138
+
139
+ response = requests.get(url, headers=headers)
140
+ soup = BeautifulSoup(response.text, 'html.parser')
141
+ news_elements = soup.find_all('div', {'class': 'article__content'})
142
+
143
+ news_list = []
144
+ for element in news_elements[:5]:
145
+ title_elem = element.find('a', {'class': 'link'})
146
+ if title_elem:
147
+ title = title_elem.text.strip()
148
+ link = title_elem['href']
149
+ date_elem = element.find('span', {'class': 'article__timestamp'})
150
+ date = date_elem.text if date_elem else 'Recent'
151
+
152
+ news_list.append({
153
+ 'title': title,
154
+ 'description': title,
155
+ 'date': date,
156
+ 'source': 'MarketWatch',
157
+ 'url': link
158
+ })
159
+
160
+ return news_list
161
+ except Exception as e:
162
+ print(f"MarketWatch Error: {str(e)}")
163
+ return []
164
+
165
+ class YahooRSSNews(NewsSource):
166
+ """Fetch news from Yahoo Finance RSS feed"""
167
+ def get_news(self, company: str) -> List[Dict]:
168
+ try:
169
+ ticker = company.split('(')[-1].replace(')', '')
170
+ url = f"https://feeds.finance.yahoo.com/rss/2.0/headline?s={ticker}&region=US&lang=en-US"
171
+
172
+ feed = feedparser.parse(url)
173
+ news_list = []
174
+
175
+ for entry in feed.entries[:5]:
176
+ news_list.append({
177
+ 'title': entry.title,
178
+ 'description': entry.description,
179
+ 'date': entry.published,
180
+ 'source': 'Yahoo Finance',
181
+ 'url': entry.link
182
+ })
183
+
184
+ return news_list
185
+ except Exception as e:
186
+ print(f"Yahoo RSS Error: {str(e)}")
187
+ return []
188
+
189
+ class TechnicalAnalysisAgent:
190
+ """Agent for technical analysis"""
191
+ def __init__(self):
192
+ self.required_periods = {
193
+ 'sma': [20, 50, 200],
194
+ 'rsi': 14,
195
+ 'volatility': 20,
196
+ 'macd': [12, 26, 9]
197
+ }
198
+
199
+ def analyze(self, data: pd.DataFrame) -> Dict:
200
+ df = data.copy()
201
+ close_col = ('Close', df.columns.get_level_values(1)[0])
202
+
203
+ # Calculate metrics
204
+ df['Returns'] = df[close_col].pct_change()
205
+
206
+ # SMAs
207
+ for period in self.required_periods['sma']:
208
+ df[f'SMA_{period}'] = df[close_col].rolling(window=period).mean()
209
+
210
+ # RSI
211
+ delta = df[close_col].diff()
212
+ gain = delta.where(delta > 0, 0).rolling(window=14).mean()
213
+ loss = -delta.where(delta < 0, 0).rolling(window=14).mean()
214
+ rs = gain / loss
215
+ df['RSI'] = 100 - (100 / (1 + rs))
216
+
217
+ # MACD
218
+ exp1 = df[close_col].ewm(span=12, adjust=False).mean()
219
+ exp2 = df[close_col].ewm(span=26, adjust=False).mean()
220
+ df['MACD'] = exp1 - exp2
221
+ df['Signal_Line'] = df['MACD'].ewm(span=9, adjust=False).mean()
222
+
223
+ # Bollinger Bands
224
+ df['BB_middle'] = df[close_col].rolling(window=20).mean()
225
+ rolling_std = df[close_col].rolling(window=20).std()
226
+ df['BB_upper'] = df['BB_middle'] + (2 * rolling_std)
227
+ df['BB_lower'] = df['BB_middle'] - (2 * rolling_std)
228
+
229
+ return {
230
+ 'processed_data': df,
231
+ 'current_signals': self._generate_signals(df, close_col)
232
+ }
233
+
234
+ def _generate_signals(self, df: pd.DataFrame, close_col) -> Dict:
235
+ if df.empty:
236
+ return {
237
+ 'trend': 'Unknown',
238
+ 'rsi_signal': 'Unknown',
239
+ 'macd_signal': 'Unknown',
240
+ 'bb_position': 'Unknown'
241
+ }
242
+
243
+ current = df.iloc[-1]
244
+
245
+ trend = 'Bullish' if float(current['SMA_20']) > float(current['SMA_50']) else 'Bearish'
246
+
247
+ rsi_value = float(current['RSI'])
248
+ if rsi_value > 70:
249
+ rsi_signal = 'Overbought'
250
+ elif rsi_value < 30:
251
+ rsi_signal = 'Oversold'
252
+ else:
253
+ rsi_signal = 'Neutral'
254
+
255
+ macd_signal = 'Buy' if float(current['MACD']) > float(current['Signal_Line']) else 'Sell'
256
+
257
+ close_value = float(current[close_col])
258
+ bb_upper = float(current['BB_upper'])
259
+ bb_lower = float(current['BB_lower'])
260
+
261
+ if close_value > bb_upper:
262
+ bb_position = 'Above Upper Band'
263
+ elif close_value < bb_lower:
264
+ bb_position = 'Below Lower Band'
265
+ else:
266
+ bb_position = 'Within Bands'
267
+
268
+ return {
269
+ 'trend': trend,
270
+ 'rsi_signal': rsi_signal,
271
+ 'macd_signal': macd_signal,
272
+ 'bb_position': bb_position
273
+ }
274
+
275
+ class SentimentAnalysisAgent:
276
+ """Agent for sentiment analysis"""
277
+ def __init__(self):
278
+ self.news_sources = [
279
+ FinvizNews(),
280
+ MarketWatchNews(),
281
+ YahooRSSNews()
282
+ ]
283
+
284
+ def analyze(self, symbol: str) -> Dict:
285
+ all_news = []
286
+ for source in self.news_sources:
287
+ news_items = source.get_news(symbol)
288
+ all_news.extend(news_items)
289
+
290
+ vader_scores = []
291
+ finbert_scores = []
292
+
293
+ for article in all_news:
294
+ vader_scores.append(vader.polarity_scores(article['title']))
295
+ finbert_scores.append(
296
+ finbert(article['title'][:512])[0]
297
+ )
298
+
299
+ return {
300
+ 'articles': all_news,
301
+ 'vader_scores': vader_scores,
302
+ 'finbert_scores': finbert_scores,
303
+ 'aggregated': self._aggregate_sentiment(vader_scores, finbert_scores)
304
+ }
305
+
306
+ def _aggregate_sentiment(self, vader_scores: List[Dict],
307
+ finbert_scores: List[Dict]) -> Dict:
308
+ if not vader_scores or not finbert_scores:
309
+ return {
310
+ 'sentiment': 'Neutral',
311
+ 'confidence': 0,
312
+ 'vader_sentiment': 0,
313
+ 'finbert_sentiment': 0
314
+ }
315
+
316
+ avg_vader = np.mean([score['compound'] for score in vader_scores])
317
+ avg_finbert = np.mean([
318
+ 1 if score['label'] == 'positive' else -1
319
+ for score in finbert_scores
320
+ ])
321
+
322
+ combined_score = (avg_vader + avg_finbert) / 2
323
+
324
+ return {
325
+ 'sentiment': 'Bullish' if combined_score > 0.1 else 'Bearish' if combined_score < -0.1 else 'Neutral',
326
+ 'confidence': abs(combined_score),
327
+ 'vader_sentiment': avg_vader,
328
+ 'finbert_sentiment': avg_finbert
329
+ }
330
+
331
+ class LLMAgent:
332
+ """Agent for LLM-based analysis using HuggingFace API"""
333
+ def __init__(self):
334
+ self.llm = llm
335
+
336
+ def generate_analysis(self, technical_data: Dict, sentiment_data: Dict) -> str:
337
+ prompt = self._create_prompt(technical_data, sentiment_data)
338
+
339
+ response = self.llm.invoke(prompt)
340
+ return response
341
+
342
+ def _create_prompt(self, technical_data: Dict, sentiment_data: Dict) -> str:
343
+ return f"""Based on technical and sentiment indicators:
344
+
345
+ Technical Signals:
346
+ - Trend: {technical_data['current_signals']['trend']}
347
+ - RSI: {technical_data['current_signals']['rsi_signal']}
348
+ - MACD: {technical_data['current_signals']['macd_signal']}
349
+ - BB Position: {technical_data['current_signals']['bb_position']}
350
+ - Sentiment: {sentiment_data['aggregated']['sentiment']} (Confidence: {sentiment_data['aggregated']['confidence']:.2f})
351
+
352
+ Provide:
353
+ 1. Current Trend Analysis
354
+ 2. Key Risk Factors
355
+ 3. Trading Recommendations
356
+ 4. Price Targets
357
+ 5. Near-term Outlook (1-2 weeks)
358
+
359
+ Note: return only required information and nothing unnecessary"""
360
+
361
+ # class ChatbotRouter:
362
+ # """Routes chatbot queries to appropriate data sources and generates responses"""
363
+ # def __init__(self):
364
+ # self.llm = llm
365
+ # self.encoder = SentenceTransformer('all-MiniLM-L6-v2')
366
+ # self.faiss_index = None
367
+ # self.company_data = {}
368
+ # self.news_sources = [
369
+ # FinvizNews(),
370
+ # MarketWatchNews(),
371
+ # YahooRSSNews()
372
+ # ]
373
+ # self.load_faiss_index()
374
+
375
+ # def route_and_respond(self, query: str, company: str) -> str:
376
+ # query_type = self._classify_query(query.lower())
377
+ # route_message = f"\n[Taking {query_type.upper()} route]\n\n"
378
+
379
+ # if query_type == "company_info":
380
+ # context = self._get_company_context(query, company)
381
+ # elif query_type == "news":
382
+ # context = self._get_news_context(company)
383
+ # elif query_type == "price":
384
+ # context = self._get_price_context(company)
385
+ # else:
386
+ # return route_message + "I'm not sure how to handle this query. Please ask about company information, news, or price data."
387
+
388
+ # prompt = self._create_prompt(query, context, query_type)
389
+ # response = self.llm.invoke(prompt)
390
+
391
+ # return route_message + response
392
+
393
+ class ChatbotRouter:
394
+ """Routes chatbot queries to appropriate data sources and generates responses"""
395
+ def __init__(self):
396
+ self.llm = llm
397
+ self.encoder = SentenceTransformer('all-MiniLM-L6-v2')
398
+ self.faiss_index = None
399
+ self.company_data = {}
400
+ self.news_sources = [
401
+ FinvizNews(),
402
+ MarketWatchNews(),
403
+ YahooRSSNews()
404
+ ]
405
+ self.load_faiss_index()
406
+
407
+ def load_faiss_index(self):
408
+ try:
409
+ self.faiss_index = faiss.read_index("company_profiles.index")
410
+ for file in os.listdir('company_data'):
411
+ with open(f'company_data/{file}', 'r') as f:
412
+ company_name = file.replace('.txt', '')
413
+ self.company_data[company_name] = json.load(f)
414
+ except Exception as e:
415
+ print(f"Error loading FAISS index: {e}")
416
+
417
+ def route_and_respond(self, query: str, company: str) -> str:
418
+ query_type = self._classify_query(query.lower())
419
+ route_message = f"\n[Taking {query_type.upper()} route]\n\n"
420
+
421
+ if query_type == "company_info":
422
+ context = self._get_company_context(query, company)
423
+ elif query_type == "news":
424
+ context = self._get_news_context(company)
425
+ elif query_type == "price":
426
+ context = self._get_price_context(company)
427
+ else:
428
+ return route_message + "I'm not sure how to handle this query. Please ask about company information, news, or price data."
429
+
430
+ prompt = self._create_prompt(query, context, query_type)
431
+ response = self.llm.invoke(prompt)
432
+
433
+ return route_message + response
434
+
435
+ def _classify_query(self, query: str) -> str:
436
+ """Classify query type"""
437
+ if any(word in query for word in ["profile", "about", "information", "details", "what", "who", "describe"]):
438
+ return "company_info"
439
+ elif any(word in query for word in ["news", "latest", "recent", "announcement", "update"]):
440
+ return "news"
441
+ elif any(word in query for word in ["price", "stock", "value", "market", "trading", "cost"]):
442
+ return "price"
443
+ return "unknown"
444
+
445
+ def _get_company_context(self, query: str, company: str) -> str:
446
+ """Get relevant company information using FAISS"""
447
+ try:
448
+ query_vector = self.encoder.encode([query])
449
+ D, I = self.faiss_index.search(query_vector, 1)
450
+
451
+ company_name = company.split(" (")[0]
452
+ company_info = self.company_data.get(company_name, {})
453
+ print(company_info)
454
+ return company_info
455
+
456
+ except Exception as e:
457
+ return f"Error retrieving company information: {str(e)}"
458
+
459
+ def _get_news_context(self, company: str) -> str:
460
+ """Get news from multiple sources"""
461
+ all_news = []
462
+
463
+ for source in self.news_sources:
464
+ news_items = source.get_news(company)
465
+ all_news.extend(news_items)
466
+
467
+ seen_titles = set()
468
+ unique_news = []
469
+ for news in all_news:
470
+ if news['title'] not in seen_titles:
471
+ seen_titles.add(news['title'])
472
+ unique_news.append(news)
473
+
474
+ if not unique_news:
475
+ return "No recent news found."
476
+
477
+ news_context = "Recent news articles:\n\n"
478
+ for news in unique_news[:5]:
479
+ news_context += f"Source: {news['source']}\n"
480
+ news_context += f"Title: {news['title']}\n"
481
+ if news['description']:
482
+ news_context += f"Description: {news['description']}\n"
483
+ news_context += f"Date: {news['date']}\n\n"
484
+
485
+ return news_context
486
+
487
+ def _get_price_context(self, company: str) -> str:
488
+ """Get current price information"""
489
+ try:
490
+ ticker = company.split('(')[-1].replace(')', '')
491
+ stock = yf.Ticker(ticker)
492
+ info = stock.info
493
+
494
+ return f"""Current Stock Information:
495
+ Price: ${info.get('currentPrice', 'N/A')}
496
+ Day Range: ${info.get('dayLow', 'N/A')} - ${info.get('dayHigh', 'N/A')}
497
+ 52 Week Range: ${info.get('fiftyTwoWeekLow', 'N/A')} - ${info.get('fiftyTwoWeekHigh', 'N/A')}
498
+ Market Cap: ${info.get('marketCap', 'N/A'):,}
499
+ Volume: {info.get('volume', 'N/A'):,}
500
+ P/E Ratio: {info.get('trailingPE', 'N/A')}
501
+ Dividend Yield: {info.get('dividendYield', 'N/A')}%"""
502
+
503
+ except Exception as e:
504
+ return f"Error fetching price data: {str(e)}"
505
+
506
+ def _create_prompt(self, query: str, context: str, query_type: str) -> str:
507
+ """Create prompt for LLM"""
508
+ if query_type == "news":
509
+ return f"""Based on the following news articles, please provide a summary addressing the query.
510
+
511
+ Context:
512
+ {context}
513
+
514
+ Query: {query}
515
+
516
+ Please analyze the news and provide:
517
+ 1. Key points from the recent articles
518
+ 2. Any significant developments or trends
519
+ 3. Potential impact on the company
520
+ 4. Overall sentiment (positive/negative/neutral)
521
+
522
+ Response should be clear, concise, and focused on the most relevant information."""
523
+ else:
524
+ return f"""Based on the following {query_type} context, please answer the question.
525
+
526
+ Context:
527
+ {context}
528
+
529
+ Question: {query}
530
+
531
+ Please provide a clear and concise answer based on the given context."""
532
+
533
+ def _generate_response(self, prompt: str) -> str:
534
+ """Generate response using LLM"""
535
+ inputs = self.llm_agent.tokenizer(prompt, return_tensors="pt").to(self.llm_agent.model.device)
536
+ outputs = self.llm_agent.model.generate(
537
+ inputs["input_ids"],
538
+ max_new_tokens=200,
539
+ temperature=0.7,
540
+ num_return_sequences=1
541
+ )
542
+ # Decode and remove the prompt part from the output
543
+ response = self.llm_agent.tokenizer.decode(outputs[0], skip_special_tokens=True)
544
+ response_only = response.replace(prompt, "").strip()
545
+ print(response)
546
+ return response_only
547
+
548
+ def analyze_stock(company: str, lookback_days: int = 180) -> Tuple[str, go.Figure, go.Figure]:
549
+ """Main analysis function"""
550
+ try:
551
+ symbol = COMPANIES[company]
552
+ end_date = datetime.now()
553
+ start_date = end_date - timedelta(days=lookback_days)
554
+
555
+ data = yf.download(symbol, start=start_date, end=end_date)
556
+ if len(data) == 0:
557
+ return "No data available.", None, None
558
+
559
+ framework = AgenticRAGFramework()
560
+ analysis = framework.analyze(symbol, data)
561
+
562
+ plots = create_plots(analysis)
563
+
564
+ return analysis['llm_analysis'], plots[0], plots[1]
565
+
566
+ except Exception as e:
567
+ return f"Error analyzing stock: {str(e)}", None, None
568
+
569
+ def create_plots(analysis: Dict) -> List[go.Figure]:
570
+ """Create analysis plots"""
571
+ data = analysis['technical_analysis']['processed_data']
572
+
573
+ # Price and Volume Plot
574
+ fig1 = make_subplots(
575
+ rows=2, cols=1,
576
+ shared_xaxes=True,
577
+ vertical_spacing=0.03,
578
+ subplot_titles=('Price Analysis', 'Volume'),
579
+ row_heights=[0.7, 0.3]
580
+ )
581
+
582
+ close_col = ('Close', data.columns.get_level_values(1)[0])
583
+ open_col = ('Open', data.columns.get_level_values(1)[0])
584
+ volume_col = ('Volume', data.columns.get_level_values(1)[0])
585
+
586
+ fig1.add_trace(
587
+ go.Scatter(x=data.index, y=data[close_col], name='Price',
588
+ line=dict(color='blue', width=2)),
589
+ row=1, col=1
590
+ )
591
+ fig1.add_trace(
592
+ go.Scatter(x=data.index, y=data['SMA_20'], name='SMA20',
593
+ line=dict(color='orange', width=1.5)),
594
+ row=1, col=1
595
+ )
596
+ fig1.add_trace(
597
+ go.Scatter(x=data.index, y=data['SMA_50'], name='SMA50',
598
+ line=dict(color='red', width=1.5)),
599
+ row=1, col=1
600
+ )
601
+
602
+ colors = ['red' if float(row[close_col]) < float(row[open_col]) else 'green'
603
+ for idx, row in data.iterrows()]
604
+
605
+ fig1.add_trace(
606
+ go.Bar(x=data.index, y=data[volume_col], marker_color=colors, name='Volume'),
607
+ row=2, col=1
608
+ )
609
+
610
+ fig1.update_layout(
611
+ height=400,
612
+ showlegend=True,
613
+ xaxis_rangeslider_visible=False,
614
+ plot_bgcolor='white',
615
+ paper_bgcolor='white'
616
+ )
617
+
618
+ # Technical Indicators Plot
619
+ fig2 = make_subplots(
620
+ rows=3, cols=1,
621
+ shared_xaxes=True,
622
+ subplot_titles=('RSI', 'MACD', 'Bollinger Bands'),
623
+ row_heights=[0.33, 0.33, 0.34],
624
+ vertical_spacing=0.03
625
+ )
626
+
627
+ # RSI
628
+ fig2.add_trace(
629
+ go.Scatter(x=data.index, y=data['RSI'], name='RSI',
630
+ line=dict(color='purple', width=1.5)),
631
+ row=1, col=1
632
+ )
633
+ fig2.add_hline(y=70, line_dash="dash", line_color="red", row=1, col=1)
634
+ fig2.add_hline(y=30, line_dash="dash", line_color="green", row=1, col=1)
635
+
636
+ # MACD
637
+ fig2.add_trace(
638
+ go.Scatter(x=data.index, y=data['MACD'], name='MACD',
639
+ line=dict(color='blue', width=1.5)),
640
+ row=2, col=1
641
+ )
642
+ fig2.add_trace(
643
+ go.Scatter(x=data.index, y=data['Signal_Line'], name='Signal',
644
+ line=dict(color='orange', width=1.5)),
645
+ row=2, col=1
646
+ )
647
+
648
+ # Bollinger Bands
649
+ fig2.add_trace(
650
+ go.Scatter(x=data.index, y=data[close_col], name='Price',
651
+ line=dict(color='blue', width=2)),
652
+ row=3, col=1
653
+ )
654
+ fig2.add_trace(
655
+ go.Scatter(x=data.index, y=data['BB_upper'], name='Upper BB',
656
+ line=dict(color='gray', dash='dash')),
657
+ row=3, col=1
658
+ )
659
+ fig2.add_trace(
660
+ go.Scatter(x=data.index, y=data['BB_lower'], name='Lower BB',
661
+ line=dict(color='gray', dash='dash')),
662
+ row=3, col=1
663
+ )
664
+
665
+ fig2.update_layout(
666
+ height=400,
667
+ showlegend=True,
668
+ plot_bgcolor='white',
669
+ paper_bgcolor='white'
670
+ )
671
+
672
+ return [fig1, fig2]
673
+
674
+ def chatbot_response(message: str, company: str, history: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
675
+ """Handle chatbot interactions"""
676
+ router = ChatbotRouter(LlamaAgent())
677
+ response = router.route_and_respond(message, company)
678
+ history = history + [(message, response)]
679
+ return history
680
+
681
+ # def create_interface():
682
+ # """Create Gradio interface"""
683
+ # with gr.Blocks() as interface:
684
+ # gr.Markdown("# Stock Analysis with Multi-Source News")
685
+
686
+ # with gr.Row():
687
+ # with gr.Column(scale=2):
688
+ # company = gr.Dropdown(
689
+ # choices=list(COMPANIES.keys()),
690
+ # value=list(COMPANIES.keys())[0],
691
+ # label="Company"
692
+ # )
693
+ # lookback = gr.Slider(
694
+ # minimum=30,
695
+ # maximum=365,
696
+ # value=180,
697
+ # step=1,
698
+ # label="Analysis Period (days)"
699
+ # )
700
+ # analyze_btn = gr.Button("Analyze", variant="primary")
701
+
702
+ # with gr.Row():
703
+ # with gr.Column(scale=1):
704
+ # chatbot = gr.Chatbot(label="Stock Assistant", height=400)
705
+ # with gr.Row():
706
+ # msg = gr.Textbox(
707
+ # label="Ask about company info, news, or prices",
708
+ # scale=4
709
+ # )
710
+ # submit = gr.Button("Submit", scale=1)
711
+ # clear = gr.Button("Clear", scale=1)
712
+
713
+ # with gr.Column(scale=2):
714
+ # analysis = gr.Textbox(
715
+ # label="Technical Analysis Summary",
716
+ # lines=10
717
+ # )
718
+ # chart1 = gr.Plot(label="Price and Volume Analysis")
719
+ # chart2 = gr.Plot(label="Technical Indicators")
720
+
721
+ # # Event handlers
722
+ # analyze_btn.click(
723
+ # fn=analyze_stock,
724
+ # inputs=[company, lookback],
725
+ # outputs=[analysis, chart1, chart2]
726
+ # )
727
+
728
+ # submit.click(
729
+ # fn=chatbot_response,
730
+ # inputs=[msg, company, chatbot],
731
+ # outputs=chatbot
732
+ # )
733
+
734
+ # msg.submit(
735
+ # fn=chatbot_response,
736
+ # inputs=[msg, company, chatbot],
737
+ # outputs=chatbot
738
+ # )
739
+
740
+ # clear.click(lambda: None, None, chatbot, queue=False)
741
+
742
+ # return interface
743
+
744
+ def create_interface():
745
+ """Create Gradio interface"""
746
+ with gr.Blocks() as interface:
747
+ gr.Markdown("# Stock Analysis with Multi-Source News")
748
+
749
+ # Top section with analysis components
750
+ with gr.Row():
751
+ # Left column - Controls and Summary
752
+ with gr.Column(scale=1):
753
+ company = gr.Dropdown(
754
+ choices=list(COMPANIES.keys()),
755
+ value=list(COMPANIES.keys())[0],
756
+ label="Company"
757
+ )
758
+ lookback = gr.Slider(
759
+ minimum=30,
760
+ maximum=365,
761
+ value=180,
762
+ step=1,
763
+ label="Analysis Period (days)"
764
+ )
765
+ analyze_btn = gr.Button("Analyze", variant="primary")
766
+ analysis = gr.Textbox(
767
+ label="Technical Analysis Summary",
768
+ lines=30
769
+ )
770
+
771
+ # Right column - Charts
772
+ with gr.Column(scale=2):
773
+ chart1 = gr.Plot(label="Price and Volume Analysis")
774
+ chart2 = gr.Plot(label="Technical Indicators")
775
+
776
+ gr.Markdown("---") # Separator
777
+
778
+ # Bottom section - Chatbot
779
+ with gr.Row():
780
+ chatbot = gr.Chatbot(label="Stock Assistant", height=400)
781
+
782
+ with gr.Row():
783
+ msg = gr.Textbox(
784
+ label="Ask about company info, news, or prices",
785
+ scale=4
786
+ )
787
+ submit = gr.Button("Submit", scale=1)
788
+ clear = gr.Button("Clear", scale=1)
789
+
790
+ # Event handlers
791
+ analyze_btn.click(
792
+ fn=analyze_stock,
793
+ inputs=[company, lookback],
794
+ outputs=[analysis, chart1, chart2]
795
+ )
796
+
797
+ submit.click(
798
+ fn=chatbot_response,
799
+ inputs=[msg, company, chatbot],
800
+ outputs=chatbot
801
+ )
802
+
803
+ msg.submit(
804
+ fn=chatbot_response,
805
+ inputs=[msg, company, chatbot],
806
+ outputs=chatbot
807
+ )
808
+
809
+ clear.click(lambda: None, None, chatbot, queue=False)
810
+
811
+ return interface
812
+
813
+ if __name__ == "__main__":
814
+ interface = create_interface()
815
+ interface.launch(debug=True)