import csv
import functools
from typing import Dict, List, Optional, Tuple

import datasets
import pkg_resources
import seqio
import t5
import tensorflow as tf
from t5.data.glue_utils import get_glue_metric, get_super_glue_metric
from t5.evaluation import metrics as mt

import promptsource.templates
from promptsource.seqio_tasks import utils


# Maps metric names as written in promptsource template metadata
# (``template.metadata.metrics``) to T5 metric functions. ``add_task`` looks
# names up here for datasets other than GLUE / SuperGLUE, so a name missing
# from this dict raises KeyError when the task is registered.
GET_METRICS = {
    "BLEU": mt.bleu,
    "ROUGE": mt.rouge,
    "Span Squad": mt.span_squad,
    "Squad": mt.squad,
    "Trivia QA": mt.trivia_qa,
    "Accuracy": mt.accuracy,
    "Sequence Accuracy": mt.sequence_accuracy,
    "Pearson Correlation": mt.pearson_corrcoef,
    "Spearman Correlation": mt.spearman_corrcoef,
    "MultiRC": mt.multirc_f1_over_all_answers,
    "AUC": mt.auc,
    "COQA F1": mt.coqa_f1,
    "Edit Distance": mt.edit_distance,
    # "Mean Reciprocal Rank": mt.accuracy,  # NOTE not in T5?
    "Other": mt.accuracy,  # catch-all label; falls back to plain accuracy
    # Missing support for mean_multiclass_f1 etc. which need a num_classes parameter
}

# Upper bound on training examples taken from one dataset. When a dataset's
# train_size exceeds it, the budget is split evenly across that dataset's
# templates (see the mixture-cap computation in the template loop below).
MAX_EXAMPLES_PER_DATASET = 500_000


def strip_whitespace(output_or_target, example=None, is_target=False):
    """Return ``output_or_target`` with leading and trailing whitespace removed.

    Ground-truth targets of cached promptsource tasks begin with a space, so
    predictions and targets are both stripped before metrics compare them.
    ``example`` and ``is_target`` are unused; they exist only so the function
    matches the seqio postprocessor call signature.
    """
    del example, is_target  # unused; kept for the seqio postprocess_fn signature
    cleaned = output_or_target.strip()
    return cleaned


def maybe_get_class_id_postprocessor(template):
    """Build the seqio postprocessor for the train/eval task of ``template``.

    If the template has a fixed (example-independent) list of answer choices,
    the returned function strips whitespace and then converts the text with
    ``t5.data.postprocessors.string_label_to_class_id`` against that list, so
    classification metrics compare class ids. Otherwise ``strip_whitespace``
    itself is returned.

    Args:
        template: a promptsource template.

    Returns:
        A callable with the seqio postprocessor signature
        ``(output_or_target, example=None, is_target=False)``.
    """
    # Resolve the fixed choices once, when the task is built, instead of on
    # every prediction/target that passes through the postprocessor.
    label_classes = template.get_fixed_answer_choices_list()
    if not label_classes:
        return strip_whitespace

    def postprocess_fn(output_or_target, example=None, is_target=False):
        output_or_target = strip_whitespace(output_or_target)
        return t5.data.postprocessors.string_label_to_class_id(output_or_target, label_classes=label_classes)

    return postprocess_fn


def get_tf_dataset(split, shuffle_files, seed, dataset_name, subset_name, template, split_mapping):
    """Load one split of a Hugging Face dataset, render it with ``template``,
    and return it as a TensorFlow dataset.

    Args:
        split: seqio split name, translated to an HF split via ``split_mapping``.
        shuffle_files: ignored (HF datasets has no file-level shuffling).
        seed: ignored for the same reason.
        dataset_name: HF dataset name passed to ``datasets.load_dataset``.
        subset_name: HF subset/config name passed to ``datasets.load_dataset``.
        template: promptsource template applied via ``utils.apply_template``.
        split_mapping: dict from seqio split name to HF split name.
    """
    del shuffle_files, seed  # HF datasets does not support file-level shuffling
    all_splits = datasets.load_dataset(dataset_name, subset_name)
    hf_split = all_splits[split_mapping[split]]
    templated = utils.apply_template(hf_split, template)
    return utils.hf_dataset_to_tf_dataset(templated)


