__all__ = ['block', 'make_clickable_model', 'make_clickable_user', 'get_submissions']

import gradio as gr
import pandas as pd
import re
import os
import json
import yaml
import matplotlib.pyplot as plt
import seaborn as sns
import plotnine as p9
import sys
import zipfile
import tempfile
sys.path.append('./src')
sys.path.append('.')

from src.about import *
from src.saving_utils import *
from src.vis_utils import *
from src.bin.PROBE import run_probe


def add_new_eval(
    human_file,
    skempi_file,
    model_name_textbox: str,
    revision_name_textbox: str,
    benchmark_types,
    similarity_tasks,
    function_prediction_aspect,
    function_prediction_dataset,
    family_prediction_dataset,
    save,
):
    """Validate a submission, run the PROBE benchmarks on it and store the results.

    Args:
        human_file: Human representation file; required when any of the
            'similarity', 'family' or 'function' benchmarks is selected.
        skempi_file: SKEMPI representation file; required when the
            'affinity' benchmark is selected.
        model_name_textbox: Method name entered by the user.
        revision_name_textbox: Revision name; when non-empty it is used as
            the representation name instead of ``model_name_textbox``.
        benchmark_types: Selected benchmark types (``None`` is treated as
            an empty selection).
        similarity_tasks: Forwarded unchanged to ``run_probe``.
        function_prediction_aspect: Forwarded unchanged to ``run_probe``.
        function_prediction_dataset: Forwarded unchanged to ``run_probe``.
        family_prediction_dataset: Forwarded unchanged to ``run_probe``.
        save: When truthy the results are saved normally; otherwise they
            are saved with ``temporary=True``.

    Returns:
        ``0`` on success, ``-1`` when a required file is missing or
        ``run_probe`` raises.
    """
    # A missing selection would otherwise make the `in` checks below raise
    # TypeError instead of producing a user-facing message.
    benchmark_types = benchmark_types or []

    # Validate required files based on selected benchmarks
    needs_human_file = any(
        task in benchmark_types for task in ('similarity', 'family', 'function')
    )
    if needs_human_file and human_file is None:
        gr.Warning("Human representations are required for similarity, family, or function benchmarks!")
        return -1
    if 'affinity' in benchmark_types and skempi_file is None:
        gr.Warning("SKEMPI representations are required for affinity benchmark!")
        return -1

    gr.Info("Your submission is being processed...")

    # Prefer the revision name; fall back to the method name when the
    # revision is empty (or None).
    representation_name = revision_name_textbox or model_name_textbox

    try:
        results = run_probe(
            benchmark_types,
            representation_name,
            human_file,
            skempi_file,
            similarity_tasks,
            function_prediction_aspect,
            function_prediction_dataset,
            family_prediction_dataset,
        )
    except Exception:
        # Top-level UI boundary: keep the friendly message for the user but
        # record the traceback so the failure can actually be diagnosed.
        import logging

        logging.getLogger(__name__).exception(
            "run_probe failed for representation %r", representation_name
        )
        gr.Warning("Your submission has not been processed. Please check your representation files!")
        return -1

    # Even if save is False, store the submission (e.g. temporarily) so that the leaderboard includes it.
    if save:
        save_results(representation_name, benchmark_types, results)
    else:
        save_results(representation_name, benchmark_types, results, temporary=True)

    return 0


def refresh_data():
    """Drop the cached result CSVs, re-download them and return the leaderboard.

    The Refresh button routes this function's return value into the
    leaderboard table, so the freshly loaded leaderboard is returned;
    returning nothing would replace the table with an empty value.

    Returns:
        The result of ``get_baseline_df(None, None)`` after the re-download.
    """
    benchmark_types = ["similarity", "function", "family", "affinity"]

    # The aggregated leaderboard cache is removed as well, but it is not one
    # of the benchmark types passed to download_from_hub.
    for cached_name in (*benchmark_types, "leaderboard"):
        path = f"/tmp/{cached_name}_results.csv"
        try:
            os.remove(path)
        except FileNotFoundError:
            # Nothing cached for this entry; that is fine.
            pass

    download_from_hub(benchmark_types)

    # NOTE(review): assumes get_baseline_df rebuilds/reads the leaderboard
    # from the files just downloaded, as it does elsewhere in this file.
    return get_baseline_df(None, None)


def download_leaderboard_csv():
    """Dump the current leaderboard to a CSV in the temp directory.

    Returns:
        Path of the written file, ``leaderboard_download.csv`` inside the
        system temporary directory. The row index is not written.
    """
    csv_path = os.path.join(tempfile.gettempdir(), "leaderboard_download.csv")
    get_baseline_df(None, None).to_csv(csv_path, index=False)
    return csv_path


