# Original Author: Gael Varoquaux
# Gradio Implementation: Lenix Carter
# License: BSD 3-Clause or CC-0

import gradio as gr
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patheffects as PathEffects

from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics import pairwise_distances

# Seed NumPy's global (legacy) RNG once at import time; anything that later
# draws from np.random.* starts from this fixed state.
np.random.seed(0)
# Select the non-interactive Agg backend: figures are only rendered to images
# for the web UI, never shown in a GUI window.
matplotlib.use('agg')
# Display names, plot colours and cluster count for the three waveform classes.
# The plotting code zips these together index-by-index, so n_clusters should
# equal len(labels) and len(colors) (zip silently truncates otherwise).
labels = ("Waveform 1", "Waveform 2", "Waveform 3")
colors = ["#f7bd01", "#377eb8", "#f781bf"]
n_clusters = 3

def sqr(x):
    """Square wave with the same period and phase as ``cos``.

    Evaluates element-wise: +1 where ``cos(x) > 0``, -1 where ``cos(x) < 0``
    and 0 exactly where ``cos(x) == 0`` (``np.sign`` semantics).
    """
    cosine = np.cos(x)
    return np.sign(cosine)

def _generate_waveforms(n_features, rng):
    """Draw the noisy three-class waveform dataset from ``rng``.

    Returns ``(X, y)``: ``X`` has shape (90, n_features) -- 30 samples for
    each of the 3 classes -- and ``y`` holds the class index of each row.
    """
    t = np.pi * np.linspace(0, 1, n_features)

    X = []
    y = []
    # (phase, amplitude) of the square wave underlying each class.
    for class_idx, (phi, a) in enumerate([(0.5, 0.15), (0.5, 0.6), (0.3, 0.2)]):
        for _ in range(30):
            # The draw order below fixes which random numbers each sample gets.
            phase_noise = 0.01 * rng.normal()
            amplitude_noise = 0.04 * rng.normal()
            additional_noise = 1 - 2 * rng.rand(n_features)
            # Make the noise sparse: only values with |v| >= 0.997 survive,
            # i.e. roughly 0.3% of the uniform [-1, 1) draws.
            additional_noise[np.abs(additional_noise) < 0.997] = 0

            X.append(
                12
                * (
                    (a + amplitude_noise) * sqr(6 * (t + phi + phase_noise))
                    + additional_noise
                )
            )
            y.append(class_idx)

    return np.array(X), np.array(y)


def ground_truth_plot(n_features, random_state=0):
    """Generate the waveform dataset and plot it coloured by true class.

    Parameters
    ----------
    n_features : int
        Number of time samples per waveform. Non-integer numbers (e.g. from a
        UI slider) are truncated with ``int``.
    random_state : int, numpy.random.RandomState or None, default=0
        Source of randomness for the data. With the default seed the same
        ``n_features`` always yields the same dataset, so changing only the
        metric compares metrics on identical data. ``None`` draws from NumPy's
        global RNG instead.

    Returns
    -------
    gt_plot : matplotlib.figure.Figure
        Figure with every waveform drawn in its class colour.
    X : ndarray of shape (90, n_features)
        The waveforms, one per row.
    y : ndarray of shape (90,)
        Class index of each row of ``X``.
    """
    # np.linspace / rand require an integer count and raise on floats.
    n_features = int(n_features)
    if random_state is None:
        rng = np.random
    elif isinstance(random_state, np.random.RandomState):
        rng = random_state
    else:
        rng = np.random.RandomState(random_state)
    X, y = _generate_waveforms(n_features, rng)

    gt_plot, ax = plt.subplots()
    for class_idx, color, name in zip(range(n_clusters), colors, labels):
        # Transpose so each waveform (row of X) becomes one line.
        lines = ax.plot(X[y == class_idx].T, c=color, alpha=0.5)
        # Label a single line per class so the legend has one entry each.
        lines[0].set_label(name)

    gt_plot.subplots_adjust(top=0.8, bottom=0, left=0, right=1.0)
    ax.set_title("Ground Truth", size=20, pad=1)
    ax.legend(loc="best")
    ax.axis("off")

    return gt_plot, X, y

def plot_cluster_waves(metric, X, y):
    """Cluster ``X`` with average-linkage agglomerative clustering and plot it.

    Each waveform is drawn in the colour of the cluster it was assigned to.
    Cluster indices are not aligned with the ground-truth classes, so colours
    need not match the ground-truth plot.

    Parameters
    ----------
    metric : str
        Distance passed to ``AgglomerativeClustering(metric=...)``.
    X : ndarray of shape (n_samples, n_features)
        Waveforms, one per row.
    y : array-like
        Ground-truth labels. Unused; accepted so the plotting helpers share
        the same ``(metric, X, y)`` signature.

    Returns
    -------
    matplotlib.figure.Figure
    """
    model = AgglomerativeClustering(
        n_clusters=n_clusters, linkage="average", metric=metric
    )
    cluster_ids = model.fit_predict(X)

    clust_plot, ax = plt.subplots()
    # Draw on this figure's own axes rather than pyplot's "current" axes.
    for cluster_idx, color in zip(range(n_clusters), colors):
        # Transpose so each waveform (row of X) becomes one line.
        ax.plot(X[cluster_ids == cluster_idx].T, c=color, alpha=0.5)

    clust_plot.subplots_adjust(top=0.75, bottom=0, left=0, right=1.0)
    ax.set_title(f"Agglomerative Clustering\n(metric={metric})", size=20, pad=1.0)
    ax.axis("tight")
    ax.axis("off")
    return clust_plot

