import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from datasets import load_dataset
from evaluate.utils import parse_readme
from scipy.stats import gaussian_kde, spearmanr

import generate_annotated_diffs
from api_wrappers import hf_data_loader
from generation_steps.metrics_analysis import AGGR_METRICS, edit_distance_fn

colors = {
    "Expert-labeled": "#C19C0B",
    "Synthetic Backward": "#913632",
    "Synthetic Forward": "#58136a",
    "Full": "#000000",
}

METRICS = {
    "Edit Distance": "editdist",
    "Edit Similarity": "editsim",
    "BLEU": "bleu",
    "METEOR": "meteor",
    "ROUGE-1": "rouge1",
    "ROUGE-2": "rouge2",
    "ROUGE-L": "rougeL",
    "BERTScore": "bertscore",
    "ChrF": "chrF",
}


df_related = generate_annotated_diffs.data_with_annotated_diffs()


def golden():
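    """Related pairs where G is the initial message and E is the expert-labeled edit."""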
    return df_related.loc[(df_related["G_type"] == "initial") & (df_related["E_type"] == "expert_labeled")].reset_index(
        drop=True
    )


def backward():
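    """Related pairs where G is a synthetic backward message and E is the expert-labeled edit."""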
    return df_related.loc[
        (df_related["G_type"] == "synthetic_backward") & (df_related["E_type"] == "expert_labeled")
    ].reset_index(drop=True)


def forward():
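    """Related pairs where G is the initial message and E is a synthetic forward edit."""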
    return df_related.loc[
        (df_related["G_type"] == "initial") & (df_related["E_type"] == "synthetic_forward")
    ].reset_index(drop=True)


def forward_from_backward():
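    """Related pairs where G is a synthetic backward message and E is a synthetic forward edit built on top of it."""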
    return df_related.loc[
        (df_related["G_type"] == "synthetic_backward")
        & (df_related["E_type"].isin(["synthetic_forward", "synthetic_forward_from_backward"]))
    ].reset_index(drop=True)


n_diffs_manual = len(golden())
n_diffs_synthetic_backward = len(backward())
n_diffs_synthetic_forward = len(forward())
n_diffs_synthetic_forward_backward = len(forward_from_backward())


def update_dataset_view(diff_idx, df):
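    """Return the annotated diff, start and end messages, and commit URL for a 1-based sample index."""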
    diff_idx -= 1
    return (
        df.iloc[diff_idx]["annotated_diff"],
        df.iloc[diff_idx]["commit_msg_start"] if "commit_msg_start" in df.columns else df.iloc[diff_idx]["G_text"],
        df.iloc[diff_idx]["commit_msg_end"] if "commit_msg_end" in df.columns else df.iloc[diff_idx]["E_text"],
        f"https://github.com/{df.iloc[diff_idx]['repo']}/commit/{df.iloc[diff_idx]['hash']}",
    )


def update_dataset_view_manual(diff_idx):
    return update_dataset_view(diff_idx, golden())


def update_dataset_view_synthetic_backward(diff_idx):
    return update_dataset_view(diff_idx, backward())


def update_dataset_view_synthetic_forward(diff_idx):
    return update_dataset_view(diff_idx, forward())


def update_dataset_view_synthetic_forward_backward(diff_idx):
    return update_dataset_view(diff_idx, forward_from_backward())


def number_of_pairs_plot():
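    """Build a stacked bar chart of related vs. conditionally independent pairs for each split."""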
    related_plot_dict = {
        "Full": df_related,
        "Synthetic Backward": backward(),
        "Synthetic Forward": pd.concat([forward(), forward_from_backward()], axis=0, ignore_index=True),
        "Expert-labeled": golden(),
    }

    df_unrelated = hf_data_loader.load_synthetic_as_pandas()
    df_unrelated = df_unrelated.loc[~df_unrelated.is_related].copy()
    unrelated_plot_dict = {
        "Full": df_unrelated,
        "Synthetic Backward": df_unrelated.loc[
            (df_unrelated["G_type"] == "synthetic_backward")
            & (~df_unrelated.E_type.isin(["synthetic_forward", "synthetic_forward_from_backward"]))
        ],
        "Synthetic Forward": df_unrelated.loc[
            ((df_unrelated["G_type"] == "initial") & (df_unrelated["E_type"] == "synthetic_forward"))
            | (
                (df_unrelated["G_type"] == "synthetic_backward")
                & (df_unrelated["E_type"].isin(["synthetic_forward", "synthetic_forward_from_backward"]))
            )
        ],
        "Expert-labeled": df_unrelated.loc[
            (df_unrelated.G_type == "initial") & (df_unrelated.E_type == "expert_labeled")
        ],
    }

    traces = []

    for split in related_plot_dict.keys():
        related_count = len(related_plot_dict[split])
        unrelated_count = len(unrelated_plot_dict[split])

        traces.append(
            go.Bar(
                name=f"{split} - Related pairs",
                x=[split],
                y=[related_count],
                marker=dict(
                    color=colors[split],
                ),
            )
        )

        traces.append(
            go.Bar(
                name=f"{split} - Conditionally independent pairs",
                x=[split],
                y=[unrelated_count],
                marker=dict(
                    color=colors[split],
                    pattern=dict(
                        shape="/",  # Crosses
                        fillmode="overlay",
                        solidity=0.5,
                    ),
                ),
            )
        )

    fig = go.Figure(data=traces)

    fig.update_layout(
        barmode="stack",
        bargap=0.2,
        xaxis=dict(title="Split", showgrid=True, gridcolor="lightgrey"),
        yaxis=dict(title="Number of Examples", showgrid=True, gridcolor="lightgrey"),
        legend=dict(title="Pair Type", orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
        width=1100,
    )
    return fig


def edit_distance_plot():
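    """Build KDE curves of the edit distance between G and E for each split of related pairs."""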
    df_edit_distance = {
        "Full": [edit_distance_fn(pred=row["G_text"], ref=row["E_text"]) for _, row in df_related.iterrows()],
        "Synthetic Backward": [
            edit_distance_fn(pred=row["G_text"], ref=row["E_text"]) for _, row in backward().iterrows()
        ],
        "Synthetic Forward": [
            edit_distance_fn(pred=row["G_text"], ref=row["E_text"])
            for _, row in pd.concat([forward(), forward_from_backward()], axis=0, ignore_index=True).iterrows()
        ],
        "Expert-labeled": [edit_distance_fn(pred=row["G_text"], ref=row["E_text"]) for _, row in golden().iterrows()],
    }
    traces = []

    for key in df_edit_distance:
        kde_x = np.linspace(0, 1200, 1000)
        kde = gaussian_kde(df_edit_distance[key])
        kde_line = go.Scatter(x=kde_x, y=kde(kde_x), mode="lines", name=key, line=dict(color=colors[key], width=5))
        traces.append(kde_line)

    fig = go.Figure(data=traces)

    fig.update_layout(
        bargap=0.1,
        xaxis=dict(title=dict(text="Edit Distance"), range=[0, 1200], showgrid=True, gridcolor="lightgrey"),
        yaxis=dict(
            title=dict(text="Probability Density"),
            range=[0, 0.004],
            showgrid=True,
            gridcolor="lightgrey",
            tickvals=[0.0005, 0.001, 0.0015, 0.002, 0.0025, 0.003, 0.0035, 0.004],
            tickformat=".4f",
        ),
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
        width=1100,
    )
    return fig


def get_correlations_table(online_metric_name: str) -> pd.DataFrame:
    """Spearman correlation of each aggregated offline metric with the given online metric."""
    df = load_dataset(
        "JetBrains-Research/synthetic-commit-msg-edits", "all_pairs_with_metrics_other_online_metrics", split="train"
    ).to_pandas()
    # Collect the conditionally independent (unrelated) pairs for each generated message G,
    # together with its online metric value.
    corr_df = (
        df.loc[~df.is_related]
        .groupby(["G_text", "G_type", "hash", "repo"] + [f"online_{online_metric_name}"])
        .apply(lambda g: g.to_dict(orient="records"), include_groups=False)
        .reset_index(name="unrelated_pairs")
        .copy()
    )
    # Aggregate each offline metric over the unrelated pairs of a group:
    # the minimum for distance-like metrics, the maximum for similarity-like ones.
    aggr_df = corr_df.copy()
    for metric in AGGR_METRICS:
        if metric in ["editdist"]:
            aggr_df[metric] = aggr_df.unrelated_pairs.apply(lambda pairs: min(pair[metric] for pair in pairs))
        else:
            aggr_df[metric] = aggr_df.unrelated_pairs.apply(lambda pairs: max(pair[metric] for pair in pairs))

    results = []

    for metric in AGGR_METRICS:
        x = aggr_df[metric].to_numpy()
        y = aggr_df[f"online_{online_metric_name}"].to_numpy()
        corr, p_value = spearmanr(x, y)
        results.append({"metric": metric, "corr": corr, "p_value": p_value})

    results_df = pd.DataFrame(results)
    results_df["p_value"] = ["< 0.05" if p < 0.05 else p for p in results_df.p_value]
    results_df["corr_abs"] = abs(results_df["corr"])
    results_df["corr"] = results_df["corr"].round(2)
    results_df["metric"] = results_df["metric"].map({v: k for k, v in METRICS.items()})
    return (
        results_df.sort_values(by=["corr_abs"], ascending=False)
        .drop(columns=["corr_abs"])
        .rename(columns={"metric": "Metric m", "corr": "Correlation Q(m, m*)", "p_value": "p-value"})
    )


force_light_theme_js_func = """
function refresh() {
    const url = new URL(window.location);

    if (url.searchParams.get('__theme') !== 'light') {
        url.searchParams.set('__theme', 'light');
        window.location.href = url.href;
    }
}
"""

if __name__ == "__main__":
    with gr.Blocks(theme=gr.themes.Soft(), js=force_light_theme_js_func) as application:
        gr.Markdown(parse_readme("README.md"))

        def dataset_view_tab(n_items):
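            """Create the sample slider and the diff, message, and commit-link views for one exploration tab."""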
            slider = gr.Slider(minimum=1, maximum=n_items, step=1, value=1, label=f"Sample number (total: {n_items})")

            diff_view = gr.HighlightedText(combine_adjacent=True, color_map={"+": "green", "-": "red"})
            start_view = gr.Textbox(interactive=False, label="Initial message G", container=True)
            end_view = gr.Textbox(interactive=False, label="Edited message E", container=True)
            link_view = gr.Markdown()

            view = [diff_view, start_view, end_view, link_view]

            return slider, view

        with gr.Tab("Examples Exploration"):
            with gr.Tab("Manual"):
                slider_manual, view_manual = dataset_view_tab(n_diffs_manual)

                slider_manual.change(update_dataset_view_manual, inputs=slider_manual, outputs=view_manual)

            with gr.Tab("Synthetic Backward"):
                slider_synthetic_backward, view_synthetic_backward = dataset_view_tab(n_diffs_synthetic_backward)

                slider_synthetic_backward.change(
                    update_dataset_view_synthetic_backward,
                    inputs=slider_synthetic_backward,
                    outputs=view_synthetic_backward,
                )

            with gr.Tab("Synthetic Forward (from initial)"):
                slider_synthetic_forward, view_synthetic_forward = dataset_view_tab(n_diffs_synthetic_forward)

                slider_synthetic_forward.change(
                    update_dataset_view_synthetic_forward,
                    inputs=slider_synthetic_forward,
                    outputs=view_synthetic_forward,
                )

            with gr.Tab("Synthetic Forward (from backward)"):
                slider_synthetic_forward_backward, view_synthetic_forward_backward = dataset_view_tab(
                    n_diffs_synthetic_forward_backward
                )

                slider_synthetic_forward_backward.change(
                    update_dataset_view_synthetic_forward_backward,
                    inputs=slider_synthetic_forward_backward,
                    outputs=view_synthetic_forward_backward,
                )

        with gr.Tab("Dataset Statistics"):
            gr.Markdown("## Number of examples per split")

            number_of_pairs_gr_plot = gr.Plot(number_of_pairs_plot, label=None)

            gr.Markdown("## Edit Distance Distribution (w/o PyCharm Logs)")

            edit_distance_gr_plot = gr.Plot(edit_distance_plot(), label=None)

        with gr.Tab("Experimental Results"):
            gr.Markdown(
                "Here, we provide the additional experimental results with different text similarity metrics used as the target online metric, "
                "in addition to edit distance between generated messages G and their edited counterparts E."
            )

            gr.Markdown(
                "Please, select one of the available metrics **m*** below to see the correlations **Q(m, m\*)** of offline text similarity metrics with **m*** as an online metric."
            )

            for metric in METRICS:
                with gr.Tab(metric):
                    gr.Markdown(
                        f"The table below presents the correlation coefficients **Q(m, m\*)** where {metric} is used as an online metric **m***."
                    )

                    result_df = get_correlations_table(METRICS[metric])
                    gr.DataFrame(result_df)

        # Populate the example views on startup so each tab shows its first sample without moving the sliders.
        application.load(update_dataset_view_manual, inputs=slider_manual, outputs=view_manual)

        application.load(
            update_dataset_view_synthetic_backward, inputs=slider_synthetic_backward, outputs=view_synthetic_backward
        )

        application.load(
            update_dataset_view_synthetic_forward, inputs=slider_synthetic_forward, outputs=view_synthetic_forward
        )

        application.load(
            update_dataset_view_synthetic_forward_backward,
            inputs=slider_synthetic_forward_backward,
            outputs=view_synthetic_forward_backward,
        )

    application.launch()