import math
import random
import re

import torch
from nltk.corpus import stopwords
from transformers import AutoTokenizer, AutoModelForMaskedLM
from transformers import pipeline

from vocabulary_split import split_vocabulary, filter_logits

# Load the tokenizer and masked-LM model once at import time; the masking
# functions below all pass these module-level objects to get_logits_for_mask.
tokenizer = AutoTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
model = AutoModelForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")
# NOTE(review): `fill_mask` is not used by any function visible in this file;
# kept in case external callers import it.
fill_mask = pipeline("fill-mask", model=model, tokenizer=tokenizer)

# Boolean tensor over the tokenizer vocabulary: True where the token id is one of
# the "permissible" ids returned by split_vocabulary (called with a fixed seed).
permissible, _ = split_vocabulary(seed=42)
# Build the id set once: `x in dict.values()` is a linear scan, which made the
# comprehension below O(vocab_size * len(permissible)) at import time.
# NOTE(review): assumes the values of `permissible` are hashable integer token ids.
_permissible_token_ids = set(permissible.values())
permissible_indices = torch.tensor(
    [token_id in _permissible_token_ids for token_id in range(len(tokenizer))]
)

def get_logits_for_mask(model, tokenizer, sentence):
    """Run the masked LM on ``sentence`` and return the logits at its [MASK] positions.

    Only the first (and only) sequence of the batch is read. The result is
    squeezed: a single [MASK] yields a 1-D tensor over the model's output
    dimension, several masks yield one row per mask (in text order).
    """
    encoded = tokenizer(sentence, return_tensors="pt")
    _, mask_positions = torch.where(encoded["input_ids"] == tokenizer.mask_token_id)

    with torch.no_grad():
        prediction = model(**encoded)

    return prediction.logits[0, mask_positions, :].squeeze()

def mask_non_stopword(sentence):
    """Mask one randomly chosen non-stopword and score replacements for it.

    A whitespace-delimited word whose lowercase form is not an NLTK English
    stopword is picked with the global ``random`` module. Its first whole-word
    occurrence is replaced by ``[MASK]``, and the masked-LM logits at that
    position are passed through ``filter_logits`` with ``permissible_indices``.

    Args:
        sentence: Input text; words are split on whitespace.

    Returns:
        ``(masked_sentence, filtered_logits_list, top_words)`` where
        ``top_words`` are the decoded ids of the five largest filtered logits in
        ascending order (best last), or ``(sentence, None, None)`` when every
        word is a stopword.
    """
    stop_words = set(stopwords.words('english'))
    words = sentence.split()
    non_stop_words = [word for word in words if word.lower() not in stop_words]
    if not non_stop_words:
        return sentence, None, None
    word_to_mask = random.choice(non_stop_words)
    # Match the word only as a whole whitespace-delimited token: str.replace would
    # also hit substrings (masking "cat" would corrupt an earlier "concatenate").
    masked_sentence = re.sub(
        rf"(?<!\S){re.escape(word_to_mask)}(?!\S)", '[MASK]', sentence, count=1
    )
    logits = get_logits_for_mask(model, tokenizer, masked_sentence)
    filtered_logits = filter_logits(logits, permissible_indices)
    # argsort is ascending, so the last five ids are the highest-scoring ones.
    words = [tokenizer.decode([i]) for i in filtered_logits.argsort()[-5:]]
    return masked_sentence, filtered_logits.tolist(), words

