import os
import json
import datetime
from email.utils import parseaddr
import numpy as np
import gradio as gr
import pandas as pd
from datasets import load_dataset
from evaluation.evaluator import question_scorer as eval_scorer
from apscheduler.schedulers.background import BackgroundScheduler
from huggingface_hub import HfApi
from content import format_error, format_warning, format_log, TITLE

# Thin adapter around the evaluation package's scorer.
def question_scorer(prediction, gold_answer):
    """Score one prediction against its gold answer.

    Delegates to ``evaluation.evaluator.question_scorer`` and returns its
    result unpacked into a 2-tuple ``(accuracy, has_answer)``.
    """
    accuracy, has_answer = eval_scorer(prediction, gold_answer)
    return (accuracy, has_answer)


# Constants and Configuration
TOKEN = os.environ.get("TOKEN", None)
OWNER = "Ori"
DATA_DATASET = "Ori/AssistantBench_V1.0"
RESULTS_DATASET = "Ori/results"
SUBMISSION_DATASET = "AssistantBench/submissions"
LEADERBOARD_PATH = f"{OWNER}/leaderboard"
api = HfApi()

# Config name of RESULTS_DATASET that results are read from and pushed to;
# also used as a prefix for uploaded submission file names.
YEAR_VERSION = "default"

# Local scratch directory for the per-submission scored JSONL files.
os.makedirs("scored", exist_ok=True)

# Load the gold data. The results dataset is loaded further down with its
# YEAR_VERSION config; an additional forced download here was never used
# (it was overwritten before any read) and only slowed startup.
gold_results = load_dataset(DATA_DATASET, token=TOKEN, trust_remote_code=True)

# Per split (only "test"): task id -> gold answer / difficulty label.
gold_answers = {split: {row["id"]: row["answer"] for row in gold_results[split]} for split in ["test"]}
gold_difficulties = {split: {row["id"]: row["difficulty"] for row in gold_results[split]} for split in ["test"]}


# Function to get dataframe from results
def get_dataframe_from_results(eval_results, split):
    """Build a pandas DataFrame from one split of the results dataset.

    Rows are sorted by "Accuracy", highest first. Every column whose name
    contains "score" is multiplied by 100 and rounded to two decimals.

    Args:
        eval_results: Mapping of split name -> dataset object exposing
            ``column_names`` (e.g. a ``datasets.DatasetDict``).
        split: Name of the split to convert.

    Returns:
        The sorted DataFrame.
    """
    # NOTE(review): the metrics written by add_new_eval ("Accuracy", "EM", ...)
    # are already percentages and none contain "score", so this scaling
    # appears to be a no-op for current data.
    split_rows = eval_results[split]
    scaled_columns = [name for name in split_rows.column_names if "score" in name]
    frame = pd.DataFrame(split_rows).sort_values(by=["Accuracy"], ascending=False)
    frame[scaled_columns] = frame[scaled_columns].multiply(100).round(decimals=2)
    return frame

# Update function to format dataframe
def format_dataframe(df):
    """Turn a results DataFrame into the leaderboard display table.

    Works on a copy, so the caller's DataFrame is not modified.

    - "Accuracy" is rendered as bold markdown with two decimals.
    - If a "URL" column exists, each model name becomes a markdown link to
      its URL (only for non-empty string URLs), and the URL column is dropped.

    Returns:
        A DataFrame with only the ten displayed columns, in display order.
    """
    df = df.copy()
    df["Accuracy"] = df["Accuracy"].apply(lambda x: f"**{x:.2f}**")
    if "URL" in df.columns:
        # Submissions without a URL would otherwise render as "[name]()" or
        # "[name](nan)"; keep those names as plain text. A list comprehension
        # (rather than df.apply(axis=1)) also works on an empty frame.
        df["Model Name"] = [
            f"[{name}]({url.strip()})" if isinstance(url, str) and url.strip() else name
            for name, url in zip(df["Model Name"], df["URL"])
        ]
        df = df.drop(columns=["URL"])
    return df[["Model Name", "Accuracy", "Answer rate", "Precision", "EM",
               "Accuracy (easy)", "Accuracy (medium)", "Accuracy (hard)",
               "Base Model", "Organization"]]


