import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.figure import Figure

from sklearn import metrics
from datasets import load_dataset

import histos

# Integer class id stored in the raw dataset -> label used throughout the UI.
_TARGET_LABELS = {0: "spin-OFF", 1: "spin-ON"}

# NOTE(review): the second positional argument of `load_dataset` is the config
# *name*, not the split; confirm "train" really is a config of this dataset repo.
dataset = load_dataset("cmpatino/optimal_observables", "train")
dataset_df = dataset["train"].to_pandas()
# Replace the numeric class ids with readable labels; ids absent from the
# mapping become NaN.
dataset_df["target"] = dataset_df["target"].map(_TARGET_LABELS)


def get_roc_auc_scores(pos_samples, neg_samples):
    """Return the ROC AUC of one feature used directly as a class score.

    The raw feature value serves as the classifier score for separating the
    spin-ON samples from the spin-OFF samples. A feature may separate the
    classes in either direction, so the class whose samples have the larger
    mean is labelled positive. This usually, but not always, gives an
    AUC >= 0.5, because mean ordering and rank ordering can disagree.

    Args:
        pos_samples: 1-D array-like of feature values for the spin-ON class.
        neg_samples: 1-D array-like of feature values for the spin-OFF class.

    Returns:
        The ROC AUC computed by ``sklearn.metrics.roc_auc_score``.

    Raises:
        ValueError: if either sample set is empty (ROC AUC needs both classes).
    """
    pos_samples = np.asarray(pos_samples)
    neg_samples = np.asarray(neg_samples)
    if pos_samples.size == 0 or neg_samples.size == 0:
        raise ValueError("ROC AUC needs at least one sample from each class.")

    y_score = np.concatenate([pos_samples, neg_samples], axis=0)
    # Treat the higher-mean class as positive so the AUC reflects separation
    # regardless of which direction the feature discriminates in.
    pos_is_higher = bool(pos_samples.mean() >= neg_samples.mean())
    y_true = np.concatenate(
        [
            np.full(pos_samples.shape, 1 if pos_is_higher else 0),
            np.full(neg_samples.shape, 0 if pos_is_higher else 1),
        ],
        axis=0,
    )
    return metrics.roc_auc_score(y_true, y_score)


def _split_by_target(columns):
    """Return the (spin-ON, spin-OFF) subsets of ``dataset_df[columns]``."""
    pos_samples = dataset_df.loc[dataset_df["target"] == "spin-ON", columns]
    neg_samples = dataset_df.loc[dataset_df["target"] == "spin-OFF", columns]
    return pos_samples, neg_samples


def _plot_single_feature(feature, n_bins):
    """Overlay both classes' distributions of one feature via ``histos.ratio_hist``."""
    pos_samples, neg_samples = _split_by_target(feature)
    auc = get_roc_auc_scores(pos_samples, neg_samples)
    labels = ["spin-ON", "spin-OFF"]
    # NOTE(review): if histos.ratio_hist builds its figure through pyplot,
    # that figure is never closed either; consider fixing it inside histos.
    return histos.ratio_hist(
        processes_q=[pos_samples, neg_samples],
        hist_labels=labels,
        reference_label=labels[1],
        n_bins=n_bins,
        hist_range=None,
        title=f"{feature} (ROC AUC: {auc:.3f})",
    )


def _plot_feature_pair(features, n_bins):
    """Draw side-by-side 2D histograms of two features, one panel per class."""
    # Build the figure without pyplot: pyplot keeps a global reference to every
    # figure it creates, so figures would pile up across UI events in a
    # long-running server.
    fig = Figure(figsize=(12, 6))
    axes = fig.subplots(ncols=2)
    pos_samples, neg_samples = _split_by_target(list(features))
    # Common (min, max) per axis over both classes, so both panels share bins.
    ranges = tuple(
        (
            min(pos_samples[name].min(), neg_samples[name].min()),
            max(pos_samples[name].max(), neg_samples[name].max()),
        )
        for name in features
    )
    panels = zip(
        axes,
        (pos_samples, neg_samples),
        ("C0", "C1"),
        ("spin-ON", "spin-OFF"),
    )
    for ax, samples, color, title in panels:
        sns.histplot(
            samples,
            x=features[0],
            y=features[1],
            bins=n_bins,
            ax=ax,
            color=color,
            binrange=ranges,
        )
        ax.set_title(title)
    return fig


def get_plot(features, n_bins):
    """Build the figure shown in the "Plots" tab.

    Args:
        features: Column names selected in the multiselect dropdown. A bare
            string is treated as a single selection; ``None`` as no selection.
        n_bins: Number of histogram bins (per axis for the 2D histograms).

    Returns:
        With one feature, the ``histos.ratio_hist`` figure comparing spin-ON
        and spin-OFF, titled with the feature's ROC AUC. With two features,
        a figure with a spin-ON and a spin-OFF 2D histogram. Otherwise ``None``.
    """
    if isinstance(features, str):
        features = [features]
    if not features:
        return None
    # Slider values may arrive as floats; histogram bin counts must be ints.
    n_bins = int(n_bins)
    if len(features) == 1:
        return _plot_single_feature(features[0], n_bins)
    if len(features) == 2:
        return _plot_feature_pair(features, n_bins)
    return None


with gr.Blocks() as demo:
    with gr.Tab("Plots"):
        with gr.Column():
            with gr.Row():
                features = gr.Dropdown(
                    # "target" holds the string class labels, which cannot be
                    # histogrammed, so it is not offered for plotting.
                    choices=[
                        column for column in dataset_df.columns if column != "target"
                    ],
                    label="Feature",
                    # A multiselect dropdown's value is a list of selections.
                    value=["m_tt"],
                    multiselect=True,
                )
                n_bins = gr.Slider(
                    label="Number of Bins for Histogram",
                    value=10,
                    minimum=10,
                    maximum=100,
                    step=10,
                )

            feature_plot = gr.Plot(label="Feature's Plot")
    with gr.Tab("ROC-AUC Table"):
        # Per-feature ROC AUC, computed once when the module is imported.
        # The label column and "reco_weight" are not scored.
        is_spin_on = dataset_df["target"] == "spin-ON"
        is_spin_off = dataset_df["target"] == "spin-OFF"
        roc_auc_values = [
            [
                feature,
                get_roc_auc_scores(
                    dataset_df.loc[is_spin_on, feature],
                    dataset_df.loc[is_spin_off, feature],
                ),
            ]
            for feature in dataset_df.columns
            if feature not in ("target", "reco_weight")
        ]
        roc_auc_table = gr.Dataframe(
            label="ROC-AUC Table", headers=["Feature", "ROC-AUC"], value=roc_auc_values
        )

    # Redraw the plot whenever either input changes, and once on page load.
    for register in (features.change, n_bins.change, demo.load):
        register(get_plot, [features, n_bins], feature_plot, queue=False)

if __name__ == "__main__":
    demo.launch()