arjunanand13 committed on
Commit
5be7da8
·
verified ·
1 Parent(s): bee42fa

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +255 -0
app.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import yfinance as yf
3
+ import pandas as pd
4
+ import numpy as np
5
+ import plotly.graph_objects as go
6
+ from plotly.subplots import make_subplots
7
+ import torch
8
+ import torch.nn as nn
9
+ from sklearn.preprocessing import StandardScaler
10
+ from typing import Dict, List, Optional, Tuple, Union
11
+ from datetime import datetime, timedelta
12
+ import warnings
13
+ warnings.filterwarnings('ignore')
14
+
# Constants
# Maps the display label offered in the UI dropdown to the ticker symbol
# that is handed to yfinance when downloading price history.
COMPANIES = {
    'Apple (AAPL)': 'AAPL',
    'Microsoft (MSFT)': 'MSFT',
    'Amazon (AMZN)': 'AMZN',
    'Google (GOOGL)': 'GOOGL',
    'Meta (META)': 'META',
    'Tesla (TSLA)': 'TSLA',
    'NVIDIA (NVDA)': 'NVDA',
    'JPMorgan Chase (JPM)': 'JPM',
    'Johnson & Johnson (JNJ)': 'JNJ',
    'Walmart (WMT)': 'WMT',
    'Visa (V)': 'V',
    'Mastercard (MA)': 'MA',
    'Procter & Gamble (PG)': 'PG',
    'UnitedHealth (UNH)': 'UNH',
    'Home Depot (HD)': 'HD',
    'Bank of America (BAC)': 'BAC',
    'Coca-Cola (KO)': 'KO',
    'Pfizer (PFE)': 'PFE',
    'Disney (DIS)': 'DIS',
    'Netflix (NFLX)': 'NFLX',
}
38
+
class TimeSeriesPreprocessor:
    """Add technical-indicator columns to price data and standardise the raw features.

    The scaler fitted by the most recent ``process`` call is kept on
    ``self.scaler`` and is also returned, so callers can invert the scaling.
    """

    def __init__(self):
        self.scaler = StandardScaler()

    def process(self, data: pd.DataFrame) -> Tuple[pd.DataFrame, "StandardScaler"]:
        """Return a copy of *data* with indicator columns plus the fitted scaler.

        Args:
            data: Frame that has at least ``Close`` and ``Volume`` columns.

        Returns:
            ``(processed, scaler)``. ``processed`` gains the columns
            ``Returns``, ``Volatility``, ``SMA_20``, ``SMA_50``, ``RSI``,
            ``MACD``, ``Signal_Line``, ``BB_middle``, ``BB_upper`` and
            ``BB_lower``. ``Close``, ``Volume``, ``Returns`` and
            ``Volatility`` are then overwritten with their standardised
            values; every other indicator is left in the units it was
            computed in (i.e. from the unscaled ``Close``).
        """
        processed = data.copy()
        close = processed['Close']

        # Calculate returns and volatility
        processed['Returns'] = close.pct_change()
        processed['Volatility'] = processed['Returns'].rolling(window=20).std()

        # Technical indicators. The 20-period mean/std are computed once and
        # shared by the SMA and the Bollinger Bands.
        rolling_20 = close.rolling(window=20)
        sma_20 = rolling_20.mean()
        std_20 = rolling_20.std()

        processed['SMA_20'] = sma_20
        processed['SMA_50'] = close.rolling(window=50).mean()
        processed['RSI'] = self.calculate_rsi(close)

        # MACD
        fast_ema = close.ewm(span=12, adjust=False).mean()
        slow_ema = close.ewm(span=26, adjust=False).mean()
        processed['MACD'] = fast_ema - slow_ema
        processed['Signal_Line'] = processed['MACD'].ewm(span=9, adjust=False).mean()

        # Bollinger Bands
        processed['BB_middle'] = sma_20
        processed['BB_upper'] = sma_20 + 2 * std_20
        processed['BB_lower'] = sma_20 - 2 * std_20

        # Handle missing values. ``fillna(method=...)`` is deprecated in
        # pandas 2.x, so the dedicated ffill/bfill methods are used instead.
        # Infinite values (e.g. a return computed from a zero price) are
        # treated as missing because the scaler rejects them.
        processed = processed.replace([np.inf, -np.inf], np.nan)
        processed = processed.ffill().bfill()

        # Scale numerical features
        numerical_cols = ['Close', 'Volume', 'Returns', 'Volatility']
        processed[numerical_cols] = self.scaler.fit_transform(processed[numerical_cols])

        return processed, self.scaler

    @staticmethod
    def calculate_rsi(prices: pd.Series, period: int = 14) -> pd.Series:
        """Return the RSI of *prices* using simple rolling means of gains and losses.

        Args:
            prices: Price series.
            period: Rolling window length.

        Returns:
            Series aligned with *prices*. The first ``period - 1`` entries
            are NaN because the rolling window is not yet full.
        """
        delta = prices.diff()
        # Non-gains / non-losses (including the leading NaN) count as 0.
        avg_gain = delta.where(delta > 0, 0).rolling(window=period).mean()
        avg_loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
        rs = avg_gain / avg_loss
        return 100 - (100 / (1 + rs))
