import random
from difflib import Differ

from textattack.attack_recipes import BAEGarg2019
from textattack.datasets import Dataset
from textattack.models.wrappers import HuggingFaceModelWrapper
from findfile import find_files
from flask import Flask
from textattack import Attacker


class ModelWrapper(HuggingFaceModelWrapper):
    def __init__(self, model):
        self.model = model  # pipeline = pipeline

    def __call__(self, text_inputs, **kwargs):
        outputs = []
        for text_input in text_inputs:
            raw_outputs = self.model.infer(text_input, print_result=False, **kwargs)
            outputs.append(raw_outputs["probs"])
        return outputs


class SentAttacker:
    def __init__(self, model, recipe_class=BAEGarg2019):
        model = model
        model_wrapper = ModelWrapper(model)

        recipe = recipe_class.build(model_wrapper)
        # WordNet defaults to english. Set the default language to French ('fra')

        # recipe.transformation.language = "en"

        _dataset = [("", 0)]
        _dataset = Dataset(_dataset)

        self.attacker = Attacker(recipe, _dataset)


def diff_texts(text1, text2):
    d = Differ()
    text1_words = text1.split()
    text2_words = text2.split()
    return [
        (token[2:], token[0] if token[0] != " " else None)
        for token in d.compare(text1_words, text2_words)
    ]


def get_ensembled_tad_results(results):
    target_dict = {}
    for r in results:
        target_dict[r["label"]] = (
            target_dict.get(r["label"]) + 1 if r["label"] in target_dict else 1
        )

    return dict(zip(target_dict.values(), target_dict.keys()))[
        max(target_dict.values())
    ]



def get_sst2_example():
    filter_key_words = [
        ".py",
        ".md",
        "readme",
        "log",
        "result",
        "zip",
        ".state_dict",
        ".model",
        ".png",
        "acc_",
        "f1_",
        ".origin",
        ".adv",
        ".csv",
    ]

    dataset_file = {"train": [], "test": [], "valid": []}
    dataset = "sst2"
    search_path = "./"
    task = "text_defense"
    dataset_file["test"] += find_files(
        search_path,
        [dataset, "test", task],
        exclude_key=[".adv", ".org", ".defense", ".inference", "train."]
                    + filter_key_words,
    )

    for dat_type in ["test"]:
        data = []
        label_set = set()
        for data_file in dataset_file[dat_type]:
            with open(data_file, mode="r", encoding="utf8") as fin:
                lines = fin.readlines()
                for line in lines:
                    text, label = line.split("$LABEL$")
                    text = text.strip()
                    label = int(label.strip())
                    data.append((text, label))
                    label_set.add(label)
        return random.choice(data)


def get_agnews_example():
    filter_key_words = [
        ".py",
        ".md",
        "readme",
        "log",
        "result",
        "zip",
        ".state_dict",
        ".model",
        ".png",
        "acc_",
        "f1_",
        ".origin",
        ".adv",
        ".csv",
    ]

    dataset_file = {"train": [], "test": [], "valid": []}
    dataset = "agnews"
    search_path = "./"
    task = "text_defense"
    dataset_file["test"] += find_files(
        search_path,
        [dataset, "test", task],
        exclude_key=[".adv", ".org", ".defense", ".inference", "train."]
                    + filter_key_words,
    )
    for dat_type in ["test"]:
        data = []
        label_set = set()
        for data_file in dataset_file[dat_type]:
            with open(data_file, mode="r", encoding="utf8") as fin:
                lines = fin.readlines()
                for line in lines:
                    text, label = line.split("$LABEL$")
                    text = text.strip()
                    label = int(label.strip())
                    data.append((text, label))
                    label_set.add(label)
        return random.choice(data)


def get_amazon_example():
    filter_key_words = [
        ".py",
        ".md",
        "readme",
        "log",
        "result",
        "zip",
        ".state_dict",
        ".model",
        ".png",
        "acc_",
        "f1_",
        ".origin",
        ".adv",
        ".csv",
    ]

    dataset_file = {"train": [], "test": [], "valid": []}
    dataset = "amazon"
    search_path = "./"
    task = "text_defense"
    dataset_file["test"] += find_files(
        search_path,
        [dataset, "test", task],
        exclude_key=[".adv", ".org", ".defense", ".inference", "train."]
                    + filter_key_words,
    )

    for dat_type in ["test"]:
        data = []
        label_set = set()
        for data_file in dataset_file[dat_type]:
            with open(data_file, mode="r", encoding="utf8") as fin:
                lines = fin.readlines()
                for line in lines:
                    text, label = line.split("$LABEL$")
                    text = text.strip()
                    label = int(label.strip())
                    data.append((text, label))
                    label_set.add(label)
        return random.choice(data)


def get_imdb_example():
    filter_key_words = [
        ".py",
        ".md",
        "readme",
        "log",
        "result",
        "zip",
        ".state_dict",
        ".model",
        ".png",
        "acc_",
        "f1_",
        ".origin",
        ".adv",
        ".csv",
    ]

    dataset_file = {"train": [], "test": [], "valid": []}
    dataset = "imdb"
    search_path = "./"
    task = "text_defense"
    dataset_file["test"] += find_files(
        search_path,
        [dataset, "test", task],
        exclude_key=[".adv", ".org", ".defense", ".inference", "train."]
                    + filter_key_words,
    )

    for dat_type in ["test"]:
        data = []
        label_set = set()
        for data_file in dataset_file[dat_type]:
            with open(data_file, mode="r", encoding="utf8") as fin:
                lines = fin.readlines()
                for line in lines:
                    text, label = line.split("$LABEL$")
                    text = text.strip()
                    label = int(label.strip())
                    data.append((text, label))
                    label_set.add(label)
        return random.choice(data)