"""Gradio demo for different clustering techniques

Derived from https://scikit-learn.org/stable/auto_examples/cluster/plot_cluster_comparison.html

"""

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from sklearn.cluster import (
    AgglomerativeClustering, Birch, DBSCAN, KMeans, MeanShift, OPTICS, SpectralClustering, estimate_bandwidth
)
from sklearn.datasets import make_blobs, make_circles, make_moons
from sklearn.mixture import GaussianMixture
from sklearn.neighbors import kneighbors_graph
from sklearn.preprocessing import StandardScaler


# matplotlib 3.6 renamed its bundled seaborn style to 'seaborn-v0_8', and later
# releases (3.8+) dropped the old 'seaborn' name. Prefer the new name when it
# exists and fall back to the old one on older matplotlib versions.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')


SEED = 0  # seed shared by the dataset generators and the seeded estimators
N_CLUSTERS = 4  # cluster count requested from estimators that take one
N_SAMPLES = 1000  # number of points in each generated dataset
# Seed NumPy's global RNG for any code that draws from it without an explicit
# random_state.
np.random.seed(SEED)


def normalize(X):
    """Standardise every column of ``X`` to zero mean and unit variance."""
    scaler = StandardScaler()
    return scaler.fit_transform(X)


def get_regular():
    """Return standardised Gaussian blobs (std 0.7) centred at (+-1, +-1) and their labels."""
    corners = [
        [1, 1],
        [1, -1],
        [-1, 1],
        [-1, -1],
    ]
    # Sanity check: one blob per cluster the models are asked to find.
    assert len(corners) == N_CLUSTERS
    points, truth = make_blobs(
        n_samples=N_SAMPLES,
        centers=corners,
        cluster_std=0.7,
        random_state=SEED,
    )
    return normalize(points), truth


def get_circles():
    """Return two standardised, slightly noisy concentric circles and their labels."""
    points, truth = make_circles(
        n_samples=N_SAMPLES,
        noise=0.05,
        factor=0.5,  # inner circle radius relative to the outer one
        random_state=SEED,
    )
    return normalize(points), truth


def get_moons():
    """Return two standardised interleaving half-moons and their labels."""
    points, truth = make_moons(
        n_samples=N_SAMPLES,
        noise=0.05,
        random_state=SEED,
    )
    return normalize(points), truth


def get_noise():
    """Return standardised uniform random 2-D points, all labelled 0.

    A dedicated ``RandomState(SEED)`` is used instead of the global NumPy RNG,
    so every call returns the same points. This matches the other generators,
    which all pass ``random_state=SEED``. The global RNG advances between
    requests, so the "true clusters" plot would otherwise change on every
    click. Labels are integers, like those returned by the other generators.
    """
    rng = np.random.RandomState(SEED)
    X = rng.rand(N_SAMPLES, 2)
    labels = np.zeros(N_SAMPLES, dtype=int)
    return normalize(X), labels


def get_anisotropic():
    """Return N_CLUSTERS sheared (anisotropic) Gaussian blobs and their labels.

    The blobs use a fixed seed (170) rather than ``SEED``. They are stretched
    by a linear map and then standardised with ``normalize``. Every other
    dataset in this module is standardised too, so the models' fixed
    hyper-parameters (e.g. DBSCAN's ``eps``) see data on the same scale.
    """
    X, labels = make_blobs(n_samples=N_SAMPLES, centers=N_CLUSTERS, random_state=170)
    # Linear shear that turns the round blobs into elongated, tilted ones.
    transformation = [[0.6, -0.6], [-0.4, 0.8]]
    X = np.dot(X, transformation)
    return normalize(X), labels


def get_varied():
    """Return standardised blobs with very different spreads, plus their labels."""
    # One blob per entry; make_blobs defaults to 3 centres when none are given.
    spreads = [1.0, 2.5, 0.5]
    points, truth = make_blobs(
        n_samples=N_SAMPLES,
        cluster_std=spreads,
        random_state=SEED,
    )
    return normalize(points), truth


# Dataset name (shown in the UI, in this order) -> zero-argument loader that
# returns (points, true_labels).
DATA_MAPPING = dict(
    regular=get_regular,
    circles=get_circles,
    moons=get_moons,
    noise=get_noise,
    anisotropic=get_anisotropic,
    varied=get_varied,
)

def get_kmeans(X, **kwargs):
    """Fit k-means (k-means++ init, 10 restarts) with N_CLUSTERS centroids.

    Keyword arguments override the defaults through ``set_params``.
    """
    return (
        KMeans(init="k-means++", n_clusters=N_CLUSTERS, n_init=10, random_state=SEED)
        .set_params(**kwargs)
        .fit(X)
    )


def get_dbscan(X, **kwargs):
    """Fit DBSCAN with ``eps=0.3``; keyword arguments override the defaults."""
    estimator = DBSCAN(eps=0.3).set_params(**kwargs)
    return estimator.fit(X)


def get_agglomerative(X, **kwargs):
    """Fit Ward agglomerative clustering constrained by a k-nearest-neighbour graph.

    The neighbour count reuses ``N_CLUSTERS``. Keyword arguments override the
    estimator defaults through ``set_params``.
    """
    knn = kneighbors_graph(X, n_neighbors=N_CLUSTERS, include_self=False)
    # Average the neighbour graph with its transpose so that the connectivity
    # matrix is symmetric.
    symmetric_knn = 0.5 * (knn + knn.T)
    estimator = AgglomerativeClustering(
        n_clusters=N_CLUSTERS,
        linkage="ward",
        connectivity=symmetric_knn,
    )
    estimator.set_params(**kwargs)
    return estimator.fit(X)


