"""
AugmenterArgs Class
===================
"""


from dataclasses import dataclass

# Recipe name -> dotted-path string of the augmenter it selects. The keys are
# the values accepted by the ``--recipe`` command-line option; the pair order
# below is the order in which they are listed.
AUGMENTATION_RECIPE_NAMES = dict(
    (
        ("wordnet", "textattack.augmentation.WordNetAugmenter"),
        ("embedding", "textattack.augmentation.EmbeddingAugmenter"),
        ("charswap", "textattack.augmentation.CharSwapAugmenter"),
        ("eda", "textattack.augmentation.EasyDataAugmenter"),
        ("checklist", "textattack.augmentation.CheckListAugmenter"),
        ("clare", "textattack.augmentation.CLAREAugmenter"),
        ("back_trans", "textattack.augmentation.BackTranslationAugmenter"),
    )
)


@dataclass
class AugmenterArgs:
    """Arguments for performing data augmentation.

    Args:
        input_csv (str): Path of input CSV file to augment.
        output_csv (str): Path of CSV file to output augmented data.
        input_column (str): CSV input column to be augmented.
        recipe (str): Name of augmentation recipe. The command-line parser
            restricts it to the keys of ``AUGMENTATION_RECIPE_NAMES``.
            Defaults to ``"embedding"``.
        pct_words_to_swap (float): Percentage of words to modify when
            generating each augmented example. Defaults to ``0.1``.
        transformations_per_example (int): Number of augmentations to return
            for each input. Defaults to ``2``.
        random_seed (int): Random seed to set. Defaults to ``42``.
        exclude_original (bool): Exclude original example from augmented CSV.
        overwrite (bool): Overwrite output file, if it exists.
        interactive (bool): Whether to run augmentation interactively.
        fast_augment (bool): Faster augmentation but may use only a few
            transformations.
        high_yield (bool): Run augmentation with high yield.
        enable_advanced_metrics (bool): Return perplexity and USE score.
    """

    input_csv: str
    output_csv: str
    input_column: str
    recipe: str = "embedding"
    pct_words_to_swap: float = 0.1
    transformations_per_example: int = 2
    random_seed: int = 42
    exclude_original: bool = False
    overwrite: bool = False
    interactive: bool = False
    fast_augment: bool = False
    high_yield: bool = False
    enable_advanced_metrics: bool = False

    @classmethod
    def _add_parser_args(cls, parser):
        """Register the augmentation command-line options on ``parser``.

        Option defaults are read from the dataclass field defaults of ``cls``
        so the two cannot drift apart and subclasses that override a default
        are honored.

        Args:
            parser: An argparse-style parser (anything exposing
                ``add_argument``) to extend.

        Returns:
            The same ``parser`` object, with the options added.
        """
        # The three path/column options have no dataclass default, so they
        # are left at argparse's default of ``None`` when omitted.
        parser.add_argument(
            "--input-csv",
            type=str,
            help="Path of input CSV file to augment.",
        )
        parser.add_argument(
            "--output-csv",
            type=str,
            help="Path of CSV file to output augmented data.",
        )
        parser.add_argument(
            "--input-column",
            "--i",
            type=str,
            help="CSV input column to be augmented",
        )
        parser.add_argument(
            "--recipe",
            "-r",
            help="Name of augmentation recipe",
            type=str,
            default=cls.recipe,
            choices=AUGMENTATION_RECIPE_NAMES.keys(),
        )
        parser.add_argument(
            "--pct-words-to-swap",
            "--p",
            help="Percentage of words to modify when generating each augmented example.",
            type=float,
            default=cls.pct_words_to_swap,
        )
        parser.add_argument(
            "--transformations-per-example",
            "--t",
            help="number of augmentations to return for each input",
            type=int,
            default=cls.transformations_per_example,
        )
        parser.add_argument(
            "--random-seed",
            default=cls.random_seed,
            type=int,
            help="random seed to set",
        )

        # Boolean switches, registered in this order. NOTE: the last three
        # flags are spelled with underscores, unlike the hyphenated flags
        # above; the spelling is kept as-is for backward compatibility.
        switches = (
            ("--exclude-original", "exclude original example from augmented CSV"),
            ("--overwrite", "overwrite output file, if it exists"),
            ("--interactive", "Whether to run augmentation interactively."),
            ("--high_yield", "run augmentation with high yield."),
            (
                "--fast_augment",
                "faster augmentation but may use only a few transformations.",
            ),
            ("--enable_advanced_metrics", "return perplexity and USE score"),
        )
        for flag, help_text in switches:
            # Same name argparse derives as the destination for this flag.
            field_name = flag[2:].replace("-", "_")
            parser.add_argument(
                flag,
                default=getattr(cls, field_name),
                action="store_true",
                help=help_text,
            )

        return parser