def _get_metric_fns(dataset_name, subset_name, template):
    """Return the metric functions (seqio ``metric_fns``) for ``template``'s task."""
    if dataset_name == "glue":
        return get_glue_metric(subset_name)
    if dataset_name == "super_glue":
        if subset_name in ("wsc.fixed", "multirc"):
            # TODO: WSC and MultiRC need special pre/postprocesing
            return [mt.accuracy]
        return get_super_glue_metric(subset_name)
    # Template metadata may leave ``metrics`` unset (None). Register such tasks
    # without metrics instead of crashing with a TypeError at import time.
    metric_names = template.metadata.metrics or []
    return [GET_METRICS[name] for name in metric_names]


def add_task(dataset_name, subset_name, template_name, task_name=None, split_mapping=None):
    """Register the seqio task(s) for one promptsource template.

    Always registers ``task_name`` as a train/eval task. When the template
    defines answer choices, it also registers ``task_name + "_score_eval"``,
    which is evaluated with rank classification over those choices.

    Args:
        dataset_name: HF dataset name.
        subset_name: HF subset/config name, or None.
        template_name: name of the template within the dataset's templates.
        task_name: optional override for the registered task name; defaults
            to ``utils.get_task_name(dataset_name, subset_name, template_name)``.
        split_mapping: optional dict from seqio split name to HF split name;
            defaults to the identity mapping over the dataset's splits.
    """
    template = all_templates.get_dataset(dataset_name, subset_name)[template_name]
    task_name = task_name or utils.get_task_name(dataset_name, subset_name, template_name)
    metrics = _get_metric_fns(dataset_name, subset_name, template)

    dataset_splits = utils.get_dataset_splits(dataset_name, subset_name)
    split_mapping = split_mapping or {k: k for k in dataset_splits.keys()}

    dataset_fn = functools.partial(
        get_tf_dataset,
        seed=None,
        dataset_name=dataset_name,
        subset_name=subset_name,
        template=template,
        split_mapping=split_mapping,
    )
    data_source = seqio.FunctionDataSource(
        dataset_fn,
        splits=list(split_mapping.keys()),
        num_input_examples={s: dataset_splits[split_mapping[s]].num_examples for s in split_mapping.keys()},
    )
    output_features = {
        "inputs": seqio.Feature(t5.data.get_default_vocabulary(), add_eos=False, dtype=tf.int32),
        "targets": seqio.Feature(t5.data.get_default_vocabulary(), add_eos=True, dtype=tf.int32),
    }
    preprocessors = [
        seqio.preprocessors.tokenize,
        seqio.preprocessors.append_eos,
        seqio.CacheDatasetPlaceholder(required=False),
    ]

    # Add train and normal eval tasks
    seqio.TaskRegistry.add(
        task_name,
        data_source,
        preprocessors=preprocessors,
        output_features=output_features,
        metric_fns=metrics,
        postprocess_fn=maybe_get_class_id_postprocessor(template),
    )

    # Add rank classification eval task
    if template.answer_choices:
        # One (inputs, choice) pair per answer choice; a pair is correct when
        # the choice equals the whitespace-stripped target.
        rank_classification_preprocessor = functools.partial(
            t5.data.preprocessors.rank_classification,
            inputs_fn=lambda ex: tf.fill((len(ex["answer_choices"]),), ex["inputs"]),
            targets_fn=lambda ex: ex["answer_choices"],
            is_correct_fn=lambda ex: tf.equal(ex["answer_choices"], tf.strings.strip(ex["targets"])),
            weight_fn=lambda ex: 1.0,
        )

        fixed_choices = template.get_fixed_answer_choices_list()
        num_classes = len(fixed_choices) if fixed_choices else None
        seqio.TaskRegistry.add(
            task_name + "_score_eval",
            data_source,
            preprocessors=[rank_classification_preprocessor] + preprocessors,
            output_features=output_features,
            metric_fns=[functools.partial(t5.evaluation.metrics.rank_classification, num_classes=num_classes)],
            postprocess_fn=t5.data.postprocessors.rank_classification,
        )


