import itertools
from typing import Any, Optional, Dict, Tuple

import torch
from diffusers import StableDiffusionPipeline, AutoencoderKL
from diffusers import Transformer2DModel, ModelMixin, ConfigMixin
from diffusers import UNet2DConditionModel
from diffusers.configuration_utils import register_to_config
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.resnet import ResnetBlock2D, Downsample2D, Upsample2D
from diffusers.models.transformer_2d import Transformer2DModelOutput
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from torch import nn
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor

# Architecture hyper-parameters of the DeciDiffusion "flexible" UNet.
# FlexibleUNet2DConditionModel reads every entry through ``dict.get``. List-valued
# entries hold one value per block, in the order the blocks are constructed.
FlexibleUnetConfigurations = dict(
    # --- Values shared by every block ---
    sample_size=64,
    temb_dim=320 * 4,  # passed to every block as ``temb_channels``
    resnet_eps=1e-5,
    resnet_act_fn='silu',
    num_attention_heads=8,
    cross_attention_dim=768,

    # True: each block runs its resnets and transformers interleaved, in construction
    # order. False: each block runs all of its resnets first, then its transformers.
    mix_block_in_forward=True,

    # --- Down (encoder) blocks ---
    down_blocks_in_channels=[320, 320, 640],
    down_blocks_out_channels=[320, 640, 1280],
    down_blocks_num_attentions=[0, 1, 3],
    down_blocks_num_resnets=[2, 2, 1],
    add_downsample=[True, True, True],

    # --- Middle block (its resnet count is in addition to one mandatory leading resnet) ---
    add_upsample_mid_block=True,
    mid_num_resnets=4,
    mid_num_attentions=2,

    # --- Up (decoder) blocks ---
    prev_output_channels=[1280, 1280, 640],
    up_blocks_num_attentions=[6, 3, 0],
    up_blocks_num_resnets=[2, 3, 3],
    add_upsample=[True, True, False],
)


def custom_sort_order(obj):
    """
    Sort key that fixes the execution order of a block's sub-modules.

    Used with ``sorted`` when ``mix_block_in_forward`` is False: every
    ``ResnetBlock2D`` (key 0) is placed before every transformer (key 1).
    ``sorted`` is stable, so modules sharing a key keep their construction order.

    :param obj: a module taken from a block's list of sub-modules.
    :return: 0 for ``ResnetBlock2D``, 1 for ``Transformer2DModel`` / ``FlexibleTransformer2DModel``.
    :raises ValueError: for any other module type. Returning ``None`` here would make
        ``sorted`` fail later with an opaque ``TypeError`` when comparing keys.
    """
    # isinstance (rather than an exact-class lookup) so subclasses are ordered like their base.
    if isinstance(obj, ResnetBlock2D):
        return 0
    if isinstance(obj, (Transformer2DModel, FlexibleTransformer2DModel)):
        return 1
    raise ValueError(f'Cannot determine execution order for module of type {type(obj)}')


