Spaces:
Runtime error
Runtime error
File size: 7,393 Bytes
1e0965c e4c8e31 1e0965c e4c8e31 1e0965c e4c8e31 5deb033 e4c8e31 1e0965c e4c8e31 5deb033 e4c8e31 fe2a089 e4c8e31 6f04329 1e0965c e4c8e31 1e0965c c15e868 1e0965c |
|
import gradio as gr
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import mplfinance as mpf
import pandas as pd
import yfinance as yf
from datetime import datetime, timedelta
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
import os
class CandlestickApp:
    """Fetch price data, compute indicators, predict the next close with an
    LSTM and render a candlestick chart for the Gradio UI."""

    # Number of past closing prices the LSTM sees per sample. Training and
    # prediction must share this value: the network's input layer is built
    # for exactly this many time steps.
    look_back = 100

    def __init__(self):
        self.current_symbol = None     # symbol the cached model was trained for
        self.current_timeframe = None  # timeframe the cached model was trained for
        self.data = None               # last DataFrame processed by inference()
        self.prediction_data = None    # {'timestamp', 'price'} of the last prediction
        self.model = None              # cached Keras model; reset on symbol/timeframe change
        self.model_path = "models/"    # directory where the best checkpoint is saved

    def get_stock_data(self, symbol, timeframe, start_date, end_date):
        """Download price history for *symbol* via yfinance.

        Blank start/end dates are passed as ``None`` so yfinance applies its
        own default range instead of receiving empty strings.

        Returns:
            The DataFrame from ``Ticker.history`` (possibly empty), or ``None``
            if the download raised.
        """
        try:
            ticker = yf.Ticker(symbol)
            return ticker.history(
                start=start_date or None,
                end=end_date or None,
                interval=timeframe,
            )
        except Exception:
            # Best-effort download: inference() reports a generic fetch error.
            return None

    def calculate_indicators(self, data):
        """Add RSI, SMA20 and SMA50 columns to *data* (in place) and return it.

        RSI uses 14-period simple rolling means of gains and losses. Rows
        before a window is full hold NaN.
        """
        delta = data['Close'].diff()
        gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
        loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
        rs = gain / loss
        data['RSI'] = 100 - (100 / (1 + rs))
        data['SMA20'] = data['Close'].rolling(window=20).mean()
        data['SMA50'] = data['Close'].rolling(window=50).mean()
        return data

    def plot_candlestick_chart(self, data, symbol, timeframe):
        """Render candles, volume, moving averages and the latest prediction.

        Returns:
            A ``(height, width, 3)`` uint8 RGB array suitable for ``gr.Image``.
        """
        fig = plt.figure(figsize=(12, 6))
        try:
            price_ax = fig.add_subplot(211)
            volume_ax = fig.add_subplot(212, sharex=price_ax)
            mpf.plot(data,
                     type='candle',
                     style='charles',
                     ax=price_ax,
                     volume=volume_ax,
                     show_nontrading=True)
            if len(data) >= 20:
                price_ax.plot(data.index, data['SMA20'], label='SMA20', color='blue', alpha=0.7)
            if len(data) >= 50:
                price_ax.plot(data.index, data['SMA50'], label='SMA50', color='red', alpha=0.7)
            if self.prediction_data is not None:
                price_ax.scatter(self.prediction_data['timestamp'],
                                 self.prediction_data['price'],
                                 color='purple',
                                 marker='*',
                                 s=100,
                                 label='Prediction')
            price_ax.legend()
            price_ax.set_title(f"{symbol} - {timeframe}")
            price_ax.tick_params(axis='x', rotation=45)
            fig.tight_layout()
            canvas = FigureCanvas(fig)
            # A freshly created canvas has no rendered pixels until drawn.
            canvas.draw()
            # buffer_rgba() replaces tostring_rgb(), which was removed in
            # Matplotlib 3.10. Drop alpha and copy out of the renderer buffer.
            rgba = np.asarray(canvas.buffer_rgba())
            return rgba[..., :3].copy()
        finally:
            # plt.figure() registers the figure with pyplot; close it so
            # repeated requests do not accumulate open figures.
            plt.close(fig)

    @staticmethod
    def _make_windows(series, look_back):
        """Return (X, Y) with X[i] = series[i:i+look_back, 0] and Y[i] = series[i+look_back, 0]."""
        count = len(series) - look_back
        windows = [series[i:i + look_back, 0] for i in range(count)]
        targets = [series[i + look_back, 0] for i in range(count)]
        return np.array(windows), np.array(targets)

    @staticmethod
    def _bar_step(data, timeframe):
        """Return the time between bars: parsed from *timeframe*, else inferred from the index."""
        if timeframe:
            return pd.Timedelta(timeframe)
        if len(data.index) >= 2:
            return data.index[-1] - data.index[-2]
        return pd.Timedelta(0)

    def predict_next_movement(self, data, timeframe=None):
        """Predict the next closing price from the last ``look_back`` closes.

        Trains a model first if none is cached, and stores the result in
        ``self.prediction_data`` so the chart can plot it.

        Args:
            data: DataFrame with a 'Close' column and a datetime-like index.
            timeframe: bar interval (e.g. '1d', '5m') used to place the
                prediction one bar after the last row. Defaults to
                ``self.current_timeframe``, then to the spacing of the last
                two rows.

        Returns:
            The predicted price as a float.

        Raises:
            ValueError: if *data* is too short to form an input window or to
                train a model.
        """
        look_back = self.look_back
        if len(data) < look_back:
            raise ValueError(f"need at least {look_back} rows, got {len(data)}")
        if self.model is None:
            self.create_lstm_model(data)
        closes = data['Close'].values.reshape(-1, 1)
        # Scale exactly as create_lstm_model does (min-max over this data).
        scaler = MinMaxScaler(feature_range=(0, 1))
        scaled = scaler.fit_transform(closes)
        window = scaled[-look_back:, 0].reshape(1, look_back, 1)
        predicted = self.model.predict(window)
        predicted_price = float(scaler.inverse_transform(predicted)[0][0])
        step = self._bar_step(data, timeframe or self.current_timeframe)
        self.prediction_data = {
            'timestamp': data.index[-1] + step,
            'price': predicted_price,
        }
        return predicted_price

    def create_lstm_model(self, data):
        """Train a stacked LSTM on the first 80% of *data*'s closes and cache it.

        The best weights (by training loss) are checkpointed under
        ``self.model_path``. Training stops early after 10 epochs without
        improvement.

        Raises:
            ValueError: if the training split cannot form a single window.
        """
        look_back = self.look_back
        closes = data['Close'].values.reshape(-1, 1)
        scaled = MinMaxScaler(feature_range=(0, 1)).fit_transform(closes)
        # Only the first 80% is used for training. The rest is held out.
        train_data = scaled[:int(len(scaled) * 0.8)]
        X_train, Y_train = self._make_windows(train_data, look_back)
        if len(X_train) == 0:
            raise ValueError(
                f"not enough rows to train: the first 80% ({len(train_data)} rows) "
                f"must exceed look_back={look_back}"
            )
        X_train = X_train.reshape(X_train.shape[0], look_back, 1)
        model = Sequential()
        model.add(LSTM(units=256, return_sequences=True, input_shape=(look_back, 1)))
        model.add(LSTM(units=128, return_sequences=True))
        model.add(LSTM(units=64))
        model.add(Dense(1))
        model.compile(loss='mean_squared_error', optimizer='adam')
        os.makedirs(self.model_path, exist_ok=True)
        filepath = os.path.join(self.model_path, "stock_prediction_model.h5")
        checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1, save_best_only=True, mode='min')
        early_stop = EarlyStopping(monitor='loss', patience=10, restore_best_weights=True)
        model.fit(X_train, Y_train, epochs=500, batch_size=64, callbacks=[checkpoint, early_stop])
        self.model = model

    def inference(self, symbol, timeframe, start_date, end_date):
        """Gradio callback: return ``(status text, chart image or None)``."""
        data = self.get_stock_data(symbol, timeframe, start_date, end_date)
        if data is None or data.empty:
            return "Error fetching data", None
        data = self.calculate_indicators(data)
        if len(data) < 20:
            return "Insufficient data for prediction & chart", None
        # The cached model was fitted to one symbol/timeframe; retrain on change.
        if (symbol, timeframe) != (self.current_symbol, self.current_timeframe):
            self.model = None
            self.current_symbol = symbol
            self.current_timeframe = timeframe
        self.data = data
        try:
            predicted_price = self.predict_next_movement(data, timeframe)
        except ValueError as err:
            # Still show the chart; just without a prediction marker.
            self.prediction_data = None
            chart = self.plot_candlestick_chart(data, symbol, timeframe)
            return f"Insufficient data for prediction: {err}", chart
        chart = self.plot_candlestick_chart(data, symbol, timeframe)
        return f"Predicted price: ${predicted_price:.2f}", chart
def main():
    """Create the app and launch its Gradio interface."""
    app = CandlestickApp()
    # The gr.inputs / gr.outputs namespaces were removed in Gradio 4 and never
    # provided a DatePicker, so use the top-level components and plain text
    # boxes for dates (yfinance accepts "YYYY-MM-DD" strings).
    iface = gr.Interface(
        fn=app.inference,
        inputs=[
            gr.Textbox(lines=1, placeholder="Enter Stock Symbol (e.g., AAPL)", label="Stock Symbol"),
            gr.Dropdown(choices=["1m", "5m", "15m", "30m", "1h", "1d"], value="1d", label="Timeframe"),
            gr.Textbox(lines=1, placeholder="YYYY-MM-DD", label="Start Date"),
            gr.Textbox(lines=1, placeholder="YYYY-MM-DD", label="End Date"),
        ],
        outputs=[
            gr.Textbox(label="Prediction"),
            gr.Image(label="Candlestick Chart"),
        ],
        title="Stock Market Prediction & Analysis (Enhanced)",
        description="Enter a stock symbol, timeframe, and date range to get a prediction and candlestick chart analysis.",
    )
    iface.launch()
# Launch the UI only when run as a script, not on import.
if __name__ == "__main__":
    main()