import os

import gradio as gr
import pandas as pd
from apscheduler.schedulers.background import BackgroundScheduler
from huggingface_hub import HfApi

from src.about import CITATION_TEXT, INTRODUCTION_TEXT, LLM_BENCHMARKS_TEXT, TITLE
from src.populate import (
    MU_MATH_COLUMNS_DICT,
    U_MATH_COLUMNS_DICT,
    U_MATH_AND_MU_MATH_COLUMNS_DICT,
    Field,
    get_mu_math_leaderboard_df,
    get_u_math_leaderboard_df,
    get_joined_leaderboard_df,
)


def restart_space():
    """Restart the ``toloka/u-math-leaderboard`` Hugging Face Space.

    The ``HfApi`` client is authenticated with the ``HF_TOKEN`` environment
    variable (a read/write token for the organization). When the variable is
    unset, ``None`` is passed as the token.
    """
    hf_api = HfApi(token=os.environ.get("HF_TOKEN"))
    hf_api.restart_space(repo_id="toloka/u-math-leaderboard")


# Load the three leaderboard tables once at import time (U-MATH, μ-MATH, and
# their join); the tabs of the app are built from these module-level frames.
(
    LEADERBOARD_U_MATH_DF,
    LEADERBOARD_MU_MATH_DF,
    LEADERBOARD_U_MATH_MU_MATH_JOINED_DF,
) = (
    get_u_math_leaderboard_df(),
    get_mu_math_leaderboard_df(),
    get_joined_leaderboard_df(),
)


