import os 
os.environ["TF_ENABLE_ONEDNN_OPTS"] = '0'

from huggingface_hub import login


from typing import Union, Any, Dict
# from datasets.arrow_dataset import Batch

import argparse
import datasets
from transformers.utils import logging, check_min_version
from transformers.utils.versions import require_version

from retro_reader import RetroReader
from retro_reader.constants import EXAMPLE_FEATURES
import torch

# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.13.0.dev0")

# Check at import time that the installed `datasets` package satisfies >=1.8.0
# (raises if it does not), before any dataset loading is attempted.
require_version("datasets>=1.8.0")

# Module-level logger from the Transformers logging utilities; used below to
# report unanswerable-sample counts and the end of training.
logger = logging.get_logger(__name__)


def schema_integrate(example) -> Union[Dict, Any]:
    """Map a batch of raw SQuAD v2.0 rows onto this library's example schema.

    Intended for ``datasets.Dataset.map(..., batched=True)``: ``example`` maps
    each raw column (``id``, ``title``, ``question``, ``context``,
    ``answers``) to a list holding one value per row.

    Returns a dict of equal-length lists keyed by the library's column names.
    A row whose ``answers["text"]`` is empty is flagged ``is_impossible=True``
    and receives the placeholder answer ``{"text": [""], "answer_start": [-1]}``.
    Every other row keeps its original answers dict.
    """
    n_rows = len(example["title"])
    # One list object backs both the "source" and "dataset" columns.
    squad_tag = ["squad_v2"] * n_rows

    fixed_answers = []
    impossible_flags = []
    for gold in example["answers"]:
        has_answer = bool(gold["text"])
        impossible_flags.append(not has_answer)
        fixed_answers.append(
            gold if has_answer else {"text": [""], "answer_start": [-1]}
        )

    # NOTE(review): these keys are not in alphabetical order; presumably their
    # order is meant to mirror EXAMPLE_FEATURES — confirm.
    return dict(
        guid=example["id"],
        question=example["question"],
        context=example["context"],
        answers=fixed_answers,
        title=example["title"],
        classtype=[""] * n_rows,
        source=squad_tag,
        is_impossible=impossible_flags,
        dataset=squad_tag,
    )


# data augmentation for multiple answers
def data_aug_for_multiple_answers(examples) -> Dict[str, list]:
    """Split answerable examples that carry several gold answers into one row each.

    Intended for a batched ``datasets.Dataset.map`` over rows produced by
    ``schema_integrate``. ``examples`` maps each column name to a list of
    per-row values and must contain ``"answers"`` and ``"is_impossible"``.

    * Answerable rows with more than one answer are emitted once per answer.
      Each copy carries a single-answer ``{"text": [t], "answer_start": [s]}``.
    * Unanswerable rows are emitted once with ``{"text": [], "answer_start": []}``.
    * All other rows are copied through unchanged.

    Every non-answer column is duplicated alongside the answer it belongs to.

    Returns:
        A dict with the same keys as ``examples``, each mapped to a list of
        (possibly more) rows.

    Raises:
        ValueError: if a row's ``text`` and ``answer_start`` lists differ in
            length and ``answer_start`` does not begin with the ``-1``
            "no answer" sentinel.
    """
    result = {key: [] for key in examples.keys()}

    def emit(i, answers=None):
        # Copy row ``i`` into ``result``, optionally overriding its answers.
        for key in result:
            if key == "answers" and answers is not None:
                result[key].append(answers)
            else:
                result[key].append(examples[key][i])

    for i, (answers, unanswerable) in enumerate(
        zip(examples["answers"], examples["is_impossible"])
    ):
        texts, starts = answers["text"], answers["answer_start"]
        # Explicit check rather than ``assert`` so it also runs under
        # ``python -O``. ``starts`` is tested for emptiness before it is indexed.
        if len(texts) != len(starts) and not (starts and starts[0] == -1):
            raise ValueError(
                f"Row {i} of the batch: answers['text'] has {len(texts)} "
                f"entries but answers['answer_start'] has {len(starts)}."
            )
        if unanswerable:
            emit(i, {"text": [], "answer_start": []})
        elif len(texts) > 1:
            for text, start in zip(texts, starts):
                emit(i, {"text": [text], "answer_start": [start]})
        else:
            emit(i)

    return result


