import itertools
from functools import partial
from typing import Any, Dict, Tuple, Callable
from typing import Union, Optional, List

import numpy as np
import torch
from diffusers import DPMSolverMultistepScheduler
from diffusers import StableDiffusionPipeline, AutoencoderKL
from diffusers import Transformer2DModel, ModelMixin, ConfigMixin, SchedulerMixin
from diffusers import UNet2DConditionModel
from diffusers.configuration_utils import register_to_config
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.resnet import ResnetBlock2D, Downsample2D, Upsample2D
from diffusers.models.transformer_2d import Transformer2DModelOutput
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import replace_example_docstring
from torch import nn
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`, following Section 3.4 of
    [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).

    The guided prediction is rescaled so that its per-sample standard deviation (taken over every dimension
    except dim 0) matches that of `noise_pred_text`. The result is then blended with the unrescaled prediction.

    Args:
        noise_cfg: guided noise prediction; dim 0 is treated as the batch dimension.
        noise_pred_text: text-conditioned noise prediction whose per-sample std is the target.
        guidance_rescale: blend factor. 0.0 keeps `noise_cfg`; 1.0 returns the fully rescaled prediction.

    Returns:
        `guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg`
    """

    def _per_sample_std(tensor):
        # std over all non-batch dims, kept broadcastable against `tensor`
        return tensor.std(dim=list(range(1, tensor.ndim)), keepdim=True)

    # Rescaling fixes overexposure caused by strong guidance.
    scale = _per_sample_std(noise_pred_text) / _per_sample_std(noise_cfg)
    rescaled = noise_cfg * scale
    # Mixing back some of the original guidance avoids "plain looking" images.
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg


def custom_sort_order(obj):
    """
    Sort key that orders modules for execution in the Flexible* blocks' forward methods.

    Returns 0 for `ResnetBlock2D`, 1 for `Transformer2DModel` / `FlexibleTransformer2DModel`, and None for any
    other class. The match is on the exact class, not on subclasses.
    """
    module_class = obj.__class__
    if module_class is ResnetBlock2D:
        return 0
    if module_class in (Transformer2DModel, FlexibleTransformer2DModel):
        return 1
    return None


def squeeze_to_len_n_starting_from_index_i(n, i, timestep_spacing):
    """
    Squeeze a timestep schedule to `n` entries, keeping its first `i` entries unchanged.

    The first `i` entries are copied from `timestep_spacing`. The remaining `n - i` entries are
    `(n - i) * k, ..., 2k, k`, where `k = timestep_spacing[i - 1] // (n - i + 1)`. The kept prefix therefore
    flows into an evenly spaced tail that ends at `k`.

    :param n: the size of the squeezed array
    :param i: the index we start squeezing from; must satisfy ``1 <= i < n``
    :param timestep_spacing: the timestep_spacing array we want to squeeze. It needs at least `i` entries and
        may be anything ``np.asarray`` accepts (e.g. a numpy array or a CPU torch tensor).
    :return: squeezed timestep_spacing, an integer numpy array of length `n`
    :raises ValueError: if `i` is outside ``[1, n)`` or `timestep_spacing` has fewer than `i` entries
    Example:
    timesteps = np.array([967, 907, 846, 786, 725, 665, 604, 544, 484, 423, 363, 302, 242, 181, 121, 60]) (len=16)
    n = 10, i = 6
    Expected:
    [967, 907, 846, 786, 725, 665, 4k, 3k, 2k, k], and if we define 665=5k => k = 133
    """
    # i == 0 would read squeezed[-1] below and silently produce an all-zero tail.
    if not 1 <= i < n:
        raise ValueError(f"Expected 1 <= i < n, got n={n}, i={i}")
    timestep_spacing = np.asarray(timestep_spacing)
    if len(timestep_spacing) < i:
        raise ValueError(f"timestep_spacing has {len(timestep_spacing)} entries, need at least i={i}")

    squeezed = np.arange(n, 0, -1)  # [n, n-1, ..., 2, 1]
    squeezed[:i] = timestep_spacing[:i]
    k = squeezed[i - 1] // (n - i + 1)
    squeezed[i:] *= k

    return squeezed


# Squeezer presets keyed by "<n>,<i>"; each value is `squeeze_to_len_n_starting_from_index_i` with n and i bound,
# so calling it with a timestep schedule returns a squeezed schedule of length n.
# Tested with DPM 16-steps (reduced 16 -> 10 or 11 steps)
PREDEFINED_TIMESTEP_SQUEEZERS = {
    key: partial(squeeze_to_len_n_starting_from_index_i, n, i)
    for key, n, i in (
        ("10,6", 10, 6),
        ("11,7", 11, 7),
    )
}

# Architecture description read by `FlexibleUNet2DConditionModel` (via its `configurations` class attribute).
FlexibleUnetConfigurations = dict(
    # General parameters for all blocks
    sample_size=64,
    temb_dim=320 * 4,
    resnet_eps=1e-5,
    resnet_act_fn="silu",
    num_attention_heads=8,
    cross_attention_dim=768,
    # Controls modules execute order in unet's forward (False: all resnets before all attentions in each block)
    mix_block_in_forward=True,
    # Down blocks parameters: one entry per down block
    down_blocks_in_channels=[320, 320, 640],
    down_blocks_out_channels=[320, 640, 1280],
    down_blocks_num_attentions=[0, 1, 3],
    down_blocks_num_resnets=[2, 2, 1],
    add_downsample=[True, True, False],
    # Middle block parameters: with zero resnets and zero attentions the mid block is a pass-through
    add_upsample_mid_block=None,
    mid_num_resnets=0,
    mid_num_attentions=0,
    # Up block parameters: one entry per up block, paired with the down-block channel lists in reverse
    prev_output_channels=[1280, 1280, 640],
    up_blocks_num_attentions=[5, 3, 0],
    up_blocks_num_resnets=[2, 3, 3],
    add_upsample=[True, True, False],
)


class SqueezedDPMSolverMultistepScheduler(DPMSolverMultistepScheduler, SchedulerMixin):
    """
    A subclass of Diffusers' `DPMSolverMultistepScheduler`, with minor differences:
    * Defaults are modified to accommodate DeciDiffusion (marked `NOTE THIS DEFAULT VALUE` below)
    * It supports a squeezer to squeeze the number of inference steps to a smaller number
    //!\\ IMPORTANT: when a squeezer is active, the actual number of inference steps is deduced by the squeezer,
    and not by the `num_inference_steps` the pipeline passes to `set_timesteps`!
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "squaredcos_cap_v2",  # NOTE THIS DEFAULT VALUE
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        solver_order: int = 2,
        prediction_type: str = "v_prediction",  # NOTE THIS DEFAULT VALUE
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        algorithm_type: str = "dpmsolver++",
        solver_type: str = "heun",  # NOTE THIS DEFAULT VALUE
        lower_order_final: bool = True,
        use_karras_sigmas: Optional[bool] = False,
        lambda_min_clipped: float = -7.5,  # NOTE THIS DEFAULT VALUE
        variance_type: Optional[str] = None,
        timestep_spacing: str = "linspace",
        steps_offset: int = 1,
        squeeze_mode: Optional[str] = None,  # NOTE THIS ADDITION. Supports keys from `PREDEFINED_TIMESTEP_SQUEEZERS` defined above
    ):
        """
        Same parameters as `DPMSolverMultistepScheduler`, plus `squeeze_mode`.

        Args:
            squeeze_mode: None to disable squeezing, or a key of `PREDEFINED_TIMESTEP_SQUEEZERS`.

        Raises:
            ValueError: if `squeeze_mode` is neither None nor a known squeezer key. Previously an unknown key
                silently disabled squeezing.
            NotImplementedError: if `use_karras_sigmas` is truthy.
        """
        if squeeze_mode is not None and squeeze_mode not in PREDEFINED_TIMESTEP_SQUEEZERS:
            raise ValueError(
                f"Unknown squeeze_mode {squeeze_mode!r}; expected None or one of {sorted(PREDEFINED_TIMESTEP_SQUEEZERS)}"
            )
        # None when squeeze_mode is None, i.e. squeezing disabled.
        self._squeezer = PREDEFINED_TIMESTEP_SQUEEZERS.get(squeeze_mode)

        if use_karras_sigmas:
            raise NotImplementedError("Squeezing isn't tested with `use_karras_sigmas`. Please provide `use_karras_sigmas=False`")

        super().__init__(
            num_train_timesteps=num_train_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
            trained_betas=trained_betas,
            solver_order=solver_order,
            prediction_type=prediction_type,
            thresholding=thresholding,
            dynamic_thresholding_ratio=dynamic_thresholding_ratio,
            sample_max_value=sample_max_value,
            algorithm_type=algorithm_type,
            solver_type=solver_type,
            lower_order_final=lower_order_final,
            use_karras_sigmas=False,
            lambda_min_clipped=lambda_min_clipped,
            variance_type=variance_type,
            timestep_spacing=timestep_spacing,
            steps_offset=steps_offset,
        )

    def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        When a squeezer is configured, the parent's schedule is squeezed. `timesteps`, `sigmas` and
        `num_inference_steps` are then recomputed from the squeezed schedule, so `num_inference_steps` ends up
        as the squeezer's output length rather than the value passed in.

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        super().set_timesteps(num_inference_steps=num_inference_steps, device=device)
        if self._squeezer is not None:
            timesteps = self._squeezer(self.timesteps.cpu())
            # Recompute sigmas at the squeezed timesteps, mirroring the parent's non-Karras sigma computation,
            # with the final sigma taken at timestep 0.
            sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
            self.sigmas = torch.from_numpy(sigmas)
            self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
            self.num_inference_steps = len(timesteps)


