ysmikey committed
Commit d895063 · 1 Parent(s): 2692412
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,97 @@
- ---
- license: mit
- ---
+ ---
+ license: other
+ license_name: flux-1-dev-non-commercial-license
+ license_link: https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md
+ language:
+ - en
+ library_name: diffusers
+ pipeline_tag: text-to-image
+ base_model: black-forest-labs/FLUX.1-dev
+ tags:
+ - diffusers
+ - lora
+ - flux
+ inference:
+   parameters:
+     width: 1440
+     height: 720
+ ---
+ <Gallery />
+
+ # Layerpano3D-FLUX-Panorama-LoRA
+
+ <table>
+ <tr>
+ <td><img src="assets/Magical academy courtyard with floating orbs of light, ancient stone buildings, and a large tree in the center, mystical and enchanting.png" alt="Image 9" width="100%"></td>
+ <td><img src="assets/Autumn park scene with people sitting on benches surrounded by colorful trees, storybook illustration style.png" alt="Image 4" width="100%"></td>
+ <td><img src="assets/An ancient stone archway standing alone in a peaceful meadow, surrounded by wildflowers, with sunlight streaming through, casting long shadows.png" alt="Image 3" width="100%"></td>
+ </tr>
+ <tr>
+ <td><img src="assets/A charming village market square filled with outdoor vendors, baskets of fresh produce, and villagers interacting in the morning sun.png" alt="Image 2" width="100%"></td>
+ <td><img src="assets/A vibrant city avenue, bustling traffic, towering skyscrapers.png" alt="Image 5" width="100%"></td>
+ <td><img src="assets/Bustling city street at sunset, skyscrapers, streets, cars.png" alt="Image 6" width="100%"></td>
+ </tr>
+ <tr>
+ <td><img src="assets/Cozy livingroom in christmas.png" alt="Image 7" width="100%"></td>
+ <td><img src="assets/lego city with lego shops, lego road with street lamp, cars and lego mans on the street, lego trees and lake at a park.png" alt="Image 8" width="100%"></td>
+ <td><img src="assets/A bustling open-air market with colorful stalls overflowing with fresh produce, flowers, and goods, bathed in soft, warm sunlight, capturing the vibrancy of daily life.png" alt="Image 1" width="100%"></td>
+ </tr>
+ </table>
+
+ A LoRA model for generating panoramas with FLUX (text-to-panorama and panorama inpainting).
+
+ ## Which image ratio and resolution to use?
+
+ This model was trained on images with a 2:1 (width:height) aspect ratio.
+
+ [NOTE]: Because this is a LoRA, inference gives the best results at the same resolution the LoRA was trained on. We provide LoRAs for multiple resolutions in `lora_hubs/` (currently 720×1440 and 512×1024, version 1), and we will keep releasing safetensors with better quality and more flexible resolutions.
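Because each checkpoint file name encodes its training resolution, a small lookup like the sketch below (an illustrative convention, not part of the repo code) keeps `height`/`width` in sync with whichever LoRA you load:

```python
# Sketch: pick the inference resolution that matches the LoRA checkpoint.
# Both files ship in this repo under lora_hubs/; the mapping itself is just a convenience.
LORA_RESOLUTIONS = {
    "lora_hubs/pano_lora_720*1440_v1.safetensors": (720, 1440),  # (height, width)
    "lora_hubs/pano_lora_512*1024_v1.safetensors": (512, 1024),
}

lora_path = "lora_hubs/pano_lora_720*1440_v1.safetensors"
height, width = LORA_RESOLUTIONS[lora_path]
```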
+
+ ## Inference
+
+ ```shell
+ pip install diffusers==0.32.0
+ ```
+ **Text-to-Panorama Generation** (run with the `pipeline_flux.py` in this repo so the panorama closes into a seamless loop)
+ ```python
+ import torch
+ from pipeline_flux import FluxPipeline  # our modified Flux pipeline, which blends the left/right edges so the panorama closes into a loop
+
+ lora_path = "lora_hubs/pano_lora_720*1440_v1.safetensors"  # download the panorama LoRA from this repo and point this to your local copy
+ pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
+ pipe.load_lora_weights(lora_path)
+ pipe.enable_model_cpu_offload()  # save some VRAM by offloading the model to CPU
+
+ prompt = 'A vibrant city avenue, bustling traffic, towering skyscrapers'
+
+ pipe.enable_vae_tiling()
+ seed = 119223
+
+ # Use the same resolution the LoRA was trained on.
+ image = pipe(prompt,
+              height=720,
+              width=1440,
+              generator=torch.Generator("cpu").manual_seed(seed),
+              num_inference_steps=50,
+              blend_extend=6,  # number of latent columns blended to close the left/right seam
+              guidance_scale=7).images[0]
+
+ image.save("result.png")
+ ```
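As a quick, optional sanity check of the closed-loop property (not part of the original example), you can roll the panorama by half its width and look for a seam in the middle; a well-closed panorama should show none. The sketch below assumes `image` is the `PIL.Image` returned by the pipeline call above:

```python
import numpy as np
from PIL import Image

# Roll the panorama by half its width: the original left/right border now sits
# in the centre of the frame, so any wrap-around seam becomes easy to spot.
pano = np.asarray(image)                               # (H, W, 3) uint8 array
rolled = np.roll(pano, shift=pano.shape[1] // 2, axis=1)
Image.fromarray(rolled).save("result_rolled.png")      # inspect the centre column
```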
+
+ ## Related 360-Panorama Projects
+ - [**LayerPano3D**: Layered 3D Panorama for Hyper-Immersive Scene Generation](https://github.com/3DTopia/LayerPano3D). LayerPano3D generates a full-view, explorable panoramic 3D scene from a single text prompt.
+
+ - [**Imagine360**: Immersive 360 Video Generation from Perspective Anchor](https://github.com/3DTopia/Imagine360). Imagine360 lifts a standard perspective video into a 360-degree video with rich, structured motion, unlocking a dynamic scene experience across the full 360 degrees.
+
+ ## Non-commercial use
+
+ Because the base model is FLUX.1-[dev] and the training data comes from Google Street View, this LoRA should be used for [non-commercial, personal, or demonstration purposes only](https://github.com/black-forest-labs/flux/blob/main/model_licenses/LICENSE-FLUX1-dev).
+
+ Please use it responsibly. Thank you!
assets/A bustling open-air market with colorful stalls overflowing with fresh produce, flowers, and goods, bathed in soft, warm sunlight, capturing the vibrancy of daily life.png ADDED

Git LFS Details

  • SHA256: 208e8f395861f06107ac1b3133932a0203d5b121b6a90c7fade70357e0d46056
  • Pointer size: 132 Bytes
  • Size of remote file: 1.76 MB
assets/A charming village market square filled with outdoor vendors, baskets of fresh produce, and villagers interacting in the morning sun.png ADDED

Git LFS Details

  • SHA256: cc54786e396ab0afb3e22935f55b0b4e3504e56152cdc0f53377e7dddc31aa12
  • Pointer size: 132 Bytes
  • Size of remote file: 1.5 MB
assets/A vibrant city avenue, bustling traffic, towering skyscrapers.png ADDED

Git LFS Details

  • SHA256: c23878ce6d783638c01f6cb885ea243072b2e42410f70b689b4c4ed03343af8d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.57 MB
assets/An ancient stone archway standing alone in a peaceful meadow, surrounded by wildflowers, with sunlight streaming through, casting long shadows.png ADDED

Git LFS Details

  • SHA256: b0d8ae2484af0ce19e3af7cbaef6a21216636d479d0fa04754ee088c4a34b99b
  • Pointer size: 132 Bytes
  • Size of remote file: 1.68 MB
assets/Autumn park scene with people sitting on benches surrounded by colorful trees, storybook illustration style.png ADDED

Git LFS Details

  • SHA256: e4a75a8fc38a304c68b7a72490bbf394c0883b8669b658c527493a6ff2237eb5
  • Pointer size: 132 Bytes
  • Size of remote file: 1.73 MB
assets/Bustling city street at sunset, skyscrapers, streets, cars.png ADDED

Git LFS Details

  • SHA256: 59b5d68e3d58726dd71e503839a6e044f524d38bc75bfb69645240e2bb8b8d52
  • Pointer size: 132 Bytes
  • Size of remote file: 1.36 MB
assets/Cozy livingroom in christmas.png ADDED

Git LFS Details

  • SHA256: 3e8d5bac47b877ef35c173cc6cc1831f731754724934280de1ce1dd572d15977
  • Pointer size: 132 Bytes
  • Size of remote file: 1.12 MB
assets/Magical academy courtyard with floating orbs of light, ancient stone buildings, and a large tree in the center, mystical and enchanting.png ADDED

Git LFS Details

  • SHA256: 23d466e632df056b2135425dc6a144b77b7a818b4fbc09ae6ae91a58134d741c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.86 MB
assets/lego city with lego shops, lego road with street lamp, cars and lego mans on the street, lego trees and lake at a park.png ADDED

Git LFS Details

  • SHA256: 62db9d6e2570f10c25b319f6027a71ca501b4818764068587b506a024ab70c99
  • Pointer size: 132 Bytes
  • Size of remote file: 1.27 MB
lora_hubs/pano_lora_512*1024_v1.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:61f3f98f06ba0df49f9137c57d50222ab8f5adf9f83b47f669bd60da19f24df8
+ size 26365456
lora_hubs/pano_lora_720*1440_v1.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:36a64d71a914a30bf8946f6715a51a551855b4647f6cb9ed8779042c9c65cb5a
+ size 26365456
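The two safetensors above are stored as Git LFS objects, so they can be fetched with `git lfs pull` or programmatically. A minimal sketch using `huggingface_hub` follows; the `repo_id` below is a placeholder and must be replaced with this repository's actual id on the Hub:

```python
from huggingface_hub import hf_hub_download

# Placeholder repo id -- substitute the actual "<owner>/<name>" of this repository.
lora_path = hf_hub_download(
    repo_id="<owner>/Layerpano3D-FLUX-Panorama-LoRA",
    filename="lora_hubs/pano_lora_720*1440_v1.safetensors",
)
print(lora_path)  # local cache path of the downloaded LoRA weights
```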
main.py ADDED
@@ -0,0 +1,24 @@
+ import torch
+ from pipeline_flux import FluxPipeline  # our modified Flux pipeline, which blends the left/right edges so the panorama closes into a loop
+
+ lora_path = "lora_hubs/pano_lora_720*1440_v1.safetensors"  # download the panorama LoRA from this repo and point this to your local copy
+ pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
+ pipe.load_lora_weights(lora_path)
+ pipe.enable_model_cpu_offload()  # save some VRAM by offloading the model to CPU
+
+ prompt = 'A vibrant city avenue, bustling traffic, towering skyscrapers'
+
+ pipe.enable_vae_tiling()
+ seed = 119223
+
+ # Use the same resolution the LoRA was trained on.
+ image = pipe(prompt,
+              height=720,
+              width=1440,
+              generator=torch.Generator("cpu").manual_seed(seed),
+              num_inference_steps=50,
+              blend_extend=6,  # number of latent columns blended to close the left/right seam
+              guidance_scale=7).images[0]
+
+ image.save("result.png")
pipeline_flux.py ADDED
@@ -0,0 +1,1114 @@
1
+ # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import (
21
+ CLIPImageProcessor,
22
+ CLIPTextModel,
23
+ CLIPTokenizer,
24
+ CLIPVisionModelWithProjection,
25
+ T5EncoderModel,
26
+ T5TokenizerFast,
27
+ )
28
+
29
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
30
+ from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
31
+ from diffusers.models.autoencoders import AutoencoderKL
32
+
33
+ # from autoencoder_kl import AutoencoderKL
34
+
35
+
36
+ from diffusers.models.transformers import FluxTransformer2DModel
37
+ # from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
38
+
39
+ from scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
40
+
41
+ from diffusers.utils import (
42
+ USE_PEFT_BACKEND,
43
+ is_torch_xla_available,
44
+ logging,
45
+ replace_example_docstring,
46
+ scale_lora_layers,
47
+ unscale_lora_layers,
48
+ )
49
+ from diffusers.utils.torch_utils import randn_tensor
50
+ from diffusers import DiffusionPipeline
51
+ from diffusers.pipelines.flux import FluxPipelineOutput
52
+ import torch.nn.functional as F
53
+ from einops import rearrange
54
+
55
+ try:
56
+ from diffusers.models.autoencoders.vae import DecoderOutput
57
+ except:
58
+ from diffusers.models.vae import DecoderOutput
59
+
60
+ if is_torch_xla_available():
61
+ import torch_xla.core.xla_model as xm
62
+
63
+ XLA_AVAILABLE = True
64
+ else:
65
+ XLA_AVAILABLE = False
66
+
67
+
68
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
69
+
70
+ EXAMPLE_DOC_STRING = """
71
+ Examples:
72
+ ```py
73
+ >>> import torch
74
+ >>> from diffusers import FluxPipeline
75
+
76
+ >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
77
+ >>> pipe.to("cuda")
78
+ >>> prompt = "A cat holding a sign that says hello world"
79
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
80
+ >>> # Refer to the pipeline documentation for more details.
81
+ >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
82
+ >>> image.save("flux.png")
83
+ ```
84
+ """
85
+
86
+
87
+ def calculate_shift(
88
+ image_seq_len,
89
+ base_seq_len: int = 256,
90
+ max_seq_len: int = 4096,
91
+ base_shift: float = 0.5,
92
+ max_shift: float = 1.16,
93
+ ):
94
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
95
+ b = base_shift - m * base_seq_len
96
+ mu = image_seq_len * m + b
97
+ return mu
98
+
99
+
100
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
101
+ def retrieve_timesteps(
102
+ scheduler,
103
+ num_inference_steps: Optional[int] = None,
104
+ device: Optional[Union[str, torch.device]] = None,
105
+ timesteps: Optional[List[int]] = None,
106
+ sigmas: Optional[List[float]] = None,
107
+ **kwargs,
108
+ ):
109
+ r"""
110
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
111
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
112
+
113
+ Args:
114
+ scheduler (`SchedulerMixin`):
115
+ The scheduler to get timesteps from.
116
+ num_inference_steps (`int`):
117
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
118
+ must be `None`.
119
+ device (`str` or `torch.device`, *optional*):
120
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
121
+ timesteps (`List[int]`, *optional*):
122
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
123
+ `num_inference_steps` and `sigmas` must be `None`.
124
+ sigmas (`List[float]`, *optional*):
125
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
126
+ `num_inference_steps` and `timesteps` must be `None`.
127
+
128
+ Returns:
129
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
130
+ second element is the number of inference steps.
131
+ """
132
+ if timesteps is not None and sigmas is not None:
133
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
134
+ if timesteps is not None:
135
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
136
+ if not accepts_timesteps:
137
+ raise ValueError(
138
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
139
+ f" timestep schedules. Please check whether you are using the correct scheduler."
140
+ )
141
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
142
+ timesteps = scheduler.timesteps
143
+ num_inference_steps = len(timesteps)
144
+ elif sigmas is not None:
145
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
146
+ if not accept_sigmas:
147
+ raise ValueError(
148
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
149
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
150
+ )
151
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
152
+ timesteps = scheduler.timesteps
153
+ num_inference_steps = len(timesteps)
154
+ else:
155
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
156
+ timesteps = scheduler.timesteps
157
+ return timesteps, num_inference_steps
158
+
159
+
160
+ class FluxPipeline(
161
+ DiffusionPipeline,
162
+ FluxLoraLoaderMixin,
163
+ FromSingleFileMixin,
164
+ TextualInversionLoaderMixin,
165
+ FluxIPAdapterMixin,
166
+ ):
167
+ r"""
168
+ The Flux pipeline for text-to-image generation.
169
+
170
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
171
+
172
+ Args:
173
+ transformer ([`FluxTransformer2DModel`]):
174
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
175
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
176
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
177
+ vae ([`AutoencoderKL`]):
178
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
179
+ text_encoder ([`CLIPTextModel`]):
180
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
181
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
182
+ text_encoder_2 ([`T5EncoderModel`]):
183
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
184
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
185
+ tokenizer (`CLIPTokenizer`):
186
+ Tokenizer of class
187
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
188
+ tokenizer_2 (`T5TokenizerFast`):
189
+ Second Tokenizer of class
190
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
191
+ """
192
+
193
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
194
+ _optional_components = ["image_encoder", "feature_extractor"]
195
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
196
+
197
+ def __init__(
198
+ self,
199
+ scheduler: FlowMatchEulerDiscreteScheduler,
200
+ vae: AutoencoderKL,
201
+ text_encoder: CLIPTextModel,
202
+ tokenizer: CLIPTokenizer,
203
+ text_encoder_2: T5EncoderModel,
204
+ tokenizer_2: T5TokenizerFast,
205
+ transformer: FluxTransformer2DModel,
206
+ image_encoder: CLIPVisionModelWithProjection = None,
207
+ feature_extractor: CLIPImageProcessor = None,
208
+ ):
209
+ super().__init__()
210
+
211
+ self.register_modules(
212
+ vae=vae,
213
+ text_encoder=text_encoder,
214
+ text_encoder_2=text_encoder_2,
215
+ tokenizer=tokenizer,
216
+ tokenizer_2=tokenizer_2,
217
+ transformer=transformer,
218
+ scheduler=scheduler,
219
+ image_encoder=image_encoder,
220
+ feature_extractor=feature_extractor,
221
+ )
222
+ self.vae_scale_factor = (
223
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
224
+ )
225
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
226
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
227
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
228
+ self.tokenizer_max_length = (
229
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
230
+ )
231
+ self.default_sample_size = 128
232
+
233
+ def _get_t5_prompt_embeds(
234
+ self,
235
+ prompt: Union[str, List[str]] = None,
236
+ num_images_per_prompt: int = 1,
237
+ max_sequence_length: int = 512,
238
+ device: Optional[torch.device] = None,
239
+ dtype: Optional[torch.dtype] = None,
240
+ ):
241
+ device = device or self._execution_device
242
+ dtype = dtype or self.text_encoder.dtype
243
+
244
+ prompt = [prompt] if isinstance(prompt, str) else prompt
245
+ batch_size = len(prompt)
246
+
247
+ if isinstance(self, TextualInversionLoaderMixin):
248
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
249
+
250
+ text_inputs = self.tokenizer_2(
251
+ prompt,
252
+ padding="max_length",
253
+ max_length=max_sequence_length,
254
+ truncation=True,
255
+ return_length=False,
256
+ return_overflowing_tokens=False,
257
+ return_tensors="pt",
258
+ )
259
+ text_input_ids = text_inputs.input_ids
260
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
261
+
262
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
263
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
264
+ logger.warning(
265
+ "The following part of your input was truncated because `max_sequence_length` is set to "
266
+ f" {max_sequence_length} tokens: {removed_text}"
267
+ )
268
+
269
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
270
+
271
+ dtype = self.text_encoder_2.dtype
272
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
273
+
274
+ _, seq_len, _ = prompt_embeds.shape
275
+
276
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
277
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
278
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
279
+
280
+ return prompt_embeds
281
+
282
+ def _get_clip_prompt_embeds(
283
+ self,
284
+ prompt: Union[str, List[str]],
285
+ num_images_per_prompt: int = 1,
286
+ device: Optional[torch.device] = None,
287
+ ):
288
+ device = device or self._execution_device
289
+
290
+ prompt = [prompt] if isinstance(prompt, str) else prompt
291
+ batch_size = len(prompt)
292
+
293
+ if isinstance(self, TextualInversionLoaderMixin):
294
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
295
+
296
+ text_inputs = self.tokenizer(
297
+ prompt,
298
+ padding="max_length",
299
+ max_length=self.tokenizer_max_length,
300
+ truncation=True,
301
+ return_overflowing_tokens=False,
302
+ return_length=False,
303
+ return_tensors="pt",
304
+ )
305
+
306
+ text_input_ids = text_inputs.input_ids
307
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
308
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
309
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
310
+ logger.warning(
311
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
312
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
313
+ )
314
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
315
+
316
+ # Use pooled output of CLIPTextModel
317
+ prompt_embeds = prompt_embeds.pooler_output
318
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
319
+
320
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
321
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
322
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
323
+
324
+ return prompt_embeds
325
+
326
+ def encode_prompt(
327
+ self,
328
+ prompt: Union[str, List[str]],
329
+ prompt_2: Union[str, List[str]],
330
+ device: Optional[torch.device] = None,
331
+ num_images_per_prompt: int = 1,
332
+ prompt_embeds: Optional[torch.FloatTensor] = None,
333
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
334
+ max_sequence_length: int = 512,
335
+ lora_scale: Optional[float] = None,
336
+ ):
337
+ r"""
338
+
339
+ Args:
340
+ prompt (`str` or `List[str]`, *optional*):
341
+ prompt to be encoded
342
+ prompt_2 (`str` or `List[str]`, *optional*):
343
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
344
+ used in all text-encoders
345
+ device: (`torch.device`):
346
+ torch device
347
+ num_images_per_prompt (`int`):
348
+ number of images that should be generated per prompt
349
+ prompt_embeds (`torch.FloatTensor`, *optional*):
350
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
351
+ provided, text embeddings will be generated from `prompt` input argument.
352
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
353
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
354
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
355
+ lora_scale (`float`, *optional*):
356
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
357
+ """
358
+ device = device or self._execution_device
359
+
360
+ # set lora scale so that monkey patched LoRA
361
+ # function of text encoder can correctly access it
362
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
363
+ self._lora_scale = lora_scale
364
+
365
+ # dynamically adjust the LoRA scale
366
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
367
+ scale_lora_layers(self.text_encoder, lora_scale)
368
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
369
+ scale_lora_layers(self.text_encoder_2, lora_scale)
370
+
371
+ prompt = [prompt] if isinstance(prompt, str) else prompt
372
+
373
+ if prompt_embeds is None:
374
+ prompt_2 = prompt_2 or prompt
375
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
376
+
377
+ # We only use the pooled prompt output from the CLIPTextModel
378
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
379
+ prompt=prompt,
380
+ device=device,
381
+ num_images_per_prompt=num_images_per_prompt,
382
+ )
383
+ prompt_embeds = self._get_t5_prompt_embeds(
384
+ prompt=prompt_2,
385
+ num_images_per_prompt=num_images_per_prompt,
386
+ max_sequence_length=max_sequence_length,
387
+ device=device,
388
+ )
389
+
390
+ if self.text_encoder is not None:
391
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
392
+ # Retrieve the original scale by scaling back the LoRA layers
393
+ unscale_lora_layers(self.text_encoder, lora_scale)
394
+
395
+ if self.text_encoder_2 is not None:
396
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
397
+ # Retrieve the original scale by scaling back the LoRA layers
398
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
399
+
400
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
401
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
402
+
403
+ return prompt_embeds, pooled_prompt_embeds, text_ids
404
+
405
+ def encode_image(self, image, device, num_images_per_prompt):
406
+ dtype = next(self.image_encoder.parameters()).dtype
407
+
408
+ if not isinstance(image, torch.Tensor):
409
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
410
+
411
+ image = image.to(device=device, dtype=dtype)
412
+ image_embeds = self.image_encoder(image).image_embeds
413
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
414
+ return image_embeds
415
+
416
+ def prepare_ip_adapter_image_embeds(
417
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
418
+ ):
419
+ image_embeds = []
420
+ if ip_adapter_image_embeds is None:
421
+ if not isinstance(ip_adapter_image, list):
422
+ ip_adapter_image = [ip_adapter_image]
423
+
424
+ if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers):
425
+ raise ValueError(
426
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters."
427
+ )
428
+
429
+ for single_ip_adapter_image, image_proj_layer in zip(
430
+ ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers
431
+ ):
432
+ single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
433
+
434
+ image_embeds.append(single_image_embeds[None, :])
435
+ else:
436
+ for single_image_embeds in ip_adapter_image_embeds:
437
+ image_embeds.append(single_image_embeds)
438
+
439
+ ip_adapter_image_embeds = []
440
+ for i, single_image_embeds in enumerate(image_embeds):
441
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
442
+ single_image_embeds = single_image_embeds.to(device=device)
443
+ ip_adapter_image_embeds.append(single_image_embeds)
444
+
445
+ return ip_adapter_image_embeds
446
+
447
+ def check_inputs(
448
+ self,
449
+ prompt,
450
+ prompt_2,
451
+ height,
452
+ width,
453
+ negative_prompt=None,
454
+ negative_prompt_2=None,
455
+ prompt_embeds=None,
456
+ negative_prompt_embeds=None,
457
+ pooled_prompt_embeds=None,
458
+ negative_pooled_prompt_embeds=None,
459
+ callback_on_step_end_tensor_inputs=None,
460
+ max_sequence_length=None,
461
+ ):
462
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
463
+ logger.warning(
464
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
465
+ )
466
+
467
+ if callback_on_step_end_tensor_inputs is not None and not all(
468
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
469
+ ):
470
+ raise ValueError(
471
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
472
+ )
473
+
474
+ if prompt is not None and prompt_embeds is not None:
475
+ raise ValueError(
476
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
477
+ " only forward one of the two."
478
+ )
479
+ elif prompt_2 is not None and prompt_embeds is not None:
480
+ raise ValueError(
481
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
482
+ " only forward one of the two."
483
+ )
484
+ elif prompt is None and prompt_embeds is None:
485
+ raise ValueError(
486
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
487
+ )
488
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
489
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
490
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
491
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
492
+
493
+ if negative_prompt is not None and negative_prompt_embeds is not None:
494
+ raise ValueError(
495
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
496
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
497
+ )
498
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
499
+ raise ValueError(
500
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
501
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
502
+ )
503
+
504
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
505
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
506
+ raise ValueError(
507
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
508
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
509
+ f" {negative_prompt_embeds.shape}."
510
+ )
511
+
512
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
513
+ raise ValueError(
514
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
515
+ )
516
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
517
+ raise ValueError(
518
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
519
+ )
520
+
521
+ if max_sequence_length is not None and max_sequence_length > 512:
522
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
523
+
524
+ @staticmethod
525
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
526
+ latent_image_ids = torch.zeros(height, width, 3)
527
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
528
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
529
+
530
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
531
+
532
+ latent_image_ids = latent_image_ids.reshape(
533
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
534
+ )
535
+
536
+ return latent_image_ids.to(device=device, dtype=dtype)
537
+
538
+ @staticmethod
539
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
540
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
541
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
542
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
543
+
544
+ return latents
545
+
546
+ @staticmethod
547
+ def _unpack_latents(latents, height, width, vae_scale_factor):
548
+ batch_size, num_patches, channels = latents.shape
549
+
550
+ # VAE applies 8x compression on images but we must also account for packing which requires
551
+ # latent height and width to be divisible by 2.
552
+ height = 2 * (int(height) // (vae_scale_factor * 2))
553
+ width = 2 * (int(width) // (vae_scale_factor * 2))
554
+
555
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
556
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
557
+
558
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
559
+
560
+ return latents
561
+
562
+
563
+ def blend_v(self, a, b, blend_extent):
564
+ blend_extent = min(a.shape[2], b.shape[2], blend_extent)
565
+ for y in range(blend_extent):
566
+ b[:, :,
567
+ y, :] = a[:, :, -blend_extent
568
+ + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (
569
+ y / blend_extent)
570
+ return b
571
+
572
+ def blend_h(self, a, b, blend_extent):
573
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
574
+ for x in range(blend_extent):
575
+ b[:, :, :, x] = a[:, :, :, -blend_extent
576
+ + x] * (1 - x / blend_extent) + b[:, :, :, x] * (
577
+ x / blend_extent)
578
+ return b
579
+
580
+
581
+
582
+ def enable_vae_slicing(self):
583
+ r"""
584
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
585
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
586
+ """
587
+ self.vae.enable_slicing()
588
+
589
+ def disable_vae_slicing(self):
590
+ r"""
591
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
592
+ computing decoding in one step.
593
+ """
594
+ self.vae.disable_slicing()
595
+
596
+ def enable_vae_tiling(self):
597
+ r"""
598
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
599
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
600
+ processing larger images.
601
+ """
602
+ self.vae.enable_tiling()
603
+
604
+ def disable_vae_tiling(self):
605
+ r"""
606
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
607
+ computing decoding in one step.
608
+ """
609
+ self.vae.disable_tiling()
610
+
611
+ def prepare_latents(
612
+ self,
613
+ batch_size,
614
+ num_channels_latents,
615
+ height,
616
+ width,
617
+ dtype,
618
+ device,
619
+ generator,
620
+ latents=None,
621
+ ):
622
+ # VAE applies 8x compression on images but we must also account for packing which requires
623
+ # latent height and width to be divisible by 2.
624
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
625
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
626
+
627
+ shape = (batch_size, num_channels_latents, height, width)
628
+
629
+ if latents is not None:
630
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
631
+ return latents.to(device=device, dtype=dtype), latent_image_ids
632
+
633
+ if isinstance(generator, list) and len(generator) != batch_size:
634
+ raise ValueError(
635
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
636
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
637
+ )
638
+
639
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
640
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
641
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
642
+
643
+ return latents, latent_image_ids
644
+
645
+ @property
646
+ def guidance_scale(self):
647
+ return self._guidance_scale
648
+
649
+ @property
650
+ def joint_attention_kwargs(self):
651
+ return self._joint_attention_kwargs
652
+
653
+ @property
654
+ def num_timesteps(self):
655
+ return self._num_timesteps
656
+
657
+ @property
658
+ def interrupt(self):
659
+ return self._interrupt
660
+
661
+ @torch.no_grad()
662
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
663
+ def __call__(
664
+ self,
665
+ prompt: Union[str, List[str]] = None,
666
+ prompt_2: Optional[Union[str, List[str]]] = None,
667
+ negative_prompt: Union[str, List[str]] = None,
668
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
669
+ true_cfg_scale: float = 1.0,
670
+ height: Optional[int] = None,
671
+ width: Optional[int] = None,
672
+ num_inference_steps: int = 28,
673
+ sigmas: Optional[List[float]] = None,
674
+ guidance_scale: float = 3.5,
675
+ num_images_per_prompt: Optional[int] = 1,
676
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
677
+ latents: Optional[torch.FloatTensor] = None,
678
+ prompt_embeds: Optional[torch.FloatTensor] = None,
679
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
680
+ ip_adapter_image: Optional[PipelineImageInput] = None,
681
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
682
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
683
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
684
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
685
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
686
+ output_type: Optional[str] = "pil",
687
+ return_dict: bool = True,
688
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
689
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
690
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
691
+ max_sequence_length: int = 512,
692
+ blend_extend: int = 6
693
+ ):
694
+ r"""
695
+ Function invoked when calling the pipeline for generation.
696
+
697
+ Args:
698
+ prompt (`str` or `List[str]`, *optional*):
699
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
700
+ instead.
701
+ prompt_2 (`str` or `List[str]`, *optional*):
702
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
703
+ will be used instead
704
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
705
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
706
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
707
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
708
+ num_inference_steps (`int`, *optional*, defaults to 50):
709
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
710
+ expense of slower inference.
711
+ sigmas (`List[float]`, *optional*):
712
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
713
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
714
+ will be used.
715
+ guidance_scale (`float`, *optional*, defaults to 7.0):
716
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
717
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
718
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
719
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
720
+ usually at the expense of lower image quality.
721
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
722
+ The number of images to generate per prompt.
723
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
724
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
725
+ to make generation deterministic.
726
+ latents (`torch.FloatTensor`, *optional*):
727
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
728
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
729
+ tensor will be generated by sampling using the supplied random `generator`.
730
+ prompt_embeds (`torch.FloatTensor`, *optional*):
731
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
732
+ provided, text embeddings will be generated from `prompt` input argument.
733
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
734
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
735
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
736
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
737
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
738
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
739
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
740
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
741
+ negative_ip_adapter_image:
742
+ (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
743
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
744
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
745
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
746
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
747
+ output_type (`str`, *optional*, defaults to `"pil"`):
748
+ The output format of the generate image. Choose between
749
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
750
+ return_dict (`bool`, *optional*, defaults to `True`):
751
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
752
+ joint_attention_kwargs (`dict`, *optional*):
753
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
754
+ `self.processor` in
755
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
756
+ callback_on_step_end (`Callable`, *optional*):
757
+ A function that calls at the end of each denoising steps during the inference. The function is called
758
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
759
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
760
+ `callback_on_step_end_tensor_inputs`.
761
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
762
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
763
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
764
+ `._callback_tensor_inputs` attribute of your pipeline class.
765
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
766
+
767
+ Examples:
768
+
769
+ Returns:
770
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
771
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
772
+ images.
773
+ """
774
+
775
+ self.vae.enable_tiling()
776
+
777
+ ###########################################
778
+ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
779
+
780
+ # if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
781
+ # return self.tiled_decode(z, return_dict=return_dict)
782
+
783
+ if self.use_tiling:
784
+ return self.tiled_decode(z, return_dict=return_dict)
785
+
786
+ if self.post_quant_conv is not None:
787
+ z = self.post_quant_conv(z)
788
+
789
+ dec = self.decoder(z)
790
+
791
+ if not return_dict:
792
+ return (dec,)
793
+
794
+ return DecoderOutput(sample=dec)
795
+
796
+
797
+
798
+
799
+ ###########################################
800
+ def tiled_decode(
801
+ self,
802
+ z: torch.FloatTensor,
803
+ return_dict: bool = True
804
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
805
+
806
+
807
+ r"""Decode a batch of images using a tiled decoder.
808
+
809
+ Args:
810
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several
811
+ steps. This is useful to keep memory use constant regardless of image size.
812
+ The end result of tiled decoding is: different from non-tiled decoding due to each tile using a different
813
+ decoder. To avoid tiling artifacts, the tiles overlap and are blended together to form a smooth output.
814
+ You may still see tile-sized changes in the look of the output, but they should be much less noticeable.
815
+ z (`torch.FloatTensor`): Input batch of latent vectors. return_dict (`bool`, *optional*, defaults to
816
+ `True`):
817
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
818
+ """
819
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
820
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
821
+ row_limit = self.tile_sample_min_size - blend_extent
822
+
823
+ w = z.shape[3]
824
+
825
+ z = torch.cat([z, z[:, :, :, :2]], dim=-1) #[1, 16, 64, 160]
826
+
827
+ # z = torch.cat([z, z[:, :, :, :w // 32]], dim=-1) #[1, 16, 64, 160]
828
+ # Split z into overlapping 64x64 tiles and decode them separately.
829
+ # The tiles have an overlap to avoid seams between tiles.
830
+
831
+ rows = []
832
+ for i in range(0, z.shape[2], overlap_size):
833
+ row = []
834
+ tile = z[:, :, i:i + self.tile_latent_min_size, :]
835
+ if self.config.use_post_quant_conv:
836
+ tile = self.post_quant_conv(tile)
837
+
838
+ decoded = self.decoder(tile)
839
+ vae_scale_factor = decoded.shape[-1] // tile.shape[-1]
840
+ row.append(decoded)
841
+ rows.append(row)
842
+ result_rows = []
843
+ for i, row in enumerate(rows):
844
+ result_row = []
845
+ for j, tile in enumerate(row):
846
+ # blend the above tile and the left tile
847
+ # to the current tile and add the current tile to the result row
848
+ if i > 0:
849
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
850
+ if j > 0:
851
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
852
+ result_row.append(
853
+ self.blend_h(
854
+ tile[:, :, :row_limit, w * vae_scale_factor:],
855
+ tile[:, :, :row_limit, :w * vae_scale_factor],
856
+ tile.shape[-1] - w * vae_scale_factor))
857
+ result_rows.append(torch.cat(result_row, dim=3))
858
+
859
+ dec = torch.cat(result_rows, dim=2)
860
+ if not return_dict:
861
+ return (dec, )
862
+ return DecoderOutput(sample=dec)
863
+
864
+ self.vae.tiled_decode = tiled_decode.__get__(self.vae, AutoencoderKL)
865
+ self.vae._decode = _decode.__get__(self.vae, AutoencoderKL)
866
+
867
+ self.blend_extend = blend_extend
868
+
869
+ # self.blend_extend = width // self.vae_scale_factor // 32
870
+ ###########################################
871
+
872
+
873
+
874
+ height = height or self.default_sample_size * self.vae_scale_factor
875
+ width = width or self.default_sample_size * self.vae_scale_factor
876
+
877
+ # 1. Check inputs. Raise error if not correct
878
+ self.check_inputs(
879
+ prompt,
880
+ prompt_2,
881
+ height,
882
+ width,
883
+ negative_prompt=negative_prompt,
884
+ negative_prompt_2=negative_prompt_2,
885
+ prompt_embeds=prompt_embeds,
886
+ negative_prompt_embeds=negative_prompt_embeds,
887
+ pooled_prompt_embeds=pooled_prompt_embeds,
888
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
889
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
890
+ max_sequence_length=max_sequence_length,
891
+ )
892
+
893
+ self._guidance_scale = guidance_scale
894
+ self._joint_attention_kwargs = joint_attention_kwargs
895
+ self._interrupt = False
896
+
897
+ # 2. Define call parameters
898
+ if prompt is not None and isinstance(prompt, str):
899
+ batch_size = 1
900
+ elif prompt is not None and isinstance(prompt, list):
901
+ batch_size = len(prompt)
902
+ else:
903
+ batch_size = prompt_embeds.shape[0]
904
+
905
+ device = self._execution_device
906
+
907
+ lora_scale = (
908
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
909
+ )
910
+ do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None
911
+ (
912
+ prompt_embeds,
913
+ pooled_prompt_embeds,
914
+ text_ids,
915
+ ) = self.encode_prompt(
916
+ prompt=prompt,
917
+ prompt_2=prompt_2,
918
+ prompt_embeds=prompt_embeds,
919
+ pooled_prompt_embeds=pooled_prompt_embeds,
920
+ device=device,
921
+ num_images_per_prompt=num_images_per_prompt,
922
+ max_sequence_length=max_sequence_length,
923
+ lora_scale=lora_scale,
924
+ )
925
+ if do_true_cfg:
926
+ (
927
+ negative_prompt_embeds,
928
+ negative_pooled_prompt_embeds,
929
+ _,
930
+ ) = self.encode_prompt(
931
+ prompt=negative_prompt,
932
+ prompt_2=negative_prompt_2,
933
+ prompt_embeds=negative_prompt_embeds,
934
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
935
+ device=device,
936
+ num_images_per_prompt=num_images_per_prompt,
937
+ max_sequence_length=max_sequence_length,
938
+ lora_scale=lora_scale,
939
+ )
940
+
941
+ # 4. Prepare latent variables
942
+ num_channels_latents = self.transformer.config.in_channels // 4
943
+ latents, latent_image_ids = self.prepare_latents(
944
+ batch_size * num_images_per_prompt,
945
+ num_channels_latents,
946
+ height,
947
+ width,
948
+ prompt_embeds.dtype,
949
+ device,
950
+ generator,
951
+ latents,
952
+ )
953
+
954
+ latents_unpack = self._unpack_latents(latents, height, width, self.vae_scale_factor)
955
+ latents_unpack = torch.cat([latents_unpack, latents_unpack[:, :, :, :self.blend_extend]], dim=-1)
956
+ width_new_blended = latents_unpack.shape[-1] * 8
957
+ latent_image_ids = self._prepare_latent_image_ids(batch_size * num_images_per_prompt,
958
+ height // 16, width_new_blended // 16,
959
+ latents.device,
960
+ latents.dtype)
961
+ latents = self._pack_latents(latents_unpack, batch_size, num_channels_latents, height // 8, width_new_blended // 8)
962
+
963
+
964
+
965
+
966
+ # 5. Prepare timesteps
967
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
968
+ image_seq_len = latents.shape[1]
969
+ mu = calculate_shift(
970
+ image_seq_len,
971
+ self.scheduler.config.base_image_seq_len,
972
+ self.scheduler.config.max_image_seq_len,
973
+ self.scheduler.config.base_shift,
974
+ self.scheduler.config.max_shift,
975
+ )
976
+ timesteps, num_inference_steps = retrieve_timesteps(
977
+ self.scheduler,
978
+ num_inference_steps,
979
+ device,
980
+ sigmas=sigmas,
981
+ mu=mu,
982
+ )
983
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
984
+ self._num_timesteps = len(timesteps)
985
+
986
+ # handle guidance
987
+ if self.transformer.config.guidance_embeds:
988
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
989
+ guidance = guidance.expand(latents.shape[0])
990
+ else:
991
+ guidance = None
992
+
993
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
994
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
995
+ ):
996
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
997
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
998
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
999
+ ):
1000
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
1001
+
1002
+ if self.joint_attention_kwargs is None:
1003
+ self._joint_attention_kwargs = {}
1004
+
1005
+ image_embeds = None
1006
+ negative_image_embeds = None
1007
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1008
+ image_embeds = self.prepare_ip_adapter_image_embeds(
1009
+ ip_adapter_image,
1010
+ ip_adapter_image_embeds,
1011
+ device,
1012
+ batch_size * num_images_per_prompt,
1013
+ )
1014
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
1015
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
1016
+ negative_ip_adapter_image,
1017
+ negative_ip_adapter_image_embeds,
1018
+ device,
1019
+ batch_size * num_images_per_prompt,
1020
+ )
1021
+
1022
+ # 6. Denoising loop
1023
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1024
+ for i, t in enumerate(timesteps):
1025
+ if self.interrupt:
1026
+ continue
1027
+
1028
+ if image_embeds is not None:
1029
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
1030
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1031
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
1032
+
1033
+ noise_pred = self.transformer(
1034
+ hidden_states=latents,
1035
+ timestep=timestep / 1000,
1036
+ guidance=guidance,
1037
+ pooled_projections=pooled_prompt_embeds,
1038
+ encoder_hidden_states=prompt_embeds,
1039
+ txt_ids=text_ids,
1040
+ img_ids=latent_image_ids,
1041
+ joint_attention_kwargs=self.joint_attention_kwargs,
1042
+ return_dict=False,
1043
+ )[0]
1044
+
1045
+ if do_true_cfg:
1046
+ if negative_image_embeds is not None:
1047
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
1048
+ neg_noise_pred = self.transformer(
1049
+ hidden_states=latents,
1050
+ timestep=timestep / 1000,
1051
+ guidance=guidance,
1052
+ pooled_projections=negative_pooled_prompt_embeds,
1053
+ encoder_hidden_states=negative_prompt_embeds,
1054
+ txt_ids=text_ids,
1055
+ img_ids=latent_image_ids,
1056
+ joint_attention_kwargs=self.joint_attention_kwargs,
1057
+ return_dict=False,
1058
+ )[0]
1059
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
1060
+
1061
+ # compute the previous noisy sample x_t -> x_t-1
1062
+ latents_dtype = latents.dtype
1063
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
1064
+
1065
+
1066
+ ### ================================== ###
1067
+ latents_unpack = self._unpack_latents(latents, height, width_new_blended, self.vae_scale_factor)
1068
+ latents_unpack = self.blend_h(latents_unpack, latents_unpack, self.blend_extend)
1069
+ latents = self._pack_latents(latents_unpack, batch_size, num_channels_latents, height // 8, width_new_blended // 8)
1070
+ ##########################################
1071
+
1072
+ if latents.dtype != latents_dtype:
1073
+ if torch.backends.mps.is_available():
1074
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1075
+ latents = latents.to(latents_dtype)
1076
+
1077
+ if callback_on_step_end is not None:
1078
+ callback_kwargs = {}
1079
+ for k in callback_on_step_end_tensor_inputs:
1080
+ callback_kwargs[k] = locals()[k]
1081
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1082
+
1083
+ latents = callback_outputs.pop("latents", latents)
1084
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1085
+
1086
+ # call the callback, if provided
1087
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1088
+ progress_bar.update()
1089
+
1090
+ if XLA_AVAILABLE:
1091
+ xm.mark_step()
1092
+
1093
+ latents_unpack = self._unpack_latents(latents, height, width_new_blended, self.vae_scale_factor)
1094
+ latents_unpack = self.blend_h(latents_unpack, latents_unpack, self.blend_extend)
1095
+ latents_unpack = latents_unpack[:, :, :, :width // self.vae_scale_factor]
1096
+ latents = self._pack_latents(latents_unpack, batch_size, num_channels_latents, height // 8, width // 8)
1097
+
1098
+ if output_type == "latent":
1099
+ image = latents
1100
+
1101
+ else:
1102
+
1103
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
1104
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1105
+ image = self.vae.decode(latents, return_dict=False)[0]
1106
+ image = self.image_processor.postprocess(image, output_type=output_type)
1107
+
1108
+ # Offload all models
1109
+ self.maybe_free_model_hooks()
1110
+
1111
+ if not return_dict:
1112
+ return (image,)
1113
+
1114
+ return FluxPipelineOutput(images=image)