import json
from datetime import datetime

from typing import Literal, List

import pandas as pd
import plotly.express as px
from huggingface_hub import HfFileSystem, hf_hub_download

# from: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/monitor/monitor.py#L389
# Maps short category keys to human-readable category (display) names.
# Every value here is also a key of CAT_NAME_TO_EXPLANATION.
# NOTE(review): the keys presumably match the category keys in the downloaded
# elo_results data — verify against the data file. "overall_limit_5_user_vote"
# has no friendlier display name (key and value are identical).
KEY_TO_CATEGORY_NAME = {
    "full": "Overall",
    "dedup": "De-duplicate Top Redundant Queries (soon to be default)",
    "math": "Math",
    "if": "Instruction Following",
    "multiturn": "Multi-Turn",
    "coding": "Coding",
    "hard_6": "Hard Prompts (Overall)",
    "hard_english_6": "Hard Prompts (English)",
    "long_user": "Longer Query",
    "english": "English",
    "chinese": "Chinese",
    "french": "French",
    "german": "German",
    "spanish": "Spanish",
    "russian": "Russian",
    "japanese": "Japanese",
    "korean": "Korean",
    "no_tie": "Exclude Ties",
    "no_short": "Exclude Short Query (< 5 tokens)",
    "no_refusal": "Exclude Refusal",
    "overall_limit_5_user_vote": "overall_limit_5_user_vote",
    "full_old": "Overall (Deprecated)",
}

# Maps each category display name (the values of KEY_TO_CATEGORY_NAME) to a
# longer explanation. Several explanations embed Markdown links to the LMSYS
# "category-hard" blog post, so they are presumably rendered as Markdown.
CAT_NAME_TO_EXPLANATION = {
    "Overall": "Overall Questions",
    "De-duplicate Top Redundant Queries (soon to be default)": "De-duplicate top redundant queries (top 0.1%). See details in [blog post](https://lmsys.org/blog/2024-05-17-category-hard/#note-enhancing-quality-through-de-duplication).",
    "Math": "Math",
    "Instruction Following": "Instruction Following",
    "Multi-Turn": "Multi-Turn Conversation (>= 2 turns)",
    "Coding": "Coding: whether conversation contains code snippets",
    "Hard Prompts (Overall)": "Hard Prompts (Overall): details in [blog post](https://lmsys.org/blog/2024-05-17-category-hard/)",
    "Hard Prompts (English)": "Hard Prompts (English), note: the delta is to English Category. details in [blog post](https://lmsys.org/blog/2024-05-17-category-hard/)",
    "Longer Query": "Longer Query (>= 500 tokens)",
    "English": "English Prompts",
    "Chinese": "Chinese Prompts",
    "French": "French Prompts",
    "German": "German Prompts",
    "Spanish": "Spanish Prompts",
    "Russian": "Russian Prompts",
    "Japanese": "Japanese Prompts",
    "Korean": "Korean Prompts",
    "Exclude Ties": "Exclude Ties and Bothbad",
    "Exclude Short Query (< 5 tokens)": "Exclude Short User Query (< 5 tokens)",
    "Exclude Refusal": 'Exclude model responses with refusal (e.g., "I cannot answer")',
    "overall_limit_5_user_vote": "overall_limit_5_user_vote",
    "Overall (Deprecated)": "Overall without De-duplicating Top Redundant Queries (top 0.1%). See details in [blog post](https://lmsys.org/blog/2024-05-17-category-hard/#note-enhancing-quality-through-de-duplication).",
}

# Raw "License" values that format_data buckets as "Proprietary LLM"; any other
# value becomes "Open LLM". NOTE(review): "Proprietory" looks like a deliberate
# match for a misspelling in the upstream data — confirm before removing it.
PROPRIETARY_LICENSES = ["Proprietary", "Proprietory"]


