# some modules/classes are copied and modified from https://github.com/mcmonkey4eva/sd3-ref
# the original code is licensed under the MIT License

# Some modules/classes were contributed by KohakuBlueleaf. Thanks for the contribution!

from ast import Tuple
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from functools import partial
import math
from types import SimpleNamespace
from typing import Dict, List, Optional, Union
import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from transformers import CLIPTokenizer, T5TokenizerFast

from library import custom_offloading_utils
from library.device_utils import clean_memory_on_device

from .utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


# Optional xformers backend. `memory_efficient_attention` stays None when
# xformers cannot be imported. `except Exception` (rather than a bare
# `except:`) keeps the import best-effort for any loader failure while still
# letting KeyboardInterrupt / SystemExit propagate.
memory_efficient_attention = None
try:
    import xformers
except Exception:
    pass

try:
    from xformers.ops import memory_efficient_attention
except Exception:
    memory_efficient_attention = None


# region mmdit


@dataclass
class SD3Params:
    """Plain record of the hyperparameters describing one SD3 MMDiT variant.

    Every field shares its name with a keyword argument of ``MMDiT.__init__``.
    NOTE(review): the per-field notes below are inferred from those names only;
    confirm them against the code that builds and consumes this object.
    """

    patch_size: int  # presumably the side length of one latent patch
    depth: int  # presumably the number of transformer blocks
    num_patches: int
    pos_embed_max_size: int
    adm_in_channels: int
    qk_norm: Optional[str]  # AttentionLinears accepts "rms", "ln" or None for its qk_norm
    x_block_self_attn_layers: list[int]  # presumably indices of blocks with the extra x self-attention
    context_embedder_in_features: int
    context_embedder_out_features: int
    model_type: str


def get_2d_sincos_pos_embed(
    embed_dim,
    grid_size,
    scaling_factor=None,
    offset=None,
):
    """
    Build a 2D sin/cos positional embedding for a square grid.

    The integer coordinates ``0 .. grid_size - 1`` are optionally divided by
    ``scaling_factor`` and then lowered by ``offset`` before being encoded.

    Args:
        embed_dim: total embedding width (must be even, see the grid encoder).
        grid_size: number of cells along each side of the square grid.
        scaling_factor: optional divisor applied to the coordinates.
        offset: optional value subtracted from the (scaled) coordinates.

    Returns:
        The array produced by ``get_2d_sincos_pos_embed_from_grid`` for the
        ``grid_size * grid_size`` positions.
    """
    axis = np.arange(grid_size, dtype=np.float32)
    # Default "xy" meshgrid indexing: the first mesh varies along the width axis.
    coords = np.stack(np.meshgrid(axis, axis), axis=0)
    if scaling_factor is not None:
        coords = coords / scaling_factor
    if offset is not None:
        coords = coords - offset

    coords = coords.reshape([2, 1, grid_size, grid_size])
    return get_2d_sincos_pos_embed_from_grid(embed_dim, coords)


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Encode a stacked pair of coordinate grids as one 2D sin/cos embedding.

    ``grid[0]`` and ``grid[1]`` are each encoded with ``embed_dim // 2``
    channels by ``get_1d_sincos_pos_embed_from_grid`` and the two results are
    concatenated along the channel axis, ``grid[0]`` first.

    Returns:
        Array with one row per position and ``embed_dim`` columns.
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    # Each entry is (H*W, D/2); half of the channels go to each grid.
    per_axis = [get_1d_sincos_pos_embed_from_grid(half_dim, grid[axis]) for axis in (0, 1)]
    return np.concatenate(per_axis, axis=1)  # (H*W, D)


def get_scaled_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, sample_size=64, base_size=16):
    """
    This function is contributed by KohakuBlueleaf. Thanks for the contribution!

    Build 2D sinusoidal positional embeddings on a rescaled, re-centred grid.
    The scaling and shift are intended to keep the embeddings of the central
    ``sample_size`` x ``sample_size`` region close to those of the reference
    resolution when the grid is larger or smaller than ``sample_size``.

    For an axis with ``size`` cells, cell ``i`` gets the coordinate
    ``(i / size) * scale - shift`` with
    ``scale = base_size * size / sample_size`` and
    ``shift = scale * (size - sample_size) / (2 * size)``.

    Args:
        embed_dim (int): Dimension of the positional embedding.
        grid_size (int or sequence): Grid size as (H, W); an int means a square grid.
        cls_token (bool): Whether to prepend rows for extra tokens. Defaults to False.
        extra_tokens (int): Number of zero rows to prepend. Defaults to 0.
        sample_size (int): Reference grid size. Defaults to 64.
        base_size (int): Base grid size used in the scale. Defaults to 16.

    Returns:
        numpy.ndarray: Embeddings of shape (H*W, embed_dim), or
                      (extra_tokens + H*W, embed_dim) when ``cls_token`` is True
                      and ``extra_tokens > 0``.
    """
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)

    def _axis_coordinates(size):
        # Normalised positions in [0, 1), then scaled and shifted (see docstring).
        scale = base_size * size / (sample_size)
        shift = scale * (size - sample_size) / (2 * size)
        return np.arange(size, dtype=np.float32) / size * scale - shift

    coords_h = _axis_coordinates(grid_size[0])
    coords_w = _axis_coordinates(grid_size[1])

    # Default "xy" meshgrid indexing: the width mesh comes first in the stack.
    mesh = np.stack(np.meshgrid(coords_w, coords_h), axis=0)

    # The (W, H) order here does not affect the result: the 1D encoder flattens
    # each grid with reshape(-1), so only the element order matters.
    mesh = mesh.reshape([2, 1, grid_size[1], grid_size[0]])

    embedding = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)

    # Extra tokens (e.g. a [CLS] token) get all-zero embeddings in front.
    if cls_token and extra_tokens > 0:
        zero_rows = np.zeros([extra_tokens, embed_dim])
        embedding = np.concatenate([zero_rows, embedding], axis=0)

    return embedding


# if __name__ == "__main__":
#     # This is what you get when you load SD3.5 state dict
#     pos_emb = torch.from_numpy(get_scaled_2d_sincos_pos_embed(
#         1536, [384, 384], sample_size=64, base_size=16
#     )).float().unsqueeze(0)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Encode scalar positions with sine/cosine features.

    embed_dim: output dimension for each position (must be even)
    pos: array of positions to encode; flattened to size (M,)
    out: (M, D), the sine half followed by the cosine half
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    # Frequencies 1 / 10000^(i / (D/2)) for i in [0, D/2), computed in float64.
    exponents = np.arange(half_dim, dtype=np.float64)
    exponents /= embed_dim / 2.0
    inv_freq = 1.0 / 10000**exponents  # (D/2,)

    # Outer product of positions and frequencies -> (M, D/2).
    angles = np.einsum("m,d->md", pos.reshape(-1), inv_freq)

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)


def get_1d_sincos_pos_embed_from_grid_torch(
    embed_dim,
    pos,
    device=None,
    dtype=torch.float32,
):
    """
    Torch counterpart of the 1D sin/cos position encoder.

    Args:
        embed_dim: output width; ``embed_dim // 2`` frequencies are used.
        pos: tensor of positions, flattened to (M,).
        device, dtype: where and in which dtype the frequencies are created.

    Returns:
        Tensor of shape (M, 2 * (embed_dim // 2)): sines first, then cosines.
    """
    half_dim = embed_dim // 2
    exponents = torch.arange(half_dim, device=device, dtype=dtype)
    exponents.mul_(2.0 / embed_dim)
    inv_freq = 1.0 / 10000**exponents
    angles = torch.outer(pos.reshape(-1), inv_freq)
    return torch.cat([angles.sin(), angles.cos()], dim=1)


