ouclxy commited on
Commit
f5d6fe2
·
verified ·
1 Parent(s): 1cb8a65

Upload 321 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. diffusers/__init__.py +734 -0
  2. diffusers/commands/__init__.py +27 -0
  3. diffusers/commands/diffusers_cli.py +43 -0
  4. diffusers/commands/env.py +84 -0
  5. diffusers/commands/fp16_safetensors.py +133 -0
  6. diffusers/configuration_utils.py +694 -0
  7. diffusers/dependency_versions_check.py +35 -0
  8. diffusers/dependency_versions_table.py +46 -0
  9. diffusers/experimental/README.md +5 -0
  10. diffusers/experimental/__init__.py +1 -0
  11. diffusers/experimental/rl/__init__.py +1 -0
  12. diffusers/experimental/rl/value_guided_sampling.py +154 -0
  13. diffusers/image_processor.py +476 -0
  14. diffusers/loaders.py +0 -0
  15. diffusers/models/README.md +3 -0
  16. diffusers/models/__init__.py +77 -0
  17. diffusers/models/activations.py +120 -0
  18. diffusers/models/adapter.py +584 -0
  19. diffusers/models/attention.py +398 -0
  20. diffusers/models/attention_flax.py +486 -0
  21. diffusers/models/attention_processor.py +2020 -0
  22. diffusers/models/autoencoder_asym_kl.py +181 -0
  23. diffusers/models/autoencoder_kl.py +465 -0
  24. diffusers/models/autoencoder_tiny.py +349 -0
  25. diffusers/models/consistency_decoder_vae.py +430 -0
  26. diffusers/models/controlnet.py +844 -0
  27. diffusers/models/controlnet_flax.py +394 -0
  28. diffusers/models/dual_transformer_2d.py +155 -0
  29. diffusers/models/embeddings.py +792 -0
  30. diffusers/models/embeddings_flax.py +95 -0
  31. diffusers/models/lora.py +304 -0
  32. diffusers/models/modeling_flax_pytorch_utils.py +134 -0
  33. diffusers/models/modeling_flax_utils.py +560 -0
  34. diffusers/models/modeling_pytorch_flax_utils.py +161 -0
  35. diffusers/models/modeling_utils.py +1158 -0
  36. diffusers/models/normalization.py +148 -0
  37. diffusers/models/prior_transformer.py +382 -0
  38. diffusers/models/resnet.py +1037 -0
  39. diffusers/models/resnet_flax.py +124 -0
  40. diffusers/models/t5_film_transformer.py +438 -0
  41. diffusers/models/transformer_2d.py +442 -0
  42. diffusers/models/transformer_temporal.py +197 -0
  43. diffusers/models/unet_1d.py +255 -0
  44. diffusers/models/unet_1d_blocks.py +702 -0
  45. diffusers/models/unet_2d.py +346 -0
  46. diffusers/models/unet_2d_blocks.py +0 -0
  47. diffusers/models/unet_2d_blocks_flax.py +395 -0
  48. diffusers/models/unet_2d_condition.py +1163 -0
  49. diffusers/models/unet_2d_condition_flax.py +444 -0
  50. diffusers/models/unet_3d_blocks.py +1611 -0
diffusers/__init__.py ADDED
@@ -0,0 +1,734 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __version__ = "0.23.1"
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ from .utils import (
6
+ DIFFUSERS_SLOW_IMPORT,
7
+ OptionalDependencyNotAvailable,
8
+ _LazyModule,
9
+ is_flax_available,
10
+ is_k_diffusion_available,
11
+ is_librosa_available,
12
+ is_note_seq_available,
13
+ is_onnx_available,
14
+ is_scipy_available,
15
+ is_torch_available,
16
+ is_torchsde_available,
17
+ is_transformers_available,
18
+ )
19
+
20
+
21
+ # Lazy Import based on
22
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/__init__.py
23
+
24
+ # When adding a new object to this init, please add it to `_import_structure`. The `_import_structure` is a dictionary submodule to list of object names,
25
+ # and is used to defer the actual importing for when the objects are requested.
26
+ # This way `import diffusers` provides the names in the namespace without actually importing anything (and especially none of the backends).
27
+
28
+ _import_structure = {
29
+ "configuration_utils": ["ConfigMixin"],
30
+ "models": [],
31
+ "pipelines": [],
32
+ "schedulers": [],
33
+ "utils": [
34
+ "OptionalDependencyNotAvailable",
35
+ "is_flax_available",
36
+ "is_inflect_available",
37
+ "is_invisible_watermark_available",
38
+ "is_k_diffusion_available",
39
+ "is_k_diffusion_version",
40
+ "is_librosa_available",
41
+ "is_note_seq_available",
42
+ "is_onnx_available",
43
+ "is_scipy_available",
44
+ "is_torch_available",
45
+ "is_torchsde_available",
46
+ "is_transformers_available",
47
+ "is_transformers_version",
48
+ "is_unidecode_available",
49
+ "logging",
50
+ ],
51
+ }
52
+
53
+ try:
54
+ if not is_onnx_available():
55
+ raise OptionalDependencyNotAvailable()
56
+ except OptionalDependencyNotAvailable:
57
+ from .utils import dummy_onnx_objects # noqa F403
58
+
59
+ _import_structure["utils.dummy_onnx_objects"] = [
60
+ name for name in dir(dummy_onnx_objects) if not name.startswith("_")
61
+ ]
62
+
63
+ else:
64
+ _import_structure["pipelines"].extend(["OnnxRuntimeModel"])
65
+
66
+ try:
67
+ if not is_torch_available():
68
+ raise OptionalDependencyNotAvailable()
69
+ except OptionalDependencyNotAvailable:
70
+ from .utils import dummy_pt_objects # noqa F403
71
+
72
+ _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
73
+
74
+ else:
75
+ _import_structure["models"].extend(
76
+ [
77
+ "AsymmetricAutoencoderKL",
78
+ "AutoencoderKL",
79
+ "AutoencoderTiny",
80
+ "ConsistencyDecoderVAE",
81
+ "ControlNetModel",
82
+ "ModelMixin",
83
+ "MotionAdapter",
84
+ "MultiAdapter",
85
+ "PriorTransformer",
86
+ "T2IAdapter",
87
+ "T5FilmDecoder",
88
+ "Transformer2DModel",
89
+ "UNet1DModel",
90
+ "UNet2DConditionModel",
91
+ "UNet2DModel",
92
+ "UNet3DConditionModel",
93
+ "UNetMotionModel",
94
+ "VQModel",
95
+ ]
96
+ )
97
+ _import_structure["optimization"] = [
98
+ "get_constant_schedule",
99
+ "get_constant_schedule_with_warmup",
100
+ "get_cosine_schedule_with_warmup",
101
+ "get_cosine_with_hard_restarts_schedule_with_warmup",
102
+ "get_linear_schedule_with_warmup",
103
+ "get_polynomial_decay_schedule_with_warmup",
104
+ "get_scheduler",
105
+ ]
106
+
107
+ _import_structure["pipelines"].extend(
108
+ [
109
+ "AudioPipelineOutput",
110
+ "AutoPipelineForImage2Image",
111
+ "AutoPipelineForInpainting",
112
+ "AutoPipelineForText2Image",
113
+ "ConsistencyModelPipeline",
114
+ "DanceDiffusionPipeline",
115
+ "DDIMPipeline",
116
+ "DDPMPipeline",
117
+ "DiffusionPipeline",
118
+ "DiTPipeline",
119
+ "ImagePipelineOutput",
120
+ "KarrasVePipeline",
121
+ "LDMPipeline",
122
+ "LDMSuperResolutionPipeline",
123
+ "PNDMPipeline",
124
+ "RePaintPipeline",
125
+ "ScoreSdeVePipeline",
126
+ ]
127
+ )
128
+ _import_structure["schedulers"].extend(
129
+ [
130
+ "CMStochasticIterativeScheduler",
131
+ "DDIMInverseScheduler",
132
+ "DDIMParallelScheduler",
133
+ "DDIMScheduler",
134
+ "DDPMParallelScheduler",
135
+ "DDPMScheduler",
136
+ "DDPMWuerstchenScheduler",
137
+ "DEISMultistepScheduler",
138
+ "DPMSolverMultistepInverseScheduler",
139
+ "DPMSolverMultistepScheduler",
140
+ "DPMSolverSinglestepScheduler",
141
+ "EulerAncestralDiscreteScheduler",
142
+ "EulerDiscreteScheduler",
143
+ "HeunDiscreteScheduler",
144
+ "IPNDMScheduler",
145
+ "KarrasVeScheduler",
146
+ "KDPM2AncestralDiscreteScheduler",
147
+ "KDPM2DiscreteScheduler",
148
+ "LCMScheduler",
149
+ "PNDMScheduler",
150
+ "RePaintScheduler",
151
+ "SchedulerMixin",
152
+ "ScoreSdeVeScheduler",
153
+ "UnCLIPScheduler",
154
+ "UniPCMultistepScheduler",
155
+ "VQDiffusionScheduler",
156
+ ]
157
+ )
158
+ _import_structure["training_utils"] = ["EMAModel"]
159
+
160
+ try:
161
+ if not (is_torch_available() and is_scipy_available()):
162
+ raise OptionalDependencyNotAvailable()
163
+ except OptionalDependencyNotAvailable:
164
+ from .utils import dummy_torch_and_scipy_objects # noqa F403
165
+
166
+ _import_structure["utils.dummy_torch_and_scipy_objects"] = [
167
+ name for name in dir(dummy_torch_and_scipy_objects) if not name.startswith("_")
168
+ ]
169
+
170
+ else:
171
+ _import_structure["schedulers"].extend(["LMSDiscreteScheduler"])
172
+
173
+ try:
174
+ if not (is_torch_available() and is_torchsde_available()):
175
+ raise OptionalDependencyNotAvailable()
176
+ except OptionalDependencyNotAvailable:
177
+ from .utils import dummy_torch_and_torchsde_objects # noqa F403
178
+
179
+ _import_structure["utils.dummy_torch_and_torchsde_objects"] = [
180
+ name for name in dir(dummy_torch_and_torchsde_objects) if not name.startswith("_")
181
+ ]
182
+
183
+ else:
184
+ _import_structure["schedulers"].extend(["DPMSolverSDEScheduler"])
185
+
186
+ try:
187
+ if not (is_torch_available() and is_transformers_available()):
188
+ raise OptionalDependencyNotAvailable()
189
+ except OptionalDependencyNotAvailable:
190
+ from .utils import dummy_torch_and_transformers_objects # noqa F403
191
+
192
+ _import_structure["utils.dummy_torch_and_transformers_objects"] = [
193
+ name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_")
194
+ ]
195
+
196
+ else:
197
+ _import_structure["pipelines"].extend(
198
+ [
199
+ "AltDiffusionImg2ImgPipeline",
200
+ "AltDiffusionPipeline",
201
+ "AnimateDiffPipeline",
202
+ "AudioLDM2Pipeline",
203
+ "AudioLDM2ProjectionModel",
204
+ "AudioLDM2UNet2DConditionModel",
205
+ "AudioLDMPipeline",
206
+ "BlipDiffusionControlNetPipeline",
207
+ "BlipDiffusionPipeline",
208
+ "CLIPImageProjection",
209
+ "CycleDiffusionPipeline",
210
+ "IFImg2ImgPipeline",
211
+ "IFImg2ImgSuperResolutionPipeline",
212
+ "IFInpaintingPipeline",
213
+ "IFInpaintingSuperResolutionPipeline",
214
+ "IFPipeline",
215
+ "IFSuperResolutionPipeline",
216
+ "ImageTextPipelineOutput",
217
+ "KandinskyCombinedPipeline",
218
+ "KandinskyImg2ImgCombinedPipeline",
219
+ "KandinskyImg2ImgPipeline",
220
+ "KandinskyInpaintCombinedPipeline",
221
+ "KandinskyInpaintPipeline",
222
+ "KandinskyPipeline",
223
+ "KandinskyPriorPipeline",
224
+ "KandinskyV22CombinedPipeline",
225
+ "KandinskyV22ControlnetImg2ImgPipeline",
226
+ "KandinskyV22ControlnetPipeline",
227
+ "KandinskyV22Img2ImgCombinedPipeline",
228
+ "KandinskyV22Img2ImgPipeline",
229
+ "KandinskyV22InpaintCombinedPipeline",
230
+ "KandinskyV22InpaintPipeline",
231
+ "KandinskyV22Pipeline",
232
+ "KandinskyV22PriorEmb2EmbPipeline",
233
+ "KandinskyV22PriorPipeline",
234
+ "LatentConsistencyModelImg2ImgPipeline",
235
+ "LatentConsistencyModelPipeline",
236
+ "LDMTextToImagePipeline",
237
+ "MusicLDMPipeline",
238
+ "PaintByExamplePipeline",
239
+ "PixArtAlphaPipeline",
240
+ "SemanticStableDiffusionPipeline",
241
+ "ShapEImg2ImgPipeline",
242
+ "ShapEPipeline",
243
+ "StableDiffusionAdapterPipeline",
244
+ "StableDiffusionAttendAndExcitePipeline",
245
+ "StableDiffusionControlNetImg2ImgPipeline",
246
+ "StableDiffusionControlNetInpaintPipeline",
247
+ "StableDiffusionControlNetPipeline",
248
+ "StableDiffusionDepth2ImgPipeline",
249
+ "StableDiffusionDiffEditPipeline",
250
+ "StableDiffusionGLIGENPipeline",
251
+ "StableDiffusionGLIGENTextImagePipeline",
252
+ "StableDiffusionImageVariationPipeline",
253
+ "StableDiffusionImg2ImgPipeline",
254
+ "StableDiffusionInpaintPipeline",
255
+ "StableDiffusionInpaintPipelineLegacy",
256
+ "StableDiffusionInstructPix2PixPipeline",
257
+ "StableDiffusionLatentUpscalePipeline",
258
+ "StableDiffusionLDM3DPipeline",
259
+ "StableDiffusionModelEditingPipeline",
260
+ "StableDiffusionPanoramaPipeline",
261
+ "StableDiffusionParadigmsPipeline",
262
+ "StableDiffusionPipeline",
263
+ "StableDiffusionPipelineSafe",
264
+ "StableDiffusionPix2PixZeroPipeline",
265
+ "StableDiffusionSAGPipeline",
266
+ "StableDiffusionUpscalePipeline",
267
+ "StableDiffusionXLAdapterPipeline",
268
+ "StableDiffusionXLControlNetImg2ImgPipeline",
269
+ "StableDiffusionXLControlNetInpaintPipeline",
270
+ "StableDiffusionXLControlNetPipeline",
271
+ "StableDiffusionXLImg2ImgPipeline",
272
+ "StableDiffusionXLInpaintPipeline",
273
+ "StableDiffusionXLInstructPix2PixPipeline",
274
+ "StableDiffusionXLPipeline",
275
+ "StableUnCLIPImg2ImgPipeline",
276
+ "StableUnCLIPPipeline",
277
+ "TextToVideoSDPipeline",
278
+ "TextToVideoZeroPipeline",
279
+ "UnCLIPImageVariationPipeline",
280
+ "UnCLIPPipeline",
281
+ "UniDiffuserModel",
282
+ "UniDiffuserPipeline",
283
+ "UniDiffuserTextDecoder",
284
+ "VersatileDiffusionDualGuidedPipeline",
285
+ "VersatileDiffusionImageVariationPipeline",
286
+ "VersatileDiffusionPipeline",
287
+ "VersatileDiffusionTextToImagePipeline",
288
+ "VideoToVideoSDPipeline",
289
+ "VQDiffusionPipeline",
290
+ "WuerstchenCombinedPipeline",
291
+ "WuerstchenDecoderPipeline",
292
+ "WuerstchenPriorPipeline",
293
+ ]
294
+ )
295
+
296
+ try:
297
+ if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
298
+ raise OptionalDependencyNotAvailable()
299
+ except OptionalDependencyNotAvailable:
300
+ from .utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403
301
+
302
+ _import_structure["utils.dummy_torch_and_transformers_and_k_diffusion_objects"] = [
303
+ name for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects) if not name.startswith("_")
304
+ ]
305
+
306
+ else:
307
+ _import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline"])
308
+
309
+ try:
310
+ if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
311
+ raise OptionalDependencyNotAvailable()
312
+ except OptionalDependencyNotAvailable:
313
+ from .utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403
314
+
315
+ _import_structure["utils.dummy_torch_and_transformers_and_onnx_objects"] = [
316
+ name for name in dir(dummy_torch_and_transformers_and_onnx_objects) if not name.startswith("_")
317
+ ]
318
+
319
+ else:
320
+ _import_structure["pipelines"].extend(
321
+ [
322
+ "OnnxStableDiffusionImg2ImgPipeline",
323
+ "OnnxStableDiffusionInpaintPipeline",
324
+ "OnnxStableDiffusionInpaintPipelineLegacy",
325
+ "OnnxStableDiffusionPipeline",
326
+ "OnnxStableDiffusionUpscalePipeline",
327
+ "StableDiffusionOnnxPipeline",
328
+ ]
329
+ )
330
+
331
+ try:
332
+ if not (is_torch_available() and is_librosa_available()):
333
+ raise OptionalDependencyNotAvailable()
334
+ except OptionalDependencyNotAvailable:
335
+ from .utils import dummy_torch_and_librosa_objects # noqa F403
336
+
337
+ _import_structure["utils.dummy_torch_and_librosa_objects"] = [
338
+ name for name in dir(dummy_torch_and_librosa_objects) if not name.startswith("_")
339
+ ]
340
+
341
+ else:
342
+ _import_structure["pipelines"].extend(["AudioDiffusionPipeline", "Mel"])
343
+
344
+ try:
345
+ if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
346
+ raise OptionalDependencyNotAvailable()
347
+ except OptionalDependencyNotAvailable:
348
+ from .utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403
349
+
350
+ _import_structure["utils.dummy_transformers_and_torch_and_note_seq_objects"] = [
351
+ name for name in dir(dummy_transformers_and_torch_and_note_seq_objects) if not name.startswith("_")
352
+ ]
353
+
354
+
355
+ else:
356
+ _import_structure["pipelines"].extend(["SpectrogramDiffusionPipeline"])
357
+
358
+ try:
359
+ if not is_flax_available():
360
+ raise OptionalDependencyNotAvailable()
361
+ except OptionalDependencyNotAvailable:
362
+ from .utils import dummy_flax_objects # noqa F403
363
+
364
+ _import_structure["utils.dummy_flax_objects"] = [
365
+ name for name in dir(dummy_flax_objects) if not name.startswith("_")
366
+ ]
367
+
368
+
369
+ else:
370
+ _import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"]
371
+ _import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
372
+ _import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
373
+ _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
374
+ _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"])
375
+ _import_structure["schedulers"].extend(
376
+ [
377
+ "FlaxDDIMScheduler",
378
+ "FlaxDDPMScheduler",
379
+ "FlaxDPMSolverMultistepScheduler",
380
+ "FlaxEulerDiscreteScheduler",
381
+ "FlaxKarrasVeScheduler",
382
+ "FlaxLMSDiscreteScheduler",
383
+ "FlaxPNDMScheduler",
384
+ "FlaxSchedulerMixin",
385
+ "FlaxScoreSdeVeScheduler",
386
+ ]
387
+ )
388
+
389
+
390
+ try:
391
+ if not (is_flax_available() and is_transformers_available()):
392
+ raise OptionalDependencyNotAvailable()
393
+ except OptionalDependencyNotAvailable:
394
+ from .utils import dummy_flax_and_transformers_objects # noqa F403
395
+
396
+ _import_structure["utils.dummy_flax_and_transformers_objects"] = [
397
+ name for name in dir(dummy_flax_and_transformers_objects) if not name.startswith("_")
398
+ ]
399
+
400
+
401
+ else:
402
+ _import_structure["pipelines"].extend(
403
+ [
404
+ "FlaxStableDiffusionControlNetPipeline",
405
+ "FlaxStableDiffusionImg2ImgPipeline",
406
+ "FlaxStableDiffusionInpaintPipeline",
407
+ "FlaxStableDiffusionPipeline",
408
+ "FlaxStableDiffusionXLPipeline",
409
+ ]
410
+ )
411
+
412
+ try:
413
+ if not (is_note_seq_available()):
414
+ raise OptionalDependencyNotAvailable()
415
+ except OptionalDependencyNotAvailable:
416
+ from .utils import dummy_note_seq_objects # noqa F403
417
+
418
+ _import_structure["utils.dummy_note_seq_objects"] = [
419
+ name for name in dir(dummy_note_seq_objects) if not name.startswith("_")
420
+ ]
421
+
422
+
423
+ else:
424
+ _import_structure["pipelines"].extend(["MidiProcessor"])
425
+
426
+ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
427
+ from .configuration_utils import ConfigMixin
428
+
429
+ try:
430
+ if not is_onnx_available():
431
+ raise OptionalDependencyNotAvailable()
432
+ except OptionalDependencyNotAvailable:
433
+ from .utils.dummy_onnx_objects import * # noqa F403
434
+ else:
435
+ from .pipelines import OnnxRuntimeModel
436
+
437
+ try:
438
+ if not is_torch_available():
439
+ raise OptionalDependencyNotAvailable()
440
+ except OptionalDependencyNotAvailable:
441
+ from .utils.dummy_pt_objects import * # noqa F403
442
+ else:
443
+ from .models import (
444
+ AsymmetricAutoencoderKL,
445
+ AutoencoderKL,
446
+ AutoencoderTiny,
447
+ ConsistencyDecoderVAE,
448
+ ControlNetModel,
449
+ ModelMixin,
450
+ MotionAdapter,
451
+ MultiAdapter,
452
+ PriorTransformer,
453
+ T2IAdapter,
454
+ T5FilmDecoder,
455
+ Transformer2DModel,
456
+ UNet1DModel,
457
+ UNet2DConditionModel,
458
+ UNet2DModel,
459
+ UNet3DConditionModel,
460
+ UNetMotionModel,
461
+ VQModel,
462
+ )
463
+ from .optimization import (
464
+ get_constant_schedule,
465
+ get_constant_schedule_with_warmup,
466
+ get_cosine_schedule_with_warmup,
467
+ get_cosine_with_hard_restarts_schedule_with_warmup,
468
+ get_linear_schedule_with_warmup,
469
+ get_polynomial_decay_schedule_with_warmup,
470
+ get_scheduler,
471
+ )
472
+ from .pipelines import (
473
+ AudioPipelineOutput,
474
+ AutoPipelineForImage2Image,
475
+ AutoPipelineForInpainting,
476
+ AutoPipelineForText2Image,
477
+ BlipDiffusionControlNetPipeline,
478
+ BlipDiffusionPipeline,
479
+ CLIPImageProjection,
480
+ ConsistencyModelPipeline,
481
+ DanceDiffusionPipeline,
482
+ DDIMPipeline,
483
+ DDPMPipeline,
484
+ DiffusionPipeline,
485
+ DiTPipeline,
486
+ ImagePipelineOutput,
487
+ KarrasVePipeline,
488
+ LDMPipeline,
489
+ LDMSuperResolutionPipeline,
490
+ PNDMPipeline,
491
+ RePaintPipeline,
492
+ ScoreSdeVePipeline,
493
+ )
494
+ from .schedulers import (
495
+ CMStochasticIterativeScheduler,
496
+ DDIMInverseScheduler,
497
+ DDIMParallelScheduler,
498
+ DDIMScheduler,
499
+ DDPMParallelScheduler,
500
+ DDPMScheduler,
501
+ DDPMWuerstchenScheduler,
502
+ DEISMultistepScheduler,
503
+ DPMSolverMultistepInverseScheduler,
504
+ DPMSolverMultistepScheduler,
505
+ DPMSolverSinglestepScheduler,
506
+ EulerAncestralDiscreteScheduler,
507
+ EulerDiscreteScheduler,
508
+ HeunDiscreteScheduler,
509
+ IPNDMScheduler,
510
+ KarrasVeScheduler,
511
+ KDPM2AncestralDiscreteScheduler,
512
+ KDPM2DiscreteScheduler,
513
+ LCMScheduler,
514
+ PNDMScheduler,
515
+ RePaintScheduler,
516
+ SchedulerMixin,
517
+ ScoreSdeVeScheduler,
518
+ UnCLIPScheduler,
519
+ UniPCMultistepScheduler,
520
+ VQDiffusionScheduler,
521
+ )
522
+ from .training_utils import EMAModel
523
+
524
+ try:
525
+ if not (is_torch_available() and is_scipy_available()):
526
+ raise OptionalDependencyNotAvailable()
527
+ except OptionalDependencyNotAvailable:
528
+ from .utils.dummy_torch_and_scipy_objects import * # noqa F403
529
+ else:
530
+ from .schedulers import LMSDiscreteScheduler
531
+
532
+ try:
533
+ if not (is_torch_available() and is_torchsde_available()):
534
+ raise OptionalDependencyNotAvailable()
535
+ except OptionalDependencyNotAvailable:
536
+ from .utils.dummy_torch_and_torchsde_objects import * # noqa F403
537
+ else:
538
+ from .schedulers import DPMSolverSDEScheduler
539
+
540
+ try:
541
+ if not (is_torch_available() and is_transformers_available()):
542
+ raise OptionalDependencyNotAvailable()
543
+ except OptionalDependencyNotAvailable:
544
+ from .utils.dummy_torch_and_transformers_objects import * # noqa F403
545
+ else:
546
+ from .pipelines import (
547
+ AltDiffusionImg2ImgPipeline,
548
+ AltDiffusionPipeline,
549
+ AnimateDiffPipeline,
550
+ AudioLDM2Pipeline,
551
+ AudioLDM2ProjectionModel,
552
+ AudioLDM2UNet2DConditionModel,
553
+ AudioLDMPipeline,
554
+ CLIPImageProjection,
555
+ CycleDiffusionPipeline,
556
+ IFImg2ImgPipeline,
557
+ IFImg2ImgSuperResolutionPipeline,
558
+ IFInpaintingPipeline,
559
+ IFInpaintingSuperResolutionPipeline,
560
+ IFPipeline,
561
+ IFSuperResolutionPipeline,
562
+ ImageTextPipelineOutput,
563
+ KandinskyCombinedPipeline,
564
+ KandinskyImg2ImgCombinedPipeline,
565
+ KandinskyImg2ImgPipeline,
566
+ KandinskyInpaintCombinedPipeline,
567
+ KandinskyInpaintPipeline,
568
+ KandinskyPipeline,
569
+ KandinskyPriorPipeline,
570
+ KandinskyV22CombinedPipeline,
571
+ KandinskyV22ControlnetImg2ImgPipeline,
572
+ KandinskyV22ControlnetPipeline,
573
+ KandinskyV22Img2ImgCombinedPipeline,
574
+ KandinskyV22Img2ImgPipeline,
575
+ KandinskyV22InpaintCombinedPipeline,
576
+ KandinskyV22InpaintPipeline,
577
+ KandinskyV22Pipeline,
578
+ KandinskyV22PriorEmb2EmbPipeline,
579
+ KandinskyV22PriorPipeline,
580
+ LatentConsistencyModelImg2ImgPipeline,
581
+ LatentConsistencyModelPipeline,
582
+ LDMTextToImagePipeline,
583
+ MusicLDMPipeline,
584
+ PaintByExamplePipeline,
585
+ PixArtAlphaPipeline,
586
+ SemanticStableDiffusionPipeline,
587
+ ShapEImg2ImgPipeline,
588
+ ShapEPipeline,
589
+ StableDiffusionAdapterPipeline,
590
+ StableDiffusionAttendAndExcitePipeline,
591
+ StableDiffusionControlNetImg2ImgPipeline,
592
+ StableDiffusionControlNetInpaintPipeline,
593
+ StableDiffusionControlNetPipeline,
594
+ StableDiffusionDepth2ImgPipeline,
595
+ StableDiffusionDiffEditPipeline,
596
+ StableDiffusionGLIGENPipeline,
597
+ StableDiffusionGLIGENTextImagePipeline,
598
+ StableDiffusionImageVariationPipeline,
599
+ StableDiffusionImg2ImgPipeline,
600
+ StableDiffusionInpaintPipeline,
601
+ StableDiffusionInpaintPipelineLegacy,
602
+ StableDiffusionInstructPix2PixPipeline,
603
+ StableDiffusionLatentUpscalePipeline,
604
+ StableDiffusionLDM3DPipeline,
605
+ StableDiffusionModelEditingPipeline,
606
+ StableDiffusionPanoramaPipeline,
607
+ StableDiffusionParadigmsPipeline,
608
+ StableDiffusionPipeline,
609
+ StableDiffusionPipelineSafe,
610
+ StableDiffusionPix2PixZeroPipeline,
611
+ StableDiffusionSAGPipeline,
612
+ StableDiffusionUpscalePipeline,
613
+ StableDiffusionXLAdapterPipeline,
614
+ StableDiffusionXLControlNetImg2ImgPipeline,
615
+ StableDiffusionXLControlNetInpaintPipeline,
616
+ StableDiffusionXLControlNetPipeline,
617
+ StableDiffusionXLImg2ImgPipeline,
618
+ StableDiffusionXLInpaintPipeline,
619
+ StableDiffusionXLInstructPix2PixPipeline,
620
+ StableDiffusionXLPipeline,
621
+ StableUnCLIPImg2ImgPipeline,
622
+ StableUnCLIPPipeline,
623
+ TextToVideoSDPipeline,
624
+ TextToVideoZeroPipeline,
625
+ UnCLIPImageVariationPipeline,
626
+ UnCLIPPipeline,
627
+ UniDiffuserModel,
628
+ UniDiffuserPipeline,
629
+ UniDiffuserTextDecoder,
630
+ VersatileDiffusionDualGuidedPipeline,
631
+ VersatileDiffusionImageVariationPipeline,
632
+ VersatileDiffusionPipeline,
633
+ VersatileDiffusionTextToImagePipeline,
634
+ VideoToVideoSDPipeline,
635
+ VQDiffusionPipeline,
636
+ WuerstchenCombinedPipeline,
637
+ WuerstchenDecoderPipeline,
638
+ WuerstchenPriorPipeline,
639
+ )
640
+
641
+ try:
642
+ if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
643
+ raise OptionalDependencyNotAvailable()
644
+ except OptionalDependencyNotAvailable:
645
+ from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
646
+ else:
647
+ from .pipelines import StableDiffusionKDiffusionPipeline
648
+
649
+ try:
650
+ if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
651
+ raise OptionalDependencyNotAvailable()
652
+ except OptionalDependencyNotAvailable:
653
+ from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
654
+ else:
655
+ from .pipelines import (
656
+ OnnxStableDiffusionImg2ImgPipeline,
657
+ OnnxStableDiffusionInpaintPipeline,
658
+ OnnxStableDiffusionInpaintPipelineLegacy,
659
+ OnnxStableDiffusionPipeline,
660
+ OnnxStableDiffusionUpscalePipeline,
661
+ StableDiffusionOnnxPipeline,
662
+ )
663
+
664
+ try:
665
+ if not (is_torch_available() and is_librosa_available()):
666
+ raise OptionalDependencyNotAvailable()
667
+ except OptionalDependencyNotAvailable:
668
+ from .utils.dummy_torch_and_librosa_objects import * # noqa F403
669
+ else:
670
+ from .pipelines import AudioDiffusionPipeline, Mel
671
+
672
+ try:
673
+ if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
674
+ raise OptionalDependencyNotAvailable()
675
+ except OptionalDependencyNotAvailable:
676
+ from .utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
677
+ else:
678
+ from .pipelines import SpectrogramDiffusionPipeline
679
+
680
+ try:
681
+ if not is_flax_available():
682
+ raise OptionalDependencyNotAvailable()
683
+ except OptionalDependencyNotAvailable:
684
+ from .utils.dummy_flax_objects import * # noqa F403
685
+ else:
686
+ from .models.controlnet_flax import FlaxControlNetModel
687
+ from .models.modeling_flax_utils import FlaxModelMixin
688
+ from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
689
+ from .models.vae_flax import FlaxAutoencoderKL
690
+ from .pipelines import FlaxDiffusionPipeline
691
+ from .schedulers import (
692
+ FlaxDDIMScheduler,
693
+ FlaxDDPMScheduler,
694
+ FlaxDPMSolverMultistepScheduler,
695
+ FlaxEulerDiscreteScheduler,
696
+ FlaxKarrasVeScheduler,
697
+ FlaxLMSDiscreteScheduler,
698
+ FlaxPNDMScheduler,
699
+ FlaxSchedulerMixin,
700
+ FlaxScoreSdeVeScheduler,
701
+ )
702
+
703
+ try:
704
+ if not (is_flax_available() and is_transformers_available()):
705
+ raise OptionalDependencyNotAvailable()
706
+ except OptionalDependencyNotAvailable:
707
+ from .utils.dummy_flax_and_transformers_objects import * # noqa F403
708
+ else:
709
+ from .pipelines import (
710
+ FlaxStableDiffusionControlNetPipeline,
711
+ FlaxStableDiffusionImg2ImgPipeline,
712
+ FlaxStableDiffusionInpaintPipeline,
713
+ FlaxStableDiffusionPipeline,
714
+ FlaxStableDiffusionXLPipeline,
715
+ )
716
+
717
+ try:
718
+ if not (is_note_seq_available()):
719
+ raise OptionalDependencyNotAvailable()
720
+ except OptionalDependencyNotAvailable:
721
+ from .utils.dummy_note_seq_objects import * # noqa F403
722
+ else:
723
+ from .pipelines import MidiProcessor
724
+
725
+ else:
726
+ import sys
727
+
728
+ sys.modules[__name__] = _LazyModule(
729
+ __name__,
730
+ globals()["__file__"],
731
+ _import_structure,
732
+ module_spec=__spec__,
733
+ extra_objects={"__version__": __version__},
734
+ )
diffusers/commands/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from abc import ABC, abstractmethod
16
+ from argparse import ArgumentParser
17
+
18
+
19
+ class BaseDiffusersCLICommand(ABC):
20
+ @staticmethod
21
+ @abstractmethod
22
+ def register_subcommand(parser: ArgumentParser):
23
+ raise NotImplementedError()
24
+
25
+ @abstractmethod
26
+ def run(self):
27
+ raise NotImplementedError()
diffusers/commands/diffusers_cli.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from argparse import ArgumentParser
17
+
18
+ from .env import EnvironmentCommand
19
+ from .fp16_safetensors import FP16SafetensorsCommand
20
+
21
+
22
+ def main():
23
+ parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
24
+ commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")
25
+
26
+ # Register commands
27
+ EnvironmentCommand.register_subcommand(commands_parser)
28
+ FP16SafetensorsCommand.register_subcommand(commands_parser)
29
+
30
+ # Let's go
31
+ args = parser.parse_args()
32
+
33
+ if not hasattr(args, "func"):
34
+ parser.print_help()
35
+ exit(1)
36
+
37
+ # Run
38
+ service = args.func(args)
39
+ service.run()
40
+
41
+
42
+ if __name__ == "__main__":
43
+ main()
diffusers/commands/env.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import platform
16
+ from argparse import ArgumentParser
17
+
18
+ import huggingface_hub
19
+
20
+ from .. import __version__ as version
21
+ from ..utils import is_accelerate_available, is_torch_available, is_transformers_available, is_xformers_available
22
+ from . import BaseDiffusersCLICommand
23
+
24
+
25
+ def info_command_factory(_):
26
+ return EnvironmentCommand()
27
+
28
+
29
+ class EnvironmentCommand(BaseDiffusersCLICommand):
30
+ @staticmethod
31
+ def register_subcommand(parser: ArgumentParser):
32
+ download_parser = parser.add_parser("env")
33
+ download_parser.set_defaults(func=info_command_factory)
34
+
35
+ def run(self):
36
+ hub_version = huggingface_hub.__version__
37
+
38
+ pt_version = "not installed"
39
+ pt_cuda_available = "NA"
40
+ if is_torch_available():
41
+ import torch
42
+
43
+ pt_version = torch.__version__
44
+ pt_cuda_available = torch.cuda.is_available()
45
+
46
+ transformers_version = "not installed"
47
+ if is_transformers_available():
48
+ import transformers
49
+
50
+ transformers_version = transformers.__version__
51
+
52
+ accelerate_version = "not installed"
53
+ if is_accelerate_available():
54
+ import accelerate
55
+
56
+ accelerate_version = accelerate.__version__
57
+
58
+ xformers_version = "not installed"
59
+ if is_xformers_available():
60
+ import xformers
61
+
62
+ xformers_version = xformers.__version__
63
+
64
+ info = {
65
+ "`diffusers` version": version,
66
+ "Platform": platform.platform(),
67
+ "Python version": platform.python_version(),
68
+ "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
69
+ "Huggingface_hub version": hub_version,
70
+ "Transformers version": transformers_version,
71
+ "Accelerate version": accelerate_version,
72
+ "xFormers version": xformers_version,
73
+ "Using GPU in script?": "<fill in>",
74
+ "Using distributed or parallel set-up in script?": "<fill in>",
75
+ }
76
+
77
+ print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
78
+ print(self.format_dict(info))
79
+
80
+ return info
81
+
82
+ @staticmethod
83
+ def format_dict(d):
84
+ return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
diffusers/commands/fp16_safetensors.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Usage example:
17
+ diffusers-cli fp16_safetensors --ckpt_id=openai/shap-e --fp16 --use_safetensors
18
+ """
19
+
20
+ import glob
21
+ import json
22
+ from argparse import ArgumentParser, Namespace
23
+ from importlib import import_module
24
+
25
+ import huggingface_hub
26
+ import torch
27
+ from huggingface_hub import hf_hub_download
28
+ from packaging import version
29
+
30
+ from ..utils import logging
31
+ from . import BaseDiffusersCLICommand
32
+
33
+
34
+ def conversion_command_factory(args: Namespace):
35
+ return FP16SafetensorsCommand(
36
+ args.ckpt_id,
37
+ args.fp16,
38
+ args.use_safetensors,
39
+ args.use_auth_token,
40
+ )
41
+
42
+
43
+ class FP16SafetensorsCommand(BaseDiffusersCLICommand):
44
+ @staticmethod
45
+ def register_subcommand(parser: ArgumentParser):
46
+ conversion_parser = parser.add_parser("fp16_safetensors")
47
+ conversion_parser.add_argument(
48
+ "--ckpt_id",
49
+ type=str,
50
+ help="Repo id of the checkpoints on which to run the conversion. Example: 'openai/shap-e'.",
51
+ )
52
+ conversion_parser.add_argument(
53
+ "--fp16", action="store_true", help="If serializing the variables in FP16 precision."
54
+ )
55
+ conversion_parser.add_argument(
56
+ "--use_safetensors", action="store_true", help="If serializing in the safetensors format."
57
+ )
58
+ conversion_parser.add_argument(
59
+ "--use_auth_token",
60
+ action="store_true",
61
+ help="When working with checkpoints having private visibility. When used `huggingface-cli login` needs to be run beforehand.",
62
+ )
63
+ conversion_parser.set_defaults(func=conversion_command_factory)
64
+
65
+ def __init__(self, ckpt_id: str, fp16: bool, use_safetensors: bool, use_auth_token: bool):
66
+ self.logger = logging.get_logger("diffusers-cli/fp16_safetensors")
67
+ self.ckpt_id = ckpt_id
68
+ self.local_ckpt_dir = f"/tmp/{ckpt_id}"
69
+ self.fp16 = fp16
70
+
71
+ self.use_safetensors = use_safetensors
72
+
73
+ if not self.use_safetensors and not self.fp16:
74
+ raise NotImplementedError(
75
+ "When `use_safetensors` and `fp16` both are False, then this command is of no use."
76
+ )
77
+
78
+ self.use_auth_token = use_auth_token
79
+
80
+ def run(self):
81
+ if version.parse(huggingface_hub.__version__) < version.parse("0.9.0"):
82
+ raise ImportError(
83
+ "The huggingface_hub version must be >= 0.9.0 to use this command. Please update your huggingface_hub"
84
+ " installation."
85
+ )
86
+ else:
87
+ from huggingface_hub import create_commit
88
+ from huggingface_hub._commit_api import CommitOperationAdd
89
+
90
+ model_index = hf_hub_download(repo_id=self.ckpt_id, filename="model_index.json", token=self.use_auth_token)
91
+ with open(model_index, "r") as f:
92
+ pipeline_class_name = json.load(f)["_class_name"]
93
+ pipeline_class = getattr(import_module("diffusers"), pipeline_class_name)
94
+ self.logger.info(f"Pipeline class imported: {pipeline_class_name}.")
95
+
96
+ # Load the appropriate pipeline. We could have use `DiffusionPipeline`
97
+ # here, but just to avoid any rough edge cases.
98
+ pipeline = pipeline_class.from_pretrained(
99
+ self.ckpt_id, torch_dtype=torch.float16 if self.fp16 else torch.float32, use_auth_token=self.use_auth_token
100
+ )
101
+ pipeline.save_pretrained(
102
+ self.local_ckpt_dir,
103
+ safe_serialization=True if self.use_safetensors else False,
104
+ variant="fp16" if self.fp16 else None,
105
+ )
106
+ self.logger.info(f"Pipeline locally saved to {self.local_ckpt_dir}.")
107
+
108
+ # Fetch all the paths.
109
+ if self.fp16:
110
+ modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.fp16.*")
111
+ elif self.use_safetensors:
112
+ modified_paths = glob.glob(f"{self.local_ckpt_dir}/*/*.safetensors")
113
+
114
+ # Prepare for the PR.
115
+ commit_message = f"Serialize variables with FP16: {self.fp16} and safetensors: {self.use_safetensors}."
116
+ operations = []
117
+ for path in modified_paths:
118
+ operations.append(CommitOperationAdd(path_in_repo="/".join(path.split("/")[4:]), path_or_fileobj=path))
119
+
120
+ # Open the PR.
121
+ commit_description = (
122
+ "Variables converted by the [`diffusers`' `fp16_safetensors`"
123
+ " CLI](https://github.com/huggingface/diffusers/blob/main/src/diffusers/commands/fp16_safetensors.py)."
124
+ )
125
+ hub_pr_url = create_commit(
126
+ repo_id=self.ckpt_id,
127
+ operations=operations,
128
+ commit_message=commit_message,
129
+ commit_description=commit_description,
130
+ repo_type="model",
131
+ create_pr=True,
132
+ ).pr_url
133
+ self.logger.info(f"PR created here: {hub_pr_url}.")
diffusers/configuration_utils.py ADDED
@@ -0,0 +1,694 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ ConfigMixin base class and utilities."""
17
+ import dataclasses
18
+ import functools
19
+ import importlib
20
+ import inspect
21
+ import json
22
+ import os
23
+ import re
24
+ from collections import OrderedDict
25
+ from pathlib import PosixPath
26
+ from typing import Any, Dict, Tuple, Union
27
+
28
+ import numpy as np
29
+ from huggingface_hub import create_repo, hf_hub_download
30
+ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
31
+ from requests import HTTPError
32
+
33
+ from . import __version__
34
+ from .utils import (
35
+ DIFFUSERS_CACHE,
36
+ HUGGINGFACE_CO_RESOLVE_ENDPOINT,
37
+ DummyObject,
38
+ deprecate,
39
+ extract_commit_hash,
40
+ http_user_agent,
41
+ logging,
42
+ )
43
+
44
+
45
+ logger = logging.get_logger(__name__)
46
+
47
+ _re_configuration_file = re.compile(r"config\.(.*)\.json")
48
+
49
+
50
+ class FrozenDict(OrderedDict):
51
+ def __init__(self, *args, **kwargs):
52
+ super().__init__(*args, **kwargs)
53
+
54
+ for key, value in self.items():
55
+ setattr(self, key, value)
56
+
57
+ self.__frozen = True
58
+
59
+ def __delitem__(self, *args, **kwargs):
60
+ raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
61
+
62
+ def setdefault(self, *args, **kwargs):
63
+ raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
64
+
65
+ def pop(self, *args, **kwargs):
66
+ raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
67
+
68
+ def update(self, *args, **kwargs):
69
+ raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
70
+
71
+ def __setattr__(self, name, value):
72
+ if hasattr(self, "__frozen") and self.__frozen:
73
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
74
+ super().__setattr__(name, value)
75
+
76
+ def __setitem__(self, name, value):
77
+ if hasattr(self, "__frozen") and self.__frozen:
78
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
79
+ super().__setitem__(name, value)
80
+
81
+
82
+ class ConfigMixin:
83
+ r"""
84
+ Base class for all configuration classes. All configuration parameters are stored under `self.config`. Also
85
+ provides the [`~ConfigMixin.from_config`] and [`~ConfigMixin.save_config`] methods for loading, downloading, and
86
+ saving classes that inherit from [`ConfigMixin`].
87
+
88
+ Class attributes:
89
+ - **config_name** (`str`) -- A filename under which the config should stored when calling
90
+ [`~ConfigMixin.save_config`] (should be overridden by parent class).
91
+ - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
92
+ overridden by subclass).
93
+ - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
94
+ - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
95
+ should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
96
+ subclass).
97
+ """
98
+ config_name = None
99
+ ignore_for_config = []
100
+ has_compatibles = False
101
+
102
+ _deprecated_kwargs = []
103
+
104
+ def register_to_config(self, **kwargs):
105
+ if self.config_name is None:
106
+ raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
107
+ # Special case for `kwargs` used in deprecation warning added to schedulers
108
+ # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
109
+ # or solve in a more general way.
110
+ kwargs.pop("kwargs", None)
111
+
112
+ if not hasattr(self, "_internal_dict"):
113
+ internal_dict = kwargs
114
+ else:
115
+ previous_dict = dict(self._internal_dict)
116
+ internal_dict = {**self._internal_dict, **kwargs}
117
+ logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
118
+
119
+ self._internal_dict = FrozenDict(internal_dict)
120
+
121
+ def __getattr__(self, name: str) -> Any:
122
+ """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
123
+ config attributes directly. See https://github.com/huggingface/diffusers/pull/3129
124
+
125
+ Tihs funtion is mostly copied from PyTorch's __getattr__ overwrite:
126
+ https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
127
+ """
128
+
129
+ is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
130
+ is_attribute = name in self.__dict__
131
+
132
+ if is_in_config and not is_attribute:
133
+ deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'."
134
+ deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
135
+ return self._internal_dict[name]
136
+
137
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
138
+
139
+ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
140
+ """
141
+ Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the
142
+ [`~ConfigMixin.from_config`] class method.
143
+
144
+ Args:
145
+ save_directory (`str` or `os.PathLike`):
146
+ Directory where the configuration JSON file is saved (will be created if it does not exist).
147
+ push_to_hub (`bool`, *optional*, defaults to `False`):
148
+ Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
149
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
150
+ namespace).
151
+ kwargs (`Dict[str, Any]`, *optional*):
152
+ Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
153
+ """
154
+ if os.path.isfile(save_directory):
155
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
156
+
157
+ os.makedirs(save_directory, exist_ok=True)
158
+
159
+ # If we save using the predefined names, we can load using `from_config`
160
+ output_config_file = os.path.join(save_directory, self.config_name)
161
+
162
+ self.to_json_file(output_config_file)
163
+ logger.info(f"Configuration saved in {output_config_file}")
164
+
165
+ if push_to_hub:
166
+ commit_message = kwargs.pop("commit_message", None)
167
+ private = kwargs.pop("private", False)
168
+ create_pr = kwargs.pop("create_pr", False)
169
+ token = kwargs.pop("token", None)
170
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
171
+ repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
172
+
173
+ self._upload_folder(
174
+ save_directory,
175
+ repo_id,
176
+ token=token,
177
+ commit_message=commit_message,
178
+ create_pr=create_pr,
179
+ )
180
+
181
+ @classmethod
182
+ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
183
+ r"""
184
+ Instantiate a Python class from a config dictionary.
185
+
186
+ Parameters:
187
+ config (`Dict[str, Any]`):
188
+ A config dictionary from which the Python class is instantiated. Make sure to only load configuration
189
+ files of compatible classes.
190
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
191
+ Whether kwargs that are not consumed by the Python class should be returned or not.
192
+ kwargs (remaining dictionary of keyword arguments, *optional*):
193
+ Can be used to update the configuration object (after it is loaded) and initiate the Python class.
194
+ `**kwargs` are passed directly to the underlying scheduler/model's `__init__` method and eventually
195
+ overwrite the same named arguments in `config`.
196
+
197
+ Returns:
198
+ [`ModelMixin`] or [`SchedulerMixin`]:
199
+ A model or scheduler object instantiated from a config dictionary.
200
+
201
+ Examples:
202
+
203
+ ```python
204
+ >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
205
+
206
+ >>> # Download scheduler from huggingface.co and cache.
207
+ >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
208
+
209
+ >>> # Instantiate DDIM scheduler class with same config as DDPM
210
+ >>> scheduler = DDIMScheduler.from_config(scheduler.config)
211
+
212
+ >>> # Instantiate PNDM scheduler class with same config as DDPM
213
+ >>> scheduler = PNDMScheduler.from_config(scheduler.config)
214
+ ```
215
+ """
216
+ # <===== TO BE REMOVED WITH DEPRECATION
217
+ # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
218
+ if "pretrained_model_name_or_path" in kwargs:
219
+ config = kwargs.pop("pretrained_model_name_or_path")
220
+
221
+ if config is None:
222
+ raise ValueError("Please make sure to provide a config as the first positional argument.")
223
+ # ======>
224
+
225
+ if not isinstance(config, dict):
226
+ deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
227
+ if "Scheduler" in cls.__name__:
228
+ deprecation_message += (
229
+ f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
230
+ " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
231
+ " be removed in v1.0.0."
232
+ )
233
+ elif "Model" in cls.__name__:
234
+ deprecation_message += (
235
+ f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
236
+ f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
237
+ " instead. This functionality will be removed in v1.0.0."
238
+ )
239
+ deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
240
+ config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)
241
+
242
+ init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
243
+
244
+ # Allow dtype to be specified on initialization
245
+ if "dtype" in unused_kwargs:
246
+ init_dict["dtype"] = unused_kwargs.pop("dtype")
247
+
248
+ # add possible deprecated kwargs
249
+ for deprecated_kwarg in cls._deprecated_kwargs:
250
+ if deprecated_kwarg in unused_kwargs:
251
+ init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
252
+
253
+ # Return model and optionally state and/or unused_kwargs
254
+ model = cls(**init_dict)
255
+
256
+ # make sure to also save config parameters that might be used for compatible classes
257
+ model.register_to_config(**hidden_dict)
258
+
259
+ # add hidden kwargs of compatible classes to unused_kwargs
260
+ unused_kwargs = {**unused_kwargs, **hidden_dict}
261
+
262
+ if return_unused_kwargs:
263
+ return (model, unused_kwargs)
264
+ else:
265
+ return model
266
+
267
+ @classmethod
268
+ def get_config_dict(cls, *args, **kwargs):
269
+ deprecation_message = (
270
+ f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
271
+ " removed in version v1.0.0"
272
+ )
273
+ deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
274
+ return cls.load_config(*args, **kwargs)
275
+
276
+ @classmethod
277
+ def load_config(
278
+ cls,
279
+ pretrained_model_name_or_path: Union[str, os.PathLike],
280
+ return_unused_kwargs=False,
281
+ return_commit_hash=False,
282
+ **kwargs,
283
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
284
+ r"""
285
+ Load a model or scheduler configuration.
286
+
287
+ Parameters:
288
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
289
+ Can be either:
290
+
291
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
292
+ the Hub.
293
+ - A path to a *directory* (for example `./my_model_directory`) containing model weights saved with
294
+ [`~ConfigMixin.save_config`].
295
+
296
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
297
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
298
+ is not used.
299
+ force_download (`bool`, *optional*, defaults to `False`):
300
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
301
+ cached versions if they exist.
302
+ resume_download (`bool`, *optional*, defaults to `False`):
303
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
304
+ incompletely downloaded files are deleted.
305
+ proxies (`Dict[str, str]`, *optional*):
306
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
307
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
308
+ output_loading_info(`bool`, *optional*, defaults to `False`):
309
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
310
+ local_files_only (`bool`, *optional*, defaults to `False`):
311
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
312
+ won't be downloaded from the Hub.
313
+ use_auth_token (`str` or *bool*, *optional*):
314
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
315
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
316
+ revision (`str`, *optional*, defaults to `"main"`):
317
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
318
+ allowed by Git.
319
+ subfolder (`str`, *optional*, defaults to `""`):
320
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
321
+ return_unused_kwargs (`bool`, *optional*, defaults to `False):
322
+ Whether unused keyword arguments of the config are returned.
323
+ return_commit_hash (`bool`, *optional*, defaults to `False):
324
+ Whether the `commit_hash` of the loaded configuration are returned.
325
+
326
+ Returns:
327
+ `dict`:
328
+ A dictionary of all the parameters stored in a JSON configuration file.
329
+
330
+ """
331
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
332
+ force_download = kwargs.pop("force_download", False)
333
+ resume_download = kwargs.pop("resume_download", False)
334
+ proxies = kwargs.pop("proxies", None)
335
+ use_auth_token = kwargs.pop("use_auth_token", None)
336
+ local_files_only = kwargs.pop("local_files_only", False)
337
+ revision = kwargs.pop("revision", None)
338
+ _ = kwargs.pop("mirror", None)
339
+ subfolder = kwargs.pop("subfolder", None)
340
+ user_agent = kwargs.pop("user_agent", {})
341
+
342
+ user_agent = {**user_agent, "file_type": "config"}
343
+ user_agent = http_user_agent(user_agent)
344
+
345
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
346
+
347
+ if cls.config_name is None:
348
+ raise ValueError(
349
+ "`self.config_name` is not defined. Note that one should not load a config from "
350
+ "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
351
+ )
352
+
353
+ if os.path.isfile(pretrained_model_name_or_path):
354
+ config_file = pretrained_model_name_or_path
355
+ elif os.path.isdir(pretrained_model_name_or_path):
356
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
357
+ # Load from a PyTorch checkpoint
358
+ config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
359
+ elif subfolder is not None and os.path.isfile(
360
+ os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
361
+ ):
362
+ config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
363
+ else:
364
+ raise EnvironmentError(
365
+ f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
366
+ )
367
+ else:
368
+ try:
369
+ # Load from URL or cache if already cached
370
+ config_file = hf_hub_download(
371
+ pretrained_model_name_or_path,
372
+ filename=cls.config_name,
373
+ cache_dir=cache_dir,
374
+ force_download=force_download,
375
+ proxies=proxies,
376
+ resume_download=resume_download,
377
+ local_files_only=local_files_only,
378
+ use_auth_token=use_auth_token,
379
+ user_agent=user_agent,
380
+ subfolder=subfolder,
381
+ revision=revision,
382
+ )
383
+ except RepositoryNotFoundError:
384
+ raise EnvironmentError(
385
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
386
+ " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
387
+ " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
388
+ " login`."
389
+ )
390
+ except RevisionNotFoundError:
391
+ raise EnvironmentError(
392
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
393
+ " this model name. Check the model page at"
394
+ f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
395
+ )
396
+ except EntryNotFoundError:
397
+ raise EnvironmentError(
398
+ f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
399
+ )
400
+ except HTTPError as err:
401
+ raise EnvironmentError(
402
+ "There was a specific connection error when trying to load"
403
+ f" {pretrained_model_name_or_path}:\n{err}"
404
+ )
405
+ except ValueError:
406
+ raise EnvironmentError(
407
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
408
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
409
+ f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
410
+ " run the library in offline mode at"
411
+ " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
412
+ )
413
+ except EnvironmentError:
414
+ raise EnvironmentError(
415
+ f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
416
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
417
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
418
+ f"containing a {cls.config_name} file"
419
+ )
420
+
421
+ try:
422
+ # Load config dict
423
+ config_dict = cls._dict_from_json_file(config_file)
424
+
425
+ commit_hash = extract_commit_hash(config_file)
426
+ except (json.JSONDecodeError, UnicodeDecodeError):
427
+ raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
428
+
429
+ if not (return_unused_kwargs or return_commit_hash):
430
+ return config_dict
431
+
432
+ outputs = (config_dict,)
433
+
434
+ if return_unused_kwargs:
435
+ outputs += (kwargs,)
436
+
437
+ if return_commit_hash:
438
+ outputs += (commit_hash,)
439
+
440
+ return outputs
441
+
442
+ @staticmethod
443
+ def _get_init_keys(cls):
444
+ return set(dict(inspect.signature(cls.__init__).parameters).keys())
445
+
446
+ @classmethod
447
+ def extract_init_dict(cls, config_dict, **kwargs):
448
+ # Skip keys that were not present in the original config, so default __init__ values were used
449
+ used_defaults = config_dict.get("_use_default_values", [])
450
+ config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}
451
+
452
+ # 0. Copy origin config dict
453
+ original_dict = dict(config_dict.items())
454
+
455
+ # 1. Retrieve expected config attributes from __init__ signature
456
+ expected_keys = cls._get_init_keys(cls)
457
+ expected_keys.remove("self")
458
+ # remove general kwargs if present in dict
459
+ if "kwargs" in expected_keys:
460
+ expected_keys.remove("kwargs")
461
+ # remove flax internal keys
462
+ if hasattr(cls, "_flax_internal_args"):
463
+ for arg in cls._flax_internal_args:
464
+ expected_keys.remove(arg)
465
+
466
+ # 2. Remove attributes that cannot be expected from expected config attributes
467
+ # remove keys to be ignored
468
+ if len(cls.ignore_for_config) > 0:
469
+ expected_keys = expected_keys - set(cls.ignore_for_config)
470
+
471
+ # load diffusers library to import compatible and original scheduler
472
+ diffusers_library = importlib.import_module(__name__.split(".")[0])
473
+
474
+ if cls.has_compatibles:
475
+ compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
476
+ else:
477
+ compatible_classes = []
478
+
479
+ expected_keys_comp_cls = set()
480
+ for c in compatible_classes:
481
+ expected_keys_c = cls._get_init_keys(c)
482
+ expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
483
+ expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
484
+ config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}
485
+
486
+ # remove attributes from orig class that cannot be expected
487
+ orig_cls_name = config_dict.pop("_class_name", cls.__name__)
488
+ if (
489
+ isinstance(orig_cls_name, str)
490
+ and orig_cls_name != cls.__name__
491
+ and hasattr(diffusers_library, orig_cls_name)
492
+ ):
493
+ orig_cls = getattr(diffusers_library, orig_cls_name)
494
+ unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
495
+ config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
496
+ elif not isinstance(orig_cls_name, str) and not isinstance(orig_cls_name, (list, tuple)):
497
+ raise ValueError(
498
+ "Make sure that the `_class_name` is of type string or list of string (for custom pipelines)."
499
+ )
500
+
501
+ # remove private attributes
502
+ config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
503
+
504
+ # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
505
+ init_dict = {}
506
+ for key in expected_keys:
507
+ # if config param is passed to kwarg and is present in config dict
508
+ # it should overwrite existing config dict key
509
+ if key in kwargs and key in config_dict:
510
+ config_dict[key] = kwargs.pop(key)
511
+
512
+ if key in kwargs:
513
+ # overwrite key
514
+ init_dict[key] = kwargs.pop(key)
515
+ elif key in config_dict:
516
+ # use value from config dict
517
+ init_dict[key] = config_dict.pop(key)
518
+
519
+ # 4. Give nice warning if unexpected values have been passed
520
+ if len(config_dict) > 0:
521
+ logger.warning(
522
+ f"The config attributes {config_dict} were passed to {cls.__name__}, "
523
+ "but are not expected and will be ignored. Please verify your "
524
+ f"{cls.config_name} configuration file."
525
+ )
526
+
527
+ # 5. Give nice info if config attributes are initialized to default because they have not been passed
528
+ passed_keys = set(init_dict.keys())
529
+ if len(expected_keys - passed_keys) > 0:
530
+ logger.info(
531
+ f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
532
+ )
533
+
534
+ # 6. Define unused keyword arguments
535
+ unused_kwargs = {**config_dict, **kwargs}
536
+
537
+ # 7. Define "hidden" config parameters that were saved for compatible classes
538
+ hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}
539
+
540
+ return init_dict, unused_kwargs, hidden_config_dict
541
+
542
+ @classmethod
543
+ def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
544
+ with open(json_file, "r", encoding="utf-8") as reader:
545
+ text = reader.read()
546
+ return json.loads(text)
547
+
548
+ def __repr__(self):
549
+ return f"{self.__class__.__name__} {self.to_json_string()}"
550
+
551
+ @property
552
+ def config(self) -> Dict[str, Any]:
553
+ """
554
+ Returns the config of the class as a frozen dictionary
555
+
556
+ Returns:
557
+ `Dict[str, Any]`: Config of the class.
558
+ """
559
+ return self._internal_dict
560
+
561
+ def to_json_string(self) -> str:
562
+ """
563
+ Serializes the configuration instance to a JSON string.
564
+
565
+ Returns:
566
+ `str`:
567
+ String containing all the attributes that make up the configuration instance in JSON format.
568
+ """
569
+ config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
570
+ config_dict["_class_name"] = self.__class__.__name__
571
+ config_dict["_diffusers_version"] = __version__
572
+
573
+ def to_json_saveable(value):
574
+ if isinstance(value, np.ndarray):
575
+ value = value.tolist()
576
+ elif isinstance(value, PosixPath):
577
+ value = str(value)
578
+ return value
579
+
580
+ config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
581
+ # Don't save "_ignore_files" or "_use_default_values"
582
+ config_dict.pop("_ignore_files", None)
583
+ config_dict.pop("_use_default_values", None)
584
+
585
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
586
+
587
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
588
+ """
589
+ Save the configuration instance's parameters to a JSON file.
590
+
591
+ Args:
592
+ json_file_path (`str` or `os.PathLike`):
593
+ Path to the JSON file to save a configuration instance's parameters.
594
+ """
595
+ with open(json_file_path, "w", encoding="utf-8") as writer:
596
+ writer.write(self.to_json_string())
597
+
598
+
599
+ def register_to_config(init):
600
+ r"""
601
+ Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
602
+ automatically sent to `self.register_to_config`. To ignore a specific argument accepted by the init but that
603
+ shouldn't be registered in the config, use the `ignore_for_config` class variable
604
+
605
+ Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
606
+ """
607
+
608
+ @functools.wraps(init)
609
+ def inner_init(self, *args, **kwargs):
610
+ # Ignore private kwargs in the init.
611
+ init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
612
+ config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
613
+ if not isinstance(self, ConfigMixin):
614
+ raise RuntimeError(
615
+ f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
616
+ "not inherit from `ConfigMixin`."
617
+ )
618
+
619
+ ignore = getattr(self, "ignore_for_config", [])
620
+ # Get positional arguments aligned with kwargs
621
+ new_kwargs = {}
622
+ signature = inspect.signature(init)
623
+ parameters = {
624
+ name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
625
+ }
626
+ for arg, name in zip(args, parameters.keys()):
627
+ new_kwargs[name] = arg
628
+
629
+ # Then add all kwargs
630
+ new_kwargs.update(
631
+ {
632
+ k: init_kwargs.get(k, default)
633
+ for k, default in parameters.items()
634
+ if k not in ignore and k not in new_kwargs
635
+ }
636
+ )
637
+
638
+ # Take note of the parameters that were not present in the loaded config
639
+ if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
640
+ new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
641
+
642
+ new_kwargs = {**config_init_kwargs, **new_kwargs}
643
+ getattr(self, "register_to_config")(**new_kwargs)
644
+ init(self, *args, **init_kwargs)
645
+
646
+ return inner_init
647
+
648
+
649
+ def flax_register_to_config(cls):
650
+ original_init = cls.__init__
651
+
652
+ @functools.wraps(original_init)
653
+ def init(self, *args, **kwargs):
654
+ if not isinstance(self, ConfigMixin):
655
+ raise RuntimeError(
656
+ f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
657
+ "not inherit from `ConfigMixin`."
658
+ )
659
+
660
+ # Ignore private kwargs in the init. Retrieve all passed attributes
661
+ init_kwargs = dict(kwargs.items())
662
+
663
+ # Retrieve default values
664
+ fields = dataclasses.fields(self)
665
+ default_kwargs = {}
666
+ for field in fields:
667
+ # ignore flax specific attributes
668
+ if field.name in self._flax_internal_args:
669
+ continue
670
+ if type(field.default) == dataclasses._MISSING_TYPE:
671
+ default_kwargs[field.name] = None
672
+ else:
673
+ default_kwargs[field.name] = getattr(self, field.name)
674
+
675
+ # Make sure init_kwargs override default kwargs
676
+ new_kwargs = {**default_kwargs, **init_kwargs}
677
+ # dtype should be part of `init_kwargs`, but not `new_kwargs`
678
+ if "dtype" in new_kwargs:
679
+ new_kwargs.pop("dtype")
680
+
681
+ # Get positional arguments aligned with kwargs
682
+ for i, arg in enumerate(args):
683
+ name = fields[i].name
684
+ new_kwargs[name] = arg
685
+
686
+ # Take note of the parameters that were not present in the loaded config
687
+ if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
688
+ new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
689
+
690
+ getattr(self, "register_to_config")(**new_kwargs)
691
+ original_init(self, *args, **kwargs)
692
+
693
+ cls.__init__ = init
694
+ return cls
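As a point of reference for how the configuration machinery above is meant to be used, here is a minimal sketch; `MyScheduler` and its arguments are hypothetical, while `ConfigMixin`, `register_to_config`, `save_config`, `load_config`, and `from_config` come from this file:

from diffusers.configuration_utils import ConfigMixin, register_to_config

class MyScheduler(ConfigMixin):
    config_name = "scheduler_config.json"  # ConfigMixin requires subclasses to define this

    @register_to_config
    def __init__(self, num_train_timesteps: int = 1000, beta_start: float = 1e-4):
        # the decorator captures these arguments into self.config automatically
        self.num_train_timesteps = num_train_timesteps
        self.beta_start = beta_start

scheduler = MyScheduler(num_train_timesteps=500)
scheduler.save_config("./my_scheduler")               # writes scheduler_config.json
config = MyScheduler.load_config("./my_scheduler")    # plain dict of the saved parameters
restored = MyScheduler.from_config(config)            # instance rebuilt from the stored values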
diffusers/dependency_versions_check.py ADDED
@@ -0,0 +1,35 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import sys
15
+
16
+ from .dependency_versions_table import deps
17
+ from .utils.versions import require_version, require_version_core
18
+
19
+
20
+ # define which module versions we always want to check at run time
21
+ # (usually the ones defined in `install_requires` in setup.py)
22
+ #
23
+ # order specific notes:
24
+ # - tqdm must be checked before tokenizers
25
+
26
+ pkgs_to_check_at_runtime = "python requests filelock numpy".split()
27
+ for pkg in pkgs_to_check_at_runtime:
28
+ if pkg in deps:
29
+ require_version_core(deps[pkg])
30
+ else:
31
+ raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
32
+
33
+
34
+ def dep_version_check(pkg, hint=None):
35
+ require_version(deps[pkg], hint)
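A short usage sketch of the runtime check above (hedged: `accelerate` is just one key from the pinned table in the next file, and the hint string is illustrative):

from diffusers.dependency_versions_check import dep_version_check

# raises a helpful error if the installed package does not satisfy the pinned
# requirement string (e.g. "accelerate>=0.11.0" from dependency_versions_table.py)
dep_version_check("accelerate", hint="Try `pip install -U accelerate`.")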
diffusers/dependency_versions_table.py ADDED
@@ -0,0 +1,46 @@
1
+ # THIS FILE HAS BEEN AUTOGENERATED. To update:
2
+ # 1. modify the `_deps` dict in setup.py
3
+ # 2. run `make deps_table_update`
4
+ deps = {
5
+ "Pillow": "Pillow",
6
+ "accelerate": "accelerate>=0.11.0",
7
+ "compel": "compel==0.1.8",
8
+ "black": "black~=23.1",
9
+ "datasets": "datasets",
10
+ "filelock": "filelock",
11
+ "flax": "flax>=0.4.1",
12
+ "hf-doc-builder": "hf-doc-builder>=0.3.0",
13
+ "huggingface-hub": "huggingface-hub>=0.13.2",
14
+ "requests-mock": "requests-mock==1.10.0",
15
+ "importlib_metadata": "importlib_metadata",
16
+ "invisible-watermark": "invisible-watermark>=0.2.0",
17
+ "isort": "isort>=5.5.4",
18
+ "jax": "jax>=0.4.1",
19
+ "jaxlib": "jaxlib>=0.4.1",
20
+ "Jinja2": "Jinja2",
21
+ "k-diffusion": "k-diffusion>=0.0.12",
22
+ "torchsde": "torchsde",
23
+ "note_seq": "note_seq",
24
+ "librosa": "librosa",
25
+ "numpy": "numpy",
26
+ "omegaconf": "omegaconf",
27
+ "parameterized": "parameterized",
28
+ "peft": "peft<=0.6.2",
29
+ "protobuf": "protobuf>=3.20.3,<4",
30
+ "pytest": "pytest",
31
+ "pytest-timeout": "pytest-timeout",
32
+ "pytest-xdist": "pytest-xdist",
33
+ "python": "python>=3.8.0",
34
+ "ruff": "ruff==0.0.280",
35
+ "safetensors": "safetensors>=0.3.1",
36
+ "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
37
+ "scipy": "scipy",
38
+ "onnx": "onnx",
39
+ "regex": "regex!=2019.12.17",
40
+ "requests": "requests",
41
+ "tensorboard": "tensorboard",
42
+ "torch": "torch>=1.4",
43
+ "torchvision": "torchvision",
44
+ "transformers": "transformers>=4.25.1",
45
+ "urllib3": "urllib3<=2.0.0",
46
+ }
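The table is consumed as a plain mapping from package name to pip requirement string, for example:

from diffusers.dependency_versions_table import deps

print(deps["torch"])         # "torch>=1.4"
print(deps["transformers"])  # "transformers>=4.25.1"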
diffusers/experimental/README.md ADDED
@@ -0,0 +1,5 @@
1
+ # 🧨 Diffusers Experimental
2
+
3
+ We are adding experimental code to support novel applications and usages of the Diffusers library.
4
+ Currently, the following experiments are supported:
5
+ * Reinforcement learning via an implementation of the [Diffuser](https://arxiv.org/abs/2205.09991) model.
diffusers/experimental/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .rl import ValueGuidedRLPipeline
diffusers/experimental/rl/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .value_guided_sampling import ValueGuidedRLPipeline
diffusers/experimental/rl/value_guided_sampling.py ADDED
@@ -0,0 +1,154 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import numpy as np
16
+ import torch
17
+ import tqdm
18
+
19
+ from ...models.unet_1d import UNet1DModel
20
+ from ...pipelines import DiffusionPipeline
21
+ from ...utils.dummy_pt_objects import DDPMScheduler
22
+ from ...utils.torch_utils import randn_tensor
23
+
24
+
25
+ class ValueGuidedRLPipeline(DiffusionPipeline):
26
+ r"""
27
+ Pipeline for value-guided sampling from a diffusion model trained to predict sequences of states.
28
+
29
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
30
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
31
+
32
+ Parameters:
33
+ value_function ([`UNet1DModel`]):
34
+ A specialized UNet for fine-tuning trajectories based on reward.
35
+ unet ([`UNet1DModel`]):
36
+ UNet architecture to denoise the encoded trajectories.
37
+ scheduler ([`SchedulerMixin`]):
38
+ A scheduler to be used in combination with `unet` to denoise the encoded trajectories. Default for this
39
+ application is [`DDPMScheduler`].
40
+ env ():
41
+ An environment following the OpenAI gym API to act in. For now only Hopper has pretrained models.
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ value_function: UNet1DModel,
47
+ unet: UNet1DModel,
48
+ scheduler: DDPMScheduler,
49
+ env,
50
+ ):
51
+ super().__init__()
52
+ self.value_function = value_function
53
+ self.unet = unet
54
+ self.scheduler = scheduler
55
+ self.env = env
56
+ self.data = env.get_dataset()
57
+ self.means = {}
58
+ for key in self.data.keys():
59
+ try:
60
+ self.means[key] = self.data[key].mean()
61
+ except: # noqa: E722
62
+ pass
63
+ self.stds = {}
64
+ for key in self.data.keys():
65
+ try:
66
+ self.stds[key] = self.data[key].std()
67
+ except: # noqa: E722
68
+ pass
69
+ self.state_dim = env.observation_space.shape[0]
70
+ self.action_dim = env.action_space.shape[0]
71
+
72
+ def normalize(self, x_in, key):
73
+ return (x_in - self.means[key]) / self.stds[key]
74
+
75
+ def de_normalize(self, x_in, key):
76
+ return x_in * self.stds[key] + self.means[key]
77
+
78
+ def to_torch(self, x_in):
79
+ if isinstance(x_in, dict):
80
+ return {k: self.to_torch(v) for k, v in x_in.items()}
81
+ elif torch.is_tensor(x_in):
82
+ return x_in.to(self.unet.device)
83
+ return torch.tensor(x_in, device=self.unet.device)
84
+
85
+ def reset_x0(self, x_in, cond, act_dim):
86
+ for key, val in cond.items():
87
+ x_in[:, key, act_dim:] = val.clone()
88
+ return x_in
89
+
90
+ def run_diffusion(self, x, conditions, n_guide_steps, scale):
91
+ batch_size = x.shape[0]
92
+ y = None
93
+ for i in tqdm.tqdm(self.scheduler.timesteps):
94
+ # create batch of timesteps to pass into model
95
+ timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
96
+ for _ in range(n_guide_steps):
97
+ with torch.enable_grad():
98
+ x.requires_grad_()
99
+
100
+ # permute to match dimension for pre-trained models
101
+ y = self.value_function(x.permute(0, 2, 1), timesteps).sample
102
+ grad = torch.autograd.grad([y.sum()], [x])[0]
103
+
104
+ posterior_variance = self.scheduler._get_variance(i)
105
+ model_std = torch.exp(0.5 * posterior_variance)
106
+ grad = model_std * grad
107
+
108
+ grad[timesteps < 2] = 0
109
+ x = x.detach()
110
+ x = x + scale * grad
111
+ x = self.reset_x0(x, conditions, self.action_dim)
112
+
113
+ prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
114
+
115
+ # TODO: verify deprecation of this kwarg
116
+ x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]
117
+
118
+ # apply conditions to the trajectory (set the initial state)
119
+ x = self.reset_x0(x, conditions, self.action_dim)
120
+ x = self.to_torch(x)
121
+ return x, y
122
+
123
+ def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
124
+ # normalize the observations and create batch dimension
125
+ obs = self.normalize(obs, "observations")
126
+ obs = obs[None].repeat(batch_size, axis=0)
127
+
128
+ conditions = {0: self.to_torch(obs)}
129
+ shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)
130
+
131
+ # generate initial noise and apply our conditions (to make the trajectories start at current state)
132
+ x1 = randn_tensor(shape, device=self.unet.device)
133
+ x = self.reset_x0(x1, conditions, self.action_dim)
134
+ x = self.to_torch(x)
135
+
136
+ # run the diffusion process
137
+ x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)
138
+
139
+ # sort output trajectories by value
140
+ sorted_idx = y.argsort(0, descending=True).squeeze()
141
+ sorted_values = x[sorted_idx]
142
+ actions = sorted_values[:, :, : self.action_dim]
143
+ actions = actions.detach().cpu().numpy()
144
+ denorm_actions = self.de_normalize(actions, key="actions")
145
+
146
+ # select the action with the highest value
147
+ if y is not None:
148
+ selected_index = 0
149
+ else:
150
+ # if we didn't run value guiding, select a random action
151
+ selected_index = np.random.randint(0, batch_size)
152
+
153
+ denorm_actions = denorm_actions[selected_index, 0]
154
+ return denorm_actions
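A rough end-to-end sketch of driving this pipeline. It assumes a D4RL/gym Hopper environment and uses the value-function checkpoint from the upstream diffusers RL example (`bglick13/hopper-medium-v2-value-function-hor32`); substitute your own environment and checkpoint as needed:

import d4rl  # noqa: F401  (registers the offline-RL gym environments; assumed installed)
import gym

from diffusers.experimental import ValueGuidedRLPipeline

env = gym.make("hopper-medium-v2")
pipeline = ValueGuidedRLPipeline.from_pretrained(
    "bglick13/hopper-medium-v2-value-function-hor32", env=env
)

obs = env.reset()
for _ in range(100):
    # plan a batch of trajectories and take the first action of the highest-value one
    action = pipeline(obs, planning_horizon=32, n_guide_steps=2, scale=0.1)
    obs, reward, done, _ = env.step(action)
    if done:
        break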
diffusers/image_processor.py ADDED
@@ -0,0 +1,476 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import warnings
16
+ from typing import List, Optional, Union
17
+
18
+ import numpy as np
19
+ import PIL.Image
20
+ import torch
21
+ from PIL import Image
22
+
23
+ from .configuration_utils import ConfigMixin, register_to_config
24
+ from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
25
+
26
+
27
+ PipelineImageInput = Union[
28
+ PIL.Image.Image,
29
+ np.ndarray,
30
+ torch.FloatTensor,
31
+ List[PIL.Image.Image],
32
+ List[np.ndarray],
33
+ List[torch.FloatTensor],
34
+ ]
35
+
36
+
37
+ class VaeImageProcessor(ConfigMixin):
38
+ """
39
+ Image processor for VAE.
40
+
41
+ Args:
42
+ do_resize (`bool`, *optional*, defaults to `True`):
43
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
44
+ `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
45
+ vae_scale_factor (`int`, *optional*, defaults to `8`):
46
+ VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
47
+ resample (`str`, *optional*, defaults to `lanczos`):
48
+ Resampling filter to use when resizing the image.
49
+ do_normalize (`bool`, *optional*, defaults to `True`):
50
+ Whether to normalize the image to [-1,1].
51
+ do_binarize (`bool`, *optional*, defaults to `False`):
52
+ Whether to binarize the image to 0/1.
53
+ do_convert_rgb (`bool`, *optional*, defaults to be `False`):
54
+ Whether to convert the images to RGB format.
55
+ do_convert_grayscale (`bool`, *optional*, defaults to be `False`):
56
+ Whether to convert the images to grayscale format.
57
+ """
58
+
59
+ config_name = CONFIG_NAME
60
+
61
+ @register_to_config
62
+ def __init__(
63
+ self,
64
+ do_resize: bool = True,
65
+ vae_scale_factor: int = 8,
66
+ resample: str = "lanczos",
67
+ do_normalize: bool = True,
68
+ do_binarize: bool = False,
69
+ do_convert_rgb: bool = False,
70
+ do_convert_grayscale: bool = False,
71
+ ):
72
+ super().__init__()
73
+ if do_convert_rgb and do_convert_grayscale:
74
+ raise ValueError(
75
+ "`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`,"
76
+ " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.",
77
+ " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`",
78
+ )
79
+ self.config.do_convert_rgb = False
80
+
81
+ @staticmethod
82
+ def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image:
83
+ """
84
+ Convert a numpy image or a batch of images to a PIL image.
85
+ """
86
+ if images.ndim == 3:
87
+ images = images[None, ...]
88
+ images = (images * 255).round().astype("uint8")
89
+ if images.shape[-1] == 1:
90
+ # special case for grayscale (single channel) images
91
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
92
+ else:
93
+ pil_images = [Image.fromarray(image) for image in images]
94
+
95
+ return pil_images
96
+
97
+ @staticmethod
98
+ def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
99
+ """
100
+ Convert a PIL image or a list of PIL images to NumPy arrays.
101
+ """
102
+ if not isinstance(images, list):
103
+ images = [images]
104
+ images = [np.array(image).astype(np.float32) / 255.0 for image in images]
105
+ images = np.stack(images, axis=0)
106
+
107
+ return images
108
+
109
+ @staticmethod
110
+ def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
111
+ """
112
+ Convert a NumPy image to a PyTorch tensor.
113
+ """
114
+ if images.ndim == 3:
115
+ images = images[..., None]
116
+
117
+ images = torch.from_numpy(images.transpose(0, 3, 1, 2))
118
+ return images
119
+
120
+ @staticmethod
121
+ def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
122
+ """
123
+ Convert a PyTorch tensor to a NumPy image.
124
+ """
125
+ images = images.cpu().permute(0, 2, 3, 1).float().numpy()
126
+ return images
127
+
128
+ @staticmethod
129
+ def normalize(images):
130
+ """
131
+ Normalize an image array to [-1,1].
132
+ """
133
+ return 2.0 * images - 1.0
134
+
135
+ @staticmethod
136
+ def denormalize(images):
137
+ """
138
+ Denormalize an image array to [0,1].
139
+ """
140
+ return (images / 2 + 0.5).clamp(0, 1)
141
+
142
+ @staticmethod
143
+ def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
144
+ """
145
+ Converts a PIL image to RGB format.
146
+ """
147
+ image = image.convert("RGB")
148
+
149
+ return image
150
+
151
+ @staticmethod
152
+ def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
153
+ """
154
+ Converts a PIL image to grayscale format.
155
+ """
156
+ image = image.convert("L")
157
+
158
+ return image
159
+
160
+ def get_default_height_width(
161
+ self,
162
+ image: [PIL.Image.Image, np.ndarray, torch.Tensor],
163
+ height: Optional[int] = None,
164
+ width: Optional[int] = None,
165
+ ):
166
+ """
167
+ This function return the height and width that are downscaled to the next integer multiple of
168
+ `vae_scale_factor`.
169
+
170
+ Args:
171
+ image(`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
172
+ The image input, can be a PIL image, numpy array or pytorch tensor. if it is a numpy array, should have
173
+ shape `[batch, height, width]` or `[batch, height, width, channel]` if it is a pytorch tensor, should
174
+ have shape `[batch, channel, height, width]`.
175
+ height (`int`, *optional*, defaults to `None`):
176
+ The height in preprocessed image. If `None`, will use the height of `image` input.
177
+ width (`int`, *optional*`, defaults to `None`):
178
+ The width in preprocessed. If `None`, will use the width of the `image` input.
179
+ """
180
+
181
+ if height is None:
182
+ if isinstance(image, PIL.Image.Image):
183
+ height = image.height
184
+ elif isinstance(image, torch.Tensor):
185
+ height = image.shape[2]
186
+ else:
187
+ height = image.shape[1]
188
+
189
+ if width is None:
190
+ if isinstance(image, PIL.Image.Image):
191
+ width = image.width
192
+ elif isinstance(image, torch.Tensor):
193
+ width = image.shape[3]
194
+ else:
195
+ width = image.shape[2]
196
+
197
+ width, height = (
198
+ x - x % self.config.vae_scale_factor for x in (width, height)
199
+ ) # resize to integer multiple of vae_scale_factor
200
+
201
+ return height, width
202
+
203
+ def resize(
204
+ self,
205
+ image: [PIL.Image.Image, np.ndarray, torch.Tensor],
206
+ height: Optional[int] = None,
207
+ width: Optional[int] = None,
208
+ ) -> [PIL.Image.Image, np.ndarray, torch.Tensor]:
209
+ """
210
+ Resize image.
211
+ """
212
+ if isinstance(image, PIL.Image.Image):
213
+ image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
214
+ elif isinstance(image, torch.Tensor):
215
+ image = torch.nn.functional.interpolate(
216
+ image,
217
+ size=(height, width),
218
+ )
219
+ elif isinstance(image, np.ndarray):
220
+ image = self.numpy_to_pt(image)
221
+ image = torch.nn.functional.interpolate(
222
+ image,
223
+ size=(height, width),
224
+ )
225
+ image = self.pt_to_numpy(image)
226
+ return image
227
+
228
+ def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
229
+ """
230
+ create a face_hair_mask
231
+ """
232
+ image[image < 0.5] = 0
233
+ image[image >= 0.5] = 1
234
+ return image
235
+
236
+ def preprocess(
237
+ self,
238
+ image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
239
+ height: Optional[int] = None,
240
+ width: Optional[int] = None,
241
+ ) -> torch.Tensor:
242
+ """
243
+ Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
244
+ """
245
+ supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
246
+
247
+ # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
248
+ if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
249
+ if isinstance(image, torch.Tensor):
250
+ # if image is a pytorch tensor could have 2 possible shapes:
251
+ # 1. batch x height x width: we should insert the channel dimension at position 1
252
+ # 2. channnel x height x width: we should insert batch dimension at position 0,
253
+ # however, since both channel and batch dimension has same size 1, it is same to insert at position 1
254
+ # for simplicity, we insert a dimension of size 1 at position 1 for both cases
255
+ image = image.unsqueeze(1)
256
+ else:
257
+ # if it is a numpy array, it could have 2 possible shapes:
258
+ # 1. batch x height x width: insert channel dimension on last position
259
+ # 2. height x width x channel: insert batch dimension on first position
260
+ if image.shape[-1] == 1:
261
+ image = np.expand_dims(image, axis=0)
262
+ else:
263
+ image = np.expand_dims(image, axis=-1)
264
+
265
+ if isinstance(image, supported_formats):
266
+ image = [image]
267
+ elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
268
+ raise ValueError(
269
+ f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}"
270
+ )
271
+
272
+ if isinstance(image[0], PIL.Image.Image):
273
+ if self.config.do_convert_rgb:
274
+ image = [self.convert_to_rgb(i) for i in image]
275
+ elif self.config.do_convert_grayscale:
276
+ image = [self.convert_to_grayscale(i) for i in image]
277
+ if self.config.do_resize:
278
+ height, width = self.get_default_height_width(image[0], height, width)
279
+ image = [self.resize(i, height, width) for i in image]
280
+ image = self.pil_to_numpy(image) # to np
281
+ image = self.numpy_to_pt(image) # to pt
282
+
283
+ elif isinstance(image[0], np.ndarray):
284
+ image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
285
+
286
+ image = self.numpy_to_pt(image)
287
+
288
+ height, width = self.get_default_height_width(image, height, width)
289
+ if self.config.do_resize:
290
+ image = self.resize(image, height, width)
291
+
292
+ elif isinstance(image[0], torch.Tensor):
293
+ image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
294
+
295
+ if self.config.do_convert_grayscale and image.ndim == 3:
296
+ image = image.unsqueeze(1)
297
+
298
+ channel = image.shape[1]
299
+ # don't need any preprocess if the image is latents
300
+ if channel == 4:
301
+ return image
302
+
303
+ height, width = self.get_default_height_width(image, height, width)
304
+ if self.config.do_resize:
305
+ image = self.resize(image, height, width)
306
+
307
+ # expected range [0,1], normalize to [-1,1]
308
+ do_normalize = self.config.do_normalize
309
+ if image.min() < 0 and do_normalize:
310
+ warnings.warn(
311
+ "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
312
+ f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
313
+ FutureWarning,
314
+ )
315
+ do_normalize = False
316
+
317
+ if do_normalize:
318
+ image = self.normalize(image)
319
+
320
+ if self.config.do_binarize:
321
+ image = self.binarize(image)
322
+
323
+ return image
324
+
325
+ def postprocess(
326
+ self,
327
+ image: torch.FloatTensor,
328
+ output_type: str = "pil",
329
+ do_denormalize: Optional[List[bool]] = None,
330
+ ):
331
+ if not isinstance(image, torch.Tensor):
332
+ raise ValueError(
333
+ f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
334
+ )
335
+ if output_type not in ["latent", "pt", "np", "pil"]:
336
+ deprecation_message = (
337
+ f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
338
+ "`pil`, `np`, `pt`, `latent`"
339
+ )
340
+ deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
341
+ output_type = "np"
342
+
343
+ if output_type == "latent":
344
+ return image
345
+
346
+ if do_denormalize is None:
347
+ do_denormalize = [self.config.do_normalize] * image.shape[0]
348
+
349
+ image = torch.stack(
350
+ [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
351
+ )
352
+
353
+ if output_type == "pt":
354
+ return image
355
+
356
+ image = self.pt_to_numpy(image)
357
+
358
+ if output_type == "np":
359
+ return image
360
+
361
+ if output_type == "pil":
362
+ return self.numpy_to_pil(image)
363
+
364
+
365
+ class VaeImageProcessorLDM3D(VaeImageProcessor):
366
+ """
367
+ Image processor for VAE LDM3D.
368
+
369
+ Args:
370
+ do_resize (`bool`, *optional*, defaults to `True`):
371
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
372
+ vae_scale_factor (`int`, *optional*, defaults to `8`):
373
+ VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
374
+ resample (`str`, *optional*, defaults to `lanczos`):
375
+ Resampling filter to use when resizing the image.
376
+ do_normalize (`bool`, *optional*, defaults to `True`):
377
+ Whether to normalize the image to [-1,1].
378
+ """
379
+
380
+ config_name = CONFIG_NAME
381
+
382
+ @register_to_config
383
+ def __init__(
384
+ self,
385
+ do_resize: bool = True,
386
+ vae_scale_factor: int = 8,
387
+ resample: str = "lanczos",
388
+ do_normalize: bool = True,
389
+ ):
390
+ super().__init__()
391
+
392
+ @staticmethod
393
+ def numpy_to_pil(images):
394
+ """
395
+ Convert a NumPy image or a batch of images to a PIL image.
396
+ """
397
+ if images.ndim == 3:
398
+ images = images[None, ...]
399
+ images = (images * 255).round().astype("uint8")
400
+ if images.shape[-1] == 1:
401
+ # special case for grayscale (single channel) images
402
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
403
+ else:
404
+ pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
405
+
406
+ return pil_images
407
+
408
+ @staticmethod
409
+ def rgblike_to_depthmap(image):
410
+ """
411
+ Args:
412
+ image: RGB-like depth image
413
+
414
+ Returns: depth map
415
+
416
+ """
417
+ return image[:, :, 1] * 2**8 + image[:, :, 2]
418
+
419
+ def numpy_to_depth(self, images):
420
+ """
421
+ Convert a NumPy depth image or a batch of images to a PIL image.
422
+ """
423
+ if images.ndim == 3:
424
+ images = images[None, ...]
425
+ images_depth = images[:, :, :, 3:]
426
+ if images.shape[-1] == 6:
427
+ images_depth = (images_depth * 255).round().astype("uint8")
428
+ pil_images = [
429
+ Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth
430
+ ]
431
+ elif images.shape[-1] == 4:
432
+ images_depth = (images_depth * 65535.0).astype(np.uint16)
433
+ pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth]
434
+ else:
435
+ raise Exception("Not supported")
436
+
437
+ return pil_images
438
+
439
+ def postprocess(
440
+ self,
441
+ image: torch.FloatTensor,
442
+ output_type: str = "pil",
443
+ do_denormalize: Optional[List[bool]] = None,
444
+ ):
445
+ if not isinstance(image, torch.Tensor):
446
+ raise ValueError(
447
+ f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
448
+ )
449
+ if output_type not in ["latent", "pt", "np", "pil"]:
450
+ deprecation_message = (
451
+ f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
452
+ "`pil`, `np`, `pt`, `latent`"
453
+ )
454
+ deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
455
+ output_type = "np"
456
+
457
+ if do_denormalize is None:
458
+ do_denormalize = [self.config.do_normalize] * image.shape[0]
459
+
460
+ image = torch.stack(
461
+ [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
462
+ )
463
+
464
+ image = self.pt_to_numpy(image)
465
+
466
+ if output_type == "np":
467
+ if image.shape[-1] == 6:
468
+ image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
469
+ else:
470
+ image_depth = image[:, :, :, 3:]
471
+ return image[:, :, :, :3], image_depth
472
+
473
+ if output_type == "pil":
474
+ return self.numpy_to_pil(image), self.numpy_to_depth(image)
475
+ else:
476
+ raise Exception(f"This type {output_type} is not supported")
diffusers/loaders.py ADDED
The diff for this file is too large to render. See raw diff
 
diffusers/models/README.md ADDED
@@ -0,0 +1,3 @@
1
+ # Models
2
+
3
+ For more detail on the models, please refer to the [docs](https://huggingface.co/docs/diffusers/api/models/overview).
diffusers/models/__init__.py ADDED
@@ -0,0 +1,77 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available
18
+
19
+
20
+ _import_structure = {}
21
+
22
+ if is_torch_available():
23
+ _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
24
+ _import_structure["autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
25
+ _import_structure["autoencoder_kl"] = ["AutoencoderKL"]
26
+ _import_structure["autoencoder_tiny"] = ["AutoencoderTiny"]
27
+ _import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
28
+ _import_structure["controlnet"] = ["ControlNetModel"]
29
+ _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
30
+ _import_structure["modeling_utils"] = ["ModelMixin"]
31
+ _import_structure["prior_transformer"] = ["PriorTransformer"]
32
+ _import_structure["t5_film_transformer"] = ["T5FilmDecoder"]
33
+ _import_structure["transformer_2d"] = ["Transformer2DModel"]
34
+ _import_structure["transformer_temporal"] = ["TransformerTemporalModel"]
35
+ _import_structure["unet_1d"] = ["UNet1DModel"]
36
+ _import_structure["unet_2d"] = ["UNet2DModel"]
37
+ _import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
38
+ _import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
39
+ _import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
40
+ _import_structure["vq_model"] = ["VQModel"]
41
+
42
+ if is_flax_available():
43
+ _import_structure["controlnet_flax"] = ["FlaxControlNetModel"]
44
+ _import_structure["unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
45
+ _import_structure["vae_flax"] = ["FlaxAutoencoderKL"]
46
+
47
+
48
+ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
49
+ if is_torch_available():
50
+ from .adapter import MultiAdapter, T2IAdapter
51
+ from .autoencoder_asym_kl import AsymmetricAutoencoderKL
52
+ from .autoencoder_kl import AutoencoderKL
53
+ from .autoencoder_tiny import AutoencoderTiny
54
+ from .consistency_decoder_vae import ConsistencyDecoderVAE
55
+ from .controlnet import ControlNetModel
56
+ from .dual_transformer_2d import DualTransformer2DModel
57
+ from .modeling_utils import ModelMixin
58
+ from .prior_transformer import PriorTransformer
59
+ from .t5_film_transformer import T5FilmDecoder
60
+ from .transformer_2d import Transformer2DModel
61
+ from .transformer_temporal import TransformerTemporalModel
62
+ from .unet_1d import UNet1DModel
63
+ from .unet_2d import UNet2DModel
64
+ from .unet_2d_condition import UNet2DConditionModel
65
+ from .unet_3d_condition import UNet3DConditionModel
66
+ from .unet_motion_model import MotionAdapter, UNetMotionModel
67
+ from .vq_model import VQModel
68
+
69
+ if is_flax_available():
70
+ from .controlnet_flax import FlaxControlNetModel
71
+ from .unet_2d_condition_flax import FlaxUNet2DConditionModel
72
+ from .vae_flax import FlaxAutoencoderKL
73
+
74
+ else:
75
+ import sys
76
+
77
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
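Because of the lazy-module setup above, model classes can be imported directly from the subpackage without paying the heavy import cost until first use. A sketch (the hub id is illustrative; any repository containing a compatible `unet` subfolder works):

from diffusers.models import UNet2DConditionModel  # resolved lazily on first access

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")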
diffusers/models/activations.py ADDED
@@ -0,0 +1,120 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 HuggingFace Inc.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ from ..utils import USE_PEFT_BACKEND
21
+ from .lora import LoRACompatibleLinear
22
+
23
+
24
+ ACTIVATION_FUNCTIONS = {
25
+ "swish": nn.SiLU(),
26
+ "silu": nn.SiLU(),
27
+ "mish": nn.Mish(),
28
+ "gelu": nn.GELU(),
29
+ "relu": nn.ReLU(),
30
+ }
31
+
32
+
33
+ def get_activation(act_fn: str) -> nn.Module:
34
+ """Helper function to get activation function from string.
35
+
36
+ Args:
37
+ act_fn (str): Name of activation function.
38
+
39
+ Returns:
40
+ nn.Module: Activation function.
41
+ """
42
+
43
+ act_fn = act_fn.lower()
44
+ if act_fn in ACTIVATION_FUNCTIONS:
45
+ return ACTIVATION_FUNCTIONS[act_fn]
46
+ else:
47
+ raise ValueError(f"Unsupported activation function: {act_fn}")
48
+
49
+
50
+ class GELU(nn.Module):
51
+ r"""
52
+ GELU activation function with tanh approximation support with `approximate="tanh"`.
53
+
54
+ Parameters:
55
+ dim_in (`int`): The number of channels in the input.
56
+ dim_out (`int`): The number of channels in the output.
57
+ approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
58
+ """
59
+
60
+ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
61
+ super().__init__()
62
+ self.proj = nn.Linear(dim_in, dim_out)
63
+ self.approximate = approximate
64
+
65
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
66
+ if gate.device.type != "mps":
67
+ return F.gelu(gate, approximate=self.approximate)
68
+ # mps: gelu is not implemented for float16
69
+ return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
70
+
71
+ def forward(self, hidden_states):
72
+ hidden_states = self.proj(hidden_states)
73
+ hidden_states = self.gelu(hidden_states)
74
+ return hidden_states
75
+
76
+
77
+ class GEGLU(nn.Module):
78
+ r"""
79
+ A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function.
80
+
81
+ Parameters:
82
+ dim_in (`int`): The number of channels in the input.
83
+ dim_out (`int`): The number of channels in the output.
84
+ """
85
+
86
+ def __init__(self, dim_in: int, dim_out: int):
87
+ super().__init__()
88
+ linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
89
+
90
+ self.proj = linear_cls(dim_in, dim_out * 2)
91
+
92
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
93
+ if gate.device.type != "mps":
94
+ return F.gelu(gate)
95
+ # mps: gelu is not implemented for float16
96
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
97
+
98
+ def forward(self, hidden_states, scale: float = 1.0):
99
+ args = () if USE_PEFT_BACKEND else (scale,)
100
+ hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
101
+ return hidden_states * self.gelu(gate)
102
+
103
+
104
+ class ApproximateGELU(nn.Module):
105
+ r"""
106
+ The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this
107
+ [paper](https://arxiv.org/abs/1606.08415).
108
+
109
+ Parameters:
110
+ dim_in (`int`): The number of channels in the input.
111
+ dim_out (`int`): The number of channels in the output.
112
+ """
113
+
114
+ def __init__(self, dim_in: int, dim_out: int):
115
+ super().__init__()
116
+ self.proj = nn.Linear(dim_in, dim_out)
117
+
118
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
119
+ x = self.proj(x)
120
+ return x * torch.sigmoid(1.702 * x)
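A quick sketch of the activation helper in use (shapes are arbitrary):

import torch

from diffusers.models.activations import get_activation

act = get_activation("silu")   # returns nn.SiLU(); an unknown name raises ValueError
x = torch.randn(2, 16)
y = act(x)                     # same shape, activation applied elementwise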
diffusers/models/adapter.py ADDED
@@ -0,0 +1,584 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ from typing import Callable, List, Optional, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from ..configuration_utils import ConfigMixin, register_to_config
21
+ from ..utils import logging
22
+ from .modeling_utils import ModelMixin
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ class MultiAdapter(ModelMixin):
29
+ r"""
30
+ MultiAdapter is a wrapper model that contains multiple adapter models and merges their outputs according to
31
+ user-assigned weighting.
32
+
33
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
34
+ implements for all models (such as downloading or saving, etc.).
35
+
36
+ Parameters:
37
+ adapters (`List[T2IAdapter]`, *optional*, defaults to None):
38
+ A list of `T2IAdapter` model instances.
39
+ """
40
+
41
+ def __init__(self, adapters: List["T2IAdapter"]):
42
+ super(MultiAdapter, self).__init__()
43
+
44
+ self.num_adapter = len(adapters)
45
+ self.adapters = nn.ModuleList(adapters)
46
+
47
+ if len(adapters) == 0:
48
+ raise ValueError("Expecting at least one adapter")
49
+
50
+ if len(adapters) == 1:
51
+ raise ValueError("For a single adapter, please use the `T2IAdapter` class instead of `MultiAdapter`")
52
+
53
+ # The outputs from each adapter are added together with a weight.
54
+ # This means that the change in dimensions from downsampling must
55
+ # be the same for all adapters. Inductively, it also means the
56
+ # downscale_factor and total_downscale_factor must be the same for all
57
+ # adapters.
58
+ first_adapter_total_downscale_factor = adapters[0].total_downscale_factor
59
+ first_adapter_downscale_factor = adapters[0].downscale_factor
60
+ for idx in range(1, len(adapters)):
61
+ if (
62
+ adapters[idx].total_downscale_factor != first_adapter_total_downscale_factor
63
+ or adapters[idx].downscale_factor != first_adapter_downscale_factor
64
+ ):
65
+ raise ValueError(
66
+ f"Expecting all adapters to have the same downscaling behavior, but got:\n"
67
+ f"adapters[0].total_downscale_factor={first_adapter_total_downscale_factor}\n"
68
+ f"adapters[0].downscale_factor={first_adapter_downscale_factor}\n"
69
+ f"adapter[`{idx}`].total_downscale_factor={adapters[idx].total_downscale_factor}\n"
70
+ f"adapter[`{idx}`].downscale_factor={adapters[idx].downscale_factor}"
71
+ )
72
+
73
+ self.total_downscale_factor = first_adapter_total_downscale_factor
74
+ self.downscale_factor = first_adapter_downscale_factor
75
+
76
+ def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = None) -> List[torch.Tensor]:
77
+ r"""
78
+ Args:
79
+ xs (`torch.Tensor`):
80
+ (batch, channel, height, width) input images for the multiple adapter models, concatenated along dimension 1;
81
+ `channel` should equal `num_adapter` times the number of channels per image.
82
+ adapter_weights (`List[float]`, *optional*, defaults to None):
83
+ List of floats representing the weight by which each adapter's output is multiplied before the outputs
84
+ are added together.
85
+ """
86
+ if adapter_weights is None:
87
+ adapter_weights = torch.tensor([1 / self.num_adapter] * self.num_adapter)
88
+ else:
89
+ adapter_weights = torch.tensor(adapter_weights)
90
+
91
+ accume_state = None
92
+ for x, w, adapter in zip(xs, adapter_weights, self.adapters):
93
+ features = adapter(x)
94
+ if accume_state is None:
95
+ accume_state = features
96
+ for i in range(len(accume_state)):
97
+ accume_state[i] = w * accume_state[i]
98
+ else:
99
+ for i in range(len(features)):
100
+ accume_state[i] += w * features[i]
101
+ return accume_state
102
+
103
+ def save_pretrained(
104
+ self,
105
+ save_directory: Union[str, os.PathLike],
106
+ is_main_process: bool = True,
107
+ save_function: Callable = None,
108
+ safe_serialization: bool = True,
109
+ variant: Optional[str] = None,
110
+ ):
111
+ """
112
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
113
+ [`~models.adapter.MultiAdapter.from_pretrained`] class method.
114
+
115
+ Arguments:
116
+ save_directory (`str` or `os.PathLike`):
117
+ Directory to which to save. Will be created if it doesn't exist.
118
+ is_main_process (`bool`, *optional*, defaults to `True`):
119
+ Whether the process calling this is the main process or not. Useful during distributed training (e.g. on
+ TPUs), when this function needs to be called on all processes. In this case, set `is_main_process=True`
+ only on the main process to avoid race conditions.
122
+ save_function (`Callable`):
123
+ The function to use to save the state dictionary. Useful during distributed training (e.g. on TPUs) when
+ one needs to replace `torch.save` with another method. Can be configured with the environment variable
125
+ `DIFFUSERS_SAVE_MODE`.
126
+ safe_serialization (`bool`, *optional*, defaults to `True`):
127
+ Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
128
+ variant (`str`, *optional*):
129
+ If specified, weights are saved in the format pytorch_model.<variant>.bin.
130
+ """
131
+ idx = 0
132
+ model_path_to_save = save_directory
133
+ for adapter in self.adapters:
134
+ adapter.save_pretrained(
135
+ model_path_to_save,
136
+ is_main_process=is_main_process,
137
+ save_function=save_function,
138
+ safe_serialization=safe_serialization,
139
+ variant=variant,
140
+ )
141
+
142
+ idx += 1
143
+ model_path_to_save = model_path_to_save + f"_{idx}"
144
+
145
+ @classmethod
146
+ def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
147
+ r"""
148
+ Instantiate a pretrained MultiAdapter model from multiple pre-trained adapter models.
149
+
150
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
151
+ the model, you should first set it back in training mode with `model.train()`.
152
+
153
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
154
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
155
+ task.
156
+
157
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
158
+ weights are discarded.
159
+
160
+ Parameters:
161
+ pretrained_model_path (`os.PathLike`):
162
+ A path to a *directory* containing model weights saved using
163
+ [`~diffusers.models.adapter.MultiAdapter.save_pretrained`], e.g., `./my_model_directory/adapter`.
164
+ torch_dtype (`str` or `torch.dtype`, *optional*):
165
+ Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
166
+ will be automatically derived from the model's weights.
167
+ output_loading_info(`bool`, *optional*, defaults to `False`):
168
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
169
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
170
+ A map that specifies where each submodule should go. It doesn't need to be refined to each
171
+ parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
172
+ same device.
173
+
174
+ To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
175
+ more information about each option see [designing a device
176
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
177
+ max_memory (`Dict`, *optional*):
178
+ A dictionary mapping device identifiers to their maximum memory. Will default to the maximum memory available for each
179
+ GPU and the available CPU RAM if unset.
180
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
181
+ Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
182
+ also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
183
+ model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
184
+ setting this argument to `True` will raise an error.
185
+ variant (`str`, *optional*):
186
+ If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
187
+ ignored when using `from_flax`.
188
+ use_safetensors (`bool`, *optional*, defaults to `None`):
189
+ If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
190
+ `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
191
+ `safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
192
+ """
193
+ idx = 0
194
+ adapters = []
195
+
196
+ # load adapter and append to list until no adapter directory exists anymore
197
+ # first adapter has to be saved under `./mydirectory/adapter` to be compliant with `DiffusionPipeline.from_pretrained`
198
+ # second, third, ... adapters have to be saved under `./mydirectory/adapter_1`, `./mydirectory/adapter_2`, ...
199
+ model_path_to_load = pretrained_model_path
200
+ while os.path.isdir(model_path_to_load):
201
+ adapter = T2IAdapter.from_pretrained(model_path_to_load, **kwargs)
202
+ adapters.append(adapter)
203
+
204
+ idx += 1
205
+ model_path_to_load = pretrained_model_path + f"_{idx}"
206
+
207
+ logger.info(f"{len(adapters)} adapters loaded from {pretrained_model_path}.")
208
+
209
+ if len(adapters) == 0:
210
+ raise ValueError(
211
+ f"No T2IAdapters found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}."
212
+ )
213
+
214
+ return cls(adapters)
215
+
216
+
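A minimal usage sketch for the weighted merging and save/load conventions above, assuming this module is importable as `diffusers.models.adapter`; the adapter roles, sizes, and the `./my_multi_adapter` path are purely illustrative. Note that `forward` zips over `xs`, so passing one control-image batch per adapter matches the loop as written.

import torch
from diffusers.models.adapter import MultiAdapter, T2IAdapter

# Two adapters with identical downscaling behavior, as MultiAdapter requires.
depth_adapter = T2IAdapter(in_channels=3, channels=[320, 640, 1280, 1280])
sketch_adapter = T2IAdapter(in_channels=3, channels=[320, 640, 1280, 1280])
multi_adapter = MultiAdapter([depth_adapter, sketch_adapter])

depth_image = torch.randn(1, 3, 256, 256)
sketch_image = torch.randn(1, 3, 256, 256)

# Each returned tensor is the weighted sum of the two adapters' feature maps at that
# scale; weights default to 1 / num_adapter each.
features = multi_adapter([depth_image, sketch_image], adapter_weights=[0.7, 0.3])
print([tuple(f.shape) for f in features])

# save_pretrained() writes the first adapter to the directory itself and the second to
# `<directory>_1`; from_pretrained() walks the same naming scheme until no directory exists.
multi_adapter.save_pretrained("./my_multi_adapter")
restored = MultiAdapter.from_pretrained("./my_multi_adapter")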
217
+ class T2IAdapter(ModelMixin, ConfigMixin):
218
+ r"""
219
+ A simple ResNet-like model that accepts images containing control signals such as keyposes and depth. The model
220
+ generates multiple feature maps that are used as additional conditioning in [`UNet2DConditionModel`]. The model's
221
+ architecture follows the original implementation of
222
+ [Adapter](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L97)
223
+ and
224
+ [AdapterLight](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L235).
225
+
226
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
227
+ implements for all the model (such as downloading or saving, etc.)
228
+
229
+ Parameters:
230
+ in_channels (`int`, *optional*, defaults to 3):
231
+ Number of channels of the Adapter's input (*control image*). Set this parameter to 1 if you're using a
+ grayscale image as the *control image*.
233
+ channels (`List[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
234
+ The number of channels in each downsample block's output hidden state. `len(channels)` also determines the
+ number of downsample blocks in the Adapter.
236
+ num_res_blocks (`int`, *optional*, defaults to 2):
237
+ Number of ResNet blocks in each downsample block.
238
+ downscale_factor (`int`, *optional*, defaults to 8):
239
+ A factor that determines the total downscale factor of the Adapter.
240
+ adapter_type (`str`, *optional*, defaults to `full_adapter`):
241
+ The type of Adapter to use. Choose either `full_adapter` or `full_adapter_xl` or `light_adapter`.
242
+ """
243
+
244
+ @register_to_config
245
+ def __init__(
246
+ self,
247
+ in_channels: int = 3,
248
+ channels: List[int] = [320, 640, 1280, 1280],
249
+ num_res_blocks: int = 2,
250
+ downscale_factor: int = 8,
251
+ adapter_type: str = "full_adapter",
252
+ ):
253
+ super().__init__()
254
+
255
+ if adapter_type == "full_adapter":
256
+ self.adapter = FullAdapter(in_channels, channels, num_res_blocks, downscale_factor)
257
+ elif adapter_type == "full_adapter_xl":
258
+ self.adapter = FullAdapterXL(in_channels, channels, num_res_blocks, downscale_factor)
259
+ elif adapter_type == "light_adapter":
260
+ self.adapter = LightAdapter(in_channels, channels, num_res_blocks, downscale_factor)
261
+ else:
262
+ raise ValueError(
263
+ f"Unsupported adapter_type: '{adapter_type}'. Choose either 'full_adapter' or "
264
+ "'full_adapter_xl' or 'light_adapter'."
265
+ )
266
+
267
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
268
+ r"""
269
+ This function processes the input tensor `x` through the adapter model and returns a list of feature tensors,
270
+ each representing information extracted at a different scale from the input. The length of the list is
271
+ determined by the number of downsample blocks in the Adapter, as specified by the `channels` and
272
+ `num_res_blocks` parameters during initialization.
273
+ """
274
+ return self.adapter(x)
275
+
276
+ @property
277
+ def total_downscale_factor(self):
278
+ return self.adapter.total_downscale_factor
279
+
280
+ @property
281
+ def downscale_factor(self):
282
+ """The downscale factor applied in the T2I-Adapter's initial pixel unshuffle operation. If an input image's dimensions are
283
+ not evenly divisible by the downscale_factor then an exception will be raised.
284
+ """
285
+ return self.adapter.unshuffle.downscale_factor
286
+
287
+
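A short sketch of the feature pyramid the full adapter produces, assuming the default configuration above and the same import path as before; the 512x512 control image is illustrative (its sides only need to be divisible by the total downscale factor).

import torch
from diffusers.models.adapter import T2IAdapter

adapter = T2IAdapter(
    in_channels=3,
    channels=[320, 640, 1280, 1280],
    num_res_blocks=2,
    downscale_factor=8,
    adapter_type="full_adapter",
)

features = adapter(torch.randn(1, 3, 512, 512))
# PixelUnshuffle(8) maps 512x512 -> 64x64 and each later block halves the resolution:
# (1, 320, 64, 64), (1, 640, 32, 32), (1, 1280, 16, 16), (1, 1280, 8, 8)
for f in features:
    print(tuple(f.shape))

print(adapter.downscale_factor)        # 8
print(adapter.total_downscale_factor)  # 8 * 2**(len(channels) - 1) = 64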
288
+ # full adapter
289
+
290
+
291
+ class FullAdapter(nn.Module):
292
+ r"""
293
+ See [`T2IAdapter`] for more information.
294
+ """
295
+
296
+ def __init__(
297
+ self,
298
+ in_channels: int = 3,
299
+ channels: List[int] = [320, 640, 1280, 1280],
300
+ num_res_blocks: int = 2,
301
+ downscale_factor: int = 8,
302
+ ):
303
+ super().__init__()
304
+
305
+ in_channels = in_channels * downscale_factor**2
306
+
307
+ self.unshuffle = nn.PixelUnshuffle(downscale_factor)
308
+ self.conv_in = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1)
309
+
310
+ self.body = nn.ModuleList(
311
+ [
312
+ AdapterBlock(channels[0], channels[0], num_res_blocks),
313
+ *[
314
+ AdapterBlock(channels[i - 1], channels[i], num_res_blocks, down=True)
315
+ for i in range(1, len(channels))
316
+ ],
317
+ ]
318
+ )
319
+
320
+ self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 1)
321
+
322
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
323
+ r"""
324
+ This method processes the input tensor `x` through the FullAdapter model and performs operations including
325
+ pixel unshuffling, convolution, and a stack of AdapterBlocks. It returns a list of feature tensors, each
326
+ capturing information at a different stage of processing within the FullAdapter model. The number of feature
327
+ tensors in the list is determined by the number of downsample blocks specified during initialization.
328
+ """
329
+ x = self.unshuffle(x)
330
+ x = self.conv_in(x)
331
+
332
+ features = []
333
+
334
+ for block in self.body:
335
+ x = block(x)
336
+ features.append(x)
337
+
338
+ return features
339
+
340
+
341
+ class FullAdapterXL(nn.Module):
342
+ r"""
343
+ See [`T2IAdapter`] for more information.
344
+ """
345
+
346
+ def __init__(
347
+ self,
348
+ in_channels: int = 3,
349
+ channels: List[int] = [320, 640, 1280, 1280],
350
+ num_res_blocks: int = 2,
351
+ downscale_factor: int = 16,
352
+ ):
353
+ super().__init__()
354
+
355
+ in_channels = in_channels * downscale_factor**2
356
+
357
+ self.unshuffle = nn.PixelUnshuffle(downscale_factor)
358
+ self.conv_in = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1)
359
+
360
+ self.body = []
361
+ # blocks to extract XL features with dimensions of [320, 64, 64], [640, 64, 64], [1280, 32, 32], [1280, 32, 32]
362
+ for i in range(len(channels)):
363
+ if i == 1:
364
+ self.body.append(AdapterBlock(channels[i - 1], channels[i], num_res_blocks))
365
+ elif i == 2:
366
+ self.body.append(AdapterBlock(channels[i - 1], channels[i], num_res_blocks, down=True))
367
+ else:
368
+ self.body.append(AdapterBlock(channels[i], channels[i], num_res_blocks))
369
+
370
+ self.body = nn.ModuleList(self.body)
371
+ # XL has only one downsampling AdapterBlock.
372
+ self.total_downscale_factor = downscale_factor * 2
373
+
374
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
375
+ r"""
376
+ This method takes the tensor `x` as input and processes it through the FullAdapterXL model: it unshuffles
+ pixels, applies the input convolution, and collects each block's output into a list of feature tensors.
378
+ """
379
+ x = self.unshuffle(x)
380
+ x = self.conv_in(x)
381
+
382
+ features = []
383
+
384
+ for block in self.body:
385
+ x = block(x)
386
+ features.append(x)
387
+
388
+ return features
389
+
390
+
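For the XL variant only the block at index 2 downsamples, so `total_downscale_factor` is `downscale_factor * 2`. A sketch with an illustrative 1024x1024 input, again assuming the module import path used earlier:

import torch
from diffusers.models.adapter import T2IAdapter

xl_adapter = T2IAdapter(
    in_channels=3,
    channels=[320, 640, 1280, 1280],
    num_res_blocks=2,
    downscale_factor=16,
    adapter_type="full_adapter_xl",
)

features = xl_adapter(torch.randn(1, 3, 1024, 1024))
# PixelUnshuffle(16) maps 1024x1024 -> 64x64; only the third block downsamples:
# (1, 320, 64, 64), (1, 640, 64, 64), (1, 1280, 32, 32), (1, 1280, 32, 32)
for f in features:
    print(tuple(f.shape))

print(xl_adapter.total_downscale_factor)  # 16 * 2 = 32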
391
+ class AdapterBlock(nn.Module):
392
+ r"""
393
+ An AdapterBlock is a helper model that contains multiple ResNet-like blocks. It is used in the `FullAdapter` and
394
+ `FullAdapterXL` models.
395
+
396
+ Parameters:
397
+ in_channels (`int`):
398
+ Number of channels of AdapterBlock's input.
399
+ out_channels (`int`):
400
+ Number of channels of AdapterBlock's output.
401
+ num_res_blocks (`int`):
402
+ Number of ResNet blocks in the AdapterBlock.
403
+ down (`bool`, *optional*, defaults to `False`):
404
+ Whether to perform downsampling on AdapterBlock's input.
405
+ """
406
+
407
+ def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
408
+ super().__init__()
409
+
410
+ self.downsample = None
411
+ if down:
412
+ self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True)
413
+
414
+ self.in_conv = None
415
+ if in_channels != out_channels:
416
+ self.in_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
417
+
418
+ self.resnets = nn.Sequential(
419
+ *[AdapterResnetBlock(out_channels) for _ in range(num_res_blocks)],
420
+ )
421
+
422
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
423
+ r"""
424
+ This method takes tensor `x` as input and applies downsampling and an input convolution when
+ `self.downsample` and `self.in_conv` are defined, then applies a series of residual blocks to the tensor.
427
+ """
428
+ if self.downsample is not None:
429
+ x = self.downsample(x)
430
+
431
+ if self.in_conv is not None:
432
+ x = self.in_conv(x)
433
+
434
+ x = self.resnets(x)
435
+
436
+ return x
437
+
438
+
439
+ class AdapterResnetBlock(nn.Module):
440
+ r"""
441
+ An `AdapterResnetBlock` is a helper model that implements a ResNet-like block.
442
+
443
+ Parameters:
444
+ channels (`int`):
445
+ Number of channels of AdapterResnetBlock's input and output.
446
+ """
447
+
448
+ def __init__(self, channels: int):
449
+ super().__init__()
450
+ self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
451
+ self.act = nn.ReLU()
452
+ self.block2 = nn.Conv2d(channels, channels, kernel_size=1)
453
+
454
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
455
+ r"""
456
+ This method takes input tensor `x` and applies a convolutional layer, ReLU activation, and another
+ convolutional layer, then returns the result added to the input tensor (a residual connection).
458
+ """
459
+
460
+ h = self.act(self.block1(x))
461
+ h = self.block2(h)
462
+
463
+ return h + x
464
+
465
+
466
+ # light adapter
467
+
468
+
469
+ class LightAdapter(nn.Module):
470
+ r"""
471
+ See [`T2IAdapter`] for more information.
472
+ """
473
+
474
+ def __init__(
475
+ self,
476
+ in_channels: int = 3,
477
+ channels: List[int] = [320, 640, 1280],
478
+ num_res_blocks: int = 4,
479
+ downscale_factor: int = 8,
480
+ ):
481
+ super().__init__()
482
+
483
+ in_channels = in_channels * downscale_factor**2
484
+
485
+ self.unshuffle = nn.PixelUnshuffle(downscale_factor)
486
+
487
+ self.body = nn.ModuleList(
488
+ [
489
+ LightAdapterBlock(in_channels, channels[0], num_res_blocks),
490
+ *[
491
+ LightAdapterBlock(channels[i], channels[i + 1], num_res_blocks, down=True)
492
+ for i in range(len(channels) - 1)
493
+ ],
494
+ LightAdapterBlock(channels[-1], channels[-1], num_res_blocks, down=True),
495
+ ]
496
+ )
497
+
498
+ self.total_downscale_factor = downscale_factor * (2 ** len(channels))
499
+
500
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
501
+ r"""
502
+ This method takes the input tensor `x`, downscales it, and collects the output of each block into a list of
+ feature tensors. Each feature tensor corresponds to a different level of processing within the LightAdapter.
504
+ """
505
+ x = self.unshuffle(x)
506
+
507
+ features = []
508
+
509
+ for block in self.body:
510
+ x = block(x)
511
+ features.append(x)
512
+
513
+ return features
514
+
515
+
516
+ class LightAdapterBlock(nn.Module):
517
+ r"""
518
+ A `LightAdapterBlock` is a helper model that contains multiple `LightAdapterResnetBlocks`. It is used in the
519
+ `LightAdapter` model.
520
+
521
+ Parameters:
522
+ in_channels (`int`):
523
+ Number of channels of LightAdapterBlock's input.
524
+ out_channels (`int`):
525
+ Number of channels of LightAdapterBlock's output.
526
+ num_res_blocks (`int`):
527
+ Number of LightAdapterResnetBlocks in the LightAdapterBlock.
528
+ down (`bool`, *optional*, defaults to `False`):
529
+ Whether to perform downsampling on LightAdapterBlock's input.
530
+ """
531
+
532
+ def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
533
+ super().__init__()
534
+ mid_channels = out_channels // 4
535
+
536
+ self.downsample = None
537
+ if down:
538
+ self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True)
539
+
540
+ self.in_conv = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
541
+ self.resnets = nn.Sequential(*[LightAdapterResnetBlock(mid_channels) for _ in range(num_res_blocks)])
542
+ self.out_conv = nn.Conv2d(mid_channels, out_channels, kernel_size=1)
543
+
544
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
545
+ r"""
546
+ This method takes tensor `x` as input and performs downsampling if required, then applies the input
+ convolution, a sequence of residual blocks, and the output convolution.
548
+ """
549
+ if self.downsample is not None:
550
+ x = self.downsample(x)
551
+
552
+ x = self.in_conv(x)
553
+ x = self.resnets(x)
554
+ x = self.out_conv(x)
555
+
556
+ return x
557
+
558
+
559
+ class LightAdapterResnetBlock(nn.Module):
560
+ """
561
+ A `LightAdapterResnetBlock` is a helper model that implements a ResNet-like block with a slightly different
562
+ architecture than `AdapterResnetBlock`.
563
+
564
+ Parameters:
565
+ channels (`int`):
566
+ Number of channels of LightAdapterResnetBlock's input and output.
567
+ """
568
+
569
+ def __init__(self, channels: int):
570
+ super().__init__()
571
+ self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
572
+ self.act = nn.ReLU()
573
+ self.block2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
574
+
575
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
576
+ r"""
577
+ This function takes input tensor `x`, processes it through one convolutional layer, ReLU activation, and
+ another convolutional layer, and adds the result to the input tensor.
579
+ """
580
+
581
+ h = self.act(self.block1(x))
582
+ h = self.block2(h)
583
+
584
+ return h + x
diffusers/models/attention.py ADDED
@@ -0,0 +1,398 @@
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional
15
+
16
+ import torch
17
+ from torch import nn
18
+
19
+ from ..utils import USE_PEFT_BACKEND
20
+ from ..utils.torch_utils import maybe_allow_in_graph
21
+ from .activations import GEGLU, GELU, ApproximateGELU
22
+ from .attention_processor import Attention
23
+ from .embeddings import SinusoidalPositionalEmbedding
24
+ from .lora import LoRACompatibleLinear
25
+ from .normalization import AdaLayerNorm, AdaLayerNormZero
26
+
27
+
28
+ @maybe_allow_in_graph
29
+ class GatedSelfAttentionDense(nn.Module):
30
+ r"""
31
+ A gated self-attention dense layer that combines visual features and object features.
32
+
33
+ Parameters:
34
+ query_dim (`int`): The number of channels in the query.
35
+ context_dim (`int`): The number of channels in the context.
36
+ n_heads (`int`): The number of heads to use for attention.
37
+ d_head (`int`): The number of channels in each head.
38
+ """
39
+
40
+ def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
41
+ super().__init__()
42
+
43
+ # we need a linear projection since we concatenate the visual features and object features
44
+ self.linear = nn.Linear(context_dim, query_dim)
45
+
46
+ self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
47
+ self.ff = FeedForward(query_dim, activation_fn="geglu")
48
+
49
+ self.norm1 = nn.LayerNorm(query_dim)
50
+ self.norm2 = nn.LayerNorm(query_dim)
51
+
52
+ self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
53
+ self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
54
+
55
+ self.enabled = True
56
+
57
+ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
58
+ if not self.enabled:
59
+ return x
60
+
61
+ n_visual = x.shape[1]
62
+ objs = self.linear(objs)
63
+
64
+ x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
65
+ x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
66
+
67
+ return x
68
+
69
+
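A small sketch of the gating contract above: `x` carries the visual tokens, `objs` carries the grounding tokens, and only the first `n_visual` positions of the attention output are added back. The dimensions below are illustrative, and the import path assumes this file lives at `diffusers.models.attention`.

import torch
from diffusers.models.attention import GatedSelfAttentionDense

fuser = GatedSelfAttentionDense(query_dim=320, context_dim=768, n_heads=8, d_head=40)

visual_tokens = torch.randn(2, 256, 320)    # (batch, n_visual, query_dim)
grounding_tokens = torch.randn(2, 30, 768)  # (batch, n_objs, context_dim)

out = fuser(visual_tokens, grounding_tokens)
print(out.shape)  # torch.Size([2, 256, 320]) -- same shape as the visual stream

# alpha_attn and alpha_dense start at 0, so tanh(alpha) = 0 and the module is an identity
# at initialization; setting `enabled = False` bypasses it entirely.
fuser.enabled = False
assert torch.equal(fuser(visual_tokens, grounding_tokens), visual_tokens)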
70
+ @maybe_allow_in_graph
71
+ class BasicTransformerBlock(nn.Module):
72
+ r"""
73
+ A basic Transformer block.
74
+
75
+ Parameters:
76
+ dim (`int`): The number of channels in the input and output.
77
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
78
+ attention_head_dim (`int`): The number of channels in each head.
79
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
80
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
81
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
82
+ num_embeds_ada_norm (:
83
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
84
+ attention_bias (:
85
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
86
+ only_cross_attention (`bool`, *optional*):
87
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
88
+ double_self_attention (`bool`, *optional*):
89
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
90
+ upcast_attention (`bool`, *optional*):
91
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
92
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
93
+ Whether to use learnable elementwise affine parameters for normalization.
94
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
95
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
96
+ final_dropout (`bool` *optional*, defaults to False):
97
+ Whether to apply a final dropout after the last feed-forward layer.
98
+ attention_type (`str`, *optional*, defaults to `"default"`):
99
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
100
+ positional_embeddings (`str`, *optional*, defaults to `None`):
101
+ The type of positional embeddings to apply to.
102
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
103
+ The maximum number of positional embeddings to apply.
104
+ """
105
+
106
+ def __init__(
107
+ self,
108
+ dim: int,
109
+ num_attention_heads: int,
110
+ attention_head_dim: int,
111
+ dropout=0.0,
112
+ cross_attention_dim: Optional[int] = None,
113
+ activation_fn: str = "geglu",
114
+ num_embeds_ada_norm: Optional[int] = None,
115
+ attention_bias: bool = False,
116
+ only_cross_attention: bool = False,
117
+ double_self_attention: bool = False,
118
+ upcast_attention: bool = False,
119
+ norm_elementwise_affine: bool = True,
120
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
121
+ norm_eps: float = 1e-5,
122
+ final_dropout: bool = False,
123
+ attention_type: str = "default",
124
+ positional_embeddings: Optional[str] = None,
125
+ num_positional_embeddings: Optional[int] = None,
126
+ ):
127
+ super().__init__()
128
+ self.only_cross_attention = only_cross_attention
129
+
130
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
131
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
132
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
133
+ self.use_layer_norm = norm_type == "layer_norm"
134
+
135
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
136
+ raise ValueError(
137
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
138
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
139
+ )
140
+
141
+ if positional_embeddings and (num_positional_embeddings is None):
142
+ raise ValueError(
143
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
144
+ )
145
+
146
+ if positional_embeddings == "sinusoidal":
147
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
148
+ else:
149
+ self.pos_embed = None
150
+
151
+ # Define 3 blocks. Each block has its own normalization layer.
152
+ # 1. Self-Attn
153
+ if self.use_ada_layer_norm:
154
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
155
+ elif self.use_ada_layer_norm_zero:
156
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
157
+ else:
158
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
159
+
160
+ self.attn1 = Attention(
161
+ query_dim=dim,
162
+ heads=num_attention_heads,
163
+ dim_head=attention_head_dim,
164
+ dropout=dropout,
165
+ bias=attention_bias,
166
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
167
+ upcast_attention=upcast_attention,
168
+ )
169
+
170
+ # 2. Cross-Attn
171
+ if cross_attention_dim is not None or double_self_attention:
172
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
173
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
174
+ # the second cross attention block.
175
+ self.norm2 = (
176
+ AdaLayerNorm(dim, num_embeds_ada_norm)
177
+ if self.use_ada_layer_norm
178
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
179
+ )
180
+ self.attn2 = Attention(
181
+ query_dim=dim,
182
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
183
+ heads=num_attention_heads,
184
+ dim_head=attention_head_dim,
185
+ dropout=dropout,
186
+ bias=attention_bias,
187
+ upcast_attention=upcast_attention,
188
+ ) # is self-attn if encoder_hidden_states is none
189
+ else:
190
+ self.norm2 = None
191
+ self.attn2 = None
192
+
193
+ # 3. Feed-forward
194
+ if not self.use_ada_layer_norm_single:
195
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
196
+
197
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
198
+
199
+ # 4. Fuser
200
+ if attention_type == "gated" or attention_type == "gated-text-image":
201
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
202
+
203
+ # 5. Scale-shift for PixArt-Alpha.
204
+ if self.use_ada_layer_norm_single:
205
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
206
+
207
+ # let chunk size default to None
208
+ self._chunk_size = None
209
+ self._chunk_dim = 0
210
+
211
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
212
+ # Sets chunk feed-forward
213
+ self._chunk_size = chunk_size
214
+ self._chunk_dim = dim
215
+
216
+ def forward(
217
+ self,
218
+ hidden_states: torch.FloatTensor,
219
+ attention_mask: Optional[torch.FloatTensor] = None,
220
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
221
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
222
+ timestep: Optional[torch.LongTensor] = None,
223
+ cross_attention_kwargs: Dict[str, Any] = None,
224
+ class_labels: Optional[torch.LongTensor] = None,
225
+ ) -> torch.FloatTensor:
226
+ # Notice that normalization is always applied before the real computation in the following blocks.
227
+ # 0. Self-Attention
228
+ batch_size = hidden_states.shape[0]
229
+
230
+ if self.use_ada_layer_norm:
231
+ norm_hidden_states = self.norm1(hidden_states, timestep)
232
+ elif self.use_ada_layer_norm_zero:
233
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
234
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
235
+ )
236
+ elif self.use_layer_norm:
237
+ norm_hidden_states = self.norm1(hidden_states)
238
+ elif self.use_ada_layer_norm_single:
239
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
240
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
241
+ ).chunk(6, dim=1)
242
+ norm_hidden_states = self.norm1(hidden_states)
243
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
244
+ norm_hidden_states = norm_hidden_states.squeeze(1)
245
+ else:
246
+ raise ValueError("Incorrect norm used")
247
+
248
+ if self.pos_embed is not None:
249
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
250
+
251
+ # 1. Retrieve lora scale.
252
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
253
+
254
+ # 2. Prepare GLIGEN inputs
255
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
256
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
257
+
258
+ attn_output = self.attn1(
259
+ norm_hidden_states, # e.g. (batch, sequence_length, dim) = (32, 4096, 320)
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, # e.g. (32, 77, 768)
261
+ attention_mask=attention_mask,
262
+ **cross_attention_kwargs,
263
+ )
264
+ if self.use_ada_layer_norm_zero:
265
+ attn_output = gate_msa.unsqueeze(1) * attn_output
266
+ elif self.use_ada_layer_norm_single:
267
+ attn_output = gate_msa * attn_output
268
+
269
+ hidden_states = attn_output + hidden_states
270
+ if hidden_states.ndim == 4:
271
+ hidden_states = hidden_states.squeeze(1)
272
+
273
+ # 2.5 GLIGEN Control
274
+ if gligen_kwargs is not None:
275
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
276
+
277
+ # 3. Cross-Attention
278
+ if self.attn2 is not None:
279
+ if self.use_ada_layer_norm:
280
+ norm_hidden_states = self.norm2(hidden_states, timestep)
281
+ elif self.use_ada_layer_norm_zero or self.use_layer_norm:
282
+ norm_hidden_states = self.norm2(hidden_states)
283
+ elif self.use_ada_layer_norm_single:
284
+ # For PixArt norm2 isn't applied here:
285
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
286
+ norm_hidden_states = hidden_states
287
+ else:
288
+ raise ValueError("Incorrect norm")
289
+
290
+ if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
291
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
292
+
293
+ attn_output = self.attn2(
294
+ norm_hidden_states,
295
+ encoder_hidden_states=encoder_hidden_states,
296
+ attention_mask=encoder_attention_mask,
297
+ **cross_attention_kwargs,
298
+ )
299
301
+ hidden_states = attn_output + hidden_states
302
+
303
+ # 4. Feed-forward
304
+ if not self.use_ada_layer_norm_single:
305
+ norm_hidden_states = self.norm3(hidden_states)
306
+
307
+ if self.use_ada_layer_norm_zero:
308
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
309
+
310
+ if self.use_ada_layer_norm_single:
311
+ norm_hidden_states = self.norm2(hidden_states)
312
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
313
+
314
+ if self._chunk_size is not None:
315
+ # "feed_forward_chunk_size" can be used to save memory
316
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
317
+ raise ValueError(
318
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
319
+ )
320
+
321
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
322
+ ff_output = torch.cat(
323
+ [
324
+ self.ff(hid_slice, scale=lora_scale)
325
+ for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
326
+ ],
327
+ dim=self._chunk_dim,
328
+ )
329
+ else:
330
+ ff_output = self.ff(norm_hidden_states, scale=lora_scale)
331
+
332
+ if self.use_ada_layer_norm_zero:
333
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
334
+ elif self.use_ada_layer_norm_single:
335
+ ff_output = gate_mlp * ff_output
336
+
337
+ hidden_states = ff_output + hidden_states
338
+ if hidden_states.ndim == 4:
339
+ hidden_states = hidden_states.squeeze(1)
340
+
341
+ return hidden_states
342
+
343
+
344
+ class FeedForward(nn.Module):
345
+ r"""
346
+ A feed-forward layer.
347
+
348
+ Parameters:
349
+ dim (`int`): The number of channels in the input.
350
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
351
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
352
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
353
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
354
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
355
+ """
356
+
357
+ def __init__(
358
+ self,
359
+ dim: int,
360
+ dim_out: Optional[int] = None,
361
+ mult: int = 4,
362
+ dropout: float = 0.0,
363
+ activation_fn: str = "geglu",
364
+ final_dropout: bool = False,
365
+ ):
366
+ super().__init__()
367
+ inner_dim = int(dim * mult)
368
+ dim_out = dim_out if dim_out is not None else dim
369
+ linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
370
+
371
+ if activation_fn == "gelu":
372
+ act_fn = GELU(dim, inner_dim)
373
+ if activation_fn == "gelu-approximate":
374
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
375
+ elif activation_fn == "geglu":
376
+ act_fn = GEGLU(dim, inner_dim)
377
+ elif activation_fn == "geglu-approximate":
378
+ act_fn = ApproximateGELU(dim, inner_dim)
379
+
380
+ self.net = nn.ModuleList([])
381
+ # project in
382
+ self.net.append(act_fn)
383
+ # project dropout
384
+ self.net.append(nn.Dropout(dropout))
385
+ # project out
386
+ self.net.append(linear_cls(inner_dim, dim_out))
387
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
388
+ if final_dropout:
389
+ self.net.append(nn.Dropout(dropout))
390
+
391
+ def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
392
+ compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
393
+ for module in self.net:
394
+ if isinstance(module, compatible_cls):
395
+ hidden_states = module(hidden_states, scale)
396
+ else:
397
+ hidden_states = module(hidden_states)
398
+ return hidden_states
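A quick sketch of the default GEGLU feed-forward: the first projection produces gated features of width `dim * mult` (internally twice that width before gating, mirroring the FlaxGEGLU module later in this upload), and the output linear maps back to `dim_out`. The sizes below are illustrative.

import torch
from diffusers.models.attention import FeedForward

ff = FeedForward(dim=320, mult=4, activation_fn="geglu")

x = torch.randn(2, 1024, 320)
print(ff(x).shape)  # torch.Size([2, 1024, 320]); inner width is 320 * 4 = 1280
print(ff.net)       # ModuleList: [GEGLU, Dropout, Linear(1280 -> 320)]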
diffusers/models/attention_flax.py ADDED
@@ -0,0 +1,486 @@
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import functools
16
+ import math
17
+
18
+ import flax.linen as nn
19
+ import jax
20
+ import jax.numpy as jnp
21
+
22
+
23
+ def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
24
+ """Multi-head dot product attention with a limited number of queries."""
25
+ num_kv, num_heads, k_features = key.shape[-3:]
26
+ v_features = value.shape[-1]
27
+ key_chunk_size = min(key_chunk_size, num_kv)
28
+ query = query / jnp.sqrt(k_features)
29
+
30
+ @functools.partial(jax.checkpoint, prevent_cse=False)
31
+ def summarize_chunk(query, key, value):
32
+ attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision)
33
+
34
+ max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
35
+ max_score = jax.lax.stop_gradient(max_score)
36
+ exp_weights = jnp.exp(attn_weights - max_score)
37
+
38
+ exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision)
39
+ max_score = jnp.einsum("...qhk->...qh", max_score)
40
+
41
+ return (exp_values, exp_weights.sum(axis=-1), max_score)
42
+
43
+ def chunk_scanner(chunk_idx):
44
+ # julienne key array
45
+ key_chunk = jax.lax.dynamic_slice(
46
+ operand=key,
47
+ start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0], # [...,k,h,d]
48
+ slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features], # [...,k,h,d]
49
+ )
50
+
51
+ # julienne value array
52
+ value_chunk = jax.lax.dynamic_slice(
53
+ operand=value,
54
+ start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0], # [...,v,h,d]
55
+ slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features], # [...,v,h,d]
56
+ )
57
+
58
+ return summarize_chunk(query, key_chunk, value_chunk)
59
+
60
+ chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))
61
+
62
+ global_max = jnp.max(chunk_max, axis=0, keepdims=True)
63
+ max_diffs = jnp.exp(chunk_max - global_max)
64
+
65
+ chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
66
+ chunk_weights *= max_diffs
67
+
68
+ all_values = chunk_values.sum(axis=0)
69
+ all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)
70
+
71
+ return all_values / all_weights
72
+
73
+
74
+ def jax_memory_efficient_attention(
75
+ query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
76
+ ):
77
+ r"""
78
+ Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2
79
+ https://github.com/AminRezaei0x443/memory-efficient-attention
80
+
81
+ Args:
82
+ query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
83
+ key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
84
+ value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
85
+ precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
86
+ numerical precision for computation
87
+ query_chunk_size (`int`, *optional*, defaults to 1024):
88
+ chunk size to divide the query array; the value must divide query_length equally without remainder
89
+ key_chunk_size (`int`, *optional*, defaults to 4096):
90
+ chunk size to divide the key and value arrays; the value must divide key_value_length equally without remainder
91
+
92
+ Returns:
93
+ (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
94
+ """
95
+ num_q, num_heads, q_features = query.shape[-3:]
96
+
97
+ def chunk_scanner(chunk_idx, _):
98
+ # julienne query array
99
+ query_chunk = jax.lax.dynamic_slice(
100
+ operand=query,
101
+ start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0], # [...,q,h,d]
102
+ slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features], # [...,q,h,d]
103
+ )
104
+
105
+ return (
106
+ chunk_idx + query_chunk_size, # advance the carry to the next chunk start; the final carry is discarded
107
+ _query_chunk_attention(
108
+ query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
109
+ ),
110
+ )
111
+
112
+ _, res = jax.lax.scan(
113
+ f=chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size) # init: first chunk offset; length: number of query chunks
114
+ )
115
+
116
+ return jnp.concatenate(res, axis=-3) # fuse the chunked result back
117
+
118
+
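A small numerical check of the chunked attention above against a naive softmax attention, assuming the module is importable as `diffusers.models.attention_flax` and that JAX and Flax are installed; shapes follow the docstring, and the chunk sizes must divide the query and key/value lengths.

import jax
import jax.numpy as jnp
from diffusers.models.attention_flax import jax_memory_efficient_attention

q_rng, k_rng, v_rng = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(q_rng, (2, 1024, 8, 64))  # (batch, q_len, heads, dim_per_head)
key = jax.random.normal(k_rng, (2, 1024, 8, 64))
value = jax.random.normal(v_rng, (2, 1024, 8, 64))

chunked = jax_memory_efficient_attention(query, key, value, query_chunk_size=256, key_chunk_size=512)

# Naive reference: softmax(Q K^T / sqrt(d)) V, computed per head.
scores = jnp.einsum("bqhd,bkhd->bhqk", query, key) / jnp.sqrt(query.shape[-1])
reference = jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(scores, axis=-1), value)

print(chunked.shape)                          # (2, 1024, 8, 64)
print(jnp.max(jnp.abs(chunked - reference)))  # ~0, up to floating-point tolerance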
119
+ class FlaxAttention(nn.Module):
120
+ r"""
121
+ A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
122
+
123
+ Parameters:
124
+ query_dim (:obj:`int`):
125
+ Input hidden states dimension
126
+ heads (:obj:`int`, *optional*, defaults to 8):
127
+ Number of heads
128
+ dim_head (:obj:`int`, *optional*, defaults to 64):
129
+ Hidden states dimension inside each head
130
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
131
+ Dropout rate
132
+ use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
133
+ enable memory efficient attention https://arxiv.org/abs/2112.05682
134
+ split_head_dim (`bool`, *optional*, defaults to `False`):
135
+ Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
136
+ enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
137
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
138
+ Parameters `dtype`
139
+
140
+ """
141
+ query_dim: int
142
+ heads: int = 8
143
+ dim_head: int = 64
144
+ dropout: float = 0.0
145
+ use_memory_efficient_attention: bool = False
146
+ split_head_dim: bool = False
147
+ dtype: jnp.dtype = jnp.float32
148
+
149
+ def setup(self):
150
+ inner_dim = self.dim_head * self.heads
151
+ self.scale = self.dim_head**-0.5
152
+
153
+ # Weights were exported with old names {to_q, to_k, to_v, to_out}
154
+ self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q")
155
+ self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
156
+ self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")
157
+
158
+ self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
159
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
160
+
161
+ def reshape_heads_to_batch_dim(self, tensor):
162
+ batch_size, seq_len, dim = tensor.shape
163
+ head_size = self.heads
164
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
165
+ tensor = jnp.transpose(tensor, (0, 2, 1, 3))
166
+ tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
167
+ return tensor
168
+
169
+ def reshape_batch_dim_to_heads(self, tensor):
170
+ batch_size, seq_len, dim = tensor.shape
171
+ head_size = self.heads
172
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
173
+ tensor = jnp.transpose(tensor, (0, 2, 1, 3))
174
+ tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
175
+ return tensor
176
+
177
+ def __call__(self, hidden_states, context=None, deterministic=True):
178
+ context = hidden_states if context is None else context
179
+
180
+ query_proj = self.query(hidden_states)
181
+ key_proj = self.key(context)
182
+ value_proj = self.value(context)
183
+
184
+ if self.split_head_dim:
185
+ b = hidden_states.shape[0]
186
+ query_states = jnp.reshape(query_proj, (b, -1, self.heads, self.dim_head))
187
+ key_states = jnp.reshape(key_proj, (b, -1, self.heads, self.dim_head))
188
+ value_states = jnp.reshape(value_proj, (b, -1, self.heads, self.dim_head))
189
+ else:
190
+ query_states = self.reshape_heads_to_batch_dim(query_proj)
191
+ key_states = self.reshape_heads_to_batch_dim(key_proj)
192
+ value_states = self.reshape_heads_to_batch_dim(value_proj)
193
+
194
+ if self.use_memory_efficient_attention:
195
+ query_states = query_states.transpose(1, 0, 2)
196
+ key_states = key_states.transpose(1, 0, 2)
197
+ value_states = value_states.transpose(1, 0, 2)
198
+
199
+ # this if statement creates a chunk size for each layer of the unet
200
+ # the chunk size is equal to the query_length dimension of the deepest layer of the unet
201
+
202
+ flatten_latent_dim = query_states.shape[-3]
203
+ if flatten_latent_dim % 64 == 0:
204
+ query_chunk_size = int(flatten_latent_dim / 64)
205
+ elif flatten_latent_dim % 16 == 0:
206
+ query_chunk_size = int(flatten_latent_dim / 16)
207
+ elif flatten_latent_dim % 4 == 0:
208
+ query_chunk_size = int(flatten_latent_dim / 4)
209
+ else:
210
+ query_chunk_size = int(flatten_latent_dim)
211
+
212
+ hidden_states = jax_memory_efficient_attention(
213
+ query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
214
+ )
215
+
216
+ hidden_states = hidden_states.transpose(1, 0, 2)
217
+ else:
218
+ # compute attentions
219
+ if self.split_head_dim:
220
+ attention_scores = jnp.einsum("b t n h, b f n h -> b n f t", key_states, query_states)
221
+ else:
222
+ attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
223
+
224
+ attention_scores = attention_scores * self.scale
225
+ attention_probs = nn.softmax(attention_scores, axis=-1 if self.split_head_dim else 2)
226
+
227
+ # attend to values
228
+ if self.split_head_dim:
229
+ hidden_states = jnp.einsum("b n f t, b t n h -> b f n h", attention_probs, value_states)
230
+ b = hidden_states.shape[0]
231
+ hidden_states = jnp.reshape(hidden_states, (b, -1, self.heads * self.dim_head))
232
+ else:
233
+ hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
234
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
235
+
236
+ hidden_states = self.proj_attn(hidden_states)
237
+ return self.dropout_layer(hidden_states, deterministic=deterministic)
238
+
239
+
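A sketch of the usual Flax init/apply cycle for the attention module above; shapes are illustrative, and passing a `context` of a different width gives cross-attention (omitting it falls back to self-attention).

import jax
import jax.numpy as jnp
from diffusers.models.attention_flax import FlaxAttention

attn = FlaxAttention(query_dim=320, heads=8, dim_head=40)

hidden_states = jnp.ones((2, 256, 320))  # (batch, tokens, query_dim)
context = jnp.ones((2, 77, 768))         # e.g. text embeddings

# Flax modules are stateless: init() builds the parameter pytree, apply() runs the forward pass.
params = attn.init(jax.random.PRNGKey(0), hidden_states, context)
out = attn.apply(params, hidden_states, context)
print(out.shape)  # (2, 256, 320)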
240
+ class FlaxBasicTransformerBlock(nn.Module):
241
+ r"""
242
+ A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in:
243
+ https://arxiv.org/abs/1706.03762
244
+
245
+
246
+ Parameters:
247
+ dim (:obj:`int`):
248
+ Inner hidden states dimension
249
+ n_heads (:obj:`int`):
250
+ Number of heads
251
+ d_head (:obj:`int`):
252
+ Hidden states dimension inside each head
253
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
254
+ Dropout rate
255
+ only_cross_attention (`bool`, defaults to `False`):
256
+ Whether to only apply cross attention.
257
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
258
+ Parameters `dtype`
259
+ use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
260
+ enable memory efficient attention https://arxiv.org/abs/2112.05682
261
+ split_head_dim (`bool`, *optional*, defaults to `False`):
262
+ Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
263
+ enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
264
+ """
265
+ dim: int
266
+ n_heads: int
267
+ d_head: int
268
+ dropout: float = 0.0
269
+ only_cross_attention: bool = False
270
+ dtype: jnp.dtype = jnp.float32
271
+ use_memory_efficient_attention: bool = False
272
+ split_head_dim: bool = False
273
+
274
+ def setup(self):
275
+ # self attention (or cross_attention if only_cross_attention is True)
276
+ self.attn1 = FlaxAttention(
277
+ self.dim,
278
+ self.n_heads,
279
+ self.d_head,
280
+ self.dropout,
281
+ self.use_memory_efficient_attention,
282
+ self.split_head_dim,
283
+ dtype=self.dtype,
284
+ )
285
+ # cross attention
286
+ self.attn2 = FlaxAttention(
287
+ self.dim,
288
+ self.n_heads,
289
+ self.d_head,
290
+ self.dropout,
291
+ self.use_memory_efficient_attention,
292
+ self.split_head_dim,
293
+ dtype=self.dtype,
294
+ )
295
+ self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
296
+ self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
297
+ self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
298
+ self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
299
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
300
+
301
+ def __call__(self, hidden_states, context, deterministic=True):
302
+ # self attention
303
+ residual = hidden_states
304
+ if self.only_cross_attention:
305
+ hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic)
306
+ else:
307
+ hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
308
+ hidden_states = hidden_states + residual
309
+
310
+ # cross attention
311
+ residual = hidden_states
312
+ hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic)
313
+ hidden_states = hidden_states + residual
314
+
315
+ # feed forward
316
+ residual = hidden_states
317
+ hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
318
+ hidden_states = hidden_states + residual
319
+
320
+ return self.dropout_layer(hidden_states, deterministic=deterministic)
321
+
322
+
323
+ class FlaxTransformer2DModel(nn.Module):
324
+ r"""
325
+ A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
326
+ https://arxiv.org/pdf/1506.02025.pdf
327
+
328
+
329
+ Parameters:
330
+ in_channels (:obj:`int`):
331
+ Input number of channels
332
+ n_heads (:obj:`int`):
333
+ Number of heads
334
+ d_head (:obj:`int`):
335
+ Hidden states dimension inside each head
336
+ depth (:obj:`int`, *optional*, defaults to 1):
337
+ Number of transformers block
338
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
339
+ Dropout rate
340
+ use_linear_projection (`bool`, defaults to `False`): Whether to project in/out with a Dense layer instead of a 1x1 convolution
341
+ only_cross_attention (`bool`, defaults to `False`): Whether the transformer blocks apply only cross-attention (no self-attention)
342
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
343
+ Parameters `dtype`
344
+ use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
345
+ enable memory efficient attention https://arxiv.org/abs/2112.05682
346
+ split_head_dim (`bool`, *optional*, defaults to `False`):
347
+ Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
348
+ enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
349
+ """
350
+ in_channels: int
351
+ n_heads: int
352
+ d_head: int
353
+ depth: int = 1
354
+ dropout: float = 0.0
355
+ use_linear_projection: bool = False
356
+ only_cross_attention: bool = False
357
+ dtype: jnp.dtype = jnp.float32
358
+ use_memory_efficient_attention: bool = False
359
+ split_head_dim: bool = False
360
+
361
+ def setup(self):
362
+ self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
363
+
364
+ inner_dim = self.n_heads * self.d_head
365
+ if self.use_linear_projection:
366
+ self.proj_in = nn.Dense(inner_dim, dtype=self.dtype)
367
+ else:
368
+ self.proj_in = nn.Conv(
369
+ inner_dim,
370
+ kernel_size=(1, 1),
371
+ strides=(1, 1),
372
+ padding="VALID",
373
+ dtype=self.dtype,
374
+ )
375
+
376
+ self.transformer_blocks = [
377
+ FlaxBasicTransformerBlock(
378
+ inner_dim,
379
+ self.n_heads,
380
+ self.d_head,
381
+ dropout=self.dropout,
382
+ only_cross_attention=self.only_cross_attention,
383
+ dtype=self.dtype,
384
+ use_memory_efficient_attention=self.use_memory_efficient_attention,
385
+ split_head_dim=self.split_head_dim,
386
+ )
387
+ for _ in range(self.depth)
388
+ ]
389
+
390
+ if self.use_linear_projection:
391
+ self.proj_out = nn.Dense(inner_dim, dtype=self.dtype)
392
+ else:
393
+ self.proj_out = nn.Conv(
394
+ inner_dim,
395
+ kernel_size=(1, 1),
396
+ strides=(1, 1),
397
+ padding="VALID",
398
+ dtype=self.dtype,
399
+ )
400
+
401
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
402
+
403
+ def __call__(self, hidden_states, context, deterministic=True):
404
+ batch, height, width, channels = hidden_states.shape
405
+ residual = hidden_states
406
+ hidden_states = self.norm(hidden_states)
407
+ if self.use_linear_projection:
408
+ hidden_states = hidden_states.reshape(batch, height * width, channels)
409
+ hidden_states = self.proj_in(hidden_states)
410
+ else:
411
+ hidden_states = self.proj_in(hidden_states)
412
+ hidden_states = hidden_states.reshape(batch, height * width, channels)
413
+
414
+ for transformer_block in self.transformer_blocks:
415
+ hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)
416
+
417
+ if self.use_linear_projection:
418
+ hidden_states = self.proj_out(hidden_states)
419
+ hidden_states = hidden_states.reshape(batch, height, width, channels)
420
+ else:
421
+ hidden_states = hidden_states.reshape(batch, height, width, channels)
422
+ hidden_states = self.proj_out(hidden_states)
423
+
424
+ hidden_states = hidden_states + residual
425
+ return self.dropout_layer(hidden_states, deterministic=deterministic)
426
+
427
+
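The spatial transformer expects channels-last (NHWC) inputs, matching Flax convolution conventions; a sketch with illustrative sizes, using the Dense projection path:

import jax
import jax.numpy as jnp
from diffusers.models.attention_flax import FlaxTransformer2DModel

model = FlaxTransformer2DModel(
    in_channels=320,
    n_heads=8,
    d_head=40,
    depth=1,
    use_linear_projection=True,  # Dense in/out projection instead of a 1x1 convolution
)

hidden_states = jnp.ones((1, 32, 32, 320))  # (batch, height, width, channels)
context = jnp.ones((1, 77, 768))            # text-encoder states for cross-attention

params = model.init(jax.random.PRNGKey(0), hidden_states, context)
out = model.apply(params, hidden_states, context)
print(out.shape)  # (1, 32, 32, 320): same layout, with the residual added back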
428
+ class FlaxFeedForward(nn.Module):
429
+ r"""
430
+ Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's
431
+ [`FeedForward`] class, with the following simplifications:
432
+ - The activation function is currently hardcoded to a gated linear unit from:
433
+ https://arxiv.org/abs/2002.05202
434
+ - `dim_out` is equal to `dim`.
435
+ - The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGELU`].
436
+
437
+ Parameters:
438
+ dim (:obj:`int`):
439
+ Inner hidden states dimension
440
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
441
+ Dropout rate
442
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
443
+ Parameters `dtype`
444
+ """
445
+ dim: int
446
+ dropout: float = 0.0
447
+ dtype: jnp.dtype = jnp.float32
448
+
449
+ def setup(self):
450
+ # The second linear layer needs to be called
451
+ # net_2 for now to match the index of the Sequential layer
452
+ self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
453
+ self.net_2 = nn.Dense(self.dim, dtype=self.dtype)
454
+
455
+ def __call__(self, hidden_states, deterministic=True):
456
+ hidden_states = self.net_0(hidden_states, deterministic=deterministic)
457
+ hidden_states = self.net_2(hidden_states)
458
+ return hidden_states
459
+
460
+
461
+ class FlaxGEGLU(nn.Module):
462
+ r"""
463
+ Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
464
+ https://arxiv.org/abs/2002.05202.
465
+
466
+ Parameters:
467
+ dim (:obj:`int`):
468
+ Input hidden states dimension
469
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
470
+ Dropout rate
471
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
472
+ Parameters `dtype`
473
+ """
474
+ dim: int
475
+ dropout: float = 0.0
476
+ dtype: jnp.dtype = jnp.float32
477
+
478
+ def setup(self):
479
+ inner_dim = self.dim * 4
480
+ self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
481
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
482
+
483
+ def __call__(self, hidden_states, deterministic=True):
484
+ hidden_states = self.proj(hidden_states)
485
+ hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
486
+ return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
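A minimal standalone sketch (illustration only, not part of the uploaded file) of the GEGLU gating implemented above: the input is projected to twice the inner width, split into a value half and a gate half, and the gate is passed through GELU before the elementwise product. `TinyGEGLU` is a hypothetical name, and the snippet assumes `jax` and `flax` are installed.

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyGEGLU(nn.Module):
    dim: int

    @nn.compact
    def __call__(self, x):
        proj = nn.Dense(self.dim * 4 * 2)(x)        # project to 2 * inner_dim (inner_dim = dim * 4)
        linear, gate = jnp.split(proj, 2, axis=-1)  # value half and gate half
        return linear * nn.gelu(gate)               # gated linear unit

x = jnp.ones((1, 16, 8))                            # (batch, tokens, dim)
variables = TinyGEGLU(dim=8).init(jax.random.PRNGKey(0), x)
y = TinyGEGLU(dim=8).apply(variables, x)            # shape (1, 16, 32)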
diffusers/models/attention_processor.py ADDED
@@ -0,0 +1,2020 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from importlib import import_module
15
+ from typing import Callable, Optional, Union
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from ..utils import USE_PEFT_BACKEND, deprecate, logging
22
+ from ..utils.import_utils import is_xformers_available
23
+ from ..utils.torch_utils import maybe_allow_in_graph
24
+ from .lora import LoRACompatibleLinear, LoRALinearLayer
25
+
26
+
27
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28
+
29
+
30
+ if is_xformers_available():
31
+ import xformers
32
+ import xformers.ops
33
+ else:
34
+ xformers = None
35
+
36
+
37
+ @maybe_allow_in_graph
38
+ class Attention(nn.Module):
39
+ r"""
40
+ A cross attention layer.
41
+
42
+ Parameters:
43
+ query_dim (`int`):
44
+ The number of channels in the query.
45
+ cross_attention_dim (`int`, *optional*):
46
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
47
+ heads (`int`, *optional*, defaults to 8):
48
+ The number of heads to use for multi-head attention.
49
+ dim_head (`int`, *optional*, defaults to 64):
50
+ The number of channels in each head.
51
+ dropout (`float`, *optional*, defaults to 0.0):
52
+ The dropout probability to use.
53
+ bias (`bool`, *optional*, defaults to False):
54
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
55
+ upcast_attention (`bool`, *optional*, defaults to False):
56
+ Set to `True` to upcast the attention computation to `float32`.
57
+ upcast_softmax (`bool`, *optional*, defaults to False):
58
+ Set to `True` to upcast the softmax computation to `float32`.
59
+ cross_attention_norm (`str`, *optional*, defaults to `None`):
60
+ The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
61
+ cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
62
+ The number of groups to use for the group norm in the cross attention.
63
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`):
64
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
65
+ norm_num_groups (`int`, *optional*, defaults to `None`):
66
+ The number of groups to use for the group norm in the attention.
67
+ spatial_norm_dim (`int`, *optional*, defaults to `None`):
68
+ The number of channels to use for the spatial normalization.
69
+ out_bias (`bool`, *optional*, defaults to `True`):
70
+ Set to `True` to use a bias in the output linear layer.
71
+ scale_qk (`bool`, *optional*, defaults to `True`):
72
+ Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
73
+ only_cross_attention (`bool`, *optional*, defaults to `False`):
74
+ Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
75
+ `added_kv_proj_dim` is not `None`.
76
+ eps (`float`, *optional*, defaults to 1e-5):
77
+ An additional value added to the denominator in group normalization that is used for numerical stability.
78
+ rescale_output_factor (`float`, *optional*, defaults to 1.0):
79
+ A factor to rescale the output by dividing it with this value.
80
+ residual_connection (`bool`, *optional*, defaults to `False`):
81
+ Set to `True` to add the residual connection to the output.
82
+ _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
83
+ Set to `True` if the attention block is loaded from a deprecated state dict.
84
+ processor (`AttnProcessor`, *optional*, defaults to `None`):
85
+ The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
86
+ `AttnProcessor` otherwise.
87
+ """
88
+
89
+ def __init__(
90
+ self,
91
+ query_dim: int,
92
+ cross_attention_dim: Optional[int] = None,
93
+ heads: int = 8,
94
+ dim_head: int = 64,
95
+ dropout: float = 0.0,
96
+ bias: bool = False,
97
+ upcast_attention: bool = False,
98
+ upcast_softmax: bool = False,
99
+ cross_attention_norm: Optional[str] = None,
100
+ cross_attention_norm_num_groups: int = 32,
101
+ added_kv_proj_dim: Optional[int] = None,
102
+ norm_num_groups: Optional[int] = None,
103
+ spatial_norm_dim: Optional[int] = None,
104
+ out_bias: bool = True,
105
+ scale_qk: bool = True,
106
+ only_cross_attention: bool = False,
107
+ eps: float = 1e-5,
108
+ rescale_output_factor: float = 1.0,
109
+ residual_connection: bool = False,
110
+ _from_deprecated_attn_block: bool = False,
111
+ processor: Optional["AttnProcessor"] = None,
112
+ ):
113
+ super().__init__()
114
+ self.inner_dim = dim_head * heads
115
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
116
+ self.upcast_attention = upcast_attention
117
+ self.upcast_softmax = upcast_softmax
118
+ self.rescale_output_factor = rescale_output_factor
119
+ self.residual_connection = residual_connection
120
+ self.dropout = dropout
121
+
122
+ # we make use of this private variable to know whether this class is loaded
123
+ # with a deprecated state dict so that we can convert it on the fly
124
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
125
+
126
+ self.scale_qk = scale_qk
127
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
128
+
129
+ self.heads = heads
130
+ # for slice_size > 0 the attention score computation
131
+ # is split across the batch axis to save memory
132
+ # You can set slice_size with `set_attention_slice`
133
+ self.sliceable_head_dim = heads
134
+
135
+ self.added_kv_proj_dim = added_kv_proj_dim
136
+ self.only_cross_attention = only_cross_attention
137
+
138
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
139
+ raise ValueError(
140
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
141
+ )
142
+
143
+ if norm_num_groups is not None:
144
+ self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
145
+ else:
146
+ self.group_norm = None
147
+
148
+ if spatial_norm_dim is not None:
149
+ self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
150
+ else:
151
+ self.spatial_norm = None
152
+
153
+ if cross_attention_norm is None:
154
+ self.norm_cross = None
155
+ elif cross_attention_norm == "layer_norm":
156
+ self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
157
+ elif cross_attention_norm == "group_norm":
158
+ if self.added_kv_proj_dim is not None:
159
+ # The given `encoder_hidden_states` are initially of shape
160
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
161
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
162
+ # before the projection, so we need to use `added_kv_proj_dim` as
163
+ # the number of channels for the group norm.
164
+ norm_cross_num_channels = added_kv_proj_dim
165
+ else:
166
+ norm_cross_num_channels = self.cross_attention_dim
167
+
168
+ self.norm_cross = nn.GroupNorm(
169
+ num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
170
+ )
171
+ else:
172
+ raise ValueError(
173
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
174
+ )
175
+
176
+ if USE_PEFT_BACKEND:
177
+ linear_cls = nn.Linear
178
+ else:
179
+ linear_cls = LoRACompatibleLinear
180
+
181
+ self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
182
+
183
+ if not self.only_cross_attention:
184
+ # only relevant for the `AddedKVProcessor` classes
185
+ self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
186
+ self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
187
+ else:
188
+ self.to_k = None
189
+ self.to_v = None
190
+
191
+ if self.added_kv_proj_dim is not None:
192
+ self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
193
+ self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
194
+
195
+ self.to_out = nn.ModuleList([])
196
+ self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias))
197
+ self.to_out.append(nn.Dropout(dropout))
198
+
199
+ # set attention processor
200
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
201
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
202
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
203
+ if processor is None:
204
+ processor = (
205
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
206
+ )
207
+ self.set_processor(processor)
208
+
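+ # NOTE (comment added for clarity; not in the original upload): for example,
+ # Attention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40) gives
+ # inner_dim = 320, so `to_q` maps 320 -> 320 while `to_k`/`to_v` project the
+ # 768-dim encoder hidden states down to 320 before the heads are split out.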
209
+ def set_use_memory_efficient_attention_xformers(
210
+ self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
211
+ ) -> None:
212
+ r"""
213
+ Set whether to use memory efficient attention from `xformers` or not.
214
+
215
+ Args:
216
+ use_memory_efficient_attention_xformers (`bool`):
217
+ Whether to use memory efficient attention from `xformers` or not.
218
+ attention_op (`Callable`, *optional*):
219
+ The attention operation to use. Defaults to `None` which uses the default attention operation from
220
+ `xformers`.
221
+ """
222
+ is_lora = hasattr(self, "processor") and isinstance(
223
+ self.processor,
224
+ LORA_ATTENTION_PROCESSORS,
225
+ )
226
+ is_custom_diffusion = hasattr(self, "processor") and isinstance(
227
+ self.processor,
228
+ (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0),
229
+ )
230
+ is_added_kv_processor = hasattr(self, "processor") and isinstance(
231
+ self.processor,
232
+ (
233
+ AttnAddedKVProcessor,
234
+ AttnAddedKVProcessor2_0,
235
+ SlicedAttnAddedKVProcessor,
236
+ XFormersAttnAddedKVProcessor,
237
+ LoRAAttnAddedKVProcessor,
238
+ ),
239
+ )
240
+
241
+ if use_memory_efficient_attention_xformers:
242
+ if is_added_kv_processor and (is_lora or is_custom_diffusion):
243
+ raise NotImplementedError(
244
+ f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}"
245
+ )
246
+ if not is_xformers_available():
247
+ raise ModuleNotFoundError(
248
+ (
249
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
250
+ " xformers"
251
+ ),
252
+ name="xformers",
253
+ )
254
+ elif not torch.cuda.is_available():
255
+ raise ValueError(
256
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
257
+ " only available for GPU "
258
+ )
259
+ else:
260
+ try:
261
+ # Make sure we can run the memory efficient attention
262
+ _ = xformers.ops.memory_efficient_attention(
263
+ torch.randn((1, 2, 40), device="cuda"),
264
+ torch.randn((1, 2, 40), device="cuda"),
265
+ torch.randn((1, 2, 40), device="cuda"),
266
+ )
267
+ except Exception as e:
268
+ raise e
269
+
270
+ if is_lora:
271
+ # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
272
+ # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
273
+ processor = LoRAXFormersAttnProcessor(
274
+ hidden_size=self.processor.hidden_size,
275
+ cross_attention_dim=self.processor.cross_attention_dim,
276
+ rank=self.processor.rank,
277
+ attention_op=attention_op,
278
+ )
279
+ processor.load_state_dict(self.processor.state_dict())
280
+ processor.to(self.processor.to_q_lora.up.weight.device)
281
+ elif is_custom_diffusion:
282
+ processor = CustomDiffusionXFormersAttnProcessor(
283
+ train_kv=self.processor.train_kv,
284
+ train_q_out=self.processor.train_q_out,
285
+ hidden_size=self.processor.hidden_size,
286
+ cross_attention_dim=self.processor.cross_attention_dim,
287
+ attention_op=attention_op,
288
+ )
289
+ processor.load_state_dict(self.processor.state_dict())
290
+ if hasattr(self.processor, "to_k_custom_diffusion"):
291
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
292
+ elif is_added_kv_processor:
293
+ # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
294
+ # which uses this type of cross attention ONLY because the attention mask of format
295
+ # [0, ..., -10.000, ..., 0, ...,] is not supported
296
+ # throw warning
297
+ logger.info(
298
+ "Memory efficient attention with `xformers` might currently not work correctly if an attention face_hair_mask is required for the attention operation."
299
+ )
300
+ processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
301
+ else:
302
+ processor = XFormersAttnProcessor(attention_op=attention_op)
303
+ else:
304
+ if is_lora:
305
+ attn_processor_class = (
306
+ LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
307
+ )
308
+ processor = attn_processor_class(
309
+ hidden_size=self.processor.hidden_size,
310
+ cross_attention_dim=self.processor.cross_attention_dim,
311
+ rank=self.processor.rank,
312
+ )
313
+ processor.load_state_dict(self.processor.state_dict())
314
+ processor.to(self.processor.to_q_lora.up.weight.device)
315
+ elif is_custom_diffusion:
316
+ attn_processor_class = (
317
+ CustomDiffusionAttnProcessor2_0
318
+ if hasattr(F, "scaled_dot_product_attention")
319
+ else CustomDiffusionAttnProcessor
320
+ )
321
+ processor = attn_processor_class(
322
+ train_kv=self.processor.train_kv,
323
+ train_q_out=self.processor.train_q_out,
324
+ hidden_size=self.processor.hidden_size,
325
+ cross_attention_dim=self.processor.cross_attention_dim,
326
+ )
327
+ processor.load_state_dict(self.processor.state_dict())
328
+ if hasattr(self.processor, "to_k_custom_diffusion"):
329
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
330
+ else:
331
+ # set attention processor
332
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
333
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
334
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
335
+ processor = (
336
+ AttnProcessor2_0()
337
+ if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
338
+ else AttnProcessor()
339
+ )
340
+
341
+ self.set_processor(processor)
342
+
343
+ def set_attention_slice(self, slice_size: int) -> None:
344
+ r"""
345
+ Set the slice size for attention computation.
346
+
347
+ Args:
348
+ slice_size (`int`):
349
+ The slice size for attention computation.
350
+ """
351
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
352
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
353
+
354
+ if slice_size is not None and self.added_kv_proj_dim is not None:
355
+ processor = SlicedAttnAddedKVProcessor(slice_size)
356
+ elif slice_size is not None:
357
+ processor = SlicedAttnProcessor(slice_size)
358
+ elif self.added_kv_proj_dim is not None:
359
+ processor = AttnAddedKVProcessor()
360
+ else:
361
+ # set attention processor
362
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
363
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
364
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
365
+ processor = (
366
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
367
+ )
368
+
369
+ self.set_processor(processor)
370
+
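+ # NOTE (comment added for clarity; not in the original upload): e.g. calling
+ # set_attention_slice(2) installs SlicedAttnProcessor(2) (or SlicedAttnAddedKVProcessor
+ # when `added_kv_proj_dim` is set), which computes the attention scores in chunks
+ # along the batch axis to lower peak memory at some speed cost.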
371
+ def set_processor(self, processor: "AttnProcessor", _remove_lora: bool = False) -> None:
372
+ r"""
373
+ Set the attention processor to use.
374
+
375
+ Args:
376
+ processor (`AttnProcessor`):
377
+ The attention processor to use.
378
+ _remove_lora (`bool`, *optional*, defaults to `False`):
379
+ Set to `True` to remove LoRA layers from the model.
380
+ """
381
+ if not USE_PEFT_BACKEND and hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None:
382
+ deprecate(
383
+ "set_processor to offload LoRA",
384
+ "0.26.0",
385
+ "In detail, removing LoRA layers via calling `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
386
+ )
387
+ # TODO(Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete
388
+ # We need to remove all LoRA layers
389
+ # Don't forget to remove ALL `_remove_lora` from the codebase
390
+ for module in self.modules():
391
+ if hasattr(module, "set_lora_layer"):
392
+ module.set_lora_layer(None)
393
+
394
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
395
+ # pop `processor` from `self._modules`
396
+ if (
397
+ hasattr(self, "processor")
398
+ and isinstance(self.processor, torch.nn.Module)
399
+ and not isinstance(processor, torch.nn.Module)
400
+ ):
401
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
402
+ self._modules.pop("processor")
403
+
404
+ self.processor = processor
405
+
406
+ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
407
+ r"""
408
+ Get the attention processor in use.
409
+
410
+ Args:
411
+ return_deprecated_lora (`bool`, *optional*, defaults to `False`):
412
+ Set to `True` to return the deprecated LoRA attention processor.
413
+
414
+ Returns:
415
+ "AttentionProcessor": The attention processor in use.
416
+ """
417
+ if not return_deprecated_lora:
418
+ return self.processor
419
+
420
+ # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
421
+ # serialization format for LoRA Attention Processors. It should be deleted once the integration
422
+ # with PEFT is completed.
423
+ is_lora_activated = {
424
+ name: module.lora_layer is not None
425
+ for name, module in self.named_modules()
426
+ if hasattr(module, "lora_layer")
427
+ }
428
+
429
+ # 1. if no layer has a LoRA activated we can return the processor as usual
430
+ if not any(is_lora_activated.values()):
431
+ return self.processor
432
+
433
+ # `add_k_proj` and `add_v_proj` don't get LoRA applied, so exclude them from the check
434
+ is_lora_activated.pop("add_k_proj", None)
435
+ is_lora_activated.pop("add_v_proj", None)
436
+ # 2. else it is not possible that only some layers have LoRA activated
437
+ if not all(is_lora_activated.values()):
438
+ raise ValueError(
439
+ f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
440
+ )
441
+
442
+ # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
443
+ non_lora_processor_cls_name = self.processor.__class__.__name__
444
+ lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name)
445
+
446
+ hidden_size = self.inner_dim
447
+
448
+ # now create a LoRA attention processor from the LoRA layers
449
+ if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]:
450
+ kwargs = {
451
+ "cross_attention_dim": self.cross_attention_dim,
452
+ "rank": self.to_q.lora_layer.rank,
453
+ "network_alpha": self.to_q.lora_layer.network_alpha,
454
+ "q_rank": self.to_q.lora_layer.rank,
455
+ "q_hidden_size": self.to_q.lora_layer.out_features,
456
+ "k_rank": self.to_k.lora_layer.rank,
457
+ "k_hidden_size": self.to_k.lora_layer.out_features,
458
+ "v_rank": self.to_v.lora_layer.rank,
459
+ "v_hidden_size": self.to_v.lora_layer.out_features,
460
+ "out_rank": self.to_out[0].lora_layer.rank,
461
+ "out_hidden_size": self.to_out[0].lora_layer.out_features,
462
+ }
463
+
464
+ if hasattr(self.processor, "attention_op"):
465
+ kwargs["attention_op"] = self.processor.attention_op
466
+
467
+ lora_processor = lora_processor_cls(hidden_size, **kwargs)
468
+ lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
469
+ lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
470
+ lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
471
+ lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
472
+ elif lora_processor_cls == LoRAAttnAddedKVProcessor:
473
+ lora_processor = lora_processor_cls(
474
+ hidden_size,
475
+ cross_attention_dim=self.add_k_proj.weight.shape[0],
476
+ rank=self.to_q.lora_layer.rank,
477
+ network_alpha=self.to_q.lora_layer.network_alpha,
478
+ )
479
+ lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
480
+ lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
481
+ lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
482
+ lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
483
+
484
+ # only save if used
485
+ if self.add_k_proj.lora_layer is not None:
486
+ lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict())
487
+ lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict())
488
+ else:
489
+ lora_processor.add_k_proj_lora = None
490
+ lora_processor.add_v_proj_lora = None
491
+ else:
492
+ raise ValueError(f"{lora_processor_cls} does not exist.")
493
+
494
+ return lora_processor
495
+
496
+ def forward(
497
+ self,
498
+ hidden_states: torch.FloatTensor,
499
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
500
+ attention_mask: Optional[torch.FloatTensor] = None,
501
+ **cross_attention_kwargs,
502
+ ) -> torch.Tensor:
503
+ r"""
504
+ The forward method of the `Attention` class.
505
+
506
+ Args:
507
+ hidden_states (`torch.Tensor`):
508
+ The hidden states of the query.
509
+ encoder_hidden_states (`torch.Tensor`, *optional*):
510
+ The hidden states of the encoder.
511
+ attention_mask (`torch.Tensor`, *optional*):
512
+ The attention mask to use. If `None`, no mask is applied.
513
+ **cross_attention_kwargs:
514
+ Additional keyword arguments to pass along to the cross attention.
515
+
516
+ Returns:
517
+ `torch.Tensor`: The output of the attention layer.
518
+ """
519
+ # The `Attention` class can call different attention processors / attention functions
520
+ # here we simply pass along all tensors to the selected processor class
521
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
522
+ return self.processor(
523
+ self,
524
+ hidden_states,
525
+ encoder_hidden_states=encoder_hidden_states,
526
+ attention_mask=attention_mask,
527
+ **cross_attention_kwargs,
528
+ )
529
+
530
+ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
531
+ r"""
532
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
533
+ is the number of heads initialized while constructing the `Attention` class.
534
+
535
+ Args:
536
+ tensor (`torch.Tensor`): The tensor to reshape.
537
+
538
+ Returns:
539
+ `torch.Tensor`: The reshaped tensor.
540
+ """
541
+ head_size = self.heads
542
+ batch_size, seq_len, dim = tensor.shape
543
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
544
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
545
+ return tensor
546
+
547
+ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
548
+ r"""
549
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]`. `heads` is
550
+ the number of heads initialized while constructing the `Attention` class.
551
+
552
+ Args:
553
+ tensor (`torch.Tensor`): The tensor to reshape.
554
+ out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
555
+ reshaped to `[batch_size * heads, seq_len, dim // heads]`.
556
+
557
+ Returns:
558
+ `torch.Tensor`: The reshaped tensor.
559
+ """
560
+ head_size = self.heads
561
+ batch_size, seq_len, dim = tensor.shape
562
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
563
+ tensor = tensor.permute(0, 2, 1, 3)
564
+
565
+ if out_dim == 3:
566
+ tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
567
+
568
+ return tensor
569
+
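+ # NOTE (comment added for clarity; not in the original upload): with heads=8, a
+ # (2, 77, 640) tensor becomes (16, 77, 80) after head_to_batch_dim (out_dim=3) and
+ # is restored to (2, 77, 640) by batch_to_head_dim; with out_dim=4,
+ # head_to_batch_dim returns (2, 8, 77, 80) instead.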
570
+ def get_attention_scores(
571
+ self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None
572
+ ) -> torch.Tensor:
573
+ r"""
574
+ Compute the attention scores.
575
+
576
+ Args:
577
+ query (`torch.Tensor`): The query tensor.
578
+ key (`torch.Tensor`): The key tensor.
579
+ attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
580
+
581
+ Returns:
582
+ `torch.Tensor`: The attention probabilities/scores.
583
+ """
584
+ dtype = query.dtype
585
+ if self.upcast_attention:
586
+ query = query.float()
587
+ key = key.float()
588
+
589
+ if attention_mask is None:
590
+ baddbmm_input = torch.empty(
591
+ query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
592
+ )
593
+ beta = 0
594
+ else:
595
+ baddbmm_input = attention_mask
596
+ beta = 1
597
+
598
+ attention_scores = torch.baddbmm(
599
+ baddbmm_input,
600
+ query,
601
+ key.transpose(-1, -2),
602
+ beta=beta,
603
+ alpha=self.scale,
604
+ )
605
+ del baddbmm_input
606
+
607
+ if self.upcast_softmax:
608
+ attention_scores = attention_scores.float()
609
+
610
+ attention_probs = attention_scores.softmax(dim=-1)
611
+ del attention_scores
612
+
613
+ attention_probs = attention_probs.to(dtype)
614
+
615
+ return attention_probs
616
+
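+ # NOTE (comment added for clarity; not in the original upload): get_attention_scores
+ # folds the additive mask into torch.baddbmm: with beta=1 and alpha=self.scale the
+ # call computes mask + scale * query @ key.transpose(-1, -2) in a single kernel,
+ # so the softmax above is applied to scale * Q @ K^T + mask.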
617
+ def prepare_attention_mask(
618
+ self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
619
+ ) -> torch.Tensor:
620
+ r"""
621
+ Prepare the attention mask for the attention computation.
622
+
623
+ Args:
624
+ attention_mask (`torch.Tensor`):
625
+ The attention mask to prepare.
626
+ target_length (`int`):
627
+ The target length of the attention mask. This is the length of the attention mask after padding.
628
+ batch_size (`int`):
629
+ The batch size, which is used to repeat the attention mask.
630
+ out_dim (`int`, *optional*, defaults to `3`):
631
+ The output dimension of the attention mask. Can be either `3` or `4`.
632
+
633
+ Returns:
634
+ `torch.Tensor`: The prepared attention mask.
635
+ """
636
+ head_size = self.heads
637
+ if attention_mask is None:
638
+ return attention_mask
639
+
640
+ current_length: int = attention_mask.shape[-1]
641
+ if current_length != target_length:
642
+ if attention_mask.device.type == "mps":
643
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
644
+ # Instead, we can manually construct the padding tensor.
645
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
646
+ padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
647
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
648
+ else:
649
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
650
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
651
+ # remaining_length: int = target_length - current_length
652
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
653
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
654
+
655
+ if out_dim == 3:
656
+ if attention_mask.shape[0] < batch_size * head_size:
657
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
658
+ elif out_dim == 4:
659
+ attention_mask = attention_mask.unsqueeze(1)
660
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
661
+
662
+ return attention_mask
663
+
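+ # NOTE (comment added for clarity; not in the original upload): a mask of shape
+ # (batch, 1, key_len) is right-padded with zeros when key_len != target_length
+ # (see the TODO above about the exact pad amount) and then repeated across heads:
+ # out_dim=3 yields (batch * heads, 1, padded_len), out_dim=4 yields (batch, heads, 1, padded_len).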
664
+ def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
665
+ r"""
666
+ Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
667
+ `Attention` class.
668
+
669
+ Args:
670
+ encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
671
+
672
+ Returns:
673
+ `torch.Tensor`: The normalized encoder hidden states.
674
+ """
675
+ assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
676
+
677
+ if isinstance(self.norm_cross, nn.LayerNorm):
678
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
679
+ elif isinstance(self.norm_cross, nn.GroupNorm):
680
+ # Group norm norms along the channels dimension and expects
681
+ # input to be in the shape of (N, C, *). In this case, we want
682
+ # to norm along the hidden dimension, so we need to move
683
+ # (batch_size, sequence_length, hidden_size) ->
684
+ # (batch_size, hidden_size, sequence_length)
685
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
686
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
687
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
688
+ else:
689
+ assert False
690
+
691
+ return encoder_hidden_states
692
+
693
+
694
+ class AttnProcessor:
695
+ r"""
696
+ Default processor for performing attention-related computations.
697
+ """
698
+
699
+ def __call__(
700
+ self,
701
+ attn: Attention,
702
+ hidden_states: torch.FloatTensor,
703
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
704
+ attention_mask: Optional[torch.FloatTensor] = None,
705
+ temb: Optional[torch.FloatTensor] = None,
706
+ scale: float = 1.0,
707
+ ) -> torch.Tensor:
708
+ residual = hidden_states
709
+
710
+ args = () if USE_PEFT_BACKEND else (scale,)
711
+
712
+ if attn.spatial_norm is not None:
713
+ hidden_states = attn.spatial_norm(hidden_states, temb)
714
+
715
+ input_ndim = hidden_states.ndim
716
+
717
+ if input_ndim == 4:
718
+ batch_size, channel, height, width = hidden_states.shape
719
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
720
+
721
+ batch_size, sequence_length, _ = (
722
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
723
+ )
724
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
725
+
726
+ if attn.group_norm is not None:
727
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
728
+
729
+ query = attn.to_q(hidden_states, *args)
730
+
731
+ if encoder_hidden_states is None:
732
+ encoder_hidden_states = hidden_states
733
+ elif attn.norm_cross:
734
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
735
+
736
+ key = attn.to_k(encoder_hidden_states, *args)
737
+ value = attn.to_v(encoder_hidden_states, *args)
738
+
739
+ query = attn.head_to_batch_dim(query)
740
+ key = attn.head_to_batch_dim(key)
741
+ value = attn.head_to_batch_dim(value)
742
+
743
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
744
+ hidden_states = torch.bmm(attention_probs, value)
745
+ hidden_states = attn.batch_to_head_dim(hidden_states)
746
+
747
+ # linear proj
748
+ hidden_states = attn.to_out[0](hidden_states, *args)
749
+ # dropout
750
+ hidden_states = attn.to_out[1](hidden_states)
751
+
752
+ if input_ndim == 4:
753
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
754
+
755
+ if attn.residual_connection:
756
+ hidden_states = hidden_states + residual
757
+
758
+ hidden_states = hidden_states / attn.rescale_output_factor
759
+
760
+ return hidden_states
761
+
762
+
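A standalone usage sketch (illustration only, not part of the uploaded file): it wires the default `AttnProcessor` into the `Attention` module defined above and runs a self-attention forward pass. Shapes are arbitrary, and the import path assumes this vendored `diffusers` package is on `sys.path`.

import torch
from diffusers.models.attention_processor import Attention, AttnProcessor

attn = Attention(query_dim=64, heads=4, dim_head=16, processor=AttnProcessor())
hidden_states = torch.randn(2, 77, 64)   # (batch, seq_len, query_dim)
out = attn(hidden_states)                # encoder_hidden_states=None -> self-attention
assert out.shape == hidden_states.shape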
763
+ class CustomDiffusionAttnProcessor(nn.Module):
764
+ r"""
765
+ Processor for implementing attention for the Custom Diffusion method.
766
+
767
+ Args:
768
+ train_kv (`bool`, defaults to `True`):
769
+ Whether to newly train the key and value matrices corresponding to the text features.
770
+ train_q_out (`bool`, defaults to `True`):
771
+ Whether to newly train query matrices corresponding to the latent image features.
772
+ hidden_size (`int`, *optional*, defaults to `None`):
773
+ The hidden size of the attention layer.
774
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
775
+ The number of channels in the `encoder_hidden_states`.
776
+ out_bias (`bool`, defaults to `True`):
777
+ Whether to include the bias parameter in `train_q_out`.
778
+ dropout (`float`, *optional*, defaults to 0.0):
779
+ The dropout probability to use.
780
+ """
781
+
782
+ def __init__(
783
+ self,
784
+ train_kv: bool = True,
785
+ train_q_out: bool = True,
786
+ hidden_size: Optional[int] = None,
787
+ cross_attention_dim: Optional[int] = None,
788
+ out_bias: bool = True,
789
+ dropout: float = 0.0,
790
+ ):
791
+ super().__init__()
792
+ self.train_kv = train_kv
793
+ self.train_q_out = train_q_out
794
+
795
+ self.hidden_size = hidden_size
796
+ self.cross_attention_dim = cross_attention_dim
797
+
798
+ # `_custom_diffusion` id for easy serialization and loading.
799
+ if self.train_kv:
800
+ self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
801
+ self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
802
+ if self.train_q_out:
803
+ self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
804
+ self.to_out_custom_diffusion = nn.ModuleList([])
805
+ self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
806
+ self.to_out_custom_diffusion.append(nn.Dropout(dropout))
807
+
808
+ def __call__(
809
+ self,
810
+ attn: Attention,
811
+ hidden_states: torch.FloatTensor,
812
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
813
+ attention_mask: Optional[torch.FloatTensor] = None,
814
+ ) -> torch.Tensor:
815
+ batch_size, sequence_length, _ = hidden_states.shape
816
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
817
+ if self.train_q_out:
818
+ query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
819
+ else:
820
+ query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
821
+
822
+ if encoder_hidden_states is None:
823
+ crossattn = False
824
+ encoder_hidden_states = hidden_states
825
+ else:
826
+ crossattn = True
827
+ if attn.norm_cross:
828
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
829
+
830
+ if self.train_kv:
831
+ key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
832
+ value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
833
+ key = key.to(attn.to_q.weight.dtype)
834
+ value = value.to(attn.to_q.weight.dtype)
835
+ else:
836
+ key = attn.to_k(encoder_hidden_states)
837
+ value = attn.to_v(encoder_hidden_states)
838
+
839
+ if crossattn:
840
+ detach = torch.ones_like(key)
841
+ detach[:, :1, :] = detach[:, :1, :] * 0.0
842
+ key = detach * key + (1 - detach) * key.detach()
843
+ value = detach * value + (1 - detach) * value.detach()
844
+
845
+ query = attn.head_to_batch_dim(query)
846
+ key = attn.head_to_batch_dim(key)
847
+ value = attn.head_to_batch_dim(value)
848
+
849
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
850
+ hidden_states = torch.bmm(attention_probs, value)
851
+ hidden_states = attn.batch_to_head_dim(hidden_states)
852
+
853
+ if self.train_q_out:
854
+ # linear proj
855
+ hidden_states = self.to_out_custom_diffusion[0](hidden_states)
856
+ # dropout
857
+ hidden_states = self.to_out_custom_diffusion[1](hidden_states)
858
+ else:
859
+ # linear proj
860
+ hidden_states = attn.to_out[0](hidden_states)
861
+ # dropout
862
+ hidden_states = attn.to_out[1](hidden_states)
863
+
864
+ return hidden_states
865
+
866
+
867
+ class AttnAddedKVProcessor:
868
+ r"""
869
+ Processor for performing attention-related computations with extra learnable key and value matrices for the text
870
+ encoder.
871
+ """
872
+
873
+ def __call__(
874
+ self,
875
+ attn: Attention,
876
+ hidden_states: torch.FloatTensor,
877
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
878
+ attention_mask: Optional[torch.FloatTensor] = None,
879
+ scale: float = 1.0,
880
+ ) -> torch.Tensor:
881
+ residual = hidden_states
882
+
883
+ args = () if USE_PEFT_BACKEND else (scale,)
884
+
885
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
886
+ batch_size, sequence_length, _ = hidden_states.shape
887
+
888
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
889
+
890
+ if encoder_hidden_states is None:
891
+ encoder_hidden_states = hidden_states
892
+ elif attn.norm_cross:
893
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
894
+
895
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
896
+
897
+ query = attn.to_q(hidden_states, *args)
898
+ query = attn.head_to_batch_dim(query)
899
+
900
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, *args)
901
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, *args)
902
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
903
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
904
+
905
+ if not attn.only_cross_attention:
906
+ key = attn.to_k(hidden_states, *args)
907
+ value = attn.to_v(hidden_states, *args)
908
+ key = attn.head_to_batch_dim(key)
909
+ value = attn.head_to_batch_dim(value)
910
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
911
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
912
+ else:
913
+ key = encoder_hidden_states_key_proj
914
+ value = encoder_hidden_states_value_proj
915
+
916
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
917
+ hidden_states = torch.bmm(attention_probs, value)
918
+ hidden_states = attn.batch_to_head_dim(hidden_states)
919
+
920
+ # linear proj
921
+ hidden_states = attn.to_out[0](hidden_states, *args)
922
+ # dropout
923
+ hidden_states = attn.to_out[1](hidden_states)
924
+
925
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
926
+ hidden_states = hidden_states + residual
927
+
928
+ return hidden_states
929
+
930
+
931
+ class AttnAddedKVProcessor2_0:
932
+ r"""
933
+ Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
934
+ learnable key and value matrices for the text encoder.
935
+ """
936
+
937
+ def __init__(self):
938
+ if not hasattr(F, "scaled_dot_product_attention"):
939
+ raise ImportError(
940
+ "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
941
+ )
942
+
943
+ def __call__(
944
+ self,
945
+ attn: Attention,
946
+ hidden_states: torch.FloatTensor,
947
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
948
+ attention_mask: Optional[torch.FloatTensor] = None,
949
+ scale: float = 1.0,
950
+ ) -> torch.Tensor:
951
+ residual = hidden_states
952
+
953
+ args = () if USE_PEFT_BACKEND else (scale,)
954
+
955
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
956
+ batch_size, sequence_length, _ = hidden_states.shape
957
+
958
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)
959
+
960
+ if encoder_hidden_states is None:
961
+ encoder_hidden_states = hidden_states
962
+ elif attn.norm_cross:
963
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
964
+
965
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
966
+
967
+ query = attn.to_q(hidden_states, *args)
968
+ query = attn.head_to_batch_dim(query, out_dim=4)
969
+
970
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
971
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
972
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
973
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
974
+
975
+ if not attn.only_cross_attention:
976
+ key = attn.to_k(hidden_states, *args)
977
+ value = attn.to_v(hidden_states, *args)
978
+ key = attn.head_to_batch_dim(key, out_dim=4)
979
+ value = attn.head_to_batch_dim(value, out_dim=4)
980
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
981
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
982
+ else:
983
+ key = encoder_hidden_states_key_proj
984
+ value = encoder_hidden_states_value_proj
985
+
986
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
987
+ # TODO: add support for attn.scale when we move to Torch 2.1
988
+ hidden_states = F.scaled_dot_product_attention(
989
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
990
+ )
991
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
992
+
993
+ # linear proj
994
+ hidden_states = attn.to_out[0](hidden_states, *args)
995
+ # dropout
996
+ hidden_states = attn.to_out[1](hidden_states)
997
+
998
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
999
+ hidden_states = hidden_states + residual
1000
+
1001
+ return hidden_states
1002
+
1003
+
1004
+ class XFormersAttnAddedKVProcessor:
1005
+ r"""
1006
+ Processor for implementing memory efficient attention using xFormers.
1007
+
1008
+ Args:
1009
+ attention_op (`Callable`, *optional*, defaults to `None`):
1010
+ The base
1011
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
1012
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
1013
+ operator.
1014
+ """
1015
+
1016
+ def __init__(self, attention_op: Optional[Callable] = None):
1017
+ self.attention_op = attention_op
1018
+
1019
+ def __call__(
1020
+ self,
1021
+ attn: Attention,
1022
+ hidden_states: torch.FloatTensor,
1023
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1024
+ attention_mask: Optional[torch.FloatTensor] = None,
1025
+ ) -> torch.Tensor:
1026
+ residual = hidden_states
1027
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
1028
+ batch_size, sequence_length, _ = hidden_states.shape
1029
+
1030
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1031
+
1032
+ if encoder_hidden_states is None:
1033
+ encoder_hidden_states = hidden_states
1034
+ elif attn.norm_cross:
1035
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1036
+
1037
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1038
+
1039
+ query = attn.to_q(hidden_states)
1040
+ query = attn.head_to_batch_dim(query)
1041
+
1042
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
1043
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
1044
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
1045
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
1046
+
1047
+ if not attn.only_cross_attention:
1048
+ key = attn.to_k(hidden_states)
1049
+ value = attn.to_v(hidden_states)
1050
+ key = attn.head_to_batch_dim(key)
1051
+ value = attn.head_to_batch_dim(value)
1052
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
1053
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
1054
+ else:
1055
+ key = encoder_hidden_states_key_proj
1056
+ value = encoder_hidden_states_value_proj
1057
+
1058
+ hidden_states = xformers.ops.memory_efficient_attention(
1059
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1060
+ )
1061
+ hidden_states = hidden_states.to(query.dtype)
1062
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1063
+
1064
+ # linear proj
1065
+ hidden_states = attn.to_out[0](hidden_states)
1066
+ # dropout
1067
+ hidden_states = attn.to_out[1](hidden_states)
1068
+
1069
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
1070
+ hidden_states = hidden_states + residual
1071
+
1072
+ return hidden_states
1073
+
1074
+
1075
+ class XFormersAttnProcessor:
1076
+ r"""
1077
+ Processor for implementing memory efficient attention using xFormers.
1078
+
1079
+ Args:
1080
+ attention_op (`Callable`, *optional*, defaults to `None`):
1081
+ The base
1082
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
1083
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
1084
+ operator.
1085
+ """
1086
+
1087
+ def __init__(self, attention_op: Optional[Callable] = None):
1088
+ self.attention_op = attention_op
1089
+
1090
+ def __call__(
1091
+ self,
1092
+ attn: Attention,
1093
+ hidden_states: torch.FloatTensor,
1094
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1095
+ attention_mask: Optional[torch.FloatTensor] = None,
1096
+ temb: Optional[torch.FloatTensor] = None,
1097
+ scale: float = 1.0,
1098
+ ) -> torch.FloatTensor:
1099
+ residual = hidden_states
1100
+
1101
+ args = () if USE_PEFT_BACKEND else (scale,)
1102
+
1103
+ if attn.spatial_norm is not None:
1104
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1105
+
1106
+ input_ndim = hidden_states.ndim
1107
+
1108
+ if input_ndim == 4:
1109
+ batch_size, channel, height, width = hidden_states.shape
1110
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1111
+
1112
+ batch_size, key_tokens, _ = (
1113
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1114
+ )
1115
+
1116
+ attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
1117
+ if attention_mask is not None:
1118
+ # expand our mask's singleton query_tokens dimension:
1119
+ # [batch*heads, 1, key_tokens] ->
1120
+ # [batch*heads, query_tokens, key_tokens]
1121
+ # so that it can be added as a bias onto the attention scores that xformers computes:
1122
+ # [batch*heads, query_tokens, key_tokens]
1123
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
1124
+ _, query_tokens, _ = hidden_states.shape
1125
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
1126
+
1127
+ if attn.group_norm is not None:
1128
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1129
+
1130
+ query = attn.to_q(hidden_states, *args)
1131
+
1132
+ if encoder_hidden_states is None:
1133
+ encoder_hidden_states = hidden_states
1134
+ elif attn.norm_cross:
1135
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1136
+
1137
+ key = attn.to_k(encoder_hidden_states, *args)
1138
+ value = attn.to_v(encoder_hidden_states, *args)
1139
+
1140
+ query = attn.head_to_batch_dim(query).contiguous()
1141
+ key = attn.head_to_batch_dim(key).contiguous()
1142
+ value = attn.head_to_batch_dim(value).contiguous()
1143
+
1144
+ hidden_states = xformers.ops.memory_efficient_attention(
1145
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1146
+ )
1147
+ hidden_states = hidden_states.to(query.dtype)
1148
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1149
+
1150
+ # linear proj
1151
+ hidden_states = attn.to_out[0](hidden_states, *args)
1152
+ # dropout
1153
+ hidden_states = attn.to_out[1](hidden_states)
1154
+
1155
+ if input_ndim == 4:
1156
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1157
+
1158
+ if attn.residual_connection:
1159
+ hidden_states = hidden_states + residual
1160
+
1161
+ hidden_states = hidden_states / attn.rescale_output_factor
1162
+
1163
+ return hidden_states
1164
+
1165
+
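+ # NOTE (comments added for clarity; not in the original upload): this processor is
+ # normally installed via attn.set_use_memory_efficient_attention_xformers(True),
+ # optionally with a specific attention_op; that setter checks that xformers is
+ # importable and a CUDA device is available before swapping the processor in.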
1166
+ class AttnProcessor2_0:
1167
+ r"""
1168
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
1169
+ """
1170
+
1171
+ def __init__(self):
1172
+ if not hasattr(F, "scaled_dot_product_attention"):
1173
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
1174
+
1175
+ def __call__(
1176
+ self,
1177
+ attn: Attention,
1178
+ hidden_states: torch.FloatTensor,
1179
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1180
+ attention_mask: Optional[torch.FloatTensor] = None,
1181
+ temb: Optional[torch.FloatTensor] = None,
1182
+ scale: float = 1.0,
1183
+ ) -> torch.FloatTensor:
1184
+ residual = hidden_states
1185
+
1186
+ args = () if USE_PEFT_BACKEND else (scale,)
1187
+
1188
+ if attn.spatial_norm is not None:
1189
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1190
+
1191
+ input_ndim = hidden_states.ndim
1192
+
1193
+ if input_ndim == 4:
1194
+ batch_size, channel, height, width = hidden_states.shape
1195
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1196
+
1197
+ batch_size, sequence_length, _ = (
1198
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1199
+ )
1200
+
1201
+ if attention_mask is not None:
1202
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1203
+ # scaled_dot_product_attention expects attention_mask shape to be
1204
+ # (batch, heads, source_length, target_length)
1205
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
1206
+
1207
+ if attn.group_norm is not None:
1208
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1209
+
1210
+ args = () if USE_PEFT_BACKEND else (scale,)
1211
+ query = attn.to_q(hidden_states, *args)
1212
+
1213
+ if encoder_hidden_states is None:
1214
+ encoder_hidden_states = hidden_states
1215
+ elif attn.norm_cross:
1216
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1217
+
1218
+ key = attn.to_k(encoder_hidden_states, *args)
1219
+ value = attn.to_v(encoder_hidden_states, *args)
1220
+
1221
+ inner_dim = key.shape[-1]
1222
+ head_dim = inner_dim // attn.heads
1223
+
1224
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1225
+
1226
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1227
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1228
+
1229
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
1230
+ # TODO: add support for attn.scale when we move to Torch 2.1
1231
+ hidden_states = F.scaled_dot_product_attention(
1232
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1233
+ )
1234
+
1235
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1236
+ hidden_states = hidden_states.to(query.dtype)
1237
+
1238
+ # linear proj
1239
+ hidden_states = attn.to_out[0](hidden_states, *args)
1240
+ # dropout
1241
+ hidden_states = attn.to_out[1](hidden_states)
1242
+
1243
+ if input_ndim == 4:
1244
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1245
+
1246
+ if attn.residual_connection:
1247
+ hidden_states = hidden_states + residual
1248
+
1249
+ hidden_states = hidden_states / attn.rescale_output_factor
1250
+
1251
+ return hidden_states
1252
+
1253
+
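+ # Illustrative usage sketch (added for illustration; not from the upstream diffusers
+ # source): pushing a tensor through a self-attention layer that uses
+ # `AttnProcessor2_0`. Assumes the `Attention` block defined earlier in this file
+ # and PyTorch >= 2.0.
+ def _example_attn_processor_2_0():
+     import torch
+ 
+     attn = Attention(query_dim=64, heads=4, dim_head=16, processor=AttnProcessor2_0())
+     hidden_states = torch.randn(2, 77, 64)  # (batch, sequence_length, query_dim)
+     out = attn(hidden_states)  # self-attention: encoder_hidden_states defaults to None
+     assert out.shape == (2, 77, 64)
+     return out
+ 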
1254
+ class CustomDiffusionXFormersAttnProcessor(nn.Module):
1255
+ r"""
1256
+ Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
1257
+
1258
+ Args:
1259
+ train_kv (`bool`, defaults to `True`):
1260
+ Whether to newly train the key and value matrices corresponding to the text features.
1261
+ train_q_out (`bool`, defaults to `False`):
1262
+ Whether to newly train query matrices corresponding to the latent image features.
1263
+ hidden_size (`int`, *optional*, defaults to `None`):
1264
+ The hidden size of the attention layer.
1265
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
1266
+ The number of channels in the `encoder_hidden_states`.
1267
+ out_bias (`bool`, defaults to `True`):
1268
+ Whether to include the bias parameter in `train_q_out`.
1269
+ dropout (`float`, *optional*, defaults to 0.0):
1270
+ The dropout probability to use.
1271
+ attention_op (`Callable`, *optional*, defaults to `None`):
1272
+ The base
1273
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
1274
+ as the attention operator. It is recommended to set it to `None` and allow xFormers to choose the best operator.
1275
+ """
1276
+
1277
+ def __init__(
1278
+ self,
1279
+ train_kv: bool = True,
1280
+ train_q_out: bool = False,
1281
+ hidden_size: Optional[int] = None,
1282
+ cross_attention_dim: Optional[int] = None,
1283
+ out_bias: bool = True,
1284
+ dropout: float = 0.0,
1285
+ attention_op: Optional[Callable] = None,
1286
+ ):
1287
+ super().__init__()
1288
+ self.train_kv = train_kv
1289
+ self.train_q_out = train_q_out
1290
+
1291
+ self.hidden_size = hidden_size
1292
+ self.cross_attention_dim = cross_attention_dim
1293
+ self.attention_op = attention_op
1294
+
1295
+ # `_custom_diffusion` id for easy serialization and loading.
1296
+ if self.train_kv:
1297
+ self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
1298
+ self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
1299
+ if self.train_q_out:
1300
+ self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
1301
+ self.to_out_custom_diffusion = nn.ModuleList([])
1302
+ self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
1303
+ self.to_out_custom_diffusion.append(nn.Dropout(dropout))
1304
+
1305
+ def __call__(
1306
+ self,
1307
+ attn: Attention,
1308
+ hidden_states: torch.FloatTensor,
1309
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1310
+ attention_mask: Optional[torch.FloatTensor] = None,
1311
+ ) -> torch.FloatTensor:
1312
+ batch_size, sequence_length, _ = (
1313
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1314
+ )
1315
+
1316
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1317
+
1318
+ if self.train_q_out:
1319
+ query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
1320
+ else:
1321
+ query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
1322
+
1323
+ if encoder_hidden_states is None:
1324
+ crossattn = False
1325
+ encoder_hidden_states = hidden_states
1326
+ else:
1327
+ crossattn = True
1328
+ if attn.norm_cross:
1329
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1330
+
1331
+ if self.train_kv:
1332
+ key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
1333
+ value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
1334
+ key = key.to(attn.to_q.weight.dtype)
1335
+ value = value.to(attn.to_q.weight.dtype)
1336
+ else:
1337
+ key = attn.to_k(encoder_hidden_states)
1338
+ value = attn.to_v(encoder_hidden_states)
1339
+
1340
+ if crossattn:
1341
+ detach = torch.ones_like(key)
1342
+ detach[:, :1, :] = detach[:, :1, :] * 0.0
1343
+ key = detach * key + (1 - detach) * key.detach()
1344
+ value = detach * value + (1 - detach) * value.detach()
1345
+
1346
+ query = attn.head_to_batch_dim(query).contiguous()
1347
+ key = attn.head_to_batch_dim(key).contiguous()
1348
+ value = attn.head_to_batch_dim(value).contiguous()
1349
+
1350
+ hidden_states = xformers.ops.memory_efficient_attention(
1351
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1352
+ )
1353
+ hidden_states = hidden_states.to(query.dtype)
1354
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1355
+
1356
+ if self.train_q_out:
1357
+ # linear proj
1358
+ hidden_states = self.to_out_custom_diffusion[0](hidden_states)
1359
+ # dropout
1360
+ hidden_states = self.to_out_custom_diffusion[1](hidden_states)
1361
+ else:
1362
+ # linear proj
1363
+ hidden_states = attn.to_out[0](hidden_states)
1364
+ # dropout
1365
+ hidden_states = attn.to_out[1](hidden_states)
1366
+
1367
+ return hidden_states
1368
+
1369
+
1370
+ class CustomDiffusionAttnProcessor2_0(nn.Module):
1371
+ r"""
1372
+ Processor for implementing attention for the Custom Diffusion method using PyTorch 2.0’s memory-efficient scaled
1373
+ dot-product attention.
1374
+
1375
+ Args:
1376
+ train_kv (`bool`, defaults to `True`):
1377
+ Whether to newly train the key and value matrices corresponding to the text features.
1378
+ train_q_out (`bool`, defaults to `True`):
1379
+ Whether to newly train query matrices corresponding to the latent image features.
1380
+ hidden_size (`int`, *optional*, defaults to `None`):
1381
+ The hidden size of the attention layer.
1382
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
1383
+ The number of channels in the `encoder_hidden_states`.
1384
+ out_bias (`bool`, defaults to `True`):
1385
+ Whether to include the bias parameter in `train_q_out`.
1386
+ dropout (`float`, *optional*, defaults to 0.0):
1387
+ The dropout probability to use.
1388
+ """
1389
+
1390
+ def __init__(
1391
+ self,
1392
+ train_kv: bool = True,
1393
+ train_q_out: bool = True,
1394
+ hidden_size: Optional[int] = None,
1395
+ cross_attention_dim: Optional[int] = None,
1396
+ out_bias: bool = True,
1397
+ dropout: float = 0.0,
1398
+ ):
1399
+ super().__init__()
1400
+ self.train_kv = train_kv
1401
+ self.train_q_out = train_q_out
1402
+
1403
+ self.hidden_size = hidden_size
1404
+ self.cross_attention_dim = cross_attention_dim
1405
+
1406
+ # `_custom_diffusion` id for easy serialization and loading.
1407
+ if self.train_kv:
1408
+ self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
1409
+ self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
1410
+ if self.train_q_out:
1411
+ self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
1412
+ self.to_out_custom_diffusion = nn.ModuleList([])
1413
+ self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
1414
+ self.to_out_custom_diffusion.append(nn.Dropout(dropout))
1415
+
1416
+ def __call__(
1417
+ self,
1418
+ attn: Attention,
1419
+ hidden_states: torch.FloatTensor,
1420
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1421
+ attention_mask: Optional[torch.FloatTensor] = None,
1422
+ ) -> torch.FloatTensor:
1423
+ batch_size, sequence_length, _ = hidden_states.shape
1424
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1425
+ if self.train_q_out:
1426
+ query = self.to_q_custom_diffusion(hidden_states)
1427
+ else:
1428
+ query = attn.to_q(hidden_states)
1429
+
1430
+ if encoder_hidden_states is None:
1431
+ crossattn = False
1432
+ encoder_hidden_states = hidden_states
1433
+ else:
1434
+ crossattn = True
1435
+ if attn.norm_cross:
1436
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1437
+
1438
+ if self.train_kv:
1439
+ key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
1440
+ value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
1441
+ key = key.to(attn.to_q.weight.dtype)
1442
+ value = value.to(attn.to_q.weight.dtype)
1443
+
1444
+ else:
1445
+ key = attn.to_k(encoder_hidden_states)
1446
+ value = attn.to_v(encoder_hidden_states)
1447
+
1448
+ if crossattn:
1449
+ detach = torch.ones_like(key)
1450
+ detach[:, :1, :] = detach[:, :1, :] * 0.0
1451
+ key = detach * key + (1 - detach) * key.detach()
1452
+ value = detach * value + (1 - detach) * value.detach()
1453
+
1454
+ inner_dim = hidden_states.shape[-1]
1455
+
1456
+ head_dim = inner_dim // attn.heads
1457
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1458
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1459
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1460
+
1461
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
1462
+ # TODO: add support for attn.scale when we move to Torch 2.1
1463
+ hidden_states = F.scaled_dot_product_attention(
1464
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1465
+ )
1466
+
1467
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1468
+ hidden_states = hidden_states.to(query.dtype)
1469
+
1470
+ if self.train_q_out:
1471
+ # linear proj
1472
+ hidden_states = self.to_out_custom_diffusion[0](hidden_states)
1473
+ # dropout
1474
+ hidden_states = self.to_out_custom_diffusion[1](hidden_states)
1475
+ else:
1476
+ # linear proj
1477
+ hidden_states = attn.to_out[0](hidden_states)
1478
+ # dropout
1479
+ hidden_states = attn.to_out[1](hidden_states)
1480
+
1481
+ return hidden_states
1482
+
1483
+
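+ # Illustrative sketch (added for illustration; not from the upstream diffusers
+ # source) of the `detach` trick used by the Custom Diffusion processors above:
+ # gradients flow through the key/value projection for every token except the
+ # first one, whose contribution is cut out of the graph. Shapes are arbitrary.
+ def _example_custom_diffusion_detach():
+     import torch
+ 
+     encoder_hidden_states = torch.randn(1, 4, 8, requires_grad=True)
+     proj = torch.nn.Linear(8, 8)
+     key = proj(encoder_hidden_states)
+ 
+     detach = torch.ones_like(key)
+     detach[:, :1, :] = detach[:, :1, :] * 0.0
+     key = detach * key + (1 - detach) * key.detach()
+ 
+     key.sum().backward()
+     # No gradient reaches the first token of the input; the remaining tokens do.
+     assert torch.all(encoder_hidden_states.grad[:, 0] == 0)
+     return key
+ 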
1484
+ class SlicedAttnProcessor:
1485
+ r"""
1486
+ Processor for implementing sliced attention.
1487
+
1488
+ Args:
1489
+ slice_size (`int`, *optional*):
1490
+ The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
1491
+ `attention_head_dim` must be a multiple of the `slice_size`.
1492
+ """
1493
+
1494
+ def __init__(self, slice_size: int):
1495
+ self.slice_size = slice_size
1496
+
1497
+ def __call__(
1498
+ self,
1499
+ attn: Attention,
1500
+ hidden_states: torch.FloatTensor,
1501
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1502
+ attention_mask: Optional[torch.FloatTensor] = None,
1503
+ ) -> torch.FloatTensor:
1504
+ residual = hidden_states
1505
+
1506
+ input_ndim = hidden_states.ndim
1507
+
1508
+ if input_ndim == 4:
1509
+ batch_size, channel, height, width = hidden_states.shape
1510
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1511
+
1512
+ batch_size, sequence_length, _ = (
1513
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1514
+ )
1515
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1516
+
1517
+ if attn.group_norm is not None:
1518
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1519
+
1520
+ query = attn.to_q(hidden_states)
1521
+ dim = query.shape[-1]
1522
+ query = attn.head_to_batch_dim(query)
1523
+
1524
+ if encoder_hidden_states is None:
1525
+ encoder_hidden_states = hidden_states
1526
+ elif attn.norm_cross:
1527
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1528
+
1529
+ key = attn.to_k(encoder_hidden_states)
1530
+ value = attn.to_v(encoder_hidden_states)
1531
+ key = attn.head_to_batch_dim(key)
1532
+ value = attn.head_to_batch_dim(value)
1533
+
1534
+ batch_size_attention, query_tokens, _ = query.shape
1535
+ hidden_states = torch.zeros(
1536
+ (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
1537
+ )
1538
+
1539
+ for i in range(batch_size_attention // self.slice_size):
1540
+ start_idx = i * self.slice_size
1541
+ end_idx = (i + 1) * self.slice_size
1542
+
1543
+ query_slice = query[start_idx:end_idx]
1544
+ key_slice = key[start_idx:end_idx]
1545
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
1546
+
1547
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
1548
+
1549
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
1550
+
1551
+ hidden_states[start_idx:end_idx] = attn_slice
1552
+
1553
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1554
+
1555
+ # linear proj
1556
+ hidden_states = attn.to_out[0](hidden_states)
1557
+ # dropout
1558
+ hidden_states = attn.to_out[1](hidden_states)
1559
+
1560
+ if input_ndim == 4:
1561
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1562
+
1563
+ if attn.residual_connection:
1564
+ hidden_states = hidden_states + residual
1565
+
1566
+ hidden_states = hidden_states / attn.rescale_output_factor
1567
+
1568
+ return hidden_states
1569
+
1570
+
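+ # Illustrative sketch (added for illustration; not from the upstream diffusers
+ # source): slicing only changes how the work is chunked along the
+ # (batch * heads) dimension, not the result. Plain tensors stand in for the
+ # projected query/key/value.
+ def _example_sliced_attention_equivalence(slice_size: int = 2):
+     import torch
+ 
+     query = torch.randn(8, 16, 32)  # (batch * heads, query_tokens, head_dim)
+     key = torch.randn(8, 16, 32)
+     value = torch.randn(8, 16, 32)
+     scale = query.shape[-1] ** -0.5
+ 
+     full = torch.softmax(torch.bmm(query, key.transpose(1, 2)) * scale, dim=-1).bmm(value)
+ 
+     sliced = torch.zeros_like(full)
+     for i in range(query.shape[0] // slice_size):
+         start_idx, end_idx = i * slice_size, (i + 1) * slice_size
+         probs = torch.softmax(
+             torch.bmm(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * scale, dim=-1
+         )
+         sliced[start_idx:end_idx] = torch.bmm(probs, value[start_idx:end_idx])
+ 
+     assert torch.allclose(full, sliced, atol=1e-6)
+     return sliced
+ 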
1571
+ class SlicedAttnAddedKVProcessor:
1572
+ r"""
1573
+ Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder.
1574
+
1575
+ Args:
1576
+ slice_size (`int`, *optional*):
1577
+ The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
1578
+ `attention_head_dim` must be a multiple of the `slice_size`.
1579
+ """
1580
+
1581
+ def __init__(self, slice_size):
1582
+ self.slice_size = slice_size
1583
+
1584
+ def __call__(
1585
+ self,
1586
+ attn: "Attention",
1587
+ hidden_states: torch.FloatTensor,
1588
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1589
+ attention_mask: Optional[torch.FloatTensor] = None,
1590
+ temb: Optional[torch.FloatTensor] = None,
1591
+ ) -> torch.FloatTensor:
1592
+ residual = hidden_states
1593
+
1594
+ if attn.spatial_norm is not None:
1595
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1596
+
1597
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
1598
+
1599
+ batch_size, sequence_length, _ = hidden_states.shape
1600
+
1601
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1602
+
1603
+ if encoder_hidden_states is None:
1604
+ encoder_hidden_states = hidden_states
1605
+ elif attn.norm_cross:
1606
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1607
+
1608
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1609
+
1610
+ query = attn.to_q(hidden_states)
1611
+ dim = query.shape[-1]
1612
+ query = attn.head_to_batch_dim(query)
1613
+
1614
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
1615
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
1616
+
1617
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
1618
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
1619
+
1620
+ if not attn.only_cross_attention:
1621
+ key = attn.to_k(hidden_states)
1622
+ value = attn.to_v(hidden_states)
1623
+ key = attn.head_to_batch_dim(key)
1624
+ value = attn.head_to_batch_dim(value)
1625
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
1626
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
1627
+ else:
1628
+ key = encoder_hidden_states_key_proj
1629
+ value = encoder_hidden_states_value_proj
1630
+
1631
+ batch_size_attention, query_tokens, _ = query.shape
1632
+ hidden_states = torch.zeros(
1633
+ (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
1634
+ )
1635
+
1636
+ for i in range(batch_size_attention // self.slice_size):
1637
+ start_idx = i * self.slice_size
1638
+ end_idx = (i + 1) * self.slice_size
1639
+
1640
+ query_slice = query[start_idx:end_idx]
1641
+ key_slice = key[start_idx:end_idx]
1642
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
1643
+
1644
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
1645
+
1646
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
1647
+
1648
+ hidden_states[start_idx:end_idx] = attn_slice
1649
+
1650
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1651
+
1652
+ # linear proj
1653
+ hidden_states = attn.to_out[0](hidden_states)
1654
+ # dropout
1655
+ hidden_states = attn.to_out[1](hidden_states)
1656
+
1657
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
1658
+ hidden_states = hidden_states + residual
1659
+
1660
+ return hidden_states
1661
+
1662
+
1663
+ class SpatialNorm(nn.Module):
1664
+ """
1665
+ Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002.
1666
+
1667
+ Args:
1668
+ f_channels (`int`):
1669
+ The number of channels for input to group normalization layer, and output of the spatial norm layer.
1670
+ zq_channels (`int`):
1671
+ The number of channels for the quantized vector as described in the paper.
1672
+ """
1673
+
1674
+ def __init__(
1675
+ self,
1676
+ f_channels: int,
1677
+ zq_channels: int,
1678
+ ):
1679
+ super().__init__()
1680
+ self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
1681
+ self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
1682
+ self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
1683
+
1684
+ def forward(self, f: torch.FloatTensor, zq: torch.FloatTensor) -> torch.FloatTensor:
1685
+ f_size = f.shape[-2:]
1686
+ zq = F.interpolate(zq, size=f_size, mode="nearest")
1687
+ norm_f = self.norm_layer(f)
1688
+ new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
1689
+ return new_f
1690
+
1691
+
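+ # Illustrative shape check (added for illustration; not from the upstream
+ # diffusers source): `SpatialNorm` rescales group-normalized features with maps
+ # predicted from a spatially smaller quantized latent `zq`, which is resized to
+ # match `f` inside the forward pass.
+ def _example_spatial_norm():
+     import torch
+ 
+     norm = SpatialNorm(f_channels=64, zq_channels=4)
+     f = torch.randn(1, 64, 32, 32)
+     zq = torch.randn(1, 4, 8, 8)
+     out = norm(f, zq)
+     assert out.shape == f.shape
+     return out
+ 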
1692
+ ## Deprecated
1693
+ class LoRAAttnProcessor(nn.Module):
1694
+ r"""
1695
+ Processor for implementing the LoRA attention mechanism.
1696
+
1697
+ Args:
1698
+ hidden_size (`int`, *optional*):
1699
+ The hidden size of the attention layer.
1700
+ cross_attention_dim (`int`, *optional*):
1701
+ The number of channels in the `encoder_hidden_states`.
1702
+ rank (`int`, defaults to 4):
1703
+ The dimension of the LoRA update matrices.
1704
+ network_alpha (`int`, *optional*):
1705
+ Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
1706
+ kwargs (`dict`):
1707
+ Additional keyword arguments to pass to the `LoRALinearLayer` layers.
1708
+ """
1709
+
1710
+ def __init__(
1711
+ self,
1712
+ hidden_size: int,
1713
+ cross_attention_dim: Optional[int] = None,
1714
+ rank: int = 4,
1715
+ network_alpha: Optional[int] = None,
1716
+ **kwargs,
1717
+ ):
1718
+ super().__init__()
1719
+
1720
+ self.hidden_size = hidden_size
1721
+ self.cross_attention_dim = cross_attention_dim
1722
+ self.rank = rank
1723
+
1724
+ q_rank = kwargs.pop("q_rank", None)
1725
+ q_hidden_size = kwargs.pop("q_hidden_size", None)
1726
+ q_rank = q_rank if q_rank is not None else rank
1727
+ q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
1728
+
1729
+ v_rank = kwargs.pop("v_rank", None)
1730
+ v_hidden_size = kwargs.pop("v_hidden_size", None)
1731
+ v_rank = v_rank if v_rank is not None else rank
1732
+ v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
1733
+
1734
+ out_rank = kwargs.pop("out_rank", None)
1735
+ out_hidden_size = kwargs.pop("out_hidden_size", None)
1736
+ out_rank = out_rank if out_rank is not None else rank
1737
+ out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
1738
+
1739
+ self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
1740
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1741
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
1742
+ self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
1743
+
1744
+ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
1745
+ self_cls_name = self.__class__.__name__
1746
+ deprecate(
1747
+ self_cls_name,
1748
+ "0.26.0",
1749
+ (
1750
+ f"Make sure use {self_cls_name[4:]} instead by setting"
1751
+ "LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
1752
+ " `LoraLoaderMixin.load_lora_weights`"
1753
+ ),
1754
+ )
1755
+ attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
1756
+ attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
1757
+ attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
1758
+ attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
1759
+
1760
+ attn._modules.pop("processor")
1761
+ attn.processor = AttnProcessor()
1762
+ return attn.processor(attn, hidden_states, *args, **kwargs)
1763
+
1764
+
1765
+ class LoRAAttnProcessor2_0(nn.Module):
1766
+ r"""
1767
+ Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product
1768
+ attention.
1769
+
1770
+ Args:
1771
+ hidden_size (`int`):
1772
+ The hidden size of the attention layer.
1773
+ cross_attention_dim (`int`, *optional*):
1774
+ The number of channels in the `encoder_hidden_states`.
1775
+ rank (`int`, defaults to 4):
1776
+ The dimension of the LoRA update matrices.
1777
+ network_alpha (`int`, *optional*):
1778
+ Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
1779
+ kwargs (`dict`):
1780
+ Additional keyword arguments to pass to the `LoRALinearLayer` layers.
1781
+ """
1782
+
1783
+ def __init__(
1784
+ self,
1785
+ hidden_size: int,
1786
+ cross_attention_dim: Optional[int] = None,
1787
+ rank: int = 4,
1788
+ network_alpha: Optional[int] = None,
1789
+ **kwargs,
1790
+ ):
1791
+ super().__init__()
1792
+ if not hasattr(F, "scaled_dot_product_attention"):
1793
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
1794
+
1795
+ self.hidden_size = hidden_size
1796
+ self.cross_attention_dim = cross_attention_dim
1797
+ self.rank = rank
1798
+
1799
+ q_rank = kwargs.pop("q_rank", None)
1800
+ q_hidden_size = kwargs.pop("q_hidden_size", None)
1801
+ q_rank = q_rank if q_rank is not None else rank
1802
+ q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
1803
+
1804
+ v_rank = kwargs.pop("v_rank", None)
1805
+ v_hidden_size = kwargs.pop("v_hidden_size", None)
1806
+ v_rank = v_rank if v_rank is not None else rank
1807
+ v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
1808
+
1809
+ out_rank = kwargs.pop("out_rank", None)
1810
+ out_hidden_size = kwargs.pop("out_hidden_size", None)
1811
+ out_rank = out_rank if out_rank is not None else rank
1812
+ out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
1813
+
1814
+ self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
1815
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1816
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
1817
+ self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
1818
+
1819
+ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
1820
+ self_cls_name = self.__class__.__name__
1821
+ deprecate(
1822
+ self_cls_name,
1823
+ "0.26.0",
1824
+ (
1825
+ f"Make sure use {self_cls_name[4:]} instead by setting"
1826
+ "LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
1827
+ " `LoraLoaderMixin.load_lora_weights`"
1828
+ ),
1829
+ )
1830
+ attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
1831
+ attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
1832
+ attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
1833
+ attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
1834
+
1835
+ attn._modules.pop("processor")
1836
+ attn.processor = AttnProcessor2_0()
1837
+ return attn.processor(attn, hidden_states, *args, **kwargs)
1838
+
1839
+
1840
+ class LoRAXFormersAttnProcessor(nn.Module):
1841
+ r"""
1842
+ Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers.
1843
+
1844
+ Args:
1845
+ hidden_size (`int`, *optional*):
1846
+ The hidden size of the attention layer.
1847
+ cross_attention_dim (`int`, *optional*):
1848
+ The number of channels in the `encoder_hidden_states`.
1849
+ rank (`int`, defaults to 4):
1850
+ The dimension of the LoRA update matrices.
1851
+ attention_op (`Callable`, *optional*, defaults to `None`):
1852
+ The base
1853
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
1854
+ use as the attention operator. It is recommended to set it to `None` and allow xFormers to choose the best
1855
+ operator.
1856
+ network_alpha (`int`, *optional*):
1857
+ Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
1858
+ kwargs (`dict`):
1859
+ Additional keyword arguments to pass to the `LoRALinearLayer` layers.
1860
+ """
1861
+
1862
+ def __init__(
1863
+ self,
1864
+ hidden_size: int,
1865
+ cross_attention_dim: int,
1866
+ rank: int = 4,
1867
+ attention_op: Optional[Callable] = None,
1868
+ network_alpha: Optional[int] = None,
1869
+ **kwargs,
1870
+ ):
1871
+ super().__init__()
1872
+
1873
+ self.hidden_size = hidden_size
1874
+ self.cross_attention_dim = cross_attention_dim
1875
+ self.rank = rank
1876
+ self.attention_op = attention_op
1877
+
1878
+ q_rank = kwargs.pop("q_rank", None)
1879
+ q_hidden_size = kwargs.pop("q_hidden_size", None)
1880
+ q_rank = q_rank if q_rank is not None else rank
1881
+ q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
1882
+
1883
+ v_rank = kwargs.pop("v_rank", None)
1884
+ v_hidden_size = kwargs.pop("v_hidden_size", None)
1885
+ v_rank = v_rank if v_rank is not None else rank
1886
+ v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
1887
+
1888
+ out_rank = kwargs.pop("out_rank", None)
1889
+ out_hidden_size = kwargs.pop("out_hidden_size", None)
1890
+ out_rank = out_rank if out_rank is not None else rank
1891
+ out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
1892
+
1893
+ self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
1894
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1895
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
1896
+ self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
1897
+
1898
+ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
1899
+ self_cls_name = self.__class__.__name__
1900
+ deprecate(
1901
+ self_cls_name,
1902
+ "0.26.0",
1903
+ (
1904
+ f"Make sure use {self_cls_name[4:]} instead by setting"
1905
+ "LoRA layers to `self.{to_q,to_k,to_v,add_k_proj,add_v_proj,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
1906
+ " `LoraLoaderMixin.load_lora_weights`"
1907
+ ),
1908
+ )
1909
+ attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
1910
+ attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
1911
+ attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
1912
+ attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
1913
+
1914
+ attn._modules.pop("processor")
1915
+ attn.processor = XFormersAttnProcessor()
1916
+ return attn.processor(attn, hidden_states, *args, **kwargs)
1917
+
1918
+
1919
+ class LoRAAttnAddedKVProcessor(nn.Module):
1920
+ r"""
1921
+ Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text
1922
+ encoder.
1923
+
1924
+ Args:
1925
+ hidden_size (`int`, *optional*):
1926
+ The hidden size of the attention layer.
1927
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
1928
+ The number of channels in the `encoder_hidden_states`.
1929
+ rank (`int`, defaults to 4):
1930
+ The dimension of the LoRA update matrices.
1931
+ network_alpha (`int`, *optional*):
1932
+ Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
1933
+ kwargs (`dict`):
1934
+ Additional keyword arguments to pass to the `LoRALinearLayer` layers.
1935
+ """
1936
+
1937
+ def __init__(
1938
+ self,
1939
+ hidden_size: int,
1940
+ cross_attention_dim: Optional[int] = None,
1941
+ rank: int = 4,
1942
+ network_alpha: Optional[int] = None,
1943
+ ):
1944
+ super().__init__()
1945
+
1946
+ self.hidden_size = hidden_size
1947
+ self.cross_attention_dim = cross_attention_dim
1948
+ self.rank = rank
1949
+
1950
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1951
+ self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1952
+ self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1953
+ self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1954
+ self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1955
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1956
+
1957
+ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
1958
+ self_cls_name = self.__class__.__name__
1959
+ deprecate(
1960
+ self_cls_name,
1961
+ "0.26.0",
1962
+ (
1963
+ f"Make sure use {self_cls_name[4:]} instead by setting"
1964
+ "LoRA layers to `self.{to_q,to_k,to_v,add_k_proj,add_v_proj,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
1965
+ " `LoraLoaderMixin.load_lora_weights`"
1966
+ ),
1967
+ )
1968
+ attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
1969
+ attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
1970
+ attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
1971
+ attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
1972
+
1973
+ attn._modules.pop("processor")
1974
+ attn.processor = AttnAddedKVProcessor()
1975
+ return attn.processor(attn, hidden_states, *args, **kwargs)
1976
+
1977
+
1978
+ LORA_ATTENTION_PROCESSORS = (
1979
+ LoRAAttnProcessor,
1980
+ LoRAAttnProcessor2_0,
1981
+ LoRAXFormersAttnProcessor,
1982
+ LoRAAttnAddedKVProcessor,
1983
+ )
1984
+
1985
+ ADDED_KV_ATTENTION_PROCESSORS = (
1986
+ AttnAddedKVProcessor,
1987
+ SlicedAttnAddedKVProcessor,
1988
+ AttnAddedKVProcessor2_0,
1989
+ XFormersAttnAddedKVProcessor,
1990
+ LoRAAttnAddedKVProcessor,
1991
+ )
1992
+
1993
+ CROSS_ATTENTION_PROCESSORS = (
1994
+ AttnProcessor,
1995
+ AttnProcessor2_0,
1996
+ XFormersAttnProcessor,
1997
+ SlicedAttnProcessor,
1998
+ LoRAAttnProcessor,
1999
+ LoRAAttnProcessor2_0,
2000
+ LoRAXFormersAttnProcessor,
2001
+ )
2002
+
2003
+ AttentionProcessor = Union[
2004
+ AttnProcessor,
2005
+ AttnProcessor2_0,
2006
+ XFormersAttnProcessor,
2007
+ SlicedAttnProcessor,
2008
+ AttnAddedKVProcessor,
2009
+ SlicedAttnAddedKVProcessor,
2010
+ AttnAddedKVProcessor2_0,
2011
+ XFormersAttnAddedKVProcessor,
2012
+ CustomDiffusionAttnProcessor,
2013
+ CustomDiffusionXFormersAttnProcessor,
2014
+ CustomDiffusionAttnProcessor2_0,
2015
+ # deprecated
2016
+ LoRAAttnProcessor,
2017
+ LoRAAttnProcessor2_0,
2018
+ LoRAXFormersAttnProcessor,
2019
+ LoRAAttnAddedKVProcessor,
2020
+ ]
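+ 
+ # Illustrative sketch (added for illustration; not from the upstream diffusers
+ # source): the tuples above act as simple registries. Helpers such as
+ # `set_default_attn_processor` on the model classes check a module's current
+ # processors against them to pick a sensible default implementation.
+ def _example_pick_default_processor(current_processors):
+     if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in current_processors):
+         return AttnAddedKVProcessor()
+     if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in current_processors):
+         return AttnProcessor()
+     raise ValueError("Mixed or unsupported processor types.")
+ 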
diffusers/models/autoencoder_asym_kl.py ADDED
@@ -0,0 +1,181 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Optional, Tuple, Union
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+
19
+ from ..configuration_utils import ConfigMixin, register_to_config
20
+ from ..utils.accelerate_utils import apply_forward_hook
21
+ from .autoencoder_kl import AutoencoderKLOutput
22
+ from .modeling_utils import ModelMixin
23
+ from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
24
+
25
+
26
+ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
27
+ r"""
28
+ Designing a Better Asymmetric VQGAN for Stable Diffusion (https://arxiv.org/abs/2306.04632). A VAE model with KL loss
29
+ for encoding images into latents and decoding latent representations into images.
30
+
31
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
32
+ for all models (such as downloading or saving).
33
+
34
+ Parameters:
35
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
36
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
37
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
38
+ Tuple of downsample block types.
39
+ down_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
40
+ Tuple of down block output channels.
41
+ layers_per_down_block (`int`, *optional*, defaults to `1`):
42
+ Number of layers for each down block.
43
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
44
+ Tuple of upsample block types.
45
+ up_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
46
+ Tuple of up block output channels.
47
+ layers_per_up_block (`int`, *optional*, defaults to `1`):
48
+ Number of layers for each up block.
49
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
50
+ latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
51
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
52
+ norm_num_groups (`int`, *optional*, defaults to `32`):
53
+ Number of groups to use for the first normalization layer in ResNet blocks.
54
+ scaling_factor (`float`, *optional*, defaults to 0.18215):
55
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
56
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
57
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
58
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
59
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
60
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
61
+ """
62
+
63
+ @register_to_config
64
+ def __init__(
65
+ self,
66
+ in_channels: int = 3,
67
+ out_channels: int = 3,
68
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
69
+ down_block_out_channels: Tuple[int] = (64,),
70
+ layers_per_down_block: int = 1,
71
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
72
+ up_block_out_channels: Tuple[int] = (64,),
73
+ layers_per_up_block: int = 1,
74
+ act_fn: str = "silu",
75
+ latent_channels: int = 4,
76
+ norm_num_groups: int = 32,
77
+ sample_size: int = 32,
78
+ scaling_factor: float = 0.18215,
79
+ ) -> None:
80
+ super().__init__()
81
+
82
+ # pass init params to Encoder
83
+ self.encoder = Encoder(
84
+ in_channels=in_channels,
85
+ out_channels=latent_channels,
86
+ down_block_types=down_block_types,
87
+ block_out_channels=down_block_out_channels,
88
+ layers_per_block=layers_per_down_block,
89
+ act_fn=act_fn,
90
+ norm_num_groups=norm_num_groups,
91
+ double_z=True,
92
+ )
93
+
94
+ # pass init params to Decoder
95
+ self.decoder = MaskConditionDecoder(
96
+ in_channels=latent_channels,
97
+ out_channels=out_channels,
98
+ up_block_types=up_block_types,
99
+ block_out_channels=up_block_out_channels,
100
+ layers_per_block=layers_per_up_block,
101
+ act_fn=act_fn,
102
+ norm_num_groups=norm_num_groups,
103
+ )
104
+
105
+ self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
106
+ self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
107
+
108
+ self.use_slicing = False
109
+ self.use_tiling = False
110
+
111
+ @apply_forward_hook
112
+ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
113
+ h = self.encoder(x)
114
+ moments = self.quant_conv(h)
115
+ posterior = DiagonalGaussianDistribution(moments)
116
+
117
+ if not return_dict:
118
+ return (posterior,)
119
+
120
+ return AutoencoderKLOutput(latent_dist=posterior)
121
+
122
+ def _decode(
123
+ self,
124
+ z: torch.FloatTensor,
125
+ image: Optional[torch.FloatTensor] = None,
126
+ mask: Optional[torch.FloatTensor] = None,
127
+ return_dict: bool = True,
128
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
129
+ z = self.post_quant_conv(z)
130
+ dec = self.decoder(z, image, mask)
131
+
132
+ if not return_dict:
133
+ return (dec,)
134
+
135
+ return DecoderOutput(sample=dec)
136
+
137
+ @apply_forward_hook
138
+ def decode(
139
+ self,
140
+ z: torch.FloatTensor,
141
+ generator: Optional[torch.Generator] = None,
142
+ image: Optional[torch.FloatTensor] = None,
143
+ mask: Optional[torch.FloatTensor] = None,
144
+ return_dict: bool = True,
145
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
146
+ decoded = self._decode(z, image, mask).sample
147
+
148
+ if not return_dict:
149
+ return (decoded,)
150
+
151
+ return DecoderOutput(sample=decoded)
152
+
153
+ def forward(
154
+ self,
155
+ sample: torch.FloatTensor,
156
+ mask: Optional[torch.FloatTensor] = None,
157
+ sample_posterior: bool = False,
158
+ return_dict: bool = True,
159
+ generator: Optional[torch.Generator] = None,
160
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
161
+ r"""
162
+ Args:
163
+ sample (`torch.FloatTensor`): Input sample.
164
+ mask (`torch.FloatTensor`, *optional*, defaults to `None`): Optional inpainting mask.
165
+ sample_posterior (`bool`, *optional*, defaults to `False`):
166
+ Whether to sample from the posterior.
167
+ return_dict (`bool`, *optional*, defaults to `True`):
168
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
169
+ """
170
+ x = sample
171
+ posterior = self.encode(x).latent_dist
172
+ if sample_posterior:
173
+ z = posterior.sample(generator=generator)
174
+ else:
175
+ z = posterior.mode()
176
+ dec = self.decode(z, generator=generator, image=sample, mask=mask).sample
177
+
178
+ if not return_dict:
179
+ return (dec,)
180
+
181
+ return DecoderOutput(sample=dec)
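+ 
+ # Illustrative usage sketch (added for illustration; not from the upstream
+ # diffusers source). The checkpoint id below is an assumption about where a
+ # released asymmetric VAE lives; substitute whatever checkpoint you actually
+ # use. The mask is 1 in the region to be inpainted and 0 elsewhere.
+ def _example_asymmetric_autoencoder_kl():
+     import torch
+ 
+     vae = AsymmetricAutoencoderKL.from_pretrained("cross-attention/asymmetric-autoencoder-kl-x-1-5")
+     image = torch.randn(1, 3, 512, 512)  # stand-in for a [-1, 1] normalized image
+     mask = torch.zeros(1, 1, 512, 512)
+     mask[:, :, 128:384, 128:384] = 1.0
+ 
+     with torch.no_grad():
+         reconstruction = vae(image, mask=mask, sample_posterior=False).sample
+     assert reconstruction.shape == image.shape
+     return reconstruction
+ 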
diffusers/models/autoencoder_kl.py ADDED
@@ -0,0 +1,465 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Dict, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from ..configuration_utils import ConfigMixin, register_to_config
21
+ from ..loaders import FromOriginalVAEMixin
22
+ from ..utils import BaseOutput
23
+ from ..utils.accelerate_utils import apply_forward_hook
24
+ from .attention_processor import (
25
+ ADDED_KV_ATTENTION_PROCESSORS,
26
+ CROSS_ATTENTION_PROCESSORS,
27
+ AttentionProcessor,
28
+ AttnAddedKVProcessor,
29
+ AttnProcessor,
30
+ )
31
+ from .modeling_utils import ModelMixin
32
+ from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
33
+
34
+
35
+ @dataclass
36
+ class AutoencoderKLOutput(BaseOutput):
37
+ """
38
+ Output of AutoencoderKL encoding method.
39
+
40
+ Args:
41
+ latent_dist (`DiagonalGaussianDistribution`):
42
+ Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
43
+ `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
44
+ """
45
+
46
+ latent_dist: "DiagonalGaussianDistribution"
47
+
48
+
49
+ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
50
+ r"""
51
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
52
+
53
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
54
+ for all models (such as downloading or saving).
55
+
56
+ Parameters:
57
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
58
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
59
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
60
+ Tuple of downsample block types.
61
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
62
+ Tuple of upsample block types.
63
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
64
+ Tuple of block output channels.
65
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
66
+ latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
67
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
68
+ scaling_factor (`float`, *optional*, defaults to 0.18215):
69
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
70
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
71
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
72
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
73
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
74
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
75
+ force_upcast (`bool`, *optional*, default to `True`):
76
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
77
+ can be fine-tuned / trained to a lower range without losing too much precision, in which case
78
+ `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
79
+ """
80
+
81
+ _supports_gradient_checkpointing = True
82
+
83
+ @register_to_config
84
+ def __init__(
85
+ self,
86
+ in_channels: int = 3,
87
+ out_channels: int = 3,
88
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
89
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
90
+ block_out_channels: Tuple[int] = (64,),
91
+ layers_per_block: int = 1,
92
+ act_fn: str = "silu",
93
+ latent_channels: int = 4,
94
+ norm_num_groups: int = 32,
95
+ sample_size: int = 32,
96
+ scaling_factor: float = 0.18215,
97
+ force_upcast: bool = True,
98
+ ):
99
+ super().__init__()
100
+
101
+ # pass init params to Encoder
102
+ self.encoder = Encoder(
103
+ in_channels=in_channels,
104
+ out_channels=latent_channels,
105
+ down_block_types=down_block_types,
106
+ block_out_channels=block_out_channels,
107
+ layers_per_block=layers_per_block,
108
+ act_fn=act_fn,
109
+ norm_num_groups=norm_num_groups,
110
+ double_z=True,
111
+ )
112
+
113
+ # pass init params to Decoder
114
+ self.decoder = Decoder(
115
+ in_channels=latent_channels,
116
+ out_channels=out_channels,
117
+ up_block_types=up_block_types,
118
+ block_out_channels=block_out_channels,
119
+ layers_per_block=layers_per_block,
120
+ norm_num_groups=norm_num_groups,
121
+ act_fn=act_fn,
122
+ )
123
+
124
+ self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
125
+ self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
126
+
127
+ self.use_slicing = False
128
+ self.use_tiling = False
129
+
130
+ # only relevant if vae tiling is enabled
131
+ self.tile_sample_min_size = self.config.sample_size
132
+ sample_size = (
133
+ self.config.sample_size[0]
134
+ if isinstance(self.config.sample_size, (list, tuple))
135
+ else self.config.sample_size
136
+ )
137
+ self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
138
+ self.tile_overlap_factor = 0.25
139
+
140
+ def _set_gradient_checkpointing(self, module, value=False):
141
+ if isinstance(module, (Encoder, Decoder)):
142
+ module.gradient_checkpointing = value
143
+
144
+ def enable_tiling(self, use_tiling: bool = True):
145
+ r"""
146
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
147
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
148
+ processing larger images.
149
+ """
150
+ self.use_tiling = use_tiling
151
+
152
+ def disable_tiling(self):
153
+ r"""
154
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
155
+ decoding in one step.
156
+ """
157
+ self.enable_tiling(False)
158
+
159
+ def enable_slicing(self):
160
+ r"""
161
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
162
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
163
+ """
164
+ self.use_slicing = True
165
+
166
+ def disable_slicing(self):
167
+ r"""
168
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
169
+ decoding in one step.
170
+ """
171
+ self.use_slicing = False
172
+
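+     # Illustrative usage of the toggles above (added for illustration; the
+     # checkpoint id is an assumption):
+     #
+     #     vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
+     #     vae.enable_slicing()   # decode batched latents one sample at a time
+     #     vae.enable_tiling()    # split very large inputs into overlapping tiles
+     #     vae.disable_tiling()   # back to single-pass decoding
+ 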
173
+ @property
174
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
175
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
176
+ r"""
177
+ Returns:
178
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
179
+ indexed by its weight name.
180
+ """
181
+ # set recursively
182
+ processors = {}
183
+
184
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
185
+ if hasattr(module, "get_processor"):
186
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
187
+
188
+ for sub_name, child in module.named_children():
189
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
190
+
191
+ return processors
192
+
193
+ for name, module in self.named_children():
194
+ fn_recursive_add_processors(name, module, processors)
195
+
196
+ return processors
197
+
198
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
199
+ def set_attn_processor(
200
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
201
+ ):
202
+ r"""
203
+ Sets the attention processor to use to compute attention.
204
+
205
+ Parameters:
206
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
207
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
208
+ for **all** `Attention` layers.
209
+
210
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
211
+ processor. This is strongly recommended when setting trainable attention processors.
212
+
213
+ """
214
+ count = len(self.attn_processors.keys())
215
+
216
+ if isinstance(processor, dict) and len(processor) != count:
217
+ raise ValueError(
218
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
219
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
220
+ )
221
+
222
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
223
+ if hasattr(module, "set_processor"):
224
+ if not isinstance(processor, dict):
225
+ module.set_processor(processor, _remove_lora=_remove_lora)
226
+ else:
227
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
228
+
229
+ for sub_name, child in module.named_children():
230
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
231
+
232
+ for name, module in self.named_children():
233
+ fn_recursive_attn_processor(name, module, processor)
234
+
235
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
236
+ def set_default_attn_processor(self):
237
+ """
238
+ Disables custom attention processors and sets the default attention implementation.
239
+ """
240
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
241
+ processor = AttnAddedKVProcessor()
242
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
243
+ processor = AttnProcessor()
244
+ else:
245
+ raise ValueError(
246
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
247
+ )
248
+
249
+ self.set_attn_processor(processor, _remove_lora=True)
250
+
251
+ @apply_forward_hook
252
+ def encode(
253
+ self, x: torch.FloatTensor, return_dict: bool = True
254
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
255
+ """
256
+ Encode a batch of images into latents.
257
+
258
+ Args:
259
+ x (`torch.FloatTensor`): Input batch of images.
260
+ return_dict (`bool`, *optional*, defaults to `True`):
261
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
262
+
263
+ Returns:
264
+ The latent representations of the encoded images. If `return_dict` is True, a
265
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
266
+ """
267
+ if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
268
+ return self.tiled_encode(x, return_dict=return_dict)
269
+
270
+ if self.use_slicing and x.shape[0] > 1:
271
+ encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
272
+ h = torch.cat(encoded_slices)
273
+ else:
274
+ h = self.encoder(x)
275
+
276
+ moments = self.quant_conv(h)
277
+ posterior = DiagonalGaussianDistribution(moments)
278
+
279
+ if not return_dict:
280
+ return (posterior,)
281
+
282
+ return AutoencoderKLOutput(latent_dist=posterior)
283
+
284
+ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
285
+ if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
286
+ return self.tiled_decode(z, return_dict=return_dict)
287
+
288
+ z = self.post_quant_conv(z)
289
+ dec = self.decoder(z)
290
+
291
+ if not return_dict:
292
+ return (dec,)
293
+
294
+ return DecoderOutput(sample=dec)
295
+
296
+ @apply_forward_hook
297
+ def decode(
298
+ self, z: torch.FloatTensor, return_dict: bool = True, generator=None
299
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
300
+ """
301
+ Decode a batch of images.
302
+
303
+ Args:
304
+ z (`torch.FloatTensor`): Input batch of latent vectors.
305
+ return_dict (`bool`, *optional*, defaults to `True`):
306
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
307
+
308
+ Returns:
309
+ [`~models.vae.DecoderOutput`] or `tuple`:
310
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
311
+ returned.
312
+
313
+ """
314
+ if self.use_slicing and z.shape[0] > 1:
315
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
316
+ decoded = torch.cat(decoded_slices)
317
+ else:
318
+ decoded = self._decode(z).sample
319
+
320
+ if not return_dict:
321
+ return (decoded,)
322
+
323
+ return DecoderOutput(sample=decoded)
324
+
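+     # Illustrative round trip through `encode`/`decode` (added for illustration;
+     # `vae` is an `AutoencoderKL` instance and `image` a [-1, 1] normalized tensor
+     # of shape (batch, 3, height, width)):
+     #
+     #     latents = vae.encode(image).latent_dist.sample()
+     #     latents = latents * vae.config.scaling_factor    # scale for the diffusion model
+     #     decoded = vae.decode(latents / vae.config.scaling_factor).sample
+ 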
325
+ def blend_v(self, a, b, blend_extent):
326
+ blend_extent = min(a.shape[2], b.shape[2], blend_extent)
327
+ for y in range(blend_extent):
328
+ b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
329
+ return b
330
+
331
+ def blend_h(self, a, b, blend_extent):
332
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
333
+ for x in range(blend_extent):
334
+ b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
335
+ return b
336
+
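+     # The two helpers above cross-fade neighboring tiles linearly over
+     # `blend_extent` rows/columns. Worked weights for blend_extent = 4
+     # (added for illustration), at row/column index y of tile `b`:
+     #
+     #     y:           0      1      2      3
+     #     weight of a: 1.00   0.75   0.50   0.25
+     #     weight of b: 0.00   0.25   0.50   0.75
+ 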
337
+ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
338
+ r"""Encode a batch of images using a tiled encoder.
339
+
340
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
341
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
342
+ different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
343
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
344
+ output, but they should be much less noticeable.
345
+
346
+ Args:
347
+ x (`torch.FloatTensor`): Input batch of images.
348
+ return_dict (`bool`, *optional*, defaults to `True`):
349
+ Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
350
+
351
+ Returns:
352
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
353
+ If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
354
+ `tuple` is returned.
355
+ """
356
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
357
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
358
+ row_limit = self.tile_latent_min_size - blend_extent
359
+
360
+ # Split the image into 512x512 tiles and encode them separately.
361
+ rows = []
362
+ for i in range(0, x.shape[2], overlap_size):
363
+ row = []
364
+ for j in range(0, x.shape[3], overlap_size):
365
+ tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
366
+ tile = self.encoder(tile)
367
+ tile = self.quant_conv(tile)
368
+ row.append(tile)
369
+ rows.append(row)
370
+ result_rows = []
371
+ for i, row in enumerate(rows):
372
+ result_row = []
373
+ for j, tile in enumerate(row):
374
+ # blend the above tile and the left tile
375
+ # to the current tile and add the current tile to the result row
376
+ if i > 0:
377
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
378
+ if j > 0:
379
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
380
+ result_row.append(tile[:, :, :row_limit, :row_limit])
381
+ result_rows.append(torch.cat(result_row, dim=3))
382
+
383
+ moments = torch.cat(result_rows, dim=2)
384
+ posterior = DiagonalGaussianDistribution(moments)
385
+
386
+ if not return_dict:
387
+ return (posterior,)
388
+
389
+ return AutoencoderKLOutput(latent_dist=posterior)
390
+
391
+ def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
392
+ r"""
393
+ Decode a batch of images using a tiled decoder.
394
+
395
+ Args:
396
+ z (`torch.FloatTensor`): Input batch of latent vectors.
397
+ return_dict (`bool`, *optional*, defaults to `True`):
398
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
399
+
400
+ Returns:
401
+ [`~models.vae.DecoderOutput`] or `tuple`:
402
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
403
+ returned.
404
+ """
405
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
406
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
407
+ row_limit = self.tile_sample_min_size - blend_extent
408
+
409
+ # Split z into overlapping 64x64 tiles and decode them separately.
410
+ # The tiles have an overlap to avoid seams between tiles.
411
+ rows = []
412
+ for i in range(0, z.shape[2], overlap_size):
413
+ row = []
414
+ for j in range(0, z.shape[3], overlap_size):
415
+ tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
416
+ tile = self.post_quant_conv(tile)
417
+ decoded = self.decoder(tile)
418
+ row.append(decoded)
419
+ rows.append(row)
420
+ result_rows = []
421
+ for i, row in enumerate(rows):
422
+ result_row = []
423
+ for j, tile in enumerate(row):
424
+ # blend the above tile and the left tile
425
+ # to the current tile and add the current tile to the result row
426
+ if i > 0:
427
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
428
+ if j > 0:
429
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
430
+ result_row.append(tile[:, :, :row_limit, :row_limit])
431
+ result_rows.append(torch.cat(result_row, dim=3))
432
+
433
+ dec = torch.cat(result_rows, dim=2)
434
+ if not return_dict:
435
+ return (dec,)
436
+
437
+ return DecoderOutput(sample=dec)
438
+
439
+ def forward(
440
+ self,
441
+ sample: torch.FloatTensor,
442
+ sample_posterior: bool = False,
443
+ return_dict: bool = True,
444
+ generator: Optional[torch.Generator] = None,
445
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
446
+ r"""
447
+ Args:
448
+ sample (`torch.FloatTensor`): Input sample.
449
+ sample_posterior (`bool`, *optional*, defaults to `False`):
450
+ Whether to sample from the posterior.
451
+ return_dict (`bool`, *optional*, defaults to `True`):
452
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
453
+ """
454
+ x = sample
455
+ posterior = self.encode(x).latent_dist
456
+ if sample_posterior:
457
+ z = posterior.sample(generator=generator)
458
+ else:
459
+ z = posterior.mode()
460
+ dec = self.decode(z).sample
461
+
462
+ if not return_dict:
463
+ return (dec,)
464
+
465
+ return DecoderOutput(sample=dec)
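A minimal usage sketch for the tiling/slicing switches defined in this file; the checkpoint name is only an example of a standard `AutoencoderKL` checkpoint:

```py
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to("cuda")
vae.enable_tiling()   # tiled_encode/tiled_decode are used once inputs exceed tile_sample_min_size
vae.enable_slicing()  # batched inputs are processed one sample at a time

image = torch.randn(1, 3, 1024, 1024, dtype=torch.float16, device="cuda")  # stand-in for a real image in [-1, 1]
with torch.no_grad():
    latents = vae.encode(image).latent_dist.sample()   # e.g. (1, 4, 128, 128)
    decoded = vae.decode(latents).sample               # back to (1, 3, 1024, 1024)
```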
diffusers/models/autoencoder_tiny.py ADDED
@@ -0,0 +1,349 @@
1
+ # Copyright 2023 Ollin Boer Bohan and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from dataclasses import dataclass
17
+ from typing import Optional, Tuple, Union
18
+
19
+ import torch
20
+
21
+ from ..configuration_utils import ConfigMixin, register_to_config
22
+ from ..utils import BaseOutput
23
+ from ..utils.accelerate_utils import apply_forward_hook
24
+ from .modeling_utils import ModelMixin
25
+ from .vae import DecoderOutput, DecoderTiny, EncoderTiny
26
+
27
+
28
+ @dataclass
29
+ class AutoencoderTinyOutput(BaseOutput):
30
+ """
31
+ Output of AutoencoderTiny encoding method.
32
+
33
+ Args:
34
+ latents (`torch.Tensor`): Encoded outputs of the `Encoder`.
35
+
36
+ """
37
+
38
+ latents: torch.Tensor
39
+
40
+
41
+ class AutoencoderTiny(ModelMixin, ConfigMixin):
42
+ r"""
43
+ A tiny distilled VAE model for encoding images into latents and decoding latent representations into images.
44
+
45
+ [`AutoencoderTiny`] is a wrapper around the original implementation of `TAESD`.
46
+
47
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for
48
+ all models (such as downloading or saving).
49
+
50
+ Parameters:
51
+ in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
52
+ out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
53
+ encoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
54
+ Tuple of integers representing the number of output channels for each encoder block. The length of the
55
+ tuple should be equal to the number of encoder blocks.
56
+ decoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
57
+ Tuple of integers representing the number of output channels for each decoder block. The length of the
58
+ tuple should be equal to the number of decoder blocks.
59
+ act_fn (`str`, *optional*, defaults to `"relu"`):
60
+ Activation function to be used throughout the model.
61
+ latent_channels (`int`, *optional*, defaults to 4):
62
+ Number of channels in the latent representation. The latent space acts as a compressed representation of
63
+ the input image.
64
+ upsampling_scaling_factor (`int`, *optional*, defaults to 2):
65
+ Scaling factor for upsampling in the decoder. It determines the size of the output image during the
66
+ upsampling process.
67
+ num_encoder_blocks (`Tuple[int]`, *optional*, defaults to `(1, 3, 3, 3)`):
68
+ Tuple of integers representing the number of encoder blocks at each stage of the encoding process. The
69
+ length of the tuple should be equal to the number of stages in the encoder. Each stage has a different
70
+ number of encoder blocks.
71
+ num_decoder_blocks (`Tuple[int]`, *optional*, defaults to `(3, 3, 3, 1)`):
72
+ Tuple of integers representing the number of decoder blocks at each stage of the decoding process. The
73
+ length of the tuple should be equal to the number of stages in the decoder. Each stage has a different
74
+ number of decoder blocks.
75
+ latent_magnitude (`int`, *optional*, defaults to 3):
76
+ Magnitude of the latent representation. This parameter scales the latent representation values to control
77
+ the extent of information preservation.
78
+ latent_shift (float, *optional*, defaults to 0.5):
79
+ Shift applied to the latent representation. This parameter controls the center of the latent space.
80
+ scaling_factor (`float`, *optional*, defaults to 1.0):
81
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
82
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
83
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
84
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
85
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
86
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. For this Autoencoder,
87
+ however, no such scaling factor was used, hence the value of 1.0 as the default.
88
+ force_upcast (`bool`, *optional*, defaults to `False`):
89
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
90
+ can be fine-tuned / trained to a lower range without losing too much precision, in which case
91
+ `force_upcast` can be set to `False` (see this fp16-friendly
92
+ [AutoEncoder](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
93
+ """
94
+ _supports_gradient_checkpointing = True
95
+
96
+ @register_to_config
97
+ def __init__(
98
+ self,
99
+ in_channels=3,
100
+ out_channels=3,
101
+ encoder_block_out_channels: Tuple[int] = (64, 64, 64, 64),
102
+ decoder_block_out_channels: Tuple[int] = (64, 64, 64, 64),
103
+ act_fn: str = "relu",
104
+ latent_channels: int = 4,
105
+ upsampling_scaling_factor: int = 2,
106
+ num_encoder_blocks: Tuple[int] = (1, 3, 3, 3),
107
+ num_decoder_blocks: Tuple[int] = (3, 3, 3, 1),
108
+ latent_magnitude: int = 3,
109
+ latent_shift: float = 0.5,
110
+ force_upcast: bool = False,
111
+ scaling_factor: float = 1.0,
112
+ ):
113
+ super().__init__()
114
+
115
+ if len(encoder_block_out_channels) != len(num_encoder_blocks):
116
+ raise ValueError("`encoder_block_out_channels` should have the same length as `num_encoder_blocks`.")
117
+ if len(decoder_block_out_channels) != len(num_decoder_blocks):
118
+ raise ValueError("`decoder_block_out_channels` should have the same length as `num_decoder_blocks`.")
119
+
120
+ self.encoder = EncoderTiny(
121
+ in_channels=in_channels,
122
+ out_channels=latent_channels,
123
+ num_blocks=num_encoder_blocks,
124
+ block_out_channels=encoder_block_out_channels,
125
+ act_fn=act_fn,
126
+ )
127
+
128
+ self.decoder = DecoderTiny(
129
+ in_channels=latent_channels,
130
+ out_channels=out_channels,
131
+ num_blocks=num_decoder_blocks,
132
+ block_out_channels=decoder_block_out_channels,
133
+ upsampling_scaling_factor=upsampling_scaling_factor,
134
+ act_fn=act_fn,
135
+ )
136
+
137
+ self.latent_magnitude = latent_magnitude
138
+ self.latent_shift = latent_shift
139
+ self.scaling_factor = scaling_factor
140
+
141
+ self.use_slicing = False
142
+ self.use_tiling = False
143
+
144
+ # only relevant if vae tiling is enabled
145
+ self.spatial_scale_factor = 2**out_channels
146
+ self.tile_overlap_factor = 0.125
147
+ self.tile_sample_min_size = 512
148
+ self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor
149
+
150
+ def _set_gradient_checkpointing(self, module, value=False):
151
+ if isinstance(module, (EncoderTiny, DecoderTiny)):
152
+ module.gradient_checkpointing = value
153
+
154
+ def scale_latents(self, x):
155
+ """raw latents -> [0, 1]"""
156
+ return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1)
157
+
158
+ def unscale_latents(self, x):
159
+ """[0, 1] -> raw latents"""
160
+ return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)
161
+
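`scale_latents`/`unscale_latents` exist so latents can be stored as a uint8 (e.g. RGBA) image; a self-contained sketch of the round trip and the resulting quantization error, using the default `latent_magnitude=3` and `latent_shift=0.5`:

```py
import torch

latent_magnitude, latent_shift = 3, 0.5
raw = torch.randn(1, 4, 8, 8).clamp(-3, 3)  # values outside [-3, 3] would be clipped by the scaling

scaled = raw.div(2 * latent_magnitude).add(latent_shift).clamp(0, 1)         # scale_latents: raw -> [0, 1]
quantized = scaled.mul(255).round().byte()                                   # store as uint8
restored = (quantized / 255.0).sub(latent_shift).mul(2 * latent_magnitude)   # unscale_latents: [0, 1] -> raw

print((raw - restored).abs().max())  # about 0.012, i.e. half of one uint8 step (6 / 255)
```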
162
+ def enable_slicing(self):
163
+ r"""
164
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
165
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
166
+ """
167
+ self.use_slicing = True
168
+
169
+ def disable_slicing(self):
170
+ r"""
171
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
172
+ decoding in one step.
173
+ """
174
+ self.use_slicing = False
175
+
176
+ def enable_tiling(self, use_tiling: bool = True):
177
+ r"""
178
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
179
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
180
+ processing larger images.
181
+ """
182
+ self.use_tiling = use_tiling
183
+
184
+ def disable_tiling(self):
185
+ r"""
186
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
187
+ decoding in one step.
188
+ """
189
+ self.enable_tiling(False)
190
+
191
+ def _tiled_encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
192
+ r"""Encode a batch of images using a tiled encoder.
193
+
194
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
195
+ steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the
196
+ tiles overlap and are blended together to form a smooth output.
197
+
198
+ Args:
199
+ x (`torch.FloatTensor`): Input batch of images.
200
+ 
+ Returns:
+ `torch.FloatTensor`: The latent representation of the encoded images.
207
+ """
208
+ # scale of encoder output relative to input
209
+ sf = self.spatial_scale_factor
210
+ tile_size = self.tile_sample_min_size
211
+
212
+ # number of pixels to blend and to traverse between tile
213
+ blend_size = int(tile_size * self.tile_overlap_factor)
214
+ traverse_size = tile_size - blend_size
215
+
216
+ # tiles index (up/left)
217
+ ti = range(0, x.shape[-2], traverse_size)
218
+ tj = range(0, x.shape[-1], traverse_size)
219
+
220
+ # mask for blending
221
+ blend_masks = torch.stack(
222
+ torch.meshgrid([torch.arange(tile_size / sf) / (blend_size / sf - 1)] * 2, indexing="ij")
223
+ )
224
+ blend_masks = blend_masks.clamp(0, 1).to(x.device)
225
+
226
+ # output array
227
+ out = torch.zeros(x.shape[0], 4, x.shape[-2] // sf, x.shape[-1] // sf, device=x.device)
228
+ for i in ti:
229
+ for j in tj:
230
+ tile_in = x[..., i : i + tile_size, j : j + tile_size]
231
+ # tile result
232
+ tile_out = out[..., i // sf : (i + tile_size) // sf, j // sf : (j + tile_size) // sf]
233
+ tile = self.encoder(tile_in)
234
+ h, w = tile.shape[-2], tile.shape[-1]
235
+ # blend tile result into output
236
+ blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
237
+ blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
238
+ blend_mask = blend_mask_i * blend_mask_j
239
+ tile, blend_mask = tile[..., :h, :w], blend_mask[..., :h, :w]
240
+ tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
241
+ return out
242
+
243
+ def _tiled_decode(self, x: torch.FloatTensor) -> torch.FloatTensor:
244
+ r"""Decode a batch of latent vectors using a tiled decoder.
245
+
246
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several
247
+ steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the
248
+ tiles overlap and are blended together to form a smooth output.
249
+
250
+ Args:
251
+ x (`torch.FloatTensor`): Input batch of latent vectors.
+ 
+ Returns:
+ `torch.FloatTensor`: The decoded images.
259
+ """
260
+ # scale of decoder output relative to input
261
+ sf = self.spatial_scale_factor
262
+ tile_size = self.tile_latent_min_size
263
+
264
+ # number of pixels to blend and to traverse between tiles
265
+ blend_size = int(tile_size * self.tile_overlap_factor)
266
+ traverse_size = tile_size - blend_size
267
+
268
+ # tiles index (up/left)
269
+ ti = range(0, x.shape[-2], traverse_size)
270
+ tj = range(0, x.shape[-1], traverse_size)
271
+
272
+ # mask for blending
273
+ blend_masks = torch.stack(
274
+ torch.meshgrid([torch.arange(tile_size * sf) / (blend_size * sf - 1)] * 2, indexing="ij")
275
+ )
276
+ blend_masks = blend_masks.clamp(0, 1).to(x.device)
277
+
278
+ # output array
279
+ out = torch.zeros(x.shape[0], 3, x.shape[-2] * sf, x.shape[-1] * sf, device=x.device)
280
+ for i in ti:
281
+ for j in tj:
282
+ tile_in = x[..., i : i + tile_size, j : j + tile_size]
283
+ # tile result
284
+ tile_out = out[..., i * sf : (i + tile_size) * sf, j * sf : (j + tile_size) * sf]
285
+ tile = self.decoder(tile_in)
286
+ h, w = tile.shape[-2], tile.shape[-1]
287
+ # blend tile result into output
288
+ blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
289
+ blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
290
+ blend_mask = (blend_mask_i * blend_mask_j)[..., :h, :w]
291
+ tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
292
+ return out
293
+
294
+ @apply_forward_hook
295
+ def encode(
296
+ self, x: torch.FloatTensor, return_dict: bool = True
297
+ ) -> Union[AutoencoderTinyOutput, Tuple[torch.FloatTensor]]:
298
+ if self.use_slicing and x.shape[0] > 1:
299
+ output = [self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x_slice) for x_slice in x.split(1)]
300
+ output = torch.cat(output)
301
+ else:
302
+ output = self._tiled_encode(x) if self.use_tiling else self.encoder(x)
303
+
304
+ if not return_dict:
305
+ return (output,)
306
+
307
+ return AutoencoderTinyOutput(latents=output)
308
+
309
+ @apply_forward_hook
310
+ def decode(
311
+ self, x: torch.FloatTensor, generator: Optional[torch.Generator] = None, return_dict: bool = True
312
+ ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
313
+ if self.use_slicing and x.shape[0] > 1:
314
+ output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x_slice) for x_slice in x.split(1)]
315
+ output = torch.cat(output)
316
+ else:
317
+ output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)
318
+
319
+ if not return_dict:
320
+ return (output,)
321
+
322
+ return DecoderOutput(sample=output)
323
+
324
+ def forward(
325
+ self,
326
+ sample: torch.FloatTensor,
327
+ return_dict: bool = True,
328
+ ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
329
+ r"""
330
+ Args:
331
+ sample (`torch.FloatTensor`): Input sample.
332
+ return_dict (`bool`, *optional*, defaults to `True`):
333
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
334
+ """
335
+ enc = self.encode(sample).latents
336
+
337
+ # scale latents to be in [0, 1], then quantize latents to a byte tensor,
338
+ # as if we were storing the latents in an RGBA uint8 image.
339
+ scaled_enc = self.scale_latents(enc).mul_(255).round_().byte()
340
+
341
+ # unquantize latents back into [0, 1], then unscale latents back to their original range,
342
+ # as if we were loading the latents from an RGBA uint8 image.
343
+ unscaled_enc = self.unscale_latents(scaled_enc / 255.0)
344
+
345
+ dec = self.decode(unscaled_enc).sample
346
+
347
+ if not return_dict:
348
+ return (dec,)
349
+ return DecoderOutput(sample=dec)
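A minimal sketch of swapping this tiny VAE into a Stable Diffusion pipeline for faster decoding; `madebyollin/taesd` is the usual TAESD checkpoint and is used here purely as an example:

```py
import torch
from diffusers import AutoencoderTiny, StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

image = pipe("a photo of an astronaut riding a horse", num_inference_steps=25).images[0]
image.save("astronaut_taesd.png")
```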
diffusers/models/consistency_decoder_vae.py ADDED
@@ -0,0 +1,430 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Dict, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from ..configuration_utils import ConfigMixin, register_to_config
22
+ from ..schedulers import ConsistencyDecoderScheduler
23
+ from ..utils import BaseOutput
24
+ from ..utils.accelerate_utils import apply_forward_hook
25
+ from ..utils.torch_utils import randn_tensor
26
+ from .attention_processor import (
27
+ ADDED_KV_ATTENTION_PROCESSORS,
28
+ CROSS_ATTENTION_PROCESSORS,
29
+ AttentionProcessor,
30
+ AttnAddedKVProcessor,
31
+ AttnProcessor,
32
+ )
33
+ from .modeling_utils import ModelMixin
34
+ from .unet_2d import UNet2DModel
35
+ from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
36
+
37
+
38
+ @dataclass
39
+ class ConsistencyDecoderVAEOutput(BaseOutput):
40
+ """
41
+ Output of encoding method.
42
+
43
+ Args:
44
+ latent_dist (`DiagonalGaussianDistribution`):
45
+ Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
46
+ `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
47
+ """
48
+
49
+ latent_dist: "DiagonalGaussianDistribution"
50
+
51
+
52
+ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
53
+ r"""
54
+ The consistency decoder used with DALL-E 3.
55
+
56
+ Examples:
57
+ ```py
58
+ >>> import torch
59
+ >>> from diffusers import StableDiffusionPipeline, ConsistencyDecoderVAE
60
+
61
+ >>> vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
62
+ >>> pipe = StableDiffusionPipeline.from_pretrained(
63
+ ... "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
64
+ ... ).to("cuda")
65
+
66
+ >>> pipe("horse", generator=torch.manual_seed(0)).images
67
+ ```
68
+ """
69
+
70
+ @register_to_config
71
+ def __init__(
72
+ self,
73
+ scaling_factor=0.18215,
74
+ latent_channels=4,
75
+ encoder_act_fn="silu",
76
+ encoder_block_out_channels=(128, 256, 512, 512),
77
+ encoder_double_z=True,
78
+ encoder_down_block_types=(
79
+ "DownEncoderBlock2D",
80
+ "DownEncoderBlock2D",
81
+ "DownEncoderBlock2D",
82
+ "DownEncoderBlock2D",
83
+ ),
84
+ encoder_in_channels=3,
85
+ encoder_layers_per_block=2,
86
+ encoder_norm_num_groups=32,
87
+ encoder_out_channels=4,
88
+ decoder_add_attention=False,
89
+ decoder_block_out_channels=(320, 640, 1024, 1024),
90
+ decoder_down_block_types=(
91
+ "ResnetDownsampleBlock2D",
92
+ "ResnetDownsampleBlock2D",
93
+ "ResnetDownsampleBlock2D",
94
+ "ResnetDownsampleBlock2D",
95
+ ),
96
+ decoder_downsample_padding=1,
97
+ decoder_in_channels=7,
98
+ decoder_layers_per_block=3,
99
+ decoder_norm_eps=1e-05,
100
+ decoder_norm_num_groups=32,
101
+ decoder_num_train_timesteps=1024,
102
+ decoder_out_channels=6,
103
+ decoder_resnet_time_scale_shift="scale_shift",
104
+ decoder_time_embedding_type="learned",
105
+ decoder_up_block_types=(
106
+ "ResnetUpsampleBlock2D",
107
+ "ResnetUpsampleBlock2D",
108
+ "ResnetUpsampleBlock2D",
109
+ "ResnetUpsampleBlock2D",
110
+ ),
111
+ ):
112
+ super().__init__()
113
+ self.encoder = Encoder(
114
+ act_fn=encoder_act_fn,
115
+ block_out_channels=encoder_block_out_channels,
116
+ double_z=encoder_double_z,
117
+ down_block_types=encoder_down_block_types,
118
+ in_channels=encoder_in_channels,
119
+ layers_per_block=encoder_layers_per_block,
120
+ norm_num_groups=encoder_norm_num_groups,
121
+ out_channels=encoder_out_channels,
122
+ )
123
+
124
+ self.decoder_unet = UNet2DModel(
125
+ add_attention=decoder_add_attention,
126
+ block_out_channels=decoder_block_out_channels,
127
+ down_block_types=decoder_down_block_types,
128
+ downsample_padding=decoder_downsample_padding,
129
+ in_channels=decoder_in_channels,
130
+ layers_per_block=decoder_layers_per_block,
131
+ norm_eps=decoder_norm_eps,
132
+ norm_num_groups=decoder_norm_num_groups,
133
+ num_train_timesteps=decoder_num_train_timesteps,
134
+ out_channels=decoder_out_channels,
135
+ resnet_time_scale_shift=decoder_resnet_time_scale_shift,
136
+ time_embedding_type=decoder_time_embedding_type,
137
+ up_block_types=decoder_up_block_types,
138
+ )
139
+ self.decoder_scheduler = ConsistencyDecoderScheduler()
140
+ self.register_to_config(block_out_channels=encoder_block_out_channels)
141
+ self.register_buffer(
142
+ "means",
143
+ torch.tensor([0.38862467, 0.02253063, 0.07381133, -0.0171294])[None, :, None, None],
144
+ persistent=False,
145
+ )
146
+ self.register_buffer(
147
+ "stds", torch.tensor([0.9654121, 1.0440036, 0.76147926, 0.77022034])[None, :, None, None], persistent=False
148
+ )
149
+
150
+ self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
151
+
152
+ self.use_slicing = False
153
+ self.use_tiling = False
154
+
155
+ # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.enable_tiling
156
+ def enable_tiling(self, use_tiling: bool = True):
157
+ r"""
158
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
159
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
160
+ processing larger images.
161
+ """
162
+ self.use_tiling = use_tiling
163
+
164
+ # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.disable_tiling
165
+ def disable_tiling(self):
166
+ r"""
167
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
168
+ decoding in one step.
169
+ """
170
+ self.enable_tiling(False)
171
+
172
+ # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.enable_slicing
173
+ def enable_slicing(self):
174
+ r"""
175
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
176
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
177
+ """
178
+ self.use_slicing = True
179
+
180
+ # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.disable_slicing
181
+ def disable_slicing(self):
182
+ r"""
183
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
184
+ decoding in one step.
185
+ """
186
+ self.use_slicing = False
187
+
188
+ @property
189
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
190
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
191
+ r"""
192
+ Returns:
193
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
194
+ indexed by its weight name.
195
+ """
196
+ # set recursively
197
+ processors = {}
198
+
199
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
200
+ if hasattr(module, "get_processor"):
201
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
202
+
203
+ for sub_name, child in module.named_children():
204
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
205
+
206
+ return processors
207
+
208
+ for name, module in self.named_children():
209
+ fn_recursive_add_processors(name, module, processors)
210
+
211
+ return processors
212
+
213
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
214
+ def set_attn_processor(
215
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
216
+ ):
217
+ r"""
218
+ Sets the attention processor to use to compute attention.
219
+
220
+ Parameters:
221
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
222
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
223
+ for **all** `Attention` layers.
224
+
225
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
226
+ processor. This is strongly recommended when setting trainable attention processors.
227
+
228
+ """
229
+ count = len(self.attn_processors.keys())
230
+
231
+ if isinstance(processor, dict) and len(processor) != count:
232
+ raise ValueError(
233
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
234
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
235
+ )
236
+
237
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
238
+ if hasattr(module, "set_processor"):
239
+ if not isinstance(processor, dict):
240
+ module.set_processor(processor, _remove_lora=_remove_lora)
241
+ else:
242
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
243
+
244
+ for sub_name, child in module.named_children():
245
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
246
+
247
+ for name, module in self.named_children():
248
+ fn_recursive_attn_processor(name, module, processor)
249
+
250
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
251
+ def set_default_attn_processor(self):
252
+ """
253
+ Disables custom attention processors and sets the default attention implementation.
254
+ """
255
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
256
+ processor = AttnAddedKVProcessor()
257
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
258
+ processor = AttnProcessor()
259
+ else:
260
+ raise ValueError(
261
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
262
+ )
263
+
264
+ self.set_attn_processor(processor, _remove_lora=True)
265
+
266
+ @apply_forward_hook
267
+ def encode(
268
+ self, x: torch.FloatTensor, return_dict: bool = True
269
+ ) -> Union[ConsistencyDecoderVAEOutput, Tuple[DiagonalGaussianDistribution]]:
270
+ """
271
+ Encode a batch of images into latents.
272
+
273
+ Args:
274
+ x (`torch.FloatTensor`): Input batch of images.
275
+ return_dict (`bool`, *optional*, defaults to `True`):
276
+ Whether to return a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] instead of a plain
277
+ tuple.
278
+
279
+ Returns:
280
+ The latent representations of the encoded images. If `return_dict` is True, a
281
+ [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] is returned, otherwise a plain `tuple`
282
+ is returned.
283
+ """
284
+ if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
285
+ return self.tiled_encode(x, return_dict=return_dict)
286
+
287
+ if self.use_slicing and x.shape[0] > 1:
288
+ encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
289
+ h = torch.cat(encoded_slices)
290
+ else:
291
+ h = self.encoder(x)
292
+
293
+ moments = self.quant_conv(h)
294
+ posterior = DiagonalGaussianDistribution(moments)
295
+
296
+ if not return_dict:
297
+ return (posterior,)
298
+
299
+ return ConsistencyDecoderVAEOutput(latent_dist=posterior)
300
+
301
+ @apply_forward_hook
302
+ def decode(
303
+ self,
304
+ z: torch.FloatTensor,
305
+ generator: Optional[torch.Generator] = None,
306
+ return_dict: bool = True,
307
+ num_inference_steps=2,
308
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
309
+ z = (z * self.config.scaling_factor - self.means) / self.stds
310
+
311
+ scale_factor = 2 ** (len(self.config.block_out_channels) - 1)
312
+ z = F.interpolate(z, mode="nearest", scale_factor=scale_factor)
313
+
314
+ batch_size, _, height, width = z.shape
315
+
316
+ self.decoder_scheduler.set_timesteps(num_inference_steps, device=self.device)
317
+
318
+ x_t = self.decoder_scheduler.init_noise_sigma * randn_tensor(
319
+ (batch_size, 3, height, width), generator=generator, dtype=z.dtype, device=z.device
320
+ )
321
+
322
+ for t in self.decoder_scheduler.timesteps:
323
+ model_input = torch.concat([self.decoder_scheduler.scale_model_input(x_t, t), z], dim=1)
324
+ model_output = self.decoder_unet(model_input, t).sample[:, :3, :, :]
325
+ prev_sample = self.decoder_scheduler.step(model_output, t, x_t, generator).prev_sample
326
+ x_t = prev_sample
327
+
328
+ x_0 = x_t
329
+
330
+ if not return_dict:
331
+ return (x_0,)
332
+
333
+ return DecoderOutput(sample=x_0)
334
+
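The `decode` method above runs a short consistency-model sampling loop (two steps by default) on top of the upsampled latents. A minimal sketch of calling it directly on Stable-Diffusion-style latents, using the checkpoint name from the class docstring:

```py
import torch
from diffusers import ConsistencyDecoderVAE

vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16).to("cuda")
latents = torch.randn(1, 4, 64, 64, dtype=torch.float16, device="cuda")  # stand-in for pipeline latents

with torch.no_grad():
    image = vae.decode(latents, generator=torch.manual_seed(0), num_inference_steps=2).sample
print(image.shape)  # (1, 3, 512, 512): latents are upsampled by 2 ** (len(block_out_channels) - 1) = 8
```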
335
+ # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_v
336
+ def blend_v(self, a, b, blend_extent):
337
+ blend_extent = min(a.shape[2], b.shape[2], blend_extent)
338
+ for y in range(blend_extent):
339
+ b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
340
+ return b
341
+
342
+ # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_h
343
+ def blend_h(self, a, b, blend_extent):
344
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
345
+ for x in range(blend_extent):
346
+ b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
347
+ return b
348
+
349
+ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> ConsistencyDecoderVAEOutput:
350
+ r"""Encode a batch of images using a tiled encoder.
351
+
352
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
353
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
354
+ different from non-tiled encoding because each tile is encoded independently. To avoid tiling artifacts, the
355
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
356
+ output, but they should be much less noticeable.
357
+
358
+ Args:
359
+ x (`torch.FloatTensor`): Input batch of images.
360
+ return_dict (`bool`, *optional*, defaults to `True`):
361
+ Whether or not to return a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] instead of a
362
+ plain tuple.
363
+
364
+ Returns:
365
+ [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] or `tuple`:
366
+ If return_dict is True, a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] is returned,
367
+ otherwise a plain `tuple` is returned.
368
+ """
369
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
370
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
371
+ row_limit = self.tile_latent_min_size - blend_extent
372
+
373
+ # Split the image into 512x512 tiles and encode them separately.
374
+ rows = []
375
+ for i in range(0, x.shape[2], overlap_size):
376
+ row = []
377
+ for j in range(0, x.shape[3], overlap_size):
378
+ tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
379
+ tile = self.encoder(tile)
380
+ tile = self.quant_conv(tile)
381
+ row.append(tile)
382
+ rows.append(row)
383
+ result_rows = []
384
+ for i, row in enumerate(rows):
385
+ result_row = []
386
+ for j, tile in enumerate(row):
387
+ # blend the above tile and the left tile
388
+ # to the current tile and add the current tile to the result row
389
+ if i > 0:
390
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
391
+ if j > 0:
392
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
393
+ result_row.append(tile[:, :, :row_limit, :row_limit])
394
+ result_rows.append(torch.cat(result_row, dim=3))
395
+
396
+ moments = torch.cat(result_rows, dim=2)
397
+ posterior = DiagonalGaussianDistribution(moments)
398
+
399
+ if not return_dict:
400
+ return (posterior,)
401
+
402
+ return ConsistencyDecoderVAEOutput(latent_dist=posterior)
403
+
404
+ def forward(
405
+ self,
406
+ sample: torch.FloatTensor,
407
+ sample_posterior: bool = False,
408
+ return_dict: bool = True,
409
+ generator: Optional[torch.Generator] = None,
410
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
411
+ r"""
412
+ Args:
413
+ sample (`torch.FloatTensor`): Input sample.
414
+ sample_posterior (`bool`, *optional*, defaults to `False`):
415
+ Whether to sample from the posterior.
416
+ return_dict (`bool`, *optional*, defaults to `True`):
417
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
418
+ """
419
+ x = sample
420
+ posterior = self.encode(x).latent_dist
421
+ if sample_posterior:
422
+ z = posterior.sample(generator=generator)
423
+ else:
424
+ z = posterior.mode()
425
+ dec = self.decode(z, generator=generator).sample
426
+
427
+ if not return_dict:
428
+ return (dec,)
429
+
430
+ return DecoderOutput(sample=dec)
diffusers/models/controlnet.py ADDED
@@ -0,0 +1,844 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ from torch import nn
19
+ from torch.nn import functional as F
20
+
21
+ from ..configuration_utils import ConfigMixin, register_to_config
22
+ from ..loaders import FromOriginalControlnetMixin
23
+ from ..utils import BaseOutput, logging
24
+ from .attention_processor import (
25
+ ADDED_KV_ATTENTION_PROCESSORS,
26
+ CROSS_ATTENTION_PROCESSORS,
27
+ AttentionProcessor,
28
+ AttnAddedKVProcessor,
29
+ AttnProcessor,
30
+ )
31
+ from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
32
+ from .modeling_utils import ModelMixin
33
+ from .unet_2d_blocks import (
34
+ CrossAttnDownBlock2D,
35
+ DownBlock2D,
36
+ UNetMidBlock2DCrossAttn,
37
+ get_down_block,
38
+ )
39
+ from .unet_2d_condition import UNet2DConditionModel
40
+
41
+
42
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
43
+
44
+
45
+ @dataclass
46
+ class ControlNetOutput(BaseOutput):
47
+ """
48
+ The output of [`ControlNetModel`].
49
+
50
+ Args:
51
+ down_block_res_samples (`tuple[torch.Tensor]`):
52
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
53
+ be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
54
+ used to condition the original UNet's downsampling activations.
55
+ mid_block_res_sample (`torch.Tensor`):
56
+ The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
57
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
58
+ Output can be used to condition the original UNet's middle block activation.
59
+ """
60
+
61
+ down_block_res_samples: Tuple[torch.Tensor]
62
+ mid_block_res_sample: torch.Tensor
63
+
64
+
65
+ class ControlNetConditioningEmbedding(nn.Module):
66
+ """
67
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
68
+ [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
69
+ training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
70
+ convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
71
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
72
+ model) to encode image-space conditions ... into feature maps ..."
73
+ """
74
+
75
+ def __init__(
76
+ self,
77
+ conditioning_embedding_channels: int,
78
+ conditioning_channels: int = 3,
79
+ block_out_channels: Tuple[int] = (16, 32, 96, 256),
80
+ ):
81
+ super().__init__()
82
+
83
+ self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
84
+
85
+ self.blocks = nn.ModuleList([])
86
+
87
+ for i in range(len(block_out_channels) - 1):
88
+ channel_in = block_out_channels[i]
89
+ channel_out = block_out_channels[i + 1]
90
+ self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
91
+ self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
92
+
93
+ self.conv_out = zero_module(
94
+ nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
95
+ )
96
+
97
+ def forward(self, conditioning):
98
+ embedding = self.conv_in(conditioning)
99
+ embedding = F.silu(embedding)
100
+
101
+ for block in self.blocks:
102
+ embedding = block(embedding)
103
+ embedding = F.silu(embedding)
104
+
105
+ embedding = self.conv_out(embedding)
106
+
107
+ return embedding
108
+
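A quick shape check of the conditioning embedding defined above (illustrative; it assumes the class can be imported from `diffusers.models.controlnet`). With the default `block_out_channels=(16, 32, 96, 256)`, three stride-2 convolutions map a 512×512 condition image onto the UNet's 64×64 latent grid:

```py
import torch
from diffusers.models.controlnet import ControlNetConditioningEmbedding

embed = ControlNetConditioningEmbedding(conditioning_embedding_channels=320)
cond = torch.rand(1, 3, 512, 512)  # stand-in for a conditioning image (e.g. a Canny edge map) in [0, 1]
out = embed(cond)
print(out.shape)  # torch.Size([1, 320, 64, 64])
# note: conv_out is zero-initialized (zero_module), so a freshly constructed module outputs zeros
```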
109
+
110
+ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
111
+ """
112
+ A ControlNet model.
113
+
114
+ Args:
115
+ in_channels (`int`, defaults to 4):
116
+ The number of channels in the input sample.
117
+ flip_sin_to_cos (`bool`, defaults to `True`):
118
+ Whether to flip the sin to cos in the time embedding.
119
+ freq_shift (`int`, defaults to 0):
120
+ The frequency shift to apply to the time embedding.
121
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
122
+ The tuple of downsample blocks to use.
123
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
124
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
125
+ The tuple of output channels for each block.
126
+ layers_per_block (`int`, defaults to 2):
127
+ The number of layers per block.
128
+ downsample_padding (`int`, defaults to 1):
129
+ The padding to use for the downsampling convolution.
130
+ mid_block_scale_factor (`float`, defaults to 1):
131
+ The scale factor to use for the mid block.
132
+ act_fn (`str`, defaults to "silu"):
133
+ The activation function to use.
134
+ norm_num_groups (`int`, *optional*, defaults to 32):
135
+ The number of groups to use for the normalization. If None, normalization and activation layers is skipped
136
+ in post-processing.
137
+ norm_eps (`float`, defaults to 1e-5):
138
+ The epsilon to use for the normalization.
139
+ cross_attention_dim (`int`, defaults to 1280):
140
+ The dimension of the cross attention features.
141
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
142
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
143
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
144
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
145
+ encoder_hid_dim (`int`, *optional*, defaults to None):
146
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
147
+ dimension to `cross_attention_dim`.
148
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
149
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
150
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
151
+ attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
152
+ The dimension of the attention heads.
153
+ use_linear_projection (`bool`, defaults to `False`):
154
+ class_embed_type (`str`, *optional*, defaults to `None`):
155
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
156
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
157
+ addition_embed_type (`str`, *optional*, defaults to `None`):
158
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None`,
159
+ `"text"`, `"text_image"`, or `"text_time"`. `"text"` will use the `TextTimeEmbedding` layer.
160
+ num_class_embeds (`int`, *optional*, defaults to `None`):
161
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
162
+ class conditioning with `class_embed_type` equal to `None`.
163
+ upcast_attention (`bool`, defaults to `False`):
164
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
165
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
166
+ projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
167
+ The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
168
+ `class_embed_type="projection"`.
169
+ controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
170
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
171
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
172
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
173
+ global_pool_conditions (`bool`, defaults to `False`):
174
+ """
175
+
176
+ _supports_gradient_checkpointing = True
177
+
178
+ @register_to_config
179
+ def __init__(
180
+ self,
181
+ in_channels: int = 4,
182
+ conditioning_channels: int = 3,
183
+ flip_sin_to_cos: bool = True,
184
+ freq_shift: int = 0,
185
+ down_block_types: Tuple[str] = (
186
+ "CrossAttnDownBlock2D",
187
+ "CrossAttnDownBlock2D",
188
+ "CrossAttnDownBlock2D",
189
+ "DownBlock2D",
190
+ ),
191
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
192
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
193
+ layers_per_block: int = 2,
194
+ downsample_padding: int = 1,
195
+ mid_block_scale_factor: float = 1,
196
+ act_fn: str = "silu",
197
+ norm_num_groups: Optional[int] = 32,
198
+ norm_eps: float = 1e-5,
199
+ cross_attention_dim: int = 1280,
200
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
201
+ encoder_hid_dim: Optional[int] = None,
202
+ encoder_hid_dim_type: Optional[str] = None,
203
+ attention_head_dim: Union[int, Tuple[int]] = 8,
204
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
205
+ use_linear_projection: bool = False,
206
+ class_embed_type: Optional[str] = None,
207
+ addition_embed_type: Optional[str] = None,
208
+ addition_time_embed_dim: Optional[int] = None,
209
+ num_class_embeds: Optional[int] = None,
210
+ upcast_attention: bool = False,
211
+ resnet_time_scale_shift: str = "default",
212
+ projection_class_embeddings_input_dim: Optional[int] = None,
213
+ controlnet_conditioning_channel_order: str = "rgb",
214
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
215
+ global_pool_conditions: bool = False,
216
+ addition_embed_type_num_heads=64,
217
+ ):
218
+ super().__init__()
219
+
220
+ # If `num_attention_heads` is not defined (which is the case for most models)
221
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
222
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
223
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
224
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
225
+ # which is why we correct for the naming here.
226
+ num_attention_heads = num_attention_heads or attention_head_dim
227
+
228
+ # Check inputs
229
+ if len(block_out_channels) != len(down_block_types):
230
+ raise ValueError(
231
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
232
+ )
233
+
234
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
235
+ raise ValueError(
236
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
237
+ )
238
+
239
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
240
+ raise ValueError(
241
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
242
+ )
243
+
244
+ if isinstance(transformer_layers_per_block, int):
245
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
246
+
247
+ # input
248
+ conv_in_kernel = 3
249
+ conv_in_padding = (conv_in_kernel - 1) // 2
250
+ self.conv_in = nn.Conv2d(
251
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
252
+ )
253
+
254
+ # time
255
+ time_embed_dim = block_out_channels[0] * 4
256
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
257
+ timestep_input_dim = block_out_channels[0]
258
+ self.time_embedding = TimestepEmbedding(
259
+ timestep_input_dim,
260
+ time_embed_dim,
261
+ act_fn=act_fn,
262
+ )
263
+
264
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
265
+ encoder_hid_dim_type = "text_proj"
266
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
267
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
268
+
269
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
270
+ raise ValueError(
271
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
272
+ )
273
+
274
+ if encoder_hid_dim_type == "text_proj":
275
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
276
+ elif encoder_hid_dim_type == "text_image_proj":
277
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
278
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
279
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
280
+ self.encoder_hid_proj = TextImageProjection(
281
+ text_embed_dim=encoder_hid_dim,
282
+ image_embed_dim=cross_attention_dim,
283
+ cross_attention_dim=cross_attention_dim,
284
+ )
285
+
286
+ elif encoder_hid_dim_type is not None:
287
+ raise ValueError(
288
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
289
+ )
290
+ else:
291
+ self.encoder_hid_proj = None
292
+
293
+ # class embedding
294
+ if class_embed_type is None and num_class_embeds is not None:
295
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
296
+ elif class_embed_type == "timestep":
297
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
298
+ elif class_embed_type == "identity":
299
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
300
+ elif class_embed_type == "projection":
301
+ if projection_class_embeddings_input_dim is None:
302
+ raise ValueError(
303
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
304
+ )
305
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
306
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
307
+ # 2. it projects from an arbitrary input dimension.
308
+ #
309
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
310
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
311
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
312
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
313
+ else:
314
+ self.class_embedding = None
315
+
316
+ if addition_embed_type == "text":
317
+ if encoder_hid_dim is not None:
318
+ text_time_embedding_from_dim = encoder_hid_dim
319
+ else:
320
+ text_time_embedding_from_dim = cross_attention_dim
321
+
322
+ self.add_embedding = TextTimeEmbedding(
323
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
324
+ )
325
+ elif addition_embed_type == "text_image":
326
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
327
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
328
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
329
+ self.add_embedding = TextImageTimeEmbedding(
330
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
331
+ )
332
+ elif addition_embed_type == "text_time":
333
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
334
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
335
+
336
+ elif addition_embed_type is not None:
337
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image' or 'text_time'.")
338
+
339
+ # control net conditioning embedding
340
+ self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
341
+ conditioning_embedding_channels=block_out_channels[0],
342
+ block_out_channels=conditioning_embedding_out_channels,
343
+ conditioning_channels=conditioning_channels,
344
+ )
345
+
346
+ self.down_blocks = nn.ModuleList([])
347
+ self.controlnet_down_blocks = nn.ModuleList([])
348
+
349
+ if isinstance(only_cross_attention, bool):
350
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
351
+
352
+ if isinstance(attention_head_dim, int):
353
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
354
+
355
+ if isinstance(num_attention_heads, int):
356
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
357
+
358
+ # down
359
+ output_channel = block_out_channels[0]
360
+
361
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
362
+ controlnet_block = zero_module(controlnet_block)
363
+ self.controlnet_down_blocks.append(controlnet_block)
364
+
365
+ for i, down_block_type in enumerate(down_block_types):
366
+ input_channel = output_channel
367
+ output_channel = block_out_channels[i]
368
+ is_final_block = i == len(block_out_channels) - 1
369
+
370
+ down_block = get_down_block(
371
+ down_block_type,
372
+ num_layers=layers_per_block,
373
+ transformer_layers_per_block=transformer_layers_per_block[i],
374
+ in_channels=input_channel,
375
+ out_channels=output_channel,
376
+ temb_channels=time_embed_dim,
377
+ add_downsample=not is_final_block,
378
+ resnet_eps=norm_eps,
379
+ resnet_act_fn=act_fn,
380
+ resnet_groups=norm_num_groups,
381
+ cross_attention_dim=cross_attention_dim,
382
+ num_attention_heads=num_attention_heads[i],
383
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
384
+ downsample_padding=downsample_padding,
385
+ use_linear_projection=use_linear_projection,
386
+ only_cross_attention=only_cross_attention[i],
387
+ upcast_attention=upcast_attention,
388
+ resnet_time_scale_shift=resnet_time_scale_shift,
389
+ )
390
+ self.down_blocks.append(down_block)
391
+
392
+ for _ in range(layers_per_block):
393
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
394
+ controlnet_block = zero_module(controlnet_block)
395
+ self.controlnet_down_blocks.append(controlnet_block)
396
+
397
+ if not is_final_block:
398
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
399
+ controlnet_block = zero_module(controlnet_block)
400
+ self.controlnet_down_blocks.append(controlnet_block)
401
+
402
+ # mid
403
+ mid_block_channel = block_out_channels[-1]
404
+
405
+ controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
406
+ controlnet_block = zero_module(controlnet_block)
407
+ self.controlnet_mid_block = controlnet_block
408
+
409
+ self.mid_block = UNetMidBlock2DCrossAttn(
410
+ transformer_layers_per_block=transformer_layers_per_block[-1],
411
+ in_channels=mid_block_channel,
412
+ temb_channels=time_embed_dim,
413
+ resnet_eps=norm_eps,
414
+ resnet_act_fn=act_fn,
415
+ output_scale_factor=mid_block_scale_factor,
416
+ resnet_time_scale_shift=resnet_time_scale_shift,
417
+ cross_attention_dim=cross_attention_dim,
418
+ num_attention_heads=num_attention_heads[-1],
419
+ resnet_groups=norm_num_groups,
420
+ use_linear_projection=use_linear_projection,
421
+ upcast_attention=upcast_attention,
422
+ )
423
+
424
+ @classmethod
425
+ def from_unet(
426
+ cls,
427
+ unet: UNet2DConditionModel,
428
+ controlnet_conditioning_channel_order: str = "rgb",
429
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
430
+ load_weights_from_unet: bool = True,
431
+ ):
432
+ r"""
433
+ Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
434
+
435
+ Parameters:
436
+ unet (`UNet2DConditionModel`):
437
+ The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
438
+ where applicable.
439
+ """
440
+ transformer_layers_per_block = (
441
+ unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
442
+ )
443
+ encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
444
+ encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
445
+ addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
446
+ addition_time_embed_dim = (
447
+ unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
448
+ )
449
+
450
+ controlnet = cls(
451
+ encoder_hid_dim=encoder_hid_dim,
452
+ encoder_hid_dim_type=encoder_hid_dim_type,
453
+ addition_embed_type=addition_embed_type,
454
+ addition_time_embed_dim=addition_time_embed_dim,
455
+ transformer_layers_per_block=transformer_layers_per_block,
456
+ in_channels=unet.config.in_channels,
457
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
458
+ freq_shift=unet.config.freq_shift,
459
+ down_block_types=unet.config.down_block_types,
460
+ only_cross_attention=unet.config.only_cross_attention,
461
+ block_out_channels=unet.config.block_out_channels,
462
+ layers_per_block=unet.config.layers_per_block,
463
+ downsample_padding=unet.config.downsample_padding,
464
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
465
+ act_fn=unet.config.act_fn,
466
+ norm_num_groups=unet.config.norm_num_groups,
467
+ norm_eps=unet.config.norm_eps,
468
+ cross_attention_dim=unet.config.cross_attention_dim,
469
+ attention_head_dim=unet.config.attention_head_dim,
470
+ num_attention_heads=unet.config.num_attention_heads,
471
+ use_linear_projection=unet.config.use_linear_projection,
472
+ class_embed_type=unet.config.class_embed_type,
473
+ num_class_embeds=unet.config.num_class_embeds,
474
+ upcast_attention=unet.config.upcast_attention,
475
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
476
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
477
+ controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
478
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
479
+ )
480
+
481
+ if load_weights_from_unet:
482
+ controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
483
+ controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
484
+ controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
485
+
486
+ if controlnet.class_embedding:
487
+ controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
488
+
489
+ controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
490
+ controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
491
+
492
+ return controlnet
493
+
494
+ @property
495
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
496
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
497
+ r"""
498
+ Returns:
499
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
500
+ indexed by their weight names.
501
+ """
502
+ # set recursively
503
+ processors = {}
504
+
505
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
506
+ if hasattr(module, "get_processor"):
507
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
508
+
509
+ for sub_name, child in module.named_children():
510
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
511
+
512
+ return processors
513
+
514
+ for name, module in self.named_children():
515
+ fn_recursive_add_processors(name, module, processors)
516
+
517
+ return processors
518
+
519
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
520
+ def set_attn_processor(
521
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
522
+ ):
523
+ r"""
524
+ Sets the attention processor to use to compute attention.
525
+
526
+ Parameters:
527
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
528
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
529
+ for **all** `Attention` layers.
530
+
531
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
532
+ processor. This is strongly recommended when setting trainable attention processors.
533
+
534
+ """
535
+ count = len(self.attn_processors.keys())
536
+
537
+ if isinstance(processor, dict) and len(processor) != count:
538
+ raise ValueError(
539
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
540
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
541
+ )
542
+
543
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
544
+ if hasattr(module, "set_processor"):
545
+ if not isinstance(processor, dict):
546
+ module.set_processor(processor, _remove_lora=_remove_lora)
547
+ else:
548
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
549
+
550
+ for sub_name, child in module.named_children():
551
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
552
+
553
+ for name, module in self.named_children():
554
+ fn_recursive_attn_processor(name, module, processor)
555
+
556
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
557
+ def set_default_attn_processor(self):
558
+ """
559
+ Disables custom attention processors and sets the default attention implementation.
560
+ """
561
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
562
+ processor = AttnAddedKVProcessor()
563
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
564
+ processor = AttnProcessor()
565
+ else:
566
+ raise ValueError(
567
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
568
+ )
569
+
570
+ self.set_attn_processor(processor, _remove_lora=True)
571
+
572
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
573
+ def set_attention_slice(self, slice_size):
574
+ r"""
575
+ Enable sliced attention computation.
576
+
577
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
578
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
579
+
580
+ Args:
581
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
582
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
583
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
584
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
585
+ must be a multiple of `slice_size`.
586
+ """
587
+ sliceable_head_dims = []
588
+
589
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
590
+ if hasattr(module, "set_attention_slice"):
591
+ sliceable_head_dims.append(module.sliceable_head_dim)
592
+
593
+ for child in module.children():
594
+ fn_recursive_retrieve_sliceable_dims(child)
595
+
596
+ # retrieve number of attention layers
597
+ for module in self.children():
598
+ fn_recursive_retrieve_sliceable_dims(module)
599
+
600
+ num_sliceable_layers = len(sliceable_head_dims)
601
+
602
+ if slice_size == "auto":
603
+ # half the attention head size is usually a good trade-off between
604
+ # speed and memory
605
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
606
+ elif slice_size == "max":
607
+ # make smallest slice possible
608
+ slice_size = num_sliceable_layers * [1]
609
+
610
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
611
+
612
+ if len(slice_size) != len(sliceable_head_dims):
613
+ raise ValueError(
614
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
615
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
616
+ )
617
+
618
+ for i in range(len(slice_size)):
619
+ size = slice_size[i]
620
+ dim = sliceable_head_dims[i]
621
+ if size is not None and size > dim:
622
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
623
+
624
+ # Recursively walk through all the children.
625
+ # Any children which exposes the set_attention_slice method
626
+ # gets the message
627
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
628
+ if hasattr(module, "set_attention_slice"):
629
+ module.set_attention_slice(slice_size.pop())
630
+
631
+ for child in module.children():
632
+ fn_recursive_set_attention_slice(child, slice_size)
633
+
634
+ reversed_slice_size = list(reversed(slice_size))
635
+ for module in self.children():
636
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
637
+
638
+ def _set_gradient_checkpointing(self, module, value=False):
639
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
640
+ module.gradient_checkpointing = value
641
+
642
+ def forward(
643
+ self,
644
+ sample: torch.FloatTensor,
645
+ timestep: Union[torch.Tensor, float, int],
646
+ encoder_hidden_states: torch.Tensor,
647
+ controlnet_cond: torch.FloatTensor,
648
+ conditioning_scale: float = 1.0,
649
+ class_labels: Optional[torch.Tensor] = None,
650
+ timestep_cond: Optional[torch.Tensor] = None,
651
+ attention_mask: Optional[torch.Tensor] = None,
652
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
653
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
654
+ guess_mode: bool = False,
655
+ return_dict: bool = True,
656
+ ) -> Union[ControlNetOutput, Tuple]:
657
+ """
658
+ The [`ControlNetModel`] forward method.
659
+
660
+ Args:
661
+ sample (`torch.FloatTensor`):
662
+ The noisy input tensor.
663
+ timestep (`Union[torch.Tensor, float, int]`):
664
+ The number of timesteps to denoise an input.
665
+ encoder_hidden_states (`torch.Tensor`):
666
+ The encoder hidden states.
667
+ controlnet_cond (`torch.FloatTensor`):
668
+ The conditional input tensor of shape `(batch_size, channels, height, width)`.
669
+ conditioning_scale (`float`, defaults to `1.0`):
670
+ The scale factor for ControlNet outputs.
671
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
672
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
673
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
674
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
675
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
676
+ embeddings.
677
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
678
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
679
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
680
+ negative values to the attention scores corresponding to "discard" tokens.
681
+ added_cond_kwargs (`dict`):
682
+ Additional conditions for the Stable Diffusion XL UNet.
683
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
684
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
685
+ guess_mode (`bool`, defaults to `False`):
686
+ In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
687
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
688
+ return_dict (`bool`, defaults to `True`):
689
+ Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
690
+
691
+ Returns:
692
+ [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
693
+ If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
694
+ returned where the first element is the sample tensor.
695
+ """
696
+ # check channel order
697
+ channel_order = self.config.controlnet_conditioning_channel_order
698
+
699
+ if channel_order == "rgb":
700
+ # in rgb order by default
701
+ ...
702
+ elif channel_order == "bgr":
703
+ controlnet_cond = torch.flip(controlnet_cond, dims=[1])
704
+ else:
705
+ raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
706
+
707
+ # prepare attention_mask
708
+ if attention_mask is not None:
709
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
710
+ attention_mask = attention_mask.unsqueeze(1)
711
+
712
+ # 1. time
713
+ timesteps = timestep
714
+ if not torch.is_tensor(timesteps):
715
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
716
+ # This would be a good case for the `match` statement (Python 3.10+)
717
+ is_mps = sample.device.type == "mps"
718
+ if isinstance(timestep, float):
719
+ dtype = torch.float32 if is_mps else torch.float64
720
+ else:
721
+ dtype = torch.int32 if is_mps else torch.int64
722
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
723
+ elif len(timesteps.shape) == 0:
724
+ timesteps = timesteps[None].to(sample.device)
725
+
726
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
727
+ timesteps = timesteps.expand(sample.shape[0])
728
+
729
+ t_emb = self.time_proj(timesteps)
730
+
731
+ # timesteps does not contain any weights and will always return f32 tensors
732
+ # but time_embedding might actually be running in fp16. so we need to cast here.
733
+ # there might be better ways to encapsulate this.
734
+ t_emb = t_emb.to(dtype=sample.dtype)
735
+
736
+ emb = self.time_embedding(t_emb, timestep_cond)
737
+ aug_emb = None
738
+
739
+ if self.class_embedding is not None:
740
+ if class_labels is None:
741
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
742
+
743
+ if self.config.class_embed_type == "timestep":
744
+ class_labels = self.time_proj(class_labels)
745
+
746
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
747
+ emb = emb + class_emb
748
+
749
+ if self.config.addition_embed_type is not None:
750
+ if self.config.addition_embed_type == "text":
751
+ aug_emb = self.add_embedding(encoder_hidden_states)
752
+
753
+ elif self.config.addition_embed_type == "text_time":
754
+ if "text_embeds" not in added_cond_kwargs:
755
+ raise ValueError(
756
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
757
+ )
758
+ text_embeds = added_cond_kwargs.get("text_embeds")
759
+ if "time_ids" not in added_cond_kwargs:
760
+ raise ValueError(
761
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
762
+ )
763
+ time_ids = added_cond_kwargs.get("time_ids")
764
+ time_embeds = self.add_time_proj(time_ids.flatten())
765
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
766
+
767
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
768
+ add_embeds = add_embeds.to(emb.dtype)
769
+ aug_emb = self.add_embedding(add_embeds)
770
+
771
+ emb = emb + aug_emb if aug_emb is not None else emb
772
+
773
+ # 2. pre-process
774
+ sample = self.conv_in(sample)
775
+
776
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
777
+ sample = sample + controlnet_cond
778
+
779
+ # 3. down
780
+ down_block_res_samples = (sample,)
781
+ for downsample_block in self.down_blocks:
782
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
783
+ sample, res_samples = downsample_block(
784
+ hidden_states=sample,
785
+ temb=emb,
786
+ encoder_hidden_states=encoder_hidden_states,
787
+ attention_mask=attention_mask,
788
+ cross_attention_kwargs=cross_attention_kwargs,
789
+ )
790
+ else:
791
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
792
+
793
+ down_block_res_samples += res_samples
794
+
795
+ # 4. mid
796
+ if self.mid_block is not None:
797
+ sample = self.mid_block(
798
+ sample,
799
+ emb,
800
+ encoder_hidden_states=encoder_hidden_states,
801
+ attention_mask=attention_mask,
802
+ cross_attention_kwargs=cross_attention_kwargs,
803
+ )
804
+
805
+ # 5. Control net blocks
806
+
807
+ controlnet_down_block_res_samples = ()
808
+
809
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
810
+ down_block_res_sample = controlnet_block(down_block_res_sample)
811
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
812
+
813
+ down_block_res_samples = controlnet_down_block_res_samples
814
+
815
+ mid_block_res_sample = self.controlnet_mid_block(sample)
816
+
817
+ # 6. scaling
818
+ if guess_mode and not self.config.global_pool_conditions:
819
+ scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
820
+ scales = scales * conditioning_scale
821
+ down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
822
+ mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
823
+ else:
824
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
825
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
826
+
827
+ if self.config.global_pool_conditions:
828
+ down_block_res_samples = [
829
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
830
+ ]
831
+ mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
832
+
833
+ if not return_dict:
834
+ return (down_block_res_samples, mid_block_res_sample)
835
+
836
+ return ControlNetOutput(
837
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
838
+ )
839
+
840
+
841
+ def zero_module(module):
842
+ for p in module.parameters():
843
+ nn.init.zeros_(p)
844
+ return module
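A minimal usage sketch based on the `from_unet` classmethod and `forward` signature above; the checkpoint name, tensor shapes, and values are illustrative assumptions, not part of this diff.

import torch
from diffusers import UNet2DConditionModel, ControlNetModel

# Assumed SD 1.5 checkpoint; any UNet2DConditionModel works here.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
controlnet = ControlNetModel.from_unet(unet)  # copies config and, by default, matching weights

sample = torch.randn(1, 4, 64, 64)               # noisy latents
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 768)  # text embeddings (illustrative)
controlnet_cond = torch.randn(1, 3, 512, 512)    # conditioning image, 8x the latent resolution

out = controlnet(sample, timestep, encoder_hidden_states, controlnet_cond, conditioning_scale=1.0)
down_res, mid_res = out.down_block_res_samples, out.mid_block_res_sample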
diffusers/models/controlnet_flax.py ADDED
@@ -0,0 +1,394 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Optional, Tuple, Union
15
+
16
+ import flax
17
+ import flax.linen as nn
18
+ import jax
19
+ import jax.numpy as jnp
20
+ from flax.core.frozen_dict import FrozenDict
21
+
22
+ from ..configuration_utils import ConfigMixin, flax_register_to_config
23
+ from ..utils import BaseOutput
24
+ from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
25
+ from .modeling_flax_utils import FlaxModelMixin
26
+ from .unet_2d_blocks_flax import (
27
+ FlaxCrossAttnDownBlock2D,
28
+ FlaxDownBlock2D,
29
+ FlaxUNetMidBlock2DCrossAttn,
30
+ )
31
+
32
+
33
+ @flax.struct.dataclass
34
+ class FlaxControlNetOutput(BaseOutput):
35
+ """
36
+ The output of [`FlaxControlNetModel`].
37
+
38
+ Args:
39
+ down_block_res_samples (`jnp.ndarray`):
40
+ mid_block_res_sample (`jnp.ndarray`):
41
+ """
42
+
43
+ down_block_res_samples: jnp.ndarray
44
+ mid_block_res_sample: jnp.ndarray
45
+
46
+
47
+ class FlaxControlNetConditioningEmbedding(nn.Module):
48
+ conditioning_embedding_channels: int
49
+ block_out_channels: Tuple[int] = (16, 32, 96, 256)
50
+ dtype: jnp.dtype = jnp.float32
51
+
52
+ def setup(self):
53
+ self.conv_in = nn.Conv(
54
+ self.block_out_channels[0],
55
+ kernel_size=(3, 3),
56
+ padding=((1, 1), (1, 1)),
57
+ dtype=self.dtype,
58
+ )
59
+
60
+ blocks = []
61
+ for i in range(len(self.block_out_channels) - 1):
62
+ channel_in = self.block_out_channels[i]
63
+ channel_out = self.block_out_channels[i + 1]
64
+ conv1 = nn.Conv(
65
+ channel_in,
66
+ kernel_size=(3, 3),
67
+ padding=((1, 1), (1, 1)),
68
+ dtype=self.dtype,
69
+ )
70
+ blocks.append(conv1)
71
+ conv2 = nn.Conv(
72
+ channel_out,
73
+ kernel_size=(3, 3),
74
+ strides=(2, 2),
75
+ padding=((1, 1), (1, 1)),
76
+ dtype=self.dtype,
77
+ )
78
+ blocks.append(conv2)
79
+ self.blocks = blocks
80
+
81
+ self.conv_out = nn.Conv(
82
+ self.conditioning_embedding_channels,
83
+ kernel_size=(3, 3),
84
+ padding=((1, 1), (1, 1)),
85
+ kernel_init=nn.initializers.zeros_init(),
86
+ bias_init=nn.initializers.zeros_init(),
87
+ dtype=self.dtype,
88
+ )
89
+
90
+ def __call__(self, conditioning):
91
+ embedding = self.conv_in(conditioning)
92
+ embedding = nn.silu(embedding)
93
+
94
+ for block in self.blocks:
95
+ embedding = block(embedding)
96
+ embedding = nn.silu(embedding)
97
+
98
+ embedding = self.conv_out(embedding)
99
+
100
+ return embedding
101
+
102
+
103
+ @flax_register_to_config
104
+ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
105
+ r"""
106
+ A ControlNet model.
107
+
108
+ This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for its generic methods
109
+ implemented for all models (such as downloading or saving).
110
+
111
+ This model is also a Flax Linen [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
112
+ subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
113
+ general usage and behavior.
114
+
115
+ Inherent JAX features such as the following are supported:
116
+
117
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
118
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
119
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
120
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
121
+
122
+ Parameters:
123
+ sample_size (`int`, *optional*):
124
+ The size of the input sample.
125
+ in_channels (`int`, *optional*, defaults to 4):
126
+ The number of channels in the input sample.
127
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
128
+ The tuple of downsample blocks to use.
129
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
130
+ The tuple of output channels for each block.
131
+ layers_per_block (`int`, *optional*, defaults to 2):
132
+ The number of layers per block.
133
+ attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
134
+ The dimension of the attention heads.
135
+ num_attention_heads (`int` or `Tuple[int]`, *optional*):
136
+ The number of attention heads.
137
+ cross_attention_dim (`int`, *optional*, defaults to 768):
138
+ The dimension of the cross attention features.
139
+ dropout (`float`, *optional*, defaults to 0):
140
+ Dropout probability for down, up and bottleneck blocks.
141
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
142
+ Whether to flip the sin to cos in the time embedding.
143
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
144
+ controlnet_conditioning_channel_order (`str`, *optional*, defaults to `rgb`):
145
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
146
+ conditioning_embedding_out_channels (`tuple`, *optional*, defaults to `(16, 32, 96, 256)`):
147
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
148
+ """
149
+ sample_size: int = 32
150
+ in_channels: int = 4
151
+ down_block_types: Tuple[str] = (
152
+ "CrossAttnDownBlock2D",
153
+ "CrossAttnDownBlock2D",
154
+ "CrossAttnDownBlock2D",
155
+ "DownBlock2D",
156
+ )
157
+ only_cross_attention: Union[bool, Tuple[bool]] = False
158
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
159
+ layers_per_block: int = 2
160
+ attention_head_dim: Union[int, Tuple[int]] = 8
161
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None
162
+ cross_attention_dim: int = 1280
163
+ dropout: float = 0.0
164
+ use_linear_projection: bool = False
165
+ dtype: jnp.dtype = jnp.float32
166
+ flip_sin_to_cos: bool = True
167
+ freq_shift: int = 0
168
+ controlnet_conditioning_channel_order: str = "rgb"
169
+ conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256)
170
+
171
+ def init_weights(self, rng: jax.Array) -> FrozenDict:
172
+ # init input tensors
173
+ sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
174
+ sample = jnp.zeros(sample_shape, dtype=jnp.float32)
175
+ timesteps = jnp.ones((1,), dtype=jnp.int32)
176
+ encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
177
+ controlnet_cond_shape = (1, 3, self.sample_size * 8, self.sample_size * 8)
178
+ controlnet_cond = jnp.zeros(controlnet_cond_shape, dtype=jnp.float32)
179
+
180
+ params_rng, dropout_rng = jax.random.split(rng)
181
+ rngs = {"params": params_rng, "dropout": dropout_rng}
182
+
183
+ return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]
184
+
185
+ def setup(self):
186
+ block_out_channels = self.block_out_channels
187
+ time_embed_dim = block_out_channels[0] * 4
188
+
189
+ # If `num_attention_heads` is not defined (which is the case for most models)
190
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
191
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
192
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
193
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
194
+ # which is why we correct for the naming here.
195
+ num_attention_heads = self.num_attention_heads or self.attention_head_dim
196
+
197
+ # input
198
+ self.conv_in = nn.Conv(
199
+ block_out_channels[0],
200
+ kernel_size=(3, 3),
201
+ strides=(1, 1),
202
+ padding=((1, 1), (1, 1)),
203
+ dtype=self.dtype,
204
+ )
205
+
206
+ # time
207
+ self.time_proj = FlaxTimesteps(
208
+ block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift
209
+ )
210
+ self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
211
+
212
+ self.controlnet_cond_embedding = FlaxControlNetConditioningEmbedding(
213
+ conditioning_embedding_channels=block_out_channels[0],
214
+ block_out_channels=self.conditioning_embedding_out_channels,
215
+ )
216
+
217
+ only_cross_attention = self.only_cross_attention
218
+ if isinstance(only_cross_attention, bool):
219
+ only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
220
+
221
+ if isinstance(num_attention_heads, int):
222
+ num_attention_heads = (num_attention_heads,) * len(self.down_block_types)
223
+
224
+ # down
225
+ down_blocks = []
226
+ controlnet_down_blocks = []
227
+
228
+ output_channel = block_out_channels[0]
229
+
230
+ controlnet_block = nn.Conv(
231
+ output_channel,
232
+ kernel_size=(1, 1),
233
+ padding="VALID",
234
+ kernel_init=nn.initializers.zeros_init(),
235
+ bias_init=nn.initializers.zeros_init(),
236
+ dtype=self.dtype,
237
+ )
238
+ controlnet_down_blocks.append(controlnet_block)
239
+
240
+ for i, down_block_type in enumerate(self.down_block_types):
241
+ input_channel = output_channel
242
+ output_channel = block_out_channels[i]
243
+ is_final_block = i == len(block_out_channels) - 1
244
+
245
+ if down_block_type == "CrossAttnDownBlock2D":
246
+ down_block = FlaxCrossAttnDownBlock2D(
247
+ in_channels=input_channel,
248
+ out_channels=output_channel,
249
+ dropout=self.dropout,
250
+ num_layers=self.layers_per_block,
251
+ num_attention_heads=num_attention_heads[i],
252
+ add_downsample=not is_final_block,
253
+ use_linear_projection=self.use_linear_projection,
254
+ only_cross_attention=only_cross_attention[i],
255
+ dtype=self.dtype,
256
+ )
257
+ else:
258
+ down_block = FlaxDownBlock2D(
259
+ in_channels=input_channel,
260
+ out_channels=output_channel,
261
+ dropout=self.dropout,
262
+ num_layers=self.layers_per_block,
263
+ add_downsample=not is_final_block,
264
+ dtype=self.dtype,
265
+ )
266
+
267
+ down_blocks.append(down_block)
268
+
269
+ for _ in range(self.layers_per_block):
270
+ controlnet_block = nn.Conv(
271
+ output_channel,
272
+ kernel_size=(1, 1),
273
+ padding="VALID",
274
+ kernel_init=nn.initializers.zeros_init(),
275
+ bias_init=nn.initializers.zeros_init(),
276
+ dtype=self.dtype,
277
+ )
278
+ controlnet_down_blocks.append(controlnet_block)
279
+
280
+ if not is_final_block:
281
+ controlnet_block = nn.Conv(
282
+ output_channel,
283
+ kernel_size=(1, 1),
284
+ padding="VALID",
285
+ kernel_init=nn.initializers.zeros_init(),
286
+ bias_init=nn.initializers.zeros_init(),
287
+ dtype=self.dtype,
288
+ )
289
+ controlnet_down_blocks.append(controlnet_block)
290
+
291
+ self.down_blocks = down_blocks
292
+ self.controlnet_down_blocks = controlnet_down_blocks
293
+
294
+ # mid
295
+ mid_block_channel = block_out_channels[-1]
296
+ self.mid_block = FlaxUNetMidBlock2DCrossAttn(
297
+ in_channels=mid_block_channel,
298
+ dropout=self.dropout,
299
+ num_attention_heads=num_attention_heads[-1],
300
+ use_linear_projection=self.use_linear_projection,
301
+ dtype=self.dtype,
302
+ )
303
+
304
+ self.controlnet_mid_block = nn.Conv(
305
+ mid_block_channel,
306
+ kernel_size=(1, 1),
307
+ padding="VALID",
308
+ kernel_init=nn.initializers.zeros_init(),
309
+ bias_init=nn.initializers.zeros_init(),
310
+ dtype=self.dtype,
311
+ )
312
+
313
+ def __call__(
314
+ self,
315
+ sample,
316
+ timesteps,
317
+ encoder_hidden_states,
318
+ controlnet_cond,
319
+ conditioning_scale: float = 1.0,
320
+ return_dict: bool = True,
321
+ train: bool = False,
322
+ ) -> Union[FlaxControlNetOutput, Tuple]:
323
+ r"""
324
+ Args:
325
+ sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
326
+ timestep (`jnp.ndarray` or `float` or `int`): timesteps
327
+ encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
328
+ controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor
329
+ conditioning_scale: (`float`) the scale factor for controlnet outputs
330
+ return_dict (`bool`, *optional*, defaults to `True`):
331
+ Whether or not to return a [`models.controlnet_flax.FlaxControlNetOutput`] instead of a
332
+ plain tuple.
333
+ train (`bool`, *optional*, defaults to `False`):
334
+ Use deterministic functions and disable dropout when not training.
335
+
336
+ Returns:
337
+ [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
338
+ [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
339
+ When returning a tuple, the first element is the sample tensor.
340
+ """
341
+ channel_order = self.controlnet_conditioning_channel_order
342
+ if channel_order == "bgr":
343
+ controlnet_cond = jnp.flip(controlnet_cond, axis=1)
344
+
345
+ # 1. time
346
+ if not isinstance(timesteps, jnp.ndarray):
347
+ timesteps = jnp.array([timesteps], dtype=jnp.int32)
348
+ elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
349
+ timesteps = timesteps.astype(dtype=jnp.float32)
350
+ timesteps = jnp.expand_dims(timesteps, 0)
351
+
352
+ t_emb = self.time_proj(timesteps)
353
+ t_emb = self.time_embedding(t_emb)
354
+
355
+ # 2. pre-process
356
+ sample = jnp.transpose(sample, (0, 2, 3, 1))
357
+ sample = self.conv_in(sample)
358
+
359
+ controlnet_cond = jnp.transpose(controlnet_cond, (0, 2, 3, 1))
360
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
361
+ sample += controlnet_cond
362
+
363
+ # 3. down
364
+ down_block_res_samples = (sample,)
365
+ for down_block in self.down_blocks:
366
+ if isinstance(down_block, FlaxCrossAttnDownBlock2D):
367
+ sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
368
+ else:
369
+ sample, res_samples = down_block(sample, t_emb, deterministic=not train)
370
+ down_block_res_samples += res_samples
371
+
372
+ # 4. mid
373
+ sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
374
+
375
+ # 5. contronet blocks
376
+ controlnet_down_block_res_samples = ()
377
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
378
+ down_block_res_sample = controlnet_block(down_block_res_sample)
379
+ controlnet_down_block_res_samples += (down_block_res_sample,)
380
+
381
+ down_block_res_samples = controlnet_down_block_res_samples
382
+
383
+ mid_block_res_sample = self.controlnet_mid_block(sample)
384
+
385
+ # 6. scaling
386
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
387
+ mid_block_res_sample *= conditioning_scale
388
+
389
+ if not return_dict:
390
+ return (down_block_res_samples, mid_block_res_sample)
391
+
392
+ return FlaxControlNetOutput(
393
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
394
+ )
diffusers/models/dual_transformer_2d.py ADDED
@@ -0,0 +1,155 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Optional
15
+
16
+ from torch import nn
17
+
18
+ from .transformer_2d import Transformer2DModel, Transformer2DModelOutput
19
+
20
+
21
+ class DualTransformer2DModel(nn.Module):
22
+ """
23
+ Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
24
+
25
+ Parameters:
26
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
27
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
28
+ in_channels (`int`, *optional*):
29
+ Pass if the input is continuous. The number of channels in the input and output.
30
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
31
+ dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
32
+ cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
33
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
34
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
35
+ `ImagePositionalEmbeddings`.
36
+ num_vector_embeds (`int`, *optional*):
37
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
38
+ Includes the class for the masked latent pixel.
39
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
40
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
41
+ The number of diffusion steps used during training. Note that this is fixed at training time as it is used
42
+ to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
43
+ up to but not more than `num_embeds_ada_norm` steps.
44
+ attention_bias (`bool`, *optional*):
45
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
46
+ """
47
+
48
+ def __init__(
49
+ self,
50
+ num_attention_heads: int = 16,
51
+ attention_head_dim: int = 88,
52
+ in_channels: Optional[int] = None,
53
+ num_layers: int = 1,
54
+ dropout: float = 0.0,
55
+ norm_num_groups: int = 32,
56
+ cross_attention_dim: Optional[int] = None,
57
+ attention_bias: bool = False,
58
+ sample_size: Optional[int] = None,
59
+ num_vector_embeds: Optional[int] = None,
60
+ activation_fn: str = "geglu",
61
+ num_embeds_ada_norm: Optional[int] = None,
62
+ ):
63
+ super().__init__()
64
+ self.transformers = nn.ModuleList(
65
+ [
66
+ Transformer2DModel(
67
+ num_attention_heads=num_attention_heads,
68
+ attention_head_dim=attention_head_dim,
69
+ in_channels=in_channels,
70
+ num_layers=num_layers,
71
+ dropout=dropout,
72
+ norm_num_groups=norm_num_groups,
73
+ cross_attention_dim=cross_attention_dim,
74
+ attention_bias=attention_bias,
75
+ sample_size=sample_size,
76
+ num_vector_embeds=num_vector_embeds,
77
+ activation_fn=activation_fn,
78
+ num_embeds_ada_norm=num_embeds_ada_norm,
79
+ )
80
+ for _ in range(2)
81
+ ]
82
+ )
83
+
84
+ # Variables that can be set by a pipeline:
85
+
86
+ # The ratio of transformer1 to transformer2's output states to be combined during inference
87
+ self.mix_ratio = 0.5
88
+
89
+ # The shape of `encoder_hidden_states` is expected to be
90
+ # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
91
+ self.condition_lengths = [77, 257]
92
+
93
+ # Which transformer to use to encode which condition.
94
+ # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
95
+ self.transformer_index_for_condition = [1, 0]
96
+
97
+ def forward(
98
+ self,
99
+ hidden_states,
100
+ encoder_hidden_states,
101
+ timestep=None,
102
+ attention_mask=None,
103
+ cross_attention_kwargs=None,
104
+ return_dict: bool = True,
105
+ ):
106
+ """
107
+ Args:
108
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
109
+ When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
110
+ hidden_states.
111
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
112
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
113
+ self-attention.
114
+ timestep ( `torch.long`, *optional*):
115
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
116
+ attention_mask (`torch.FloatTensor`, *optional*):
117
+ Optional attention mask to be applied in Attention.
118
+ cross_attention_kwargs (`dict`, *optional*):
119
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
120
+ `self.processor` in
121
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
122
+ return_dict (`bool`, *optional*, defaults to `True`):
123
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
124
+
125
+ Returns:
126
+ [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
127
+ [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
128
+ returning a tuple, the first element is the sample tensor.
129
+ """
130
+ input_states = hidden_states
131
+
132
+ encoded_states = []
133
+ tokens_start = 0
134
+ # attention_mask is not used yet
135
+ for i in range(2):
136
+ # for each of the two transformers, pass the corresponding condition tokens
137
+ condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
138
+ transformer_index = self.transformer_index_for_condition[i]
139
+ encoded_state = self.transformers[transformer_index](
140
+ input_states,
141
+ encoder_hidden_states=condition_state,
142
+ timestep=timestep,
143
+ cross_attention_kwargs=cross_attention_kwargs,
144
+ return_dict=False,
145
+ )[0]
146
+ encoded_states.append(encoded_state - input_states)
147
+ tokens_start += self.condition_lengths[i]
148
+
149
+ output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
150
+ output_states = output_states + input_states
151
+
152
+ if not return_dict:
153
+ return (output_states,)
154
+
155
+ return Transformer2DModelOutput(sample=output_states)
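An illustrative sketch of how the pipeline-settable attributes above (`mix_ratio`, `condition_lengths`, `transformer_index_for_condition`) are meant to be used; the module path matches this diff, but the dimensions are assumptions.

import torch
from diffusers.models.dual_transformer_2d import DualTransformer2DModel

dual = DualTransformer2DModel(num_attention_heads=8, attention_head_dim=64, in_channels=512, cross_attention_dim=768)
dual.mix_ratio = 0.5                         # blend between the two transformers' residuals
dual.condition_lengths = [77, 257]           # e.g. text tokens followed by image tokens
dual.transformer_index_for_condition = [1, 0]

hidden_states = torch.randn(1, 512, 16, 16)
text_cond = torch.randn(1, 77, 768)
image_cond = torch.randn(1, 257, 768)
encoder_hidden_states = torch.cat([text_cond, image_cond], dim=1)  # (1, 77 + 257, 768)

out = dual(hidden_states, encoder_hidden_states, return_dict=False)[0]  # same shape as hidden_states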
diffusers/models/embeddings.py ADDED
@@ -0,0 +1,792 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ from typing import Optional
16
+
17
+ import numpy as np
18
+ import torch
19
+ from torch import nn
20
+
21
+ from ..utils import USE_PEFT_BACKEND
22
+ from .activations import get_activation
23
+ from .lora import LoRACompatibleLinear
24
+
25
+
26
+ def get_timestep_embedding(
27
+ timesteps: torch.Tensor,
28
+ embedding_dim: int,
29
+ flip_sin_to_cos: bool = False,
30
+ downscale_freq_shift: float = 1,
31
+ scale: float = 1,
32
+ max_period: int = 10000,
33
+ ):
34
+ """
35
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
36
+
37
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
38
+ These may be fractional.
39
+ :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
40
+ embeddings. :return: an [N x dim] Tensor of positional embeddings.
41
+ """
42
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
43
+
44
+ half_dim = embedding_dim // 2
45
+ exponent = -math.log(max_period) * torch.arange(
46
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
47
+ )
48
+ exponent = exponent / (half_dim - downscale_freq_shift)
49
+
50
+ emb = torch.exp(exponent)
51
+ emb = timesteps[:, None].float() * emb[None, :]
52
+
53
+ # scale embeddings
54
+ emb = scale * emb
55
+
56
+ # concat sine and cosine embeddings
57
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
58
+
59
+ # flip sine and cosine embeddings
60
+ if flip_sin_to_cos:
61
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
62
+
63
+ # zero pad
64
+ if embedding_dim % 2 == 1:
65
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
66
+ return emb
67
+
68
+
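A tiny worked example for `get_timestep_embedding` above; the argument values are illustrative.

import torch
from diffusers.models.embeddings import get_timestep_embedding

timesteps = torch.tensor([0, 10, 999])
emb = get_timestep_embedding(timesteps, embedding_dim=320, flip_sin_to_cos=True, downscale_freq_shift=0)
print(emb.shape)  # torch.Size([3, 320]); with flip_sin_to_cos=True the first half is cos, the second half sin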
69
+ def get_2d_sincos_pos_embed(
70
+ embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
71
+ ):
72
+ """
73
+ grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
74
+ [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
75
+ """
76
+ if isinstance(grid_size, int):
77
+ grid_size = (grid_size, grid_size)
78
+
79
+ grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
80
+ grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
81
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
82
+ grid = np.stack(grid, axis=0)
83
+
84
+ grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
85
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
86
+ if cls_token and extra_tokens > 0:
87
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
88
+ return pos_embed
89
+
90
+
91
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
92
+ if embed_dim % 2 != 0:
93
+ raise ValueError("embed_dim must be divisible by 2")
94
+
95
+ # use half of dimensions to encode grid_h
96
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
97
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
98
+
99
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
100
+ return emb
101
+
102
+
103
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
104
+ """
105
+ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
106
+ """
107
+ if embed_dim % 2 != 0:
108
+ raise ValueError("embed_dim must be divisible by 2")
109
+
110
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
111
+ omega /= embed_dim / 2.0
112
+ omega = 1.0 / 10000**omega # (D/2,)
113
+
114
+ pos = pos.reshape(-1) # (M,)
115
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
116
+
117
+ emb_sin = np.sin(out) # (M, D/2)
118
+ emb_cos = np.cos(out) # (M, D/2)
119
+
120
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
121
+ return emb
122
+
123
+
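A short shape sketch for the 2D sin-cos helpers above; the numbers are illustrative.

from diffusers.models.embeddings import get_2d_sincos_pos_embed

pos_embed = get_2d_sincos_pos_embed(embed_dim=64, grid_size=8)
print(pos_embed.shape)  # (64, 64): one 64-dim embedding per position of the 8x8 grid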
124
+ class PatchEmbed(nn.Module):
125
+ """2D Image to Patch Embedding"""
126
+
127
+ def __init__(
128
+ self,
129
+ height=224,
130
+ width=224,
131
+ patch_size=16,
132
+ in_channels=3,
133
+ embed_dim=768,
134
+ layer_norm=False,
135
+ flatten=True,
136
+ bias=True,
137
+ interpolation_scale=1,
138
+ ):
139
+ super().__init__()
140
+
141
+ num_patches = (height // patch_size) * (width // patch_size)
142
+ self.flatten = flatten
143
+ self.layer_norm = layer_norm
144
+
145
+ self.proj = nn.Conv2d(
146
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
147
+ )
148
+ if layer_norm:
149
+ self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
150
+ else:
151
+ self.norm = None
152
+
153
+ self.patch_size = patch_size
154
+ # See:
155
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
156
+ self.height, self.width = height // patch_size, width // patch_size
157
+ self.base_size = height // patch_size
158
+ self.interpolation_scale = interpolation_scale
159
+ pos_embed = get_2d_sincos_pos_embed(
160
+ embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
161
+ )
162
+ self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
163
+
164
+ def forward(self, latent):
165
+ height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
166
+
167
+ latent = self.proj(latent)
168
+ if self.flatten:
169
+ latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
170
+ if self.layer_norm:
171
+ latent = self.norm(latent)
172
+
173
+ # Interpolate positional embeddings if needed.
174
+ # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
175
+ if self.height != height or self.width != width:
176
+ pos_embed = get_2d_sincos_pos_embed(
177
+ embed_dim=self.pos_embed.shape[-1],
178
+ grid_size=(height, width),
179
+ base_size=self.base_size,
180
+ interpolation_scale=self.interpolation_scale,
181
+ )
182
+ pos_embed = torch.from_numpy(pos_embed)
183
+ pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
184
+ else:
185
+ pos_embed = self.pos_embed
186
+
187
+ return (latent + pos_embed).to(latent.dtype)
188
+
189
+
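A brief shape sketch for `PatchEmbed` above; the sizes are illustrative assumptions.

import torch
from diffusers.models.embeddings import PatchEmbed

patchify = PatchEmbed(height=64, width=64, patch_size=8, in_channels=4, embed_dim=768)
latent = torch.randn(2, 4, 64, 64)
tokens = patchify(latent)
print(tokens.shape)  # torch.Size([2, 64, 768]): (64 // 8) * (64 // 8) = 64 patch tokens per sample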
190
+ class TimestepEmbedding(nn.Module):
191
+ def __init__(
192
+ self,
193
+ in_channels: int,
194
+ time_embed_dim: int,
195
+ act_fn: str = "silu",
196
+ out_dim: int = None,
197
+ post_act_fn: Optional[str] = None,
198
+ cond_proj_dim=None,
199
+ ):
200
+ super().__init__()
201
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
202
+
203
+ self.linear_1 = linear_cls(in_channels, time_embed_dim)
204
+
205
+ if cond_proj_dim is not None:
206
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
207
+ else:
208
+ self.cond_proj = None
209
+
210
+ self.act = get_activation(act_fn)
211
+
212
+ if out_dim is not None:
213
+ time_embed_dim_out = out_dim
214
+ else:
215
+ time_embed_dim_out = time_embed_dim
216
+ self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out)
217
+
218
+ if post_act_fn is None:
219
+ self.post_act = None
220
+ else:
221
+ self.post_act = get_activation(post_act_fn)
222
+
223
+ def forward(self, sample, condition=None):
224
+ if condition is not None:
225
+ sample = sample + self.cond_proj(condition)
226
+ sample = self.linear_1(sample)
227
+
228
+ if self.act is not None:
229
+ sample = self.act(sample)
230
+
231
+ sample = self.linear_2(sample)
232
+
233
+ if self.post_act is not None:
234
+ sample = self.post_act(sample)
235
+ return sample
236
+
237
+
238
+ class Timesteps(nn.Module):
239
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
240
+ super().__init__()
241
+ self.num_channels = num_channels
242
+ self.flip_sin_to_cos = flip_sin_to_cos
243
+ self.downscale_freq_shift = downscale_freq_shift
244
+
245
+ def forward(self, timesteps):
246
+ t_emb = get_timestep_embedding(
247
+ timesteps,
248
+ self.num_channels,
249
+ flip_sin_to_cos=self.flip_sin_to_cos,
250
+ downscale_freq_shift=self.downscale_freq_shift,
251
+ )
252
+ return t_emb
253
+
254
+
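Taken together, `Timesteps` and `TimestepEmbedding` form the usual two-stage time conditioning: a fixed sinusoidal projection followed by a small learned MLP. A minimal usage sketch (the import path assumes this vendored `diffusers` package is on the Python path):

```python
import torch
from diffusers.models.embeddings import TimestepEmbedding, Timesteps

time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
time_embedding = TimestepEmbedding(in_channels=256, time_embed_dim=1280)

timesteps = torch.tensor([1, 500, 999])
temb = time_embedding(time_proj(timesteps))
print(temb.shape)  # torch.Size([3, 1280])
```

This mirrors how `CombinedTimestepLabelEmbeddings` further below wires the two modules together.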
255
+ class GaussianFourierProjection(nn.Module):
256
+ """Gaussian Fourier embeddings for noise levels."""
257
+
258
+ def __init__(
259
+ self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
260
+ ):
261
+ super().__init__()
262
+ self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
263
+ self.log = log
264
+ self.flip_sin_to_cos = flip_sin_to_cos
265
+
266
+ if set_W_to_weight:
267
+ # to delete later
268
+ self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
269
+
270
+ self.weight = self.W
271
+
272
+ def forward(self, x):
273
+ if self.log:
274
+ x = torch.log(x)
275
+
276
+ x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
277
+
278
+ if self.flip_sin_to_cos:
279
+ out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
280
+ else:
281
+ out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
282
+ return out
283
+
284
+
285
+ class SinusoidalPositionalEmbedding(nn.Module):
286
+ """Apply positional information to a sequence of embeddings.
287
+
288
+ Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to
289
+ them
290
+
291
+ Args:
292
+ embed_dim: (int): Dimension of the positional embedding.
293
+ max_seq_length: Maximum sequence length to apply positional embeddings
294
+
295
+ """
296
+
297
+ def __init__(self, embed_dim: int, max_seq_length: int = 32):
298
+ super().__init__()
299
+ position = torch.arange(max_seq_length).unsqueeze(1)
300
+ div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
301
+ pe = torch.zeros(1, max_seq_length, embed_dim)
302
+ pe[0, :, 0::2] = torch.sin(position * div_term)
303
+ pe[0, :, 1::2] = torch.cos(position * div_term)
304
+ self.register_buffer("pe", pe)
305
+
306
+ def forward(self, x):
307
+ _, seq_length, _ = x.shape
308
+ x = x + self.pe[:, :seq_length]
309
+ return x
310
+
311
+
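A quick sketch of `SinusoidalPositionalEmbedding`: the precomputed `pe` buffer is sliced to the incoming sequence length and added, so the output keeps the input shape (the sequence length must not exceed `max_seq_length`):

```python
import torch
from diffusers.models.embeddings import SinusoidalPositionalEmbedding

pos_embed = SinusoidalPositionalEmbedding(embed_dim=64, max_seq_length=32)
tokens = torch.randn(2, 16, 64)  # (batch, seq_len, embed_dim)
out = pos_embed(tokens)
print(out.shape)  # torch.Size([2, 16, 64])
```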
312
+ class ImagePositionalEmbeddings(nn.Module):
313
+ """
314
+ Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
315
+ height and width of the latent space.
316
+
317
+ For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
318
+
319
+ For VQ-diffusion:
320
+
321
+ Output vector embeddings are used as input for the transformer.
322
+
323
+ Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
324
+
325
+ Args:
326
+ num_embed (`int`):
327
+ Number of embeddings for the latent pixels embeddings.
328
+ height (`int`):
329
+ Height of the latent image i.e. the number of height embeddings.
330
+ width (`int`):
331
+ Width of the latent image i.e. the number of width embeddings.
332
+ embed_dim (`int`):
333
+ Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
334
+ """
335
+
336
+ def __init__(
337
+ self,
338
+ num_embed: int,
339
+ height: int,
340
+ width: int,
341
+ embed_dim: int,
342
+ ):
343
+ super().__init__()
344
+
345
+ self.height = height
346
+ self.width = width
347
+ self.num_embed = num_embed
348
+ self.embed_dim = embed_dim
349
+
350
+ self.emb = nn.Embedding(self.num_embed, embed_dim)
351
+ self.height_emb = nn.Embedding(self.height, embed_dim)
352
+ self.width_emb = nn.Embedding(self.width, embed_dim)
353
+
354
+ def forward(self, index):
355
+ emb = self.emb(index)
356
+
357
+ height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))
358
+
359
+ # 1 x H x D -> 1 x H x 1 x D
360
+ height_emb = height_emb.unsqueeze(2)
361
+
362
+ width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))
363
+
364
+ # 1 x W x D -> 1 x 1 x W x D
365
+ width_emb = width_emb.unsqueeze(1)
366
+
367
+ pos_emb = height_emb + width_emb
368
+
369
+ # 1 x H x W x D -> 1 x L x D
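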
370
+ pos_emb = pos_emb.view(1, self.height * self.width, -1)
371
+
372
+ emb = emb + pos_emb[:, : emb.shape[1], :]
373
+
374
+ return emb
375
+
376
+
377
+ class LabelEmbedding(nn.Module):
378
+ """
379
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
380
+
381
+ Args:
382
+ num_classes (`int`): The number of classes.
383
+ hidden_size (`int`): The size of the vector embeddings.
384
+ dropout_prob (`float`): The probability of dropping a label.
385
+ """
386
+
387
+ def __init__(self, num_classes, hidden_size, dropout_prob):
388
+ super().__init__()
389
+ use_cfg_embedding = dropout_prob > 0
390
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
391
+ self.num_classes = num_classes
392
+ self.dropout_prob = dropout_prob
393
+
394
+ def token_drop(self, labels, force_drop_ids=None):
395
+ """
396
+ Drops labels to enable classifier-free guidance.
397
+ """
398
+ if force_drop_ids is None:
399
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
400
+ else:
401
+ drop_ids = torch.tensor(force_drop_ids == 1)
402
+ labels = torch.where(drop_ids, self.num_classes, labels)
403
+ return labels
404
+
405
+ def forward(self, labels: torch.LongTensor, force_drop_ids=None):
406
+ use_dropout = self.dropout_prob > 0
407
+ if (self.training and use_dropout) or (force_drop_ids is not None):
408
+ labels = self.token_drop(labels, force_drop_ids)
409
+ embeddings = self.embedding_table(labels)
410
+ return embeddings
411
+
412
+
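Because `use_cfg_embedding` is derived from `dropout_prob > 0`, the embedding table gains one extra row that acts as the learned null class for classifier-free guidance, and dropped labels are remapped to index `num_classes`. A small sketch (the tensor values are arbitrary):

```python
import torch
from diffusers.models.embeddings import LabelEmbedding

label_emb = LabelEmbedding(num_classes=1000, hidden_size=1152, dropout_prob=0.1)
print(label_emb.embedding_table.num_embeddings)  # 1001 -> index 1000 is the null class

labels = torch.tensor([3, 7])
# force_drop_ids == 1 swaps that entry for the null class even outside training.
embeddings = label_emb(labels, force_drop_ids=torch.tensor([1, 0]))
print(embeddings.shape)  # torch.Size([2, 1152])
```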
413
+ class TextImageProjection(nn.Module):
414
+ def __init__(
415
+ self,
416
+ text_embed_dim: int = 1024,
417
+ image_embed_dim: int = 768,
418
+ cross_attention_dim: int = 768,
419
+ num_image_text_embeds: int = 10,
420
+ ):
421
+ super().__init__()
422
+
423
+ self.num_image_text_embeds = num_image_text_embeds
424
+ self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
425
+ self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)
426
+
427
+ def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
428
+ batch_size = text_embeds.shape[0]
429
+
430
+ # image
431
+ image_text_embeds = self.image_embeds(image_embeds)
432
+ image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
433
+
434
+ # text
435
+ text_embeds = self.text_proj(text_embeds)
436
+
437
+ return torch.cat([image_text_embeds, text_embeds], dim=1)
438
+
439
+
440
+ class ImageProjection(nn.Module):
441
+ def __init__(
442
+ self,
443
+ image_embed_dim: int = 768,
444
+ cross_attention_dim: int = 768,
445
+ num_image_text_embeds: int = 32,
446
+ ):
447
+ super().__init__()
448
+
449
+ self.num_image_text_embeds = num_image_text_embeds
450
+ self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
451
+ self.norm = nn.LayerNorm(cross_attention_dim)
452
+
453
+ def forward(self, image_embeds: torch.FloatTensor):
454
+ batch_size = image_embeds.shape[0]
455
+
456
+ # image
457
+ image_embeds = self.image_embeds(image_embeds)
458
+ image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
459
+ image_embeds = self.norm(image_embeds)
460
+ return image_embeds
461
+
462
+
463
+ class CombinedTimestepLabelEmbeddings(nn.Module):
464
+ def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
465
+ super().__init__()
466
+
467
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
468
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
469
+ self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)
470
+
471
+ def forward(self, timestep, class_labels, hidden_dtype=None):
472
+ timesteps_proj = self.time_proj(timestep)
473
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
474
+
475
+ class_labels = self.class_embedder(class_labels) # (N, D)
476
+
477
+ conditioning = timesteps_emb + class_labels # (N, D)
478
+
479
+ return conditioning
480
+
481
+
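`CombinedTimestepLabelEmbeddings` produces the DiT-style conditioning vector: the sinusoidal timestep embedding and the class embedding are computed at the same width and summed. A minimal sketch:

```python
import torch
from diffusers.models.embeddings import CombinedTimestepLabelEmbeddings

cond = CombinedTimestepLabelEmbeddings(num_classes=1000, embedding_dim=1152)
cond.eval()  # avoid the random label dropout used for classifier-free guidance training

conditioning = cond(
    timestep=torch.tensor([999, 250]),
    class_labels=torch.tensor([3, 7]),
    hidden_dtype=torch.float32,
)
print(conditioning.shape)  # torch.Size([2, 1152])
```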
482
+ class TextTimeEmbedding(nn.Module):
483
+ def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
484
+ super().__init__()
485
+ self.norm1 = nn.LayerNorm(encoder_dim)
486
+ self.pool = AttentionPooling(num_heads, encoder_dim)
487
+ self.proj = nn.Linear(encoder_dim, time_embed_dim)
488
+ self.norm2 = nn.LayerNorm(time_embed_dim)
489
+
490
+ def forward(self, hidden_states):
491
+ hidden_states = self.norm1(hidden_states)
492
+ hidden_states = self.pool(hidden_states)
493
+ hidden_states = self.proj(hidden_states)
494
+ hidden_states = self.norm2(hidden_states)
495
+ return hidden_states
496
+
497
+
498
+ class TextImageTimeEmbedding(nn.Module):
499
+ def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536):
500
+ super().__init__()
501
+ self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
502
+ self.text_norm = nn.LayerNorm(time_embed_dim)
503
+ self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
504
+
505
+ def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
506
+ # text
507
+ time_text_embeds = self.text_proj(text_embeds)
508
+ time_text_embeds = self.text_norm(time_text_embeds)
509
+
510
+ # image
511
+ time_image_embeds = self.image_proj(image_embeds)
512
+
513
+ return time_image_embeds + time_text_embeds
514
+
515
+
516
+ class ImageTimeEmbedding(nn.Module):
517
+ def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
518
+ super().__init__()
519
+ self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
520
+ self.image_norm = nn.LayerNorm(time_embed_dim)
521
+
522
+ def forward(self, image_embeds: torch.FloatTensor):
523
+ # image
524
+ time_image_embeds = self.image_proj(image_embeds)
525
+ time_image_embeds = self.image_norm(time_image_embeds)
526
+ return time_image_embeds
527
+
528
+
529
+ class ImageHintTimeEmbedding(nn.Module):
530
+ def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
531
+ super().__init__()
532
+ self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
533
+ self.image_norm = nn.LayerNorm(time_embed_dim)
534
+ self.input_hint_block = nn.Sequential(
535
+ nn.Conv2d(3, 16, 3, padding=1),
536
+ nn.SiLU(),
537
+ nn.Conv2d(16, 16, 3, padding=1),
538
+ nn.SiLU(),
539
+ nn.Conv2d(16, 32, 3, padding=1, stride=2),
540
+ nn.SiLU(),
541
+ nn.Conv2d(32, 32, 3, padding=1),
542
+ nn.SiLU(),
543
+ nn.Conv2d(32, 96, 3, padding=1, stride=2),
544
+ nn.SiLU(),
545
+ nn.Conv2d(96, 96, 3, padding=1),
546
+ nn.SiLU(),
547
+ nn.Conv2d(96, 256, 3, padding=1, stride=2),
548
+ nn.SiLU(),
549
+ nn.Conv2d(256, 4, 3, padding=1),
550
+ )
551
+
552
+ def forward(self, image_embeds: torch.FloatTensor, hint: torch.FloatTensor):
553
+ # image
554
+ time_image_embeds = self.image_proj(image_embeds)
555
+ time_image_embeds = self.image_norm(time_image_embeds)
556
+ hint = self.input_hint_block(hint)
557
+ return time_image_embeds, hint
558
+
559
+
560
+ class AttentionPooling(nn.Module):
561
+ # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54
562
+
563
+ def __init__(self, num_heads, embed_dim, dtype=None):
564
+ super().__init__()
565
+ self.dtype = dtype
566
+ self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
567
+ self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
568
+ self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
569
+ self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
570
+ self.num_heads = num_heads
571
+ self.dim_per_head = embed_dim // self.num_heads
572
+
573
+ def forward(self, x):
574
+ bs, length, width = x.size()
575
+
576
+ def shape(x):
577
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
578
+ x = x.view(bs, -1, self.num_heads, self.dim_per_head)
579
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
580
+ x = x.transpose(1, 2)
581
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
582
+ x = x.reshape(bs * self.num_heads, -1, self.dim_per_head)
583
+ # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
584
+ x = x.transpose(1, 2)
585
+ return x
586
+
587
+ class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
588
+ x = torch.cat([class_token, x], dim=1) # (bs, length+1, width)
589
+
590
+ # (bs*n_heads, dim_per_head, class_token_length)
591
+ q = shape(self.q_proj(class_token))
592
+ # (bs*n_heads, dim_per_head, length+class_token_length)
593
+ k = shape(self.k_proj(x))
594
+ v = shape(self.v_proj(x))
595
+
596
+ # (bs*n_heads, class_token_length, length+class_token_length):
597
+ scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
598
+ weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
599
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
600
+
601
+ # (bs*n_heads, dim_per_head, class_token_length)
602
+ a = torch.einsum("bts,bcs->bct", weight, v)
603
+
604
+ # (bs, 1, width)
605
+ a = a.reshape(bs, -1, 1).transpose(1, 2)
606
+
607
+ return a[:, 0, :] # cls_token
608
+
609
+
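`AttentionPooling` collapses a variable-length sequence into a single vector by attending from a mean-pooled class token, which is how `TextTimeEmbedding` above turns text-encoder hidden states into a time-embedding-sized summary. Sketch:

```python
import torch
from diffusers.models.embeddings import AttentionPooling

pool = AttentionPooling(num_heads=8, embed_dim=64)
hidden_states = torch.randn(2, 77, 64)  # (batch, sequence, embed_dim)
pooled = pool(hidden_states)
print(pooled.shape)  # torch.Size([2, 64])
```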
610
+ class FourierEmbedder(nn.Module):
611
+ def __init__(self, num_freqs=64, temperature=100):
612
+ super().__init__()
613
+
614
+ self.num_freqs = num_freqs
615
+ self.temperature = temperature
616
+
617
+ freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
618
+ freq_bands = freq_bands[None, None, None]
619
+ self.register_buffer("freq_bands", freq_bands, persistent=False)
620
+
621
+ def __call__(self, x):
622
+ x = self.freq_bands * x.unsqueeze(-1)
623
+ return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)
624
+
625
+
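`FourierEmbedder` maps each box coordinate to `num_freqs` sin/cos pairs, which is where `PositionNet`'s `position_dim = fourier_freqs * 2 * 4` comes from. Sketch:

```python
import torch
from diffusers.models.embeddings import FourierEmbedder

fourier = FourierEmbedder(num_freqs=8)
boxes = torch.rand(2, 5, 4)  # (batch, num_boxes, xyxy) in [0, 1]
emb = fourier(boxes)
print(emb.shape)  # torch.Size([2, 5, 64]) == num_freqs * 2 (sin/cos) * 4 (coords)
```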
626
+ class PositionNet(nn.Module):
627
+ def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8):
628
+ super().__init__()
629
+ self.positive_len = positive_len
630
+ self.out_dim = out_dim
631
+
632
+ self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
633
+ self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy
634
+
635
+ if isinstance(out_dim, tuple):
636
+ out_dim = out_dim[0]
637
+
638
+ if feature_type == "text-only":
639
+ self.linears = nn.Sequential(
640
+ nn.Linear(self.positive_len + self.position_dim, 512),
641
+ nn.SiLU(),
642
+ nn.Linear(512, 512),
643
+ nn.SiLU(),
644
+ nn.Linear(512, out_dim),
645
+ )
646
+ self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
647
+
648
+ elif feature_type == "text-image":
649
+ self.linears_text = nn.Sequential(
650
+ nn.Linear(self.positive_len + self.position_dim, 512),
651
+ nn.SiLU(),
652
+ nn.Linear(512, 512),
653
+ nn.SiLU(),
654
+ nn.Linear(512, out_dim),
655
+ )
656
+ self.linears_image = nn.Sequential(
657
+ nn.Linear(self.positive_len + self.position_dim, 512),
658
+ nn.SiLU(),
659
+ nn.Linear(512, 512),
660
+ nn.SiLU(),
661
+ nn.Linear(512, out_dim),
662
+ )
663
+ self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
664
+ self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
665
+
666
+ self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
667
+
668
+ def forward(
669
+ self,
670
+ boxes,
671
+ masks,
672
+ positive_embeddings=None,
673
+ phrases_masks=None,
674
+ image_masks=None,
675
+ phrases_embeddings=None,
676
+ image_embeddings=None,
677
+ ):
678
+ masks = masks.unsqueeze(-1)
679
+
680
+ # embedding position (it may include padding as a placeholder)
681
+ xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 -> B*N*C
682
+
683
+ # learnable null embedding
684
+ xyxy_null = self.null_position_feature.view(1, 1, -1)
685
+
686
+ # replace padding with learnable null embedding
687
+ xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
688
+
689
+ # PositionNet with text-only information
690
+ if positive_embeddings is not None:
691
+ # learnable null embedding
692
+ positive_null = self.null_positive_feature.view(1, 1, -1)
693
+
694
+ # replace padding with learnable null embedding
695
+ positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null
696
+
697
+ objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
698
+
699
+ # PositionNet with text and image information
700
+ else:
701
+ phrases_masks = phrases_masks.unsqueeze(-1)
702
+ image_masks = image_masks.unsqueeze(-1)
703
+
704
+ # learnable null embedding
705
+ text_null = self.null_text_feature.view(1, 1, -1)
706
+ image_null = self.null_image_feature.view(1, 1, -1)
707
+
708
+ # replace padding with learnable null embedding
709
+ phrases_embeddings = phrases_embeddings * phrases_masks + (1 - phrases_masks) * text_null
710
+ image_embeddings = image_embeddings * image_masks + (1 - image_masks) * image_null
711
+
712
+ objs_text = self.linears_text(torch.cat([phrases_embeddings, xyxy_embedding], dim=-1))
713
+ objs_image = self.linears_image(torch.cat([image_embeddings, xyxy_embedding], dim=-1))
714
+ objs = torch.cat([objs_text, objs_image], dim=1)
715
+
716
+ return objs
717
+
718
+
719
+ class CombinedTimestepSizeEmbeddings(nn.Module):
720
+ """
721
+ For PixArt-Alpha.
722
+
723
+ Reference:
724
+ https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
725
+ """
726
+
727
+ def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
728
+ super().__init__()
729
+
730
+ self.outdim = size_emb_dim
731
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
732
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
733
+
734
+ self.use_additional_conditions = use_additional_conditions
735
+ if use_additional_conditions:
736
+ self.use_additional_conditions = True
737
+ self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
738
+ self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
739
+ self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
740
+
741
+ def apply_condition(self, size: torch.Tensor, batch_size: int, embedder: nn.Module):
742
+ if size.ndim == 1:
743
+ size = size[:, None]
744
+
745
+ if size.shape[0] != batch_size:
746
+ size = size.repeat(batch_size // size.shape[0], 1)
747
+ if size.shape[0] != batch_size:
748
+ raise ValueError(f"`batch_size` should be {size.shape[0]} but found {batch_size}.")
749
+
750
+ current_batch_size, dims = size.shape[0], size.shape[1]
751
+ size = size.reshape(-1)
752
+ size_freq = self.additional_condition_proj(size).to(size.dtype)
753
+
754
+ size_emb = embedder(size_freq)
755
+ size_emb = size_emb.reshape(current_batch_size, dims * self.outdim)
756
+ return size_emb
757
+
758
+ def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
759
+ timesteps_proj = self.time_proj(timestep)
760
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
761
+
762
+ if self.use_additional_conditions:
763
+ resolution = self.apply_condition(resolution, batch_size=batch_size, embedder=self.resolution_embedder)
764
+ aspect_ratio = self.apply_condition(
765
+ aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder
766
+ )
767
+ conditioning = timesteps_emb + torch.cat([resolution, aspect_ratio], dim=1)
768
+ else:
769
+ conditioning = timesteps_emb
770
+
771
+ return conditioning
772
+
773
+
774
+ class CaptionProjection(nn.Module):
775
+ """
776
+ Projects caption embeddings. Also handles dropout for classifier-free guidance.
777
+
778
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
779
+ """
780
+
781
+ def __init__(self, in_features, hidden_size, num_tokens=120):
782
+ super().__init__()
783
+ self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
784
+ self.act_1 = nn.GELU(approximate="tanh")
785
+ self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
786
+ self.register_buffer("y_embedding", nn.Parameter(torch.randn(num_tokens, in_features) / in_features**0.5))
787
+
788
+ def forward(self, caption, force_drop_ids=None):
789
+ hidden_states = self.linear_1(caption)
790
+ hidden_states = self.act_1(hidden_states)
791
+ hidden_states = self.linear_2(hidden_states)
792
+ return hidden_states
diffusers/models/embeddings_flax.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+
16
+ import flax.linen as nn
17
+ import jax.numpy as jnp
18
+
19
+
20
+ def get_sinusoidal_embeddings(
21
+ timesteps: jnp.ndarray,
22
+ embedding_dim: int,
23
+ freq_shift: float = 1,
24
+ min_timescale: float = 1,
25
+ max_timescale: float = 1.0e4,
26
+ flip_sin_to_cos: bool = False,
27
+ scale: float = 1.0,
28
+ ) -> jnp.ndarray:
29
+ """Returns the positional encoding (same as Tensor2Tensor).
30
+
31
+ Args:
32
+ timesteps: a 1-D Tensor of N indices, one per batch element.
33
+ These may be fractional.
34
+ embedding_dim: The number of output channels.
35
+ min_timescale: The smallest time unit (should probably be 0.0).
36
+ max_timescale: The largest time unit.
37
+ Returns:
38
+ a Tensor of timing signals [N, num_channels]
39
+ """
40
+ assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
41
+ assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"
42
+ num_timescales = float(embedding_dim // 2)
43
+ log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift)
44
+ inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment)
45
+ emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)
46
+
47
+ # scale embeddings
48
+ scaled_time = scale * emb
49
+
50
+ if flip_sin_to_cos:
51
+ signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1)
52
+ else:
53
+ signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1)
54
+ signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
55
+ return signal
56
+
57
+
58
+ class FlaxTimestepEmbedding(nn.Module):
59
+ r"""
60
+ Time step Embedding Module. Learns embeddings for input time steps.
61
+
62
+ Args:
63
+ time_embed_dim (`int`, *optional*, defaults to `32`):
64
+ Time step embedding dimension
65
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
66
+ Parameters `dtype`
67
+ """
68
+ time_embed_dim: int = 32
69
+ dtype: jnp.dtype = jnp.float32
70
+
71
+ @nn.compact
72
+ def __call__(self, temb):
73
+ temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb)
74
+ temb = nn.silu(temb)
75
+ temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb)
76
+ return temb
77
+
78
+
79
+ class FlaxTimesteps(nn.Module):
80
+ r"""
81
+ Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239
82
+
83
+ Args:
84
+ dim (`int`, *optional*, defaults to `32`):
85
+ Time step embedding dimension
86
+ """
87
+ dim: int = 32
88
+ flip_sin_to_cos: bool = False
89
+ freq_shift: float = 1
90
+
91
+ @nn.compact
92
+ def __call__(self, timesteps):
93
+ return get_sinusoidal_embeddings(
94
+ timesteps, embedding_dim=self.dim, flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift
95
+ )
diffusers/models/lora.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from ..loaders import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
22
+ from ..utils import logging
23
+
24
+
25
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
26
+
27
+
28
+ def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
29
+ for _, attn_module in text_encoder_attn_modules(text_encoder):
30
+ if isinstance(attn_module.q_proj, PatchedLoraProjection):
31
+ attn_module.q_proj.lora_scale = lora_scale
32
+ attn_module.k_proj.lora_scale = lora_scale
33
+ attn_module.v_proj.lora_scale = lora_scale
34
+ attn_module.out_proj.lora_scale = lora_scale
35
+
36
+ for _, mlp_module in text_encoder_mlp_modules(text_encoder):
37
+ if isinstance(mlp_module.fc1, PatchedLoraProjection):
38
+ mlp_module.fc1.lora_scale = lora_scale
39
+ mlp_module.fc2.lora_scale = lora_scale
40
+
41
+
42
+ class LoRALinearLayer(nn.Module):
43
+ r"""
44
+ A linear layer that is used with LoRA.
45
+
46
+ Parameters:
47
+ in_features (`int`):
48
+ Number of input features.
49
+ out_features (`int`):
50
+ Number of output features.
51
+ rank (`int`, `optional`, defaults to 4):
52
+ The rank of the LoRA layer.
53
+ network_alpha (`float`, `optional`, defaults to `None`):
54
+ The value of the network alpha used for stable learning and preventing underflow. This value has the same
55
+ meaning as the `--network_alpha` option in the kohya-ss trainer script. See
56
+ https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
57
+ device (`torch.device`, `optional`, defaults to `None`):
58
+ The device to use for the layer's weights.
59
+ dtype (`torch.dtype`, `optional`, defaults to `None`):
60
+ The dtype to use for the layer's weights.
61
+ """
62
+
63
+ def __init__(
64
+ self,
65
+ in_features: int,
66
+ out_features: int,
67
+ rank: int = 4,
68
+ network_alpha: Optional[float] = None,
69
+ device: Optional[Union[torch.device, str]] = None,
70
+ dtype: Optional[torch.dtype] = None,
71
+ ):
72
+ super().__init__()
73
+
74
+ self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
75
+ self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
76
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
77
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
78
+ self.network_alpha = network_alpha
79
+ self.rank = rank
80
+ self.out_features = out_features
81
+ self.in_features = in_features
82
+
83
+ nn.init.normal_(self.down.weight, std=1 / rank)
84
+ nn.init.zeros_(self.up.weight)
85
+
86
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
87
+ orig_dtype = hidden_states.dtype
88
+ dtype = self.down.weight.dtype
89
+
90
+ down_hidden_states = self.down(hidden_states.to(dtype))
91
+ up_hidden_states = self.up(down_hidden_states)
92
+
93
+ if self.network_alpha is not None:
94
+ up_hidden_states *= self.network_alpha / self.rank
95
+
96
+ return up_hidden_states.to(orig_dtype)
97
+
98
+
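`LoRALinearLayer` is a standard low-rank adapter: a `down` projection to `rank`, an `up` projection back, with the `up` weight zero-initialized so that a freshly created layer contributes nothing until it is trained. Sketch:

```python
import torch
from diffusers.models.lora import LoRALinearLayer

lora = LoRALinearLayer(in_features=320, out_features=320, rank=4, network_alpha=4)
x = torch.randn(2, 77, 320)
print(torch.count_nonzero(lora(x)))  # tensor(0) -- zero-init up weight means a no-op delta
```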
99
+ class LoRAConv2dLayer(nn.Module):
100
+ r"""
101
+ A convolutional layer that is used with LoRA.
102
+
103
+ Parameters:
104
+ in_features (`int`):
105
+ Number of input features.
106
+ out_features (`int`):
107
+ Number of output features.
108
+ rank (`int`, `optional`, defaults to 4):
109
+ The rank of the LoRA layer.
110
+ kernel_size (`int` or `tuple` of two `int`, `optional`, defaults to 1):
111
+ The kernel size of the convolution.
112
+ stride (`int` or `tuple` of two `int`, `optional`, defaults to 1):
113
+ The stride of the convolution.
114
+ padding (`int` or `tuple` of two `int` or `str`, `optional`, defaults to 0):
115
+ The padding of the convolution.
116
+ network_alpha (`float`, `optional`, defaults to `None`):
117
+ The value of the network alpha used for stable learning and preventing underflow. This value has the same
118
+ meaning as the `--network_alpha` option in the kohya-ss trainer script. See
119
+ https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
120
+ """
121
+
122
+ def __init__(
123
+ self,
124
+ in_features: int,
125
+ out_features: int,
126
+ rank: int = 4,
127
+ kernel_size: Union[int, Tuple[int, int]] = (1, 1),
128
+ stride: Union[int, Tuple[int, int]] = (1, 1),
129
+ padding: Union[int, Tuple[int, int], str] = 0,
130
+ network_alpha: Optional[float] = None,
131
+ ):
132
+ super().__init__()
133
+
134
+ self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
135
+ # according to the official kohya_ss trainer kernel_size are always fixed for the up layer
136
+ # # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
137
+ self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=False)
138
+
139
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
140
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
141
+ self.network_alpha = network_alpha
142
+ self.rank = rank
143
+
144
+ nn.init.normal_(self.down.weight, std=1 / rank)
145
+ nn.init.zeros_(self.up.weight)
146
+
147
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
148
+ orig_dtype = hidden_states.dtype
149
+ dtype = self.down.weight.dtype
150
+
151
+ down_hidden_states = self.down(hidden_states.to(dtype))
152
+ up_hidden_states = self.up(down_hidden_states)
153
+
154
+ if self.network_alpha is not None:
155
+ up_hidden_states *= self.network_alpha / self.rank
156
+
157
+ return up_hidden_states.to(orig_dtype)
158
+
159
+
160
+ class LoRACompatibleConv(nn.Conv2d):
161
+ """
162
+ A convolutional layer that can be used with LoRA.
163
+ """
164
+
165
+ def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs):
166
+ super().__init__(*args, **kwargs)
167
+ self.lora_layer = lora_layer
168
+
169
+ def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
170
+ self.lora_layer = lora_layer
171
+
172
+ def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
173
+ if self.lora_layer is None:
174
+ return
175
+
176
+ dtype, device = self.weight.data.dtype, self.weight.data.device
177
+
178
+ w_orig = self.weight.data.float()
179
+ w_up = self.lora_layer.up.weight.data.float()
180
+ w_down = self.lora_layer.down.weight.data.float()
181
+
182
+ if self.lora_layer.network_alpha is not None:
183
+ w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank
184
+
185
+ fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1))
186
+ fusion = fusion.reshape((w_orig.shape))
187
+ fused_weight = w_orig + (lora_scale * fusion)
188
+
189
+ if safe_fusing and torch.isnan(fused_weight).any().item():
190
+ raise ValueError(
191
+ "This LoRA weight seems to be broken. "
192
+ f"Encountered NaN values when trying to fuse LoRA weights for {self}."
193
+ "LoRA weights will not be fused."
194
+ )
195
+
196
+ self.weight.data = fused_weight.to(device=device, dtype=dtype)
197
+
198
+ # we can drop the lora layer now
199
+ self.lora_layer = None
200
+
201
+ # offload the up and down matrices to CPU so they do not blow up memory
202
+ self.w_up = w_up.cpu()
203
+ self.w_down = w_down.cpu()
204
+ self._lora_scale = lora_scale
205
+
206
+ def _unfuse_lora(self):
207
+ if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
208
+ return
209
+
210
+ fused_weight = self.weight.data
211
+ dtype, device = fused_weight.data.dtype, fused_weight.data.device
212
+
213
+ self.w_up = self.w_up.to(device=device).float()
214
+ self.w_down = self.w_down.to(device).float()
215
+
216
+ fusion = torch.mm(self.w_up.flatten(start_dim=1), self.w_down.flatten(start_dim=1))
217
+ fusion = fusion.reshape((fused_weight.shape))
218
+ unfused_weight = fused_weight.float() - (self._lora_scale * fusion)
219
+ self.weight.data = unfused_weight.to(device=device, dtype=dtype)
220
+
221
+ self.w_up = None
222
+ self.w_down = None
223
+
224
+ def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
225
+ if self.lora_layer is None:
226
+ # make sure to use the functional F.conv2d, as otherwise torch.compile's graph will break
227
+ # see: https://github.com/huggingface/diffusers/pull/4315
228
+ return F.conv2d(
229
+ hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
230
+ )
231
+ else:
232
+ original_outputs = F.conv2d(
233
+ hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
234
+ )
235
+ return original_outputs + (scale * self.lora_layer(hidden_states))
236
+
237
+
238
+ class LoRACompatibleLinear(nn.Linear):
239
+ """
240
+ A Linear layer that can be used with LoRA.
241
+ """
242
+
243
+ def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
244
+ super().__init__(*args, **kwargs)
245
+ self.lora_layer = lora_layer
246
+
247
+ def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
248
+ self.lora_layer = lora_layer
249
+
250
+ def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
251
+ if self.lora_layer is None:
252
+ return
253
+
254
+ dtype, device = self.weight.data.dtype, self.weight.data.device
255
+
256
+ w_orig = self.weight.data.float()
257
+ w_up = self.lora_layer.up.weight.data.float()
258
+ w_down = self.lora_layer.down.weight.data.float()
259
+
260
+ if self.lora_layer.network_alpha is not None:
261
+ w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank
262
+
263
+ fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
264
+
265
+ if safe_fusing and torch.isnan(fused_weight).any().item():
266
+ raise ValueError(
267
+ "This LoRA weight seems to be broken. "
268
+ f"Encountered NaN values when trying to fuse LoRA weights for {self}."
269
+ "LoRA weights will not be fused."
270
+ )
271
+
272
+ self.weight.data = fused_weight.to(device=device, dtype=dtype)
273
+
274
+ # we can drop the lora layer now
275
+ self.lora_layer = None
276
+
277
+ # offload the up and down matrices to CPU to not blow the memory
278
+ self.w_up = w_up.cpu()
279
+ self.w_down = w_down.cpu()
280
+ self._lora_scale = lora_scale
281
+
282
+ def _unfuse_lora(self):
283
+ if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
284
+ return
285
+
286
+ fused_weight = self.weight.data
287
+ dtype, device = fused_weight.dtype, fused_weight.device
288
+
289
+ w_up = self.w_up.to(device=device).float()
290
+ w_down = self.w_down.to(device).float()
291
+
292
+ unfused_weight = fused_weight.float() - (self._lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
293
+ self.weight.data = unfused_weight.to(device=device, dtype=dtype)
294
+
295
+ self.w_up = None
296
+ self.w_down = None
297
+
298
+ def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
299
+ if self.lora_layer is None:
300
+ out = super().forward(hidden_states)
301
+ return out
302
+ else:
303
+ out = super().forward(hidden_states) + (scale * self.lora_layer(hidden_states))
304
+ return out
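The fuse/unfuse helpers on `LoRACompatibleLinear` fold `lora_scale * up @ down` into the base weight so inference can skip the extra matmul, and keep the factors on CPU so the fusion can be undone. A rough round-trip sketch; note that `_fuse_lora`/`_unfuse_lora` are internal helpers normally driven by the pipeline-level LoRA loaders, and the `normal_` init below exists only to make the delta non-zero:

```python
import torch
from diffusers.models.lora import LoRACompatibleLinear, LoRALinearLayer

torch.manual_seed(0)
lora = LoRALinearLayer(in_features=16, out_features=16, rank=4)
torch.nn.init.normal_(lora.up.weight)  # the up weight is zero-initialized by default

layer = LoRACompatibleLinear(16, 16, lora_layer=lora)
x = torch.randn(2, 16)

unfused = layer(x, scale=0.5)      # base output + 0.5 * LoRA delta
layer._fuse_lora(lora_scale=0.5)   # folds the scaled delta into layer.weight
fused = layer(x)
print(torch.allclose(unfused, fused, atol=1e-5))  # True

layer._unfuse_lora()               # restores the original base weight from w_up/w_down
```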
diffusers/models/modeling_flax_pytorch_utils.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch - Flax general utilities."""
16
+ import re
17
+
18
+ import jax.numpy as jnp
19
+ from flax.traverse_util import flatten_dict, unflatten_dict
20
+ from jax.random import PRNGKey
21
+
22
+ from ..utils import logging
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ def rename_key(key):
29
+ regex = r"\w+[.]\d+"
30
+ pats = re.findall(regex, key)
31
+ for pat in pats:
32
+ key = key.replace(pat, "_".join(pat.split(".")))
33
+ return key
34
+
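`rename_key` only rewrites the `module.<index>` segments that PyTorch uses for `nn.ModuleList`/`nn.Sequential` children into the underscore form Flax expects, leaving everything else untouched:

```python
from diffusers.models.modeling_flax_pytorch_utils import rename_key

print(rename_key("down_blocks.0.attentions.1.proj_in.weight"))
# down_blocks_0.attentions_1.proj_in.weight
```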
35
+
36
+ #####################
37
+ # PyTorch => Flax #
38
+ #####################
39
+
40
+
41
+ # Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
42
+ # and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
43
+ def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
44
+ """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""
45
+ # conv norm or layer norm
46
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
47
+
48
+ # rename attention layers
49
+ if len(pt_tuple_key) > 1:
50
+ for rename_from, rename_to in (
51
+ ("to_out_0", "proj_attn"),
52
+ ("to_k", "key"),
53
+ ("to_v", "value"),
54
+ ("to_q", "query"),
55
+ ):
56
+ if pt_tuple_key[-2] == rename_from:
57
+ weight_name = pt_tuple_key[-1]
58
+ weight_name = "kernel" if weight_name == "weight" else weight_name
59
+ renamed_pt_tuple_key = pt_tuple_key[:-2] + (rename_to, weight_name)
60
+ if renamed_pt_tuple_key in random_flax_state_dict:
61
+ assert random_flax_state_dict[renamed_pt_tuple_key].shape == pt_tensor.T.shape
62
+ return renamed_pt_tuple_key, pt_tensor.T
63
+
64
+ if (
65
+ any("norm" in str_ for str_ in pt_tuple_key)
66
+ and (pt_tuple_key[-1] == "bias")
67
+ and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict)
68
+ and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
69
+ ):
70
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
71
+ return renamed_pt_tuple_key, pt_tensor
72
+ elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
73
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
74
+ return renamed_pt_tuple_key, pt_tensor
75
+
76
+ # embedding
77
+ if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
78
+ pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
79
+ return renamed_pt_tuple_key, pt_tensor
80
+
81
+ # conv layer
82
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
83
+ if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4:
84
+ pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
85
+ return renamed_pt_tuple_key, pt_tensor
86
+
87
+ # linear layer
88
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
89
+ if pt_tuple_key[-1] == "weight":
90
+ pt_tensor = pt_tensor.T
91
+ return renamed_pt_tuple_key, pt_tensor
92
+
93
+ # old PyTorch layer norm weight
94
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
95
+ if pt_tuple_key[-1] == "gamma":
96
+ return renamed_pt_tuple_key, pt_tensor
97
+
98
+ # old PyTorch layer norm bias
99
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
100
+ if pt_tuple_key[-1] == "beta":
101
+ return renamed_pt_tuple_key, pt_tensor
102
+
103
+ return pt_tuple_key, pt_tensor
104
+
105
+
106
+ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42):
107
+ # Step 1: Convert pytorch tensor to numpy
108
+ pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
109
+
110
+ # Step 2: Since the model is stateless, get random Flax params
111
+ random_flax_params = flax_model.init_weights(PRNGKey(init_key))
112
+
113
+ random_flax_state_dict = flatten_dict(random_flax_params)
114
+ flax_state_dict = {}
115
+
116
+ # Need to change some parameters name to match Flax names
117
+ for pt_key, pt_tensor in pt_state_dict.items():
118
+ renamed_pt_key = rename_key(pt_key)
119
+ pt_tuple_key = tuple(renamed_pt_key.split("."))
120
+
121
+ # Correctly rename weight parameters
122
+ flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict)
123
+
124
+ if flax_key in random_flax_state_dict:
125
+ if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
126
+ raise ValueError(
127
+ f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
128
+ f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
129
+ )
130
+
131
+ # also add unexpected weight so that warning is thrown
132
+ flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
133
+
134
+ return unflatten_dict(flax_state_dict)
diffusers/models/modeling_flax_utils.py ADDED
@@ -0,0 +1,560 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ from pickle import UnpicklingError
18
+ from typing import Any, Dict, Union
19
+
20
+ import jax
21
+ import jax.numpy as jnp
22
+ import msgpack.exceptions
23
+ from flax.core.frozen_dict import FrozenDict, unfreeze
24
+ from flax.serialization import from_bytes, to_bytes
25
+ from flax.traverse_util import flatten_dict, unflatten_dict
26
+ from huggingface_hub import create_repo, hf_hub_download
27
+ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
28
+ from requests import HTTPError
29
+
30
+ from .. import __version__, is_torch_available
31
+ from ..utils import (
32
+ CONFIG_NAME,
33
+ DIFFUSERS_CACHE,
34
+ FLAX_WEIGHTS_NAME,
35
+ HUGGINGFACE_CO_RESOLVE_ENDPOINT,
36
+ WEIGHTS_NAME,
37
+ PushToHubMixin,
38
+ logging,
39
+ )
40
+ from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
41
+
42
+
43
+ logger = logging.get_logger(__name__)
44
+
45
+
46
+ class FlaxModelMixin(PushToHubMixin):
47
+ r"""
48
+ Base class for all Flax models.
49
+
50
+ [`FlaxModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
51
+ saving models.
52
+
53
+ - **config_name** ([`str`]) -- Filename to save a model to when calling [`~FlaxModelMixin.save_pretrained`].
54
+ """
55
+ config_name = CONFIG_NAME
56
+ _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
57
+ _flax_internal_args = ["name", "parent", "dtype"]
58
+
59
+ @classmethod
60
+ def _from_config(cls, config, **kwargs):
61
+ """
62
+ All context managers that the model should be initialized under go here.
63
+ """
64
+ return cls(config, **kwargs)
65
+
66
+ def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
67
+ """
68
+ Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
69
+ """
70
+
71
+ # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
72
+ def conditional_cast(param):
73
+ if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
74
+ param = param.astype(dtype)
75
+ return param
76
+
77
+ if mask is None:
78
+ return jax.tree_map(conditional_cast, params)
79
+
80
+ flat_params = flatten_dict(params)
81
+ flat_mask, _ = jax.tree_flatten(mask)
82
+
83
+ for masked, key in zip(flat_mask, flat_params.keys()):
84
+ if masked:
85
+ param = flat_params[key]
86
+ flat_params[key] = conditional_cast(param)
87
+
88
+ return unflatten_dict(flat_params)
89
+
90
+ def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
91
+ r"""
92
+ Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
93
+ the `params` in place.
94
+
95
+ This method can be used on a TPU to explicitly convert the model parameters to bfloat16 precision to do full
96
+ half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.
97
+
98
+ Arguments:
99
+ params (`Union[Dict, FrozenDict]`):
100
+ A `PyTree` of model parameters.
101
+ mask (`Union[Dict, FrozenDict]`):
102
+ A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
103
+ for params you want to cast, and `False` for those you want to skip.
104
+
105
+ Examples:
106
+
107
+ ```python
108
+ >>> from diffusers import FlaxUNet2DConditionModel
109
+
110
+ >>> # load model
111
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
112
+ >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
113
+ >>> params = model.to_bf16(params)
114
+ >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
115
+ >>> # then pass the face_hair_mask as follows
116
+ >>> from flax import traverse_util
117
+
118
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
119
+ >>> flat_params = traverse_util.flatten_dict(params)
120
+ >>> face_hair_mask = {
121
+ ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
122
+ ... for path in flat_params
123
+ ... }
124
+ >>> face_hair_mask = traverse_util.unflatten_dict(face_hair_mask)
125
+ >>> params = model.to_bf16(params, face_hair_mask)
126
+ ```"""
127
+ return self._cast_floating_to(params, jnp.bfloat16, mask)
128
+
129
+ def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
130
+ r"""
131
+ Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the
132
+ model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place.
133
+
134
+ Arguments:
135
+ params (`Union[Dict, FrozenDict]`):
136
+ A `PyTree` of model parameters.
137
+ mask (`Union[Dict, FrozenDict]`):
138
+ A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
139
+ for params you want to cast, and `False` for those you want to skip.
140
+
141
+ Examples:
142
+
143
+ ```python
144
+ >>> from diffusers import FlaxUNet2DConditionModel
145
+
146
+ >>> # Download model and configuration from huggingface.co
147
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
148
+ >>> # By default, the model params will be in fp32, to illustrate the use of this method,
149
+ >>> # we'll first cast to fp16 and back to fp32
150
+ >>> params = model.to_f16(params)
151
+ >>> # now cast back to fp32
152
+ >>> params = model.to_fp32(params)
153
+ ```"""
154
+ return self._cast_floating_to(params, jnp.float32, mask)
155
+
156
+ def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
157
+ r"""
158
+ Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the
159
+ `params` in place.
160
+
161
+ This method can be used on a GPU to explicitly convert the model parameters to float16 precision to do full
162
+ half-precision training or to save weights in float16 for inference in order to save memory and improve speed.
163
+
164
+ Arguments:
165
+ params (`Union[Dict, FrozenDict]`):
166
+ A `PyTree` of model parameters.
167
+ mask (`Union[Dict, FrozenDict]`):
168
+ A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
169
+ for params you want to cast, and `False` for those you want to skip.
170
+
171
+ Examples:
172
+
173
+ ```python
174
+ >>> from diffusers import FlaxUNet2DConditionModel
175
+
176
+ >>> # load model
177
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
178
+ >>> # By default, the model params will be in fp32, to cast these to float16
179
+ >>> params = model.to_fp16(params)
180
+ >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
181
+ >>> # then pass the face_hair_mask as follows
182
+ >>> from flax import traverse_util
183
+
184
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
185
+ >>> flat_params = traverse_util.flatten_dict(params)
186
+ >>> face_hair_mask = {
187
+ ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
188
+ ... for path in flat_params
189
+ ... }
190
+ >>> face_hair_mask = traverse_util.unflatten_dict(face_hair_mask)
191
+ >>> params = model.to_fp16(params, face_hair_mask)
192
+ ```"""
193
+ return self._cast_floating_to(params, jnp.float16, mask)
194
+
195
+ def init_weights(self, rng: jax.Array) -> Dict:
196
+ raise NotImplementedError(f"init_weights method has to be implemented for {self}")
197
+
198
+ @classmethod
199
+ def from_pretrained(
200
+ cls,
201
+ pretrained_model_name_or_path: Union[str, os.PathLike],
202
+ dtype: jnp.dtype = jnp.float32,
203
+ *model_args,
204
+ **kwargs,
205
+ ):
206
+ r"""
207
+ Instantiate a pretrained Flax model from a pretrained model configuration.
208
+
209
+ Parameters:
210
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
211
+ Can be either:
212
+
213
+ - A string, the *model id* (for example `runwayml/stable-diffusion-v1-5`) of a pretrained model
214
+ hosted on the Hub.
215
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
216
+ using [`~FlaxModelMixin.save_pretrained`].
217
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
218
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
219
+ `jax.numpy.bfloat16` (on TPUs).
220
+
221
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
222
+ specified, all the computation will be performed with the given `dtype`.
223
+
224
+ <Tip>
225
+
226
+ This only specifies the dtype of the *computation* and does not influence the dtype of model
227
+ parameters.
228
+
229
+ If you wish to change the dtype of the model parameters, see [`~FlaxModelMixin.to_fp16`] and
230
+ [`~FlaxModelMixin.to_bf16`].
231
+
232
+ </Tip>
233
+
234
+ model_args (sequence of positional arguments, *optional*):
235
+ All remaining positional arguments are passed to the underlying model's `__init__` method.
236
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
237
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
238
+ is not used.
239
+ force_download (`bool`, *optional*, defaults to `False`):
240
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
241
+ cached versions if they exist.
242
+ resume_download (`bool`, *optional*, defaults to `False`):
243
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
244
+ incompletely downloaded files are deleted.
245
+ proxies (`Dict[str, str]`, *optional*):
246
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
247
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
248
+ local_files_only(`bool`, *optional*, defaults to `False`):
249
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
250
+ won't be downloaded from the Hub.
251
+ revision (`str`, *optional*, defaults to `"main"`):
252
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
253
+ allowed by Git.
254
+ from_pt (`bool`, *optional*, defaults to `False`):
255
+ Load the model weights from a PyTorch checkpoint save file.
256
+ kwargs (remaining dictionary of keyword arguments, *optional*):
257
+ Can be used to update the configuration object (after it is loaded) and initiate the model (for
258
+ example, `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
259
+ automatically loaded:
260
+
261
+ - If a configuration is provided with `config`, `kwargs` are directly passed to the underlying
262
+ model's `__init__` method (we assume all relevant updates to the configuration have already been
263
+ done).
264
+ - If a configuration is not provided, `kwargs` are first passed to the configuration class
265
+ initialization function [`~ConfigMixin.from_config`]. Each key of the `kwargs` that corresponds
266
+ to a configuration attribute is used to override said attribute with the supplied `kwargs` value.
267
+ Remaining keys that do not correspond to any configuration attribute are passed to the underlying
268
+ model's `__init__` function.
269
+
270
+ Examples:
271
+
272
+ ```python
273
+ >>> from diffusers import FlaxUNet2DConditionModel
274
+
275
+ >>> # Download model and configuration from huggingface.co and cache.
276
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
277
+ >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
278
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/")
279
+ ```
280
+
281
+ If you get the error message below, you need to finetune the weights for your downstream task:
282
+
283
+ ```bash
284
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
285
+ - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
286
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
287
+ ```
288
+ """
289
+ config = kwargs.pop("config", None)
290
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
291
+ force_download = kwargs.pop("force_download", False)
292
+ from_pt = kwargs.pop("from_pt", False)
293
+ resume_download = kwargs.pop("resume_download", False)
294
+ proxies = kwargs.pop("proxies", None)
295
+ local_files_only = kwargs.pop("local_files_only", False)
296
+ use_auth_token = kwargs.pop("use_auth_token", None)
297
+ revision = kwargs.pop("revision", None)
298
+ subfolder = kwargs.pop("subfolder", None)
299
+
300
+ user_agent = {
301
+ "diffusers": __version__,
302
+ "file_type": "model",
303
+ "framework": "flax",
304
+ }
305
+
306
+ # Load config if we don't provide one
307
+ if config is None:
308
+ config, unused_kwargs = cls.load_config(
309
+ pretrained_model_name_or_path,
310
+ cache_dir=cache_dir,
311
+ return_unused_kwargs=True,
312
+ force_download=force_download,
313
+ resume_download=resume_download,
314
+ proxies=proxies,
315
+ local_files_only=local_files_only,
316
+ use_auth_token=use_auth_token,
317
+ revision=revision,
318
+ subfolder=subfolder,
319
+ **kwargs,
320
+ )
321
+
322
+ model, model_kwargs = cls.from_config(config, dtype=dtype, return_unused_kwargs=True, **unused_kwargs)
323
+
324
+ # Load model
325
+ pretrained_path_with_subfolder = (
326
+ pretrained_model_name_or_path
327
+ if subfolder is None
328
+ else os.path.join(pretrained_model_name_or_path, subfolder)
329
+ )
330
+ if os.path.isdir(pretrained_path_with_subfolder):
331
+ if from_pt:
332
+ if not os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
333
+ raise EnvironmentError(
334
+ f"Error: no file named {WEIGHTS_NAME} found in directory {pretrained_path_with_subfolder} "
335
+ )
336
+ model_file = os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)
337
+ elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)):
338
+ # Load from a Flax checkpoint
339
+ model_file = os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)
340
+ # Check if pytorch weights exist instead
341
+ elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
342
+ raise EnvironmentError(
343
+ f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model"
344
+ " using `from_pt=True`."
345
+ )
346
+ else:
347
+ raise EnvironmentError(
348
+ f"Error: no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
349
+ f"{pretrained_path_with_subfolder}."
350
+ )
351
+ else:
352
+ try:
353
+ model_file = hf_hub_download(
354
+ pretrained_model_name_or_path,
355
+ filename=FLAX_WEIGHTS_NAME if not from_pt else WEIGHTS_NAME,
356
+ cache_dir=cache_dir,
357
+ force_download=force_download,
358
+ proxies=proxies,
359
+ resume_download=resume_download,
360
+ local_files_only=local_files_only,
361
+ use_auth_token=use_auth_token,
362
+ user_agent=user_agent,
363
+ subfolder=subfolder,
364
+ revision=revision,
365
+ )
366
+
367
+ except RepositoryNotFoundError:
368
+ raise EnvironmentError(
369
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
370
+ "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
371
+ "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
372
+ "login`."
373
+ )
374
+ except RevisionNotFoundError:
375
+ raise EnvironmentError(
376
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
377
+ "this model name. Check the model page at "
378
+ f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
379
+ )
380
+ except EntryNotFoundError:
381
+ raise EnvironmentError(
382
+ f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME}."
383
+ )
384
+ except HTTPError as err:
385
+ raise EnvironmentError(
386
+ f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
387
+ f"{err}"
388
+ )
389
+ except ValueError:
390
+ raise EnvironmentError(
391
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
392
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
393
+ f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheck your"
394
+ " internet connection or see how to run the library in offline mode at"
395
+ " 'https://huggingface.co/docs/transformers/installation#offline-mode'."
396
+ )
397
+ except EnvironmentError:
398
+ raise EnvironmentError(
399
+ f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
400
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
401
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
402
+ f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
403
+ )
404
+
405
+ if from_pt:
406
+ if is_torch_available():
407
+ from .modeling_utils import load_state_dict
408
+ else:
409
+ raise EnvironmentError(
410
+ "Can't load the model in PyTorch format because PyTorch is not installed. "
411
+ "Please, install PyTorch or use native Flax weights."
412
+ )
413
+
414
+ # Step 1: Get the pytorch file
415
+ pytorch_model_file = load_state_dict(model_file)
416
+
417
+ # Step 2: Convert the weights
418
+ state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model)
419
+ else:
420
+ try:
421
+ with open(model_file, "rb") as state_f:
422
+ state = from_bytes(cls, state_f.read())
423
+ except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
424
+ try:
425
+ with open(model_file) as f:
426
+ if f.read().startswith("version"):
427
+ raise OSError(
428
+ "You seem to have cloned a repository without having git-lfs installed. Please"
429
+ " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
430
+ " folder you cloned."
431
+ )
432
+ else:
433
+ raise ValueError from e
434
+ except (UnicodeDecodeError, ValueError):
435
+ raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
436
+ # make sure all arrays are stored as jnp.ndarray
437
+ # NOTE: This is to prevent a bug that will be fixed in Flax >= v0.3.4:
438
+ # https://github.com/google/flax/issues/1261
439
+ state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)
440
+
441
+ # flatten dicts
442
+ state = flatten_dict(state)
443
+
444
+ params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0))
445
+ required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
446
+
447
+ shape_state = flatten_dict(unfreeze(params_shape_tree))
448
+
449
+ missing_keys = required_params - set(state.keys())
450
+ unexpected_keys = set(state.keys()) - required_params
451
+
452
+ if missing_keys:
453
+ logger.warning(
454
+ f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
455
+ "Make sure to call model.init_weights to initialize the missing weights."
456
+ )
457
+ cls._missing_keys = missing_keys
458
+
459
+ for key in state.keys():
460
+ if key in shape_state and state[key].shape != shape_state[key].shape:
461
+ raise ValueError(
462
+ f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
463
+ f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}. "
464
+ )
465
+
466
+ # remove unexpected keys to not be saved again
467
+ for unexpected_key in unexpected_keys:
468
+ del state[unexpected_key]
469
+
470
+ if len(unexpected_keys) > 0:
471
+ logger.warning(
472
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
473
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
474
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
475
+ " with another architecture."
476
+ )
477
+ else:
478
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
479
+
480
+ if len(missing_keys) > 0:
481
+ logger.warning(
482
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
483
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
484
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
485
+ )
486
+ else:
487
+ logger.info(
488
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
489
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
490
+ f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
491
+ " training."
492
+ )
493
+
494
+ return model, unflatten_dict(state)
495
+
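As a minimal usage sketch of the `dtype` behaviour documented above (assuming JAX is installed and the checkpoint ships Flax weights): `dtype` only sets the computation precision, while the returned parameter PyTree keeps its stored precision until it is cast explicitly.

```python
import jax.numpy as jnp
from diffusers import FlaxUNet2DConditionModel

# dtype controls the computation precision only; the params here are still float32.
model, params = FlaxUNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", dtype=jnp.bfloat16
)

# Cast the parameter PyTree itself (e.g. for TPU inference), as the docstring's
# tip suggests via `to_bf16` / `to_fp16`.
params = model.to_bf16(params)
```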
496
+ def save_pretrained(
497
+ self,
498
+ save_directory: Union[str, os.PathLike],
499
+ params: Union[Dict, FrozenDict],
500
+ is_main_process: bool = True,
501
+ push_to_hub: bool = False,
502
+ **kwargs,
503
+ ):
504
+ """
505
+ Save a model and its configuration file to a directory so that it can be reloaded using the
506
+ [`~FlaxModelMixin.from_pretrained`] class method.
507
+
508
+ Arguments:
509
+ save_directory (`str` or `os.PathLike`):
510
+ Directory to save a model and its configuration file to. Will be created if it doesn't exist.
511
+ params (`Union[Dict, FrozenDict]`):
512
+ A `PyTree` of model parameters.
513
+ is_main_process (`bool`, *optional*, defaults to `True`):
514
+ Whether the process calling this is the main process or not. Useful during distributed training when you
515
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
516
+ process to avoid race conditions.
517
+ push_to_hub (`bool`, *optional*, defaults to `False`):
518
+ Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
519
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
520
+ namespace).
521
+ kwargs (`Dict[str, Any]`, *optional*):
522
+ Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
523
+ """
524
+ if os.path.isfile(save_directory):
525
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
526
+ return
527
+
528
+ os.makedirs(save_directory, exist_ok=True)
529
+
530
+ if push_to_hub:
531
+ commit_message = kwargs.pop("commit_message", None)
532
+ private = kwargs.pop("private", False)
533
+ create_pr = kwargs.pop("create_pr", False)
534
+ token = kwargs.pop("token", None)
535
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
536
+ repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
537
+
538
+ model_to_save = self
539
+
540
+ # Attach architecture to the config
541
+ # Save the config
542
+ if is_main_process:
543
+ model_to_save.save_config(save_directory)
544
+
545
+ # save model
546
+ output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
547
+ with open(output_model_file, "wb") as f:
548
+ model_bytes = to_bytes(params)
549
+ f.write(model_bytes)
550
+
551
+ logger.info(f"Model weights saved in {output_model_file}")
552
+
553
+ if push_to_hub:
554
+ self._upload_folder(
555
+ save_directory,
556
+ repo_id,
557
+ token=token,
558
+ commit_message=commit_message,
559
+ create_pr=create_pr,
560
+ )
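A hedged round-trip sketch for the two methods above (the local path is a placeholder):

```python
from diffusers import FlaxUNet2DConditionModel

model, params = FlaxUNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# Writes config.json plus the msgpack weights (FLAX_WEIGHTS_NAME) to the directory.
model.save_pretrained("./flax-unet", params=params)

# Reload from the local directory created above.
model2, params2 = FlaxUNet2DConditionModel.from_pretrained("./flax-unet")
```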
diffusers/models/modeling_pytorch_flax_utils.py ADDED
@@ -0,0 +1,161 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch - Flax general utilities."""
16
+
17
+ from pickle import UnpicklingError
18
+
19
+ import jax
20
+ import jax.numpy as jnp
21
+ import numpy as np
22
+ from flax.serialization import from_bytes
23
+ from flax.traverse_util import flatten_dict
24
+
25
+ from ..utils import logging
26
+
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ #####################
32
+ # Flax => PyTorch #
33
+ #####################
34
+
35
+
36
+ # from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_pytorch_utils.py#L224-L352
37
+ def load_flax_checkpoint_in_pytorch_model(pt_model, model_file):
38
+ try:
39
+ with open(model_file, "rb") as flax_state_f:
40
+ flax_state = from_bytes(None, flax_state_f.read())
41
+ except UnpicklingError as e:
42
+ try:
43
+ with open(model_file) as f:
44
+ if f.read().startswith("version"):
45
+ raise OSError(
46
+ "You seem to have cloned a repository without having git-lfs installed. Please"
47
+ " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
48
+ " folder you cloned."
49
+ )
50
+ else:
51
+ raise ValueError from e
52
+ except (UnicodeDecodeError, ValueError):
53
+ raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
54
+
55
+ return load_flax_weights_in_pytorch_model(pt_model, flax_state)
56
+
57
+
58
+ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
59
+ """Load flax checkpoints in a PyTorch model"""
60
+
61
+ try:
62
+ import torch # noqa: F401
63
+ except ImportError:
64
+ logger.error(
65
+ "Loading Flax weights in PyTorch requires both PyTorch and Flax to be installed. Please see"
66
+ " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
67
+ " instructions."
68
+ )
69
+ raise
70
+
71
+ # check if we have bf16 weights
72
+ is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
73
+ if any(is_type_bf16):
74
+ # convert all weights to fp32 if they are bf16 since torch.from_numpy cannot handle bf16
75
+
76
+ # and bf16 is not fully supported in PT yet.
77
+ logger.warning(
78
+ "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
79
+ "before loading those in PyTorch model."
80
+ )
81
+ flax_state = jax.tree_util.tree_map(
82
+ lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state
83
+ )
84
+
85
+ pt_model.base_model_prefix = ""
86
+
87
+ flax_state_dict = flatten_dict(flax_state, sep=".")
88
+ pt_model_dict = pt_model.state_dict()
89
+
90
+ # keep track of unexpected & missing keys
91
+ unexpected_keys = []
92
+ missing_keys = set(pt_model_dict.keys())
93
+
94
+ for flax_key_tuple, flax_tensor in flax_state_dict.items():
95
+ flax_key_tuple_array = flax_key_tuple.split(".")
96
+
97
+ if flax_key_tuple_array[-1] == "kernel" and flax_tensor.ndim == 4:
98
+ flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
99
+ flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1))
100
+ elif flax_key_tuple_array[-1] == "kernel":
101
+ flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
102
+ flax_tensor = flax_tensor.T
103
+ elif flax_key_tuple_array[-1] == "scale":
104
+ flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
105
+
106
+ if "time_embedding" not in flax_key_tuple_array:
107
+ for i, flax_key_tuple_string in enumerate(flax_key_tuple_array):
108
+ flax_key_tuple_array[i] = (
109
+ flax_key_tuple_string.replace("_0", ".0")
110
+ .replace("_1", ".1")
111
+ .replace("_2", ".2")
112
+ .replace("_3", ".3")
113
+ .replace("_4", ".4")
114
+ .replace("_5", ".5")
115
+ .replace("_6", ".6")
116
+ .replace("_7", ".7")
117
+ .replace("_8", ".8")
118
+ .replace("_9", ".9")
119
+ )
120
+
121
+ flax_key = ".".join(flax_key_tuple_array)
122
+
123
+ if flax_key in pt_model_dict:
124
+ if flax_tensor.shape != pt_model_dict[flax_key].shape:
125
+ raise ValueError(
126
+ f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected "
127
+ f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}."
128
+ )
129
+ else:
130
+ # add weight to pytorch dict
131
+ flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor
132
+ pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
133
+ # remove from missing keys
134
+ missing_keys.remove(flax_key)
135
+ else:
136
+ # weight is not expected by PyTorch model
137
+ unexpected_keys.append(flax_key)
138
+
139
+ pt_model.load_state_dict(pt_model_dict)
140
+
141
+ # re-transform missing_keys to list
142
+ missing_keys = list(missing_keys)
143
+
144
+ if len(unexpected_keys) > 0:
145
+ logger.warning(
146
+ "Some weights of the Flax model were not used when initializing the PyTorch model"
147
+ f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
148
+ f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture"
149
+ " (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This"
150
+ f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect"
151
+ " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a"
152
+ " FlaxBertForSequenceClassification model)."
153
+ )
154
+ if len(missing_keys) > 0:
155
+ logger.warning(
156
+ f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly"
157
+ f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
158
+ " use it for predictions and inference."
159
+ )
160
+
161
+ return pt_model
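A small shape check of the kernel conventions the conversion above relies on (illustrative sizes only):

```python
import jax.numpy as jnp

# Flax conv kernels are (H, W, in_channels, out_channels); PyTorch expects
# (out_channels, in_channels, H, W), hence the (3, 2, 0, 1) permutation above.
flax_conv_kernel = jnp.zeros((3, 3, 4, 320))
assert jnp.transpose(flax_conv_kernel, (3, 2, 0, 1)).shape == (320, 4, 3, 3)

# Dense kernels are (in_features, out_features) in Flax, so a plain transpose
# gives PyTorch's (out_features, in_features) layout.
flax_dense_kernel = jnp.zeros((768, 320))
assert flax_dense_kernel.T.shape == (320, 768)
```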
diffusers/models/modeling_utils.py ADDED
@@ -0,0 +1,1158 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import inspect
18
+ import itertools
19
+ import os
20
+ import re
21
+ from functools import partial
22
+ from typing import Any, Callable, List, Optional, Tuple, Union
23
+
24
+ import safetensors
25
+ import torch
26
+ from huggingface_hub import create_repo
27
+ from torch import Tensor, device, nn
28
+
29
+ from .. import __version__
30
+ from ..utils import (
31
+ CONFIG_NAME,
32
+ DIFFUSERS_CACHE,
33
+ FLAX_WEIGHTS_NAME,
34
+ HF_HUB_OFFLINE,
35
+ MIN_PEFT_VERSION,
36
+ SAFETENSORS_WEIGHTS_NAME,
37
+ WEIGHTS_NAME,
38
+ _add_variant,
39
+ _get_model_file,
40
+ check_peft_version,
41
+ deprecate,
42
+ is_accelerate_available,
43
+ is_torch_version,
44
+ logging,
45
+ )
46
+ from ..utils.hub_utils import PushToHubMixin
47
+
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+
52
+ if is_torch_version(">=", "1.9.0"):
53
+ _LOW_CPU_MEM_USAGE_DEFAULT = True
54
+ else:
55
+ _LOW_CPU_MEM_USAGE_DEFAULT = False
56
+
57
+
58
+ if is_accelerate_available():
59
+ import accelerate
60
+ from accelerate.utils import set_module_tensor_to_device
61
+ from accelerate.utils.versions import is_torch_version
62
+
63
+
64
+ def get_parameter_device(parameter: torch.nn.Module):
65
+ try:
66
+ parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
67
+ return next(parameters_and_buffers).device
68
+ except StopIteration:
69
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
70
+
71
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
72
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
73
+ return tuples
74
+
75
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
76
+ first_tuple = next(gen)
77
+ return first_tuple[1].device
78
+
79
+
80
+ def get_parameter_dtype(parameter: torch.nn.Module):
81
+ try:
82
+ params = tuple(parameter.parameters())
83
+ if len(params) > 0:
84
+ return params[0].dtype
85
+
86
+ buffers = tuple(parameter.buffers())
87
+ if len(buffers) > 0:
88
+ return buffers[0].dtype
89
+
90
+ except StopIteration:
91
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
92
+
93
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
94
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
95
+ return tuples
96
+
97
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
98
+ first_tuple = next(gen)
99
+ return first_tuple[1].dtype
100
+
101
+
102
+ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
103
+ """
104
+ Reads a checkpoint file, returning properly formatted errors if they arise.
105
+ """
106
+ try:
107
+ if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant):
108
+ return torch.load(checkpoint_file, map_location="cpu")
109
+ else:
110
+ return safetensors.torch.load_file(checkpoint_file, device="cpu")
111
+ except Exception as e:
112
+ try:
113
+ with open(checkpoint_file) as f:
114
+ if f.read().startswith("version"):
115
+ raise OSError(
116
+ "You seem to have cloned a repository without having git-lfs installed. Please install "
117
+ "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
118
+ "you cloned."
119
+ )
120
+ else:
121
+ raise ValueError(
122
+ f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
123
+ "model. Make sure you have saved the model properly."
124
+ ) from e
125
+ except (UnicodeDecodeError, ValueError):
126
+ raise OSError(
127
+ f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
128
+ f"at '{checkpoint_file}'. "
129
+ "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
130
+ )
131
+
132
+
133
+ def load_model_dict_into_meta(model, state_dict, device=None, dtype=None, model_name_or_path=None):
134
+ device = device or torch.device("cpu")
135
+ dtype = dtype or torch.float32
136
+
137
+ accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
138
+
139
+ unexpected_keys = []
140
+ empty_state_dict = model.state_dict()
141
+ for param_name, param in state_dict.items():
142
+ if param_name not in empty_state_dict:
143
+ unexpected_keys.append(param_name)
144
+ continue
145
+
146
+ if empty_state_dict[param_name].shape != param.shape:
147
+ model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
148
+ raise ValueError(
149
+ f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
150
+ )
151
+
152
+ if accepts_dtype:
153
+ set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
154
+ else:
155
+ set_module_tensor_to_device(model, param_name, device, value=param)
156
+ return unexpected_keys
157
+
158
+
159
+ def _load_state_dict_into_model(model_to_load, state_dict):
160
+ # Convert old format to new format if needed from a PyTorch state_dict
161
+ # copy state_dict so _load_from_state_dict can modify it
162
+ state_dict = state_dict.copy()
163
+ error_msgs = []
164
+
165
+ # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
166
+ # so we need to apply the function recursively.
167
+ def load(module: torch.nn.Module, prefix=""):
168
+ args = (state_dict, prefix, {}, True, [], [], error_msgs)
169
+ module._load_from_state_dict(*args)
170
+
171
+ for name, child in module._modules.items():
172
+ if child is not None:
173
+ load(child, prefix + name + ".")
174
+
175
+ load(model_to_load)
176
+
177
+ return error_msgs
178
+
179
+
180
+ class ModelMixin(torch.nn.Module, PushToHubMixin):
181
+ r"""
182
+ Base class for all models.
183
+
184
+ [`ModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
185
+ saving models.
186
+
187
+ - **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`].
188
+ """
189
+ config_name = CONFIG_NAME
190
+ _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
191
+ _supports_gradient_checkpointing = False
192
+ _keys_to_ignore_on_load_unexpected = None
193
+ _hf_peft_config_loaded = False
194
+
195
+ def __init__(self):
196
+ super().__init__()
197
+
198
+ def __getattr__(self, name: str) -> Any:
199
+ """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
200
+ config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
201
+ __getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__':
202
+ https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
203
+ """
204
+
205
+ is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
206
+ is_attribute = name in self.__dict__
207
+
208
+ if is_in_config and not is_attribute:
209
+ deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
210
+ deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
211
+ return self._internal_dict[name]
212
+
213
+ # call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
214
+ return super().__getattr__(name)
215
+
216
+ @property
217
+ def is_gradient_checkpointing(self) -> bool:
218
+ """
219
+ Whether gradient checkpointing is activated for this model or not.
220
+ """
221
+ return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
222
+
223
+ def enable_gradient_checkpointing(self):
224
+ """
225
+ Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
226
+ *checkpoint activations* in other frameworks).
227
+ """
228
+ if not self._supports_gradient_checkpointing:
229
+ raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
230
+ self.apply(partial(self._set_gradient_checkpointing, value=True))
231
+
232
+ def disable_gradient_checkpointing(self):
233
+ """
234
+ Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
235
+ *checkpoint activations* in other frameworks).
236
+ """
237
+ if self._supports_gradient_checkpointing:
238
+ self.apply(partial(self._set_gradient_checkpointing, value=False))
239
+
240
+ def set_use_memory_efficient_attention_xformers(
241
+ self, valid: bool, attention_op: Optional[Callable] = None
242
+ ) -> None:
243
+ # Recursively walk through all the children.
244
+ # Any children which exposes the set_use_memory_efficient_attention_xformers method
245
+ # gets the message
246
+ def fn_recursive_set_mem_eff(module: torch.nn.Module):
247
+ if hasattr(module, "set_use_memory_efficient_attention_xformers"):
248
+ module.set_use_memory_efficient_attention_xformers(valid, attention_op)
249
+
250
+ for child in module.children():
251
+ fn_recursive_set_mem_eff(child)
252
+
253
+ for module in self.children():
254
+ if isinstance(module, torch.nn.Module):
255
+ fn_recursive_set_mem_eff(module)
256
+
257
+ def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
258
+ r"""
259
+ Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
260
+
261
+ When this option is enabled, you should observe lower GPU memory usage and a potential speed up during
262
+ inference. Speed up during training is not guaranteed.
263
+
264
+ <Tip warning={true}>
265
+
266
+ ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
267
+ precedent.
268
+
269
+ </Tip>
270
+
271
+ Parameters:
272
+ attention_op (`Callable`, *optional*):
273
+ Override the default `None` operator for use as `op` argument to the
274
+ [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
275
+ function of xFormers.
276
+
277
+ Examples:
278
+
279
+ ```py
280
+ >>> import torch
281
+ >>> from diffusers import UNet2DConditionModel
282
+ >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
283
+
284
+ >>> model = UNet2DConditionModel.from_pretrained(
285
+ ... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
286
+ ... )
287
+ >>> model = model.to("cuda")
288
+ >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
289
+ ```
290
+ """
291
+ self.set_use_memory_efficient_attention_xformers(True, attention_op)
292
+
293
+ def disable_xformers_memory_efficient_attention(self):
294
+ r"""
295
+ Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
296
+ """
297
+ self.set_use_memory_efficient_attention_xformers(False)
298
+
299
+ def add_adapter(self, adapter_config, adapter_name: str = "default") -> None:
300
+ r"""
301
+ Adds a new adapter to the current model for training. If no adapter name is passed, a default name is assigned
302
+ to the adapter to follow the convention of the PEFT library.
303
+
304
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
305
+ [documentation](https://huggingface.co/docs/peft).
306
+
307
+ Args:
308
+ adapter_config (`[~peft.PeftConfig]`):
309
+ The configuration of the adapter to add; supported adapters are non-prefix tuning and adaption prompt
310
+ methods.
311
+ adapter_name (`str`, *optional*, defaults to `"default"`):
312
+ The name of the adapter to add. If no name is passed, a default name is assigned to the adapter.
313
+ """
314
+ check_peft_version(min_version=MIN_PEFT_VERSION)
315
+
316
+ from peft import PeftConfig, inject_adapter_in_model
317
+
318
+ if not self._hf_peft_config_loaded:
319
+ self._hf_peft_config_loaded = True
320
+ elif adapter_name in self.peft_config:
321
+ raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")
322
+
323
+ if not isinstance(adapter_config, PeftConfig):
324
+ raise ValueError(
325
+ f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead."
326
+ )
327
+
328
+ # Unlike transformers, here we don't need to retrieve the name_or_path of the unet as the loading logic is
329
+ # handled by the `load_lora_layers` or `LoraLoaderMixin`. Therefore we set it to `None` here.
330
+ adapter_config.base_model_name_or_path = None
331
+ inject_adapter_in_model(adapter_config, self, adapter_name)
332
+ self.set_adapter(adapter_name)
333
+
334
+ def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
335
+ """
336
+ Sets a specific adapter by forcing the model to only use that adapter and disables the other adapters.
337
+
338
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
339
+ official documentation: https://huggingface.co/docs/peft
340
+
341
+ Args:
342
+ adapter_name (Union[str, List[str]])):
343
+ The list of adapters to set or the adapter name in case of single adapter.
344
+ """
345
+ check_peft_version(min_version=MIN_PEFT_VERSION)
346
+
347
+ if not self._hf_peft_config_loaded:
348
+ raise ValueError("No adapter loaded. Please load an adapter first.")
349
+
350
+ if isinstance(adapter_name, str):
351
+ adapter_name = [adapter_name]
352
+
353
+ missing = set(adapter_name) - set(self.peft_config)
354
+ if len(missing) > 0:
355
+ raise ValueError(
356
+ f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)."
357
+ f" current loaded adapters are: {list(self.peft_config.keys())}"
358
+ )
359
+
360
+ from peft.tuners.tuners_utils import BaseTunerLayer
361
+
362
+ _adapters_has_been_set = False
363
+
364
+ for _, module in self.named_modules():
365
+ if isinstance(module, BaseTunerLayer):
366
+ if hasattr(module, "set_adapter"):
367
+ module.set_adapter(adapter_name)
368
+ # Previous versions of PEFT does not support multi-adapter inference
369
+ elif not hasattr(module, "set_adapter") and len(adapter_name) != 1:
370
+ raise ValueError(
371
+ "You are trying to set multiple adapters and you have a PEFT version that does not support multi-adapter inference. Please upgrade to the latest version of PEFT."
372
+ " `pip install -U peft` or `pip install -U git+https://github.com/huggingface/peft.git`"
373
+ )
374
+ else:
375
+ module.active_adapter = adapter_name
376
+ _adapters_has_been_set = True
377
+
378
+ if not _adapters_has_been_set:
379
+ raise ValueError(
380
+ "Did not succeeded in setting the adapter. Please make sure you are using a model that supports adapters."
381
+ )
382
+
383
+ def disable_adapters(self) -> None:
384
+ r"""
385
+ Disable all adapters attached to the model and fallback to inference with the base model only.
386
+
387
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
388
+ official documentation: https://huggingface.co/docs/peft
389
+ """
390
+ check_peft_version(min_version=MIN_PEFT_VERSION)
391
+
392
+ if not self._hf_peft_config_loaded:
393
+ raise ValueError("No adapter loaded. Please load an adapter first.")
394
+
395
+ from peft.tuners.tuners_utils import BaseTunerLayer
396
+
397
+ for _, module in self.named_modules():
398
+ if isinstance(module, BaseTunerLayer):
399
+ if hasattr(module, "enable_adapters"):
400
+ module.enable_adapters(enabled=False)
401
+ else:
402
+ # support for older PEFT versions
403
+ module.disable_adapters = True
404
+
405
+ def enable_adapters(self) -> None:
406
+ """
407
+ Enable adapters that are attached to the model. The model will use `self.active_adapters()` to retrieve the
408
+ list of adapters to enable.
409
+
410
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
411
+ official documentation: https://huggingface.co/docs/peft
412
+ """
413
+ check_peft_version(min_version=MIN_PEFT_VERSION)
414
+
415
+ if not self._hf_peft_config_loaded:
416
+ raise ValueError("No adapter loaded. Please load an adapter first.")
417
+
418
+ from peft.tuners.tuners_utils import BaseTunerLayer
419
+
420
+ for _, module in self.named_modules():
421
+ if isinstance(module, BaseTunerLayer):
422
+ if hasattr(module, "enable_adapters"):
423
+ module.enable_adapters(enabled=True)
424
+ else:
425
+ # support for older PEFT versions
426
+ module.disable_adapters = False
427
+
428
+ def active_adapters(self) -> List[str]:
429
+ """
430
+ Gets the current list of active adapters of the model.
431
+
432
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
433
+ official documentation: https://huggingface.co/docs/peft
434
+ """
435
+ check_peft_version(min_version=MIN_PEFT_VERSION)
436
+
437
+ if not self._hf_peft_config_loaded:
438
+ raise ValueError("No adapter loaded. Please load an adapter first.")
439
+
440
+ from peft.tuners.tuners_utils import BaseTunerLayer
441
+
442
+ for _, module in self.named_modules():
443
+ if isinstance(module, BaseTunerLayer):
444
+ return module.active_adapter
445
+
446
+ def save_pretrained(
447
+ self,
448
+ save_directory: Union[str, os.PathLike],
449
+ is_main_process: bool = True,
450
+ save_function: Callable = None,
451
+ safe_serialization: bool = True,
452
+ variant: Optional[str] = None,
453
+ push_to_hub: bool = False,
454
+ **kwargs,
455
+ ):
456
+ """
457
+ Save a model and its configuration file to a directory so that it can be reloaded using the
458
+ [`~models.ModelMixin.from_pretrained`] class method.
459
+
460
+ Arguments:
461
+ save_directory (`str` or `os.PathLike`):
462
+ Directory to save a model and its configuration file to. Will be created if it doesn't exist.
463
+ is_main_process (`bool`, *optional*, defaults to `True`):
464
+ Whether the process calling this is the main process or not. Useful during distributed training and you
465
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
466
+ process to avoid race conditions.
467
+ save_function (`Callable`):
468
+ The function to use to save the state dictionary. Useful during distributed training when you need to
469
+ replace `torch.save` with another method. Can be configured with the environment variable
470
+ `DIFFUSERS_SAVE_MODE`.
471
+ safe_serialization (`bool`, *optional*, defaults to `True`):
472
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
473
+ variant (`str`, *optional*):
474
+ If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
475
+ push_to_hub (`bool`, *optional*, defaults to `False`):
476
+ Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
477
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
478
+ namespace).
479
+ kwargs (`Dict[str, Any]`, *optional*):
480
+ Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
481
+ """
482
+ if os.path.isfile(save_directory):
483
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
484
+ return
485
+
486
+ os.makedirs(save_directory, exist_ok=True)
487
+
488
+ if push_to_hub:
489
+ commit_message = kwargs.pop("commit_message", None)
490
+ private = kwargs.pop("private", False)
491
+ create_pr = kwargs.pop("create_pr", False)
492
+ token = kwargs.pop("token", None)
493
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
494
+ repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
495
+
496
+ # Only save the model itself if we are using distributed training
497
+ model_to_save = self
498
+
499
+ # Attach architecture to the config
500
+ # Save the config
501
+ if is_main_process:
502
+ model_to_save.save_config(save_directory)
503
+
504
+ # Save the model
505
+ state_dict = model_to_save.state_dict()
506
+
507
+ weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
508
+ weights_name = _add_variant(weights_name, variant)
509
+
510
+ # Save the model
511
+ if safe_serialization:
512
+ safetensors.torch.save_file(
513
+ state_dict, os.path.join(save_directory, weights_name), metadata={"format": "pt"}
514
+ )
515
+ else:
516
+ torch.save(state_dict, os.path.join(save_directory, weights_name))
517
+
518
+ logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
519
+
520
+ if push_to_hub:
521
+ self._upload_folder(
522
+ save_directory,
523
+ repo_id,
524
+ token=token,
525
+ commit_message=commit_message,
526
+ create_pr=create_pr,
527
+ )
528
+
529
+ @classmethod
530
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
531
+ r"""
532
+ Instantiate a pretrained PyTorch model from a pretrained model configuration.
533
+
534
+ The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
535
+ train the model, set it back in training mode with `model.train()`.
536
+
537
+ Parameters:
538
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
539
+ Can be either:
540
+
541
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
542
+ the Hub.
543
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
544
+ with [`~ModelMixin.save_pretrained`].
545
+
546
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
547
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
548
+ is not used.
549
+ torch_dtype (`str` or `torch.dtype`, *optional*):
550
+ Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
551
+ dtype is automatically derived from the model's weights.
552
+ force_download (`bool`, *optional*, defaults to `False`):
553
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
554
+ cached versions if they exist.
555
+ resume_download (`bool`, *optional*, defaults to `False`):
556
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
557
+ incompletely downloaded files are deleted.
558
+ proxies (`Dict[str, str]`, *optional*):
559
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
560
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
561
+ output_loading_info (`bool`, *optional*, defaults to `False`):
562
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
563
+ local_files_only(`bool`, *optional*, defaults to `False`):
564
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
565
+ won't be downloaded from the Hub.
566
+ use_auth_token (`str` or *bool*, *optional*):
567
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
568
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
569
+ revision (`str`, *optional*, defaults to `"main"`):
570
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
571
+ allowed by Git.
572
+ from_flax (`bool`, *optional*, defaults to `False`):
573
+ Load the model weights from a Flax checkpoint save file.
574
+ subfolder (`str`, *optional*, defaults to `""`):
575
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
576
+ mirror (`str`, *optional*):
577
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
578
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
579
+ information.
580
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
581
+ A map that specifies where each submodule should go. It doesn't need to be defined for each
582
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
583
+ same device.
584
+
585
+ Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
586
+ more information about each option see [designing a device
587
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
588
+ max_memory (`Dict`, *optional*):
589
+ A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
590
+ each GPU and the available CPU RAM if unset.
591
+ offload_folder (`str` or `os.PathLike`, *optional*):
592
+ The path to offload weights if `device_map` contains the value `"disk"`.
593
+ offload_state_dict (`bool`, *optional*):
594
+ If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
595
+ the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
596
+ when there is some disk offload.
597
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
598
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
599
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
600
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
601
+ argument to `True` will raise an error.
602
+ variant (`str`, *optional*):
603
+ Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
604
+ loading `from_flax`.
605
+ use_safetensors (`bool`, *optional*, defaults to `None`):
606
+ If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
607
+ `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
608
+ weights. If set to `False`, `safetensors` weights are not loaded.
609
+
610
+ <Tip>
611
+
612
+ To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
613
+ `huggingface-cli login`. You can also activate the special
614
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
615
+ firewalled environment.
616
+
617
+ </Tip>
618
+
619
+ Example:
620
+
621
+ ```py
622
+ from diffusers import UNet2DConditionModel
623
+
624
+ unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
625
+ ```
626
+
627
+ If you get the error message below, you need to finetune the weights for your downstream task:
628
+
629
+ ```bash
630
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
631
+ - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
632
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
633
+ ```
634
+ """
635
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
636
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
637
+ force_download = kwargs.pop("force_download", False)
638
+ from_flax = kwargs.pop("from_flax", False)
639
+ resume_download = kwargs.pop("resume_download", False)
640
+ proxies = kwargs.pop("proxies", None)
641
+ output_loading_info = kwargs.pop("output_loading_info", False)
642
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
643
+ use_auth_token = kwargs.pop("use_auth_token", None)
644
+ revision = kwargs.pop("revision", None)
645
+ torch_dtype = kwargs.pop("torch_dtype", None)
646
+ subfolder = kwargs.pop("subfolder", None)
647
+ device_map = kwargs.pop("device_map", None)
648
+ max_memory = kwargs.pop("max_memory", None)
649
+ offload_folder = kwargs.pop("offload_folder", None)
650
+ offload_state_dict = kwargs.pop("offload_state_dict", False)
651
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
652
+ variant = kwargs.pop("variant", None)
653
+ use_safetensors = kwargs.pop("use_safetensors", None)
654
+
655
+ allow_pickle = False
656
+ if use_safetensors is None:
657
+ use_safetensors = True
658
+ allow_pickle = True
659
+
660
+ if low_cpu_mem_usage and not is_accelerate_available():
661
+ low_cpu_mem_usage = False
662
+ logger.warning(
663
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
664
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
665
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
666
+ " install accelerate\n```\n."
667
+ )
668
+
669
+ if device_map is not None and not is_accelerate_available():
670
+ raise NotImplementedError(
671
+ "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
672
+ " `device_map=None`. You can install accelerate with `pip install accelerate`."
673
+ )
674
+
675
+ # Check if we can handle device_map and dispatching the weights
676
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
677
+ raise NotImplementedError(
678
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
679
+ " `device_map=None`."
680
+ )
681
+
682
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
683
+ raise NotImplementedError(
684
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
685
+ " `low_cpu_mem_usage=False`."
686
+ )
687
+
688
+ if low_cpu_mem_usage is False and device_map is not None:
689
+ raise ValueError(
690
+ f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
691
+ " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
692
+ )
693
+
694
+ # Load config if we don't provide a configuration
695
+ config_path = pretrained_model_name_or_path
696
+
697
+ user_agent = {
698
+ "diffusers": __version__,
699
+ "file_type": "model",
700
+ "framework": "pytorch",
701
+ }
702
+
703
+ # load config
704
+ config, unused_kwargs, commit_hash = cls.load_config(
705
+ config_path,
706
+ cache_dir=cache_dir,
707
+ return_unused_kwargs=True,
708
+ return_commit_hash=True,
709
+ force_download=force_download,
710
+ resume_download=resume_download,
711
+ proxies=proxies,
712
+ local_files_only=local_files_only,
713
+ use_auth_token=use_auth_token,
714
+ revision=revision,
715
+ subfolder=subfolder,
716
+ device_map=device_map,
717
+ max_memory=max_memory,
718
+ offload_folder=offload_folder,
719
+ offload_state_dict=offload_state_dict,
720
+ user_agent=user_agent,
721
+ **kwargs,
722
+ )
723
+
724
+ # load model
725
+ model_file = None
726
+ if from_flax:
727
+ model_file = _get_model_file(
728
+ pretrained_model_name_or_path,
729
+ weights_name=FLAX_WEIGHTS_NAME,
730
+ cache_dir=cache_dir,
731
+ force_download=force_download,
732
+ resume_download=resume_download,
733
+ proxies=proxies,
734
+ local_files_only=local_files_only,
735
+ use_auth_token=use_auth_token,
736
+ revision=revision,
737
+ subfolder=subfolder,
738
+ user_agent=user_agent,
739
+ commit_hash=commit_hash,
740
+ )
741
+ model = cls.from_config(config, **unused_kwargs)
742
+
743
+ # Convert the weights
744
+ from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
745
+
746
+ model = load_flax_checkpoint_in_pytorch_model(model, model_file)
747
+ else:
748
+ if use_safetensors:
749
+ try:
750
+ model_file = _get_model_file(
751
+ pretrained_model_name_or_path,
752
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
753
+ cache_dir=cache_dir,
754
+ force_download=force_download,
755
+ resume_download=resume_download,
756
+ proxies=proxies,
757
+ local_files_only=local_files_only,
758
+ use_auth_token=use_auth_token,
759
+ revision=revision,
760
+ subfolder=subfolder,
761
+ user_agent=user_agent,
762
+ commit_hash=commit_hash,
763
+ )
764
+ except IOError as e:
765
+ if not allow_pickle:
766
+ raise e
767
+ pass
768
+ if model_file is None:
769
+ model_file = _get_model_file(
770
+ pretrained_model_name_or_path,
771
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
772
+ cache_dir=cache_dir,
773
+ force_download=force_download,
774
+ resume_download=resume_download,
775
+ proxies=proxies,
776
+ local_files_only=local_files_only,
777
+ use_auth_token=use_auth_token,
778
+ revision=revision,
779
+ subfolder=subfolder,
780
+ user_agent=user_agent,
781
+ commit_hash=commit_hash,
782
+ )
783
+
784
+ if low_cpu_mem_usage:
785
+ # Instantiate model with empty weights
786
+ with accelerate.init_empty_weights():
787
+ model = cls.from_config(config, **unused_kwargs)
788
+
789
+ # if device_map is None, load the state dict and move the params from meta device to the cpu
790
+ if device_map is None:
791
+ param_device = "cpu"
792
+ state_dict = load_state_dict(model_file, variant=variant)
793
+ model._convert_deprecated_attention_blocks(state_dict)
794
+ # move the params from meta device to cpu
795
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
796
+ # if len(missing_keys) > 0:
797
+ # raise ValueError(
798
+ # f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
799
+ # f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
800
+ # " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
801
+ # " those weights or else make sure your checkpoint file is correct."
802
+ # )
803
+
804
+ unexpected_keys = load_model_dict_into_meta(
805
+ model,
806
+ state_dict,
807
+ device=param_device,
808
+ dtype=torch_dtype,
809
+ model_name_or_path=pretrained_model_name_or_path,
810
+ )
811
+
812
+ if cls._keys_to_ignore_on_load_unexpected is not None:
813
+ for pat in cls._keys_to_ignore_on_load_unexpected:
814
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
815
+
816
+ if len(unexpected_keys) > 0:
817
+ logger.warn(
818
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
819
+ )
820
+
821
+ else: # else let accelerate handle loading and dispatching.
822
+ # Load weights and dispatch according to the device_map
823
+ # by default the device_map is None and the weights are loaded on the CPU
824
+ try:
825
+ accelerate.load_checkpoint_and_dispatch(
826
+ model,
827
+ model_file,
828
+ device_map,
829
+ max_memory=max_memory,
830
+ offload_folder=offload_folder,
831
+ offload_state_dict=offload_state_dict,
832
+ dtype=torch_dtype,
833
+ )
834
+ except AttributeError as e:
835
+ # When using accelerate loading, we do not have the ability to load the state
836
+ # dict and rename the weight names manually. Additionally, accelerate skips
837
+ # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
838
+ # (which look like they should be private variables?), so we can't use the standard hooks
839
+ # to rename parameters on load. We need to mimic the original weight names so the correct
840
+ # attributes are available. After we have loaded the weights, we convert the deprecated
841
+ # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
842
+ # the weights so we don't have to do this again.
843
+
844
+ if "'Attention' object has no attribute" in str(e):
845
+ logger.warn(
846
+ f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
847
+ " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
848
+ " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
849
+ " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
850
+ " please also re-upload it or open a PR on the original repository."
851
+ )
852
+ model._temp_convert_self_to_deprecated_attention_blocks()
853
+ accelerate.load_checkpoint_and_dispatch(
854
+ model,
855
+ model_file,
856
+ device_map,
857
+ max_memory=max_memory,
858
+ offload_folder=offload_folder,
859
+ offload_state_dict=offload_state_dict,
860
+ dtype=torch_dtype,
861
+ )
862
+ model._undo_temp_convert_self_to_deprecated_attention_blocks()
863
+ else:
864
+ raise e
865
+
866
+ loading_info = {
867
+ "missing_keys": [],
868
+ "unexpected_keys": [],
869
+ "mismatched_keys": [],
870
+ "error_msgs": [],
871
+ }
872
+ else:
873
+ model = cls.from_config(config, **unused_kwargs)
874
+
875
+ state_dict = load_state_dict(model_file, variant=variant)
876
+ model._convert_deprecated_attention_blocks(state_dict)
877
+
878
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
879
+ model,
880
+ state_dict,
881
+ model_file,
882
+ pretrained_model_name_or_path,
883
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
884
+ )
885
+
886
+ loading_info = {
887
+ "missing_keys": missing_keys,
888
+ "unexpected_keys": unexpected_keys,
889
+ "mismatched_keys": mismatched_keys,
890
+ "error_msgs": error_msgs,
891
+ }
892
+
893
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
894
+ raise ValueError(
895
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
896
+ )
897
+ elif torch_dtype is not None:
898
+ model = model.to(torch_dtype)
899
+
900
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
901
+
902
+ # Set model in evaluation mode to deactivate DropOut modules by default
903
+ model.eval()
904
+ if output_loading_info:
905
+ return model, loading_info
906
+
907
+ return model
908
+
909
+ @classmethod
910
+ def _load_pretrained_model(
911
+ cls,
912
+ model,
913
+ state_dict,
914
+ resolved_archive_file,
915
+ pretrained_model_name_or_path,
916
+ ignore_mismatched_sizes=False,
917
+ ):
918
+ # Retrieve missing & unexpected_keys
919
+ model_state_dict = model.state_dict()
920
+ loaded_keys = list(state_dict.keys())
921
+
922
+ expected_keys = list(model_state_dict.keys())
923
+
924
+ original_loaded_keys = loaded_keys
925
+
926
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
927
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
928
+
929
+ # Make sure we are able to load base models as well as derived models (with heads)
930
+ model_to_load = model
931
+
932
+ def _find_mismatched_keys(
933
+ state_dict,
934
+ model_state_dict,
935
+ loaded_keys,
936
+ ignore_mismatched_sizes,
937
+ ):
938
+ mismatched_keys = []
939
+ if ignore_mismatched_sizes:
940
+ for checkpoint_key in loaded_keys:
941
+ model_key = checkpoint_key
942
+
943
+ if (
944
+ model_key in model_state_dict
945
+ and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
946
+ ):
947
+ mismatched_keys.append(
948
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
949
+ )
950
+ del state_dict[checkpoint_key]
951
+ return mismatched_keys
952
+
953
+ if state_dict is not None:
954
+ # Whole checkpoint
955
+ mismatched_keys = _find_mismatched_keys(
956
+ state_dict,
957
+ model_state_dict,
958
+ original_loaded_keys,
959
+ ignore_mismatched_sizes,
960
+ )
961
+ error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
962
+
963
+ if len(error_msgs) > 0:
964
+ error_msg = "\n\t".join(error_msgs)
965
+ if "size mismatch" in error_msg:
966
+ error_msg += (
967
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
968
+ )
969
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
970
+
971
+ if len(unexpected_keys) > 0:
972
+ logger.warning(
973
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
974
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
975
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
976
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
977
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
978
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
979
+ " identical (initializing a BertForSequenceClassification model from a"
980
+ " BertForSequenceClassification model)."
981
+ )
982
+ else:
983
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
984
+ if len(missing_keys) > 0:
985
+ logger.warning(
986
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
987
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
988
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
989
+ )
990
+ elif len(mismatched_keys) == 0:
991
+ logger.info(
992
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
993
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
994
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
995
+ " without further training."
996
+ )
997
+ if len(mismatched_keys) > 0:
998
+ mismatched_warning = "\n".join(
999
+ [
1000
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
1001
+ for key, shape1, shape2 in mismatched_keys
1002
+ ]
1003
+ )
1004
+ logger.warning(
1005
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
1006
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
1007
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
1008
+ " able to use it for predictions and inference."
1009
+ )
1010
+
1011
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
1012
+
1013
+ @property
1014
+ def device(self) -> device:
1015
+ """
1016
+ `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
1017
+ device).
1018
+ """
1019
+ return get_parameter_device(self)
1020
+
1021
+ @property
1022
+ def dtype(self) -> torch.dtype:
1023
+ """
1024
+ `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
1025
+ """
1026
+ return get_parameter_dtype(self)
1027
+
1028
+ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
1029
+ """
1030
+ Get number of (trainable or non-embedding) parameters in the module.
1031
+
1032
+ Args:
1033
+ only_trainable (`bool`, *optional*, defaults to `False`):
1034
+ Whether or not to return only the number of trainable parameters.
1035
+ exclude_embeddings (`bool`, *optional*, defaults to `False`):
1036
+ Whether or not to return only the number of non-embedding parameters.
1037
+
1038
+ Returns:
1039
+ `int`: The number of parameters.
1040
+
1041
+ Example:
1042
+
1043
+ ```py
1044
+ from diffusers import UNet2DConditionModel
1045
+
1046
+ model_id = "runwayml/stable-diffusion-v1-5"
1047
+ unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
1048
+ unet.num_parameters(only_trainable=True)
1049
+ 859520964
1050
+ ```
1051
+ """
1052
+
1053
+ if exclude_embeddings:
1054
+ embedding_param_names = [
1055
+ f"{name}.weight"
1056
+ for name, module_type in self.named_modules()
1057
+ if isinstance(module_type, torch.nn.Embedding)
1058
+ ]
1059
+ non_embedding_parameters = [
1060
+ parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
1061
+ ]
1062
+ return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
1063
+ else:
1064
+ return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
1065
+
1066
+ def _convert_deprecated_attention_blocks(self, state_dict):
1067
+ deprecated_attention_block_paths = []
1068
+
1069
+ def recursive_find_attn_block(name, module):
1070
+ if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
1071
+ deprecated_attention_block_paths.append(name)
1072
+
1073
+ for sub_name, sub_module in module.named_children():
1074
+ sub_name = sub_name if name == "" else f"{name}.{sub_name}"
1075
+ recursive_find_attn_block(sub_name, sub_module)
1076
+
1077
+ recursive_find_attn_block("", self)
1078
+
1079
+ # NOTE: we have to check if the deprecated parameters are in the state dict
1080
+ # because it is possible we are loading from a state dict that was already
1081
+ # converted
1082
+
1083
+ for path in deprecated_attention_block_paths:
1084
+ # group_norm path stays the same
1085
+
1086
+ # query -> to_q
1087
+ if f"{path}.query.weight" in state_dict:
1088
+ state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
1089
+ if f"{path}.query.bias" in state_dict:
1090
+ state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")
1091
+
1092
+ # key -> to_k
1093
+ if f"{path}.key.weight" in state_dict:
1094
+ state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
1095
+ if f"{path}.key.bias" in state_dict:
1096
+ state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")
1097
+
1098
+ # value -> to_v
1099
+ if f"{path}.value.weight" in state_dict:
1100
+ state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
1101
+ if f"{path}.value.bias" in state_dict:
1102
+ state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")
1103
+
1104
+ # proj_attn -> to_out.0
1105
+ if f"{path}.proj_attn.weight" in state_dict:
1106
+ state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
1107
+ if f"{path}.proj_attn.bias" in state_dict:
1108
+ state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
1109
+
1110
+ def _temp_convert_self_to_deprecated_attention_blocks(self):
1111
+ deprecated_attention_block_modules = []
1112
+
1113
+ def recursive_find_attn_block(module):
1114
+ if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
1115
+ deprecated_attention_block_modules.append(module)
1116
+
1117
+ for sub_module in module.children():
1118
+ recursive_find_attn_block(sub_module)
1119
+
1120
+ recursive_find_attn_block(self)
1121
+
1122
+ for module in deprecated_attention_block_modules:
1123
+ module.query = module.to_q
1124
+ module.key = module.to_k
1125
+ module.value = module.to_v
1126
+ module.proj_attn = module.to_out[0]
1127
+
1128
+ # We don't _have_ to delete the old attributes, but it's helpful to ensure
1129
+ # that _all_ the weights are loaded into the new attributes and we're not
1130
+ # making an incorrect assumption that this model should be converted when
1131
+ # it really shouldn't be.
1132
+ del module.to_q
1133
+ del module.to_k
1134
+ del module.to_v
1135
+ del module.to_out
1136
+
1137
+ def _undo_temp_convert_self_to_deprecated_attention_blocks(self):
1138
+ deprecated_attention_block_modules = []
1139
+
1140
+ def recursive_find_attn_block(module):
1141
+ if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
1142
+ deprecated_attention_block_modules.append(module)
1143
+
1144
+ for sub_module in module.children():
1145
+ recursive_find_attn_block(sub_module)
1146
+
1147
+ recursive_find_attn_block(self)
1148
+
1149
+ for module in deprecated_attention_block_modules:
1150
+ module.to_q = module.query
1151
+ module.to_k = module.key
1152
+ module.to_v = module.value
1153
+ module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
1154
+
1155
+ del module.query
1156
+ del module.key
1157
+ del module.value
1158
+ del module.proj_attn
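For reference, a minimal usage sketch of the loading path implemented above. It assumes the `diffusers` package from this upload is importable and that the hub repo id (borrowed from the `num_parameters` docstring example) is reachable; nothing in the sketch is part of the upload itself.

```py
# Minimal sketch of the `from_pretrained` loading path (illustrative assumptions:
# the hub repo id below exists and this upload is importable as `diffusers`).
import torch
from diffusers import UNet2DConditionModel

unet, loading_info = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="unet",
    torch_dtype=torch.float16,   # cast after loading, as done at the end of from_pretrained
    low_cpu_mem_usage=False,     # take the plain `_load_pretrained_model` path so loading_info is populated
    output_loading_info=True,    # return the `loading_info` dict built above
)
print(loading_info["missing_keys"], loading_info["unexpected_keys"])
print(unet.device, unet.dtype, unet.num_parameters(only_trainable=False))
```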
diffusers/models/normalization.py ADDED
@@ -0,0 +1,148 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 HuggingFace Inc.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Dict, Optional, Tuple
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+ from .activations import get_activation
23
+ from .embeddings import CombinedTimestepLabelEmbeddings, CombinedTimestepSizeEmbeddings
24
+
25
+
26
+ class AdaLayerNorm(nn.Module):
27
+ r"""
28
+ Norm layer modified to incorporate timestep embeddings.
29
+
30
+ Parameters:
31
+ embedding_dim (`int`): The size of each embedding vector.
32
+ num_embeddings (`int`): The size of the embeddings dictionary.
33
+ """
34
+
35
+ def __init__(self, embedding_dim: int, num_embeddings: int):
36
+ super().__init__()
37
+ self.emb = nn.Embedding(num_embeddings, embedding_dim)
38
+ self.silu = nn.SiLU()
39
+ self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
40
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
41
+
42
+ def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
43
+ emb = self.linear(self.silu(self.emb(timestep)))
44
+ scale, shift = torch.chunk(emb, 2)
45
+ x = self.norm(x) * (1 + scale) + shift
46
+ return x
47
+
48
+
49
+ class AdaLayerNormZero(nn.Module):
50
+ r"""
51
+ Norm layer adaptive layer norm zero (adaLN-Zero).
52
+
53
+ Parameters:
54
+ embedding_dim (`int`): The size of each embedding vector.
55
+ num_embeddings (`int`): The size of the embeddings dictionary.
56
+ """
57
+
58
+ def __init__(self, embedding_dim: int, num_embeddings: int):
59
+ super().__init__()
60
+
61
+ self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
62
+
63
+ self.silu = nn.SiLU()
64
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
65
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
66
+
67
+ def forward(
68
+ self,
69
+ x: torch.Tensor,
70
+ timestep: torch.Tensor,
71
+ class_labels: torch.LongTensor,
72
+ hidden_dtype: Optional[torch.dtype] = None,
73
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
74
+ emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
75
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
76
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
77
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
78
+
79
+
80
+ class AdaLayerNormSingle(nn.Module):
81
+ r"""
82
+ Norm layer adaptive layer norm single (adaLN-single).
83
+
84
+ As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
85
+
86
+ Parameters:
87
+ embedding_dim (`int`): The size of each embedding vector.
88
+ use_additional_conditions (`bool`): To use additional conditions for normalization or not.
89
+ """
90
+
91
+ def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
92
+ super().__init__()
93
+
94
+ self.emb = CombinedTimestepSizeEmbeddings(
95
+ embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
96
+ )
97
+
98
+ self.silu = nn.SiLU()
99
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
100
+
101
+ def forward(
102
+ self,
103
+ timestep: torch.Tensor,
104
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
105
+ batch_size: int = None,
106
+ hidden_dtype: Optional[torch.dtype] = None,
107
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
108
+ # No modulation happening here.
109
+ embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
110
+ return self.linear(self.silu(embedded_timestep)), embedded_timestep
111
+
112
+
113
+ class AdaGroupNorm(nn.Module):
114
+ r"""
115
+ GroupNorm layer modified to incorporate timestep embeddings.
116
+
117
+ Parameters:
118
+ embedding_dim (`int`): The size of each embedding vector.
119
+ num_embeddings (`int`): The size of the embeddings dictionary.
120
+ num_groups (`int`): The number of groups to separate the channels into.
121
+ act_fn (`str`, *optional*, defaults to `None`): The activation function to use.
122
+ eps (`float`, *optional*, defaults to `1e-5`): The epsilon value to use for numerical stability.
123
+ """
124
+
125
+ def __init__(
126
+ self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
127
+ ):
128
+ super().__init__()
129
+ self.num_groups = num_groups
130
+ self.eps = eps
131
+
132
+ if act_fn is None:
133
+ self.act = None
134
+ else:
135
+ self.act = get_activation(act_fn)
136
+
137
+ self.linear = nn.Linear(embedding_dim, out_dim * 2)
138
+
139
+ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
140
+ if self.act:
141
+ emb = self.act(emb)
142
+ emb = self.linear(emb)
143
+ emb = emb[:, :, None, None]
144
+ scale, shift = emb.chunk(2, dim=1)
145
+
146
+ x = F.group_norm(x, self.num_groups, eps=self.eps)
147
+ x = x * (1 + scale) + shift
148
+ return x
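A minimal shape-check sketch for the layers defined in this file; the tensor sizes below are arbitrary illustrative assumptions.

```py
# Shape checks for AdaGroupNorm / AdaLayerNorm (sizes are illustrative assumptions).
import torch
from diffusers.models.normalization import AdaGroupNorm, AdaLayerNorm

x = torch.randn(2, 64, 16, 16)         # (batch, channels, height, width)
emb = torch.randn(2, 128)              # conditioning embedding
norm = AdaGroupNorm(embedding_dim=128, out_dim=64, num_groups=8, act_fn="silu")
print(norm(x, emb).shape)              # torch.Size([2, 64, 16, 16])

ada_ln = AdaLayerNorm(embedding_dim=32, num_embeddings=1000)
tokens = torch.randn(2, 10, 32)        # (batch, sequence, embedding_dim)
timestep = torch.tensor(3)             # a single (scalar) timestep index
print(ada_ln(tokens, timestep).shape)  # torch.Size([2, 10, 32])
```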
diffusers/models/prior_transformer.py ADDED
@@ -0,0 +1,382 @@
1
+ from dataclasses import dataclass
2
+ from typing import Dict, Optional, Union
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+
8
+ from ..configuration_utils import ConfigMixin, register_to_config
9
+ from ..loaders import UNet2DConditionLoadersMixin
10
+ from ..utils import BaseOutput
11
+ from .attention import BasicTransformerBlock
12
+ from .attention_processor import (
13
+ ADDED_KV_ATTENTION_PROCESSORS,
14
+ CROSS_ATTENTION_PROCESSORS,
15
+ AttentionProcessor,
16
+ AttnAddedKVProcessor,
17
+ AttnProcessor,
18
+ )
19
+ from .embeddings import TimestepEmbedding, Timesteps
20
+ from .modeling_utils import ModelMixin
21
+
22
+
23
+ @dataclass
24
+ class PriorTransformerOutput(BaseOutput):
25
+ """
26
+ The output of [`PriorTransformer`].
27
+
28
+ Args:
29
+ predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
30
+ The predicted CLIP image embedding conditioned on the CLIP text embedding input.
31
+ """
32
+
33
+ predicted_image_embedding: torch.FloatTensor
34
+
35
+
36
+ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
37
+ """
38
+ A Prior Transformer model.
39
+
40
+ Parameters:
41
+ num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
42
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
43
+ num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
44
+ embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`
45
+ num_embeddings (`int`, *optional*, defaults to 77):
46
+ The number of embeddings of the model input `hidden_states`
47
+ additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
48
+ projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
49
+ additional_embeddings`.
50
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
51
+ time_embed_act_fn (`str`, *optional*, defaults to 'silu'):
52
+ The activation function to use to create timestep embeddings.
53
+ norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before
54
+ passing to Transformer blocks. Set it to `None` if normalization is not needed.
55
+ embedding_proj_norm_type (`str`, *optional*, defaults to None):
56
+ The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
57
+ needed.
58
+ encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
59
+ The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
60
+ `encoder_hidden_states` is `None`.
61
+ added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
62
+ Choose from `prd` or `None`. If `prd` is chosen, it will prepend a token indicating the (quantized) dot
63
+ product between the text embedding and image embedding as proposed in the unclip paper
64
+ https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended.
65
+ time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings.
66
+ If None, will be set to `num_attention_heads * attention_head_dim`
67
+ embedding_proj_dim (`int`, *optional*, default to None):
68
+ The dimension of `proj_embedding`. If None, will be set to `embedding_dim`.
69
+ clip_embed_dim (`int`, *optional*, default to None):
70
+ The dimension of the output. If None, will be set to `embedding_dim`.
71
+ """
72
+
73
+ @register_to_config
74
+ def __init__(
75
+ self,
76
+ num_attention_heads: int = 32,
77
+ attention_head_dim: int = 64,
78
+ num_layers: int = 20,
79
+ embedding_dim: int = 768,
80
+ num_embeddings=77,
81
+ additional_embeddings=4,
82
+ dropout: float = 0.0,
83
+ time_embed_act_fn: str = "silu",
84
+ norm_in_type: Optional[str] = None, # layer
85
+ embedding_proj_norm_type: Optional[str] = None, # layer
86
+ encoder_hid_proj_type: Optional[str] = "linear", # linear
87
+ added_emb_type: Optional[str] = "prd", # prd
88
+ time_embed_dim: Optional[int] = None,
89
+ embedding_proj_dim: Optional[int] = None,
90
+ clip_embed_dim: Optional[int] = None,
91
+ ):
92
+ super().__init__()
93
+ self.num_attention_heads = num_attention_heads
94
+ self.attention_head_dim = attention_head_dim
95
+ inner_dim = num_attention_heads * attention_head_dim
96
+ self.additional_embeddings = additional_embeddings
97
+
98
+ time_embed_dim = time_embed_dim or inner_dim
99
+ embedding_proj_dim = embedding_proj_dim or embedding_dim
100
+ clip_embed_dim = clip_embed_dim or embedding_dim
101
+
102
+ self.time_proj = Timesteps(inner_dim, True, 0)
103
+ self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn)
104
+
105
+ self.proj_in = nn.Linear(embedding_dim, inner_dim)
106
+
107
+ if embedding_proj_norm_type is None:
108
+ self.embedding_proj_norm = None
109
+ elif embedding_proj_norm_type == "layer":
110
+ self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim)
111
+ else:
112
+ raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}")
113
+
114
+ self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim)
115
+
116
+ if encoder_hid_proj_type is None:
117
+ self.encoder_hidden_states_proj = None
118
+ elif encoder_hid_proj_type == "linear":
119
+ self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
120
+ else:
121
+ raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}")
122
+
123
+ self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))
124
+
125
+ if added_emb_type == "prd":
126
+ self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
127
+ elif added_emb_type is None:
128
+ self.prd_embedding = None
129
+ else:
130
+ raise ValueError(
131
+ f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`."
132
+ )
133
+
134
+ self.transformer_blocks = nn.ModuleList(
135
+ [
136
+ BasicTransformerBlock(
137
+ inner_dim,
138
+ num_attention_heads,
139
+ attention_head_dim,
140
+ dropout=dropout,
141
+ activation_fn="gelu",
142
+ attention_bias=True,
143
+ )
144
+ for d in range(num_layers)
145
+ ]
146
+ )
147
+
148
+ if norm_in_type == "layer":
149
+ self.norm_in = nn.LayerNorm(inner_dim)
150
+ elif norm_in_type is None:
151
+ self.norm_in = None
152
+ else:
153
+ raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.")
154
+
155
+ self.norm_out = nn.LayerNorm(inner_dim)
156
+
157
+ self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim)
158
+
159
+ causal_attention_mask = torch.full(
160
+ [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
161
+ )
162
+ causal_attention_mask.triu_(1)
163
+ causal_attention_mask = causal_attention_mask[None, ...]
164
+ self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)
165
+
166
+ self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
167
+ self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))
168
+
169
+ @property
170
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
171
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
172
+ r"""
173
+ Returns:
174
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
175
+ indexed by its weight name.
176
+ """
177
+ # set recursively
178
+ processors = {}
179
+
180
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
181
+ if hasattr(module, "get_processor"):
182
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
183
+
184
+ for sub_name, child in module.named_children():
185
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
186
+
187
+ return processors
188
+
189
+ for name, module in self.named_children():
190
+ fn_recursive_add_processors(name, module, processors)
191
+
192
+ return processors
193
+
194
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
195
+ def set_attn_processor(
196
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
197
+ ):
198
+ r"""
199
+ Sets the attention processor to use to compute attention.
200
+
201
+ Parameters:
202
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
203
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
204
+ for **all** `Attention` layers.
205
+
206
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
207
+ processor. This is strongly recommended when setting trainable attention processors.
208
+
209
+ """
210
+ count = len(self.attn_processors.keys())
211
+
212
+ if isinstance(processor, dict) and len(processor) != count:
213
+ raise ValueError(
214
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
215
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
216
+ )
217
+
218
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
219
+ if hasattr(module, "set_processor"):
220
+ if not isinstance(processor, dict):
221
+ module.set_processor(processor, _remove_lora=_remove_lora)
222
+ else:
223
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
224
+
225
+ for sub_name, child in module.named_children():
226
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
227
+
228
+ for name, module in self.named_children():
229
+ fn_recursive_attn_processor(name, module, processor)
230
+
231
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
232
+ def set_default_attn_processor(self):
233
+ """
234
+ Disables custom attention processors and sets the default attention implementation.
235
+ """
236
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
237
+ processor = AttnAddedKVProcessor()
238
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
239
+ processor = AttnProcessor()
240
+ else:
241
+ raise ValueError(
242
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
243
+ )
244
+
245
+ self.set_attn_processor(processor, _remove_lora=True)
246
+
247
+ def forward(
248
+ self,
249
+ hidden_states,
250
+ timestep: Union[torch.Tensor, float, int],
251
+ proj_embedding: torch.FloatTensor,
252
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
253
+ attention_mask: Optional[torch.BoolTensor] = None,
254
+ return_dict: bool = True,
255
+ ):
256
+ """
257
+ The [`PriorTransformer`] forward method.
258
+
259
+ Args:
260
+ hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
261
+ The currently predicted image embeddings.
262
+ timestep (`torch.LongTensor`):
263
+ Current denoising step.
264
+ proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
265
+ Projected embedding vector the denoising process is conditioned on.
266
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
267
+ Hidden states of the text embeddings the denoising process is conditioned on.
268
+ attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
269
+ Text mask for the text embeddings.
270
+ return_dict (`bool`, *optional*, defaults to `True`):
271
+ Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain
272
+ tuple.
273
+
274
+ Returns:
275
+ [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
276
+ If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a
277
+ tuple is returned where the first element is the sample tensor.
278
+ """
279
+ batch_size = hidden_states.shape[0]
280
+
281
+ timesteps = timestep
282
+ if not torch.is_tensor(timesteps):
283
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
284
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
285
+ timesteps = timesteps[None].to(hidden_states.device)
286
+
287
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
288
+ timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)
289
+
290
+ timesteps_projected = self.time_proj(timesteps)
291
+
292
+ # timesteps does not contain any weights and will always return f32 tensors
293
+ # but time_embedding might be fp16, so we need to cast here.
294
+ timesteps_projected = timesteps_projected.to(dtype=self.dtype)
295
+ time_embeddings = self.time_embedding(timesteps_projected)
296
+
297
+ if self.embedding_proj_norm is not None:
298
+ proj_embedding = self.embedding_proj_norm(proj_embedding)
299
+
300
+ proj_embeddings = self.embedding_proj(proj_embedding)
301
+ if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
302
+ encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
303
+ elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
304
+ raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")
305
+
306
+ hidden_states = self.proj_in(hidden_states)
307
+
308
+ positional_embeddings = self.positional_embedding.to(hidden_states.dtype)
309
+
310
+ additional_embeds = []
311
+ additional_embeddings_len = 0
312
+
313
+ if encoder_hidden_states is not None:
314
+ additional_embeds.append(encoder_hidden_states)
315
+ additional_embeddings_len += encoder_hidden_states.shape[1]
316
+
317
+ if len(proj_embeddings.shape) == 2:
318
+ proj_embeddings = proj_embeddings[:, None, :]
319
+
320
+ if len(hidden_states.shape) == 2:
321
+ hidden_states = hidden_states[:, None, :]
322
+
323
+ additional_embeds = additional_embeds + [
324
+ proj_embeddings,
325
+ time_embeddings[:, None, :],
326
+ hidden_states,
327
+ ]
328
+
329
+ if self.prd_embedding is not None:
330
+ prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
331
+ additional_embeds.append(prd_embedding)
332
+
333
+ hidden_states = torch.cat(
334
+ additional_embeds,
335
+ dim=1,
336
+ )
337
+
338
+ # Allow positional_embedding to not include the `additional_embeddings` and instead pad it with zeros for these additional tokens
339
+ additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1
340
+ if positional_embeddings.shape[1] < hidden_states.shape[1]:
341
+ positional_embeddings = F.pad(
342
+ positional_embeddings,
343
+ (
344
+ 0,
345
+ 0,
346
+ additional_embeddings_len,
347
+ self.prd_embedding.shape[1] if self.prd_embedding is not None else 0,
348
+ ),
349
+ value=0.0,
350
+ )
351
+
352
+ hidden_states = hidden_states + positional_embeddings
353
+
354
+ if attention_mask is not None:
355
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
356
+ attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
357
+ attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
358
+ attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
359
+
360
+ if self.norm_in is not None:
361
+ hidden_states = self.norm_in(hidden_states)
362
+
363
+ for block in self.transformer_blocks:
364
+ hidden_states = block(hidden_states, attention_mask=attention_mask)
365
+
366
+ hidden_states = self.norm_out(hidden_states)
367
+
368
+ if self.prd_embedding is not None:
369
+ hidden_states = hidden_states[:, -1]
370
+ else:
371
+ hidden_states = hidden_states[:, additional_embeddings_len:]
372
+
373
+ predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
374
+
375
+ if not return_dict:
376
+ return (predicted_image_embedding,)
377
+
378
+ return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
379
+
380
+ def post_process_latents(self, prior_latents):
381
+ prior_latents = (prior_latents * self.clip_std) + self.clip_mean
382
+ return prior_latents
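A minimal forward-pass sketch for `PriorTransformer` with a deliberately tiny, randomly initialized configuration; every size here is an illustrative assumption rather than a real checkpoint.

```py
# Tiny randomly initialized PriorTransformer (all sizes are illustrative assumptions).
import torch
from diffusers.models.prior_transformer import PriorTransformer

prior = PriorTransformer(
    num_attention_heads=2,
    attention_head_dim=8,
    num_layers=2,
    embedding_dim=16,
    num_embeddings=7,
    additional_embeddings=4,
)
batch = 2
hidden_states = torch.randn(batch, 16)             # current image-embedding estimate
proj_embedding = torch.randn(batch, 16)            # pooled text (CLIP) embedding
encoder_hidden_states = torch.randn(batch, 7, 16)  # per-token text hidden states

out = prior(
    hidden_states,
    timestep=10,
    proj_embedding=proj_embedding,
    encoder_hidden_states=encoder_hidden_states,
)
print(out.predicted_image_embedding.shape)          # torch.Size([2, 16])
```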
diffusers/models/resnet.py ADDED
@@ -0,0 +1,1037 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ # `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import partial
17
+ from typing import Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+
23
+ from ..utils import USE_PEFT_BACKEND
24
+ from .activations import get_activation
25
+ from .attention_processor import SpatialNorm
26
+ from .lora import LoRACompatibleConv, LoRACompatibleLinear
27
+ from .normalization import AdaGroupNorm
28
+
29
+
30
+ class Upsample1D(nn.Module):
31
+ """A 1D upsampling layer with an optional convolution.
32
+
33
+ Parameters:
34
+ channels (`int`):
35
+ number of channels in the inputs and outputs.
36
+ use_conv (`bool`, default `False`):
37
+ option to use a convolution.
38
+ use_conv_transpose (`bool`, default `False`):
39
+ option to use a convolution transpose.
40
+ out_channels (`int`, optional):
41
+ number of output channels. Defaults to `channels`.
42
+ name (`str`, default `conv`):
43
+ name of the upsampling 1D layer.
44
+ """
45
+
46
+ def __init__(
47
+ self,
48
+ channels: int,
49
+ use_conv: bool = False,
50
+ use_conv_transpose: bool = False,
51
+ out_channels: Optional[int] = None,
52
+ name: str = "conv",
53
+ ):
54
+ super().__init__()
55
+ self.channels = channels
56
+ self.out_channels = out_channels or channels
57
+ self.use_conv = use_conv
58
+ self.use_conv_transpose = use_conv_transpose
59
+ self.name = name
60
+
61
+ self.conv = None
62
+ if use_conv_transpose:
63
+ self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
64
+ elif use_conv:
65
+ self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
66
+
67
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
68
+ assert inputs.shape[1] == self.channels
69
+ if self.use_conv_transpose:
70
+ return self.conv(inputs)
71
+
72
+ outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
73
+
74
+ if self.use_conv:
75
+ outputs = self.conv(outputs)
76
+
77
+ return outputs
78
+
79
+
80
+ class Downsample1D(nn.Module):
81
+ """A 1D downsampling layer with an optional convolution.
82
+
83
+ Parameters:
84
+ channels (`int`):
85
+ number of channels in the inputs and outputs.
86
+ use_conv (`bool`, default `False`):
87
+ option to use a convolution.
88
+ out_channels (`int`, optional):
89
+ number of output channels. Defaults to `channels`.
90
+ padding (`int`, default `1`):
91
+ padding for the convolution.
92
+ name (`str`, default `conv`):
93
+ name of the downsampling 1D layer.
94
+ """
95
+
96
+ def __init__(
97
+ self,
98
+ channels: int,
99
+ use_conv: bool = False,
100
+ out_channels: Optional[int] = None,
101
+ padding: int = 1,
102
+ name: str = "conv",
103
+ ):
104
+ super().__init__()
105
+ self.channels = channels
106
+ self.out_channels = out_channels or channels
107
+ self.use_conv = use_conv
108
+ self.padding = padding
109
+ stride = 2
110
+ self.name = name
111
+
112
+ if use_conv:
113
+ self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
114
+ else:
115
+ assert self.channels == self.out_channels
116
+ self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
117
+
118
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
119
+ assert inputs.shape[1] == self.channels
120
+ return self.conv(inputs)
121
+
122
+
123
+ class Upsample2D(nn.Module):
124
+ """A 2D upsampling layer with an optional convolution.
125
+
126
+ Parameters:
127
+ channels (`int`):
128
+ number of channels in the inputs and outputs.
129
+ use_conv (`bool`, default `False`):
130
+ option to use a convolution.
131
+ use_conv_transpose (`bool`, default `False`):
132
+ option to use a convolution transpose.
133
+ out_channels (`int`, optional):
134
+ number of output channels. Defaults to `channels`.
135
+ name (`str`, default `conv`):
136
+ name of the upsampling 2D layer.
137
+ """
138
+
139
+ def __init__(
140
+ self,
141
+ channels: int,
142
+ use_conv: bool = False,
143
+ use_conv_transpose: bool = False,
144
+ out_channels: Optional[int] = None,
145
+ name: str = "conv",
146
+ ):
147
+ super().__init__()
148
+ self.channels = channels
149
+ self.out_channels = out_channels or channels
150
+ self.use_conv = use_conv
151
+ self.use_conv_transpose = use_conv_transpose
152
+ self.name = name
153
+ conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
154
+
155
+ conv = None
156
+ if use_conv_transpose:
157
+ conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
158
+ elif use_conv:
159
+ conv = conv_cls(self.channels, self.out_channels, 3, padding=1)
160
+
161
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
162
+ if name == "conv":
163
+ self.conv = conv
164
+ else:
165
+ self.Conv2d_0 = conv
166
+
167
+ def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None, scale: float = 1.0):
168
+ assert hidden_states.shape[1] == self.channels
169
+
170
+ if self.use_conv_transpose:
171
+ return self.conv(hidden_states)
172
+
173
+ # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
174
+ # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
175
+ # https://github.com/pytorch/pytorch/issues/86679
176
+ dtype = hidden_states.dtype
177
+ if dtype == torch.bfloat16:
178
+ hidden_states = hidden_states.to(torch.float32)
179
+
180
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
181
+ if hidden_states.shape[0] >= 64:
182
+ hidden_states = hidden_states.contiguous()
183
+
184
+ # if `output_size` is passed we force the interpolation output
185
+ # size and do not make use of `scale_factor=2`
186
+ if output_size is None:
187
+ hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
188
+ else:
189
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
190
+
191
+ # If the input is bfloat16, we cast back to bfloat16
192
+ if dtype == torch.bfloat16:
193
+ hidden_states = hidden_states.to(dtype)
194
+
195
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
196
+ if self.use_conv:
197
+ if self.name == "conv":
198
+ if isinstance(self.conv, LoRACompatibleConv) and not USE_PEFT_BACKEND:
199
+ hidden_states = self.conv(hidden_states, scale)
200
+ else:
201
+ hidden_states = self.conv(hidden_states)
202
+ else:
203
+ if isinstance(self.Conv2d_0, LoRACompatibleConv) and not USE_PEFT_BACKEND:
204
+ hidden_states = self.Conv2d_0(hidden_states, scale)
205
+ else:
206
+ hidden_states = self.Conv2d_0(hidden_states)
207
+
208
+ return hidden_states
209
+
210
+
211
+ class Downsample2D(nn.Module):
212
+ """A 2D downsampling layer with an optional convolution.
213
+
214
+ Parameters:
215
+ channels (`int`):
216
+ number of channels in the inputs and outputs.
217
+ use_conv (`bool`, default `False`):
218
+ option to use a convolution.
219
+ out_channels (`int`, optional):
220
+ number of output channels. Defaults to `channels`.
221
+ padding (`int`, default `1`):
222
+ padding for the convolution.
223
+ name (`str`, default `conv`):
224
+ name of the downsampling 2D layer.
225
+ """
226
+
227
+ def __init__(
228
+ self,
229
+ channels: int,
230
+ use_conv: bool = False,
231
+ out_channels: Optional[int] = None,
232
+ padding: int = 1,
233
+ name: str = "conv",
234
+ ):
235
+ super().__init__()
236
+ self.channels = channels
237
+ self.out_channels = out_channels or channels
238
+ self.use_conv = use_conv
239
+ self.padding = padding
240
+ stride = 2
241
+ self.name = name
242
+ conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
243
+
244
+ if use_conv:
245
+ conv = conv_cls(self.channels, self.out_channels, 3, stride=stride, padding=padding)
246
+ else:
247
+ assert self.channels == self.out_channels
248
+ conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
249
+
250
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
251
+ if name == "conv":
252
+ self.Conv2d_0 = conv
253
+ self.conv = conv
254
+ elif name == "Conv2d_0":
255
+ self.conv = conv
256
+ else:
257
+ self.conv = conv
258
+
259
+ def forward(self, hidden_states, scale: float = 1.0):
260
+ assert hidden_states.shape[1] == self.channels
261
+
262
+ if self.use_conv and self.padding == 0:
263
+ pad = (0, 1, 0, 1)
264
+ hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
265
+
266
+ assert hidden_states.shape[1] == self.channels
267
+
268
+ if not USE_PEFT_BACKEND:
269
+ if isinstance(self.conv, LoRACompatibleConv):
270
+ hidden_states = self.conv(hidden_states, scale)
271
+ else:
272
+ hidden_states = self.conv(hidden_states)
273
+ else:
274
+ hidden_states = self.conv(hidden_states)
275
+
276
+ return hidden_states
277
+
278
+
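A quick shape check for the two resize blocks above; the tensor sizes are illustrative assumptions.

```py
# Shape checks for Upsample2D / Downsample2D (sizes are illustrative assumptions).
import torch
from diffusers.models.resnet import Upsample2D, Downsample2D

x = torch.randn(1, 32, 8, 8)                     # (batch, channels, height, width)
up = Upsample2D(channels=32, use_conv=True)
down = Downsample2D(channels=32, use_conv=True)
print(up(x).shape)                               # torch.Size([1, 32, 16, 16])
print(down(x).shape)                             # torch.Size([1, 32, 4, 4])
```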
279
+ class FirUpsample2D(nn.Module):
280
+ """A 2D FIR upsampling layer with an optional convolution.
281
+
282
+ Parameters:
283
+ channels (`int`):
284
+ number of channels in the inputs and outputs.
285
+ use_conv (`bool`, default `False`):
286
+ option to use a convolution.
287
+ out_channels (`int`, optional):
288
+ number of output channels. Defaults to `channels`.
289
+ fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
290
+ kernel for the FIR filter.
291
+ """
292
+
293
+ def __init__(
294
+ self,
295
+ channels: int = None,
296
+ out_channels: Optional[int] = None,
297
+ use_conv: bool = False,
298
+ fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
299
+ ):
300
+ super().__init__()
301
+ out_channels = out_channels if out_channels else channels
302
+ if use_conv:
303
+ self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
304
+ self.use_conv = use_conv
305
+ self.fir_kernel = fir_kernel
306
+ self.out_channels = out_channels
307
+
308
+ def _upsample_2d(
309
+ self,
310
+ hidden_states: torch.Tensor,
311
+ weight: Optional[torch.Tensor] = None,
312
+ kernel: Optional[torch.FloatTensor] = None,
313
+ factor: int = 2,
314
+ gain: float = 1,
315
+ ) -> torch.Tensor:
316
+ """Fused `upsample_2d()` followed by `Conv2d()`.
317
+
318
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
319
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
320
+ arbitrary order.
321
+
322
+ Args:
323
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
324
+ weight: Weight tensor of the shape `[filterH, filterW, inChannels,
325
+ outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
326
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
327
+ (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
328
+ factor: Integer upsampling factor (default: 2).
329
+ gain: Scaling factor for signal magnitude (default: 1.0).
330
+
331
+ Returns:
332
+ output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
333
+ datatype as `hidden_states`.
334
+ """
335
+
336
+ assert isinstance(factor, int) and factor >= 1
337
+
338
+ # Setup filter kernel.
339
+ if kernel is None:
340
+ kernel = [1] * factor
341
+
342
+ # setup kernel
343
+ kernel = torch.tensor(kernel, dtype=torch.float32)
344
+ if kernel.ndim == 1:
345
+ kernel = torch.outer(kernel, kernel)
346
+ kernel /= torch.sum(kernel)
347
+
348
+ kernel = kernel * (gain * (factor**2))
349
+
350
+ if self.use_conv:
351
+ convH = weight.shape[2]
352
+ convW = weight.shape[3]
353
+ inC = weight.shape[1]
354
+
355
+ pad_value = (kernel.shape[0] - factor) - (convW - 1)
356
+
357
+ stride = (factor, factor)
358
+ # Determine data dimensions.
359
+ output_shape = (
360
+ (hidden_states.shape[2] - 1) * factor + convH,
361
+ (hidden_states.shape[3] - 1) * factor + convW,
362
+ )
363
+ output_padding = (
364
+ output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
365
+ output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
366
+ )
367
+ assert output_padding[0] >= 0 and output_padding[1] >= 0
368
+ num_groups = hidden_states.shape[1] // inC
369
+
370
+ # Transpose weights.
371
+ weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
372
+ weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
373
+ weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
374
+
375
+ inverse_conv = F.conv_transpose2d(
376
+ hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
377
+ )
378
+
379
+ output = upfirdn2d_native(
380
+ inverse_conv,
381
+ torch.tensor(kernel, device=inverse_conv.device),
382
+ pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
383
+ )
384
+ else:
385
+ pad_value = kernel.shape[0] - factor
386
+ output = upfirdn2d_native(
387
+ hidden_states,
388
+ torch.tensor(kernel, device=hidden_states.device),
389
+ up=factor,
390
+ pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
391
+ )
392
+
393
+ return output
394
+
395
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
396
+ if self.use_conv:
397
+ height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
398
+ height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
399
+ else:
400
+ height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
401
+
402
+ return height
403
+
404
+
405
+ class FirDownsample2D(nn.Module):
406
+ """A 2D FIR downsampling layer with an optional convolution.
407
+
408
+ Parameters:
409
+ channels (`int`):
410
+ number of channels in the inputs and outputs.
411
+ use_conv (`bool`, default `False`):
412
+ option to use a convolution.
413
+ out_channels (`int`, optional):
414
+ number of output channels. Defaults to `channels`.
415
+ fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
416
+ kernel for the FIR filter.
417
+ """
418
+
419
+ def __init__(
420
+ self,
421
+ channels: int = None,
422
+ out_channels: Optional[int] = None,
423
+ use_conv: bool = False,
424
+ fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
425
+ ):
426
+ super().__init__()
427
+ out_channels = out_channels if out_channels else channels
428
+ if use_conv:
429
+ self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
430
+ self.fir_kernel = fir_kernel
431
+ self.use_conv = use_conv
432
+ self.out_channels = out_channels
433
+
434
+ def _downsample_2d(
435
+ self,
436
+ hidden_states: torch.Tensor,
437
+ weight: Optional[torch.Tensor] = None,
438
+ kernel: Optional[torch.FloatTensor] = None,
439
+ factor: int = 2,
440
+ gain: float = 1,
441
+ ) -> torch.Tensor:
442
+ """Fused `Conv2d()` followed by `downsample_2d()`.
443
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
444
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
445
+ arbitrary order.
446
+
447
+ Args:
448
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
449
+ weight:
450
+ Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
451
+ performed by `inChannels = x.shape[0] // numGroups`.
452
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
453
+ factor`, which corresponds to average pooling.
454
+ factor: Integer downsampling factor (default: 2).
455
+ gain: Scaling factor for signal magnitude (default: 1.0).
456
+
457
+ Returns:
458
+ output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
459
+ same datatype as `x`.
460
+ """
461
+
462
+ assert isinstance(factor, int) and factor >= 1
463
+ if kernel is None:
464
+ kernel = [1] * factor
465
+
466
+ # setup kernel
467
+ kernel = torch.tensor(kernel, dtype=torch.float32)
468
+ if kernel.ndim == 1:
469
+ kernel = torch.outer(kernel, kernel)
470
+ kernel /= torch.sum(kernel)
471
+
472
+ kernel = kernel * gain
473
+
474
+ if self.use_conv:
475
+ _, _, convH, convW = weight.shape
476
+ pad_value = (kernel.shape[0] - factor) + (convW - 1)
477
+ stride_value = [factor, factor]
478
+ upfirdn_input = upfirdn2d_native(
479
+ hidden_states,
480
+ torch.tensor(kernel, device=hidden_states.device),
481
+ pad=((pad_value + 1) // 2, pad_value // 2),
482
+ )
483
+ output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
484
+ else:
485
+ pad_value = kernel.shape[0] - factor
486
+ output = upfirdn2d_native(
487
+ hidden_states,
488
+ torch.tensor(kernel, device=hidden_states.device),
489
+ down=factor,
490
+ pad=((pad_value + 1) // 2, pad_value // 2),
491
+ )
492
+
493
+ return output
494
+
495
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
496
+ if self.use_conv:
497
+ downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
498
+ hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
499
+ else:
500
+ hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
501
+
502
+ return hidden_states
503
+
504
+
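A minimal usage sketch for the FIR resampling layers above, assuming they are importable from this module (`diffusers.models.resnet` as laid out in this diff); channel counts and spatial sizes are illustrative only:

```python
import torch

from diffusers.models.resnet import FirDownsample2D  # path as laid out in this diff

# FIR low-pass filtering fused with a stride-2 3x3 convolution (use_conv=True)
down = FirDownsample2D(channels=64, out_channels=64, use_conv=True)
x = torch.randn(1, 64, 32, 32)
y = down(x)
print(y.shape)  # torch.Size([1, 64, 16, 16])
```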
505
+ # downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/FirUpsample2D instead
506
+ class KDownsample2D(nn.Module):
507
+ r"""A 2D K-downsampling layer.
508
+
509
+ Parameters:
510
+ pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
511
+ """
512
+
513
+ def __init__(self, pad_mode: str = "reflect"):
514
+ super().__init__()
515
+ self.pad_mode = pad_mode
516
+ kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
517
+ self.pad = kernel_1d.shape[1] // 2 - 1
518
+ self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
519
+
520
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
521
+ inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
522
+ weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
523
+ indices = torch.arange(inputs.shape[1], device=inputs.device)
524
+ kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
525
+ weight[indices, indices] = kernel
526
+ return F.conv2d(inputs, weight, stride=2)
527
+
528
+
529
+ class KUpsample2D(nn.Module):
530
+ r"""A 2D K-upsampling layer.
531
+
532
+ Parameters:
533
+ pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
534
+ """
535
+
536
+ def __init__(self, pad_mode: str = "reflect"):
537
+ super().__init__()
538
+ self.pad_mode = pad_mode
539
+ kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
540
+ self.pad = kernel_1d.shape[1] // 2 - 1
541
+ self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
542
+
543
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
544
+ inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
545
+ weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
546
+ indices = torch.arange(inputs.shape[1], device=inputs.device)
547
+ kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
548
+ weight[indices, indices] = kernel
549
+ return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)
550
+
551
+
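A small sketch of the fixed-kernel K-diffusion resampling layers above (shapes illustrative; assumes the classes are importable as defined in this file):

```python
import torch

down = KDownsample2D()   # fixed (1, 3, 3, 1)/8 binomial kernel applied with stride 2
up = KUpsample2D()       # transposed-conv counterpart with the kernel scaled by 2

x = torch.randn(2, 8, 16, 16)
print(down(x).shape)     # torch.Size([2, 8, 8, 8])
print(up(x).shape)       # torch.Size([2, 8, 32, 32])
```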
552
+ class ResnetBlock2D(nn.Module):
553
+ r"""
554
+ A Resnet block.
555
+
556
+ Parameters:
557
+ in_channels (`int`): The number of channels in the input.
558
+ out_channels (`int`, *optional*, default to be `None`):
559
+ The number of output channels for the first conv2d layer. If None, same as `in_channels`.
560
+ dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
561
+ temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
562
+ groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
563
+ groups_out (`int`, *optional*, default to None):
564
+ The number of groups to use for the second normalization layer. if set to None, same as `groups`.
565
+ eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
566
+ non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
567
+ time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
568
+ By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
569
+ "ada_group" for a stronger conditioning with scale and shift.
570
+ kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
571
+ [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
572
+ output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
573
+ use_in_shortcut (`bool`, *optional*, default to `True`):
574
+ If `True`, add a 1x1 nn.conv2d layer for skip-connection.
575
+ up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
576
+ down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
577
+ conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
578
+ `conv_shortcut` output.
579
+ conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
580
+ If None, same as `out_channels`.
581
+ """
582
+
583
+ def __init__(
584
+ self,
585
+ *,
586
+ in_channels: int,
587
+ out_channels: Optional[int] = None,
588
+ conv_shortcut: bool = False,
589
+ dropout: float = 0.0,
590
+ temb_channels: int = 512,
591
+ groups: int = 32,
592
+ groups_out: Optional[int] = None,
593
+ pre_norm: bool = True,
594
+ eps: float = 1e-6,
595
+ non_linearity: str = "swish",
596
+ skip_time_act: bool = False,
597
+ time_embedding_norm: str = "default", # default, scale_shift, ada_group, spatial
598
+ kernel: Optional[torch.FloatTensor] = None,
599
+ output_scale_factor: float = 1.0,
600
+ use_in_shortcut: Optional[bool] = None,
601
+ up: bool = False,
602
+ down: bool = False,
603
+ conv_shortcut_bias: bool = True,
604
+ conv_2d_out_channels: Optional[int] = None,
605
+ ):
606
+ super().__init__()
607
+ self.pre_norm = pre_norm
608
+ self.pre_norm = True
609
+ self.in_channels = in_channels
610
+ out_channels = in_channels if out_channels is None else out_channels
611
+ self.out_channels = out_channels
612
+ self.use_conv_shortcut = conv_shortcut
613
+ self.up = up
614
+ self.down = down
615
+ self.output_scale_factor = output_scale_factor
616
+ self.time_embedding_norm = time_embedding_norm
617
+ self.skip_time_act = skip_time_act
618
+
619
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
620
+ conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
621
+
622
+ if groups_out is None:
623
+ groups_out = groups
624
+
625
+ if self.time_embedding_norm == "ada_group":
626
+ self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
627
+ elif self.time_embedding_norm == "spatial":
628
+ self.norm1 = SpatialNorm(in_channels, temb_channels)
629
+ else:
630
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
631
+
632
+ self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
633
+
634
+ if temb_channels is not None:
635
+ if self.time_embedding_norm == "default":
636
+ self.time_emb_proj = linear_cls(temb_channels, out_channels)
637
+ elif self.time_embedding_norm == "scale_shift":
638
+ self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
639
+ elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
640
+ self.time_emb_proj = None
641
+ else:
642
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
643
+ else:
644
+ self.time_emb_proj = None
645
+
646
+ if self.time_embedding_norm == "ada_group":
647
+ self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
648
+ elif self.time_embedding_norm == "spatial":
649
+ self.norm2 = SpatialNorm(out_channels, temb_channels)
650
+ else:
651
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
652
+
653
+ self.dropout = torch.nn.Dropout(dropout)
654
+ conv_2d_out_channels = conv_2d_out_channels or out_channels
655
+ self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
656
+
657
+ self.nonlinearity = get_activation(non_linearity)
658
+
659
+ self.upsample = self.downsample = None
660
+ if self.up:
661
+ if kernel == "fir":
662
+ fir_kernel = (1, 3, 3, 1)
663
+ self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
664
+ elif kernel == "sde_vp":
665
+ self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
666
+ else:
667
+ self.upsample = Upsample2D(in_channels, use_conv=False)
668
+ elif self.down:
669
+ if kernel == "fir":
670
+ fir_kernel = (1, 3, 3, 1)
671
+ self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
672
+ elif kernel == "sde_vp":
673
+ self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
674
+ else:
675
+ self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
676
+
677
+ self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
678
+
679
+ self.conv_shortcut = None
680
+ if self.use_in_shortcut:
681
+ self.conv_shortcut = conv_cls(
682
+ in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
683
+ )
684
+
685
+ def forward(self, input_tensor, temb, scale: float = 1.0):
686
+ hidden_states = input_tensor
687
+
688
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
689
+ hidden_states = self.norm1(hidden_states, temb)
690
+ else:
691
+ hidden_states = self.norm1(hidden_states)
692
+
693
+ hidden_states = self.nonlinearity(hidden_states)
694
+
695
+ if self.upsample is not None:
696
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
697
+ if hidden_states.shape[0] >= 64:
698
+ input_tensor = input_tensor.contiguous()
699
+ hidden_states = hidden_states.contiguous()
700
+ input_tensor = (
701
+ self.upsample(input_tensor, scale=scale)
702
+ if isinstance(self.upsample, Upsample2D)
703
+ else self.upsample(input_tensor)
704
+ )
705
+ hidden_states = (
706
+ self.upsample(hidden_states, scale=scale)
707
+ if isinstance(self.upsample, Upsample2D)
708
+ else self.upsample(hidden_states)
709
+ )
710
+ elif self.downsample is not None:
711
+ input_tensor = (
712
+ self.downsample(input_tensor, scale=scale)
713
+ if isinstance(self.downsample, Downsample2D)
714
+ else self.downsample(input_tensor)
715
+ )
716
+ hidden_states = (
717
+ self.downsample(hidden_states, scale=scale)
718
+ if isinstance(self.downsample, Downsample2D)
719
+ else self.downsample(hidden_states)
720
+ )
721
+
722
+ hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states)
723
+
724
+ if self.time_emb_proj is not None:
725
+ if not self.skip_time_act:
726
+ temb = self.nonlinearity(temb)
727
+ temb = (
728
+ self.time_emb_proj(temb, scale)[:, :, None, None]
729
+ if not USE_PEFT_BACKEND
730
+ else self.time_emb_proj(temb)[:, :, None, None]
731
+ )
732
+
733
+ if temb is not None and self.time_embedding_norm == "default":
734
+ hidden_states = hidden_states + temb
735
+
736
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
737
+ hidden_states = self.norm2(hidden_states, temb)
738
+ else:
739
+ hidden_states = self.norm2(hidden_states)
740
+
741
+ if temb is not None and self.time_embedding_norm == "scale_shift":
742
+ scale, shift = torch.chunk(temb, 2, dim=1)
743
+ hidden_states = hidden_states * (1 + scale) + shift
744
+
745
+ hidden_states = self.nonlinearity(hidden_states)
746
+
747
+ hidden_states = self.dropout(hidden_states)
748
+ hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states)
749
+
750
+ if self.conv_shortcut is not None:
751
+ input_tensor = (
752
+ self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor)
753
+ )
754
+
755
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
756
+
757
+ return output_tensor
758
+
759
+
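A minimal forward-pass sketch for `ResnetBlock2D` (channel sizes are illustrative; the 1x1 shortcut convolution is created automatically here because `in_channels != out_channels`):

```python
import torch

block = ResnetBlock2D(in_channels=64, out_channels=128, temb_channels=512)
x = torch.randn(1, 64, 32, 32)
temb = torch.randn(1, 512)   # timestep embedding
out = block(x, temb)
print(out.shape)             # torch.Size([1, 128, 32, 32])
```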
760
+ # unet_rl.py
761
+ def rearrange_dims(tensor: torch.Tensor) -> torch.Tensor:
762
+ if len(tensor.shape) == 2:
763
+ return tensor[:, :, None]
764
+ if len(tensor.shape) == 3:
765
+ return tensor[:, :, None, :]
766
+ elif len(tensor.shape) == 4:
767
+ return tensor[:, :, 0, :]
768
+ else:
769
+ raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
770
+
771
+
772
+ class Conv1dBlock(nn.Module):
773
+ """
774
+ Conv1d --> GroupNorm --> Mish
775
+
776
+ Parameters:
777
+ inp_channels (`int`): Number of input channels.
778
+ out_channels (`int`): Number of output channels.
779
+ kernel_size (`int` or `tuple`): Size of the convolving kernel.
780
+ n_groups (`int`, default `8`): Number of groups to separate the channels into.
781
+ activation (`str`, defaults to `mish`): Name of the activation function.
782
+ """
783
+
784
+ def __init__(
785
+ self,
786
+ inp_channels: int,
787
+ out_channels: int,
788
+ kernel_size: Union[int, Tuple[int, int]],
789
+ n_groups: int = 8,
790
+ activation: str = "mish",
791
+ ):
792
+ super().__init__()
793
+
794
+ self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
795
+ self.group_norm = nn.GroupNorm(n_groups, out_channels)
796
+ self.mish = get_activation(activation)
797
+
798
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
799
+ intermediate_repr = self.conv1d(inputs)
800
+ intermediate_repr = rearrange_dims(intermediate_repr)
801
+ intermediate_repr = self.group_norm(intermediate_repr)
802
+ intermediate_repr = rearrange_dims(intermediate_repr)
803
+ output = self.mish(intermediate_repr)
804
+ return output
805
+
806
+
807
+ # unet_rl.py
808
+ class ResidualTemporalBlock1D(nn.Module):
809
+ """
810
+ Residual 1D block with temporal convolutions.
811
+
812
+ Parameters:
813
+ inp_channels (`int`): Number of input channels.
814
+ out_channels (`int`): Number of output channels.
815
+ embed_dim (`int`): Embedding dimension.
816
+ kernel_size (`int` or `tuple`): Size of the convolving kernel.
817
+ activation (`str`, defaults to `mish`): Name of the activation function to use.
818
+ """
819
+
820
+ def __init__(
821
+ self,
822
+ inp_channels: int,
823
+ out_channels: int,
824
+ embed_dim: int,
825
+ kernel_size: Union[int, Tuple[int, int]] = 5,
826
+ activation: str = "mish",
827
+ ):
828
+ super().__init__()
829
+ self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
830
+ self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
831
+
832
+ self.time_emb_act = get_activation(activation)
833
+ self.time_emb = nn.Linear(embed_dim, out_channels)
834
+
835
+ self.residual_conv = (
836
+ nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
837
+ )
838
+
839
+ def forward(self, inputs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
840
+ """
841
+ Args:
842
+ inputs : [ batch_size x inp_channels x horizon ]
843
+ t : [ batch_size x embed_dim ]
844
+
845
+ returns:
846
+ out : [ batch_size x out_channels x horizon ]
847
+ """
848
+ t = self.time_emb_act(t)
849
+ t = self.time_emb(t)
850
+ out = self.conv_in(inputs) + rearrange_dims(t)
851
+ out = self.conv_out(out)
852
+ return out + self.residual_conv(inputs)
853
+
854
+
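A sketch of the RL/1D residual block above applied to a `[batch, channels, horizon]` trajectory tensor (all sizes illustrative):

```python
import torch

block = ResidualTemporalBlock1D(inp_channels=14, out_channels=32, embed_dim=128)
trajectory = torch.randn(4, 14, 64)   # [batch, inp_channels, horizon]
t_emb = torch.randn(4, 128)           # [batch, embed_dim]
out = block(trajectory, t_emb)
print(out.shape)                      # torch.Size([4, 32, 64])
```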
855
+ def upsample_2d(
856
+ hidden_states: torch.Tensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
857
+ ) -> torch.Tensor:
858
+ r"""Upsample2D a batch of 2D images with the given filter.
859
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
860
+ filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
861
+ `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
862
+ a multiple of the upsampling factor.
863
+
864
+ Args:
865
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
866
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
867
+ (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
868
+ factor: Integer upsampling factor (default: 2).
869
+ gain: Scaling factor for signal magnitude (default: 1.0).
870
+
871
+ Returns:
872
+ output: Tensor of the shape `[N, C, H * factor, W * factor]`
873
+ """
874
+ assert isinstance(factor, int) and factor >= 1
875
+ if kernel is None:
876
+ kernel = [1] * factor
877
+
878
+ kernel = torch.tensor(kernel, dtype=torch.float32)
879
+ if kernel.ndim == 1:
880
+ kernel = torch.outer(kernel, kernel)
881
+ kernel /= torch.sum(kernel)
882
+
883
+ kernel = kernel * (gain * (factor**2))
884
+ pad_value = kernel.shape[0] - factor
885
+ output = upfirdn2d_native(
886
+ hidden_states,
887
+ kernel.to(device=hidden_states.device),
888
+ up=factor,
889
+ pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
890
+ )
891
+ return output
892
+
893
+
894
+ def downsample_2d(
895
+ hidden_states: torch.Tensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
896
+ ) -> torch.Tensor:
897
+ r"""Downsample2D a batch of 2D images with the given filter.
898
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
899
+ given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
900
+ specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
901
+ shape is a multiple of the downsampling factor.
902
+
903
+ Args:
904
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
905
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
906
+ (separable). The default is `[1] * factor`, which corresponds to average pooling.
907
+ factor: Integer downsampling factor (default: 2).
908
+ gain: Scaling factor for signal magnitude (default: 1.0).
909
+
910
+ Returns:
911
+ output: Tensor of the shape `[N, C, H // factor, W // factor]`
912
+ """
913
+
914
+ assert isinstance(factor, int) and factor >= 1
915
+ if kernel is None:
916
+ kernel = [1] * factor
917
+
918
+ kernel = torch.tensor(kernel, dtype=torch.float32)
919
+ if kernel.ndim == 1:
920
+ kernel = torch.outer(kernel, kernel)
921
+ kernel /= torch.sum(kernel)
922
+
923
+ kernel = kernel * gain
924
+ pad_value = kernel.shape[0] - factor
925
+ output = upfirdn2d_native(
926
+ hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
927
+ )
928
+ return output
929
+
930
+
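With the default separable kernel `[1] * factor`, `upsample_2d` reduces to nearest-neighbor upsampling and `downsample_2d` to average pooling; a quick sanity check (shapes illustrative):

```python
import torch

x = torch.randn(1, 3, 16, 16)
print(upsample_2d(x, factor=2).shape)    # torch.Size([1, 3, 32, 32])
print(downsample_2d(x, factor=2).shape)  # torch.Size([1, 3, 8, 8])
```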
931
+ def upfirdn2d_native(
932
+ tensor: torch.Tensor, kernel: torch.Tensor, up: int = 1, down: int = 1, pad: Tuple[int, int] = (0, 0)
933
+ ) -> torch.Tensor:
934
+ up_x = up_y = up
935
+ down_x = down_y = down
936
+ pad_x0 = pad_y0 = pad[0]
937
+ pad_x1 = pad_y1 = pad[1]
938
+
939
+ _, channel, in_h, in_w = tensor.shape
940
+ tensor = tensor.reshape(-1, in_h, in_w, 1)
941
+
942
+ _, in_h, in_w, minor = tensor.shape
943
+ kernel_h, kernel_w = kernel.shape
944
+
945
+ out = tensor.view(-1, in_h, 1, in_w, 1, minor)
946
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
947
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
948
+
949
+ out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
950
+ out = out.to(tensor.device) # Move back to mps if necessary
951
+ out = out[
952
+ :,
953
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
954
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
955
+ :,
956
+ ]
957
+
958
+ out = out.permute(0, 3, 1, 2)
959
+ out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
960
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
961
+ out = F.conv2d(out, w)
962
+ out = out.reshape(
963
+ -1,
964
+ minor,
965
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
966
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
967
+ )
968
+ out = out.permute(0, 2, 3, 1)
969
+ out = out[:, ::down_y, ::down_x, :]
970
+
971
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
972
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
973
+
974
+ return out.view(-1, channel, out_h, out_w)
975
+
976
+
977
+ class TemporalConvLayer(nn.Module):
978
+ """
979
+ Temporal convolutional layer that can be used for video (sequence of images) input. Code mostly copied from:
980
+ https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
981
+
982
+ Parameters:
983
+ in_dim (`int`): Number of input channels.
984
+ out_dim (`int`): Number of output channels.
985
+ dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
986
+ """
987
+
988
+ def __init__(self, in_dim: int, out_dim: Optional[int] = None, dropout: float = 0.0):
989
+ super().__init__()
990
+ out_dim = out_dim or in_dim
991
+ self.in_dim = in_dim
992
+ self.out_dim = out_dim
993
+
994
+ # conv layers
995
+ self.conv1 = nn.Sequential(
996
+ nn.GroupNorm(32, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))
997
+ )
998
+ self.conv2 = nn.Sequential(
999
+ nn.GroupNorm(32, out_dim),
1000
+ nn.SiLU(),
1001
+ nn.Dropout(dropout),
1002
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
1003
+ )
1004
+ self.conv3 = nn.Sequential(
1005
+ nn.GroupNorm(32, out_dim),
1006
+ nn.SiLU(),
1007
+ nn.Dropout(dropout),
1008
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
1009
+ )
1010
+ self.conv4 = nn.Sequential(
1011
+ nn.GroupNorm(32, out_dim),
1012
+ nn.SiLU(),
1013
+ nn.Dropout(dropout),
1014
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
1015
+ )
1016
+
1017
+ # zero out the last layer params, so the conv block is the identity at initialization
1018
+ nn.init.zeros_(self.conv4[-1].weight)
1019
+ nn.init.zeros_(self.conv4[-1].bias)
1020
+
1021
+ def forward(self, hidden_states: torch.Tensor, num_frames: int = 1) -> torch.Tensor:
1022
+ hidden_states = (
1023
+ hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)
1024
+ )
1025
+
1026
+ identity = hidden_states
1027
+ hidden_states = self.conv1(hidden_states)
1028
+ hidden_states = self.conv2(hidden_states)
1029
+ hidden_states = self.conv3(hidden_states)
1030
+ hidden_states = self.conv4(hidden_states)
1031
+
1032
+ hidden_states = identity + hidden_states
1033
+
1034
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
1035
+ (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
1036
+ )
1037
+ return hidden_states
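A minimal sketch for `TemporalConvLayer`: frames are passed as a flattened `[batch * num_frames, C, H, W]` tensor and regrouped internally (sizes illustrative). Because the last conv is zero-initialized, the layer starts out as the identity:

```python
import torch

layer = TemporalConvLayer(in_dim=64, dropout=0.1)
frames = torch.randn(2 * 8, 64, 32, 32)   # 2 videos x 8 frames, flattened into the batch dim
out = layer(frames, num_frames=8)
print(out.shape)                          # torch.Size([16, 64, 32, 32])
```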
diffusers/models/resnet_flax.py ADDED
@@ -0,0 +1,124 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import flax.linen as nn
15
+ import jax
16
+ import jax.numpy as jnp
17
+
18
+
19
+ class FlaxUpsample2D(nn.Module):
20
+ out_channels: int
21
+ dtype: jnp.dtype = jnp.float32
22
+
23
+ def setup(self):
24
+ self.conv = nn.Conv(
25
+ self.out_channels,
26
+ kernel_size=(3, 3),
27
+ strides=(1, 1),
28
+ padding=((1, 1), (1, 1)),
29
+ dtype=self.dtype,
30
+ )
31
+
32
+ def __call__(self, hidden_states):
33
+ batch, height, width, channels = hidden_states.shape
34
+ hidden_states = jax.image.resize(
35
+ hidden_states,
36
+ shape=(batch, height * 2, width * 2, channels),
37
+ method="nearest",
38
+ )
39
+ hidden_states = self.conv(hidden_states)
40
+ return hidden_states
41
+
42
+
43
+ class FlaxDownsample2D(nn.Module):
44
+ out_channels: int
45
+ dtype: jnp.dtype = jnp.float32
46
+
47
+ def setup(self):
48
+ self.conv = nn.Conv(
49
+ self.out_channels,
50
+ kernel_size=(3, 3),
51
+ strides=(2, 2),
52
+ padding=((1, 1), (1, 1)), # padding="VALID",
53
+ dtype=self.dtype,
54
+ )
55
+
56
+ def __call__(self, hidden_states):
57
+ # pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim
58
+ # hidden_states = jnp.pad(hidden_states, pad_width=pad)
59
+ hidden_states = self.conv(hidden_states)
60
+ return hidden_states
61
+
62
+
63
+ class FlaxResnetBlock2D(nn.Module):
64
+ in_channels: int
65
+ out_channels: int = None
66
+ dropout_prob: float = 0.0
67
+ use_nin_shortcut: bool = None
68
+ dtype: jnp.dtype = jnp.float32
69
+
70
+ def setup(self):
71
+ out_channels = self.in_channels if self.out_channels is None else self.out_channels
72
+
73
+ self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
74
+ self.conv1 = nn.Conv(
75
+ out_channels,
76
+ kernel_size=(3, 3),
77
+ strides=(1, 1),
78
+ padding=((1, 1), (1, 1)),
79
+ dtype=self.dtype,
80
+ )
81
+
82
+ self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype)
83
+
84
+ self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
85
+ self.dropout = nn.Dropout(self.dropout_prob)
86
+ self.conv2 = nn.Conv(
87
+ out_channels,
88
+ kernel_size=(3, 3),
89
+ strides=(1, 1),
90
+ padding=((1, 1), (1, 1)),
91
+ dtype=self.dtype,
92
+ )
93
+
94
+ use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut
95
+
96
+ self.conv_shortcut = None
97
+ if use_nin_shortcut:
98
+ self.conv_shortcut = nn.Conv(
99
+ out_channels,
100
+ kernel_size=(1, 1),
101
+ strides=(1, 1),
102
+ padding="VALID",
103
+ dtype=self.dtype,
104
+ )
105
+
106
+ def __call__(self, hidden_states, temb, deterministic=True):
107
+ residual = hidden_states
108
+ hidden_states = self.norm1(hidden_states)
109
+ hidden_states = nn.swish(hidden_states)
110
+ hidden_states = self.conv1(hidden_states)
111
+
112
+ temb = self.time_emb_proj(nn.swish(temb))
113
+ temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1)
114
+ hidden_states = hidden_states + temb
115
+
116
+ hidden_states = self.norm2(hidden_states)
117
+ hidden_states = nn.swish(hidden_states)
118
+ hidden_states = self.dropout(hidden_states, deterministic)
119
+ hidden_states = self.conv2(hidden_states)
120
+
121
+ if self.conv_shortcut is not None:
122
+ residual = self.conv_shortcut(residual)
123
+
124
+ return hidden_states + residual
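A minimal Flax sketch for `FlaxResnetBlock2D` (channel-last `NHWC` inputs, illustrative shapes; dropout defaults to 0 and `deterministic=True`, so no dropout RNG is needed):

```python
import jax
import jax.numpy as jnp

block = FlaxResnetBlock2D(in_channels=64, out_channels=128)
x = jnp.ones((1, 32, 32, 64))    # NHWC
temb = jnp.ones((1, 512))
params = block.init(jax.random.PRNGKey(0), x, temb)
print(block.apply(params, x, temb).shape)   # (1, 32, 32, 128)
```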
diffusers/models/t5_film_transformer.py ADDED
@@ -0,0 +1,438 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ from typing import Optional, Tuple
16
+
17
+ import torch
18
+ from torch import nn
19
+
20
+ from ..configuration_utils import ConfigMixin, register_to_config
21
+ from .attention_processor import Attention
22
+ from .embeddings import get_timestep_embedding
23
+ from .modeling_utils import ModelMixin
24
+
25
+
26
+ class T5FilmDecoder(ModelMixin, ConfigMixin):
27
+ r"""
28
+ T5 style decoder with FiLM conditioning.
29
+
30
+ Args:
31
+ input_dims (`int`, *optional*, defaults to `128`):
32
+ The number of input dimensions.
33
+ targets_length (`int`, *optional*, defaults to `256`):
34
+ The length of the targets.
35
+ d_model (`int`, *optional*, defaults to `768`):
36
+ Size of the input hidden states.
37
+ num_layers (`int`, *optional*, defaults to `12`):
38
+ The number of `DecoderLayer`'s to use.
39
+ num_heads (`int`, *optional*, defaults to `12`):
40
+ The number of attention heads to use.
41
+ d_kv (`int`, *optional*, defaults to `64`):
42
+ Size of the key-value projection vectors.
43
+ d_ff (`int`, *optional*, defaults to `2048`):
44
+ The number of dimensions in the intermediate feed-forward layer of `DecoderLayer`'s.
45
+ dropout_rate (`float`, *optional*, defaults to `0.1`):
46
+ Dropout probability.
47
+ """
48
+
49
+ @register_to_config
50
+ def __init__(
51
+ self,
52
+ input_dims: int = 128,
53
+ targets_length: int = 256,
54
+ max_decoder_noise_time: float = 2000.0,
55
+ d_model: int = 768,
56
+ num_layers: int = 12,
57
+ num_heads: int = 12,
58
+ d_kv: int = 64,
59
+ d_ff: int = 2048,
60
+ dropout_rate: float = 0.1,
61
+ ):
62
+ super().__init__()
63
+
64
+ self.conditioning_emb = nn.Sequential(
65
+ nn.Linear(d_model, d_model * 4, bias=False),
66
+ nn.SiLU(),
67
+ nn.Linear(d_model * 4, d_model * 4, bias=False),
68
+ nn.SiLU(),
69
+ )
70
+
71
+ self.position_encoding = nn.Embedding(targets_length, d_model)
72
+ self.position_encoding.weight.requires_grad = False
73
+
74
+ self.continuous_inputs_projection = nn.Linear(input_dims, d_model, bias=False)
75
+
76
+ self.dropout = nn.Dropout(p=dropout_rate)
77
+
78
+ self.decoders = nn.ModuleList()
79
+ for lyr_num in range(num_layers):
80
+ # FiLM conditional T5 decoder
81
+ lyr = DecoderLayer(d_model=d_model, d_kv=d_kv, num_heads=num_heads, d_ff=d_ff, dropout_rate=dropout_rate)
82
+ self.decoders.append(lyr)
83
+
84
+ self.decoder_norm = T5LayerNorm(d_model)
85
+
86
+ self.post_dropout = nn.Dropout(p=dropout_rate)
87
+ self.spec_out = nn.Linear(d_model, input_dims, bias=False)
88
+
89
+ def encoder_decoder_mask(self, query_input: torch.FloatTensor, key_input: torch.FloatTensor) -> torch.FloatTensor:
90
+ mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
91
+ return mask.unsqueeze(-3)
92
+
93
+ def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time):
94
+ batch, _, _ = decoder_input_tokens.shape
95
+ assert decoder_noise_time.shape == (batch,)
96
+
97
+ # decoder_noise_time is in [0, 1), so rescale to expected timing range.
98
+ time_steps = get_timestep_embedding(
99
+ decoder_noise_time * self.config.max_decoder_noise_time,
100
+ embedding_dim=self.config.d_model,
101
+ max_period=self.config.max_decoder_noise_time,
102
+ ).to(dtype=self.dtype)
103
+
104
+ conditioning_emb = self.conditioning_emb(time_steps).unsqueeze(1)
105
+
106
+ assert conditioning_emb.shape == (batch, 1, self.config.d_model * 4)
107
+
108
+ seq_length = decoder_input_tokens.shape[1]
109
+
110
+ # If we want to use relative positions for audio context, we can just offset
111
+ # this sequence by the length of encodings_and_masks.
112
+ decoder_positions = torch.broadcast_to(
113
+ torch.arange(seq_length, device=decoder_input_tokens.device),
114
+ (batch, seq_length),
115
+ )
116
+
117
+ position_encodings = self.position_encoding(decoder_positions)
118
+
119
+ inputs = self.continuous_inputs_projection(decoder_input_tokens)
120
+ inputs += position_encodings
121
+ y = self.dropout(inputs)
122
+
123
+ # decoder: No padding present.
124
+ decoder_mask = torch.ones(
125
+ decoder_input_tokens.shape[:2], device=decoder_input_tokens.device, dtype=inputs.dtype
126
+ )
127
+
128
+ # Translate encoding masks to encoder-decoder masks.
129
+ encodings_and_encdec_masks = [(x, self.encoder_decoder_mask(decoder_mask, y)) for x, y in encodings_and_masks]
130
+
131
+ # cross attend style: concat encodings
132
+ encoded = torch.cat([x[0] for x in encodings_and_encdec_masks], dim=1)
133
+ encoder_decoder_mask = torch.cat([x[1] for x in encodings_and_encdec_masks], dim=-1)
134
+
135
+ for lyr in self.decoders:
136
+ y = lyr(
137
+ y,
138
+ conditioning_emb=conditioning_emb,
139
+ encoder_hidden_states=encoded,
140
+ encoder_attention_mask=encoder_decoder_mask,
141
+ )[0]
142
+
143
+ y = self.decoder_norm(y)
144
+ y = self.post_dropout(y)
145
+
146
+ spec_out = self.spec_out(y)
147
+ return spec_out
148
+
149
+
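A minimal forward sketch for `T5FilmDecoder` as used for spectrogram-style continuous targets; the encoder outputs, sequence lengths, and the small 2-layer config are illustrative assumptions:

```python
import torch

decoder = T5FilmDecoder(input_dims=128, targets_length=256, d_model=768,
                        num_layers=2, num_heads=12, d_kv=64)
targets = torch.randn(1, 256, 128)    # continuous decoder inputs
noise_time = torch.rand(1)            # in [0, 1)
encodings = torch.randn(1, 100, 768)
enc_mask = torch.ones(1, 100)
spec = decoder([(encodings, enc_mask)], targets, noise_time)
print(spec.shape)                     # torch.Size([1, 256, 128])
```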
150
+ class DecoderLayer(nn.Module):
151
+ r"""
152
+ T5 decoder layer.
153
+
154
+ Args:
155
+ d_model (`int`):
156
+ Size of the input hidden states.
157
+ d_kv (`int`):
158
+ Size of the key-value projection vectors.
159
+ num_heads (`int`):
160
+ Number of attention heads.
161
+ d_ff (`int`):
162
+ Size of the intermediate feed-forward layer.
163
+ dropout_rate (`float`):
164
+ Dropout probability.
165
+ layer_norm_epsilon (`float`, *optional*, defaults to `1e-6`):
166
+ A small value used for numerical stability to avoid dividing by zero.
167
+ """
168
+
169
+ def __init__(
170
+ self, d_model: int, d_kv: int, num_heads: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float = 1e-6
171
+ ):
172
+ super().__init__()
173
+ self.layer = nn.ModuleList()
174
+
175
+ # cond self attention: layer 0
176
+ self.layer.append(
177
+ T5LayerSelfAttentionCond(d_model=d_model, d_kv=d_kv, num_heads=num_heads, dropout_rate=dropout_rate)
178
+ )
179
+
180
+ # cross attention: layer 1
181
+ self.layer.append(
182
+ T5LayerCrossAttention(
183
+ d_model=d_model,
184
+ d_kv=d_kv,
185
+ num_heads=num_heads,
186
+ dropout_rate=dropout_rate,
187
+ layer_norm_epsilon=layer_norm_epsilon,
188
+ )
189
+ )
190
+
191
+ # Film Cond MLP + dropout: last layer
192
+ self.layer.append(
193
+ T5LayerFFCond(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate, layer_norm_epsilon=layer_norm_epsilon)
194
+ )
195
+
196
+ def forward(
197
+ self,
198
+ hidden_states: torch.FloatTensor,
199
+ conditioning_emb: Optional[torch.FloatTensor] = None,
200
+ attention_mask: Optional[torch.FloatTensor] = None,
201
+ encoder_hidden_states: Optional[torch.Tensor] = None,
202
+ encoder_attention_mask: Optional[torch.Tensor] = None,
203
+ encoder_decoder_position_bias=None,
204
+ ) -> Tuple[torch.FloatTensor]:
205
+ hidden_states = self.layer[0](
206
+ hidden_states,
207
+ conditioning_emb=conditioning_emb,
208
+ attention_mask=attention_mask,
209
+ )
210
+
211
+ if encoder_hidden_states is not None:
212
+ encoder_extended_attention_mask = torch.where(encoder_attention_mask > 0, 0, -1e10).to(
213
+ encoder_hidden_states.dtype
214
+ )
215
+
216
+ hidden_states = self.layer[1](
217
+ hidden_states,
218
+ key_value_states=encoder_hidden_states,
219
+ attention_mask=encoder_extended_attention_mask,
220
+ )
221
+
222
+ # Apply Film Conditional Feed Forward layer
223
+ hidden_states = self.layer[-1](hidden_states, conditioning_emb)
224
+
225
+ return (hidden_states,)
226
+
227
+
228
+ class T5LayerSelfAttentionCond(nn.Module):
229
+ r"""
230
+ T5 style self-attention layer with conditioning.
231
+
232
+ Args:
233
+ d_model (`int`):
234
+ Size of the input hidden states.
235
+ d_kv (`int`):
236
+ Size of the key-value projection vectors.
237
+ num_heads (`int`):
238
+ Number of attention heads.
239
+ dropout_rate (`float`):
240
+ Dropout probability.
241
+ """
242
+
243
+ def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float):
244
+ super().__init__()
245
+ self.layer_norm = T5LayerNorm(d_model)
246
+ self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
247
+ self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
248
+ self.dropout = nn.Dropout(dropout_rate)
249
+
250
+ def forward(
251
+ self,
252
+ hidden_states: torch.FloatTensor,
253
+ conditioning_emb: Optional[torch.FloatTensor] = None,
254
+ attention_mask: Optional[torch.FloatTensor] = None,
255
+ ) -> torch.FloatTensor:
256
+ # pre_self_attention_layer_norm
257
+ normed_hidden_states = self.layer_norm(hidden_states)
258
+
259
+ if conditioning_emb is not None:
260
+ normed_hidden_states = self.FiLMLayer(normed_hidden_states, conditioning_emb)
261
+
262
+ # Self-attention block
263
+ attention_output = self.attention(normed_hidden_states)
264
+
265
+ hidden_states = hidden_states + self.dropout(attention_output)
266
+
267
+ return hidden_states
268
+
269
+
270
+ class T5LayerCrossAttention(nn.Module):
271
+ r"""
272
+ T5 style cross-attention layer.
273
+
274
+ Args:
275
+ d_model (`int`):
276
+ Size of the input hidden states.
277
+ d_kv (`int`):
278
+ Size of the key-value projection vectors.
279
+ num_heads (`int`):
280
+ Number of attention heads.
281
+ dropout_rate (`float`):
282
+ Dropout probability.
283
+ layer_norm_epsilon (`float`):
284
+ A small value used for numerical stability to avoid dividing by zero.
285
+ """
286
+
287
+ def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float, layer_norm_epsilon: float):
288
+ super().__init__()
289
+ self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
290
+ self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
291
+ self.dropout = nn.Dropout(dropout_rate)
292
+
293
+ def forward(
294
+ self,
295
+ hidden_states: torch.FloatTensor,
296
+ key_value_states: Optional[torch.FloatTensor] = None,
297
+ attention_mask: Optional[torch.FloatTensor] = None,
298
+ ) -> torch.FloatTensor:
299
+ normed_hidden_states = self.layer_norm(hidden_states)
300
+ attention_output = self.attention(
301
+ normed_hidden_states,
302
+ encoder_hidden_states=key_value_states,
303
+ attention_mask=attention_mask.squeeze(1),
304
+ )
305
+ layer_output = hidden_states + self.dropout(attention_output)
306
+ return layer_output
307
+
308
+
309
+ class T5LayerFFCond(nn.Module):
310
+ r"""
311
+ T5 style feed-forward conditional layer.
312
+
313
+ Args:
314
+ d_model (`int`):
315
+ Size of the input hidden states.
316
+ d_ff (`int`):
317
+ Size of the intermediate feed-forward layer.
318
+ dropout_rate (`float`):
319
+ Dropout probability.
320
+ layer_norm_epsilon (`float`):
321
+ A small value used for numerical stability to avoid dividing by zero.
322
+ """
323
+
324
+ def __init__(self, d_model: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float):
325
+ super().__init__()
326
+ self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate)
327
+ self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
328
+ self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
329
+ self.dropout = nn.Dropout(dropout_rate)
330
+
331
+ def forward(
332
+ self, hidden_states: torch.FloatTensor, conditioning_emb: Optional[torch.FloatTensor] = None
333
+ ) -> torch.FloatTensor:
334
+ forwarded_states = self.layer_norm(hidden_states)
335
+ if conditioning_emb is not None:
336
+ forwarded_states = self.film(forwarded_states, conditioning_emb)
337
+
338
+ forwarded_states = self.DenseReluDense(forwarded_states)
339
+ hidden_states = hidden_states + self.dropout(forwarded_states)
340
+ return hidden_states
341
+
342
+
343
+ class T5DenseGatedActDense(nn.Module):
344
+ r"""
345
+ T5 style feed-forward layer with gated activations and dropout.
346
+
347
+ Args:
348
+ d_model (`int`):
349
+ Size of the input hidden states.
350
+ d_ff (`int`):
351
+ Size of the intermediate feed-forward layer.
352
+ dropout_rate (`float`):
353
+ Dropout probability.
354
+ """
355
+
356
+ def __init__(self, d_model: int, d_ff: int, dropout_rate: float):
357
+ super().__init__()
358
+ self.wi_0 = nn.Linear(d_model, d_ff, bias=False)
359
+ self.wi_1 = nn.Linear(d_model, d_ff, bias=False)
360
+ self.wo = nn.Linear(d_ff, d_model, bias=False)
361
+ self.dropout = nn.Dropout(dropout_rate)
362
+ self.act = NewGELUActivation()
363
+
364
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
365
+ hidden_gelu = self.act(self.wi_0(hidden_states))
366
+ hidden_linear = self.wi_1(hidden_states)
367
+ hidden_states = hidden_gelu * hidden_linear
368
+ hidden_states = self.dropout(hidden_states)
369
+
370
+ hidden_states = self.wo(hidden_states)
371
+ return hidden_states
372
+
373
+
374
+ class T5LayerNorm(nn.Module):
375
+ r"""
376
+ T5 style layer normalization module.
377
+
378
+ Args:
379
+ hidden_size (`int`):
380
+ Size of the input hidden states.
381
+ eps (`float`, `optional`, defaults to `1e-6`):
382
+ A small value used for numerical stability to avoid dividing by zero.
383
+ """
384
+
385
+ def __init__(self, hidden_size: int, eps: float = 1e-6):
386
+ """
387
+ Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
388
+ """
389
+ super().__init__()
390
+ self.weight = nn.Parameter(torch.ones(hidden_size))
391
+ self.variance_epsilon = eps
392
+
393
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
394
+ # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
395
+ # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
396
+ # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
397
+ # half-precision inputs is done in fp32
398
+
399
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
400
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
401
+
402
+ # convert into half-precision if necessary
403
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
404
+ hidden_states = hidden_states.to(self.weight.dtype)
405
+
406
+ return self.weight * hidden_states
407
+
408
+
409
+ class NewGELUActivation(nn.Module):
410
+ """
411
+ Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
412
+ the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
413
+ """
414
+
415
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
416
+ return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
417
+
418
+
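The `tanh` formula above is the same approximation exposed by PyTorch's built-in GELU; a quick check, assuming a reasonably recent PyTorch that supports `approximate="tanh"`:

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-3, 3, 7)
print(torch.allclose(NewGELUActivation()(x), F.gelu(x, approximate="tanh")))  # True
```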
419
+ class T5FiLMLayer(nn.Module):
420
+ """
421
+ T5 style FiLM Layer.
422
+
423
+ Args:
424
+ in_features (`int`):
425
+ Number of input features.
426
+ out_features (`int`):
427
+ Number of output features.
428
+ """
429
+
430
+ def __init__(self, in_features: int, out_features: int):
431
+ super().__init__()
432
+ self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False)
433
+
434
+ def forward(self, x: torch.FloatTensor, conditioning_emb: torch.FloatTensor) -> torch.FloatTensor:
435
+ emb = self.scale_bias(conditioning_emb)
436
+ scale, shift = torch.chunk(emb, 2, -1)
437
+ x = x * (1 + scale) + shift
438
+ return x
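A small sketch combining the RMS-style `T5LayerNorm` with the FiLM layer above (dimensions illustrative; the conditioning embedding is `4 * d_model` wide, matching how `T5FilmDecoder` builds it):

```python
import torch

d_model = 768
norm = T5LayerNorm(hidden_size=d_model)
film = T5FiLMLayer(in_features=4 * d_model, out_features=d_model)

x = torch.randn(2, 16, d_model)          # [batch, seq, d_model]
cond = torch.randn(2, 1, 4 * d_model)    # broadcast over the sequence dimension
print(film(norm(x), cond).shape)         # torch.Size([2, 16, 768])
```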
diffusers/models/transformer_2d.py ADDED
@@ -0,0 +1,442 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, Optional
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from ..configuration_utils import ConfigMixin, register_to_config
22
+ from ..models.embeddings import ImagePositionalEmbeddings
23
+ from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate
24
+ from .attention import BasicTransformerBlock
25
+ from .embeddings import CaptionProjection, PatchEmbed
26
+ from .lora import LoRACompatibleConv, LoRACompatibleLinear
27
+ from .modeling_utils import ModelMixin
28
+ from .normalization import AdaLayerNormSingle
29
+
30
+
31
+ @dataclass
32
+ class Transformer2DModelOutput(BaseOutput):
33
+ """
34
+ The output of [`Transformer2DModel`].
35
+
36
+ Args:
37
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
38
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
39
+ distributions for the unnoised latent pixels.
40
+ """
41
+
42
+ sample: torch.FloatTensor
43
+
44
+
45
+ class Transformer2DModel(ModelMixin, ConfigMixin):
46
+ """
47
+ A 2D Transformer model for image-like data.
48
+
49
+ Parameters:
50
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
51
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
52
+ in_channels (`int`, *optional*):
53
+ The number of channels in the input and output (specify if the input is **continuous**).
54
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
55
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
56
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
57
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
58
+ This is fixed during training since it is used to learn a number of position embeddings.
59
+ num_vector_embeds (`int`, *optional*):
60
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
61
+ Includes the class for the masked latent pixel.
62
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
63
+ num_embeds_ada_norm ( `int`, *optional*):
64
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
65
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
66
+ added to the hidden states.
67
+
68
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
69
+ attention_bias (`bool`, *optional*):
70
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
71
+ """
72
+
73
+ @register_to_config
74
+ def __init__(
75
+ self,
76
+ num_attention_heads: int = 16,
77
+ attention_head_dim: int = 88,
78
+ in_channels: Optional[int] = None,
79
+ out_channels: Optional[int] = None,
80
+ num_layers: int = 1,
81
+ dropout: float = 0.0,
82
+ norm_num_groups: int = 32,
83
+ cross_attention_dim: Optional[int] = None,
84
+ attention_bias: bool = False,
85
+ sample_size: Optional[int] = None,
86
+ num_vector_embeds: Optional[int] = None,
87
+ patch_size: Optional[int] = None,
88
+ activation_fn: str = "geglu",
89
+ num_embeds_ada_norm: Optional[int] = None,
90
+ use_linear_projection: bool = False,
91
+ only_cross_attention: bool = False,
92
+ double_self_attention: bool = False,
93
+ upcast_attention: bool = False,
94
+ norm_type: str = "layer_norm",
95
+ norm_elementwise_affine: bool = True,
96
+ norm_eps: float = 1e-5,
97
+ attention_type: str = "default",
98
+ caption_channels: int = None,
99
+ ):
100
+ super().__init__()
101
+ self.use_linear_projection = use_linear_projection
102
+ self.num_attention_heads = num_attention_heads
103
+ self.attention_head_dim = attention_head_dim
104
+ inner_dim = num_attention_heads * attention_head_dim
105
+
106
+ conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
107
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
108
+
109
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
110
+ # Define whether input is continuous or discrete depending on configuration
111
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
112
+ self.is_input_vectorized = num_vector_embeds is not None
113
+ self.is_input_patches = in_channels is not None and patch_size is not None
114
+
115
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
116
+ deprecation_message = (
117
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
118
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
119
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
120
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
121
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
122
+ )
123
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
124
+ norm_type = "ada_norm"
125
+
126
+ if self.is_input_continuous and self.is_input_vectorized:
127
+ raise ValueError(
128
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
129
+ " sure that either `in_channels` or `num_vector_embeds` is None."
130
+ )
131
+ elif self.is_input_vectorized and self.is_input_patches:
132
+ raise ValueError(
133
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
134
+ " sure that either `num_vector_embeds` or `num_patches` is None."
135
+ )
136
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
137
+ raise ValueError(
138
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
139
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
140
+ )
141
+
142
+ # 2. Define input layers
143
+ if self.is_input_continuous:
144
+ self.in_channels = in_channels
145
+
146
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
147
+ if use_linear_projection:
148
+ self.proj_in = linear_cls(in_channels, inner_dim)
149
+ else:
150
+ self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
151
+ elif self.is_input_vectorized:
152
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
153
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
154
+
155
+ self.height = sample_size
156
+ self.width = sample_size
157
+ self.num_vector_embeds = num_vector_embeds
158
+ self.num_latent_pixels = self.height * self.width
159
+
160
+ self.latent_image_embedding = ImagePositionalEmbeddings(
161
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
162
+ )
163
+ elif self.is_input_patches:
164
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
165
+
166
+ self.height = sample_size
167
+ self.width = sample_size
168
+
169
+ self.patch_size = patch_size
170
+ interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
171
+ interpolation_scale = max(interpolation_scale, 1)
172
+ self.pos_embed = PatchEmbed(
173
+ height=sample_size,
174
+ width=sample_size,
175
+ patch_size=patch_size,
176
+ in_channels=in_channels,
177
+ embed_dim=inner_dim,
178
+ interpolation_scale=interpolation_scale,
179
+ )
180
+
181
+ # 3. Define transformers blocks
182
+ self.transformer_blocks = nn.ModuleList(
183
+ [
184
+ BasicTransformerBlock(
185
+ inner_dim,
186
+ num_attention_heads,
187
+ attention_head_dim,
188
+ dropout=dropout,
189
+ cross_attention_dim=cross_attention_dim,
190
+ activation_fn=activation_fn,
191
+ num_embeds_ada_norm=num_embeds_ada_norm,
192
+ attention_bias=attention_bias,
193
+ only_cross_attention=only_cross_attention,
194
+ double_self_attention=double_self_attention,
195
+ upcast_attention=upcast_attention,
196
+ norm_type=norm_type,
197
+ norm_elementwise_affine=norm_elementwise_affine,
198
+ norm_eps=norm_eps,
199
+ attention_type=attention_type,
200
+ )
201
+ for d in range(num_layers)
202
+ ]
203
+ )
204
+
205
+ # 4. Define output layers
206
+ self.out_channels = in_channels if out_channels is None else out_channels
207
+ if self.is_input_continuous:
208
+ # TODO: should use out_channels for continuous projections
209
+ if use_linear_projection:
210
+ self.proj_out = linear_cls(inner_dim, in_channels)
211
+ else:
212
+ self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
213
+ elif self.is_input_vectorized:
214
+ self.norm_out = nn.LayerNorm(inner_dim)
215
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
216
+ elif self.is_input_patches and norm_type != "ada_norm_single":
217
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
218
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
219
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
220
+ elif self.is_input_patches and norm_type == "ada_norm_single":
221
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
222
+ self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
223
+ self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
224
+
225
+ # 5. PixArt-Alpha blocks.
226
+ self.adaln_single = None
227
+ self.use_additional_conditions = False
228
+ if norm_type == "ada_norm_single":
229
+ self.use_additional_conditions = self.config.sample_size == 128
230
+ # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
231
+ # additional conditions until we find better name
232
+ self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
233
+
234
+ self.caption_projection = None
235
+ if caption_channels is not None:
236
+ self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim)
237
+
238
+ self.gradient_checkpointing = False
239
+
240
+ def forward(
241
+ self,
242
+ hidden_states: torch.Tensor,
243
+ encoder_hidden_states: Optional[torch.Tensor] = None,
244
+ timestep: Optional[torch.LongTensor] = None,
245
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
246
+ class_labels: Optional[torch.LongTensor] = None,
247
+ cross_attention_kwargs: Dict[str, Any] = None,
248
+ attention_mask: Optional[torch.Tensor] = None,
249
+ encoder_attention_mask: Optional[torch.Tensor] = None,
250
+ return_dict: bool = True,
251
+ ):
252
+ """
253
+ The [`Transformer2DModel`] forward method.
254
+
255
+ Args:
256
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
257
+ Input `hidden_states`.
258
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
259
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
260
+ self-attention.
261
+ timestep ( `torch.LongTensor`, *optional*):
262
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
263
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
264
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
265
+ `AdaLayerNormZero`.
266
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
267
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
268
+ `self.processor` in
269
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
270
+ attention_mask ( `torch.Tensor`, *optional*):
271
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
272
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
273
+ negative values to the attention scores corresponding to "discard" tokens.
274
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
275
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
276
+
277
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
278
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
279
+
280
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
281
+ above. This bias will be added to the cross-attention scores.
282
+ return_dict (`bool`, *optional*, defaults to `True`):
283
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
284
+ tuple.
285
+
286
+ Returns:
287
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
288
+ `tuple` where the first element is the sample tensor.
289
+ """
290
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
291
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
292
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
293
+ # expects mask of shape:
294
+ # [batch, key_tokens]
295
+ # adds singleton query_tokens dimension:
296
+ # [batch, 1, key_tokens]
297
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
298
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
299
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
300
+ if attention_mask is not None and attention_mask.ndim == 2:
301
+ # assume that mask is expressed as:
302
+ # (1 = keep, 0 = discard)
303
+ # convert face_hair_mask into a bias that can be added to attention scores:
304
+ # (keep = +0, discard = -10000.0)
305
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
306
+ attention_mask = attention_mask.unsqueeze(1)
307
+
308
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
309
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
310
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
311
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
312
+
313
+ # Retrieve lora scale.
314
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
315
+
316
+ # 1. Input
317
+ if self.is_input_continuous:
318
+ batch, _, height, width = hidden_states.shape
319
+ residual = hidden_states
320
+
321
+ hidden_states = self.norm(hidden_states)
322
+ if not self.use_linear_projection:
323
+ hidden_states = (
324
+ self.proj_in(hidden_states, scale=lora_scale)
325
+ if not USE_PEFT_BACKEND
326
+ else self.proj_in(hidden_states)
327
+ )
328
+ inner_dim = hidden_states.shape[1]
329
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
330
+ else:
331
+ inner_dim = hidden_states.shape[1]
332
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
333
+ hidden_states = (
334
+ self.proj_in(hidden_states, scale=lora_scale)
335
+ if not USE_PEFT_BACKEND
336
+ else self.proj_in(hidden_states)
337
+ )
338
+
339
+ elif self.is_input_vectorized:
340
+ hidden_states = self.latent_image_embedding(hidden_states)
341
+ elif self.is_input_patches:
342
+ height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
343
+ hidden_states = self.pos_embed(hidden_states)
344
+
345
+ if self.adaln_single is not None:
346
+ if self.use_additional_conditions and added_cond_kwargs is None:
347
+ raise ValueError(
348
+ "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
349
+ )
350
+ batch_size = hidden_states.shape[0]
351
+ timestep, embedded_timestep = self.adaln_single(
352
+ timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
353
+ )
354
+
355
+ # 2. Blocks
356
+ if self.caption_projection is not None:
357
+ batch_size = hidden_states.shape[0]
358
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
359
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
360
+
361
+ for block in self.transformer_blocks:
362
+ if self.training and self.gradient_checkpointing:
363
+ hidden_states = torch.utils.checkpoint.checkpoint(
364
+ block,
365
+ hidden_states,
366
+ attention_mask,
367
+ encoder_hidden_states,
368
+ encoder_attention_mask,
369
+ timestep,
370
+ cross_attention_kwargs,
371
+ class_labels,
372
+ use_reentrant=False,
373
+ )
374
+ else:
375
+ hidden_states = block(
376
+ hidden_states,
377
+ attention_mask=attention_mask,
378
+ encoder_hidden_states=encoder_hidden_states,
379
+ encoder_attention_mask=encoder_attention_mask,
380
+ timestep=timestep,
381
+ cross_attention_kwargs=cross_attention_kwargs,
382
+ class_labels=class_labels,
383
+ )
384
+
385
+ # 3. Output
386
+ if self.is_input_continuous:
387
+ if not self.use_linear_projection:
388
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
389
+ hidden_states = (
390
+ self.proj_out(hidden_states, scale=lora_scale)
391
+ if not USE_PEFT_BACKEND
392
+ else self.proj_out(hidden_states)
393
+ )
394
+ else:
395
+ hidden_states = (
396
+ self.proj_out(hidden_states, scale=lora_scale)
397
+ if not USE_PEFT_BACKEND
398
+ else self.proj_out(hidden_states)
399
+ )
400
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
401
+
402
+ output = hidden_states + residual
403
+ elif self.is_input_vectorized:
404
+ hidden_states = self.norm_out(hidden_states)
405
+ logits = self.out(hidden_states)
406
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
407
+ logits = logits.permute(0, 2, 1)
408
+
409
+ # log(p(x_0))
410
+ output = F.log_softmax(logits.double(), dim=1).float()
411
+
412
+ if self.is_input_patches:
413
+ if self.config.norm_type != "ada_norm_single":
414
+ conditioning = self.transformer_blocks[0].norm1.emb(
415
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
416
+ )
417
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
418
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
419
+ hidden_states = self.proj_out_2(hidden_states)
420
+ elif self.config.norm_type == "ada_norm_single":
421
+ shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
422
+ hidden_states = self.norm_out(hidden_states)
423
+ # Modulation
424
+ hidden_states = hidden_states * (1 + scale) + shift
425
+ hidden_states = self.proj_out(hidden_states)
426
+ hidden_states = hidden_states.squeeze(1)
427
+
428
+ # unpatchify
429
+ if self.adaln_single is None:
430
+ height = width = int(hidden_states.shape[1] ** 0.5)
431
+ hidden_states = hidden_states.reshape(
432
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
433
+ )
434
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
435
+ output = hidden_states.reshape(
436
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
437
+ )
438
+
439
+ if not return_dict:
440
+ return (output,)
441
+
442
+ return Transformer2DModelOutput(sample=output)
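The 2D mask handling at the top of this forward pass is the part most callers trip over, so here is a minimal standalone sketch of the same keep/discard-to-bias conversion (the function name and toy tensor are illustrative, not part of the library):

import torch

def mask_to_bias(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # (batch, key_tokens) with 1 = keep, 0 = discard
    # -> (batch, 1, key_tokens) additive bias with 0 = keep, -10000 = discard
    bias = (1 - attention_mask.to(dtype)) * -10000.0
    return bias.unsqueeze(1)

mask = torch.tensor([[1, 1, 0, 0]])
bias = mask_to_bias(mask, torch.float32)  # broadcasts over heads and query tokens
print(bias.shape)                         # torch.Size([1, 1, 4])

A bias already in this 3D form skips the conversion branch above, since the ndim == 2 check no longer matches.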
diffusers/models/transformer_temporal.py ADDED
@@ -0,0 +1,197 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, Optional
16
+
17
+ import torch
18
+ from torch import nn
19
+
20
+ from ..configuration_utils import ConfigMixin, register_to_config
21
+ from ..utils import BaseOutput
22
+ from .attention import BasicTransformerBlock
23
+ from .modeling_utils import ModelMixin
24
+
25
+
26
+ @dataclass
27
+ class TransformerTemporalModelOutput(BaseOutput):
28
+ """
29
+ The output of [`TransformerTemporalModel`].
30
+
31
+ Args:
32
+ sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
33
+ The hidden states output conditioned on `encoder_hidden_states` input.
34
+ """
35
+
36
+ sample: torch.FloatTensor
37
+
38
+
39
+ class TransformerTemporalModel(ModelMixin, ConfigMixin):
40
+ """
41
+ A Transformer model for video-like data.
42
+
43
+ Parameters:
44
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
45
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
46
+ in_channels (`int`, *optional*):
47
+ The number of channels in the input and output (specify if the input is **continuous**).
48
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
49
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
50
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
51
+ attention_bias (`bool`, *optional*):
52
+ Configure if the `TransformerBlock` attention should contain a bias parameter.
53
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
54
+ This is fixed during training since it is used to learn a number of position embeddings.
55
+ activation_fn (`str`, *optional*, defaults to `"geglu"`):
56
+ Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
57
+ activation functions.
58
+ norm_elementwise_affine (`bool`, *optional*):
59
+ Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
60
+ double_self_attention (`bool`, *optional*):
61
+ Configure if each `TransformerBlock` should contain two self-attention layers.
62
+ positional_embeddings: (`str`, *optional*):
63
+ The type of positional embeddings to apply to the sequence input before it is used.
64
+ num_positional_embeddings: (`int`, *optional*):
65
+ The maximum length of the sequence over which to apply positional embeddings.
66
+ """
67
+
68
+ @register_to_config
69
+ def __init__(
70
+ self,
71
+ num_attention_heads: int = 16,
72
+ attention_head_dim: int = 88,
73
+ in_channels: Optional[int] = None,
74
+ out_channels: Optional[int] = None,
75
+ num_layers: int = 1,
76
+ dropout: float = 0.0,
77
+ norm_num_groups: int = 32,
78
+ cross_attention_dim: Optional[int] = None,
79
+ attention_bias: bool = False,
80
+ sample_size: Optional[int] = None,
81
+ activation_fn: str = "geglu",
82
+ norm_elementwise_affine: bool = True,
83
+ double_self_attention: bool = True,
84
+ positional_embeddings: Optional[str] = None,
85
+ num_positional_embeddings: Optional[int] = None,
86
+ ):
87
+ super().__init__()
88
+ self.num_attention_heads = num_attention_heads
89
+ self.attention_head_dim = attention_head_dim
90
+ inner_dim = num_attention_heads * attention_head_dim
91
+
92
+ self.in_channels = in_channels
93
+
94
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
95
+ self.proj_in = nn.Linear(in_channels, inner_dim)
96
+
97
+ # 3. Define transformers blocks
98
+ self.transformer_blocks = nn.ModuleList(
99
+ [
100
+ BasicTransformerBlock(
101
+ inner_dim,
102
+ num_attention_heads,
103
+ attention_head_dim,
104
+ dropout=dropout,
105
+ cross_attention_dim=cross_attention_dim,
106
+ activation_fn=activation_fn,
107
+ attention_bias=attention_bias,
108
+ double_self_attention=double_self_attention,
109
+ norm_elementwise_affine=norm_elementwise_affine,
110
+ positional_embeddings=positional_embeddings,
111
+ num_positional_embeddings=num_positional_embeddings,
112
+ )
113
+ for d in range(num_layers)
114
+ ]
115
+ )
116
+
117
+ self.proj_out = nn.Linear(inner_dim, in_channels)
118
+
119
+ def forward(
120
+ self,
121
+ hidden_states: torch.FloatTensor,
122
+ encoder_hidden_states: Optional[torch.LongTensor] = None,
123
+ timestep: Optional[torch.LongTensor] = None,
124
+ class_labels: torch.LongTensor = None,
125
+ num_frames: int = 1,
126
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
127
+ return_dict: bool = True,
128
+ ) -> TransformerTemporalModelOutput:
129
+ """
130
+ The [`TransformerTemporalModel`] forward method.
131
+
132
+ Args:
133
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
134
+ Input hidden_states.
135
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
136
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
137
+ self-attention.
138
+ timestep ( `torch.LongTensor`, *optional*):
139
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
140
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
141
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
142
+ `AdaLayerNormZero`.
143
+ num_frames (`int`, *optional*, defaults to 1):
144
+ The number of frames to be processed per batch. This is used to reshape the hidden states.
145
+ cross_attention_kwargs (`dict`, *optional*):
146
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
147
+ `self.processor` in
148
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
149
+ return_dict (`bool`, *optional*, defaults to `True`):
150
+ Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain
151
+ tuple.
152
+
153
+ Returns:
154
+ [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
155
+ If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
156
+ returned, otherwise a `tuple` where the first element is the sample tensor.
157
+ """
158
+ # 1. Input
159
+ batch_frames, channel, height, width = hidden_states.shape
160
+ batch_size = batch_frames // num_frames
161
+
162
+ residual = hidden_states
163
+
164
+ hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
165
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
166
+
167
+ hidden_states = self.norm(hidden_states)
168
+ hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
169
+
170
+ hidden_states = self.proj_in(hidden_states)
171
+
172
+ # 2. Blocks
173
+ for block in self.transformer_blocks:
174
+ hidden_states = block(
175
+ hidden_states,
176
+ encoder_hidden_states=encoder_hidden_states,
177
+ timestep=timestep,
178
+ cross_attention_kwargs=cross_attention_kwargs,
179
+ class_labels=class_labels,
180
+ )
181
+
182
+ # 3. Output
183
+ hidden_states = self.proj_out(hidden_states)
184
+ hidden_states = (
185
+ hidden_states[None, None, :]
186
+ .reshape(batch_size, height, width, num_frames, channel)
187
+ .permute(0, 3, 4, 1, 2)
188
+ .contiguous()
189
+ )
190
+ hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
191
+
192
+ output = hidden_states + residual
193
+
194
+ if not return_dict:
195
+ return (output,)
196
+
197
+ return TransformerTemporalModelOutput(sample=output)
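The reshapes in this forward pass turn a stack of frames into one temporal token sequence per spatial location. A small sketch of just that bookkeeping, with made-up sizes chosen purely for illustration:

import torch

batch_size, num_frames, channel, height, width = 2, 4, 8, 16, 16
hidden_states = torch.randn(batch_size * num_frames, channel, height, width)

# (batch*frames, C, H, W) -> (batch, frames, C, H, W) -> (batch, C, frames, H, W)
x = hidden_states.reshape(batch_size, num_frames, channel, height, width).permute(0, 2, 1, 3, 4)

# flatten the spatial grid so every pixel becomes a sequence of num_frames tokens
x = x.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
print(x.shape)  # torch.Size([512, 4, 8])

The transformer blocks then attend only along the num_frames axis, and the inverse reshape at the end of the method restores the (batch*frames, C, H, W) layout before the residual is added.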
diffusers/models/unet_1d.py ADDED
@@ -0,0 +1,255 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from ..configuration_utils import ConfigMixin, register_to_config
22
+ from ..utils import BaseOutput
23
+ from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
24
+ from .modeling_utils import ModelMixin
25
+ from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
26
+
27
+
28
+ @dataclass
29
+ class UNet1DOutput(BaseOutput):
30
+ """
31
+ The output of [`UNet1DModel`].
32
+
33
+ Args:
34
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
35
+ The hidden states output from the last layer of the model.
36
+ """
37
+
38
+ sample: torch.FloatTensor
39
+
40
+
41
+ class UNet1DModel(ModelMixin, ConfigMixin):
42
+ r"""
43
+ A 1D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
44
+
45
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
46
+ for all models (such as downloading or saving).
47
+
48
+ Parameters:
49
+ sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime.
50
+ in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
51
+ out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
52
+ extra_in_channels (`int`, *optional*, defaults to 0):
53
+ Number of additional channels to be added to the input of the first down block. Useful for cases where the
54
+ input data has more channels than what the model was initially designed for.
55
+ time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
56
+ freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding.
57
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
58
+ Whether to flip sin to cos for Fourier time embedding.
59
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D")`):
60
+ Tuple of downsample block types.
61
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip")`):
62
+ Tuple of upsample block types.
63
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`):
64
+ Tuple of block output channels.
65
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet.
66
+ out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet.
67
+ act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks.
68
+ norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization.
69
+ layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block.
70
+ downsample_each_block (`int`, *optional*, defaults to `False`):
71
+ Experimental feature for using a UNet without upsampling.
72
+ """
73
+
74
+ @register_to_config
75
+ def __init__(
76
+ self,
77
+ sample_size: int = 65536,
78
+ sample_rate: Optional[int] = None,
79
+ in_channels: int = 2,
80
+ out_channels: int = 2,
81
+ extra_in_channels: int = 0,
82
+ time_embedding_type: str = "fourier",
83
+ flip_sin_to_cos: bool = True,
84
+ use_timestep_embedding: bool = False,
85
+ freq_shift: float = 0.0,
86
+ down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
87
+ up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
88
+ mid_block_type: Tuple[str] = "UNetMidBlock1D",
89
+ out_block_type: str = None,
90
+ block_out_channels: Tuple[int] = (32, 32, 64),
91
+ act_fn: str = None,
92
+ norm_num_groups: int = 8,
93
+ layers_per_block: int = 1,
94
+ downsample_each_block: bool = False,
95
+ ):
96
+ super().__init__()
97
+ self.sample_size = sample_size
98
+
99
+ # time
100
+ if time_embedding_type == "fourier":
101
+ self.time_proj = GaussianFourierProjection(
102
+ embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
103
+ )
104
+ timestep_input_dim = 2 * block_out_channels[0]
105
+ elif time_embedding_type == "positional":
106
+ self.time_proj = Timesteps(
107
+ block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
108
+ )
109
+ timestep_input_dim = block_out_channels[0]
110
+
111
+ if use_timestep_embedding:
112
+ time_embed_dim = block_out_channels[0] * 4
113
+ self.time_mlp = TimestepEmbedding(
114
+ in_channels=timestep_input_dim,
115
+ time_embed_dim=time_embed_dim,
116
+ act_fn=act_fn,
117
+ out_dim=block_out_channels[0],
118
+ )
119
+
120
+ self.down_blocks = nn.ModuleList([])
121
+ self.mid_block = None
122
+ self.up_blocks = nn.ModuleList([])
123
+ self.out_block = None
124
+
125
+ # down
126
+ output_channel = in_channels
127
+ for i, down_block_type in enumerate(down_block_types):
128
+ input_channel = output_channel
129
+ output_channel = block_out_channels[i]
130
+
131
+ if i == 0:
132
+ input_channel += extra_in_channels
133
+
134
+ is_final_block = i == len(block_out_channels) - 1
135
+
136
+ down_block = get_down_block(
137
+ down_block_type,
138
+ num_layers=layers_per_block,
139
+ in_channels=input_channel,
140
+ out_channels=output_channel,
141
+ temb_channels=block_out_channels[0],
142
+ add_downsample=not is_final_block or downsample_each_block,
143
+ )
144
+ self.down_blocks.append(down_block)
145
+
146
+ # mid
147
+ self.mid_block = get_mid_block(
148
+ mid_block_type,
149
+ in_channels=block_out_channels[-1],
150
+ mid_channels=block_out_channels[-1],
151
+ out_channels=block_out_channels[-1],
152
+ embed_dim=block_out_channels[0],
153
+ num_layers=layers_per_block,
154
+ add_downsample=downsample_each_block,
155
+ )
156
+
157
+ # up
158
+ reversed_block_out_channels = list(reversed(block_out_channels))
159
+ output_channel = reversed_block_out_channels[0]
160
+ if out_block_type is None:
161
+ final_upsample_channels = out_channels
162
+ else:
163
+ final_upsample_channels = block_out_channels[0]
164
+
165
+ for i, up_block_type in enumerate(up_block_types):
166
+ prev_output_channel = output_channel
167
+ output_channel = (
168
+ reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels
169
+ )
170
+
171
+ is_final_block = i == len(block_out_channels) - 1
172
+
173
+ up_block = get_up_block(
174
+ up_block_type,
175
+ num_layers=layers_per_block,
176
+ in_channels=prev_output_channel,
177
+ out_channels=output_channel,
178
+ temb_channels=block_out_channels[0],
179
+ add_upsample=not is_final_block,
180
+ )
181
+ self.up_blocks.append(up_block)
182
+ prev_output_channel = output_channel
183
+
184
+ # out
185
+ num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
186
+ self.out_block = get_out_block(
187
+ out_block_type=out_block_type,
188
+ num_groups_out=num_groups_out,
189
+ embed_dim=block_out_channels[0],
190
+ out_channels=out_channels,
191
+ act_fn=act_fn,
192
+ fc_dim=block_out_channels[-1] // 4,
193
+ )
194
+
195
+ def forward(
196
+ self,
197
+ sample: torch.FloatTensor,
198
+ timestep: Union[torch.Tensor, float, int],
199
+ return_dict: bool = True,
200
+ ) -> Union[UNet1DOutput, Tuple]:
201
+ r"""
202
+ The [`UNet1DModel`] forward method.
203
+
204
+ Args:
205
+ sample (`torch.FloatTensor`):
206
+ The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
207
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
208
+ return_dict (`bool`, *optional*, defaults to `True`):
209
+ Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.
210
+
211
+ Returns:
212
+ [`~models.unet_1d.UNet1DOutput`] or `tuple`:
213
+ If `return_dict` is True, an [`~models.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is
214
+ returned where the first element is the sample tensor.
215
+ """
216
+
217
+ # 1. time
218
+ timesteps = timestep
219
+ if not torch.is_tensor(timesteps):
220
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
221
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
222
+ timesteps = timesteps[None].to(sample.device)
223
+
224
+ timestep_embed = self.time_proj(timesteps)
225
+ if self.config.use_timestep_embedding:
226
+ timestep_embed = self.time_mlp(timestep_embed)
227
+ else:
228
+ timestep_embed = timestep_embed[..., None]
229
+ timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
230
+ timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:]))
231
+
232
+ # 2. down
233
+ down_block_res_samples = ()
234
+ for downsample_block in self.down_blocks:
235
+ sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed)
236
+ down_block_res_samples += res_samples
237
+
238
+ # 3. mid
239
+ if self.mid_block:
240
+ sample = self.mid_block(sample, timestep_embed)
241
+
242
+ # 4. up
243
+ for i, upsample_block in enumerate(self.up_blocks):
244
+ res_samples = down_block_res_samples[-1:]
245
+ down_block_res_samples = down_block_res_samples[:-1]
246
+ sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed)
247
+
248
+ # 5. post-process
249
+ if self.out_block:
250
+ sample = self.out_block(sample, timestep_embed)
251
+
252
+ if not return_dict:
253
+ return (sample,)
254
+
255
+ return UNet1DOutput(sample=sample)
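As a rough usage sketch, not taken from the library's documentation: the default config pairs an 8-dimensional Gaussian Fourier projection (a 16-channel time embedding) with a first DownBlock1DNoSkip that concatenates that embedding onto the input, so extra_in_channels=16 is assumed below to make the channel counts line up, mirroring the Dance Diffusion style of configuration:

import torch
from diffusers import UNet1DModel

# extra_in_channels=16 absorbs the 16-channel Fourier time embedding that the
# first DownBlock1DNoSkip concatenates onto the 2-channel input.
model = UNet1DModel(sample_size=256, extra_in_channels=16)

sample = torch.randn(1, 2, 256)        # (batch, channels, length); length divisible by 4
timestep = torch.tensor([10])
out = model(sample, timestep).sample   # same shape as the input
print(out.shape)                       # torch.Size([1, 2, 256])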
diffusers/models/unet_1d_blocks.py ADDED
@@ -0,0 +1,702 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ from typing import Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from .activations import get_activation
22
+ from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims
23
+
24
+
25
+ class DownResnetBlock1D(nn.Module):
26
+ def __init__(
27
+ self,
28
+ in_channels: int,
29
+ out_channels: Optional[int] = None,
30
+ num_layers: int = 1,
31
+ conv_shortcut: bool = False,
32
+ temb_channels: int = 32,
33
+ groups: int = 32,
34
+ groups_out: Optional[int] = None,
35
+ non_linearity: Optional[str] = None,
36
+ time_embedding_norm: str = "default",
37
+ output_scale_factor: float = 1.0,
38
+ add_downsample: bool = True,
39
+ ):
40
+ super().__init__()
41
+ self.in_channels = in_channels
42
+ out_channels = in_channels if out_channels is None else out_channels
43
+ self.out_channels = out_channels
44
+ self.use_conv_shortcut = conv_shortcut
45
+ self.time_embedding_norm = time_embedding_norm
46
+ self.add_downsample = add_downsample
47
+ self.output_scale_factor = output_scale_factor
48
+
49
+ if groups_out is None:
50
+ groups_out = groups
51
+
52
+ # there will always be at least one resnet
53
+ resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)]
54
+
55
+ for _ in range(num_layers):
56
+ resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
57
+
58
+ self.resnets = nn.ModuleList(resnets)
59
+
60
+ if non_linearity is None:
61
+ self.nonlinearity = None
62
+ else:
63
+ self.nonlinearity = get_activation(non_linearity)
64
+
65
+ self.downsample = None
66
+ if add_downsample:
67
+ self.downsample = Downsample1D(out_channels, use_conv=True, padding=1)
68
+
69
+ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
70
+ output_states = ()
71
+
72
+ hidden_states = self.resnets[0](hidden_states, temb)
73
+ for resnet in self.resnets[1:]:
74
+ hidden_states = resnet(hidden_states, temb)
75
+
76
+ output_states += (hidden_states,)
77
+
78
+ if self.nonlinearity is not None:
79
+ hidden_states = self.nonlinearity(hidden_states)
80
+
81
+ if self.downsample is not None:
82
+ hidden_states = self.downsample(hidden_states)
83
+
84
+ return hidden_states, output_states
85
+
86
+
87
+ class UpResnetBlock1D(nn.Module):
88
+ def __init__(
89
+ self,
90
+ in_channels: int,
91
+ out_channels: Optional[int] = None,
92
+ num_layers: int = 1,
93
+ temb_channels: int = 32,
94
+ groups: int = 32,
95
+ groups_out: Optional[int] = None,
96
+ non_linearity: Optional[str] = None,
97
+ time_embedding_norm: str = "default",
98
+ output_scale_factor: float = 1.0,
99
+ add_upsample: bool = True,
100
+ ):
101
+ super().__init__()
102
+ self.in_channels = in_channels
103
+ out_channels = in_channels if out_channels is None else out_channels
104
+ self.out_channels = out_channels
105
+ self.time_embedding_norm = time_embedding_norm
106
+ self.add_upsample = add_upsample
107
+ self.output_scale_factor = output_scale_factor
108
+
109
+ if groups_out is None:
110
+ groups_out = groups
111
+
112
+ # there will always be at least one resnet
113
+ resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)]
114
+
115
+ for _ in range(num_layers):
116
+ resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
117
+
118
+ self.resnets = nn.ModuleList(resnets)
119
+
120
+ if non_linearity is None:
121
+ self.nonlinearity = None
122
+ else:
123
+ self.nonlinearity = get_activation(non_linearity)
124
+
125
+ self.upsample = None
126
+ if add_upsample:
127
+ self.upsample = Upsample1D(out_channels, use_conv_transpose=True)
128
+
129
+ def forward(
130
+ self,
131
+ hidden_states: torch.FloatTensor,
132
+ res_hidden_states_tuple: Optional[Tuple[torch.FloatTensor, ...]] = None,
133
+ temb: Optional[torch.FloatTensor] = None,
134
+ ) -> torch.FloatTensor:
135
+ if res_hidden_states_tuple is not None:
136
+ res_hidden_states = res_hidden_states_tuple[-1]
137
+ hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1)
138
+
139
+ hidden_states = self.resnets[0](hidden_states, temb)
140
+ for resnet in self.resnets[1:]:
141
+ hidden_states = resnet(hidden_states, temb)
142
+
143
+ if self.nonlinearity is not None:
144
+ hidden_states = self.nonlinearity(hidden_states)
145
+
146
+ if self.upsample is not None:
147
+ hidden_states = self.upsample(hidden_states)
148
+
149
+ return hidden_states
150
+
151
+
152
+ class ValueFunctionMidBlock1D(nn.Module):
153
+ def __init__(self, in_channels: int, out_channels: int, embed_dim: int):
154
+ super().__init__()
155
+ self.in_channels = in_channels
156
+ self.out_channels = out_channels
157
+ self.embed_dim = embed_dim
158
+
159
+ self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim)
160
+ self.down1 = Downsample1D(out_channels // 2, use_conv=True)
161
+ self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
162
+ self.down2 = Downsample1D(out_channels // 4, use_conv=True)
163
+
164
+ def forward(self, x: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
165
+ x = self.res1(x, temb)
166
+ x = self.down1(x)
167
+ x = self.res2(x, temb)
168
+ x = self.down2(x)
169
+ return x
170
+
171
+
172
+ class MidResTemporalBlock1D(nn.Module):
173
+ def __init__(
174
+ self,
175
+ in_channels: int,
176
+ out_channels: int,
177
+ embed_dim: int,
178
+ num_layers: int = 1,
179
+ add_downsample: bool = False,
180
+ add_upsample: bool = False,
181
+ non_linearity: Optional[str] = None,
182
+ ):
183
+ super().__init__()
184
+ self.in_channels = in_channels
185
+ self.out_channels = out_channels
186
+ self.add_downsample = add_downsample
187
+
188
+ # there will always be at least one resnet
189
+ resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]
190
+
191
+ for _ in range(num_layers):
192
+ resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim))
193
+
194
+ self.resnets = nn.ModuleList(resnets)
195
+
196
+ if non_linearity is None:
197
+ self.nonlinearity = None
198
+ else:
199
+ self.nonlinearity = get_activation(non_linearity)
200
+
201
+ self.upsample = None
202
+ if add_upsample:
203
+ self.upsample = Downsample1D(out_channels, use_conv=True)
204
+
205
+ self.downsample = None
206
+ if add_downsample:
207
+ self.downsample = Downsample1D(out_channels, use_conv=True)
208
+
209
+ if self.upsample and self.downsample:
210
+ raise ValueError("Block cannot downsample and upsample")
211
+
212
+ def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
213
+ hidden_states = self.resnets[0](hidden_states, temb)
214
+ for resnet in self.resnets[1:]:
215
+ hidden_states = resnet(hidden_states, temb)
216
+
217
+ if self.upsample:
218
+ hidden_states = self.upsample(hidden_states)
219
+ if self.downsample:
220
+ self.downsample = self.downsample(hidden_states)
221
+
222
+ return hidden_states
223
+
224
+
225
+ class OutConv1DBlock(nn.Module):
226
+ def __init__(self, num_groups_out: int, out_channels: int, embed_dim: int, act_fn: str):
227
+ super().__init__()
228
+ self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
229
+ self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
230
+ self.final_conv1d_act = get_activation(act_fn)
231
+ self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
232
+
233
+ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
234
+ hidden_states = self.final_conv1d_1(hidden_states)
235
+ hidden_states = rearrange_dims(hidden_states)
236
+ hidden_states = self.final_conv1d_gn(hidden_states)
237
+ hidden_states = rearrange_dims(hidden_states)
238
+ hidden_states = self.final_conv1d_act(hidden_states)
239
+ hidden_states = self.final_conv1d_2(hidden_states)
240
+ return hidden_states
241
+
242
+
243
+ class OutValueFunctionBlock(nn.Module):
244
+ def __init__(self, fc_dim: int, embed_dim: int, act_fn: str = "mish"):
245
+ super().__init__()
246
+ self.final_block = nn.ModuleList(
247
+ [
248
+ nn.Linear(fc_dim + embed_dim, fc_dim // 2),
249
+ get_activation(act_fn),
250
+ nn.Linear(fc_dim // 2, 1),
251
+ ]
252
+ )
253
+
254
+ def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
255
+ hidden_states = hidden_states.view(hidden_states.shape[0], -1)
256
+ hidden_states = torch.cat((hidden_states, temb), dim=-1)
257
+ for layer in self.final_block:
258
+ hidden_states = layer(hidden_states)
259
+
260
+ return hidden_states
261
+
262
+
263
+ _kernels = {
264
+ "linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8],
265
+ "cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875],
266
+ "lanczos3": [
267
+ 0.003689131001010537,
268
+ 0.015056144446134567,
269
+ -0.03399861603975296,
270
+ -0.066637322306633,
271
+ 0.13550527393817902,
272
+ 0.44638532400131226,
273
+ 0.44638532400131226,
274
+ 0.13550527393817902,
275
+ -0.066637322306633,
276
+ -0.03399861603975296,
277
+ 0.015056144446134567,
278
+ 0.003689131001010537,
279
+ ],
280
+ }
281
+
282
+
283
+ class Downsample1d(nn.Module):
284
+ def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
285
+ super().__init__()
286
+ self.pad_mode = pad_mode
287
+ kernel_1d = torch.tensor(_kernels[kernel])
288
+ self.pad = kernel_1d.shape[0] // 2 - 1
289
+ self.register_buffer("kernel", kernel_1d)
290
+
291
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
292
+ hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
293
+ weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
294
+ indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
295
+ kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
296
+ weight[indices, indices] = kernel
297
+ return F.conv1d(hidden_states, weight, stride=2)
298
+
299
+
300
+ class Upsample1d(nn.Module):
301
+ def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
302
+ super().__init__()
303
+ self.pad_mode = pad_mode
304
+ kernel_1d = torch.tensor(_kernels[kernel]) * 2
305
+ self.pad = kernel_1d.shape[0] // 2 - 1
306
+ self.register_buffer("kernel", kernel_1d)
307
+
308
+ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
309
+ hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
310
+ weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
311
+ indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
312
+ kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
313
+ weight[indices, indices] = kernel
314
+ return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1)
315
+
316
+
317
+ class SelfAttention1d(nn.Module):
318
+ def __init__(self, in_channels: int, n_head: int = 1, dropout_rate: float = 0.0):
319
+ super().__init__()
320
+ self.channels = in_channels
321
+ self.group_norm = nn.GroupNorm(1, num_channels=in_channels)
322
+ self.num_heads = n_head
323
+
324
+ self.query = nn.Linear(self.channels, self.channels)
325
+ self.key = nn.Linear(self.channels, self.channels)
326
+ self.value = nn.Linear(self.channels, self.channels)
327
+
328
+ self.proj_attn = nn.Linear(self.channels, self.channels, bias=True)
329
+
330
+ self.dropout = nn.Dropout(dropout_rate, inplace=True)
331
+
332
+ def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
333
+ new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
334
+ # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
335
+ new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
336
+ return new_projection
337
+
338
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
339
+ residual = hidden_states
340
+ batch, channel_dim, seq = hidden_states.shape
341
+
342
+ hidden_states = self.group_norm(hidden_states)
343
+ hidden_states = hidden_states.transpose(1, 2)
344
+
345
+ query_proj = self.query(hidden_states)
346
+ key_proj = self.key(hidden_states)
347
+ value_proj = self.value(hidden_states)
348
+
349
+ query_states = self.transpose_for_scores(query_proj)
350
+ key_states = self.transpose_for_scores(key_proj)
351
+ value_states = self.transpose_for_scores(value_proj)
352
+
353
+ scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1]))
354
+
355
+ attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
356
+ attention_probs = torch.softmax(attention_scores, dim=-1)
357
+
358
+ # compute attention output
359
+ hidden_states = torch.matmul(attention_probs, value_states)
360
+
361
+ hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
362
+ new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
363
+ hidden_states = hidden_states.view(new_hidden_states_shape)
364
+
365
+ # compute next hidden_states
366
+ hidden_states = self.proj_attn(hidden_states)
367
+ hidden_states = hidden_states.transpose(1, 2)
368
+ hidden_states = self.dropout(hidden_states)
369
+
370
+ output = hidden_states + residual
371
+
372
+ return output
373
+
374
+
375
+ class ResConvBlock(nn.Module):
376
+ def __init__(self, in_channels: int, mid_channels: int, out_channels: int, is_last: bool = False):
377
+ super().__init__()
378
+ self.is_last = is_last
379
+ self.has_conv_skip = in_channels != out_channels
380
+
381
+ if self.has_conv_skip:
382
+ self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False)
383
+
384
+ self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2)
385
+ self.group_norm_1 = nn.GroupNorm(1, mid_channels)
386
+ self.gelu_1 = nn.GELU()
387
+ self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2)
388
+
389
+ if not self.is_last:
390
+ self.group_norm_2 = nn.GroupNorm(1, out_channels)
391
+ self.gelu_2 = nn.GELU()
392
+
393
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
394
+ residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states
395
+
396
+ hidden_states = self.conv_1(hidden_states)
397
+ hidden_states = self.group_norm_1(hidden_states)
398
+ hidden_states = self.gelu_1(hidden_states)
399
+ hidden_states = self.conv_2(hidden_states)
400
+
401
+ if not self.is_last:
402
+ hidden_states = self.group_norm_2(hidden_states)
403
+ hidden_states = self.gelu_2(hidden_states)
404
+
405
+ output = hidden_states + residual
406
+ return output
407
+
408
+
409
+ class UNetMidBlock1D(nn.Module):
410
+ def __init__(self, mid_channels: int, in_channels: int, out_channels: Optional[int] = None):
411
+ super().__init__()
412
+
413
+ out_channels = in_channels if out_channels is None else out_channels
414
+
415
+ # there is always at least one resnet
416
+ self.down = Downsample1d("cubic")
417
+ resnets = [
418
+ ResConvBlock(in_channels, mid_channels, mid_channels),
419
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
420
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
421
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
422
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
423
+ ResConvBlock(mid_channels, mid_channels, out_channels),
424
+ ]
425
+ attentions = [
426
+ SelfAttention1d(mid_channels, mid_channels // 32),
427
+ SelfAttention1d(mid_channels, mid_channels // 32),
428
+ SelfAttention1d(mid_channels, mid_channels // 32),
429
+ SelfAttention1d(mid_channels, mid_channels // 32),
430
+ SelfAttention1d(mid_channels, mid_channels // 32),
431
+ SelfAttention1d(out_channels, out_channels // 32),
432
+ ]
433
+ self.up = Upsample1d(kernel="cubic")
434
+
435
+ self.attentions = nn.ModuleList(attentions)
436
+ self.resnets = nn.ModuleList(resnets)
437
+
438
+ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
439
+ hidden_states = self.down(hidden_states)
440
+ for attn, resnet in zip(self.attentions, self.resnets):
441
+ hidden_states = resnet(hidden_states)
442
+ hidden_states = attn(hidden_states)
443
+
444
+ hidden_states = self.up(hidden_states)
445
+
446
+ return hidden_states
447
+
448
+
449
+ class AttnDownBlock1D(nn.Module):
450
+ def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
451
+ super().__init__()
452
+ mid_channels = out_channels if mid_channels is None else mid_channels
453
+
454
+ self.down = Downsample1d("cubic")
455
+ resnets = [
456
+ ResConvBlock(in_channels, mid_channels, mid_channels),
457
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
458
+ ResConvBlock(mid_channels, mid_channels, out_channels),
459
+ ]
460
+ attentions = [
461
+ SelfAttention1d(mid_channels, mid_channels // 32),
462
+ SelfAttention1d(mid_channels, mid_channels // 32),
463
+ SelfAttention1d(out_channels, out_channels // 32),
464
+ ]
465
+
466
+ self.attentions = nn.ModuleList(attentions)
467
+ self.resnets = nn.ModuleList(resnets)
468
+
469
+ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
470
+ hidden_states = self.down(hidden_states)
471
+
472
+ for resnet, attn in zip(self.resnets, self.attentions):
473
+ hidden_states = resnet(hidden_states)
474
+ hidden_states = attn(hidden_states)
475
+
476
+ return hidden_states, (hidden_states,)
477
+
478
+
479
+ class DownBlock1D(nn.Module):
480
+ def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
481
+ super().__init__()
482
+ mid_channels = out_channels if mid_channels is None else mid_channels
483
+
484
+ self.down = Downsample1d("cubic")
485
+ resnets = [
486
+ ResConvBlock(in_channels, mid_channels, mid_channels),
487
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
488
+ ResConvBlock(mid_channels, mid_channels, out_channels),
489
+ ]
490
+
491
+ self.resnets = nn.ModuleList(resnets)
492
+
493
+ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
494
+ hidden_states = self.down(hidden_states)
495
+
496
+ for resnet in self.resnets:
497
+ hidden_states = resnet(hidden_states)
498
+
499
+ return hidden_states, (hidden_states,)
500
+
501
+
502
+ class DownBlock1DNoSkip(nn.Module):
503
+ def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
504
+ super().__init__()
505
+ mid_channels = out_channels if mid_channels is None else mid_channels
506
+
507
+ resnets = [
508
+ ResConvBlock(in_channels, mid_channels, mid_channels),
509
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
510
+ ResConvBlock(mid_channels, mid_channels, out_channels),
511
+ ]
512
+
513
+ self.resnets = nn.ModuleList(resnets)
514
+
515
+ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
516
+ hidden_states = torch.cat([hidden_states, temb], dim=1)
517
+ for resnet in self.resnets:
518
+ hidden_states = resnet(hidden_states)
519
+
520
+ return hidden_states, (hidden_states,)
521
+
522
+
523
+ class AttnUpBlock1D(nn.Module):
524
+ def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
525
+ super().__init__()
526
+ mid_channels = out_channels if mid_channels is None else mid_channels
527
+
528
+ resnets = [
529
+ ResConvBlock(2 * in_channels, mid_channels, mid_channels),
530
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
531
+ ResConvBlock(mid_channels, mid_channels, out_channels),
532
+ ]
533
+ attentions = [
534
+ SelfAttention1d(mid_channels, mid_channels // 32),
535
+ SelfAttention1d(mid_channels, mid_channels // 32),
536
+ SelfAttention1d(out_channels, out_channels // 32),
537
+ ]
538
+
539
+ self.attentions = nn.ModuleList(attentions)
540
+ self.resnets = nn.ModuleList(resnets)
541
+ self.up = Upsample1d(kernel="cubic")
542
+
543
+ def forward(
544
+ self,
545
+ hidden_states: torch.FloatTensor,
546
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
547
+ temb: Optional[torch.FloatTensor] = None,
548
+ ) -> torch.FloatTensor:
549
+ res_hidden_states = res_hidden_states_tuple[-1]
550
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
551
+
552
+ for resnet, attn in zip(self.resnets, self.attentions):
553
+ hidden_states = resnet(hidden_states)
554
+ hidden_states = attn(hidden_states)
555
+
556
+ hidden_states = self.up(hidden_states)
557
+
558
+ return hidden_states
559
+
560
+
561
+ class UpBlock1D(nn.Module):
562
+ def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
563
+ super().__init__()
564
+ mid_channels = in_channels if mid_channels is None else mid_channels
565
+
566
+ resnets = [
567
+ ResConvBlock(2 * in_channels, mid_channels, mid_channels),
568
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
569
+ ResConvBlock(mid_channels, mid_channels, out_channels),
570
+ ]
571
+
572
+ self.resnets = nn.ModuleList(resnets)
573
+ self.up = Upsample1d(kernel="cubic")
574
+
575
+ def forward(
576
+ self,
577
+ hidden_states: torch.FloatTensor,
578
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
579
+ temb: Optional[torch.FloatTensor] = None,
580
+ ) -> torch.FloatTensor:
581
+ res_hidden_states = res_hidden_states_tuple[-1]
582
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
583
+
584
+ for resnet in self.resnets:
585
+ hidden_states = resnet(hidden_states)
586
+
587
+ hidden_states = self.up(hidden_states)
588
+
589
+ return hidden_states
590
+
591
+
592
+ class UpBlock1DNoSkip(nn.Module):
593
+ def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
594
+ super().__init__()
595
+ mid_channels = in_channels if mid_channels is None else mid_channels
596
+
597
+ resnets = [
598
+ ResConvBlock(2 * in_channels, mid_channels, mid_channels),
599
+ ResConvBlock(mid_channels, mid_channels, mid_channels),
600
+ ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True),
601
+ ]
602
+
603
+ self.resnets = nn.ModuleList(resnets)
604
+
605
+ def forward(
606
+ self,
607
+ hidden_states: torch.FloatTensor,
608
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
609
+ temb: Optional[torch.FloatTensor] = None,
610
+ ) -> torch.FloatTensor:
611
+ res_hidden_states = res_hidden_states_tuple[-1]
612
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
613
+
614
+ for resnet in self.resnets:
615
+ hidden_states = resnet(hidden_states)
616
+
617
+ return hidden_states
618
+
619
+
620
+ DownBlockType = Union[DownResnetBlock1D, DownBlock1D, AttnDownBlock1D, DownBlock1DNoSkip]
621
+ MidBlockType = Union[MidResTemporalBlock1D, ValueFunctionMidBlock1D, UNetMidBlock1D]
622
+ OutBlockType = Union[OutConv1DBlock, OutValueFunctionBlock]
623
+ UpBlockType = Union[UpResnetBlock1D, UpBlock1D, AttnUpBlock1D, UpBlock1DNoSkip]
624
+
625
+
626
+ def get_down_block(
627
+ down_block_type: str,
628
+ num_layers: int,
629
+ in_channels: int,
630
+ out_channels: int,
631
+ temb_channels: int,
632
+ add_downsample: bool,
633
+ ) -> DownBlockType:
634
+ if down_block_type == "DownResnetBlock1D":
635
+ return DownResnetBlock1D(
636
+ in_channels=in_channels,
637
+ num_layers=num_layers,
638
+ out_channels=out_channels,
639
+ temb_channels=temb_channels,
640
+ add_downsample=add_downsample,
641
+ )
642
+ elif down_block_type == "DownBlock1D":
643
+ return DownBlock1D(out_channels=out_channels, in_channels=in_channels)
644
+ elif down_block_type == "AttnDownBlock1D":
645
+ return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels)
646
+ elif down_block_type == "DownBlock1DNoSkip":
647
+ return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels)
648
+ raise ValueError(f"{down_block_type} does not exist.")
649
+
650
+
651
+ def get_up_block(
652
+ up_block_type: str, num_layers: int, in_channels: int, out_channels: int, temb_channels: int, add_upsample: bool
653
+ ) -> UpBlockType:
654
+ if up_block_type == "UpResnetBlock1D":
655
+ return UpResnetBlock1D(
656
+ in_channels=in_channels,
657
+ num_layers=num_layers,
658
+ out_channels=out_channels,
659
+ temb_channels=temb_channels,
660
+ add_upsample=add_upsample,
661
+ )
662
+ elif up_block_type == "UpBlock1D":
663
+ return UpBlock1D(in_channels=in_channels, out_channels=out_channels)
664
+ elif up_block_type == "AttnUpBlock1D":
665
+ return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels)
666
+ elif up_block_type == "UpBlock1DNoSkip":
667
+ return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
668
+ raise ValueError(f"{up_block_type} does not exist.")
669
+
670
+
671
+ def get_mid_block(
672
+ mid_block_type: str,
673
+ num_layers: int,
674
+ in_channels: int,
675
+ mid_channels: int,
676
+ out_channels: int,
677
+ embed_dim: int,
678
+ add_downsample: bool,
679
+ ) -> MidBlockType:
680
+ if mid_block_type == "MidResTemporalBlock1D":
681
+ return MidResTemporalBlock1D(
682
+ num_layers=num_layers,
683
+ in_channels=in_channels,
684
+ out_channels=out_channels,
685
+ embed_dim=embed_dim,
686
+ add_downsample=add_downsample,
687
+ )
688
+ elif mid_block_type == "ValueFunctionMidBlock1D":
689
+ return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim)
690
+ elif mid_block_type == "UNetMidBlock1D":
691
+ return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
692
+ raise ValueError(f"{mid_block_type} does not exist.")
693
+
694
+
695
+ def get_out_block(
696
+ *, out_block_type: str, num_groups_out: int, embed_dim: int, out_channels: int, act_fn: str, fc_dim: int
697
+ ) -> Optional[OutBlockType]:
698
+ if out_block_type == "OutConv1DBlock":
699
+ return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
700
+ elif out_block_type == "ValueFunction":
701
+ return OutValueFunctionBlock(fc_dim, embed_dim, act_fn)
702
+ return None
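The fixed-kernel Downsample1d/Upsample1d pair defined above resamples every channel independently with the FIR taps from _kernels. A quick sketch of how the two compose (dummy tensors, shapes only):

import torch
from diffusers.models.unet_1d_blocks import Downsample1d, Upsample1d

x = torch.randn(1, 4, 64)            # (batch, channels, length)
down = Downsample1d(kernel="cubic")
up = Upsample1d(kernel="cubic")

y = down(x)                          # anti-aliased 2x downsample with the fixed "cubic" taps
z = up(y)                            # matching 2x upsample (not an exact inverse of down)
print(y.shape, z.shape)              # torch.Size([1, 4, 32]) torch.Size([1, 4, 64])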
diffusers/models/unet_2d.py ADDED
@@ -0,0 +1,346 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from ..configuration_utils import ConfigMixin, register_to_config
21
+ from ..utils import BaseOutput
22
+ from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
23
+ from .modeling_utils import ModelMixin
24
+ from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
25
+
26
+
27
+ @dataclass
28
+ class UNet2DOutput(BaseOutput):
29
+ """
30
+ The output of [`UNet2DModel`].
31
+
32
+ Args:
33
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
34
+ The hidden states output from the last layer of the model.
35
+ """
36
+
37
+ sample: torch.FloatTensor
38
+
39
+
40
+ class UNet2DModel(ModelMixin, ConfigMixin):
41
+ r"""
42
+ A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
43
+
44
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
45
+ for all models (such as downloading or saving).
46
+
47
+ Parameters:
48
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
49
+ Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
50
+ 1)`.
51
+ in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample.
52
+ out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
53
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
54
+ time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
55
+ freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding.
56
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
57
+ Whether to flip sin to cos for Fourier time embedding.
58
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
59
+ Tuple of downsample block types.
60
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
61
+ Block type for the middle of the UNet; it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
62
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
63
+ Tuple of upsample block types.
64
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
65
+ Tuple of block output channels.
66
+ layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
67
+ mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
68
+ downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
69
+ downsample_type (`str`, *optional*, defaults to `conv`):
70
+ The downsample type for downsampling layers. Choose between "conv" and "resnet"
71
+ upsample_type (`str`, *optional*, defaults to `conv`):
72
+ The upsample type for upsampling layers. Choose between "conv" and "resnet"
73
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
74
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
75
+ attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
76
+ norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
77
+ attn_norm_num_groups (`int`, *optional*, defaults to `None`):
78
+ If set to an integer, a group norm layer will be created in the mid block's [`Attention`] layer with the
79
+ given number of groups. If left as `None`, the group norm layer will only be created if
80
+ `resnet_time_scale_shift` is set to `default`, and if created will have `norm_num_groups` groups.
81
+ norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
82
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
83
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
84
+ class_embed_type (`str`, *optional*, defaults to `None`):
85
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
86
+ `"timestep"`, or `"identity"`.
87
+ num_class_embeds (`int`, *optional*, defaults to `None`):
88
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class
89
+ conditioning with `class_embed_type` equal to `None`.
90
+ """
91
+
92
+ @register_to_config
93
+ def __init__(
94
+ self,
95
+ sample_size: Optional[Union[int, Tuple[int, int]]] = None,
96
+ in_channels: int = 3,
97
+ out_channels: int = 3,
98
+ center_input_sample: bool = False,
99
+ time_embedding_type: str = "positional",
100
+ freq_shift: int = 0,
101
+ flip_sin_to_cos: bool = True,
102
+ down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
103
+ up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
104
+ block_out_channels: Tuple[int] = (224, 448, 672, 896),
105
+ layers_per_block: int = 2,
106
+ mid_block_scale_factor: float = 1,
107
+ downsample_padding: int = 1,
108
+ downsample_type: str = "conv",
109
+ upsample_type: str = "conv",
110
+ dropout: float = 0.0,
111
+ act_fn: str = "silu",
112
+ attention_head_dim: Optional[int] = 8,
113
+ norm_num_groups: int = 32,
114
+ attn_norm_num_groups: Optional[int] = None,
115
+ norm_eps: float = 1e-5,
116
+ resnet_time_scale_shift: str = "default",
117
+ add_attention: bool = True,
118
+ class_embed_type: Optional[str] = None,
119
+ num_class_embeds: Optional[int] = None,
120
+ num_train_timesteps: Optional[int] = None,
121
+ ):
122
+ super().__init__()
123
+
124
+ self.sample_size = sample_size
125
+ time_embed_dim = block_out_channels[0] * 4
126
+
127
+ # Check inputs
128
+ if len(down_block_types) != len(up_block_types):
129
+ raise ValueError(
130
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
131
+ )
132
+
133
+ if len(block_out_channels) != len(down_block_types):
134
+ raise ValueError(
135
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
136
+ )
137
+
138
+ # input
139
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
140
+
141
+ # time
142
+ if time_embedding_type == "fourier":
143
+ self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
144
+ timestep_input_dim = 2 * block_out_channels[0]
145
+ elif time_embedding_type == "positional":
146
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
147
+ timestep_input_dim = block_out_channels[0]
148
+ elif time_embedding_type == "learned":
149
+ self.time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0])
150
+ timestep_input_dim = block_out_channels[0]
151
+
152
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
153
+
154
+ # class embedding
155
+ if class_embed_type is None and num_class_embeds is not None:
156
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
157
+ elif class_embed_type == "timestep":
158
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
159
+ elif class_embed_type == "identity":
160
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
161
+ else:
162
+ self.class_embedding = None
163
+
164
+ self.down_blocks = nn.ModuleList([])
165
+ self.mid_block = None
166
+ self.up_blocks = nn.ModuleList([])
167
+
168
+ # down
169
+ output_channel = block_out_channels[0]
170
+ for i, down_block_type in enumerate(down_block_types):
171
+ input_channel = output_channel
172
+ output_channel = block_out_channels[i]
173
+ is_final_block = i == len(block_out_channels) - 1
174
+
175
+ down_block = get_down_block(
176
+ down_block_type,
177
+ num_layers=layers_per_block,
178
+ in_channels=input_channel,
179
+ out_channels=output_channel,
180
+ temb_channels=time_embed_dim,
181
+ add_downsample=not is_final_block,
182
+ resnet_eps=norm_eps,
183
+ resnet_act_fn=act_fn,
184
+ resnet_groups=norm_num_groups,
185
+ attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
186
+ downsample_padding=downsample_padding,
187
+ resnet_time_scale_shift=resnet_time_scale_shift,
188
+ downsample_type=downsample_type,
189
+ dropout=dropout,
190
+ )
191
+ self.down_blocks.append(down_block)
192
+
193
+ # mid
194
+ self.mid_block = UNetMidBlock2D(
195
+ in_channels=block_out_channels[-1],
196
+ temb_channels=time_embed_dim,
197
+ dropout=dropout,
198
+ resnet_eps=norm_eps,
199
+ resnet_act_fn=act_fn,
200
+ output_scale_factor=mid_block_scale_factor,
201
+ resnet_time_scale_shift=resnet_time_scale_shift,
202
+ attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
203
+ resnet_groups=norm_num_groups,
204
+ attn_groups=attn_norm_num_groups,
205
+ add_attention=add_attention,
206
+ )
207
+
208
+ # up
209
+ reversed_block_out_channels = list(reversed(block_out_channels))
210
+ output_channel = reversed_block_out_channels[0]
211
+ for i, up_block_type in enumerate(up_block_types):
212
+ prev_output_channel = output_channel
213
+ output_channel = reversed_block_out_channels[i]
214
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
215
+
216
+ is_final_block = i == len(block_out_channels) - 1
217
+
218
+ up_block = get_up_block(
219
+ up_block_type,
220
+ num_layers=layers_per_block + 1,
221
+ in_channels=input_channel,
222
+ out_channels=output_channel,
223
+ prev_output_channel=prev_output_channel,
224
+ temb_channels=time_embed_dim,
225
+ add_upsample=not is_final_block,
226
+ resnet_eps=norm_eps,
227
+ resnet_act_fn=act_fn,
228
+ resnet_groups=norm_num_groups,
229
+ attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
230
+ resnet_time_scale_shift=resnet_time_scale_shift,
231
+ upsample_type=upsample_type,
232
+ dropout=dropout,
233
+ )
234
+ self.up_blocks.append(up_block)
235
+ prev_output_channel = output_channel
236
+
237
+ # out
238
+ num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
239
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
240
+ self.conv_act = nn.SiLU()
241
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
242
+
243
+ def forward(
244
+ self,
245
+ sample: torch.FloatTensor,
246
+ timestep: Union[torch.Tensor, float, int],
247
+ class_labels: Optional[torch.Tensor] = None,
248
+ return_dict: bool = True,
249
+ ) -> Union[UNet2DOutput, Tuple]:
250
+ r"""
251
+ The [`UNet2DModel`] forward method.
252
+
253
+ Args:
254
+ sample (`torch.FloatTensor`):
255
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
256
+ timestep (`torch.FloatTensor` or `float` or `int`): The timestep(s) at which to denoise the input.
257
+ class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
258
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
259
+ return_dict (`bool`, *optional*, defaults to `True`):
260
+ Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
261
+
262
+ Returns:
263
+ [`~models.unet_2d.UNet2DOutput`] or `tuple`:
264
+ If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
265
+ returned where the first element is the sample tensor.
266
+ """
267
+ # 0. center input if necessary
268
+ if self.config.center_input_sample:
269
+ sample = 2 * sample - 1.0
270
+
271
+ # 1. time
272
+ timesteps = timestep
273
+ if not torch.is_tensor(timesteps):
274
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
275
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
276
+ timesteps = timesteps[None].to(sample.device)
277
+
278
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
279
+ timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
280
+
281
+ t_emb = self.time_proj(timesteps)
282
+
283
+ # timesteps does not contain any weights and will always return f32 tensors
284
+ # but time_embedding might actually be running in fp16. so we need to cast here.
285
+ # there might be better ways to encapsulate this.
286
+ t_emb = t_emb.to(dtype=self.dtype)
287
+ emb = self.time_embedding(t_emb)
288
+
289
+ if self.class_embedding is not None:
290
+ if class_labels is None:
291
+ raise ValueError("class_labels should be provided when doing class conditioning")
292
+
293
+ if self.config.class_embed_type == "timestep":
294
+ class_labels = self.time_proj(class_labels)
295
+
296
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
297
+ emb = emb + class_emb
298
+ elif self.class_embedding is None and class_labels is not None:
299
+ raise ValueError("class_embedding needs to be initialized in order to use class conditioning")
300
+
301
+ # 2. pre-process
302
+ skip_sample = sample
303
+ sample = self.conv_in(sample)
304
+
305
+ # 3. down
306
+ down_block_res_samples = (sample,)
307
+ for downsample_block in self.down_blocks:
308
+ if hasattr(downsample_block, "skip_conv"):
309
+ sample, res_samples, skip_sample = downsample_block(
310
+ hidden_states=sample, temb=emb, skip_sample=skip_sample
311
+ )
312
+ else:
313
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
314
+
315
+ down_block_res_samples += res_samples
316
+
317
+ # 4. mid
318
+ sample = self.mid_block(sample, emb)
319
+
320
+ # 5. up
321
+ skip_sample = None
322
+ for upsample_block in self.up_blocks:
323
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
324
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
325
+
326
+ if hasattr(upsample_block, "skip_conv"):
327
+ sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
328
+ else:
329
+ sample = upsample_block(sample, res_samples, emb)
330
+
331
+ # 6. post-process
332
+ sample = self.conv_norm_out(sample)
333
+ sample = self.conv_act(sample)
334
+ sample = self.conv_out(sample)
335
+
336
+ if skip_sample is not None:
337
+ sample += skip_sample
338
+
339
+ if self.config.time_embedding_type == "fourier":
340
+ timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
341
+ sample = sample / timesteps
342
+
343
+ if not return_dict:
344
+ return (sample,)
345
+
346
+ return UNet2DOutput(sample=sample)
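For readers skimming the diff, here is a minimal usage sketch (not part of the uploaded files) that exercises the forward pass above end to end. The configuration is deliberately tiny and arbitrary so it runs quickly, and it assumes the vendored `diffusers` package is importable:

import torch
from diffusers import UNet2DModel

model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)

noisy_sample = torch.randn(1, 3, 32, 32)  # (batch, channel, height, width)
timestep = torch.tensor([10])

with torch.no_grad():
    prediction = model(noisy_sample, timestep).sample  # same shape as the input: (1, 3, 32, 32)

In a denoising loop, a scheduler would consume `prediction` to produce the sample for the next timestep.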
diffusers/models/unet_2d_blocks.py ADDED
The diff for this file is too large to render. See raw diff
 
diffusers/models/unet_2d_blocks_flax.py ADDED
@@ -0,0 +1,395 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import flax.linen as nn
16
+ import jax.numpy as jnp
17
+
18
+ from .attention_flax import FlaxTransformer2DModel
19
+ from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
20
+
21
+
22
+ class FlaxCrossAttnDownBlock2D(nn.Module):
23
+ r"""
24
+ Cross Attention 2D Downsizing block - original architecture from Unet transformers:
25
+ https://arxiv.org/abs/2103.06104
26
+
27
+ Parameters:
28
+ in_channels (:obj:`int`):
29
+ Input channels
30
+ out_channels (:obj:`int`):
31
+ Output channels
32
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
33
+ Dropout rate
34
+ num_layers (:obj:`int`, *optional*, defaults to 1):
35
+ Number of attention blocks layers
36
+ num_attention_heads (:obj:`int`, *optional*, defaults to 1):
37
+ Number of attention heads of each spatial transformer block
38
+ add_downsample (:obj:`bool`, *optional*, defaults to `True`):
39
+ Whether to add downsampling layer before each final output
40
+ use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
41
+ enable memory efficient attention https://arxiv.org/abs/2112.05682
42
+ split_head_dim (`bool`, *optional*, defaults to `False`):
43
+ Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
44
+ enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
45
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
46
+ Parameters `dtype`
47
+ """
48
+ in_channels: int
49
+ out_channels: int
50
+ dropout: float = 0.0
51
+ num_layers: int = 1
52
+ num_attention_heads: int = 1
53
+ add_downsample: bool = True
54
+ use_linear_projection: bool = False
55
+ only_cross_attention: bool = False
56
+ use_memory_efficient_attention: bool = False
57
+ split_head_dim: bool = False
58
+ dtype: jnp.dtype = jnp.float32
59
+ transformer_layers_per_block: int = 1
60
+
61
+ def setup(self):
62
+ resnets = []
63
+ attentions = []
64
+
65
+ for i in range(self.num_layers):
66
+ in_channels = self.in_channels if i == 0 else self.out_channels
67
+
68
+ res_block = FlaxResnetBlock2D(
69
+ in_channels=in_channels,
70
+ out_channels=self.out_channels,
71
+ dropout_prob=self.dropout,
72
+ dtype=self.dtype,
73
+ )
74
+ resnets.append(res_block)
75
+
76
+ attn_block = FlaxTransformer2DModel(
77
+ in_channels=self.out_channels,
78
+ n_heads=self.num_attention_heads,
79
+ d_head=self.out_channels // self.num_attention_heads,
80
+ depth=self.transformer_layers_per_block,
81
+ use_linear_projection=self.use_linear_projection,
82
+ only_cross_attention=self.only_cross_attention,
83
+ use_memory_efficient_attention=self.use_memory_efficient_attention,
84
+ split_head_dim=self.split_head_dim,
85
+ dtype=self.dtype,
86
+ )
87
+ attentions.append(attn_block)
88
+
89
+ self.resnets = resnets
90
+ self.attentions = attentions
91
+
92
+ if self.add_downsample:
93
+ self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
94
+
95
+ def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
96
+ output_states = ()
97
+
98
+ for resnet, attn in zip(self.resnets, self.attentions):
99
+ hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
100
+ hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
101
+ output_states += (hidden_states,)
102
+
103
+ if self.add_downsample:
104
+ hidden_states = self.downsamplers_0(hidden_states)
105
+ output_states += (hidden_states,)
106
+
107
+ return hidden_states, output_states
108
+
109
+
110
+ class FlaxDownBlock2D(nn.Module):
111
+ r"""
112
+ Flax 2D downsizing block
113
+
114
+ Parameters:
115
+ in_channels (:obj:`int`):
116
+ Input channels
117
+ out_channels (:obj:`int`):
118
+ Output channels
119
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
120
+ Dropout rate
121
+ num_layers (:obj:`int`, *optional*, defaults to 1):
122
+ Number of resnet layers
123
+ add_downsample (:obj:`bool`, *optional*, defaults to `True`):
124
+ Whether to add downsampling layer before each final output
125
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
126
+ Parameters `dtype`
127
+ """
128
+ in_channels: int
129
+ out_channels: int
130
+ dropout: float = 0.0
131
+ num_layers: int = 1
132
+ add_downsample: bool = True
133
+ dtype: jnp.dtype = jnp.float32
134
+
135
+ def setup(self):
136
+ resnets = []
137
+
138
+ for i in range(self.num_layers):
139
+ in_channels = self.in_channels if i == 0 else self.out_channels
140
+
141
+ res_block = FlaxResnetBlock2D(
142
+ in_channels=in_channels,
143
+ out_channels=self.out_channels,
144
+ dropout_prob=self.dropout,
145
+ dtype=self.dtype,
146
+ )
147
+ resnets.append(res_block)
148
+ self.resnets = resnets
149
+
150
+ if self.add_downsample:
151
+ self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
152
+
153
+ def __call__(self, hidden_states, temb, deterministic=True):
154
+ output_states = ()
155
+
156
+ for resnet in self.resnets:
157
+ hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
158
+ output_states += (hidden_states,)
159
+
160
+ if self.add_downsample:
161
+ hidden_states = self.downsamplers_0(hidden_states)
162
+ output_states += (hidden_states,)
163
+
164
+ return hidden_states, output_states
165
+
166
+
167
+ class FlaxCrossAttnUpBlock2D(nn.Module):
168
+ r"""
169
+ Cross Attention 2D Upsampling block - original architecture from Unet transformers:
170
+ https://arxiv.org/abs/2103.06104
171
+
172
+ Parameters:
173
+ in_channels (:obj:`int`):
174
+ Input channels
175
+ out_channels (:obj:`int`):
176
+ Output channels
177
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
178
+ Dropout rate
179
+ num_layers (:obj:`int`, *optional*, defaults to 1):
180
+ Number of attention blocks layers
181
+ num_attention_heads (:obj:`int`, *optional*, defaults to 1):
182
+ Number of attention heads of each spatial transformer block
183
+ add_upsample (:obj:`bool`, *optional*, defaults to `True`):
184
+ Whether to add upsampling layer before each final output
185
+ use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
186
+ enable memory efficient attention https://arxiv.org/abs/2112.05682
187
+ split_head_dim (`bool`, *optional*, defaults to `False`):
188
+ Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
189
+ enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
190
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
191
+ Parameters `dtype`
192
+ """
193
+ in_channels: int
194
+ out_channels: int
195
+ prev_output_channel: int
196
+ dropout: float = 0.0
197
+ num_layers: int = 1
198
+ num_attention_heads: int = 1
199
+ add_upsample: bool = True
200
+ use_linear_projection: bool = False
201
+ only_cross_attention: bool = False
202
+ use_memory_efficient_attention: bool = False
203
+ split_head_dim: bool = False
204
+ dtype: jnp.dtype = jnp.float32
205
+ transformer_layers_per_block: int = 1
206
+
207
+ def setup(self):
208
+ resnets = []
209
+ attentions = []
210
+
211
+ for i in range(self.num_layers):
212
+ res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
213
+ resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels
214
+
215
+ res_block = FlaxResnetBlock2D(
216
+ in_channels=resnet_in_channels + res_skip_channels,
217
+ out_channels=self.out_channels,
218
+ dropout_prob=self.dropout,
219
+ dtype=self.dtype,
220
+ )
221
+ resnets.append(res_block)
222
+
223
+ attn_block = FlaxTransformer2DModel(
224
+ in_channels=self.out_channels,
225
+ n_heads=self.num_attention_heads,
226
+ d_head=self.out_channels // self.num_attention_heads,
227
+ depth=self.transformer_layers_per_block,
228
+ use_linear_projection=self.use_linear_projection,
229
+ only_cross_attention=self.only_cross_attention,
230
+ use_memory_efficient_attention=self.use_memory_efficient_attention,
231
+ split_head_dim=self.split_head_dim,
232
+ dtype=self.dtype,
233
+ )
234
+ attentions.append(attn_block)
235
+
236
+ self.resnets = resnets
237
+ self.attentions = attentions
238
+
239
+ if self.add_upsample:
240
+ self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
241
+
242
+ def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True):
243
+ for resnet, attn in zip(self.resnets, self.attentions):
244
+ # pop res hidden states
245
+ res_hidden_states = res_hidden_states_tuple[-1]
246
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
247
+ hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)
248
+
249
+ hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
250
+ hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
251
+
252
+ if self.add_upsample:
253
+ hidden_states = self.upsamplers_0(hidden_states)
254
+
255
+ return hidden_states
256
+
257
+
258
+ class FlaxUpBlock2D(nn.Module):
259
+ r"""
260
+ Flax 2D upsampling block
261
+
262
+ Parameters:
263
+ in_channels (:obj:`int`):
264
+ Input channels
265
+ out_channels (:obj:`int`):
266
+ Output channels
267
+ prev_output_channel (:obj:`int`):
268
+ Output channels from the previous block
269
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
270
+ Dropout rate
271
+ num_layers (:obj:`int`, *optional*, defaults to 1):
272
+ Number of resnet layers
273
+ add_upsample (:obj:`bool`, *optional*, defaults to `True`):
274
+ Whether to add upsampling layer before each final output
275
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
276
+ Parameters `dtype`
277
+ """
278
+ in_channels: int
279
+ out_channels: int
280
+ prev_output_channel: int
281
+ dropout: float = 0.0
282
+ num_layers: int = 1
283
+ add_upsample: bool = True
284
+ dtype: jnp.dtype = jnp.float32
285
+
286
+ def setup(self):
287
+ resnets = []
288
+
289
+ for i in range(self.num_layers):
290
+ res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
291
+ resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels
292
+
293
+ res_block = FlaxResnetBlock2D(
294
+ in_channels=resnet_in_channels + res_skip_channels,
295
+ out_channels=self.out_channels,
296
+ dropout_prob=self.dropout,
297
+ dtype=self.dtype,
298
+ )
299
+ resnets.append(res_block)
300
+
301
+ self.resnets = resnets
302
+
303
+ if self.add_upsample:
304
+ self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
305
+
306
+ def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True):
307
+ for resnet in self.resnets:
308
+ # pop res hidden states
309
+ res_hidden_states = res_hidden_states_tuple[-1]
310
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
311
+ hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)
312
+
313
+ hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
314
+
315
+ if self.add_upsample:
316
+ hidden_states = self.upsamplers_0(hidden_states)
317
+
318
+ return hidden_states
319
+
320
+
321
+ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
322
+ r"""
323
+ Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104
324
+
325
+ Parameters:
326
+ in_channels (:obj:`int`):
327
+ Input channels
328
+ dropout (:obj:`float`, *optional*, defaults to 0.0):
329
+ Dropout rate
330
+ num_layers (:obj:`int`, *optional*, defaults to 1):
331
+ Number of attention blocks layers
332
+ num_attention_heads (:obj:`int`, *optional*, defaults to 1):
333
+ Number of attention heads of each spatial transformer block
334
+ use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
335
+ enable memory efficient attention https://arxiv.org/abs/2112.05682
336
+ split_head_dim (`bool`, *optional*, defaults to `False`):
337
+ Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
338
+ enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
339
+ dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
340
+ Parameters `dtype`
341
+ """
342
+ in_channels: int
343
+ dropout: float = 0.0
344
+ num_layers: int = 1
345
+ num_attention_heads: int = 1
346
+ use_linear_projection: bool = False
347
+ use_memory_efficient_attention: bool = False
348
+ split_head_dim: bool = False
349
+ dtype: jnp.dtype = jnp.float32
350
+ transformer_layers_per_block: int = 1
351
+
352
+ def setup(self):
353
+ # there is always at least one resnet
354
+ resnets = [
355
+ FlaxResnetBlock2D(
356
+ in_channels=self.in_channels,
357
+ out_channels=self.in_channels,
358
+ dropout_prob=self.dropout,
359
+ dtype=self.dtype,
360
+ )
361
+ ]
362
+
363
+ attentions = []
364
+
365
+ for _ in range(self.num_layers):
366
+ attn_block = FlaxTransformer2DModel(
367
+ in_channels=self.in_channels,
368
+ n_heads=self.num_attention_heads,
369
+ d_head=self.in_channels // self.num_attention_heads,
370
+ depth=self.transformer_layers_per_block,
371
+ use_linear_projection=self.use_linear_projection,
372
+ use_memory_efficient_attention=self.use_memory_efficient_attention,
373
+ split_head_dim=self.split_head_dim,
374
+ dtype=self.dtype,
375
+ )
376
+ attentions.append(attn_block)
377
+
378
+ res_block = FlaxResnetBlock2D(
379
+ in_channels=self.in_channels,
380
+ out_channels=self.in_channels,
381
+ dropout_prob=self.dropout,
382
+ dtype=self.dtype,
383
+ )
384
+ resnets.append(res_block)
385
+
386
+ self.resnets = resnets
387
+ self.attentions = attentions
388
+
389
+ def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
390
+ hidden_states = self.resnets[0](hidden_states, temb)
391
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
392
+ hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
393
+ hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
394
+
395
+ return hidden_states
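As with the PyTorch blocks, a short sketch (not part of the uploaded files) shows how one of these Flax modules is typically initialized and applied; it assumes `jax` and `flax` are installed, that the import path matches this file's location, and the shapes are arbitrary. Note that the Flax blocks operate on channels-last (NHWC) tensors:

import jax
import jax.numpy as jnp
from diffusers.models.unet_2d_blocks_flax import FlaxDownBlock2D

block = FlaxDownBlock2D(in_channels=32, out_channels=64, num_layers=1, add_downsample=True)

sample = jnp.zeros((1, 16, 16, 32))  # (batch, height, width, channels)
temb = jnp.zeros((1, 128))           # time embedding, projected inside each resnet

params = block.init(jax.random.PRNGKey(0), sample, temb)
hidden_states, output_states = block.apply(params, sample, temb)  # hidden_states: (1, 8, 8, 64)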
diffusers/models/unet_2d_condition.py ADDED
@@ -0,0 +1,1163 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.utils.checkpoint
20
+
21
+ from ..configuration_utils import ConfigMixin, register_to_config
22
+ from ..loaders import UNet2DConditionLoadersMixin
23
+ from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
24
+ from .activations import get_activation
25
+ from .attention_processor import (
26
+ ADDED_KV_ATTENTION_PROCESSORS,
27
+ CROSS_ATTENTION_PROCESSORS,
28
+ AttentionProcessor,
29
+ AttnAddedKVProcessor,
30
+ AttnProcessor,
31
+ )
32
+ from .embeddings import (
33
+ GaussianFourierProjection,
34
+ ImageHintTimeEmbedding,
35
+ ImageProjection,
36
+ ImageTimeEmbedding,
37
+ PositionNet,
38
+ TextImageProjection,
39
+ TextImageTimeEmbedding,
40
+ TextTimeEmbedding,
41
+ TimestepEmbedding,
42
+ Timesteps,
43
+ )
44
+ from .modeling_utils import ModelMixin
45
+ from .unet_2d_blocks import (
46
+ UNetMidBlock2D,
47
+ UNetMidBlock2DCrossAttn,
48
+ UNetMidBlock2DSimpleCrossAttn,
49
+ get_down_block,
50
+ get_up_block,
51
+ )
52
+
53
+
54
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
55
+
56
+
57
+ @dataclass
58
+ class UNet2DConditionOutput(BaseOutput):
59
+ """
60
+ The output of [`UNet2DConditionModel`].
61
+
62
+ Args:
63
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
64
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
65
+ """
66
+
67
+ sample: torch.FloatTensor = None
68
+
69
+
70
+ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
71
+ r"""
72
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
73
+ shaped output.
74
+
75
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
76
+ for all models (such as downloading or saving).
77
+
78
+ Parameters:
79
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
80
+ Height and width of input/output sample.
81
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
82
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
83
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
84
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
85
+ Whether to flip the sin to cos in the time embedding.
86
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
87
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
88
+ The tuple of downsample blocks to use.
89
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
90
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
91
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
92
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
93
+ The tuple of upsample blocks to use.
94
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
95
+ Whether to include self-attention in the basic transformer blocks, see
96
+ [`~models.attention.BasicTransformerBlock`].
97
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
98
+ The tuple of output channels for each block.
99
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
100
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
101
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
102
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
103
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
104
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
105
+ If `None`, normalization and activation layers is skipped in post-processing.
106
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
107
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
108
+ The dimension of the cross attention features.
109
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
110
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
111
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
112
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
113
+ reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
114
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
115
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
116
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
117
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
118
+ encoder_hid_dim (`int`, *optional*, defaults to None):
119
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
120
+ dimension to `cross_attention_dim`.
121
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
122
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
123
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
124
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
125
+ num_attention_heads (`int`, *optional*):
126
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
127
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
128
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
129
+ class_embed_type (`str`, *optional*, defaults to `None`):
130
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
131
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
132
+ addition_embed_type (`str`, *optional*, defaults to `None`):
133
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
134
+ "text". "text" will use the `TextTimeEmbedding` layer.
135
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
136
+ Dimension for the timestep embeddings.
137
+ num_class_embeds (`int`, *optional*, defaults to `None`):
138
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
139
+ class conditioning with `class_embed_type` equal to `None`.
140
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
141
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
142
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
143
+ An optional override for the dimension of the projected time embedding.
144
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
145
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
146
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
147
+ timestep_post_act (`str`, *optional*, defaults to `None`):
148
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
149
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
150
+ The dimension of `cond_proj` layer in the timestep embedding.
151
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`,
152
+ *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`,
153
+ *optional*): The dimension of the `class_labels` input when
154
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
155
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
156
+ embeddings with the class embeddings.
157
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
158
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
159
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
160
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
161
+ otherwise.
162
+ """
163
+
164
+ _supports_gradient_checkpointing = True
165
+
166
+ @register_to_config
167
+ def __init__(
168
+ self,
169
+ sample_size: Optional[int] = None,
170
+ in_channels: int = 4,
171
+ out_channels: int = 4,
172
+ center_input_sample: bool = False,
173
+ flip_sin_to_cos: bool = True,
174
+ freq_shift: int = 0,
175
+ down_block_types: Tuple[str] = (
176
+ "CrossAttnDownBlock2D",
177
+ "CrossAttnDownBlock2D",
178
+ "CrossAttnDownBlock2D",
179
+ "DownBlock2D",
180
+ ),
181
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
182
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
183
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
184
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
185
+ layers_per_block: Union[int, Tuple[int]] = 2,
186
+ downsample_padding: int = 1,
187
+ mid_block_scale_factor: float = 1,
188
+ dropout: float = 0.0,
189
+ act_fn: str = "silu",
190
+ norm_num_groups: Optional[int] = 32,
191
+ norm_eps: float = 1e-5,
192
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
193
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
194
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
195
+ encoder_hid_dim: Optional[int] = None,
196
+ encoder_hid_dim_type: Optional[str] = None,
197
+ attention_head_dim: Union[int, Tuple[int]] = 8,
198
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
199
+ dual_cross_attention: bool = False,
200
+ use_linear_projection: bool = False,
201
+ class_embed_type: Optional[str] = None,
202
+ addition_embed_type: Optional[str] = None,
203
+ addition_time_embed_dim: Optional[int] = None,
204
+ num_class_embeds: Optional[int] = None,
205
+ upcast_attention: bool = False,
206
+ resnet_time_scale_shift: str = "default",
207
+ resnet_skip_time_act: bool = False,
208
+ resnet_out_scale_factor: int = 1.0,
209
+ time_embedding_type: str = "positional",
210
+ time_embedding_dim: Optional[int] = None,
211
+ time_embedding_act_fn: Optional[str] = None,
212
+ timestep_post_act: Optional[str] = None,
213
+ time_cond_proj_dim: Optional[int] = None,
214
+ conv_in_kernel: int = 3,
215
+ conv_out_kernel: int = 3,
216
+ projection_class_embeddings_input_dim: Optional[int] = None,
217
+ attention_type: str = "default",
218
+ class_embeddings_concat: bool = False,
219
+ mid_block_only_cross_attention: Optional[bool] = None,
220
+ cross_attention_norm: Optional[str] = None,
221
+ addition_embed_type_num_heads=64,
222
+ ):
223
+ super().__init__()
224
+
225
+ self.sample_size = sample_size
226
+
227
+ if num_attention_heads is not None:
228
+ raise ValueError(
229
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
230
+ )
231
+
232
+ # If `num_attention_heads` is not defined (which is the case for most models)
233
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
234
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
235
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
236
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
237
+ # which is why we correct for the naming here.
238
+ num_attention_heads = num_attention_heads or attention_head_dim
239
+
240
+ # Check inputs
241
+ if len(down_block_types) != len(up_block_types):
242
+ raise ValueError(
243
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
244
+ )
245
+
246
+ if len(block_out_channels) != len(down_block_types):
247
+ raise ValueError(
248
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
249
+ )
250
+
251
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
252
+ raise ValueError(
253
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
254
+ )
255
+
256
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
257
+ raise ValueError(
258
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
259
+ )
260
+
261
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
262
+ raise ValueError(
263
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
264
+ )
265
+
266
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
267
+ raise ValueError(
268
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
269
+ )
270
+
271
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
272
+ raise ValueError(
273
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
274
+ )
275
+ if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
276
+ for layer_number_per_block in transformer_layers_per_block:
277
+ if isinstance(layer_number_per_block, list):
278
+ raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
279
+
280
+ # input
281
+ conv_in_padding = (conv_in_kernel - 1) // 2
282
+ self.conv_in = nn.Conv2d(
283
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
284
+ )
285
+
286
+ # time
287
+ if time_embedding_type == "fourier":
288
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
289
+ if time_embed_dim % 2 != 0:
290
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
291
+ self.time_proj = GaussianFourierProjection(
292
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
293
+ )
294
+ timestep_input_dim = time_embed_dim
295
+ elif time_embedding_type == "positional":
296
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
297
+
298
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
299
+ timestep_input_dim = block_out_channels[0]
300
+ else:
301
+ raise ValueError(
302
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
303
+ )
304
+
305
+ self.time_embedding = TimestepEmbedding(
306
+ timestep_input_dim,
307
+ time_embed_dim,
308
+ act_fn=act_fn,
309
+ post_act_fn=timestep_post_act,
310
+ cond_proj_dim=time_cond_proj_dim,
311
+ )
312
+
313
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
314
+ encoder_hid_dim_type = "text_proj"
315
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
316
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
317
+
318
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
319
+ raise ValueError(
320
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
321
+ )
322
+
323
+ if encoder_hid_dim_type == "text_proj":
324
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
325
+ elif encoder_hid_dim_type == "text_image_proj":
326
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
327
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
328
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
329
+ self.encoder_hid_proj = TextImageProjection(
330
+ text_embed_dim=encoder_hid_dim,
331
+ image_embed_dim=cross_attention_dim,
332
+ cross_attention_dim=cross_attention_dim,
333
+ )
334
+ elif encoder_hid_dim_type == "image_proj":
335
+ # Kandinsky 2.2
336
+ self.encoder_hid_proj = ImageProjection(
337
+ image_embed_dim=encoder_hid_dim,
338
+ cross_attention_dim=cross_attention_dim,
339
+ )
340
+ elif encoder_hid_dim_type is not None:
341
+ raise ValueError(
342
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
343
+ )
344
+ else:
345
+ self.encoder_hid_proj = None
346
+
347
+ # class embedding
348
+ if class_embed_type is None and num_class_embeds is not None:
349
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
350
+ elif class_embed_type == "timestep":
351
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
352
+ elif class_embed_type == "identity":
353
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
354
+ elif class_embed_type == "projection":
355
+ if projection_class_embeddings_input_dim is None:
356
+ raise ValueError(
357
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
358
+ )
359
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
360
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
361
+ # 2. it projects from an arbitrary input dimension.
362
+ #
363
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
364
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
365
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
366
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
367
+ elif class_embed_type == "simple_projection":
368
+ if projection_class_embeddings_input_dim is None:
369
+ raise ValueError(
370
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
371
+ )
372
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
373
+ else:
374
+ self.class_embedding = None
375
+
376
+ if addition_embed_type == "text":
377
+ if encoder_hid_dim is not None:
378
+ text_time_embedding_from_dim = encoder_hid_dim
379
+ else:
380
+ text_time_embedding_from_dim = cross_attention_dim
381
+
382
+ self.add_embedding = TextTimeEmbedding(
383
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
384
+ )
385
+ elif addition_embed_type == "text_image":
386
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
387
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
388
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
389
+ self.add_embedding = TextImageTimeEmbedding(
390
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
391
+ )
392
+ elif addition_embed_type == "text_time":
393
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
394
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
395
+ elif addition_embed_type == "image":
396
+ # Kandinsky 2.2
397
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
398
+ elif addition_embed_type == "image_hint":
399
+ # Kandinsky 2.2 ControlNet
400
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
401
+ elif addition_embed_type is not None:
402
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
403
+
404
+ if time_embedding_act_fn is None:
405
+ self.time_embed_act = None
406
+ else:
407
+ self.time_embed_act = get_activation(time_embedding_act_fn)
408
+
409
+ self.down_blocks = nn.ModuleList([])
410
+ self.up_blocks = nn.ModuleList([])
411
+
412
+ if isinstance(only_cross_attention, bool):
413
+ if mid_block_only_cross_attention is None:
414
+ mid_block_only_cross_attention = only_cross_attention
415
+
416
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
417
+
418
+ if mid_block_only_cross_attention is None:
419
+ mid_block_only_cross_attention = False
420
+
421
+ if isinstance(num_attention_heads, int):
422
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
423
+
424
+ if isinstance(attention_head_dim, int):
425
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
426
+
427
+ if isinstance(cross_attention_dim, int):
428
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
429
+
430
+ if isinstance(layers_per_block, int):
431
+ layers_per_block = [layers_per_block] * len(down_block_types)
432
+
433
+ if isinstance(transformer_layers_per_block, int):
434
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
435
+
436
+ if class_embeddings_concat:
437
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
438
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
439
+ # regular time embeddings
440
+ blocks_time_embed_dim = time_embed_dim * 2
441
+ else:
442
+ blocks_time_embed_dim = time_embed_dim
443
+
444
+ # down
445
+ output_channel = block_out_channels[0]
446
+ for i, down_block_type in enumerate(down_block_types):
447
+ input_channel = output_channel
448
+ output_channel = block_out_channels[i]
449
+ is_final_block = i == len(block_out_channels) - 1
450
+
451
+ down_block = get_down_block(
452
+ down_block_type,
453
+ num_layers=layers_per_block[i],
454
+ transformer_layers_per_block=transformer_layers_per_block[i],
455
+ in_channels=input_channel,
456
+ out_channels=output_channel,
457
+ temb_channels=blocks_time_embed_dim,
458
+ add_downsample=not is_final_block,
459
+ resnet_eps=norm_eps,
460
+ resnet_act_fn=act_fn,
461
+ resnet_groups=norm_num_groups,
462
+ cross_attention_dim=cross_attention_dim[i],
463
+ num_attention_heads=num_attention_heads[i],
464
+ downsample_padding=downsample_padding,
465
+ dual_cross_attention=dual_cross_attention,
466
+ use_linear_projection=use_linear_projection,
467
+ only_cross_attention=only_cross_attention[i],
468
+ upcast_attention=upcast_attention,
469
+ resnet_time_scale_shift=resnet_time_scale_shift,
470
+ attention_type=attention_type,
471
+ resnet_skip_time_act=resnet_skip_time_act,
472
+ resnet_out_scale_factor=resnet_out_scale_factor,
473
+ cross_attention_norm=cross_attention_norm,
474
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
475
+ dropout=dropout,
476
+ )
477
+ self.down_blocks.append(down_block)
478
+
479
+ # mid
480
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
481
+ self.mid_block = UNetMidBlock2DCrossAttn(
482
+ transformer_layers_per_block=transformer_layers_per_block[-1],
483
+ in_channels=block_out_channels[-1],
484
+ temb_channels=blocks_time_embed_dim,
485
+ dropout=dropout,
486
+ resnet_eps=norm_eps,
487
+ resnet_act_fn=act_fn,
488
+ output_scale_factor=mid_block_scale_factor,
489
+ resnet_time_scale_shift=resnet_time_scale_shift,
490
+ cross_attention_dim=cross_attention_dim[-1],
491
+ num_attention_heads=num_attention_heads[-1],
492
+ resnet_groups=norm_num_groups,
493
+ dual_cross_attention=dual_cross_attention,
494
+ use_linear_projection=use_linear_projection,
495
+ upcast_attention=upcast_attention,
496
+ attention_type=attention_type,
497
+ )
498
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
499
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
500
+ in_channels=block_out_channels[-1],
501
+ temb_channels=blocks_time_embed_dim,
502
+ dropout=dropout,
503
+ resnet_eps=norm_eps,
504
+ resnet_act_fn=act_fn,
505
+ output_scale_factor=mid_block_scale_factor,
506
+ cross_attention_dim=cross_attention_dim[-1],
507
+ attention_head_dim=attention_head_dim[-1],
508
+ resnet_groups=norm_num_groups,
509
+ resnet_time_scale_shift=resnet_time_scale_shift,
510
+ skip_time_act=resnet_skip_time_act,
511
+ only_cross_attention=mid_block_only_cross_attention,
512
+ cross_attention_norm=cross_attention_norm,
513
+ )
514
+ elif mid_block_type == "UNetMidBlock2D":
515
+ self.mid_block = UNetMidBlock2D(
516
+ in_channels=block_out_channels[-1],
517
+ temb_channels=blocks_time_embed_dim,
518
+ dropout=dropout,
519
+ num_layers=0,
520
+ resnet_eps=norm_eps,
521
+ resnet_act_fn=act_fn,
522
+ output_scale_factor=mid_block_scale_factor,
523
+ resnet_groups=norm_num_groups,
524
+ resnet_time_scale_shift=resnet_time_scale_shift,
525
+ add_attention=False,
526
+ )
527
+ elif mid_block_type is None:
528
+ self.mid_block = None
529
+ else:
530
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
531
+
532
+ # count how many layers upsample the images
533
+ self.num_upsamplers = 0
534
+
535
+ # up
536
+ reversed_block_out_channels = list(reversed(block_out_channels))
537
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
538
+ reversed_layers_per_block = list(reversed(layers_per_block))
539
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
540
+ reversed_transformer_layers_per_block = (
541
+ list(reversed(transformer_layers_per_block))
542
+ if reverse_transformer_layers_per_block is None
543
+ else reverse_transformer_layers_per_block
544
+ )
545
+ only_cross_attention = list(reversed(only_cross_attention))
546
+
547
+ output_channel = reversed_block_out_channels[0]
548
+ for i, up_block_type in enumerate(up_block_types):
549
+ is_final_block = i == len(block_out_channels) - 1
550
+
551
+ prev_output_channel = output_channel
552
+ output_channel = reversed_block_out_channels[i]
553
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
554
+
555
+ # add upsample block for all BUT final layer
556
+ if not is_final_block:
557
+ add_upsample = True
558
+ self.num_upsamplers += 1
559
+ else:
560
+ add_upsample = False
561
+
562
+ up_block = get_up_block(
563
+ up_block_type,
564
+ num_layers=reversed_layers_per_block[i] + 1,
565
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
566
+ in_channels=input_channel,
567
+ out_channels=output_channel,
568
+ prev_output_channel=prev_output_channel,
569
+ temb_channels=blocks_time_embed_dim,
570
+ add_upsample=add_upsample,
571
+ resnet_eps=norm_eps,
572
+ resnet_act_fn=act_fn,
573
+ resolution_idx=i,
574
+ resnet_groups=norm_num_groups,
575
+ cross_attention_dim=reversed_cross_attention_dim[i],
576
+ num_attention_heads=reversed_num_attention_heads[i],
577
+ dual_cross_attention=dual_cross_attention,
578
+ use_linear_projection=use_linear_projection,
579
+ only_cross_attention=only_cross_attention[i],
580
+ upcast_attention=upcast_attention,
581
+ resnet_time_scale_shift=resnet_time_scale_shift,
582
+ attention_type=attention_type,
583
+ resnet_skip_time_act=resnet_skip_time_act,
584
+ resnet_out_scale_factor=resnet_out_scale_factor,
585
+ cross_attention_norm=cross_attention_norm,
586
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
587
+ dropout=dropout,
588
+ )
589
+ self.up_blocks.append(up_block)
590
+ prev_output_channel = output_channel
591
+
592
+ # out
593
+ if norm_num_groups is not None:
594
+ self.conv_norm_out = nn.GroupNorm(
595
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
596
+ )
597
+
598
+ self.conv_act = get_activation(act_fn)
599
+
600
+ else:
601
+ self.conv_norm_out = None
602
+ self.conv_act = None
603
+
604
+ conv_out_padding = (conv_out_kernel - 1) // 2
605
+ self.conv_out = nn.Conv2d(
606
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
607
+ )
608
+
609
+ if attention_type in ["gated", "gated-text-image"]:
610
+ positive_len = 768
611
+ if isinstance(cross_attention_dim, int):
612
+ positive_len = cross_attention_dim
613
+ elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
614
+ positive_len = cross_attention_dim[0]
615
+
616
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
617
+ self.position_net = PositionNet(
618
+ positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
619
+ )
620
+
621
+ @property
622
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
623
+ r"""
624
+ Returns:
625
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
626
+ indexed by its weight name.
627
+ """
628
+ # set recursively
629
+ processors = {}
630
+
631
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
632
+ if hasattr(module, "get_processor"):
633
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
634
+
635
+ for sub_name, child in module.named_children():
636
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
637
+
638
+ return processors
639
+
640
+ for name, module in self.named_children():
641
+ fn_recursive_add_processors(name, module, processors)
642
+
643
+ return processors
644
+
645
+ def set_attn_processor(
646
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
647
+ ):
648
+ r"""
649
+ Sets the attention processor to use to compute attention.
650
+
651
+ Parameters:
652
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
653
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
654
+ for **all** `Attention` layers.
655
+
656
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
657
+ processor. This is strongly recommended when setting trainable attention processors.
658
+
659
+ """
660
+ count = len(self.attn_processors.keys())
661
+
662
+ if isinstance(processor, dict) and len(processor) != count:
663
+ raise ValueError(
664
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
665
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
666
+ )
667
+
668
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
669
+ if hasattr(module, "set_processor"):
670
+ if not isinstance(processor, dict):
671
+ module.set_processor(processor, _remove_lora=_remove_lora)
672
+ else:
673
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
674
+
675
+ for sub_name, child in module.named_children():
676
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
677
+
678
+ for name, module in self.named_children():
679
+ fn_recursive_attn_processor(name, module, processor)
680
+
681
+ def set_default_attn_processor(self):
682
+ """
683
+ Disables custom attention processors and sets the default attention implementation.
684
+ """
685
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
686
+ processor = AttnAddedKVProcessor()
687
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
688
+ processor = AttnProcessor()
689
+ else:
690
+ raise ValueError(
691
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
692
+ )
693
+
694
+ self.set_attn_processor(processor, _remove_lora=True)
695
+
696
+ def set_attention_slice(self, slice_size):
697
+ r"""
698
+ Enable sliced attention computation.
699
+
700
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
701
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
702
+
703
+ Args:
704
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
705
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
706
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
707
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
708
+ must be a multiple of `slice_size`.
709
+ """
710
+ sliceable_head_dims = []
711
+
712
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
713
+ if hasattr(module, "set_attention_slice"):
714
+ sliceable_head_dims.append(module.sliceable_head_dim)
715
+
716
+ for child in module.children():
717
+ fn_recursive_retrieve_sliceable_dims(child)
718
+
719
+ # retrieve number of attention layers
720
+ for module in self.children():
721
+ fn_recursive_retrieve_sliceable_dims(module)
722
+
723
+ num_sliceable_layers = len(sliceable_head_dims)
724
+
725
+ if slice_size == "auto":
726
+ # half the attention head size is usually a good trade-off between
727
+ # speed and memory
728
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
729
+ elif slice_size == "max":
730
+ # make smallest slice possible
731
+ slice_size = num_sliceable_layers * [1]
732
+
733
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
734
+
735
+ if len(slice_size) != len(sliceable_head_dims):
736
+ raise ValueError(
737
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
738
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
739
+ )
740
+
741
+ for i in range(len(slice_size)):
742
+ size = slice_size[i]
743
+ dim = sliceable_head_dims[i]
744
+ if size is not None and size > dim:
745
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
746
+
747
+ # Recursively walk through all the children.
748
+ # Any children which exposes the set_attention_slice method
749
+ # gets the message
750
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
751
+ if hasattr(module, "set_attention_slice"):
752
+ module.set_attention_slice(slice_size.pop())
753
+
754
+ for child in module.children():
755
+ fn_recursive_set_attention_slice(child, slice_size)
756
+
757
+ reversed_slice_size = list(reversed(slice_size))
758
+ for module in self.children():
759
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
760
+
761
+ def _set_gradient_checkpointing(self, module, value=False):
762
+ if hasattr(module, "gradient_checkpointing"):
763
+ module.gradient_checkpointing = value
764
+
765
+ def enable_freeu(self, s1, s2, b1, b2):
766
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
767
+
768
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
769
+
770
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
771
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
772
+
773
+ Args:
774
+ s1 (`float`):
775
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
776
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
777
+ s2 (`float`):
778
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
779
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
780
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
781
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
782
+ """
783
+ for i, upsample_block in enumerate(self.up_blocks):
784
+ setattr(upsample_block, "s1", s1)
785
+ setattr(upsample_block, "s2", s2)
786
+ setattr(upsample_block, "b1", b1)
787
+ setattr(upsample_block, "b2", b2)
788
+
789
+ def disable_freeu(self):
790
+ """Disables the FreeU mechanism."""
791
+ freeu_keys = {"s1", "s2", "b1", "b2"}
792
+ for i, upsample_block in enumerate(self.up_blocks):
793
+ for k in freeu_keys:
794
+ if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
795
+ setattr(upsample_block, k, None)
796
+
797
+ def forward(
798
+ self,
799
+ sample: torch.FloatTensor,
800
+ timestep: Union[torch.Tensor, float, int],
801
+ encoder_hidden_states: torch.Tensor,
802
+ class_labels: Optional[torch.Tensor] = None,
803
+ timestep_cond: Optional[torch.Tensor] = None,
804
+ attention_mask: Optional[torch.Tensor] = None,
805
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
806
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
807
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
808
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
809
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
810
+ encoder_attention_mask: Optional[torch.Tensor] = None,
811
+ return_dict: bool = True,
812
+ ) -> Union[UNet2DConditionOutput, Tuple]:
813
+ r"""
814
+ The [`UNet2DConditionModel`] forward method.
815
+
816
+ Args:
817
+ sample (`torch.FloatTensor`):
818
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
819
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
820
+ encoder_hidden_states (`torch.FloatTensor`):
821
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
822
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
823
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
824
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
825
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
826
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
827
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
828
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
829
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
830
+ negative values to the attention scores corresponding to "discard" tokens.
831
+ cross_attention_kwargs (`dict`, *optional*):
832
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
833
+ `self.processor` in
834
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
835
+ added_cond_kwargs: (`dict`, *optional*):
836
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
837
+ are passed along to the UNet blocks.
838
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
839
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
840
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
841
+ A tensor that if specified is added to the residual of the middle unet block.
842
+ encoder_attention_mask (`torch.Tensor`):
843
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
844
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
845
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
846
+ return_dict (`bool`, *optional*, defaults to `True`):
847
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
848
+ tuple.
854
+ down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
855
+ additional residuals to be added to UNet long skip connections from down blocks to up blocks for
856
+ example from ControlNet side model(s)
857
+ mid_block_additional_residual (`torch.Tensor`, *optional*):
858
+ additional residual to be added to UNet mid block output, for example from ControlNet side model
859
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
860
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
861
+
862
+ Returns:
863
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
864
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
865
+ a `tuple` is returned where the first element is the sample tensor.
866
+ """
867
+ # By default, samples have to be at least a multiple of the overall upsampling factor.
868
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
869
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
870
+ # on the fly if necessary.
871
+ default_overall_up_factor = 2**self.num_upsamplers
872
+
873
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
874
+ forward_upsample_size = False
875
+ upsample_size = None
876
+
877
+ for dim in sample.shape[-2:]:
878
+ if dim % default_overall_up_factor != 0:
879
+ # Forward upsample size to force interpolation output size.
880
+ forward_upsample_size = True
881
+ break
882
+
883
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
884
+ # expects mask of shape:
885
+ # [batch, key_tokens]
886
+ # adds singleton query_tokens dimension:
887
+ # [batch, 1, key_tokens]
888
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
889
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
890
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
891
+ if attention_mask is not None:
892
+ # assume that mask is expressed as:
893
+ # (1 = keep, 0 = discard)
894
+ # convert mask into a bias that can be added to attention scores:
895
+ # (keep = +0, discard = -10000.0)
896
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
897
+ attention_mask = attention_mask.unsqueeze(1)
898
+
899
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
900
+ if encoder_attention_mask is not None:
901
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
902
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
903
+
904
+ # 0. center input if necessary
905
+ if self.config.center_input_sample:
906
+ sample = 2 * sample - 1.0
907
+
908
+ # 1. time
909
+ timesteps = timestep
910
+ if not torch.is_tensor(timesteps):
911
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
912
+ # This would be a good case for the `match` statement (Python 3.10+)
913
+ is_mps = sample.device.type == "mps"
914
+ if isinstance(timestep, float):
915
+ dtype = torch.float32 if is_mps else torch.float64
916
+ else:
917
+ dtype = torch.int32 if is_mps else torch.int64
918
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
919
+ elif len(timesteps.shape) == 0:
920
+ timesteps = timesteps[None].to(sample.device)
921
+
922
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
923
+ timesteps = timesteps.expand(sample.shape[0])
924
+
925
+ t_emb = self.time_proj(timesteps)
926
+
927
+ # `Timesteps` does not contain any weights and will always return f32 tensors
928
+ # but time_embedding might actually be running in fp16. so we need to cast here.
929
+ # there might be better ways to encapsulate this.
930
+ t_emb = t_emb.to(dtype=sample.dtype)
931
+
932
+ emb = self.time_embedding(t_emb, timestep_cond)
933
+ aug_emb = None
934
+
935
+ if self.class_embedding is not None:
936
+ if class_labels is None:
937
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
938
+
939
+ if self.config.class_embed_type == "timestep":
940
+ class_labels = self.time_proj(class_labels)
941
+
942
+ # `Timesteps` does not contain any weights and will always return f32 tensors
943
+ # there might be better ways to encapsulate this.
944
+ class_labels = class_labels.to(dtype=sample.dtype)
945
+
946
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
947
+
948
+ if self.config.class_embeddings_concat:
949
+ emb = torch.cat([emb, class_emb], dim=-1)
950
+ else:
951
+ emb = emb + class_emb
952
+
953
+ if self.config.addition_embed_type == "text":
954
+ aug_emb = self.add_embedding(encoder_hidden_states)
955
+ elif self.config.addition_embed_type == "text_image":
956
+ # Kandinsky 2.1 - style
957
+ if "image_embeds" not in added_cond_kwargs:
958
+ raise ValueError(
959
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
960
+ )
961
+
962
+ image_embs = added_cond_kwargs.get("image_embeds")
963
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
964
+ aug_emb = self.add_embedding(text_embs, image_embs)
965
+ elif self.config.addition_embed_type == "text_time":
966
+ # SDXL - style
967
+ if "text_embeds" not in added_cond_kwargs:
968
+ raise ValueError(
969
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
970
+ )
971
+ text_embeds = added_cond_kwargs.get("text_embeds")
972
+ if "time_ids" not in added_cond_kwargs:
973
+ raise ValueError(
974
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
975
+ )
976
+ time_ids = added_cond_kwargs.get("time_ids")
977
+ time_embeds = self.add_time_proj(time_ids.flatten())
978
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
979
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
980
+ add_embeds = add_embeds.to(emb.dtype)
981
+ aug_emb = self.add_embedding(add_embeds)
982
+ elif self.config.addition_embed_type == "image":
983
+ # Kandinsky 2.2 - style
984
+ if "image_embeds" not in added_cond_kwargs:
985
+ raise ValueError(
986
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
987
+ )
988
+ image_embs = added_cond_kwargs.get("image_embeds")
989
+ aug_emb = self.add_embedding(image_embs)
990
+ elif self.config.addition_embed_type == "image_hint":
991
+ # Kandinsky 2.2 - style
992
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
993
+ raise ValueError(
994
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
995
+ )
996
+ image_embs = added_cond_kwargs.get("image_embeds")
997
+ hint = added_cond_kwargs.get("hint")
998
+ aug_emb, hint = self.add_embedding(image_embs, hint)
999
+ sample = torch.cat([sample, hint], dim=1)
1000
+
1001
+ emb = emb + aug_emb if aug_emb is not None else emb
1002
+
1003
+ if self.time_embed_act is not None:
1004
+ emb = self.time_embed_act(emb)
1005
+
1006
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
1007
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1008
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
1009
+ # Kandinsky 2.1 - style
1010
+ if "image_embeds" not in added_cond_kwargs:
1011
+ raise ValueError(
1012
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1013
+ )
1014
+
1015
+ image_embeds = added_cond_kwargs.get("image_embeds")
1016
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
1017
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
1018
+ # Kandinsky 2.2 - style
1019
+ if "image_embeds" not in added_cond_kwargs:
1020
+ raise ValueError(
1021
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1022
+ )
1023
+ image_embeds = added_cond_kwargs.get("image_embeds")
1024
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1025
+ # 2. pre-process
1026
+ sample = self.conv_in(sample)
1027
+
1028
+ # 2.5 GLIGEN position net
1029
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
1030
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1031
+ gligen_args = cross_attention_kwargs.pop("gligen")
1032
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
1033
+
1034
+ # 3. down
1035
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
1036
+ if USE_PEFT_BACKEND:
1037
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1038
+ scale_lora_layers(self, lora_scale)
1039
+
1040
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
1041
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1042
+ is_adapter = down_intrablock_additional_residuals is not None
1043
+ # maintain backward compatibility for legacy usage, where
1044
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
1045
+ # but can only use one or the other
1046
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
1047
+ deprecate(
1048
+ "T2I should not use down_block_additional_residuals",
1049
+ "1.3.0",
1050
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
1051
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
1052
+ for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead.",
1053
+ standard_warn=False,
1054
+ )
1055
+ down_intrablock_additional_residuals = down_block_additional_residuals
1056
+ is_adapter = True
1057
+
1058
+ down_block_res_samples = (sample,)
1059
+ for downsample_block in self.down_blocks:
1060
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1061
+ # For t2i-adapter CrossAttnDownBlock2D
1062
+ additional_residuals = {}
1063
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1064
+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
1065
+
1066
+ sample, res_samples = downsample_block(
1067
+ hidden_states=sample,
1068
+ temb=emb,
1069
+ encoder_hidden_states=encoder_hidden_states,
1070
+ attention_mask=attention_mask,
1071
+ cross_attention_kwargs=cross_attention_kwargs,
1072
+ encoder_attention_mask=encoder_attention_mask,
1073
+ **additional_residuals,
1074
+ )
1075
+ else:
1076
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
1077
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1078
+ sample += down_intrablock_additional_residuals.pop(0)
1079
+
1080
+ down_block_res_samples += res_samples
1081
+
1082
+ if is_controlnet:
1083
+ new_down_block_res_samples = ()
1084
+
1085
+ for down_block_res_sample, down_block_additional_residual in zip(
1086
+ down_block_res_samples, down_block_additional_residuals
1087
+ ):
1088
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
1089
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1090
+
1091
+ down_block_res_samples = new_down_block_res_samples
1092
+
1093
+ # 4. mid
1094
+ if self.mid_block is not None:
1095
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
1096
+ sample = self.mid_block(
1097
+ sample,
1098
+ emb,
1099
+ encoder_hidden_states=encoder_hidden_states,
1100
+ attention_mask=attention_mask,
1101
+ cross_attention_kwargs=cross_attention_kwargs,
1102
+ encoder_attention_mask=encoder_attention_mask,
1103
+ )
1104
+ else:
1105
+ sample = self.mid_block(sample, emb)
1106
+
1107
+ # To support T2I-Adapter-XL
1108
+ if (
1109
+ is_adapter
1110
+ and len(down_intrablock_additional_residuals) > 0
1111
+ and sample.shape == down_intrablock_additional_residuals[0].shape
1112
+ ):
1113
+ sample += down_intrablock_additional_residuals.pop(0)
1114
+
1115
+ if is_controlnet:
1116
+ sample = sample + mid_block_additional_residual
1117
+
1118
+ # 5. up
1119
+ for i, upsample_block in enumerate(self.up_blocks):
1120
+ is_final_block = i == len(self.up_blocks) - 1
1121
+
1122
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1123
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1124
+
1125
+ # if we have not reached the final block and need to forward the
1126
+ # upsample size, we do it here
1127
+ if not is_final_block and forward_upsample_size:
1128
+ upsample_size = down_block_res_samples[-1].shape[2:]
1129
+
1130
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1131
+ sample = upsample_block(
1132
+ hidden_states=sample,
1133
+ temb=emb,
1134
+ res_hidden_states_tuple=res_samples,
1135
+ encoder_hidden_states=encoder_hidden_states,
1136
+ cross_attention_kwargs=cross_attention_kwargs,
1137
+ upsample_size=upsample_size,
1138
+ attention_mask=attention_mask,
1139
+ encoder_attention_mask=encoder_attention_mask,
1140
+ )
1141
+ else:
1142
+ sample = upsample_block(
1143
+ hidden_states=sample,
1144
+ temb=emb,
1145
+ res_hidden_states_tuple=res_samples,
1146
+ upsample_size=upsample_size,
1147
+ scale=lora_scale,
1148
+ )
1149
+
1150
+ # 6. post-process
1151
+ if self.conv_norm_out:
1152
+ sample = self.conv_norm_out(sample)
1153
+ sample = self.conv_act(sample)
1154
+ sample = self.conv_out(sample)
1155
+
1156
+ if USE_PEFT_BACKEND:
1157
+ # remove `lora_scale` from each PEFT layer
1158
+ unscale_lora_layers(self, lora_scale)
1159
+
1160
+ if not return_dict:
1161
+ return (sample,)
1162
+
1163
+ return UNet2DConditionOutput(sample=sample)
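For reference, a minimal usage sketch of the PyTorch `UNet2DConditionModel` defined above (not part of the uploaded file), assuming this copy of the package is importable as `diffusers`; the tiny configuration values below are illustrative only, not a real checkpoint config:

import torch
from diffusers import UNet2DConditionModel

# Small, hypothetical configuration so the example is cheap to run; real checkpoints use much larger values.
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    block_out_channels=(32, 64),
    layers_per_block=1,
    norm_num_groups=32,
    cross_attention_dim=32,
    attention_head_dim=8,
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
)

sample = torch.randn(1, 4, 32, 32)              # noisy latents: (batch, channel, height, width)
timestep = torch.tensor([10])                   # a single denoising timestep
encoder_hidden_states = torch.randn(1, 77, 32)  # text conditioning: (batch, seq_len, cross_attention_dim)

with torch.no_grad():
    out = unet(sample, timestep, encoder_hidden_states)

print(out.sample.shape)  # torch.Size([1, 4, 32, 32]) -- same spatial size as the input latents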
diffusers/models/unet_2d_condition_flax.py ADDED
@@ -0,0 +1,444 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Dict, Optional, Tuple, Union
15
+
16
+ import flax
17
+ import flax.linen as nn
18
+ import jax
19
+ import jax.numpy as jnp
20
+ from flax.core.frozen_dict import FrozenDict
21
+
22
+ from ..configuration_utils import ConfigMixin, flax_register_to_config
23
+ from ..utils import BaseOutput
24
+ from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
25
+ from .modeling_flax_utils import FlaxModelMixin
26
+ from .unet_2d_blocks_flax import (
27
+ FlaxCrossAttnDownBlock2D,
28
+ FlaxCrossAttnUpBlock2D,
29
+ FlaxDownBlock2D,
30
+ FlaxUNetMidBlock2DCrossAttn,
31
+ FlaxUpBlock2D,
32
+ )
33
+
34
+
35
+ @flax.struct.dataclass
36
+ class FlaxUNet2DConditionOutput(BaseOutput):
37
+ """
38
+ The output of [`FlaxUNet2DConditionModel`].
39
+
40
+ Args:
41
+ sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
42
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
43
+ """
44
+
45
+ sample: jnp.ndarray
46
+
47
+
48
+ @flax_register_to_config
49
+ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
50
+ r"""
51
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
52
+ shaped output.
53
+
54
+ This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for its generic methods
55
+ implemented for all models (such as downloading or saving).
56
+
57
+ This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
58
+ subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
59
+ general usage and behavior.
60
+
61
+ Inherent JAX features such as the following are supported:
62
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
63
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
64
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
65
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
66
+
67
+ Parameters:
68
+ sample_size (`int`, *optional*):
69
+ The size of the input sample.
70
+ in_channels (`int`, *optional*, defaults to 4):
71
+ The number of channels in the input sample.
72
+ out_channels (`int`, *optional*, defaults to 4):
73
+ The number of channels in the output.
74
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
75
+ The tuple of downsample blocks to use.
76
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`):
77
+ The tuple of upsample blocks to use.
78
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
79
+ The tuple of output channels for each block.
80
+ layers_per_block (`int`, *optional*, defaults to 2):
81
+ The number of layers per block.
82
+ attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
83
+ The dimension of the attention heads.
84
+ num_attention_heads (`int` or `Tuple[int]`, *optional*):
85
+ The number of attention heads.
86
+ cross_attention_dim (`int`, *optional*, defaults to 768):
87
+ The dimension of the cross attention features.
88
+ dropout (`float`, *optional*, defaults to 0):
89
+ Dropout probability for down, up and bottleneck blocks.
90
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
91
+ Whether to flip the sin to cos in the time embedding.
92
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
93
+ use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
94
+ Enable memory efficient attention as described [here](https://arxiv.org/abs/2112.05682).
95
+ split_head_dim (`bool`, *optional*, defaults to `False`):
96
+ Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
97
+ enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
98
+ """
99
+
100
+ sample_size: int = 32
101
+ in_channels: int = 4
102
+ out_channels: int = 4
103
+ down_block_types: Tuple[str] = (
104
+ "CrossAttnDownBlock2D",
105
+ "CrossAttnDownBlock2D",
106
+ "CrossAttnDownBlock2D",
107
+ "DownBlock2D",
108
+ )
109
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
110
+ only_cross_attention: Union[bool, Tuple[bool]] = False
111
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
112
+ layers_per_block: int = 2
113
+ attention_head_dim: Union[int, Tuple[int]] = 8
114
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None
115
+ cross_attention_dim: int = 1280
116
+ dropout: float = 0.0
117
+ use_linear_projection: bool = False
118
+ dtype: jnp.dtype = jnp.float32
119
+ flip_sin_to_cos: bool = True
120
+ freq_shift: int = 0
121
+ use_memory_efficient_attention: bool = False
122
+ split_head_dim: bool = False
123
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1
124
+ addition_embed_type: Optional[str] = None
125
+ addition_time_embed_dim: Optional[int] = None
126
+ addition_embed_type_num_heads: int = 64
127
+ projection_class_embeddings_input_dim: Optional[int] = None
128
+
129
+ def init_weights(self, rng: jax.Array) -> FrozenDict:
130
+ # init input tensors
131
+ sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
132
+ sample = jnp.zeros(sample_shape, dtype=jnp.float32)
133
+ timesteps = jnp.ones((1,), dtype=jnp.int32)
134
+ encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
135
+
136
+ params_rng, dropout_rng = jax.random.split(rng)
137
+ rngs = {"params": params_rng, "dropout": dropout_rng}
138
+
139
+ added_cond_kwargs = None
140
+ if self.addition_embed_type == "text_time":
141
+ # we retrieve the expected `text_embeds_dim` by first checking if the architecture is a refiner
142
+ # or non-refiner architecture and then by "reverse-computing" from `projection_class_embeddings_input_dim`
143
+ is_refiner = (
144
+ 5 * self.config.addition_time_embed_dim + self.config.cross_attention_dim
145
+ == self.config.projection_class_embeddings_input_dim
146
+ )
147
+ num_micro_conditions = 5 if is_refiner else 6
148
+
149
+ text_embeds_dim = self.config.projection_class_embeddings_input_dim - (
150
+ num_micro_conditions * self.config.addition_time_embed_dim
151
+ )
152
+
153
+ time_ids_channels = self.projection_class_embeddings_input_dim - text_embeds_dim
154
+ time_ids_dims = time_ids_channels // self.addition_time_embed_dim
155
+ added_cond_kwargs = {
156
+ "text_embeds": jnp.zeros((1, text_embeds_dim), dtype=jnp.float32),
157
+ "time_ids": jnp.zeros((1, time_ids_dims), dtype=jnp.float32),
158
+ }
159
+ return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"]
160
+
161
+ def setup(self):
162
+ block_out_channels = self.block_out_channels
163
+ time_embed_dim = block_out_channels[0] * 4
164
+
165
+ if self.num_attention_heads is not None:
166
+ raise ValueError(
167
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
168
+ )
169
+
170
+ # If `num_attention_heads` is not defined (which is the case for most models)
171
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
172
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
173
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
174
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
175
+ # which is why we correct for the naming here.
176
+ num_attention_heads = self.num_attention_heads or self.attention_head_dim
177
+
178
+ # input
179
+ self.conv_in = nn.Conv(
180
+ block_out_channels[0],
181
+ kernel_size=(3, 3),
182
+ strides=(1, 1),
183
+ padding=((1, 1), (1, 1)),
184
+ dtype=self.dtype,
185
+ )
186
+
187
+ # time
188
+ self.time_proj = FlaxTimesteps(
189
+ block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift
190
+ )
191
+ self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
192
+
193
+ only_cross_attention = self.only_cross_attention
194
+ if isinstance(only_cross_attention, bool):
195
+ only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
196
+
197
+ if isinstance(num_attention_heads, int):
198
+ num_attention_heads = (num_attention_heads,) * len(self.down_block_types)
199
+
200
+ # transformer layers per block
201
+ transformer_layers_per_block = self.transformer_layers_per_block
202
+ if isinstance(transformer_layers_per_block, int):
203
+ transformer_layers_per_block = [transformer_layers_per_block] * len(self.down_block_types)
204
+
205
+ # addition embed types
206
+ if self.addition_embed_type is None:
207
+ self.add_embedding = None
208
+ elif self.addition_embed_type == "text_time":
209
+ if self.addition_time_embed_dim is None:
210
+ raise ValueError(
211
+ f"addition_embed_type {self.addition_embed_type} requires `addition_time_embed_dim` to not be None"
212
+ )
213
+ self.add_time_proj = FlaxTimesteps(self.addition_time_embed_dim, self.flip_sin_to_cos, self.freq_shift)
214
+ self.add_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
215
+ else:
216
+ raise ValueError(f"addition_embed_type: {self.addition_embed_type} must be None or `text_time`.")
217
+
218
+ # down
219
+ down_blocks = []
220
+ output_channel = block_out_channels[0]
221
+ for i, down_block_type in enumerate(self.down_block_types):
222
+ input_channel = output_channel
223
+ output_channel = block_out_channels[i]
224
+ is_final_block = i == len(block_out_channels) - 1
225
+
226
+ if down_block_type == "CrossAttnDownBlock2D":
227
+ down_block = FlaxCrossAttnDownBlock2D(
228
+ in_channels=input_channel,
229
+ out_channels=output_channel,
230
+ dropout=self.dropout,
231
+ num_layers=self.layers_per_block,
232
+ transformer_layers_per_block=transformer_layers_per_block[i],
233
+ num_attention_heads=num_attention_heads[i],
234
+ add_downsample=not is_final_block,
235
+ use_linear_projection=self.use_linear_projection,
236
+ only_cross_attention=only_cross_attention[i],
237
+ use_memory_efficient_attention=self.use_memory_efficient_attention,
238
+ split_head_dim=self.split_head_dim,
239
+ dtype=self.dtype,
240
+ )
241
+ else:
242
+ down_block = FlaxDownBlock2D(
243
+ in_channels=input_channel,
244
+ out_channels=output_channel,
245
+ dropout=self.dropout,
246
+ num_layers=self.layers_per_block,
247
+ add_downsample=not is_final_block,
248
+ dtype=self.dtype,
249
+ )
250
+
251
+ down_blocks.append(down_block)
252
+ self.down_blocks = down_blocks
253
+
254
+ # mid
255
+ self.mid_block = FlaxUNetMidBlock2DCrossAttn(
256
+ in_channels=block_out_channels[-1],
257
+ dropout=self.dropout,
258
+ num_attention_heads=num_attention_heads[-1],
259
+ transformer_layers_per_block=transformer_layers_per_block[-1],
260
+ use_linear_projection=self.use_linear_projection,
261
+ use_memory_efficient_attention=self.use_memory_efficient_attention,
262
+ split_head_dim=self.split_head_dim,
263
+ dtype=self.dtype,
264
+ )
265
+
266
+ # up
267
+ up_blocks = []
268
+ reversed_block_out_channels = list(reversed(block_out_channels))
269
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
270
+ only_cross_attention = list(reversed(only_cross_attention))
271
+ output_channel = reversed_block_out_channels[0]
272
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
273
+ for i, up_block_type in enumerate(self.up_block_types):
274
+ prev_output_channel = output_channel
275
+ output_channel = reversed_block_out_channels[i]
276
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
277
+
278
+ is_final_block = i == len(block_out_channels) - 1
279
+
280
+ if up_block_type == "CrossAttnUpBlock2D":
281
+ up_block = FlaxCrossAttnUpBlock2D(
282
+ in_channels=input_channel,
283
+ out_channels=output_channel,
284
+ prev_output_channel=prev_output_channel,
285
+ num_layers=self.layers_per_block + 1,
286
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
287
+ num_attention_heads=reversed_num_attention_heads[i],
288
+ add_upsample=not is_final_block,
289
+ dropout=self.dropout,
290
+ use_linear_projection=self.use_linear_projection,
291
+ only_cross_attention=only_cross_attention[i],
292
+ use_memory_efficient_attention=self.use_memory_efficient_attention,
293
+ split_head_dim=self.split_head_dim,
294
+ dtype=self.dtype,
295
+ )
296
+ else:
297
+ up_block = FlaxUpBlock2D(
298
+ in_channels=input_channel,
299
+ out_channels=output_channel,
300
+ prev_output_channel=prev_output_channel,
301
+ num_layers=self.layers_per_block + 1,
302
+ add_upsample=not is_final_block,
303
+ dropout=self.dropout,
304
+ dtype=self.dtype,
305
+ )
306
+
307
+ up_blocks.append(up_block)
308
+ prev_output_channel = output_channel
309
+ self.up_blocks = up_blocks
310
+
311
+ # out
312
+ self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-5)
313
+ self.conv_out = nn.Conv(
314
+ self.out_channels,
315
+ kernel_size=(3, 3),
316
+ strides=(1, 1),
317
+ padding=((1, 1), (1, 1)),
318
+ dtype=self.dtype,
319
+ )
320
+
321
+ def __call__(
322
+ self,
323
+ sample,
324
+ timesteps,
325
+ encoder_hidden_states,
326
+ added_cond_kwargs: Optional[Union[Dict, FrozenDict]] = None,
327
+ down_block_additional_residuals=None,
328
+ mid_block_additional_residual=None,
329
+ return_dict: bool = True,
330
+ train: bool = False,
331
+ ) -> Union[FlaxUNet2DConditionOutput, Tuple]:
332
+ r"""
333
+ Args:
334
+ sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
335
+ timesteps (`jnp.ndarray` or `float` or `int`): timesteps
336
+ encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
337
+ added_cond_kwargs: (`dict`, *optional*):
338
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
339
+ are passed along to the UNet blocks.
340
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
341
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
342
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
343
+ A tensor that if specified is added to the residual of the middle unet block.
344
+ return_dict (`bool`, *optional*, defaults to `True`):
345
+ Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
346
+ plain tuple.
347
+ train (`bool`, *optional*, defaults to `False`):
348
+ Use deterministic functions and disable dropout when not training.
349
+
350
+ Returns:
351
+ [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
352
+ [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
353
+ When returning a tuple, the first element is the sample tensor.
354
+ """
355
+ # 1. time
356
+ if not isinstance(timesteps, jnp.ndarray):
357
+ timesteps = jnp.array([timesteps], dtype=jnp.int32)
358
+ elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
359
+ timesteps = timesteps.astype(dtype=jnp.float32)
360
+ timesteps = jnp.expand_dims(timesteps, 0)
361
+
362
+ t_emb = self.time_proj(timesteps)
363
+ t_emb = self.time_embedding(t_emb)
364
+
365
+ # additional embeddings
366
+ aug_emb = None
367
+ if self.addition_embed_type == "text_time":
368
+ if added_cond_kwargs is None:
369
+ raise ValueError(
370
+ f"Need to provide argument `added_cond_kwargs` for {self.__class__} when using `addition_embed_type={self.addition_embed_type}`"
371
+ )
372
+ text_embeds = added_cond_kwargs.get("text_embeds")
373
+ if text_embeds is None:
374
+ raise ValueError(
375
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
376
+ )
377
+ time_ids = added_cond_kwargs.get("time_ids")
378
+ if time_ids is None:
379
+ raise ValueError(
380
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
381
+ )
382
+ # compute time embeds
383
+ time_embeds = self.add_time_proj(jnp.ravel(time_ids)) # (1, 6) => (6,) => (6, 256)
384
+ time_embeds = jnp.reshape(time_embeds, (text_embeds.shape[0], -1))
385
+ add_embeds = jnp.concatenate([text_embeds, time_embeds], axis=-1)
386
+ aug_emb = self.add_embedding(add_embeds)
387
+
388
+ t_emb = t_emb + aug_emb if aug_emb is not None else t_emb
389
+
390
+ # 2. pre-process
391
+ sample = jnp.transpose(sample, (0, 2, 3, 1))
392
+ sample = self.conv_in(sample)
393
+
394
+ # 3. down
395
+ down_block_res_samples = (sample,)
396
+ for down_block in self.down_blocks:
397
+ if isinstance(down_block, FlaxCrossAttnDownBlock2D):
398
+ sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
399
+ else:
400
+ sample, res_samples = down_block(sample, t_emb, deterministic=not train)
401
+ down_block_res_samples += res_samples
402
+
403
+ if down_block_additional_residuals is not None:
404
+ new_down_block_res_samples = ()
405
+
406
+ for down_block_res_sample, down_block_additional_residual in zip(
407
+ down_block_res_samples, down_block_additional_residuals
408
+ ):
409
+ down_block_res_sample += down_block_additional_residual
410
+ new_down_block_res_samples += (down_block_res_sample,)
411
+
412
+ down_block_res_samples = new_down_block_res_samples
413
+
414
+ # 4. mid
415
+ sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
416
+
417
+ if mid_block_additional_residual is not None:
418
+ sample += mid_block_additional_residual
419
+
420
+ # 5. up
421
+ for up_block in self.up_blocks:
422
+ res_samples = down_block_res_samples[-(self.layers_per_block + 1) :]
423
+ down_block_res_samples = down_block_res_samples[: -(self.layers_per_block + 1)]
424
+ if isinstance(up_block, FlaxCrossAttnUpBlock2D):
425
+ sample = up_block(
426
+ sample,
427
+ temb=t_emb,
428
+ encoder_hidden_states=encoder_hidden_states,
429
+ res_hidden_states_tuple=res_samples,
430
+ deterministic=not train,
431
+ )
432
+ else:
433
+ sample = up_block(sample, temb=t_emb, res_hidden_states_tuple=res_samples, deterministic=not train)
434
+
435
+ # 6. post-process
436
+ sample = self.conv_norm_out(sample)
437
+ sample = nn.silu(sample)
438
+ sample = self.conv_out(sample)
439
+ sample = jnp.transpose(sample, (0, 3, 1, 2))
440
+
441
+ if not return_dict:
442
+ return (sample,)
443
+
444
+ return FlaxUNet2DConditionOutput(sample=sample)
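Similarly, a minimal sketch (not part of the uploaded file) of initializing and applying the Flax variant above, assuming `jax` and `flax` are installed and the package is importable as `diffusers`; the small configuration is illustrative only, and pretrained weights would normally be loaded with `from_pretrained` instead:

import jax
import jax.numpy as jnp
from diffusers import FlaxUNet2DConditionModel

# Hypothetical tiny configuration for a quick smoke test.
unet = FlaxUNet2DConditionModel(
    sample_size=32,
    block_out_channels=(32, 64),
    layers_per_block=1,
    attention_head_dim=8,
    cross_attention_dim=32,
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
)

# init_weights builds dummy inputs from the config and returns the parameter pytree.
params = unet.init_weights(jax.random.PRNGKey(0))

sample = jnp.zeros((1, 4, 32, 32), dtype=jnp.float32)              # NCHW; transposed to NHWC internally
timesteps = jnp.array([10], dtype=jnp.int32)
encoder_hidden_states = jnp.zeros((1, 77, 32), dtype=jnp.float32)  # (batch, seq_len, cross_attention_dim)

out = unet.apply({"params": params}, sample, timesteps, encoder_hidden_states)
print(out.sample.shape)  # (1, 4, 32, 32)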
diffusers/models/unet_3d_blocks.py ADDED
@@ -0,0 +1,1611 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, Dict, Optional, Tuple
16
+
17
+ import torch
18
+ from torch import nn
19
+
20
+ from ..utils import is_torch_version
21
+ from ..utils.torch_utils import apply_freeu
22
+ from .dual_transformer_2d import DualTransformer2DModel
23
+ from .resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D
24
+ from .transformer_2d import Transformer2DModel
25
+ from .transformer_temporal import TransformerTemporalModel
26
+
27
+
28
+ def get_down_block(
29
+ down_block_type,
30
+ num_layers,
31
+ in_channels,
32
+ out_channels,
33
+ temb_channels,
34
+ add_downsample,
35
+ resnet_eps,
36
+ resnet_act_fn,
37
+ num_attention_heads,
38
+ resnet_groups=None,
39
+ cross_attention_dim=None,
40
+ downsample_padding=None,
41
+ dual_cross_attention=False,
42
+ use_linear_projection=True,
43
+ only_cross_attention=False,
44
+ upcast_attention=False,
45
+ resnet_time_scale_shift="default",
46
+ temporal_num_attention_heads=8,
47
+ temporal_max_seq_length=32,
48
+ ):
49
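+ # map the down_block_type string from the UNet config to the matching block class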
+ if down_block_type == "DownBlock3D":
50
+ return DownBlock3D(
51
+ num_layers=num_layers,
52
+ in_channels=in_channels,
53
+ out_channels=out_channels,
54
+ temb_channels=temb_channels,
55
+ add_downsample=add_downsample,
56
+ resnet_eps=resnet_eps,
57
+ resnet_act_fn=resnet_act_fn,
58
+ resnet_groups=resnet_groups,
59
+ downsample_padding=downsample_padding,
60
+ resnet_time_scale_shift=resnet_time_scale_shift,
61
+ )
62
+ elif down_block_type == "CrossAttnDownBlock3D":
63
+ if cross_attention_dim is None:
64
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
65
+ return CrossAttnDownBlock3D(
66
+ num_layers=num_layers,
67
+ in_channels=in_channels,
68
+ out_channels=out_channels,
69
+ temb_channels=temb_channels,
70
+ add_downsample=add_downsample,
71
+ resnet_eps=resnet_eps,
72
+ resnet_act_fn=resnet_act_fn,
73
+ resnet_groups=resnet_groups,
74
+ downsample_padding=downsample_padding,
75
+ cross_attention_dim=cross_attention_dim,
76
+ num_attention_heads=num_attention_heads,
77
+ dual_cross_attention=dual_cross_attention,
78
+ use_linear_projection=use_linear_projection,
79
+ only_cross_attention=only_cross_attention,
80
+ upcast_attention=upcast_attention,
81
+ resnet_time_scale_shift=resnet_time_scale_shift,
82
+ )
83
+ if down_block_type == "DownBlockMotion":
84
+ return DownBlockMotion(
85
+ num_layers=num_layers,
86
+ in_channels=in_channels,
87
+ out_channels=out_channels,
88
+ temb_channels=temb_channels,
89
+ add_downsample=add_downsample,
90
+ resnet_eps=resnet_eps,
91
+ resnet_act_fn=resnet_act_fn,
92
+ resnet_groups=resnet_groups,
93
+ downsample_padding=downsample_padding,
94
+ resnet_time_scale_shift=resnet_time_scale_shift,
95
+ temporal_num_attention_heads=temporal_num_attention_heads,
96
+ temporal_max_seq_length=temporal_max_seq_length,
97
+ )
98
+ elif down_block_type == "CrossAttnDownBlockMotion":
99
+ if cross_attention_dim is None:
100
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMotion")
101
+ return CrossAttnDownBlockMotion(
102
+ num_layers=num_layers,
103
+ in_channels=in_channels,
104
+ out_channels=out_channels,
105
+ temb_channels=temb_channels,
106
+ add_downsample=add_downsample,
107
+ resnet_eps=resnet_eps,
108
+ resnet_act_fn=resnet_act_fn,
109
+ resnet_groups=resnet_groups,
110
+ downsample_padding=downsample_padding,
111
+ cross_attention_dim=cross_attention_dim,
112
+ num_attention_heads=num_attention_heads,
113
+ dual_cross_attention=dual_cross_attention,
114
+ use_linear_projection=use_linear_projection,
115
+ only_cross_attention=only_cross_attention,
116
+ upcast_attention=upcast_attention,
117
+ resnet_time_scale_shift=resnet_time_scale_shift,
118
+ temporal_num_attention_heads=temporal_num_attention_heads,
119
+ temporal_max_seq_length=temporal_max_seq_length,
120
+ )
121
+
122
+ raise ValueError(f"{down_block_type} does not exist.")
123
+
124
+
125
+ def get_up_block(
126
+ up_block_type,
127
+ num_layers,
128
+ in_channels,
129
+ out_channels,
130
+ prev_output_channel,
131
+ temb_channels,
132
+ add_upsample,
133
+ resnet_eps,
134
+ resnet_act_fn,
135
+ num_attention_heads,
136
+ resolution_idx=None,
137
+ resnet_groups=None,
138
+ cross_attention_dim=None,
139
+ dual_cross_attention=False,
140
+ use_linear_projection=True,
141
+ only_cross_attention=False,
142
+ upcast_attention=False,
143
+ resnet_time_scale_shift="default",
144
+ temporal_num_attention_heads=8,
145
+ temporal_cross_attention_dim=None,
146
+ temporal_max_seq_length=32,
147
+ ):
148
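+ # map the up_block_type string from the UNet config to the matching block class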
+ if up_block_type == "UpBlock3D":
149
+ return UpBlock3D(
150
+ num_layers=num_layers,
151
+ in_channels=in_channels,
152
+ out_channels=out_channels,
153
+ prev_output_channel=prev_output_channel,
154
+ temb_channels=temb_channels,
155
+ add_upsample=add_upsample,
156
+ resnet_eps=resnet_eps,
157
+ resnet_act_fn=resnet_act_fn,
158
+ resnet_groups=resnet_groups,
159
+ resnet_time_scale_shift=resnet_time_scale_shift,
160
+ resolution_idx=resolution_idx,
161
+ )
162
+ elif up_block_type == "CrossAttnUpBlock3D":
163
+ if cross_attention_dim is None:
164
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
165
+ return CrossAttnUpBlock3D(
166
+ num_layers=num_layers,
167
+ in_channels=in_channels,
168
+ out_channels=out_channels,
169
+ prev_output_channel=prev_output_channel,
170
+ temb_channels=temb_channels,
171
+ add_upsample=add_upsample,
172
+ resnet_eps=resnet_eps,
173
+ resnet_act_fn=resnet_act_fn,
174
+ resnet_groups=resnet_groups,
175
+ cross_attention_dim=cross_attention_dim,
176
+ num_attention_heads=num_attention_heads,
177
+ dual_cross_attention=dual_cross_attention,
178
+ use_linear_projection=use_linear_projection,
179
+ only_cross_attention=only_cross_attention,
180
+ upcast_attention=upcast_attention,
181
+ resnet_time_scale_shift=resnet_time_scale_shift,
182
+ resolution_idx=resolution_idx,
183
+ )
184
+ if up_block_type == "UpBlockMotion":
185
+ return UpBlockMotion(
186
+ num_layers=num_layers,
187
+ in_channels=in_channels,
188
+ out_channels=out_channels,
189
+ prev_output_channel=prev_output_channel,
190
+ temb_channels=temb_channels,
191
+ add_upsample=add_upsample,
192
+ resnet_eps=resnet_eps,
193
+ resnet_act_fn=resnet_act_fn,
194
+ resnet_groups=resnet_groups,
195
+ resnet_time_scale_shift=resnet_time_scale_shift,
196
+ resolution_idx=resolution_idx,
197
+ temporal_num_attention_heads=temporal_num_attention_heads,
198
+ temporal_max_seq_length=temporal_max_seq_length,
199
+ )
200
+ elif up_block_type == "CrossAttnUpBlockMotion":
201
+ if cross_attention_dim is None:
202
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMotion")
203
+ return CrossAttnUpBlockMotion(
204
+ num_layers=num_layers,
205
+ in_channels=in_channels,
206
+ out_channels=out_channels,
207
+ prev_output_channel=prev_output_channel,
208
+ temb_channels=temb_channels,
209
+ add_upsample=add_upsample,
210
+ resnet_eps=resnet_eps,
211
+ resnet_act_fn=resnet_act_fn,
212
+ resnet_groups=resnet_groups,
213
+ cross_attention_dim=cross_attention_dim,
214
+ num_attention_heads=num_attention_heads,
215
+ dual_cross_attention=dual_cross_attention,
216
+ use_linear_projection=use_linear_projection,
217
+ only_cross_attention=only_cross_attention,
218
+ upcast_attention=upcast_attention,
219
+ resnet_time_scale_shift=resnet_time_scale_shift,
220
+ resolution_idx=resolution_idx,
221
+ temporal_num_attention_heads=temporal_num_attention_heads,
222
+ temporal_max_seq_length=temporal_max_seq_length,
223
+ )
224
+ raise ValueError(f"{up_block_type} does not exist.")
225
+
226
+
227
+ class UNetMidBlock3DCrossAttn(nn.Module):
228
+ def __init__(
229
+ self,
230
+ in_channels: int,
231
+ temb_channels: int,
232
+ dropout: float = 0.0,
233
+ num_layers: int = 1,
234
+ resnet_eps: float = 1e-6,
235
+ resnet_time_scale_shift: str = "default",
236
+ resnet_act_fn: str = "swish",
237
+ resnet_groups: int = 32,
238
+ resnet_pre_norm: bool = True,
239
+ num_attention_heads=1,
240
+ output_scale_factor=1.0,
241
+ cross_attention_dim=1280,
242
+ dual_cross_attention=False,
243
+ use_linear_projection=True,
244
+ upcast_attention=False,
245
+ ):
246
+ super().__init__()
247
+
248
+ self.has_cross_attention = True
249
+ self.num_attention_heads = num_attention_heads
250
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
251
+
252
+ # there is always at least one resnet
253
+ resnets = [
254
+ ResnetBlock2D(
255
+ in_channels=in_channels,
256
+ out_channels=in_channels,
257
+ temb_channels=temb_channels,
258
+ eps=resnet_eps,
259
+ groups=resnet_groups,
260
+ dropout=dropout,
261
+ time_embedding_norm=resnet_time_scale_shift,
262
+ non_linearity=resnet_act_fn,
263
+ output_scale_factor=output_scale_factor,
264
+ pre_norm=resnet_pre_norm,
265
+ )
266
+ ]
267
+ temp_convs = [
268
+ TemporalConvLayer(
269
+ in_channels,
270
+ in_channels,
271
+ dropout=0.1,
272
+ )
273
+ ]
274
+ attentions = []
275
+ temp_attentions = []
276
+
277
+ for _ in range(num_layers):
278
+ attentions.append(
279
+ Transformer2DModel(
280
+ in_channels // num_attention_heads,
281
+ num_attention_heads,
282
+ in_channels=in_channels,
283
+ num_layers=1,
284
+ cross_attention_dim=cross_attention_dim,
285
+ norm_num_groups=resnet_groups,
286
+ use_linear_projection=use_linear_projection,
287
+ upcast_attention=upcast_attention,
288
+ )
289
+ )
290
+ temp_attentions.append(
291
+ TransformerTemporalModel(
292
+ in_channels // num_attention_heads,
293
+ num_attention_heads,
294
+ in_channels=in_channels,
295
+ num_layers=1,
296
+ cross_attention_dim=cross_attention_dim,
297
+ norm_num_groups=resnet_groups,
298
+ )
299
+ )
300
+ resnets.append(
301
+ ResnetBlock2D(
302
+ in_channels=in_channels,
303
+ out_channels=in_channels,
304
+ temb_channels=temb_channels,
305
+ eps=resnet_eps,
306
+ groups=resnet_groups,
307
+ dropout=dropout,
308
+ time_embedding_norm=resnet_time_scale_shift,
309
+ non_linearity=resnet_act_fn,
310
+ output_scale_factor=output_scale_factor,
311
+ pre_norm=resnet_pre_norm,
312
+ )
313
+ )
314
+ temp_convs.append(
315
+ TemporalConvLayer(
316
+ in_channels,
317
+ in_channels,
318
+ dropout=0.1,
319
+ )
320
+ )
321
+
322
+ self.resnets = nn.ModuleList(resnets)
323
+ self.temp_convs = nn.ModuleList(temp_convs)
324
+ self.attentions = nn.ModuleList(attentions)
325
+ self.temp_attentions = nn.ModuleList(temp_attentions)
326
+
327
+ def forward(
328
+ self,
329
+ hidden_states,
330
+ temb=None,
331
+ encoder_hidden_states=None,
332
+ attention_mask=None,
333
+ num_frames=1,
334
+ cross_attention_kwargs=None,
335
+ ):
336
+ hidden_states = self.resnets[0](hidden_states, temb)
337
+ hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
338
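+ # remaining layers apply spatial attention, temporal attention, a resnet and a temporal convolution in turn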
+ for attn, temp_attn, resnet, temp_conv in zip(
339
+ self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]
340
+ ):
341
+ hidden_states = attn(
342
+ hidden_states,
343
+ encoder_hidden_states=encoder_hidden_states,
344
+ cross_attention_kwargs=cross_attention_kwargs,
345
+ return_dict=False,
346
+ )[0]
347
+ hidden_states = temp_attn(
348
+ hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
349
+ )[0]
350
+ hidden_states = resnet(hidden_states, temb)
351
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
352
+
353
+ return hidden_states
354
+
355
+
356
+ class CrossAttnDownBlock3D(nn.Module):
357
+ def __init__(
358
+ self,
359
+ in_channels: int,
360
+ out_channels: int,
361
+ temb_channels: int,
362
+ dropout: float = 0.0,
363
+ num_layers: int = 1,
364
+ resnet_eps: float = 1e-6,
365
+ resnet_time_scale_shift: str = "default",
366
+ resnet_act_fn: str = "swish",
367
+ resnet_groups: int = 32,
368
+ resnet_pre_norm: bool = True,
369
+ num_attention_heads=1,
370
+ cross_attention_dim=1280,
371
+ output_scale_factor=1.0,
372
+ downsample_padding=1,
373
+ add_downsample=True,
374
+ dual_cross_attention=False,
375
+ use_linear_projection=False,
376
+ only_cross_attention=False,
377
+ upcast_attention=False,
378
+ ):
379
+ super().__init__()
380
+ resnets = []
381
+ attentions = []
382
+ temp_attentions = []
383
+ temp_convs = []
384
+
385
+ self.has_cross_attention = True
386
+ self.num_attention_heads = num_attention_heads
387
+
388
+ for i in range(num_layers):
389
+ in_channels = in_channels if i == 0 else out_channels
390
+ resnets.append(
391
+ ResnetBlock2D(
392
+ in_channels=in_channels,
393
+ out_channels=out_channels,
394
+ temb_channels=temb_channels,
395
+ eps=resnet_eps,
396
+ groups=resnet_groups,
397
+ dropout=dropout,
398
+ time_embedding_norm=resnet_time_scale_shift,
399
+ non_linearity=resnet_act_fn,
400
+ output_scale_factor=output_scale_factor,
401
+ pre_norm=resnet_pre_norm,
402
+ )
403
+ )
404
+ temp_convs.append(
405
+ TemporalConvLayer(
406
+ out_channels,
407
+ out_channels,
408
+ dropout=0.1,
409
+ )
410
+ )
411
+ attentions.append(
412
+ Transformer2DModel(
413
+ out_channels // num_attention_heads,
414
+ num_attention_heads,
415
+ in_channels=out_channels,
416
+ num_layers=1,
417
+ cross_attention_dim=cross_attention_dim,
418
+ norm_num_groups=resnet_groups,
419
+ use_linear_projection=use_linear_projection,
420
+ only_cross_attention=only_cross_attention,
421
+ upcast_attention=upcast_attention,
422
+ )
423
+ )
424
+ temp_attentions.append(
425
+ TransformerTemporalModel(
426
+ out_channels // num_attention_heads,
427
+ num_attention_heads,
428
+ in_channels=out_channels,
429
+ num_layers=1,
430
+ cross_attention_dim=cross_attention_dim,
431
+ norm_num_groups=resnet_groups,
432
+ )
433
+ )
434
+ self.resnets = nn.ModuleList(resnets)
435
+ self.temp_convs = nn.ModuleList(temp_convs)
436
+ self.attentions = nn.ModuleList(attentions)
437
+ self.temp_attentions = nn.ModuleList(temp_attentions)
438
+
439
+ if add_downsample:
440
+ self.downsamplers = nn.ModuleList(
441
+ [
442
+ Downsample2D(
443
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
444
+ )
445
+ ]
446
+ )
447
+ else:
448
+ self.downsamplers = None
449
+
450
+ self.gradient_checkpointing = False
451
+
452
+ def forward(
453
+ self,
454
+ hidden_states,
455
+ temb=None,
456
+ encoder_hidden_states=None,
457
+ attention_mask=None,
458
+ num_frames=1,
459
+ cross_attention_kwargs=None,
460
+ ):
461
+ # TODO(Patrick, William) - attention mask is not used
462
+ output_states = ()
463
+
464
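+ # each layer: resnet -> temporal convolution -> spatial cross-attention -> temporal attention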
+ for resnet, temp_conv, attn, temp_attn in zip(
465
+ self.resnets, self.temp_convs, self.attentions, self.temp_attentions
466
+ ):
467
+ hidden_states = resnet(hidden_states, temb)
468
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
469
+ hidden_states = attn(
470
+ hidden_states,
471
+ encoder_hidden_states=encoder_hidden_states,
472
+ cross_attention_kwargs=cross_attention_kwargs,
473
+ return_dict=False,
474
+ )[0]
475
+ hidden_states = temp_attn(
476
+ hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
477
+ )[0]
478
+
479
+ output_states += (hidden_states,)
480
+
481
+ if self.downsamplers is not None:
482
+ for downsampler in self.downsamplers:
483
+ hidden_states = downsampler(hidden_states)
484
+
485
+ output_states += (hidden_states,)
486
+
487
+ return hidden_states, output_states
488
+
489
+
490
+ class DownBlock3D(nn.Module):
491
+ def __init__(
492
+ self,
493
+ in_channels: int,
494
+ out_channels: int,
495
+ temb_channels: int,
496
+ dropout: float = 0.0,
497
+ num_layers: int = 1,
498
+ resnet_eps: float = 1e-6,
499
+ resnet_time_scale_shift: str = "default",
500
+ resnet_act_fn: str = "swish",
501
+ resnet_groups: int = 32,
502
+ resnet_pre_norm: bool = True,
503
+ output_scale_factor=1.0,
504
+ add_downsample=True,
505
+ downsample_padding=1,
506
+ ):
507
+ super().__init__()
508
+ resnets = []
509
+ temp_convs = []
510
+
511
+ for i in range(num_layers):
512
+ in_channels = in_channels if i == 0 else out_channels
513
+ resnets.append(
514
+ ResnetBlock2D(
515
+ in_channels=in_channels,
516
+ out_channels=out_channels,
517
+ temb_channels=temb_channels,
518
+ eps=resnet_eps,
519
+ groups=resnet_groups,
520
+ dropout=dropout,
521
+ time_embedding_norm=resnet_time_scale_shift,
522
+ non_linearity=resnet_act_fn,
523
+ output_scale_factor=output_scale_factor,
524
+ pre_norm=resnet_pre_norm,
525
+ )
526
+ )
527
+ temp_convs.append(
528
+ TemporalConvLayer(
529
+ out_channels,
530
+ out_channels,
531
+ dropout=0.1,
532
+ )
533
+ )
534
+
535
+ self.resnets = nn.ModuleList(resnets)
536
+ self.temp_convs = nn.ModuleList(temp_convs)
537
+
538
+ if add_downsample:
539
+ self.downsamplers = nn.ModuleList(
540
+ [
541
+ Downsample2D(
542
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
543
+ )
544
+ ]
545
+ )
546
+ else:
547
+ self.downsamplers = None
548
+
549
+ self.gradient_checkpointing = False
550
+
551
+ def forward(self, hidden_states, temb=None, num_frames=1):
552
+ output_states = ()
553
+
554
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
555
+ hidden_states = resnet(hidden_states, temb)
556
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
557
+
558
+ output_states += (hidden_states,)
559
+
560
+ if self.downsamplers is not None:
561
+ for downsampler in self.downsamplers:
562
+ hidden_states = downsampler(hidden_states)
563
+
564
+ output_states += (hidden_states,)
565
+
566
+ return hidden_states, output_states
567
+
568
+
569
+ class CrossAttnUpBlock3D(nn.Module):
570
+ def __init__(
571
+ self,
572
+ in_channels: int,
573
+ out_channels: int,
574
+ prev_output_channel: int,
575
+ temb_channels: int,
576
+ dropout: float = 0.0,
577
+ num_layers: int = 1,
578
+ resnet_eps: float = 1e-6,
579
+ resnet_time_scale_shift: str = "default",
580
+ resnet_act_fn: str = "swish",
581
+ resnet_groups: int = 32,
582
+ resnet_pre_norm: bool = True,
583
+ num_attention_heads=1,
584
+ cross_attention_dim=1280,
585
+ output_scale_factor=1.0,
586
+ add_upsample=True,
587
+ dual_cross_attention=False,
588
+ use_linear_projection=False,
589
+ only_cross_attention=False,
590
+ upcast_attention=False,
591
+ resolution_idx=None,
592
+ ):
593
+ super().__init__()
594
+ resnets = []
595
+ temp_convs = []
596
+ attentions = []
597
+ temp_attentions = []
598
+
599
+ self.has_cross_attention = True
600
+ self.num_attention_heads = num_attention_heads
601
+
602
+ for i in range(num_layers):
603
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
604
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
605
+
606
+ resnets.append(
607
+ ResnetBlock2D(
608
+ in_channels=resnet_in_channels + res_skip_channels,
609
+ out_channels=out_channels,
610
+ temb_channels=temb_channels,
611
+ eps=resnet_eps,
612
+ groups=resnet_groups,
613
+ dropout=dropout,
614
+ time_embedding_norm=resnet_time_scale_shift,
615
+ non_linearity=resnet_act_fn,
616
+ output_scale_factor=output_scale_factor,
617
+ pre_norm=resnet_pre_norm,
618
+ )
619
+ )
620
+ temp_convs.append(
621
+ TemporalConvLayer(
622
+ out_channels,
623
+ out_channels,
624
+ dropout=0.1,
625
+ )
626
+ )
627
+ attentions.append(
628
+ Transformer2DModel(
629
+ out_channels // num_attention_heads,
630
+ num_attention_heads,
631
+ in_channels=out_channels,
632
+ num_layers=1,
633
+ cross_attention_dim=cross_attention_dim,
634
+ norm_num_groups=resnet_groups,
635
+ use_linear_projection=use_linear_projection,
636
+ only_cross_attention=only_cross_attention,
637
+ upcast_attention=upcast_attention,
638
+ )
639
+ )
640
+ temp_attentions.append(
641
+ TransformerTemporalModel(
642
+ out_channels // num_attention_heads,
643
+ num_attention_heads,
644
+ in_channels=out_channels,
645
+ num_layers=1,
646
+ cross_attention_dim=cross_attention_dim,
647
+ norm_num_groups=resnet_groups,
648
+ )
649
+ )
650
+ self.resnets = nn.ModuleList(resnets)
651
+ self.temp_convs = nn.ModuleList(temp_convs)
652
+ self.attentions = nn.ModuleList(attentions)
653
+ self.temp_attentions = nn.ModuleList(temp_attentions)
654
+
655
+ if add_upsample:
656
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
657
+ else:
658
+ self.upsamplers = None
659
+
660
+ self.gradient_checkpointing = False
661
+ self.resolution_idx = resolution_idx
662
+
663
+ def forward(
664
+ self,
665
+ hidden_states,
666
+ res_hidden_states_tuple,
667
+ temb=None,
668
+ encoder_hidden_states=None,
669
+ upsample_size=None,
670
+ attention_mask=None,
671
+ num_frames=1,
672
+ cross_attention_kwargs=None,
673
+ ):
674
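+ # FreeU is enabled only if the s1/s2/b1/b2 scaling factors have been set on this block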
+ is_freeu_enabled = (
675
+ getattr(self, "s1", None)
676
+ and getattr(self, "s2", None)
677
+ and getattr(self, "b1", None)
678
+ and getattr(self, "b2", None)
679
+ )
680
+
681
+ # TODO(Patrick, William) - attention mask is not used
682
+ for resnet, temp_conv, attn, temp_attn in zip(
683
+ self.resnets, self.temp_convs, self.attentions, self.temp_attentions
684
+ ):
685
+ # pop res hidden states
686
+ res_hidden_states = res_hidden_states_tuple[-1]
687
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
688
+
689
+ # FreeU: Only operate on the first two stages
690
+ if is_freeu_enabled:
691
+ hidden_states, res_hidden_states = apply_freeu(
692
+ self.resolution_idx,
693
+ hidden_states,
694
+ res_hidden_states,
695
+ s1=self.s1,
696
+ s2=self.s2,
697
+ b1=self.b1,
698
+ b2=self.b2,
699
+ )
700
+
701
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
702
+
703
+ hidden_states = resnet(hidden_states, temb)
704
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
705
+ hidden_states = attn(
706
+ hidden_states,
707
+ encoder_hidden_states=encoder_hidden_states,
708
+ cross_attention_kwargs=cross_attention_kwargs,
709
+ return_dict=False,
710
+ )[0]
711
+ hidden_states = temp_attn(
712
+ hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
713
+ )[0]
714
+
715
+ if self.upsamplers is not None:
716
+ for upsampler in self.upsamplers:
717
+ hidden_states = upsampler(hidden_states, upsample_size)
718
+
719
+ return hidden_states
720
+
721
+
722
+ class UpBlock3D(nn.Module):
723
+ def __init__(
724
+ self,
725
+ in_channels: int,
726
+ prev_output_channel: int,
727
+ out_channels: int,
728
+ temb_channels: int,
729
+ dropout: float = 0.0,
730
+ num_layers: int = 1,
731
+ resnet_eps: float = 1e-6,
732
+ resnet_time_scale_shift: str = "default",
733
+ resnet_act_fn: str = "swish",
734
+ resnet_groups: int = 32,
735
+ resnet_pre_norm: bool = True,
736
+ output_scale_factor=1.0,
737
+ add_upsample=True,
738
+ resolution_idx=None,
739
+ ):
740
+ super().__init__()
741
+ resnets = []
742
+ temp_convs = []
743
+
744
+ for i in range(num_layers):
745
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
746
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
747
+
748
+ resnets.append(
749
+ ResnetBlock2D(
750
+ in_channels=resnet_in_channels + res_skip_channels,
751
+ out_channels=out_channels,
752
+ temb_channels=temb_channels,
753
+ eps=resnet_eps,
754
+ groups=resnet_groups,
755
+ dropout=dropout,
756
+ time_embedding_norm=resnet_time_scale_shift,
757
+ non_linearity=resnet_act_fn,
758
+ output_scale_factor=output_scale_factor,
759
+ pre_norm=resnet_pre_norm,
760
+ )
761
+ )
762
+ temp_convs.append(
763
+ TemporalConvLayer(
764
+ out_channels,
765
+ out_channels,
766
+ dropout=0.1,
767
+ )
768
+ )
769
+
770
+ self.resnets = nn.ModuleList(resnets)
771
+ self.temp_convs = nn.ModuleList(temp_convs)
772
+
773
+ if add_upsample:
774
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
775
+ else:
776
+ self.upsamplers = None
777
+
778
+ self.gradient_checkpointing = False
779
+ self.resolution_idx = resolution_idx
780
+
781
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1):
782
+ is_freeu_enabled = (
783
+ getattr(self, "s1", None)
784
+ and getattr(self, "s2", None)
785
+ and getattr(self, "b1", None)
786
+ and getattr(self, "b2", None)
787
+ )
788
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
789
+ # pop res hidden states
790
+ res_hidden_states = res_hidden_states_tuple[-1]
791
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
792
+
793
+ # FreeU: Only operate on the first two stages
794
+ if is_freeu_enabled:
795
+ hidden_states, res_hidden_states = apply_freeu(
796
+ self.resolution_idx,
797
+ hidden_states,
798
+ res_hidden_states,
799
+ s1=self.s1,
800
+ s2=self.s2,
801
+ b1=self.b1,
802
+ b2=self.b2,
803
+ )
804
+
805
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
806
+
807
+ hidden_states = resnet(hidden_states, temb)
808
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
809
+
810
+ if self.upsamplers is not None:
811
+ for upsampler in self.upsamplers:
812
+ hidden_states = upsampler(hidden_states, upsample_size)
813
+
814
+ return hidden_states
815
+
816
+
817
+ class DownBlockMotion(nn.Module):
818
+ def __init__(
819
+ self,
820
+ in_channels: int,
821
+ out_channels: int,
822
+ temb_channels: int,
823
+ dropout: float = 0.0,
824
+ num_layers: int = 1,
825
+ resnet_eps: float = 1e-6,
826
+ resnet_time_scale_shift: str = "default",
827
+ resnet_act_fn: str = "swish",
828
+ resnet_groups: int = 32,
829
+ resnet_pre_norm: bool = True,
830
+ output_scale_factor=1.0,
831
+ add_downsample=True,
832
+ downsample_padding=1,
833
+ temporal_num_attention_heads=1,
834
+ temporal_cross_attention_dim=None,
835
+ temporal_max_seq_length=32,
836
+ ):
837
+ super().__init__()
838
+ resnets = []
839
+ motion_modules = []
840
+
841
+ for i in range(num_layers):
842
+ in_channels = in_channels if i == 0 else out_channels
843
+ resnets.append(
844
+ ResnetBlock2D(
845
+ in_channels=in_channels,
846
+ out_channels=out_channels,
847
+ temb_channels=temb_channels,
848
+ eps=resnet_eps,
849
+ groups=resnet_groups,
850
+ dropout=dropout,
851
+ time_embedding_norm=resnet_time_scale_shift,
852
+ non_linearity=resnet_act_fn,
853
+ output_scale_factor=output_scale_factor,
854
+ pre_norm=resnet_pre_norm,
855
+ )
856
+ )
857
+ motion_modules.append(
858
+ TransformerTemporalModel(
859
+ num_attention_heads=temporal_num_attention_heads,
860
+ in_channels=out_channels,
861
+ norm_num_groups=resnet_groups,
862
+ cross_attention_dim=temporal_cross_attention_dim,
863
+ attention_bias=False,
864
+ activation_fn="geglu",
865
+ positional_embeddings="sinusoidal",
866
+ num_positional_embeddings=temporal_max_seq_length,
867
+ attention_head_dim=out_channels // temporal_num_attention_heads,
868
+ )
869
+ )
870
+
871
+ self.resnets = nn.ModuleList(resnets)
872
+ self.motion_modules = nn.ModuleList(motion_modules)
873
+
874
+ if add_downsample:
875
+ self.downsamplers = nn.ModuleList(
876
+ [
877
+ Downsample2D(
878
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
879
+ )
880
+ ]
881
+ )
882
+ else:
883
+ self.downsamplers = None
884
+
885
+ self.gradient_checkpointing = False
886
+
887
+ def forward(self, hidden_states, temb=None, scale: float = 1.0, num_frames=1):
888
+ output_states = ()
889
+
890
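+ # each layer pairs a spatial resnet with a temporal (motion) transformer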
+ blocks = zip(self.resnets, self.motion_modules)
891
+ for resnet, motion_module in blocks:
892
+ if self.training and self.gradient_checkpointing:
893
+
894
+ def create_custom_forward(module):
895
+ def custom_forward(*inputs):
896
+ return module(*inputs)
897
+
898
+ return custom_forward
899
+
900
+ if is_torch_version(">=", "1.11.0"):
901
+ hidden_states = torch.utils.checkpoint.checkpoint(
902
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
903
+ )
904
+ else:
905
+ hidden_states = torch.utils.checkpoint.checkpoint(
906
+ create_custom_forward(resnet), hidden_states, temb, scale
907
+ )
908
+ hidden_states = torch.utils.checkpoint.checkpoint(
909
+ create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, num_frames
910
+ )
911
+
912
+ else:
913
+ hidden_states = resnet(hidden_states, temb, scale=scale)
914
+ hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
915
+
916
+ output_states = output_states + (hidden_states,)
917
+
918
+ if self.downsamplers is not None:
919
+ for downsampler in self.downsamplers:
920
+ hidden_states = downsampler(hidden_states, scale=scale)
921
+
922
+ output_states = output_states + (hidden_states,)
923
+
924
+ return hidden_states, output_states
925
+
926
+
927
+ class CrossAttnDownBlockMotion(nn.Module):
928
+ def __init__(
929
+ self,
930
+ in_channels: int,
931
+ out_channels: int,
932
+ temb_channels: int,
933
+ dropout: float = 0.0,
934
+ num_layers: int = 1,
935
+ transformer_layers_per_block: int = 1,
936
+ resnet_eps: float = 1e-6,
937
+ resnet_time_scale_shift: str = "default",
938
+ resnet_act_fn: str = "swish",
939
+ resnet_groups: int = 32,
940
+ resnet_pre_norm: bool = True,
941
+ num_attention_heads=1,
942
+ cross_attention_dim=1280,
943
+ output_scale_factor=1.0,
944
+ downsample_padding=1,
945
+ add_downsample=True,
946
+ dual_cross_attention=False,
947
+ use_linear_projection=False,
948
+ only_cross_attention=False,
949
+ upcast_attention=False,
950
+ attention_type="default",
951
+ temporal_cross_attention_dim=None,
952
+ temporal_num_attention_heads=8,
953
+ temporal_max_seq_length=32,
954
+ ):
955
+ super().__init__()
956
+ resnets = []
957
+ attentions = []
958
+ motion_modules = []
959
+
960
+ self.has_cross_attention = True
961
+ self.num_attention_heads = num_attention_heads
962
+
963
+ for i in range(num_layers):
964
+ in_channels = in_channels if i == 0 else out_channels
965
+ resnets.append(
966
+ ResnetBlock2D(
967
+ in_channels=in_channels,
968
+ out_channels=out_channels,
969
+ temb_channels=temb_channels,
970
+ eps=resnet_eps,
971
+ groups=resnet_groups,
972
+ dropout=dropout,
973
+ time_embedding_norm=resnet_time_scale_shift,
974
+ non_linearity=resnet_act_fn,
975
+ output_scale_factor=output_scale_factor,
976
+ pre_norm=resnet_pre_norm,
977
+ )
978
+ )
979
+
980
+ if not dual_cross_attention:
981
+ attentions.append(
982
+ Transformer2DModel(
983
+ num_attention_heads,
984
+ out_channels // num_attention_heads,
985
+ in_channels=out_channels,
986
+ num_layers=transformer_layers_per_block,
987
+ cross_attention_dim=cross_attention_dim,
988
+ norm_num_groups=resnet_groups,
989
+ use_linear_projection=use_linear_projection,
990
+ only_cross_attention=only_cross_attention,
991
+ upcast_attention=upcast_attention,
992
+ attention_type=attention_type,
993
+ )
994
+ )
995
+ else:
996
+ attentions.append(
997
+ DualTransformer2DModel(
998
+ num_attention_heads,
999
+ out_channels // num_attention_heads,
1000
+ in_channels=out_channels,
1001
+ num_layers=1,
1002
+ cross_attention_dim=cross_attention_dim,
1003
+ norm_num_groups=resnet_groups,
1004
+ )
1005
+ )
1006
+
1007
+ motion_modules.append(
1008
+ TransformerTemporalModel(
1009
+ num_attention_heads=temporal_num_attention_heads,
1010
+ in_channels=out_channels,
1011
+ norm_num_groups=resnet_groups,
1012
+ cross_attention_dim=temporal_cross_attention_dim,
1013
+ attention_bias=False,
1014
+ activation_fn="geglu",
1015
+ positional_embeddings="sinusoidal",
1016
+ num_positional_embeddings=temporal_max_seq_length,
1017
+ attention_head_dim=out_channels // temporal_num_attention_heads,
1018
+ )
1019
+ )
1020
+
1021
+ self.attentions = nn.ModuleList(attentions)
1022
+ self.resnets = nn.ModuleList(resnets)
1023
+ self.motion_modules = nn.ModuleList(motion_modules)
1024
+
1025
+ if add_downsample:
1026
+ self.downsamplers = nn.ModuleList(
1027
+ [
1028
+ Downsample2D(
1029
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
1030
+ )
1031
+ ]
1032
+ )
1033
+ else:
1034
+ self.downsamplers = None
1035
+
1036
+ self.gradient_checkpointing = False
1037
+
1038
+ def forward(
1039
+ self,
1040
+ hidden_states,
1041
+ temb=None,
1042
+ encoder_hidden_states=None,
1043
+ attention_mask=None,
1044
+ num_frames=1,
1045
+ encoder_attention_mask=None,
1046
+ cross_attention_kwargs=None,
1047
+ additional_residuals=None,
1048
+ ):
1049
+ output_states = ()
1050
+
1051
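+ # LoRA scale forwarded to the resnets and downsamplers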
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
1052
+
1053
+ blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
1054
+ for i, (resnet, attn, motion_module) in enumerate(blocks):
1055
+ if self.training and self.gradient_checkpointing:
1056
+
1057
+ def create_custom_forward(module, return_dict=None):
1058
+ def custom_forward(*inputs):
1059
+ if return_dict is not None:
1060
+ return module(*inputs, return_dict=return_dict)
1061
+ else:
1062
+ return module(*inputs)
1063
+
1064
+ return custom_forward
1065
+
1066
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1067
+ hidden_states = torch.utils.checkpoint.checkpoint(
1068
+ create_custom_forward(resnet),
1069
+ hidden_states,
1070
+ temb,
1071
+ **ckpt_kwargs,
1072
+ )
1073
+ hidden_states = attn(
1074
+ hidden_states,
1075
+ encoder_hidden_states=encoder_hidden_states,
1076
+ cross_attention_kwargs=cross_attention_kwargs,
1077
+ attention_mask=attention_mask,
1078
+ encoder_attention_mask=encoder_attention_mask,
1079
+ return_dict=False,
1080
+ )[0]
1081
+ else:
1082
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
1083
+ hidden_states = attn(
1084
+ hidden_states,
1085
+ encoder_hidden_states=encoder_hidden_states,
1086
+ cross_attention_kwargs=cross_attention_kwargs,
1087
+ attention_mask=attention_mask,
1088
+ encoder_attention_mask=encoder_attention_mask,
1089
+ return_dict=False,
1090
+ )[0]
1091
+ hidden_states = motion_module(
1092
+ hidden_states,
1093
+ num_frames=num_frames,
1094
+ )[0]
1095
+
1096
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
1097
+ if i == len(blocks) - 1 and additional_residuals is not None:
1098
+ hidden_states = hidden_states + additional_residuals
1099
+
1100
+ output_states = output_states + (hidden_states,)
1101
+
1102
+ if self.downsamplers is not None:
1103
+ for downsampler in self.downsamplers:
1104
+ hidden_states = downsampler(hidden_states, scale=lora_scale)
1105
+
1106
+ output_states = output_states + (hidden_states,)
1107
+
1108
+ return hidden_states, output_states
1109
+
1110
+
1111
+ class CrossAttnUpBlockMotion(nn.Module):
1112
+ def __init__(
1113
+ self,
1114
+ in_channels: int,
1115
+ out_channels: int,
1116
+ prev_output_channel: int,
1117
+ temb_channels: int,
1118
+ resolution_idx: int = None,
1119
+ dropout: float = 0.0,
1120
+ num_layers: int = 1,
1121
+ transformer_layers_per_block: int = 1,
1122
+ resnet_eps: float = 1e-6,
1123
+ resnet_time_scale_shift: str = "default",
1124
+ resnet_act_fn: str = "swish",
1125
+ resnet_groups: int = 32,
1126
+ resnet_pre_norm: bool = True,
1127
+ num_attention_heads=1,
1128
+ cross_attention_dim=1280,
1129
+ output_scale_factor=1.0,
1130
+ add_upsample=True,
1131
+ dual_cross_attention=False,
1132
+ use_linear_projection=False,
1133
+ only_cross_attention=False,
1134
+ upcast_attention=False,
1135
+ attention_type="default",
1136
+ temporal_cross_attention_dim=None,
1137
+ temporal_num_attention_heads=8,
1138
+ temporal_max_seq_length=32,
1139
+ ):
1140
+ super().__init__()
1141
+ resnets = []
1142
+ attentions = []
1143
+ motion_modules = []
1144
+
1145
+ self.has_cross_attention = True
1146
+ self.num_attention_heads = num_attention_heads
1147
+
1148
+ for i in range(num_layers):
1149
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1150
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1151
+
1152
+ resnets.append(
1153
+ ResnetBlock2D(
1154
+ in_channels=resnet_in_channels + res_skip_channels,
1155
+ out_channels=out_channels,
1156
+ temb_channels=temb_channels,
1157
+ eps=resnet_eps,
1158
+ groups=resnet_groups,
1159
+ dropout=dropout,
1160
+ time_embedding_norm=resnet_time_scale_shift,
1161
+ non_linearity=resnet_act_fn,
1162
+ output_scale_factor=output_scale_factor,
1163
+ pre_norm=resnet_pre_norm,
1164
+ )
1165
+ )
1166
+
1167
+ if not dual_cross_attention:
1168
+ attentions.append(
1169
+ Transformer2DModel(
1170
+ num_attention_heads,
1171
+ out_channels // num_attention_heads,
1172
+ in_channels=out_channels,
1173
+ num_layers=transformer_layers_per_block,
1174
+ cross_attention_dim=cross_attention_dim,
1175
+ norm_num_groups=resnet_groups,
1176
+ use_linear_projection=use_linear_projection,
1177
+ only_cross_attention=only_cross_attention,
1178
+ upcast_attention=upcast_attention,
1179
+ attention_type=attention_type,
1180
+ )
1181
+ )
1182
+ else:
1183
+ attentions.append(
1184
+ DualTransformer2DModel(
1185
+ num_attention_heads,
1186
+ out_channels // num_attention_heads,
1187
+ in_channels=out_channels,
1188
+ num_layers=1,
1189
+ cross_attention_dim=cross_attention_dim,
1190
+ norm_num_groups=resnet_groups,
1191
+ )
1192
+ )
1193
+ motion_modules.append(
1194
+ TransformerTemporalModel(
1195
+ num_attention_heads=temporal_num_attention_heads,
1196
+ in_channels=out_channels,
1197
+ norm_num_groups=resnet_groups,
1198
+ cross_attention_dim=temporal_cross_attention_dim,
1199
+ attention_bias=False,
1200
+ activation_fn="geglu",
1201
+ positional_embeddings="sinusoidal",
1202
+ num_positional_embeddings=temporal_max_seq_length,
1203
+ attention_head_dim=out_channels // temporal_num_attention_heads,
1204
+ )
1205
+ )
1206
+
1207
+ self.attentions = nn.ModuleList(attentions)
1208
+ self.resnets = nn.ModuleList(resnets)
1209
+ self.motion_modules = nn.ModuleList(motion_modules)
1210
+
1211
+ if add_upsample:
1212
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1213
+ else:
1214
+ self.upsamplers = None
1215
+
1216
+ self.gradient_checkpointing = False
1217
+ self.resolution_idx = resolution_idx
1218
+
1219
+ def forward(
1220
+ self,
1221
+ hidden_states: torch.FloatTensor,
1222
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
1223
+ temb: Optional[torch.FloatTensor] = None,
1224
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1225
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1226
+ upsample_size: Optional[int] = None,
1227
+ attention_mask: Optional[torch.FloatTensor] = None,
1228
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1229
+ num_frames=1,
1230
+ ):
1231
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
1232
+ is_freeu_enabled = (
1233
+ getattr(self, "s1", None)
1234
+ and getattr(self, "s2", None)
1235
+ and getattr(self, "b1", None)
1236
+ and getattr(self, "b2", None)
1237
+ )
1238
+
1239
+ blocks = zip(self.resnets, self.attentions, self.motion_modules)
1240
+ for resnet, attn, motion_module in blocks:
1241
+ # pop res hidden states
1242
+ res_hidden_states = res_hidden_states_tuple[-1]
1243
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1244
+
1245
+ # FreeU: Only operate on the first two stages
1246
+ if is_freeu_enabled:
1247
+ hidden_states, res_hidden_states = apply_freeu(
1248
+ self.resolution_idx,
1249
+ hidden_states,
1250
+ res_hidden_states,
1251
+ s1=self.s1,
1252
+ s2=self.s2,
1253
+ b1=self.b1,
1254
+ b2=self.b2,
1255
+ )
1256
+
1257
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1258
+
1259
+ if self.training and self.gradient_checkpointing:
1260
+
1261
+ def create_custom_forward(module, return_dict=None):
1262
+ def custom_forward(*inputs):
1263
+ if return_dict is not None:
1264
+ return module(*inputs, return_dict=return_dict)
1265
+ else:
1266
+ return module(*inputs)
1267
+
1268
+ return custom_forward
1269
+
1270
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1271
+ hidden_states = torch.utils.checkpoint.checkpoint(
1272
+ create_custom_forward(resnet),
1273
+ hidden_states,
1274
+ temb,
1275
+ **ckpt_kwargs,
1276
+ )
1277
+ hidden_states = attn(
1278
+ hidden_states,
1279
+ encoder_hidden_states=encoder_hidden_states,
1280
+ cross_attention_kwargs=cross_attention_kwargs,
1281
+ attention_mask=attention_mask,
1282
+ encoder_attention_mask=encoder_attention_mask,
1283
+ return_dict=False,
1284
+ )[0]
1285
+ else:
1286
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
1287
+ hidden_states = attn(
1288
+ hidden_states,
1289
+ encoder_hidden_states=encoder_hidden_states,
1290
+ cross_attention_kwargs=cross_attention_kwargs,
1291
+ attention_mask=attention_mask,
1292
+ encoder_attention_mask=encoder_attention_mask,
1293
+ return_dict=False,
1294
+ )[0]
1295
+ hidden_states = motion_module(
1296
+ hidden_states,
1297
+ num_frames=num_frames,
1298
+ )[0]
1299
+
1300
+ if self.upsamplers is not None:
1301
+ for upsampler in self.upsamplers:
1302
+ hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
1303
+
1304
+ return hidden_states
1305
+
1306
+
1307
+ class UpBlockMotion(nn.Module):
1308
+ def __init__(
1309
+ self,
1310
+ in_channels: int,
1311
+ prev_output_channel: int,
1312
+ out_channels: int,
1313
+ temb_channels: int,
1314
+ resolution_idx: int = None,
1315
+ dropout: float = 0.0,
1316
+ num_layers: int = 1,
1317
+ resnet_eps: float = 1e-6,
1318
+ resnet_time_scale_shift: str = "default",
1319
+ resnet_act_fn: str = "swish",
1320
+ resnet_groups: int = 32,
1321
+ resnet_pre_norm: bool = True,
1322
+ output_scale_factor=1.0,
1323
+ add_upsample=True,
1324
+ temporal_norm_num_groups=32,
1325
+ temporal_cross_attention_dim=None,
1326
+ temporal_num_attention_heads=8,
1327
+ temporal_max_seq_length=32,
1328
+ ):
1329
+ super().__init__()
1330
+ resnets = []
1331
+ motion_modules = []
1332
+
1333
+ for i in range(num_layers):
1334
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1335
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1336
+
1337
+ resnets.append(
1338
+ ResnetBlock2D(
1339
+ in_channels=resnet_in_channels + res_skip_channels,
1340
+ out_channels=out_channels,
1341
+ temb_channels=temb_channels,
1342
+ eps=resnet_eps,
1343
+ groups=resnet_groups,
1344
+ dropout=dropout,
1345
+ time_embedding_norm=resnet_time_scale_shift,
1346
+ non_linearity=resnet_act_fn,
1347
+ output_scale_factor=output_scale_factor,
1348
+ pre_norm=resnet_pre_norm,
1349
+ )
1350
+ )
1351
+
1352
+ motion_modules.append(
1353
+ TransformerTemporalModel(
1354
+ num_attention_heads=temporal_num_attention_heads,
1355
+ in_channels=out_channels,
1356
+ norm_num_groups=temporal_norm_num_groups,
1357
+ cross_attention_dim=temporal_cross_attention_dim,
1358
+ attention_bias=False,
1359
+ activation_fn="geglu",
1360
+ positional_embeddings="sinusoidal",
1361
+ num_positional_embeddings=temporal_max_seq_length,
1362
+ attention_head_dim=out_channels // temporal_num_attention_heads,
1363
+ )
1364
+ )
1365
+
1366
+ self.resnets = nn.ModuleList(resnets)
1367
+ self.motion_modules = nn.ModuleList(motion_modules)
1368
+
1369
+ if add_upsample:
1370
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1371
+ else:
1372
+ self.upsamplers = None
1373
+
1374
+ self.gradient_checkpointing = False
1375
+ self.resolution_idx = resolution_idx
1376
+
1377
+ def forward(
1378
+ self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0, num_frames=1
1379
+ ):
1380
+ is_freeu_enabled = (
1381
+ getattr(self, "s1", None)
1382
+ and getattr(self, "s2", None)
1383
+ and getattr(self, "b1", None)
1384
+ and getattr(self, "b2", None)
1385
+ )
1386
+
1387
+ blocks = zip(self.resnets, self.motion_modules)
1388
+
1389
+ for resnet, motion_module in blocks:
1390
+ # pop res hidden states
1391
+ res_hidden_states = res_hidden_states_tuple[-1]
1392
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1393
+
1394
+ # FreeU: Only operate on the first two stages
1395
+ if is_freeu_enabled:
1396
+ hidden_states, res_hidden_states = apply_freeu(
1397
+ self.resolution_idx,
1398
+ hidden_states,
1399
+ res_hidden_states,
1400
+ s1=self.s1,
1401
+ s2=self.s2,
1402
+ b1=self.b1,
1403
+ b2=self.b2,
1404
+ )
1405
+
1406
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1407
+
1408
+ if self.training and self.gradient_checkpointing:
1409
+
1410
+ def create_custom_forward(module):
1411
+ def custom_forward(*inputs):
1412
+ return module(*inputs)
1413
+
1414
+ return custom_forward
1415
+
1416
+ if is_torch_version(">=", "1.11.0"):
1417
+ hidden_states = torch.utils.checkpoint.checkpoint(
1418
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
1419
+ )
1420
+ else:
1421
+ hidden_states = torch.utils.checkpoint.checkpoint(
1422
+ create_custom_forward(resnet), hidden_states, temb
1423
+ )
1424
+ # apply the temporal (motion) transformer, matching the non-checkpointed branch
+ hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
1429
+
1430
+ else:
1431
+ hidden_states = resnet(hidden_states, temb, scale=scale)
1432
+ hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
1433
+
1434
+ if self.upsamplers is not None:
1435
+ for upsampler in self.upsamplers:
1436
+ hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
1437
+
1438
+ return hidden_states
1439
+
1440
+
1441
+ class UNetMidBlockCrossAttnMotion(nn.Module):
1442
+ def __init__(
1443
+ self,
1444
+ in_channels: int,
1445
+ temb_channels: int,
1446
+ dropout: float = 0.0,
1447
+ num_layers: int = 1,
1448
+ transformer_layers_per_block: int = 1,
1449
+ resnet_eps: float = 1e-6,
1450
+ resnet_time_scale_shift: str = "default",
1451
+ resnet_act_fn: str = "swish",
1452
+ resnet_groups: int = 32,
1453
+ resnet_pre_norm: bool = True,
1454
+ num_attention_heads=1,
1455
+ output_scale_factor=1.0,
1456
+ cross_attention_dim=1280,
1457
+ dual_cross_attention=False,
1458
+ use_linear_projection=False,
1459
+ upcast_attention=False,
1460
+ attention_type="default",
1461
+ temporal_num_attention_heads=1,
1462
+ temporal_cross_attention_dim=None,
1463
+ temporal_max_seq_length=32,
1464
+ ):
1465
+ super().__init__()
1466
+
1467
+ self.has_cross_attention = True
1468
+ self.num_attention_heads = num_attention_heads
1469
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
1470
+
1471
+ # there is always at least one resnet
1472
+ resnets = [
1473
+ ResnetBlock2D(
1474
+ in_channels=in_channels,
1475
+ out_channels=in_channels,
1476
+ temb_channels=temb_channels,
1477
+ eps=resnet_eps,
1478
+ groups=resnet_groups,
1479
+ dropout=dropout,
1480
+ time_embedding_norm=resnet_time_scale_shift,
1481
+ non_linearity=resnet_act_fn,
1482
+ output_scale_factor=output_scale_factor,
1483
+ pre_norm=resnet_pre_norm,
1484
+ )
1485
+ ]
1486
+ attentions = []
1487
+ motion_modules = []
1488
+
1489
+ for _ in range(num_layers):
1490
+ if not dual_cross_attention:
1491
+ attentions.append(
1492
+ Transformer2DModel(
1493
+ num_attention_heads,
1494
+ in_channels // num_attention_heads,
1495
+ in_channels=in_channels,
1496
+ num_layers=transformer_layers_per_block,
1497
+ cross_attention_dim=cross_attention_dim,
1498
+ norm_num_groups=resnet_groups,
1499
+ use_linear_projection=use_linear_projection,
1500
+ upcast_attention=upcast_attention,
1501
+ attention_type=attention_type,
1502
+ )
1503
+ )
1504
+ else:
1505
+ attentions.append(
1506
+ DualTransformer2DModel(
1507
+ num_attention_heads,
1508
+ in_channels // num_attention_heads,
1509
+ in_channels=in_channels,
1510
+ num_layers=1,
1511
+ cross_attention_dim=cross_attention_dim,
1512
+ norm_num_groups=resnet_groups,
1513
+ )
1514
+ )
1515
+ resnets.append(
1516
+ ResnetBlock2D(
1517
+ in_channels=in_channels,
1518
+ out_channels=in_channels,
1519
+ temb_channels=temb_channels,
1520
+ eps=resnet_eps,
1521
+ groups=resnet_groups,
1522
+ dropout=dropout,
1523
+ time_embedding_norm=resnet_time_scale_shift,
1524
+ non_linearity=resnet_act_fn,
1525
+ output_scale_factor=output_scale_factor,
1526
+ pre_norm=resnet_pre_norm,
1527
+ )
1528
+ )
1529
+ motion_modules.append(
1530
+ TransformerTemporalModel(
1531
+ num_attention_heads=temporal_num_attention_heads,
1532
+ attention_head_dim=in_channels // temporal_num_attention_heads,
1533
+ in_channels=in_channels,
1534
+ norm_num_groups=resnet_groups,
1535
+ cross_attention_dim=temporal_cross_attention_dim,
1536
+ attention_bias=False,
1537
+ positional_embeddings="sinusoidal",
1538
+ num_positional_embeddings=temporal_max_seq_length,
1539
+ activation_fn="geglu",
1540
+ )
1541
+ )
1542
+
1543
+ self.attentions = nn.ModuleList(attentions)
1544
+ self.resnets = nn.ModuleList(resnets)
1545
+ self.motion_modules = nn.ModuleList(motion_modules)
1546
+
1547
+ self.gradient_checkpointing = False
1548
+
1549
+ def forward(
1550
+ self,
1551
+ hidden_states: torch.FloatTensor,
1552
+ temb: Optional[torch.FloatTensor] = None,
1553
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1554
+ attention_mask: Optional[torch.FloatTensor] = None,
1555
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1556
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1557
+ num_frames=1,
1558
+ ) -> torch.FloatTensor:
1559
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
1560
+ hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
1561
+
1562
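+ # after the initial resnet, each layer applies spatial cross-attention, the motion module and a resnet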
+ blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
1563
+ for attn, resnet, motion_module in blocks:
1564
+ if self.training and self.gradient_checkpointing:
1565
+
1566
+ def create_custom_forward(module, return_dict=None):
1567
+ def custom_forward(*inputs):
1568
+ if return_dict is not None:
1569
+ return module(*inputs, return_dict=return_dict)
1570
+ else:
1571
+ return module(*inputs)
1572
+
1573
+ return custom_forward
1574
+
1575
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1576
+ hidden_states = attn(
1577
+ hidden_states,
1578
+ encoder_hidden_states=encoder_hidden_states,
1579
+ cross_attention_kwargs=cross_attention_kwargs,
1580
+ attention_mask=attention_mask,
1581
+ encoder_attention_mask=encoder_attention_mask,
1582
+ return_dict=False,
1583
+ )[0]
1584
+ hidden_states = torch.utils.checkpoint.checkpoint(
1585
+ create_custom_forward(motion_module),
1586
+ hidden_states,
1587
+ temb,
1588
+ **ckpt_kwargs,
1589
+ )
1590
+ hidden_states = torch.utils.checkpoint.checkpoint(
1591
+ create_custom_forward(resnet),
1592
+ hidden_states,
1593
+ temb,
1594
+ **ckpt_kwargs,
1595
+ )
1596
+ else:
1597
+ hidden_states = attn(
1598
+ hidden_states,
1599
+ encoder_hidden_states=encoder_hidden_states,
1600
+ cross_attention_kwargs=cross_attention_kwargs,
1601
+ attention_mask=attention_mask,
1602
+ encoder_attention_mask=encoder_attention_mask,
1603
+ return_dict=False,
1604
+ )[0]
1605
+ hidden_states = motion_module(
1606
+ hidden_states,
1607
+ num_frames=num_frames,
1608
+ )[0]
1609
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
1610
+
1611
+ return hidden_states