Spaces:
Runtime error
Runtime error
Update pipeline.py
Browse files- pipeline.py +7 -7
pipeline.py
CHANGED
|
@@ -240,7 +240,7 @@ class Flex2Pipeline(FluxControlPipeline):
|
|
| 240 |
if control_image is None:
|
| 241 |
control_latents = torch.zeros(
|
| 242 |
batch_size * num_images_per_prompt,
|
| 243 |
-
|
| 244 |
latent_height,
|
| 245 |
latent_width,
|
| 246 |
device=device,
|
|
@@ -261,12 +261,11 @@ class Flex2Pipeline(FluxControlPipeline):
|
|
| 261 |
|
| 262 |
# apply control strength
|
| 263 |
control_latents = control_latents * control_strength
|
| 264 |
-
print("control_latents", control_latents.shape)
|
| 265 |
|
| 266 |
if inpaint_image is None and inpaint_mask is None:
|
| 267 |
inpaint_latents = torch.zeros(
|
| 268 |
batch_size * num_images_per_prompt,
|
| 269 |
-
|
| 270 |
latent_height,
|
| 271 |
latent_width,
|
| 272 |
device=device,
|
|
@@ -282,7 +281,7 @@ class Flex2Pipeline(FluxControlPipeline):
|
|
| 282 |
)
|
| 283 |
else:
|
| 284 |
print("inpaint_image.shape",inpaint_image.size)
|
| 285 |
-
print("inpaint_mask.shape",inpaint_mask.
|
| 286 |
inpaint_image = self.prepare_image(
|
| 287 |
image=inpaint_image,
|
| 288 |
width=width,
|
|
@@ -294,7 +293,6 @@ class Flex2Pipeline(FluxControlPipeline):
|
|
| 294 |
)
|
| 295 |
inpaint_image = self.vae.encode(inpaint_image).latent_dist.sample(generator=generator)
|
| 296 |
inpaint_latents = (inpaint_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
| 297 |
-
print("inpaint_latents", inpaint_latents.shape)
|
| 298 |
height_inpaint_image, width_inpaint_image = inpaint_image.shape[2:]
|
| 299 |
|
| 300 |
inpaint_mask = self.prepare_image(
|
|
@@ -310,7 +308,7 @@ class Flex2Pipeline(FluxControlPipeline):
|
|
| 310 |
inpaint_mask = inpaint_mask[:, 0:1, :, :] * 0.5 + 0.5
|
| 311 |
# resize to match height_inpaint_image and width_inpaint_image
|
| 312 |
inpaint_latents_mask = F.interpolate(inpaint_mask, size=(height_inpaint_image, width_inpaint_image), mode="bilinear", align_corners=False)
|
| 313 |
-
|
| 314 |
# apply inverted mask to inpaint latents
|
| 315 |
inpaint_latents = inpaint_latents * (1 - inpaint_latents_mask)
|
| 316 |
|
|
@@ -443,4 +441,6 @@ class Flex2Pipeline(FluxControlPipeline):
|
|
| 443 |
if not return_dict:
|
| 444 |
return (image,)
|
| 445 |
|
| 446 |
-
return FluxPipelineOutput(images=image)
|
|
|
|
|
|
|
|
|
| 240 |
if control_image is None:
|
| 241 |
control_latents = torch.zeros(
|
| 242 |
batch_size * num_images_per_prompt,
|
| 243 |
+
16,
|
| 244 |
latent_height,
|
| 245 |
latent_width,
|
| 246 |
device=device,
|
|
|
|
| 261 |
|
| 262 |
# apply control strength
|
| 263 |
control_latents = control_latents * control_strength
|
|
|
|
| 264 |
|
| 265 |
if inpaint_image is None and inpaint_mask is None:
|
| 266 |
inpaint_latents = torch.zeros(
|
| 267 |
batch_size * num_images_per_prompt,
|
| 268 |
+
16,
|
| 269 |
latent_height,
|
| 270 |
latent_width,
|
| 271 |
device=device,
|
|
|
|
| 281 |
)
|
| 282 |
else:
|
| 283 |
print("inpaint_image.shape",inpaint_image.size)
|
| 284 |
+
print("inpaint_mask.shape",inpaint_mask.shape)
|
| 285 |
inpaint_image = self.prepare_image(
|
| 286 |
image=inpaint_image,
|
| 287 |
width=width,
|
|
|
|
| 293 |
)
|
| 294 |
inpaint_image = self.vae.encode(inpaint_image).latent_dist.sample(generator=generator)
|
| 295 |
inpaint_latents = (inpaint_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
|
|
|
| 296 |
height_inpaint_image, width_inpaint_image = inpaint_image.shape[2:]
|
| 297 |
|
| 298 |
inpaint_mask = self.prepare_image(
|
|
|
|
| 308 |
inpaint_mask = inpaint_mask[:, 0:1, :, :] * 0.5 + 0.5
|
| 309 |
# resize to match height_inpaint_image and width_inpaint_image
|
| 310 |
inpaint_latents_mask = F.interpolate(inpaint_mask, size=(height_inpaint_image, width_inpaint_image), mode="bilinear", align_corners=False)
|
| 311 |
+
|
| 312 |
# apply inverted mask to inpaint latents
|
| 313 |
inpaint_latents = inpaint_latents * (1 - inpaint_latents_mask)
|
| 314 |
|
|
|
|
| 441 |
if not return_dict:
|
| 442 |
return (image,)
|
| 443 |
|
| 444 |
+
return FluxPipelineOutput(images=image)
|
| 445 |
+
|
| 446 |
+
|