# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union

import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import BaseTransformerLayer
from mmengine.model import BaseModule, ModuleList
from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict
from torch.nn import functional as F

from mmseg.registry import MODELS
from mmseg.utils import get_classes, get_predefined_templates, tokenizer


@MODELS.register_module()
class CLIPTextEncoder(BaseModule):
    """A text encoder with transformer architecture to encode the label text.

    Modified from https://github.com/MendelXu/SAN/blob/main/san/model/clip_utils/classifier.py # noqa:E501
    Copyright (c) 2023 MendelXu.
    Licensed under the MIT License

    Exactly one of ``dataset_name`` and ``vocabulary`` must be given.

    Args:
        dataset_name (str|None): The name of the dataset to which
            the data belongs. Default: None.
        vocabulary (List[str]|None): The list of class names. Default: None.
        templates (str|List[str]): The prompt templates used for labels.
            Either a list of format strings (each is applied with
            ``template.format(class_name)``) or a name that is resolved
            through ``get_predefined_templates``. Default: 'vild'.
        total_vocab_size (int): Number of all words used by the pre-trained
            model. Default: 49408 (CLIP).
        context_length (int): The max length of prompt text.
            Default: 77 (CLIP).
        embed_dims (int): Width of transformer model. Default: 512.
        num_layers (int): Depth of transformer. Default: 12.
        num_heads (int): Number of attention heads in transformer.
            Default: 8.
        mlp_ratio (int): Ratio of mlp hidden dim to embedding dim in
            transformer. Default: 4.
        output_dims (int): Dim of output text embeddings. Default: 512.
        cache_feature (bool): Whether to save class embeddings in cache.
            Default: True.
        cat_bg (bool): Whether to add background embedding. Default: True.
        norm_cfg (dict|None): Config for norm layer. Default: dict(type='LN')
        init_cfg (dict or list[dict], optional): Initialization config dict.
            A dict with ``type='Pretrained_Part'`` and a ``checkpoint`` entry
            makes ``init_weights`` load only the ``text_encoder.*`` weights
            of that checkpoint. Default: None.
    """

    def __init__(self,
                 dataset_name: Optional[str] = None,
                 vocabulary: Optional[List[str]] = None,
                 templates: Union[str, List[str]] = 'vild',
                 total_vocab_size: int = 49408,
                 context_length: int = 77,
                 embed_dims: int = 512,
                 num_layers: int = 12,
                 num_heads: int = 8,
                 mlp_ratio: int = 4,
                 output_dims: int = 512,
                 cache_feature: bool = True,
                 cat_bg: bool = True,
                 norm_cfg: dict = dict(type='LN'),
                 init_cfg: Optional[dict] = None):
        super().__init__(init_cfg)
        # An explicit sequence of templates is used as given; anything else
        # is treated as the name of a predefined template set.
        if isinstance(templates, (list, tuple)):
            self.templates = templates
        else:
            self.templates = get_predefined_templates(templates)

        assert dataset_name is not None or vocabulary is not None, \
            "text_encoder required either 'dataset_name' or 'vocabulary'"
        assert dataset_name is None or vocabulary is None, \
            "there is conflict between 'dataset_name' and 'vocabulary'"
        self.dataset_name = dataset_name
        self.vocabulary = vocabulary
        self.num_pos = context_length
        self.token_embedding = nn.Embedding(total_vocab_size, embed_dims)
        self.positional_embedding = nn.Parameter(
            torch.empty(context_length, embed_dims))
        self.text_projection = nn.Parameter(
            torch.empty(embed_dims, output_dims))
        # Learnable log-temperature, initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.transformer = ModuleList()
        # Non-persistent: the mask is rebuilt from `context_length` and is
        # therefore not stored in / expected from checkpoints.
        self.register_buffer(
            'attn_mask', self.build_attention_mask(), persistent=False)
        for _ in range(num_layers):
            self.transformer.append(
                BaseTransformerLayer(
                    attn_cfgs=dict(
                        type='MultiheadAttention',
                        embed_dims=embed_dims,
                        num_heads=num_heads,
                        batch_first=False,
                        bias=True),
                    ffn_cfgs=dict(
                        type='FFN',
                        embed_dims=embed_dims,
                        feedforward_channels=mlp_ratio * embed_dims,
                        act_cfg=dict(type='QuickGELU')),
                    operation_order=('norm', 'self_attn', 'norm', 'ffn')))
        # Registered as a regular submodule under the attribute name
        # `ln_final`, so its parameters appear as `ln_final.*` in the
        # state dict.
        self.ln_final = build_norm_layer(
            norm_cfg, embed_dims, postfix='_final')[1]

        self.cache_feature = cache_feature
        if self.cache_feature:
            # Maps a class name (vocabulary mode) or the dataset name
            # (dataset mode) to already computed text embeddings.
            self.cache = {}

        self._freeze()

        self.cat_bg = cat_bg
        if self.cat_bg:
            # Created after `_freeze()`, so the background embedding keeps
            # `requires_grad=True` while everything above is frozen.
            self.bg_embed = nn.Parameter(
                torch.randn(1, self.text_projection.shape[1]))

    def build_attention_mask(self):
        """Build the additive causal attention mask.

        PyTorch uses additive attention masks: entries strictly above the
        main diagonal are ``-inf`` (future tokens are hidden), the diagonal
        and everything below it are ``0``.

        Returns:
            torch.Tensor: Mask of shape (num_pos, num_pos).
        """
        mask = torch.empty(self.num_pos, self.num_pos)
        mask.fill_(float('-inf'))
        mask.triu_(1)  # zero out the diagonal and everything below it
        return mask

    def _freeze(self):
        """Disable gradients for every parameter registered so far."""
        for param in self.parameters():
            param.requires_grad = False

    def init_weights(self):
        """Initialize the weights.

        The background embedding (if any) is drawn from a normal
        distribution with ``std = output_dims ** -0.5``. When ``init_cfg``
        is a dict of type ``'Pretrained_Part'``, the entries of the
        checkpoint whose key starts with ``text_encoder.`` are loaded
        (non-strictly) with that prefix removed; otherwise the default
        ``BaseModule`` initialization is used.
        """
        if self.cat_bg:
            nn.init.normal_(
                self.bg_embed,
                std=self.bg_embed.shape[1]**-0.5,
            )
        if isinstance(self.init_cfg, dict) and \
                self.init_cfg.get('type') == 'Pretrained_Part':
            checkpoint = CheckpointLoader.load_checkpoint(
                self.init_cfg['checkpoint'], logger=None, map_location='cpu')
            # Checkpoints written by OpenMMLab runners wrap the weights in a
            # `state_dict` entry; a plain weight mapping is used as is.
            if isinstance(checkpoint.get('state_dict'), dict):
                checkpoint = checkpoint['state_dict']

            # Match the prefix at the start of the key only: a substring
            # match combined with a fixed-length slice would mangle keys
            # such as `module.text_encoder.*`.
            para_prefix = 'text_encoder.'
            prefix_len = len(para_prefix)
            state_dict = {
                key[prefix_len:]: value
                for key, value in checkpoint.items()
                if key.startswith(para_prefix)
            }

            load_state_dict(self, state_dict, strict=False, logger=None)

        else:
            super().init_weights()

    @torch.no_grad()
    def encode_text(self, text, normalize=False):
        """Encode tokenized text into one embedding per sequence.

        Args:
            text (torch.Tensor): Token ids, indexed as
                [batch_size, n_ctx]. The end-of-text token is expected to
                have the highest id in each sequence.
            normalize (bool): Whether to L2-normalize the embeddings.
                Default: False.

        Returns:
            torch.Tensor: Text embeddings, one row per input sequence.
        """
        # Move the token ids once so that both the embedding lookup and the
        # EOT indexing below operate on the module's device.
        text = text.to(self.token_embedding.weight.device)
        x = self.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        for block in self.transformer:
            x = block(query=x, attn_masks=self.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding
        # (eot_token is the highest number in each sequence)
        eot_index = text.argmax(dim=-1)
        batch_index = torch.arange(x.shape[0], device=x.device)
        x = x[batch_index, eot_index] @ self.text_projection
        return F.normalize(x, dim=-1) if normalize else x

    def template_encode(self, vocabulary):
        """Prompt engineering.

        Every class name is inserted into every template, the normalized
        embeddings are averaged over the templates and the mean is
        normalized again.

        Args:
            vocabulary (List[str]): Class names to encode.

        Returns:
            torch.Tensor: One unit-norm embedding per class name.
        """
        text_embed_bucket = []
        for template in self.templates:
            text_inputs = tokenizer.tokenize(
                [template.format(noun) for noun in vocabulary])
            text_embed = self.encode_text(text_inputs, normalize=True)
            text_embed_bucket.append(text_embed)
        text_embed = torch.stack(text_embed_bucket).mean(dim=0)
        text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
        return text_embed

    def forward(self):
        """Forward function.

        Returns:
            torch.Tensor: Class embeddings (followed by the background
            embedding when ``cat_bg`` is set), scaled by
            ``exp(logit_scale)``.
        """
        if self.dataset_name is None:  # encoding vocabulary directly
            class_names = self.vocabulary
            if self.cache_feature:
                # Only encode the words that have not been seen before.
                new_classes = [
                    word for word in class_names if word not in self.cache
                ]
                if new_classes:
                    class_embeds = self.template_encode(new_classes)
                    self.cache.update(dict(zip(new_classes, class_embeds)))
                class_embeds = torch.stack(
                    [self.cache[word] for word in class_names])
            else:
                class_embeds = self.template_encode(class_names)

        else:  # encoding the classes of the dataset
            if self.cache_feature and self.dataset_name in self.cache:
                class_embeds = self.cache[self.dataset_name]
            else:
                class_names = get_classes(self.dataset_name)
                # A leading 'background' class is not encoded as text.
                if class_names[0] == 'background':
                    class_names = class_names[1:]
                class_embeds = self.template_encode(class_names)
                if self.cache_feature:
                    self.cache[self.dataset_name] = class_embeds

        if self.cat_bg:
            # Append the background embedding as the last row and
            # re-normalize all rows together.
            class_embeds = torch.cat([class_embeds, self.bg_embed])
            class_embeds = F.normalize(class_embeds, p=2, dim=-1)
        return self.logit_scale.exp() * class_embeds


@MODELS.register_module()
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``.

    From https://github.com/openai/CLIP/blob/main/clip/model.py
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation element-wise to ``x``."""
        gate = torch.sigmoid(1.702 * x)
        return x * gate