def download_latest_data_from_space(
    repo_id: str, file_type: Literal["pkl", "csv"]
) -> str:
    """
    Downloads the latest data file of the specified file type from the given repository space.

    Only top-level files of the Space whose file name contains "leaderboard_table"
    or "elo_results" are considered. "Latest" is decided by the last
    underscore-separated part of the file's base name (for a name like
    "<prefix>_<date>.<ext>" that is "<date>"), compared as a plain string. The
    ordering is therefore only chronological if that part uses a fixed-width,
    year-first format such as YYYYMMDD — presumably the case for this Space; verify.

    Args:
        repo_id (str): The ID of the repository space.
        file_type (Literal["pkl", "csv"]): The type of the data file to download. Must be either "pkl" or "csv".

    Returns:
        str: The local file path of the downloaded data file.

    Raises:
        FileNotFoundError: If the Space contains no matching data file of the
            requested type.
    """

    def base_name(path: str) -> str:
        # HfFileSystem paths look like "spaces/<repo_id>/<file>".
        return path.rsplit("/", 1)[-1]

    def extract_date(path: str) -> str:
        # "<prefix>_<date>.<ext>" -> "<date>"
        return base_name(path).split(".")[0].split("_")[-1]

    fs = HfFileSystem()
    # Filter on the file name only, so a repo_id that happens to contain one of
    # the markers does not make every file match.
    candidates = [
        path
        for path in fs.glob(f"spaces/{repo_id}/*.{file_type}")
        if "leaderboard_table" in base_name(path) or "elo_results" in base_name(path)
    ]
    if not candidates:
        raise FileNotFoundError(
            f"No leaderboard_table/elo_results *.{file_type} file found in "
            f"Space {repo_id!r}"
        )

    # max() is O(n) and, among files sharing the latest date, returns the first
    # one in glob order.
    latest_name = base_name(max(candidates, key=extract_date))

    latest_filepath_local = hf_hub_download(
        repo_id=repo_id,
        filename=latest_name,
        repo_type="space",
    )
    print(latest_name)
    return latest_filepath_local


def get_constants(dfs):
    """
    Calculate and return the minimum and maximum Elo scores, as well as the maximum number of models per month.

    Parameters:
    - dfs (dict): A dictionary containing DataFrames for different categories.

    Returns:
    - min_elo_score (float): The minimum Elo score across all DataFrames.
    - max_elo_score (float): The maximum Elo score across all DataFrames.
    - upper_models_per_month (int): The maximum number of models per month per license across all DataFrames.
    """
    filter_ranges = {}
    for k, df in dfs.items():
        filter_ranges[k] = {
            "min_elo_score": df["rating"].min().round(),
            "max_elo_score": df["rating"].max().round(),
            "upper_models_per_month": int(
                df.groupby(["Month-Year", "License"])["rating"]
                .apply(lambda x: x.count())
                .max()
            ),
        }

    min_elo_score = float("inf")
    max_elo_score = float("-inf")
    upper_models_per_month = 0

    for _, value in filter_ranges.items():
        min_elo_score = min(min_elo_score, value["min_elo_score"])
        max_elo_score = max(max_elo_score, value["max_elo_score"])
        upper_models_per_month = max(
            upper_models_per_month, value["upper_models_per_month"]
        )
    return min_elo_score, max_elo_score, upper_models_per_month


def update_release_date_mapping(
    new_model_keys_to_add: List[str],
    leaderboard_df: pd.DataFrame,
    release_date_mapping: pd.DataFrame,
    mapping_path: str = "release_date_mapping.json",
) -> pd.DataFrame:
    """
    Update the release date mapping with new model keys.

    Each new key is recorded with its "Model" value from ``leaderboard_df`` and
    today's local date (YYYY-MM-DD) as its "Release Date". The JSON file is
    read once, extended with all new entries, and written back once.

    Args:
        new_model_keys_to_add (List[str]): A list of new model keys to add to the release date mapping.
        leaderboard_df (pd.DataFrame): The leaderboard DataFrame containing the model information
            (needs "key" and "Model" columns).
        release_date_mapping (pd.DataFrame): The current release date mapping DataFrame.
        mapping_path (str): Path of the JSON file (a list of records) that stores
            the mapping. Defaults to "release_date_mapping.json" in the current
            working directory.

    Returns:
        pd.DataFrame: The mapping re-read from ``mapping_path`` if any keys were
        added, otherwise ``release_date_mapping`` unchanged.

    Raises:
        IndexError: If a key has no matching row in ``leaderboard_df``. The file
            is left untouched in that case.
    """
    if not new_model_keys_to_add:
        return release_date_mapping

    today = datetime.today().strftime("%Y-%m-%d")
    # Build every entry before touching the file so that a missing key cannot
    # leave a half-updated mapping on disk.
    new_entries = [
        {
            "key": key,
            "Model": leaderboard_df.loc[leaderboard_df["key"] == key, "Model"].values[0],
            "Release Date": today,
        }
        for key in new_model_keys_to_add
    ]

    with open(mapping_path, "r", encoding="utf-8") as file:
        data = json.load(file)
    data.extend(new_entries)
    with open(mapping_path, "w", encoding="utf-8") as file:
        json.dump(data, file, indent=4)

    for entry in new_entries:
        print(f"Added {entry['key']} to {mapping_path}")

    # Reload so the returned DataFrame reflects exactly what is on disk.
    return pd.read_json(mapping_path, orient="records")


