Reality123b commited on
Commit
e4c8e31
·
verified ·
1 Parent(s): 9b26658

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +183 -146
app.py CHANGED
@@ -1,156 +1,193 @@
1
  import gradio as gr
2
- import yfinance as yf
 
 
3
  import pandas as pd
 
 
4
  import numpy as np
5
- from sklearn.model_selection import train_test_split
6
  from sklearn.preprocessing import MinMaxScaler
7
- from sklearn.linear_model import LinearRegression
8
- from sklearn.metrics import mean_squared_error, r2_score
9
- import mplfinance as mpf
10
- import matplotlib.pyplot as plt
11
-
12
-
13
- def get_stock_data(symbol, timeframe):
14
- """Fetches stock data from Yahoo Finance."""
15
- ticker = yf.Ticker(symbol)
16
-
17
- # Calculate period based on timeframe
18
- if timeframe in ['1m', '5m', '15m', '30m']:
19
- period = "1d"
20
- elif timeframe in ['1h']:
21
- period = "5d"
22
- else:
23
- period = "60d"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- data = ticker.history(period=period, interval=timeframe)
26
-
27
- if data.empty:
28
- raise ValueError(f"No data found for symbol '{symbol}' with timeframe '{timeframe}'.")
29
-
30
- return data
31
-
32
-
33
-
34
- def calculate_indicators(data):
35
- """Calculates technical indicators."""
36
- data['SMA20'] = data['Close'].rolling(window=20).mean() # Simple Moving Average (20 days)
37
- data['EMA20'] = data['Close'].ewm(span=20, adjust=False).mean() # Exponential Moving Average (20 days)
38
- data['RSI'] = calculate_rsi(data['Close']) # Relative Strength Index
39
- data['MACD'], data['MACD_Signal'], _ = calculate_macd(data['Close']) # Moving Average Convergence Divergence
40
- data['Stochastic_K'], data['Stochastic_D'] = calculate_stochastic(data['High'], data['Low'], data['Close']) # Stochastic Oscillator
41
- return data
42
-
43
-
44
-
45
- def calculate_rsi(close_prices, period=14):
46
- """Calculates the Relative Strength Index (RSI)."""
47
- delta = close_prices.diff()
48
- gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
49
- loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
50
- rs = gain / loss
51
- rsi = 100 - (100 / (1 + rs))
52
- return rsi
53
-
54
- def calculate_macd(close_prices, fast_period=12, slow_period=26, signal_period=9):
55
- """Calculates the Moving Average Convergence Divergence (MACD)."""
56
- fast_ema = close_prices.ewm(span=fast_period, adjust=False).mean()
57
- slow_ema = close_prices.ewm(span=slow_period, adjust=False).mean()
58
- macd = fast_ema - slow_ema
59
- macd_signal = macd.ewm(span=signal_period, adjust=False).mean()
60
- macd_histogram = macd - macd_signal
61
- return macd, macd_signal, macd_histogram
62
-
63
- def calculate_stochastic(high_prices, low_prices, close_prices, period=14):
64
- """Calculates the Stochastic Oscillator."""
65
- lowest_low = low_prices.rolling(window=period).min()
66
- highest_high = high_prices.rolling(window=period).max()
67
- k = ((close_prices - lowest_low) / (highest_high - lowest_low)) * 100
68
- d = k.rolling(window=3).mean()
69
- return k, d
70
-
71
- def predict_next_day(symbol, timeframe):
72
- """Predicts the next day's closing price."""
73
- data = get_stock_data(symbol, timeframe)
74
- data = calculate_indicators(data)
75
-
76
- # Prepare data for training
77
- data = data.dropna()
78
- X = data[['SMA20', 'EMA20', 'RSI', 'MACD', 'MACD_Signal', 'Stochastic_K', 'Stochastic_D']]
79
- y = data['Close']
80
-
81
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
82
-
83
- # Scale the data
84
- scaler = MinMaxScaler()
85
- # Create a linear regression model
86
- model = LinearRegression()
87
- model.fit(X_train, y_train)
88
-
89
- # Predict next day's closing price
90
- last_data_point = data.iloc[-1]
91
- last_data_point = last_data_point[['SMA20', 'EMA20', 'RSI', 'MACD', 'MACD_Signal', 'Stochastic_K', 'Stochastic_D']]
92
- predicted_price = model.predict([last_data_point.values])[0]
93
-
94
- # Calculate model evaluation metrics
95
- y_pred = model.predict(X_test)
96
- mse = mean_squared_error(y_test, y_pred)
97
- rmse = np.sqrt(mse)
98
- r2 = r2_score(y_test, y_pred)
99
-
100
- print(f"Mean Squared Error: {mse:.2f}")
101
- print(f"Root Mean Squared Error: {rmse:.2f}")
102
- print(f"R-squared: {r2:.2f}")
103
-
104
- return predicted_price
105
-
106
- def plot_candlestick(data, symbol, timeframe, predicted_price=None):
107
- """Plots the candlestick chart with technical indicators."""
108
- if data.empty:
109
- raise ValueError("No valid data to plot. Please check your inputs.")
110
-
111
- fig, ax = plt.subplots(figsize=(12, 6))
112
- mpf.plot(data, type='candle', style='charles', ax=ax, volume=True, show_nontrading=True)
113
-
114
- # Add moving averages
115
- ax.plot(data.index, data['SMA20'], label='SMA20', color='blue', alpha=0.7)
116
- ax.plot(data.index, data['EMA20'], label='EMA20', color='red', alpha=0.7)
117
-
118
- # Add prediction
119
- if predicted_price is not None:
120
- last_timestamp = data.index[-1] + pd.Timedelta(timeframe)
121
- ax.scatter(last_timestamp, predicted_price, color='green', marker='*', s=100, label='Prediction')
122
-
123
- ax.legend()
124
- ax.set_title(f"{symbol} - {timeframe}")
125
- ax.tick_params(axis='x', rotation=45)
126
- fig.tight_layout()
127
- return fig
128
-
129
 
