arjunanand13 commited on
Commit
8dbdb70
·
verified ·
1 Parent(s): 723772b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +149 -178
app.py CHANGED
@@ -4,15 +4,10 @@ import pandas as pd
4
  import numpy as np
5
  import plotly.graph_objects as go
6
  from plotly.subplots import make_subplots
7
- import torch
8
- import torch.nn as nn
9
- from sklearn.preprocessing import StandardScaler
10
- from typing import Dict, List, Optional, Tuple, Union
11
  from datetime import datetime, timedelta
12
  import warnings
13
  warnings.filterwarnings('ignore')
14
 
15
- # Constants
16
  COMPANIES = {
17
  'Apple (AAPL)': 'AAPL',
18
  'Microsoft (MSFT)': 'MSFT',
@@ -36,204 +31,168 @@ COMPANIES = {
36
  'Netflix (NFLX)': 'NFLX'
37
  }
38
 
39
- class TimeSeriesPreprocessor:
40
- def __init__(self):
41
- self.scaler = StandardScaler()
42
-
43
- def process(self, data: pd.DataFrame) -> Tuple[pd.DataFrame, StandardScaler]:
44
- processed = data.copy()
45
-
46
- # Calculate returns and volatility
47
- processed['Returns'] = processed['Close'].pct_change()
48
- processed['Volatility'] = processed['Returns'].rolling(window=20).std()
49
-
50
- # Technical indicators
51
- processed['SMA_20'] = processed['Close'].rolling(window=20).mean()
52
- processed['SMA_50'] = processed['Close'].rolling(window=50).mean()
53
- processed['RSI'] = self.calculate_rsi(processed['Close'])
54
-
55
- # MACD
56
- exp1 = processed['Close'].ewm(span=12, adjust=False).mean()
57
- exp2 = processed['Close'].ewm(span=26, adjust=False).mean()
58
- processed['MACD'] = exp1 - exp2
59
- processed['Signal_Line'] = processed['MACD'].ewm(span=9, adjust=False).mean()
60
-
61
- # Bollinger Bands
62
- processed['BB_middle'] = processed['Close'].rolling(window=20).mean()
63
- processed['BB_upper'] = processed['BB_middle'] + 2 * processed['Close'].rolling(window=20).std()
64
- processed['BB_lower'] = processed['BB_middle'] - 2 * processed['Close'].rolling(window=20).std()
65
-
66
- # Handle missing values
67
- processed = processed.fillna(method='ffill').fillna(method='bfill')
68
-
69
- # Scale numerical features
70
- numerical_cols = ['Close', 'Volume', 'Returns', 'Volatility']
71
- processed[numerical_cols] = self.scaler.fit_transform(processed[numerical_cols])
72
-
73
- return processed, self.scaler
74
-
75
- @staticmethod
76
- def calculate_rsi(prices: pd.Series, period: int = 14) -> pd.Series:
77
- delta = prices.diff()
78
- gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
79
- loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
80
- rs = gain / loss
81
- return 100 - (100 / (1 + rs))
82
-
83
- class AgenticRAGFramework:
84
- def __init__(self):
85
- self.preprocessor = TimeSeriesPreprocessor()
86
 
87
- def analyze(self, data: pd.DataFrame) -> Dict:
88
- processed_data, scaler = self.preprocessor.process(data)
89
-
90
- analysis = {
91
- 'processed_data': processed_data,
92
- 'trend': self.analyze_trend(processed_data),
93
- 'technical': self.analyze_technical(processed_data),
94
- 'volatility': self.analyze_volatility(processed_data),
95
- 'summary': self.generate_summary(processed_data)
96
- }
97
-
98
- return analysis
99
 
100
- def analyze_trend(self, data: pd.DataFrame) -> Dict:
101
- sma_20 = data['SMA_20'].iloc[-1]
102
- sma_50 = data['SMA_50'].iloc[-1]
103
-
104
- trend = {
105
- 'direction': 'Bullish' if sma_20 > sma_50 else 'Bearish',
106
- 'strength': abs(sma_20 - sma_50) / sma_50,
107
- 'sma_20': sma_20,
108
- 'sma_50': sma_50
109
- }
110
-
111
- return trend
112
-
113
- def analyze_technical(self, data: pd.DataFrame) -> Dict:
114
- technical = {
115
- 'rsi': data['RSI'].iloc[-1],
116
- 'macd': data['MACD'].iloc[-1],
117
- 'signal_line': data['Signal_Line'].iloc[-1],
118
- 'bb_position': (data['Close'].iloc[-1] - data['BB_lower'].iloc[-1]) /
119
- (data['BB_upper'].iloc[-1] - data['BB_lower'].iloc[-1])
120
- }
121
-
122
- return technical
123
-
124
- def analyze_volatility(self, data: pd.DataFrame) -> Dict:
125
- volatility = {
126
- 'current': data['Volatility'].iloc[-1],
127
- 'avg_20d': data['Volatility'].rolling(20).mean().iloc[-1],
128
- 'trend': 'Increasing' if data['Volatility'].iloc[-1] > data['Volatility'].iloc[-2] else 'Decreasing'
129
- }
130
-
131
- return volatility
132
 
133
- def generate_summary(self, data: pd.DataFrame) -> str:
134
- latest_close = data['Close'].iloc[-1]
135
- prev_close = data['Close'].iloc[-2]
136
- daily_return = (latest_close - prev_close) / prev_close * 100
137
-
138
- rsi = data['RSI'].iloc[-1]
139
- volatility = data['Volatility'].iloc[-1]
140
-
141
- summary = f"""Market Analysis Summary:
142
-
143
- • Price Action: The stock {'increased' if daily_return > 0 else 'decreased'} by {abs(daily_return):.2f}% in the last session.
144
-
145
- • Technical Indicators:
146
- - RSI is at {rsi:.2f} indicating {'overbought' if rsi > 70 else 'oversold' if rsi < 30 else 'neutral'} conditions
147
- - Current volatility is {volatility:.2f} which is {'high' if volatility > 0.5 else 'moderate' if volatility > 0.2 else 'low'}
148
-
149
- • Market Signals:
150
- - MACD: {'Bullish' if data['MACD'].iloc[-1] > data['Signal_Line'].iloc[-1] else 'Bearish'} crossover
151
- - Bollinger Bands: Price is {
152
- 'near upper band (potential resistance)' if data['BB_position'].iloc[-1] > 0.8
153
- else 'near lower band (potential support)' if data['BB_position'].iloc[-1] < 0.2
154
- else 'in middle range'}
155
- """
156
-
157
- return summary
158
 
159
- def create_analysis_plots(data: pd.DataFrame, analysis: Dict) -> List[go.Figure]:
160
- # Price and Technical Indicators Plot
161
  fig1 = make_subplots(rows=2, cols=1, shared_xaxes=True,
162
- subplot_titles=('Price and Technical Indicators', 'Volume'),
163
- row_heights=[0.7, 0.3])
 
164
 
165
  # Price and SMAs
166
- fig1.add_trace(go.Scatter(x=data.index, y=data['Close'],
167
- name='Close Price', line=dict(color='blue')), row=1, col=1)
168
- fig1.add_trace(go.Scatter(x=data.index, y=data['SMA_20'],
169
- name='SMA 20', line=dict(color='orange', dash='dash')), row=1, col=1)
170
- fig1.add_trace(go.Scatter(x=data.index, y=data['SMA_50'],
171
- name='SMA 50', line=dict(color='green', dash='dash')), row=1, col=1)
 
 
 
 
 
 
172
 
173
  # Volume
174
- fig1.add_trace(go.Bar(x=data.index, y=data['Volume'],
175
- name='Volume', marker_color='lightblue'), row=2, col=1)
 
 
176
 
177
  fig1.update_layout(height=600, title_text="Price Analysis")
178
 
179
- # Technical Analysis Plot
180
- fig2 = make_subplots(rows=3, cols=1, shared_xaxes=True,
181
- subplot_titles=('RSI', 'MACD', 'Bollinger Bands'),
182
- row_heights=[0.33, 0.33, 0.33])
 
183
 
184
  # RSI
185
- fig2.add_trace(go.Scatter(x=data.index, y=data['RSI'],
186
- name='RSI', line=dict(color='purple')), row=1, col=1)
 
 
187
  fig2.add_hline(y=70, line_dash="dash", line_color="red", row=1, col=1)
188
  fig2.add_hline(y=30, line_dash="dash", line_color="green", row=1, col=1)
189
 
190
- # MACD
191
- fig2.add_trace(go.Scatter(x=data.index, y=data['MACD'],
192
- name='MACD', line=dict(color='blue')), row=2, col=1)
193
- fig2.add_trace(go.Scatter(x=data.index, y=data['Signal_Line'],
194
- name='Signal Line', line=dict(color='red')), row=2, col=1)
195
-
196
  # Bollinger Bands
197
- fig2.add_trace(go.Scatter(x=data.index, y=data['BB_upper'],
198
- name='Upper BB', line=dict(color='gray', dash='dash')), row=3, col=1)
199
- fig2.add_trace(go.Scatter(x=data.index, y=data['BB_middle'],
200
- name='Middle BB', line=dict(color='blue', dash='dash')), row=3, col=1)
201
- fig2.add_trace(go.Scatter(x=data.index, y=data['BB_lower'],
202
- name='Lower BB', line=dict(color='gray', dash='dash')), row=3, col=1)
203
-
204
- fig2.update_layout(height=800, title_text="Technical Analysis")
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
  return [fig1, fig2]
207
 
208
- def analyze_stock(company: str, lookback_days: int) -> Tuple[str, List[go.Figure]]:
209
- symbol = COMPANIES[company]
210
- end_date = datetime.now()
211
- start_date = end_date - timedelta(days=lookback_days)
212
-
213
- # Download data
214
- data = yf.download(symbol, start=start_date, end=end_date)
215
-
216
- if len(data) == 0:
217
- return "No data available for the selected period.", []
218
-
219
- # Analyze data
220
- framework = AgenticRAGFramework()
221
- analysis = framework.analyze(data)
222
-
223
- # Create plots
224
- plots = create_analysis_plots(data, analysis)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
 
226
- return analysis['summary'], plots
 
 
 
 
227
 
228
  def create_gradio_interface():
229
  with gr.Blocks() as interface:
230
- gr.Markdown("# Stock Market Analysis with Agentic RAG")
231
 
232
  with gr.Row():
233
- company = gr.Dropdown(choices=list(COMPANIES.keys()), label="Select Company")
234
- lookback = gr.Slider(minimum=30, maximum=365, value=180, step=1, label="Lookback Period (days)")
235
-
236
- analyze_btn = gr.Button("Analyze")
 
 
 
 
 
 
 
 
 
237
 
238
  with gr.Row():
239
  summary = gr.Textbox(label="Analysis Summary", lines=10)
@@ -242,12 +201,24 @@ def create_gradio_interface():
242
  plot1 = gr.Plot(label="Price Analysis")
243
  plot2 = gr.Plot(label="Technical Analysis")
244
 
245
- analyze_btn.click(
 
 
 
 
 
 
 
246
  fn=analyze_stock,
247
  inputs=[company, lookback],
248
  outputs=[summary, plot1, plot2]
249
  )
250
-
 
 
 
 
 
251
  return interface
252
 
253
  if __name__ == "__main__":
 
4
  import numpy as np
5
  import plotly.graph_objects as go
6
  from plotly.subplots import make_subplots
 
 
 
 
7
  from datetime import datetime, timedelta
8
  import warnings
9
  warnings.filterwarnings('ignore')
10
 
 
11
  COMPANIES = {
12
  'Apple (AAPL)': 'AAPL',
13
  'Microsoft (MSFT)': 'MSFT',
 
31
  'Netflix (NFLX)': 'NFLX'
32
  }
33
 
34
+ def calculate_metrics(data: pd.DataFrame) -> pd.DataFrame:
35
+ df = data.copy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
+ # Basic metrics
38
+ df['Returns'] = df['Close'].pct_change()
39
+ df['SMA_20'] = df['Close'].rolling(window=20).mean()
40
+ df['SMA_50'] = df['Close'].rolling(window=50).mean()
 
 
 
 
 
 
 
 
41
 
42
+ # RSI
43
+ delta = df['Close'].diff()
44
+ gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
45
+ loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
46
+ rs = gain / loss
47
+ df['RSI'] = 100 - (100 / (1 + rs))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
+ # Bollinger Bands
50
+ df['BB_middle'] = df['Close'].rolling(window=20).mean()
51
+ bb_std = df['Close'].rolling(window=20).std()
52
+ df['BB_upper'] = df['BB_middle'] + (2 * bb_std)
53
+ df['BB_lower'] = df['BB_middle'] - (2 * bb_std)
54
+
55
+ return df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
+ def create_analysis_plots(data: pd.DataFrame) -> list:
58
+ # Price and Volume Plot
59
  fig1 = make_subplots(rows=2, cols=1, shared_xaxes=True,
60
+ subplot_titles=('Price and Moving Averages', 'Volume'),
61
+ row_heights=[0.7, 0.3],
62
+ vertical_spacing=0.1)
63
 
64
  # Price and SMAs
65
+ fig1.add_trace(
66
+ go.Scatter(x=data.index, y=data['Close'], name='Close', line=dict(color='blue')),
67
+ row=1, col=1
68
+ )
69
+ fig1.add_trace(
70
+ go.Scatter(x=data.index, y=data['SMA_20'], name='SMA 20', line=dict(color='orange', dash='dash')),
71
+ row=1, col=1
72
+ )
73
+ fig1.add_trace(
74
+ go.Scatter(x=data.index, y=data['SMA_50'], name='SMA 50', line=dict(color='green', dash='dash')),
75
+ row=1, col=1
76
+ )
77
 
78
  # Volume
79
+ fig1.add_trace(
80
+ go.Bar(x=data.index, y=data['Volume'], name='Volume', marker_color='lightblue'),
81
+ row=2, col=1
82
+ )
83
 
84
  fig1.update_layout(height=600, title_text="Price Analysis")
85
 
86
+ # Technical Indicators Plot
87
+ fig2 = make_subplots(rows=2, cols=1, shared_xaxes=True,
88
+ subplot_titles=('RSI', 'Bollinger Bands'),
89
+ row_heights=[0.5, 0.5],
90
+ vertical_spacing=0.1)
91
 
92
  # RSI
93
+ fig2.add_trace(
94
+ go.Scatter(x=data.index, y=data['RSI'], name='RSI', line=dict(color='purple')),
95
+ row=1, col=1
96
+ )
97
  fig2.add_hline(y=70, line_dash="dash", line_color="red", row=1, col=1)
98
  fig2.add_hline(y=30, line_dash="dash", line_color="green", row=1, col=1)
99
 
 
 
 
 
 
 
100
  # Bollinger Bands
101
+ fig2.add_trace(
102
+ go.Scatter(x=data.index, y=data['Close'], name='Close', line=dict(color='blue')),
103
+ row=2, col=1
104
+ )
105
+ fig2.add_trace(
106
+ go.Scatter(x=data.index, y=data['BB_upper'], name='Upper BB',
107
+ line=dict(color='gray', dash='dash')),
108
+ row=2, col=1
109
+ )
110
+ fig2.add_trace(
111
+ go.Scatter(x=data.index, y=data['BB_middle'], name='Middle BB',
112
+ line=dict(color='red', dash='dash')),
113
+ row=2, col=1
114
+ )
115
+ fig2.add_trace(
116
+ go.Scatter(x=data.index, y=data['BB_lower'], name='Lower BB',
117
+ line=dict(color='gray', dash='dash')),
118
+ row=2, col=1
119
+ )
120
+
121
+ fig2.update_layout(height=600, title_text="Technical Analysis")
122
 
123
  return [fig1, fig2]
124
 
125
+ def generate_summary(data: pd.DataFrame) -> str:
126
+ current_price = data['Close'].iloc[-1]
127
+ prev_price = data['Close'].iloc[-2]
128
+ daily_return = ((current_price - prev_price) / prev_price) * 100
129
+
130
+ rsi = data['RSI'].iloc[-1]
131
+ sma_20 = data['SMA_20'].iloc[-1]
132
+ sma_50 = data['SMA_50'].iloc[-1]
133
+
134
+ summary = f"""Market Analysis Summary:
135
+
136
+ Current Price: ${current_price:.2f}
137
+ Daily Change: {daily_return:+.2f}%
138
+ Trend: {'Bullish' if sma_20 > sma_50 else 'Bearish'} (20-day MA vs 50-day MA)
139
+ • RSI: {rsi:.2f} ({'Overbought' if rsi > 70 else 'Oversold' if rsi < 30 else 'Neutral'})
140
+ Volume: {data['Volume'].iloc[-1]:,.0f}
141
+
142
+ Technical Signals:
143
+ • Moving Averages: Price is {'above' if current_price > sma_20 else 'below'} 20-day MA
144
+ • Bollinger Bands: Price is {
145
+ 'near upper band (potential resistance)' if current_price > data['BB_upper'].iloc[-1] * 0.95
146
+ else 'near lower band (potential support)' if current_price < data['BB_lower'].iloc[-1] * 1.05
147
+ else 'in middle range'}
148
+ """
149
+ return summary
150
+
151
+ def analyze_stock(company: str, lookback_days: int = 180) -> tuple:
152
+ try:
153
+ symbol = COMPANIES[company]
154
+ end_date = datetime.now()
155
+ start_date = end_date - timedelta(days=lookback_days)
156
+
157
+ # Download data
158
+ data = yf.download(symbol, start=start_date, end=end_date)
159
+
160
+ if len(data) == 0:
161
+ return "No data available for the selected period.", None, None
162
+
163
+ # Calculate metrics
164
+ data = calculate_metrics(data)
165
+
166
+ # Generate analysis
167
+ summary = generate_summary(data)
168
+ plots = create_analysis_plots(data)
169
+
170
+ return summary, plots[0], plots[1]
171
 
172
+ except Exception as e:
173
+ return f"Error analyzing stock: {str(e)}", None, None
174
+
175
+ def refresh_analysis(company, lookback_days):
176
+ return analyze_stock(company, lookback_days)
177
 
178
  def create_gradio_interface():
179
  with gr.Blocks() as interface:
180
+ gr.Markdown("# Stock Market Analysis Dashboard")
181
 
182
  with gr.Row():
183
+ company = gr.Dropdown(
184
+ choices=list(COMPANIES.keys()),
185
+ label="Select Company",
186
+ value="Apple (AAPL)"
187
+ )
188
+ lookback = gr.Slider(
189
+ minimum=30,
190
+ maximum=365,
191
+ value=180,
192
+ step=1,
193
+ label="Lookback Period (days)"
194
+ )
195
+ refresh_btn = gr.Button("Refresh Analysis")
196
 
197
  with gr.Row():
198
  summary = gr.Textbox(label="Analysis Summary", lines=10)
 
201
  plot1 = gr.Plot(label="Price Analysis")
202
  plot2 = gr.Plot(label="Technical Analysis")
203
 
204
+ refresh_btn.click(
205
+ fn=refresh_analysis,
206
+ inputs=[company, lookback],
207
+ outputs=[summary, plot1, plot2]
208
+ )
209
+
210
+ # Also trigger analysis when company or lookback period changes
211
+ company.change(
212
  fn=analyze_stock,
213
  inputs=[company, lookback],
214
  outputs=[summary, plot1, plot2]
215
  )
216
+ lookback.release(
217
+ fn=analyze_stock,
218
+ inputs=[company, lookback],
219
+ outputs=[summary, plot1, plot2]
220
+ )
221
+
222
  return interface
223
 
224
  if __name__ == "__main__":