class FlexibleIdentityBlock(nn.Module):
    """
    Parameter-free pass-through block.

    `FlexibleUNet2DConditionModel` uses it as the mid block when the configuration asks for zero mid resnets and
    zero mid attentions. It accepts the same keyword arguments as the other Flexible* blocks, ignores all of them
    except `hidden_states`, and returns that tensor object unchanged.
    """

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        # The conditioning inputs exist only for call-signature compatibility.
        passthrough = hidden_states
        return passthrough


class FlexibleUNet2DConditionModel(UNet2DConditionModel, ModelMixin):
    """
    `UNet2DConditionModel` whose down, mid and up blocks are rebuilt from `configurations`.

    The parent constructor is called with only `sample_size` and `cross_attention_dim` taken from
    `configurations`, so every other parent argument keeps its default. Afterwards `down_blocks`, `mid_block`
    and `up_blocks` are replaced with the Flexible* blocks defined in this module.
    """

    # Architecture description read by `__init__` (keys as in `FlexibleUnetConfigurations`).
    configurations = FlexibleUnetConfigurations

    @register_to_config
    def __init__(self):
        # NOTE(review): the parent presumably also builds its own default down/mid/up blocks here; they are
        # overwritten below, so they only cost construction time and memory.
        super().__init__(
            sample_size=self.configurations.get("sample_size", FlexibleUnetConfigurations["sample_size"]),
            cross_attention_dim=self.configurations.get("cross_attention_dim", FlexibleUnetConfigurations["cross_attention_dim"]),
        )

        # Settings shared by every block.
        num_attention_heads = self.configurations.get("num_attention_heads")
        cross_attention_dim = self.configurations.get("cross_attention_dim")
        mix_block_in_forward = self.configurations.get("mix_block_in_forward")
        resnet_act_fn = self.configurations.get("resnet_act_fn")
        resnet_eps = self.configurations.get("resnet_eps")
        temb_dim = self.configurations.get("temb_dim")

        ###############
        # Down blocks #
        ###############
        # Per-block lists: entry i configures down block i (zip stops at the shortest list).
        down_blocks_num_attentions = self.configurations.get("down_blocks_num_attentions")
        down_blocks_out_channels = self.configurations.get("down_blocks_out_channels")
        down_blocks_in_channels = self.configurations.get("down_blocks_in_channels")
        down_blocks_num_resnets = self.configurations.get("down_blocks_num_resnets")
        add_downsample = self.configurations.get("add_downsample")

        self.down_blocks = nn.ModuleList()

        for i, (in_c, out_c, n_res, n_att, add_down) in enumerate(
            zip(down_blocks_in_channels, down_blocks_out_channels, down_blocks_num_resnets, down_blocks_num_attentions, add_downsample)
        ):
            # The final down block does not add its downsampler output to the skip-connection states.
            last_block = i == len(down_blocks_in_channels) - 1
            self.down_blocks.append(
                FlexibleCrossAttnDownBlock2D(
                    in_channels=in_c,
                    out_channels=out_c,
                    temb_channels=temb_dim,
                    num_resnets=n_res,
                    num_attentions=n_att,
                    resnet_eps=resnet_eps,
                    resnet_act_fn=resnet_act_fn,
                    num_attention_heads=num_attention_heads,
                    cross_attention_dim=cross_attention_dim,
                    add_downsample=add_down,
                    last_block=last_block,
                    mix_block_in_forward=mix_block_in_forward,
                )
            )

        ###############
        # Mid blocks  #
        ###############

        mid_block_add_upsample = self.configurations.get("add_upsample_mid_block")
        mid_num_attentions = self.configurations.get("mid_num_attentions")
        mid_num_resnets = self.configurations.get("mid_num_resnets")

        # With no mid resnets and no mid attentions the mid block is a pass-through.
        if mid_num_resnets == mid_num_attentions == 0:
            self.mid_block = FlexibleIdentityBlock()
        else:
            # Operates at the channel width of the last down block.
            self.mid_block = FlexibleUNetMidBlock2DCrossAttn(
                in_channels=down_blocks_out_channels[-1],
                temb_channels=temb_dim,
                resnet_act_fn=resnet_act_fn,
                resnet_eps=resnet_eps,
                cross_attention_dim=cross_attention_dim,
                num_attention_heads=num_attention_heads,
                num_resnets=mid_num_resnets,
                num_attentions=mid_num_attentions,
                mix_block_in_forward=mix_block_in_forward,
                add_upsample=mid_block_add_upsample,
            )

        ###############
        #  Up blocks  #
        ###############

        up_blocks_num_attentions = self.configurations.get("up_blocks_num_attentions")
        up_blocks_num_resnets = self.configurations.get("up_blocks_num_resnets")
        prev_output_channels = self.configurations.get("prev_output_channels")
        up_upsample = self.configurations.get("add_upsample")

        self.up_blocks = nn.ModuleList()
        # The down-block channel lists are consumed in reverse, pairing up block j with the
        # (len - 1 - j)-th down block's channel sizes.
        for in_c, out_c, prev_out, n_res, n_att, add_up in zip(
            reversed(down_blocks_in_channels),
            reversed(down_blocks_out_channels),
            prev_output_channels,
            up_blocks_num_resnets,
            up_blocks_num_attentions,
            up_upsample,
        ):
            self.up_blocks.append(
                FlexibleCrossAttnUpBlock2D(
                    in_channels=in_c,
                    out_channels=out_c,
                    prev_output_channel=prev_out,
                    temb_channels=temb_dim,
                    num_resnets=n_res,
                    num_attentions=n_att,
                    resnet_eps=resnet_eps,
                    resnet_act_fn=resnet_act_fn,
                    num_attention_heads=num_attention_heads,
                    cross_attention_dim=cross_attention_dim,
                    add_upsample=add_up,
                    mix_block_in_forward=mix_block_in_forward,
                )
            )


