import gradio as gr
import pandas as pd 
import numpy as np
from time import time
from sklearn import metrics
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from datasets import load_dataset
import matplotlib.pyplot as plt


# https://scikit-learn.org/stable/auto_examples/cluster/plot_kmeans_digits.html#sphx-glr-auto-examples-cluster-plot-kmeans-digits-py

def display_plot(data, n_digits):
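    """Fit KMeans on the PCA-reduced data and plot the decision regions with the cluster centroids."""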
    reduced_data = PCA(n_components=2).fit_transform(data)
    kmeans = KMeans(init="k-means++", n_clusters=n_digits, n_init=4)
    kmeans.fit(reduced_data)

    # Step size of the mesh. Decrease to increase the quality of the VQ.
    h = 0.02  # point in the mesh [x_min, x_max]x[y_min, y_max].

    # Plot the decision boundary. For that, we will assign a color to each point in the mesh.
    x_min, x_max = reduced_data[:, 0].min() - 1, reduced_data[:, 0].max() + 1
    y_min, y_max = reduced_data[:, 1].min() - 1, reduced_data[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))

    # Obtain labels for each point in mesh. Use last trained model.
    Z = kmeans.predict(np.c_[xx.ravel(), yy.ravel()])

    # Put the result into a color plot
    Z = Z.reshape(xx.shape)
    
    # Create the figure that will be returned to Gradio and rendered by gr.Plot.
    fig = plt.figure()
    plt.imshow(
        Z,
        interpolation="nearest",
        extent=(xx.min(), xx.max(), yy.min(), yy.max()),
        cmap=plt.cm.Paired,
        aspect="auto",
        origin="lower",
    )

    plt.plot(reduced_data[:, 0], reduced_data[:, 1], "k.", markersize=2)
    # Plot the centroids as a white X
    centroids = kmeans.cluster_centers_
    plt.scatter(
        centroids[:, 0],
        centroids[:, 1],
        marker="x",
        s=169,
        linewidths=3,
        color="w",
        zorder=10,
    )
    plt.title(
        "K-means clustering on the digits dataset (PCA-reduced data)\n"
        "Centroids are marked with white cross"
    )
    plt.xlim(x_min, x_max)
    plt.ylim(y_min, y_max)
    plt.xticks(())
    plt.yticks(())
    return fig

def bench_k_means(kmeans, name, data, labels):
    """Benchmark to evaluate the KMeans initialization methods.

    Parameters
    ----------
    kmeans : KMeans instance
        A :class:`~sklearn.cluster.KMeans` instance with the initialization
        already set.
    name : str
        Name given to the strategy. It will be used to show the results in a
        table.
    data : ndarray of shape (n_samples, n_features)
        The data to cluster.
    labels : ndarray of shape (n_samples,)
        The ground-truth labels used to compute the clustering metrics that
        require some supervision.
    """
    t0 = time()
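    # Fit the full scaling + clustering pipeline; fit_time covers both steps.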
    estimator = make_pipeline(StandardScaler(), kmeans).fit(data)
    fit_time = time() - t0
    results = [name, fit_time, estimator[-1].inertia_]

    # Define the metrics which require only the true labels and estimator
    # labels
    clustering_metrics = [
        metrics.homogeneity_score,
        metrics.completeness_score,
        metrics.v_measure_score,
        metrics.adjusted_rand_score,
        metrics.adjusted_mutual_info_score,
    ]
    results += [m(labels, estimator[-1].labels_) for m in clustering_metrics]

    # The silhouette score requires the full dataset
    results += [
        metrics.silhouette_score(
            data,
            estimator[-1].labels_,
            metric="euclidean",
            sample_size=300,
        )
    ]

    return results

title = "A demo of K-Means clustering on the handwritten digits data"
def do_submit(kmeans_n_digit, random_n_digit, pca_n_digit):
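    """Benchmark the three KMeans initialization strategies and return the decision-boundary plot and a metrics table."""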
    # Load the dataset
    dataset = load_dataset("sklearn-docs/digits", header=None)
    # Convert the dataset split to a pandas DataFrame
    df = dataset['train'].to_pandas()
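    # Columns 0-63 hold the 8x8 pixel intensities; column 64 holds the digit label.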
    data = df.iloc[:, :64]
    labels = df.iloc[:, 64]

    # Benchmark 1: k-means++ initialization (smart seeding of the centroids).
    kmeans = KMeans(init="k-means++", n_clusters=int(kmeans_n_digit), n_init=4, random_state=0)
    results = bench_k_means(kmeans=kmeans, name="k-means++", data=data, labels=labels)
    
    # Collect the results in a DataFrame, one row per initialization strategy.
    df = pd.DataFrame(results).T
    numeric_cols = ['time', 'inertia', 'homo', 'compl', 'v-meas', 'ARI', 'AMI', 'silhouette']
    df.columns = ['init'] + numeric_cols

    # Benchmark 2: random initialization of the centroids.
    kmeans = KMeans(init="random", n_clusters=int(random_n_digit), n_init=4, random_state=0)
    results = bench_k_means(kmeans=kmeans, name="random", data=data, labels=labels)
    df.loc[len(df.index)] = results
    
    # Benchmark 3: PCA-based initialization; the principal components serve as the
    # initial centroids, which is deterministic, so a single run (n_init=1) suffices.
    pca = PCA(n_components=int(pca_n_digit)).fit(data)
    kmeans = KMeans(init=pca.components_, n_clusters=int(pca_n_digit), n_init=1)
    results = bench_k_means(kmeans=kmeans, name="PCA-based", data=data, labels=labels)
    df.loc[len(df.index)] = results
    df[df.columns[1:]] = df.iloc[:,1:].astype(float).round(3)
    
    df = df.T  # Transpose so metrics appear as rows and strategies as columns
    df.columns = df.iloc[0, :].tolist()
    df = df.iloc[1:, :].reset_index()
    df.columns = ['metrics', 'k-means++', 'random', 'PCA-based']
    return display_plot(data, int(kmeans_n_digit)), df

# Theme from https://huggingface.co/spaces/trl-lib/stack-llama/blob/main/app.py
theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    radius_size=gr.themes.sizes.radius_sm,
    font=[gr.themes.GoogleFont("Open Sans"), "ui-sans-serif", "system-ui", "sans-serif"],
)

with gr.Blocks(title=title, theme=theme) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown("This demo is based on this [scikit-learn example](https://scikit-learn.org/stable/auto_examples/cluster/plot_kmeans_digits.html#sphx-glr-auto-examples-cluster-plot-kmeans-digits-py)")
    gr.Markdown("In this example we compare the various initialization strategies for K-means in terms of runtime and quality of the results.")
    gr.Markdown("As the ground truth is known here, we also apply different cluster quality metrics to judge the goodness of fit of the cluster labels to the ground truth.")
    gr.Markdown("Cluster quality metrics evaluated (see [Clustering performance evaluation](https://scikit-learn.org/stable/modules/clustering.html#clustering-evaluation) \
                for definitions and discussions of the metrics):")
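    gr.Markdown("homogeneity, completeness, V-measure, adjusted Rand index (ARI), adjusted mutual information (AMI), and the silhouette coefficient, together with the fit time and the final inertia.")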
    gr.Markdown("---")
    gr.Markdown("We use the [digits](https://huggingface.co/datasets/sklearn-docs/digits) dataset, which contains images of handwritten digits from 0 to 9. \
        In the context of clustering, we would like to group the images so that all images in a cluster show the same handwritten digit.")


    with gr.Row():
        with gr.Column(scale=1):
            kmeans_n_digit = gr.Slider(minimum=2, maximum=10, label="KMeans n_digits", info="Number of clusters (one per digit class)", step=1, value=10)
            random_n_digit = gr.Slider(minimum=2, maximum=10, label="Random n_digits", step=1, value=10)
            pca_n_digit = gr.Slider(minimum=2, maximum=10, label="PCA n_digits", step=1, value=10)
            
            plt_out = gr.Plot()
            
        with gr.Column(scale=1):
            # Placeholder table shown before the first run: one row per metric.
            metric_names = ['time', 'inertia', 'homo', 'compl', 'v-meas', 'ARI', 'AMI', 'silhouette']
            sample_df = pd.DataFrame({'metrics': metric_names, 'k-means++': np.zeros(8), 'random': np.zeros(8), 'PCA-based': np.zeros(8)})
        
            output = gr.Dataframe(sample_df, label="Clustering Metrics")
            
    with gr.Row():
        sub_btn = gr.Button("Submit")
        sub_btn.click(fn=do_submit, inputs=[kmeans_n_digit, random_n_digit, pca_n_digit], outputs=[plt_out, output])

demo.launch()