Text-to-Image
Diffusers
Safetensors
jiuntian commited on
Commit
122e973
·
verified ·
1 Parent(s): 18973cc

Upload folder using huggingface_hub

Browse files
model_index.json ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": ["pipeline_gligen_sdxl", "StableDiffusionXLGLIGENPipeline"],
3
+ "_diffusers_version": "0.30.1",
4
+ "_name_or_path": "jiuntian/gligen-xl-1024",
5
+ "feature_extractor": [
6
+ null,
7
+ null
8
+ ],
9
+ "force_zeros_for_empty_prompt": true,
10
+ "image_encoder": [
11
+ null,
12
+ null
13
+ ],
14
+ "scheduler": [
15
+ "diffusers",
16
+ "EulerDiscreteScheduler"
17
+ ],
18
+ "text_encoder": [
19
+ "transformers",
20
+ "CLIPTextModel"
21
+ ],
22
+ "text_encoder_2": [
23
+ "transformers",
24
+ "CLIPTextModelWithProjection"
25
+ ],
26
+ "tokenizer": [
27
+ "transformers",
28
+ "CLIPTokenizer"
29
+ ],
30
+ "tokenizer_2": [
31
+ "transformers",
32
+ "CLIPTokenizer"
33
+ ],
34
+ "unet": [
35
+ "diffusers",
36
+ "UNet2DConditionModel"
37
+ ],
38
+ "vae": [
39
+ "diffusers",
40
+ "AutoencoderKL"
41
+ ]
42
+ }
pipeline_gligen_sdxl.py ADDED
@@ -0,0 +1,1406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
3
+ import warnings
4
+
5
+ import PIL.Image
6
+ import torch
7
+
8
+ from transformers import (
9
+ CLIPImageProcessor,
10
+ CLIPTextModel,
11
+ CLIPTextModelWithProjection,
12
+ CLIPTokenizer,
13
+ CLIPVisionModelWithProjection,
14
+ )
15
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
16
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
17
+ from diffusers.loaders import (
18
+ FromSingleFileMixin,
19
+ IPAdapterMixin,
20
+ StableDiffusionXLLoraLoaderMixin,
21
+ TextualInversionLoaderMixin,
22
+ )
23
+ from diffusers.models import (
24
+ AutoencoderKL,
25
+ ImageProjection,
26
+ UNet2DConditionModel
27
+ )
28
+ from diffusers.models.attention import GatedSelfAttentionDense
29
+ from diffusers.models.attention_processor import (
30
+ AttnProcessor2_0,
31
+ FusedAttnProcessor2_0,
32
+ XFormersAttnProcessor,
33
+ )
34
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
35
+ from diffusers.pipelines.pipeline_utils import (
36
+ DiffusionPipeline,
37
+ StableDiffusionMixin
38
+ )
39
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg, retrieve_timesteps
40
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
41
+ from diffusers.schedulers import (
42
+ KarrasDiffusionSchedulers
43
+ )
44
+ from diffusers.utils import (
45
+ USE_PEFT_BACKEND,
46
+ deprecate,
47
+ is_invisible_watermark_available,
48
+ is_torch_xla_available,
49
+ logging,
50
+ replace_example_docstring,
51
+ scale_lora_layers,
52
+ unscale_lora_layers,
53
+ )
54
+ from diffusers.utils.torch_utils import randn_tensor
55
+
56
+
57
+ if is_invisible_watermark_available():
58
+ from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
59
+
60
+ if is_torch_xla_available():
61
+ import torch_xla.core.xla_model as xm # type: ignore
62
+
63
+ XLA_AVAILABLE = True
64
+ else:
65
+ XLA_AVAILABLE = False
66
+
67
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
68
+
69
+ EXAMPLE_DOC_STRING = """
70
+ Examples:
71
+ ```py
72
+ >>> import torch
73
+ >>> from pipeline_gligen_sdxl import StableDiffusionXLGLIGENPipeline
74
+
75
+ >>> pipe = StableDiffusionXLGLIGENPipeline.from_pretrained(
76
+ ... "xxx", torch_dtype=torch.float16
77
+ ... )
78
+ >>> pipe = pipe.to("cuda")
79
+
80
+ >>> prompt = "a waterfall and a modern high speed train running through the tunnel in a beautiful forest with fall foliage"
81
+ >>> boxes = [[0.1387, 0.2051, 0.4277, 0.7090], [0.4980, 0.4355, 0.8516, 0.7266]]
82
+ >>> phrases = ["a waterfall", "a modern high speed train running through the tunnel"]
83
+
84
+ >>> images = pipe(
85
+ ... prompt=prompt,
86
+ ... gligen_phrases=phrases,
87
+ ... gligen_boxes=boxes,
88
+ ... gligen_scheduled_sampling_beta=1,
89
+ ... output_type="pil",
90
+ ... num_inference_steps=50,
91
+ ... ).images
92
+
93
+ >>> images[0].save("./gligen-xl-generation-text-box.jpg")
94
+ ```
95
+ """
96
+
97
+ class StableDiffusionXLGLIGENPipeline(
98
+ DiffusionPipeline,
99
+ StableDiffusionMixin,
100
+ FromSingleFileMixin,
101
+ StableDiffusionXLLoraLoaderMixin,
102
+ TextualInversionLoaderMixin,
103
+ IPAdapterMixin,
104
+ ):
105
+ r"""
106
+ Pipeline for GLIGEN layout text-to-image generation using Stable Diffusion XL.
107
+ """
108
+
109
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
110
+ _optional_components = [
111
+ "tokenizer",
112
+ "tokenizer_2",
113
+ "text_encoder",
114
+ "text_encoder_2",
115
+ "image_encoder",
116
+ "feature_extractor",
117
+ ]
118
+ _callback_tensor_inputs = [
119
+ "latents",
120
+ "prompt_embeds",
121
+ "negative_prompt_embeds",
122
+ "add_text_embeds",
123
+ "add_time_ids",
124
+ "negative_pooled_prompt_embeds",
125
+ "negative_add_time_ids",
126
+ ]
127
+
128
+
129
+ def __init__(
130
+ self,
131
+ vae: AutoencoderKL,
132
+ text_encoder: CLIPTextModel,
133
+ text_encoder_2: CLIPTextModelWithProjection,
134
+ tokenizer: CLIPTokenizer,
135
+ tokenizer_2: CLIPTokenizer,
136
+ unet: UNet2DConditionModel,
137
+ scheduler: KarrasDiffusionSchedulers,
138
+ image_encoder: CLIPVisionModelWithProjection = None,
139
+ feature_extractor: CLIPImageProcessor = None,
140
+ force_zeros_for_empty_prompt: bool = True,
141
+ add_watermarker: Optional[bool] = None,
142
+ ):
143
+ super().__init__()
144
+
145
+ self.register_modules(
146
+ vae=vae,
147
+ text_encoder=text_encoder,
148
+ text_encoder_2=text_encoder_2,
149
+ tokenizer=tokenizer,
150
+ tokenizer_2=tokenizer_2,
151
+ unet=unet,
152
+ scheduler=scheduler,
153
+ image_encoder=image_encoder,
154
+ feature_extractor=feature_extractor,
155
+ )
156
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
157
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
158
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
159
+
160
+ self.default_sample_size = self.unet.config.sample_size
161
+
162
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
163
+
164
+ if add_watermarker:
165
+ self.watermark = StableDiffusionXLWatermarker()
166
+ else:
167
+ self.watermark = None
168
+ # copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionPipelineXL.encode_prompt
169
+ def encode_prompt(
170
+ self,
171
+ prompt: str,
172
+ prompt_2: Optional[str] = None,
173
+ device: Optional[torch.device] = None,
174
+ num_images_per_prompt: int = 1,
175
+ do_classifier_free_guidance: bool = True,
176
+ negative_prompt: Optional[str] = None,
177
+ negative_prompt_2: Optional[str] = None,
178
+ prompt_embeds: Optional[torch.Tensor] = None,
179
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
180
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
181
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
182
+ lora_scale: Optional[float] = None,
183
+ clip_skip: Optional[int] = None,
184
+ ):
185
+ device = device or self._execution_device
186
+
187
+ # set lora scale so that monkey patched LoRA
188
+ # function of text encoder can correctly access it
189
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
190
+ self._lora_scale = lora_scale
191
+
192
+ # dynamically adjust the LoRA scale
193
+ if self.text_encoder is not None:
194
+ if not USE_PEFT_BACKEND:
195
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
196
+ else:
197
+ scale_lora_layers(self.text_encoder, lora_scale)
198
+
199
+ if self.text_encoder_2 is not None:
200
+ if not USE_PEFT_BACKEND:
201
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
202
+ else:
203
+ scale_lora_layers(self.text_encoder_2, lora_scale)
204
+
205
+ prompt = [prompt] if isinstance(prompt, str) else prompt
206
+
207
+ if prompt is not None:
208
+ batch_size = len(prompt)
209
+ else:
210
+ batch_size = prompt_embeds.shape[0]
211
+
212
+ # Define tokenizers and text encoders
213
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
214
+ text_encoders = (
215
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
216
+ )
217
+
218
+ if prompt_embeds is None:
219
+ prompt_2 = prompt_2 or prompt
220
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
221
+
222
+ # textual inversion: process multi-vector tokens if necessary
223
+ prompt_embeds_list = []
224
+ prompts = [prompt, prompt_2]
225
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
226
+ if isinstance(self, TextualInversionLoaderMixin):
227
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
228
+
229
+ text_inputs = tokenizer(
230
+ prompt,
231
+ padding="max_length",
232
+ max_length=tokenizer.model_max_length,
233
+ truncation=True,
234
+ return_tensors="pt",
235
+ )
236
+
237
+ text_input_ids = text_inputs.input_ids
238
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
239
+
240
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
241
+ text_input_ids, untruncated_ids
242
+ ):
243
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
244
+ logger.warning(
245
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
246
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
247
+ )
248
+
249
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
250
+
251
+ # We are only ALWAYS interested in the pooled output of the final text encoder
252
+ pooled_prompt_embeds = prompt_embeds[0]
253
+ if clip_skip is None:
254
+ prompt_embeds = prompt_embeds.hidden_states[-2]
255
+ else:
256
+ # "2" because SDXL always indexes from the penultimate layer.
257
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
258
+
259
+ prompt_embeds_list.append(prompt_embeds)
260
+
261
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
262
+
263
+ # get unconditional embeddings for classifier free guidance
264
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
265
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
266
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
267
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
268
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
269
+ negative_prompt = negative_prompt or ""
270
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
271
+
272
+ # normalize str to list
273
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
274
+ negative_prompt_2 = (
275
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
276
+ )
277
+
278
+ uncond_tokens: List[str]
279
+ if prompt is not None and type(prompt) is not type(negative_prompt):
280
+ raise TypeError(
281
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
282
+ f" {type(prompt)}."
283
+ )
284
+ elif batch_size != len(negative_prompt):
285
+ raise ValueError(
286
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
287
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
288
+ " the batch size of `prompt`."
289
+ )
290
+ else:
291
+ uncond_tokens = [negative_prompt, negative_prompt_2]
292
+
293
+ negative_prompt_embeds_list = []
294
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
295
+ if isinstance(self, TextualInversionLoaderMixin):
296
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
297
+
298
+ max_length = prompt_embeds.shape[1]
299
+ uncond_input = tokenizer(
300
+ negative_prompt,
301
+ padding="max_length",
302
+ max_length=max_length,
303
+ truncation=True,
304
+ return_tensors="pt",
305
+ )
306
+
307
+ negative_prompt_embeds = text_encoder(
308
+ uncond_input.input_ids.to(device),
309
+ output_hidden_states=True,
310
+ )
311
+ # We are only ALWAYS interested in the pooled output of the final text encoder
312
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
313
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
314
+
315
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
316
+
317
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
318
+
319
+ if self.text_encoder_2 is not None:
320
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
321
+ else:
322
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
323
+
324
+ bs_embed, seq_len, _ = prompt_embeds.shape
325
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
326
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
327
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
328
+
329
+ if do_classifier_free_guidance:
330
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
331
+ seq_len = negative_prompt_embeds.shape[1]
332
+
333
+ if self.text_encoder_2 is not None:
334
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
335
+ else:
336
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
337
+
338
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
339
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
340
+
341
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
342
+ bs_embed * num_images_per_prompt, -1
343
+ )
344
+ if do_classifier_free_guidance:
345
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
346
+ bs_embed * num_images_per_prompt, -1
347
+ )
348
+
349
+ if self.text_encoder is not None:
350
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
351
+ # Retrieve the original scale by scaling back the LoRA layers
352
+ unscale_lora_layers(self.text_encoder, lora_scale)
353
+
354
+ if self.text_encoder_2 is not None:
355
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
356
+ # Retrieve the original scale by scaling back the LoRA layers
357
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
358
+
359
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
360
+
361
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
362
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
363
+ dtype = next(self.image_encoder.parameters()).dtype
364
+
365
+ if not isinstance(image, torch.Tensor):
366
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
367
+
368
+ image = image.to(device=device, dtype=dtype)
369
+ if output_hidden_states:
370
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
371
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
372
+ uncond_image_enc_hidden_states = self.image_encoder(
373
+ torch.zeros_like(image), output_hidden_states=True
374
+ ).hidden_states[-2]
375
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
376
+ num_images_per_prompt, dim=0
377
+ )
378
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
379
+ else:
380
+ image_embeds = self.image_encoder(image).image_embeds
381
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
382
+ uncond_image_embeds = torch.zeros_like(image_embeds)
383
+
384
+ return image_embeds, uncond_image_embeds
385
+
386
+ def encode_prompt_gligen(
387
+ self,
388
+ prompt: str,
389
+ prompt_2: Optional[str] = None,
390
+ device: Optional[torch.device] = None,
391
+ num_images_per_prompt: int = 1,
392
+ gligen_embeds: Optional[torch.Tensor] = None,
393
+ lora_scale: Optional[float] = None,
394
+ clip_skip: Optional[int] = None,
395
+ ):
396
+ device = device or self._execution_device
397
+
398
+ # set lora scale so that monkey patched LoRA
399
+ # function of text encoder can correctly access it
400
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
401
+ self._lora_scale = lora_scale
402
+
403
+ # dynamically adjust the LoRA scale
404
+ if self.text_encoder is not None:
405
+ if not USE_PEFT_BACKEND:
406
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
407
+ else:
408
+ scale_lora_layers(self.text_encoder, lora_scale)
409
+
410
+ if self.text_encoder_2 is not None:
411
+ if not USE_PEFT_BACKEND:
412
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
413
+ else:
414
+ scale_lora_layers(self.text_encoder_2, lora_scale)
415
+
416
+ prompt = [prompt] if isinstance(prompt, str) else prompt
417
+
418
+ if prompt is not None:
419
+ batch_size = len(prompt)
420
+ else:
421
+ batch_size = prompt_embeds.shape[0]
422
+
423
+ # Define tokenizers and text encoders
424
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
425
+ text_encoders = (
426
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
427
+ )
428
+
429
+ if gligen_embeds is None:
430
+ prompt_2 = prompt_2 or prompt
431
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
432
+
433
+ # textual inversion: process multi-vector tokens if necessary
434
+ gligen_embeds_list = []
435
+ prompts = [prompt, prompt_2]
436
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
437
+ if isinstance(self, TextualInversionLoaderMixin):
438
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
439
+
440
+ text_inputs = tokenizer(
441
+ prompt,
442
+ padding="max_length",
443
+ max_length=tokenizer.model_max_length,
444
+ truncation=True,
445
+ return_tensors="pt",
446
+ )
447
+
448
+ text_input_ids = text_inputs.input_ids
449
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
450
+
451
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
452
+ text_input_ids, untruncated_ids
453
+ ):
454
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
455
+ logger.warning(
456
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
457
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
458
+ )
459
+
460
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
461
+
462
+ if isinstance(text_encoder, CLIPTextModel):
463
+ gligen_embeds_list.append(prompt_embeds.pooler_output)
464
+ elif isinstance(text_encoder, CLIPTextModelWithProjection):
465
+ gligen_embeds_list.append(prompt_embeds.text_embeds)
466
+
467
+ gligen_embeds = torch.concat(gligen_embeds_list, dim=-1)
468
+
469
+ if self.text_encoder is not None:
470
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
471
+ # Retrieve the original scale by scaling back the LoRA layers
472
+ unscale_lora_layers(self.text_encoder, lora_scale)
473
+
474
+ if self.text_encoder_2 is not None:
475
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
476
+ # Retrieve the original scale by scaling back the LoRA layers
477
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
478
+
479
+ return gligen_embeds
480
+
481
+ # Copied from SDXL
482
+ def prepare_ip_adapter_image_embeds(
483
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
484
+ ):
485
+ image_embeds = []
486
+ if do_classifier_free_guidance:
487
+ negative_image_embeds = []
488
+ if ip_adapter_image_embeds is None:
489
+ if not isinstance(ip_adapter_image, list):
490
+ ip_adapter_image = [ip_adapter_image]
491
+
492
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
493
+ raise ValueError(
494
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
495
+ )
496
+
497
+ for single_ip_adapter_image, image_proj_layer in zip(
498
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
499
+ ):
500
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
501
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
502
+ single_ip_adapter_image, device, 1, output_hidden_state
503
+ )
504
+
505
+ image_embeds.append(single_image_embeds[None, :])
506
+ if do_classifier_free_guidance:
507
+ negative_image_embeds.append(single_negative_image_embeds[None, :])
508
+ else:
509
+ for single_image_embeds in ip_adapter_image_embeds:
510
+ if do_classifier_free_guidance:
511
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
512
+ negative_image_embeds.append(single_negative_image_embeds)
513
+ image_embeds.append(single_image_embeds)
514
+
515
+ ip_adapter_image_embeds = []
516
+ for i, single_image_embeds in enumerate(image_embeds):
517
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
518
+ if do_classifier_free_guidance:
519
+ single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
520
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
521
+
522
+ single_image_embeds = single_image_embeds.to(device=device)
523
+ ip_adapter_image_embeds.append(single_image_embeds)
524
+
525
+ return ip_adapter_image_embeds
526
+
527
+ # Copied form SDXL
528
+ def prepare_extra_step_kwargs(self, generator, eta):
529
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
530
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
531
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
532
+ # and should be between [0, 1]
533
+
534
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
535
+ extra_step_kwargs = {}
536
+ if accepts_eta:
537
+ extra_step_kwargs["eta"] = eta
538
+
539
+ # check if the scheduler accepts generator
540
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
541
+ if accepts_generator:
542
+ extra_step_kwargs["generator"] = generator
543
+ return extra_step_kwargs
544
+
545
+ # Copied from SDXL and StableDiffusionGLIGENPipeline
546
+ def check_inputs(
547
+ self,
548
+ prompt,
549
+ prompt_2,
550
+ height,
551
+ width,
552
+ callback_steps,
553
+ gligen_phrases,
554
+ gligen_boxes,
555
+ negative_prompt=None,
556
+ negative_prompt_2=None,
557
+ prompt_embeds=None,
558
+ negative_prompt_embeds=None,
559
+ pooled_prompt_embeds=None,
560
+ negative_pooled_prompt_embeds=None,
561
+ ip_adapter_image=None,
562
+ ip_adapter_image_embeds=None,
563
+ callback_on_step_end_tensor_inputs=None,
564
+ ):
565
+ if height % 8 != 0 or width % 8 != 0:
566
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
567
+
568
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
569
+ raise ValueError(
570
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
571
+ f" {type(callback_steps)}."
572
+ )
573
+
574
+ if callback_on_step_end_tensor_inputs is not None and not all(
575
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
576
+ ):
577
+ raise ValueError(
578
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
579
+ )
580
+
581
+ if prompt is not None and prompt_embeds is not None:
582
+ raise ValueError(
583
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
584
+ " only forward one of the two."
585
+ )
586
+ elif prompt_2 is not None and prompt_embeds is not None:
587
+ raise ValueError(
588
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
589
+ " only forward one of the two."
590
+ )
591
+ elif prompt is None and prompt_embeds is None:
592
+ raise ValueError(
593
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
594
+ )
595
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
596
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
597
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
598
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
599
+
600
+ if negative_prompt is not None and negative_prompt_embeds is not None:
601
+ raise ValueError(
602
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
603
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
604
+ )
605
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
606
+ raise ValueError(
607
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
608
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
609
+ )
610
+
611
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
612
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
613
+ raise ValueError(
614
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
615
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
616
+ f" {negative_prompt_embeds.shape}."
617
+ )
618
+
619
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
620
+ raise ValueError(
621
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
622
+ )
623
+
624
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
625
+ raise ValueError(
626
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
627
+ )
628
+
629
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
630
+ raise ValueError(
631
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
632
+ )
633
+
634
+ if ip_adapter_image_embeds is not None:
635
+ if not isinstance(ip_adapter_image_embeds, list):
636
+ raise ValueError(
637
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
638
+ )
639
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
640
+ raise ValueError(
641
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
642
+ )
643
+
644
+ if len(gligen_phrases) != len(gligen_boxes):
645
+ raise ValueError(
646
+ "length of `gligen_phrases` and `gligen_boxes` has to be same, but"
647
+ f" got: `gligen_phrases` {len(gligen_phrases)} != `gligen_boxes` {len(gligen_boxes)}"
648
+ )
649
+
650
+ # Copied from SDXL
651
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
652
+ shape = (
653
+ batch_size,
654
+ num_channels_latents,
655
+ int(height) // self.vae_scale_factor,
656
+ int(width) // self.vae_scale_factor,
657
+ )
658
+ if isinstance(generator, list) and len(generator) != batch_size:
659
+ raise ValueError(
660
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
661
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
662
+ )
663
+
664
+ if latents is None:
665
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
666
+ else:
667
+ latents = latents.to(device)
668
+
669
+ # scale the initial noise by the standard deviation required by the scheduler
670
+ latents = latents * self.scheduler.init_noise_sigma
671
+ return latents
672
+
673
+ # Copied from SDGligenPipeline
674
+ def enable_fuser(self, enabled=True):
675
+ for module in self.unet.modules():
676
+ if type(module) is GatedSelfAttentionDense:
677
+ module.enabled = enabled
678
+
679
+ # Copied from SDGligenPipeline
680
+ def draw_inpaint_mask_from_boxes(self, boxes, size):
681
+ inpaint_mask = torch.ones(size[0], size[1])
682
+ for box in boxes:
683
+ x0, x1 = box[0] * size[0], box[2] * size[0]
684
+ y0, y1 = box[1] * size[1], box[3] * size[1]
685
+ inpaint_mask[int(y0) : int(y1), int(x0) : int(x1)] = 0
686
+ return inpaint_mask
687
+
688
+ # Copied from SDGligenPipeline
689
+ def crop(self, im, new_width, new_height):
690
+ width, height = im.size
691
+ left = (width - new_width) / 2
692
+ top = (height - new_height) / 2
693
+ right = (width + new_width) / 2
694
+ bottom = (height + new_height) / 2
695
+ return im.crop((left, top, right, bottom))
696
+
697
+ # Copied from SDGligenPipeline
698
+ def target_size_center_crop(self, im, new_hw):
699
+ width, height = im.size
700
+ if width != height:
701
+ im = self.crop(im, min(height, width), min(height, width))
702
+ return im.resize((new_hw, new_hw), PIL.Image.LANCZOS)
703
+
704
+ # Copied from SDXL
705
+ def _get_add_time_ids(
706
+ self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
707
+ ):
708
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
709
+
710
+ passed_add_embed_dim = (
711
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
712
+ )
713
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
714
+
715
+ if expected_add_embed_dim != passed_add_embed_dim:
716
+ raise ValueError(
717
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
718
+ )
719
+
720
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
721
+ return add_time_ids
722
+
723
+ # Copied from SDXL
724
+ def upcast_vae(self):
725
+ dtype = self.vae.dtype
726
+ self.vae.to(dtype=torch.float32)
727
+ use_torch_2_0_or_xformers = isinstance(
728
+ self.vae.decoder.mid_block.attentions[0].processor,
729
+ (
730
+ AttnProcessor2_0,
731
+ XFormersAttnProcessor,
732
+ FusedAttnProcessor2_0,
733
+ ),
734
+ )
735
+ # if xformers or torch_2_0 is used attention block does not need
736
+ # to be in float32 which can save lots of memory
737
+ if use_torch_2_0_or_xformers:
738
+ self.vae.post_quant_conv.to(dtype)
739
+ self.vae.decoder.conv_in.to(dtype)
740
+ self.vae.decoder.mid_block.to(dtype)
741
+
742
+ # Copied from SDXL
743
+ def get_guidance_scale_embedding(
744
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
745
+ ) -> torch.Tensor:
746
+ """
747
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
748
+
749
+ Args:
750
+ w (`torch.Tensor`):
751
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
752
+ embedding_dim (`int`, *optional*, defaults to 512):
753
+ Dimension of the embeddings to generate.
754
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
755
+ Data type of the generated embeddings.
756
+
757
+ Returns:
758
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
759
+ """
760
+ assert len(w.shape) == 1
761
+ w = w * 1000.0
762
+
763
+ half_dim = embedding_dim // 2
764
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
765
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
766
+ emb = w.to(dtype)[:, None] * emb[None, :]
767
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
768
+ if embedding_dim % 2 == 1: # zero pad
769
+ emb = torch.nn.functional.pad(emb, (0, 1))
770
+ assert emb.shape == (w.shape[0], embedding_dim)
771
+ return emb
772
+
773
+ @property
774
+ def guidance_scale(self):
775
+ return self._guidance_scale
776
+
777
+ @property
778
+ def guidance_rescale(self):
779
+ return self._guidance_rescale
780
+
781
+ @property
782
+ def clip_skip(self):
783
+ return self._clip_skip
784
+
785
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
786
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
787
+ # corresponds to doing no classifier free guidance.
788
+ @property
789
+ def do_classifier_free_guidance(self):
790
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
791
+
792
+ @property
793
+ def cross_attention_kwargs(self):
794
+ return self._cross_attention_kwargs
795
+
796
+ @property
797
+ def denoising_end(self):
798
+ return self._denoising_end
799
+
800
+ @property
801
+ def num_timesteps(self):
802
+ return self._num_timesteps
803
+
804
+ @property
805
+ def interrupt(self):
806
+ return self._interrupt
807
+
808
+ @torch.no_grad()
809
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
810
+ def __call__(
811
+ self,
812
+ prompt: Union[str, List[str]] = None,
813
+ prompt_2: Optional[Union[str, List[str]]] = None,
814
+ height: Optional[int] = None,
815
+ width: Optional[int] = None,
816
+ num_inference_steps: int = 50,
817
+ timesteps: List[int] = None,
818
+ sigmas: List[float] = None,
819
+ denoising_end: Optional[float] = None,
820
+ guidance_scale: float = 5.0,
821
+ gligen_scheduled_sampling_beta: float = 0.3,
822
+ gligen_phrases: List[str] = None,
823
+ gligen_boxes: List[List[float]] = None,
824
+ gligen_inpaint_image: Optional[PIL.Image.Image] = None,
825
+ negative_prompt: Optional[Union[str, List[str]]] = None,
826
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
827
+ num_images_per_prompt: Optional[int] = 1,
828
+ eta: float = 0.0,
829
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
830
+ latents: Optional[torch.Tensor] = None,
831
+ prompt_embeds: Optional[torch.Tensor] = None,
832
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
833
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
834
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
835
+ ip_adapter_image: Optional[PipelineImageInput] = None,
836
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
837
+ output_type: Optional[str] = "pil",
838
+ return_dict: bool = True,
839
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
840
+ guidance_rescale: float = 0.0,
841
+ original_size: Optional[Tuple[int, int]] = None,
842
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
843
+ target_size: Optional[Tuple[int, int]] = None,
844
+ negative_original_size: Optional[Tuple[int, int]] = None,
845
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
846
+ negative_target_size: Optional[Tuple[int, int]] = None,
847
+ clip_skip: Optional[int] = None,
848
+ callback_on_step_end: Optional[
849
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
850
+ ] = None,
851
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
852
+ **kwargs,
853
+ ):
854
+ r"""
855
+ Function invoked when calling the pipeline for generation.
856
+
857
+ Args:
858
+ prompt (`str` or `List[str]`, *optional*):
859
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
860
+ instead.
861
+ prompt_2 (`str` or `List[str]`, *optional*):
862
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
863
+ used in both text-encoders
864
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
865
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
866
+ Anything below 512 pixels won't work well for
867
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
868
+ and checkpoints that are not specifically fine-tuned on low resolutions.
869
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
870
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
871
+ Anything below 512 pixels won't work well for
872
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
873
+ and checkpoints that are not specifically fine-tuned on low resolutions.
874
+ num_inference_steps (`int`, *optional*, defaults to 50):
875
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
876
+ expense of slower inference.
877
+ timesteps (`List[int]`, *optional*):
878
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
879
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
880
+ passed will be used. Must be in descending order.
881
+ sigmas (`List[float]`, *optional*):
882
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
883
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
884
+ will be used.
885
+ denoising_end (`float`, *optional*):
886
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
887
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
888
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
889
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
890
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
891
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
892
+ guidance_scale (`float`, *optional*, defaults to 5.0):
893
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
894
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
895
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
896
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
897
+ usually at the expense of lower image quality.
898
+ gligen_phrases (`List[str]`):
899
+ The phrases to guide what to include in each of the regions defined by the corresponding
900
+ `gligen_boxes`. There should only be one phrase per bounding box.
901
+ gligen_boxes (`List[List[float]]`):
902
+ The bounding boxes that identify rectangular regions of the image that are going to be filled with the
903
+ content described by the corresponding `gligen_phrases`. Each rectangular box is defined as a
904
+ `List[float]` of 4 elements `[xmin, ymin, xmax, ymax]` where each value is between [0,1].
905
+ gligen_inpaint_image (`PIL.Image.Image`, *optional*):
906
+ The input image, if provided, is inpainted with objects described by the `gligen_boxes` and
907
+ `gligen_phrases`. Otherwise, it is treated as a generation task on a blank input image.
908
+ gligen_scheduled_sampling_beta (`float`, defaults to 0.3):
909
+ Scheduled Sampling factor from [GLIGEN: Open-Set Grounded Text-to-Image
910
+ Generation](https://arxiv.org/pdf/2301.07093.pdf). Scheduled Sampling factor is only varied for
911
+ scheduled sampling during inference for improved quality and controllability.
912
+ negative_prompt (`str` or `List[str]`, *optional*):
913
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
914
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
915
+ less than `1`).
916
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
917
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
918
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
919
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
920
+ The number of images to generate per prompt.
921
+ eta (`float`, *optional*, defaults to 0.0):
922
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
923
+ [`schedulers.DDIMScheduler`], will be ignored for others.
924
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
925
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
926
+ to make generation deterministic.
927
+ latents (`torch.Tensor`, *optional*):
928
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
929
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
930
+ tensor will ge generated by sampling using the supplied random `generator`.
931
+ prompt_embeds (`torch.Tensor`, *optional*):
932
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
933
+ provided, text embeddings will be generated from `prompt` input argument.
934
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
935
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
936
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
937
+ argument.
938
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
939
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
940
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
941
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
942
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
943
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
944
+ input argument.
945
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
946
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
947
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
948
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
949
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
950
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
951
+ output_type (`str`, *optional*, defaults to `"pil"`):
952
+ The output format of the generate image. Choose between
953
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
954
+ return_dict (`bool`, *optional*, defaults to `True`):
955
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
956
+ of a plain tuple.
957
+ cross_attention_kwargs (`dict`, *optional*):
958
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
959
+ `self.processor` in
960
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
961
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
962
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
963
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
964
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
965
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
966
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
967
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
968
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
969
+ explained in section 2.2 of
970
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
971
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
972
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
973
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
974
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
975
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
976
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
977
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
978
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
979
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
980
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
981
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
982
+ micro-conditioning as explained in section 2.2 of
983
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
984
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
985
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
986
+ To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
987
+ micro-conditioning as explained in section 2.2 of
988
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
989
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
990
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
991
+ To negatively condition the generation process based on a target image resolution. It should be as same
992
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
993
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
994
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
995
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
996
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
997
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
998
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
999
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
1000
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1001
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1002
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1003
+ `._callback_tensor_inputs` attribute of your pipeline class.
1004
+
1005
+ Examples:
1006
+
1007
+ Returns:
1008
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
1009
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
1010
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
1011
+ """
1012
+
1013
+ callback = kwargs.pop("callback", None)
1014
+ callback_steps = kwargs.pop("callback_steps", None)
1015
+
1016
+ if callback is not None:
1017
+ deprecate(
1018
+ "callback",
1019
+ "1.0.0",
1020
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
1021
+ )
1022
+ if callback_steps is not None:
1023
+ deprecate(
1024
+ "callback_steps",
1025
+ "1.0.0",
1026
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
1027
+ )
1028
+
1029
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
1030
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
1031
+
1032
+ # 0. Default height and width to unet
1033
+ height = height or self.default_sample_size * self.vae_scale_factor
1034
+ width = width or self.default_sample_size * self.vae_scale_factor
1035
+
1036
+ original_size = original_size or (height, width)
1037
+ target_size = target_size or (height, width)
1038
+
1039
+ # 1. Check inputs. Raise error if not correct
1040
+ self.check_inputs(
1041
+ prompt,
1042
+ prompt_2,
1043
+ height,
1044
+ width,
1045
+ callback_steps,
1046
+ gligen_phrases,
1047
+ gligen_boxes,
1048
+ negative_prompt,
1049
+ negative_prompt_2,
1050
+ prompt_embeds,
1051
+ negative_prompt_embeds,
1052
+ pooled_prompt_embeds,
1053
+ negative_pooled_prompt_embeds,
1054
+ ip_adapter_image,
1055
+ ip_adapter_image_embeds,
1056
+ callback_on_step_end_tensor_inputs,
1057
+ )
1058
+
1059
+ self._guidance_scale = guidance_scale
1060
+ self._guidance_rescale = guidance_rescale
1061
+ self._clip_skip = clip_skip
1062
+ self._cross_attention_kwargs = cross_attention_kwargs
1063
+ self._denoising_end = denoising_end
1064
+ self._interrupt = False
1065
+
1066
+ # 2. Define call parameters
1067
+ if prompt is not None and isinstance(prompt, str):
1068
+ batch_size = 1
1069
+ elif prompt is not None and isinstance(prompt, list):
1070
+ batch_size = len(prompt)
1071
+ else:
1072
+ batch_size = prompt_embeds.shape[0]
1073
+
1074
+ device = self._execution_device
1075
+
1076
+ # 3. Encode input prompt
1077
+ lora_scale = (
1078
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
1079
+ )
1080
+
1081
+ (
1082
+ prompt_embeds,
1083
+ negative_prompt_embeds,
1084
+ pooled_prompt_embeds,
1085
+ negative_pooled_prompt_embeds,
1086
+ ) = self.encode_prompt(
1087
+ prompt=prompt,
1088
+ prompt_2=prompt_2,
1089
+ device=device,
1090
+ num_images_per_prompt=num_images_per_prompt,
1091
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1092
+ negative_prompt=negative_prompt,
1093
+ negative_prompt_2=negative_prompt_2,
1094
+ prompt_embeds=prompt_embeds,
1095
+ negative_prompt_embeds=negative_prompt_embeds,
1096
+ pooled_prompt_embeds=pooled_prompt_embeds,
1097
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1098
+ lora_scale=lora_scale,
1099
+ clip_skip=self.clip_skip,
1100
+ )
1101
+
1102
+ # 4. Prepare timesteps
1103
+ timesteps, num_inference_steps = retrieve_timesteps(
1104
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
1105
+ )
1106
+
1107
+ # 5. Prepare latent variables
1108
+ num_channels_latents = self.unet.config.in_channels
1109
+ latents = self.prepare_latents(
1110
+ batch_size * num_images_per_prompt,
1111
+ num_channels_latents,
1112
+ height,
1113
+ width,
1114
+ prompt_embeds.dtype,
1115
+ device,
1116
+ generator,
1117
+ latents,
1118
+ )
1119
+ # 5.1 Prepare GLIGEN variables
1120
+ max_objs = 30
1121
+ if len(gligen_boxes) > max_objs:
1122
+ warnings.warn(
1123
+ f"More that {max_objs} objects found. Only first {max_objs} objects will be processed.",
1124
+ FutureWarning,
1125
+ )
1126
+ gligen_phrases = gligen_phrases[:max_objs]
1127
+ gligen_boxes = gligen_boxes[:max_objs]
1128
+ # prepare batched input to the GLIGENTextBoundingboxProjection (boxes, phrases, mask)
1129
+ # obtain its text features for phrases
1130
+ (
1131
+ gligen_embeds
1132
+ ) = self.encode_prompt_gligen(
1133
+ prompt=gligen_phrases,
1134
+ device=device,
1135
+ num_images_per_prompt=1,
1136
+ # TODO: whether we had to follow prompt encoding configuration on LoRA and CLIP skip
1137
+ lora_scale=lora_scale,
1138
+ clip_skip=self.clip_skip,
1139
+ )
1140
+
1141
+ n_objs = len(gligen_boxes)
1142
+ # For each entity, described in phrases, is denoted with a bounding box,
1143
+ # we represent the location information as (xmin,ymin,xmax,ymax)
1144
+ boxes = torch.zeros(max_objs, 4, device=device, dtype=self.text_encoder.dtype)
1145
+ boxes[:n_objs] = torch.tensor(gligen_boxes)
1146
+ text_embeddings = torch.zeros(
1147
+ max_objs, self.unet.config.cross_attention_dim, device=device, dtype=self.text_encoder.dtype
1148
+ )
1149
+ text_embeddings[:n_objs] = gligen_embeds
1150
+ # Generate a mask for each object that is entity described by phrases
1151
+ masks = torch.zeros(max_objs, device=device, dtype=self.text_encoder.dtype)
1152
+ masks[:n_objs] = 1
1153
+ repeat_batch = batch_size * num_images_per_prompt
1154
+ boxes = boxes.unsqueeze(0).expand(repeat_batch, -1, -1).clone()
1155
+ text_embeddings = text_embeddings.unsqueeze(0).expand(repeat_batch, -1, -1).clone()
1156
+ masks = masks.unsqueeze(0).expand(repeat_batch, -1).clone()
1157
+ if self.do_classifier_free_guidance:
1158
+ repeat_batch = repeat_batch * 2
1159
+ boxes = torch.cat([boxes] * 2)
1160
+ text_embeddings = torch.cat([text_embeddings] * 2)
1161
+ masks = torch.cat([masks] * 2)
1162
+ masks[: repeat_batch // 2] = 0
1163
+ if self.cross_attention_kwargs is None:
1164
+ self._cross_attention_kwargs = {}
1165
+ self.cross_attention_kwargs["gligen"] = {"boxes": boxes, "positive_embeddings": text_embeddings, "masks": masks}
1166
+
1167
+ # Prepare latent variables for GLIGEN inpainting
1168
+ if gligen_inpaint_image is not None:
1169
+ # If the given input image is not of the size expected by the VAE,
1170
+ # center crop and resize it to the expected shape.
1171
+ if gligen_inpaint_image.size != (self.vae.sample_size, self.vae.sample_size):
1172
+ gligen_inpaint_image = self.target_size_center_crop(gligen_inpaint_image, self.vae.sample_size)
1173
+ # Convert a single image into a batch of images with a batch size of 1
1174
+ # The resulting shape becomes (1, C, H, W), where C is the number of channels,
1175
+ # and H and W are the height and width of the image.
1176
+ # It also scales the pixel values to the range [-1, 1].
1177
+ gligen_inpaint_image = self.image_processor.preprocess(gligen_inpaint_image)
1178
+ gligen_inpaint_image = gligen_inpaint_image.to(dtype=self.vae.dtype, device=self.vae.device)
1179
+ # Run AutoEncoder to get corresponding latents
1180
+ gligen_inpaint_latent = self.vae.encode(gligen_inpaint_image).latent_dist.sample()
1181
+ gligen_inpaint_latent = self.vae.config.scaling_factor * gligen_inpaint_latent
1182
+ # Generate an inpainting mask
1183
+ # pixel value = 0, where the object is present (defined by bounding boxes above)
1184
+ # 1, everywhere else
1185
+ gligen_inpaint_mask = self.draw_inpaint_mask_from_boxes(gligen_boxes, gligen_inpaint_latent.shape[2:])
1186
+ gligen_inpaint_mask = gligen_inpaint_mask.to(
1187
+ dtype=gligen_inpaint_latent.dtype, device=gligen_inpaint_latent.device
1188
+ )
1189
+ gligen_inpaint_mask = gligen_inpaint_mask[None, None]
1190
+ gligen_inpaint_mask_addition = torch.cat(
1191
+ (gligen_inpaint_latent * gligen_inpaint_mask, gligen_inpaint_mask), dim=1
1192
+ )
1193
+ # Convert a single mask into a batch of masks with a batch size of 1
1194
+ gligen_inpaint_mask_addition = gligen_inpaint_mask_addition.expand(repeat_batch, -1, -1, -1).clone()
1195
+
1196
+ num_grounding_steps = int(gligen_scheduled_sampling_beta * len(timesteps))
1197
+ self.enable_fuser(True)
1198
+
1199
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1200
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1201
+
1202
+ # 7. Prepare added time ids & embeddings
1203
+ add_text_embeds = pooled_prompt_embeds
1204
+ if self.text_encoder_2 is None:
1205
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
1206
+ else:
1207
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
1208
+
1209
+ add_time_ids = self._get_add_time_ids(
1210
+ original_size,
1211
+ crops_coords_top_left,
1212
+ target_size,
1213
+ dtype=prompt_embeds.dtype,
1214
+ text_encoder_projection_dim=text_encoder_projection_dim,
1215
+ )
1216
+ if negative_original_size is not None and negative_target_size is not None:
1217
+ negative_add_time_ids = self._get_add_time_ids(
1218
+ negative_original_size,
1219
+ negative_crops_coords_top_left,
1220
+ negative_target_size,
1221
+ dtype=prompt_embeds.dtype,
1222
+ text_encoder_projection_dim=text_encoder_projection_dim,
1223
+ )
1224
+ else:
1225
+ negative_add_time_ids = add_time_ids
1226
+
1227
+ if self.do_classifier_free_guidance:
1228
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1229
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
1230
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
1231
+
1232
+ prompt_embeds = prompt_embeds.to(device)
1233
+ add_text_embeds = add_text_embeds.to(device)
1234
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
1235
+
1236
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1237
+ image_embeds = self.prepare_ip_adapter_image_embeds(
1238
+ ip_adapter_image,
1239
+ ip_adapter_image_embeds,
1240
+ device,
1241
+ batch_size * num_images_per_prompt,
1242
+ self.do_classifier_free_guidance,
1243
+ )
1244
+
1245
+ # 8. Denoising loop
1246
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1247
+
1248
+ # 8.1 Apply denoising_end
1249
+ if (
1250
+ self.denoising_end is not None
1251
+ and isinstance(self.denoising_end, float)
1252
+ and self.denoising_end > 0
1253
+ and self.denoising_end < 1
1254
+ ):
1255
+ discrete_timestep_cutoff = int(
1256
+ round(
1257
+ self.scheduler.config.num_train_timesteps
1258
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
1259
+ )
1260
+ )
1261
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
1262
+ timesteps = timesteps[:num_inference_steps]
1263
+
1264
+ # 9. Optionally get Guidance Scale Embedding
1265
+ timestep_cond = None
1266
+ if self.unet.config.time_cond_proj_dim is not None:
1267
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
1268
+ timestep_cond = self.get_guidance_scale_embedding(
1269
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1270
+ ).to(device=device, dtype=latents.dtype)
1271
+
1272
+ self._num_timesteps = len(timesteps)
1273
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1274
+ for i, t in enumerate(timesteps):
1275
+ if self.interrupt:
1276
+ continue
1277
+
1278
+ if i == num_grounding_steps:
1279
+ self.enable_fuser(False)
1280
+
1281
+ if gligen_inpaint_image is not None:
1282
+ gligen_inpaint_latent_with_noise = (
1283
+ self.scheduler.add_noise(
1284
+ gligen_inpaint_latent, torch.randn_like(gligen_inpaint_latent), torch.tensor([t])
1285
+ )
1286
+ .expand(latents.shape[0], -1, -1, -1)
1287
+ .clone()
1288
+ )
1289
+ latents = gligen_inpaint_latent_with_noise * gligen_inpaint_mask + latents * (
1290
+ 1 - gligen_inpaint_mask
1291
+ )
1292
+
1293
+ # expand the latents if we are doing classifier free guidance
1294
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1295
+
1296
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1297
+
1298
+ if gligen_inpaint_image is not None:
1299
+ latent_model_input = torch.cat((latent_model_input, gligen_inpaint_mask_addition), dim=1)
1300
+
1301
+ # predict the noise residual
1302
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1303
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1304
+ added_cond_kwargs["image_embeds"] = image_embeds
1305
+ noise_pred = self.unet(
1306
+ latent_model_input,
1307
+ t,
1308
+ encoder_hidden_states=prompt_embeds,
1309
+ timestep_cond=timestep_cond,
1310
+ cross_attention_kwargs=self.cross_attention_kwargs,
1311
+ added_cond_kwargs=added_cond_kwargs,
1312
+ return_dict=False,
1313
+ )[0]
1314
+
1315
+ # perform guidance
1316
+ if self.do_classifier_free_guidance:
1317
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1318
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1319
+
1320
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1321
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1322
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
1323
+
1324
+ # compute the previous noisy sample x_t -> x_t-1
1325
+ latents_dtype = latents.dtype
1326
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1327
+ if latents.dtype != latents_dtype:
1328
+ if torch.backends.mps.is_available():
1329
+ # some platforms (e.g. Apple MPS) misbehave due to a PyTorch bug: https://github.com/pytorch/pytorch/pull/99272
1330
+ latents = latents.to(latents_dtype)
1331
+
1332
+ if callback_on_step_end is not None:
1333
+ callback_kwargs = {}
1334
+ for k in callback_on_step_end_tensor_inputs:
1335
+ callback_kwargs[k] = locals()[k]
1336
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1337
+
1338
+ latents = callback_outputs.pop("latents", latents)
1339
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1340
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1341
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
1342
+ negative_pooled_prompt_embeds = callback_outputs.pop(
1343
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1344
+ )
1345
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
1346
+ negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
1347
+
1348
+ # call the callback, if provided
1349
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1350
+ progress_bar.update()
1351
+ if callback is not None and i % callback_steps == 0:
1352
+ step_idx = i // getattr(self.scheduler, "order", 1)
1353
+ callback(step_idx, t, latents)
1354
+
1355
+ if XLA_AVAILABLE:
1356
+ xm.mark_step()
1357
+
1358
+ if not output_type == "latent":
1359
+ # make sure the VAE is in float32 mode, as it overflows in float16
1360
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1361
+
1362
+ if needs_upcasting:
1363
+ self.upcast_vae()
1364
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1365
+ elif latents.dtype != self.vae.dtype:
1366
+ if torch.backends.mps.is_available():
1367
+ # some platforms (e.g. Apple MPS) misbehave due to a PyTorch bug: https://github.com/pytorch/pytorch/pull/99272
1368
+ self.vae = self.vae.to(latents.dtype)
1369
+
1370
+ # unscale/denormalize the latents
1371
+ # denormalize with the mean and std if available and not None
1372
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
1373
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
1374
+ if has_latents_mean and has_latents_std:
1375
+ latents_mean = (
1376
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
1377
+ )
1378
+ latents_std = (
1379
+ torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
1380
+ )
1381
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
1382
+ else:
1383
+ latents = latents / self.vae.config.scaling_factor
1384
+
1385
+ image = self.vae.decode(latents, return_dict=False)[0]
1386
+
1387
+ # cast back to fp16 if needed
1388
+ if needs_upcasting:
1389
+ self.vae.to(dtype=torch.float16)
1390
+ else:
1391
+ image = latents
1392
+
1393
+ if not output_type == "latent":
1394
+ # apply watermark if available
1395
+ if self.watermark is not None:
1396
+ image = self.watermark.apply_watermark(image)
1397
+
1398
+ image = self.image_processor.postprocess(image, output_type=output_type)
1399
+
1400
+ # Offload all models
1401
+ self.maybe_free_model_hooks()
1402
+
1403
+ if not return_dict:
1404
+ return (image,)
1405
+
1406
+ return StableDiffusionXLPipelineOutput(images=image)
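
The __call__ body above is the standard SDXL denoising loop with a GLIGEN grounding payload injected through cross_attention_kwargs["gligen"] and an optional inpainting branch. Below is a minimal usage sketch, not part of the upload: it assumes the repo can be loaded as a custom pipeline via trust_remote_code (supported by recent diffusers when model_index.json points at a bundled pipeline module), and the repo/path string, prompt, phrases, and boxes are illustrative placeholders only.

import torch
from diffusers import DiffusionPipeline

# Placeholder: replace with the Hub repo id or local path of this upload.
repo_or_path = "path/to/this/repo"

pipe = DiffusionPipeline.from_pretrained(
    repo_or_path,
    trust_remote_code=True,    # load the bundled StableDiffusionXLGLIGENPipeline
    torch_dtype=torch.float16,
).to("cuda")

image = pipe(
    prompt="a cat and a dog playing in a park",
    gligen_phrases=["a cat", "a dog"],                # one phrase per box
    gligen_boxes=[[0.10, 0.40, 0.45, 0.90],           # (xmin, ymin, xmax, ymax),
                  [0.55, 0.40, 0.90, 0.90]],          # presumably normalized to [0, 1] as in GLIGEN
    gligen_scheduled_sampling_beta=0.3,               # fraction of steps with the fuser enabled
    num_inference_steps=50,
    guidance_scale=7.5,
    height=1024,
    width=1024,
).images[0]
image.save("gligen_sdxl_sample.png")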
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "EulerDiscreteScheduler",
3
+ "_diffusers_version": "0.30.1",
4
+ "beta_end": 0.012,
5
+ "beta_schedule": "scaled_linear",
6
+ "beta_start": 0.00085,
7
+ "clip_sample": false,
8
+ "final_sigmas_type": "zero",
9
+ "interpolation_type": "linear",
10
+ "num_train_timesteps": 1000,
11
+ "prediction_type": "epsilon",
12
+ "rescale_betas_zero_snr": false,
13
+ "sample_max_value": 1.0,
14
+ "set_alpha_to_one": false,
15
+ "sigma_max": null,
16
+ "sigma_min": null,
17
+ "skip_prk_steps": true,
18
+ "steps_offset": 1,
19
+ "timestep_spacing": "leading",
20
+ "timestep_type": "discrete",
21
+ "trained_betas": null,
22
+ "use_karras_sigmas": false
23
+ }
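
The scheduler shipped here is a stock EulerDiscreteScheduler with the usual SDXL settings (scaled_linear betas, leading timestep spacing, epsilon prediction). As a hedged sketch, it can be loaded or swapped independently of the pipeline through the scheduler subfolder; the repo/path string is a placeholder.

from diffusers import EulerDiscreteScheduler

# Placeholder: replace with the Hub repo id or local path of this upload.
scheduler = EulerDiscreteScheduler.from_pretrained("path/to/this/repo", subfolder="scheduler")

# Mirrors what retrieve_timesteps() does inside __call__:
scheduler.set_timesteps(50)
print(scheduler.timesteps[:5])   # descending discrete timesteps ("leading" spacing)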
text_encoder/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/home/user/jiuntian/.cache/huggingface/hub/models--stabilityai--stable-diffusion-xl-base-1.0/snapshots/462165984030d82259a11f4367a4eed129e94a7b/text_encoder",
3
+ "architectures": [
4
+ "CLIPTextModel"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 0,
8
+ "dropout": 0.0,
9
+ "eos_token_id": 2,
10
+ "hidden_act": "quick_gelu",
11
+ "hidden_size": 768,
12
+ "initializer_factor": 1.0,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 3072,
15
+ "layer_norm_eps": 1e-05,
16
+ "max_position_embeddings": 77,
17
+ "model_type": "clip_text_model",
18
+ "num_attention_heads": 12,
19
+ "num_hidden_layers": 12,
20
+ "pad_token_id": 1,
21
+ "projection_dim": 768,
22
+ "torch_dtype": "float16",
23
+ "transformers_version": "4.44.2",
24
+ "vocab_size": 49408
25
+ }
text_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:660c6f5b1abae9dc498ac2d21e1347d2abdb0cf6c0c0c8576cd796491d9a6cdd
3
+ size 246144152
text_encoder_2/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/home/user/jiuntian/.cache/huggingface/hub/models--stabilityai--stable-diffusion-xl-base-1.0/snapshots/462165984030d82259a11f4367a4eed129e94a7b/text_encoder_2",
3
+ "architectures": [
4
+ "CLIPTextModelWithProjection"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 0,
8
+ "dropout": 0.0,
9
+ "eos_token_id": 2,
10
+ "hidden_act": "gelu",
11
+ "hidden_size": 1280,
12
+ "initializer_factor": 1.0,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 5120,
15
+ "layer_norm_eps": 1e-05,
16
+ "max_position_embeddings": 77,
17
+ "model_type": "clip_text_model",
18
+ "num_attention_heads": 20,
19
+ "num_hidden_layers": 32,
20
+ "pad_token_id": 1,
21
+ "projection_dim": 1280,
22
+ "torch_dtype": "float16",
23
+ "transformers_version": "4.44.2",
24
+ "vocab_size": 49408
25
+ }
text_encoder_2/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec310df2af79c318e24d20511b601a591ca8cd4f1fce1d8dff822a356bcdb1f4
3
+ size 1389382176
tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|startoftext|>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "<|endoftext|>",
17
+ "unk_token": {
18
+ "content": "<|endoftext|>",
19
+ "lstrip": false,
20
+ "normalized": true,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "49406": {
5
+ "content": "<|startoftext|>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "49407": {
13
+ "content": "<|endoftext|>",
14
+ "lstrip": false,
15
+ "normalized": true,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ }
20
+ },
21
+ "bos_token": "<|startoftext|>",
22
+ "clean_up_tokenization_spaces": true,
23
+ "do_lower_case": true,
24
+ "eos_token": "<|endoftext|>",
25
+ "errors": "replace",
26
+ "model_max_length": 77,
27
+ "pad_token": "<|endoftext|>",
28
+ "tokenizer_class": "CLIPTokenizer",
29
+ "unk_token": "<|endoftext|>"
30
+ }
tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_2/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_2/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|startoftext|>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "!",
17
+ "unk_token": {
18
+ "content": "<|endoftext|>",
19
+ "lstrip": false,
20
+ "normalized": true,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
tokenizer_2/tokenizer_config.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "!",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "49406": {
13
+ "content": "<|startoftext|>",
14
+ "lstrip": false,
15
+ "normalized": true,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "49407": {
21
+ "content": "<|endoftext|>",
22
+ "lstrip": false,
23
+ "normalized": true,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ }
28
+ },
29
+ "bos_token": "<|startoftext|>",
30
+ "clean_up_tokenization_spaces": true,
31
+ "do_lower_case": true,
32
+ "eos_token": "<|endoftext|>",
33
+ "errors": "replace",
34
+ "model_max_length": 77,
35
+ "pad_token": "!",
36
+ "tokenizer_class": "CLIPTokenizer",
37
+ "unk_token": "<|endoftext|>"
38
+ }
tokenizer_2/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
unet/config.json ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "UNet2DConditionModel",
3
+ "_diffusers_version": "0.30.1",
4
+ "_name_or_path": "logs/gligen_sdxl_bs32/checkpoint-599000",
5
+ "act_fn": "silu",
6
+ "addition_embed_type": "text_time",
7
+ "addition_embed_type_num_heads": 64,
8
+ "addition_time_embed_dim": 256,
9
+ "attention_head_dim": [
10
+ 5,
11
+ 10,
12
+ 20
13
+ ],
14
+ "attention_type": "gated",
15
+ "block_out_channels": [
16
+ 320,
17
+ 640,
18
+ 1280
19
+ ],
20
+ "center_input_sample": false,
21
+ "class_embed_type": null,
22
+ "class_embeddings_concat": false,
23
+ "conv_in_kernel": 3,
24
+ "conv_out_kernel": 3,
25
+ "cross_attention_dim": 2048,
26
+ "cross_attention_norm": null,
27
+ "down_block_types": [
28
+ "DownBlock2D",
29
+ "CrossAttnDownBlock2D",
30
+ "CrossAttnDownBlock2D"
31
+ ],
32
+ "downsample_padding": 1,
33
+ "dropout": 0.0,
34
+ "dual_cross_attention": false,
35
+ "encoder_hid_dim": null,
36
+ "encoder_hid_dim_type": null,
37
+ "flip_sin_to_cos": true,
38
+ "freq_shift": 0,
39
+ "in_channels": 4,
40
+ "layers_per_block": 2,
41
+ "mid_block_only_cross_attention": null,
42
+ "mid_block_scale_factor": 1,
43
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
44
+ "norm_eps": 1e-05,
45
+ "norm_num_groups": 32,
46
+ "num_attention_heads": null,
47
+ "num_class_embeds": null,
48
+ "only_cross_attention": false,
49
+ "out_channels": 4,
50
+ "projection_class_embeddings_input_dim": 2816,
51
+ "resnet_out_scale_factor": 1.0,
52
+ "resnet_skip_time_act": false,
53
+ "resnet_time_scale_shift": "default",
54
+ "reverse_transformer_layers_per_block": null,
55
+ "sample_size": 128,
56
+ "time_cond_proj_dim": null,
57
+ "time_embedding_act_fn": null,
58
+ "time_embedding_dim": null,
59
+ "time_embedding_type": "positional",
60
+ "timestep_post_act": null,
61
+ "transformer_layers_per_block": [
62
+ 1,
63
+ 2,
64
+ 10
65
+ ],
66
+ "up_block_types": [
67
+ "CrossAttnUpBlock2D",
68
+ "CrossAttnUpBlock2D",
69
+ "UpBlock2D"
70
+ ],
71
+ "upcast_attention": null,
72
+ "use_linear_projection": true
73
+ }
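
The notable deviation from a stock SDXL UNet config is "attention_type": "gated", which adds the gated self-attention (GLIGEN) layers that consume the cross_attention_kwargs["gligen"] payload prepared in the pipeline above; cross_attention_dim (2048) likewise matches the width of the text_embeddings buffer built there. A hedged sketch for loading only this UNet (placeholder repo/path):

import torch
from diffusers import UNet2DConditionModel

# Placeholder: replace with the Hub repo id or local path of this upload.
unet = UNet2DConditionModel.from_pretrained(
    "path/to/this/repo",
    subfolder="unet",
    torch_dtype=torch.float16,
)
print(unet.config.attention_type)       # "gated" -> gated self-attention (GLIGEN) blocks
print(unet.config.cross_attention_dim)  # 2048, matches the GLIGEN text_embeddings width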
unet/diffusion_pytorch_model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f62bc47d05b06cafaa022c3622a4239bb9fb92b4eeda4ec8c95fcc14f6510dfe
3
+ size 9995914112
unet/diffusion_pytorch_model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1f200f888eead0ee769381c89d089bf499b2272a852b8bcabdc5321343f7d132
3
+ size 7524087696
unet/diffusion_pytorch_model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
vae/config.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.30.1",
4
+ "_name_or_path": "madebyollin/sdxl-vae-fp16-fix",
5
+ "act_fn": "silu",
6
+ "block_out_channels": [
7
+ 128,
8
+ 256,
9
+ 512,
10
+ 512
11
+ ],
12
+ "down_block_types": [
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D",
16
+ "DownEncoderBlock2D"
17
+ ],
18
+ "force_upcast": false,
19
+ "in_channels": 3,
20
+ "latent_channels": 4,
21
+ "latents_mean": null,
22
+ "latents_std": null,
23
+ "layers_per_block": 2,
24
+ "mid_block_add_attention": true,
25
+ "norm_num_groups": 32,
26
+ "out_channels": 3,
27
+ "sample_size": 512,
28
+ "scaling_factor": 0.13025,
29
+ "shift_factor": null,
30
+ "up_block_types": [
31
+ "UpDecoderBlock2D",
32
+ "UpDecoderBlock2D",
33
+ "UpDecoderBlock2D",
34
+ "UpDecoderBlock2D"
35
+ ],
36
+ "use_post_quant_conv": true,
37
+ "use_quant_conv": true
38
+ }
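
This VAE is the fp16-safe SDXL autoencoder (madebyollin/sdxl-vae-fp16-fix) with force_upcast set to false, so the needs_upcasting branch in the decoding step of __call__ above stays disabled and decoding can remain in float16. A hedged sketch (placeholder repo/path):

import torch
from diffusers import AutoencoderKL

# Placeholder: replace with the Hub repo id or local path of this upload.
vae = AutoencoderKL.from_pretrained(
    "path/to/this/repo",
    subfolder="vae",
    torch_dtype=torch.float16,
)
# Same check as in the pipeline's decode step:
needs_upcasting = vae.dtype == torch.float16 and vae.config.force_upcast
print(needs_upcasting)   # expected: False for this fp16-fix VAE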
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6353737672c94b96174cb590f711eac6edf2fcce5b6e91aa9d73c5adc589ee48
3
+ size 167335342