Update pipeline.py
Browse files- pipeline.py +15 -22
pipeline.py
CHANGED
@@ -33,31 +33,24 @@ class SuperDiffPipeline(DiffusionPipeline, ConfigMixin):
|
|
33 |
|
34 |
"""
|
35 |
super().__init__()
|
36 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
self.vae = vae
|
38 |
-
self.text_encoder = text_encoder
|
39 |
-
self.tokenizer = tokenizer
|
40 |
self.scheduler = scheduler
|
|
|
|
|
|
|
41 |
|
42 |
-
|
43 |
-
|
44 |
-
self.vae.to(device)
|
45 |
-
self.unet.to(device)
|
46 |
-
self.text_encoder.to(device)
|
47 |
-
|
48 |
-
self.register_to_config(
|
49 |
-
vae=vae.__class__.__name__,
|
50 |
-
scheduler=scheduler.__class__.__name__,
|
51 |
-
tokenizer=tokenizer.__class__.__name__,
|
52 |
-
unet=unet.__class__.__name__,
|
53 |
-
text_encoder=text_encoder.__class__.__name__,
|
54 |
-
device=device,
|
55 |
-
batch_size=None,
|
56 |
-
num_inference_steps=None,
|
57 |
-
guidance_scale=None,
|
58 |
-
lift=None,
|
59 |
-
seed=None,
|
60 |
-
)
|
61 |
|
62 |
@torch.no_grad
|
63 |
def get_batch(self, latents: Callable, nrow: int, ncol: int) -> Callable:
|
|
|
33 |
|
34 |
"""
|
35 |
super().__init__()
|
36 |
+
self.register_to_config(
|
37 |
+
batch_size=kwargs.get("batch_size", 1),
|
38 |
+
device=kwargs.get("device", "cuda"),
|
39 |
+
guidance_scale=kwargs.get("guidance_scale", 7.5),
|
40 |
+
lift=kwargs.get("lift", 0.0),
|
41 |
+
num_inference_steps=kwargs.get("num_inference_steps", 50),
|
42 |
+
seed=kwargs.get("seed", 42)
|
43 |
+
)
|
44 |
+
|
45 |
+
# Assign model components
|
46 |
self.vae = vae
|
|
|
|
|
47 |
self.scheduler = scheduler
|
48 |
+
self.tokenizer = tokenizer
|
49 |
+
self.unet = unet
|
50 |
+
self.text_encoder = text_encoder
|
51 |
|
52 |
+
# Move components to device
|
53 |
+
self.to(torch.device(self.config.device))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
@torch.no_grad
|
56 |
def get_batch(self, latents: Callable, nrow: int, ncol: int) -> Callable:
|