# A dataset is identified by (HF dataset name, subset name or None).
datatset_subset_tuple = Tuple[str, Optional[str]]
d4_train: List[datatset_subset_tuple] = []  # rows with do_train == TRUE
d4_eval: List[datatset_subset_tuple] = []  # rows with do_eval == TRUE
d3_train_gpt: List[datatset_subset_tuple] = []  # D3 training rows whose seed_paper mentions GPT
d3_train_sglue: List[datatset_subset_tuple] = []  # D3 training rows from super_glue
bias_fairness_eval: List[datatset_subset_tuple] = []  # bias/fairness eval rows, except winogender
gsheet: Dict[datatset_subset_tuple, Dict] = {}  # full CSV row for each non-skipped dataset
experiment_path = pkg_resources.resource_filename(__name__, "experiment_D4.csv")
# The csv module requires newline="" so quoted fields containing line breaks
# are parsed correctly; UTF-8 is pinned so parsing does not depend on the
# host locale.
with open(experiment_path, newline="", encoding="utf-8") as exp_file:
    reader = csv.DictReader(exp_file)
    for row in reader:
        # Any non-empty "skip" cell excludes the row entirely.
        if row["skip"]:
            continue
        if row["subset"] == "":
            row["subset"] = None  # to match promptsource.Template object
        dataset_subset = (row["HF_name"], row["subset"])
        if row["do_train"] == "TRUE":
            d4_train.append(dataset_subset)
        if row["do_eval"] == "TRUE":
            d4_eval.append(dataset_subset)
        if row["D3_do_train"] == "TRUE" and "GPT" in row["seed_paper"]:
            d3_train_gpt.append(dataset_subset)
        if row["D3_do_train"] == "TRUE" and row["HF_name"] == "super_glue":
            d3_train_sglue.append(dataset_subset)
        if (
            row["do_eval"] == "TRUE"
            and row["task_by_convention"] == "bias_and_fairness"
            and row["HF_name"] != "winogender"
        ):
            bias_fairness_eval.append(dataset_subset)
        gsheet[dataset_subset] = row
# Every selected dataset (may contain duplicates); used below to decide which
# template collections are kept.
all_datasets = d4_train + d4_eval + d3_train_gpt + d3_train_sglue + bias_fairness_eval

# Load all promptsource templates. ANLI is removed here so the generic loop
# below skips it; it is registered separately further down.
all_templates = promptsource.templates.TemplateCollection()
all_templates.remove("anli")  # Need to special-case ANLI due to weird split conventions

# 3 stages of training/ablation: D4 -> GPT -> SuperGLUE
d4_train_mixture: List[str] = []  # strings are dataset_subset_template
gpt_train_mixture: List[str] = []
sglue_train_mixture: List[str] = []
d4_eval_mixture: List[str] = []  # only templates with metadata.original_task
bias_fairness_eval_mixture: List[str] = []
# Per-task example cap; used as the mixing rate of the training mixtures.
mixture_cap: Dict[str, int] = {}
# (dataset, subset) -> name of the first original-task template encountered.
single_original_task: Dict[Tuple[str, Optional[str]], str] = {}
# Names of every task whose template has metadata.original_task set.
all_original_tasks: List[str] = []
# NOTE(review): this removes entries from all_templates while iterating
# all_templates.keys; assumes `keys` returns a snapshot (e.g. a list copy) —
# confirm against TemplateCollection.
for dataset_name, subset_name in all_templates.keys:
    # Datasets not selected in experiment_D4.csv are dropped from the collection.
    if (dataset_name, subset_name) not in all_datasets:
        all_templates.remove(dataset_name, subset_name)
        continue

    dataset = all_templates.get_dataset(dataset_name, subset_name)
    num_templates = len(dataset.all_template_names)
    # train_size is read from the CSV as a string; an empty cell counts as 0.
    train_size = gsheet[(dataset_name, subset_name)]["train_size"]
    if train_size == "":
        train_size = 0
    else:
        train_size = int(train_size)
    # Large datasets split MAX_EXAMPLES_PER_DATASET evenly across their
    # templates; smaller datasets give every template the full train_size.
    if train_size > MAX_EXAMPLES_PER_DATASET:
        cap = MAX_EXAMPLES_PER_DATASET // num_templates
    else:
        cap = train_size
    for template_name in dataset.all_template_names:
        # Registers the seqio task (plus a _score_eval task when the template
        # has answer choices).
        add_task(dataset_name, subset_name, template_name)

        template = dataset[template_name]

        task_name = utils.get_task_name(dataset_name, subset_name, template_name)

        if (dataset_name, subset_name) not in single_original_task and template.metadata.original_task:
            single_original_task[(dataset_name, subset_name)] = task_name

        if template.metadata.original_task:
            all_original_tasks.append(task_name)

        # A dataset can belong to several stages; each stage mixture gets the
        # task, and all of them share the same cap.
        if (dataset_name, subset_name) in d4_train:
            d4_train_mixture.append(task_name)
            mixture_cap[task_name] = cap
        if (dataset_name, subset_name) in d3_train_gpt:
            gpt_train_mixture.append(task_name)
            mixture_cap[task_name] = cap
        if (dataset_name, subset_name) in d3_train_sglue:
            sglue_train_mixture.append(task_name)
            mixture_cap[task_name] = cap
        if (dataset_name, subset_name) in d4_eval:
            if template.metadata.original_task:
                d4_eval_mixture.append(task_name)
            # TODO use template.metadata.answer_choices here for rank eval
        if (dataset_name, subset_name) in bias_fairness_eval:
            bias_fairness_eval_mixture.append(task_name)

# Special case for ANLI, which has weirdly-named splits and rounds that should be subsets.
# Each round rN becomes its own set of tasks (task-name suffix "_rN") whose seqio
# train/validation/test splits map onto the HF splits train_rN / dev_rN / test_rN.
# ANLI tasks are only added to the eval mixture, never to a training mixture.
dataset_name, subset_name = ("anli", None)
dataset = all_templates.get_dataset(dataset_name, subset_name)
for anli_round in ("r1", "r2", "r3"):
    # Reuse the templates fetched above instead of re-fetching the dataset's
    # template collection on every round.
    for template_name in dataset.all_template_names:
        task_name = utils.get_task_name(dataset_name, subset_name, template_name) + f"_{anli_round}"
        split_mapping = {
            "train": f"train_{anli_round}",
            "validation": f"dev_{anli_round}",
            "test": f"test_{anli_round}",
        }
        add_task(dataset_name, subset_name, template_name, task_name, split_mapping)

        template = dataset[template_name]
        if template.metadata.original_task:
            d4_eval_mixture.append(task_name)  # TODO or add to ANLI special mixture
        # TODO use template.metadata.answer_choices here for rank eval