def plot_distances(metric, X, y):
    """Plot the normalised mean distance between every pair of classes.

    Cell (i, j) holds the mean ``metric`` distance between samples of class
    ``i`` and samples of class ``j``, divided by the largest such mean (so the
    values lie in [0, 1] for a non-negative metric).

    Parameters
    ----------
    metric : str
        Metric name passed to ``sklearn.metrics.pairwise_distances``.
    X : ndarray of shape (n_samples, n_features)
        Waveforms, one per row.
    y : ndarray of shape (n_samples,)
        Class index of each row of ``X``.

    Returns
    -------
    matplotlib.figure.Figure
    """
    avg_dist = np.array(
        [
            [
                pairwise_distances(X[y == i], X[y == j], metric=metric).mean()
                for j in range(n_clusters)
            ]
            for i in range(n_clusters)
        ]
    )
    avg_dist /= avg_dist.max()

    dist_plot, ax = plt.subplots()
    for i in range(n_clusters):
        for j in range(n_clusters):
            # imshow draws avg_dist[i, j] at x=j (column), y=i (row), so put
            # its numeric label on that same cell.
            txt = ax.text(
                j,
                i,
                "%5.3f" % avg_dist[i, j],
                verticalalignment="center",
                horizontalalignment="center",
            )
            # Translucent white halo keeps the number readable on any colour.
            txt.set_path_effects(
                [PathEffects.withStroke(linewidth=5, foreground="w", alpha=0.5)]
            )

    image = ax.imshow(avg_dist, interpolation="nearest", cmap="cividis", vmin=0)
    ax.set_xticks(range(n_clusters))
    ax.set_xticklabels(labels, rotation=45)
    ax.set_yticks(range(n_clusters))
    ax.set_yticklabels(labels)
    dist_plot.colorbar(image, ax=ax)
    # The axes are deliberately left on: turning them off would hide the
    # class-name tick labels that identify each row and column. The bottom
    # margin leaves room for the rotated x labels.
    dist_plot.subplots_adjust(top=0.8, bottom=0.2)
    ax.set_title("Interclass %s distances" % metric, size=20, pad=1.0)
    return dist_plot

def agg_cluster(n_feats, measure):
    """UI callback: rebuild all three figures for the current inputs.

    Parameters
    ----------
    n_feats : number
        Number of features (time samples) per waveform.
    measure : str
        Metric name used both for clustering and for the distance matrix.

    Returns
    -------
    tuple of matplotlib.figure.Figure
        ``(ground_truth, clustering, distances)`` figures.
    """
    gt_plt, X, y = ground_truth_plot(n_feats)
    cluster_waves_plot = plot_cluster_waves(measure, X, y)
    dist_plot = plot_distances(measure, X, y)
    figures = (gt_plt, cluster_waves_plot, dist_plot)
    # Deregister the new figures from pyplot so they can be garbage-collected
    # once rendered; otherwise every call leaks three figures and matplotlib
    # warns once more than 20 are open. Closed figures can still be saved.
    for fig in figures:
        plt.close(fig)
    return figures

title = "Agglomerative clustering with different metrics"
with gr.Blocks() as demo:
    gr.Markdown(f" # {title}")
    gr.Markdown(
            """
            This example demonstrates the effect of different metrics on hierarchical clustering.

            This is based on the example [here](https://scikit-learn.org/stable/auto_examples/cluster/plot_agglomerative_clustering_metrics.html#sphx-glr-auto-examples-cluster-plot-agglomerative-clustering-metrics-py)
            """
    )
    with gr.Row():
        with gr.Column():
            n_feats = gr.Slider(10, 4000, 2000, label="Number of Features")
            measure = gr.Radio(["cosine", "euclidean", "cityblock"], label="Metric", value="cosine")
        gt_graph = gr.Plot(label="Ground Truth Graph")
    with gr.Row():
        dist_plot = gr.Plot(label="Interclass Distances")
        clust_waves = gr.Plot(label="Agglomerative Clustering")

    # Order must match the tuple returned by agg_cluster.
    controls = [n_feats, measure]
    plots = [gt_graph, clust_waves, dist_plot]
    # Render the initial figures on page load; otherwise the plots stay empty
    # until the user first changes a control.
    demo.load(fn=agg_cluster, inputs=controls, outputs=plots)
    # Recompute everything whenever either control changes.
    for control in controls:
        control.change(fn=agg_cluster, inputs=controls, outputs=plots)

if __name__ == '__main__':
    demo.launch()