"""
DatasetArgs Class
=================
"""

from dataclasses import dataclass

import textattack
from textattack.shared.utils import ARGS_SPLIT_TOKEN, load_module_from_file

HUGGINGFACE_DATASET_BY_MODEL = {
    #
    # bert-base-uncased
    #
    "bert-base-uncased-ag-news": ("ag_news", None, "test"),
    "bert-base-uncased-cola": ("glue", "cola", "validation"),
    "bert-base-uncased-imdb": ("imdb", None, "test"),
    "bert-base-uncased-mnli": (
        "glue",
        "mnli",
        "validation_matched",
        None,
        {0: 1, 1: 2, 2: 0},
    ),
    "bert-base-uncased-mrpc": ("glue", "mrpc", "validation"),
    "bert-base-uncased-qnli": ("glue", "qnli", "validation"),
    "bert-base-uncased-qqp": ("glue", "qqp", "validation"),
    "bert-base-uncased-rte": ("glue", "rte", "validation"),
    "bert-base-uncased-sst2": ("glue", "sst2", "validation"),
    "bert-base-uncased-stsb": (
        "glue",
        "stsb",
        "validation",
        None,
        None,
        None,
        5.0,
    ),
    "bert-base-uncased-wnli": ("glue", "wnli", "validation"),
    "bert-base-uncased-mr": ("rotten_tomatoes", None, "test"),
    "bert-base-uncased-snli": ("snli", None, "test", None, {0: 1, 1: 2, 2: 0}),
    "bert-base-uncased-yelp": ("yelp_polarity", None, "test"),
    #
    # distilbert-base-cased
    #
    "distilbert-base-cased-cola": ("glue", "cola", "validation"),
    "distilbert-base-cased-mrpc": ("glue", "mrpc", "validation"),
    "distilbert-base-cased-qqp": ("glue", "qqp", "validation"),
    "distilbert-base-cased-snli": ("snli", None, "test"),
    "distilbert-base-cased-sst2": ("glue", "sst2", "validation"),
    "distilbert-base-cased-stsb": (
        "glue",
        "stsb",
        "validation",
        None,
        None,
        None,
        5.0,
    ),
    "distilbert-base-uncased-ag-news": ("ag_news", None, "test"),
    "distilbert-base-uncased-cola": ("glue", "cola", "validation"),
    "distilbert-base-uncased-imdb": ("imdb", None, "test"),
    "distilbert-base-uncased-mnli": (
        "glue",
        "mnli",
        "validation_matched",
        None,
        {0: 1, 1: 2, 2: 0},
    ),
    "distilbert-base-uncased-mr": ("rotten_tomatoes", None, "test"),
    "distilbert-base-uncased-mrpc": ("glue", "mrpc", "validation"),
    "distilbert-base-uncased-qnli": ("glue", "qnli", "validation"),
    "distilbert-base-uncased-rte": ("glue", "rte", "validation"),
    "distilbert-base-uncased-wnli": ("glue", "wnli", "validation"),
    #
    # roberta-base (RoBERTa is cased by default)
    #
    "roberta-base-ag-news": ("ag_news", None, "test"),
    "roberta-base-cola": ("glue", "cola", "validation"),
    "roberta-base-imdb": ("imdb", None, "test"),
    "roberta-base-mr": ("rotten_tomatoes", None, "test"),
    "roberta-base-mrpc": ("glue", "mrpc", "validation"),
    "roberta-base-qnli": ("glue", "qnli", "validation"),
    "roberta-base-rte": ("glue", "rte", "validation"),
    "roberta-base-sst2": ("glue", "sst2", "validation"),
    "roberta-base-stsb": ("glue", "stsb", "validation", None, None, None, 5.0),
    "roberta-base-wnli": ("glue", "wnli", "validation"),
    #
    # albert-base-v2 (ALBERT is cased by default)
    #
    "albert-base-v2-ag-news": ("ag_news", None, "test"),
    "albert-base-v2-cola": ("glue", "cola", "validation"),
    "albert-base-v2-imdb": ("imdb", None, "test"),
    "albert-base-v2-mr": ("rotten_tomatoes", None, "test"),
    "albert-base-v2-rte": ("glue", "rte", "validation"),
    "albert-base-v2-qqp": ("glue", "qqp", "validation"),
    "albert-base-v2-snli": ("snli", None, "test"),
    "albert-base-v2-sst2": ("glue", "sst2", "validation"),
    "albert-base-v2-stsb": ("glue", "stsb", "validation", None, None, None, 5.0),
    "albert-base-v2-wnli": ("glue", "wnli", "validation"),
    "albert-base-v2-yelp": ("yelp_polarity", None, "test"),
    #
    # xlnet-base-cased
    #
    "xlnet-base-cased-cola": ("glue", "cola", "validation"),
    "xlnet-base-cased-imdb": ("imdb", None, "test"),
    "xlnet-base-cased-mr": ("rotten_tomatoes", None, "test"),
    "xlnet-base-cased-mrpc": ("glue", "mrpc", "validation"),
    "xlnet-base-cased-rte": ("glue", "rte", "validation"),
    "xlnet-base-cased-stsb": (
        "glue",
        "stsb",
        "validation",
        None,
        None,
        None,
        5.0,
    ),
    "xlnet-base-cased-wnli": ("glue", "wnli", "validation"),
}