# Task names dropped from the mixtures below that filter on this list
# (the bias/fairness mixtures do not consult it).
TASK_BLACKLIST = [
    # Tasks which often tokenize to > 1024 tokens currently
    "hotpot_qa_distractor_Generate_Explanations",
    "hotpot_qa_fullwiki_Generate_Explanations",
    "hotpot_qa_distractor_Generate_Answer_and_Explanations",
    "hotpot_qa_fullwiki_Generate_Answer_and_Explanations",
    "hotpot_qa_fullwiki_Generate_Answer",
    "hotpot_qa_distractor_Generate_Answer",
    "hotpot_qa_distractor_Generate_Title_2",
    "hotpot_qa_fullwiki_Generate_Title_2",
    "hotpot_qa_fullwiki_Generate_Title_1",
    "hotpot_qa_distractor_Generate_Title_1",
    "hotpot_qa_distractor_Generate_Question",
    "hotpot_qa_fullwiki_Generate_Question",
    "tab_fact_tab_fact_tab_fact_3",
    "tab_fact_tab_fact_tab_fact_2",
    "tab_fact_tab_fact_tab_fact_1",
    "tab_fact_tab_fact_tab_fact_7",
    "tab_fact_tab_fact_tab_fact_4",
    "tab_fact_tab_fact_tab_fact_5",
    "tab_fact_tab_fact_tab_fact_6",
    "wiki_hop_masked_Choose_Best_Object_Candidate",
    "wiki_hop_masked_Indirect_Question_about_Birthplace_Citizenship_Place_of_Death",
    "narrativeqa_Template_05",
    "ecthr_cases_alleged_violation_prediction_silver_rationales",
    # Tasks with broken cached files
    "gigaword_summarize_",
]

# Tasks that failed caching (won't try to fix them for now) - remove when we are done
# Full "_score_eval" task names excluded from the d4_train_score_eval mixture.
D4_TRAIN_SCORE_EVAL_TASK_BLACKLIST = [
    "amazon_polarity_Is_this_product_review_positive_score_eval",
    "amazon_polarity_Is_this_review_negative_score_eval",
    "amazon_polarity_Is_this_review_score_eval",
    "amazon_polarity_User_recommend_this_product_score_eval",
    "amazon_polarity_convey_negative_or_positive_sentiment_score_eval",
    "amazon_polarity_flattering_or_not_score_eval",
    "amazon_polarity_negative_or_positive_tone_score_eval",
    "amazon_polarity_user_satisfied_score_eval",
    "amazon_polarity_would_you_buy_score_eval",
    "dbpedia_14_given_a_choice_of_categories__score_eval",
    "dbpedia_14_given_list_what_category_does_the_paragraph_belong_to_score_eval",
    "dbpedia_14_pick_one_category_for_the_following_text_score_eval",
    "wiki_hop_original_choose_best_object_affirmative_1_score_eval",
    "wiki_hop_original_choose_best_object_affirmative_2_score_eval",
    "wiki_hop_original_choose_best_object_affirmative_3_score_eval",
    "wiki_hop_original_choose_best_object_interrogative_1_score_eval",
    "wiki_hop_original_choose_best_object_interrogative_2_score_eval",
]

def _rate_from_cap(task):
    """Mixing rate for the training mixtures: the task's entry in ``mixture_cap``."""
    return mixture_cap[task.name]


def _drop_blacklisted(task_names):
    """Return ``task_names`` in order, minus any name listed in TASK_BLACKLIST."""
    return [name for name in task_names if name not in TASK_BLACKLIST]


def _d4_score_eval_candidates():
    """Registered "_score_eval" tasks whose base task is a non-blacklisted D4 eval task."""
    selected = []
    for name in seqio.TaskRegistry.names():
        if not name.endswith("_score_eval"):
            continue
        base = name.split("_score_eval")[0]
        if base in d4_eval_mixture and base not in TASK_BLACKLIST:
            selected.append(name)
    return selected


# Training mixtures (registered in this order), each weighted by per-task cap.
_TRAINING_MIXTURES = (
    ("d4_train", d4_train_mixture),
    ("gpt_train", gpt_train_mixture),
    ("sglue_train", sglue_train_mixture),
    ("d4_gpt_train", d4_train_mixture + gpt_train_mixture),
    ("d4_gpt_sglue_train", d4_train_mixture + gpt_train_mixture + sglue_train_mixture),
)
for _mixture_name, _member_tasks in _TRAINING_MIXTURES:
    seqio.MixtureRegistry.add(
        _mixture_name,
        _drop_blacklisted(_member_tasks),
        default_rate=_rate_from_cap,
    )

