arjunanand13's picture
Create app.py
5be7da8 verified
raw
history blame
10 kB
import gradio as gr
import yfinance as yf
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import torch
import torch.nn as nn
from sklearn.preprocessing import StandardScaler
from typing import Dict, List, Optional, Tuple, Union
from datetime import datetime, timedelta
import warnings
warnings.filterwarnings('ignore')
# Constants
# Dropdown label -> ticker symbol. Every label has the form "<Name> (<TICKER>)";
# dict insertion order is the order the choices appear in the UI dropdown.
COMPANIES = {
    f'{name} ({ticker})': ticker
    for name, ticker in (
        ('Apple', 'AAPL'),
        ('Microsoft', 'MSFT'),
        ('Amazon', 'AMZN'),
        ('Google', 'GOOGL'),
        ('Meta', 'META'),
        ('Tesla', 'TSLA'),
        ('NVIDIA', 'NVDA'),
        ('JPMorgan Chase', 'JPM'),
        ('Johnson & Johnson', 'JNJ'),
        ('Walmart', 'WMT'),
        ('Visa', 'V'),
        ('Mastercard', 'MA'),
        ('Procter & Gamble', 'PG'),
        ('UnitedHealth', 'UNH'),
        ('Home Depot', 'HD'),
        ('Bank of America', 'BAC'),
        ('Coca-Cola', 'KO'),
        ('Pfizer', 'PFE'),
        ('Disney', 'DIS'),
        ('Netflix', 'NFLX'),
    )
}
class TimeSeriesPreprocessor:
    """Derive technical-indicator columns from a price frame.

    ``process`` expects at least ``Close`` and ``Volume`` columns and returns a
    copy with returns, volatility, moving averages, RSI, MACD and Bollinger Band
    columns added, plus standardized ``*_scaled`` copies of a few series.
    """

    def __init__(self):
        # Refit on every ``process`` call; returned to the caller so the
        # ``*_scaled`` columns can be inverse-transformed if needed.
        self.scaler = StandardScaler()

    def process(self, data: pd.DataFrame) -> Tuple[pd.DataFrame, StandardScaler]:
        """Return ``(processed, scaler)`` for *data*.

        Columns added:
          * Returns: ``Close.pct_change()``
          * Volatility: 20-period rolling std of Returns
          * SMA_20, SMA_50: rolling means of Close
          * RSI: 14-period RSI (see ``calculate_rsi``)
          * MACD: EMA(12) - EMA(26) of Close; Signal_Line: EMA(9) of MACD
          * BB_middle / BB_upper / BB_lower: 20-period mean of Close +/- 2 std
          * Close_scaled, Volume_scaled, Returns_scaled, Volatility_scaled:
            z-scores from the fitted StandardScaler

        The original Close/Volume/Returns/Volatility columns are deliberately
        left in their own units. The SMA and Bollinger columns are derived
        from raw Close, so overwriting Close with z-scores would make
        comparisons between them (band position, price-vs-SMA plots, daily
        return) meaningless.

        Warm-up NaNs are forward- then back-filled. A column that is entirely
        NaN (e.g. SMA_50 on fewer than 50 rows) stays NaN.
        """
        processed = data.copy()
        # If the columns are a MultiIndex (newer yfinance appears to return
        # (field, ticker) pairs even for one symbol -- verify), keep only the
        # first level so processed['Close'] is a Series, not a DataFrame.
        if isinstance(processed.columns, pd.MultiIndex):
            processed.columns = processed.columns.get_level_values(0)

        close = processed['Close']

        # Returns and volatility
        processed['Returns'] = close.pct_change()
        processed['Volatility'] = processed['Returns'].rolling(window=20).std()

        # Trend indicators
        processed['SMA_20'] = close.rolling(window=20).mean()
        processed['SMA_50'] = close.rolling(window=50).mean()
        processed['RSI'] = self.calculate_rsi(close)

        # MACD and its signal line
        ema_12 = close.ewm(span=12, adjust=False).mean()
        ema_26 = close.ewm(span=26, adjust=False).mean()
        processed['MACD'] = ema_12 - ema_26
        processed['Signal_Line'] = processed['MACD'].ewm(span=9, adjust=False).mean()

        # Bollinger Bands (20-period, 2 standard deviations); the rolling std
        # is computed once and shared by both bands.
        rolling_std = close.rolling(window=20).std()
        processed['BB_middle'] = close.rolling(window=20).mean()
        processed['BB_upper'] = processed['BB_middle'] + 2 * rolling_std
        processed['BB_lower'] = processed['BB_middle'] - 2 * rolling_std

        # fillna(method=...) is deprecated in pandas 2.x; ffill/bfill are the
        # supported equivalents.
        processed = processed.ffill().bfill()

        # Standardize into separate columns instead of overwriting the raw ones.
        numerical_cols = ['Close', 'Volume', 'Returns', 'Volatility']
        scaled = self.scaler.fit_transform(processed[numerical_cols])
        for idx, col in enumerate(numerical_cols):
            processed[f'{col}_scaled'] = scaled[:, idx]
        return processed, self.scaler

    @staticmethod
    def calculate_rsi(prices: pd.Series, period: int = 14) -> pd.Series:
        """Relative Strength Index from simple rolling means of gains and losses.

        The first ``period`` values are NaN (diff + rolling warm-up). When a
        window has no losses the ratio is inf and the RSI is 100; when it has
        neither gains nor losses the RSI is NaN.
        """
        delta = prices.diff()
        gain = delta.where(delta > 0, 0).rolling(window=period).mean()
        loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
        rs = gain / loss
        return 100 - (100 / (1 + rs))
