import gradio as gr
import yfinance as yf
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from datetime import datetime, timedelta
from langchain_huggingface import HuggingFaceEndpoint
from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import chromadb
import requests
from bs4 import BeautifulSoup
import warnings
from typing import Dict, List, Tuple
import feedparser
from sentence_transformers import SentenceTransformer
import faiss
import json
import os

warnings.filterwarnings('ignore')
COMPANIES = {
    'Apple (AAPL)': 'AAPL',
    'Microsoft (MSFT)': 'MSFT',
    'Amazon (AMZN)': 'AMZN',
    'Google (GOOGL)': 'GOOGL',
    'Meta (META)': 'META',
    'Tesla (TSLA)': 'TSLA',
    'NVIDIA (NVDA)': 'NVDA',
    'JPMorgan Chase (JPM)': 'JPM',
    'Johnson & Johnson (JNJ)': 'JNJ',
    'Walmart (WMT)': 'WMT',
    'Visa (V)': 'V',
    'Mastercard (MA)': 'MA',
    'Procter & Gamble (PG)': 'PG',
    'UnitedHealth (UNH)': 'UNH',
    'Home Depot (HD)': 'HD',
    'Bank of America (BAC)': 'BAC',
    'Coca-Cola (KO)': 'KO',
    'Pfizer (PFE)': 'PFE',
    'Disney (DIS)': 'DIS',
    'Netflix (NFLX)': 'NFLX'
}
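
# Note: the dropdown labels above double as lookup keys; the raw ticker is
# recovered from a label downstream with a small string split, e.g.:
#   'Apple (AAPL)'.split('(')[-1].replace(')', '')  ->  'AAPL'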
# Initialize models
print("Initializing models...")

api_token = os.getenv("TOKEN")

llm = HuggingFaceEndpoint(
    repo_id="mistralai/Mistral-7B-Instruct-v0.2",
    huggingfacehub_api_token=api_token,
    temperature=0.7,
    max_new_tokens=1000
)

vader = SentimentIntensityAnalyzer()
finbert = pipeline("sentiment-analysis", model="ProsusAI/finbert")

print("Models initialized successfully!")
class AgenticRAGFramework:
    """Main framework coordinating all agents"""

    def __init__(self):
        self.technical_agent = TechnicalAnalysisAgent()
        self.sentiment_agent = SentimentAnalysisAgent()
        self.llama_agent = LLMAgent()
        self.knowledge_base = chromadb.Client()

    def analyze(self, symbol: str, data: pd.DataFrame) -> Dict:
        """Perform comprehensive analysis"""
        technical_analysis = self.technical_agent.analyze(data)
        sentiment_analysis = self.sentiment_agent.analyze(symbol)
        llm_analysis = self.llama_agent.generate_analysis(
            technical_analysis,
            sentiment_analysis
        )
        return {
            'technical_analysis': technical_analysis,
            'sentiment_analysis': sentiment_analysis,
            'llm_analysis': llm_analysis
        }
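
# Example usage (a sketch; assumes yf.download returns a non-empty frame):
#   framework = AgenticRAGFramework()
#   result = framework.analyze('AAPL', yf.download('AAPL', period='6mo'))
#   result.keys()  -> ['technical_analysis', 'sentiment_analysis', 'llm_analysis']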
class NewsSource:
    """Base class for news sources"""

    def get_news(self, company: str) -> List[Dict]:
        raise NotImplementedError
class FinvizNews(NewsSource):
    """Fetch news from FinViz"""

    def get_news(self, company: str) -> List[Dict]:
        try:
            ticker = company.split('(')[-1].replace(')', '')
            url = f"https://finviz.com/quote.ashx?t={ticker}"
            headers = {
                'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
            }
            # Timeout so one slow source can't hang the whole analysis
            response = requests.get(url, headers=headers, timeout=10)
            soup = BeautifulSoup(response.text, 'html.parser')
            news_table = soup.find('table', {'class': 'news-table'})
            if not news_table:
                return []
            news_list = []
            for row in news_table.find_all('tr')[:5]:
                cols = row.find_all('td')
                if len(cols) >= 2:
                    date = cols[0].text.strip()
                    title = cols[1].a.text.strip()
                    link = cols[1].a['href']
                    news_list.append({
                        'title': title,
                        'description': title,
                        'date': date,
                        'source': 'FinViz',
                        'url': link
                    })
            return news_list
        except Exception as e:
            print(f"FinViz Error: {str(e)}")
            return []
class MarketWatchNews(NewsSource):
    """Fetch news from MarketWatch"""

    def get_news(self, company: str) -> List[Dict]:
        try:
            ticker = company.split('(')[-1].replace(')', '')
            url = f"https://www.marketwatch.com/investing/stock/{ticker}"
            headers = {
                'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
            }
            response = requests.get(url, headers=headers, timeout=10)
            soup = BeautifulSoup(response.text, 'html.parser')
            news_elements = soup.find_all('div', {'class': 'article__content'})
            news_list = []
            for element in news_elements[:5]:
                title_elem = element.find('a', {'class': 'link'})
                if title_elem:
                    title = title_elem.text.strip()
                    link = title_elem['href']
                    date_elem = element.find('span', {'class': 'article__timestamp'})
                    date = date_elem.text if date_elem else 'Recent'
                    news_list.append({
                        'title': title,
                        'description': title,
                        'date': date,
                        'source': 'MarketWatch',
                        'url': link
                    })
            return news_list
        except Exception as e:
            print(f"MarketWatch Error: {str(e)}")
            return []
class YahooRSSNews(NewsSource):
    """Fetch news from Yahoo Finance RSS feed"""

    def get_news(self, company: str) -> List[Dict]:
        try:
            ticker = company.split('(')[-1].replace(')', '')
            url = f"https://feeds.finance.yahoo.com/rss/2.0/headline?s={ticker}&region=US&lang=en-US"
            feed = feedparser.parse(url)
            news_list = []
            for entry in feed.entries[:5]:
                news_list.append({
                    'title': entry.title,
                    'description': entry.description,
                    'date': entry.published,
                    'source': 'Yahoo Finance',
                    'url': entry.link
                })
            return news_list
        except Exception as e:
            print(f"Yahoo RSS Error: {str(e)}")
            return []
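
# All three sources normalize articles to a common shape, so the sentiment
# agent and chatbot can mix them freely:
#   {'title': ..., 'description': ..., 'date': ..., 'source': ..., 'url': ...}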
class TechnicalAnalysisAgent:
    """Agent for technical analysis"""

    def __init__(self):
        self.required_periods = {
            'sma': [20, 50, 200],
            'rsi': 14,
            'volatility': 20,
            'macd': [12, 26, 9]
        }

    def analyze(self, data: pd.DataFrame) -> Dict:
        df = data.copy()
        # yf.download returns MultiIndex columns of (field, ticker);
        # select the 'Close' column for the frame's single ticker.
        close_col = ('Close', df.columns.get_level_values(1)[0])

        # Calculate metrics
        df['Returns'] = df[close_col].pct_change()

        # SMAs
        for period in self.required_periods['sma']:
            df[f'SMA_{period}'] = df[close_col].rolling(window=period).mean()

        # RSI
        delta = df[close_col].diff()
        gain = delta.where(delta > 0, 0).rolling(window=14).mean()
        loss = -delta.where(delta < 0, 0).rolling(window=14).mean()
        rs = gain / loss
        df['RSI'] = 100 - (100 / (1 + rs))

        # MACD
        exp1 = df[close_col].ewm(span=12, adjust=False).mean()
        exp2 = df[close_col].ewm(span=26, adjust=False).mean()
        df['MACD'] = exp1 - exp2
        df['Signal_Line'] = df['MACD'].ewm(span=9, adjust=False).mean()

        # Bollinger Bands
        df['BB_middle'] = df[close_col].rolling(window=20).mean()
        rolling_std = df[close_col].rolling(window=20).std()
        df['BB_upper'] = df['BB_middle'] + (2 * rolling_std)
        df['BB_lower'] = df['BB_middle'] - (2 * rolling_std)

        return {
            'processed_data': df,
            'current_signals': self._generate_signals(df, close_col)
        }

    def _generate_signals(self, df: pd.DataFrame, close_col) -> Dict:
        if df.empty:
            return {
                'trend': 'Unknown',
                'rsi_signal': 'Unknown',
                'macd_signal': 'Unknown',
                'bb_position': 'Unknown'
            }
        current = df.iloc[-1]
        trend = 'Bullish' if float(current['SMA_20']) > float(current['SMA_50']) else 'Bearish'
        rsi_value = float(current['RSI'])
        if rsi_value > 70:
            rsi_signal = 'Overbought'
        elif rsi_value < 30:
            rsi_signal = 'Oversold'
        else:
            rsi_signal = 'Neutral'
        macd_signal = 'Buy' if float(current['MACD']) > float(current['Signal_Line']) else 'Sell'
        close_value = float(current[close_col])
        bb_upper = float(current['BB_upper'])
        bb_lower = float(current['BB_lower'])
        if close_value > bb_upper:
            bb_position = 'Above Upper Band'
        elif close_value < bb_lower:
            bb_position = 'Below Lower Band'
        else:
            bb_position = 'Within Bands'
        return {
            'trend': trend,
            'rsi_signal': rsi_signal,
            'macd_signal': macd_signal,
            'bb_position': bb_position
        }
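
# Worked example of the RSI computed above: with a 14-day average gain of 1.0
# and average loss of 0.5, RS = 1.0 / 0.5 = 2, so RSI = 100 - 100 / (1 + 2)
# ≈ 66.7 -- neutral territory (>70 reads overbought, <30 oversold).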
class SentimentAnalysisAgent:
    """Agent for sentiment analysis"""

    def __init__(self):
        self.news_sources = [
            FinvizNews(),
            MarketWatchNews(),
            YahooRSSNews()
        ]

    def analyze(self, symbol: str) -> Dict:
        all_news = []
        for source in self.news_sources:
            news_items = source.get_news(symbol)
            all_news.extend(news_items)
        vader_scores = []
        finbert_scores = []
        for article in all_news:
            vader_scores.append(vader.polarity_scores(article['title']))
            finbert_scores.append(
                finbert(article['title'][:512])[0]
            )
        return {
            'articles': all_news,
            'vader_scores': vader_scores,
            'finbert_scores': finbert_scores,
            'aggregated': self._aggregate_sentiment(vader_scores, finbert_scores)
        }
    def _aggregate_sentiment(self, vader_scores: List[Dict],
                             finbert_scores: List[Dict]) -> Dict:
        if not vader_scores or not finbert_scores:
            return {
                'sentiment': 'Neutral',
                'confidence': 0,
                'vader_sentiment': 0,
                'finbert_sentiment': 0
            }
        avg_vader = np.mean([score['compound'] for score in vader_scores])
        # FinBERT labels are positive/negative/neutral; map them to +1/-1/0 so
        # neutral headlines don't drag the average negative.
        finbert_map = {'positive': 1, 'negative': -1, 'neutral': 0}
        avg_finbert = np.mean([
            finbert_map.get(score['label'], 0)
            for score in finbert_scores
        ])
        combined_score = (avg_vader + avg_finbert) / 2
        return {
            'sentiment': 'Bullish' if combined_score > 0.1 else 'Bearish' if combined_score < -0.1 else 'Neutral',
            'confidence': abs(combined_score),
            'vader_sentiment': avg_vader,
            'finbert_sentiment': avg_finbert
        }
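
# Aggregation arithmetic, by way of example: avg_vader = 0.30 and
# avg_finbert = 0.50 give combined_score = (0.30 + 0.50) / 2 = 0.40, which
# clears the 0.1 threshold -> 'Bullish' with confidence 0.40.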
class LLMAgent:
    """Agent for LLM-based analysis using the HuggingFace API"""

    def __init__(self):
        self.llm = llm

    def generate_analysis(self, technical_data: Dict, sentiment_data: Dict) -> str:
        prompt = self._create_prompt(technical_data, sentiment_data)
        response = self.llm.invoke(prompt)
        return response

    def _create_prompt(self, technical_data: Dict, sentiment_data: Dict) -> str:
        return f"""Based on technical and sentiment indicators:

Technical Signals:
- Trend: {technical_data['current_signals']['trend']}
- RSI: {technical_data['current_signals']['rsi_signal']}
- MACD: {technical_data['current_signals']['macd_signal']}
- BB Position: {technical_data['current_signals']['bb_position']}
- Sentiment: {sentiment_data['aggregated']['sentiment']} (Confidence: {sentiment_data['aggregated']['confidence']:.2f})

Provide:
1. Current Trend Analysis
2. Key Risk Factors
3. Trading Recommendations
4. Price Targets
5. Near-term Outlook (1-2 weeks)

Note: Return only the requested information, nothing else."""
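
# The rendered prompt is plain text; with illustrative signal values it
# reads roughly:
#   - Trend: Bullish
#   - RSI: Neutral
#   - MACD: Buy
#   - BB Position: Within Bands
#   - Sentiment: Bullish (Confidence: 0.40)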
class ChatbotRouter:
    """Routes chatbot queries to appropriate data sources and generates responses"""

    def __init__(self):
        self.llm = llm
        self.encoder = SentenceTransformer('all-MiniLM-L6-v2')
        self.faiss_index = None
        self.company_data = {}
        self.news_sources = [
            FinvizNews(),
            MarketWatchNews(),
            YahooRSSNews()
        ]
        self.load_faiss_index()

    def load_faiss_index(self):
        try:
            self.faiss_index = faiss.read_index("company_profiles.index")
            for file in os.listdir('company_data'):
                with open(f'company_data/{file}', 'r') as f:
                    company_name = file.replace('.txt', '')
                    self.company_data[company_name] = json.load(f)
        except Exception as e:
            print(f"Error loading FAISS index: {e}")

    def route_and_respond(self, query: str, company: str) -> str:
        query_type = self._classify_query(query.lower())
        route_message = f"\n[Taking {query_type.upper()} route]\n\n"
        if query_type == "company_info":
            context = self._get_company_context(query, company)
        elif query_type == "news":
            context = self._get_news_context(company)
        elif query_type == "price":
            context = self._get_price_context(company)
        else:
            return route_message + "I'm not sure how to handle this query. Please ask about company information, news, or price data."
        prompt = self._create_prompt(query, context, query_type)
        response = self.llm.invoke(prompt)
        return route_message + response
    def _classify_query(self, query: str) -> str:
        """Classify query type"""
        if any(word in query for word in ["profile", "about", "information", "details", "what", "who", "describe"]):
            return "company_info"
        elif any(word in query for word in ["news", "latest", "recent", "announcement", "update"]):
            return "news"
        elif any(word in query for word in ["price", "stock", "value", "market", "trading", "cost"]):
            return "price"
        return "unknown"
    def _get_company_context(self, query: str, company: str) -> str:
        """Get relevant company information using FAISS"""
        try:
            query_vector = self.encoder.encode([query])
            # Nearest-neighbour search over the profile index; the returned
            # distances/ids are not used further yet -- the profile is looked
            # up by company name below.
            D, I = self.faiss_index.search(query_vector, 1)
            company_name = company.split(" (")[0]
            company_info = self.company_data.get(company_name, {})
            print(company_info)
            return company_info
        except Exception as e:
            return f"Error retrieving company information: {str(e)}"
    def _get_news_context(self, company: str) -> str:
        """Get news from multiple sources"""
        all_news = []
        for source in self.news_sources:
            news_items = source.get_news(company)
            all_news.extend(news_items)
        # De-duplicate by title across sources
        seen_titles = set()
        unique_news = []
        for news in all_news:
            if news['title'] not in seen_titles:
                seen_titles.add(news['title'])
                unique_news.append(news)
        if not unique_news:
            return "No recent news found."
        news_context = "Recent news articles:\n\n"
        for news in unique_news[:5]:
            news_context += f"Source: {news['source']}\n"
            news_context += f"Title: {news['title']}\n"
            if news['description']:
                news_context += f"Description: {news['description']}\n"
            news_context += f"Date: {news['date']}\n\n"
        return news_context
    def _get_price_context(self, company: str) -> str:
        """Get current price information"""
        try:
            ticker = company.split('(')[-1].replace(')', '')
            stock = yf.Ticker(ticker)
            info = stock.info
            # Format numeric fields defensively: applying the ',' format spec
            # directly to the 'N/A' fallback string would raise a ValueError.
            market_cap = info.get('marketCap')
            market_cap = f"{market_cap:,}" if isinstance(market_cap, (int, float)) else 'N/A'
            volume = info.get('volume')
            volume = f"{volume:,}" if isinstance(volume, (int, float)) else 'N/A'
            return f"""Current Stock Information:
Price: ${info.get('currentPrice', 'N/A')}
Day Range: ${info.get('dayLow', 'N/A')} - ${info.get('dayHigh', 'N/A')}
52 Week Range: ${info.get('fiftyTwoWeekLow', 'N/A')} - ${info.get('fiftyTwoWeekHigh', 'N/A')}
Market Cap: ${market_cap}
Volume: {volume}
P/E Ratio: {info.get('trailingPE', 'N/A')}
Dividend Yield: {info.get('dividendYield', 'N/A')}%"""
        except Exception as e:
            return f"Error fetching price data: {str(e)}"
    def _create_prompt(self, query: str, context: str, query_type: str) -> str:
        """Create prompt for LLM"""
        if query_type == "news":
            return f"""Based on the following news articles, please provide a summary addressing the query.

Context:
{context}

Query: {query}

Please analyze the news and provide:
1. Key points from the recent articles
2. Any significant developments or trends
3. Potential impact on the company
4. Overall sentiment (positive/negative/neutral)

The response should be clear, concise, and focused on the most relevant information."""
        else:
            return f"""Based on the following {query_type} context, please answer the question.

Context:
{context}

Question: {query}

Please provide a clear and concise answer based on the given context."""
    def _generate_response(self, prompt: str) -> str:
        """Generate a response using the LLM endpoint"""
        # The router talks to the hosted HuggingFaceEndpoint; there is no
        # local tokenizer/model pair on this class, so generation is a single
        # invoke call (the previous local-generation code referenced
        # attributes that do not exist here).
        response = self.llm.invoke(prompt)
        print(response)
        return response.strip()
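
# Example round trip (a sketch; assumes company_profiles.index and the
# company_data/ directory exist alongside the app):
#   router = ChatbotRouter()
#   router.route_and_respond("latest news", "Apple (AAPL)")
#   -> "\n[Taking NEWS route]\n\n" + LLM summary of the fetched headlines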
def analyze_stock(company: str, lookback_days: int = 180) -> Tuple[str, go.Figure, go.Figure]:
    """Main analysis function"""
    try:
        symbol = COMPANIES[company]
        end_date = datetime.now()
        start_date = end_date - timedelta(days=lookback_days)
        data = yf.download(symbol, start=start_date, end=end_date)
        if len(data) == 0:
            return "No data available.", None, None
        framework = AgenticRAGFramework()
        analysis = framework.analyze(symbol, data)
        plots = create_plots(analysis)
        return analysis['llm_analysis'], plots[0], plots[1]
    except Exception as e:
        return f"Error analyzing stock: {str(e)}", None, None
def create_plots(analysis: Dict) -> List[go.Figure]:
    """Create analysis plots"""
    data = analysis['technical_analysis']['processed_data']

    # Price and Volume Plot
    fig1 = make_subplots(
        rows=2, cols=1,
        shared_xaxes=True,
        vertical_spacing=0.03,
        subplot_titles=('Price Analysis', 'Volume'),
        row_heights=[0.7, 0.3]
    )
    close_col = ('Close', data.columns.get_level_values(1)[0])
    open_col = ('Open', data.columns.get_level_values(1)[0])
    volume_col = ('Volume', data.columns.get_level_values(1)[0])
    fig1.add_trace(
        go.Scatter(x=data.index, y=data[close_col], name='Price',
                   line=dict(color='blue', width=2)),
        row=1, col=1
    )
    fig1.add_trace(
        go.Scatter(x=data.index, y=data['SMA_20'], name='SMA20',
                   line=dict(color='orange', width=1.5)),
        row=1, col=1
    )
    fig1.add_trace(
        go.Scatter(x=data.index, y=data['SMA_50'], name='SMA50',
                   line=dict(color='red', width=1.5)),
        row=1, col=1
    )
    colors = ['red' if float(row[close_col]) < float(row[open_col]) else 'green'
              for idx, row in data.iterrows()]
    fig1.add_trace(
        go.Bar(x=data.index, y=data[volume_col], marker_color=colors, name='Volume'),
        row=2, col=1
    )
    fig1.update_layout(
        height=400,
        showlegend=True,
        xaxis_rangeslider_visible=False,
        plot_bgcolor='white',
        paper_bgcolor='white'
    )

    # Technical Indicators Plot
    fig2 = make_subplots(
        rows=3, cols=1,
        shared_xaxes=True,
        subplot_titles=('RSI', 'MACD', 'Bollinger Bands'),
        row_heights=[0.33, 0.33, 0.34],
        vertical_spacing=0.03
    )
    # RSI
    fig2.add_trace(
        go.Scatter(x=data.index, y=data['RSI'], name='RSI',
                   line=dict(color='purple', width=1.5)),
        row=1, col=1
    )
    fig2.add_hline(y=70, line_dash="dash", line_color="red", row=1, col=1)
    fig2.add_hline(y=30, line_dash="dash", line_color="green", row=1, col=1)
    # MACD
    fig2.add_trace(
        go.Scatter(x=data.index, y=data['MACD'], name='MACD',
                   line=dict(color='blue', width=1.5)),
        row=2, col=1
    )
    fig2.add_trace(
        go.Scatter(x=data.index, y=data['Signal_Line'], name='Signal',
                   line=dict(color='orange', width=1.5)),
        row=2, col=1
    )
    # Bollinger Bands
    fig2.add_trace(
        go.Scatter(x=data.index, y=data[close_col], name='Price',
                   line=dict(color='blue', width=2)),
        row=3, col=1
    )
    fig2.add_trace(
        go.Scatter(x=data.index, y=data['BB_upper'], name='Upper BB',
                   line=dict(color='gray', dash='dash')),
        row=3, col=1
    )
    fig2.add_trace(
        go.Scatter(x=data.index, y=data['BB_lower'], name='Lower BB',
                   line=dict(color='gray', dash='dash')),
        row=3, col=1
    )
    fig2.update_layout(
        height=400,
        showlegend=True,
        plot_bgcolor='white',
        paper_bgcolor='white'
    )
    return [fig1, fig2]
def chatbot_response(message: str, company: str, history: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    """Handle chatbot interactions"""
    router = ChatbotRouter()
    response = router.route_and_respond(message, company)
    history = history + [(message, response)]
    return history
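
# Note: a fresh ChatbotRouter (and hence a fresh SentenceTransformer and
# FAISS load) is built for every message. Constructing one router at module
# scope and reusing it here would be a reasonable optimization, at the cost
# of keeping the encoder in memory for the life of the app.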
def create_interface():
    """Create Gradio interface"""
    with gr.Blocks() as interface:
        gr.Markdown("# Stock Analysis with Multi-Source News")

        # Top section with analysis components
        with gr.Row():
            # Left column - Controls and Summary
            with gr.Column(scale=1):
                company = gr.Dropdown(
                    choices=list(COMPANIES.keys()),
                    value=list(COMPANIES.keys())[0],
                    label="Company"
                )
                lookback = gr.Slider(
                    minimum=30,
                    maximum=365,
                    value=180,
                    step=1,
                    label="Analysis Period (days)"
                )
                analyze_btn = gr.Button("Analyze", variant="primary")
                analysis = gr.Textbox(
                    label="Technical Analysis Summary",
                    lines=30
                )
            # Right column - Charts
            with gr.Column(scale=2):
                chart1 = gr.Plot(label="Price and Volume Analysis")
                chart2 = gr.Plot(label="Technical Indicators")

        gr.Markdown("---")  # Separator

        # Bottom section - Chatbot
        with gr.Row():
            chatbot = gr.Chatbot(label="Stock Assistant", height=400)
        with gr.Row():
            msg = gr.Textbox(
                label="Ask about company info, news, or prices",
                scale=4
            )
            submit = gr.Button("Submit", scale=1)
            clear = gr.Button("Clear", scale=1)

        # Event handlers
        analyze_btn.click(
            fn=analyze_stock,
            inputs=[company, lookback],
            outputs=[analysis, chart1, chart2]
        )
        submit.click(
            fn=chatbot_response,
            inputs=[msg, company, chatbot],
            outputs=chatbot
        )
        msg.submit(
            fn=chatbot_response,
            inputs=[msg, company, chatbot],
            outputs=chatbot
        )
        clear.click(lambda: None, None, chatbot, queue=False)
    return interface
if __name__ == "__main__":
    interface = create_interface()
    interface.launch(debug=True)