130
  def main():
131
- """Gradio Interface."""
132
-
133
- symbol_input = gr.Textbox("AAPL", label="Symbol", interactive=True) # Moved "AAPL" to the correct position
134
- timeframe_input = gr.Dropdown(label="Timeframe", choices=["1m", "5m", "15m", "30m", "1h", "1d"], value="1d")
135
-
136
- with gr.Blocks() as interface:
137
- gr.Markdown("## Real-time Stock Market Analysis")
138
- with gr.Row():
139
- symbol_input = gr.Textbox("AAPL", label="Symbol", interactive=True)
140
- timeframe_input = gr.Dropdown(label="Timeframe", choices=["1m", "5m", "15m", "30m", "1h", "1d"], value="1d", interactive=True)
141
-
142
- with gr.Row():
143
- predict_button = gr.Button(value="Predict")
144
- predicted_price = gr.Textbox(label="Predicted Price")
145
-
146
- with gr.Row():
147
- output_plot = gr.Plot(label="Candlestick Chart")
148
-
149
- predict_button.click(fn=predict_next_day, inputs=[symbol_input, timeframe_input], outputs=predicted_price)
150
-
151
- predicted_price.change(fn=plot_candlestick, inputs=[symbol_input, timeframe_input, predicted_price], outputs=output_plot)
152
-
153
- interface.launch()
154
 
155
  if __name__ == "__main__":
156
  main()
 
1
  import gradio as gr
2
+ import matplotlib.pyplot as plt
3
+ from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
4
+ import mplfinance as mpf
5
  import pandas as pd
6
+ import yfinance as yf
7
+ from datetime import datetime, timedelta
8
  import numpy as np
 
9
  from sklearn.preprocessing import MinMaxScaler