class AgenticRAGFramework:
    """Turn a raw price frame into indicator-based analysis and a text summary."""

    def __init__(self):
        self.preprocessor = TimeSeriesPreprocessor()

    def analyze(self, data: pd.DataFrame) -> Dict:
        """Preprocess *data* and bundle every analysis view.

        Returns a dict with keys 'processed_data', 'trend', 'technical',
        'volatility' and 'summary'. The summary and volatility trend compare
        the last two rows, so *data* needs at least two rows.
        """
        processed_data, _scaler = self.preprocessor.process(data)
        return {
            'processed_data': processed_data,
            'trend': self.analyze_trend(processed_data),
            'technical': self.analyze_technical(processed_data),
            'volatility': self.analyze_volatility(processed_data),
            'summary': self.generate_summary(processed_data),
        }

    @staticmethod
    def _bb_position(data: pd.DataFrame) -> float:
        """Latest Close's position inside the latest Bollinger band.

        0 means at BB_lower, 1 at BB_upper; values outside [0, 1] mean the
        price is outside the band. Returns NaN when the band has zero width,
        instead of dividing by zero.
        """
        upper = data['BB_upper'].iloc[-1]
        lower = data['BB_lower'].iloc[-1]
        width = upper - lower
        if width == 0:
            return float('nan')
        return (data['Close'].iloc[-1] - lower) / width

    def analyze_trend(self, data: pd.DataFrame) -> Dict:
        """Compare the latest 20- and 50-period SMAs.

        'direction' is 'Bullish' when SMA_20 > SMA_50, else 'Bearish';
        'strength' is the relative gap |SMA_20 - SMA_50| / SMA_50.
        """
        sma_20 = data['SMA_20'].iloc[-1]
        sma_50 = data['SMA_50'].iloc[-1]
        return {
            'direction': 'Bullish' if sma_20 > sma_50 else 'Bearish',
            'strength': abs(sma_20 - sma_50) / sma_50,
            'sma_20': sma_20,
            'sma_50': sma_50,
        }

    def analyze_technical(self, data: pd.DataFrame) -> Dict:
        """Latest RSI, MACD, signal line and Bollinger band position."""
        return {
            'rsi': data['RSI'].iloc[-1],
            'macd': data['MACD'].iloc[-1],
            'signal_line': data['Signal_Line'].iloc[-1],
            'bb_position': self._bb_position(data),
        }

    def analyze_volatility(self, data: pd.DataFrame) -> Dict:
        """Latest volatility, its 20-row average, and its direction vs. the prior row."""
        volatility = data['Volatility']
        return {
            'current': volatility.iloc[-1],
            'avg_20d': volatility.rolling(20).mean().iloc[-1],
            'trend': 'Increasing' if volatility.iloc[-1] > volatility.iloc[-2] else 'Decreasing',
        }

    def generate_summary(self, data: pd.DataFrame) -> str:
        """Render a bullet-point text summary of the latest session.

        Requires at least two rows (the daily return compares the last two
        closes). The Bollinger position is computed here from BB_upper/BB_lower.
        Previously it was read from a 'BB_position' column that nothing
        creates, which raised KeyError.
        """
        latest_close = data['Close'].iloc[-1]
        prev_close = data['Close'].iloc[-2]
        daily_return = (latest_close - prev_close) / prev_close * 100

        technical = self.analyze_technical(data)
        rsi = technical['rsi']
        bb_position = technical['bb_position']
        volatility = data['Volatility'].iloc[-1]

        direction = 'increased' if daily_return > 0 else 'decreased'
        rsi_state = 'overbought' if rsi > 70 else 'oversold' if rsi < 30 else 'neutral'
        # NOTE(review): these thresholds are applied to whatever scale the
        # 'Volatility' column is in. Confirm they match the preprocessor's output.
        vol_level = 'high' if volatility > 0.5 else 'moderate' if volatility > 0.2 else 'low'
        macd_state = 'Bullish' if technical['macd'] > technical['signal_line'] else 'Bearish'
        # A NaN position (zero-width band) fails both comparisons -> middle range.
        if bb_position > 0.8:
            bb_state = 'near upper band (potential resistance)'
        elif bb_position < 0.2:
            bb_state = 'near lower band (potential support)'
        else:
            bb_state = 'in middle range'

        lines = [
            "Market Analysis Summary:",
            f"• Price Action: The stock {direction} by {abs(daily_return):.2f}% in the last session.",
            "• Technical Indicators:",
            f"- RSI is at {rsi:.2f} indicating {rsi_state} conditions",
            f"- Current volatility is {volatility:.2f} which is {vol_level}",
            "• Market Signals:",
            f"- MACD: {macd_state} crossover",
            f"- Bollinger Bands: Price is {bb_state}",
        ]
        return "\n".join(lines) + "\n"
