from vncorenlp import VnCoreNLP

from typing import Union
from transformers import AutoConfig, AutoTokenizer
from Model.NER.VLSP2021.Ner_CRF import PhoBertCrf,PhoBertSoftmax,PhoBertLstmCrf
import re
import os
import torch
import itertools
import numpy as np

# Registry of supported models: pretrained encoder name -> architecture key ->
# model class. ViTagger.load_model indexes it with the checkpoint's
# `args.model_name_or_path` and `args.model_arch`.
MODEL_MAPPING = {
    'vinai/phobert-base': dict(
        softmax=PhoBertSoftmax,
        crf=PhoBertCrf,
        lstm_crf=PhoBertLstmCrf,
    ),
}


def normalize_text(txt: str) -> str:
    """Return *txt* cleaned up for the tagger.

    The substitutions below are applied one after another, in order:

    * soft hyphens, zero-width spaces and byte-order marks are deleted;
    * in the lowercase clusters ``oa``/``oe``/``uy`` (plus the single
      capitalised form ``Ủy``) the tone mark is moved from the first vowel
      to the second one;
    * every straight double quote becomes ``”``;
    * runs of spaces collapse to one space.

    Finally the result is stripped of leading and trailing whitespace.
    """
    # Ordered (regex pattern, replacement) pairs.
    substitutions = (
        # Remove special character
        ("\xad|\u200b|\ufeff", ""),
        # Normalize vietnamese accents
        (r"òa", "oà"),
        (r"óa", "oá"),
        (r"ỏa", "oả"),
        (r"õa", "oã"),
        (r"ọa", "oạ"),
        (r"òe", "oè"),
        (r"óe", "oé"),
        (r"ỏe", "oẻ"),
        (r"õe", "oẽ"),
        (r"ọe", "oẹ"),
        (r"ùy", "uỳ"),
        (r"úy", "uý"),
        (r"ủy", "uỷ"),
        (r"ũy", "uỹ"),
        (r"ụy", "uỵ"),
        (r"Ủy", "Uỷ"),
        # Straight double quote -> right curly quote
        (r'"', '”'),
        # Remove multi-space
        (" +", " "),
    )
    for pattern, replacement in substitutions:
        txt = re.sub(pattern, replacement, txt)
    return txt.strip()
