# Spaces: Runtime error
import logging
import os
from datetime import datetime, timedelta

import gradio as gr
import matplotlib.pyplot as plt
import mplfinance as mpf
import numpy as np
import pandas as pd
import yfinance as yf
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.models import Sequential


class CandlestickApp:
    """Fetch price history, fit an LSTM on closing prices, and render a chart.

    ``inference`` is the entry point wired to the UI: it downloads history,
    adds indicator columns, predicts the next close and returns a status
    string together with the chart as an RGB array.
    """

    # Upper bound on the number of past closes fed to the LSTM. The value
    # actually used is capped by the amount of training data and stored in
    # ``self.look_back`` so training and prediction always agree.
    LOOK_BACK = 100
    # Share of the (chronological) series used for fitting the model.
    TRAIN_FRACTION = 0.8
    # Minimum number of rows ``inference`` requires before predicting.
    MIN_ROWS = 20
    EPOCHS = 500
    BATCH_SIZE = 64

    _log = logging.getLogger(__name__)

    # (suffix, pandas.Timedelta keyword) pairs for interval strings such as
    # "5m", "1h", "1d" or "1wk". Calendar units like "1mo" are intentionally
    # absent because they have no fixed length.
    _INTERVAL_UNITS = (
        ("wk", "weeks"),
        ("m", "minutes"),
        ("h", "hours"),
        ("d", "days"),
    )

    def __init__(self):
        self.current_symbol = None
        self.current_timeframe = None
        self.data = None
        self.prediction_data = None
        self.model = None
        # Scaler and window length the current model was trained with.
        self.scaler = None
        self.look_back = None
        self.model_path = "models/"  # directory where checkpoints are saved

    def get_stock_data(self, symbol, timeframe, start_date, end_date):
        """Download price history for *symbol* via yfinance.

        Args:
            symbol: Ticker symbol passed to ``yf.Ticker``.
            timeframe: Bar interval passed as ``interval``.
            start_date: Start of the range, passed as ``start``.
            end_date: End of the range, passed as ``end``.

        Returns:
            Whatever ``Ticker.history`` returns, or ``None`` if it raised.
        """
        try:
            ticker = yf.Ticker(symbol)
            return ticker.history(
                start=start_date, end=end_date, interval=timeframe
            )
        except Exception:
            # Boundary with a network library: report the failure as ``None``
            # (callers check for it) but keep the traceback in the log.
            self._log.exception(
                "Failed to fetch %s (interval=%s)", symbol, timeframe
            )
            return None

    def calculate_indicators(self, data):
        """Add ``RSI``, ``SMA20`` and ``SMA50`` columns to *data*.

        RSI is computed from 14-period simple rolling means of gains and
        losses of ``Close``; the SMAs are 20- and 50-period rolling means.
        The frame is modified in place and also returned.
        """
        delta = data["Close"].diff()
        # The first difference is NaN; ``where`` maps it to 0 on both sides.
        gain = delta.where(delta > 0, 0.0).rolling(window=14).mean()
        loss = (-delta).where(delta < 0, 0.0).rolling(window=14).mean()
        rs = gain / loss
        data["RSI"] = 100 - (100 / (1 + rs))
        data["SMA20"] = data["Close"].rolling(window=20).mean()
        data["SMA50"] = data["Close"].rolling(window=50).mean()
        return data

    def plot_candlestick_chart(self, data, symbol, timeframe):
        """Render candles, volume, SMAs and the stored prediction.

        Returns:
            ``numpy.ndarray`` of dtype ``uint8`` with shape
            ``(height, width, 3)`` holding the rendered RGB image.
        """
        fig = plt.figure(figsize=(12, 6))
        try:
            ax1 = fig.add_subplot(211)
            ax2 = fig.add_subplot(212, sharex=ax1)
            mpf.plot(
                data,
                type="candle",
                style="charles",
                ax=ax1,
                volume=ax2,
                show_nontrading=True,
            )

            if "SMA20" in data.columns and len(data) >= 20:
                ax1.plot(data.index, data["SMA20"], label="SMA20",
                         color="blue", alpha=0.7)
            if "SMA50" in data.columns and len(data) >= 50:
                ax1.plot(data.index, data["SMA50"], label="SMA50",
                         color="red", alpha=0.7)

            if self.prediction_data is not None:
                ax1.scatter(
                    self.prediction_data["timestamp"],
                    self.prediction_data["price"],
                    color="purple",
                    marker="*",
                    s=100,
                    label="Prediction",
                )

            # Only build a legend when something labelled was drawn.
            if ax1.get_legend_handles_labels()[0]:
                ax1.legend()
            ax1.set_title(f"{symbol} - {timeframe}")
            ax1.tick_params(axis="x", rotation=45)
            fig.tight_layout()

            canvas = FigureCanvas(fig)
            # The Agg buffer only exists after an explicit draw.
            canvas.draw()
            # Drop alpha and copy so the array outlives the figure.
            return np.asarray(canvas.buffer_rgba())[..., :3].copy()
        finally:
            # Figures created through pyplot stay registered until closed.
            plt.close(fig)

    def predict_next_movement(self, data, timeframe=None):
        """Predict the close of the bar following the last row of *data*.

        Trains a model first when none is cached. The window is scaled with
        the scaler the model was trained with and has the same length as the
        training windows.

        Args:
            data: Frame with a ``Close`` column, indexed by timestamp.
            timeframe: Optional interval string (e.g. ``"1d"``) used to place
                the prediction on the time axis. When omitted or not
                recognised, the spacing of the last two rows is used.

        Returns:
            The predicted price as a ``float``. Also stored, with its
            timestamp, in ``self.prediction_data``.

        Raises:
            ValueError: If *data* has too few closing prices.
        """
        if self.model is None or self.scaler is None or not self.look_back:
            self.create_lstm_model(data)

        closes = self._close_values(data)
        if len(closes) < self.look_back:
            raise ValueError(
                f"Need at least {self.look_back} closing prices to predict, "
                f"got {len(closes)}"
            )

        window = self.scaler.transform(closes[-self.look_back:])
        batch = window.reshape(1, self.look_back, 1)
        scaled_prediction = self.model.predict(batch)
        predicted_price = float(
            self.scaler.inverse_transform(scaled_prediction)[0][0]
        )

        step = self._interval_to_timedelta(timeframe, data.index)
        self.prediction_data = {
            "timestamp": data.index[-1] + step,
            "price": predicted_price,
        }
        return predicted_price

    def create_lstm_model(self, data):
        """Fit a three-layer LSTM on the scaled ``Close`` series of *data*.

        The first ``TRAIN_FRACTION`` of the series is used for fitting. The
        window length is ``LOOK_BACK`` capped at half the training split so
        short histories still yield training samples. On success the model,
        its scaler and the window length are stored on the instance.

        Raises:
            ValueError: If there are too few closing prices to build even one
                training sample.
        """
        closes = self._close_values(data)
        train_size = int(len(closes) * self.TRAIN_FRACTION)
        look_back = max(1, min(self.LOOK_BACK, train_size // 2))
        if train_size <= look_back:
            raise ValueError("Not enough closing prices to train the model")

        scaler = MinMaxScaler(feature_range=(0, 1))
        scaled = scaler.fit_transform(closes)
        x_train, y_train = self._make_windows(scaled[:train_size], look_back)
        x_train = x_train.reshape(x_train.shape[0], look_back, 1)

        model = Sequential()
        model.add(LSTM(units=256, return_sequences=True,
                       input_shape=(look_back, 1)))
        model.add(LSTM(units=128, return_sequences=True))
        model.add(LSTM(units=64))
        model.add(Dense(1))
        model.compile(loss="mean_squared_error", optimizer="adam")

        os.makedirs(self.model_path, exist_ok=True)
        filepath = os.path.join(self.model_path, "stock_prediction_model.h5")
        # Both callbacks watch the training loss: the checkpoint keeps the
        # best epoch on disk, early stopping ends a stalled run.
        checkpoint = ModelCheckpoint(filepath, monitor="loss", verbose=1,
                                     save_best_only=True, mode="min")
        early_stop = EarlyStopping(monitor="loss", patience=10,
                                   restore_best_weights=True)

        model.fit(x_train, y_train, epochs=self.EPOCHS,
                  batch_size=self.BATCH_SIZE,
                  callbacks=[checkpoint, early_stop])

        self.model = model
        self.scaler = scaler
        self.look_back = look_back

    def inference(self, symbol, timeframe, start_date, end_date):
        """Run fetch, indicators, prediction and chart for one request.

        Returns:
            ``(message, chart)`` where *chart* is an RGB array, or ``None``
            when the request could not be served.
        """
        symbol = (symbol or "").strip()
        if not symbol:
            return "Please enter a stock symbol", None

        # Blank date fields are forwarded as ``None``.
        data = self.get_stock_data(
            symbol, timeframe, start_date or None, end_date or None
        )
        if data is None:
            return "Error fetching data", None

        if "Close" in data.columns:
            data = data.dropna(subset=["Close"])
        if len(data) < self.MIN_ROWS:
            return "Insufficient data for prediction & chart", None
        data = self.calculate_indicators(data)

        # A model fitted on one instrument or interval must not be reused
        # for another; drop it when the request changes.
        if (symbol, timeframe) != (self.current_symbol,
                                   self.current_timeframe):
            self.model = None
            self.scaler = None
            self.look_back = None
            self.current_symbol = symbol
            self.current_timeframe = timeframe

        try:
            predicted_price = self.predict_next_movement(data, timeframe)
        except ValueError:
            return "Insufficient data for prediction & chart", None

        self.data = data
        chart = self.plot_candlestick_chart(data, symbol, timeframe)
        return f"Predicted price: ${predicted_price:.2f}", chart

    @staticmethod
    def _close_values(data):
        """Return the non-NaN ``Close`` prices as a float column vector."""
        return data["Close"].dropna().to_numpy(dtype=float).reshape(-1, 1)

    @staticmethod
    def _make_windows(series, look_back):
        """Split *series* into sliding windows and their next-step targets.

        Returns:
            ``(X, y)`` where ``X`` has shape ``(n, look_back)``, ``y`` has
            shape ``(n,)`` and ``n = len(series) - look_back`` (zero rows
            when the series is too short).
        """
        values = np.asarray(series, dtype=float).reshape(-1)
        count = len(values) - look_back
        if count <= 0:
            return np.empty((0, look_back)), np.empty((0,))
        windows = np.lib.stride_tricks.sliding_window_view(values, look_back)
        # The final window has no following value to predict, so drop it.
        return windows[:count].copy(), values[look_back:].copy()

    @classmethod
    def _interval_to_timedelta(cls, timeframe, index):
        """Translate an interval string such as ``"5m"`` to a Timedelta.

        Falls back to the spacing of the last two entries of *index*, then to
        one day, when *timeframe* is missing or not a fixed-length interval.
        """
        text = str(timeframe).strip().lower() if timeframe else ""
        for suffix, unit in cls._INTERVAL_UNITS:
            amount = text[: -len(suffix)]
            if text.endswith(suffix) and amount.isdigit():
                return pd.Timedelta(**{unit: int(amount)})
        if len(index) >= 2:
            return index[-1] - index[-2]
        return pd.Timedelta(days=1)
def main():
    """Build the Gradio interface around ``CandlestickApp.inference`` and launch it.

    Components are taken from the top-level ``gr`` namespace. Dates are
    entered as plain text and handed to ``inference`` unchanged.
    """
    app = CandlestickApp()
    iface = gr.Interface(
        fn=app.inference,
        inputs=[
            gr.Textbox(
                lines=1,
                placeholder="Enter Stock Symbol (e.g., AAPL)",
                label="Stock Symbol",
            ),
            # A default avoids submitting ``None`` as the interval.
            gr.Dropdown(
                ["1m", "5m", "15m", "30m", "1h", "1d"],
                value="1d",
                label="Timeframe",
            ),
            gr.Textbox(lines=1, placeholder="YYYY-MM-DD", label="Start Date"),
            gr.Textbox(lines=1, placeholder="YYYY-MM-DD", label="End Date"),
        ],
        outputs=[
            gr.Textbox(label="Prediction"),
            gr.Image(label="Candlestick Chart"),
        ],
        title="Stock Market Prediction & Analysis (Enhanced)",
        description="Enter a stock symbol, timeframe, and date range to get a prediction and candlestick chart analysis.",
    )
    iface.launch()


if __name__ == "__main__":
    main()