# Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------------
# Adapted from https://github.com/wl-zhao/VPD/blob/main/vpd/models.py
# Original licence: MIT License
# ------------------------------------------------------------------------------

import math
from typing import List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule
from mmengine.runner import CheckpointLoader, load_checkpoint

from mmseg.registry import MODELS
from mmseg.utils import ConfigType, OptConfigType

try:
    from ldm.modules.diffusionmodules.util import timestep_embedding
    from ldm.util import instantiate_from_config
    has_ldm = True
except ImportError:
    has_ldm = False


def register_attention_control(model, controller):
    """Patch attention layers of a UNet so they report their attention maps.

    Every sub-module whose class is named ``CrossAttention`` and that lives
    under ``model.diffusion_model``'s ``input_blocks``, ``middle_block`` or
    ``output_blocks`` gets its ``forward`` replaced by a version that calls
    ``controller`` with the head-averaged attention map.

    Args:
        model: Object exposing a ``diffusion_model`` attribute whose
            ``named_children()`` yields the UNet stages.
        controller: Callable invoked as
            ``controller(attn_mean, is_cross, place_in_unet)``. Its
            ``num_att_layers`` attribute is set to the number of patched
            layers.
    """

    def ca_forward(self, place_in_unet):
        """Build the replacement forward for one attention layer.

        Args:
            self: The attention module being patched. It must provide
                ``heads``, ``scale``, ``to_q``, ``to_k``, ``to_v`` and
                ``to_out``.
            place_in_unet: The location in UNet (down/mid/up).

        Returns:
            The replacement forward function (a closure over ``self``).
        """

        def forward(x, context=None, mask=None):
            h = self.heads
            is_cross = context is not None
            # Do not use ``context or x``: truth-testing a tensor with more
            # than one element raises a RuntimeError.
            if context is None:
                context = x

            def split_heads(tensor):
                # (b, n, h * d) -> (b * h, n, d), batch-major / head-minor.
                # A bare ``view`` would interleave tokens and heads.
                b, n, c = tensor.shape
                tensor = tensor.reshape(b, n, h, c // h).permute(0, 2, 1, 3)
                return tensor.reshape(b * h, n, c // h)

            q = split_heads(self.to_q(x))
            k = split_heads(self.to_k(context))
            v = split_heads(self.to_v(context))

            sim = torch.matmul(q, k.transpose(-2, -1)) * self.scale

            if mask is not None:
                # Rows of ``sim`` are ordered (b0h0, b0h1, ..., b1h0, ...),
                # so every sample's mask must be repeated h times in a row.
                mask = mask.flatten(1).unsqueeze(1).repeat_interleave(
                    h, dim=0)
                max_neg_value = -torch.finfo(sim.dtype).max
                sim.masked_fill_(~mask, max_neg_value)

            attn = sim.softmax(dim=-1)
            # Average over the head axis, keeping one map per sample.
            attn_mean = attn.view(attn.shape[0] // h, h,
                                  *attn.shape[1:]).mean(1)
            controller(attn_mean, is_cross, place_in_unet)

            out = torch.matmul(attn, v)
            # (b * h, n, d) -> (b, n, h * d), the inverse of split_heads.
            bh, n, d = out.shape
            out = out.view(bh // h, h, n, d).permute(0, 2, 1, 3)
            out = out.reshape(bh // h, n, h * d)
            return self.to_out(out)

        return forward

    def register_recr(net_, count, place_in_unet):
        """Recursively patch every ``CrossAttention`` layer below ``net_``.

        Args:
            net_: The network layer currently being processed.
            count: Number of layers already counted by the caller.
            place_in_unet: The location in UNet (down/mid/up).

        Returns:
            ``count`` plus the number of layers patched below ``net_``.
        """
        if net_.__class__.__name__ == 'CrossAttention':
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        if hasattr(net_, 'children'):
            for child in net_.children():
                count = register_recr(child, count, place_in_unet)
        return count

    cross_att_count = 0
    for name, child in model.diffusion_model.named_children():
        if 'input_blocks' in name:
            place = 'down'
        elif 'output_blocks' in name:
            place = 'up'
        elif 'middle_block' in name:
            place = 'mid'
        else:
            # Not one of the three UNet stages; nothing to patch.
            continue
        # Walk the whole stage. Indexing ``child[1]`` would patch only one
        # sub-block and fail on stages with fewer than two elements.
        cross_att_count += register_recr(child, 0, place)

    controller.num_att_layers = cross_att_count


class AttentionStore:
    """Collects attention maps emitted by the patched UNet attention layers.

    Maps are grouped per UNet stage and attention type under keys of the
    form ``'<down|mid|up>_<cross|self>'``.

    Attributes:
        base_size (int): Base size for storing attention information.
        max_size (int): Maps whose second dimension exceeds ``max_size ** 2``
            are not stored.
        num_att_layers (int): Number of registered attention layers, ``-1``
            until it is set from outside.
    """

    def __init__(self, base_size=64, max_size=None):
        """Set up empty stores and the size limits."""
        self.reset()
        self.base_size = base_size
        # A falsy ``max_size`` (None or 0) falls back to half the base size.
        self.max_size = max_size if max_size else base_size // 2
        self.num_att_layers = -1

    @staticmethod
    def get_empty_store():
        """Return a fresh dict with one empty list per store key."""
        return {
            f'{place}_{kind}': []
            for kind in ('cross', 'self')
            for place in ('down', 'mid', 'up')
        }

    def reset(self):
        """Clear the counters, the per-step store and the accumulated
        store."""
        self.cur_step = 0
        self.cur_att_layer = 0
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Record one attention map in the per-step store.

        Args:
            attn: The attention tensor.
            is_cross (bool): Whether it's cross attention.
            place_in_unet (str): The location in UNet (down/mid/up).

        Returns:
            The unmodified attention tensor.
        """
        kind = 'cross' if is_cross else 'self'
        key = f'{place_in_unet}_{kind}'
        limit = self.max_size**2
        # Maps that are too large are passed through without being stored.
        if attn.shape[1] <= limit:
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        """Fold the per-step store into the accumulated store and start a new
        step."""
        if self.attention_store:
            for key in self.attention_store:
                previous = self.attention_store[key]
                current = self.step_store[key]
                self.attention_store[key] = [
                    old + new for old, new in zip(previous, current)
                ]
        else:
            # First step: the accumulated store takes over the step store.
            self.attention_store = self.step_store
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Return a shallow copy of the per-step store (new lists, same
        tensors)."""
        return {key: list(maps) for key, maps in self.step_store.items()}

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        """Forward the call to :meth:`forward`."""
        return self.forward(attn, is_cross, place_in_unet)

    @property
    def num_uncond_att_layers(self):
        """Number of unconditional attention layers; always 0 here."""
        return 0

    def step_callback(self, x_t):
        """Placeholder step hook that hands back ``x_t`` untouched."""
        return x_t


class UNetWrapper(nn.Module):
    """A wrapper for UNet with optional attention mechanisms.

    The wrapper runs the wrapped diffusion UNet block by block, collects
    intermediate decoder features and, when ``use_attn`` is enabled,
    concatenates the recorded attention maps to those features along the
    channel dimension.

    Args:
        unet (nn.Module): The UNet model to wrap. It must expose a
            ``diffusion_model`` attribute.
        use_attn (bool): Whether to use attention. Defaults to True
        base_size (int): Base size for the attention store. Defaults to 512
        max_attn_size (int, optional): Maximum size for the attention store.
            Defaults to None
        attn_selector (str): The types of attention to use, joined by ``+``.
            Defaults to 'up_cross+down_cross'
    """

    def __init__(self,
                 unet,
                 use_attn=True,
                 base_size=512,
                 max_attn_size=None,
                 attn_selector='up_cross+down_cross'):
        super().__init__()

        assert has_ldm, 'To use UNetWrapper, please install required ' \
            'packages via `pip install -r requirements/optional.txt`.'

        self.unet = unet
        self.attention_store = AttentionStore(
            base_size=base_size // 8, max_size=max_attn_size)
        self.attn_selector = attn_selector.split('+')
        self.use_attn = use_attn
        self.init_sizes(base_size)
        if self.use_attn:
            register_attention_control(unet, self.attention_store)

    def init_sizes(self, base_size):
        """Derive the three attention map side lengths from ``base_size``."""
        self.size16 = base_size // 32
        self.size32 = base_size // 16
        self.size64 = base_size // 8

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        """Run the UNet and return the collected features.

        Args:
            x: Input tensor fed to the first UNet block.
            timesteps: Timesteps used to build the time embedding.
            context: Conditioning passed to every UNet block.
            y: Accepted for interface compatibility; not used by the
                forward computation.

        Returns:
            list: The collected features in reverse collection order, i.e.
            the final decoder output comes first.
        """
        diffusion_model = self.unet.diffusion_model
        if self.use_attn:
            # Drop maps recorded by a previous forward pass.
            self.attention_store.reset()
        hs, emb, out_list = self._unet_forward(x, timesteps, context, y,
                                               diffusion_model)
        if self.use_attn:
            self._append_attn_to_output(out_list)
        return out_list[::-1]

    def _unet_forward(self, x, timesteps, context, y, diffusion_model):
        """Run encoder, middle block and decoder of ``diffusion_model``.

        Returns:
            tuple: ``(hs, emb, out_list)`` where ``hs`` is the list of skip
            features left after the decoder consumed them, ``emb`` is the
            time embedding and ``out_list`` holds the outputs of decoder
            blocks 1, 4 and 7 followed by the final decoder output.
        """
        hs = []
        t_emb = timestep_embedding(
            timesteps, diffusion_model.model_channels, repeat_only=False)
        emb = diffusion_model.time_embed(t_emb)
        h = x.type(diffusion_model.dtype)
        for module in diffusion_model.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = diffusion_model.middle_block(h, emb, context)
        out_list = []
        for i_out, module in enumerate(diffusion_model.output_blocks):
            # Skip connection: pair each decoder block with the most recent
            # unused encoder feature.
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
            if i_out in [1, 4, 7]:
                out_list.append(h)
        h = h.type(x.dtype)
        out_list.append(h)
        return hs, emb, out_list

    def _append_attn_to_output(self, out_list):
        """Concatenate the stored attention maps to ``out_list`` in place.

        Each stored map of shape ``(B, H * W, C)`` is turned into a
        ``(B, C, H, W)`` feature map, maps of equal size are averaged, and
        the result is appended to the channels of ``out_list[1]``,
        ``out_list[2]`` and (if any maps of that size exist) ``out_list[3]``.
        """
        avg_attn = self.attention_store.get_average_attention()
        attns = {self.size16: [], self.size32: [], self.size64: []}
        for k in self.attn_selector:
            for up_attn in avg_attn[k]:
                batch, tokens, channels = up_attn.shape
                size = int(math.sqrt(tokens))
                # (B, H*W, C) -> (B, C, H, W). The target shape must start
                # with (batch, channels), not with the first two input dims.
                up_attn = up_attn.transpose(-1, -2).reshape(
                    batch, channels, size, -1)
                attns[size].append(up_attn)
        attn16 = torch.stack(attns[self.size16]).mean(0)
        attn32 = torch.stack(attns[self.size32]).mean(0)
        attn64 = torch.stack(attns[self.size64]).mean(0) if len(
            attns[self.size64]) > 0 else None
        out_list[1] = torch.cat([out_list[1], attn16], dim=1)
        out_list[2] = torch.cat([out_list[2], attn32], dim=1)
        if attn64 is not None:
            out_list[3] = torch.cat([out_list[3], attn64], dim=1)


class TextAdapter(nn.Module):
    """Residual MLP adapter for text embeddings.

    A two-layer MLP (Linear - GELU - Linear) produces an update that is
    scaled by ``gamma`` and added back onto the incoming embeddings.

    Args:
        text_dim (int): Size of the last dimension of the text embeddings.
            Defaults to 768.
    """

    def __init__(self, text_dim=768):
        super().__init__()
        layers = [
            nn.Linear(text_dim, text_dim),
            nn.GELU(),
            nn.Linear(text_dim, text_dim),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, texts, gamma):
        """Return ``texts + gamma * fc(texts)``.

        Args:
            texts: Text embeddings to adapt.
            gamma: Scale applied to the MLP output before the residual add.
        """
        update = self.fc(texts)
        return texts + gamma * update


@MODELS.register_module()
class VPD(BaseModule):
    """VPD (Visual Perception Diffusion) model.

    .. _`VPD`: https://arxiv.org/abs/2303.02153

    Args:
        diffusion_cfg (dict): Configuration for diffusion model. An optional
            ``checkpoint`` entry is loaded into the instantiated model; the
            passed config itself is left unmodified.
        class_embed_path (str): Path for class embeddings.
        unet_cfg (dict, optional): Keyword arguments forwarded to
            :class:`UNetWrapper`. ``None`` is treated as an empty dict.
        gamma (float, optional): Gamma for text adaptation. Defaults to 1e-4.
        class_embed_select (bool, optional): If True, enables class embedding
            selection. Defaults to False.
        pad_shape (Optional[Union[int, List[int]]], optional): Padding shape.
            A single int is used for both height and width.
            Defaults to None.
        pad_val (Union[int, List[int]], optional): Padding value.
            Defaults to 0.
        init_cfg (dict, optional): Configuration for network initialization.
    """

    def __init__(self,
                 diffusion_cfg: ConfigType,
                 class_embed_path: str,
                 unet_cfg: OptConfigType = dict(),
                 gamma: float = 1e-4,
                 class_embed_select=False,
                 pad_shape: Optional[Union[int, List[int]]] = None,
                 pad_val: Union[int, List[int]] = 0,
                 init_cfg: OptConfigType = None):

        super().__init__(init_cfg=init_cfg)

        assert has_ldm, 'To use VPD model, please install required packages' \
            ' via `pip install -r requirements/optional.txt`.'

        if pad_shape is not None:
            if not isinstance(pad_shape, (list, tuple)):
                pad_shape = (pad_shape, pad_shape)

        self.pad_shape = pad_shape
        self.pad_val = pad_val

        # diffusion model
        # Work on a shallow copy so that popping ``checkpoint`` does not
        # strip it from the caller's config; otherwise a second build from
        # the same config would silently skip checkpoint loading.
        diffusion_cfg = diffusion_cfg.copy()
        diffusion_checkpoint = diffusion_cfg.pop('checkpoint', None)
        sd_model = instantiate_from_config(diffusion_cfg)
        if diffusion_checkpoint is not None:
            load_checkpoint(sd_model, diffusion_checkpoint, strict=False)

        # ``OptConfigType`` permits None, which cannot be unpacked with **.
        if unet_cfg is None:
            unet_cfg = {}

        self.encoder_vq = sd_model.first_stage_model
        self.unet = UNetWrapper(sd_model.model, **unet_cfg)

        # class embeddings & text adapter
        class_embeddings = CheckpointLoader.load_checkpoint(class_embed_path)
        text_dim = class_embeddings.size(-1)
        self.text_adapter = TextAdapter(text_dim=text_dim)
        self.class_embed_select = class_embed_select
        if class_embed_select:
            # Append the mean embedding as an extra last row; ``forward``
            # addresses it with index -1 when no class ids are given.
            class_embeddings = torch.cat(
                (class_embeddings, class_embeddings.mean(dim=0,
                                                         keepdim=True)),
                dim=0)
        self.register_buffer('class_embeddings', class_embeddings)
        self.gamma = nn.Parameter(torch.ones(text_dim) * gamma)

    def forward(self, x):
        """Extract features from images.

        Args:
            x: Image tensor, or, when ``class_embed_select`` is enabled, a
                tuple/list whose first two items are the image tensor and a
                tensor of class ids.

        Returns:
            The outputs of the wrapped UNet.
        """

        # calculate cross-attn map
        if self.class_embed_select:
            if isinstance(x, (tuple, list)):
                x, class_ids = x[:2]
                class_ids = class_ids.tolist()
            else:
                # No class ids given: use the appended mean embedding.
                class_ids = [-1] * x.size(0)
            class_embeddings = self.class_embeddings[class_ids]
            c_crossattn = self.text_adapter(class_embeddings, self.gamma)
            c_crossattn = c_crossattn.unsqueeze(1)
        else:
            # Every sample is conditioned on the full set of embeddings.
            class_embeddings = self.class_embeddings
            c_crossattn = self.text_adapter(class_embeddings, self.gamma)
            c_crossattn = c_crossattn.unsqueeze(0).repeat(x.size(0), 1, 1)

        # pad to required input shape for pretrained diffusion model
        if self.pad_shape is not None:
            # Only pad on the right/bottom, and never crop larger inputs.
            pad_width = max(0, self.pad_shape[1] - x.shape[-1])
            pad_height = max(0, self.pad_shape[0] - x.shape[-2])
            x = F.pad(x, (0, pad_width, 0, pad_height), value=self.pad_val)

        # forward the denoising model
        with torch.no_grad():
            # The first-stage encoder is run without gradient tracking.
            latents = self.encoder_vq.encode(x).mode().detach()
        t = torch.ones((x.shape[0], ), device=x.device).long()
        outs = self.unet(latents, t, context=c_crossattn)

        return outs