# cloth-vton/src/pipeline_tryon.py
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import PIL.Image
from PIL import Image
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models.autoencoders import AutoencoderKL
from diffusers.models.transformers import FluxTransformer2DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
USE_PEFT_BACKEND,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from diffusers.pipelines import FluxInpaintPipeline
from diffusers.pipelines.flux.pipeline_flux_inpaint import calculate_shift, retrieve_latents, retrieve_timesteps
class FluxTryonPipeline(FluxInpaintPipeline):
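    """FLUX inpainting pipeline extended for virtual try-on.

    The input image is treated as a horizontal concatenation of a target panel and one or more
    condition panels; the latent image positional ids are built per panel so the transformer can
    tell the panels apart (see `_prepare_latent_image_ids`).
    """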
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
def _prepare_latent_image_ids(batch_size, height, width, device, dtype, target_width=-1, tryon=False):
latent_image_ids = torch.zeros(height, width, 3)
if target_width==-1:
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
else:
latent_image_ids[:, target_width:, 0] = 1
            # row (height) coordinates are assigned the same way as in the standard case
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
if tryon:
latent_image_ids[:, target_width*2:, 0] = 2
                # left (target) panel: x runs from 0 to target_width - 1
latent_image_ids[:, :target_width, 2] = latent_image_ids[:, :target_width, 2] + torch.arange(target_width)[None, :]
                # condition panel(s) to the right: x restarts at 0 and covers the remaining width
latent_image_ids[:, target_width:, 2] = latent_image_ids[:, target_width:, 2] + torch.arange(width-target_width)[None, :]
else:
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids.reshape(
latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
return latent_image_ids.to(device=device, dtype=dtype)
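    # Example of the id layout produced above: each positional id is a (panel_flag, y, x) triple.
    # With height=2, width=6, target_width=2 and tryon=True, the per-column panel flags (channel 0)
    # are [0, 0, 1, 1, 2, 2] and the per-column x coordinates (channel 2) are [0, 1, 0, 1, 2, 3],
    # i.e. the flag distinguishes the target panel from the condition panel(s) while the x range
    # restarts at the first condition panel.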
def prepare_latents(
self,
image,
timestep,
batch_size,
num_channels_latents,
height,
width,
target_width,
tryon,
dtype,
device,
generator,
latents=None,
):
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width)
        # `sp` is the column index where the target panel ends in the packed (2x2-patched) latent grid
        sp = 2 * (int(target_width) // (self.vae_scale_factor * 2)) // 2
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype, sp, tryon)
image = image.to(device=device, dtype=dtype)
        # encode the target (left) and condition (right) panels through the VAE separately, then re-concatenate along width
img_parts = [image[:,:,:,:target_width], image[:,:,:,target_width:]]
image_latents = [self._encode_vae_image(image=img, generator=generator) for img in img_parts]
image_latents = torch.cat(image_latents, dim=-1)
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
# expand init_latents for batch_size
additional_image_per_prompt = batch_size // image_latents.shape[0]
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
)
else:
image_latents = torch.cat([image_latents], dim=0)
if latents is None:
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
else:
noise = latents.to(device)
latents = noise
noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)
image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
return latents, noise, image_latents, latent_image_ids
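    # Shape summary for prepare_latents (derived from the code above): `image` is expected as
    # (batch, 3, height, width) pixels; the left `target_width` columns and the remaining columns
    # are VAE-encoded separately and concatenated back along width, then packed into
    # (batch, (latent_height // 2) * (latent_width // 2), num_channels_latents * 4) token sequences.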
def prepare_mask_latents(
self,
mask,
masked_image,
batch_size,
num_channels_latents,
num_images_per_prompt,
height,
width,
dtype,
device,
generator,
):
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
mask = torch.nn.functional.interpolate(mask, size=(height, width), mode="nearest")
mask = mask.to(device=device, dtype=dtype)
batch_size = batch_size * num_images_per_prompt
masked_image = masked_image.to(device=device, dtype=dtype)
if masked_image.shape[1] == 16:
masked_image_latents = masked_image
else:
masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
if not batch_size % mask.shape[0] == 0:
raise ValueError(
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
" of masks that you pass is divisible by the total requested batch size."
)
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
if masked_image_latents.shape[0] < batch_size:
if not batch_size % masked_image_latents.shape[0] == 0:
raise ValueError(
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
" Make sure the number of images that you pass is divisible by the total requested batch size."
)
masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
        # aligning device to prevent device errors when concatenating it with the latent model input
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
masked_image_latents = self._pack_latents(
masked_image_latents,
batch_size,
num_channels_latents,
height,
width,
)
mask = self._pack_latents(
mask.repeat(1, num_channels_latents, 1, 1),
batch_size,
num_channels_latents,
height,
width,
)
return mask, masked_image_latents
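    # Note on prepare_mask_latents: the single-channel mask is resized to the latent resolution,
    # repeated across all latent channels, and packed with the same 2x2 patching as the image
    # latents, so `mask` and `masked_image_latents` line up token-for-token with `latents` in the
    # denoising loop.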
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
image: PipelineImageInput = None,
mask_image: PipelineImageInput = None,
masked_image_latents: PipelineImageInput = None,
height: Optional[int] = None,
width: Optional[int] = None,
target_width: Optional[int] = None,
tryon: bool = False,
padding_mask_crop: Optional[int] = None,
strength: float = 0.6,
num_inference_steps: int = 28,
timesteps: List[int] = None,
guidance_scale: float = 7.0,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
):
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
prompt_2,
image,
mask_image,
strength,
height,
width,
output_type=output_type,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
padding_mask_crop=padding_mask_crop,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
# 2. Preprocess mask and image
if padding_mask_crop is not None:
crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
resize_mode = "fill"
else:
crops_coords = None
resize_mode = "default"
original_image = image
init_image = self.image_processor.preprocess(
image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
)
init_image = init_image.to(dtype=torch.float32)
# 3. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
(
prompt_embeds,
pooled_prompt_embeds,
text_ids,
) = self.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
# 4.Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas,
mu=mu,
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
if num_inference_steps < 1:
raise ValueError(
f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
num_channels_transformer = self.transformer.config.in_channels
        latents, noise, image_latents, latent_image_ids = self.prepare_latents(
init_image,
latent_timestep,
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
target_width,
tryon,
prompt_embeds.dtype,
device,
generator,
latents,
)
mask_condition = self.mask_processor.preprocess(
mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
)
if masked_image_latents is None:
masked_image = init_image * (mask_condition < 0.5)
else:
masked_image = masked_image_latents
mask, masked_image_latents = self.prepare_mask_latents(
mask_condition,
masked_image,
batch_size,
num_channels_latents,
num_images_per_prompt,
height,
width,
prompt_embeds.dtype,
device,
generator,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# handle guidance
if self.transformer.config.guidance_embeds:
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
else:
guidance = None
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
                # keep the unmasked region fixed to the clean image latents before each denoising step
init_latents_proper = image_latents
init_mask = mask
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
'''
# for 64 channel transformer only.
init_latents_proper = image_latents
init_mask = mask
# NOTE: we just use clean latents
# if i < len(timesteps) - 1:
# noise_timestep = timesteps[i + 1]
# init_latents_proper = self.scheduler.scale_noise(
# init_latents_proper, torch.tensor([noise_timestep]), noise
# )
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
'''
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
# if XLA_AVAILABLE:
# xm.mark_step()
# latents = (1 - mask) * image_latents + mask * latents
if output_type == "latent":
image = latents
else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
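            # keep only the left (target) panel of the stitched latent; the condition panels are discarded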
latents = latents[:,:,:,:target_width//self.vae_scale_factor]
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents.to(device=self.vae.device, dtype=self.vae.dtype), return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return FluxPipelineOutput(images=image)
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
def flux_pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
def flux_unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (vae_scale_factor * 2))
width = 2 * (int(width) // (vae_scale_factor * 2))
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
return latents
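# Shape round-trip for the two helpers above, e.g. for a 16-channel latent of size 64x48
# (i.e. a 512x384 image with vae_scale_factor=8):
#   flux_pack_latents:   (B, 16, 64, 48) -> (B, 32 * 24, 64)
#   flux_unpack_latents: (B, 32 * 24, 64) with height=512, width=384 -> (B, 16, 64, 48)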
# TODO: it would be more reasonable to have the target positional embeddings starting at 0
def prepare_latent_image_ids(height, width_tgt, height_spa, width_spa, height_sub, width_sub, device, dtype):
assert width_spa==0 or width_tgt==width_spa
latent_image_ids = torch.zeros(height, width_tgt, 3, device=device, dtype=dtype)
    latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height, device=device)[:, None]  # y coordinate
    latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width_tgt, device=device)[None, :]  # x coordinate
cond_mark = 0
if width_spa>0:
cond_mark += 1
condspa_image_ids = torch.zeros(height_spa, width_spa, 3, device=device, dtype=dtype)
condspa_image_ids[..., 0] = cond_mark
condspa_image_ids[..., 1] = condspa_image_ids[..., 1] + torch.arange(height_spa, device=device)[:, None]
condspa_image_ids[..., 2] = condspa_image_ids[..., 2] + torch.arange(width_spa, device=device)[None, :]
condspa_image_ids = condspa_image_ids.reshape(-1, condspa_image_ids.shape[-1])
if width_sub>0:
cond_mark += 1
condsub_image_ids = torch.zeros(height_sub, width_sub, 3, device=device, dtype=dtype)
condsub_image_ids[..., 0] = cond_mark
condsub_image_ids[..., 1] = condsub_image_ids[..., 1] + torch.arange(height_sub, device=device)[:, None]
condsub_image_ids[..., 2] = condsub_image_ids[..., 2] + torch.arange(width_sub, device=device)[None, :] + width_tgt
condsub_image_ids = condsub_image_ids.reshape(-1, condsub_image_ids.shape[-1])
latent_image_ids = latent_image_ids.reshape(-1, latent_image_ids.shape[-1])
latent_image_ids = torch.cat([latent_image_ids, condspa_image_ids],dim=-2) if width_spa>0 else latent_image_ids
latent_image_ids = torch.cat([latent_image_ids, condsub_image_ids],dim=-2) if width_sub>0 else latent_image_ids
return latent_image_ids
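# Layout produced by prepare_latent_image_ids: the target grid uses flag 0 and x in [0, width_tgt);
# an optional spatial-condition (`spa`) grid uses flag 1 with x starting again at 0; an optional
# second condition (`sub`) grid uses the next flag value with x offset by width_tgt. The id sets
# are flattened and concatenated along the token dimension.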
def crop_to_multiple_of_16(img):
width, height = img.size
    # Calculate new dimensions that are multiples of 16
new_width = width - (width % 16)
new_height = height - (height % 16)
# Calculate crop box coordinates
left = (width - new_width) // 2
top = (height - new_height) // 2
right = left + new_width
bottom = top + new_height
# Crop the image
cropped_img = img.crop((left, top, right, bottom))
return cropped_img
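# e.g. crop_to_multiple_of_16 center-crops a 1023x767 image to 1008x752.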
def resize_and_pad_to_size(image, target_width, target_height):
# Convert numpy array to PIL Image if needed
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
# Get original dimensions
orig_width, orig_height = image.size
# Calculate aspect ratios
target_ratio = target_width / target_height
orig_ratio = orig_width / orig_height
# Calculate new dimensions while maintaining aspect ratio
if orig_ratio > target_ratio:
# Image is wider than target ratio - scale by width
new_width = target_width
new_height = int(new_width / orig_ratio)
else:
# Image is taller than target ratio - scale by height
new_height = target_height
new_width = int(new_height * orig_ratio)
# Resize image
resized_image = image.resize((new_width, new_height))
# Create white background image of target size
padded_image = Image.new('RGB', (target_width, target_height), 'white')
# Calculate padding to center the image
left_padding = (target_width - new_width) // 2
top_padding = (target_height - new_height) // 2
# Paste resized image onto padded background
padded_image.paste(resized_image, (left_padding, top_padding))
return padded_image, left_padding, top_padding, target_width - new_width - left_padding, target_height - new_height - top_padding
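# e.g. resize_and_pad_to_size(img_1024x1024, 576, 768) scales the image to 576x576, pastes it onto
# a white 576x768 canvas, and returns (padded_image, 0, 96, 0, 96) as (image, left, top, right, bottom).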
def resize_by_height(image, height):
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
# image is a PIL image
image = image.resize((int(image.width * height / image.height), height))
return crop_to_multiple_of_16(image)
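

# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The checkpoint id, the side-by-side
# panel layout, and the mask convention below are assumptions made for this
# example; they are not documented in this file. Adapt them to the actual
# try-on weights and preprocessing used by this project.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Assumed base checkpoint; the try-on behaviour may additionally require
    # project-specific LoRA/transformer weights that are not shown here.
    pipe = FluxTryonPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    ).to("cuda")

    # Hypothetical input files.
    person = Image.open("person.jpg").convert("RGB")
    garment = Image.open("garment.jpg").convert("RGB")

    # Build a side-by-side composite: [person | garment], each panel target_width wide.
    target_width, height = 384, 512
    person, _, _, _, _ = resize_and_pad_to_size(person, target_width, height)
    garment, _, _, _, _ = resize_and_pad_to_size(garment, target_width, height)
    composite = Image.new("RGB", (target_width * 2, height))
    composite.paste(person, (0, 0))
    composite.paste(garment, (target_width, 0))

    # Assumed mask convention: white where the person's clothing should be
    # repainted, black elsewhere (including over the garment panel).
    clothing_mask = Image.open("clothing_mask.png").convert("L").resize((target_width, height))
    full_mask = Image.new("L", (target_width * 2, height), 0)
    full_mask.paste(clothing_mask, (0, 0))

    result = pipe(
        prompt="a person wearing the garment",
        image=composite,
        mask_image=full_mask,
        height=height,
        width=target_width * 2,
        target_width=target_width,
        tryon=True,
        strength=1.0,
        num_inference_steps=28,
    ).images[0]
    result.save("tryon_result.png")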