Spaces:
Sleeping
Sleeping
import gradio as gr | |
import yfinance as yf | |
import pandas as pd | |
import numpy as np | |
import plotly.graph_objects as go | |
from plotly.subplots import make_subplots | |
import torch | |
import torch.nn as nn | |
from sklearn.preprocessing import StandardScaler | |
from typing import Dict, List, Optional, Tuple, Union | |
from datetime import datetime, timedelta | |
import warnings | |
# Install a process-wide "ignore" filter for every warning category.
warnings.filterwarnings('ignore')

# Constants
# Dropdown display label -> ticker symbol passed to yfinance.
COMPANIES = dict([
    ('Apple (AAPL)', 'AAPL'),
    ('Microsoft (MSFT)', 'MSFT'),
    ('Amazon (AMZN)', 'AMZN'),
    ('Google (GOOGL)', 'GOOGL'),
    ('Meta (META)', 'META'),
    ('Tesla (TSLA)', 'TSLA'),
    ('NVIDIA (NVDA)', 'NVDA'),
    ('JPMorgan Chase (JPM)', 'JPM'),
    ('Johnson & Johnson (JNJ)', 'JNJ'),
    ('Walmart (WMT)', 'WMT'),
    ('Visa (V)', 'V'),
    ('Mastercard (MA)', 'MA'),
    ('Procter & Gamble (PG)', 'PG'),
    ('UnitedHealth (UNH)', 'UNH'),
    ('Home Depot (HD)', 'HD'),
    ('Bank of America (BAC)', 'BAC'),
    ('Coca-Cola (KO)', 'KO'),
    ('Pfizer (PFE)', 'PFE'),
    ('Disney (DIS)', 'DIS'),
    ('Netflix (NFLX)', 'NFLX'),
])
class TimeSeriesPreprocessor:
    """Derive technical indicators from a price frame and standardise core columns.

    Indicators (SMA, RSI, MACD, Bollinger Bands) are computed from the raw ``Close``
    series first, so they stay in price units. Afterwards ``Close``, ``Volume``,
    ``Returns`` and ``Volatility`` are overwritten in place with standardised values;
    the fitted scaler is returned so callers can map them back with
    ``inverse_transform``.
    """

    def __init__(self):
        # Refitted on every ``process`` call.
        self.scaler = StandardScaler()

    def process(self, data: pd.DataFrame) -> "Tuple[pd.DataFrame, StandardScaler]":
        """Return ``(processed_frame, fitted_scaler)`` for ``data``.

        Args:
            data: Frame with at least ``Close`` and ``Volume`` columns. Rows are
                presumably in chronological order (rolling windows run top to
                bottom) -- confirm with the caller.

        Returns:
            A copy of ``data`` with indicator columns added, gaps forward- then
            back-filled, and ``Close``/``Volume``/``Returns``/``Volatility``
            standardised, together with the scaler fitted on those four columns.
        """
        processed = data.copy()
        close = processed['Close']

        # Simple returns and their 20-session rolling standard deviation.
        processed['Returns'] = close.pct_change()
        processed['Volatility'] = processed['Returns'].rolling(window=20).std()

        # Trend / momentum indicators.
        processed['SMA_20'] = close.rolling(window=20).mean()
        processed['SMA_50'] = close.rolling(window=50).mean()
        processed['RSI'] = self.calculate_rsi(close)

        # MACD: spread of the 12- and 26-span EMAs, plus its 9-span EMA signal line.
        ema_12 = close.ewm(span=12, adjust=False).mean()
        ema_26 = close.ewm(span=26, adjust=False).mean()
        processed['MACD'] = ema_12 - ema_26
        processed['Signal_Line'] = processed['MACD'].ewm(span=9, adjust=False).mean()

        # Bollinger Bands: 20-session mean +/- two rolling standard deviations.
        window_20 = close.rolling(window=20)
        band_mid = window_20.mean()
        band_offset = 2 * window_20.std()
        processed['BB_middle'] = band_mid
        processed['BB_upper'] = band_mid + band_offset
        processed['BB_lower'] = band_mid - band_offset

        # Fill rolling-window warm-up gaps. .ffill()/.bfill() replace the
        # deprecated fillna(method=...) spelling with identical results.
        processed = processed.ffill().bfill()

        # Standardise the core numeric features in place.
        numerical_cols = ['Close', 'Volume', 'Returns', 'Volatility']
        processed[numerical_cols] = self.scaler.fit_transform(processed[numerical_cols])
        return processed, self.scaler

    # staticmethod: ``process`` calls this as ``self.calculate_rsi(series)``; without
    # it the instance would be bound to ``prices`` and the series to ``period``.
    @staticmethod
    def calculate_rsi(prices: pd.Series, period: int = 14) -> pd.Series:
        """Return the RSI of ``prices`` using simple rolling means over ``period`` steps.

        The first ``period - 1`` values are NaN (incomplete window). A window with
        gains and no losses yields 100, losses and no gains yields 0, and a window
        with neither yields NaN.
        """
        delta = prices.diff()
        gain = delta.where(delta > 0, 0).rolling(window=period).mean()
        loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
        rs = gain / loss
        return 100 - (100 / (1 + rs))