def generate_plots_based_on_submission(benchmark_types, similarity_tasks, function_prediction_aspect, function_prediction_dataset, family_prediction_dataset):
    """Create one plot per selected benchmark type and bundle them in a ZIP.

    Args:
        benchmark_types: Benchmark types chosen for the submission
            (``None`` is treated as an empty selection).
        similarity_tasks: Its first two entries become the x/y metrics of
            the 'similarity' plot.
        function_prediction_aspect: Passed as x metric of the 'function' plot.
        function_prediction_dataset: Passed as y metric of the 'function' plot.
        family_prediction_dataset: Its first two entries become the x/y
            metrics of the 'family' plot.

    Returns:
        Path to ``submission_plots.zip`` inside a fresh temporary directory.
        Benchmarks for which no plot file could be obtained are skipped, so
        the archive may contain fewer entries than benchmarks requested.
    """
    tmp_dir = tempfile.mkdtemp()

    # Get the current leaderboard to retrieve available method names.
    leaderboard = get_baseline_df(None, None)
    method_names = leaderboard['Method'].unique().tolist()

    plot_files = []
    for btype in benchmark_types or []:
        # Choose plotting parameters based on the extra submission selections.
        if btype == "similarity":
            x_metric, y_metric = _first_two(similarity_tasks)
        elif btype == "function":
            # NOTE(review): the leaderboard tab's Plot button passes aspect and
            # dataset as the 5th/6th arguments of benchmark_plot, whereas here
            # they travel as x/y metrics - confirm which one it expects.
            x_metric = function_prediction_aspect or None
            y_metric = function_prediction_dataset or None
        elif btype == "family":
            x_metric, y_metric = _first_two(family_prediction_dataset)
        else:
            # 'affinity' and any unknown type use the default plotting.
            x_metric, y_metric = None, None

        plot_img = benchmark_plot(btype, method_names, x_metric, y_metric, None, None, None)

        if isinstance(plot_img, plt.Figure):
            plot_file = os.path.join(tmp_dir, f"{btype}.png")
            plot_img.savefig(plot_file)
            plt.close(plot_img)
        elif isinstance(plot_img, (str, os.PathLike)) and os.path.isfile(plot_img):
            # benchmark_plot already wrote the image; reuse that file.
            plot_file = plot_img
        else:
            # No usable plot (e.g. None): skip it rather than letting
            # ZipFile.write fail and take the whole submission down.
            gr.Warning(f"No plot could be generated for the '{btype}' benchmark.")
            continue
        plot_files.append(plot_file)

    # Zip all plot images
    zip_path = os.path.join(tmp_dir, "submission_plots.zip")
    with zipfile.ZipFile(zip_path, "w") as zipf:
        for plot_file in plot_files:
            zipf.write(plot_file, arcname=os.path.basename(plot_file))
    return zip_path


def _first_two(selection):
    """Return the first two items of *selection* as a pair, padding with None."""
    items = list(selection) if selection else []
    items += [None, None]
    return items[0], items[1]


def submission_callback(
    human_file,
    skempi_file,
    model_name_textbox,
    revision_name_textbox,
    benchmark_types,
    similarity_tasks,
    function_prediction_aspect,
    function_prediction_dataset,
    family_prediction_dataset,
    save_checkbox,
    return_leaderboard,  # Checkbox: if checked, return leaderboard CSV
    return_plots       # Checkbox: if checked, return plot results ZIP
):
    """
    Evaluate a submission and hand back the optional artifacts the user asked for.

    Returns a ``(message, csv_path, zip_path)`` triple. Both paths are ``None``
    when the evaluation fails; otherwise each one is only filled in when its
    checkbox (``return_leaderboard`` / ``return_plots``) is ticked.
    """
    status = add_new_eval(
        human_file,
        skempi_file,
        model_name_textbox,
        revision_name_textbox,
        benchmark_types,
        similarity_tasks,
        function_prediction_aspect,
        function_prediction_dataset,
        family_prediction_dataset,
        save_checkbox,
    )

    # add_new_eval signals any failure with -1.
    if status == -1:
        return "Submission failed. Please check your files and selections.", None, None

    message_parts = ["Submission processed. "]
    leaderboard_csv = None
    plots_zip = None

    if return_leaderboard:
        leaderboard_csv = download_leaderboard_csv()
        message_parts.append("Leaderboard CSV is ready. ")

    if return_plots:
        plots_zip = generate_plots_based_on_submission(
            benchmark_types,
            similarity_tasks,
            function_prediction_aspect,
            function_prediction_dataset,
            family_prediction_dataset,
        )
        message_parts.append("Plot results ZIP is ready.")

    return "".join(message_parts), leaderboard_csv, plots_zip


# --------------------------
# Build the Gradio interface
# --------------------------
# Root container of the app. Everything created inside `with block:` is
# attached to it, in creation order.
block = gr.Blocks()

with block:
    gr.Markdown(LEADERBOARD_INTRODUCTION)

    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("🏅 PROBE Leaderboard", elem_id="probe-benchmark-tab-table", id=1):
            # Leaderboard tab: load the full leaderboard once, at build time,
            # to derive the method/metric choices offered by the selectors.
            leaderboard = get_baseline_df(None, None)
            method_names = leaderboard['Method'].unique().tolist()
            metric_names = leaderboard.columns.tolist()
            # Copy taken before 'Method' is removed; not referenced again in
            # the code visible here.
            metrics_with_method = metric_names.copy()
            metric_names.remove('Method')

            # Group the metric columns by benchmark type using their name prefix.
            benchmark_metric_mapping = {
                "similarity": [metric for metric in metric_names if metric.startswith('sim_')],
                "function": [metric for metric in metric_names if metric.startswith('func')],
                "family": [metric for metric in metric_names if metric.startswith('fam_')],
                "affinity": [metric for metric in metric_names if metric.startswith('aff_')],
            }

            # All methods are pre-selected; benchmark types and metrics start empty.
            leaderboard_method_selector = gr.CheckboxGroup(
                choices=method_names,
                label="Select Methods for the Leaderboard",
                value=method_names,
                interactive=True
            )
            benchmark_type_selector = gr.CheckboxGroup(
                choices=list(benchmark_metric_mapping.keys()),
                label="Select Benchmark Types",
                value=None,
                interactive=True
            )
            leaderboard_metric_selector = gr.CheckboxGroup(
                choices=metric_names,
                label="Select Metrics for the Leaderboard",
                value=None,
                interactive=True
            )

            # Initial table content: every method and metric, with numeric
            # cells rounded to 4 decimals. The change handlers below write
            # get_baseline_df's output to the table directly, without this rounding.
            # NOTE(review): DataFrame.applymap is deprecated in recent pandas
            # releases in favor of DataFrame.map - confirm the pinned version.
            baseline_value = get_baseline_df(method_names, metric_names)
            baseline_value = baseline_value.applymap(lambda x: round(x, 4) if isinstance(x, (int, float)) else x)
            baseline_header = ["Method"] + metric_names
            baseline_datatype = ['markdown'] + ['number'] * len(metric_names)

            with gr.Row(show_progress=True, variant='panel'):
                data_component = gr.components.Dataframe(
                    value=baseline_value,
                    headers=baseline_header,
                    type="pandas",
                    datatype=baseline_datatype,
                    interactive=False,
                    visible=True,
                )

            # Changing the method or metric selection recomputes the table;
            # changing the benchmark types updates the metric selector.
            leaderboard_method_selector.change(
                get_baseline_df,
                inputs=[leaderboard_method_selector, leaderboard_metric_selector],
                outputs=data_component
            )
            benchmark_type_selector.change(
                lambda selected_benchmarks: update_metrics(selected_benchmarks),
                inputs=[benchmark_type_selector],
                outputs=leaderboard_metric_selector
            )
            leaderboard_metric_selector.change(
                get_baseline_df,
                inputs=[leaderboard_method_selector, leaderboard_metric_selector],
                outputs=data_component
            )

            with gr.Row():
                gr.Markdown(
                    """
                    ## **Visualize the Leaderboard Results**
                    Select options to update the visualization.
                    """
                )
            # Plotting section. `benchmark_specific_metrics` is not defined in
            # this file; presumably it comes from one of the star imports.
            benchmark_type_selector_plot = gr.Dropdown(
                choices=list(benchmark_specific_metrics.keys()),
                label="Select Benchmark Type for Plotting",
                value=None
            )
            # All five dropdowns start hidden and empty; they are the outputs
            # of update_metric_choices when a benchmark type is picked.
            with gr.Row():
                x_metric_selector = gr.Dropdown(choices=[], label="Select X-axis Metric", visible=False)
                y_metric_selector = gr.Dropdown(choices=[], label="Select Y-axis Metric", visible=False)
                aspect_type_selector = gr.Dropdown(choices=[], label="Select Aspect Type", visible=False)
                dataset_selector = gr.Dropdown(choices=[], label="Select Dataset", visible=False)
                single_metric_selector = gr.Dropdown(choices=[], label="Select Metric", visible=False)
            method_selector = gr.CheckboxGroup(
                choices=method_names,
                label="Select Methods to Visualize",
                interactive=True,
                value=method_names
            )
            plot_button = gr.Button("Plot")
            with gr.Row(show_progress=True, variant='panel'):
                plot_output = gr.Image(label="Plot")
            benchmark_type_selector_plot.change(
                update_metric_choices,
                inputs=[benchmark_type_selector_plot],
                outputs=[x_metric_selector, y_metric_selector, aspect_type_selector, dataset_selector, single_metric_selector]
            )
            # benchmark_plot receives, in order: benchmark type, methods,
            # x metric, y metric, aspect, dataset, single metric.
            plot_button.click(
                benchmark_plot,
                inputs=[benchmark_type_selector_plot, method_selector, x_metric_selector, y_metric_selector, aspect_type_selector, dataset_selector, single_metric_selector],
                outputs=plot_output
            )

        # NOTE(review): all three tabs share the same elem_id
        # ("probe-benchmark-tab-table") - confirm that this is intended.
        with gr.TabItem("📝 About", elem_id="probe-benchmark-tab-table", id=2):
            with gr.Row():
                gr.Markdown(LLM_BENCHMARKS_TEXT, elem_classes="markdown-text")
            with gr.Row():
                gr.Image(
                    value="./src/data/PROBE_workflow_figure.jpg",
                    label="PROBE Workflow Figure",
                    elem_classes="about-image",
                )

        with gr.TabItem("🚀 Submit here! ", elem_id="probe-benchmark-tab-table", id=3):
            with gr.Row():
                gr.Markdown(EVALUATION_QUEUE_TEXT, elem_classes="markdown-text")
            with gr.Row():
                gr.Markdown("# ✉️✨ Submit your model's representation files here!", elem_classes="markdown-text")
            with gr.Row():
                with gr.Column():
                    # Submission form. The choice lists (TASK_INFO and the
                    # *_options names) are not defined in this file;
                    # presumably they come from the star imports.
                    model_name_textbox = gr.Textbox(label="Method name")
                    revision_name_textbox = gr.Textbox(label="Revision Method Name")
                    benchmark_types = gr.CheckboxGroup(
                        choices=TASK_INFO,
                        label="Benchmark Types",
                        interactive=True,
                    )
                    similarity_tasks = gr.CheckboxGroup(
                        choices=similarity_tasks_options,
                        label="Similarity Tasks (if selected)",
                        interactive=True,
                    )
                    function_prediction_aspect = gr.Radio(
                        choices=function_prediction_aspect_options,
                        label="Function Prediction Aspects (if selected)",
                        interactive=True,
                    )
                    family_prediction_dataset = gr.CheckboxGroup(
                        choices=family_prediction_dataset_options,
                        label="Family Prediction Datasets (if selected)",
                        interactive=True,
                    )
                    # Hidden field with a fixed value; it is what gets passed
                    # as the function prediction dataset on submit.
                    function_dataset = gr.Textbox(
                        label="Function Prediction Datasets",
                        visible=False,
                        value="All_Data_Sets"
                    )
                    save_checkbox = gr.Checkbox(
                        label="Save results for leaderboard and visualization",
                        value=True
                    )
                    # Independent checkboxes selecting which files the
                    # submission returns; both are off by default.
                    return_leaderboard = gr.Checkbox(
                        label="Return Leaderboard CSV",
                        value=False
                    )
                    return_plots = gr.Checkbox(
                        label="Return Plot Results",
                        value=False
                    )
            # type='filepath': the callback receives a path, not file contents.
            with gr.Row():
                human_file = gr.components.File(
                    label="The representation file (csv) for Human dataset",
                    file_count="single",
                    type='filepath'
                )
                skempi_file = gr.components.File(
                    label="The representation file (csv) for SKEMPI dataset",
                    file_count="single",
                    type='filepath'
                )
            submit_button = gr.Button("Submit Eval")
            submission_result_msg = gr.Markdown()
            # Two file outputs: one for CSV, one for Plot ZIP.
            submission_csv_file = gr.File(label="Leaderboard CSV", visible=True)
            submission_plots_file = gr.File(label="Plot Results ZIP", visible=True)
            # The inputs list must follow submission_callback's positional
            # parameter order; its three return values map onto the outputs.
            submit_button.click(
                submission_callback,
                inputs=[
                    human_file,
                    skempi_file,
                    model_name_textbox,
                    revision_name_textbox,
                    benchmark_types,
                    similarity_tasks,
                    function_prediction_aspect,
                    function_dataset,
                    family_prediction_dataset,
                    save_checkbox,
                    return_leaderboard,
                    return_plots,
                ],
                outputs=[submission_result_msg, submission_csv_file, submission_plots_file]
            )

    # Shown below the tabs. Whatever refresh_data returns becomes the new
    # value of the leaderboard table (data_component).
    with gr.Row():
        data_run = gr.Button("Refresh")
        data_run.click(refresh_data, outputs=[data_component])

    with gr.Accordion("Citation", open=False):
        citation_button = gr.Textbox(
            value=CITATION_BUTTON_TEXT,
            label=CITATION_BUTTON_LABEL,
            elem_id="citation-button",
            show_copy_button=True,
        )

# Runs at module level (no __main__ guard), so importing this module also
# starts the app.
block.launch()