def format_data(df):
    """
    Normalise a leaderboard DataFrame for plotting and return a re-indexed copy.

    The following columns of ``df`` are rewritten in place (the caller's frame
    is mutated) before ``df.reset_index(drop=True)`` is returned:

    - "License": a value found in PROPRIETARY_LICENSES becomes "Proprietary LLM";
      anything else, including missing values, becomes "Open LLM".
    - "Release Date": parsed with ``pd.to_datetime``.
    - "Month-Year": new column holding the monthly period of "Release Date".
    - "rating": rounded to the nearest integer (dtype unchanged).

    NOTE(review): "Proprietary LLM" itself is not in PROPRIETARY_LICENSES, so
    running this twice on the same frame relabels every model as "Open LLM".

    Args:
        df (pandas.DataFrame): Frame with "License", "Release Date" and "rating" columns.

    Returns:
        pandas.DataFrame: The formatted frame with a fresh 0..n-1 index.
    """

    def to_license_group(raw_license):
        if raw_license in PROPRIETARY_LICENSES:
            return "Proprietary LLM"
        return "Open LLM"

    df["License"] = df["License"].apply(to_license_group)
    df["Release Date"] = pd.to_datetime(df["Release Date"])
    df["Month-Year"] = df["Release Date"].dt.to_period("M")
    df["rating"] = df["rating"].round()
    formatted = df.reset_index(drop=True)
    return formatted


def get_trendlines(fig):
    """
    Return the fitted parameters of every trendline in a Plotly Express figure.

    Args:
        fig: A figure created by a Plotly Express function called with a
            ``trendline`` argument (e.g. ``px.scatter(..., trendline="ols")``).

    Returns:
        list[list]: One list of fitted parameters per row of
        ``px.get_trendline_results(fig)``, in that row order. For OLS fits this
        is presumably ``[intercept, slope]``, which is the order
        ``find_crossover_point`` takes (b, m). Confirm this for the trendline
        type in use.
    """
    trend_results = px.get_trendline_results(fig)
    # Each "px_fit_results" cell holds a fit-results object (presumably from
    # statsmodels); ``.params`` is its coefficient array.
    return [fit.params.tolist() for fit in trend_results["px_fit_results"]]


def find_crossover_point(b1, m1, b2, m2, *, tol=0.0):
    """
    Determine the X value at which two trendlines will cross over.

    The lines are ``y = b1 + m1 * x`` and ``y = b2 + m2 * x``. Setting them
    equal gives ``x = (b2 - b1) / (m1 - m2)``.

    Parameters:
    b1 (float): Intercept of the first trendline.
    m1 (float): Slope of the first trendline.
    b2 (float): Intercept of the second trendline.
    m2 (float): Slope of the second trendline.
    tol (float, keyword-only): Slopes whose absolute difference is <= ``tol``
        are treated as parallel. Fitted float slopes are rarely exactly equal,
        and near-parallel lines give huge, meaningless crossover values.
        Defaults to 0.0, which treats only exactly equal slopes as parallel.

    Returns:
    float: The X value where the two trendlines cross.

    Raises:
    ValueError: If the trendlines are parallel (within ``tol``).
    """
    # Keep the explicit equality test so equal infinite slopes are still caught
    # (abs(inf - inf) is NaN, which never compares <= tol).
    if m1 == m2 or abs(m1 - m2) <= tol:
        raise ValueError("The trendlines are parallel and do not cross.")

    return (b2 - b1) / (m1 - m2)