# finance_rag / app.py
import gradio as gr
import yfinance as yf
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from datetime import datetime, timedelta
from langchain_huggingface import HuggingFaceEndpoint
from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
from transformers import pipeline
import chromadb
import requests
from bs4 import BeautifulSoup
import warnings
from typing import Dict, List, Tuple
import feedparser
from sentence_transformers import SentenceTransformer
import faiss
import json
import os
warnings.filterwarnings('ignore')
COMPANIES = {
'Apple (AAPL)': 'AAPL',
'Microsoft (MSFT)': 'MSFT',
'Amazon (AMZN)': 'AMZN',
'Google (GOOGL)': 'GOOGL',
'Meta (META)': 'META',
'Tesla (TSLA)': 'TSLA',
'NVIDIA (NVDA)': 'NVDA',
'JPMorgan Chase (JPM)': 'JPM',
'Johnson & Johnson (JNJ)': 'JNJ',
'Walmart (WMT)': 'WMT',
'Visa (V)': 'V',
'Mastercard (MA)': 'MA',
'Procter & Gamble (PG)': 'PG',
'UnitedHealth (UNH)': 'UNH',
'Home Depot (HD)': 'HD',
'Bank of America (BAC)': 'BAC',
'Coca-Cola (KO)': 'KO',
'Pfizer (PFE)': 'PFE',
'Disney (DIS)': 'DIS',
'Netflix (NFLX)': 'NFLX'
}
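# Display names map to tickers; the scrapers below recover a ticker from the
# display name by splitting on '(' (e.g. 'Apple (AAPL)' -> 'AAPL').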
# Initialize models
print("Initializing models...")
api_token = os.getenv("TOKEN")
llm = HuggingFaceEndpoint(
repo_id="mistralai/Mistral-7B-Instruct-v0.2",
huggingfacehub_api_token=api_token,
temperature=0.7,
max_new_tokens=1000
)
vader = SentimentIntensityAnalyzer()
finbert = pipeline("sentiment-analysis",
model="ProsusAI/finbert")
print("Models initialized successfully!")
class AgenticRAGFramework:
"""Main framework coordinating all agents"""
def __init__(self):
self.technical_agent = TechnicalAnalysisAgent()
self.sentiment_agent = SentimentAnalysisAgent()
        self.llm_agent = LLMAgent()
self.knowledge_base = chromadb.Client()
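        # Note: the Chroma client is instantiated but never queried in this
        # file; it appears to be a placeholder for a persistent knowledge base.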
def analyze(self, symbol: str, data: pd.DataFrame) -> Dict:
"""Perform comprehensive analysis"""
technical_analysis = self.technical_agent.analyze(data)
sentiment_analysis = self.sentiment_agent.analyze(symbol)
        llm_analysis = self.llm_agent.generate_analysis(
technical_analysis,
sentiment_analysis
)
return {
'technical_analysis': technical_analysis,
'sentiment_analysis': sentiment_analysis,
'llm_analysis': llm_analysis
}
class NewsSource:
"""Base class for news sources"""
def get_news(self, company: str) -> List[Dict]:
raise NotImplementedError
class FinvizNews(NewsSource):
"""Fetch news from FinViz"""
def get_news(self, company: str) -> List[Dict]:
try:
ticker = company.split('(')[-1].replace(')', '')
url = f"https://finviz.com/quote.ashx?t={ticker}"
headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
}
response = requests.get(url, headers=headers)
soup = BeautifulSoup(response.text, 'html.parser')
news_table = soup.find('table', {'class': 'news-table'})
if not news_table:
return []
news_list = []
for row in news_table.find_all('tr')[:5]:
cols = row.find_all('td')
if len(cols) >= 2:
date = cols[0].text.strip()
title = cols[1].a.text.strip()
link = cols[1].a['href']
news_list.append({
'title': title,
'description': title,
'date': date,
'source': 'FinViz',
'url': link
})
return news_list
except Exception as e:
print(f"FinViz Error: {str(e)}")
return []
class MarketWatchNews(NewsSource):
"""Fetch news from MarketWatch"""
def get_news(self, company: str) -> List[Dict]:
try:
ticker = company.split('(')[-1].replace(')', '')
url = f"https://www.marketwatch.com/investing/stock/{ticker}"
headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
}
response = requests.get(url, headers=headers)
soup = BeautifulSoup(response.text, 'html.parser')
news_elements = soup.find_all('div', {'class': 'article__content'})
news_list = []
for element in news_elements[:5]:
title_elem = element.find('a', {'class': 'link'})
if title_elem:
title = title_elem.text.strip()
link = title_elem['href']
date_elem = element.find('span', {'class': 'article__timestamp'})
date = date_elem.text if date_elem else 'Recent'
news_list.append({
'title': title,
'description': title,
'date': date,
'source': 'MarketWatch',
'url': link
})
return news_list
except Exception as e:
print(f"MarketWatch Error: {str(e)}")
return []
class YahooRSSNews(NewsSource):
"""Fetch news from Yahoo Finance RSS feed"""
def get_news(self, company: str) -> List[Dict]:
try:
ticker = company.split('(')[-1].replace(')', '')
url = f"https://feeds.finance.yahoo.com/rss/2.0/headline?s={ticker}&region=US&lang=en-US"
feed = feedparser.parse(url)
news_list = []
for entry in feed.entries[:5]:
news_list.append({
'title': entry.title,
'description': entry.description,
'date': entry.published,
'source': 'Yahoo Finance',
'url': entry.link
})
return news_list
except Exception as e:
print(f"Yahoo RSS Error: {str(e)}")
return []
class TechnicalAnalysisAgent:
"""Agent for technical analysis"""
def __init__(self):
self.required_periods = {
'sma': [20, 50, 200],
'rsi': 14,
'volatility': 20,
'macd': [12, 26, 9]
}
def analyze(self, data: pd.DataFrame) -> Dict:
df = data.copy()
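        # yf.download returns MultiIndex columns of (field, ticker) in recent
        # yfinance versions; resolve the ('Close', <ticker>) tuple once here.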
close_col = ('Close', df.columns.get_level_values(1)[0])
# Calculate metrics
df['Returns'] = df[close_col].pct_change()
# SMAs
for period in self.required_periods['sma']:
df[f'SMA_{period}'] = df[close_col].rolling(window=period).mean()
# RSI
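        # RS = avg(gain, 14) / avg(loss, 14); RSI = 100 - 100 / (1 + RS).
        # (A rolling-mean variant of Wilder's RSI, which smooths exponentially.)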
delta = df[close_col].diff()
gain = delta.where(delta > 0, 0).rolling(window=14).mean()
loss = -delta.where(delta < 0, 0).rolling(window=14).mean()
rs = gain / loss
df['RSI'] = 100 - (100 / (1 + rs))
# MACD
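        # MACD = EMA(12) - EMA(26); the signal line is a 9-period EMA of MACD.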
exp1 = df[close_col].ewm(span=12, adjust=False).mean()
exp2 = df[close_col].ewm(span=26, adjust=False).mean()
df['MACD'] = exp1 - exp2
df['Signal_Line'] = df['MACD'].ewm(span=9, adjust=False).mean()
# Bollinger Bands
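        # Bands sit two rolling standard deviations around a 20-period SMA.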
df['BB_middle'] = df[close_col].rolling(window=20).mean()
rolling_std = df[close_col].rolling(window=20).std()
df['BB_upper'] = df['BB_middle'] + (2 * rolling_std)
df['BB_lower'] = df['BB_middle'] - (2 * rolling_std)
return {
'processed_data': df,
'current_signals': self._generate_signals(df, close_col)
}
def _generate_signals(self, df: pd.DataFrame, close_col) -> Dict:
if df.empty:
return {
'trend': 'Unknown',
'rsi_signal': 'Unknown',
'macd_signal': 'Unknown',
'bb_position': 'Unknown'
}
current = df.iloc[-1]
trend = 'Bullish' if float(current['SMA_20']) > float(current['SMA_50']) else 'Bearish'
rsi_value = float(current['RSI'])
if rsi_value > 70:
rsi_signal = 'Overbought'
elif rsi_value < 30:
rsi_signal = 'Oversold'
else:
rsi_signal = 'Neutral'
macd_signal = 'Buy' if float(current['MACD']) > float(current['Signal_Line']) else 'Sell'
close_value = float(current[close_col])
bb_upper = float(current['BB_upper'])
bb_lower = float(current['BB_lower'])
if close_value > bb_upper:
bb_position = 'Above Upper Band'
elif close_value < bb_lower:
bb_position = 'Below Lower Band'
else:
bb_position = 'Within Bands'
return {
'trend': trend,
'rsi_signal': rsi_signal,
'macd_signal': macd_signal,
'bb_position': bb_position
}
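# Example usage (illustrative; assumes a network connection for yfinance and
# the MultiIndex column layout noted in analyze()):
#
#   data = yf.download("AAPL", period="6mo")
#   signals = TechnicalAnalysisAgent().analyze(data)['current_signals']
#   print(signals)  # e.g. {'trend': 'Bullish', 'rsi_signal': 'Neutral', ...}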
class SentimentAnalysisAgent:
"""Agent for sentiment analysis"""
def __init__(self):
self.news_sources = [
FinvizNews(),
MarketWatchNews(),
YahooRSSNews()
]
def analyze(self, symbol: str) -> Dict:
all_news = []
for source in self.news_sources:
news_items = source.get_news(symbol)
all_news.extend(news_items)
vader_scores = []
finbert_scores = []
for article in all_news:
vader_scores.append(vader.polarity_scores(article['title']))
finbert_scores.append(
finbert(article['title'][:512])[0]
)
return {
'articles': all_news,
'vader_scores': vader_scores,
'finbert_scores': finbert_scores,
'aggregated': self._aggregate_sentiment(vader_scores, finbert_scores)
}
def _aggregate_sentiment(self, vader_scores: List[Dict],
finbert_scores: List[Dict]) -> Dict:
if not vader_scores or not finbert_scores:
return {
'sentiment': 'Neutral',
'confidence': 0,
'vader_sentiment': 0,
'finbert_sentiment': 0
}
        avg_vader = np.mean([score['compound'] for score in vader_scores])
        # Map FinBERT labels onto [-1, 1]; neutral counts as 0 instead of
        # being lumped in with negative.
        avg_finbert = np.mean([
            1 if score['label'] == 'positive'
            else -1 if score['label'] == 'negative'
            else 0
            for score in finbert_scores
        ])
        # Blend the two models equally; +/-0.1 thresholds separate
        # Bullish / Neutral / Bearish.
        combined_score = (avg_vader + avg_finbert) / 2
        return {
            'sentiment': 'Bullish' if combined_score > 0.1 else 'Bearish' if combined_score < -0.1 else 'Neutral',
            'confidence': abs(combined_score),
            'vader_sentiment': avg_vader,
            'finbert_sentiment': avg_finbert
        }
class LLMAgent:
"""Agent for LLM-based analysis using HuggingFace API"""
def __init__(self):
self.llm = llm
def generate_analysis(self, technical_data: Dict, sentiment_data: Dict) -> str:
prompt = self._create_prompt(technical_data, sentiment_data)
response = self.llm.invoke(prompt)
return response
def _create_prompt(self, technical_data: Dict, sentiment_data: Dict) -> str:
return f"""Based on technical and sentiment indicators:
Technical Signals:
- Trend: {technical_data['current_signals']['trend']}
- RSI: {technical_data['current_signals']['rsi_signal']}
- MACD: {technical_data['current_signals']['macd_signal']}
- BB Position: {technical_data['current_signals']['bb_position']}
Sentiment:
- {sentiment_data['aggregated']['sentiment']} (Confidence: {sentiment_data['aggregated']['confidence']:.2f})
Provide:
1. Current Trend Analysis
2. Key Risk Factors
3. Trading Recommendations
4. Price Targets
5. Near-term Outlook (1-2 weeks)
Note: Return only the requested information; omit anything extraneous."""
class ChatbotRouter:
"""Routes chatbot queries to appropriate data sources and generates responses"""
def __init__(self):
self.llm = llm
self.encoder = SentenceTransformer('all-MiniLM-L6-v2')
self.faiss_index = None
self.company_data = {}
self.news_sources = [
FinvizNews(),
MarketWatchNews(),
YahooRSSNews()
]
self.load_faiss_index()
def load_faiss_index(self):
try:
self.faiss_index = faiss.read_index("company_profiles.index")
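            # Each file under company_data/ is expected to contain a JSON
            # profile despite the .txt extension; profiles are keyed by the
            # file's base name.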
for file in os.listdir('company_data'):
with open(f'company_data/{file}', 'r') as f:
company_name = file.replace('.txt', '')
self.company_data[company_name] = json.load(f)
except Exception as e:
print(f"Error loading FAISS index: {e}")
def route_and_respond(self, query: str, company: str) -> str:
query_type = self._classify_query(query.lower())
route_message = f"\n[Taking {query_type.upper()} route]\n\n"
if query_type == "company_info":
context = self._get_company_context(query, company)
elif query_type == "news":
context = self._get_news_context(company)
elif query_type == "price":
context = self._get_price_context(company)
else:
return route_message + "I'm not sure how to handle this query. Please ask about company information, news, or price data."
prompt = self._create_prompt(query, context, query_type)
response = self.llm.invoke(prompt)
return route_message + response
def _classify_query(self, query: str) -> str:
"""Classify query type"""
if any(word in query for word in ["profile", "about", "information", "details", "what", "who", "describe"]):
return "company_info"
elif any(word in query for word in ["news", "latest", "recent", "announcement", "update"]):
return "news"
elif any(word in query for word in ["price", "stock", "value", "market", "trading", "cost"]):
return "price"
return "unknown"
    def _get_company_context(self, query: str, company: str) -> str:
        """Get relevant company information using FAISS"""
        try:
            # Embed the query and probe the index; the nearest-neighbour
            # result (D, I) is not yet used downstream -- the profile is
            # looked up directly by company name.
            query_vector = self.encoder.encode([query])
            D, I = self.faiss_index.search(query_vector, 1)
            company_name = company.split(" (")[0]
            company_info = self.company_data.get(company_name, {})
            return str(company_info)
        except Exception as e:
            return f"Error retrieving company information: {str(e)}"
def _get_news_context(self, company: str) -> str:
"""Get news from multiple sources"""
all_news = []
for source in self.news_sources:
news_items = source.get_news(company)
all_news.extend(news_items)
seen_titles = set()
unique_news = []
for news in all_news:
if news['title'] not in seen_titles:
seen_titles.add(news['title'])
unique_news.append(news)
if not unique_news:
return "No recent news found."
news_context = "Recent news articles:\n\n"
for news in unique_news[:5]:
news_context += f"Source: {news['source']}\n"
news_context += f"Title: {news['title']}\n"
if news['description']:
news_context += f"Description: {news['description']}\n"
news_context += f"Date: {news['date']}\n\n"
return news_context
    def _get_price_context(self, company: str) -> str:
        """Get current price information"""
        try:
            ticker = company.split('(')[-1].replace(')', '')
            stock = yf.Ticker(ticker)
            info = stock.info
            # Guard the thousands-separator formatting: applying ':,' to the
            # 'N/A' fallback string would raise a ValueError.
            market_cap = info.get('marketCap')
            volume = info.get('volume')
            return f"""Current Stock Information:
Price: ${info.get('currentPrice', 'N/A')}
Day Range: ${info.get('dayLow', 'N/A')} - ${info.get('dayHigh', 'N/A')}
52 Week Range: ${info.get('fiftyTwoWeekLow', 'N/A')} - ${info.get('fiftyTwoWeekHigh', 'N/A')}
Market Cap: ${f'{market_cap:,}' if market_cap is not None else 'N/A'}
Volume: {f'{volume:,}' if volume is not None else 'N/A'}
P/E Ratio: {info.get('trailingPE', 'N/A')}
Dividend Yield: {info.get('dividendYield', 'N/A')}%"""
        except Exception as e:
            return f"Error fetching price data: {str(e)}"
def _create_prompt(self, query: str, context: str, query_type: str) -> str:
"""Create prompt for LLM"""
if query_type == "news":
return f"""Based on the following news articles, please provide a summary addressing the query.
Context:
{context}
Query: {query}
Please analyze the news and provide:
1. Key points from the recent articles
2. Any significant developments or trends
3. Potential impact on the company
4. Overall sentiment (positive/negative/neutral)
Response should be clear, concise, and focused on the most relevant information."""
else:
return f"""Based on the following {query_type} context, please answer the question.
Context:
{context}
Question: {query}
Please provide a clear and concise answer based on the given context."""
    def _generate_response(self, prompt: str) -> str:
        """Generate a response using the HuggingFace endpoint LLM"""
        # The local tokenizer/model pair referenced here previously does not
        # exist on this class; delegate to the same endpoint LLM used by
        # route_and_respond().
        return self.llm.invoke(prompt).strip()
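# Example (illustrative; assumes company_profiles.index and company_data/
# exist alongside this file):
#
#   router = ChatbotRouter()
#   print(router.route_and_respond("any recent announcements?", "Apple (AAPL)"))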
def analyze_stock(company: str, lookback_days: int = 180) -> Tuple[str, go.Figure, go.Figure]:
"""Main analysis function"""
try:
symbol = COMPANIES[company]
end_date = datetime.now()
start_date = end_date - timedelta(days=lookback_days)
data = yf.download(symbol, start=start_date, end=end_date)
if len(data) == 0:
return "No data available.", None, None
framework = AgenticRAGFramework()
analysis = framework.analyze(symbol, data)
plots = create_plots(analysis)
return analysis['llm_analysis'], plots[0], plots[1]
except Exception as e:
return f"Error analyzing stock: {str(e)}", None, None
def create_plots(analysis: Dict) -> List[go.Figure]:
"""Create analysis plots"""
data = analysis['technical_analysis']['processed_data']
# Price and Volume Plot
fig1 = make_subplots(
rows=2, cols=1,
shared_xaxes=True,
vertical_spacing=0.03,
subplot_titles=('Price Analysis', 'Volume'),
row_heights=[0.7, 0.3]
)
close_col = ('Close', data.columns.get_level_values(1)[0])
open_col = ('Open', data.columns.get_level_values(1)[0])
volume_col = ('Volume', data.columns.get_level_values(1)[0])
fig1.add_trace(
go.Scatter(x=data.index, y=data[close_col], name='Price',
line=dict(color='blue', width=2)),
row=1, col=1
)
fig1.add_trace(
go.Scatter(x=data.index, y=data['SMA_20'], name='SMA20',
line=dict(color='orange', width=1.5)),
row=1, col=1
)
fig1.add_trace(
go.Scatter(x=data.index, y=data['SMA_50'], name='SMA50',
line=dict(color='red', width=1.5)),
row=1, col=1
)
colors = ['red' if float(row[close_col]) < float(row[open_col]) else 'green'
for idx, row in data.iterrows()]
fig1.add_trace(
go.Bar(x=data.index, y=data[volume_col], marker_color=colors, name='Volume'),
row=2, col=1
)
fig1.update_layout(
height=400,
showlegend=True,
xaxis_rangeslider_visible=False,
plot_bgcolor='white',
paper_bgcolor='white'
)
# Technical Indicators Plot
fig2 = make_subplots(
rows=3, cols=1,
shared_xaxes=True,
subplot_titles=('RSI', 'MACD', 'Bollinger Bands'),
row_heights=[0.33, 0.33, 0.34],
vertical_spacing=0.03
)
# RSI
fig2.add_trace(
go.Scatter(x=data.index, y=data['RSI'], name='RSI',
line=dict(color='purple', width=1.5)),
row=1, col=1
)
fig2.add_hline(y=70, line_dash="dash", line_color="red", row=1, col=1)
fig2.add_hline(y=30, line_dash="dash", line_color="green", row=1, col=1)
# MACD
fig2.add_trace(
go.Scatter(x=data.index, y=data['MACD'], name='MACD',
line=dict(color='blue', width=1.5)),
row=2, col=1
)
fig2.add_trace(
go.Scatter(x=data.index, y=data['Signal_Line'], name='Signal',
line=dict(color='orange', width=1.5)),
row=2, col=1
)
# Bollinger Bands
fig2.add_trace(
go.Scatter(x=data.index, y=data[close_col], name='Price',
line=dict(color='blue', width=2)),
row=3, col=1
)
fig2.add_trace(
go.Scatter(x=data.index, y=data['BB_upper'], name='Upper BB',
line=dict(color='gray', dash='dash')),
row=3, col=1
)
fig2.add_trace(
go.Scatter(x=data.index, y=data['BB_lower'], name='Lower BB',
line=dict(color='gray', dash='dash')),
row=3, col=1
)
fig2.update_layout(
height=400,
showlegend=True,
plot_bgcolor='white',
paper_bgcolor='white'
)
return [fig1, fig2]
def chatbot_response(message: str, company: str, history: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    """Handle chatbot interactions"""
    # ChatbotRouter takes no constructor arguments; the undefined LlamaAgent()
    # previously passed here was a leftover from an earlier revision.
    router = ChatbotRouter()
    response = router.route_and_respond(message, company)
    return history + [(message, response)]
def create_interface():
"""Create Gradio interface"""
with gr.Blocks() as interface:
gr.Markdown("# Stock Analysis with Multi-Source News")
# Top section with analysis components
with gr.Row():
# Left column - Controls and Summary
with gr.Column(scale=1):
company = gr.Dropdown(
choices=list(COMPANIES.keys()),
value=list(COMPANIES.keys())[0],
label="Company"
)
lookback = gr.Slider(
minimum=30,
maximum=365,
value=180,
step=1,
label="Analysis Period (days)"
)
analyze_btn = gr.Button("Analyze", variant="primary")
analysis = gr.Textbox(
label="Technical Analysis Summary",
lines=30
)
# Right column - Charts
with gr.Column(scale=2):
chart1 = gr.Plot(label="Price and Volume Analysis")
chart2 = gr.Plot(label="Technical Indicators")
gr.Markdown("---") # Separator
# Bottom section - Chatbot
with gr.Row():
chatbot = gr.Chatbot(label="Stock Assistant", height=400)
with gr.Row():
msg = gr.Textbox(
label="Ask about company info, news, or prices",
scale=4
)
submit = gr.Button("Submit", scale=1)
clear = gr.Button("Clear", scale=1)
# Event handlers
analyze_btn.click(
fn=analyze_stock,
inputs=[company, lookback],
outputs=[analysis, chart1, chart2]
)
submit.click(
fn=chatbot_response,
inputs=[msg, company, chatbot],
outputs=chatbot
)
msg.submit(
fn=chatbot_response,
inputs=[msg, company, chatbot],
outputs=chatbot
)
clear.click(lambda: None, None, chatbot, queue=False)
return interface
if __name__ == "__main__":
interface = create_interface()
interface.launch(debug=True)