def init_leaderboard(dataframe: pd.DataFrame, columns_dict: dict[str, Field]) -> gr.components.Component:
    """Build the interactive leaderboard UI for one benchmark table.

    Lays out the following controls:

    - a column-visibility selector;
    - one preset button per column tag;
    - a model-name search box;
    - model type/size/family dropdowns.

    It also adds two ``gr.Dataframe`` components. The hidden one always holds the
    full ``dataframe``. The visible one shows the currently selected rows and
    columns. Every callback rebuilds the visible table from the hidden full copy.
    All row filters (search, type, size, family) are applied together, so using
    one control never discards the others.

    Must be called inside an active ``gr.Blocks`` context.

    Args:
        dataframe: Full leaderboard table, indexed by the ``pretty_name`` of the
            entries in ``columns_dict``.
        columns_dict: Column metadata keyed by internal column id. The
            ``"model_family"`` entry is required to build the UI. The
            ``"model_name"``, ``"model_type_symbol"`` and ``"model_size_symbol"``
            entries are read when the corresponding filter is used.

    Returns:
        The ``gr.Column`` that contains the whole leaderboard UI.

    Raises:
        ValueError: If ``dataframe`` is ``None`` or empty.
    """
    if dataframe is None or dataframe.empty:
        raise ValueError("Leaderboard DataFrame is empty or None.")

    # Columns shown regardless of the user's column selection.
    always_here_cols = [c.pretty_name for c in columns_dict.values() if c.never_hidden]
    # Columns the user may toggle in the "See All Columns" selector.
    columns_to_select_visibility = [
        c.pretty_name for c in columns_dict.values() if not c.fully_hidden and not c.never_hidden
    ]
    # A per-column `datatype` list is matched to columns by position. It must
    # therefore be built for the exact columns of the table it is attached to.
    datatype_by_column = {c.pretty_name: c.gradio_column_type for c in columns_dict.values()}

    def datatypes_for(df: pd.DataFrame) -> list:
        """Return the Gradio datatype of each column of ``df``, in column order ("str" if unknown)."""
        return [datatype_by_column.get(col, "str") for col in df.columns]

    def filter_rows(
        full_df: pd.DataFrame, search: str, model_type: str, model_size: str, model_family: str
    ) -> pd.DataFrame:
        """Return the rows of ``full_df`` that satisfy every active row filter.

        ``"All"`` or an empty value disables a dropdown filter, and an empty
        search string matches every row.
        """
        mask = pd.Series(True, index=full_df.index)
        if search:
            # Plain, case-insensitive substring match. User input must not be
            # treated as a regex, otherwise a query such as "(" raises re.error.
            mask &= full_df[columns_dict["model_name"].pretty_name].str.contains(
                search, case=False, na=False, regex=False
            )
        if model_type and model_type != "All":
            # The first character of the dropdown label is the symbol stored in the *_symbol column.
            mask &= full_df[columns_dict["model_type_symbol"].pretty_name] == model_type[0]
        if model_size and model_size != "All":
            mask &= full_df[columns_dict["model_size_symbol"].pretty_name] == model_size[0]
        if model_family and model_family != "All":
            mask &= full_df[columns_dict["model_family"].pretty_name] == model_family
        return full_df[mask]

    def filter_dataframe_by_selected_columns(
        full_df: pd.DataFrame,
        columns: list,
        search: str,
        model_type: str,
        model_size: str,
        model_family: str,
    ) -> pd.DataFrame:
        """Show the always-visible columns plus the selected ``columns``, with all row filters applied."""
        selected_columns = [c for c in (columns or []) if c in full_df.columns and c not in always_here_cols]
        visible = set(always_here_cols) | set(selected_columns)
        rows = filter_rows(full_df, search, model_type, model_size, model_family)
        # keep the order of the columns of the full table
        return rows[[c for c in full_df.columns if c in visible]]

    def filter_dataframe_by_selected_tag_columns(
        full_df: pd.DataFrame, current_tag: str, search: str, model_family: str
    ) -> tuple[pd.DataFrame, list, str, str]:
        """Show the always-visible columns plus every column tagged ``current_tag``.

        Returns the new table, the matching column-selector value, and "All" twice
        to reset the model type and size dropdowns. The search and family filters
        stay applied, consistent with their widgets, which are not reset.
        """
        tag_columns = [
            c.pretty_name for c in columns_dict.values() if current_tag in c.tags and c.pretty_name not in always_here_cols
        ]
        visible = set(always_here_cols) | set(tag_columns)
        rows = filter_rows(full_df, search, "All", "All", model_family)
        # keep the order of the columns of the full table
        filtered_df = rows[[c for c in full_df.columns if c in visible]]
        return filtered_df, [c for c in columns_to_select_visibility if c in filtered_df.columns], "All", "All"

    def filter_dataframe_by_rows(
        full_df: pd.DataFrame,
        current_df: pd.DataFrame,
        search: str,
        model_type: str,
        model_size: str,
        model_family: str,
    ) -> pd.DataFrame:
        """Re-apply all row filters to the full table, keeping the currently displayed columns."""
        return filter_rows(full_df, search, model_type, model_size, model_family)[current_df.columns]

    with gr.Column(scale=1) as col:
        with gr.Row():
            with gr.Column(scale=8):
                with gr.Accordion("➡️ See All Columns", open=False):
                    all_columns_selector = gr.CheckboxGroup(
                        choices=columns_to_select_visibility,
                        value=[
                            c.pretty_name
                            for c in columns_dict.values()
                            if c.pretty_name in columns_to_select_visibility and c.displayed_by_default
                        ],
                        label="Select Columns to Display:",
                        interactive=True,
                        container=False,
                    )
                with gr.Column(variant='panel'):
                    gr.Markdown("Visible Columns:", elem_id="visible-columns-label")
                    # One preset button per distinct tag, in first-seen order.
                    all_tags = {}
                    with gr.Row():
                        for c in columns_dict.values():
                            for tag in c.tags:
                                if tag not in all_tags:
                                    all_tags[tag] = gr.Button(tag, interactive=True, size="sm", variant="secondary", min_width=50)
            with gr.Column(scale=8):
                with gr.Row():
                    search_bar = gr.Textbox(
                        placeholder="🔍 Search for your model and press ENTER...",
                        show_label=False,
                        elem_id="search-bar",
                    )

                with gr.Row():
                    model_type_filter_selector = gr.Dropdown(
                        label="Filter model types:",
                        choices=["All", "💙 Open-Weights", "🟥 Proprietary"],
                        value="All",
                        elem_id="model-type-filter",
                        interactive=True,
                        multiselect=False,
                        min_width=120,
                    )
                    model_size_filter_selector = gr.Dropdown(
                        label="Filter model sizes:",
                        choices=["All", "🛴 Tiny (<5B)", "🚗 Small (5-50B)", "🚚 Medium (50-100B)", "🚀 Large (>100B)"],
                        value="All",
                        elem_id="model-size-filter",
                        interactive=True,
                        multiselect=False,
                        min_width=120,
                    )
                    model_family_filter_selector = gr.Dropdown(
                        label="Filter model families:",
                        choices=["All"] + list(dataframe[columns_dict["model_family"].pretty_name].unique()),
                        value="All",
                        elem_id="model-family-filter",
                        interactive=True,
                        multiselect=False,
                        min_width=120,
                    )

        # Create the hidden (full) and visible dataframes. Initially the visible
        # one shows the columns marked displayed_by_default.
        initial_df = dataframe[[c.pretty_name for c in columns_dict.values() if c.displayed_by_default]]
        hidden_leaderboard_df = gr.components.Dataframe(
            value=dataframe,
            datatype=datatypes_for(dataframe),
            visible=False,
            interactive=False,
        )
        # NOTE(review): the callbacks below return plain DataFrames, so this
        # datatype list is not re-synced when the visible columns change.
        leaderboard_df = gr.components.Dataframe(
            value=initial_df,
            datatype=datatypes_for(initial_df),
            elem_id="leaderboard-df",
            interactive=False,
        )

        # Add the callbacks. Every callback receives the current value of every
        # row filter, so the filters compose instead of replacing one another.
        row_filter_inputs = [
            search_bar,
            model_type_filter_selector,
            model_size_filter_selector,
            model_family_filter_selector,
        ]
        all_columns_selector.change(
            fn=filter_dataframe_by_selected_columns,
            inputs=[hidden_leaderboard_df, all_columns_selector, *row_filter_inputs],
            outputs=[leaderboard_df],
        )
        for event in (
            search_bar.submit,
            model_type_filter_selector.change,
            model_size_filter_selector.change,
            model_family_filter_selector.change,
        ):
            event(
                fn=filter_dataframe_by_rows,
                inputs=[hidden_leaderboard_df, leaderboard_df, *row_filter_inputs],
                outputs=[leaderboard_df],
            )
        # Wire up each visible-column button to filter by tag (the button's value is its tag label).
        for button in all_tags.values():
            button.click(
                fn=filter_dataframe_by_selected_tag_columns,
                inputs=[hidden_leaderboard_df, button, search_bar, model_family_filter_selector],
                outputs=[leaderboard_df, all_columns_selector, model_type_filter_selector, model_size_filter_selector],
            )

    return col