def get_meanshift(X, **kwargs):
    """Fit mean shift with a data-estimated bandwidth (quantile 0.3) and binned seeding."""
    estimator = MeanShift(
        bandwidth=estimate_bandwidth(X, quantile=0.3),
        bin_seeding=True,
    )
    return estimator.set_params(**kwargs).fit(X)


def get_spectral(X, **kwargs):
    """Fit spectral clustering on a nearest-neighbour affinity graph.

    ``random_state=SEED`` is passed explicitly. Without it the estimator draws
    from NumPy's global RNG, whose state advances from request to request, so
    the same algorithm/dataset choice could produce different partitions. The
    scikit-learn docs recommend an int ``random_state`` for deterministic
    results. Keyword arguments (including ``random_state``) still override the
    defaults through ``set_params``.
    """
    model = SpectralClustering(
        n_clusters=N_CLUSTERS,
        eigen_solver="arpack",
        affinity="nearest_neighbors",
        random_state=SEED,
    )
    model.set_params(**kwargs)
    return model.fit(X)


def get_optics(X, **kwargs):
    """Fit OPTICS (min_samples=7, xi=0.05, min_cluster_size=10% of the points)."""
    estimator = OPTICS(min_samples=7, xi=0.05, min_cluster_size=0.1)
    return estimator.set_params(**kwargs).fit(X)


def get_birch(X, **kwargs):
    """Fit BIRCH, requesting ``N_CLUSTERS`` clusters in its final step.

    Using ``N_CLUSTERS`` instead of a separately hard-coded count keeps BIRCH
    consistent with KMeans, AgglomerativeClustering, SpectralClustering and
    GaussianMixture. It also matches the "regular" dataset, which has exactly
    ``N_CLUSTERS`` blobs. Keyword arguments override the defaults through
    ``set_params``.
    """
    model = Birch(n_clusters=N_CLUSTERS)
    model.set_params(**kwargs)
    return model.fit(X)


def get_gaussianmixture(X, **kwargs):
    """Fit a full-covariance Gaussian mixture with N_CLUSTERS components."""
    gmm = GaussianMixture(
        n_components=N_CLUSTERS,
        covariance_type="full",
        random_state=SEED,
    )
    return gmm.set_params(**kwargs).fit(X)


# Algorithm name (shown in the UI, in this order) -> function that fits the
# corresponding estimator on X and returns it.
MODEL_MAPPING = dict(
    KMeans=get_kmeans,
    DBSCAN=get_dbscan,
    AgglomerativeClustering=get_agglomerative,
    MeanShift=get_meanshift,
    SpectralClustering=get_spectral,
    OPTICS=get_optics,
    Birch=get_birch,
    GaussianMixture=get_gaussianmixture,
)


def plot_clusters(ax, X, labels):
    """Scatter-plot 2-D points on ``ax``, one colour per cluster label.

    Every distinct value in ``labels`` is drawn. Algorithms that find more
    clusters than ``N_CLUSTERS`` (MeanShift, DBSCAN, OPTICS) are therefore shown
    in full instead of having points silently dropped. The label ``-1``, which
    DBSCAN and OPTICS use for noise, is drawn in light grey.

    Args:
        ax: matplotlib Axes to draw on.
        X: point coordinates; columns 0 and 1 are plotted.
        labels: one cluster label per row of ``X``.

    Returns:
        The same ``ax``.
    """
    labels = np.asarray(labels)
    for label in np.unique(labels):
        idx = labels == label
        if label == -1:
            # Noise is not a cluster; keep it visually distinct from real clusters.
            ax.scatter(X[idx, 0], X[idx, 1], color='lightgray')
        else:
            ax.scatter(X[idx, 0], X[idx, 1])

    # Explicitly hide the grid; grid(None) would toggle it instead.
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])
    return ax


def cluster(clustering_algorithm: str, dataset: str):
    """Fit the chosen algorithm on the chosen dataset and plot the result.

    Args:
        clustering_algorithm: a key of ``MODEL_MAPPING``.
        dataset: a key of ``DATA_MAPPING``.

    Returns:
        A matplotlib ``Figure`` showing the dataset's true labels on the left
        and the algorithm's predicted labels on the right.

    Raises:
        KeyError: if either name is not a known key.
    """
    X, labels = DATA_MAPPING[dataset]()
    model = MODEL_MAPPING[clustering_algorithm](X)
    if hasattr(model, "labels_"):
        y_pred = model.labels_.astype(int)
    else:
        # Estimators without ``labels_`` (e.g. GaussianMixture) assign via predict.
        y_pred = model.predict(X)

    # Build the Figure directly instead of through pyplot. pyplot keeps a
    # reference to every figure it opens until plt.close() is called, so a
    # server that creates one figure per request would leak memory.
    fig = Figure(figsize=(16, 8))
    ax_true, ax_pred = fig.subplots(1, 2)

    plot_clusters(ax_true, X, labels)
    ax_true.set_title("True clusters")

    plot_clusters(ax_pred, X, y_pred)
    ax_pred.set_title(clustering_algorithm)

    return fig


# Web UI: choose an algorithm and a dataset; ``cluster`` returns the comparison figure.
demo = gr.Interface(
    fn=cluster,
    inputs=[
        gr.Radio(
            list(MODEL_MAPPING),
            value="KMeans",
            label="clustering algorithm"
        ),
        gr.Radio(
            list(DATA_MAPPING),
            value="regular",
            label="dataset"
        ),
    ],
    outputs=gr.Plot(),
)

# Start the server only when this file is run as a script, so that importing
# the module (e.g. from tests) does not launch a web server as a side effect.
if __name__ == "__main__":
    demo.launch()