import os
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import streamlit as st
from datasets import get_dataset_config_names
from dotenv import load_dotenv
from huggingface_hub import DatasetFilter, list_datasets

if Path(".env").is_file():
    load_dotenv(".env")

auth_token = os.getenv("HF_HUB_TOKEN")

TASKS = sorted(get_dataset_config_names("ought/raft"))
# Split and capitalize the task names, e.g. banking_77 => Banking 77
FORMATTED_TASK_NAMES = sorted([" ".join(t.capitalize() for t in task.split("_")) for task in TASKS])


def download_submissions():
    filt = DatasetFilter(benchmark="raft")
    all_submissions = list_datasets(filter=filt, full=True, use_auth_token=auth_token)
    submissions = []

    for dataset in all_submissions:
        tags = dataset.cardData
        if tags.get("type") == "evaluation":
            submissions.append(dataset)
    return submissions


def format_submissions(submissions):
    submission_data = {
        **{"Submitter": []},
        **{"Submission Name": []},
        **{"Submission Date": []},
        **{t: [] for t in TASKS},
    }

    # The following picks the latest submissions which adhere to the model card schema
    for submission in submissions:
        submission_id = submission.id
        card_data = submission.cardData
        username = card_data["submission_dataset"].split("/")[0]
        submission_data["Submitter"].append(username)
        submission_id = card_data["submission_id"]
        submission_name, sha, timestamp = submission_id.split("__")
        # Format submission names with new backend constraints
        # TODO(lewtun): make this less hacky!
        if "_XXX_" in submission_name:
            submission_name = submission_name.replace("_XXX_", " ")
        if "_DDD_" in submission_name:
            submission_name = submission_name.replace("_DDD_", "--")
        submission_data["Submission Name"].append(submission_name)
        # Handle mismatch in epoch microseconds vs epoch seconds in new AutoTrain API
        if len(timestamp) > 10:
            timestamp = pd.to_datetime(int(timestamp))
        else:
            timestamp = pd.to_datetime(int(timestamp), unit="s")
        submission_data["Submission Date"].append(datetime.date(timestamp).strftime("%b %d, %Y"))

        for task in card_data["results"]:
            task_data = task["task"]
            task_name = task_data["name"]
            score = task_data["metrics"][0]["value"]
            submission_data[task_name].append(score)

    df = pd.DataFrame(submission_data)
    df.insert(3, "Overall", df[TASKS].mean(axis=1))
    df = df.copy().sort_values("Overall", ascending=False)
    df.rename(columns={k: v for k, v in zip(TASKS, FORMATTED_TASK_NAMES)}, inplace=True)
    # Start ranking from 1
    df.insert(0, "Rank", np.arange(1, len(df) + 1))
    return df


###########
### APP ###
###########
st.set_page_config(layout="wide")
st.title("RAFT: Real-world Annotated Few-shot Tasks")
st.markdown(
    """
⚠️ **The RAFT benchmark is currently undergoing maintenance and is not accepting submissions at the moment. We apologise for the inconvenience.**

Large pre-trained language models have shown promise for few-shot learning, completing text-based tasks given only a few task-specific examples. Will models soon solve classification tasks that have so far been reserved for human research assistants? 

[RAFT](https://raft.elicit.org) is a few-shot classification benchmark that tests language models:

- across multiple domains (lit review, tweets, customer interaction, etc.)
- on economically valuable classification tasks (someone inherently cares about the task)
- in a setting that mirrors deployment (50 examples per task, info retrieval allowed, hidden test set)

To submit to RAFT, follow the instruction posted on [this page](https://huggingface.co/datasets/ought/raft-submission).
"""
)
submissions = download_submissions()
print(f"INFO - downloaded {len(submissions)} submissions")
df = format_submissions(submissions)
styler = pd.io.formats.style.Styler(df, precision=3).set_properties(
    **{"white-space": "pre-wrap", "text-align": "center"}
)
# hack to remove index column: https://discuss.streamlit.io/t/questions-on-st-table/6878/3
st.markdown(
    """
<style>
table td:nth-child(1) {
    display: none
}
table th:nth-child(1) {
    display: none
}
</style>
""",
    unsafe_allow_html=True,
)
st.table(styler)