import os
from glob import glob

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib.colors import BoundaryNorm, ListedColormap

# Evaluation results for every model and benchmark category (presumably one
# row per model/question answer — verify against the producer of the file).
# get_accuracy_dataframe reads its "model", "category", "is_answer_correct"
# and "difficulty_level" columns. read_pickle unpickles arbitrary Python
# objects, so final_df.pkl must come from a trusted source.
all_results = pd.read_pickle("final_df.pkl")


def get_accuracy_dataframe(df_mother, category):
    """Build the leaderboard table for one benchmark category.

    Args:
        df_mother: Results table with at least the columns ``model``,
            ``category``, ``is_answer_correct`` and ``difficulty_level``.
        category: Value of the ``category`` column to keep (e.g. ``"CoT"``).

    Returns:
        A DataFrame with one row per model and the columns
        ``🤖 Model Name``, ``⭐ Overall`` and ``📈 Level 1`` … ``🔬 Level 4``.
        Accuracies are percentages rounded to one decimal place; a level with
        no answers for a model (or for the whole category) is NaN. Model
        names keep only the part after the last ``/``. Rows are sorted by
        overall accuracy, best first.
    """
    # The display headers below are fixed to exactly these four levels;
    # levels outside this list are not shown.
    expected_levels = [1, 2, 3, 4]

    df = df_mother[df_mother["category"] == category].copy()
    df["is_answer_correct"] = df["is_answer_correct"].astype(float)
    df["difficulty_level"] = df["difficulty_level"].astype(int)

    # Overall accuracy per model.
    overall = df.groupby("model")["is_answer_correct"].mean().reset_index()

    # Accuracy per (model, level), pivoted to one column per level.
    # Reindexing pins the column order to expected_levels and inserts NaN
    # columns for levels absent from this category, so the header labels
    # assigned below always line up with the matching data.
    per_level = (
        df.groupby(["model", "difficulty_level"])["is_answer_correct"]
        .mean()
        .reset_index()
        .pivot(index="model", columns="difficulty_level", values="is_answer_correct")
        .reindex(columns=expected_levels)
        .reset_index()
    )

    table = overall.merge(per_level, on="model")
    table["model"] = table["model"].map(lambda name: name.split("/")[-1])

    # Fractions -> percentages with one decimal place.
    score_columns = ["is_answer_correct", *expected_levels]
    table[score_columns] = (table[score_columns].astype(float) * 100).round(1)

    # Select columns explicitly before relabelling so the order is guaranteed.
    table = table[["model", *score_columns]]
    table.columns = [
        "🤖 Model Name",
        "⭐ Overall",
        "📈 Level 1",
        "🔍 Level 2",
        "📘 Level 3",
        "🔬 Level 4",
    ]
    return table.sort_values(by="⭐ Overall", ascending=False)


# One leaderboard table per benchmark category found in all_results
# ('1shot', 'CoT', 'Textonly', 'vision', 'vision-CoT').
(
    accuracy_df_textonly,
    accuracy_df_cot,
    accuracy_df_vision,
    accuracy_df_vision_cot,
    accuracy_df_1shot,
) = (
    get_accuracy_dataframe(all_results, category_name)
    for category_name in ("Textonly", "CoT", "vision", "vision-CoT", "1shot")
)


# Leaderboard display headers, each prefixed with an icon. These match the
# labels that get_accuracy_dataframe assigns to its output columns.
headers_with_icons = [
    f"{icon} {label}"
    for icon, label in (
        ("🤖", "Model Name"),
        ("⭐", "Overall"),
        ("📈", "Level 1"),
        ("🔍", "Level 2"),
        ("📘", "Level 3"),
        ("🔬", "Level 4"),
    )
]

# Plain-text column names (no icons), in the same order as the headers above.
column_names = ["Model Name", "Overall Accuracy"] + [
    f"Level {level} Accuracy" for level in range(1, 5)
]


def _load_heatmap(evt: gr.SelectData, category):
    """Return an Image component for the heatmap of the selected cell.

    The selected cell's value is used as the model name, and the image is
    read from ``./heatmaps/<value>_<category>.jpg``. If that file does not
    exist (e.g. the user clicked an accuracy value rather than a model
    name), an empty image is returned instead of a path to a missing file.
    """
    path = f"./heatmaps/{evt.value}_{category}.jpg"
    if not os.path.isfile(path):
        return gr.Image(value=None)
    return gr.Image(path)


def load_heatmap_textonly(evt: gr.SelectData):
    """Select handler for the Text-only leaderboard."""
    return _load_heatmap(evt, "Textonly")


def load_heatmap_cot(evt: gr.SelectData):
    """Select handler for the CoT leaderboard."""
    return _load_heatmap(evt, "CoT")


def load_heatmap_vision(evt: gr.SelectData):
    """Select handler for the Vision leaderboard."""
    return _load_heatmap(evt, "vision")


def load_heatmap_vision_cot(evt: gr.SelectData):
    """Select handler for the Vision-CoT leaderboard."""
    return _load_heatmap(evt, "vision-CoT")


def load_heatmap_1shot(evt: gr.SelectData):
    """Select handler for the 1shot leaderboard."""
    return _load_heatmap(evt, "1shot")


# Build the UI: each tab's leaderboard sends cell selections to the matching load_heatmap_* handler.

def _benchmark_tab(title, table, load_heatmap):
    """Add one benchmark tab: the leaderboard table with a heatmap viewer below.

    Must be called while a ``gr.Blocks`` context is open. Selecting a cell in
    the table invokes *load_heatmap* and renders its result in the image slot.

    Args:
        title: Tab label.
        table: Leaderboard DataFrame (as built by get_accuracy_dataframe).
        load_heatmap: ``.select`` handler that returns the heatmap image.

    Returns:
        The ``(leaderboard, heatmap_image)`` components created for the tab.
    """
    with gr.Tab(title):
        leaderboard = gr.Dataframe(table, headers=headers_with_icons)
        gr.Markdown("## Heatmap")
        heatmap_image = gr.Image(label="", show_label=False)
        leaderboard.select(fn=load_heatmap, outputs=[heatmap_image])
    return leaderboard, heatmap_image


with gr.Blocks() as demo:
    gr.Markdown("# FSM Benchmark Leaderboard")

    leader_board_textonly, heatmap_image_textonly = _benchmark_tab(
        "Text-only Benchmark", accuracy_df_textonly, load_heatmap_textonly
    )
    leader_board_cot, heatmap_image_cot = _benchmark_tab(
        "CoT Benchmark", accuracy_df_cot, load_heatmap_cot
    )
    leader_board_vision, heatmap_image_vision = _benchmark_tab(
        "Vision Benchmark", accuracy_df_vision, load_heatmap_vision
    )
    leader_board_vision_cot, heatmap_image_vision_cot = _benchmark_tab(
        "Vision-CoT Benchmark", accuracy_df_vision_cot, load_heatmap_vision_cot
    )
    leader_board_1shot, heatmap_image_1shot = _benchmark_tab(
        "1shot Benchmark", accuracy_df_1shot, load_heatmap_1shot
    )

# Gradio's documented pattern: launch after the `with gr.Blocks()` context
# has exited, so the app is fully built before the server starts.
demo.launch()