def get_2d_sincos_pos_embed_torch(
    embed_dim,
    w,
    h,
    val_center=7.5,
    val_magnitude=7.5,
    device=None,
    dtype=torch.float32,
):
    """
    Build a 2D sin/cos embedding for a ``h`` x ``w`` grid directly in torch.

    Coordinates along each axis are evenly spaced around ``val_center``. The
    shorter side spans ``val_center +/- val_magnitude``; the longer side spans
    a proportionally wider range (``val_magnitude * side / min(h, w)``).

    Returns:
        Tensor with one row per grid cell (row-major, "ij" indexing); the
        width encoding occupies the first half of the channels and the height
        encoding the second half.
    """
    shorter_side = min(h, w)
    half_span_h = (h / shorter_side) * val_magnitude
    half_span_w = (w / shorter_side) * val_magnitude

    rows = torch.linspace(-half_span_h + val_center, half_span_h + val_center, h, device=device, dtype=dtype)
    cols = torch.linspace(-half_span_w + val_center, half_span_w + val_center, w, device=device, dtype=dtype)
    grid_h, grid_w = torch.meshgrid(rows, cols, indexing="ij")

    encode_axis = partial(get_1d_sincos_pos_embed_from_grid_torch, embed_dim // 2, device=device, dtype=dtype)
    emb_h = encode_axis(grid_h)
    emb_w = encode_axis(grid_w)
    return torch.cat([emb_w, emb_h], dim=1)  # (H*W, D)


def modulate(x, shift, scale):
    """
    Apply an adaLN-style affine modulation: ``x * (1 + scale) + shift``.

    ``scale`` and ``shift`` get a broadcast axis inserted at dim 1 before being
    combined with ``x``. A ``shift`` of None is treated as all zeros.
    """
    scaled = x * (1 + scale.unsqueeze(1))
    offset = torch.zeros_like(scale) if shift is None else shift
    return scaled + offset.unsqueeze(1)


def default(x, default_value):
    """Return ``x`` unless it is None, in which case return ``default_value``."""
    return default_value if x is None else x


def timestep_embedding(t, dim, max_period=10000):
    """
    Create sinusoidal embeddings for a 1D tensor of timesteps.

    Args:
        t: tensor of shape (N,), one timestep per batch element.
        dim: width of the output embedding.
        max_period: controls the lowest frequency of the embedding.

    Returns:
        Tensor of shape (N, dim): cosines first, then sines, plus one zero
        column when ``dim`` is odd. The result is cast to ``t.dtype`` when
        ``t`` is floating point; otherwise it stays float32.
    """
    half = dim // 2
    # Frequencies are always built in float32, independent of t's dtype.
    exponent = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    freqs = torch.exp(exponent).to(device=t.device)

    angles = t[:, None].float() * freqs[None]
    pieces = [torch.cos(angles), torch.sin(angles)]
    if dim % 2:
        # Odd width: pad with a single zero column.
        pieces.append(torch.zeros_like(angles[:, :1]))
    embedding = torch.cat(pieces, dim=-1)

    if torch.is_floating_point(t):
        embedding = embedding.to(dtype=t.dtype)
    return embedding


class PatchEmbed(nn.Module):
    """
    Turn an image-like tensor into patch embeddings with a strided convolution.

    The projection is a ``Conv2d`` whose kernel size and stride both equal
    ``patch_size``. ``strict_img_size`` is stored but not consulted by the
    code in this class.
    """

    def __init__(
        self,
        img_size=256,
        patch_size=4,
        in_channels=3,
        embed_dim=512,
        norm_layer=None,
        flatten=True,
        bias=True,
        strict_img_size=True,
        dynamic_img_pad=False,
    ):
        # dynamic_img_pad and norm is omitted in SD3.5
        super().__init__()
        self.patch_size = patch_size
        self.flatten = flatten
        self.strict_img_size = strict_img_size
        self.dynamic_img_pad = dynamic_img_pad

        # Grid bookkeeping is only available when a fixed image size is given.
        if img_size is None:
            self.img_size = None
            self.grid_size = None
            self.num_patches = None
        else:
            self.img_size = img_size
            self.grid_size = img_size // patch_size
            self.num_patches = self.grid_size**2

        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer is not None else nn.Identity()

    def forward(self, x):
        """
        Project ``x`` of shape (B, C, H, W) to patch embeddings.

        Returns (B, num_patches, embed_dim) when ``flatten`` is set, otherwise
        the convolution output (B, embed_dim, H', W'), passed through ``norm``.
        """
        _, _, height, width = x.shape

        if self.dynamic_img_pad:
            # Pad bottom/right so that no partial patch remains.
            patch = self.patch_size
            pad_bottom = (patch - height % patch) % patch
            pad_right = (patch - width % patch) % patch
            x = nn.functional.pad(x, (0, pad_right, 0, pad_bottom), mode="reflect")

        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)
        return self.norm(x)


# FinalLayer in mmdit.py
class UnPatch(nn.Module):
    """
    Final layer: modulated LayerNorm, linear projection and un-patchify.

    Each token is projected to ``patch_size**2 * out_channels`` values, which
    are then rearranged into an image-like tensor.
    """

    def __init__(self, hidden_size=512, patch_size=4, out_channels=3):
        super().__init__()
        self.patch_size = patch_size
        self.c = out_channels

        # eps is default in mmdit.py
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size**2 * out_channels)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size))

    def forward(self, x: torch.Tensor, cmod, H=None, W=None):
        """
        Args:
            x: tokens of shape (B, N, hidden_size).
            cmod: conditioning vector fed to the adaLN modulation.
            H, W: optional output height/width in pixels of the un-patched
                tensor; when both are None a square token grid is assumed.

        Returns:
            Tensor of shape (B, out_channels, h * patch_size, w * patch_size).
        """
        batch, num_tokens, _ = x.shape
        patch = self.patch_size
        channels = self.c

        if H is None and W is None:
            rows = cols = int(num_tokens**0.5)
        else:
            # Derive whichever side is missing from the token count.
            rows = H // patch if H else num_tokens // (W // patch)
            cols = W // patch if W else num_tokens // rows
        assert rows * cols == num_tokens

        shift, scale = self.adaLN_modulation(cmod).chunk(2, dim=-1)
        tokens = self.linear(modulate(self.norm_final(x), shift, scale))

        # (B, h, w, p, p, c) -> (B, c, h, p, w, p) -> (B, c, h*p, w*p)
        patches = tokens.view(batch, rows, cols, patch, patch, channels)
        patches = patches.permute(0, 5, 1, 3, 2, 4).contiguous()
        return patches.view(batch, channels, rows * patch, cols * patch)


