import re
from typing import Dict, List, Optional, Tuple, Union

import torch

from utils.finetune import Graph2TextModule

if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
    print('CUDA NOT AVAILABLE')

CHECKPOINT = 'base/t5-base_13881_val_avg_bleu=68.1000-step_count=5.ckpt'
MAX_LENGTH = 384
SEED = 42


class VerbModule:

    def __init__(self, override_args: Optional[Dict[str, str]] = None):
        # Model
        if not override_args:
            override_args = {}
        self.g2t_module = Graph2TextModule.load_from_checkpoint(CHECKPOINT, strict=False, **override_args)
        self.tokenizer = self.g2t_module.tokenizer
        # Unk replacer
        self.vocab = self.tokenizer.get_vocab()
        self.convert_some_japanese_characters = True
        self.unk_char_replace_sliding_window_size = 2
        self.unknowns = []

    def __generate_verbalisations_from_inputs(self, inputs: Union[str, List[str]]):
        try:
            inputs_encoding = self.tokenizer.prepare_seq2seq_batch(
                inputs, truncation=True, max_length=MAX_LENGTH, return_tensors='pt'
            )
            inputs_encoding = {k: v.to(DEVICE) for k, v in inputs_encoding.items()}
            
            self.g2t_module.model.eval()
            with torch.no_grad():
                gen_output = self.g2t_module.model.generate(
                    inputs_encoding['input_ids'],
                    attention_mask=inputs_encoding['attention_mask'],
                    use_cache=True,
                    decoder_start_token_id=self.g2t_module.decoder_start_token_id,
                    num_beams=self.g2t_module.eval_beams,
                    max_length=self.g2t_module.eval_max_length,
                    length_penalty=1.0
                )
        except Exception:
            print(inputs)
            raise

        return gen_output
    
    def __decode_ids_to_string_custom(
        self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
    ) -> str:
        '''
        Adapted from [this function](https://github.com/huggingface/transformers/blob/198c335d219a5eb4d3f124fdd1ce1a9cd9f78a9b/src/transformers/tokenization_utils_fast.py#L537), mainly because the official 'tokenizer.decode' treats all special tokens the same, while we want to drop all special tokens from the decoded sentence EXCEPT for the <unk> token, which we will replace later on.
        '''
        # Convert the ids to tokens, keeping special tokens for now
        filtered_tokens = self.tokenizer.convert_ids_to_tokens(token_ids, skip_special_tokens=False)

        # To avoid mixing byte-level and unicode for byte-level BPE,
        # we need to build the string separately for added tokens and byte-level tokens
        # cf. https://github.com/huggingface/transformers/issues/1133
        sub_texts = []
        current_sub_text = []
        for token in filtered_tokens:
            if (
                skip_special_tokens
                and token != self.tokenizer.unk_token
                and token in self.tokenizer.all_special_tokens
            ):
                continue
            current_sub_text.append(token)
        if current_sub_text:
            sub_texts.append(self.tokenizer.convert_tokens_to_string(current_sub_text))
        text = " ".join(sub_texts)

        if clean_up_tokenization_spaces:
            clean_text = self.tokenizer.clean_up_tokenization(text)
            return clean_text
        else:
            return text
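
    # Illustrative example (the token ids and their mapping are assumptions, following
    # the common T5 convention of <unk> = 2 and </s> = 1): for ids decoding to
    # 'the <unk> tower </s>', __decode_ids_to_string_custom(ids, skip_special_tokens=True)
    # returns 'the <unk> tower', keeping <unk> so it can be replaced later, whereas
    # tokenizer.decode(ids, skip_special_tokens=True) would drop it.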

    def __decode_sentences(self, encoded_sentences: Union[str, List[str]]):
        if isinstance(encoded_sentences, str):
            encoded_sentences = [encoded_sentences]
        decoded_sentences = [self.__decode_ids_to_string_custom(i, skip_special_tokens=True) for i in encoded_sentences]
        return decoded_sentences
        
    def verbalise_sentence(self, inputs: Union[str, List[str]]):
        if isinstance(inputs, str):
            inputs = [inputs]
        
        gen_output = self.__generate_verbalisations_from_inputs(inputs)
        
        decoded_sentences = self.__decode_sentences(gen_output)

        if len(decoded_sentences) == 1:
            return decoded_sentences[0]
        else:
            return decoded_sentences

    def verbalise_triples(self, input_triples: Union[Dict[str, str], List[Dict[str, str]], List[List[Dict[str, str]]]]):
        if isinstance(input_triples, dict):
            input_triples = [input_triples]

        verbalisation_inputs = []
        for triple in input_triples:
            if isinstance(triple, dict):
                assert 'subject' in triple
                assert 'predicate' in triple
                assert 'object' in triple
                verbalisation_inputs.append(
                    f'translate Graph to English: <H> {triple["subject"]} <R> {triple["predicate"]} <T> {triple["object"]}'
                )
            elif isinstance(triple, list):
                input_sentence = ['translate Graph to English:']
                for subtriple in triple:
                    assert 'subject' in subtriple
                    assert 'predicate' in subtriple
                    assert 'object' in subtriple
                    input_sentence.append(f'<H> {subtriple["subject"]}')
                    input_sentence.append(f'<R> {subtriple["predicate"]}')
                    input_sentence.append(f'<T> {subtriple["object"]}')
                verbalisation_inputs.append(
                    ' '.join(input_sentence)
                )

        return self.verbalise_sentence(verbalisation_inputs)
        
    def verbalise(self, input: Union[str, List, Dict]):
        try:
            if isinstance(input, str) or (isinstance(input, list) and isinstance(input[0], str)):
                return self.verbalise_sentence(input)
            elif isinstance(input, dict) or (isinstance(input, list) and isinstance(input[0], dict)):
                return self.verbalise_triples(input)
            else:
                return self.verbalise_triples(input)
        except Exception:
            print(f'ERROR VERBALISING {input}')
            raise
                
    def add_label_to_unk_replacer(self, label: str):
        N = self.unk_char_replace_sliding_window_size
        self.unknowns.append({})
        
        # Some pre-processing of labels to normalise some characters
        if self.convert_some_japanese_characters:
            label = label.replace('(','(')
            label = label.replace(')',')')
            label = label.replace('〈','<')
            label = label.replace('/','/')
            label = label.replace('〉','>')        
        
        label_encoded = self.tokenizer.encode(label)
        label_tokens = self.tokenizer.convert_ids_to_tokens(label_encoded)
        
        # Here, we also remove the </s> (eos) and <pad> tokens from the replacing key, because:
        # 1) When the whole label is unks:
        #    label_token_to_string would be '<unk></s>', meaning the replacing key (which is the same) would only replace
        #    the <unk> if it appeared at the end of the sentence, which is not the desired effect.
        #    And since such a key would match ANY <unk>, it is best to apply all-<unk> keys only
        #    on the last replacing pass.
        # 2) In other cases, when the unk sits at the start or end of the label rather than covering all of it, the replacing key
        #    might include the starting <pad> token or the ending </s> token, again forcing the replacement
        #    to only happen if the label appears at the end of the sentence.
        label_tokens = [t for t in label_tokens if t not in [
            self.tokenizer.eos_token, self.tokenizer.pad_token
        ]]
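        # Illustrative example (the label here is an assumption): a label made only of
        # characters outside the T5 vocabulary may encode to something like
        # ['▁', '<unk>', '</s>']; dropping '</s>' (and any '<pad>') leaves ['▁', '<unk>'],
        # so the replacing key can match the <unk> anywhere in the generated sentence
        # rather than only at its very end.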

        label_token_to_string = self.tokenizer.convert_tokens_to_string(label_tokens)
        unk_token_to_string = self.tokenizer.convert_tokens_to_string([self.tokenizer.unk_token])
                
        #print(label_encoded,label_tokens,label_token_to_string)
        
        match_unks_in_label = re.findall('(?:(?: )*<unk>(?: )*)+', label_token_to_string)
        if len(match_unks_in_label) > 0:
            # If the whole label is made of UNK
            if (match_unks_in_label[0]) == label_token_to_string:
                #print('Label is all unks')                    
                self.unknowns[-1][label_token_to_string.strip()] = label
            # Else, there should be non-UNK characters in the label
            else:
                #print('Label is NOT all unks')
                # Analyse the label with a sliding window of size N (N before, N ahead)
                for idx, token in enumerate(label_tokens):
                    idx_before = max(0,idx-N)
                    idx_ahead = min(len(label_tokens), idx+N+1)
                    
                                       
                    # Found a UNK
                    if token == self.tokenizer.unk_token:
                        
                        # In case multiple UNK, exclude UNKs seen after this one, expand window to other side if possible
                        if len(match_unks_in_label) > 1:
                            #print(idx)
                            #print(label_tokens)
                            #print(label_tokens[idx_before:idx_ahead])
                            #print('HERE!')
                            # Reduce on the right, expanding on the left
                            while self.tokenizer.unk_token in label_tokens[idx+1:idx_ahead]:
                                idx_before = max(0,idx_before-1)
                                idx_ahead = min(idx+2, idx_ahead-1)
                                #print(label_tokens[idx_before:idx_ahead])
                            # Now just reduce on the left
                            while self.tokenizer.unk_token in label_tokens[idx_before:idx]:
                                idx_before = min(idx-1,idx_before+2)
                                #print(label_tokens[idx_before:idx_ahead])

                        span = self.tokenizer.convert_tokens_to_string(label_tokens[idx_before:idx_ahead])        
                        # First token of the label is UNK                        
                        if idx == 1 and label_tokens[0] == '▁':
                            #print('Label begins with unks')
                            to_replace = '^' + re.escape(span).replace(
                                    re.escape(unk_token_to_string),
                                    '.+?'
                                )
                            
                            replaced_span = re.search(
                                to_replace,
                                label
                            )[0]
                            self.unknowns[-1][span.strip()] = replaced_span
                        # Last token of the label is UNK
                        elif idx == len(label_tokens)-2 and label_tokens[-1] == self.tokenizer.eos_token:
                            #print('Label ends with unks')
                            pre_idx = self.tokenizer.convert_tokens_to_string(label_tokens[idx_before:idx])
                            pre_idx_unk_counts = pre_idx.count(unk_token_to_string)
                            to_replace = re.escape(span).replace(
                                    re.escape(unk_token_to_string),
                                    f'[^{re.escape(pre_idx)}]+?'
                                ) + '$'
                            
                            if pre_idx.strip() == '':
                                to_replace = to_replace.replace('[^]', r'(?<=\s)[^a-zA-Z0-9]')
                            
                            replaced_span = re.search(
                                to_replace,
                                label
                            )[0]
                            self.unknowns[-1][span.strip()] = replaced_span
                            
                        # A token in-between the label is UNK                            
                        else:
                            #print('Label has unks in the middle')
                            pre_idx = self.tokenizer.convert_tokens_to_string(label_tokens[idx_before:idx])

                            to_replace = re.escape(span).replace(
                                re.escape(unk_token_to_string),
                                f'[^{re.escape(pre_idx)}]+?'
                            )
                            # If there is nothing before the <unk> (e.g. because the previous token is
                            # also an <unk>), to_replace would start with '[^]', which is not a valid
                            # regex, so patch it here
                            if pre_idx.strip() == '':
                                to_replace = to_replace.replace('[^]', r'(?<=\s)[^a-zA-Z0-9]')
        
                            replaced_span = re.search(
                                to_replace,
                                label
                            )
                            
                            if replaced_span:
                                span = re.sub(r'\s([?.!",](?:\s|$))', r'\1', span.strip())
                                self.unknowns[-1][span] = replaced_span[0]  

    def replace_unks_on_sentence(self, sentence: str, loop_n: int = 3, empty_after: bool = False):
        # Loop over the replacement dictionaries repeatedly (at most loop_n passes) in case labels are repeated
        while '<unk>' in sentence and loop_n > 0:
            loop_n -= 1
            for unknowns in self.unknowns:
                for k,v in unknowns.items():
                    # Leave to replace all-unk labels at the last pass
                    if k == '<unk>' and loop_n > 0:
                        continue
                    # In case the key is not found because the first letter of the sentence has been uppercased
                    if k not in sentence and k[0] == k[0].lower() and k[0].upper() == sentence[0]:
                        k = k[0].upper() + k[1:]
                        v = v[0].upper() + v[1:]
                    # In case the key is not found because of a double space where there should not be one
                    elif k not in sentence and len(re.findall(r'\s{2,}', k)) > 0:
                        k = re.sub(r'\s+', ' ', k)
                    #print(k,'/',v,'/',sentence)
                    sentence = sentence.replace(k.strip(),v.strip(),1)
                    #sentence = re.sub(k, v, sentence)
            # Remove any leftover double spaces
            sentence = re.sub(r'\s+', ' ', sentence).strip()
            # Remove spaces before punctuation
            sentence = re.sub(r'\s([?.!",](?:\s|$))', r'\1', sentence)
        if empty_after:
            self.unknowns = []
        return sentence

if __name__ == '__main__':

    verb_module = VerbModule()
    verbs = verb_module.verbalise('translate Graph to English: <H> World Trade Center <R> height <T> 200 meter <H> World Trade Center <R> is a <T> tower')
    print(verbs)
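
    # A minimal usage sketch (the triple and label below are assumptions chosen for
    # illustration) showing the dictionary-based interface together with the <unk>
    # replacement workflow.
    triple = {'subject': 'World Trade Center', 'predicate': 'height', 'object': '200 meter'}
    verb_from_triple = verb_module.verbalise(triple)

    # Register labels that may contain characters outside the T5 vocabulary, then map
    # any <unk> tokens in the generated sentence back to the original labels.
    verb_module.add_label_to_unk_replacer('World Trade Center')
    print(verb_module.replace_unks_on_sentence(verb_from_triple, empty_after=True))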