# Top-level Gradio app. The CSS rule sizes elements with the `scatter-plot`
# class, which is only used by the commented-out scatter plots below.
demo = gr.Blocks(css=".scatter-plot {height: 500px;}")
with demo:
    gr.HTML(TITLE)
    gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")

    # Dumps the U-MATH and μ-MATH tables to stdout at startup (presumably for
    # debugging; the joined table is not printed).
    print(LEADERBOARD_U_MATH_DF)
    print(LEADERBOARD_MU_MATH_DF)

    # One tab per leaderboard table, plus an "About" tab. Each leaderboard tab is
    # built by init_leaderboard from a module-level DataFrame and its column
    # metadata dict.
    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("🏆 U-MATH", elem_id="u-math-benchmark-tab-table", id=0):
            leaderboard_umath = init_leaderboard(LEADERBOARD_U_MATH_DF, U_MATH_COLUMNS_DICT)
            # gr.ScatterPlot(
            #     value=LEADERBOARD_U_MATH_DF,
            #     title="U-MATH: Text vs Visual Accuracy",
            #     x=U_MATH_COLUMNS_DICT["u_math_text_acc"].pretty_name,
            #     y=U_MATH_COLUMNS_DICT["u_math_visual_acc"].pretty_name,
            #     color=U_MATH_COLUMNS_DICT["model_family"].pretty_name,
            #     tooltip=[U_MATH_COLUMNS_DICT["full_model_name"].pretty_name, U_MATH_COLUMNS_DICT["u_math_acc"].pretty_name],
            #     elem_classes="scatter-plot",
            #     height=500,
            # )

        with gr.TabItem("🏅 μ-MATH (Meta-Benchmark)", elem_id="mu-math-benchmark-tab-table", id=1):
            leaderboard_mumath = init_leaderboard(LEADERBOARD_MU_MATH_DF, MU_MATH_COLUMNS_DICT)
            # gr.ScatterPlot(
            #     value=LEADERBOARD_MU_MATH_DF,
            #     title="μ-MATH: True Positive Rate (Recall) vs True Negative Rate (Specificity)",
            #     x=MU_MATH_COLUMNS_DICT["mu_math_tpr"].pretty_name,
            #     y=MU_MATH_COLUMNS_DICT["mu_math_tnr"].pretty_name,
            #     color=MU_MATH_COLUMNS_DICT["model_family"].pretty_name,
            #     tooltip=[MU_MATH_COLUMNS_DICT["full_model_name"].pretty_name, MU_MATH_COLUMNS_DICT["mu_math_f1"].pretty_name],
            #     elem_classes="scatter-plot",
            #     height=500,
            # )

        with gr.TabItem("📊 U-MATH vs μ-MATH", elem_id="u-math-vs-mu-math-tab-table", id=2):
            leaderboard_aggregated = init_leaderboard(LEADERBOARD_U_MATH_MU_MATH_JOINED_DF, U_MATH_AND_MU_MATH_COLUMNS_DICT)
            # gr.ScatterPlot(
            #     value=LEADERBOARD_U_MATH_MU_MATH_JOINED_DF,
            #     title="U-MATH Accuracy (Solving) vs μ-MATH F1 Score (Judging)",
            #     x=U_MATH_AND_MU_MATH_COLUMNS_DICT["u_math_acc"].pretty_name,
            #     y=U_MATH_AND_MU_MATH_COLUMNS_DICT["mu_math_f1"].pretty_name,
            #     color=U_MATH_AND_MU_MATH_COLUMNS_DICT["model_family"].pretty_name,
            #     tooltip=[
            #         U_MATH_AND_MU_MATH_COLUMNS_DICT["full_model_name"].pretty_name,
            #         U_MATH_AND_MU_MATH_COLUMNS_DICT["u_math_text_acc"].pretty_name,
            #         U_MATH_AND_MU_MATH_COLUMNS_DICT["u_math_visual_acc"].pretty_name,
            #     ],
            #     elem_classes="scatter-plot",
            #     height=500,
            # )

        with gr.TabItem("📝 About", elem_id="about-tab-table", id=3):
            gr.Markdown(LLM_BENCHMARKS_TEXT, elem_classes="markdown-text")

        # Copyable citation box.
        # NOTE(review): this Textbox is a direct child of gr.Tabs rather than of
        # a TabItem (or of the outer Blocks). Confirm it renders where intended.
        citation_button = gr.Textbox(
            value=CITATION_TEXT,
            label="📙 Citation",
            lines=9,
            elem_id="citation-button",
            show_copy_button=True,
            container=True,
        )

# Restart the Space once an hour in a background scheduler thread. Then enable
# request queueing (at most 40 concurrent events by default) and start serving.
# Gradio is allowed to serve files from `.cache`.
scheduler = BackgroundScheduler()
scheduler.add_job(restart_space, trigger="interval", hours=1)
scheduler.start()
# Alternative launch without server-side rendering:
# demo.queue(default_concurrency_limit=40).launch(ssr_mode=False)
demo.queue(default_concurrency_limit=40)
demo.launch(allowed_paths=[".cache"])