'''
Copyright (c) Alibaba, Inc. and its affiliates.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from ldm.modules.diffusionmodules.util import conv_nd, linear


def get_clip_token_for_string(tokenizer, string):
    """Return the id of the single token that `string` maps to.

    The tokenizer is asked for a fixed-length (77), padded, truncated
    PyTorch encoding. The string is accepted only when exactly two ids in
    the result differ from 49407 (presumably CLIP's end/pad id, leaving
    the start token plus one real token -- confirm against the tokenizer).

    Returns:
        The 0-dim tensor at position [0, 1] of the encoded ids.

    Raises:
        AssertionError: if the number of ids different from 49407 is not 2.
    """
    encoded = tokenizer(
        string,
        truncation=True,
        max_length=77,
        return_length=True,
        return_overflowing_tokens=False,
        padding="max_length",
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"]
    n_other = torch.count_nonzero(input_ids - 49407)
    assert n_other == 2, f"String '{string}' maps to more than a single token. Please use another string"
    return input_ids[0, 1]


def get_bert_token_for_string(tokenizer, string):
    """Return the id of the single token that `string` maps to.

    `tokenizer` is called with the string alone and must return a tensor
    of ids. Exactly three non-zero ids are required (presumably start,
    token, end with zero padding -- confirm against the tokenizer).

    Returns:
        The 0-dim tensor at position [0, 1] of the tokenizer output.

    Raises:
        AssertionError: if the number of non-zero ids is not 3.
    """
    token_ids = tokenizer(string)
    n_nonzero = torch.count_nonzero(token_ids)
    assert n_nonzero == 3, f"String '{string}' maps to more than a single token. Please use another string"
    return token_ids[0, 1]


def get_clip_vision_emb(encoder, processor, img):
    """Run `img` through `processor` and `encoder` and return the image embeddings.

    The image is tiled three times along dim 1 and multiplied by 255
    before preprocessing (assumes a single-channel input in [0, 1] --
    TODO confirm with callers). The preprocessed `pixel_values` are moved
    to the device of `img` before the encoder is called.

    Returns:
        The `image_embeds` attribute of the encoder output.
    """
    rgb = img.repeat(1, 3, 1, 1) * 255
    batch = processor(images=rgb, return_tensors="pt")
    batch['pixel_values'] = batch['pixel_values'].to(img.device)
    return encoder(**batch).image_embeds


def get_recog_emb(encoder, img_list):
    """Return the second output of `encoder.pred_imglist` for the given images.

    Each image is tiled three times along dim 1, multiplied by 255 and
    reduced to its first batch element. `encoder.predictor` is switched to
    eval mode before prediction.

    Returns:
        The second element of the pair returned by `encoder.pred_imglist`.
    """
    prepared = []
    for img in img_list:
        rgb = img.repeat(1, 3, 1, 1) * 255
        prepared.append(rgb[0])
    encoder.predictor.eval()
    _unused, neck = encoder.pred_imglist(prepared, show_debug=False)
    return neck


def pad_H(x):
    """Zero-pad a 4-D tensor along its height (dim 2) so height equals width.

    The padding W - H is split between top and bottom, with the bottom
    receiving the extra row when the difference is odd. Width is untouched.
    """
    _batch, _chan, height, width = x.shape
    diff = width - height
    top = diff // 2
    bottom = diff - top
    # F.pad takes (left, right, top, bottom) for the last two dims.
    return F.pad(x, (0, 0, top, bottom))


class EncodeNet(nn.Module):
    """Small convolutional encoder that reduces an image to one vector per sample.

    Layout: an input conv to 16 channels, four stride-2 convs that each
    double the channel count, an output conv to `out_channels`, then
    global average pooling. Every conv is followed by SiLU.
    (`conv_nd(2, ...)` is presumably a 2-D convolution factory -- see
    ldm.modules.diffusionmodules.util.)
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        base_width = 16
        num_down = 4  # number of stride-2 stages

        # Channel widths seen between stages: 16, 32, 64, 128, 256.
        widths = [base_width * 2 ** k for k in range(num_down + 1)]

        self.conv1 = conv_nd(2, in_channels, base_width, 3, padding=1)
        self.conv_list = nn.ModuleList([
            conv_nd(2, c_in, c_out, 3, padding=1, stride=2)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ])
        self.conv2 = conv_nd(2, widths[-1], out_channels, 3, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.act = nn.SiLU()

    def forward(self, x):
        """Encode `x` and return the pooled features flattened to (batch, -1)."""
        h = self.act(self.conv1(x))
        for down in self.conv_list:
            h = self.act(down(h))
        h = self.avgpool(self.act(self.conv2(h)))
        return h.view(h.size(0), -1)


class EmbeddingManager(nn.Module):
    """Swap placeholder-token embeddings in a prompt for per-line glyph embeddings.

    Usage is two-step: `encode_text` turns the glyph images of a batch into
    one embedding per text line and stores them in `self.text_embs_all`;
    `forward` then writes those embeddings over the positions of the
    placeholder token in the already-embedded prompt.

    Args:
        embedder: text encoder wrapper. If it has a `tokenizer` attribute the
            CLIP path is used (token_dim 768); otherwise `embedder.tknz_fn`
            is used (token_dim 1280).
        valid: accepted for interface compatibility; not used in this class.
        glyph_channels: input channels of the glyph EncodeNet ('conv' only).
        position_channels: input channels of the position EncodeNet.
        placeholder_string: string whose single token marks where glyph
            embeddings are inserted.
        add_pos: if True, a position embedding is added to each glyph embedding.
        emb_type: one of 'ocr', 'vit' or 'conv'.
        **kwargs: ignored.

    Note:
        For emb_type 'ocr', `self.recog` is not assigned anywhere in this
        class; it has to be attached to the instance before the first call
        to `encode_text`.
    """

    def __init__(
            self,
            embedder,
            valid=True,
            glyph_channels=20,
            position_channels=1,
            placeholder_string='*',
            add_pos=False,
            emb_type='ocr',
            **kwargs
    ):
        super().__init__()
        # Defined on every path so encode_text can test it without raising
        # AttributeError when the BERT branch below is taken.
        self.get_recog_emb = None
        if hasattr(embedder, 'tokenizer'):  # using Stable Diffusion's CLIP encoder
            get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
            token_dim = 768
            if hasattr(embedder, 'vit'):
                assert emb_type == 'vit'
                self.get_vision_emb = partial(get_clip_vision_emb, embedder.vit, embedder.processor)
        else:  # using LDM's BERT encoder
            get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn)
            token_dim = 1280
        self.token_dim = token_dim
        self.emb_type = emb_type

        self.add_pos = add_pos
        if add_pos:
            self.position_encoder = EncodeNet(position_channels, token_dim)
        if emb_type == 'ocr':
            # 40*64 must equal the flattened per-line size of what
            # get_recog_emb returns (presumably the recognizer's neck
            # features -- confirm against the recognizer).
            self.proj = linear(40*64, token_dim)
        if emb_type == 'conv':
            self.glyph_encoder = EncodeNet(glyph_channels, token_dim)

        self.placeholder_token = get_token_for_string(placeholder_string)

    def encode_text(self, text_info):
        """Compute one embedding per text line and store them in `self.text_embs_all`.

        Args:
            text_info: mapping with
                'n_lines'   - per-sample number of text lines,
                'gly_line'  - per-line batched glyph images, indexed [line][sample],
                'positions' - per-line batched position images (read only
                              when `add_pos` is True).

        Raises:
            ValueError: if there are lines to encode and `emb_type` is not
                one of 'ocr', 'vit', 'conv'.

        After the call, `self.text_embs_all[i]` is a list holding one
        (1, token_dim) tensor for every line of sample i.
        """
        if self.get_recog_emb is None and self.emb_type == 'ocr':
            # Bound lazily because self.recog is attached after construction.
            self.get_recog_emb = partial(get_recog_emb, self.recog)

        line_counts = [int(n) for n in text_info['n_lines']]

        gline_list = []
        pos_list = []
        for i, n_lines in enumerate(line_counts):  # sample index in a batch
            for j in range(n_lines):  # line
                gline_list.append(text_info['gly_line'][j][i:i+1])
                if self.add_pos:
                    pos_list.append(text_info['positions'][j][i:i+1])

        enc_glyph = None
        if gline_list:
            if self.emb_type == 'ocr':
                recog_emb = self.get_recog_emb(gline_list)
                enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1))
            elif self.emb_type == 'vit':
                enc_glyph = self.get_vision_emb(pad_H(torch.cat(gline_list, dim=0)))
            elif self.emb_type == 'conv':
                enc_glyph = self.glyph_encoder(pad_H(torch.cat(gline_list, dim=0)))
            else:
                raise ValueError(f"Unsupported emb_type: {self.emb_type!r}")
            if self.add_pos:
                # Encode the position images, not the glyph images.
                enc_pos = self.position_encoder(torch.cat(pos_list, dim=0))
                enc_glyph = enc_glyph + enc_pos

        # Regroup the flat (total_lines, token_dim) result per sample.
        self.text_embs_all = []
        n_idx = 0
        for n_lines in line_counts:
            self.text_embs_all.append(
                [enc_glyph[k:k+1] for k in range(n_idx, n_idx + n_lines)]
            )
            n_idx += n_lines

    def forward(
            self,
            tokenized_text,
            embedded_text,
    ):
        """Overwrite placeholder positions of `embedded_text` with glyph embeddings.

        Args:
            tokenized_text: (batch, seq) token ids.
            embedded_text: (batch, seq, dim) token embeddings; modified in place.

        Returns:
            `embedded_text`, with the k-th placeholder of sample i replaced by
            the k-th line embedding of `self.text_embs_all[i]`. When the two
            counts differ, only the first min(count) pairs are written and a
            message is printed. Processing stops at the first sample that has
            placeholders but no entry in `self.text_embs_all`.
        """
        b, device = tokenized_text.shape[0], tokenized_text.device
        placeholder = self.placeholder_token.to(device)
        for i in range(b):
            idx = tokenized_text[i] == placeholder
            n_placeholders = int(idx.sum())
            if n_placeholders == 0:
                continue
            if i >= len(self.text_embs_all):
                print('truncation for log images...')
                break
            line_embs = self.text_embs_all[i]
            n_embs = len(line_embs)
            if n_placeholders < n_embs:
                print('truncation for long caption...')
            elif n_placeholders > n_embs:
                print('fewer glyph lines than placeholders...')
            n_fill = min(n_placeholders, n_embs)
            if n_fill == 0:
                # Nothing to insert; keep the original placeholder embeddings.
                continue
            text_emb = torch.cat(line_embs[:n_fill], dim=0)
            positions = idx.nonzero(as_tuple=True)[0][:n_fill]
            embedded_text[i, positions] = text_emb
        return embedded_text

    def embedding_parameters(self):
        """Return an iterator over all parameters of this module."""
        return self.parameters()