Spaces:
Runtime error
Runtime error
import inspect | |
from typing import Any, Callable, Dict, List, Optional, Union | |
import numpy as np | |
import PIL.Image | |
from PIL import Image | |
import torch | |
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast | |
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor | |
from diffusers.loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin | |
from diffusers.models.autoencoders import AutoencoderKL | |
from diffusers.models.transformers import FluxTransformer2DModel | |
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler | |
from diffusers.utils import ( | |
USE_PEFT_BACKEND, | |
is_torch_xla_available, | |
logging, | |
replace_example_docstring, | |
scale_lora_layers, | |
unscale_lora_layers, | |
) | |
from diffusers.utils.torch_utils import randn_tensor | |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline | |
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput | |
from diffusers.pipelines import FluxInpaintPipeline | |
from diffusers.pipelines.flux.pipeline_flux_inpaint import calculate_shift, retrieve_latents, retrieve_timesteps | |
class FluxTryonPipeline(FluxInpaintPipeline):
    """Flux inpainting pipeline for horizontally panelled (try-on style) inputs.

    The input ``image`` is split at ``target_width`` pixels. The left part is
    the target panel: it is the only part that gets decoded and returned. The
    right part holds further panel(s). The two parts are VAE-encoded
    separately and get distinct positional ids. Before every transformer call,
    latents outside ``mask_image`` are reset to the clean image latents.
    """

    @staticmethod
    def _prepare_latent_image_ids(batch_size, height, width, device, dtype, target_width=-1, tryon=False):
        """Build the ``(height * width, 3)`` positional-id table for packed latents.

        Adapted from ``FluxPipeline._prepare_latent_image_ids``. Column 1 of
        every id is the row index.

        With ``target_width == -1`` the ids match upstream: column 0 is 0 and
        column 2 is the column index.

        Otherwise, columns at or right of ``target_width`` get panel index 1
        in column 0. When ``tryon`` is set:
          - columns from ``2 * target_width`` onward get panel index 2;
          - column 2 restarts at 0 at the split and runs continuously to the
            right edge.

        Args:
            batch_size: Unused; kept for signature parity with diffusers.
            height: Packed latent grid height.
            width: Packed latent grid width.
            device: Device of the returned tensor.
            dtype: Dtype of the returned tensor.
            target_width: Packed-latent column where the target panel ends,
                or -1 for no split.
            tryon: Enables the three-panel id layout described above.

        Returns:
            Tensor of shape ``(height * width, 3)``.
        """
        latent_image_ids = torch.zeros(height, width, 3)
        # Column 1 is the row index in every layout.
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
        if target_width == -1:
            latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
        else:
            # Everything right of the target panel is marked as panel 1.
            latent_image_ids[:, target_width:, 0] = 1
            if tryon:
                latent_image_ids[:, target_width * 2 :, 0] = 2
                # Target panel: x ids 0 .. target_width - 1.
                latent_image_ids[:, :target_width, 2] = (
                    latent_image_ids[:, :target_width, 2] + torch.arange(target_width)[None, :]
                )
                # Right of the split: x ids restart at 0.
                latent_image_ids[:, target_width:, 2] = (
                    latent_image_ids[:, target_width:, 2] + torch.arange(width - target_width)[None, :]
                )
            else:
                latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
        latent_image_ids = latent_image_ids.reshape(height * width, 3)
        return latent_image_ids.to(device=device, dtype=dtype)

    def prepare_latents(
        self,
        image,
        timestep,
        batch_size,
        num_channels_latents,
        height,
        width,
        target_width,
        tryon,
        dtype,
        device,
        generator,
        latents=None,
    ):
        """Encode ``image`` and build the initial packed latents and position ids.

        The image is split at ``target_width`` pixels. Each side is VAE-encoded
        separately, and the resulting latents are concatenated back along the
        width axis. ``target_width=None`` means the whole image is the target
        panel.

        Returns:
            ``(latents, noise, image_latents, latent_image_ids)``. The first
            three are packed with ``_pack_latents``.

        Raises:
            ValueError: if a generator list's length differs from
                ``batch_size``, or the image batch cannot be tiled to it.
        """
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )
        if target_width is None:
            target_width = int(width)

        # VAE applies 8x compression on images but we must also account for packing which requires
        # latent height and width to be divisible by 2.
        height = 2 * (int(height) // (self.vae_scale_factor * 2))
        width = 2 * (int(width) // (self.vae_scale_factor * 2))
        shape = (batch_size, num_channels_latents, height, width)

        # Target-panel width measured in packed-latent columns.
        packed_target_width = int(target_width) // (self.vae_scale_factor * 2)
        latent_image_ids = self._prepare_latent_image_ids(
            batch_size, height // 2, width // 2, device, dtype, packed_target_width, tryon
        )

        image = image.to(device=device, dtype=dtype)
        if target_width < image.shape[-1]:
            # Encode both sides of the split independently (presumably so the VAE
            # does not mix content across panels), then stitch along width.
            image_parts = [image[:, :, :, :target_width], image[:, :, :, target_width:]]
            image_latents = torch.cat(
                [self._encode_vae_image(image=part, generator=generator) for part in image_parts], dim=-1
            )
        else:
            # Nothing to split; slicing would yield an empty panel the VAE cannot encode.
            image_latents = self._encode_vae_image(image=image, generator=generator)

        if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
            # expand init_latents for batch_size
            additional_image_per_prompt = batch_size // image_latents.shape[0]
            image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
        elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            image_latents = torch.cat([image_latents], dim=0)

        if latents is None:
            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            latents = self.scheduler.scale_noise(image_latents, timestep, noise)
        else:
            noise = latents.to(device)
            latents = noise

        noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)
        image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width)
        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
        return latents, noise, image_latents, latent_image_ids

    def prepare_mask_latents(
        self,
        mask,
        masked_image,
        batch_size,
        num_channels_latents,
        num_images_per_prompt,
        height,
        width,
        dtype,
        device,
        generator,
    ):
        """Resize, batch-expand and pack the mask and masked-image latents.

        ``masked_image`` is used as-is when it already has 16 channels.
        Otherwise it is VAE-encoded.

        Returns:
            ``(mask, masked_image_latents)``, both packed with ``_pack_latents``.

        Raises:
            ValueError: if the mask or image batch cannot be tiled to
                ``batch_size * num_images_per_prompt``.
        """
        # VAE applies 8x compression on images but we must also account for packing which requires
        # latent height and width to be divisible by 2.
        height = 2 * (int(height) // (self.vae_scale_factor * 2))
        width = 2 * (int(width) // (self.vae_scale_factor * 2))
        # resize the mask to latents shape as we concatenate the mask to the latents
        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
        # and half precision
        mask = torch.nn.functional.interpolate(mask, size=(height, width), mode="nearest")
        mask = mask.to(device=device, dtype=dtype)

        batch_size = batch_size * num_images_per_prompt

        masked_image = masked_image.to(device=device, dtype=dtype)
        if masked_image.shape[1] == 16:
            masked_image_latents = masked_image
        else:
            masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
        masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor

        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
        if mask.shape[0] < batch_size:
            if not batch_size % mask.shape[0] == 0:
                raise ValueError(
                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
                    " of masks that you pass is divisible by the total requested batch size."
                )
            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
        if masked_image_latents.shape[0] < batch_size:
            if not batch_size % masked_image_latents.shape[0] == 0:
                raise ValueError(
                    "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                    f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
                    " Make sure the number of images that you pass is divisible by the total requested batch size."
                )
            masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)

        # aligning device to prevent device errors when concating it with the latent model input
        masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)

        masked_image_latents = self._pack_latents(
            masked_image_latents,
            batch_size,
            num_channels_latents,
            height,
            width,
        )
        mask = self._pack_latents(
            mask.repeat(1, num_channels_latents, 1, 1),
            batch_size,
            num_channels_latents,
            height,
            width,
        )
        return mask, masked_image_latents

    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        image: PipelineImageInput = None,
        mask_image: PipelineImageInput = None,
        masked_image_latents: PipelineImageInput = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        target_width: Optional[int] = None,
        tryon: bool = False,
        padding_mask_crop: Optional[int] = None,
        strength: float = 0.6,
        num_inference_steps: int = 28,
        timesteps: List[int] = None,
        guidance_scale: float = 7.0,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
    ):
        """Generate the target panel of a panelled inpainting input.

        Parameters not listed here behave as in ``FluxInpaintPipeline.__call__``.

        Args:
            target_width: Pixel width of the left (target) panel of ``image``.
                Only this panel is decoded and returned. ``None`` means the
                full ``width``. It should be a multiple of
                ``vae_scale_factor * 2`` so the panel boundary falls on a
                packed-latent column.
            tryon: Forwarded to ``_prepare_latent_image_ids``. It selects the
                three-panel positional-id layout.

        Returns:
            ``FluxPipelineOutput``, or a 1-tuple of images when
            ``return_dict`` is False.

        Raises:
            ValueError: if ``target_width`` is outside ``(0, width]``, or if
                ``strength`` leaves fewer than one denoising step.
        """
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            image,
            mask_image,
            strength,
            height,
            width,
            output_type=output_type,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            padding_mask_crop=padding_mask_crop,
            max_sequence_length=max_sequence_length,
        )
        if target_width is None:
            # Previously crashed in int(None); now means "the whole image is the target".
            target_width = width
        if not 0 < target_width <= width:
            raise ValueError(f"`target_width` must be in the range (0, {width}], but got {target_width}.")

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Preprocess mask and image
        if padding_mask_crop is not None:
            crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
            resize_mode = "fill"
        else:
            crops_coords = None
            resize_mode = "default"

        init_image = self.image_processor.preprocess(
            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
        )
        init_image = init_image.to(dtype=torch.float32)

        # 3. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        # 4. Prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
        image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.base_image_seq_len,
            self.scheduler.config.max_image_seq_len,
            self.scheduler.config.base_shift,
            self.scheduler.config.max_shift,
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            timesteps,
            sigmas,
            mu=mu,
        )
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)

        if num_inference_steps < 1:
            raise ValueError(
                f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
                f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
            )
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        # 5. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels // 4

        latents, _noise, image_latents, latent_image_ids = self.prepare_latents(
            init_image,
            latent_timestep,
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            target_width,
            tryon,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        mask_condition = self.mask_processor.preprocess(
            mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
        )
        if masked_image_latents is None:
            masked_image = init_image * (mask_condition < 0.5)
        else:
            masked_image = masked_image_latents

        # NOTE(review): only `mask` is used below. The masked-image latents this
        # returns (which may cost a VAE encode) are not read again in __call__.
        mask, masked_image_latents = self.prepare_mask_latents(
            mask_condition,
            masked_image,
            batch_size,
            num_channels_latents,
            num_images_per_prompt,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
        )

        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # handle guidance
        if self.transformer.config.guidance_embeds:
            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
            guidance = guidance.expand(latents.shape[0])
        else:
            guidance = None

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # Outside the mask, reset to the clean image latents (no re-noising),
                # so only the masked region is actually generated.
                latents = (1 - mask) * image_latents + mask * latents

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0]).to(latents.dtype)

                noise_pred = self.transformer(
                    hidden_states=latents,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                # update the progress bar
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        if output_type == "latent":
            image = latents
        else:
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            # Decode only the target panel (the left `target_width` pixels).
            latents = latents[:, :, :, : target_width // self.vae_scale_factor]
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents.to(device=self.vae.device, dtype=self.vae.dtype), return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return FluxPipelineOutput(images=image)
# Mirrors diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
def flux_pack_latents(latents, batch_size, num_channels_latents, height, width):
    """Fold a (B, C, H, W) latent into (B, (H/2)*(W/2), C*4) tokens of 2x2 patches.

    Each output token concatenates the 2x2 pixel patch for every channel,
    ordered channel-major. Uses ``view``, so ``latents`` must be contiguous.
    """
    rows, cols = height // 2, width // 2
    patches = latents.view(batch_size, num_channels_latents, rows, 2, cols, 2)
    # Bring the patch-grid axes forward, keep (channel, dy, dx) as the token payload.
    patches = patches.permute(0, 2, 4, 1, 3, 5)
    return patches.reshape(batch_size, rows * cols, num_channels_latents * 4)
# Mirrors diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
def flux_unpack_latents(latents, height, width, vae_scale_factor):
    """Inverse of ``flux_pack_latents``: (B, N, C*4) tokens back to (B, C, h, w).

    ``height``/``width`` are pixel sizes. The latent size is derived from them
    and rounded down to an even number, so that the 2x2 packing divides it.
    """
    batch_size, _, channels = latents.shape
    stride = vae_scale_factor * 2
    latent_h = (int(height) // stride) * 2
    latent_w = (int(width) // stride) * 2
    grid = latents.view(batch_size, latent_h // 2, latent_w // 2, channels // 4, 2, 2)
    # (B, rows, cols, C, dy, dx) -> (B, C, rows, dy, cols, dx)
    grid = grid.permute(0, 3, 1, 4, 2, 5)
    return grid.reshape(batch_size, channels // (2 * 2), latent_h, latent_w)
def _id_grid(rows, cols, mark, x_offset, device, dtype):
    """Return a flattened ``(rows * cols, 3)`` id grid of ``[mark, row, col + x_offset]``."""
    grid = torch.zeros(rows, cols, 3, device=device, dtype=dtype)
    if mark:
        grid[..., 0] = mark
    grid[..., 1] = grid[..., 1] + torch.arange(rows, device=device)[:, None]  # y coordinate
    grid[..., 2] = grid[..., 2] + torch.arange(cols, device=device)[None, :] + x_offset  # x coordinate
    return grid.reshape(-1, grid.shape[-1])


# TODO: it is more reasonable to have target pe starting at 0
def prepare_latent_image_ids(height, width_tgt, height_spa, width_spa, height_sub, width_sub, device, dtype):
    """Build Flux positional ids for a target grid plus optional condition grids.

    Every id is ``[mark, row, col]``.
      - Target grid: ``height`` x ``width_tgt``, mark 0.
      - ``*_spa`` grid: appended when ``width_spa > 0``. It gets the next mark
        and the same column range as the target.
      - ``*_sub`` grid: appended when ``width_sub > 0``. It gets the next mark
        and its columns are shifted right by ``width_tgt``.

    Returns:
        ``(N, 3)`` tensor: flattened grids concatenated in the order
        target, ``spa``, ``sub``.

    Raises:
        ValueError: if ``width_spa`` is non-zero and differs from ``width_tgt``.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if width_spa != 0 and width_tgt != width_spa:
        raise ValueError(
            f"width_spa must be 0 or equal to width_tgt, got width_spa={width_spa}, width_tgt={width_tgt}"
        )
    parts = [_id_grid(height, width_tgt, 0, 0, device, dtype)]
    cond_mark = 0
    if width_spa > 0:
        cond_mark += 1
        parts.append(_id_grid(height_spa, width_spa, cond_mark, 0, device, dtype))
    if width_sub > 0:
        cond_mark += 1
        parts.append(_id_grid(height_sub, width_sub, cond_mark, width_tgt, device, dtype))
    return torch.cat(parts, dim=-2)
def crop_to_multiple_of_16(img):
    """Center-crop a PIL-style image so both sides are multiples of 16 pixels.

    At most 15 pixels are dropped per axis. When the surplus is odd, the extra
    pixel comes off the right/bottom edge.
    """
    w, h = img.size
    surplus_w, surplus_h = w % 16, h % 16
    x0, y0 = surplus_w // 2, surplus_h // 2
    box = (x0, y0, x0 + (w - surplus_w), y0 + (h - surplus_h))
    return img.crop(box)
def resize_and_pad_to_size(image, target_width, target_height):
    """Letterbox ``image`` into a white ``target_width`` x ``target_height`` RGB canvas.

    The image is scaled to fit inside the target with its aspect ratio kept,
    then centered. The remaining area is white.

    Args:
        image: PIL image, or a numpy array (converted via ``Image.fromarray``).
        target_width: Output width in pixels.
        target_height: Output height in pixels.

    Returns:
        ``(padded_image, left, top, right, bottom)``, where the last four
        values are the padding in pixels on each side.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    orig_width, orig_height = image.size
    target_ratio = target_width / target_height
    orig_ratio = orig_width / orig_height

    if orig_ratio > target_ratio:
        # Wider than the target: width is the limiting side.
        new_width = target_width
        new_height = round(new_width / orig_ratio)
    else:
        # Taller than (or equal to) the target: height is the limiting side.
        new_height = target_height
        new_width = round(new_height * orig_ratio)
    # Use round() rather than int() so float error cannot leave a 1px white
    # strip. Clamp so extreme aspect ratios never request a 0px side, which
    # PIL's resize rejects, or an oversized one.
    new_width = min(max(new_width, 1), target_width)
    new_height = min(max(new_height, 1), target_height)

    resized_image = image.resize((new_width, new_height))

    padded_image = Image.new('RGB', (target_width, target_height), 'white')
    left_padding = (target_width - new_width) // 2
    top_padding = (target_height - new_height) // 2
    padded_image.paste(resized_image, (left_padding, top_padding))

    right_padding = target_width - new_width - left_padding
    bottom_padding = target_height - new_height - top_padding
    return padded_image, left_padding, top_padding, right_padding, bottom_padding
def resize_by_height(image, height):
    """Scale ``image`` to ``height`` px tall (aspect ratio kept), then center-crop to multiples of 16.

    Numpy arrays are converted to PIL images first. The new width is
    truncated to an integer.
    """
    pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    scaled_width = int(pil_image.width * height / pil_image.height)
    scaled = pil_image.resize((scaled_width, height))
    return crop_to_multiple_of_16(scaled)