import pandas as pd
import json
from typing import Dict, Any, Tuple
import os
from constants import MODEL_NAME_MAP, DIMENSION_NAME_MAP, KEYWORD_NAME_MAP, MODEL_URLS, BASE_MODEL_GROUPS


class MEGABenchEvalDataLoader:
    """Load per-model evaluation results from ``base_path`` and shape them into leaderboard tables.

    Expected on-disk layout (as read by this class):

    * ``<base_path>/<model_name>/summary_and_keyword_stats.json`` -- one file per
      model folder, optionally holding the keys ``"keyword_stats"`` and
      ``"model_summary"``.
    * ``<base_path>/self_reported.json`` -- optional mapping of model name to a
      self-reported overall score.

    Attributes set by ``__init__``:

    * ``KEYWORD_DATA`` -- model name -> ``"keyword_stats"`` payload.
    * ``SUMMARY_DATA`` -- model name -> ``"model_summary"`` payload.
    * ``SELF_REPORTED`` -- model name -> self-reported overall score.
    * ``SUPER_GROUPS`` -- dimension display name -> ordered keyword display names.
    * ``MODEL_GROUPS`` -- model group name -> list of available model names.
    * ``keyword_display_map`` -- keyword display name (``"Name(count)"``) -> raw keyword.
    """

    # Name of the results file expected inside every model folder.
    STATS_FILENAME = "summary_and_keyword_stats.json"
    # Name of the optional self-reported scores file directly under base_path.
    SELF_REPORTED_FILENAME = "self_reported.json"
    # Model whose summary supplies the header task counts when it is available.
    REFERENCE_MODEL = "GPT_4o"
    # Columns that precede the per-keyword columns in every row built by get_df.
    _FIXED_COLUMNS = ("Models", "Overall", "Overall_display", "Core", "Open-ended")

    def __init__(self, base_path):
        """Read all result files under ``base_path`` and build the group tables."""
        self.base_path = base_path
        # Load both model and summary data at once
        self.KEYWORD_DATA, self.SUMMARY_DATA = self._load_data()
        self.SELF_REPORTED = self._load_self_reported()
        self.SUPER_GROUPS = self._initialize_super_groups()
        self.MODEL_GROUPS = self._initialize_model_groups()

    def _get_base_path(self) -> str:
        """Hook for subclasses; not called anywhere inside this class."""
        raise NotImplementedError("Subclasses must implement _get_base_path")

    def _load_data(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Load keyword statistics and summaries for every model folder.

        Returns:
            ``(keyword_data, summary_data)``, both keyed by model folder name.
            A model only appears in a mapping when its stats file contains the
            corresponding key.

        Folders without a stats file are skipped with a warning instead of
        aborting the whole load. Malformed JSON still raises.
        """
        summary_data = {}
        keyword_data = {}
        # Sorted so the "sample" model picked later does not depend on the
        # arbitrary ordering of os.listdir.
        for model_name in sorted(os.listdir(self.base_path)):
            model_dir = os.path.join(self.base_path, model_name)
            if not os.path.isdir(model_dir):
                continue

            stats_path = os.path.join(model_dir, self.STATS_FILENAME)
            try:
                with open(stats_path, "r", encoding="utf-8") as f:
                    data = json.load(f)
            except FileNotFoundError:
                print("Warning: No stats file found at", stats_path)
                continue

            if "keyword_stats" in data:
                keyword_data[model_name] = data["keyword_stats"]
            if "model_summary" in data:
                summary_data[model_name] = data["model_summary"]

        return keyword_data, summary_data

    def _load_self_reported(self) -> Dict[str, float]:
        """Load ``self_reported.json`` from ``base_path``.

        Returns:
            The parsed mapping, or an empty dict (after printing a warning) when
            the file does not exist.
        """
        path = os.path.join(self.base_path, self.SELF_REPORTED_FILENAME)
        try:
            with open(path, "r", encoding="utf-8") as f:
                return json.load(f)
        except FileNotFoundError:
            # Report the path that was actually tried.
            print("Warning: No self-reported file found at", path)
            return {}

    def _initialize_super_groups(self):
        """Build the dimension -> keyword display-name table.

        The structure (dimensions, keywords and task counts) is taken from the
        first model in ``KEYWORD_DATA``. As a side effect this (re)creates
        ``self.keyword_display_map``.

        Returns:
            Dict mapping dimension display name to keyword display names
            (``"Name(count)"``), sorted by count descending then by name.
            Only dimensions in the predefined order list are kept. Empty when
            no evaluated model is available.

        Raises:
            KeyError: if a dimension is missing from ``DIMENSION_NAME_MAP``.
        """
        groups = {}
        self.keyword_display_map = {}  # display name -> original keyword

        # Get a sample model to access the structure
        sample_model = next(iter(self.KEYWORD_DATA), None)
        if sample_model is None:
            # Nothing evaluated (e.g. only self-reported scores): no dimensions.
            return groups

        for dim, keywords in self.KEYWORD_DATA[sample_model].items():
            dim_name = DIMENSION_NAME_MAP[dim]
            # Tuples of (display_name, count, keyword) for sorting
            keyword_info = []

            for keyword, stats in keywords.items():
                task_count = stats["count"]
                original_name = KEYWORD_NAME_MAP.get(keyword, keyword)
                display_name = f"{original_name}({task_count})"
                keyword_info.append((display_name, task_count, keyword))

            # Sort by count (descending) and then by display name (for ties)
            keyword_info.sort(key=lambda x: (-x[1], x[0]))

            groups[dim_name] = [info[0] for info in keyword_info]
            for display_name, _, keyword in keyword_info:
                self.keyword_display_map[display_name] = keyword

        # Sort based on predefined order
        order = ["Application", "Skills", "Output Format", "Input Format", "Visual Input Number"]
        return {k: groups[k] for k in order if k in groups}

    def _initialize_model_groups(self) -> Dict[str, list]:
        """Restrict ``BASE_MODEL_GROUPS`` to models that have data.

        A model counts as available when it has keyword data or a self-reported
        score. The ``"All"`` group becomes the sorted list of every available
        model; other groups keep their configured order and are dropped when
        none of their models is available.
        """
        # Include both evaluated and self-reported models
        available_models = set(self.KEYWORD_DATA) | set(self.SELF_REPORTED)

        filtered_groups = {}
        for group_name, models in BASE_MODEL_GROUPS.items():
            if group_name == "All":
                filtered_groups[group_name] = sorted(available_models)
                continue
            filtered_models = [model for model in models if model in available_models]
            if filtered_models:
                filtered_groups[group_name] = filtered_models

        return filtered_groups

    def get_df(self, selected_super_group: str, selected_model_group: str) -> pd.DataFrame:
        """Build the leaderboard table for one dimension and one model group.

        Args:
            selected_super_group: Dimension display name (a key of ``SUPER_GROUPS``).
            selected_model_group: Model group name (a key of ``MODEL_GROUPS``).

        Returns:
            DataFrame sorted by overall score (descending) with the columns
            ``Models``, ``Overall``, ``Core``, ``Open-ended`` followed by one
            column per keyword of the dimension. Scores are percentages rounded
            to two decimals. ``Overall`` holds display strings; self-reported
            values carry a trailing ``*``. Missing scores are shown as ``"-"``.
            When no model qualifies, an empty frame with the same columns is
            returned.

        A model listed in ``SELF_REPORTED`` is always rendered from its
        self-reported score, even when evaluated data exists for it.
        """
        original_dimension = get_original_dimension(selected_super_group)
        keyword_columns = self.SUPER_GROUPS[selected_super_group]
        data = []

        for model in self.MODEL_GROUPS[selected_model_group]:
            is_self_reported = model in self.SELF_REPORTED
            is_evaluated = model in self.KEYWORD_DATA and model in self.SUMMARY_DATA
            if not (is_self_reported or is_evaluated):
                continue

            # Basic model information
            row = {"Models": get_display_model_name(model, as_link=True)}

            if is_self_reported:
                # Store numeric value for sorting but display with asterisk
                reported = self.SELF_REPORTED[model]
                row["Overall"] = reported
                row["Overall_display"] = f"{reported:.2f}*"
                row["Core"] = None
                row["Open-ended"] = None
                for display_name in keyword_columns:
                    row[display_name] = None
            else:
                model_data = self.KEYWORD_DATA[model]
                summary = self.SUMMARY_DATA[model]

                overall_score = round(summary["overall_score"] * 100, 2)
                row["Overall"] = overall_score
                row["Overall_display"] = f"{overall_score:.2f}"
                row["Core"] = round(summary["core"]["macro_mean_score"] * 100, 2)
                row["Open-ended"] = round(summary["open"]["macro_mean_score"] * 100, 2)

                # Dimension-specific scores; absent dimension/keyword -> None
                dimension_scores = model_data.get(original_dimension, {})
                for display_name in keyword_columns:
                    original_keyword = self.keyword_display_map[display_name]
                    if original_keyword in dimension_scores:
                        row[display_name] = round(
                            dimension_scores[original_keyword]["average_score"] * 100, 2
                        )
                    else:
                        row[display_name] = None

            data.append(row)

        # Explicit columns keep the frame well-formed (and sortable) even when
        # no row was produced.
        df = pd.DataFrame(data, columns=list(self._FIXED_COLUMNS) + keyword_columns)
        # Sort by numeric Overall column
        df = df.sort_values(by="Overall", ascending=False)

        # Replace missing values with "-" for display
        display_cols = ["Core", "Open-ended"] + keyword_columns
        df[display_cols] = df[display_cols].fillna("-")

        # Replace Overall with Overall_display
        df["Overall"] = df["Overall_display"]
        df = df.drop("Overall_display", axis=1)

        return df

    def _get_task_totals(self) -> Tuple[int, int]:
        """Return ``(core_task_count, open_task_count)`` from a model summary.

        ``REFERENCE_MODEL`` is used when it was loaded; otherwise the first
        model in ``SUMMARY_DATA`` is used.

        Raises:
            KeyError: if no model summary was loaded at all.
        """
        if self.REFERENCE_MODEL in self.SUMMARY_DATA:
            sample_model = self.REFERENCE_MODEL
        else:
            sample_model = next(iter(self.SUMMARY_DATA), None)
            if sample_model is None:
                raise KeyError("No model summary data loaded; cannot determine task counts")
        summary = self.SUMMARY_DATA[sample_model]
        return summary["core"]["num_eval_tasks"], summary["open"]["num_eval_tasks"]

    @staticmethod
    def _format_dimension_header(display_name: str) -> str:
        """Turn ``"Name(count)"`` into ``"Name\\n(count)"``.

        Splits on the last ``"("`` so keyword names that contain parentheses
        themselves keep their full name. A name without ``"("`` is returned
        unchanged.
        """
        base_name, sep, count_part = display_name.rpartition("(")
        if not sep:
            return display_name
        task_count = count_part.rstrip(")")
        return f"{base_name}\n({task_count})"

    def get_leaderboard_data(self, selected_super_group: str, selected_model_group: str) -> Tuple[list, list]:
        """Return ``(headers, rows)`` for rendering the leaderboard.

        Args:
            selected_super_group: Dimension display name (a key of ``SUPER_GROUPS``).
            selected_model_group: Model group name (a key of ``MODEL_GROUPS``).

        Returns:
            ``headers`` is the list of column titles, with task counts placed on
            a second line. ``rows`` is the table from :meth:`get_df` as a list
            of lists, prefixed with a 1-based ``Rank`` column.
        """
        df = self.get_df(selected_super_group, selected_model_group)
        keyword_columns = self.SUPER_GROUPS[selected_super_group]

        total_core_tasks, total_open_tasks = self._get_task_totals()
        total_tasks = total_core_tasks + total_open_tasks

        # Define headers with task counts on a new line
        column_headers = {
            "Rank": "Rank",
            "Models": "Models",
            "Overall": f"Overall\n({total_tasks})",
            "Core": f"Core\n({total_core_tasks})",
            "Open-ended": f"Open-ended\n({total_open_tasks})",
        }

        # Add rank column to DataFrame (rows are already sorted by get_df)
        df = df.reset_index(drop=True)
        df.insert(0, "Rank", range(1, len(df) + 1))

        # Rename the columns in DataFrame to match headers
        df = df.rename(columns=column_headers)

        fixed_headers = [
            column_headers[name] for name in ("Rank", "Models", "Overall", "Core", "Open-ended")
        ]
        # For dimension columns, add task counts on new line
        dimension_headers = [self._format_dimension_header(name) for name in keyword_columns]
        headers = fixed_headers + dimension_headers

        # Keyword columns keep their original names in the frame; only the
        # fixed columns were renamed above.
        data = df[fixed_headers + keyword_columns].values.tolist()

        return headers, data


# Helper functions for translating between raw and display names
def get_original_dimension(mapped_dimension):
    """Return the first raw dimension key whose display name equals ``mapped_dimension``.

    Performs a reverse lookup over ``DIMENSION_NAME_MAP`` in its iteration
    order. ``StopIteration`` propagates when no entry matches.
    """
    matching_keys = (
        raw_name
        for raw_name, display_name in DIMENSION_NAME_MAP.items()
        if display_name == mapped_dimension
    )
    return next(matching_keys)


def get_original_keyword(mapped_keyword):
    """Return the first raw keyword whose display name equals ``mapped_keyword``.

    Performs a reverse lookup over ``KEYWORD_NAME_MAP`` in its iteration order;
    when nothing matches, ``mapped_keyword`` itself is returned unchanged.
    """
    for raw_keyword, display_keyword in KEYWORD_NAME_MAP.items():
        if display_keyword == mapped_keyword:
            return raw_keyword
    return mapped_keyword


def get_display_model_name(model_name: str, as_link: bool = True) -> str:
    """Return the display name for ``model_name``, optionally wrapped in an HTML link.

    The display name comes from ``MODEL_NAME_MAP`` and falls back to
    ``model_name`` itself. An ``<a>`` tag is produced only when ``as_link`` is
    true and ``MODEL_URLS`` has an entry for the model; otherwise the plain
    display name is returned.
    """
    display_name = MODEL_NAME_MAP.get(model_name, model_name)
    # Guard clause: plain text when linking is disabled or no URL is known.
    if not as_link or model_name not in MODEL_URLS:
        return display_name
    return f'<a href="{MODEL_URLS[model_name]}" target="_blank" style="text-decoration: none; color: #2196F3;">{display_name}</a>'