jefflee committed on
Commit
9f9dc44
·
1 Parent(s): 5604534

tiny fixes

Browse files
Files changed (1) hide show
  1. lib/pipline_ConsistentID.py +14 -8
lib/pipline_ConsistentID.py CHANGED
@@ -362,7 +362,6 @@ class ConsistentIDPipeline(StableDiffusionPipeline):
362
  unet = edict()
363
  # Only keep the config and in_channels attributes that are used in the pipeline.
364
  unet.config = self.unet.config
365
- unet.in_channels = self.unet.in_channels
366
  self.unet = unet
367
 
368
  if "vae" in released_components:
@@ -484,8 +483,6 @@ class ConsistentIDPipeline(StableDiffusionPipeline):
484
  prompt_embeds,
485
  negative_prompt_embeds,
486
  )
487
- if not isinstance(input_subj_image_objs, list):
488
- input_subj_image_objs = [input_subj_image_objs]
489
 
490
  # 2. Define call parameters
491
  if prompt is not None and isinstance(prompt, str):
@@ -499,22 +496,31 @@ class ConsistentIDPipeline(StableDiffusionPipeline):
499
  do_classifier_free_guidance = guidance_scale >= 1.0
500
  assert do_classifier_free_guidance
501
 
502
- # 3. Encode input prompt
503
- coarse_prompt_embeds, fine_prompt_embeds = \
504
- self.extract_double_id_prompt_embeds(prompt, negative_prompt, input_subj_image_objs[0], device)
 
 
 
 
 
 
 
 
 
505
 
506
  # 7. Prepare timesteps
507
  self.scheduler.set_timesteps(num_inference_steps, device=device)
508
  timesteps = self.scheduler.timesteps
509
 
510
  # 8. Prepare latent variables
511
- num_channels_latents = self.unet.in_channels
512
  latents = self.prepare_latents(
513
  batch_size * num_images_per_prompt,
514
  num_channels_latents,
515
  height,
516
  width,
517
- coarse_prompt_embeds.dtype,
518
  device,
519
  generator,
520
  latents,
 
362
  unet = edict()
363
  # Only keep the config and in_channels attributes that are used in the pipeline.
364
  unet.config = self.unet.config
 
365
  self.unet = unet
366
 
367
  if "vae" in released_components:
 
483
  prompt_embeds,
484
  negative_prompt_embeds,
485
  )
 
 
486
 
487
  # 2. Define call parameters
488
  if prompt is not None and isinstance(prompt, str):
 
496
  do_classifier_free_guidance = guidance_scale >= 1.0
497
  assert do_classifier_free_guidance
498
 
499
+ if input_subj_image_objs is not None:
500
+ if not isinstance(input_subj_image_objs, list):
501
+ input_subj_image_objs = [input_subj_image_objs]
502
+
503
+ # 3. Encode input prompt
504
+ coarse_prompt_embeds, fine_prompt_embeds = \
505
+ self.extract_double_id_prompt_embeds(prompt, negative_prompt, input_subj_image_objs[0], device)
506
+ else:
507
+ # Hijack the coarse_prompt_embeds and fine_prompt_embeds to be the input prompt_embeds.
508
+ cfg_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
509
+ coarse_prompt_embeds = cfg_prompt_embeds
510
+ fine_prompt_embeds = cfg_prompt_embeds
511
 
512
  # 7. Prepare timesteps
513
  self.scheduler.set_timesteps(num_inference_steps, device=device)
514
  timesteps = self.scheduler.timesteps
515
 
516
  # 8. Prepare latent variables
517
+ num_channels_latents = self.unet.config.in_channels
518
  latents = self.prepare_latents(
519
  batch_size * num_images_per_prompt,
520
  num_channels_latents,
521
  height,
522
  width,
523
+ self.dtype,
524
  device,
525
  generator,
526
  latents,