mskrt commited on
Commit
4eb194a
verified
1 Parent(s): ffcb9f4

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +15 -22
pipeline.py CHANGED
@@ -33,31 +33,24 @@ class SuperDiffPipeline(DiffusionPipeline, ConfigMixin):
33
 
34
  """
35
  super().__init__()
36
- self.unet = unet
 
 
 
 
 
 
 
 
 
37
  self.vae = vae
38
- self.text_encoder = text_encoder
39
- self.tokenizer = tokenizer
40
  self.scheduler = scheduler
 
 
 
41
 
42
- device = "cuda" if torch.cuda.is_available() else "cpu"
43
-
44
- self.vae.to(device)
45
- self.unet.to(device)
46
- self.text_encoder.to(device)
47
-
48
- self.register_to_config(
49
- vae=vae.__class__.__name__,
50
- scheduler=scheduler.__class__.__name__,
51
- tokenizer=tokenizer.__class__.__name__,
52
- unet=unet.__class__.__name__,
53
- text_encoder=text_encoder.__class__.__name__,
54
- device=device,
55
- batch_size=None,
56
- num_inference_steps=None,
57
- guidance_scale=None,
58
- lift=None,
59
- seed=None,
60
- )
61
 
62
  @torch.no_grad
63
  def get_batch(self, latents: Callable, nrow: int, ncol: int) -> Callable:
 
33
 
34
  """
35
  super().__init__()
36
+ self.register_to_config(
37
+ batch_size=kwargs.get("batch_size", 1),
38
+ device=kwargs.get("device", "cuda"),
39
+ guidance_scale=kwargs.get("guidance_scale", 7.5),
40
+ lift=kwargs.get("lift", 0.0),
41
+ num_inference_steps=kwargs.get("num_inference_steps", 50),
42
+ seed=kwargs.get("seed", 42)
43
+ )
44
+
45
+ # Assign model components
46
  self.vae = vae
 
 
47
  self.scheduler = scheduler
48
+ self.tokenizer = tokenizer
49
+ self.unet = unet
50
+ self.text_encoder = text_encoder
51
 
52
+ # Move components to device
53
+ self.to(torch.device(self.config.device))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  @torch.no_grad
56
  def get_batch(self, latents: Callable, nrow: int, ncol: int) -> Callable: