import torch
import spacy
import en_core_web_sm
from torch import nn
import math


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

from transformers import AutoModel, TrainingArguments, Trainer, RobertaTokenizer, RobertaModel
from transformers import AutoTokenizer

model_checkpoint = "ehsanaghaei/SecureBERT"

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
roberta_model = RobertaModel.from_pretrained(model_checkpoint).to(device)

nlp = en_core_web_sm.load()
pos_spacy_tag_list = ["ADJ","ADP","ADV","AUX","CCONJ","DET","INTJ","NOUN","NUM","PART","PRON","PROPN","PUNCT","SCONJ","SYM","VERB","SPACE","X"]
ner_spacy_tag_list = [bio + entity for entity in list(nlp.get_pipe('ner').labels) for bio in ["B-", "I-"]] + ["O"]


class CustomRobertaWithPOS(nn.Module):
    """Token classifier: SecureBERT (RoBERTa) hidden states concatenated with
    one-hot POS and NER features produced by spaCy."""

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes
        self.roberta = roberta_model
        self.dropout1 = nn.Dropout(0.2)
        # The classification head sees the contextual embedding plus the
        # one-hot POS and NER features built in forward().
        self.fc1 = nn.Linear(
            self.roberta.config.hidden_size + len(pos_spacy_tag_list) + len(ner_spacy_tag_list),
            num_classes,
        )

    def forward(self, input_ids, attention_mask, pos_spacy, ner_spacy, dep_spacy, depth_spacy):
        # dep_spacy and depth_spacy arrive with the batch but are not used by this head.
        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_output = outputs.last_hidden_state

        # Positions labelled -100 (special tokens / padding) keep an all-zero feature vector.
        pos_mask = pos_spacy != -100
        pos_one_hot = torch.zeros(
            (pos_spacy.shape[0], pos_spacy.shape[1], len(pos_spacy_tag_list)),
            dtype=last_hidden_output.dtype,
            device=last_hidden_output.device,
        )
        pos_one_hot[pos_mask, pos_spacy[pos_mask]] = 1

        ner_mask = ner_spacy != -100
        ner_one_hot = torch.zeros(
            (ner_spacy.shape[0], ner_spacy.shape[1], len(ner_spacy_tag_list)),
            dtype=last_hidden_output.dtype,
            device=last_hidden_output.device,
        )
        ner_one_hot[ner_mask, ner_spacy[ner_mask]] = 1

        # Concatenate the contextual embeddings with the linguistic features.
        features_concat = torch.cat((last_hidden_output, pos_one_hot, ner_one_hot), dim=-1)
        features_concat = self.dropout1(features_concat)

        logits = self.fc1(features_concat)

        return logits
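

# Minimal usage sketch (assumption: the class size and random inputs below are
# illustrative only; nothing in this file defines `num_classes`). It shows the
# shapes the forward pass expects: every tensor is (batch, seq_len), with the POS
# and NER ids indexing into the tag lists defined above.
def _sketch_forward_pass(num_classes=10, seq_len=16):
    model = CustomRobertaWithPOS(num_classes=num_classes).to(device)
    input_ids = torch.randint(0, tokenizer.vocab_size, (1, seq_len), device=device)
    attention_mask = torch.ones((1, seq_len), dtype=torch.long, device=device)
    pos_spacy = torch.randint(0, len(pos_spacy_tag_list), (1, seq_len), device=device)
    ner_spacy = torch.randint(0, len(ner_spacy_tag_list), (1, seq_len), device=device)
    dep_spacy = torch.zeros((1, seq_len), dtype=torch.long, device=device)
    depth_spacy = torch.zeros((1, seq_len), dtype=torch.long, device=device)
    with torch.no_grad():
        logits = model(input_ids, attention_mask, pos_spacy, ner_spacy, dep_spacy, depth_spacy)
    return logits.shape  # torch.Size([1, seq_len, num_classes])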


def tokenize_and_align_labels_with_pos_ner_dep(examples, tokenizer, label_all_tokens=True):
    """Tokenize pre-split words and align the word-level spaCy features
    (POS, NER, dependency relation, dependency depth) to the sub-word tokens."""
    tokenized_inputs = tokenizer(examples["tokens"], padding='max_length', truncation=True, is_split_into_words=True)
    ner_spacy = []
    pos_spacy = []
    dep_spacy = []
    depth_spacy = []

    for i, (pos, ner, dep, depth) in enumerate(zip(examples["pos_spacy"], 
                                                   examples["ner_spacy"], 
                                                   examples["dep_spacy"], 
                                                   examples["depth_spacy"])):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        previous_word_idx = None
        ner_spacy_ids = []
        pos_spacy_ids = []
        dep_spacy_ids = []
        depth_spacy_ids = []

        for word_idx in word_ids:
            # Special tokens have a word id that is None. We set the label to -100 so they are automatically
            # ignored in the loss function.
            if word_idx is None:
                ner_spacy_ids.append(-100)
                pos_spacy_ids.append(-100)
                dep_spacy_ids.append(-100)
                depth_spacy_ids.append(-100)
            # We set the label for the first token of each word.
            elif word_idx != previous_word_idx:
                ner_spacy_ids.append(ner[word_idx])
                pos_spacy_ids.append(pos[word_idx])
                dep_spacy_ids.append(dep[word_idx])
                depth_spacy_ids.append(depth[word_idx])
            # For the other tokens in a word, we set the label to either the current label or -100, depending on
            # the label_all_tokens flag.
            else:
                ner_spacy_ids.append(ner[word_idx] if label_all_tokens else -100)
                pos_spacy_ids.append(pos[word_idx] if label_all_tokens else -100)
                dep_spacy_ids.append(dep[word_idx] if label_all_tokens else -100)
                depth_spacy_ids.append(depth[word_idx] if label_all_tokens else -100)
            previous_word_idx = word_idx

        ner_spacy.append(ner_spacy_ids)
        pos_spacy.append(pos_spacy_ids)
        dep_spacy.append(dep_spacy_ids)
        depth_spacy.append(depth_spacy_ids)

    tokenized_inputs["pos_spacy"] = pos_spacy
    tokenized_inputs["ner_spacy"] = ner_spacy
    tokenized_inputs["dep_spacy"] = dep_spacy
    tokenized_inputs["depth_spacy"] = depth_spacy

    return tokenized_inputs
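

# Sketch of how the alignment function is typically applied (assumption: a
# `datasets.Dataset` named `raw_dataset`, with "tokens", "pos_spacy", "ner_spacy",
# "dep_spacy" and "depth_spacy" columns, is built elsewhere and is not part of
# this file).
def _tokenize_dataset(raw_dataset):
    from functools import partial
    return raw_dataset.map(
        partial(tokenize_and_align_labels_with_pos_ner_dep, tokenizer=tokenizer),
        batched=True,
    )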


def find_nearest_nugget_features(doc, start_idx, end_idx, event_nuggets):
    """For the span [start_idx, end_idx] (character offsets), return the subtype of
    the nearest event nugget, the bucketed distance to it (0-10, in 20-character
    steps) and the nugget's position relative to the span (before/after,
    same/different sentence)."""
    nearest_subtype = None
    nearest_dist = math.inf
    relative_pos = None

    mid_idx = (end_idx + start_idx) / 2
    for nugget in event_nuggets:
        mid_nugget_idx = (nugget["nugget"]["startOffset"] + nugget["nugget"]["endOffset"]) / 2
        dist = abs(mid_nugget_idx - mid_idx)

        if dist < nearest_dist:
            nearest_dist = dist
            nearest_subtype = nugget["subtype"]
            for sent in doc.sents:
                if between_idxs(mid_idx, sent.start_char, sent.end_char) and between_idxs(mid_nugget_idx, sent.start_char, sent.end_char):
                    if mid_idx < mid_nugget_idx:
                        relative_pos = "before-same-sentence"
                    else:
                        relative_pos = "after-same-sentence"
                    break
                elif between_idxs(mid_nugget_idx, sent.start_char, sent.end_char) and mid_idx > mid_nugget_idx:
                    relative_pos = "after-differ-sentence"
                    break
                elif between_idxs(mid_idx, sent.start_char, sent.end_char) and mid_idx < mid_nugget_idx:
                    relative_pos = "before-differ-sentence"
                    break

    # Bucket the character distance into 0-10 (each bucket is 20 characters wide).
    nearest_dist = int(min(10, nearest_dist // 20))
    return nearest_subtype, nearest_dist, relative_pos
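

# Illustrative call (assumption: the sentence and character offsets below are made
# up; the dict layout, a "nugget" entry with "startOffset"/"endOffset" plus a
# "subtype", matches what the function above indexes into).
def _example_nearest_nugget():
    doc = nlp("Attackers sent phishing emails. Victims then lost credentials.")
    event_nuggets = [{"nugget": {"startOffset": 15, "endOffset": 30}, "subtype": "Phishing"}]
    # The query span covers "then lost", which sits in the sentence after the nugget,
    # so this returns something like ("Phishing", 1, "after-differ-sentence").
    return find_nearest_nugget_features(doc, 40, 48, event_nuggets)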

def find_dep_depth(token):
    """Depth of a token in its dependency tree (root = 0), capped at 16."""
    depth = 0
    current_token = token
    while current_token.head != current_token:
        depth += 1
        current_token = current_token.head
    return min(depth, 16)


def between_idxs(idx, start_idx, end_idx):
    return start_idx <= idx <= end_idx
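

# Small sketch (assumption: the sentence is illustrative only): per-token dependency
# depths as computed by find_dep_depth, i.e. the values that end up in the
# "depth_spacy" feature.
def _example_dep_depths(text="The attacker exploited a known flaw."):
    doc = nlp(text)
    return [(token.text, find_dep_depth(token)) for token in doc]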