10
+ from tensorflow.keras.models import Sequential
11
+ from tensorflow.keras.layers import LSTM, Dense
12
+ from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
13
+ import os
14
+
15
+
16
+ class CandlestickApp:
17
+ def __init__(self):
18
+ self.current_symbol = None
19
+ self.data = None
20
+ self.prediction_data = None
21
+ self.model = None
22
+ self.model_path = "models/" # Specify the directory to save models
23
+
24
+ def get_stock_data(self, symbol, timeframe, start_date, end_date):
25
+ try:
26
+ ticker = yf.Ticker(symbol)
27
+ data = ticker.history(start=start_date, end=end_date, interval=timeframe) # Now with start/end dates
28
+ return data
29
+ except Exception as e:
30
+ return None
31
+
32
+ def calculate_indicators(self, data):
33
+ # Calculate RSI, SMA20, SMA50 (as before)
34
+ delta = data['Close'].diff()
35
+ gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
36
+ loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
37
+ rs = gain / loss
38
+ data['RSI'] = 100 - (100 / (1 + rs))
39
+
40
+ data['SMA20'] = data['Close'].rolling(window=20).mean()
41
+ data['SMA50'] = data['Close'].rolling(window=50).mean()
42
+
43
+ # Add more indicators as needed
44
+ # ...
45
+ return data
46
+
47
+ def plot_candlestick_chart(self, data, symbol, timeframe):
48
+ # Chart plotting logic (remains the same)
49
+ fig = plt.figure(figsize=(12, 6))
50
+ ax1 = fig.add_subplot(211)
51
+ ax2 = fig.add_subplot(212, sharex=ax1)
52
+
53
+ mpf.plot(data,
54
+ type='candle',
55
+ style='charles',
56
+ ax=ax1,
57
+ volume=ax2,
58
+ show_nontrading=True)
59
+
60
+ # Add moving averages
61
+ if len(data) >= 20:
62
+ ax1.plot(data.index, data['SMA20'], label='SMA20', color='blue', alpha=0.7)
63
+ if len(data) >= 50:
64
+ ax1.plot(data.index, data['SMA50'], label='SMA50', color='red', alpha=0.7)
65
+
66
+ # Add Prediction (if available)
67
+ if self.prediction_data is not None:
68
+ ax1.scatter(self.prediction_data['timestamp'],
69
+ self.prediction_data['price'],
70
+ color='purple',
71
+ marker='*',
72
+ s=100,
73
+ label='Prediction')
74
+ ax1.legend()
75
+ ax1.set_title(f"{symbol} - {timeframe}")
76
+ ax1.tick_params(axis='x', rotation=45)
77
+ fig.tight_layout()
78
+
79
+ canvas = FigureCanvas(fig)
80
+ image = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')
81
+ image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))
82
+ return image
83
+
84
+
85
+ def predict_next_movement(self, data):
86
+ # Improved prediction using LSTM and potentially larger model
87
+
88
+ if self.model is None:
89
+ self.create_lstm_model(data)
90
+
91
+ # Prepare data for prediction
92
+ dataset = data['Close'].values.reshape(-1, 1)
93
+ scaler = MinMaxScaler(feature_range=(0, 1))
94
+ dataset = scaler.fit_transform(dataset)
95
+
96
+ # Prepare the input sequence
97
+ look_back = 20
98
+ X_test = []
99
+ X_test.append(dataset[-look_back:, 0])
100
+ X_test = np.array(X_test)
101
+ X_test = np.reshape(X_test, (X_test.shape[0], X_test.shape[1], 1))
102
+
103
+ # Make prediction
104
+ predicted_price = self.model.predict(X_test)
105
+ predicted_price = scaler.inverse_transform(predicted_price)[0][0]
106
+
107
+ # Store prediction data
108
+ self.prediction_data = {
109
+ 'timestamp': data.index[-1] + pd.Timedelta(self.timeframe_var.get()),
110
+ 'price': predicted_price
111
+ }
112
+
113
+ return predicted_price
114
 