82
+
class AgenticRAGFramework:
    """Preprocess price data and derive trend, technical, volatility and text summaries."""

    def __init__(self):
        self.preprocessor = TimeSeriesPreprocessor()

    def analyze(self, data: pd.DataFrame) -> Dict:
        """Run the preprocessor on *data* and collect every analysis section.

        Args:
            data: Raw price frame accepted by the preprocessor (needs ``Close``).

        Returns:
            Dict with keys ``processed_data`` (the frame exactly as returned
            by the preprocessor), ``trend``, ``technical``, ``volatility``
            and ``summary``.
        """
        processed_data, _scaler = self.preprocessor.process(data)

        # The preprocessor standardises 'Close' in place, while the SMA and
        # Bollinger columns stay in price units. Percent changes and band
        # positions only make sense on the real price, so the analysis runs
        # on a copy whose 'Close' is taken from the raw input again.
        price_frame = processed_data.copy()
        price_frame['Close'] = data['Close'].ffill().bfill()

        return {
            'processed_data': processed_data,
            'trend': self.analyze_trend(price_frame),
            'technical': self.analyze_technical(price_frame),
            'volatility': self.analyze_volatility(price_frame),
            'summary': self.generate_summary(price_frame),
        }

    @staticmethod
    def _bollinger_position(data: pd.DataFrame) -> float:
        """Return where the last close sits between the bands (0 = lower, 1 = upper).

        Returns NaN when the band width is zero or missing.
        """
        lower = data['BB_lower'].iloc[-1]
        upper = data['BB_upper'].iloc[-1]
        width = upper - lower
        if pd.isna(width) or width == 0:
            return float('nan')
        return (data['Close'].iloc[-1] - lower) / width

    def analyze_trend(self, data: pd.DataFrame) -> Dict:
        """Compare the latest 20- and 50-period SMAs.

        Returns:
            Dict with ``direction`` ('Bullish' when SMA_20 > SMA_50, else
            'Bearish'), ``strength`` (absolute gap relative to SMA_50),
            ``sma_20`` and ``sma_50``.
        """
        sma_20 = data['SMA_20'].iloc[-1]
        sma_50 = data['SMA_50'].iloc[-1]

        return {
            'direction': 'Bullish' if sma_20 > sma_50 else 'Bearish',
            'strength': abs(sma_20 - sma_50) / sma_50,
            'sma_20': sma_20,
            'sma_50': sma_50,
        }

    def analyze_technical(self, data: pd.DataFrame) -> Dict:
        """Return the latest RSI, MACD, signal line and Bollinger position."""
        return {
            'rsi': data['RSI'].iloc[-1],
            'macd': data['MACD'].iloc[-1],
            'signal_line': data['Signal_Line'].iloc[-1],
            'bb_position': self._bollinger_position(data),
        }

    def analyze_volatility(self, data: pd.DataFrame) -> Dict:
        """Return the latest volatility, its 20-row rolling mean and its direction.

        ``trend`` is 'Increasing' or 'Decreasing' relative to the previous
        row, or 'Unknown' when there is no previous row to compare with.
        """
        vol = data['Volatility']
        current = vol.iloc[-1]

        if len(vol) > 1:
            direction = 'Increasing' if current > vol.iloc[-2] else 'Decreasing'
        else:
            direction = 'Unknown'

        return {
            'current': current,
            'avg_20d': vol.rolling(20).mean().iloc[-1],
            'trend': direction,
        }

    def generate_summary(self, data: pd.DataFrame) -> str:
        """Build the human-readable summary from the last rows of *data*.

        Args:
            data: Frame with ``Close``, ``RSI``, ``Volatility``, ``MACD``,
                ``Signal_Line``, ``BB_upper`` and ``BB_lower`` columns.

        Returns:
            Multi-line summary text.
        """
        close = data['Close']
        latest_close = close.iloc[-1]
        # With a single row there is no previous session; report 0% change.
        prev_close = close.iloc[-2] if len(close) > 1 else latest_close
        daily_return = (latest_close - prev_close) / prev_close * 100

        rsi = data['RSI'].iloc[-1]
        # NOTE(review): the thresholds below are applied to whatever scale the
        # caller supplies; via analyze() this is the preprocessor's
        # standardised 'Volatility' column - confirm that is intended.
        volatility = data['Volatility'].iloc[-1]

        price_move = 'increased' if daily_return > 0 else 'decreased'

        if rsi > 70:
            rsi_state = 'overbought'
        elif rsi < 30:
            rsi_state = 'oversold'
        else:
            rsi_state = 'neutral'

        if volatility > 0.5:
            vol_state = 'high'
        elif volatility > 0.2:
            vol_state = 'moderate'
        else:
            vol_state = 'low'

        macd_state = 'Bullish' if data['MACD'].iloc[-1] > data['Signal_Line'].iloc[-1] else 'Bearish'

        # The position is computed here from the band columns; there is no
        # 'BB_position' column in the frame to read it from.
        bb_position = self._bollinger_position(data)
        if bb_position > 0.8:
            band_state = 'near upper band (potential resistance)'
        elif bb_position < 0.2:
            band_state = 'near lower band (potential support)'
        else:
            band_state = 'in middle range'

        lines = [
            "Market Analysis Summary:",
            "",
            f"• Price Action: The stock {price_move} by {abs(daily_return):.2f}% in the last session.",
            "",
            "• Technical Indicators:",
            f"- RSI is at {rsi:.2f} indicating {rsi_state} conditions",
            f"- Current volatility is {volatility:.2f} which is {vol_state}",
            "",
            "• Market Signals:",
            f"- MACD: {macd_state} crossover",
            f"- Bollinger Bands: Price is {band_state}",
            "",
        ]
        return "\n".join(lines)