# Eval mixtures are weighted by example count (up to 500k); the eval mixture
# does not need to be capped.
_eval_mixing_rate = functools.partial(seqio.mixing_rate_num_examples, maximum=500_000)

seqio.MixtureRegistry.add(
    "d4_eval",
    _drop_blacklisted(d4_eval_mixture),
    default_rate=_eval_mixing_rate,
)

seqio.MixtureRegistry.add(
    "d4_score_eval",
    _d4_score_eval_candidates(),
    default_rate=_eval_mixing_rate,
)

# Train tasks we don't care about evaluating on
# Each entry is matched as a substring of the task name (``skip in task``), so
# it removes every template of that dataset from the d4_train_eval and
# d4_train_score_eval mixtures.
D4_TRAIN_SKIP_EVAL = [
    "paws_labeled_final",
    "adversarial_qa_dbidaf",
    "adversarial_qa_dbert",
    "duorc_ParaphraseRC",
    "dream",
    "amazon_polarity",
    "app_reviews",
    "imdb",
    "wiki_bio",
    "gigaword",
    "multi_news",
    "samsum",
    "dbpedia_14",
    "trec",
]

# Set views for O(1) membership tests: the comprehensions below scan every
# registered task, so list membership here would be quadratic.
_original_task_set = set(all_original_tasks)
_d4_train_task_set = set(d4_train_mixture)


def _skipped_for_train_eval(task_name):
    """Return True if ``task_name`` contains any substring from D4_TRAIN_SKIP_EVAL."""
    return any(skip in task_name for skip in D4_TRAIN_SKIP_EVAL)


# Original-task prompts of D4 training datasets, for evaluation on train data.
seqio.MixtureRegistry.add(
    "d4_train_eval",
    [
        task
        for task in d4_train_mixture
        if task not in TASK_BLACKLIST and not _skipped_for_train_eval(task) and task in _original_task_set
    ],
    default_rate=lambda t: mixture_cap[t.name],
)

# Rank-classification counterparts of the d4_train_eval tasks.
seqio.MixtureRegistry.add(
    "d4_train_score_eval",
    [
        task
        for task in seqio.TaskRegistry.names()
        if task.endswith("_score_eval")
        and task.split("_score_eval")[0] in _d4_train_task_set
        and task.split("_score_eval")[0] not in TASK_BLACKLIST
        and task not in D4_TRAIN_SCORE_EVAL_TASK_BLACKLIST
        and not _skipped_for_train_eval(task)
        and task.split("_score_eval")[0] in _original_task_set
    ],
    default_rate=functools.partial(seqio.mixing_rate_num_examples, maximum=500_000),
)

# One original-task prompt per D4 training dataset.
seqio.MixtureRegistry.add(
    "d4_train_one_og_prompt",
    [task for task in single_original_task.values() if task in _d4_train_task_set and task not in TASK_BLACKLIST],
    default_rate=lambda t: mixture_cap[t.name],
)

# Every original-task prompt of the D4 training datasets.
seqio.MixtureRegistry.add(
    "d4_train_all_og_prompts",
    [task for task in all_original_tasks if task in _d4_train_task_set and task not in TASK_BLACKLIST],
    default_rate=lambda t: mixture_cap[t.name],
)

# Bias/fairness evaluation (not filtered by TASK_BLACKLIST).
seqio.MixtureRegistry.add(
    "bias_fairness_eval",
    bias_fairness_eval_mixture,
    default_rate=functools.partial(seqio.mixing_rate_num_examples, maximum=500_000),
)

seqio.MixtureRegistry.add(
    "bias_fairness_eval_score_eval",
    [
        task
        for task in seqio.TaskRegistry.names()
        if task.endswith("_score_eval") and task.split("_score_eval")[0] in bias_fairness_eval_mixture
    ],
    default_rate=functools.partial(seqio.mixing_rate_num_examples, maximum=500_000),
)