import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB


def choose_model(model):
    """Build a fresh, unfitted classifier for a model name shown in the UI.

    Args:
        model: One of "Logistic Regression", "Random Forest" or
            "Gaussian Naive Bayes".

    Returns:
        A new scikit-learn estimator instance. Estimators that accept a
        ``random_state`` get a fixed seed (123) so the demo is reproducible.

    Raises:
        ValueError: If ``model`` matches none of the supported names.
    """
    # Ordered (name, constructor) pairs. The constructors are lambdas so that
    # an estimator is only instantiated for the name that actually matches.
    builders = (
        (
            "Logistic Regression",
            lambda: LogisticRegression(max_iter=1000, random_state=123),
        ),
        (
            "Random Forest",
            lambda: RandomForestClassifier(n_estimators=100, random_state=123),
        ),
        ("Gaussian Naive Bayes", lambda: GaussianNB()),
    )
    for name, build in builders:
        if model == name:
            return build()
    raise ValueError("Model is not supported.")


def get_proba_plots(
    model_1, model_2, model_3, model_1_weight, model_2_weight, model_3_weight
):
    """Plot sample-1 class probabilities for three models and their soft vote.

    Each selected model is fit on a fixed four-sample toy dataset. A
    ``VotingClassifier`` with ``voting="soft"`` combines them using the given
    weights. For the first training sample, the predicted probabilities of
    class 1 and class 2 are drawn as grouped bars: one group per individual
    model, followed by one group for the ensemble.

    Args:
        model_1, model_2, model_3: Model names understood by ``choose_model``.
        model_1_weight, model_2_weight, model_3_weight: Weights passed to the
            ``VotingClassifier``; they must not sum to zero.

    Returns:
        The matplotlib ``Figure`` holding the bar chart. It has been closed
        in pyplot (no longer tracked by pyplot's figure manager) but can
        still be saved/rendered by the caller.

    Raises:
        ValueError: If the three weights sum to zero, or (from
            ``choose_model``) if a model name is not supported.
    """
    weights = [model_1_weight, model_2_weight, model_3_weight]
    # The weight sliders start at 0, so an all-zero vector is reachable from
    # the UI. Soft voting takes a weighted average, which is undefined when
    # the weights sum to zero. Fail early with a readable message instead of
    # numpy's error from inside predict_proba.
    if sum(weights) == 0:
        raise ValueError(
            "The model weights sum to zero; set at least one weight above zero."
        )

    clf1 = choose_model(model_1)
    clf2 = choose_model(model_2)
    clf3 = choose_model(model_3)
    X = np.array([[-1.0, -1.0], [-1.2, -1.4], [-3.4, -2.2], [1.1, 1.2]])
    y = np.array([1, 1, 2, 2])

    eclf = VotingClassifier(
        estimators=[("clf1", clf1), ("clf2", clf2), ("clf3", clf3)],
        voting="soft",
        weights=weights,
    )

    # Predict class probabilities with each individual model and the ensemble.
    probas = [c.fit(X, y).predict_proba(X) for c in (clf1, clf2, clf3, eclf)]

    # Probabilities of class 1 (column 0) and class 2 (column 1) for the first
    # sample. The last entry of each list belongs to the VotingClassifier.
    class1_1 = [pr[0, 0] for pr in probas]
    class2_1 = [pr[0, 1] for pr in probas]

    # plotting
    N = 4  # number of groups: three individual models + the ensemble
    ind = np.arange(N)  # group positions
    width = 0.35  # bar width

    fig, ax = plt.subplots()

    # Bars for classifiers 1-3; the ensemble slot is left at zero height.
    p1 = ax.bar(ind, class1_1[:-1] + [0], width, color="green", edgecolor="k")
    p2 = ax.bar(
        ind + width,
        class2_1[:-1] + [0],
        width,
        color="lightgreen",
        edgecolor="k",
    )

    # Bars for the VotingClassifier; only the last slot is non-zero.
    ax.bar(ind, [0, 0, 0, class1_1[-1]], width, color="blue", edgecolor="k")
    ax.bar(
        ind + width, [0, 0, 0, class2_1[-1]], width, color="steelblue", edgecolor="k"
    )

    # Plot annotations. Everything goes through `ax`/`fig` rather than
    # pyplot's implicit "current axes". That state is process-global and could
    # be switched by another callback drawing at the same time.
    ax.axvline(2.8, color="k", linestyle="dashed")  # separates the ensemble group
    ax.set_xticks(ind + width)
    ax.set_xticklabels(
        [
            f"{model_1}\nweight {model_1_weight}",
            f"{model_2}\nweight {model_2_weight}",
            f"{model_3}\nweight {model_3_weight}",
            "VotingClassifier\n(average probabilities)",
        ],
        rotation=40,
        ha="right",
    )
    ax.set_ylim(0, 1)
    ax.set_title("Class probabilities for sample 1 by different classifiers")
    ax.legend([p1[0], p2[0]], ["class 1", "class 2"], loc="upper left")
    fig.tight_layout()

    # The caller renders the returned Figure, so plt.show() is not needed.
    # On non-GUI backends it only warns; on GUI backends it blocks the callback.
    # Closing drops pyplot's reference so that one figure per UI event does
    # not accumulate for the lifetime of the process.
    plt.close(fig)
    return fig


# Model names offered in every dropdown; they must match the names handled
# by choose_model().
MODEL_CHOICES = ["Logistic Regression", "Random Forest", "Gaussian Naive Bayes"]


def _model_selector(index, default_model, default_weight):
    """Create one row holding a model dropdown and its voting-weight slider.

    Must be called inside an open ``gr.Blocks`` layout (here: the left-hand
    column of the demo).

    Args:
        index: 1-based model number used in the component labels.
        default_model: Initially selected entry of ``MODEL_CHOICES``.
        default_weight: Initial slider value (slider range 0-10, step 1).

    Returns:
        The ``(gr.Dropdown, gr.Slider)`` pair, in creation order.
    """
    with gr.Row():
        dropdown = gr.Dropdown(
            list(MODEL_CHOICES),  # fresh copy so components never share a list
            label=f"Model {index}",
            value=default_model,
        )
        slider = gr.Slider(
            value=default_weight,
            label=f"Model {index} Weight",
            minimum=0,
            maximum=10,
            step=1,
        )
    return dropdown, slider


with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Class probabilities by the `VotingClassifier`

        This space shows the effect of the weight of different classifiers when using sklearn's [VotingClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.VotingClassifier.html#sklearn.ensemble.VotingClassifier).

        For example, suppose you set the weights as in the table below, and the models have the following predicted probabilities:
        
        |         | Weights | Predicted Probabilities |
        |---------|:-------:|:----------------:|
        | Model 1 |    1    |        0.5       |
        | Model 2 |    2    |        0.8       |
        | Model 3 |    5    |        0.9       |

        The predicted probability by the `VotingClassifier` will be $(1*0.5 + 2*0.8 + 5*0.9) / (1 + 2 + 5)$

        You can experiment with different model types and weights and see their effect on the VotingClassifier's prediction.

        This space is based on [sklearn’s original demo](https://scikit-learn.org/stable/auto_examples/ensemble/plot_voting_probas.html#sphx-glr-auto-examples-ensemble-plot-voting-probas-py).
        """
    )
    with gr.Row():
        with gr.Column(scale=3):
            model_1, model_1_weight = _model_selector(1, "Logistic Regression", 1)
            model_2, model_2_weight = _model_selector(2, "Random Forest", 1)
            model_3, model_3_weight = _model_selector(3, "Gaussian Naive Bayes", 5)
        with gr.Column(scale=4):
            proba_plots = gr.Plot()

    # Every control re-renders the plot from the full set of current values,
    # and the same callback draws the initial plot when the page loads.
    # Registration order matches the order of the controls in `inputs`.
    inputs = [model_1, model_2, model_3, model_1_weight, model_2_weight, model_3_weight]
    for component in inputs:
        component.change(get_proba_plots, inputs, proba_plots, queue=False)
    demo.load(get_proba_plots, inputs, proba_plots, queue=False)

if __name__ == "__main__":
    demo.launch()