def main(args):
    """Prepare SQuAD v2.0 and train a Retro Reader on it.

    Steps:
        1. Load SQuAD v2.0 (optionally shrunk to 5 rows per split in debug mode).
        2. Map it onto the library schema with ``schema_integrate``.
        3. Log how many examples in each split are unanswerable.
        4. Split multi-answer training examples into one row per answer.
        5. Build a ``RetroReader`` from ``args.configs``, optionally resume from
           a checkpoint, and train the selected module(s).

    Args:
        args: parsed command-line namespace. Reads ``configs``, ``batch_size``,
            ``resume_checkpoint``, ``module`` and ``debug``.
    """
    # Load SQuAD V2.0 dataset
    print("Loading SQuAD v2.0 dataset ...")
    squad_v2 = datasets.load_dataset("squad_v2")

    # Minimize the dataset for debugging: keep only the first 5 rows per split.
    if args.debug:
        squad_v2["train"] = squad_v2["train"].select(range(5))
        squad_v2["validation"] = squad_v2["validation"].select(range(5))

    # Integrate into the schema used in this library.
    # Note: The columns used for preprocessing are `question`, `context`,
    #       `answers` and `is_impossible`. The remaining columns exist to
    #       process other types of data.
    print("Integrating into the schema used in this library ...")
    squad_v2 = squad_v2.map(
        schema_integrate,
        batched=True,
        remove_columns=squad_v2.column_names["train"],
        features=EXAMPLE_FEATURES,
    )
    # Expected for the full dataset:
    #   num_rows in train: 130,319, num_unanswerable in train: 43,498
    #   num_rows in valid:  11,873, num_unanswerable in valid:  5,945
    num_unanswerable_train = sum(squad_v2["train"]["is_impossible"])
    num_unanswerable_valid = sum(squad_v2["validation"]["is_impossible"])
    logger.warning(f"Number of unanswerable sample for SQuAD v2.0 train dataset: {num_unanswerable_train}")
    logger.warning(f"Number of unanswerable sample for SQuAD v2.0 validation dataset: {num_unanswerable_valid}")

    # Train data augmentation for multiple answers. Answerable examples with
    # several gold answers become one row per answer. The unanswerable
    # placeholder written by schema_integrate,
    #   {"text": [""], "answer_start": [-1]},
    # becomes {"text": [], "answer_start": []}.
    # Only the train split is augmented; validation keeps all gold answers.
    print("Data augmentation for multiple answers ...")
    squad_v2_train = squad_v2["train"].map(
        data_aug_for_multiple_answers,
        batched=True,
        batch_size=args.batch_size,
        num_proc=5,
    )
    squad_v2 = datasets.DatasetDict({
        "train": squad_v2_train,              # num_rows: 130,319
        "validation": squad_v2["validation"]  # num_rows:  11,873
    })

    # Load Retro Reader
    # features: parse arguments
    #           make train/eval dataset from examples
    #           load model from 🤗 hub
    #           set sketch/intensive reader and rear verifier
    print("Loading Retro Reader ...")
    retro_reader = RetroReader.load(
        train_examples=squad_v2["train"],
        eval_examples=squad_v2["validation"],
        config_file=args.configs,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    if args.resume_checkpoint:
        # NOTE(review): relies on load_checkpoint returning the reader; confirm.
        retro_reader = retro_reader.load_checkpoint(args.resume_checkpoint)

    # Train
    print("Training ...")
    retro_reader.train(module=args.module)
    logger.warning("Train retrospective reader Done.")
    
    
if __name__ == "__main__":
    # Command-line entry point: parse options and hand them to main().
    parser = argparse.ArgumentParser(
        description="Train a Retro Reader on SQuAD v2.0."
    )
    parser.add_argument(
        "--configs", "-c", type=str, default="configs/train_distilbert.yaml",
        help="config file path",
    )
    # Only used as the `datasets.map` batch size when splitting multi-answer
    # training examples.
    parser.add_argument(
        "--batch_size", "-b", type=int, default=1024,
        help="batch size for the multiple-answer data-augmentation map step",
    )
    parser.add_argument(
        "--resume_checkpoint", "-r", type=str, default=None,
        help="resume checkpoint path",
    )
    parser.add_argument(
        "--module", "-m", type=str, default="all",
        choices=["all", "sketch", "intensive"],
        help="module to train",
    )
    parser.add_argument(
        "--debug", "-d", action="store_true",
        help="debug mode (use only 5 examples per split)",
    )
    args = parser.parse_args()
    main(args)