def mask_non_stopword_pseudorandom(sentence, seed=10):
    """Mask a reproducibly chosen non-stopword and score replacements for it.

    Identical selection and scoring to a random non-stopword mask, except that
    the global ``random`` module is seeded with ``seed`` just before the word
    is chosen, so the same sentence always masks the same word.

    NOTE: this reseeds the *global* ``random`` module (only when there is at
    least one candidate word), which also affects any later ``random.*`` calls.

    Args:
        sentence: Input text; words are split on whitespace.
        seed: Seed passed to ``random.seed``; defaults to 10, the value that
            was previously hard-coded.

    Returns:
        ``(masked_sentence, filtered_logits_list, top_words)`` where
        ``top_words`` are the decoded ids of the five largest filtered logits in
        ascending order (best last), or ``(sentence, None, None)`` when every
        word is a stopword.
    """
    stop_words = set(stopwords.words('english'))
    words = sentence.split()
    non_stop_words = [word for word in words if word.lower() not in stop_words]
    if not non_stop_words:
        return sentence, None, None
    random.seed(seed)  # Fixed seed for pseudo-randomness
    word_to_mask = random.choice(non_stop_words)
    # Match the word only as a whole whitespace-delimited token: str.replace would
    # also hit substrings (masking "cat" would corrupt an earlier "concatenate").
    masked_sentence = re.sub(
        rf"(?<!\S){re.escape(word_to_mask)}(?!\S)", '[MASK]', sentence, count=1
    )
    logits = get_logits_for_mask(model, tokenizer, masked_sentence)
    filtered_logits = filter_logits(logits, permissible_indices)
    # argsort is ascending, so the last five ids are the highest-scoring ones.
    words = [tokenizer.decode([i]) for i in filtered_logits.argsort()[-5:]]
    return masked_sentence, filtered_logits.tolist(), words

# New function: mask words between LCS points
def mask_between_lcs(sentence, lcs_points):
    """Mask one random word in each gap around/between LCS word positions and score fillers.

    Words are split on whitespace. ``lcs_points`` are word indices; they are
    assumed to be sorted ascending and smaller than the number of words
    (TODO confirm with callers). One word is chosen with the global ``random``
    module from each non-empty gap: before the first point, strictly between
    consecutive points, and after the last point. The masked sentence is
    re-joined with single spaces.

    Returns:
        ``(masked_sentence, logits_list, top_words_list)``: for each mask, in
        the order masks appear in the text, the ``filter_logits`` output as a
        list and the five highest-scoring decoded tokens (highest first). Both
        lists are empty when no word was masked.
    """
    words = sentence.split()
    masked_indices = []

    # Gap before the first LCS point.
    if lcs_points and lcs_points[0] > 0:
        idx = random.randint(0, lcs_points[0] - 1)
        words[idx] = '[MASK]'
        masked_indices.append(idx)

    # Gaps strictly between consecutive LCS points (skipped when they are adjacent).
    for start, end in zip(lcs_points, lcs_points[1:]):
        if end - start > 1:
            idx = random.randint(start + 1, end - 1)
            words[idx] = '[MASK]'
            masked_indices.append(idx)

    # Gap after the last LCS point.
    if lcs_points and lcs_points[-1] < len(words) - 1:
        idx = random.randint(lcs_points[-1] + 1, len(words) - 1)
        words[idx] = '[MASK]'
        masked_indices.append(idx)

    masked_sentence = ' '.join(words)
    if not masked_indices:
        # Nothing to predict; skip the forward pass (result is the same).
        return masked_sentence, [], []

    logits = get_logits_for_mask(model, tokenizer, masked_sentence)
    # get_logits_for_mask squeezes its output, so a single mask comes back as a
    # 1-D row; logits[0] would then be a scalar. Restore the per-mask axis.
    if logits.dim() == 1:
        logits = logits.unsqueeze(0)

    top_words_list = []
    logits_list = []
    # Row k of `logits` is the k-th [MASK] in the text; masked_indices is in the
    # same order when lcs_points is sorted.
    for mask_no in range(len(masked_indices)):
        filtered = filter_logits(logits[mask_no], permissible_indices)
        logits_list.append(filtered.tolist())
        top_ids = filtered.topk(5).indices.tolist()
        top_words_list.append([tokenizer.decode([tok]) for tok in top_ids])

    return masked_sentence, logits_list, top_words_list