class FlexibleIdentityBlock(nn.Module):
    """
    Pass-through module used in place of the UNet mid block when that block
    would contain neither resnets nor attentions.

    It accepts the same arguments as the mid block's ``forward`` so it can be
    called the same way. It ignores everything except ``hidden_states``, which
    is returned unchanged (the very same tensor object).
    """

    def forward(
            self,
            hidden_states: torch.FloatTensor,
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        # Conditioning inputs are accepted only for call-signature compatibility.
        passthrough = hidden_states
        return passthrough


class FlexibleUNet2DConditionModel(UNet2DConditionModel, ModelMixin):
    """
    ``UNet2DConditionModel`` whose down, mid and up blocks are rebuilt from ``configurations``.

    The parent constructor runs first, overriding only ``sample_size`` and
    ``cross_attention_dim``. Every other layer and the ``forward`` method are
    therefore inherited from ``UNet2DConditionModel``. The inherited
    ``down_blocks``, ``mid_block`` and ``up_blocks`` are then replaced with the
    Flexible* blocks described by the configuration dictionary.
    """
    configurations = FlexibleUnetConfigurations

    @register_to_config
    def __init__(self):
        cfg = self.configurations
        super().__init__(
            sample_size=cfg.get('sample_size', FlexibleUnetConfigurations['sample_size']),
            cross_attention_dim=cfg.get("cross_attention_dim", FlexibleUnetConfigurations['cross_attention_dim']),
        )
        option = cfg.get

        # Keyword arguments handed unchanged to every down, mid and up block.
        shared = dict(
            temb_channels=option("temb_dim"),
            resnet_eps=option("resnet_eps"),
            resnet_act_fn=option("resnet_act_fn"),
            num_attention_heads=option("num_attention_heads"),
            cross_attention_dim=option("cross_attention_dim"),
            mix_block_in_forward=option("mix_block_in_forward"),
        )

        # ---- Down blocks: one per configured entry; the final one gets last_block=True ----
        down_in = option("down_blocks_in_channels")
        down_out = option("down_blocks_out_channels")
        down_specs = zip(down_in, down_out,
                         option("down_blocks_num_resnets"),
                         option("down_blocks_num_attentions"),
                         option("add_downsample"))
        final_index = len(down_in) - 1

        self.down_blocks = nn.ModuleList()
        self.down_blocks.extend(
            FlexibleCrossAttnDownBlock2D(
                in_channels=c_in,
                out_channels=c_out,
                num_resnets=n_res,
                num_attentions=n_att,
                add_downsample=downsample,
                last_block=index == final_index,
                **shared,
            )
            for index, (c_in, c_out, n_res, n_att, downsample) in enumerate(down_specs)
        )

        # ---- Mid block: an identity stand-in when it would have no layers at all ----
        mid_resnets = option("mid_num_resnets")
        mid_attentions = option("mid_num_attentions")
        if mid_resnets == 0 and mid_attentions == 0:
            self.mid_block = FlexibleIdentityBlock()
        else:
            self.mid_block = FlexibleUNetMidBlock2DCrossAttn(
                in_channels=down_out[-1],
                num_resnets=mid_resnets,
                num_attentions=mid_attentions,
                add_upsample=option("add_upsample_mid_block"),
                **shared,
            )

        # ---- Up blocks: channel sizes mirror the down blocks in reverse order ----
        up_specs = zip(reversed(down_in), reversed(down_out),
                       option("prev_output_channels"),
                       option("up_blocks_num_resnets"),
                       option("up_blocks_num_attentions"),
                       option("add_upsample"))

        self.up_blocks = nn.ModuleList()
        self.up_blocks.extend(
            FlexibleCrossAttnUpBlock2D(
                in_channels=c_in,
                out_channels=c_out,
                prev_output_channel=prev_out,
                num_resnets=n_res,
                num_attentions=n_att,
                add_upsample=upsample,
                **shared,
            )
            for c_in, c_out, prev_out, n_res, n_att, upsample in up_specs
        )


class FlexibleCrossAttnDownBlock2D(nn.Module):
    """
    Down (encoder) block: a configurable mix of ``ResnetBlock2D`` and
    ``FlexibleTransformer2DModel`` layers, optionally followed by ``Downsample2D``.

    Layers are created in steps. Step ``i`` adds a resnet when ``i < num_resnets``,
    then a transformer when ``i < num_attentions``. The resnet of step 0 maps
    ``in_channels`` to ``out_channels``; every other layer works at ``out_channels``.
    When ``mix_block_in_forward`` is False, the layers are reordered so that all
    resnets run before all transformers.

    ``forward`` returns ``(hidden_states, output_states)``. ``output_states`` holds
    the output of every resnet, plus the downsampled result unless ``last_block`` is set.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            temb_channels: int,
            dropout: float = 0.0,
            num_resnets: int = 1,
            num_attentions: int = 1,
            transformer_layers_per_block: int = 1,
            resnet_eps: float = 1e-6,
            resnet_time_scale_shift: str = "default",
            resnet_act_fn: str = "swish",
            resnet_groups: int = 32,
            resnet_pre_norm: bool = True,
            num_attention_heads: int = 1,
            cross_attention_dim: int = 1280,
            output_scale_factor: float = 1.0,
            downsample_padding: int = 1,
            add_downsample: bool = True,
            use_linear_projection: bool = False,
            only_cross_attention: bool = False,
            upcast_attention: bool = False,
            last_block: bool = False,
            mix_block_in_forward: bool = True,
    ):
        super().__init__()

        self.last_block = last_block
        self.mix_block_in_forward = mix_block_in_forward
        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        layers = []
        steps = itertools.zip_longest([True] * num_resnets, [True] * num_attentions, fillvalue=False)
        for step, (wants_resnet, wants_attention) in enumerate(steps):
            if wants_resnet:
                layers.append(
                    ResnetBlock2D(
                        # Only the very first step changes the channel count.
                        in_channels=out_channels if step else in_channels,
                        out_channels=out_channels,
                        temb_channels=temb_channels,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                    )
                )
            if wants_attention:
                layers.append(
                    FlexibleTransformer2DModel(
                        num_attention_heads=num_attention_heads,
                        attention_head_dim=out_channels // num_attention_heads,
                        in_channels=out_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        only_cross_attention=only_cross_attention,
                        upcast_attention=upcast_attention,
                    )
                )

        if not mix_block_in_forward:
            # Stable sort: resnets (key 0) move ahead of transformers (key 1).
            layers.sort(key=custom_sort_order)

        self.modules_list = nn.ModuleList(layers)

        if add_downsample:
            downsampler = Downsample2D(
                out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
            )
            self.downsamplers = nn.ModuleList([downsampler])
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(
            self,
            hidden_states: torch.FloatTensor,
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        skip_states = ()

        for layer in self.modules_list:
            if isinstance(layer, ResnetBlock2D):
                hidden_states = layer(hidden_states, temb)
                # A skip connection is recorded right after each resnet, before any following transformer.
                skip_states = skip_states + (hidden_states,)
            elif isinstance(layer, (FlexibleTransformer2DModel, Transformer2DModel)):
                hidden_states = layer(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
            else:
                raise ValueError(f'Got an unexpected module in modules list! {type(layer)}')

        if self.downsamplers is None:
            return hidden_states, skip_states

        for downsampler in self.downsamplers:
            hidden_states = downsampler(hidden_states)

        # The last block's downsampled output is not exposed as a skip connection.
        if not self.last_block:
            skip_states = skip_states + (hidden_states,)

        return hidden_states, skip_states


class FlexibleCrossAttnUpBlock2D(nn.Module):
    """
    Up (decoder) block: a configurable mix of ``ResnetBlock2D`` and
    ``FlexibleTransformer2DModel`` layers, optionally followed by ``Upsample2D``.

    Layers are created in steps. Step ``i`` adds a resnet when ``i < num_resnets``,
    then a transformer when ``i < num_attentions``. Each resnet consumes one skip
    tensor, concatenated along dim 1, so its input width is incoming + skip:

    - incoming is ``prev_output_channel`` at step 0 and ``out_channels`` afterwards;
    - skip is ``in_channels`` for the last resnet and ``out_channels`` otherwise.

    When ``mix_block_in_forward`` is False, all resnets are moved ahead of all transformers.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            prev_output_channel: int,
            temb_channels: int,
            dropout: float = 0.0,
            num_resnets: int = 1,
            num_attentions: int = 1,
            transformer_layers_per_block: int = 1,
            resnet_eps: float = 1e-6,
            resnet_time_scale_shift: str = "default",
            resnet_act_fn: str = "swish",
            resnet_groups: int = 32,
            resnet_pre_norm: bool = True,
            num_attention_heads: int = 1,
            cross_attention_dim: int = 1280,
            output_scale_factor: float = 1.0,
            add_upsample: bool = True,
            use_linear_projection: bool = False,
            only_cross_attention: bool = False,
            upcast_attention: bool = False,
            mix_block_in_forward: bool = True
    ):
        super().__init__()
        layers = []

        # WARNING: filled with one ``True`` per resnet and read from outside this block
        # (originally noted as used within StableDiffusionPipeline). Presumably only its
        # length matters, as the number of skip tensors this block consumes — confirm.
        self.resnets = []

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        last_resnet = num_resnets - 1
        steps = itertools.zip_longest([True] * num_resnets, [True] * num_attentions, fillvalue=False)
        for step, (wants_resnet, wants_attention) in enumerate(steps):
            if wants_resnet:
                incoming = out_channels if step else prev_output_channel
                skip = in_channels if step == last_resnet else out_channels
                self.resnets.append(True)
                layers.append(
                    ResnetBlock2D(
                        in_channels=incoming + skip,
                        out_channels=out_channels,
                        temb_channels=temb_channels,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                    )
                )
            if wants_attention:
                # Heads / head size stay positional: register_to_config records positional
                # and keyword arguments differently.
                layers.append(
                    FlexibleTransformer2DModel(
                        num_attention_heads,
                        out_channels // num_attention_heads,
                        in_channels=out_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        only_cross_attention=only_cross_attention,
                        upcast_attention=upcast_attention,
                    )
                )

        if not mix_block_in_forward:
            # Stable sort: resnets (key 0) move ahead of transformers (key 1).
            layers.sort(key=custom_sort_order)

        self.modules_list = nn.ModuleList(layers)

        self.upsamplers = (
            nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
            if add_upsample else None
        )

        self.gradient_checkpointing = False

    def forward(
            self,
            hidden_states: torch.FloatTensor,
            res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            upsample_size: Optional[int] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        pending_skips = res_hidden_states_tuple

        for layer in self.modules_list:
            if isinstance(layer, ResnetBlock2D):
                # Skip tensors are consumed from the end of the tuple, one per resnet.
                skip = pending_skips[-1]
                pending_skips = pending_skips[:-1]
                hidden_states = layer(torch.cat([hidden_states, skip], dim=1), temb)
            elif isinstance(layer, (FlexibleTransformer2DModel, Transformer2DModel)):
                hidden_states = layer(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states


class FlexibleUNetMidBlock2DCrossAttn(nn.Module):
    """
    Middle block: one leading ``ResnetBlock2D``, then ``num_attentions``
    transformers and ``num_resnets`` further resnets, optionally followed by an
    ``Upsample2D``. Every layer keeps ``in_channels`` channels.

    After the leading resnet, layers are created in steps. Step ``i`` adds a
    transformer when ``i < num_attentions``, then a resnet when ``i < num_resnets``.
    When ``mix_block_in_forward`` is False, the layers are stably reordered so all
    resnets come before all transformers. The leading resnet stays first either way.
    """

    def __init__(
            self,
            in_channels: int,
            temb_channels: int,
            dropout: float = 0.0,
            num_resnets: int = 1,
            num_attentions: int = 1,
            transformer_layers_per_block: int = 1,
            resnet_eps: float = 1e-6,
            resnet_time_scale_shift: str = "default",
            resnet_act_fn: str = "swish",
            resnet_groups: int = 32,
            resnet_pre_norm: bool = True,
            num_attention_heads: int = 1,
            output_scale_factor: float = 1.0,
            cross_attention_dim: int = 1280,
            use_linear_projection: bool = False,
            upcast_attention: bool = False,
            mix_block_in_forward: bool = True,
            add_upsample: bool = True,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

        def make_resnet():
            # Every resnet in this block preserves the channel count.
            return ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )

        # There is always at least one resnet; forward() runs it first, exactly once.
        modules = [make_resnet()]

        for add_resnet, add_cross_attention in itertools.zip_longest(
                [True] * num_resnets, [True] * num_attentions, fillvalue=False):
            if add_cross_attention:
                # Heads / head size stay positional: register_to_config records positional
                # and keyword arguments differently.
                modules.append(
                    FlexibleTransformer2DModel(
                        num_attention_heads,
                        in_channels // num_attention_heads,
                        in_channels=in_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        upcast_attention=upcast_attention,
                    )
                )
            if add_resnet:
                modules.append(make_resnet())

        if not mix_block_in_forward:
            # Stable sort with resnets keyed 0, so the leading resnet remains at index 0.
            modules = sorted(modules, key=custom_sort_order)

        self.modules_list = nn.ModuleList(modules)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(in_channels, use_conv=True, out_channels=in_channels)])
        else:
            self.upsamplers = nn.ModuleList([nn.Identity()])

    def forward(
            self,
            hidden_states: torch.FloatTensor,
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Run the leading resnet, then the remaining layers in order, then the upsampler.

        :raises ValueError: if ``modules_list`` contains an unexpected module type.
        """
        hidden_states = self.modules_list[0](hidden_states, temb)

        # Start after the leading resnet; iterating the full list would apply it a second time.
        for module in itertools.islice(self.modules_list, 1, None):
            if isinstance(module, (FlexibleTransformer2DModel, Transformer2DModel)):
                hidden_states = module(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
            elif isinstance(module, ResnetBlock2D):
                hidden_states = module(hidden_states, temb)
            else:
                raise ValueError(f'Got an unexpected module in modules list! {type(module)}')

        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states)

        return hidden_states


class FlexibleTransformer2DModel(ModelMixin, ConfigMixin):
    """
    Spatial transformer for continuous feature maps.

    Steps: GroupNorm, input projection, a stack of ``BasicTransformerBlock`` layers,
    output projection back to ``in_channels``, then a residual add of the input.
    The projections are 1x1 convolutions, or linear layers when
    ``use_linear_projection`` is True.
    """

    @register_to_config
    def __init__(
            self,
            num_attention_heads: int = 16,
            attention_head_dim: int = 88,
            in_channels: Optional[int] = None,
            out_channels: Optional[int] = None,
            num_layers: int = 1,
            dropout: float = 0.0,
            norm_num_groups: int = 32,
            cross_attention_dim: Optional[int] = None,
            attention_bias: bool = False,
            activation_fn: str = "geglu",
            num_embeds_ada_norm: Optional[int] = None,
            only_cross_attention: bool = False,
            use_linear_projection: bool = False,
            upcast_attention: bool = False,
            norm_type: str = "layer_norm",
            norm_elementwise_affine: bool = True,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.in_channels = in_channels
        inner_dim = num_attention_heads * attention_head_dim

        # Define input layers
        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        self.use_linear_projection = use_linear_projection
        if self.use_linear_projection:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        # Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                )
                for _ in range(num_layers)
            ]
        )

        # Define output layers. ``out_channels`` is stored, but the projection maps back to
        # ``in_channels`` so that the residual add in forward() is shape-compatible.
        self.out_channels = in_channels if out_channels is None else out_channels
        if self.use_linear_projection:
            self.proj_out = nn.Linear(inner_dim, in_channels)
        else:
            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(
            self,
            hidden_states: torch.Tensor,
            encoder_hidden_states: Optional[torch.Tensor] = None,
            timestep: Optional[torch.LongTensor] = None,
            class_labels: Optional[torch.LongTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            attention_mask: Optional[torch.Tensor] = None,
            encoder_attention_mask: Optional[torch.Tensor] = None,
            return_dict: bool = False
    ):
        """
        Apply the transformer stack to a feature map and add the input back.

        :param hidden_states: feature map unpacked as ``(batch, channels, height, width)``.
        :param encoder_hidden_states: conditioning passed to every transformer block.
        :param return_dict: when False, return a one-element tuple ``(output,)``;
            when True, return ``Transformer2DModelOutput(sample=output)``.
        :return: ``output`` has the same shape as ``hidden_states``.
        """
        # 1. Input
        batch, _, height, width = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        if not self.use_linear_projection:
            hidden_states = self.proj_in(hidden_states)
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
        else:
            # Here "inner_dim" is the input channel count; proj_out maps back to it below.
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
            hidden_states = self.proj_in(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )

        # 3. Output
        if not self.use_linear_projection:
            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
            hidden_states = self.proj_out(hidden_states)
        else:
            hidden_states = self.proj_out(hidden_states)
            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()

        output = hidden_states + residual
        # return_dict=False means "plain tuple", matching the diffusers convention.
        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)


class DeciDiffusionPipeline(StableDiffusionPipeline):
    """
    ``StableDiffusionPipeline`` that uses ``FlexibleUNet2DConditionModel`` as its UNet
    and fills in Deci's default sampling settings when the caller does not supply them.
    """
    deci_default_number_of_iterations = 30
    deci_default_guidance_rescale = 0.7

    def __init__(self,
                 vae: AutoencoderKL,
                 text_encoder: CLIPTextModel,
                 tokenizer: CLIPTokenizer,
                 unet: UNet2DConditionModel,
                 scheduler: KarrasDiffusionSchedulers,
                 safety_checker: StableDiffusionSafetyChecker,
                 feature_extractor: CLIPImageProcessor,
                 requires_safety_checker: bool = True
                 ):
        # The ``unet`` argument is discarded and replaced by a new FlexibleUNet2DConditionModel.
        # NOTE(review): that UNet is freshly constructed here; trained weights presumably have to
        # be loaded into ``pipeline.unet`` separately — confirm against the loading code.
        del unet
        unet = FlexibleUNet2DConditionModel()

        super().__init__(vae=vae,
                         text_encoder=text_encoder,
                         tokenizer=tokenizer,
                         unet=unet,
                         scheduler=scheduler,
                         safety_checker=safety_checker,
                         feature_extractor=feature_extractor,
                         requires_safety_checker=requires_safety_checker
                         )

        # NOTE(review): possibly redundant if the parent constructor already registers these
        # components; kept so the registered modules are unchanged.
        self.register_modules(vae=vae,
                              text_encoder=text_encoder,
                              tokenizer=tokenizer,
                              unet=unet,
                              scheduler=scheduler,
                              safety_checker=safety_checker,
                              feature_extractor=feature_extractor)

    def __call__(self, *args, **kwargs):
        """
        Run the parent pipeline, defaulting ``guidance_rescale`` and ``num_inference_steps``
        to Deci's values when the caller has not supplied them.

        A value counts as supplied whether it was passed by keyword or positionally.
        Injecting a keyword for an argument already given positionally would raise
        "got multiple values for argument".
        """
        import inspect

        parent_call = super().__call__
        try:
            supplied = inspect.signature(parent_call).bind_partial(*args, **kwargs).arguments
        except (TypeError, ValueError):
            # No usable signature, or the arguments do not bind: only explicit keywords can be checked.
            supplied = {}

        # Inference-time defaults (not training parameters).
        deci_defaults = (
            ('guidance_rescale', self.deci_default_guidance_rescale),
            ('num_inference_steps', self.deci_default_number_of_iterations),
        )
        for name, value in deci_defaults:
            if name not in kwargs and name not in supplied:
                kwargs[name] = value

        return parent_call(*args, **kwargs)