# Startup snapshot of the leaderboard: force a fresh download of the
# YEAR_VERSION results config, then build the formatted "test" table that
# seeds the Gradio Dataframe.
eval_results = load_dataset(
    RESULTS_DATASET,
    YEAR_VERSION,
    token=TOKEN,
    download_mode="force_redownload",
    trust_remote_code=True,
)
eval_dataframe_test = format_dataframe(
    get_dataframe_from_results(eval_results=eval_results, split="test")
)

# Function to restart the space
def restart_space():
    """Ask the Hugging Face Hub to restart the leaderboard Space (LEADERBOARD_PATH)."""
    api.restart_space(token=TOKEN, repo_id=LEADERBOARD_PATH)


# Gradio datatypes for the ten leaderboard columns, in display order: the
# linked model name and bolded accuracy are markdown, the six metrics are
# numbers, and base model / organization are plain strings.
TYPES = ["markdown"] * 2 + ["number"] * 6 + ["str"] * 2


def _pct(fraction):
    """Return *fraction* (expected in 0..1) as a percentage rounded to one decimal."""
    return float("{:.1f}".format(fraction * 100))


def _mean_or_zero(values):
    """Return the mean of *values*, or 0.0 for an empty list.

    ``np.average([])`` yields nan (with a RuntimeWarning), which would be
    pushed to the results dataset as a NaN metric.
    """
    return float(np.average(values)) if values else 0.0


def _safe_filename(text):
    """Replace every char of *text* that is not alphanumeric or one of '-_.' with '_'.

    Keeps user-supplied names containing path separators from escaping the
    local ``scored/`` directory.
    """
    return "".join(ch if ch.isalnum() or ch in "-_." else "_" for ch in text)