def create_analysis_plots(data: pd.DataFrame, analysis: Dict) -> List[go.Figure]:
    """Build the price figure and the technical-indicator figure.

    Args:
        data: frame indexed for the x-axis, with Close, Volume, SMA_20, SMA_50,
            RSI, MACD, Signal_Line, BB_upper, BB_middle and BB_lower columns.
        analysis: accepted for call-site compatibility; not read here.

    Returns:
        ``[price_fig, technical_fig]``.
    """
    dates = data.index

    def scatter(column, label, color, dash=None):
        # One line trace of data[column]. 'dash' is only set when requested,
        # so solid lines get a bare {'color': ...} style.
        style = {'color': color}
        if dash is not None:
            style['dash'] = dash
        return go.Scatter(x=dates, y=data[column], name=label, line=style)

    # Figure 1: price with moving averages on top, volume bars underneath.
    price_fig = make_subplots(
        rows=2,
        cols=1,
        shared_xaxes=True,
        subplot_titles=('Price and Technical Indicators', 'Volume'),
        row_heights=[0.7, 0.3],
    )
    for trace in (
        scatter('Close', 'Close Price', 'blue'),
        scatter('SMA_20', 'SMA 20', 'orange', dash='dash'),
        scatter('SMA_50', 'SMA 50', 'green', dash='dash'),
    ):
        price_fig.add_trace(trace, row=1, col=1)
    volume_bars = go.Bar(x=dates, y=data['Volume'], name='Volume', marker_color='lightblue')
    price_fig.add_trace(volume_bars, row=2, col=1)
    price_fig.update_layout(height=600, title_text="Price Analysis")

    # Figure 2: RSI, MACD and Bollinger Bands stacked in three rows.
    tech_fig = make_subplots(
        rows=3,
        cols=1,
        shared_xaxes=True,
        subplot_titles=('RSI', 'MACD', 'Bollinger Bands'),
        row_heights=[0.33] * 3,
    )
    tech_fig.add_trace(scatter('RSI', 'RSI', 'purple'), row=1, col=1)
    # Same 70 / 30 overbought / oversold levels the text summary uses.
    for level, color in ((70, "red"), (30, "green")):
        tech_fig.add_hline(y=level, line_dash="dash", line_color=color, row=1, col=1)
    lower_panels = (
        (2, [
            scatter('MACD', 'MACD', 'blue'),
            scatter('Signal_Line', 'Signal Line', 'red'),
        ]),
        (3, [
            scatter('BB_upper', 'Upper BB', 'gray', dash='dash'),
            scatter('BB_middle', 'Middle BB', 'blue', dash='dash'),
            scatter('BB_lower', 'Lower BB', 'gray', dash='dash'),
        ]),
    )
    for row, traces in lower_panels:
        for trace in traces:
            tech_fig.add_trace(trace, row=row, col=1)
    tech_fig.update_layout(height=800, title_text="Technical Analysis")

    return [price_fig, tech_fig]
