from typing import List, Union, Tuple

import torch
from transformers import AutoConfig, AutoTokenizer, AutoModelWithLMHead, AutoModel

from transformer_formatter import TransformerOutputFormatter
from utils.f import delegates, pick, memoize

@memoize
def get_details(mname):
    return ModelDetails(mname)

def get_model_tok(mname):
    conf = AutoConfig.from_pretrained(mname, output_attentions=True, output_past=False)
    tok = AutoTokenizer.from_pretrained(mname, config=conf)
    model = AutoModelWithLMHead.from_pretrained(mname, config=conf)
    return model, tok

class ModelDetails:
    """Wraps a transformer model and tokenizer to prepare inputs to the frontend visualization"""
    def __init__(self, mname):
        self.mname = mname
        self.model, self.tok = get_model_tok(self.mname)
        self.model.eval()
        self.config = self.model.config
    
    def from_sentence(self, sentence: str) -> TransformerOutputFormatter:
        """Get attentions and word probabilities from a sentence. Special tokens are automatically added if a sentence is passed.
        
        Args:
            sentence: The input sentence to tokenize and analyze. 
        """
        tokens = self.tok.tokenize(sentence)

        return self.from_tokens(tokens, sentence, add_special_tokens=True)

    def from_tokens(
        self, tokens: List[str], orig_sentence:str, add_special_tokens:bool=False, mask_attentions:bool=False, topk:int=5
    ) -> TransformerOutputFormatter:
        """Get formatted attention and predictions from a list of tokens.
        Args:
            tokens: Tokens to analyze
            orig_sentence: The sentence the tokens came from (needed to help organize the output)
            add_special_tokens: Whether to add special tokens like CLS / <|endoftext|> to the tokens.
                If False, assume the tokens already have the special tokens
            mask_attentions: If True, do not pay attention to attention patterns to special tokens through the model.
            topk: How many top predictions to report
        """
        ids = self.tok.convert_tokens_to_ids(tokens)

        # For GPT2, add the beginning of sentence token to the input. Note that this will work on all models but XLM
        bost = self.tok.bos_token_id
        clst = self.tok.cls_token_id
        sept = self.tok.sep_token_id
        if (bost is not None) and (bost != clst)and add_special_tokens:
            ids.insert(0, bost)

        inputs = self.tok.prepare_for_model(ids, add_special_tokens=add_special_tokens, return_tensors="pt")
        parsed_input = self.parse_inputs(inputs, mask_attentions=mask_attentions)
        output = self.model(parsed_input['input_ids'], attention_mask=parsed_input['attention_mask'])

        logits, atts = self.choose_logits_att(output)
        words, probs = self.logits2words(logits, topk)
        tokens = self.view_ids(inputs["input_ids"])

        formatted_output = TransformerOutputFormatter(
            orig_sentence,
            tokens,
            inputs["special_tokens_mask"],
            atts,
            words,
            probs.tolist(),
            self.config
        )

        return formatted_output

    def choose_logits_att(self, out:Tuple) -> Tuple:
        """Select from the model's output the logits and the attentions, switching on model name
        
        Args:
            out: Output from the model's forward pass
        Returns:
            (logits: tensor((bs, N)), attentions: Tuple[tensor(())])
        """
        if 't5' in self.mname:
            logits, _, atts = out
        else:
            logits, atts = out

        return logits, atts

    def logits2words(self, logits, topk):
        """Convert logit probabilities into words from the tokenizer's vocabulary.
        
        """
        probs, idxs = torch.topk(torch.softmax(logits.squeeze(0), 1), topk)
        words = [self.tok.convert_ids_to_tokens(i) for i in idxs]
        return words, probs

    def view_ids(self, ids: Union[List[int], torch.Tensor]) -> List[str]:
        """View what the tokenizer thinks certain ids are for a single input"""
        if type(ids) == torch.Tensor:
            # Remove batch dimension
            ids = ids.squeeze(0).tolist()

        out = self.tok.convert_ids_to_tokens(ids)
        return out

    def parse_inputs(self, inputs, mask_attentions=False):
        """Parse the output from `tokenizer.prepare_for_model` to the desired attention mask from special tokens
        Args:
            - inputs: The output of `tokenizer.prepare_for_model`.
                A dict with keys: {'special_token_mask', 'token_type_ids', 'input_ids'}
            - mask_attentions: Flag indicating whether to mask the attentions or not
        Returns:
            Dict with keys: {'input_ids', 'token_type_ids', 'attention_mask', 'special_tokens_mask'}
        Usage:
            ```
            s = "test sentence"
            # from raw sentence to tokens
            tokens = tokenizer.tokenize(s)
            # From tokens to ids
            ids = tokenizer.convert_tokens_to_ids(tokens)
            # From ids to input
            inputs = tokenizer.prepare_for_model(ids, return_tensors='pt')
            # Parse the input. Optionally mask the special tokens from the analysis.
            parsed_input = parse_inputs(inputs)
            # Run the model, pick from this output whatever inputs you want
            from utils.f import pick
            out = model(**pick(['input_ids'], parse_inputs(inputs)))
            ```
        """

        out = inputs.copy()

        # DEFINE SPECIAL TOKENS MASK
        if "special_tokens_mask" not in inputs.keys():
            special_tokens = set([self.tok.unk_token_id, self.tok.cls_token_id, self.tok.sep_token_id, self.tok.bos_token_id, self.tok.eos_token_id, self.tok.pad_token_id])
            in_ids = inputs['input_ids'][0]
            special_tok_mask = [1 if int(i) in special_tokens else 0 for i in in_ids]
            inputs['special_tokens_mask'] = special_tok_mask

        if mask_attentions:
            out["attention_mask"] = torch.tensor(
                [int(not i) for i in inputs.get("special_tokens_mask")]
            ).unsqueeze(0)
        else:
            out["attention_mask"] = torch.tensor(
                [1 for i in inputs.get("special_tokens_mask")]
            ).unsqueeze(0)

        return out