158
+
def create_analysis_plots(data: pd.DataFrame, analysis: Dict) -> List[go.Figure]:
    """Build the price/volume figure and the RSI/MACD/Bollinger figure.

    Args:
        data: Price frame; ``Close`` and ``Volume`` are plotted from it and
            its index is used as the x axis.
        analysis: Analysis dict. Any plotted column that *data* does not
            contain is looked up in ``analysis['processed_data']`` instead.

    Returns:
        ``[price_figure, technical_figure]``.

    Raises:
        KeyError: If a required column is in neither frame.
    """
    indicator_frame = (analysis or {}).get('processed_data')

    def column(name: str) -> pd.Series:
        # Prefer the caller's frame; fall back to the processed frame for
        # indicator columns that only exist after preprocessing.
        if name in data.columns:
            return data[name]
        if indicator_frame is not None and name in indicator_frame.columns:
            return indicator_frame[name]
        raise KeyError(name)

    def add_line(fig, name: str, label: str, row: int, **line_style) -> None:
        fig.add_trace(
            go.Scatter(x=data.index, y=column(name), name=label, line=dict(**line_style)),
            row=row, col=1,
        )

    # Price and Technical Indicators Plot
    fig1 = make_subplots(rows=2, cols=1, shared_xaxes=True,
                         subplot_titles=('Price and Technical Indicators', 'Volume'),
                         row_heights=[0.7, 0.3])

    # Price and SMAs
    add_line(fig1, 'Close', 'Close Price', 1, color='blue')
    add_line(fig1, 'SMA_20', 'SMA 20', 1, color='orange', dash='dash')
    add_line(fig1, 'SMA_50', 'SMA 50', 1, color='green', dash='dash')

    # Volume
    fig1.add_trace(go.Bar(x=data.index, y=column('Volume'),
                          name='Volume', marker_color='lightblue'), row=2, col=1)

    fig1.update_layout(height=600, title_text="Price Analysis")

    # Technical Analysis Plot
    fig2 = make_subplots(rows=3, cols=1, shared_xaxes=True,
                         subplot_titles=('RSI', 'MACD', 'Bollinger Bands'),
                         row_heights=[0.33, 0.33, 0.33])

    # RSI with the 70 / 30 reference lines
    add_line(fig2, 'RSI', 'RSI', 1, color='purple')
    fig2.add_hline(y=70, line_dash="dash", line_color="red", row=1, col=1)
    fig2.add_hline(y=30, line_dash="dash", line_color="green", row=1, col=1)

    # MACD
    add_line(fig2, 'MACD', 'MACD', 2, color='blue')
    add_line(fig2, 'Signal_Line', 'Signal Line', 2, color='red')

    # Bollinger Bands
    add_line(fig2, 'BB_upper', 'Upper BB', 3, color='gray', dash='dash')
    add_line(fig2, 'BB_middle', 'Middle BB', 3, color='blue', dash='dash')
    add_line(fig2, 'BB_lower', 'Lower BB', 3, color='gray', dash='dash')

    fig2.update_layout(height=800, title_text="Technical Analysis")

    return [fig1, fig2]