def analyze_stock(company: str, lookback_days: int) -> Tuple[str, List[go.Figure]]:
    """Download price history for *company*, analyze it and build the plots.

    Args:
        company: a key of COMPANIES (the dropdown label).
        lookback_days: calendar days of history to request, ending now.

    Returns:
        ``(summary_text, [price_fig, technical_fig])``. When the selection is
        missing/unknown or there is too little data, the text holds a
        user-facing message and the figure list is empty.
    """
    symbol = COMPANIES.get(company)
    if symbol is None:
        return "Please select a company from the list.", []

    end_date = datetime.now()
    start_date = end_date - timedelta(days=lookback_days)
    data = yf.download(symbol, start=start_date, end=end_date)

    if data is None or data.empty:
        return "No data available for the selected period.", []

    # If the download has MultiIndex columns (presumably (field, ticker) from
    # newer yfinance -- verify), keep the field level so data['Close'] is a
    # Series rather than a one-column DataFrame.
    if isinstance(data.columns, pd.MultiIndex):
        data = data.copy()
        data.columns = data.columns.get_level_values(0)

    # The summary and volatility trend compare the last two rows.
    if len(data) < 2:
        return "Not enough data for the selected period.", []

    analysis = AgenticRAGFramework().analyze(data)
    # Plot the indicator-enriched frame. The raw download has no SMA/RSI/MACD/
    # Bollinger columns, so passing it directly raised KeyError.
    plots = create_analysis_plots(analysis['processed_data'], analysis)
    return analysis['summary'], plots
def create_gradio_interface():
    """Build the Gradio Blocks app and return it (not yet launched).

    Layout: company dropdown, lookback slider and Analyze button; a summary
    textbox; two plot panels.
    """

    def run_analysis(company, lookback_days):
        # analyze_stock returns (summary, figures) with either 0 or 2 figures,
        # but the click handler feeds three output components. Spread the list
        # into two plot slots, padding with None (empty plot) when nothing was
        # produced. Returning the 2-tuple directly is an output-count mismatch.
        summary_text, figures = analyze_stock(company, lookback_days)
        padded = list(figures) + [None, None]
        return summary_text, padded[0], padded[1]

    with gr.Blocks() as interface:
        gr.Markdown("# Stock Market Analysis with Agentic RAG")
        with gr.Row():
            # Preselect the first company so the first click does not send an
            # empty selection.
            company = gr.Dropdown(
                choices=list(COMPANIES.keys()),
                value=next(iter(COMPANIES)),
                label="Select Company",
            )
            lookback = gr.Slider(
                minimum=30, maximum=365, value=180, step=1, label="Lookback Period (days)"
            )
            analyze_btn = gr.Button("Analyze")
        with gr.Row():
            summary = gr.Textbox(label="Analysis Summary", lines=10)
        with gr.Row():
            plot1 = gr.Plot(label="Price Analysis")
            plot2 = gr.Plot(label="Technical Analysis")
        analyze_btn.click(
            fn=run_analysis,
            inputs=[company, lookback],
            outputs=[summary, plot1, plot2],
        )
    return interface
if __name__ == "__main__":
    # Script entry point: build the UI and serve it. share=True is passed to
    # launch() to request a public share link alongside the local server.
    interface = create_gradio_interface()
    interface.launch(share=True)