# Based on https://raw.githubusercontent.com/okotaku/diffusers/feature/reference_only_control/examples/community/stable_diffusion_reference.py
# Inspired by: https://github.com/Mikubill/sd-webui-controlnet/discussions/1236
import torch.fft as fft
from typing import Any, Callable, Dict, List, Optional, Union, Tuple

import numpy as np
import PIL.Image
import torch

from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.unet_2d_blocks import (
    CrossAttnDownBlock2D,
    CrossAttnUpBlock2D,
    DownBlock2D,
    UpBlock2D,
)
from diffusers.utils import PIL_INTERPOLATION, logging
import torch.nn.functional as F

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import UniPCMultistepScheduler
        >>> from diffusers.utils import load_image

        >>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")

        >>> pipe = StableDiffusionReferencePipeline.from_pretrained(
                "runwayml/stable-diffusion-v1-5",
                safety_checker=None,
                torch_dtype=torch.float16
                ).to('cuda:0')

        >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe_controlnet.scheduler.config)

        >>> result_img = pipe(ref_image=input_image,
                        prompt="1girl",
                        num_inference_steps=20,
                        reference_attn=True,
                        reference_adain=True).images[0]

        >>> result_img.show()
        ```
"""


def torch_dfs(model: torch.nn.Module):
    result = [model]
    for child in model.children():
        result += torch_dfs(child)
    return result


@torch.no_grad()
def add_freq_feature(feature1, feature2, ref_ratio):
    """
    feature1: reference feature
    feature2: target feature
    ref_ratio: larger ratio means larger reference frequency
    """
    # Convert features to float32 (if not already) for compatibility with fft operations
    data_type = feature2.dtype
    feature1 = feature1.to(torch.float32)
    feature2 = feature2.to(torch.float32)

    # Compute the Fourier transforms of both features
    spectrum1 = fft.fftn(feature1, dim=(-2, -1))
    spectrum2 = fft.fftn(feature2, dim=(-2, -1))

    # Extract high-frequency magnitude and phase from feature1
    magnitude1 = torch.abs(spectrum1)
    # phase1 = torch.angle(spectrum1)

    # Extract magnitude and phase from feature2
    magnitude2 = torch.abs(spectrum2)
    phase2 = torch.angle(spectrum2)

    magnitude2.mul_((1-ref_ratio)).add_(magnitude1 * ref_ratio)
    # phase2.mul_(1.0).add_(phase1 * 0.0)

    # Combine magnitude and phase information
    mixed_spectrum = torch.polar(magnitude2, phase2)

    # Compute the inverse Fourier transform to get the mixed feature
    mixed_feature = fft.ifftn(mixed_spectrum, dim=(-2, -1))

    del feature1, feature2, spectrum1, spectrum2, magnitude1, magnitude2, phase2, mixed_spectrum

    # Convert back to the original data type and return the result
    return mixed_feature.to(data_type)


@torch.no_grad()
def save_ref_feature(feature, mask):
    """
    feature: n,c,h,w
    mask: n,1,h,w

    return n,c,h,w
    """
    return feature * mask


@torch.no_grad()
def mix_ref_feature(feature, ref_fea_bank, cfg=True, ref_scale=0.0, dim3=False):
    """
    feature: n,l,c or n,c,h,w
    ref_fea_bank: [(n,c,h,w)]
    cfg: True/False

    return n,l,c or n,c,h,w
    """
    if cfg:
        ref_fea = torch.cat(
            (ref_fea_bank+ref_fea_bank), dim=0)
    else:
        ref_fea = ref_fea_bank

    if dim3:
        feature = feature.permute(0, 2, 1).view(ref_fea.shape)

    mixed_feature = add_freq_feature(ref_fea, feature, ref_scale)

    if dim3:
        mixed_feature = mixed_feature.view(
            ref_fea.shape[0], ref_fea.shape[1], -1).permute(0, 2, 1)

    del ref_fea
    del feature
    return mixed_feature


def mix_norm_feature(x, inpaint_mask, mean_bank, var_bank, do_classifier_free_guidance, style_fidelity, uc_mask, eps=1e-6):
    """
    x: input feature n,c,h,w
    inpaint_mask: mask region to inpain
    """

    # get the inpainting region and only mix this region.
    scale_ratio = inpaint_mask.shape[2] / x.shape[2]
    this_inpaint_mask = F.interpolate(
        inpaint_mask.to(x.device), scale_factor=1 / scale_ratio
    )
    this_inpaint_mask = this_inpaint_mask.repeat(
        x.shape[0], x.shape[1], 1, 1
    ).bool()
    masked_x = (
        x[this_inpaint_mask]
        .detach()
        .clone()
        .view(x.shape[0], x.shape[1], -1, 1)
    )
    var, mean = torch.var_mean(
        masked_x, dim=(2, 3), keepdim=True, correction=0
    )
    std = torch.maximum(
        var, torch.zeros_like(var) + eps) ** 0.5
    mean_acc = sum(mean_bank) / float(len(mean_bank))
    var_acc = sum(var_bank) / float(len(var_bank))
    std_acc = (
        torch.maximum(var_acc, torch.zeros_like(
            var_acc) + eps) ** 0.5
    )

    x_uc = (((masked_x - mean) / std) * std_acc) + mean_acc
    x_c = x_uc.clone()
    if do_classifier_free_guidance and style_fidelity > 0:
        x_c[uc_mask] = masked_x[uc_mask]
    masked_x = style_fidelity * x_c + \
        (1.0 - style_fidelity) * x_uc
    x[this_inpaint_mask] = masked_x.view(-1)
    return x


class StableDiffusionReferencePipeline:
    def prepare_ref_image(
        self,
        image,
        width,
        height,
        batch_size,
        num_images_per_prompt,
        device,
        dtype,
        do_classifier_free_guidance=False,
        guess_mode=False,
    ):
        if not isinstance(image, torch.Tensor):
            if isinstance(image, PIL.Image.Image):
                image = [image]

            if isinstance(image[0], PIL.Image.Image):
                images = []

                for image_ in image:
                    image_ = image_.convert("RGB")
                    image_ = image_.resize(
                        (width, height), resample=PIL_INTERPOLATION["lanczos"]
                    )
                    image_ = np.array(image_)
                    image_ = image_[None, :]
                    images.append(image_)

                image = images

                image = np.concatenate(image, axis=0)
                image = np.array(image).astype(np.float32) / 255.0
                image = (image - 0.5) / 0.5
                image = image.transpose(0, 3, 1, 2)
                image = torch.from_numpy(image)
            elif isinstance(image[0], torch.Tensor):
                image = torch.cat(image, dim=0)

        image_batch_size = image.shape[0]

        if image_batch_size == 1:
            repeat_by = batch_size
        else:
            # image batch size is the same as prompt batch size
            repeat_by = num_images_per_prompt

        image = image.repeat_interleave(repeat_by, dim=0)

        image = image.to(device=device, dtype=dtype)

        if do_classifier_free_guidance and not guess_mode:
            image = torch.cat([image] * 2)

        return image

    def prepare_ref_latents(
        self,
        refimage,
        batch_size,
        dtype,
        device,
        generator,
        do_classifier_free_guidance,
    ):
        refimage = refimage.to(device=device, dtype=dtype)

        # encode the mask image into latents space so we can concatenate it to the latents
        if isinstance(generator, list):
            ref_image_latents = [
                self.vae.encode(refimage[i: i + 1]).latent_dist.sample(
                    generator=generator[i]
                )
                for i in range(batch_size)
            ]
            ref_image_latents = torch.cat(ref_image_latents, dim=0)
        else:
            ref_image_latents = self.vae.encode(refimage).latent_dist.sample(
                generator=generator
            )
        ref_image_latents = self.vae.config.scaling_factor * ref_image_latents

        # duplicate mask and ref_image_latents for each generation per prompt, using mps friendly method
        if ref_image_latents.shape[0] < batch_size:
            if not batch_size % ref_image_latents.shape[0] == 0:
                raise ValueError(
                    "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                    f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed."
                    " Make sure the number of images that you pass is divisible by the total requested batch size."
                )
            ref_image_latents = ref_image_latents.repeat(
                batch_size // ref_image_latents.shape[0], 1, 1, 1
            )

        ref_image_latents = (
            torch.cat([ref_image_latents] * 2)
            if do_classifier_free_guidance
            else ref_image_latents
        )

        # aligning device to prevent device errors when concating it with the latent model input
        ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
        return ref_image_latents

    def check_ref_input(self, reference_attn, reference_adain):
        assert (
            reference_attn or reference_adain
        ), "`reference_attn` or `reference_adain` must be True."

    def redefine_ref_model(
        self, model, reference_attn, reference_adain, model_type="unet"
    ):
        def hacked_basic_transformer_inner_forward(
            self,
            hidden_states: torch.FloatTensor,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
            timestep: Optional[torch.LongTensor] = None,
            cross_attention_kwargs: Dict[str, Any] = None,
            class_labels: Optional[torch.LongTensor] = None,
        ):
            if self.use_ada_layer_norm:
                norm_hidden_states = self.norm1(hidden_states, timestep)
            elif self.use_ada_layer_norm_zero:
                (
                    norm_hidden_states,
                    gate_msa,
                    shift_mlp,
                    scale_mlp,
                    gate_mlp,
                ) = self.norm1(
                    hidden_states,
                    timestep,
                    class_labels,
                    hidden_dtype=hidden_states.dtype,
                )
            else:
                norm_hidden_states = self.norm1(hidden_states)

            # 1. Self-Attention
            cross_attention_kwargs = (
                cross_attention_kwargs if cross_attention_kwargs is not None else {}
            )
            if self.only_cross_attention:
                attn_output = self.attn1(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states
                    if self.only_cross_attention
                    else None,
                    attention_mask=attention_mask,
                    **cross_attention_kwargs,
                )
            else:
                if self.MODE == "write":
                    if self.attention_auto_machine_weight > self.attn_weight:
                        # print("hacked_basic_transformer_inner_forward")
                        scale_ratio = (
                            (self.ref_mask.shape[2] * self.ref_mask.shape[3])
                            / norm_hidden_states.shape[1]
                        ) ** 0.5
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(norm_hidden_states.device),
                            scale_factor=1 / scale_ratio,
                        )
                        resize_norm_hidden_states = norm_hidden_states.view(
                            norm_hidden_states.shape[0],
                            this_ref_mask.shape[2],
                            this_ref_mask.shape[3],
                            -1,
                        ).permute(0, 3, 1, 2)

                        ref_scale = 1.0
                        resize_norm_hidden_states = F.interpolate(
                            resize_norm_hidden_states,
                            scale_factor=ref_scale,
                            mode="bilinear",
                        )
                        this_ref_mask = F.interpolate(
                            this_ref_mask, scale_factor=ref_scale
                        )
                        self.fea_bank.append(save_ref_feature(
                            resize_norm_hidden_states, this_ref_mask))
                        this_ref_mask = this_ref_mask.repeat(
                            resize_norm_hidden_states.shape[0],
                            resize_norm_hidden_states.shape[1],
                            1,
                            1,
                        ).bool()
                        masked_norm_hidden_states = (
                            resize_norm_hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(
                                resize_norm_hidden_states.shape[0],
                                resize_norm_hidden_states.shape[1],
                                -1,
                            )
                        )

                        masked_norm_hidden_states = masked_norm_hidden_states.permute(
                            0, 2, 1
                        )
                        self.bank.append(masked_norm_hidden_states)
                        del masked_norm_hidden_states
                        del this_ref_mask
                        del resize_norm_hidden_states
                    attn_output = self.attn1(
                        norm_hidden_states,
                        encoder_hidden_states=encoder_hidden_states
                        if self.only_cross_attention
                        else None,
                        attention_mask=attention_mask,
                        **cross_attention_kwargs,
                    )
                if self.MODE == "read":
                    if self.attention_auto_machine_weight > self.attn_weight:
                        freq_norm_hidden_states = mix_ref_feature(
                            norm_hidden_states,
                            self.fea_bank,
                            cfg=self.do_classifier_free_guidance,
                            ref_scale=self.ref_scale,
                            dim3=True)
                        self.fea_bank.clear()

                        this_bank = torch.cat(self.bank+self.bank, dim=0)
                        ref_hidden_states = torch.cat(
                            (freq_norm_hidden_states, this_bank), dim=1
                        )
                        del this_bank
                        self.bank.clear()

                        attn_output_uc = self.attn1(
                            freq_norm_hidden_states,
                            encoder_hidden_states=ref_hidden_states,
                            **cross_attention_kwargs,
                        )
                        del ref_hidden_states
                        attn_output_c = attn_output_uc.clone()
                        if self.do_classifier_free_guidance and self.style_fidelity > 0:
                            attn_output_c[self.uc_mask] = self.attn1(
                                norm_hidden_states[self.uc_mask],
                                encoder_hidden_states=norm_hidden_states[self.uc_mask],
                                **cross_attention_kwargs,
                            )
                        attn_output = (
                            self.style_fidelity * attn_output_c
                            + (1.0 - self.style_fidelity) * attn_output_uc
                        )
                        self.bank.clear()
                        self.fea_bank.clear()
                        del attn_output_c
                        del attn_output_uc
                    else:
                        attn_output = self.attn1(
                            norm_hidden_states,
                            encoder_hidden_states=encoder_hidden_states
                            if self.only_cross_attention
                            else None,
                            attention_mask=attention_mask,
                            **cross_attention_kwargs,
                        )
                    self.bank.clear()
                    self.fea_bank.clear()

            if self.use_ada_layer_norm_zero:
                attn_output = gate_msa.unsqueeze(1) * attn_output
            hidden_states = attn_output + hidden_states

            if self.attn2 is not None:
                norm_hidden_states = (
                    self.norm2(hidden_states, timestep)
                    if self.use_ada_layer_norm
                    else self.norm2(hidden_states)
                )

                # 2. Cross-Attention
                attn_output = self.attn2(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=encoder_attention_mask,
                    **cross_attention_kwargs,
                )
                hidden_states = attn_output + hidden_states

            # 3. Feed-forward
            norm_hidden_states = self.norm3(hidden_states)

            if self.use_ada_layer_norm_zero:
                norm_hidden_states = (
                    norm_hidden_states *
                    (1 + scale_mlp[:, None]) + shift_mlp[:, None]
                )

            ff_output = self.ff(norm_hidden_states)

            if self.use_ada_layer_norm_zero:
                ff_output = gate_mlp.unsqueeze(1) * ff_output

            hidden_states = ff_output + hidden_states

            return hidden_states

        def hacked_mid_forward(self, *args, **kwargs):
            eps = 1e-6
            x = self.original_forward(*args, **kwargs)
            if self.MODE == "write":
                if self.gn_auto_machine_weight >= self.gn_weight:
                    # mask var mean
                    scale_ratio = self.ref_mask.shape[2] / x.shape[2]
                    this_ref_mask = F.interpolate(
                        self.ref_mask.to(x.device), scale_factor=1 / scale_ratio
                    )

                    self.fea_bank.append(save_ref_feature(
                        x, this_ref_mask))

                    this_ref_mask = this_ref_mask.repeat(
                        x.shape[0], x.shape[1], 1, 1
                    ).bool()
                    masked_x = (
                        x[this_ref_mask]
                        .detach()
                        .clone()
                        .view(x.shape[0], x.shape[1], -1, 1)
                    )
                    var, mean = torch.var_mean(
                        masked_x, dim=(2, 3), keepdim=True, correction=0
                    )

                    self.mean_bank.append(torch.cat([mean]*2, dim=0))
                    self.var_bank.append(torch.cat([var]*2, dim=0))
            if self.MODE == "read":
                if (
                    self.gn_auto_machine_weight >= self.gn_weight
                    and len(self.mean_bank) > 0
                    and len(self.var_bank) > 0
                ):
                    # print("hacked_mid_forward")
                    x = mix_ref_feature(
                        x, self.fea_bank, cfg=self.do_classifier_free_guidance, ref_scale=self.ref_scale)
                    self.fea_bank = []
                    x = mix_norm_feature(x, self.inpaint_mask, self.mean_bank, self.var_bank,
                                         self.do_classifier_free_guidance,
                                         self.style_fidelity, self.uc_mask)
                self.mean_bank = []
                self.var_bank = []
            return x

        def hack_CrossAttnDownBlock2D_forward(
            self,
            hidden_states: torch.FloatTensor,
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
        ):
            eps = 1e-6
            # TODO(Patrick, William) - attention mask is not used
            output_states = ()

            for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
                hidden_states = resnet(hidden_states, temb)

                if self.MODE == "write":
                    if self.gn_auto_machine_weight >= self.gn_weight:
                        # var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
                        # mask var mean
                        scale_ratio = self.ref_mask.shape[2] / \
                            hidden_states.shape[2]
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(hidden_states.device),
                            scale_factor=1 / scale_ratio,
                        )
                        self.fea_bank0.append(save_ref_feature(
                            hidden_states, this_ref_mask))
                        this_ref_mask = this_ref_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(
                            masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
                        )
                        self.mean_bank0.append(torch.cat([mean]*2, dim=0))
                        self.var_bank0.append(torch.cat([var]*2, dim=0))
                if self.MODE == "read":
                    if (
                        self.gn_auto_machine_weight >= self.gn_weight
                        and len(self.mean_bank0) > 0
                        and len(self.var_bank0) > 0
                    ):
                        # print("hacked_CrossAttnDownBlock2D_forward0")
                        hidden_states = mix_ref_feature(
                            hidden_states, [self.fea_bank0[i]], cfg=self.do_classifier_free_guidance, ref_scale=self.ref_scale)

                        hidden_states = mix_norm_feature(hidden_states, self.inpaint_mask, self.mean_bank0[i], self.var_bank0[i],
                                                         self.do_classifier_free_guidance,
                                                         self.style_fidelity, self.uc_mask)

                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]
                if self.MODE == "write":
                    if self.gn_auto_machine_weight >= self.gn_weight:
                        # mask var mean
                        scale_ratio = self.ref_mask.shape[2] / \
                            hidden_states.shape[2]
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(hidden_states.device),
                            scale_factor=1 / scale_ratio,
                        )
                        self.fea_bank.append(save_ref_feature(
                            hidden_states, this_ref_mask))
                        this_ref_mask = this_ref_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(
                            masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
                        )
                        self.mean_bank.append(torch.cat([mean]*2, dim=0))
                        self.var_bank.append(torch.cat([var]*2, dim=0))
                if self.MODE == "read":
                    if (
                        self.gn_auto_machine_weight >= self.gn_weight
                        and len(self.mean_bank) > 0
                        and len(self.var_bank) > 0
                    ):
                        # print("hack_CrossAttnDownBlock2D_forward")
                        hidden_states = mix_ref_feature(
                            hidden_states, [self.fea_bank[i]], cfg=self.do_classifier_free_guidance, ref_scale=self.ref_scale)

                        hidden_states = mix_norm_feature(hidden_states, self.inpaint_mask, self.mean_bank[i], self.var_bank[i],
                                                         self.do_classifier_free_guidance,
                                                         self.style_fidelity, self.uc_mask)
                output_states = output_states + (hidden_states,)

            if self.MODE == "read":
                self.mean_bank0 = []
                self.var_bank0 = []
                self.mean_bank = []
                self.var_bank = []
                self.fea_bank0 = []
                self.fea_bank = []

            if self.downsamplers is not None:
                for downsampler in self.downsamplers:
                    hidden_states = downsampler(hidden_states)

                output_states = output_states + (hidden_states,)

            return hidden_states, output_states

        def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
            eps = 1e-6

            output_states = ()

            for i, resnet in enumerate(self.resnets):
                hidden_states = resnet(hidden_states, temb)

                if self.MODE == "write":
                    if self.gn_auto_machine_weight >= self.gn_weight:
                        # var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
                        # mask var mean
                        scale_ratio = self.ref_mask.shape[2] / \
                            hidden_states.shape[2]
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(hidden_states.device),
                            scale_factor=1 / scale_ratio,
                        )
                        self.fea_bank.append(save_ref_feature(
                            hidden_states, this_ref_mask))
                        this_ref_mask = this_ref_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(
                            masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
                        )
                        self.mean_bank.append(torch.cat([mean]*2, dim=0))
                        self.var_bank.append(torch.cat([var]*2, dim=0))
                if self.MODE == "read":
                    if (
                        self.gn_auto_machine_weight >= self.gn_weight
                        and len(self.mean_bank) > 0
                        and len(self.var_bank) > 0
                    ):
                        # print("hacked_DownBlock2D_forward")
                        hidden_states = mix_ref_feature(
                            hidden_states, [self.fea_bank[i]], cfg=self.do_classifier_free_guidance, ref_scale=self.ref_scale)

                        hidden_states = mix_norm_feature(hidden_states, self.inpaint_mask, self.mean_bank[i], self.var_bank[i],
                                                         self.do_classifier_free_guidance,
                                                         self.style_fidelity, self.uc_mask)

                output_states = output_states + (hidden_states,)

            if self.MODE == "read":
                self.mean_bank = []
                self.var_bank = []
                self.fea_bank = []

            if self.downsamplers is not None:
                for downsampler in self.downsamplers:
                    hidden_states = downsampler(hidden_states)

                output_states = output_states + (hidden_states,)

            return hidden_states, output_states

        def hacked_CrossAttnUpBlock2D_forward(
            self,
            hidden_states: torch.FloatTensor,
            res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            upsample_size: Optional[int] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
        ):
            eps = 1e-6
            # TODO(Patrick, William) - attention mask is not used
            for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
                # pop res hidden states
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                hidden_states = torch.cat(
                    [hidden_states, res_hidden_states], dim=1)
                hidden_states = resnet(hidden_states, temb)

                if self.MODE == "write":
                    if self.gn_auto_machine_weight >= self.gn_weight:
                        # var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
                        # mask var mean
                        scale_ratio = self.ref_mask.shape[2] / \
                            hidden_states.shape[2]
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(hidden_states.device),
                            scale_factor=1 / scale_ratio,
                        )
                        self.fea_bank0.append(save_ref_feature(
                            hidden_states, this_ref_mask))
                        this_ref_mask = this_ref_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(
                            masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
                        )
                        self.mean_bank0.append(torch.cat([mean]*2, dim=0))
                        self.var_bank0.append(torch.cat([var]*2, dim=0))
                if self.MODE == "read":
                    if (
                        self.gn_auto_machine_weight >= self.gn_weight
                        and len(self.mean_bank0) > 0
                        and len(self.var_bank0) > 0
                    ):
                        # print("hacked_CrossAttnUpBlock2D_forward1")
                        hidden_states = mix_ref_feature(
                            hidden_states, [self.fea_bank0[i]], cfg=self.do_classifier_free_guidance, ref_scale=self.ref_scale)

                        hidden_states = mix_norm_feature(hidden_states, self.inpaint_mask, self.mean_bank0[i], self.var_bank0[i],
                                                         self.do_classifier_free_guidance,
                                                         self.style_fidelity, self.uc_mask)

                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    # attention_mask=attention_mask,
                    # encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]

                if self.MODE == "write":
                    if self.gn_auto_machine_weight >= self.gn_weight:
                        # var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
                        # mask var mean
                        scale_ratio = self.ref_mask.shape[2] / \
                            hidden_states.shape[2]
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(hidden_states.device),
                            scale_factor=1 / scale_ratio,
                        )
                        self.fea_bank.append(save_ref_feature(
                            hidden_states, this_ref_mask))
                        this_ref_mask = this_ref_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(
                            masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
                        )
                        self.mean_bank.append(torch.cat([mean]*2, dim=0))
                        self.var_bank.append(torch.cat([var]*2, dim=0))
                if self.MODE == "read":
                    if (
                        self.gn_auto_machine_weight >= self.gn_weight
                        and len(self.mean_bank) > 0
                        and len(self.var_bank) > 0
                    ):
                        # print("hacked_CrossAttnUpBlock2D_forward")
                        hidden_states = mix_ref_feature(
                            hidden_states, [self.fea_bank[i]], cfg=self.do_classifier_free_guidance, ref_scale=self.ref_scale)

                        hidden_states = mix_norm_feature(hidden_states, self.inpaint_mask, self.mean_bank[i], self.var_bank[i],
                                                         self.do_classifier_free_guidance,
                                                         self.style_fidelity, self.uc_mask)

            if self.MODE == "read":
                self.mean_bank0 = []
                self.var_bank0 = []
                self.mean_bank = []
                self.var_bank = []
                self.fea_bank = []
                self.fea_bank0 = []

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            return hidden_states

        def hacked_UpBlock2D_forward(
            self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None
        ):
            eps = 1e-6
            for i, resnet in enumerate(self.resnets):
                # pop res hidden states
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                hidden_states = torch.cat(
                    [hidden_states, res_hidden_states], dim=1)
                hidden_states = resnet(hidden_states, temb)

                if self.MODE == "write":
                    if self.gn_auto_machine_weight >= self.gn_weight:
                        # var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)
                        # mask var mean
                        scale_ratio = self.ref_mask.shape[2] / \
                            hidden_states.shape[2]
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(hidden_states.device),
                            scale_factor=1 / scale_ratio,
                        )
                        self.fea_bank.append(save_ref_feature(
                            hidden_states, this_ref_mask))
                        this_ref_mask = this_ref_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(
                            masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
                        )
                        self.mean_bank.append(torch.cat([mean]*2, dim=0))
                        self.var_bank.append(torch.cat([var]*2, dim=0))
                if self.MODE == "read":
                    if (
                        self.gn_auto_machine_weight >= self.gn_weight
                        and len(self.mean_bank) > 0
                        and len(self.var_bank) > 0
                    ):
                        # print("hacked_UpBlock2D_forward")
                        hidden_states = mix_ref_feature(
                            hidden_states, [self.fea_bank[i]], cfg=self.do_classifier_free_guidance, ref_scale=self.ref_scale)

                        hidden_states = mix_norm_feature(hidden_states, self.inpaint_mask, self.mean_bank[i], self.var_bank[i],
                                                         self.do_classifier_free_guidance,
                                                         self.style_fidelity, self.uc_mask)

            if self.MODE == "read":
                self.mean_bank = []
                self.var_bank = []
                self.fea_bank = []

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            return hidden_states

        if model_type == "unet":

            if reference_attn:
                attn_modules = [
                    module
                    for module in torch_dfs(model)
                    if isinstance(module, BasicTransformerBlock)
                ]
                attn_modules = sorted(
                    attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
                )

                for i, module in enumerate(attn_modules):
                    module._original_inner_forward = module.forward
                    module.forward = hacked_basic_transformer_inner_forward.__get__(
                        module, BasicTransformerBlock
                    )
                    module.bank = []
                    module.fea_bank = []
                    module.attn_weight = float(i) / float(len(attn_modules))
                    module.attention_auto_machine_weight = (
                        self.attention_auto_machine_weight
                    )
                    module.gn_auto_machine_weight = self.gn_auto_machine_weight
                    module.do_classifier_free_guidance = (
                        self.do_classifier_free_guidance
                    )
                    module.do_classifier_free_guidance = (
                        self.do_classifier_free_guidance
                    )
                    module.uc_mask = self.uc_mask
                    module.style_fidelity = self.style_fidelity
                    module.ref_mask = self.ref_mask
                    module.ref_scale = self.ref_scale
            else:
                attn_modules = None
            if reference_adain:
                gn_modules = [model.mid_block]
                model.mid_block.gn_weight = 0

                down_blocks = model.down_blocks
                for w, module in enumerate(down_blocks):
                    module.gn_weight = 1.0 - float(w) / float(len(down_blocks))
                    gn_modules.append(module)
                    # print(module.__class__.__name__,module.gn_weight)

                up_blocks = model.up_blocks
                for w, module in enumerate(up_blocks):
                    module.gn_weight = float(w) / float(len(up_blocks))
                    gn_modules.append(module)
                    # print(module.__class__.__name__,module.gn_weight)

                for i, module in enumerate(gn_modules):
                    if getattr(module, "original_forward", None) is None:
                        module.original_forward = module.forward
                    if i == 0:
                        # mid_block
                        module.forward = hacked_mid_forward.__get__(
                            module, torch.nn.Module
                        )
                    # elif isinstance(module, CrossAttnDownBlock2D):
                    #     module.forward = hack_CrossAttnDownBlock2D_forward.__get__(
                    #         module, CrossAttnDownBlock2D
                    #     )
                    #     module.mean_bank0 = []
                    #     module.var_bank0 = []
                    #     module.fea_bank0 = []
                        
                    elif isinstance(module, DownBlock2D):
                        module.forward = hacked_DownBlock2D_forward.__get__(
                            module, DownBlock2D
                        )
                    # elif isinstance(module, CrossAttnUpBlock2D):
                    #     module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D)
                    #     module.mean_bank0 = []
                    #     module.var_bank0 = []
                    #     module.fea_bank0 = []
                    elif isinstance(module, UpBlock2D):
                        module.forward = hacked_UpBlock2D_forward.__get__(
                            module, UpBlock2D
                        )
                        module.mean_bank0 = []
                        module.var_bank0 = []
                        module.fea_bank0 = []
                    module.mean_bank = []
                    module.var_bank = []
                    module.fea_bank = []
                    module.attention_auto_machine_weight = (
                        self.attention_auto_machine_weight
                    )
                    module.gn_auto_machine_weight = self.gn_auto_machine_weight
                    module.do_classifier_free_guidance = (
                        self.do_classifier_free_guidance
                    )
                    module.do_classifier_free_guidance = (
                        self.do_classifier_free_guidance
                    )
                    module.uc_mask = self.uc_mask
                    module.style_fidelity = self.style_fidelity
                    module.ref_mask = self.ref_mask
                    module.inpaint_mask = self.inpaint_mask
                    module.ref_scale = self.ref_scale
            else:
                gn_modules = None
        elif model_type == "controlnet":
            model = model.nets[-1]  # only hack the inpainting controlnet
            if reference_attn:
                attn_modules = [
                    module
                    for module in torch_dfs(model)
                    if isinstance(module, BasicTransformerBlock)
                ]
                attn_modules = sorted(
                    attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
                )
                for i, module in enumerate(attn_modules):
                    module._original_inner_forward = module.forward
                    module.forward = hacked_basic_transformer_inner_forward.__get__(
                        module, BasicTransformerBlock
                    )
                    module.bank = []
                    module.fea_bank = []
                    # float(i) / float(len(attn_modules))
                    module.attn_weight = 0.0
                    module.attention_auto_machine_weight = (
                        self.attention_auto_machine_weight
                    )
                    module.gn_auto_machine_weight = self.gn_auto_machine_weight
                    module.do_classifier_free_guidance = (
                        self.do_classifier_free_guidance
                    )
                    module.do_classifier_free_guidance = (
                        self.do_classifier_free_guidance
                    )
                    module.uc_mask = self.uc_mask
                    module.style_fidelity = self.style_fidelity
                    module.ref_mask = self.ref_mask
                    module.ref_scale = self.ref_scale
            else:
                attn_modules = None
            # gn_modules = None
            if reference_adain:
                gn_modules = [model.mid_block]
                model.mid_block.gn_weight = 0

                down_blocks = model.down_blocks
                for w, module in enumerate(down_blocks):
                    module.gn_weight = 1.0 - float(w) / float(len(down_blocks))
                    gn_modules.append(module)
                    # print(module.__class__.__name__,module.gn_weight)


                for i, module in enumerate(gn_modules):
                    if getattr(module, "original_forward", None) is None:
                        module.original_forward = module.forward
                    if i == 0:
                        # mid_block
                        module.forward = hacked_mid_forward.__get__(
                            module, torch.nn.Module
                        )
                    # elif isinstance(module, CrossAttnDownBlock2D):
                    #     module.forward = hack_CrossAttnDownBlock2D_forward.__get__(
                    #         module, CrossAttnDownBlock2D
                    #     )
                    #     module.mean_bank0 = []
                    #     module.var_bank0 = []
                    #     module.fea_bank0 = []
                        
                    elif isinstance(module, DownBlock2D):
                        module.forward = hacked_DownBlock2D_forward.__get__(
                            module, DownBlock2D
                        )
                    module.mean_bank = []
                    module.var_bank = []
                    module.fea_bank = []
                    module.attention_auto_machine_weight = (
                        self.attention_auto_machine_weight
                    )
                    module.gn_auto_machine_weight = self.gn_auto_machine_weight
                    module.do_classifier_free_guidance = (
                        self.do_classifier_free_guidance
                    )
                    module.do_classifier_free_guidance = (
                        self.do_classifier_free_guidance
                    )
                    module.uc_mask = self.uc_mask
                    module.style_fidelity = self.style_fidelity
                    module.ref_mask = self.ref_mask
                    module.inpaint_mask = self.inpaint_mask
                    module.ref_scale = self.ref_scale
            else:
                gn_modules = None

        return attn_modules, gn_modules

    def change_module_mode(self, mode, attn_modules, gn_modules):
        if attn_modules is not None:
            for i, module in enumerate(attn_modules):
                module.MODE = mode

        if gn_modules is not None:
            for i, module in enumerate(gn_modules):
                module.MODE = mode