def high_entropy_words(sentence, non_melting_points):
    """Mask the candidate word whose masked-LM prediction is most uncertain.

    Candidates are whitespace-delimited words that are neither NLTK English
    stopwords nor (case-insensitively) one of the words of any non-melting
    point. Each candidate is masked in turn and scored by the entropy terms
    summed over the five most probable filtered tokens. The top-5
    probabilities are not renormalized. The highest-scoring candidate is the
    one masked in the result.

    Args:
        sentence: Input text.
        non_melting_points: Iterable of pairs whose second element is a phrase
            string; the words of those phrases are never masked. The first
            element of each pair is ignored.

    Returns:
        ``(masked_sentence, filtered_logits_list, top_words)`` with
        ``top_words`` the five highest-scoring decoded tokens in ascending
        order (best last), or ``(sentence, None, None)`` when no candidate
        exists.
    """
    stop_words = set(stopwords.words('english'))

    non_melting_words = set()
    for _, point in non_melting_points:
        non_melting_words.update(point.lower().split())

    # dict.fromkeys drops repeated words while keeping order. A repeat would mask
    # the same (first) occurrence and score identically, so it only cost a forward pass.
    candidate_words = list(dict.fromkeys(
        word for word in sentence.split()
        if word.lower() not in stop_words and word.lower() not in non_melting_words
    ))
    if not candidate_words:
        return sentence, None, None

    def mask_first(word):
        # Whole-token match; str.replace would also hit substrings of longer words.
        return re.sub(rf"(?<!\S){re.escape(word)}(?!\S)", '[MASK]', sentence, count=1)

    max_entropy = -float('inf')
    max_entropy_word = None
    max_logits = None

    for word in candidate_words:
        logits = get_logits_for_mask(model, tokenizer, mask_first(word))
        filtered_logits = filter_logits(logits, permissible_indices)

        # Entropy terms of the 5 most likely tokens. xlogy(p, p) is 0 at p == 0,
        # whereas p * log(p) gives nan there and the candidate would be silently skipped.
        probs = torch.softmax(filtered_logits, dim=-1)
        top_5_probs = probs.topk(5).values
        entropy = -torch.xlogy(top_5_probs, top_5_probs).sum()

        if entropy > max_entropy:
            max_entropy = entropy
            max_entropy_word = word
            max_logits = filtered_logits

    if max_entropy_word is None:
        return sentence, None, None

    masked_sentence = mask_first(max_entropy_word)
    # argsort is ascending, so the last five ids are the highest-scoring ones.
    words = [tokenizer.decode([i]) for i in max_logits.argsort()[-5:]]
    return masked_sentence, max_logits.tolist(), words

# New function: mask based on part of speech
def mask_by_pos(sentence, pos_to_mask=('NOUN', 'VERB', 'ADJ')):
    """Mask one random token whose part-of-speech tag is in ``pos_to_mask``.

    Args:
        sentence: Input text, tokenized with ``nltk.word_tokenize`` and tagged
            with ``nltk.pos_tag``.
        pos_to_mask: Tags to mask. Accepts universal names ('NOUN', 'VERB',
            'ADJ', 'ADV') and/or two-letter Penn Treebank prefixes ('NN', 'VB',
            'JJ', 'RB', ...).

    Returns:
        ``(masked_sentence, filtered_logits_list, top_words)`` where
        ``top_words`` are the decoded ids of the five largest filtered logits in
        ascending order (best last), or ``(sentence, None, None)`` when no token
        has a requested tag.
    """
    import nltk
    nltk.download('averaged_perceptron_tagger', quiet=True)

    # nltk.pos_tag emits Penn Treebank tags (NN, NNS, VBD, JJ, ...). Their two-letter
    # prefix never equals a universal name such as 'NOUN', so map it and accept
    # either spelling.
    penn_prefix_to_universal = {'NN': 'NOUN', 'VB': 'VERB', 'JJ': 'ADJ', 'RB': 'ADV'}

    words = nltk.word_tokenize(sentence)
    pos_tags = nltk.pos_tag(words)

    maskable_words = []
    for word, pos in pos_tags:
        prefix = pos[:2]
        universal = penn_prefix_to_universal.get(prefix)
        if prefix in pos_to_mask or (universal is not None and universal in pos_to_mask):
            maskable_words.append(word)

    if not maskable_words:
        return sentence, None, None

    word_to_mask = random.choice(maskable_words)
    # Prefer an occurrence not glued to other word characters, so "cat" does not
    # match inside "concatenate". Fall back to plain substring replacement for
    # tokens that word_tokenize split out of a larger word (e.g. "n't" in "don't").
    masked_sentence, n_replaced = re.subn(
        rf"(?<!\w){re.escape(word_to_mask)}(?!\w)", '[MASK]', sentence, count=1
    )
    if n_replaced == 0:
        masked_sentence = sentence.replace(word_to_mask, '[MASK]', 1)

    logits = get_logits_for_mask(model, tokenizer, masked_sentence)
    filtered_logits = filter_logits(logits, permissible_indices)
    # argsort is ascending, so the last five ids are the highest-scoring ones.
    words = [tokenizer.decode([i]) for i in filtered_logits.argsort()[-5:]]

    return masked_sentence, filtered_logits.tolist(), words