class MLP(nn.Module):
    """
    Two-layer feed-forward network: fc1 -> activation -> norm -> fc2.

    With ``use_conv`` the two layers are 1x1 ``Conv1d`` layers instead of
    ``Linear`` layers. ``hidden_features`` and ``out_features`` fall back to
    ``in_features`` when left unset.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=lambda: nn.GELU(),
        norm_layer=None,
        bias=True,
        use_conv=False,
    ):
        super().__init__()
        output_width = out_features or in_features
        hidden_width = hidden_features or in_features
        self.use_conv = use_conv

        make_layer = partial(nn.Conv1d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = make_layer(in_features, hidden_width, bias=bias)
        self.fc2 = make_layer(hidden_width, output_width, bias=bias)
        self.act = act_layer()
        self.norm = nn.Identity() if not norm_layer else norm_layer(hidden_width)

    def forward(self, x):
        hidden = self.act(self.fc1(x))
        return self.fc2(self.norm(hidden))


class TimestepEmbedding(nn.Module):
    """
    Embed scalar timesteps: sinusoidal features followed by a small MLP
    (Linear -> SiLU -> Linear).
    """

    def __init__(self, hidden_size, freq_embed_size=256):
        super().__init__()
        layers = [
            nn.Linear(freq_embed_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        ]
        self.mlp = nn.Sequential(*layers)
        self.freq_embed_size = freq_embed_size

    def forward(self, t, dtype=None, **kwargs):
        """
        Args:
            t: timesteps, passed to ``timestep_embedding``.
            dtype: dtype the sinusoidal features are converted to before the MLP.
            **kwargs: accepted and ignored.
        """
        frequencies = timestep_embedding(t, self.freq_embed_size)
        return self.mlp(frequencies.to(dtype))


class Embedder(nn.Module):
    """Embed a flat vector with a small MLP (Linear -> SiLU -> Linear)."""

    def __init__(self, input_dim, hidden_size):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        embedded = self.mlp(x)
        return embedded


def rmsnorm(x, eps=1e-6):
    """
    Scale ``x`` by the reciprocal root-mean-square of its last dimension.

    ``eps`` is added to the mean of squares before the reciprocal square root.
    """
    mean_square = x.pow(2).mean(-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + eps)
    return x * inv_rms


class RMSNorm(torch.nn.Module):
    def __init__(
        self,
        dim: int,
        elementwise_affine: bool = False,
        eps: float = 1e-6,
        device=None,
        dtype=None,
    ):
        """
        Initialize the RMSNorm normalization layer.
        Args:
            dim (int): The dimension of the input tensor.
            elementwise_affine (bool, optional): Whether to learn a per-channel scale. Default is False.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
            device, dtype: Passed to the creation of the scale parameter.
        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            learnable_scale (bool): Copy of ``elementwise_affine``.
            weight (nn.Parameter or None): Learnable scaling parameter, or None when disabled.
        """
        super().__init__()
        self.eps = eps
        self.learnable_scale = elementwise_affine
        if self.learnable_scale:
            # Start from the identity scale. An uninitialised (torch.empty) weight
            # would hold arbitrary memory whenever no checkpoint overwrites it.
            self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype))
        else:
            self.register_parameter("weight", None)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.
        Args:
            x (torch.Tensor): The input tensor.
        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.
        """
        x = rmsnorm(x, eps=self.eps)
        if not self.learnable_scale:
            return x
        # The scale follows the input's device and dtype.
        return x * self.weight.to(device=x.device, dtype=x.dtype)


