import gradio as gr
import pandas as pd
import plotly.express as px


# Columns of the merged BetterTransformer dataframe (see get_bt_df) passed to
# plotly as ``custom_data`` and rendered, in this order, in the hover tooltip.
# Each column must appear only once, otherwise the tooltip repeats lines.
BETTERTRANSFORMER_DATA = [
    # open llm
    "Model 🤗",
    "Arch 🏛️",
    "Params (B)",
    "Open LLM Score (%)",
    # deployment settings
    "DType 📥",
    "Backend 🏭",
    "Optimization 🛠️",
    "Quantization 🗜️",
    "Optimization 🛠️ BetterTransformer",
    # primary measurements
    "Prefill Latency (s)",
    "Prefill Latency (s) BetterTransformer",
    "Decode Throughput (tokens/s)",
    "Decode Throughput (tokens/s) BetterTransformer",
    "E2E Throughput (tokens/s)",
    "E2E Throughput (tokens/s) BetterTransformer",
    # speedups
    "Prefill Latency Speedup (%)",
    "Decode Throughput Speedup (%)",
]


def get_bt_df(llm_perf_df):
    """Pair float16 baseline runs with their BetterTransformer runs and compute speedups.

    Rows of ``llm_perf_df`` with ``"DType 📥" == "float16"`` are split into
    baseline runs (``"Optimization 🛠️" == "None"``) and BetterTransformer runs
    (``"Optimization 🛠️" == "BetterTransformer"``). The two sets are then
    inner-joined on ``"Model 🤗"`` and ``"Quantization 🗜️"``. Overlapping
    columns from the BetterTransformer side get the ``" BetterTransformer"``
    suffix.

    Two columns are added, both rounded to 2 decimals:

    * ``"Prefill Latency Speedup (%)"``: baseline latency / BetterTransformer
      latency, expressed as a percentage gain (positive = BT is faster).
    * ``"Decode Throughput Speedup (%)"``: BetterTransformer throughput /
      baseline throughput, expressed as a percentage gain.

    Rows where either speedup is not strictly below 1000% are dropped,
    including NaN speedups, because NaN comparisons are False.

    Args:
        llm_perf_df: benchmark results dataframe; it is not modified.

    Returns:
        The merged dataframe with the two speedup columns.
    """
    is_fp16 = llm_perf_df["DType 📥"] == "float16"
    optimization = llm_perf_df["Optimization 🛠️"]
    # separate original model experiments from BetterTransformer experiments
    # (boolean indexing returns new frames, so the input is left untouched)
    original_df = llm_perf_df[is_fp16 & (optimization == "None")]
    bt_df = llm_perf_df[is_fp16 & (optimization == "BetterTransformer")]
    # merge the two dataframes
    bt_df = pd.merge(
        original_df,
        bt_df,
        on=["Model 🤗", "Quantization 🗜️"],
        suffixes=("", " BetterTransformer"),
    )
    # compute speedups; round *after* subtracting 100 so the values shown in
    # hover tooltips don't carry float noise (e.g. 12.349999999999994)
    prefill_ratio = bt_df["Prefill Latency (s)"] / bt_df["Prefill Latency (s) BetterTransformer"]
    decode_ratio = bt_df["Decode Throughput (tokens/s) BetterTransformer"] / bt_df["Decode Throughput (tokens/s)"]
    bt_df["Prefill Latency Speedup (%)"] = (prefill_ratio * 100 - 100).round(2)
    bt_df["Decode Throughput Speedup (%)"] = (decode_ratio * 100 - 100).round(2)
    # filter speedups >= 1000% (treated as outliers)
    keep = (bt_df["Prefill Latency Speedup (%)"] < 1000) & (bt_df["Decode Throughput Speedup (%)"] < 1000)

    return bt_df[keep]


def _get_bt_box_fig(bt_df, speedup_column, title, yaxis_title):
    """Build a per-architecture box plot of one speedup column of ``bt_df``.

    Points are colored by ``"Quantization 🗜️"``. Every column in
    ``BETTERTRANSFORMER_DATA`` is attached as ``custom_data`` and listed in
    the hover tooltip.

    Args:
        bt_df: dataframe produced by ``get_bt_df``.
        speedup_column: name of the column plotted on the y axis.
        title: figure title text.
        yaxis_title: y-axis label.

    Returns:
        The configured plotly figure.
    """
    fig = px.box(
        bt_df,
        x="Arch 🏛️",
        y=speedup_column,
        color_discrete_sequence=px.colors.qualitative.Light24,
        custom_data=BETTERTRANSFORMER_DATA,
        color="Quantization 🗜️",
        points="all",
    )
    # add hover data: one "<b>column:</b> value" line per custom_data entry,
    # referenced by its position in BETTERTRANSFORMER_DATA
    fig.update_traces(
        hovertemplate="<br>".join(
            [f"<b>{column}:</b> %{{customdata[{i}]}}" for i, column in enumerate(BETTERTRANSFORMER_DATA)]
        )
    )
    # add layout
    fig.update_layout(
        title={
            "text": title,
            "y": 0.95,
            "x": 0.5,
            "xanchor": "center",
            "yanchor": "top",
        },
        xaxis_title="LLM Architecture",
        yaxis_title=yaxis_title,
        legend_title="Quantization Scheme",
        width=1200,
        height=600,
    )

    return fig


def get_bt_prefill_fig(llm_perf_df):
    """Return a box plot of BetterTransformer prefill-latency speedup per architecture.

    Args:
        llm_perf_df: benchmark results dataframe, passed to ``get_bt_df``.
    """
    return _get_bt_box_fig(
        get_bt_df(llm_perf_df),
        speedup_column="Prefill Latency Speedup (%)",
        title="Prefill Latency Speedup per Architecture, Compared To Non-Optimized Model",
        yaxis_title="Prefill Speedup (%)",
    )


def get_bt_decode_fig(llm_perf_df):
    """Return a box plot of BetterTransformer decode-throughput speedup per architecture.

    Args:
        llm_perf_df: benchmark results dataframe, passed to ``get_bt_df``.
    """
    return _get_bt_box_fig(
        get_bt_df(llm_perf_df),
        speedup_column="Decode Throughput Speedup (%)",
        title="Decode Throughput Speedup per Architecture, Compared To Non-Optimized Model",
        yaxis_title="Decode Speedup (%)",
    )


def create_bt_plots(llm_perf_df):
    """Emit the BetterTransformer section: a hint line plus two Gradio plots.

    The components are created in the current Gradio context in this order:
    an HTML hint, the prefill speedup plot, then the decode speedup plot.

    Args:
        llm_perf_df: benchmark results dataframe used to build both figures.

    Returns:
        A ``(prefill_plot, decode_plot)`` tuple of Gradio Plot components.
    """
    # short usage hint shown above the charts
    gr.HTML("👆 Hover over the points 👆 for additional information.", elem_id="text")

    figures = (get_bt_prefill_fig(llm_perf_df), get_bt_decode_fig(llm_perf_df))
    # wrap each figure in a label-less Plot component, prefill first
    return tuple(gr.components.Plot(value=fig, elem_id="plot", show_label=False) for fig in figures)