115
+ def create_lstm_model(self, data):
116
+ # Enhanced Model Training (larger model, more features, callbacks)
117
+
118
+ # Standardize the data for the model
119
+ dataset = data['Close'].values.reshape(-1, 1)
120
+ scaler = MinMaxScaler(feature_range=(0, 1))
121
+ dataset = scaler.fit_transform(dataset)
122
+ train_size = int(len(dataset) * 0.8)
123
+ test_size = len(dataset) - train_size
124
+ train_data, test_data = dataset[0:train_size,:], dataset[train_size:len(dataset),:]
125
+
126
+ # Create dataset for LSTM with possible modifications
127
+ def create_dataset(dataset, look_back=1):
128
+ X, Y = [], []
129
+ for i in range(len(dataset)-look_back-1):
130
+ a = dataset[i:(i+look_back), 0]
131
+ X.append(a)
132
+ Y.append(dataset[i + look_back, 0])
133
+ return np.array(X), np.array(Y)
134
+
135
+ look_back = 100
136
+ X_train, Y_train = create_dataset(train_data, look_back)
137
+ X_train = np.reshape(X_train, (X_train.shape[0], X_train.shape[1], 1))
138
+
139
+ # Create and fit the LSTM network; potentially with more layers
140
+ model = Sequential()
141
+ model.add(LSTM(units=256, return_sequences=True, input_shape=(X_train.shape[1], 1)))
142
+ model.add(LSTM(units=128, return_sequences=True)) # Add more LSTM layers if needed
143
+ model.add(LSTM(units=64))
144
+ model.add(Dense(1))
145
+ model.compile(loss='mean_squared_error', optimizer='adam')
146
+
147
+ os.makedirs(self.model_path, exist_ok=True)
148
+ filepath = os.path.join(self.model_path,"stock_prediction_model.h5")
149
+ checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1, save_best_only=True, mode='min') # Save best model
150
+ early_stop = EarlyStopping(monitor='loss', patience=10, restore_best_weights=True) # Prevent overfitting
151
+
152
+ # Train the model
153
+ model.fit(X_train, Y_train, epochs=500, batch_size=64, callbacks=[checkpoint, early_stop]) # Increase epochs potentially
154
+
155
+ self.model = model
156
+
157
+ def inference(self, symbol, timeframe, start_date, end_date):
158
+ data = self.get_stock_data(symbol, timeframe, start_date, end_date)
159
+
160
+ if data is None:
161
+ return "Error fetching data", None
162
+
163
+ data = self.calculate_indicators(data)
164
+
165
+ if len(data) < 20:
166
+ return "Insufficient data for prediction & chart", None
167
+
168
+ predicted_price = self.predict_next_movement(data)
169
+ chart = self.plot_candlestick_chart(data, symbol, timeframe)
170
+ return f"Predicted price: ${predicted_price:.2f}", chart
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
  def main():
173
+ app = CandlestickApp()
174
+ iface = gr.Interface(
175
+ fn=app.inference,
176
+ inputs=[
177
+ gr.inputs.Textbox(lines=1, placeholder="Enter Stock Symbol (e.g., AAPL)", label="Stock Symbol"),
178
+ gr.inputs.Dropdown(["1m", "5m", "15m", "30m", "1h", "1d"], label="Timeframe"),
179
+ gr.inputs.DatePicker(label="Start Date"), # New, for start date
180
+ gr.inputs.DatePicker(label="End Date"), # New, for end date
181
+ ],
182
+ outputs=[
183
+ gr.outputs.Textbox(label="Prediction"),
184
+ gr.outputs.Image(label="Candlestick Chart"),
185
+ ],
186
+ title="Stock Market Prediction & Analysis (Enhanced)",
187
+ description="Enter a stock symbol, timeframe, and date range to get a prediction and candlestick chart analysis.",
188
+ )
189
+
190
+ iface.launch()
 
 
 
 
 
191
 
192
  if __name__ == "__main__":
193
  main()