import os

import gradio as gr
import pandas as pd

from apscheduler.schedulers.background import BackgroundScheduler
from collections import Counter, defaultdict
from datasets import load_dataset
from huggingface_hub import HfApi, list_datasets

api = HfApi(token=os.environ.get("HF_TOKEN", None))
def restart_space():
    api.restart_space(repo_id="OpenGenAI/parti-prompts-leaderboard")
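# Note: restarting a Space through the API requires the token behind HF_TOKEN
# to have write access to that Space; otherwise `restart_space` raises an
# authorization error.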

ORG = "diffusers-parti-prompts"
SUBMISSIONS = {
    "kand2": None,
    "sdxl": None,
    "wuerst": None,
    "karlo": None,
}
LINKS = {
    "kand2": "https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder",
    "sdxl": "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0",
    "wuerst": "https://huggingface.co/warp-ai/wuerstchen",
    "karlo": "https://huggingface.co/kakaobrain/karlo-v1-alpha",
}
MODEL_KEYS = "-".join(SUBMISSIONS.keys())
SUBMISSION_ORG = f"result-{MODEL_KEYS}"
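# The results org name encodes the compared models,
# e.g. "result-kand2-sdxl-wuerst-karlo"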

submission_names = list(SUBMISSIONS.keys())
ds = load_dataset("nateraw/parti-prompts")["train"]

parti_prompt_categories = ds["Category"]
parti_prompt_challenge = ds["Challenge"]
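# `Category` and `Challenge` are per-prompt labels from Parti Prompts; both
# lists are positionally indexed, which the submission datasets reference via
# their `id` column (inferred from the lookups in `load_submissions` below).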


def load_submissions():
    all_datasets = list_datasets(author=SUBMISSION_ORG)
    relevant_ids = [d.id for d in all_datasets]
    
    ids = defaultdict(list)
    challenges = defaultdict(list)
    categories = defaultdict(list)

    total_submissions = 0

    for _id in relevant_ids:
        try:
            ds = load_dataset(_id)["train"]
        except Exception:
            # Skip datasets that fail to load (e.g. empty or corrupted repos)
            continue

        all_results = []
        all_ids = []
        for result, image_id in zip(ds["result"], ds["id"]):
            # `result` holds comma-separated model keys voted for this image id
            votes = result.split(",")

            all_results += votes
            all_ids += len(votes) * [image_id]
            
        for result, image_id in zip(all_results, all_ids):
            if result == "":
                # An empty vote means no model solved this prompt; it is kept
                # and later surfaces as the "NOT SOLVED" column
                print(f"Prompt {image_id} was not solved by any model.")
            elif result not in submission_names:
                # Make sure that incorrect model names are not added
                continue

            ids[result].append(image_id)
            challenges[parti_prompt_challenge[image_id]].append(result)
            categories[parti_prompt_categories[image_id]].append(result)
            total_submissions += 1
    
    all_values = sum(len(v) for v in ids.values())
    main_dict = {k: len(v) / all_values for k, v in ids.items()}
    challenges = {k: Counter(v) for k, v in challenges.items()}
    categories = {k: Counter(v) for k, v in categories.items()}

    return total_submissions, main_dict, challenges, categories
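
# A minimal sketch of the submission-dataset schema assumed by
# `load_submissions` (inferred from the parsing above, not an official spec):
# each dataset in SUBMISSION_ORG has a "train" split with rows like
#     {"id": 42, "result": "sdxl,kand2"}
# where "result" lists the comma-separated model keys voted for prompt `id`,
# and an empty string means no model was preferred.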

def sort_by_highest_percentage(df):
    # Order the columns by the single row's values, descending, so the
    # best-performing model comes first
    df = df[df.loc[0].sort_values(ascending=False).index]

    return df
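# Example (hypothetical values): pd.DataFrame([{"karlo": 0.1, "sdxl": 0.4}])
# comes back with its columns reordered to ["sdxl", "karlo"].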

def get_dataframe_all():
    total_submissions, main, challenges, categories = load_submissions()
    main_frame = pd.DataFrame([main])

    challenges_frame = pd.DataFrame.from_dict(challenges).fillna(0).T
    challenges_frame = challenges_frame.div(challenges_frame.sum(axis=1), axis=0)

    categories_frame = pd.DataFrame.from_dict(categories).fillna(0).T
    categories_frame = categories_frame.div(categories_frame.sum(axis=1), axis=0)
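    # Both frames above are row-normalized to sum to 1, so each entry is the
    # share of votes a model received within that category or challenge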

    main_frame = main_frame.rename(columns={"": "NOT SOLVED"})
    categories_frame = categories_frame.rename(columns={"": "NOT SOLVED"})
    challenges_frame = challenges_frame.rename(columns={"": "NOT SOLVED"})

    main_frame = sort_by_highest_percentage(main_frame)

    main_frame = main_frame.applymap(lambda x: '{:.2%}'.format(x))
    challenges_frame = challenges_frame.applymap(lambda x: '{:.2%}'.format(x))
    categories_frame = categories_frame.applymap(lambda x: '{:.2%}'.format(x))
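    # (`DataFrame.applymap` was renamed to `DataFrame.map` in pandas 2.1; if a
    # newer pandas emits a FutureWarning here, switch to `.map`.)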

    categories_frame = categories_frame.reindex(columns=main_frame.columns.to_list())
    challenges_frame = challenges_frame.reindex(columns=main_frame.columns.to_list())

    categories_frame = categories_frame.reset_index().rename(columns={'index': 'Category'})
    challenges_frame = challenges_frame.reset_index().rename(columns={'index': 'Challenge'})

    return total_submissions, main_frame, challenges_frame, categories_frame

TITLE = "# Open Parti Prompts Leaderboard"
DESCRIPTION = """
The *Open Parti Prompts Leaderboard* compares state-of-the-art, open-source text-to-image models to each other according to **human preferences**. \n\n
Text-to-image models are notoriously difficult to evaluate. [FID](https://en.wikipedia.org/wiki/Fr%C3%A9chet_inception_distance) and 
[CLIP Score](https://arxiv.org/abs/2104.08718) are not enough to accurately state whether a text-to-image model can 
**generate "good" images**. "Good" is extremely difficult to put into numbers. \n\n
Instead, the **Open Parti Prompts Leaderboard** uses human feedback from the community to compare images from different text-to-image models to each other.

\n\n

❤️ ***Please take 3 minutes to contribute to the benchmark.*** \n
👉 ***Play one round of [Open Parti Prompts Game](https://huggingface.co/spaces/OpenGenAI/open-parti-prompts) to contribute 10 answers.*** 🤗
"""

EXPLANATION = """\n\n
## How the data is collected 📊 \n\n

In more detail, the [Open Parti Prompts Game](https://huggingface.co/spaces/OpenGenAI/open-parti-prompts) collects human preferences that state which generated image 
best fits a given prompt from the [Parti Prompts](https://huggingface.co/datasets/nateraw/parti-prompts) dataset. Parti Prompts has been designed to challenge
text-to-image models on prompts of varying categories and difficulty. The images have been pre-generated with the models that are compared in this Space.
For more information on how the images were created, please refer to [Open Parti Prompts](https://huggingface.co/spaces/OpenGenAI/open-parti-prompts).
The community's answers are then stored and used in this space to give a human-preference-based comparison of the different models. \n\n

Currently, the leaderboard includes the following models:
- [Kandinsky 2.2](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder)
- [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
- [Wuerstchen](https://huggingface.co/warp-ai/wuerstchen)
- [Karlo](https://huggingface.co/kakaobrain/karlo-v1-alpha)

Below, you can see three result tables. The first shows the overall comparison of the four models. The score states 
**how often images generated by the corresponding model were preferred over the images from all other models**. The second and third tables
break the results down per category and per type of challenge, as defined by [Parti Prompts](https://huggingface.co/datasets/nateraw/parti-prompts).
"""


def refresh():
    total_submissions, main_df, challenge_df, category_df = get_dataframe_all()
    # Return values in the order expected by the `outputs` of the refresh
    # button below: number of submissions, main table, categories, challenges
    return total_submissions, main_df, category_df, challenge_df

with gr.Blocks() as demo:
    with gr.Column(visible=True) as intro_view:
        gr.Markdown(TITLE)
        gr.Markdown(DESCRIPTION)
        gr.Markdown(EXPLANATION)


    total_submissions, main_df, challenge_df, category_df = get_dataframe_all()

    with gr.Column():
        gr.Markdown("# Open Parti Prompts")
        main_dataframe = gr.Dataframe(
            value=main_df,
            headers=main_df.columns.to_list(),
            datatype="str",
            row_count=main_df.shape[0],
            col_count=main_df.shape[1],
            interactive=False,
        )

    with gr.Column():
        gr.Markdown("## per category")
        cat_dataframe = gr.Dataframe(
            value=category_df,
            headers=category_df.columns.to_list(),
            datatype="str",
            row_count=category_df.shape[0],
            col_count=category_df.shape[1],
            interactive=False,
        )

    with gr.Column():
        gr.Markdown("## per challenge")
        chal_dataframe = gr.Dataframe(
            value=challenge_df,
            headers=challenge_df.columns.to_list(),
            datatype="str",
            row_count=challenge_df.shape[0],
            col_count=challenge_df.shape[1],
            interactive=False,
        )

    with gr.Column():
        gr.Markdown("## # Submissions")
        num_submissions = gr.Number(value=total_submissions, interactive=False)

    with gr.Row():
        refresh_button = gr.Button("Refresh")
        refresh_button.click(refresh, inputs=[], outputs=[num_submissions, main_dataframe, cat_dataframe, chal_dataframe])   

# Restart the Space every hour so newly collected answers are picked up
scheduler = BackgroundScheduler()
scheduler.add_job(restart_space, "interval", seconds=3600)
scheduler.start()

demo.launch()