class AgenticRAGFramework:
    """Rule-based analysis of the indicator frame built by ``TimeSeriesPreprocessor``.

    ``TimeSeriesPreprocessor.process`` overwrites ``Close``, ``Volume``, ``Returns``
    and ``Volatility`` with standardised values but leaves the indicator columns
    (``SMA_*``, ``BB_*``, ``MACD`` ...) in price units. Comparing the two, or taking
    a percentage change of a standardised series, is meaningless, so ``analyze``
    maps the scaled columns back to their original units before applying the rules.
    """

    # Sessions per year used to annualise the daily return volatility in the
    # summary; the 0.2 / 0.5 "moderate" / "high" cut-offs are interpreted as
    # annualised levels.
    TRADING_DAYS_PER_YEAR = 252

    def __init__(self):
        self.preprocessor = TimeSeriesPreprocessor()

    def analyze(self, data: pd.DataFrame) -> Dict:
        """Preprocess ``data`` and run every rule on it.

        Returns:
            A dict with ``processed_data`` (the preprocessor output, scaled columns
            as returned), ``price_data`` (the same frame with the scaler's columns
            restored to original units), and the ``trend``, ``technical``,
            ``volatility`` and ``summary`` results computed from ``price_data``.
        """
        processed_data, scaler = self.preprocessor.process(data)
        price_data = self._unscale(processed_data, scaler)
        return {
            'processed_data': processed_data,
            'price_data': price_data,
            'trend': self.analyze_trend(price_data),
            'technical': self.analyze_technical(price_data),
            'volatility': self.analyze_volatility(price_data),
            'summary': self.generate_summary(price_data),
        }

    @staticmethod
    def _unscale(processed: pd.DataFrame, scaler) -> pd.DataFrame:
        """Return a copy of ``processed`` with the scaler's columns inverse-transformed.

        The columns come from ``scaler.feature_names_in_``. If the scaler exposes no
        such names, or any of them is missing from the frame, an unmodified copy is
        returned.
        """
        names = list(getattr(scaler, 'feature_names_in_', ()))
        restored = processed.copy()
        if names and all(name in restored.columns for name in names):
            restored[names] = scaler.inverse_transform(restored[names])
        return restored

    @staticmethod
    def _bb_position(data: pd.DataFrame):
        """Return where the last close sits in its Bollinger band (0 = lower, 1 = upper).

        Returns NaN when the band width is zero or undefined instead of dividing
        by zero.
        """
        close = data['Close'].iloc[-1]
        lower = data['BB_lower'].iloc[-1]
        width = data['BB_upper'].iloc[-1] - lower
        if pd.isna(width) or width == 0:
            return np.nan
        return (close - lower) / width

    def analyze_trend(self, data: pd.DataFrame) -> Dict:
        """Compare the last 20- and 50-session SMAs.

        ``direction`` is 'Insufficient data' (and ``strength`` NaN) when either SMA
        is undefined, e.g. fewer than 50 rows.
        """
        sma_20 = data['SMA_20'].iloc[-1]
        sma_50 = data['SMA_50'].iloc[-1]
        if pd.isna(sma_20) or pd.isna(sma_50):
            direction, strength = 'Insufficient data', np.nan
        else:
            direction = 'Bullish' if sma_20 > sma_50 else 'Bearish'
            # Relative gap between the averages, as a fraction of SMA_50.
            strength = abs(sma_20 - sma_50) / sma_50
        return {
            'direction': direction,
            'strength': strength,
            'sma_20': sma_20,
            'sma_50': sma_50,
        }

    def analyze_technical(self, data: pd.DataFrame) -> Dict:
        """Return the latest RSI, MACD, signal line and Bollinger-band position."""
        return {
            'rsi': data['RSI'].iloc[-1],
            'macd': data['MACD'].iloc[-1],
            'signal_line': data['Signal_Line'].iloc[-1],
            'bb_position': self._bb_position(data),
        }

    def analyze_volatility(self, data: pd.DataFrame) -> Dict:
        """Return the latest volatility, its 20-row mean, and its last-step direction.

        ``trend`` is 'Insufficient data' when only one row is available.
        """
        volatility = data['Volatility']
        current = volatility.iloc[-1]
        if len(volatility) > 1:
            trend = 'Increasing' if current > volatility.iloc[-2] else 'Decreasing'
        else:
            trend = 'Insufficient data'
        return {
            'current': current,
            'avg_20d': volatility.rolling(20).mean().iloc[-1],
            'trend': trend,
        }

    def generate_summary(self, data: pd.DataFrame) -> str:
        """Render a plain-text market summary from the last rows of ``data``.

        ``data`` is expected in price units (``analyze`` passes the unscaled frame).
        With a single row the session-over-session change is reported as 0.
        """
        closes = data['Close']
        latest_close = closes.iloc[-1]
        prev_close = closes.iloc[-2] if len(closes) > 1 else latest_close
        daily_return = (latest_close - prev_close) / prev_close * 100 if prev_close else 0.0

        rsi = data['RSI'].iloc[-1]
        # Daily return std scaled to a yearly figure so the 0.2 / 0.5 cut-offs apply.
        volatility = data['Volatility'].iloc[-1] * np.sqrt(self.TRADING_DAYS_PER_YEAR)
        bb_position = self._bb_position(data)

        move = 'increased' if daily_return > 0 else 'decreased'
        rsi_state = 'overbought' if rsi > 70 else 'oversold' if rsi < 30 else 'neutral'
        vol_state = 'high' if volatility > 0.5 else 'moderate' if volatility > 0.2 else 'low'
        macd_state = 'Bullish' if data['MACD'].iloc[-1] > data['Signal_Line'].iloc[-1] else 'Bearish'
        # NaN positions (degenerate band) compare False and fall through to the middle.
        if bb_position > 0.8:
            band_state = 'near upper band (potential resistance)'
        elif bb_position < 0.2:
            band_state = 'near lower band (potential support)'
        else:
            band_state = 'in middle range'

        lines = [
            "Market Analysis Summary:",
            f"• Price Action: The stock {move} by {abs(daily_return):.2f}% in the last session.",
            "• Technical Indicators:",
            f"- RSI is at {rsi:.2f} indicating {rsi_state} conditions",
            f"- Current annualized volatility is {volatility:.2f} which is {vol_state}",
            "• Market Signals:",
            f"- MACD: {macd_state} crossover",
            f"- Bollinger Bands: Price is {band_state}",
        ]
        return "\n".join(lines) + "\n"
