tiny fixes
lib/pipline_ConsistentID.py  +14 -8
@@ -362,7 +362,6 @@ class ConsistentIDPipeline(StableDiffusionPipeline):
             unet = edict()
             # Only keep the config and in_channels attributes that are used in the pipeline.
             unet.config = self.unet.config
-            unet.in_channels = self.unet.in_channels
             self.unet = unet

         if "vae" in released_components:
@@ -484,8 +483,6 @@ class ConsistentIDPipeline(StableDiffusionPipeline):
             prompt_embeds,
             negative_prompt_embeds,
         )
-        if not isinstance(input_subj_image_objs, list):
-            input_subj_image_objs = [input_subj_image_objs]

         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -499,22 +496,31 @@ class ConsistentIDPipeline(StableDiffusionPipeline):
         do_classifier_free_guidance = guidance_scale >= 1.0
         assert do_classifier_free_guidance

-
-
-
+        if input_subj_image_objs is not None:
+            if not isinstance(input_subj_image_objs, list):
+                input_subj_image_objs = [input_subj_image_objs]
+
+            # 3. Encode input prompt
+            coarse_prompt_embeds, fine_prompt_embeds = \
+                self.extract_double_id_prompt_embeds(prompt, negative_prompt, input_subj_image_objs[0], device)
+        else:
+            # Hijack the coarse_prompt_embeds and fine_prompt_embeds to be the input prompt_embeds.
+            cfg_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            coarse_prompt_embeds = cfg_prompt_embeds
+            fine_prompt_embeds = cfg_prompt_embeds

         # 7. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps

         # 8. Prepare latent variables
-        num_channels_latents = self.unet.in_channels
+        num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
             height,
             width,
-
+            self.dtype,
             device,
             generator,
             latents,
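
A note on the UNet placeholder change in the first hunk: once every consumer reads the channel count from the config, as the updated num_channels_latents = self.unet.config.in_channels line does, the released UNet can be replaced by a bare namespace that only carries config, and the removed unet.in_channels attribute is no longer needed. A minimal sketch, assuming the easydict package (already used via edict in the pipeline) and a toy config object standing in for the real diffusers UNet config:

    from easydict import EasyDict as edict

    # Toy stand-in for the real UNet and its config (assumption: only
    # config.in_channels is read downstream, as in the updated pipeline code).
    real_unet = edict(config=edict(in_channels=4, sample_size=64), in_channels=4)

    # Release the heavy model but keep a lightweight placeholder exposing .config.
    unet = edict()
    unet.config = real_unet.config
    placeholder = unet

    # Downstream code reads the channel count from the config, so dropping the
    # separate in_channels attribute on the placeholder is safe.
    num_channels_latents = placeholder.config.in_channels
    assert num_channels_latents == 4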
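
For the new else branch in the last hunk, the intent is that when no subject image is supplied, the coarse and fine ID prompt embeddings fall back to the caller-provided text embeddings, stacked negative-first then positive, the usual classifier-free-guidance layout (guidance_scale >= 1.0 is asserted, so CFG is always on). A rough, self-contained sketch; the tensor shapes (batch 1, 77 tokens, 768-dim embeddings) are illustrative assumptions, not values taken from the pipeline:

    import torch

    # Hypothetical caller-provided embeddings, shaped [batch, seq_len, dim].
    prompt_embeds = torch.randn(1, 77, 768)
    negative_prompt_embeds = torch.randn(1, 77, 768)

    # Negative embeddings first, positive second, matching the concatenation
    # order used in the added code.
    cfg_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

    # With no subject image there is nothing to extract ID features from, so
    # both the coarse and the fine branches reuse the same text embeddings.
    coarse_prompt_embeds = cfg_prompt_embeds
    fine_prompt_embeds = cfg_prompt_embeds
    print(cfg_prompt_embeds.shape)  # torch.Size([2, 77, 768])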