import logging
import os
from datetime import datetime, timedelta

import gradio as gr
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import mplfinance as mpf
import numpy as np
import pandas as pd
import yfinance as yf
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.models import Sequential

class CandlestickApp:
    """Fetch OHLCV data, chart it, and predict the next close with an LSTM.

    One instance backs the Gradio interface. The trained Keras model is cached
    on the instance and reused until the requested symbol or timeframe changes.
    """

    # Number of past closes the LSTM sees per sample. Training and prediction
    # must use the same value because the network is built with a fixed
    # input length.
    look_back = 100
    # Leading fraction of the history used to build training windows; the
    # remaining tail is held out (not used for fitting).
    train_fraction = 0.8

    def __init__(self):
        self.current_symbol = None     # symbol the cached model was trained for
        self.current_timeframe = None  # timeframe the cached model was trained for
        self.data = None
        self.prediction_data = None    # {'timestamp', 'price'} of the latest prediction
        self.model = None              # cached Keras model; None until trained
        self.model_path = "models/"    # directory where training checkpoints are saved

    def get_stock_data(self, symbol, timeframe, start_date, end_date):
        """Download price history for *symbol* via yfinance.

        Empty date strings (e.g. blank Gradio textboxes) are passed as None so
        yfinance applies its own default range instead of failing on "".

        Returns:
            The history DataFrame, or None if the download raised or
            returned no rows.
        """
        try:
            ticker = yf.Ticker(symbol)
            data = ticker.history(start=start_date or None,
                                  end=end_date or None,
                                  interval=timeframe)
        except Exception:
            # Fetch errors must not crash the UI, but they should not vanish either.
            logging.getLogger(__name__).exception("Failed to fetch data for %r", symbol)
            return None
        if data is None or data.empty:
            return None
        return data

    def calculate_indicators(self, data):
        """Add RSI(14), SMA20 and SMA50 columns to *data* in place and return it.

        RSI here uses simple rolling means of gains/losses (not Wilder smoothing).
        """
        delta = data['Close'].diff()
        gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
        loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
        rs = gain / loss
        data['RSI'] = 100 - (100 / (1 + rs))

        data['SMA20'] = data['Close'].rolling(window=20).mean()
        data['SMA50'] = data['Close'].rolling(window=50).mean()
        return data

    def plot_candlestick_chart(self, data, symbol, timeframe):
        """Render candles, volume, SMAs and the last prediction to an image.

        Expects *data* to already carry the SMA20/SMA50 columns from
        ``calculate_indicators``.

        Returns:
            A uint8 numpy array of shape (height, width, 3).
        """
        fig = plt.figure(figsize=(12, 6))
        try:
            ax1 = fig.add_subplot(211)
            ax2 = fig.add_subplot(212, sharex=ax1)

            # show_nontrading=True keeps a datetime x-axis, so the SMA lines and
            # the prediction marker can be plotted against real timestamps.
            mpf.plot(data,
                     type='candle',
                     style='charles',
                     ax=ax1,
                     volume=ax2,
                     show_nontrading=True)

            if len(data) >= 20:
                ax1.plot(data.index, data['SMA20'], label='SMA20', color='blue', alpha=0.7)
            if len(data) >= 50:
                ax1.plot(data.index, data['SMA50'], label='SMA50', color='red', alpha=0.7)

            if self.prediction_data is not None:
                ax1.scatter([self.prediction_data['timestamp']],
                            [self.prediction_data['price']],
                            color='purple',
                            marker='*',
                            s=100,
                            label='Prediction')

            # Avoid the "no artists with labels" warning on short series.
            if ax1.get_legend_handles_labels()[0]:
                ax1.legend()
            ax1.set_title(f"{symbol} - {timeframe}")
            ax1.tick_params(axis='x', rotation=45)
            fig.tight_layout()

            canvas = FigureCanvas(fig)
            # The Agg buffer is empty until the figure has been drawn.
            canvas.draw()
            image = np.asarray(canvas.buffer_rgba())[..., :3].copy()
        finally:
            # pyplot keeps every figure alive until closed; release it per request.
            plt.close(fig)
        return image

    @staticmethod
    def _make_windows(series, look_back):
        """Split an (n, 1) array into (window, next value) training pairs.

        Returns:
            X of shape (n - look_back, look_back) and Y of shape
            (n - look_back,); both are empty when n <= look_back.
        """
        n_samples = len(series) - look_back
        if n_samples <= 0:
            return np.empty((0, look_back)), np.empty((0,))
        X = np.array([series[i:i + look_back, 0] for i in range(n_samples)])
        Y = series[look_back:, 0].copy()
        return X, Y

    @staticmethod
    def _next_timestamp(index, timeframe=None):
        """Return the timestamp one bar after ``index[-1]``.

        Uses ``pd.Timedelta(timeframe)`` when a timeframe string such as
        "5m" / "1h" / "1d" is given; otherwise repeats the spacing of the last
        two index entries.
        """
        if timeframe:
            return index[-1] + pd.Timedelta(timeframe)
        if len(index) >= 2:
            return index[-1] + (index[-1] - index[-2])
        return index[-1]

    def predict_next_movement(self, data, timeframe=None):
        """Predict the close one bar after the end of *data*.

        Trains the LSTM on *data* first if no model is cached. Stores the
        result in ``self.prediction_data`` for charting.

        Args:
            data: History with a 'Close' column and a datetime index.
            timeframe: Bar interval string used to place the prediction on the
                time axis; inferred from the index spacing when omitted.

        Returns:
            The predicted closing price as a float.

        Raises:
            ValueError: If there is too little history to train or to build
                a ``look_back``-long input window.
        """
        if self.model is None:
            self.create_lstm_model(data)

        closes = data['Close'].to_numpy().reshape(-1, 1)
        if len(closes) < self.look_back:
            raise ValueError(
                f"need at least {self.look_back} rows to predict, got {len(closes)}")

        scaler = MinMaxScaler(feature_range=(0, 1))
        scaled = scaler.fit_transform(closes)

        # Same window length the network was built with: (1 sample, look_back steps, 1 feature).
        window = scaled[-self.look_back:, 0].reshape(1, self.look_back, 1)
        predicted_scaled = self.model.predict(window)
        predicted_price = float(scaler.inverse_transform(predicted_scaled)[0][0])

        self.prediction_data = {
            'timestamp': self._next_timestamp(data.index, timeframe),
            'price': predicted_price,
        }
        return predicted_price

    def create_lstm_model(self, data):
        """Train a stacked LSTM on the leading part of *data* and cache it.

        The best epoch (by training loss) is also checkpointed to
        ``<model_path>/stock_prediction_model.h5``.

        Raises:
            ValueError: If the training split yields no ``look_back`` windows.
        """
        closes = data['Close'].to_numpy().reshape(-1, 1)
        scaled = MinMaxScaler(feature_range=(0, 1)).fit_transform(closes)
        train_size = int(len(scaled) * self.train_fraction)

        X_train, Y_train = self._make_windows(scaled[:train_size], self.look_back)
        if len(X_train) == 0:
            raise ValueError(
                f"need more than {self.look_back} training rows, got {train_size} "
                f"(from {len(scaled)} total)")
        X_train = X_train.reshape(X_train.shape[0], self.look_back, 1)

        model = Sequential()
        model.add(LSTM(units=256, return_sequences=True, input_shape=(self.look_back, 1)))
        model.add(LSTM(units=128, return_sequences=True))
        model.add(LSTM(units=64))
        model.add(Dense(1))
        model.compile(loss='mean_squared_error', optimizer='adam')

        os.makedirs(self.model_path, exist_ok=True)
        filepath = os.path.join(self.model_path, "stock_prediction_model.h5")
        checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1, save_best_only=True, mode='min')
        early_stop = EarlyStopping(monitor='loss', patience=10, restore_best_weights=True)

        model.fit(X_train, Y_train, epochs=500, batch_size=64, callbacks=[checkpoint, early_stop])

        self.model = model

    def inference(self, symbol, timeframe, start_date, end_date):
        """Gradio callback: fetch data, predict the next close and draw the chart.

        Returns:
            A (message, chart) tuple; chart is an RGB array or None.
        """
        data = self.get_stock_data(symbol, timeframe, start_date, end_date)
        if data is None:
            return "Error fetching data", None

        data = self.calculate_indicators(data)
        if len(data) < 20:
            return "Insufficient data for prediction & chart", None

        # A cached model was fit on one series; retrain when the series changes.
        if (symbol, timeframe) != (self.current_symbol, self.current_timeframe):
            self.model = None
            self.current_symbol, self.current_timeframe = symbol, timeframe
        # Never draw a marker left over from a previous request.
        self.prediction_data = None

        try:
            predicted_price = self.predict_next_movement(data, timeframe)
        except ValueError as err:
            # Too little history for the LSTM; the chart itself is still useful.
            chart = self.plot_candlestick_chart(data, symbol, timeframe)
            return f"Insufficient data for prediction: {err}", chart

        chart = self.plot_candlestick_chart(data, symbol, timeframe)
        return f"Predicted price: ${predicted_price:.2f}", chart

def main():
    """Build the Gradio interface around one CandlestickApp and launch it."""
    app = CandlestickApp()
    iface = gr.Interface(
        fn=app.inference,
        inputs=[
            gr.Textbox(lines=1, placeholder="Enter Stock Symbol (e.g., AAPL)", label="Stock Symbol"),
            gr.Dropdown(["1m", "5m", "15m", "30m", "1h", "1d"], value="1d", label="Timeframe"),
            # The legacy gr.inputs namespace never had a DatePicker; plain
            # YYYY-MM-DD text is accepted by yfinance's start/end arguments.
            gr.Textbox(lines=1, placeholder="YYYY-MM-DD", label="Start Date"),
            gr.Textbox(lines=1, placeholder="YYYY-MM-DD", label="End Date"),
        ],
        outputs=[
            gr.Textbox(label="Prediction"),
            gr.Image(label="Candlestick Chart"),
        ],
        title="Stock Market Prediction & Analysis (Enhanced)",
        description="Enter a stock symbol, timeframe, and date range to get a prediction and candlestick chart analysis.",
    )

    iface.launch()

if __name__ == "__main__":
    # Launch the web UI only when run as a script, not when imported.
    main()