def create_analysis_plots(data: pd.DataFrame, analysis: Dict) -> List[go.Figure]:
    """Build ``[price_figure, technical_figure]`` for one analysed symbol.

    ``Close`` and ``Volume`` are read from ``data``. The indicator columns
    (``SMA_20``, ``SMA_50``, ``RSI``, ``MACD``, ``Signal_Line``, ``BB_*``) are read from
    ``data`` when it has them, otherwise from ``analysis['processed_data']``: the raw
    download that ``analyze_stock`` passes in carries no indicator columns.
    """
    indicators = data if 'SMA_20' in data.columns else analysis['processed_data']

    # Figure 1: price with moving averages (top) and volume (bottom).
    fig1 = make_subplots(rows=2, cols=1, shared_xaxes=True,
                         subplot_titles=('Price and Technical Indicators', 'Volume'),
                         row_heights=[0.7, 0.3])
    fig1.add_trace(go.Scatter(x=data.index, y=data['Close'],
                              name='Close Price', line=dict(color='blue')), row=1, col=1)
    fig1.add_trace(go.Scatter(x=indicators.index, y=indicators['SMA_20'],
                              name='SMA 20', line=dict(color='orange', dash='dash')), row=1, col=1)
    fig1.add_trace(go.Scatter(x=indicators.index, y=indicators['SMA_50'],
                              name='SMA 50', line=dict(color='green', dash='dash')), row=1, col=1)
    fig1.add_trace(go.Bar(x=data.index, y=data['Volume'],
                          name='Volume', marker_color='lightblue'), row=2, col=1)
    fig1.update_layout(height=600, title_text="Price Analysis")

    # Figure 2: RSI with 70/30 guide lines, MACD vs signal line, Bollinger Bands.
    fig2 = make_subplots(rows=3, cols=1, shared_xaxes=True,
                         subplot_titles=('RSI', 'MACD', 'Bollinger Bands'),
                         row_heights=[0.33, 0.33, 0.33])
    fig2.add_trace(go.Scatter(x=indicators.index, y=indicators['RSI'],
                              name='RSI', line=dict(color='purple')), row=1, col=1)
    fig2.add_hline(y=70, line_dash="dash", line_color="red", row=1, col=1)
    fig2.add_hline(y=30, line_dash="dash", line_color="green", row=1, col=1)
    fig2.add_trace(go.Scatter(x=indicators.index, y=indicators['MACD'],
                              name='MACD', line=dict(color='blue')), row=2, col=1)
    fig2.add_trace(go.Scatter(x=indicators.index, y=indicators['Signal_Line'],
                              name='Signal Line', line=dict(color='red')), row=2, col=1)
    fig2.add_trace(go.Scatter(x=indicators.index, y=indicators['BB_upper'],
                              name='Upper BB', line=dict(color='gray', dash='dash')), row=3, col=1)
    fig2.add_trace(go.Scatter(x=indicators.index, y=indicators['BB_middle'],
                              name='Middle BB', line=dict(color='blue', dash='dash')), row=3, col=1)
    fig2.add_trace(go.Scatter(x=indicators.index, y=indicators['BB_lower'],
                              name='Lower BB', line=dict(color='gray', dash='dash')), row=3, col=1)
    fig2.update_layout(height=800, title_text="Technical Analysis")

    return [fig1, fig2]