# New function: mask named entities
def mask_named_entity(sentence):
    """Mask one random word that belongs to a named-entity chunk.

    The sentence is tokenized with ``nltk.word_tokenize``, POS-tagged, and
    chunked with ``nltk.ne_chunk``. Every word inside any entity chunk is a
    candidate, and one is chosen with the global ``random`` module.

    Returns:
        ``(masked_sentence, filtered_logits_list, top_words)`` where
        ``top_words`` are the decoded ids of the five largest filtered logits in
        ascending order (best last), or ``(sentence, None, None)`` when no named
        entity is found.
    """
    import nltk
    # pos_tag below needs the tagger model, just like ne_chunk needs the chunker.
    nltk.download('averaged_perceptron_tagger', quiet=True)
    nltk.download('maxent_ne_chunker', quiet=True)
    nltk.download('words', quiet=True)

    words = nltk.word_tokenize(sentence)
    pos_tags = nltk.pos_tag(words)
    named_entities = nltk.ne_chunk(pos_tags)

    # The chunk tree's direct children are either (word, POS) tuples or Tree
    # chunks for entities. leaves() flattens everything to (word, POS-string)
    # tuples, so entity membership has to be read from the children instead.
    maskable_words = []
    for chunk in named_entities:
        if isinstance(chunk, nltk.Tree):
            maskable_words.extend(word for word, _ in chunk.leaves())

    if not maskable_words:
        return sentence, None, None

    word_to_mask = random.choice(maskable_words)
    # Prefer an occurrence not glued to other word characters, so "Al" does not
    # match inside "Alice". Fall back to plain substring replacement otherwise.
    masked_sentence, n_replaced = re.subn(
        rf"(?<!\w){re.escape(word_to_mask)}(?!\w)", '[MASK]', sentence, count=1
    )
    if n_replaced == 0:
        masked_sentence = sentence.replace(word_to_mask, '[MASK]', 1)

    logits = get_logits_for_mask(model, tokenizer, masked_sentence)
    filtered_logits = filter_logits(logits, permissible_indices)
    # argsort is ascending, so the last five ids are the highest-scoring ones.
    words = [tokenizer.decode([i]) for i in filtered_logits.argsort()[-5:]]

    return masked_sentence, filtered_logits.tolist(), words


# sentence = "This is a sample sentence with some LCS points"
# lcs_points = [2, 5, 8]  # Indices of LCS points
# masked_sentence, logits_list, top_words_list = mask_between_lcs(sentence, lcs_points)

# print("Masked Sentence:", masked_sentence)
# for idx, top_words in enumerate(top_words_list):
#     print(f"Top words for mask {idx+1}:", top_words)