class FlexibleCrossAttnDownBlock2D(nn.Module):
    """
    Configurable UNet down block: `num_resnets` `ResnetBlock2D`s and `num_attentions`
    `FlexibleTransformer2DModel`s, optionally followed by a `Downsample2D`.

    Layers are created in (resnet, transformer) pairs until both counts are used up. When one count is larger,
    the surplus layers of that kind are appended alone. If `mix_block_in_forward` is False, all resnets are
    moved before all transformers instead.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_resnets: int = 1,
        num_attentions: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        output_scale_factor: float = 1.0,
        downsample_padding: int = 1,
        add_downsample: bool = True,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        last_block: bool = False,
        mix_block_in_forward: bool = True,
    ):
        super().__init__()

        # When True, `forward` does not append the downsampler output to the skip-connection states.
        self.last_block = last_block
        self.mix_block_in_forward = mix_block_in_forward
        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        modules = []

        # zip_longest pads the shorter list with False; e.g. 2 resnets + 1 attention gives
        # [resnet, transformer, resnet].
        add_resnets = [True] * num_resnets
        add_cross_attentions = [True] * num_attentions
        for i, (add_resnet, add_cross_attention) in enumerate(itertools.zip_longest(add_resnets, add_cross_attentions, fillvalue=False)):
            # Only the first layer sees the block's input width; later layers work on out_channels.
            # NOTE(review): transformers are always built with out_channels, so num_resnets == 0 with
            # in_channels != out_channels would presumably fail at runtime — confirm configs avoid this.
            in_channels = in_channels if i == 0 else out_channels
            if add_resnet:
                modules.append(
                    ResnetBlock2D(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        temb_channels=temb_channels,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                    )
                )
            if add_cross_attention:
                modules.append(
                    FlexibleTransformer2DModel(
                        num_attention_heads=num_attention_heads,
                        attention_head_dim=out_channels // num_attention_heads,
                        in_channels=out_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        only_cross_attention=only_cross_attention,
                        upcast_attention=upcast_attention,
                    )
                )

        if not mix_block_in_forward:
            # Stable sort: all resnets (key 0) first, then all transformers (key 1), each in creation order.
            modules = sorted(modules, key=custom_sort_order)

        # NOTE(review): each module's index in this list presumably must match the pretrained checkpoint's
        # `modules_list.<idx>` parameter names, so the construction order above should not change.
        self.modules_list = nn.ModuleList(modules)

        if add_downsample:
            self.downsamplers = nn.ModuleList([Downsample2D(out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op")])
        else:
            self.downsamplers = None

        # Not read anywhere in this block.
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        """
        Run `modules_list` in order, then the optional downsampler.

        Returns:
            `(hidden_states, output_states)`. `output_states` holds the hidden states captured right after each
            resnet (not after transformers), plus the downsampler output unless `last_block` is True.

        Raises:
            ValueError: if `modules_list` contains anything other than a resnet or a transformer.
        """
        output_states = ()

        for module in self.modules_list:
            if isinstance(module, ResnetBlock2D):
                hidden_states = module(hidden_states, temb)
            elif isinstance(module, (FlexibleTransformer2DModel, Transformer2DModel)):
                hidden_states = module(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
            else:
                raise ValueError(f"Got an unexpected module in modules list! {type(module)}")
            # Skip-connection states are recorded after resnets only.
            if isinstance(module, ResnetBlock2D):
                output_states = output_states + (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            if not self.last_block:
                output_states = output_states + (hidden_states,)

        return hidden_states, output_states


class FlexibleCrossAttnUpBlock2D(nn.Module):
    """
    Configurable UNet up block: `num_resnets` skip-consuming `ResnetBlock2D`s and `num_attentions`
    `FlexibleTransformer2DModel`s, optionally followed by an `Upsample2D`.

    Layers are created in (resnet, transformer) pairs until both counts are used up. If
    `mix_block_in_forward` is False, all resnets are moved before all transformers instead.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_resnets: int = 1,
        num_attentions: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        output_scale_factor: float = 1.0,
        add_upsample: bool = True,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        mix_block_in_forward: bool = True,
    ):
        super().__init__()
        modules = []

        # WARNING: This parameter is filled with number of resnets and used within StableDiffusionPipeline
        # (a plain list holding one True per resnet; presumably only its length is read — confirm in diffusers)
        self.resnets = []

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        add_resnets = [True] * num_resnets
        add_cross_attentions = [True] * num_attentions
        for i, (add_resnet, add_cross_attention) in enumerate(itertools.zip_longest(add_resnets, add_cross_attentions, fillvalue=False)):
            # The last resnet consumes a skip state with `in_channels` channels; the others consume `out_channels`.
            res_skip_channels = in_channels if (i == len(add_resnets) - 1) else out_channels
            # The first resnet's main input comes from the previous block; later ones from this block.
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            if add_resnet:
                self.resnets += [True]
                modules.append(
                    ResnetBlock2D(
                        in_channels=resnet_in_channels + res_skip_channels,
                        out_channels=out_channels,
                        temb_channels=temb_channels,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                    )
                )
            if add_cross_attention:
                modules.append(
                    FlexibleTransformer2DModel(
                        num_attention_heads,
                        out_channels // num_attention_heads,
                        in_channels=out_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        only_cross_attention=only_cross_attention,
                        upcast_attention=upcast_attention,
                    )
                )

        if not mix_block_in_forward:
            # Stable sort: all resnets (key 0) first, then all transformers (key 1), each in creation order.
            modules = sorted(modules, key=custom_sort_order)

        # NOTE(review): each module's index in this list presumably must match the pretrained checkpoint's
        # `modules_list.<idx>` parameter names, so the construction order above should not change.
        self.modules_list = nn.ModuleList(modules)

        self.upsamplers = None
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])

        # Not read anywhere in this block.
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        upsample_size: Optional[int] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        """
        Run `modules_list` in order, then the optional upsampler.

        Each resnet takes the last remaining entry of `res_hidden_states_tuple` and concatenates it to
        `hidden_states` along dim 1 before running. `res_hidden_states_tuple` therefore needs at least one
        entry per resnet. Modules that are neither resnets nor transformers are skipped silently.

        Returns:
            The block's output hidden states.
        """

        for module in self.modules_list:
            if isinstance(module, ResnetBlock2D):
                # Consume skip states from the end of the tuple.
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
                hidden_states = module(hidden_states, temb)
            if isinstance(module, (FlexibleTransformer2DModel, Transformer2DModel)):
                hidden_states = module(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states


class FlexibleUNetMidBlock2DCrossAttn(nn.Module):
    """
    Configurable UNet mid block operating at a constant `in_channels` width.

    It always starts with one `ResnetBlock2D`. That is followed by `num_attentions`
    `FlexibleTransformer2DModel`s and `num_resnets` further `ResnetBlock2D`s, interleaved transformer-then-resnet,
    or with all resnets before all transformers when `mix_block_in_forward` is False. An `Upsample2D` is added
    at the end when `add_upsample` is truthy.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_resnets: int = 1,
        num_attentions: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        output_scale_factor: float = 1.0,
        cross_attention_dim: int = 1280,
        use_linear_projection: bool = False,
        upcast_attention: bool = False,
        mix_block_in_forward: bool = True,
        add_upsample: bool = True,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        # There is always at least one resnet; `forward` applies it before the rest of `modules_list`.
        modules = [
            ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
        ]

        add_resnets = [True] * num_resnets
        add_cross_attentions = [True] * num_attentions
        # zip_longest pads the shorter list with False; within each step the transformer is added before the resnet.
        for add_resnet, add_cross_attention in itertools.zip_longest(add_resnets, add_cross_attentions, fillvalue=False):
            if add_cross_attention:
                modules.append(
                    FlexibleTransformer2DModel(
                        num_attention_heads,
                        in_channels // num_attention_heads,
                        in_channels=in_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        upcast_attention=upcast_attention,
                    )
                )

            if add_resnet:
                modules.append(
                    ResnetBlock2D(
                        in_channels=in_channels,
                        out_channels=in_channels,
                        temb_channels=temb_channels,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                    )
                )
        if not mix_block_in_forward:
            # Stable sort: resnets (key 0) first, so the leading resnet stays at index 0, then transformers.
            modules = sorted(modules, key=custom_sort_order)

        self.modules_list = nn.ModuleList(modules)

        # Identity unless `add_upsample` is truthy, so `forward` can always iterate over `self.upsamplers`.
        self.upsamplers = nn.ModuleList([nn.Identity()])
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(in_channels, use_conv=True, out_channels=in_channels)])

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Apply the leading resnet, then the remaining modules in `modules_list` order, then the upsampler(s).

        Each module, including the leading resnet, is applied exactly once.
        """
        hidden_states = self.modules_list[0](hidden_states, temb)

        # Start at index 1: the leading resnet was applied above and must not run a second time.
        for module in self.modules_list[1:]:
            if isinstance(module, (FlexibleTransformer2DModel, Transformer2DModel)):
                hidden_states = module(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
            elif isinstance(module, ResnetBlock2D):
                hidden_states = module(hidden_states, temb)

        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states)

        return hidden_states


class FlexibleTransformer2DModel(ModelMixin, ConfigMixin):
    """
    2D transformer over continuous (batch, channels, height, width) feature maps, used by the Flexible* blocks.

    Pipeline: GroupNorm -> input projection (1x1 conv, or linear after flattening to tokens) ->
    `num_layers` `BasicTransformerBlock`s over the height*width tokens -> output projection back to
    `in_channels` -> residual add with the input. `out_channels` is only stored on `self.out_channels`; the
    projection always maps back to `in_channels`.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        only_cross_attention: bool = False,
        use_linear_projection: bool = False,
        upcast_attention: bool = False,
        norm_type: str = "layer_norm",
        norm_elementwise_affine: bool = True,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.in_channels = in_channels
        inner_dim = num_attention_heads * attention_head_dim

        # Define input layers
        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        self.use_linear_projection = use_linear_projection
        if self.use_linear_projection:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        # Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                )
                for _ in range(num_layers)
            ]
        )

        # Define output layers
        self.out_channels = in_channels if out_channels is None else out_channels
        if self.use_linear_projection:
            self.proj_out = nn.Linear(inner_dim, in_channels)
        else:
            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        Args:
            hidden_states: input feature map of shape (batch, channels, height, width).
            encoder_hidden_states, timestep, class_labels, cross_attention_kwargs, attention_mask,
            encoder_attention_mask: forwarded unchanged to every `BasicTransformerBlock`.
            return_dict: if True, return a `Transformer2DModelOutput`; if False, return the 1-tuple `(sample,)`.
                The default is True, so a call without `return_dict` returns a `Transformer2DModelOutput`.
                Either result supports `[0]` to get the output tensor.

        Returns:
            The transformer output added to `hidden_states` (residual connection), wrapped as described above.
        """
        # 1. Input
        batch, channels, height, width = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        if not self.use_linear_projection:
            # 1x1 conv projection first, then flatten the spatial grid into height*width tokens.
            hidden_states = self.proj_in(hidden_states)
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
        else:
            # Flatten to tokens first (still `channels` wide), then apply the linear projection.
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, channels)
            hidden_states = self.proj_in(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )

        # 3. Output
        if not self.use_linear_projection:
            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
            hidden_states = self.proj_out(hidden_states)
        else:
            hidden_states = self.proj_out(hidden_states)
            hidden_states = hidden_states.reshape(batch, height, width, channels).permute(0, 3, 1, 2).contiguous()

        output = hidden_states + residual
        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)


