import gradio as gr

from src.llm_perf import get_llm_perf_df
from src.leaderboard import get_leaderboard_df
from src.latency_score_memory import get_lat_score_mem_fig
from src.bettertransformer import get_bt_prefill_fig, get_bt_decode_fig
from src.flashattentionv2 import get_fa2_prefill_fig, get_fa2_decode_fig
from src.quantization_kernels import get_quant_prefill_fig, get_quant_decode_fig


def create_control_panel(machine: str = "hf-dgx-01"):
    """Create the leaderboard filter widgets and return their handles.

    Components are created in the order below, which is the order they are
    laid out on the page (rows/columns are opened as layout contexts).

    Args:
        machine: Benchmark machine identifier. It is stored in a hidden
            textbox so that the filter callback can receive it as an input.

    Returns:
        A 9-tuple ``(filter_button, machine_textbox, search_bar,
        score_slider, memory_slider, backend_checkboxes,
        datatype_checkboxes, optimization_checkboxes,
        quantization_checkboxes)``.
    """
    # Every checkbox group starts with all of its options selected, so each
    # option set is declared once and copied into both `choices` and `value`.
    backend_options = ("pytorch",)
    dtype_options = ("float32", "float16", "bfloat16")
    optimization_options = ("None", "BetterTransformer", "FlashAttentionV2")
    quantization_options = (
        "None",
        "BnB.4bit",
        "BnB.8bit",
        "GPTQ.4bit",
        "GPTQ.4bit+ExllamaV1",
        "GPTQ.4bit+ExllamaV2",
        "AWQ.4bit+GEMM",
        "AWQ.4bit+GEMV",
    )
    memory_cap_mb = 80 * 1024  # slider upper bound and initial value

    # descriptive text
    gr.HTML("Use this control panel to filter the leaderboard.", elem_id="text")
    # hidden input carrying the machine identifier to the callback
    hidden_machine = gr.Textbox(value=machine, visible=False)

    # first row: free-text model search
    with gr.Row():
        with gr.Column():
            model_search = gr.Textbox(
                label="Model 🤗",
                info="🔍 Search for a model name",
                elem_id="search-bar",
            )

    # second row: score / memory thresholds and backend selection
    with gr.Row():
        with gr.Column(scale=1, variant="panel"):
            min_score = gr.Slider(
                label="Open LLM Score (%) 📈",
                info="🎚️ Slide to minimum Open LLM score",
                value=0,
                elem_id="threshold-slider",
            )
        with gr.Column(scale=1, variant="panel"):
            max_memory = gr.Slider(
                label="Peak Memory (MB) 📈",
                info="🎚️ Slide to maximum Peak Memory",
                minimum=0,
                maximum=memory_cap_mb,
                value=memory_cap_mb,
                elem_id="memory-slider",
            )
        with gr.Column(scale=1):
            backend_group = gr.CheckboxGroup(
                label="Backends 🏭",
                choices=list(backend_options),
                value=list(backend_options),
                info="☑️ Select the backends",
                elem_id="backend-checkboxes",
            )

    # third row: dtype, optimization and quantization selection
    with gr.Row():
        with gr.Column(scale=1, variant="panel"):
            dtype_group = gr.CheckboxGroup(
                label="Load DTypes 📥",
                choices=list(dtype_options),
                value=list(dtype_options),
                info="☑️ Select the load data types",
                elem_id="dtype-checkboxes",
            )
        with gr.Column(scale=1, variant="panel"):
            optimization_group = gr.CheckboxGroup(
                label="Optimizations 🛠️",
                choices=list(optimization_options),
                value=list(optimization_options),
                info="☑️ Select the optimization",
                elem_id="optimization-checkboxes",
            )
        with gr.Column(scale=2):
            quantization_group = gr.CheckboxGroup(
                label="Quantizations 🗜️",
                choices=list(quantization_options),
                value=list(quantization_options),
                info="☑️ Select the quantization schemes",
                elem_id="quantization-checkboxes",
            )

    # last row: the button that triggers filtering
    with gr.Row():
        apply_button = gr.Button(
            value="Filter 🚀",
            elem_id="filter-button",
        )

    return (
        apply_button,
        hidden_machine,
        model_search,
        min_score,
        max_memory,
        backend_group,
        dtype_group,
        optimization_group,
        quantization_group,
    )


def filter_fn(
    machine,
    model,
    backends,
    datatypes,
    optimizations,
    quantizations,
    score,
    memory,
):
    """Filter the benchmark results and rebuild the leaderboard and all plots.

    Args:
        machine: Machine identifier forwarded to ``get_llm_perf_df``.
        model: Case-insensitive search text matched against the
            ``"Model 🤗"`` column. It is tried as a regular expression first;
            if it is not a valid pattern it is matched as a literal substring.
            ``None`` is treated as an empty search (matches every named model).
        backends: Allowed values of the ``"Backend 🏭"`` column.
        datatypes: Allowed values of the ``"DType 📥"`` column.
        optimizations: Allowed values of the ``"Optimization 🛠️"`` column.
        quantizations: Allowed values of the ``"Quantization 🗜️"`` column.
        score: Minimum ``"Open LLM Score (%)"`` a row must reach (inclusive).
        memory: Maximum ``"Allocated Memory (MB)"`` a row may use (inclusive).

    Returns:
        A list of eight items, in the order expected by the filter button's
        outputs: the leaderboard dataframe, then the latency/score/memory,
        BetterTransformer prefill/decode, FlashAttentionV2 prefill/decode and
        quantization prefill/decode figures.
    """
    import re

    raw_df = get_llm_perf_df(machine=machine)

    model_names = raw_df["Model 🤗"]
    search_text = model or ""
    # `na=False` keeps rows with a missing model name out of the result
    # instead of putting NaN into the boolean mask, which pandas refuses to
    # index with.
    try:
        model_mask = model_names.str.contains(search_text, case=False, na=False)
    except re.error:
        # Free-form input such as "c++" or "(" is not a valid regex; fall back
        # to a literal substring match instead of crashing the callback.
        model_mask = model_names.str.contains(
            search_text, case=False, regex=False, na=False
        )

    filtered_df = raw_df[
        model_mask
        & raw_df["Backend 🏭"].isin(backends)
        & raw_df["DType 📥"].isin(datatypes)
        & raw_df["Optimization 🛠️"].isin(optimizations)
        & raw_df["Quantization 🗜️"].isin(quantizations)
        & (raw_df["Open LLM Score (%)"] >= score)
        # NOTE(review): the UI slider is labelled "Peak Memory (MB)" but this
        # filters on "Allocated Memory (MB)" — confirm the intended column.
        & (raw_df["Allocated Memory (MB)"] <= memory)
    ]

    # Figure builders in the exact order of the callback's output components.
    figure_builders = (
        get_lat_score_mem_fig,
        get_bt_prefill_fig,
        get_bt_decode_fig,
        get_fa2_prefill_fig,
        get_fa2_decode_fig,
        get_quant_prefill_fig,
        get_quant_decode_fig,
    )
    leaderboard_df = get_leaderboard_df(filtered_df)
    return [leaderboard_df] + [build(filtered_df) for build in figure_builders]


def create_control_callback(
    # button
    filter_button,
    # inputs
    machine_textbox,
    search_bar,
    score_slider,
    memory_slider,
    backend_checkboxes,
    datatype_checkboxes,
    optimization_checkboxes,
    quantization_checkboxes,
    # outputs
    leaderboard_table,
    lat_score_mem_plot,
    bt_prefill_plot,
    bt_decode_plot,
    fa2_prefill_plot,
    fa2_decode_plot,
    quant_prefill_plot,
    quant_decode_plot,
):
    """Register ``filter_fn`` as the click handler of ``filter_button``.

    The control components are the handler's inputs and the leaderboard
    table plus the seven plots are its outputs. Nothing is returned.
    """
    # Gradio hands input values to `fn` positionally, so this list must follow
    # filter_fn's parameter order: machine, model, backends, datatypes,
    # optimizations, quantizations, score, memory.
    handler_inputs = [
        machine_textbox,
        search_bar,
        backend_checkboxes,
        datatype_checkboxes,
        optimization_checkboxes,
        quantization_checkboxes,
        score_slider,
        memory_slider,
    ]
    # One component per element of the list filter_fn returns, same order.
    handler_outputs = [
        leaderboard_table,
        lat_score_mem_plot,
        bt_prefill_plot,
        bt_decode_plot,
        fa2_prefill_plot,
        fa2_decode_plot,
        quant_prefill_plot,
        quant_decode_plot,
    ]
    filter_button.click(fn=filter_fn, inputs=handler_inputs, outputs=handler_outputs)