import textattack
import transformers
from FlowCorrector import Flow_Corrector
import torch
import torch.nn.functional as F

def count_matching_classes(original, corrected):
    """Count the positions at which two equal-length sequences hold equal values.

    Args:
        original: Sequence of reference values (e.g. class ids predicted for
            the original texts).
        corrected: Sequence of values compared element-by-element against
            ``original``.

    Returns:
        The number of indices ``i`` for which ``original[i] == corrected[i]``
        (0 for two empty sequences).

    Raises:
        ValueError: If the two sequences have different lengths.
    """
    if len(original) != len(corrected):
        raise ValueError("Arrays must have the same length")

    # The length check above guarantees zip() does not silently truncate.
    # `if a == b` keeps the original truthiness-based comparison.
    return sum(1 for a, b in zip(original, corrected) if a == b)

# Entry point: benchmark Flow_Corrector by checking how many corrected
# adversarial texts recover the victim model's prediction on the clean text.
if __name__ == "__main__":

    # Load model, tokenizer, and model_wrapper
    model = transformers.AutoModelForSequenceClassification.from_pretrained(
        "textattack/bert-base-uncased-ag-news"
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        "textattack/bert-base-uncased-ag-news"
    )
    model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

    # Construct our four components for `Attack`
    from textattack.constraints.pre_transformation import (
        RepeatModification,
        StopwordModification,
    )
    from textattack.constraints.semantics import WordEmbeddingDistance
    from textattack.transformations import WordSwapEmbedding
    from textattack.search_methods import GreedyWordSwapWIR

    goal_function = textattack.goal_functions.UntargetedClassification(model_wrapper)
    constraints = [
        RepeatModification(),
        StopwordModification(),
        WordEmbeddingDistance(min_cos_sim=0.9),
    ]
    transformation = WordSwapEmbedding(max_candidates=50)
    search_method = GreedyWordSwapWIR(wir_method="weighted-saliency")

    # Construct the actual attack
    attack = textattack.Attack(goal_function, constraints, transformation, search_method)
    # NOTE(review): cuda_() presumably requires a CUDA device — confirm.
    attack.cuda_()

    # Initialise the corrector
    corrector = Flow_Corrector(
        attack,
        word_rank_file="en_full_ranked.json",
        word_freq_file="en_full_freq.json",
    )

    # Adversarial (perturbed) texts, one per line.
    with open("perturbed_texts_ag_news.txt", "r", encoding="utf-8") as f:
        detected_texts = [line.strip() for line in f]

    # Original (clean) texts, in the same order as the adversarial ones.
    with open("original_texts_ag_news.txt", "r", encoding="utf-8") as f:
        original_texts = [line.strip() for line in f]

    # Fail fast, before the expensive correction step, if the files are
    # not line-aligned (count_matching_classes would reject them later anyway).
    if len(detected_texts) != len(original_texts):
        raise ValueError(
            f"perturbed/original text files differ in length: "
            f"{len(detected_texts)} vs {len(original_texts)}"
        )

    victim_model = attack.goal_function.model

    # Reference labels for benchmarking: the victim model's prediction on each
    # clean text. The wrapper is given a one-element list (a batch of one);
    # presumably it returns (batch, num_labels) logits — TODO confirm.
    # AG News label ids: 0 = World, 1 = Sports, 2 = Business, 3 = Sci/Tech.
    original_classes = [
        torch.argmax(F.softmax(victim_model([original_text]), dim=1)).item()
        for original_text in original_texts
    ]

    # Correct the adversarial texts (not the clean ones) so the result can be
    # compared with the clean-text labels.
    # NOTE(review): assumes Flow_Corrector.correct returns one class id per
    # input text, aligned with its input — confirm against FlowCorrector.
    corrected_classes = corrector.correct(detected_texts)
    print(f"match {count_matching_classes(original_classes, corrected_classes)}")