207
+
def analyze_stock(company: str, lookback_days: int) -> Tuple[str, List[go.Figure]]:
    """Download recent prices for *company* and return its summary and plots.

    Args:
        company: A key of ``COMPANIES`` (the dropdown label).
        lookback_days: Number of calendar days of history to request.

    Returns:
        ``(summary_text, figures)``. ``figures`` is empty when no company is
        selected or no data is available.
    """
    # The dropdown has no default, so the callback can receive None.
    symbol = COMPANIES.get(company)
    if symbol is None:
        return "Please select a company to analyze.", []

    end_date = datetime.now()
    start_date = end_date - timedelta(days=lookback_days)

    # Download data
    data = yf.download(symbol, start=start_date, end=end_date)

    if data is None or data.empty:
        return "No data available for the selected period.", []

    # Some yfinance versions return MultiIndex columns even for one ticker;
    # collapse them to the level that carries the field names so that
    # data['Close'] is a Series.
    if isinstance(data.columns, pd.MultiIndex):
        for level in range(data.columns.nlevels):
            level_values = data.columns.get_level_values(level)
            if 'Close' in level_values:
                data = data.copy()
                data.columns = level_values
                break

    # Analyze data
    framework = AgenticRAGFramework()
    analysis = framework.analyze(data)

    # Create plots
    plots = create_analysis_plots(data, analysis)

    return analysis['summary'], plots
227
+
def create_gradio_interface():
    """Build and return the Gradio Blocks UI for the stock analysis app."""

    def run_analysis(company_label, lookback_days):
        # analyze_stock returns (summary, [figures]) but the click handler
        # must return one value per output component, so the figure list is
        # unpacked here; missing figures become None (empty plot).
        summary_text, figures = analyze_stock(company_label, lookback_days)
        figures = list(figures or [])
        price_fig = figures[0] if len(figures) > 0 else None
        technical_fig = figures[1] if len(figures) > 1 else None
        return summary_text, price_fig, technical_fig

    with gr.Blocks() as interface:
        gr.Markdown("# Stock Market Analysis with Agentic RAG")

        with gr.Row():
            company = gr.Dropdown(choices=list(COMPANIES.keys()), label="Select Company")
            lookback = gr.Slider(minimum=30, maximum=365, value=180, step=1, label="Lookback Period (days)")

        analyze_btn = gr.Button("Analyze")

        with gr.Row():
            summary = gr.Textbox(label="Analysis Summary", lines=10)

        with gr.Row():
            plot1 = gr.Plot(label="Price Analysis")
            plot2 = gr.Plot(label="Technical Analysis")

        analyze_btn.click(
            fn=run_analysis,
            inputs=[company, lookback],
            outputs=[summary, plot1, plot2],
        )

    return interface
252
+
if __name__ == "__main__":
    # Only start the UI when run as a script, not when imported.
    interface = create_gradio_interface()
    # share=True asks Gradio for a public share link in addition to the
    # local server.
    interface.launch(share=True)