AlekseyCalvin committed
Commit 0018e73 · verified · 1 Parent(s): 0fa63d6

Update pipeline.py

Files changed (1)
  1. pipeline.py +152 -1
pipeline.py CHANGED
@@ -56,7 +56,158 @@ def prepare_timesteps(
 
 # FLUX pipeline function
 class FluxWithCFGPipeline(FluxPipeline):
-
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        prompt_2: Optional[Union[str, List[str]]] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
+        num_inference_steps: int = 4,
+        timesteps: List[int] = None,
+        guidance_scale: float = 3.5,
+        num_images_per_prompt: Optional[int] = 1,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        max_sequence_length: int = 300,
+    ):
+        height = height or self.default_sample_size * self.vae_scale_factor
+        width = width or self.default_sample_size * self.vae_scale_factor
+
+        # 1. Check inputs
+        self.check_inputs(
+            prompt,
+            prompt_2,
+            negative_prompt,
+            height,
+            width,
+            prompt_embeds=prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            max_sequence_length=max_sequence_length,
+        )
+
+        self._guidance_scale = guidance_scale
+        self._joint_attention_kwargs = joint_attention_kwargs
+        self._interrupt = False
+
+        # 2. Define call parameters
+        batch_size = 1 if isinstance(prompt, str) else len(prompt)
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        # 3. Encode prompt
+        lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
+        prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
+            prompt=prompt,
+            prompt_2=prompt_2,
+            prompt_embeds=prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            max_sequence_length=max_sequence_length,
+            lora_scale=lora_scale,
+        )
+        negative_prompt_embeds, negative_pooled_prompt_embeds, negative_text_ids = self.encode_prompt(
+            prompt=negative_prompt,
+            prompt_2=negative_prompt_2,
+            prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            max_sequence_length=max_sequence_length,
+            lora_scale=lora_scale,
+        )
+
+        # 4. Prepare latent variables
+        num_channels_latents = self.transformer.config.in_channels // 4
+        latents, latent_image_ids = self.prepare_latents(
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            prompt_embeds.dtype,
+            negative_prompt_embeds.dtype,
+            device,
+            generator,
+            latents,
+        )
+
+        # 5. Prepare timesteps
+        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+        image_seq_len = latents.shape[1]
+        mu = calculate_timestep_shift(image_seq_len)
+        timesteps, num_inference_steps = prepare_timesteps(
+            self.scheduler,
+            num_inference_steps,
+            device,
+            timesteps,
+            sigmas,
+            mu=mu,
+        )
+        self._num_timesteps = len(timesteps)
+
+        # Handle guidance
+        guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float16).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
+
+        # 6. Denoising loop
+        for i, t in enumerate(timesteps):
+            if self.interrupt:
+                continue
+
+            timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+            noise_pred_text = self.transformer(
+                hidden_states=latents,
+                timestep=timestep / 1000,
+                guidance=guidance,
+                pooled_projections=pooled_prompt_embeds,
+                encoder_hidden_states=prompt_embeds,
+                txt_ids=text_ids,
+                img_ids=latent_image_ids,
+                joint_attention_kwargs=self.joint_attention_kwargs,
+                return_dict=False,
+            )[0]
+
+            noise_pred_uncond = self.transformer(
+                hidden_states=latents,
+                timestep=timestep / 1000,
+                guidance=guidance,
+                pooled_projections=negative_pooled_prompt_embeds,
+                encoder_hidden_states=negative_prompt_embeds,
+                txt_ids=negative_text_ids,
+                img_ids=latent_image_ids,
+                joint_attention_kwargs=self.joint_attention_kwargs,
+                return_dict=False,
+            )[0]
+
+            # Classifier-free guidance: combine unconditional and conditional predictions
+            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+            latents_dtype = latents.dtype
+            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+            # Free cached memory between steps
+            torch.cuda.empty_cache()
+
+        image = self._decode_latents_to_image(latents, height, width, output_type)  # final image
+        self.maybe_free_model_hooks()
+        torch.cuda.empty_cache()
+        return image
+
+    def _decode_latents_to_image(self, latents, height, width, output_type, vae=None):
+        """Decodes the given latents into an image."""
+        vae = vae or self.vae
+        latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+        latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
+        image = vae.decode(latents, return_dict=False)[0]
+        return self.image_processor.postprocess(image, output_type=output_type)[0]
+
+class FluxWithCFGPipeline(FluxPipeline):
     @torch.inference_mode()
     def generate_images(
         self,
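For context, here is a minimal usage sketch of the `__call__` method this commit adds. It is not part of the commit: the import assumes this repo's pipeline.py is on the Python path, and the checkpoint id is illustrative rather than the one the repo actually loads.

```python
import torch
from pipeline import FluxWithCFGPipeline  # assumes this repo's pipeline.py is importable

# Illustrative base checkpoint; substitute whichever FLUX weights the repo uses.
pipe = FluxWithCFGPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

# The added __call__ runs two transformer passes per step (prompt and negative prompt)
# and blends them with guidance_scale, i.e. true classifier-free guidance.
image = pipe(
    prompt="a watercolor fox in a birch forest",
    negative_prompt="blurry, low quality",
    num_inference_steps=4,
    guidance_scale=3.5,
    height=1024,
    width=1024,
)
image.save("fox.png")
```

Because each denoising step evaluates the transformer twice (conditional and unconditional), this path costs roughly one extra transformer pass per step compared to running with the prompt alone, which is the usual price of supporting negative prompts via classifier-free guidance.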