import einops
from collections import OrderedDict
from functools import partial
from typing import Callable

import torch
import torch.nn as nn
import torchvision
from torch.utils.checkpoint import checkpoint

from accelerate.utils import set_module_tensor_to_device
from diffusers.models.embeddings import apply_rotary_emb, FluxPosEmbed
from diffusers.models.modeling_utils import ModelMixin
from diffusers.configuration_utils import ConfigMixin
from diffusers.loaders import FromOriginalModelMixin


class MLPBlock(torchvision.ops.misc.MLP):
    """Transformer MLP block."""

    _version = 2

    def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
        super().__init__(in_dim, [mlp_dim, in_dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # Replacing legacy MLPBlock with MLP. See https://github.com/pytorch/vision/pull/6053
            for i in range(2):
                for type in ["weight", "bias"]:
                    old_key = f"{prefix}linear_{i+1}.{type}"
                    new_key = f"{prefix}{3*i}.{type}"
                    if old_key in state_dict:
                        state_dict[new_key] = state_dict.pop(old_key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )


class EncoderBlock(nn.Module):
    """Transformer encoder block."""

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads

        # Attention block
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

        # MLP block
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

    def forward(self, input: torch.Tensor, freqs_cis):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        B, L, C = input.shape
        x = self.ln_1(input)
        if freqs_cis is not None:
            query = x.view(B, L, self.num_heads, self.hidden_dim // self.num_heads).transpose(1, 2)
            query = apply_rotary_emb(query, freqs_cis)
            query = query.transpose(1, 2).reshape(B, L, self.hidden_dim)
        x, _ = self.self_attention(query, query, x, need_weights=False)
        x = self.dropout(x)
        x = x + input

        y = self.ln_2(x)
        y = self.mlp(y)
        return x + y


class Encoder(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation."""

    def __init__(
        self,
        seq_length: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        # Note that batch_size is on the first dim because
        # we have batch_first=True in nn.MultiAttention() by default
        # self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02))  # from BERT
        self.dropout = nn.Dropout(dropout)
        layers: OrderedDict[str, nn.Module] = OrderedDict()
        for i in range(num_layers):
            layers[f"encoder_layer_{i}"] = EncoderBlock(
                num_heads,
                hidden_dim,
                mlp_dim,
                dropout,
                attention_dropout,
                norm_layer,
            )
        self.layers = nn.Sequential(layers)
        self.ln = norm_layer(hidden_dim)

    def forward(self, input: torch.Tensor, freqs_cis):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        input = input # + self.pos_embedding
        x = self.dropout(input)
        for l in self.layers:
            x = checkpoint(l, x, freqs_cis)
        x = self.ln(x)
        return x


class ViTEncoder(nn.Module):
    def __init__(self, arch='vit-b/32'):
        super().__init__()
        self.arch = arch

        if self.arch == 'vit-b/32':
            ch = 768
            layers = 12
            heads = 12
        elif self.arch == 'vit-h/14':
            ch = 1280
            layers = 32
            heads = 16
        
        self.encoder = Encoder(
            seq_length=-1,
            num_layers=layers,
            num_heads=heads,
            hidden_dim=ch,
            mlp_dim=ch*4,
            dropout=0.0,
            attention_dropout=0.0,
        )
        self.fc_in = nn.Linear(16, ch)
        self.fc_out = nn.Linear(ch, 256)

        if self.arch == 'vit-b/32':
            from torchvision.models.vision_transformer import vit_b_32, ViT_B_32_Weights
            vit = vit_b_32(weights=ViT_B_32_Weights.DEFAULT)
        elif self.arch == 'vit-h/14':
            from torchvision.models.vision_transformer import vit_h_14, ViT_H_14_Weights
            vit = vit_h_14(weights=ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1)

        missing_keys, unexpected_keys = self.encoder.load_state_dict(vit.encoder.state_dict(), strict=False)
        if len(missing_keys) > 0 or len(unexpected_keys) > 0:
            print(f"ViT Encoder Missing keys: {missing_keys}")
            print(f"ViT Encoder Unexpected keys: {unexpected_keys}")
        del vit
    
    def forward(self, x, freqs_cis):
        out = self.fc_in(x)
        out = self.encoder(out, freqs_cis)
        out = checkpoint(self.fc_out, out)
        return out


def patchify(x, patch_size=8):
    if len(x.shape) == 4:
        bs, c, h, w = x.shape
        x = einops.rearrange(x, "b c (h p1) (w p2) -> b (c p1 p2) h w", p1=patch_size, p2=patch_size)
    elif len(x.shape) == 3:
        c, h, w = x.shape
        x = einops.rearrange(x, "c (h p1) (w p2) -> (c p1 p2) h w", p1=patch_size, p2=patch_size)
    return x


def unpatchify(x, patch_size=8):
    if len(x.shape) == 4:
        bs, c, h, w = x.shape
        x = einops.rearrange(x, "b (c p1 p2) h w -> b c (h p1) (w p2)", p1=patch_size, p2=patch_size)
    elif len(x.shape) == 3:
        c, h, w = x.shape
        x = einops.rearrange(x, "(c p1 p2) h w -> c (h p1) (w p2)", p1=patch_size, p2=patch_size)
    return x


def crop_each_layer(hidden_states, use_layers, list_layer_box, H, W, pos_embedding):
    token_list = []
    cos_list, sin_list = [], []
    for layer_idx in range(hidden_states.shape[1]):
        if list_layer_box[layer_idx] is None:
            continue
        else:
            x1, y1, x2, y2 = list_layer_box[layer_idx]
            x1, y1, x2, y2 = x1 // 8, y1 // 8, x2 // 8, y2 // 8
            layer_token = hidden_states[:, layer_idx, y1:y2, x1:x2]
            c, h, w = layer_token.shape
            layer_token = layer_token.reshape(c, -1)
            token_list.append(layer_token)
            ids = prepare_latent_image_ids(-1, H * 2, W * 2, hidden_states.device, hidden_states.dtype)
            ids[:, 0] = use_layers[layer_idx]
            image_rotary_emb = pos_embedding(ids)
            pos_cos, pos_sin = image_rotary_emb[0].reshape(H, W, -1), image_rotary_emb[1].reshape(H, W, -1)
            cos_list.append(pos_cos[y1:y2, x1:x2].reshape(-1, 64))
            sin_list.append(pos_sin[y1:y2, x1:x2].reshape(-1, 64))
    token_list = torch.cat(token_list, dim=1).permute(1, 0)
    cos_list = torch.cat(cos_list, dim=0)
    sin_list = torch.cat(sin_list, dim=0)
    return token_list, (cos_list, sin_list)


def prepare_latent_image_ids(batch_size, height, width, device, dtype):
    latent_image_ids = torch.zeros(height // 2, width // 2, 3)
    latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
    latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]

    latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

    latent_image_ids = latent_image_ids.reshape(
        latent_image_id_height * latent_image_id_width, latent_image_id_channels
    )

    return latent_image_ids.to(device=device, dtype=dtype)


class AutoencoderKLTransformerTraining(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    def __init__(self):
        super().__init__()

        self.decoder_arch = 'vit'
        self.layer_embedding = 'rope'

        self.decoder = ViTEncoder()
        self.pos_embedding = FluxPosEmbed(theta=10000, axes_dim=(8, 28, 28))
        if 'rel' in self.layer_embedding or 'abs' in self.layer_embedding:
            self.layer_embedding = nn.Parameter(torch.empty(16, 2 + self.max_layers, 1, 1).normal_(std=0.02), requires_grad=True)

        def zero_module(module):
            """
            Zero out the parameters of a module and return it.
            """
            for p in module.parameters():
                p.detach().zero_()
            return module

    def encode(self, z_2d, box, use_layers):
        B, C, T, H, W = z_2d.shape

        z, freqs_cis = [], []
        for b in range(B):
            _z = z_2d[b]
            if 'vit' in self.decoder_arch:
                _use_layers = torch.tensor(use_layers[b], device=z_2d.device)
                if 'rel' in self.layer_embedding:
                    _use_layers[_use_layers > 2] = 2
                if 'rel' in self.layer_embedding or 'abs' in self.layer_embedding:
                    _z = _z + self.layer_embedding[:, _use_layers] # + self.pos_embedding
            if 'rope' not in self.layer_embedding:
                use_layers[b] = [0] * len(use_layers[b])
            _z, cis = crop_each_layer(_z, use_layers[b], box[b], H, W, self.pos_embedding) ### modified
            z.append(_z)
            freqs_cis.append(cis)

        return z, freqs_cis

    def decode(self, z, freqs_cis, box, H, W):
        B = len(z)
        pad = torch.zeros(4, H, W, device=z[0].device, dtype=z[0].dtype)
        pad[3, :, :] = -1
        x = []
        for b in range(B):
            _x = []
            _z = self.decoder(z[b].unsqueeze(0), freqs_cis[b]).squeeze(0)
            current_index = 0
            for layer_idx in range(len(box[b])):
                if box[b][layer_idx] == None:
                    _x.append(pad.clone())
                else:
                    x1, y1, x2, y2 = box[b][layer_idx]
                    x1_tok, y1_tok, x2_tok, y2_tok = x1 // 8, y1 // 8, x2 // 8, y2 // 8
                    token_length = (x2_tok - x1_tok) * (y2_tok - y1_tok)
                    tokens = _z[current_index:current_index + token_length]
                    pixels = einops.rearrange(tokens, "(h w) c -> c h w", h=y2_tok - y1_tok, w=x2_tok - x1_tok)
                    unpatched = unpatchify(pixels)
                    pixels = pad.clone()
                    pixels[:, y1:y2, x1:x2] = unpatched
                    _x.append(pixels)
                    current_index += token_length
            _x = torch.stack(_x, dim=1)
            x.append(_x)
        x = torch.stack(x, dim=0)
        return x
    
    def forward(self, z_2d, box, use_layers=None):
        z_2d = z_2d.transpose(0, 1).unsqueeze(0)
        use_layers = use_layers or [list(range(z_2d.shape[2]))]
        z, freqs_cis = self.encode(z_2d, box, use_layers)
        H, W = z_2d.shape[-2:]
        x_hat = self.decode(z, freqs_cis, box, H * 8, W * 8)
        assert x_hat.shape[0] == 1, x_hat.shape
        x_hat = einops.rearrange(x_hat[0], "c t h w -> t c h w")
        x_hat_rgb, x_hat_alpha = x_hat[:, :3], x_hat[:, 3:]
        return x_hat_rgb, x_hat_alpha