#
# Models hosted by textattack.
#
TEXTATTACK_DATASET_BY_MODEL = {
    #
    # LSTMs
    #
    "lstm-ag-news": ("ag_news", None, "test"),
    "lstm-imdb": ("imdb", None, "test"),
    "lstm-mr": ("rotten_tomatoes", None, "test"),
    "lstm-sst2": ("glue", "sst2", "validation"),
    "lstm-yelp": ("yelp_polarity", None, "test"),
    #
    # CNNs
    #
    "cnn-ag-news": ("ag_news", None, "test"),
    "cnn-imdb": ("imdb", None, "test"),
    "cnn-mr": ("rotten_tomatoes", None, "test"),
    "cnn-sst2": ("glue", "sst2", "validation"),
    "cnn-yelp": ("yelp_polarity", None, "test"),
    #
    # T5 for translation
    #
    "t5-en-de": (
        "textattack.datasets.helpers.TedMultiTranslationDataset",
        "en",
        "de",
    ),
    "t5-en-fr": (
        "textattack.datasets.helpers.TedMultiTranslationDataset",
        "en",
        "fr",
    ),
    "t5-en-ro": (
        "textattack.datasets.helpers.TedMultiTranslationDataset",
        "en",
        "de",
    ),
    #
    # T5 for summarization
    #
    "t5-summarization": ("gigaword", None, "test"),
}


@dataclass
class DatasetArgs:
    """Arguments for loading dataset from command line input."""

    dataset_by_model: str = None
    dataset_from_huggingface: str = None
    dataset_from_file: str = None
    dataset_split: str = None
    filter_by_labels: list = None

    @classmethod
    def _add_parser_args(cls, parser):
        """Adds dataset-related arguments to an argparser."""

        dataset_group = parser.add_mutually_exclusive_group()
        dataset_group.add_argument(
            "--dataset-by-model",
            type=str,
            required=False,
            default=None,
            help="Dataset to load depending on the name of the model",
        )
        dataset_group.add_argument(
            "--dataset-from-huggingface",
            type=str,
            required=False,
            default=None,
            help="Dataset to load from `datasets` repository.",
        )
        dataset_group.add_argument(
            "--dataset-from-file",
            type=str,
            required=False,
            default=None,
            help="Dataset to load from a file.",
        )
        parser.add_argument(
            "--dataset-split",
            type=str,
            required=False,
            default=None,
            help="Split of dataset to use when specifying --dataset-by-model or --dataset-from-huggingface.",
        )
        parser.add_argument(
            "--filter-by-labels",
            nargs="+",
            type=int,
            required=False,
            default=None,
            help="List of labels to keep in the dataset and discard all others.",
        )
        return parser

    @classmethod
    def _create_dataset_from_args(cls, args):
        """Given ``DatasetArgs``, return specified
        ``textattack.dataset.Dataset`` object."""

        assert isinstance(
            args, cls
        ), f"Expect args to be of type `{type(cls)}`, but got type `{type(args)}`."

        # Automatically detect dataset for huggingface & textattack models.
        # This allows us to use the --model shortcut without specifying a dataset.
        if hasattr(args, "model"):
            args.dataset_by_model = args.model
        if args.dataset_by_model in HUGGINGFACE_DATASET_BY_MODEL:
            args.dataset_from_huggingface = HUGGINGFACE_DATASET_BY_MODEL[
                args.dataset_by_model
            ]
        elif args.dataset_by_model in TEXTATTACK_DATASET_BY_MODEL:
            dataset = TEXTATTACK_DATASET_BY_MODEL[args.dataset_by_model]
            if dataset[0].startswith("textattack"):
                # unsavory way to pass custom dataset classes
                # ex: dataset = ('textattack.datasets.helpers.TedMultiTranslationDataset', 'en', 'de')
                dataset = eval(f"{dataset[0]}")(*dataset[1:])
                return dataset
            else:
                args.dataset_from_huggingface = dataset

        # Get dataset from args.
        if args.dataset_from_file:
            textattack.shared.logger.info(
                f"Loading model and tokenizer from file: {args.model_from_file}"
            )
            if ARGS_SPLIT_TOKEN in args.dataset_from_file:
                dataset_file, dataset_name = args.dataset_from_file.split(
                    ARGS_SPLIT_TOKEN
                )
            else:
                dataset_file, dataset_name = args.dataset_from_file, "dataset"
            try:
                dataset_module = load_module_from_file(dataset_file)
            except Exception:
                raise ValueError(f"Failed to import file {args.dataset_from_file}")
            try:
                dataset = getattr(dataset_module, dataset_name)
            except AttributeError:
                raise AttributeError(
                    f"Variable ``dataset`` not found in module {args.dataset_from_file}"
                )
        elif args.dataset_from_huggingface:
            dataset_args = args.dataset_from_huggingface
            if isinstance(dataset_args, str):
                if ARGS_SPLIT_TOKEN in dataset_args:
                    dataset_args = dataset_args.split(ARGS_SPLIT_TOKEN)
                else:
                    dataset_args = (dataset_args,)
            if args.dataset_split:
                if len(dataset_args) > 1:
                    dataset_args = (
                        dataset_args[:2] + (args.dataset_split,) + dataset_args[3:]
                    )
                    dataset = textattack.datasets.HuggingFaceDataset(
                        *dataset_args, shuffle=False
                    )
                else:
                    dataset = textattack.datasets.HuggingFaceDataset(
                        *dataset_args, split=args.dataset_split, shuffle=False
                    )
            else:
                dataset = textattack.datasets.HuggingFaceDataset(
                    *dataset_args, shuffle=False
                )
        else:
            raise ValueError("Must supply pretrained model or dataset")

        assert isinstance(
            dataset, textattack.datasets.Dataset
        ), "Loaded `dataset` must be of type `textattack.datasets.Dataset`."

        if args.filter_by_labels:
            dataset.filter_by_labels_(args.filter_by_labels)

        return dataset