def analyze_stock(company: str, lookback_days: int) -> Tuple[str, List[go.Figure]]:
    """Download recent prices for ``company`` and return ``(summary, figures)``.

    Args:
        company: A display label from ``COMPANIES`` (the dropdown value).
        lookback_days: Calendar days of history to request, counting back from now.

    Returns:
        The analysis summary and ``[price_figure, technical_figure]``. When the
        company is unknown or too little data comes back, an explanatory message
        and an empty list.
    """
    symbol = COMPANIES.get(company)
    if symbol is None:
        # e.g. nothing selected in the dropdown; indexing would raise KeyError.
        return "Please select a company from the list.", []

    end_date = datetime.now()
    start_date = end_date - timedelta(days=lookback_days)

    data = yf.download(symbol, start=start_date, end=end_date)
    if data is None or data.empty:
        return "No data available for the selected period.", []

    # Recent yfinance releases return (field, ticker) MultiIndex columns even for a
    # single symbol; keep only the field level so data['Close'] is a Series.
    if isinstance(data.columns, pd.MultiIndex):
        data.columns = data.columns.get_level_values(0)

    # The session-over-session rules need at least the last two rows.
    if len(data) < 2:
        return "Not enough data for the selected period; try a longer lookback.", []

    framework = AgenticRAGFramework()
    analysis = framework.analyze(data)
    plots = create_analysis_plots(data, analysis)
    return analysis['summary'], plots
def create_gradio_interface():
    """Build the Gradio Blocks app: company/lookback inputs, a summary box, two plots.

    Returns the ``gr.Blocks`` instance without launching it.
    """

    def run_analysis(company_label, lookback_days):
        """Adapt ``analyze_stock``'s ``(summary, figures)`` result to three outputs.

        ``analyze_stock`` returns its figures as one list (empty when there is
        nothing to plot), while the click handler feeds one textbox and two plots,
        so the list is spread into two values and padded with ``None``.
        """
        summary_text, figures = analyze_stock(company_label, lookback_days)
        figures = list(figures)[:2]
        figures += [None] * (2 - len(figures))
        return summary_text, figures[0], figures[1]

    with gr.Blocks() as interface:
        gr.Markdown("# Stock Market Analysis with Agentic RAG")
        with gr.Row():
            company = gr.Dropdown(choices=list(COMPANIES.keys()), label="Select Company")
            lookback = gr.Slider(minimum=30, maximum=365, value=180, step=1,
                                 label="Lookback Period (days)")
        analyze_btn = gr.Button("Analyze")
        with gr.Row():
            summary = gr.Textbox(label="Analysis Summary", lines=10)
        with gr.Row():
            plot1 = gr.Plot(label="Price Analysis")
            plot2 = gr.Plot(label="Technical Analysis")
        analyze_btn.click(
            fn=run_analysis,
            inputs=[company, lookback],
            outputs=[summary, plot1, plot2],
        )
    return interface
if __name__ == "__main__":
    # Script entry point only: build the Blocks app, keeping it bound to the
    # module-level name ``interface``, then start it; share=True is forwarded
    # unchanged to Gradio's launch().
    interface = create_gradio_interface()
    interface.launch(share=True)