class SwiGLUFeedForward(nn.Module):
    """
    SwiGLU feed-forward network: ``w2(silu(w1(x)) * w3(x))``, all without bias.

    The hidden width is ``2/3`` of ``hidden_dim``, optionally scaled by
    ``ffn_dim_multiplier``, then rounded up to a multiple of ``multiple_of``.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: float = None,
    ):
        super().__init__()
        width = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            width = int(ffn_dim_multiplier * width)
        # Round up to the next multiple of `multiple_of`.
        num_groups = (width + multiple_of - 1) // multiple_of
        width = multiple_of * num_groups

        self.w1 = nn.Linear(dim, width, bias=False)
        self.w2 = nn.Linear(width, dim, bias=False)
        self.w3 = nn.Linear(dim, width, bias=False)

    def forward(self, x):
        gate = nn.functional.silu(self.w1(x))
        return self.w2(gate * self.w3(x))


# Linears for SelfAttention in mmdit.py
class AttentionLinears(nn.Module):
    """
    The linear layers around an attention operation: a fused q/k/v projection,
    optional per-head q/k normalisation, and (unless ``pre_only``) the output
    projection. The attention itself is computed by the caller.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        pre_only: bool = False,
        qk_norm: Optional[str] = None,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        if not pre_only:
            self.proj = nn.Linear(dim, dim)
        self.pre_only = pre_only

        # Pick the per-head normalisation applied to q and k.
        if qk_norm is None:
            make_norm = nn.Identity
        elif qk_norm == "rms":
            make_norm = partial(RMSNorm, self.head_dim, elementwise_affine=True, eps=1.0e-6)
        elif qk_norm == "ln":
            make_norm = partial(nn.LayerNorm, self.head_dim, elementwise_affine=True, eps=1.0e-6)
        else:
            raise ValueError(qk_norm)
        self.ln_q = make_norm()
        self.ln_k = make_norm()

    def pre_attention(self, x: torch.Tensor) -> torch.Tensor:
        """
        Project ``x`` of shape [B, L, C] to queries, keys and values.

        output:
            (q, k, v) tuple; q and k are normalised per head and flattened
            back to [B, L, D], while v is returned per head as
            [B, L, num_heads, head_dim].
        """
        batch, seq_len, _ = x.shape
        per_head = self.qkv(x).reshape(batch, seq_len, -1, self.head_dim)
        q, k, v = per_head.chunk(3, dim=2)
        q = self.ln_q(q).reshape(batch, seq_len, -1)
        k = self.ln_k(k).reshape(batch, seq_len, -1)
        return (q, k, v)

    def post_attention(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the output projection; only valid when not ``pre_only``."""
        assert not self.pre_only
        return self.proj(x)


def _split_heads_first(x, head_dim):
    """[B, L, D] -> [B, H, L, head_dim] (heads moved in front of the sequence axis)."""
    return x.reshape(x.shape[0], x.shape[1], -1, head_dim).transpose(1, 2)


def _merge_heads_first(x):
    """[B, H, L, head_dim] -> [B, L, H * head_dim]."""
    return x.transpose(1, 2).reshape(x.shape[0], x.shape[2], -1)


def _heads_first_shape(x):
    """Return the 4-tuple (1, x, 1, 1)."""
    return (1, x, 1, 1)


def _split_heads_last(x, head_dim):
    """[B, L, D] -> [B, L, H, head_dim] (heads kept after the sequence axis)."""
    return x.reshape(x.shape[0], x.shape[1], -1, head_dim)


def _merge_heads_last(x):
    """[B, L, H, head_dim] -> [B, L, H * head_dim]."""
    return x.reshape(x.shape[0], x.shape[1], -1)


def _heads_last_shape(x):
    """Return the 4-tuple (1, 1, x, 1)."""
    return (1, 1, x, 1)


# Per attention mode: (split heads before attention, merge heads after
# attention, shape helper). `attention` uses the first two entries; the third
# is not used in the visible part of this file.
MEMORY_LAYOUTS = {
    "torch": (_split_heads_first, _merge_heads_first, _heads_first_shape),
    "xformers": (_split_heads_last, _merge_heads_last, _heads_last_shape),
    "math": (_split_heads_first, _merge_heads_first, _heads_first_shape),
}
# ATTN_FUNCTION = {
#     "torch": F.scaled_dot_product_attention,
#     "xformers": memory_efficient_attention,
# }


def vanilla_attention(q, k, v, mask, scale=None):
    """
    Plain softmax attention computed with explicit matrix products.

    Args:
        q, k, v: tensors whose last two dims are (sequence, head_dim), with
            either [B, H, L, D] (as produced by the "math" memory layout) or
            [B*H, L, D] leading dims. ``torch.matmul`` is used because
            ``torch.bmm`` only accepts 3-D inputs and fails on the 4-D layout.
        mask: optional boolean tensor, True meaning "attend to this key". Its
            first dim is the batch; all remaining dims are flattened into the
            key axis. It is broadcast over heads and query positions.
        scale: value the raw scores are DIVIDED by; defaults to
            ``sqrt(head_dim)``.

    Returns:
        Attention output with the same leading dims as ``q``.
    """
    if scale is None:
        scale = math.sqrt(q.size(-1))
    scores = torch.matmul(q, k.transpose(-1, -2)) / scale

    if mask is not None:
        key_mask = mask.reshape(mask.shape[0], -1)  # (B, Lk)
        if scores.dim() == 3 and scores.shape[0] != key_mask.shape[0]:
            # Heads are folded into the batch axis in "(b h)" order.
            heads = scores.shape[0] // key_mask.shape[0]
            key_mask = key_mask.repeat_interleave(heads, dim=0)
        # Singleton axes for heads (4-D case) and query positions so the mask
        # broadcasts against scores of shape [..., Lq, Lk].
        broadcast_axes = [1] * (scores.dim() - 2)
        key_mask = key_mask.reshape(key_mask.shape[0], *broadcast_axes, key_mask.shape[-1])
        max_neg_value = -torch.finfo(scores.dtype).max
        scores = scores.masked_fill(~key_mask, max_neg_value)

    p_attn = F.softmax(scores, dim=-1)
    return torch.matmul(p_attn, v)


def attention(q, k, v, head_dim, mask=None, scale=None, mode="xformers"):
    """
    Run multi-head attention with the backend selected by ``mode``.

    q, k, v: [B, L, D]; they are split into heads of width ``head_dim`` using
    the layout registered for ``mode`` in ``MEMORY_LAYOUTS`` and merged back
    to [B, L, D] afterwards. k and v are converted to q's dtype/device.

    Args:
        head_dim: width of one attention head (NOT the number of heads).
        mask: forwarded unchanged to the backend.
        scale: forwarded unchanged to the "xformers" and "math" backends; must
            be None for "torch". Note that ``vanilla_attention`` divides the
            scores by this value.
        mode: "torch", "xformers" or "math" (a key of ``MEMORY_LAYOUTS``).

    Raises:
        KeyError: ``mode`` is not a key of ``MEMORY_LAYOUTS``.
        RuntimeError: ``mode`` is "xformers" but xformers could not be imported.
    """
    split_heads, merge_heads = MEMORY_LAYOUTS[mode][:2]
    q = split_heads(q, head_dim)
    k = split_heads(k, head_dim).to(q)
    v = split_heads(v, head_dim).to(q)

    if mode == "torch":
        assert scale is None
        scores = F.scaled_dot_product_attention(q, k, v, mask)  # , scale=scale)
    elif mode == "xformers":
        if memory_efficient_attention is None:
            # The optional import at module load failed; calling None would
            # only produce an opaque "'NoneType' object is not callable".
            raise RuntimeError(
                "attention mode 'xformers' was requested but xformers.ops.memory_efficient_attention "
                "could not be imported; install xformers or use mode='torch'"
            )
        scores = memory_efficient_attention(q, k, v, mask, scale=scale)
    else:
        scores = vanilla_attention(q, k, v, mask, scale=scale)

    return merge_heads(scores)


# DismantledBlock in mmdit.py
class SingleDiTBlock(nn.Module):
    """
    A DiT block with gated adaptive layer norm (adaLN) conditioning.

    The block is split in two halves so that the caller can run the attention
    itself between them:

    * ``pre_attention`` / ``pre_attention_x``: norm, adaLN modulation and the
      q/k/v projection(s);
    * ``post_attention`` / ``post_attention_x``: output projection(s), gated
      residual connections and the MLP.

    The ``*_x`` variants are for blocks built with ``x_block_self_attn=True``,
    which carry a second attention branch (``attn2``).

    The conditioning vector is mapped to ``n_mods`` chunks of width
    ``hidden_size``: 9 with the second attention branch, otherwise 6
    (shift/scale/gate for attention and MLP) or 4 with ``scale_mod_only``
    (no shifts); ``pre_only`` blocks need just 2 (shift, scale) or 1 (scale).
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: str = "xformers",
        qkv_bias: bool = False,
        pre_only: bool = False,
        rmsnorm: bool = False,
        scale_mod_only: bool = False,
        swiglu: bool = False,
        qk_norm: Optional[str] = None,
        x_block_self_attn: bool = False,
        **block_kwargs,
    ):
        """
        Args:
            hidden_size: token width.
            num_heads: number of attention heads.
            mlp_ratio: MLP hidden width as a multiple of ``hidden_size``.
            attn_mode: key of ``MEMORY_LAYOUTS``; stored for the caller.
            qkv_bias: whether the q/k/v projection has a bias.
            pre_only: build only the pre-attention half (no proj, norm2, MLP).
            rmsnorm: use ``RMSNorm`` instead of ``nn.LayerNorm``.
            scale_mod_only: modulate with scales only (no shifts).
            swiglu: use ``SwiGLUFeedForward`` instead of ``MLP``.
            qk_norm: q/k normalisation passed to ``AttentionLinears``.
            x_block_self_attn: add the second attention branch ``attn2``.
            **block_kwargs: accepted and ignored.
        """
        super().__init__()
        assert attn_mode in MEMORY_LAYOUTS
        self.attn_mode = attn_mode

        norm_cls = RMSNorm if rmsnorm else nn.LayerNorm
        self.norm1 = norm_cls(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = AttentionLinears(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, pre_only=pre_only, qk_norm=qk_norm)

        self.x_block_self_attn = x_block_self_attn
        if self.x_block_self_attn:
            assert not pre_only
            assert not scale_mod_only
            self.attn2 = AttentionLinears(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, pre_only=False, qk_norm=qk_norm)

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        if not pre_only:
            self.norm2 = norm_cls(hidden_size, elementwise_affine=False, eps=1e-6)
            if not swiglu:
                self.mlp = MLP(
                    in_features=hidden_size,
                    hidden_features=mlp_hidden_dim,
                    act_layer=lambda: nn.GELU(approximate="tanh"),
                )
            else:
                self.mlp = SwiGLUFeedForward(
                    dim=hidden_size,
                    hidden_dim=mlp_hidden_dim,
                    multiple_of=256,
                )

        self.scale_mod_only = scale_mod_only
        if self.x_block_self_attn:
            n_mods = 9
        elif not scale_mod_only:
            n_mods = 6 if not pre_only else 2
        else:
            n_mods = 4 if not pre_only else 1
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, n_mods * hidden_size))
        self.pre_only = pre_only

    def pre_attention(self, x: torch.Tensor, c: torch.Tensor) -> tuple:
        """
        Modulate ``x`` with the conditioning ``c`` and project it to q/k/v.

        Returns:
            ``(qkv, intermediates)`` where ``qkv`` is the tuple from
            ``AttentionLinears.pre_attention`` and ``intermediates`` is
            ``(x, gate_msa, shift_mlp, scale_mlp, gate_mlp)`` for
            ``post_attention``, or None for ``pre_only`` blocks.
        """
        modulation = self.adaLN_modulation(c)
        if not self.pre_only:
            if not self.scale_mod_only:
                (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp) = modulation.chunk(6, dim=-1)
            else:
                shift_msa = None
                shift_mlp = None
                (scale_msa, gate_msa, scale_mlp, gate_mlp) = modulation.chunk(4, dim=-1)
            qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
            return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp)

        if not self.scale_mod_only:
            (shift_msa, scale_msa) = modulation.chunk(2, dim=-1)
        else:
            shift_msa = None
            scale_msa = modulation
        qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
        return qkv, None

    def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> tuple:
        """
        Like ``pre_attention`` but also prepares q/k/v for the second
        attention branch from the same normalised input.

        Returns:
            ``(qkv, qkv2, intermediates)`` with ``intermediates`` being
            ``(x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2)`` for
            ``post_attention_x``.
        """
        assert self.x_block_self_attn
        # Chunk along the last (feature) axis, like every other modulation split.
        (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2) = self.adaLN_modulation(
            c
        ).chunk(9, dim=-1)
        x_norm = self.norm1(x)
        qkv = self.attn.pre_attention(modulate(x_norm, shift_msa, scale_msa))
        qkv2 = self.attn2.pre_attention(modulate(x_norm, shift_msa2, scale_msa2))
        return qkv, qkv2, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2)

    def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp):
        """Add the gated attention output and the gated, modulated MLP output to ``x``."""
        assert not self.pre_only
        x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x

    def post_attention_x(self, attn, attn2, x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2, attn1_dropout: float = 0.0):
        """
        ``post_attention`` for blocks with the second attention branch.

        Args:
            attn, attn2: outputs of the first and second attention.
            attn1_dropout: probability of dropping the first attention branch
                for a whole batch element (no rescaling of the kept ones).
        """
        assert not self.pre_only

        keep_mask = None
        if attn1_dropout > 0.0:
            # Use torch.bernoulli to implement dropout, only dropout the batch dimension
            keep_mask = torch.bernoulli(torch.full((attn.size(0), 1, 1), 1 - attn1_dropout, device=attn.device))

        attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
        if keep_mask is not None:
            # The mask is sampled in float32; cast it so that fp16/bf16
            # activations are not promoted to float32 by the multiplication.
            attn_ = attn_ * keep_mask.to(attn_.dtype)
        x = x + attn_

        attn2_ = gate_msa2.unsqueeze(1) * self.attn2.post_attention(attn2)
        x = x + attn2_

        mlp_ = gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        x = x + mlp_
        return x


