import gradio as gr
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.utils import shuffle
from sklearn.ensemble import StackingRegressor
from sklearn.linear_model import RidgeCV
from skops.hub_utils import download
import joblib
import shutil

# load dataset
def load_ames_housing():
    """Return a shuffled, 600-row slice of the OpenML ``house_prices`` data.

    The dataset is fetched as pandas objects, reduced to a fixed list of 20
    columns, shuffled with ``random_state=0`` and truncated to its first 600
    rows.

    Returns:
        A ``(X, y)`` pair where ``X`` holds the selected columns and ``y`` is
        the natural logarithm of the matching target values.
    """
    # Columns kept from the full feature table.
    selected_columns = [
        "YrSold",
        "HeatingQC",
        "Street",
        "YearRemodAdd",
        "Heating",
        "MasVnrType",
        "BsmtUnfSF",
        "Foundation",
        "MasVnrArea",
        "MSSubClass",
        "ExterQual",
        "Condition2",
        "GarageCars",
        "GarageType",
        "OverallQual",
        "TotalBsmtSF",
        "BsmtFinSF1",
        "HouseStyle",
        "MiscFeature",
        "MoSold",
    ]
    row_limit = 600

    bunch = fetch_openml(name="house_prices", as_frame=True, parser="pandas")

    # Shuffle features and target together so rows stay aligned.
    shuffled_X, shuffled_y = shuffle(
        bunch.data.loc[:, selected_columns], bunch.target, random_state=0
    )

    return shuffled_X.iloc[:row_limit], np.log(shuffled_y.iloc[:row_limit])

def _load_hub_pipeline(repo_id):
    """Download the Hub repo ``repo_id`` and return the object stored in its ``model.pkl``."""
    import os
    import tempfile

    # A uniquely named scratch directory per call: simultaneous requests never
    # share files, and the context manager deletes it even when the download or
    # the unpickling raises (the previous fixed "temp_dir" was left behind on
    # any failure).
    with tempfile.TemporaryDirectory() as scratch_dir:
        # Hand download() a path that does not exist yet, as before.
        dst = os.path.join(scratch_dir, "repo")
        download(repo_id=repo_id, dst=dst)
        # SECURITY: joblib.load unpickles the file, which can execute arbitrary
        # code. The repo id comes from a user-editable textbox, so only use
        # this with repositories you trust.
        return joblib.load(os.path.join(dst, "model.pkl"))


def stacked_model(model1,model2,model3):
    """Compare three downloaded predictors with a stacking ensemble of them.

    Each argument is a Hub repo id whose ``model.pkl`` is downloaded and
    loaded. The three loaded estimators are combined in a
    ``StackingRegressor`` with ``RidgeCV`` as final estimator. Every single
    estimator and the stacked one is then cross-validated on the data from
    ``load_ames_housing`` and drawn as an actual-vs-predicted plot.

    Args:
        model1: Repo id of the first model.
        model2: Repo id of the second model.
        model3: Repo id of the third model.

    Returns:
        The matplotlib figure holding the 2x2 grid of comparison plots.

    Raises:
        Errors from downloading or unpickling a model, and from scikit-learn
        during cross-validation, are propagated unchanged.
    """
    import time
    from sklearn.metrics import PredictionErrorDisplay
    from sklearn.model_selection import cross_validate, cross_val_predict

    X, y = load_ames_housing()

    # Each estimator is named after the last path component of its repo id.
    estimators = [
        (repo_id.split('/')[-1], _load_hub_pipeline(repo_id))
        for repo_id in (model1, model2, model3)
    ]

    stacking_regressor = StackingRegressor(estimators=estimators, final_estimator=RidgeCV())

    # plot and compare the performance of the single models and the stacked model
    # NOTE(review): figures created through pyplot stay registered until they
    # are closed; in a long-running app that may accumulate -- confirm whether
    # the caller closes them.
    fig, axs = plt.subplots(2, 2, figsize=(9, 7))
    axs = np.ravel(axs)

    # Legend label -> scikit-learn scoring name; identical for every estimator.
    scorers = {"R2": "r2", "MAE": "neg_mean_absolute_error"}
    candidates = estimators + [("Stacking Regressor", stacking_regressor)]

    for ax, (name, est) in zip(axs, candidates):
        # Only the scoring run is timed, not the prediction run below.
        start_time = time.time()
        cv_results = cross_validate(
            est, X, y, scoring=list(scorers.values()), n_jobs=-1, verbose=0
        )
        elapsed_time = time.time() - start_time

        y_pred = cross_val_predict(est, X, y, n_jobs=-1, verbose=0)

        # "mean +- std" per metric; abs() turns the negated MAE positive.
        score_labels = {
            metric: (
                f"{np.abs(np.mean(cv_results[f'test_{scoring}'])):.2f} +- "
                f"{np.std(cv_results[f'test_{scoring}']):.2f}"
            )
            for metric, scoring in scorers.items()
        }

        PredictionErrorDisplay.from_predictions(
            y_true=y,
            y_pred=y_pred,
            kind="actual_vs_predicted",
            ax=ax,
            scatter_kwargs={"alpha": 0.2, "color": "tab:blue"},
            line_kwargs={"color": "tab:red"},
        )
        ax.set_title(f"{name}\nEvaluation in {elapsed_time:.2f} seconds")

        # Empty, invisible plots whose only purpose is to put the scores in
        # the legend.
        for metric, summary in score_labels.items():
            ax.plot([], [], " ", label=f"{metric}: {summary}")
        ax.legend(loc="upper left")

    fig.suptitle("Single predictor versus stacked predictors")
    fig.tight_layout()
    fig.subplots_adjust(top=0.9)
    return fig

title = "Combine predictors using stacking"

# (label, default repo id) for the three model textboxes, in display order.
_MODEL_TEXTBOXES = (
    ("Repo id of first model", "haizad/ames-housing-random-forest-predictor"),
    ("Repo id of second model", "haizad/ames-housing-gbdt-predictor"),
    ("Repo id of third model", "haizad/ames-housing-lasso-predictor"),
)

with gr.Blocks(title=title) as demo:
    # Heading followed by a short description of the demo.
    gr.Markdown(f"## {title}")
    gr.Markdown("""
    This app demonstrates combining 3 predictors trained on Ames housing dataset from OpenML using stacking and Ridge estimator as final estimator.  
    Stacking uses a meta-learning algorithm to learn how to combine the predictions from trained models. 
    The OpenML Ames housing dataset is a processed version of the 'Ames Iowa Housing' with 81 features.
    This app is developed based on [scikit-learn example](https://scikit-learn.org/stable/auto_examples/ensemble/plot_stack_predictors.html#sphx-glr-auto-examples-ensemble-plot-stack-predictors-py)
    """)

    # Creation order is unchanged: three textboxes, the plot, then the button.
    model1, model2, model3 = [
        gr.Textbox(label=label, value=default_repo)
        for label, default_repo in _MODEL_TEXTBOXES
    ]
    plot = gr.Plot(label="Comparison of single predictor against stacked predictor")
    stack_btn = gr.Button("Stack")

    # Clicking the button passes the three repo ids to stacked_model and shows
    # the returned figure in the plot component.
    stack_btn.click(fn=stacked_model, inputs=[model1,model2,model3], outputs=[plot])

demo.launch()