# Function to add a new evaluation
def add_new_eval(
        model_name: str,
        model_family: str,
        url: str,
        path_to_file: str,
        organization: str,
        mail: str,
):
    """Validate, score and record a leaderboard submission.

    The function does the following, in order:

    1. Uploads the raw predictions file to SUBMISSION_DATASET.
    2. Scores every line against the gold "test" answers and uploads the
       per-question scored file.
    3. Appends the aggregate metrics to the "test" split of the global
       ``eval_results`` and pushes it to RESULTS_DATASET.

    Args:
        model_name: Display name of the submitted model.
        model_family: Stored in the "Base Model" column.
        url: Stored in the "URL" column.
        path_to_file: The uploaded JSONL predictions file as delivered by
            ``gr.File`` -- an object with a ``.name`` path, or a path string.
        organization: Submitting organization.
        mail: Contact address; validated with ``parseaddr`` but not otherwise
            used here.

    Returns:
        A formatted warning/error/log message for the Markdown output.
    """
    _, parsed_mail = parseaddr(mail)
    if "@" not in parsed_mail:
        return format_warning("Please provide a valid email address.")

    if not model_name.strip() or not organization.strip():
        return format_warning("Please provide a model name and an organization.")

    print("Adding new eval")

    # Only an exact (model, organization) pair is a duplicate. Checking the
    # two names independently rejected e.g. model A from org Y whenever some
    # org had a model A and org Y had any other submission.
    existing = {
        ((m or "").lower(), (o or "").lower())
        for m, o in zip(eval_results["test"]["Model Name"], eval_results["test"]["Organization"])
    }
    if (model_name.lower(), organization.lower()) in existing:
        return format_warning("This model has already been submitted.")

    if path_to_file is None:
        return format_warning("Please attach a file.")

    # Depending on the Gradio version, gr.File hands over a tempfile-like
    # object (with .name) or a plain path string; accept both.
    file_path = getattr(path_to_file, "name", path_to_file)

    # One timestamp so the raw and scored uploads can be paired.
    timestamp = datetime.datetime.today()

    api.upload_file(
        repo_id=SUBMISSION_DATASET,
        path_or_fileobj=file_path,
        path_in_repo=f"{organization}/{model_name}/{YEAR_VERSION}_test_raw_{timestamp}.jsonl",
        repo_type="dataset",
        token=TOKEN
    )

    scored_path = os.path.join(
        "scored", f"{_safe_filename(organization)}_{_safe_filename(model_name)}.jsonl")

    difficulty_scores = {"Easy": 0, "Medium": 0, "Hard": 0}
    difficulty_counts = {"Easy": 0, "Medium": 0, "Hard": 0}
    all_scores = []
    submitted_ids = set()

    with open(scored_path, "w", encoding="utf-8") as scored_file, \
            open(file_path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                # Tolerate blank lines (e.g. extra newlines at end of file).
                continue
            try:
                task = json.loads(line)
            except ValueError:  # json.JSONDecodeError is a ValueError
                return format_error(f"Line {line_no} is incorrectly formatted. Please fix it and resubmit your file.")

            if not isinstance(task, dict) or "answer" not in task:
                return format_error(
                    f"Line {line_no} contains no answer key. Please fix it and resubmit your file.")
            if "id" not in task:
                return format_error(
                    f"Line {line_no} contains no id key. Please fix it and resubmit your file.")

            answer = task["answer"]
            task_id = task["id"]
            if task_id not in gold_answers["test"]:
                return format_error(
                    f"{task_id} not found in test set. Are you sure you submitted the correct file?")
            # Each question may be answered once; repeats would be counted
            # twice in every metric while still passing the missing-id check.
            if task_id in submitted_ids:
                return format_error(
                    f"{task_id} appears more than once in your submission. Please fix it and resubmit your file.")

            score, has_ans = question_scorer(answer, gold_answers["test"][task_id])
            difficulty = gold_difficulties["test"][task_id]

            scored_file.write(
                json.dumps({
                    "id": task_id,
                    "model_answer": answer,
                    "score": score,
                    "has_ans": has_ans
                }) + "\n"
            )

            all_scores.append({"score": score, "has_ans": has_ans, "model_answer": answer, 'id': task_id})
            submitted_ids.add(task_id)
            difficulty_scores[difficulty] = difficulty_scores.get(difficulty, 0) + score
            difficulty_counts[difficulty] = difficulty_counts.get(difficulty, 0) + 1

    # Check if all gold answer IDs are present in the submission
    missing_ids = set(gold_answers["test"].keys()) - submitted_ids
    if missing_ids:
        missing = ", ".join(sorted(str(task_id) for task_id in missing_ids))
        return format_error(f"Submission is missing the following IDs: {missing}")

    api.upload_file(
        repo_id=SUBMISSION_DATASET,
        path_or_fileobj=scored_path,
        path_in_repo=f"{organization}/{model_name}/{YEAR_VERSION}_test_scored_{timestamp}.jsonl",
        repo_type="dataset",
        token=TOKEN
    )

    # Per-difficulty accuracy; 0 for a difficulty with no questions.
    accuracy_easy, accuracy_medium, accuracy_hard = (
        _pct(difficulty_scores[level] / difficulty_counts[level] if difficulty_counts[level] > 0 else 0)
        for level in ("Easy", "Medium", "Hard")
    )
    accuracy = _pct(_mean_or_zero([x["score"] for x in all_scores]))
    coverage = _pct(_mean_or_zero([x["has_ans"] for x in all_scores]))
    em = _pct(_mean_or_zero([1 if x["score"] == 1 else 0 for x in all_scores]))
    # Precision is averaged over answered questions only (0 if none answered).
    precision = _pct(_mean_or_zero([x["score"] for x in all_scores if x["has_ans"] == 1]))

    eval_entry = {
        "Model Name": model_name,
        "Base Model": model_family,
        "URL": url,
        "Organization": organization,
        "Accuracy": accuracy,
        "Accuracy (easy)": accuracy_easy,
        "Accuracy (medium)": accuracy_medium,
        "Accuracy (hard)": accuracy_hard,
        "Answer rate": coverage,
        "Precision": precision,
        "EM": em
    }
    eval_results["test"] = eval_results["test"].add_item(eval_entry)

    eval_results.push_to_hub(RESULTS_DATASET, config_name=YEAR_VERSION, token=TOKEN)

    return format_log(
        f"Model {model_name} submitted by {organization} successfully.\nPlease wait a few hours and refresh the leaderboard to see your score displayed.")


# Function to refresh the results
def refresh():
    """Re-download the results dataset and return the formatted "test" table.

    The freshly loaded dataset is kept local; the module-level
    ``eval_results`` used by submissions is not replaced.
    """
    latest_results = load_dataset(
        RESULTS_DATASET,
        YEAR_VERSION,
        token=TOKEN,
        download_mode="force_redownload",
        trust_remote_code=True,
    )
    table = get_dataframe_from_results(eval_results=latest_results, split="test")
    return format_dataframe(table)


# Gradio interface
# Layout, top to bottom: title and description, the "Results: Test" table
# with a Refresh button, the submission form, a citation box, and an
# acknowledgement line.
demo = gr.Blocks()
with demo:
    gr.HTML("<h1>AssistantBench</h1>")
    gr.Markdown("""
        AssistantBench aims to evaluate the ability of web agents to assist with real and time-consuming tasks.
        For more information, please check out our paper or the official website.
        To download AssistantBench, press [here](https://huggingface.co/datasets/AssistantBench/AssistantBench).
    """)

    gr.HTML("<h2>AssistantBench Leaderboard</h2>")
    with gr.Tab("Results: Test"):
        # Read-only table seeded with the snapshot built at import time;
        # TYPES supplies one datatype per displayed column.
        leaderboard_table_test = gr.Dataframe(
            value=eval_dataframe_test, datatype=TYPES, interactive=False,
            column_widths=["20%"]
        )

    # Re-download the results and replace the table contents.
    refresh_button = gr.Button("Refresh")
    refresh_button.click(
        refresh,
        inputs=[],
        outputs=[
            leaderboard_table_test,
        ],
    )

    gr.HTML("<h2>Making a New Submission</h2>")
    with gr.Accordion("Submit a new model for evaluation"):
        with gr.Row():
            # NOTE(review): this links to scorer.py, while this app scores with
            # evaluation.evaluator.question_scorer -- confirm they are the same.
            gr.Markdown("""
                To make a new submission, upload a predictions file. Our scoring function can be found [here](https://huggingface.co/spaces/AssistantBench/leaderboard/blob/main/scorer.py). We support JSONL files with the following format:
                ```
                {"id": "task_id_1", "answer": "Answer 1 from your model"}
                {"id": "task_id_2", "answer": "Answer 2 from your model"}
                ```
            """)
        with gr.Row():
            with gr.Column():
                model_name_textbox = gr.Textbox(label="Model Name")
                model_family_textbox = gr.Textbox(label="Base Model")
                url_textbox = gr.Textbox(label="URL to Model Information")
            with gr.Column():
                organization = gr.Textbox(label="Organization")
                # NOTE(review): add_new_eval only validates this address; it
                # does not store it anywhere, despite what the label says.
                mail = gr.Textbox(
                    label="Contact Email (will be stored privately & used if there is an issue with your submission)")
                file_output = gr.File()

        submit_button = gr.Button("Submit Eval")
        submission_result = gr.Markdown()
        # Inputs are passed to add_new_eval positionally, so this list must
        # match its parameter order: model_name, model_family, url,
        # path_to_file, organization, mail.
        submit_button.click(
            add_new_eval,
            [
                model_name_textbox,
                model_family_textbox,
                url_textbox,
                file_output,
                organization,
                mail
            ],
            submission_result,
        )

    with gr.Row():
        with gr.Accordion("📙 Citation", open=False):
            # NOTE(review): eprint={?} is a placeholder -- fill in the arXiv id.
            citation_text = """@article{yoran-etal-2024-assistantbench,
    title={AssistantBench: Can Web Agents Solve Realistic and Time-Consuming Tasks?}, 
    author={Ori Yoran and Samuel Amouyal and Chaitanya Malaviya and Ben Bogin and Ofir Press and Jonathan Berant},
    year={2024},
    eprint={?},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}"""
            citation_button = gr.Textbox(
                value=citation_text,
                label="Citation",
                lines=20,
                elem_id="citation-button",
                show_copy_button=True
            )

    gr.HTML(
        "<p>We would like to thank the GAIA team for sharing the source code for their leaderboard which we used as a template and HuggingFace for hosting the leaderboard.</p>")

# Restart the Space every hour (3600 s). A restart re-runs the module-level
# dataset loads above, so the leaderboard snapshot is rebuilt.
scheduler = BackgroundScheduler()
scheduler.add_job(restart_space, "interval", seconds=3600)
scheduler.start()
# Start the Gradio app.
demo.launch(debug=True)