class ViTagger(object):
    """Named-entity tagger combining VnCoreNLP word segmentation with a
    token-classification checkpoint listed in ``MODEL_MAPPING``.

    Calling an instance on raw text returns the recognised entities as
    ``(text, entity_type)`` tuples; ``extract_entity_doc`` additionally
    returns the words outside any entity, tagged ``'O'``.
    """

    def __init__(self, model_path: Union[str, os.PathLike],  no_cuda=False):
        """Start the word segmenter and load the checkpoint.

        Args:
            model_path: Path of the checkpoint file handed to ``torch.load``.
            no_cuda: When true, run on CPU even if CUDA is available.
        """
        self.device = 'cuda' if not no_cuda and torch.cuda.is_available() else 'cpu'
        print("[ViTagger] VnCoreNLP loading ...")
        self.rdrsegmenter = VnCoreNLP("VnCoreNLP/VnCoreNLP-1.1.1.jar", annotators="wseg", max_heap_size='-Xmx500m')
        print("[ViTagger] Model loading ...")
        self.model, self.tokenizer,  self.max_seq_len, self.label2id, self.use_crf = self.load_model(model_path, device=self.device)
        if isinstance(self.label2id, dict):
            # Invert using the stored ids rather than relying on the
            # iteration order of the mapping matching the id order.
            self.id2label = {idx: label for label, idx in self.label2id.items()}
        else:
            self.id2label = dict(enumerate(self.label2id))
        print("[ViTagger] All ready!")

    @staticmethod
    def load_model(model_path: Union[str, os.PathLike],  device='cpu'):
        """Rebuild the model and tokenizer described by a checkpoint.

        The checkpoint is expected to hold ``"args"`` (with
        ``max_seq_length``, ``model_arch``, ``model_name_or_path`` and
        ``label2id``) and ``"model"`` (the state dict).

        Args:
            model_path: Path of the checkpoint file.
            device: Device the model is moved to; ``'cpu'`` also forces the
                checkpoint tensors to be mapped onto the CPU.

        Returns:
            ``(model, tokenizer, max_seq_len, label2id, use_crf)``.

        Raises:
            KeyError: If the checkpoint's encoder name or architecture is
                not registered in ``MODEL_MAPPING``.
        """
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from a trusted source.
        checkpoint_data = torch.load(model_path, map_location='cpu' if device == 'cpu' else None)
        args = checkpoint_data["args"]
        max_seq_len = args.max_seq_length
        use_crf = 'crf' in args.model_arch
        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False)
        config = AutoConfig.from_pretrained(args.model_name_or_path, num_labels=len(args.label2id))
        model_clss = MODEL_MAPPING[args.model_name_or_path][args.model_arch]
        model = model_clss(config=config)
        # strict=False: keys missing from / unexpected in the checkpoint are
        # tolerated instead of raising.
        model.load_state_dict(checkpoint_data['model'], strict=False)
        model.to(device)
        model.eval()

        return model, tokenizer, max_seq_len, args.label2id, use_crf

    def preprocess(self, in_raw: str):
        """Normalise ``in_raw`` and return the segmenter's sentences as a list."""
        norm_text = normalize_text(in_raw)
        return list(self.rdrsegmenter.tokenize(norm_text))

    def convert_tensor(self, tokens):
        """Encode one segmented sentence into the model's keyword inputs.

        Args:
            tokens: The words of a single sentence.

        Returns:
            A dict with every tokenizer output plus ``valid_ids`` and
            ``label_masks``, each a ``torch.long`` tensor with a leading
            batch dimension of 1, placed on ``self.device``.

        Raises:
            ValueError: In CRF mode, when the first label-mask position is 0
                (this happens for an empty sentence).
        """
        seq_len = len(tokens)
        encoding = self.tokenizer(tokens,
                                  padding='max_length',
                                  truncation=True,
                                  is_split_into_words=True,
                                  max_length=self.max_seq_len)
        valid_ids = np.zeros(len(encoding.input_ids), dtype=int)
        label_marks = np.zeros(len(encoding.input_ids), dtype=int)
        # Number of word-initial subwords recorded so far.
        word_pos = 0
        if 'vinai/phobert' in self.tokenizer.name_or_path:
            subwords = self.tokenizer.tokenize(' '.join(tokens))
            for idx in range(len(subwords[:self.max_seq_len - 2])):
                # A subword following one that ends in "@@" continues the
                # same word, so it is not a word start.
                if idx != 0 and subwords[idx - 1].endswith("@@"):
                    continue
                if self.use_crf:
                    # CRF mode: store the encoded position (+1 skips the
                    # leading special token) of each word's first subword.
                    valid_ids[word_pos] = idx + 1
                else:
                    # Otherwise: flag the first-subword positions with 1.
                    valid_ids[idx + 1] = 1
                word_pos += 1
        else:
            word_ids = encoding.word_ids()
            for idx in range(1, len(word_ids)):
                if word_ids[idx] is not None and word_ids[idx] != word_ids[idx - 1]:
                    if self.use_crf:
                        valid_ids[word_pos] = idx
                    else:
                        valid_ids[idx] = 1
                    word_pos += 1
        # Mark one label slot per word; when the sentence does not fit,
        # mark everything except the last two slots.
        if self.max_seq_len >= seq_len + 2:
            label_marks[:seq_len] = 1
        else:
            label_marks[:-2] = 1
        if self.use_crf and label_marks[0] == 0:
            raise ValueError(f"{tokens} have mark == 0 at index 0!")
        item = {key: torch.as_tensor([val]).to(self.device, dtype=torch.long) for key, val in encoding.items()}
        item['valid_ids'] = torch.as_tensor([valid_ids]).to(self.device, dtype=torch.long)
        item['label_masks'] = torch.as_tensor([label_marks]).to(self.device, dtype=torch.long)
        return item

    def _predict_tags(self, tokens):
        """Run the model on one sentence and return its label ids as a flat sequence."""
        item = self.convert_tensor(tokens)
        with torch.no_grad():
            outputs = self.model(**item)
        tags = outputs.tags
        # The model may return one list of ids per batch element.
        if isinstance(tags[0], list):
            return list(itertools.chain.from_iterable(tags))
        return tags

    @staticmethod
    def _collect_entities(words, labels, include_outside=False):
        """Merge per-word labels of one sentence into ``(text, type)`` tuples.

        Args:
            words: Segmented words; ``_`` inside a word is shown as a space.
            labels: Label strings aligned with ``words``: ``'O'`` or
                ``'<prefix>-<type>'``. Only the first hyphen separates the
                prefix, so the type itself may contain hyphens.
            include_outside: Also emit ``(word, 'O')`` for words outside any
                entity (words that are a single space are skipped).

        Returns:
            The entities (and optionally outside words) in sentence order.
        """
        entities = []
        current = None
        for word, label in zip(words, labels):
            word = word.replace("_", " ")
            if label == 'O':
                if current is not None:
                    entities.append(current)
                    current = None
                if include_outside and word != ' ':
                    entities.append((word, 'O'))
                continue
            prefix, _, entity_type = label.partition('-')
            if current is not None and prefix == 'I' and current[1] == entity_type:
                # Continuation of the entity being built.
                current = (current[0] + f' {word}', entity_type)
            else:
                if current is not None:
                    entities.append(current)
                current = (word, entity_type)
        # An entity reaching the end of the sentence must not be lost.
        if current is not None:
            entities.append(current)
        return entities

    def extract_entity_doc(self, in_raw: str):
        """Tag ``in_raw`` and return every word or entity in reading order.

        Returns:
            A list of ``(text, tag)`` tuples where ``tag`` is the entity
            type, or ``'O'`` for words outside any entity.
        """
        entities_doc = []
        for sent in self.preprocess(in_raw):
            # Lazy lookup: only labels paired with a word are resolved.
            labels = (self.id2label[tag] for tag in self._predict_tags(sent))
            entities_doc.extend(self._collect_entities(sent, labels, include_outside=True))
        return entities_doc

    def __call__(self, in_raw: str):
        """Tag ``in_raw`` and return only the entities as ``(text, type)`` tuples."""
        entities = []
        for sent in self.preprocess(in_raw):
            # Lazy lookup: only labels paired with a word are resolved.
            labels = (self.id2label[tag] for tag in self._predict_tags(sent))
            entities.extend(self._collect_entities(sent, labels))
        return entities