# JointBlock + block_mixing in mmdit.py
class MMDiTBlock(nn.Module):
    """
    Joint block: a context ``SingleDiTBlock`` and an x ``SingleDiTBlock`` that
    share one attention over the concatenated (context, x) sequence.

    Required keyword arguments (removed from ``kwargs`` before the rest is
    forwarded to both ``SingleDiTBlock`` constructors):
        pre_only: whether the context block is pre-attention only.
        x_block_self_attn: whether the x block has the second self-attention.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        pre_only = kwargs.pop("pre_only")
        x_block_self_attn = kwargs.pop("x_block_self_attn")

        self.context_block = SingleDiTBlock(*args, pre_only=pre_only, **kwargs)
        self.x_block = SingleDiTBlock(*args, pre_only=False, x_block_self_attn=x_block_self_attn, **kwargs)

        self.head_dim = self.x_block.attn.head_dim
        self.mode = self.x_block.attn_mode
        self.gradient_checkpointing = False

    def enable_gradient_checkpointing(self):
        """Use activation checkpointing for this block while in training mode."""
        self.gradient_checkpointing = True

    def _forward(self, context, x, c):
        """
        Args:
            context: context tokens, first in the joint sequence.
            x: x tokens, appended after the context tokens.
            c: conditioning vector for the adaLN modulation of both blocks.

        Returns:
            ``(context, x)``; ``context`` is None when the context block is
            ``pre_only``.
        """
        ctx_qkv, ctx_intermediates = self.context_block.pre_attention(context, c)

        if self.x_block.x_block_self_attn:
            x_qkv, x_qkv2, x_intermediates = self.x_block.pre_attention_x(x, c)
        else:
            x_qkv, x_intermediates = self.x_block.pre_attention(x, c)

        ctx_len = ctx_qkv[0].size(1)

        # Joint attention over [context tokens, x tokens].
        q = torch.concat((ctx_qkv[0], x_qkv[0]), dim=1)
        k = torch.concat((ctx_qkv[1], x_qkv[1]), dim=1)
        v = torch.concat((ctx_qkv[2], x_qkv[2]), dim=1)

        attn = attention(q, k, v, head_dim=self.head_dim, mode=self.mode)
        ctx_attn_out = attn[:, :ctx_len]
        x_attn_out = attn[:, ctx_len:]

        if self.x_block.x_block_self_attn:
            x_q2, x_k2, x_v2 = x_qkv2
            # `attention` expects the per-head width. Passing the number of
            # heads here would split the channels into heads of the wrong size.
            attn2 = attention(x_q2, x_k2, x_v2, head_dim=self.x_block.attn2.head_dim, mode=self.mode)
            x = self.x_block.post_attention_x(x_attn_out, attn2, *x_intermediates)
        else:
            x = self.x_block.post_attention(x_attn_out, *x_intermediates)

        if not self.context_block.pre_only:
            context = self.context_block.post_attention(ctx_attn_out, *ctx_intermediates)
        else:
            context = None

        return context, x

    def forward(self, *args, **kwargs):
        """Run ``_forward``, under activation checkpointing when enabled and training."""
        if self.training and self.gradient_checkpointing:
            return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
        else:
            return self._forward(*args, **kwargs)


class MMDiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.

    The latent is patch-embedded, a cropped positional embedding is added, and the
    result is passed together with the context tokens through a stack of
    ``MMDiTBlock`` joint blocks before ``UnPatch`` maps it back to image layout.
    The joint blocks can optionally be handed to a ``ModelOffloader`` for block
    swapping (see ``enable_block_swap``).
    """

    # Scaled positional-embedding grids are built this many times larger (per side)
    # than the patched latent size they belong to; see enable_scaled_pos_embed.
    POS_EMBED_MAX_RATIO = 1.5

    def __init__(
        self,
        input_size: int = 32,
        patch_size: int = 2,
        in_channels: int = 4,
        depth: int = 28,
        # hidden_size: Optional[int] = None,
        # num_heads: Optional[int] = None,
        mlp_ratio: float = 4.0,
        learn_sigma: bool = False,
        adm_in_channels: Optional[int] = None,
        context_embedder_in_features: Optional[int] = None,
        context_embedder_out_features: Optional[int] = None,
        use_checkpoint: bool = False,
        register_length: int = 0,
        attn_mode: str = "torch",
        rmsnorm: bool = False,
        scale_mod_only: bool = False,
        swiglu: bool = False,
        out_channels: Optional[int] = None,
        pos_embed_scaling_factor: Optional[float] = None,
        pos_embed_offset: Optional[float] = None,
        pos_embed_max_size: Optional[int] = None,
        num_patches=None,
        qk_norm: Optional[str] = None,
        x_block_self_attn_layers: Optional[list[int]] = None,
        qkv_bias: bool = True,
        pos_emb_random_crop_rate: float = 0.0,
        use_scaled_pos_embed: bool = False,
        pos_embed_latent_sizes: Optional[list[int]] = None,
        model_type: str = "sd3m",
    ):
        """
        Build the embedders, the joint block stack and the final un-patch layer.

        Notes on the less obvious arguments:
            depth: number of joint blocks; also fixes ``hidden_size = 64 * depth`` and
                ``num_heads = depth``.
            num_patches: when given, an uninitialised ``pos_embed`` buffer of shape
                ``(1, num_patches, hidden_size)`` is registered; otherwise ``pos_embed`` is ``None``.
            x_block_self_attn_layers: indices of blocks created with ``x_block_self_attn=True``.
                ``None`` (the default) means no such blocks.
            use_scaled_pos_embed / pos_embed_latent_sizes: forwarded to
                ``enable_scaled_pos_embed`` once ``pos_embed`` has been set up.
        """
        super().__init__()
        self._model_type = model_type
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        default_out_channels = in_channels * 2 if learn_sigma else in_channels
        self.out_channels = default(out_channels, default_out_channels)
        self.patch_size = patch_size
        self.pos_embed_scaling_factor = pos_embed_scaling_factor
        self.pos_embed_offset = pos_embed_offset
        self.pos_embed_max_size = pos_embed_max_size
        # None stands for "no layers"; creating the list here avoids a shared mutable default
        self.x_block_self_attn_layers = [] if x_block_self_attn_layers is None else x_block_self_attn_layers
        self.pos_emb_random_crop_rate = pos_emb_random_crop_rate
        self.gradient_checkpointing = use_checkpoint

        # hidden_size = default(hidden_size, 64 * depth)
        # num_heads = default(num_heads, hidden_size // 64)

        # apply magic --> this defines a head_size of 64
        self.hidden_size = 64 * depth
        num_heads = depth

        self.num_heads = num_heads

        self.x_embedder = PatchEmbed(
            input_size,
            patch_size,
            in_channels,
            self.hidden_size,
            bias=True,
            strict_img_size=self.pos_embed_max_size is None,
        )
        self.t_embedder = TimestepEmbedding(self.hidden_size)

        self.y_embedder = None
        if adm_in_channels is not None:
            assert isinstance(adm_in_channels, int)
            self.y_embedder = Embedder(adm_in_channels, self.hidden_size)

        if context_embedder_in_features is not None:
            self.context_embedder = nn.Linear(context_embedder_in_features, context_embedder_out_features)
        else:
            self.context_embedder = nn.Identity()

        self.register_length = register_length
        if self.register_length > 0:
            self.register = nn.Parameter(torch.randn(1, register_length, self.hidden_size))

        # num_patches = self.x_embedder.num_patches
        # Will use fixed sin-cos embedding:
        # just use a buffer already
        if num_patches is not None:
            self.register_buffer(
                "pos_embed",
                torch.empty(1, num_patches, self.hidden_size),
            )
        else:
            self.pos_embed = None

        # Must run after pos_embed is set up: enabling scaled embeddings assigns
        # `self.pos_embed = None`, and doing that *before* register_buffer would leave a
        # plain attribute behind that makes register_buffer raise KeyError.
        self.enable_scaled_pos_embed(use_scaled_pos_embed, pos_embed_latent_sizes)

        self.use_checkpoint = use_checkpoint
        self.joint_blocks = nn.ModuleList(
            [
                MMDiTBlock(
                    self.hidden_size,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    attn_mode=attn_mode,
                    qkv_bias=qkv_bias,
                    pre_only=i == depth - 1,  # the last block does not produce a context output
                    rmsnorm=rmsnorm,
                    scale_mod_only=scale_mod_only,
                    swiglu=swiglu,
                    qk_norm=qk_norm,
                    x_block_self_attn=(i in self.x_block_self_attn_layers),
                )
                for i in range(depth)
            ]
        )
        for block in self.joint_blocks:
            block.gradient_checkpointing = use_checkpoint

        self.final_layer = UnPatch(self.hidden_size, patch_size, self.out_channels)
        # self.initialize_weights()

        self.blocks_to_swap = None
        self.offloader = None
        self.num_blocks = len(self.joint_blocks)

    def enable_scaled_pos_embed(self, use_scaled_pos_embed: bool, latent_sizes: Optional[list[int]]):
        """
        Enable or disable per-resolution scaled positional embeddings.

        When enabled, ``pos_embed`` is dropped and one sin-cos embedding is precomputed
        for every distinct entry of ``latent_sizes`` (divided by the patch size).
        ``resolution_area_to_latent_size`` then holds ``(max_area, patched_size)`` pairs in
        ascending order, used by ``cropped_scaled_pos_embed`` to pick an embedding from the
        patched input area.  When disabled, both lookup structures are set to ``None``.

        Raises:
            ValueError: if enabled with ``latent_sizes`` that is ``None`` or empty.
        """
        if use_scaled_pos_embed and not latent_sizes:
            raise ValueError("latent_sizes must be a non-empty list when use_scaled_pos_embed is True")

        self.use_scaled_pos_embed = use_scaled_pos_embed

        if not self.use_scaled_pos_embed:
            self.resolution_area_to_latent_size = None
            self.resolution_pos_embeds = None
            return

        # remove pos_embed to free up memory up to 0.4 GB
        self.pos_embed = None

        # remove duplicates and sort latent sizes in ascending order
        latent_sizes = sorted(set(latent_sizes))
        patched_sizes = [latent_size // self.patch_size for latent_size in latent_sizes]

        # calculate value range for each latent area: this is used to determine the pos_emb size from the latent shape
        # (the boundary between two neighbouring sizes is the midpoint of their areas)
        max_areas = [(prev**2 + cur**2) // 2 for prev, cur in zip(patched_sizes, patched_sizes[1:])]

        # area of the last latent size, if the latent size exceeds this, error will be raised
        max_areas.append(int((patched_sizes[-1] * MMDiT.POS_EMBED_MAX_RATIO) ** 2))

        self.resolution_area_to_latent_size = list(zip(max_areas, patched_sizes))

        self.resolution_pos_embeds = {}
        for patched_size in patched_sizes:
            grid_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO)
            pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, grid_size, sample_size=patched_size)
            pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)
            self.resolution_pos_embeds[patched_size] = pos_embed

    @property
    def model_type(self):
        """Model type string given at construction."""
        return self._model_type

    @property
    def device(self):
        """Device of the first parameter."""
        return next(self.parameters()).device

    @property
    def dtype(self):
        """Dtype of the first parameter."""
        return next(self.parameters()).dtype

    def enable_gradient_checkpointing(self):
        """Turn gradient checkpointing on for the model and every joint block."""
        self.gradient_checkpointing = True
        for block in self.joint_blocks:
            block.enable_gradient_checkpointing()

    def disable_gradient_checkpointing(self):
        """
        Turn gradient checkpointing off for the model and every joint block.

        Blocks that expose ``disable_gradient_checkpointing`` are asked to do it
        themselves; for blocks that only carry the flag, the flag is cleared directly
        (the same way ``__init__`` sets it).
        """
        self.gradient_checkpointing = False
        for block in self.joint_blocks:
            disable = getattr(block, "disable_gradient_checkpointing", None)
            if callable(disable):
                disable()
            else:
                block.gradient_checkpointing = False

    def initialize_weights(self):
        """Apply the default initialisation: Xavier for linears, sin-cos pos_embed, zeroed adaLN/output layers."""

        # TODO: Init context_embedder?
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize (and freeze) pos_embed by sin-cos embedding
        if self.pos_embed is not None:
            pos_embed = get_2d_sincos_pos_embed(
                self.pos_embed.shape[-1],
                int(self.pos_embed.shape[-2] ** 0.5),
                scaling_factor=self.pos_embed_scaling_factor,
            )
            self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        if getattr(self, "y_embedder", None) is not None:
            nn.init.normal_(self.y_embedder.mlp[0].weight, std=0.02)
            nn.init.normal_(self.y_embedder.mlp[2].weight, std=0.02)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.joint_blocks:
            nn.init.constant_(block.x_block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.x_block.adaLN_modulation[-1].bias, 0)
            nn.init.constant_(block.context_block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.context_block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def set_pos_emb_random_crop_rate(self, rate: float):
        """Set the probability with which ``forward`` crops the positional embedding at a random offset."""
        self.pos_emb_random_crop_rate = rate

    def cropped_pos_embed(self, h, w, device=None, random_crop: bool = False):
        """
        Crop an ``h x w`` (un-patched size) window out of the square ``pos_embed`` grid.

        The window is centred unless ``random_crop`` is set, in which case its top-left
        corner is drawn uniformly from the valid range.  Returns a tensor of shape
        ``(1, patched_h * patched_w, pos_embed.shape[-1])``.
        """
        p = self.x_embedder.patch_size
        # patched size
        h = (h + 1) // p
        w = (w + 1) // p
        if self.pos_embed is None:  # should not happen
            return get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=device)
        assert self.pos_embed_max_size is not None
        assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size)
        assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size)

        if not random_crop:
            top = (self.pos_embed_max_size - h) // 2
            left = (self.pos_embed_max_size - w) // 2
        else:
            top = torch.randint(0, self.pos_embed_max_size - h + 1, (1,)).item()
            left = torch.randint(0, self.pos_embed_max_size - w + 1, (1,)).item()

        spatial_pos_embed = self.pos_embed.reshape(
            1,
            self.pos_embed_max_size,
            self.pos_embed_max_size,
            self.pos_embed.shape[-1],
        )
        spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
        spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
        return spatial_pos_embed

    def cropped_scaled_pos_embed(self, h, w, device=None, dtype=None, random_crop: bool = False):
        """
        Crop an ``h x w`` (un-patched size) window out of the scaled embedding chosen by area.

        The embedding is picked as the first entry of ``resolution_area_to_latent_size``
        whose area bound is not exceeded.  If the window does not fit into the cached
        grid, a larger grid is generated and cached in its place.  The cached tensor is
        also replaced after being moved to ``device`` / cast to ``dtype``.

        Raises:
            ValueError: if the patched area is larger than every configured bound.
        """
        p = self.x_embedder.patch_size
        # patched size
        h = (h + 1) // p
        w = (w + 1) // p

        # select pos_embed size based on area
        area = h * w
        patched_size = None
        for area_, patched_size_ in self.resolution_area_to_latent_size:
            if area <= area_:
                patched_size = patched_size_
                break
        if patched_size is None:
            raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.")

        pos_embed = self.resolution_pos_embeds[patched_size]
        pos_embed_size = round(math.sqrt(pos_embed.shape[1]))
        if h > pos_embed_size or w > pos_embed_size:
            # # fallback to normal pos_embed
            # return self.cropped_pos_embed(h * p, w * p, device=device, random_crop=random_crop)
            # extend pos_embed size
            # (despite the wording of the warning, the code below regenerates a larger
            # *scaled* embedding; it does not fall back to the normal pos_embed)
            logger.warning(
                f"Using normal pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide."
            )
            pos_embed_size = max(h, w)
            pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, pos_embed_size, sample_size=patched_size)
            pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)
            self.resolution_pos_embeds[patched_size] = pos_embed
            logger.info(f"Updated pos_embed for size {pos_embed_size}x{pos_embed_size}")

        if not random_crop:
            top = (pos_embed_size - h) // 2
            left = (pos_embed_size - w) // 2
        else:
            top = torch.randint(0, pos_embed_size - h + 1, (1,)).item()
            left = torch.randint(0, pos_embed_size - w + 1, (1,)).item()

        if pos_embed.device != device:
            pos_embed = pos_embed.to(device)
            # which is better to update device, or transfer every time to device? -> 64x64 emb is 96*96*1536*4=56MB. It's okay to update device.
            self.resolution_pos_embeds[patched_size] = pos_embed  # update device
        if pos_embed.dtype != dtype:
            pos_embed = pos_embed.to(dtype)
            self.resolution_pos_embeds[patched_size] = pos_embed  # update dtype

        spatial_pos_embed = pos_embed.reshape(1, pos_embed_size, pos_embed_size, pos_embed.shape[-1])
        spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
        spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
        return spatial_pos_embed

    def enable_block_swap(self, num_blocks: int, device: torch.device):
        """Create the ``ModelOffloader`` that swaps ``num_blocks`` joint blocks; at most ``num_blocks - 2`` of the stack."""
        self.blocks_to_swap = num_blocks

        assert (
            self.blocks_to_swap <= self.num_blocks - 2
        ), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks."

        self.offloader = custom_offloading_utils.ModelOffloader(
            self.joint_blocks, self.num_blocks, self.blocks_to_swap, device  # , debug=True
        )
        print(f"SD3: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.")

    def move_to_device_except_swap_blocks(self, device: torch.device):
        """
        Move the model to ``device``, leaving the joint blocks where they are when block swap is enabled.

        The blocks are detached from the module for the duration of ``self.to`` and
        re-attached afterwards, even if the move raises.
        """
        # assume model is on cpu. do not move blocks to device to reduce temporary memory usage
        if not self.blocks_to_swap:
            self.to(device)
            return

        save_blocks = self.joint_blocks
        self.joint_blocks = None
        try:
            self.to(device)
        finally:
            # never leave the model without its blocks, even on failure
            self.joint_blocks = save_blocks

    def prepare_block_swap_before_forward(self):
        """Let the offloader arrange block devices before a forward pass; no-op without block swap."""
        if self.blocks_to_swap is None or self.blocks_to_swap == 0:
            return
        self.offloader.prepare_block_devices_before_forward(self.joint_blocks)

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass of DiT.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N, D) tensor of class labels
        context: optional context tokens, passed through ``context_embedder`` when given

        Returns the un-patched output cropped to the input's H x W.
        """
        pos_emb_random_crop = (
            False if self.pos_emb_random_crop_rate == 0.0 else torch.rand(1).item() < self.pos_emb_random_crop_rate
        )

        B, C, H, W = x.shape

        if not self.use_scaled_pos_embed:
            pos_embed = self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype)
        else:
            pos_embed = self.cropped_scaled_pos_embed(H, W, device=x.device, dtype=x.dtype, random_crop=pos_emb_random_crop)
        x = self.x_embedder(x) + pos_embed
        del pos_embed

        c = self.t_embedder(t, dtype=x.dtype)  # (N, D)
        if y is not None and self.y_embedder is not None:
            y = self.y_embedder(y)  # (N, D)
            c = c + y  # (N, D)

        if context is not None:
            context = self.context_embedder(context)

        if self.register_length > 0:
            # prepend the learned register tokens (repeated per batch item) to the context
            context = torch.cat(
                (einops.repeat(self.register, "1 ... -> b ...", b=x.shape[0]), default(context, torch.Tensor([]).type_as(x))), 1
            )

        if not self.blocks_to_swap:
            for block in self.joint_blocks:
                context, x = block(context, x, c)
        else:
            for block_idx, block in enumerate(self.joint_blocks):
                # wait until this block is ready, run it, then hand it back to the offloader
                self.offloader.wait_for_block(block_idx)

                context, x = block(context, x, c)

                self.offloader.submit_move_blocks(self.joint_blocks, block_idx)

        x = self.final_layer(x, c, H, W)  # Our final layer combined UnPatchify
        return x[:, :, :H, :W]


def create_sd3_mmdit(params: SD3Params, attn_mode: str = "torch") -> MMDiT:
    """Instantiate an ``MMDiT`` from an ``SD3Params`` record, with 16 input channels and ``mlp_ratio`` 4."""
    # settings that do not come from the parameter record
    fixed_kwargs = dict(input_size=None, in_channels=16, mlp_ratio=4, attn_mode=attn_mode)

    # settings copied one-to-one: the MMDiT keyword has the same name as the SD3Params field
    copied_fields = (
        "pos_embed_max_size",
        "patch_size",
        "adm_in_channels",
        "context_embedder_in_features",
        "context_embedder_out_features",
        "depth",
        "qk_norm",
        "x_block_self_attn_layers",
        "num_patches",
        "model_type",
    )
    copied_kwargs = {field: getattr(params, field) for field in copied_fields}

    return MMDiT(**fixed_kwargs, **copied_kwargs)


# endregion

# region VAE

# Latent normalisation constants read by SDVAE.process_in / SDVAE.process_out:
# process_in computes (latent - VAE_SHIFT_FACTOR) * VAE_SCALE_FACTOR and
# process_out applies the inverse, latent / VAE_SCALE_FACTOR + VAE_SHIFT_FACTOR.
VAE_SCALE_FACTOR = 1.5305
VAE_SHIFT_FACTOR = 0.0609


def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None):
    """Return an affine ``GroupNorm`` over ``in_channels`` channels with ``eps=1e-6``."""
    norm_settings = dict(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
    return torch.nn.GroupNorm(dtype=dtype, device=device, **norm_settings)


class ResnetBlock(torch.nn.Module):
    """
    Residual block: two (GroupNorm -> SiLU -> 3x3 conv) stages added onto a skip path.

    The skip path goes through a 1x1 convolution when the input and output channel
    counts differ, and is the identity otherwise.
    """

    def __init__(self, *, in_channels, out_channels=None, dtype=torch.float32, device=None):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels

        factory = dict(dtype=dtype, device=device)
        conv3x3 = dict(kernel_size=3, stride=1, padding=1, **factory)

        self.norm1 = Normalize(in_channels, **factory)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, **conv3x3)
        self.norm2 = Normalize(out_channels, **factory)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, **conv3x3)
        # channel-matching projection for the skip path, only when it is needed
        self.nin_shortcut = (
            torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, **factory)
            if in_channels != out_channels
            else None
        )
        self.swish = torch.nn.SiLU(inplace=True)

    def forward(self, x):
        skip = x
        hidden = self.conv1(self.swish(self.norm1(x)))
        hidden = self.conv2(self.swish(self.norm2(hidden)))
        if self.nin_shortcut is not None:
            skip = self.nin_shortcut(skip)
        return skip + hidden


class AttnBlock(torch.nn.Module):
    """Single-head self-attention over the spatial positions of a feature map, with a residual connection."""

    def __init__(self, in_channels, dtype=torch.float32, device=None):
        super().__init__()
        pointwise = dict(kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
        self.norm = Normalize(in_channels, dtype=dtype, device=device)
        self.q = torch.nn.Conv2d(in_channels, in_channels, **pointwise)
        self.k = torch.nn.Conv2d(in_channels, in_channels, **pointwise)
        self.v = torch.nn.Conv2d(in_channels, in_channels, **pointwise)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, **pointwise)

    def forward(self, x):
        normed = self.norm(x)
        q, k, v = self.q(normed), self.k(normed), self.v(normed)
        b, c, h, w = q.shape

        def to_tokens(feature_map):
            # flatten the spatial grid into a sequence, with a singleton head axis
            return einops.rearrange(feature_map, "b c h w -> b 1 (h w) c").contiguous()

        # scale is dim ** -0.5 per default
        attended = torch.nn.functional.scaled_dot_product_attention(to_tokens(q), to_tokens(k), to_tokens(v))
        attended = einops.rearrange(attended, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
        return x + self.proj_out(attended)


class Downsample(torch.nn.Module):
    """Stride-2 3x3 convolution applied after zero-padding the right and bottom edges by one."""

    def __init__(self, in_channels, dtype=torch.float32, device=None):
        super().__init__()
        conv_settings = dict(kernel_size=3, stride=2, padding=0, dtype=dtype, device=device)
        self.conv = torch.nn.Conv2d(in_channels, in_channels, **conv_settings)

    def forward(self, x):
        # (left, right, top, bottom): pad only on the right and at the bottom
        padded = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)


class Upsample(torch.nn.Module):
    """Nearest-neighbour 2x upsampling followed by a 3x3 convolution."""

    def __init__(self, in_channels, dtype=torch.float32, device=None):
        super().__init__()
        conv_settings = dict(kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
        self.conv = torch.nn.Conv2d(in_channels, in_channels, **conv_settings)

    def forward(self, x):
        input_dtype = x.dtype
        # bfloat16 inputs are interpolated in float32 and cast back before the convolution
        work = x.to(torch.float32) if input_dtype == torch.bfloat16 else x
        work = torch.nn.functional.interpolate(work, scale_factor=2.0, mode="nearest")
        if work.dtype != input_dtype:
            work = work.to(input_dtype)
        return self.conv(work)


class VAEEncoder(torch.nn.Module):
    """
    Convolutional VAE encoder.

    ``conv_in`` is followed by ``len(ch_mult)`` resolution levels of ``num_res_blocks``
    ``ResnetBlock``s each (every level but the last ends with a ``Downsample``), a
    ResnetBlock / AttnBlock / ResnetBlock middle section, and a GroupNorm -> SiLU ->
    conv head producing ``2 * z_channels`` output channels.
    """

    def __init__(
        self, ch=128, ch_mult=(1, 2, 4, 4), num_res_blocks=2, in_channels=3, z_channels=16, dtype=torch.float32, device=None
    ):
        super().__init__()
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = torch.nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = torch.nn.ModuleList()
            # never populated here; kept so the module layout stays unchanged
            attn = torch.nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
                block_in = block_out
            down = torch.nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, dtype=dtype, device=device)
            self.down.append(down)
        # middle
        self.mid = torch.nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
        self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
        # end
        self.norm_out = Normalize(block_in, dtype=dtype, device=device)
        self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
        self.swish = torch.nn.SiLU(inplace=True)

    def forward(self, x):
        """Encode ``x`` and return the raw ``conv_out`` output with ``2 * z_channels`` channels."""
        # downsampling
        # Only the most recent feature map is ever consumed, so a single running tensor is
        # kept instead of a list of every intermediate activation (lower peak memory).
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            level = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                h = level.block[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = level.downsample(h)
        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        # end
        h = self.norm_out(h)
        h = self.swish(h)
        h = self.conv_out(h)
        return h


class VAEDecoder(torch.nn.Module):
    """
    Convolutional VAE decoder.

    ``conv_in`` lifts the latent to the widest channel count, a ResnetBlock /
    AttnBlock / ResnetBlock middle section follows, then the resolution levels are
    walked from the highest index down to 0 (``num_res_blocks + 1`` ``ResnetBlock``s
    each, with an ``Upsample`` after every level except level 0), and a
    GroupNorm -> SiLU -> conv head produces ``out_ch`` channels.
    """

    def __init__(
        self,
        ch=128,
        out_ch=3,
        ch_mult=(1, 2, 4, 4),
        num_res_blocks=2,
        resolution=256,
        z_channels=16,
        dtype=torch.float32,
        device=None,
    ):
        super().__init__()
        factory = dict(dtype=dtype, device=device)
        conv3x3 = dict(kernel_size=3, stride=1, padding=1, **factory)

        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        top_level = self.num_resolutions - 1
        block_in = ch * ch_mult[top_level]
        # resolution bookkeeping; the value is not read anywhere in this class
        curr_res = resolution // 2**top_level

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels, block_in, **conv3x3)

        # middle
        self.mid = torch.nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, **factory)
        self.mid.attn_1 = AttnBlock(block_in, **factory)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, **factory)

        # upsampling: levels are built from the highest index down, then stored in ascending order
        built_levels = []
        for i_level in range(top_level, -1, -1):
            block_out = ch * ch_mult[i_level]
            resnets = torch.nn.ModuleList()
            for _ in range(self.num_res_blocks + 1):
                resnets.append(ResnetBlock(in_channels=block_in, out_channels=block_out, **factory))
                block_in = block_out
            up = torch.nn.Module()
            up.block = resnets
            if i_level != 0:
                up.upsample = Upsample(block_in, **factory)
                curr_res = curr_res * 2
            built_levels.append(up)
        self.up = torch.nn.ModuleList(built_levels[::-1])

        # end
        self.norm_out = Normalize(block_in, **factory)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, **conv3x3)
        self.swish = torch.nn.SiLU(inplace=True)

    def forward(self, z):
        """Decode latent ``z`` and return the ``conv_out`` output."""
        # z to block_in, then the middle section
        hidden = self.conv_in(z)
        hidden = self.mid.block_2(self.mid.attn_1(self.mid.block_1(hidden)))

        # upsampling, highest level index first
        for i_level in range(self.num_resolutions - 1, -1, -1):
            level = self.up[i_level]
            for i_block in range(self.num_res_blocks + 1):
                hidden = level.block[i_block](hidden)
            if i_level != 0:
                hidden = level.upsample(hidden)

        # end
        return self.conv_out(self.swish(self.norm_out(hidden)))


class SDVAE(torch.nn.Module):
    """VAE made of a ``VAEEncoder`` and a ``VAEDecoder``, plus the latent shift/scale helpers."""

    def __init__(self, dtype=torch.float32, device=None):
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        self.encoder = VAEEncoder(**factory)
        self.decoder = VAEDecoder(**factory)

    @property
    def device(self):
        """Device of the first parameter."""
        first_param = next(self.parameters())
        return first_param.device

    @property
    def dtype(self):
        """Dtype of the first parameter."""
        first_param = next(self.parameters())
        return first_param.dtype

    # @torch.autocast("cuda", dtype=torch.float16)
    def decode(self, latent):
        """Run the decoder on ``latent``."""
        return self.decoder(latent)

    # @torch.autocast("cuda", dtype=torch.float16)
    def encode(self, image):
        """Encode ``image`` and draw one sample from the predicted diagonal Gaussian."""
        moments = self.encoder(image)
        # first half of the channels is the mean, second half the log-variance
        mean, logvar = moments.chunk(2, dim=1)
        std = (0.5 * logvar.clamp(-30.0, 20.0)).exp()
        noise = torch.randn_like(mean)
        return mean + std * noise

    @staticmethod
    def process_in(latent):
        """Shift, then scale, a latent: ``(latent - VAE_SHIFT_FACTOR) * VAE_SCALE_FACTOR``."""
        shifted = latent - VAE_SHIFT_FACTOR
        return shifted * VAE_SCALE_FACTOR

    @staticmethod
    def process_out(latent):
        """Inverse of ``process_in``: ``latent / VAE_SCALE_FACTOR + VAE_SHIFT_FACTOR``."""
        unscaled = latent / VAE_SCALE_FACTOR
        return unscaled + VAE_SHIFT_FACTOR


# endregion