# Copyright (c) 2024 Jaerin Lee
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from typing import Tuple, List, Literal, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

from einops import rearrange
from tqdm import tqdm
from PIL import Image

from transformers import Blip2Processor, Blip2ForConditionalGeneration
from diffusers import (
    AutoencoderTiny,
    StableDiffusionXLPipeline,
    UNet2DConditionModel,
    EulerDiscreteScheduler,
)
from diffusers.models.attention_processor import (
    AttnProcessor2_0,
    FusedAttnProcessor2_0,
    LoRAAttnProcessor2_0,
    LoRAXFormersAttnProcessor,
    XFormersAttnProcessor,
)
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.loaders import (
    StableDiffusionXLLoraLoaderMixin,
    TextualInversionLoaderMixin,
)
from diffusers.utils import (
    USE_PEFT_BACKEND,
    logging,
    scale_lora_layers,
    unscale_lora_layers,
)
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from util import gaussian_lowpass, blend, get_panorama_views, shift_to_mask_bbox_center

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4.
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # Rescale the results from guidance (fixes overexposure).
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # Mix with the original results from guidance by factor `guidance_rescale` to avoid "plain looking" images.
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg
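
# A minimal sketch (not part of the original pipeline) of where
# `rescale_noise_cfg` sits in a classifier-free guidance step. The tensor
# names below are hypothetical placeholders:
#
#     noise_uncond, noise_text = noise_pred.chunk(2)
#     noise_cfg = noise_uncond + guidance_scale * (noise_text - noise_uncond)
#     # CFG inflates the std of the prediction; rescaling matches it back to
#     # the std of the text-conditioned branch, then blends by the factor:
#     noise_cfg = rescale_noise_cfg(noise_cfg, noise_text, guidance_rescale=0.7)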


class StableMultiDiffusionSDXLPipeline(nn.Module):
    def __init__(
        self,
        device: torch.device,
        dtype: torch.dtype = torch.float16,
        hf_key: Optional[str] = None,
        lora_key: Optional[str] = None,
        load_from_local: bool = False,  # Turn on if you have already downloaded LoRA & Hugging Face hub is down.
        default_mask_std: float = 1.0,  # 8.0
        default_mask_strength: float = 1.0,
        default_prompt_strength: float = 1.0,  # 8.0
        default_bootstrap_steps: int = 1,
        default_boostrap_mix_steps: float = 1.0,
        default_bootstrap_leak_sensitivity: float = 0.2,
        default_preprocess_mask_cover_alpha: float = 0.3,
        t_index_list: List[int] = [0, 4, 12, 25, 37],  # Magic number. Alternatives: [0, 5, 16, 18, 20, 37], [0, 12, 25, 37].
        mask_type: Literal['discrete', 'semi-continuous', 'continuous'] = 'discrete',
        has_i2t: bool = True,
        lora_weight: float = 1.0,
    ) -> None:
        r"""Stabilized MultiDiffusion for fast sampling.

        Accelerated region-based text-to-image synthesis with Latent
        Consistency Model while preserving mask fidelity and quality.

        Args:
            device (torch.device): Specify CUDA device.
            hf_key (Optional[str]): Custom StableDiffusion checkpoint for
                stylized generation.
            lora_key (Optional[str]): Custom Lightning LoRA for acceleration.
            load_from_local (bool): Turn on if you have already downloaded
                LoRA & Hugging Face hub is down.
            default_mask_std (float): Preprocess mask with Gaussian blur with
                specified standard deviation.
            default_mask_strength (float): Preprocess mask by multiplying it
                globally with the specified variable. Caution: extremely
                sensitive. Recommended range: 0.98-1.
            default_prompt_strength (float): Preprocess foreground prompts
                globally by linearly interpolating its embedding with the
                background prompt embedding with specified mix ratio. Useful
                control handle for foreground blending. Recommended range:
                0.5-1.
            default_bootstrap_steps (int): Bootstrapping stage steps to
                encourage region separation. Recommended range: 1-3.
            default_boostrap_mix_steps (float): Bootstrapping background is a
                linear interpolation between background latent and the white
                image latent. This handle controls the mix ratio. Available
                range: 0-(number of bootstrapping inference steps). For
                example, 2.3 means that for the first two steps, white image
                is used as a bootstrapping background and in the third step,
                mixture of white (0.3) and registered background (0.7) is used
                as a bootstrapping background.
            default_bootstrap_leak_sensitivity (float): Postprocessing at each
                inference step by masking away the remaining bootstrap
                backgrounds. Recommended range: 0-1.
            default_preprocess_mask_cover_alpha (float): Optional preprocessing
                where each mask covered by other masks is reduced in its alpha
                value by this specified factor.
            t_index_list (List[int]): The default scheduling for LCM scheduler.
            mask_type (Literal['discrete', 'semi-continuous', 'continuous']):
                Defines the mask quantization modes. Details in the code of
                `self.process_mask`. Basically, this (subtly) controls the
                smoothness of foreground-background blending. More continuous
                means more blending, but a smaller generated patch depending
                on the mask standard deviation.
            has_i2t (bool): Automatic background image to text prompt
                conversion with the BLIP-2 model. May not be necessary for
                non-streaming applications.
            lora_weight (float): Adjusts weight of the LCM/Lightning LoRA.
                Heavily affects the overall quality!
        """
        super().__init__()

        self.device = device
        self.dtype = dtype

        self.default_mask_std = default_mask_std
        self.default_mask_strength = default_mask_strength
        self.default_prompt_strength = default_prompt_strength
        self.default_t_list = t_index_list
        self.default_bootstrap_steps = default_bootstrap_steps
        self.default_boostrap_mix_steps = default_boostrap_mix_steps
        self.default_bootstrap_leak_sensitivity = default_bootstrap_leak_sensitivity
        self.default_preprocess_mask_cover_alpha = default_preprocess_mask_cover_alpha
        self.mask_type = mask_type

        # Create model.
        print('[INFO] Loading Stable Diffusion...')
        variant = None
        model_ckpt = None
        lora_ckpt = None
        lightning_repo = 'ByteDance/SDXL-Lightning'
        if hf_key is not None:
            print(f'[INFO] Using Hugging Face custom model key: {hf_key}')
            model_key = hf_key
            lora_ckpt = 'sdxl_lightning_4step_lora.safetensors'
            self.pipe = StableDiffusionXLPipeline.from_pretrained(model_key, variant=variant, torch_dtype=self.dtype).to(self.device)
            self.pipe.load_lora_weights(hf_hub_download(lightning_repo, lora_ckpt), adapter_name='lightning')
            self.pipe.set_adapters(["lightning"], adapter_weights=[lora_weight])
            self.pipe.fuse_lora()
        else:
            model_key = 'stabilityai/stable-diffusion-xl-base-1.0'
            variant = 'fp16'
            model_ckpt = 'sdxl_lightning_4step_unet.safetensors'  # Use the correct ckpt for your step setting!
            unet = UNet2DConditionModel.from_config(model_key, subfolder='unet').to(self.device, self.dtype)
            unet.load_state_dict(load_file(hf_hub_download(lightning_repo, model_ckpt), device=self.device))
            self.pipe = StableDiffusionXLPipeline.from_pretrained(model_key, unet=unet, torch_dtype=self.dtype, variant=variant).to(self.device)

        # Create an image-to-text model for automatic background prompting.
        if has_i2t:
            self.i2t_processor = Blip2Processor.from_pretrained('Salesforce/blip2-opt-2.7b')
            self.i2t_model = Blip2ForConditionalGeneration.from_pretrained('Salesforce/blip2-opt-2.7b')

        # Use SDXL-Lightning LoRA by default.
        self.pipe.scheduler = EulerDiscreteScheduler.from_config(
            self.pipe.scheduler.config, timestep_spacing="trailing")
        self.scheduler = self.pipe.scheduler
        self.default_num_inference_steps = 4
        self.default_guidance_scale = 0.0

        if t_index_list is None:
            self.prepare_lightning_schedule(
                list(range(self.default_num_inference_steps)),
                self.default_num_inference_steps,
            )
        else:
            self.prepare_lightning_schedule(t_index_list, 50)

        self.vae = self.pipe.vae
        self.tokenizer = self.pipe.tokenizer
        self.tokenizer_2 = self.pipe.tokenizer_2
        self.text_encoder = self.pipe.text_encoder
        self.text_encoder_2 = self.pipe.text_encoder_2
        self.unet = self.pipe.unet
        self.vae_scale_factor = self.pipe.vae_scale_factor

        # Prepare white background for bootstrapping.
        self.get_white_background(1024, 1024)

        print('[INFO] Model is loaded!')
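
    # A minimal construction sketch (assumptions: a CUDA device and the
    # default SDXL-Lightning checkpoint; shown as a comment so it is not
    # executed at import):
    #
    #     >>> smd = StableMultiDiffusionSDXLPipeline(torch.device('cuda:0'))
    #     >>> # Or, a custom checkpoint with the Lightning LoRA fused in:
    #     >>> # smd = StableMultiDiffusionSDXLPipeline(
    #     >>> #     torch.device('cuda:0'), hf_key='your/custom-sdxl-key')
    #
    # 'your/custom-sdxl-key' is a hypothetical placeholder.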

    def prepare_lightning_schedule(
        self,
        t_index_list: Optional[List[int]] = None,
        num_inference_steps: Optional[int] = None,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
    ) -> None:
        r"""Set up a different inference schedule for the diffusion model.

        You do not have to run this explicitly if you want to use the default
        setting, but if you want other time schedules, run this function
        between the module initialization and the main call.

        Note:
            - Recommended `t_index_list` values for LCMs:
                - [0, 12, 25, 37]: Default schedule for 4 steps. Best for
                    panorama. Not recommended if you want to use
                    bootstrapping. Because the bootstrapping stage affects the
                    initial structuring of the generated image and, in this
                    four-step LCM, it is done only at the first step, the
                    structure may be distorted.
                - [0, 4, 12, 25, 37]: Recommended if you would use 1-step
                    bootstrapping. Default initialization in this
                    implementation.
                - [0, 5, 16, 18, 20, 37]: Recommended if you would use 2-step
                    bootstrapping.
            - Due to the characteristic of the SD1.5 LCM LoRA, setting
                `num_inference_steps` larger than 20 may result in overly
                blurry and unrealistic images. Beware!

        Args:
            t_index_list (Optional[List[int]]): The specified scheduling steps
                regarding the maximum timestep as `num_inference_steps`, which
                is, by default, 50. That means that
                `t_index_list=[0, 12, 25, 37]` is a set of relative time
                indices based on the full scale of 50. If None, reinitialize
                the module with the default value.
            num_inference_steps (Optional[int]): The maximum timestep of the
                sampler. Defines the relative scale of the `t_index_list`.
                Rarely used in practice. If None, reinitialize the module with
                the default value.
        """
        if t_index_list is None:
            t_index_list = self.default_t_list
        if num_inference_steps is None:
            num_inference_steps = self.default_num_inference_steps

        self.scheduler.set_timesteps(num_inference_steps)
        self.timesteps = self.scheduler.timesteps[torch.tensor(t_index_list)]

        # EulerDiscreteScheduler
        self.sigmas = self.scheduler.sigmas[torch.tensor(t_index_list)]
        self.sigmas_next = torch.cat([self.sigmas, self.sigmas.new_zeros(1)])[1:]
        sigma_mask = torch.logical_and(s_tmin <= self.sigmas, self.sigmas <= s_tmax)
        # self.gammas = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) * sigma_mask
        self.gammas = min(s_churn / (num_inference_steps - 1), 2**0.5 - 1) * sigma_mask
        self.sigma_hats = self.sigmas * (self.gammas + 1)
        self.dt = self.sigmas_next - self.sigma_hats

        noise_lvs = self.sigmas * (self.sigmas**2 + 1)**(-0.5)
        self.noise_lvs = noise_lvs[None, :, None, None, None]
        self.next_noise_lvs = torch.cat([noise_lvs[1:], noise_lvs.new_zeros(1)])[None, :, None, None, None]
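
    # A worked example of the schedule indexing above (a sketch, not executed
    # at import): with the default `t_index_list=[0, 4, 12, 25, 37]` and
    # `num_inference_steps=50`, the scheduler lays out 50 trailing-spaced
    # timesteps and we keep only the five entries at those indices:
    #
    #     >>> scheduler.set_timesteps(50)
    #     >>> scheduler.timesteps[torch.tensor([0, 4, 12, 25, 37])]
    #     tensor([999., 919., 759., 499., 259.])  # approximate values
    #
    # The corresponding `noise_lvs = sigma / sqrt(sigma**2 + 1)` decrease
    # monotonically and later serve as the quantization bins in
    # `process_mask`.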

    def upcast_vae(self):
        dtype = self.vae.dtype
        self.vae.to(dtype=torch.float32)
        use_torch_2_0_or_xformers = isinstance(
            self.vae.decoder.mid_block.attentions[0].processor,
            (
                AttnProcessor2_0,
                XFormersAttnProcessor,
                LoRAXFormersAttnProcessor,
                LoRAAttnProcessor2_0,
                FusedAttnProcessor2_0,
            ),
        )
        # If xformers or torch_2_0 is used, the attention block does not need
        # to be in float32, which can save lots of memory.
        if use_torch_2_0_or_xformers:
            self.vae.post_quant_conv.to(dtype)
            self.vae.decoder.conv_in.to(dtype)
            self.vae.decoder.mid_block.to(dtype)

    def _get_add_time_ids(
        self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
    ):
        add_time_ids = list(original_size + crops_coords_top_left + target_size)

        passed_add_embed_dim = (
            self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
        )
        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features

        if expected_add_embed_dim != passed_add_embed_dim:
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of"
                f" {passed_add_embed_dim} was created. The model has an incorrect config. Please check"
                " `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
            )

        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        return add_time_ids

    def encode_prompt(
        self,
        prompt: str,
        prompt_2: Optional[str] = None,
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        do_classifier_free_guidance: bool = True,
        negative_prompt: Optional[str] = None,
        negative_prompt_2: Optional[str] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                used in both text-encoders
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            negative_prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                input argument.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
        """
        # This module stores its device directly (the original pipeline uses `self._execution_device`).
        device = device or self.device

        # Set LoRA scale so that the monkey-patched LoRA
        # function of the text encoder can correctly access it.
        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
            self._lora_scale = lora_scale

            # Dynamically adjust the LoRA scale.
            if self.text_encoder is not None:
                if not USE_PEFT_BACKEND:
                    adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
                else:
                    scale_lora_layers(self.text_encoder, lora_scale)

            if self.text_encoder_2 is not None:
                if not USE_PEFT_BACKEND:
                    adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
                else:
                    scale_lora_layers(self.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # Define tokenizers and text encoders.
        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
        text_encoders = (
            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
        )

        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            # textual inversion: process multi-vector tokens if necessary
            prompt_embeds_list = []
            prompts = [prompt, prompt_2]
            for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
                if isinstance(self, TextualInversionLoaderMixin):
                    prompt = self.maybe_convert_prompt(prompt, tokenizer)

                text_inputs = tokenizer(
                    prompt,
                    padding="max_length",
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                )

                text_input_ids = text_inputs.input_ids
                untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                    text_input_ids, untruncated_ids
                ):
                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                    logger.warning(
                        "The following part of your input was truncated because CLIP can only handle sequences up to"
                        f" {tokenizer.model_max_length} tokens: {removed_text}"
                    )

                prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

                # We are ALWAYS interested only in the pooled output of the final text encoder.
                pooled_prompt_embeds = prompt_embeds[0]
                if clip_skip is None:
                    prompt_embeds = prompt_embeds.hidden_states[-2]
                else:
                    # "+2" because SDXL always indexes from the penultimate layer.
                    prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]

                prompt_embeds_list.append(prompt_embeds)

            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

        # Get unconditional embeddings for classifier-free guidance.
        zero_out_negative_prompt = negative_prompt is None and self.pipe.config.force_zeros_for_empty_prompt
        if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
            negative_prompt_embeds = torch.zeros_like(prompt_embeds)
            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
        elif do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt_2 = negative_prompt_2 or negative_prompt

            # normalize str to list
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
            negative_prompt_2 = (
                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
            )

            uncond_tokens: List[str]
            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = [negative_prompt, negative_prompt_2]

            negative_prompt_embeds_list = []
            for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
                if isinstance(self, TextualInversionLoaderMixin):
                    negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)

                max_length = prompt_embeds.shape[1]
                uncond_input = tokenizer(
                    negative_prompt,
                    padding="max_length",
                    max_length=max_length,
                    truncation=True,
                    return_tensors="pt",
                )

                negative_prompt_embeds = text_encoder(
                    uncond_input.input_ids.to(device),
                    output_hidden_states=True,
                )
                # We are ALWAYS interested only in the pooled output of the final text encoder.
                negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

                negative_prompt_embeds_list.append(negative_prompt_embeds)

            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

        if self.text_encoder_2 is not None:
            prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
        else:
            prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # Duplicate text embeddings for each generation per prompt, using an mps-friendly method.
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        if do_classifier_free_guidance:
            # Duplicate unconditional embeddings for each generation per prompt, using an mps-friendly method.
            seq_len = negative_prompt_embeds.shape[1]

            if self.text_encoder_2 is not None:
                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
            else:
                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
            bs_embed * num_images_per_prompt, -1
        )
        if do_classifier_free_guidance:
            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
                bs_embed * num_images_per_prompt, -1
            )

        if self.text_encoder is not None:
            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers.
                unscale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers.
                unscale_lora_layers(self.text_encoder_2, lora_scale)

        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

    def get_text_prompts(self, image: Image.Image) -> str:
        r"""A convenient method to extract a text prompt from an image.

        This is called if the user does not provide a background prompt but
        only the background image. We use BLIP-2 to automatically generate
        prompts.

        Args:
            image (Image.Image): A PIL image.

        Returns:
            A single string of text prompt.
        """
        if hasattr(self, 'i2t_model'):
            question = 'Question: What are in the image? Answer:'
            inputs = self.i2t_processor(image, question, return_tensors='pt')
            out = self.i2t_model.generate(**inputs, max_new_tokens=77)
            prompt = self.i2t_processor.decode(out[0], skip_special_tokens=True).strip()
            return prompt
        else:
            return ''

    def encode_imgs(
        self,
        imgs: torch.Tensor,
        generator: Optional[torch.Generator] = None,
        vae: Optional[nn.Module] = None,
    ) -> torch.Tensor:
        r"""A wrapper function for the VAE encoder of the latent diffusion model.

        Args:
            imgs (torch.Tensor): An image to get StableDiffusion latents.
                Expected shape: (B, 3, H, W). Expected pixel scale: [0, 1].
            generator (Optional[torch.Generator]): Seed for KL-Autoencoder.
            vae (Optional[nn.Module]): Explicitly specify VAE (used for
                the demo application with TinyVAE).

        Returns:
            An image latent embedding at 1/8 of the input size (depending on
            the autoencoder). Shape: (B, 4, H//8, W//8).
        """
        def _retrieve_latents(
            encoder_output: torch.Tensor,
            generator: Optional[torch.Generator] = None,
            sample_mode: str = 'sample',
        ):
            if hasattr(encoder_output, 'latent_dist') and sample_mode == 'sample':
                return encoder_output.latent_dist.sample(generator)
            elif hasattr(encoder_output, 'latent_dist') and sample_mode == 'argmax':
                return encoder_output.latent_dist.mode()
            elif hasattr(encoder_output, 'latents'):
                return encoder_output.latents
            else:
                raise AttributeError('Could not access latents of provided encoder_output')

        vae = self.vae if vae is None else vae
        imgs = 2 * imgs - 1  # Rescale pixel values from [0, 1] to [-1, 1].
        latents = vae.config.scaling_factor * _retrieve_latents(vae.encode(imgs), generator=generator)
        return latents

    def decode_latents(self, latents: torch.Tensor, vae: Optional[nn.Module] = None) -> torch.Tensor:
        r"""A wrapper function for the VAE decoder of the latent diffusion model.

        Args:
            latents (torch.Tensor): An image latent to get associated images.
                Expected shape: (B, 4, H//8, W//8).
            vae (Optional[nn.Module]): Explicitly specify VAE (used for
                the demo application with TinyVAE).

        Returns:
            A decoded image tensor with pixel values in [0, 1].
            Shape: (B, 3, H, W).
        """
        vae = self.vae if vae is None else vae
        latents = 1 / vae.config.scaling_factor * latents
        imgs = vae.decode(latents).sample
        imgs = (imgs / 2 + 0.5).clip_(0, 1)
        return imgs
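
    # A minimal round-trip sketch (assumed usage, not part of the pipeline):
    # encoding then decoding should approximately reproduce the input image.
    #
    #     >>> img = T.ToTensor()(Image.open('photo.png'))[None].to(smd.device, smd.dtype)
    #     >>> lat = smd.encode_imgs(img)     # (1, 4, H//8, W//8)
    #     >>> rec = smd.decode_latents(lat)  # (1, 3, H, W), values in [0, 1]
    #
    # `smd` and 'photo.png' are hypothetical placeholders.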

    def get_white_background(self, height: int, width: int) -> torch.Tensor:
        r"""White background image latent for bootstrapping or in case of
        absent background.

        Additionally stores the maximally-sized white latent for fast
        retrieval in the future. By default, we initially call this with a
        1024x1024 sized white image, so the function is rarely visited twice.

        Args:
            height (int): The height of the white *image*, not its latent.
            width (int): The width of the white *image*, not its latent.

        Returns:
            A white image latent of size (1, 4, height//8, width//8). A
            cropped version of the stored white latent is returned if the
            requested size is smaller than what we already have created.
        """
        if not hasattr(self, 'white') or self.white.shape[-2] < height or self.white.shape[-1] < width:
            white = torch.ones(1, 3, height, width, dtype=self.dtype, device=self.device)
            self.white = self.encode_imgs(white)
            return self.white
        return self.white[..., :(height // self.vae_scale_factor), :(width // self.vae_scale_factor)]

    def process_mask(
        self,
        masks: Union[torch.Tensor, Image.Image, List[Image.Image]],
        strength: Optional[Union[torch.Tensor, float]] = None,
        std: Optional[Union[torch.Tensor, float]] = None,
        height: int = 1024,
        width: int = 1024,
        use_boolean_mask: bool = True,
        timesteps: Optional[torch.Tensor] = None,
        preprocess_mask_cover_alpha: Optional[float] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Fast preprocess of masks for region-based generation with
        fine-grained controls.

        Mask preprocessing is done in four steps:
            1. Resizing: Resize the masks into the specified width and height
                by nearest neighbor interpolation.
            2. (Optional) Ordering: Masks with higher indices are considered
                to cover the masks with smaller indices. Covered masks are
                decayed in their alpha values by the specified factor of
                `preprocess_mask_cover_alpha`.
            3. Blurring: Gaussian blur is applied to the mask with the
                specified standard deviation (isotropic). This results in a
                gradual increase of the masked region as the timesteps evolve,
                naturally blending the foreground and the predesignated
                background. Not strictly required if you want to produce
                images from scratch without a background.
            4. Quantization: Split the real-numbered masks of value between
                [0, 1] into predefined noise levels for each quantized
                scheduling step of the diffusion sampler. For example, if the
                diffusion model sampler has noise levels of [0.9977, 0.9912,
                0.9735, 0.8499, 0.5840], which are the default noise levels of
                this module with schedule [0, 4, 12, 25, 37], the masks are
                split into binary masks whose values are greater than these
                levels. This results in a gradual increase of the mask region
                as the timesteps increase. Details are described in our paper
                at https://arxiv.org/pdf/2403.09055.pdf.

        On the three modes of `mask_type`:
            `self.mask_type` is predefined at the initialization stage of this
            pipeline. Three possible modes are available: 'discrete',
            'semi-continuous', and 'continuous'. These define the mask
            quantization modes we use. Basically, this (subtly) controls the
            smoothness of foreground-background blending. The continuous mode
            produces nonbinary masks to further blend foreground and
            background latents by linearly interpolating between them.
            Semi-continuous masks only apply the continuous mask at the last
            step of the LCM sampler. Due to the large step size of the LCM
            scheduler, we find that our continuous blending helps generate
            seamless inpainting and editing results.

        Args:
            masks (Union[torch.Tensor, Image.Image, List[Image.Image]]): Masks.
            strength (Optional[Union[torch.Tensor, float]]): Mask strength that
                overrides the default value. A globally multiplied factor to
                the mask at the initial stage of processing. Can be applied
                separately for each mask.
            std (Optional[Union[torch.Tensor, float]]): Mask blurring Gaussian
                kernel's standard deviation. Overrides the default value. Can
                be applied separately for each mask.
            height (int): The height of the expected generation. Mask is
                resized to (height//8, width//8) with nearest neighbor
                interpolation.
            width (int): The width of the expected generation. Mask is resized
                to (height//8, width//8) with nearest neighbor interpolation.
            use_boolean_mask (bool): Specify this to treat the mask image as
                a boolean tensor. The region darker than 0.5 of the maximal
                pixel value (that is, 127.5) is considered the designated
                mask.
            timesteps (Optional[torch.Tensor]): Defines the scheduler noise
                levels that act as bins of mask quantization.
            preprocess_mask_cover_alpha (Optional[float]): Optional
                preprocessing where each mask covered by other masks is
                reduced in its alpha value by this specified factor. Overrides
                the default value.

        Returns: A tuple of tensors.
            - masks: Preprocessed (ordered, blurred, and quantized)
                binary/nonbinary masks (see the explanation on `mask_type`
                above) for region-based image synthesis.
            - masks_blurred: Gaussian blurred masks. Used for optionally
                specified foreground-background blending after image
                generation.
            - std: Mask blur standard deviation. Used for optionally specified
                foreground-background blending after image generation.
        """
        if isinstance(masks, Image.Image):
            masks = [masks]
        if isinstance(masks, (tuple, list)):
            # Assumes white background for Image.Image;
            # inverted boolean masks with shape (1, 1, H, W) for torch.Tensor.
            if use_boolean_mask:
                proc = lambda m: T.ToTensor()(m)[None, -1:] < 0.5
            else:
                proc = lambda m: 1.0 - T.ToTensor()(m)[None, -1:]
            masks = torch.cat([proc(mask) for mask in masks], dim=0).float().clip_(0, 1)
        masks = F.interpolate(masks.float(), size=(height, width), mode='bilinear', align_corners=False)
        masks = masks.to(self.device)

        # Background mask alpha is decayed by the specified factor where foreground masks cover it.
        if preprocess_mask_cover_alpha is None:
            preprocess_mask_cover_alpha = self.default_preprocess_mask_cover_alpha
        if preprocess_mask_cover_alpha > 0:
            masks = torch.stack([
                torch.where(
                    masks[i + 1:].sum(dim=0) > 0,
                    mask * preprocess_mask_cover_alpha,
                    mask,
                ) if i < len(masks) - 1 else mask
                for i, mask in enumerate(masks)
            ], dim=0)

        # Scheduler noise levels for mask quantization.
        if timesteps is None:
            noise_lvs = self.noise_lvs
            next_noise_lvs = self.next_noise_lvs
        else:
            noise_lvs_ = self.sigmas * (self.sigmas**2 + 1)**(-0.5)
            # noise_lvs_ = (1 - self.scheduler.alphas_cumprod[timesteps].to(self.device)) ** 0.5
            noise_lvs = noise_lvs_[None, :, None, None, None].to(masks.device)
            next_noise_lvs = torch.cat([noise_lvs_[1:], noise_lvs_.new_zeros(1)])[None, :, None, None, None]

        # Mask preprocessing parameters are fetched from the default settings.
        if std is None:
            std = self.default_mask_std
        if isinstance(std, (int, float)):
            std = [std] * len(masks)
        if isinstance(std, (list, tuple)):
            std = torch.as_tensor(std, dtype=torch.float, device=self.device)

        if strength is None:
            strength = self.default_mask_strength
        if isinstance(strength, (int, float)):
            strength = [strength] * len(masks)
        if isinstance(strength, (list, tuple)):
            strength = torch.as_tensor(strength, dtype=torch.float, device=self.device)

        if (std > 0).any():
            std = torch.where(std > 0, std, 1e-5)
            masks = gaussian_lowpass(masks, std)
        masks_blurred = masks

        # NOTE: This `strength` aligns with `denoising strength`. However, with LCM, using strength < 0.96
        #       gives unpleasant results.
        masks = masks * strength[:, None, None, None]
        masks = masks.unsqueeze(1).repeat(1, noise_lvs.shape[1], 1, 1, 1)

        # Mask is quantized according to the current noise levels specified by the scheduler.
        if self.mask_type == 'discrete':
            # Discrete mode.
            masks = masks > noise_lvs
        elif self.mask_type == 'semi-continuous':
            # Semi-continuous mode (continuous at the last step only).
            masks = torch.cat((
                masks[:, :-1] > noise_lvs[:, :-1],
                (
                    (masks[:, -1:] - next_noise_lvs[:, -1:]) / (noise_lvs[:, -1:] - next_noise_lvs[:, -1:])
                ).clip_(0, 1),
            ), dim=1)
        elif self.mask_type == 'continuous':
            # Continuous mode: Has the exact same `1` coverage as the discrete mode, but the mask gradually
            # decreases after the discrete mode boundary to become `0` at the next lower threshold.
            masks = ((masks - next_noise_lvs) / (noise_lvs - next_noise_lvs)).clip_(0, 1)

        # NOTE: Post-processing mask strength does not align with the conventional `denoising_strength`.
        #       However, fine-grained mask alpha channel tuning is available with this form.
        # masks = masks * strength[None, :, None, None, None]

        h = height // self.vae_scale_factor
        w = width // self.vae_scale_factor
        masks = rearrange(masks.float(), 'p t () h w -> (p t) () h w')
        masks = F.interpolate(masks, size=(h, w), mode='nearest')
        masks = rearrange(masks.to(self.dtype), '(p t) () h w -> p t () h w', p=len(std))
        return masks, masks_blurred, std
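
    # A worked example of the mask quantization above (a sketch, assuming the
    # default schedule [0, 4, 12, 25, 37] with noise levels approximately
    # [0.9977, 0.9912, 0.9735, 0.8499, 0.5840]): a blurred mask pixel with
    # value 0.98 in 'discrete' mode becomes the binary vector
    #
    #     0.98 > [0.9977, 0.9912, 0.9735, 0.8499, 0.5840] -> [0, 0, 1, 1, 1],
    #
    # i.e., that pixel joins the foreground region only from the third
    # sampling step on. Pixels near the blurred mask boundary therefore enter
    # the foreground progressively later, which is what blends the foreground
    # into the background.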

    def scheduler_scale_model_input(
        self,
        latent: torch.FloatTensor,
        idx: int,
    ) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.

        Args:
            latent (`torch.FloatTensor`):
                The input sample.
            idx (`int`):
                The current index into the timesteps tensor of the diffusion chain.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        latent = latent / ((self.sigmas[idx]**2 + 1) ** 0.5)
        return latent

    def scheduler_step(
        self,
        noise_pred: torch.Tensor,
        idx: int,
        latent: torch.Tensor,
    ) -> torch.Tensor:
        r"""Denoise-only step for the reverse diffusion scheduler.

        Designed to match the interface of the original `pipe.scheduler.step`,
        which is a combination of this method and the following
        `scheduler_add_noise`.

        Args:
            noise_pred (torch.Tensor): Noise prediction results from the U-Net.
            idx (int): Instead of timesteps (in [0, 1000]-scale) use indices
                for the timesteps tensor (ranged in [0, len(timesteps)-1]).
            latent (torch.Tensor): Noisy latent.

        Returns:
            A denoised tensor with the same size as latent.
        """
        # Upcast to avoid precision issues when computing prev_sample.
        latent = latent.to(torch.float32)

        # 1. Compute predicted original sample (x_0) from sigma-scaled predicted noise.
        assert self.scheduler.config.prediction_type == 'epsilon', 'Only supports `prediction_type` of `epsilon` for now.'
        # pred_original_sample = latent - self.sigma_hats[idx] * noise_pred
        # prev_sample = pred_original_sample + noise_pred * (self.dt[idx] + self.sigma_hats[idx])
        # return pred_original_sample.to(self.dtype)

        # 2. Convert to an ODE derivative.
        prev_sample = latent + noise_pred * self.dt[idx]
        return prev_sample.to(self.dtype)
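
    # A worked sketch of the Euler step above: with `prediction_type ==
    # 'epsilon'`, the U-Net output approximates the noise, and one first-order
    # step from sigma_hat[i] to sigma[i+1] is
    #
    #     x_{i+1} = x_i + noise_pred * (sigma[i+1] - sigma_hat[i])
    #             = x_i + noise_pred * self.dt[i],
    #
    # which is exactly `prev_sample` here. The equivalent x0-prediction form
    # is kept in the comment block above.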

    def scheduler_add_noise(
        self,
        latent: torch.Tensor,
        noise: Optional[torch.Tensor],
        idx: int,
        s_noise: float = 1.0,
        initial: bool = False,
    ) -> torch.Tensor:
        r"""Separated noise-add step for the reverse diffusion scheduler.

        Designed to match the interface of the original
        `pipe.scheduler.add_noise`.

        Args:
            latent (torch.Tensor): Denoised latent.
            noise (Optional[torch.Tensor]): Added noise. Can be None. If None,
                a random noise is newly sampled for addition.
            idx (int): Instead of timesteps (in [0, 1000]-scale) use indices
                for the timesteps tensor (ranged in [0, len(timesteps)-1]).

        Returns:
            A noisy tensor with the same size as latent.
        """
        if initial:
            if 0 <= idx < len(self.sigmas):
                noise = torch.randn_like(latent) if noise is None else noise
                return latent + self.sigmas[idx] * noise
            return latent

        # 3. Post-add noise (stochastic churn).
        noise_lv = (self.sigma_hats[idx]**2 - self.sigmas[idx]**2) ** 0.5
        if self.gammas[idx] > 0 and noise_lv > 0 and s_noise > 0 and 0 <= idx < len(self.sigmas):
            noise = torch.randn_like(latent) if noise is None else noise
            eps = noise * s_noise * noise_lv
            latent = latent + eps
            # pred_original_sample = pred_original_sample + eps
        return latent
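
    # Note on the churn branch above (an explanatory sketch): with the default
    # `s_churn=0.0` in `prepare_lightning_schedule`, `self.gammas` is all
    # zeros, so `sigma_hat == sigma`, `noise_lv == 0`, and this method is a
    # no-op unless `initial=True`. A nonzero `s_churn` would re-inject
    # `sqrt(sigma_hat**2 - sigma**2)` worth of fresh noise per step, as in the
    # stochastic sampler of Karras et al. (2022).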

    def __call__(
        self,
        prompts: Optional[Union[str, List[str]]] = None,
        negative_prompts: Union[str, List[str]] = '',
        suffix: Optional[str] = None,  # ', background is '
        background: Optional[Union[torch.Tensor, Image.Image]] = None,
        background_prompt: Optional[str] = None,
        background_negative_prompt: str = '',
        height: int = 1024,
        width: int = 1024,
        num_inference_steps: Optional[int] = None,
        guidance_scale: Optional[float] = None,
        prompt_strengths: Optional[Union[torch.Tensor, float, List[float]]] = None,
        masks: Optional[Union[Image.Image, List[Image.Image]]] = None,
        mask_strengths: Optional[Union[torch.Tensor, float, List[float]]] = None,
        mask_stds: Optional[Union[torch.Tensor, float, List[float]]] = None,
        use_boolean_mask: bool = True,
        do_blend: bool = True,
        tile_size: int = 1024,
        bootstrap_steps: Optional[int] = None,
        boostrap_mix_steps: Optional[float] = None,
        bootstrap_leak_sensitivity: Optional[float] = None,
        preprocess_mask_cover_alpha: Optional[float] = None,
    ) -> Image.Image:
        r"""Arbitrary-size image generation from multiple pairs of (regional)
        text prompt-mask pairs.

        This is the main routine for this pipeline.

        Example:
            >>> device = torch.device('cuda:0')
            >>> smd = StableMultiDiffusionSDXLPipeline(device)
            >>> prompts = [...]  # Specify a list of prompts.
            >>> masks = ...  # Specify a stack of mask tensors, one per prompt.
            >>> height, width = masks.shape[-2:]
            >>> image = smd(
            >>>     prompts, masks=masks.float(), height=height, width=width)
            >>> image.save('my_beautiful_creation.png')

        Args:
            prompts (Union[str, List[str]]): A text prompt.
            negative_prompts (Union[str, List[str]]): A negative text prompt.
            suffix (Optional[str]): One option for blending foreground prompts
                with background prompts by simply appending the background
                prompt to the end of each foreground prompt with this `middle
                word` in between. For example, if you set this as
                `, background is`, then the foreground prompt will be changed
                into `(fg), background is (bg)` before conditional generation.
            background (Optional[Union[torch.Tensor, Image.Image]]): A
                background image, if the user wants to draw in front of the
                specified image. A background prompt will be automatically
                generated with a BLIP-2 model.
            background_prompt (Optional[str]): The background prompt is used
                for preprocessing foreground prompt embeddings to blend
                foreground and background.
            background_negative_prompt (Optional[str]): The negative
                background prompt.
            height (int): Height of a generated image. It is tiled if larger
                than `tile_size`.
            width (int): Width of a generated image. It is tiled if larger
                than `tile_size`.
            num_inference_steps (Optional[int]): Number of inference steps.
                Default inference scheduling is used if none is specified.
            guidance_scale (Optional[float]): Classifier guidance scale.
                Default value is used if none is specified.
            prompt_strengths (Optional[Union[torch.Tensor, float,
                List[float]]]): Overrides the default value. Preprocess
                foreground prompts globally by linearly interpolating their
                embeddings with the background prompt embedding with the
                specified mix ratio. Useful control handle for foreground
                blending. Recommended range: 0.5-1.
            masks (Optional[Union[Image.Image, List[Image.Image]]]): A list of
                mask images. Each mask associates with each of the text
                prompts and each of the negative prompts. If specified as an
                image, it regards the image as a boolean mask. Also accepts
                torch.Tensor masks, which can have nonbinary values for
                fine-grained controls in mixing regional generations.
            mask_strengths (Optional[Union[torch.Tensor, float, List[float]]]):
                Overrides the default value. Can be assigned for each mask
                separately. Preprocess mask by multiplying it globally with
                the specified variable. Caution: extremely sensitive.
                Recommended range: 0.98-1.
            mask_stds (Optional[Union[torch.Tensor, float, List[float]]]):
                Overrides the default value. Can be assigned for each mask
                separately. Preprocess mask with Gaussian blur with specified
                standard deviation. Recommended range: 0-64.
            use_boolean_mask (bool): Turn this off if you want to treat the
                mask image as a nonbinary one. The module will use the last
                channel of the given image in `masks` as the mask value.
            do_blend (bool): Blend the generated foreground and the optionally
                predefined background by the smooth boundary obtained from
                Gaussian blurs of the foreground `masks` with the given
                `mask_stds`.
            tile_size (Optional[int]): Tile size of the panorama generation.
                Works best with the default training size of the Stable-
                Diffusion model, i.e., 512 for SD1.5 and 1024 for SDXL.
            bootstrap_steps (int): Overrides the default value. Bootstrapping
                stage steps to encourage region separation. Recommended range:
                1-3.
            boostrap_mix_steps (float): Overrides the default value.
                Bootstrapping background is a linear interpolation between
                background latent and the white image latent. This handle
                controls the mix ratio. Available range: 0-(number of
                bootstrapping inference steps). For example, 2.3 means that
                for the first two steps, white image is used as a
                bootstrapping background and in the third step, mixture of
                white (0.3) and registered background (0.7) is used as a
                bootstrapping background.
            bootstrap_leak_sensitivity (float): Overrides the default value.
                Postprocessing at each inference step by masking away the
                remaining bootstrap backgrounds. Recommended range: 0-1.
            preprocess_mask_cover_alpha (float): Overrides the default value.
                Optional preprocessing where each mask covered by other masks
                is reduced in its alpha value by this specified factor.

        Returns:
            A PIL.Image image of a panorama (large-size) image.
        """
        ### Simplest cases
        # prompts is None: return background.
        # masks is None but prompts is not None: return prompts.
        # masks is not None and prompts is not None: Do StableMultiDiffusion.
        if prompts is None or (isinstance(prompts, (list, tuple, str)) and len(prompts) == 0):
            if background is None and background_prompt is not None:
                # `self.sample` is the plain single-prompt sampling routine of this class.
                return self.sample(background_prompt, background_negative_prompt, height, width, num_inference_steps, guidance_scale)
            return background
        elif masks is None or (isinstance(masks, (list, tuple)) and len(masks) == 0):
            return self.sample(prompts, negative_prompts, height, width, num_inference_steps, guidance_scale)

        ### Prepare generation
        if num_inference_steps is not None:
            # self.prepare_lcm_schedule(list(range(num_inference_steps)), num_inference_steps)
            self.prepare_lightning_schedule(list(range(num_inference_steps)), num_inference_steps)

        if guidance_scale is None:
            guidance_scale = self.default_guidance_scale
        do_classifier_free_guidance = guidance_scale > 1.0

        ### Prompts & Masks
        # Asserts #m > 0 and #p > 0.
        # #m == #p == #n > 0: We happily generate according to the prompts & masks.
        # #m != #p: #p should be 1 and we will broadcast text embeds of p through m masks.
        # #p != #n: #n should be 1 and we will broadcast negative embeds n through p prompts.
        if isinstance(masks, Image.Image):
            masks = [masks]
        if isinstance(prompts, str):
            prompts = [prompts]
        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]
        num_masks = len(masks)
        num_prompts = len(prompts)
        num_nprompts = len(negative_prompts)
        assert num_prompts in (num_masks, 1), \
            f'The number of prompts {num_prompts} should match the number of masks {num_masks}!'
        assert num_nprompts in (num_prompts, 1), \
            f'The number of negative prompts {num_nprompts} should match the number of prompts {num_prompts}!'

        fg_masks, masks_g, std = self.process_mask(
            masks,
            mask_strengths,
            mask_stds,
            height=height,
            width=width,
            use_boolean_mask=use_boolean_mask,
            timesteps=self.timesteps,
            preprocess_mask_cover_alpha=preprocess_mask_cover_alpha,
        )  # (p, t, 1, H, W)
        bg_masks = (1 - fg_masks.sum(dim=0)).clip_(0, 1)  # (T, 1, h, w)
        has_background = bg_masks.sum() > 0

        h = (height + self.vae_scale_factor - 1) // self.vae_scale_factor
        w = (width + self.vae_scale_factor - 1) // self.vae_scale_factor

        ### Background
        # background == None && background_prompt == None: Initialize with white background.
        # background == None && background_prompt != None: Generate background *along with other prompts*.
        # background != None && background_prompt == None: Retrieve text prompt using BLIP.
        # background != None && background_prompt != None: Use the given arguments.
        # not has_background: no effect of prompt_strength (the mix ratio between fg prompt & bg prompt).
        # has_background && prompt_strength != 1: mix only for this case.
        bg_latents = None
        if has_background:
            if background is None and background_prompt is not None:
                fg_masks = torch.cat((bg_masks[None], fg_masks), dim=0)
                if suffix is not None:
                    prompts = [p + suffix + background_prompt for p in prompts]
                prompts = [background_prompt] + prompts
                negative_prompts = [background_negative_prompt] + negative_prompts
                has_background = False  # Regard that background does not exist.
            else:
                if background is None and background_prompt is None:
                    background = torch.ones(1, 3, height, width, dtype=self.dtype, device=self.device)
                    background_prompt = 'simple white background image'
                elif background is not None and background_prompt is None:
                    background_prompt = self.get_text_prompts(background)
                if suffix is not None:
                    prompts = [p + suffix + background_prompt for p in prompts]
                prompts = [background_prompt] + prompts
                negative_prompts = [background_negative_prompt] + negative_prompts
                if isinstance(background, Image.Image):
                    background = T.ToTensor()(background).to(dtype=self.dtype, device=self.device)[None]
                background = F.interpolate(background, size=(height, width), mode='bicubic', align_corners=False)
                bg_latents = self.encode_imgs(background)

        # Bootstrapping stage preparation.
        if bootstrap_steps is None:
            bootstrap_steps = self.default_bootstrap_steps
        if boostrap_mix_steps is None:
            boostrap_mix_steps = self.default_boostrap_mix_steps
        if bootstrap_leak_sensitivity is None:
            bootstrap_leak_sensitivity = self.default_bootstrap_leak_sensitivity
        if bootstrap_steps > 0:
            height_ = min(height, tile_size)
            width_ = min(width, tile_size)
            white = self.get_white_background(height, width)  # (1, 4, h, w)

        ### Prepare text embeddings (optimized for the minimal encoder batch size)
        # SDXL pipeline settings.
        batch_size = 1
        output_type = 'pil'
        guidance_rescale = 0.7
        prompt_2 = None
        device = self.device
        num_images_per_prompt = 1
        negative_prompt_2 = None
        original_size = (height, width)
        target_size = (height, width)
        crops_coords_top_left = (0, 0)
        negative_crops_coords_top_left = (0, 0)
        negative_original_size = None
        negative_target_size = None
        pooled_prompt_embeds = None
        negative_pooled_prompt_embeds = None
        text_encoder_lora_scale = None
        prompt_embeds = None
        negative_prompt_embeds = None

        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompts,
            prompt_2=prompt_2,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompts,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )
        add_text_embeds = pooled_prompt_embeds
        if self.text_encoder_2 is None:
            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
        else:
            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
        add_time_ids = self._get_add_time_ids(
            original_size,
            crops_coords_top_left,
            target_size,
            dtype=prompt_embeds.dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )
        if negative_original_size is not None and negative_target_size is not None:
            negative_add_time_ids = self._get_add_time_ids(
                negative_original_size,
                negative_crops_coords_top_left,
                negative_target_size,
                dtype=prompt_embeds.dtype,
                text_encoder_projection_dim=text_encoder_projection_dim,
            )
        else:
            negative_add_time_ids = add_time_ids

        if has_background:
            # The first channel is the background prompt text embeds. The background prompt itself is
            # not used for generation.
            s = prompt_strengths
            if prompt_strengths is None:
                s = self.default_prompt_strength
            if isinstance(s, (int, float)):
                s = [s] * num_prompts
            if isinstance(s, (list, tuple)):
                assert len(s) == num_prompts, \
                    f'The number of prompt strengths {len(s)} should match the number of prompts {num_prompts}!'
                s = torch.as_tensor(s, dtype=self.dtype, device=self.device)
            s = s[:, None, None]

            be = prompt_embeds[:1]
            fe = prompt_embeds[1:]
            prompt_embeds = torch.lerp(be, fe, s)  # (p, 77, 1024)

            if negative_prompt_embeds is not None:
                bu = negative_prompt_embeds[:1]
                fu = negative_prompt_embeds[1:]
                if num_prompts > num_nprompts:
                    # # negative prompts = 1; # prompts > 1.
                    assert fu.shape[0] == 1 and fe.shape[0] == num_prompts
                    fu = fu.repeat(num_prompts, 1, 1)
                negative_prompt_embeds = torch.lerp(bu, fu, s)  # (n, 77, 1024)
        elif negative_prompt_embeds is not None and num_prompts > num_nprompts:
            # # negative prompts = 1; # prompts > 1.
            assert negative_prompt_embeds.shape[0] == 1 and prompt_embeds.shape[0] == num_prompts
            negative_prompt_embeds = negative_prompt_embeds.repeat(num_prompts, 1, 1)
        # assert negative_prompt_embeds.shape[0] == prompt_embeds.shape[0] == num_prompts

        if num_masks > num_prompts:
            assert len(masks) == num_masks and num_prompts == 1
            prompt_embeds = prompt_embeds.repeat(num_masks, 1, 1)
            if negative_prompt_embeds is not None:
                negative_prompt_embeds = negative_prompt_embeds.repeat(num_masks, 1, 1)

        # SDXL pipeline settings.
        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
        del negative_prompt_embeds, negative_pooled_prompt_embeds, negative_add_time_ids

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

        ### Run
        # Latent initialization.
        if self.timesteps[0] < 999 and has_background:
            latents = self.scheduler_add_noise(bg_latents, None, 0, initial=True)
        else:
            latents = torch.randn((1, self.unet.config.in_channels, h, w), dtype=self.dtype, device=self.device)
            latents = latents * self.scheduler.init_noise_sigma

        # Tiling (if needed).
        if height > tile_size or width > tile_size:
            t = (tile_size + self.vae_scale_factor - 1) // self.vae_scale_factor
            views, tile_masks = get_panorama_views(h, w, t)
            tile_masks = tile_masks.to(self.device)
        else:
            views = [(0, h, 0, w)]
            tile_masks = latents.new_ones((1, 1, h, w))
        value = torch.zeros_like(latents)
        count_all = torch.zeros_like(latents)

        with torch.autocast('cuda'):
            for i, t in enumerate(tqdm(self.timesteps)):
                fg_mask = fg_masks[:, i]
                bg_mask = bg_masks[i:i + 1]

                value.zero_()
                count_all.zero_()
                for j, (h_start, h_end, w_start, w_end) in enumerate(views):
                    fg_mask_ = fg_mask[..., h_start:h_end, w_start:w_end]
                    latents_ = latents[..., h_start:h_end, w_start:w_end].repeat(num_masks, 1, 1, 1)

                    # Additional arguments for the SDXL pipeline.
                    add_time_ids_input = add_time_ids.clone()
                    add_time_ids_input[:, 2] = h_start * self.vae_scale_factor
                    add_time_ids_input[:, 3] = w_start * self.vae_scale_factor
                    add_time_ids_input = add_time_ids_input.repeat_interleave(num_prompts, dim=0)

                    # Bootstrap for tight background.
                    if i < bootstrap_steps:
                        mix_ratio = min(1, max(0, boostrap_mix_steps - i))
                        # Treat the first foreground latent as the background latent if one does not exist.
                        bg_latents_ = bg_latents[..., h_start:h_end, w_start:w_end] if has_background else latents_[:1]
                        white_ = white[..., h_start:h_end, w_start:w_end]
                        white_ = self.scheduler_add_noise(white_, None, i, initial=True)
                        bg_latents_ = mix_ratio * white_ + (1.0 - mix_ratio) * bg_latents_
                        latents_ = (1.0 - fg_mask_) * bg_latents_ + fg_mask_ * latents_

                        # Centering.
                        latents_ = shift_to_mask_bbox_center(latents_, fg_mask_, reverse=True)

                    latent_model_input = torch.cat([latents_] * 2) if do_classifier_free_guidance else latents_
                    latent_model_input = self.scheduler_scale_model_input(latent_model_input, i)

                    # Perform one step of the reverse diffusion.
                    added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids_input}
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=prompt_embeds,
                        timestep_cond=None,
                        cross_attention_kwargs=None,
                        added_cond_kwargs=added_cond_kwargs,
                        return_dict=False,
                    )[0]

                    if do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

                    if do_classifier_free_guidance and guidance_rescale > 0.0:
                        # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                        noise_pred = rescale_noise_cfg(noise_pred, noise_pred_cond, guidance_rescale=guidance_rescale)

                    latents_ = self.scheduler_step(noise_pred, i, latents_)

                    if i < bootstrap_steps:
                        # Uncentering.
                        latents_ = shift_to_mask_bbox_center(latents_, fg_mask_)

                        # Remove leakage (optional).
                        leak = (latents_ - bg_latents_).pow(2).mean(dim=1, keepdim=True)
                        leak_sigmoid = torch.sigmoid(leak / bootstrap_leak_sensitivity) * 2 - 1
                        fg_mask_ = fg_mask_ * leak_sigmoid

                    # Mix the latents.
                    fg_mask_ = fg_mask_ * tile_masks[:, j:j + 1, h_start:h_end, w_start:w_end]
                    value[..., h_start:h_end, w_start:w_end] += (fg_mask_ * latents_).sum(dim=0, keepdim=True)
                    count_all[..., h_start:h_end, w_start:w_end] += fg_mask_.sum(dim=0, keepdim=True)

                latents = torch.where(count_all > 0, value / count_all, value)
                bg_mask = (1 - count_all).clip_(0, 1)  # (T, 1, h, w)
                if has_background:
                    latents = (1 - bg_mask) * latents + bg_mask * bg_latents

                # Noise is added after mixing.
                if i < len(self.timesteps) - 1:
                    latents = self.scheduler_add_noise(latents, None, i + 1)

        if not output_type == "latent":
            # Make sure the VAE is in float32 mode, as it overflows in float16.
            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
            if needs_upcasting:
                self.upcast_vae()
                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

            # Unscale/denormalize the latents: denormalize with the mean and std if available and not None.
            has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
            has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
            if has_latents_mean and has_latents_std:
                latents_mean = (
                    torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
                )
                latents_std = (
                    torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
                )
                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
            else:
                latents = latents / self.vae.config.scaling_factor

            image = self.vae.decode(latents, return_dict=False)[0]

            # Cast back to fp16 if needed.
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)
        else:
            image = latents

        # Return PIL Image.
        image = image[0].clip_(-1, 1) * 0.5 + 0.5
        if has_background and do_blend:
            fg_mask = torch.sum(masks_g, dim=0).clip_(0, 1)
            image = blend(image, background[0], fg_mask)
        else:
            image = T.ToPILImage()(image)
        return image
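

# A minimal end-to-end usage sketch under stated assumptions: a CUDA machine
# and two boolean mask images whose file names below are hypothetical
# placeholders. Guarded so that importing this module stays side-effect free.
if __name__ == '__main__':
    device = torch.device('cuda:0')
    smd = StableMultiDiffusionSDXLPipeline(device)

    # One mask per foreground prompt; white pixels are treated as background
    # (see `process_mask` with `use_boolean_mask=True`).
    masks = [Image.open('mask_sun.png'), Image.open('mask_lake.png')]
    prompts = ['a blazing sun', 'a calm mountain lake']

    image = smd(
        prompts,
        masks=masks,
        background_prompt='a watercolor landscape',
        height=1024,
        width=1024,
    )
    image.save('regional_generation.png')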