import os
from typing import List

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.functional import normalize
from transformers import AutoModel, AutoTokenizer, logging


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()

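        # Standard sinusoidal positional encoding (Vaswani et al., 2017):
        # PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
        # PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))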
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)

        self.register_buffer('pe', pe, persistent=False)

    def forward(self, x):
        return x + self.pe[:x.shape[0], :]


class TMR_textencoder(nn.Module):
    def __init__(self, modelpath: str, latent_dim: int, ff_size: int,
                 num_layers: int, num_heads: int, activation: str, **kwargs) -> None:
        super().__init__()

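        # Silence transformers warnings (e.g. about unused checkpoint weights)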
        logging.set_verbosity_error()

        # Tokenizer
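        # Disable tokenizers parallelism to avoid fork-related warnings with DataLoader workers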
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        self.tokenizer = AutoTokenizer.from_pretrained(modelpath)

        # Text model
        self.text_model = AutoModel.from_pretrained(modelpath)
        # Dimension of the text model's hidden states
        self.text_encoded_dim = self.text_model.config.hidden_size

        # Projection of the text model outputs into the latent space
        self.projection = nn.Sequential(
            nn.ReLU(),
            nn.Linear(self.text_encoded_dim, latent_dim)
        )

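        # Learned distribution tokens, prepended to the text sequence so the
        # transformer can aggregate it into the mean and log-variance of a
        # Gaussian latent distribution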
        self.mu_token = nn.Parameter(torch.randn(latent_dim))
        self.logvar_token = nn.Parameter(torch.randn(latent_dim))
        self.sequence_pos_encoding = PositionalEncoding(latent_dim)

        seq_trans_encoder_layer = nn.TransformerEncoderLayer(d_model=latent_dim,
                                                             nhead=num_heads,
                                                             dim_feedforward=ff_size,
                                                             dropout=0.0,
                                                             activation=activation)
        self.seqTransEncoder = nn.TransformerEncoder(
            seq_trans_encoder_layer,
            num_layers=num_layers
        )

    def get_last_hidden_state(self, texts: List[str],
                              return_mask: bool = False):
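        # Tokenize and encode the texts; last_hidden_state has shape
        # [batch, tokens, text_encoded_dim] and the mask has shape [batch, tokens]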
        encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
        output = self.text_model(**encoded_inputs.to(self.text_model.device))
        if not return_mask:
            return output.last_hidden_state
        return output.last_hidden_state, encoded_inputs.attention_mask.to(dtype=bool)

    def forward(self, texts: List[str]) -> Tensor:
        text_encoded, mask = self.get_last_hidden_state(texts, return_mask=True)

        x = self.projection(text_encoded)
        bs, nframes, _ = x.shape
        # Switch the sequence and batch dimensions, because the PyTorch
        # Transformer expects inputs of shape [sequence, batch size, ...]
        x = x.permute(1, 0, 2)  # now [nframes, bs, latent_dim]

        mu_token = torch.tile(self.mu_token, (bs,)).reshape(bs, -1)
        logvar_token = torch.tile(self.logvar_token, (bs,)).reshape(bs, -1)

        # prepend the distribution tokens to every sequence in the batch
        xseq = torch.cat((mu_token[None], logvar_token[None], x), 0)
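        # xseq: [2 + nframes, bs, latent_dim]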

        # extend the mask so that the mu and logvar tokens can be attended to
        token_mask = torch.ones((bs, 2), dtype=bool, device=x.device)
        aug_mask = torch.cat((token_mask, mask), 1)

        # add positional encoding
        xseq = self.sequence_pos_encoding(xseq)
        final = self.seqTransEncoder(xseq, src_key_padding_mask=~aug_mask)

        # only the mu token is used at inference time
        mu = final[0]
        return mu

    # compute retrieval scores between texts and precomputed embeddings
    def compute_scores(self, texts, unit_embs=None, embs=None):
        # exactly one of unit_embs (already normalized) and embs must be given
        assert not (unit_embs is None and embs is None)
        assert not (unit_embs is not None and embs is not None)

        output_str = False
        # if one input, squeeze the output
        if isinstance(texts, str):
            texts = [texts]
            output_str = True

        # compute unit_embs from embs if not given
        if embs is not None:
            unit_embs = normalize(embs)

        with torch.no_grad():
            latent_unit_texts = normalize(self(texts))
            # cosine similarity rescaled from [-1, 1] to [0, 1]
            scores = (unit_embs @ latent_unit_texts.T).T / 2 + 0.5
            scores = scores.cpu().numpy()

        if output_str:
            scores = scores[0]

        return scores
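

# Minimal usage sketch (illustrative only): the model path, the hyper-parameters
# and the random "database" embeddings below are assumptions for demonstration,
# not values shipped with this module.
if __name__ == "__main__":
    model = TMR_textencoder(
        modelpath="distilbert-base-uncased",  # hypothetical text backbone
        latent_dim=256,
        ff_size=1024,
        num_layers=6,
        num_heads=4,
        activation="gelu",
    )
    model.eval()

    # Random embeddings standing in for a real retrieval database (assumption)
    embs = torch.randn(10, 256)

    scores = model.compute_scores("a person walks forward", embs=embs)
    print(scores.shape)  # (10,): one score per database entry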