import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import make_moons, make_circles, make_classification
from sklearn.neural_network import MLPClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.gaussian_process.kernels import RBF
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.inspection import DecisionBoundaryDisplay
import gradio as gr
from functools import partial



### DATASETS
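# Three toy 2-D binary classification datasets from scikit-learn's
# classifier-comparison example: moons, concentric circles, and a jittered
# linearly separable dataset built with make_classification.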

def normalize(X):
    return StandardScaler().fit_transform(X)


def linearly_separable():
    # Two informative features plus a little uniform jitter, so the classes
    # stay (mostly) linearly separable.
    X, y = make_classification(
        n_features=2, n_redundant=0, n_informative=2, random_state=1, n_clusters_per_class=1
    )
    rng = np.random.RandomState(2)
    X += 2 * rng.uniform(size=X.shape)
    return X, y

# Radio-button label -> (X, y) tuple; datasets are generated once at import
# time and reused across callbacks.
DATA_MAPPING = {
    "Moons": make_moons(noise=0.3, random_state=0),
    "Circles": make_circles(noise=0.2, factor=0.5, random_state=1),
    "Linearly Separable Random Dataset": linearly_separable(),
}


#### MODELS

def get_groundtruth_model(X, labels):
    # Dummy "model" used as a placeholder so "Ground Truth" can live in the
    # same mapping as the real classifiers; it only carries the true labels.
    class Dummy:
        def __init__(self, labels):
            self.labels_ = labels

    return Dummy(labels)
    
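# Display name -> unfitted estimator (hyperparameters follow scikit-learn's
# classifier-comparison example); each instance is re-fit on every callback.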
NAME_CLF_MAPPING = {
    "Ground Truth": get_groundtruth_model,
    "Nearest Neighbors": KNeighborsClassifier(3),
    "Linear SVM": SVC(kernel="linear", C=0.025),
    "RBF SVM": SVC(gamma=2, C=1),
    "Gaussian Process": GaussianProcessClassifier(1.0 * RBF(1.0)),
    "Decision Tree": DecisionTreeClassifier(max_depth=5),
    "Random Forest": RandomForestClassifier(max_depth=5, n_estimators=10, max_features=1),
    "Neural Net": MLPClassifier(alpha=1, max_iter=1000),
    "AdaBoost": AdaBoostClassifier(),
    "Naive Bayes": GaussianNB(),
}



#### PLOT
FIGSIZE = (7, 7)




def train_models(selected_data, clf_name):
    cm = plt.cm.RdBu
    cm_bright = ListedColormap(["#FF0000", "#0000FF"])
    clf = NAME_CLF_MAPPING[clf_name]
    
    X, y = DATA_MAPPING[selected_data]
    X = normalize(X)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.4, random_state=42
    )
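    # 60/40 train/test split with a fixed seed, so every panel sees the same partition.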
    
    x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
    y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
    if clf_name != "Ground Truth":
        clf.fit(X_train, y_train)
        score = clf.score(X_test, y_test)
        fig, ax = plt.subplots(figsize=FIGSIZE)
        ax.set_title(clf_name, fontsize=10)
        # Shade the decision regions, then overlay the train and test points.
        DecisionBoundaryDisplay.from_estimator(
            clf, X, cmap=cm, alpha=0.8, ax=ax, eps=0.5
        )
        ax.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap=cm_bright, edgecolors="k")
        ax.scatter(
            X_test[:, 0], X_test[:, 1], c=y_test, cmap=cm_bright, alpha=0.6, edgecolors="k"
        )
        # Show the held-out accuracy in the lower-right corner.
        ax.text(x_max - 0.3, y_min + 0.3, f"{score:.2f}".lstrip("0"), size=15, horizontalalignment="right")
        return fig
    else:
        # "Ground Truth" has nothing to fit: just plot the selected dataset's
        # train/test points with their true labels.
        fig, ax = plt.subplots(figsize=FIGSIZE)
        ax.set_title("Input data", fontsize=10)
        # Plot the training points
        ax.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap=cm_bright, edgecolors="k")
        # Plot the testing points
        ax.scatter(
            X_test[:, 0], X_test[:, 1], c=y_test, cmap=cm_bright, alpha=0.6, edgecolors="k"
        )
        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)
        ax.set_xticks(())
        ax.set_yticks(())
        return fig


description = "Learn how different statistical classifiers perform on different datasets."

def iter_grid(n_rows, n_cols):
    # Lay out an n_rows x n_cols grid of Gradio Rows/Columns, yielding once per cell
    for _ in range(n_rows):
        with gr.Row():
            for _ in range(n_cols):
                with gr.Column():
                    yield

title = "Compare Classifiers!"
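# One radio selector on top; each classifier then gets its own plot panel in
# a 2 x 5 grid underneath.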
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description)

    input_models = list(NAME_CLF_MAPPING)
    input_data = gr.Radio(
        choices=["Moons", "Circles", "Linearly Separable Random Dataset"],
        value="Moons",
        label="Dataset",
    )
    counter = 0


    for _ in iter_grid(2, 5):
        if counter >= len(input_models):
            break

        input_model = input_models[counter]
        plot = gr.Plot(label=input_model)
        fn = partial(train_models, clf_name=input_model)
        input_data.change(fn=fn, inputs=[input_data], outputs=plot)
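        # Also render this panel once on initial page load (via the Blocks
        # `load` event), so the grid isn't empty before the first change.
        demo.load(fn=fn, inputs=[input_data], outputs=plot)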
        counter += 1

if __name__ == "__main__":
    demo.launch(debug=True)