pandaphd committed · Commit 1201269 · 1 Parent(s): e6809a0

fix diffusers
diffuserss/models/autoencoder_kl.py ADDED
@@ -0,0 +1,450 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Dict, Optional, Tuple, Union
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+
19
+ from ..configuration_utils import ConfigMixin, register_to_config
20
+ from ..loaders import FromOriginalVAEMixin
21
+ from ..utils.accelerate_utils import apply_forward_hook
22
+ from .attention_processor import (
23
+ ADDED_KV_ATTENTION_PROCESSORS,
24
+ CROSS_ATTENTION_PROCESSORS,
25
+ AttentionProcessor,
26
+ AttnAddedKVProcessor,
27
+ AttnProcessor,
28
+ )
29
+ from .modeling_outputs import AutoencoderKLOutput
30
+ from .modeling_utils import ModelMixin
31
+ from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
32
+
33
+
34
+ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
35
+ r"""
36
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
37
+
38
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
39
+ for all models (such as downloading or saving).
40
+
41
+ Parameters:
42
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
43
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
44
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
45
+ Tuple of downsample block types.
46
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
47
+ Tuple of upsample block types.
48
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
49
+ Tuple of block output channels.
50
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
51
+ latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
52
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
53
+ scaling_factor (`float`, *optional*, defaults to 0.18215):
54
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
55
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
56
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
57
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
58
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
59
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
60
+ force_upcast (`bool`, *optional*, defaults to `True`):
61
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
62
+ can be fine-tuned / trained to a lower range without losing too much precision in which case
63
+ `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
64
+ """
65
+
66
+ _supports_gradient_checkpointing = True
67
+
68
+ @register_to_config
69
+ def __init__(
70
+ self,
71
+ in_channels: int = 3,
72
+ out_channels: int = 3,
73
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
74
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
75
+ block_out_channels: Tuple[int] = (64,),
76
+ layers_per_block: int = 1,
77
+ act_fn: str = "silu",
78
+ latent_channels: int = 4,
79
+ norm_num_groups: int = 32,
80
+ sample_size: int = 32,
81
+ scaling_factor: float = 0.18215,
82
+ force_upcast: bool = True,
83
+ ):
84
+ super().__init__()
85
+
86
+ # pass init params to Encoder
87
+ self.encoder = Encoder(
88
+ in_channels=in_channels,
89
+ out_channels=latent_channels,
90
+ down_block_types=down_block_types,
91
+ block_out_channels=block_out_channels,
92
+ layers_per_block=layers_per_block,
93
+ act_fn=act_fn,
94
+ norm_num_groups=norm_num_groups,
95
+ double_z=True,
96
+ )
97
+
98
+ # pass init params to Decoder
99
+ self.decoder = Decoder(
100
+ in_channels=latent_channels,
101
+ out_channels=out_channels,
102
+ up_block_types=up_block_types,
103
+ block_out_channels=block_out_channels,
104
+ layers_per_block=layers_per_block,
105
+ norm_num_groups=norm_num_groups,
106
+ act_fn=act_fn,
107
+ )
108
+
109
+ self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
110
+ self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
111
+
112
+ self.use_slicing = False
113
+ self.use_tiling = False
114
+
115
+ # only relevant if vae tiling is enabled
116
+ self.tile_sample_min_size = self.config.sample_size
117
+ sample_size = (
118
+ self.config.sample_size[0]
119
+ if isinstance(self.config.sample_size, (list, tuple))
120
+ else self.config.sample_size
121
+ )
122
+ self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
123
+ self.tile_overlap_factor = 0.25
124
+
125
+ def _set_gradient_checkpointing(self, module, value=False):
126
+ if isinstance(module, (Encoder, Decoder)):
127
+ module.gradient_checkpointing = value
128
+
129
+ def enable_tiling(self, use_tiling: bool = True):
130
+ r"""
131
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
132
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and for
133
+ processing larger images.
134
+ """
135
+ self.use_tiling = use_tiling
136
+
137
+ def disable_tiling(self):
138
+ r"""
139
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
140
+ decoding in one step.
141
+ """
142
+ self.enable_tiling(False)
143
+
144
+ def enable_slicing(self):
145
+ r"""
146
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
147
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
148
+ """
149
+ self.use_slicing = True
150
+
151
+ def disable_slicing(self):
152
+ r"""
153
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
154
+ decoding in one step.
155
+ """
156
+ self.use_slicing = False
157
+
158
+ @property
159
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
160
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
161
+ r"""
162
+ Returns:
163
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
164
+ indexed by its weight name.
165
+ """
166
+ # set recursively
167
+ processors = {}
168
+
169
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
170
+ if hasattr(module, "get_processor"):
171
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
172
+
173
+ for sub_name, child in module.named_children():
174
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
175
+
176
+ return processors
177
+
178
+ for name, module in self.named_children():
179
+ fn_recursive_add_processors(name, module, processors)
180
+
181
+ return processors
182
+
183
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
184
+ def set_attn_processor(
185
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
186
+ ):
187
+ r"""
188
+ Sets the attention processor to use to compute attention.
189
+
190
+ Parameters:
191
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
192
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
193
+ for **all** `Attention` layers.
194
+
195
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
196
+ processor. This is strongly recommended when setting trainable attention processors.
197
+
198
+ """
199
+ count = len(self.attn_processors.keys())
200
+
201
+ if isinstance(processor, dict) and len(processor) != count:
202
+ raise ValueError(
203
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
204
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
205
+ )
206
+
207
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
208
+ if hasattr(module, "set_processor"):
209
+ if not isinstance(processor, dict):
210
+ module.set_processor(processor, _remove_lora=_remove_lora)
211
+ else:
212
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
213
+
214
+ for sub_name, child in module.named_children():
215
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
216
+
217
+ for name, module in self.named_children():
218
+ fn_recursive_attn_processor(name, module, processor)
219
+
220
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
221
+ def set_default_attn_processor(self):
222
+ """
223
+ Disables custom attention processors and sets the default attention implementation.
224
+ """
225
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
226
+ processor = AttnAddedKVProcessor()
227
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
228
+ processor = AttnProcessor()
229
+ else:
230
+ raise ValueError(
231
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
232
+ )
233
+
234
+ self.set_attn_processor(processor, _remove_lora=True)
235
+
236
+ @apply_forward_hook
237
+ def encode(
238
+ self, x: torch.FloatTensor, return_dict: bool = True
239
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
240
+ """
241
+ Encode a batch of images into latents.
242
+
243
+ Args:
244
+ x (`torch.FloatTensor`): Input batch of images.
245
+ return_dict (`bool`, *optional*, defaults to `True`):
246
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
247
+
248
+ Returns:
249
+ The latent representations of the encoded images. If `return_dict` is True, a
250
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
251
+ """
252
+ if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
253
+ return self.tiled_encode(x, return_dict=return_dict)
254
+
255
+ if self.use_slicing and x.shape[0] > 1:
256
+ encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
257
+ h = torch.cat(encoded_slices)
258
+ else:
259
+ h = self.encoder(x)
260
+
261
+ moments = self.quant_conv(h)
262
+ posterior = DiagonalGaussianDistribution(moments)
263
+
264
+ if not return_dict:
265
+ return (posterior,)
266
+
267
+ return AutoencoderKLOutput(latent_dist=posterior)
268
+
269
+ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
270
+ if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
271
+ return self.tiled_decode(z, return_dict=return_dict)
272
+
273
+ z = self.post_quant_conv(z)
274
+ dec = self.decoder(z)
275
+
276
+ if not return_dict:
277
+ return (dec,)
278
+
279
+ return DecoderOutput(sample=dec)
280
+
281
+ @apply_forward_hook
282
+ def decode(
283
+ self, z: torch.FloatTensor, return_dict: bool = True, generator=None
284
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
285
+ """
286
+ Decode a batch of images.
287
+
288
+ Args:
289
+ z (`torch.FloatTensor`): Input batch of latent vectors.
290
+ return_dict (`bool`, *optional*, defaults to `True`):
291
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
292
+
293
+ Returns:
294
+ [`~models.vae.DecoderOutput`] or `tuple`:
295
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
296
+ returned.
297
+
298
+ """
299
+ if self.use_slicing and z.shape[0] > 1:
300
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
301
+ decoded = torch.cat(decoded_slices)
302
+ else:
303
+ decoded = self._decode(z).sample
304
+
305
+ if not return_dict:
306
+ return (decoded,)
307
+
308
+ return DecoderOutput(sample=decoded)
309
+
310
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
311
+ blend_extent = min(a.shape[2], b.shape[2], blend_extent)
312
+ for y in range(blend_extent):
313
+ b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
314
+ return b
315
+
316
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
317
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
318
+ for x in range(blend_extent):
319
+ b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
320
+ return b
321
+
322
+ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
323
+ r"""Encode a batch of images using a tiled encoder.
324
+
325
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
326
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
327
+ different from non-tiled encoding because each tile is encoded independently. To avoid tiling artifacts, the
328
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
329
+ output, but they should be much less noticeable.
330
+
331
+ Args:
332
+ x (`torch.FloatTensor`): Input batch of images.
333
+ return_dict (`bool`, *optional*, defaults to `True`):
334
+ Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
335
+
336
+ Returns:
337
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
338
+ If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
339
+ `tuple` is returned.
340
+ """
341
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
342
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
343
+ row_limit = self.tile_latent_min_size - blend_extent
344
+
345
+ # Split the image into 512x512 tiles and encode them separately.
346
+ rows = []
347
+ for i in range(0, x.shape[2], overlap_size):
348
+ row = []
349
+ for j in range(0, x.shape[3], overlap_size):
350
+ tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
351
+ tile = self.encoder(tile)
352
+ tile = self.quant_conv(tile)
353
+ row.append(tile)
354
+ rows.append(row)
355
+ result_rows = []
356
+ for i, row in enumerate(rows):
357
+ result_row = []
358
+ for j, tile in enumerate(row):
359
+ # blend the above tile and the left tile
360
+ # to the current tile and add the current tile to the result row
361
+ if i > 0:
362
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
363
+ if j > 0:
364
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
365
+ result_row.append(tile[:, :, :row_limit, :row_limit])
366
+ result_rows.append(torch.cat(result_row, dim=3))
367
+
368
+ moments = torch.cat(result_rows, dim=2)
369
+ posterior = DiagonalGaussianDistribution(moments)
370
+
371
+ if not return_dict:
372
+ return (posterior,)
373
+
374
+ return AutoencoderKLOutput(latent_dist=posterior)
375
+
376
+ def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
377
+ r"""
378
+ Decode a batch of images using a tiled decoder.
379
+
380
+ Args:
381
+ z (`torch.FloatTensor`): Input batch of latent vectors.
382
+ return_dict (`bool`, *optional*, defaults to `True`):
383
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
384
+
385
+ Returns:
386
+ [`~models.vae.DecoderOutput`] or `tuple`:
387
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
388
+ returned.
389
+ """
390
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
391
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
392
+ row_limit = self.tile_sample_min_size - blend_extent
393
+
394
+ # Split z into overlapping 64x64 tiles and decode them separately.
395
+ # The tiles have an overlap to avoid seams between tiles.
396
+ rows = []
397
+ for i in range(0, z.shape[2], overlap_size):
398
+ row = []
399
+ for j in range(0, z.shape[3], overlap_size):
400
+ tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
401
+ tile = self.post_quant_conv(tile)
402
+ decoded = self.decoder(tile)
403
+ row.append(decoded)
404
+ rows.append(row)
405
+ result_rows = []
406
+ for i, row in enumerate(rows):
407
+ result_row = []
408
+ for j, tile in enumerate(row):
409
+ # blend the above tile and the left tile
410
+ # to the current tile and add the current tile to the result row
411
+ if i > 0:
412
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
413
+ if j > 0:
414
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
415
+ result_row.append(tile[:, :, :row_limit, :row_limit])
416
+ result_rows.append(torch.cat(result_row, dim=3))
417
+
418
+ dec = torch.cat(result_rows, dim=2)
419
+ if not return_dict:
420
+ return (dec,)
421
+
422
+ return DecoderOutput(sample=dec)
423
+
424
+ def forward(
425
+ self,
426
+ sample: torch.FloatTensor,
427
+ sample_posterior: bool = False,
428
+ return_dict: bool = True,
429
+ generator: Optional[torch.Generator] = None,
430
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
431
+ r"""
432
+ Args:
433
+ sample (`torch.FloatTensor`): Input sample.
434
+ sample_posterior (`bool`, *optional*, defaults to `False`):
435
+ Whether to sample from the posterior.
436
+ return_dict (`bool`, *optional*, defaults to `True`):
437
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
438
+ """
439
+ x = sample
440
+ posterior = self.encode(x).latent_dist
441
+ if sample_posterior:
442
+ z = posterior.sample(generator=generator)
443
+ else:
444
+ z = posterior.mode()
445
+ dec = self.decode(z).sample
446
+
447
+ if not return_dict:
448
+ return (dec,)
449
+
450
+ return DecoderOutput(sample=dec)
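
Note: a minimal usage sketch of the vendored AutoencoderKL above, for orientation only. The checkpoint name, subfolder, and input shape are illustrative assumptions and are not part of this commit.

import torch
from diffuserss.models.autoencoder_kl import AutoencoderKL

# Illustrative checkpoint; any Stable Diffusion style repo with a "vae" subfolder would work.
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
vae.eval()

image = torch.randn(1, 3, 512, 512)  # stand-in for a preprocessed image in [-1, 1]

with torch.no_grad():
    # encode() returns an AutoencoderKLOutput whose latent_dist is a DiagonalGaussianDistribution
    latents = vae.encode(image).latent_dist.sample()
    latents = latents * vae.config.scaling_factor                    # z = z * scaling_factor
    recon = vae.decode(latents / vae.config.scaling_factor).sample   # undo the scaling before decoding

# Optional memory savers defined above:
# vae.enable_tiling()   # tile large inputs and blend the overlapping seams
# vae.enable_slicing()  # decode the batch one sample at a time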
diffuserss/schedulers/scheduling_ddim.py ADDED
@@ -0,0 +1,519 @@
1
+
2
+ # Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
17
+ # and https://github.com/hojonathanho/diffusion
18
+
19
+ import math
20
+ from dataclasses import dataclass
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import numpy as np
24
+ import torch
25
+
26
+ from ..configuration_utils import ConfigMixin, register_to_config
27
+ from ..utils import BaseOutput
28
+ from ..utils.torch_utils import randn_tensor
29
+ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
30
+
31
+
32
+ @dataclass
33
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
34
+ class DDIMSchedulerOutput(BaseOutput):
35
+ """
36
+ Output class for the scheduler's `step` function output.
37
+
38
+ Args:
39
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
40
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
41
+ denoising loop.
42
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
43
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
44
+ `pred_original_sample` can be used to preview progress or for guidance.
45
+ """
46
+
47
+ prev_sample: torch.FloatTensor
48
+ pred_original_sample: Optional[torch.FloatTensor] = None
49
+
50
+
51
+ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
52
+ def betas_for_alpha_bar(
53
+ num_diffusion_timesteps,
54
+ max_beta=0.999,
55
+ alpha_transform_type="cosine",
56
+ ):
57
+ """
58
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
59
+ (1-beta) over time from t = [0,1].
60
+
61
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
62
+ to that part of the diffusion process.
63
+
64
+
65
+ Args:
66
+ num_diffusion_timesteps (`int`): the number of betas to produce.
67
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
68
+ prevent singularities.
69
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
70
+ Choose from `cosine` or `exp`
71
+
72
+ Returns:
73
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
74
+ """
75
+ if alpha_transform_type == "cosine":
76
+
77
+ def alpha_bar_fn(t):
78
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
79
+
80
+ elif alpha_transform_type == "exp":
81
+
82
+ def alpha_bar_fn(t):
83
+ return math.exp(t * -12.0)
84
+
85
+ else:
86
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
87
+
88
+ betas = []
89
+ for i in range(num_diffusion_timesteps):
90
+ t1 = i / num_diffusion_timesteps
91
+ t2 = (i + 1) / num_diffusion_timesteps
92
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
93
+ return torch.tensor(betas, dtype=torch.float32)
94
+
95
+
96
+ def rescale_zero_terminal_snr(betas):
97
+ """
98
+ Rescales betas to have zero terminal SNR, based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).
99
+
100
+
101
+ Args:
102
+ betas (`torch.FloatTensor`):
103
+ the betas that the scheduler is being initialized with.
104
+
105
+ Returns:
106
+ `torch.FloatTensor`: rescaled betas with zero terminal SNR
107
+ """
108
+ # Convert betas to alphas_bar_sqrt
109
+ alphas = 1.0 - betas
110
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
111
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
112
+
113
+ # Store old values.
114
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
115
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
116
+
117
+ # Shift so the last timestep is zero.
118
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
119
+
120
+ # Scale so the first timestep is back to the old value.
121
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
122
+
123
+ # Convert alphas_bar_sqrt to betas
124
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
125
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
126
+ alphas = torch.cat([alphas_bar[0:1], alphas])
127
+ betas = 1 - alphas
128
+
129
+ return betas
130
+
131
+
132
+ class DDIMScheduler(SchedulerMixin, ConfigMixin):
133
+ """
134
+ `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
135
+ non-Markovian guidance.
136
+
137
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
138
+ methods the library implements for all schedulers such as loading and saving.
139
+
140
+ Args:
141
+ num_train_timesteps (`int`, defaults to 1000):
142
+ The number of diffusion steps to train the model.
143
+ beta_start (`float`, defaults to 0.0001):
144
+ The starting `beta` value of inference.
145
+ beta_end (`float`, defaults to 0.02):
146
+ The final `beta` value.
147
+ beta_schedule (`str`, defaults to `"linear"`):
148
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
149
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
150
+ trained_betas (`np.ndarray`, *optional*):
151
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
152
+ clip_sample (`bool`, defaults to `True`):
153
+ Clip the predicted sample for numerical stability.
154
+ clip_sample_range (`float`, defaults to 1.0):
155
+ The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
156
+ set_alpha_to_one (`bool`, defaults to `True`):
157
+ Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
158
+ there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
159
+ otherwise it uses the alpha value at step 0.
160
+ steps_offset (`int`, defaults to 0):
161
+ An offset added to the inference steps. You can use a combination of `offset=1` and
162
+ `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
163
+ Diffusion.
164
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
165
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
166
+ `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
167
+ Video](https://imagen.research.google/video/paper.pdf) paper).
168
+ thresholding (`bool`, defaults to `False`):
169
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
170
+ as Stable Diffusion.
171
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
172
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
173
+ sample_max_value (`float`, defaults to 1.0):
174
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
175
+ timestep_spacing (`str`, defaults to `"leading"`):
176
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
177
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
178
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
179
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
180
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
181
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
182
+ """
183
+
184
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
185
+ order = 1
186
+
187
+ @register_to_config
188
+ def __init__(
189
+ self,
190
+ num_train_timesteps: int = 1000,
191
+ beta_start: float = 0.0001,
192
+ beta_end: float = 0.02,
193
+ beta_schedule: str = "linear",
194
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
195
+ clip_sample: bool = True,
196
+ set_alpha_to_one: bool = True,
197
+ steps_offset: int = 0,
198
+ prediction_type: str = "epsilon",
199
+ thresholding: bool = False,
200
+ dynamic_thresholding_ratio: float = 0.995,
201
+ clip_sample_range: float = 1.0,
202
+ sample_max_value: float = 1.0,
203
+ timestep_spacing: str = "leading",
204
+ rescale_betas_zero_snr: bool = False,
205
+ ):
206
+ if trained_betas is not None:
207
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
208
+ elif beta_schedule == "linear":
209
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
210
+ elif beta_schedule == "scaled_linear":
211
+ # this schedule is very specific to the latent diffusion model.
212
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
213
+ elif beta_schedule == "squaredcos_cap_v2":
214
+ # Glide cosine schedule
215
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
216
+ else:
217
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
218
+
219
+ # Rescale for zero SNR
220
+ if rescale_betas_zero_snr:
221
+ self.betas = rescale_zero_terminal_snr(self.betas)
222
+
223
+ self.alphas = 1.0 - self.betas
224
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
225
+
226
+ # At every step in ddim, we are looking into the previous alphas_cumprod
227
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
228
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
229
+ # whether we use the final alpha of the "non-previous" one.
230
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
231
+
232
+ # standard deviation of the initial noise distribution
233
+ self.init_noise_sigma = 1.0
234
+
235
+ # setable values
236
+ self.num_inference_steps = None
237
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
238
+
239
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
240
+ """
241
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
242
+ current timestep.
243
+
244
+ Args:
245
+ sample (`torch.FloatTensor`):
246
+ The input sample.
247
+ timestep (`int`, *optional*):
248
+ The current timestep in the diffusion chain.
249
+
250
+ Returns:
251
+ `torch.FloatTensor`:
252
+ A scaled input sample.
253
+ """
254
+ return sample
255
+
256
+ def _get_variance(self, timestep, prev_timestep):
257
+ alpha_prod_t = self.alphas_cumprod[timestep]
258
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
259
+ beta_prod_t = 1 - alpha_prod_t
260
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
261
+
262
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
263
+
264
+ return variance
265
+
266
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
267
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
268
+ """
269
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
270
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
271
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
272
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
273
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
274
+
275
+ https://arxiv.org/abs/2205.11487
276
+ """
277
+ dtype = sample.dtype
278
+ batch_size, channels, *remaining_dims = sample.shape
279
+
280
+ if dtype not in (torch.float32, torch.float64):
281
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
282
+
283
+ # Flatten sample for doing quantile calculation along each image
284
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
285
+
286
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
287
+
288
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
289
+ s = torch.clamp(
290
+ s, min=1, max=self.config.sample_max_value
291
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
292
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
293
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
294
+
295
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
296
+ sample = sample.to(dtype)
297
+
298
+ return sample
299
+
300
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
301
+ """
302
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
303
+
304
+ Args:
305
+ num_inference_steps (`int`):
306
+ The number of diffusion steps used when generating samples with a pre-trained model.
307
+ """
308
+
309
+ if num_inference_steps > self.config.num_train_timesteps:
310
+ raise ValueError(
311
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.num_train_timesteps`:"
312
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
313
+ f" maximal {self.config.num_train_timesteps} timesteps."
314
+ )
315
+
316
+ self.num_inference_steps = num_inference_steps
317
+
318
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
319
+ if self.config.timestep_spacing == "linspace":
320
+ timesteps = (
321
+ np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
322
+ .round()[::-1]
323
+ .copy()
324
+ .astype(np.int64)
325
+ )
326
+ elif self.config.timestep_spacing == "leading":
327
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
328
+ # creates integer timesteps by multiplying by ratio
329
+ # casting to int to avoid issues when num_inference_step is power of 3
330
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
331
+ timesteps += self.config.steps_offset
332
+ elif self.config.timestep_spacing == "trailing":
333
+ step_ratio = self.config.num_train_timesteps / self.num_inference_steps
334
+ # creates integer timesteps by multiplying by ratio
335
+ # casting to int to avoid issues when num_inference_step is power of 3
336
+ timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
337
+ timesteps -= 1
338
+ else:
339
+ raise ValueError(
340
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
341
+ )
342
+
343
+ self.timesteps = torch.from_numpy(timesteps).to(device)
344
+
345
+ def step(
346
+ self,
347
+ model_output: torch.FloatTensor,
348
+ timestep: int,
349
+ sample: torch.FloatTensor,
350
+ eta: float = 0.0,
351
+ use_clipped_model_output: bool = False,
352
+ generator=None,
353
+ variance_noise: Optional[torch.FloatTensor] = None,
354
+ return_dict: bool = True,
355
+ ) -> Union[DDIMSchedulerOutput, Tuple]:
356
+ """
357
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
358
+ process from the learned model outputs (most often the predicted noise).
359
+
360
+ Args:
361
+ model_output (`torch.FloatTensor`):
362
+ The direct output from learned diffusion model.
363
+ timestep (`int`):
364
+ The current discrete timestep in the diffusion chain.
365
+ sample (`torch.FloatTensor`):
366
+ A current instance of a sample created by the diffusion process.
367
+ eta (`float`):
368
+ The weight of the added noise in a diffusion step.
369
+ use_clipped_model_output (`bool`, defaults to `False`):
370
+ If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
371
+ because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
372
+ clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
373
+ `use_clipped_model_output` has no effect.
374
+ generator (`torch.Generator`, *optional*):
375
+ A random number generator.
376
+ variance_noise (`torch.FloatTensor`):
377
+ Alternative to generating noise with `generator` by directly providing the noise for the variance
378
+ itself. Useful for methods such as [`CycleDiffusion`].
379
+ return_dict (`bool`, *optional*, defaults to `True`):
380
+ Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.
381
+
382
+ Returns:
383
+ [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
384
+ If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
385
+ tuple is returned where the first element is the sample tensor.
386
+
387
+ """
388
+ if self.num_inference_steps is None:
389
+ raise ValueError(
390
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
391
+ )
392
+
393
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
394
+ # Ideally, read the DDIM paper for a detailed understanding
395
+
396
+ # Notation (<variable name> -> <name in paper>)
397
+ # - pred_noise_t -> e_theta(x_t, t)
398
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
399
+ # - std_dev_t -> sigma_t
400
+ # - eta -> η
401
+ # - pred_sample_direction -> "direction pointing to x_t"
402
+ # - pred_prev_sample -> "x_t-1"
403
+
404
+ # 1. get previous step value (=t-1)
405
+ prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
406
+
407
+ # 2. compute alphas, betas
408
+ alpha_prod_t = self.alphas_cumprod[timestep]
409
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
410
+
411
+ beta_prod_t = 1 - alpha_prod_t
412
+
413
+ # 3. compute predicted original sample from predicted noise also called
414
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
415
+ if self.config.prediction_type == "epsilon":
416
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
417
+ pred_epsilon = model_output
418
+ elif self.config.prediction_type == "sample":
419
+ pred_original_sample = model_output
420
+ pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
421
+ elif self.config.prediction_type == "v_prediction":
422
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
423
+ pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
424
+ else:
425
+ raise ValueError(
426
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
427
+ " `v_prediction`"
428
+ )
429
+
430
+ # 4. Clip or threshold "predicted x_0"
431
+ if self.config.thresholding:
432
+ pred_original_sample = self._threshold_sample(pred_original_sample)
433
+ elif self.config.clip_sample:
434
+ pred_original_sample = pred_original_sample.clamp(
435
+ -self.config.clip_sample_range, self.config.clip_sample_range
436
+ )
437
+
438
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
439
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
440
+ variance = self._get_variance(timestep, prev_timestep)
441
+ std_dev_t = eta * variance ** (0.5)
442
+
443
+ if use_clipped_model_output:
444
+ # the pred_epsilon is always re-derived from the clipped x_0 in Glide
445
+ pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
446
+
447
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
448
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
449
+
450
+ # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
451
+ prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
452
+
453
+ if eta > 0:
454
+ if variance_noise is not None and generator is not None:
455
+ raise ValueError(
456
+ "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
457
+ " `variance_noise` stays `None`."
458
+ )
459
+
460
+ if variance_noise is None:
461
+ variance_noise = randn_tensor(
462
+ model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
463
+ )
464
+ variance = std_dev_t * variance_noise
465
+
466
+ prev_sample = prev_sample + variance
467
+
468
+ if not return_dict:
469
+ return (prev_sample,)
470
+
471
+ return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
472
+
473
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
474
+ def add_noise(
475
+ self,
476
+ original_samples: torch.FloatTensor,
477
+ noise: torch.FloatTensor,
478
+ timesteps: torch.IntTensor,
479
+ ) -> torch.FloatTensor:
480
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
481
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
482
+ timesteps = timesteps.to(original_samples.device)
483
+
484
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
485
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
486
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
487
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
488
+
489
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
490
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
491
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
492
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
493
+
494
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
495
+ return noisy_samples
496
+
497
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
498
+ def get_velocity(
499
+ self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
500
+ ) -> torch.FloatTensor:
501
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
502
+ alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
503
+ timesteps = timesteps.to(sample.device)
504
+
505
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
506
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
507
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
508
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
509
+
510
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
511
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
512
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
513
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
514
+
515
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
516
+ return velocity
517
+
518
+ def __len__(self):
519
+ return self.config.num_train_timesteps
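
Note: a minimal DDIM sampling-loop sketch using the vendored DDIMScheduler above, for orientation only. The denoiser is a dummy stand-in for a trained UNet; shapes and hyperparameters are illustrative assumptions, not part of this commit.

import torch
from diffuserss.schedulers.scheduling_ddim import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000, beta_schedule="linear")
scheduler.set_timesteps(num_inference_steps=50)

sample = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma  # start from Gaussian noise

def dummy_denoiser(x, t):
    # Stand-in for a trained epsilon-predicting UNet; returns a dummy noise estimate.
    return torch.zeros_like(x)

for t in scheduler.timesteps:
    model_input = scheduler.scale_model_input(sample, t)  # identity for DDIM, kept for API parity
    noise_pred = dummy_denoiser(model_input, t)           # a real pipeline calls its UNet here
    sample = scheduler.step(noise_pred, t, sample, eta=0.0).prev_sample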
inference_bokehK.py CHANGED
@@ -11,7 +11,12 @@ from pathlib import Path
  from omegaconf import OmegaConf
  from torch.utils.data import Dataset
  from transformers import CLIPTextModel, CLIPTokenizer
- from diffusers import AutoencoderKL, DDIMScheduler
+
+
+ from diffuserss.models.autoencoder_kl import AutoencoderKL
+ from diffuserss.schedulers.scheduling_ddim import DDIMScheduler
+
+
  from einops import rearrange
 
  from genphoto.pipelines.pipeline_animation import GenPhotoPipeline
inference_color_temperature.py CHANGED
@@ -11,7 +11,13 @@ from pathlib import Path
  from omegaconf import OmegaConf
  from torch.utils.data import Dataset
  from transformers import CLIPTextModel, CLIPTokenizer
- from diffusers import AutoencoderKL, DDIMScheduler
+
+
+ from diffuserss.models.autoencoder_kl import AutoencoderKL
+ from diffuserss.schedulers.scheduling_ddim import DDIMScheduler
+
+
+
  from einops import rearrange
 
  from genphoto.pipelines.pipeline_animation import GenPhotoPipeline
inference_focal_length.py CHANGED
@@ -11,7 +11,14 @@ from pathlib import Path
  from omegaconf import OmegaConf
  from torch.utils.data import Dataset
  from transformers import CLIPTextModel, CLIPTokenizer
- from diffusers import AutoencoderKL, DDIMScheduler
+
+
+
+ from diffuserss.models.autoencoder_kl import AutoencoderKL
+ from diffuserss.schedulers.scheduling_ddim import DDIMScheduler
+
+
+
  from einops import rearrange
 
  from genphoto.pipelines.pipeline_animation import GenPhotoPipeline
inference_shutter_speed.py CHANGED
@@ -11,7 +11,11 @@ from pathlib import Path
  from omegaconf import OmegaConf
  from torch.utils.data import Dataset
  from transformers import CLIPTextModel, CLIPTokenizer
- from diffusers import AutoencoderKL, DDIMScheduler
+
+ from diffuserss.models.autoencoder_kl import AutoencoderKL
+ from diffuserss.schedulers.scheduling_ddim import DDIMScheduler
+
+
  from einops import rearrange
 
  from genphoto.pipelines.pipeline_animation import GenPhotoPipeline
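
Note: with these changes, all four inference scripts resolve AutoencoderKL and DDIMScheduler from the vendored diffuserss package instead of the installed diffusers distribution. Below is a hypothetical sketch of how the swapped components are typically constructed inside such a script; the checkpoint path and scheduler hyperparameters are assumptions, not taken from this commit.

from transformers import CLIPTextModel, CLIPTokenizer
from diffuserss.models.autoencoder_kl import AutoencoderKL
from diffuserss.schedulers.scheduling_ddim import DDIMScheduler

pretrained_model_path = "runwayml/stable-diffusion-v1-5"  # illustrative checkpoint

tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")

# Typical Stable Diffusion DDIM settings; the repo's own config may differ.
scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    steps_offset=1,
)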