class DeciDiffusionPipeline(StableDiffusionPipeline):
    """
    Stable Diffusion pipeline that always runs Deci's `FlexibleUNet2DConditionModel` together with a
    `SqueezedDPMSolverMultistepScheduler`.

    The `unet` and `scheduler` handed to the constructor are discarded and replaced by freshly constructed Deci
    components; every other component (VAE, text encoder, tokenizer, safety checker, feature extractor) is used as
    given.
    """

    # Squeeze mode passed to `SqueezedDPMSolverMultistepScheduler` when the pipeline is constructed.
    deci_default_squeeze_mode = "10,6"
    # Default number of denoising steps used by `__call__`.
    deci_default_number_of_iterations = 16
    # Default guidance-rescale factor used by `__call__`.
    deci_default_guidance_rescale = 0.8

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
        """
        Build the pipeline, swapping in Deci's UNet and scheduler.

        Args:
            vae: Autoencoder used to decode latents into images.
            text_encoder: CLIP text encoder used to embed prompts.
            tokenizer: CLIP tokenizer matching `text_encoder`.
            unet: Ignored. Replaced by a default-constructed `FlexibleUNet2DConditionModel`.
            scheduler: Ignored. Replaced by a `SqueezedDPMSolverMultistepScheduler` configured with
                `deci_default_squeeze_mode`.
            safety_checker: Safety checker forwarded to the parent pipeline.
            feature_extractor: Image processor forwarded to the parent pipeline.
            requires_safety_checker: Forwarded to the parent pipeline.
        """
        # Replace UNet with Deci`s unet.
        # NOTE(review): the UNet is built from its default config here; presumably its weights are loaded
        # separately after construction — confirm with the loading code.
        del unet
        unet = FlexibleUNet2DConditionModel()

        # Replace with custom scheduler
        del scheduler
        scheduler = SqueezedDPMSolverMultistepScheduler(squeeze_mode=self.deci_default_squeeze_mode)

        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            requires_safety_checker=requires_safety_checker,
        )

        # Register the (replaced) components explicitly so the pipeline's attributes and config reference the
        # Deci UNet and scheduler.
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = deci_default_number_of_iterations,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = deci_default_guidance_rescale,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 16 — `deci_default_number_of_iterations`):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`. Passing
                `"latent"` skips VAE decoding and the safety checker and returns the latents.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
                Its `"scale"` entry, if present, is also used as the LoRA scale for the text encoder.
            guidance_rescale (`float`, *optional*, defaults to 0.8 — `deci_default_guidance_rescale`):
                Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
                Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
                using zero terminal SNR. Only applied when classifier-free guidance is active and the value is > 0.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )
        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=len(timesteps)) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance; the unconditional half comes first because the embeddings were
                # concatenated as [negative, positive] above
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                    if guidance_rescale > 0.0:
                        # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                        noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # update the progress bar and call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if output_type != "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        # Only denormalize images that were not flagged by the safety checker.
        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)