|  | import inspect | 
					
						
						|  | from typing import Callable, List, Optional, Tuple, Union | 
					
						
						|  |  | 
					
						
						|  | import numpy as np | 
					
						
						|  | import PIL.Image | 
					
						
						|  | import torch | 
					
						
						|  | from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer | 
					
						
						|  |  | 
					
						
						|  | from diffusers import DiffusionPipeline | 
					
						
						|  | from diffusers.configuration_utils import FrozenDict | 
					
						
						|  | from diffusers.models import AutoencoderKL, UNet2DConditionModel | 
					
						
						|  | from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput | 
					
						
						|  | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker | 
					
						
						|  | from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler | 
					
						
						|  | from diffusers.utils import deprecate, logging | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
# Module-level logger, created through the `logging` helper imported from `diffusers.utils`.
logger = logging.get_logger(__name__)
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def prepare_mask_and_masked_image(image, mask): | 
					
						
						|  | image = np.array(image.convert("RGB")) | 
					
						
						|  | image = image[None].transpose(0, 3, 1, 2) | 
					
						
						|  | image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 | 
					
						
						|  |  | 
					
						
						|  | mask = np.array(mask.convert("L")) | 
					
						
						|  | mask = mask.astype(np.float32) / 255.0 | 
					
						
						|  | mask = mask[None, None] | 
					
						
						|  | mask[mask < 0.5] = 0 | 
					
						
						|  | mask[mask >= 0.5] = 1 | 
					
						
						|  | mask = torch.from_numpy(mask) | 
					
						
						|  |  | 
					
						
						|  | masked_image = image * (mask < 0.5) | 
					
						
						|  |  | 
					
						
						|  | return mask, masked_image | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def check_size(image, height, width): | 
					
						
						|  | if isinstance(image, PIL.Image.Image): | 
					
						
						|  | w, h = image.size | 
					
						
						|  | elif isinstance(image, torch.Tensor): | 
					
						
						|  | *_, h, w = image.shape | 
					
						
						|  |  | 
					
						
						|  | if h != height or w != width: | 
					
						
						|  | raise ValueError(f"Image size should be {height}x{width}, but got {h}x{w}") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def overlay_inner_image(image, inner_image, paste_offset: Tuple[int] = (0, 0)): | 
					
						
						|  | inner_image = inner_image.convert("RGBA") | 
					
						
						|  | image = image.convert("RGB") | 
					
						
						|  |  | 
					
						
						|  | image.paste(inner_image, paste_offset, inner_image) | 
					
						
						|  | image = image.convert("RGB") | 
					
						
						|  |  | 
					
						
						|  | return image | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class ImageToImageInpaintingPipeline(DiffusionPipeline): | 
					
						
						|  | r""" | 
					
						
						|  | Pipeline for text-guided image-to-image inpainting using Stable Diffusion. *This is an experimental feature*. | 
					
						
						|  |  | 
					
						
						|  | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the | 
					
						
						|  | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | vae ([`AutoencoderKL`]): | 
					
						
						|  | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. | 
					
						
						|  | text_encoder ([`CLIPTextModel`]): | 
					
						
						|  | Frozen text-encoder. Stable Diffusion uses the text portion of | 
					
						
						|  | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically | 
					
						
						|  | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. | 
					
						
						|  | tokenizer (`CLIPTokenizer`): | 
					
						
						|  | Tokenizer of class | 
					
						
						|  | [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). | 
					
						
						|  | unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. | 
					
						
						|  | scheduler ([`SchedulerMixin`]): | 
					
						
						|  | A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of | 
					
						
						|  | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. | 
					
						
						|  | safety_checker ([`StableDiffusionSafetyChecker`]): | 
					
						
						|  | Classification module that estimates whether generated images could be considered offensive or harmful. | 
					
						
						|  | Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. | 
					
						
						|  | feature_extractor ([`CLIPImageProcessor`]): | 
					
						
						|  | Model that extracts features from generated images to be used as inputs for the `safety_checker`. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | def __init__( | 
					
						
						|  | self, | 
					
						
						|  | vae: AutoencoderKL, | 
					
						
						|  | text_encoder: CLIPTextModel, | 
					
						
						|  | tokenizer: CLIPTokenizer, | 
					
						
						|  | unet: UNet2DConditionModel, | 
					
						
						|  | scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], | 
					
						
						|  | safety_checker: StableDiffusionSafetyChecker, | 
					
						
						|  | feature_extractor: CLIPImageProcessor, | 
					
						
						|  | ): | 
					
						
						|  | super().__init__() | 
					
						
						|  |  | 
					
						
						|  | if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: | 
					
						
						|  | deprecation_message = ( | 
					
						
						|  | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" | 
					
						
						|  | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " | 
					
						
						|  | "to update the config accordingly as leaving `steps_offset` might led to incorrect results" | 
					
						
						|  | " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," | 
					
						
						|  | " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" | 
					
						
						|  | " file" | 
					
						
						|  | ) | 
					
						
						|  | deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) | 
					
						
						|  | new_config = dict(scheduler.config) | 
					
						
						|  | new_config["steps_offset"] = 1 | 
					
						
						|  | scheduler._internal_dict = FrozenDict(new_config) | 
					
						
						|  |  | 
					
						
						|  | if safety_checker is None: | 
					
						
						|  | logger.warning( | 
					
						
						|  | f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" | 
					
						
						|  | " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" | 
					
						
						|  | " results in services or applications open to the public. Both the diffusers team and Hugging Face" | 
					
						
						|  | " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" | 
					
						
						|  | " it only for use-cases that involve analyzing network behavior or auditing its results. For more" | 
					
						
						|  | " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | self.register_modules( | 
					
						
						|  | vae=vae, | 
					
						
						|  | text_encoder=text_encoder, | 
					
						
						|  | tokenizer=tokenizer, | 
					
						
						|  | unet=unet, | 
					
						
						|  | scheduler=scheduler, | 
					
						
						|  | safety_checker=safety_checker, | 
					
						
						|  | feature_extractor=feature_extractor, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): | 
					
						
						|  | r""" | 
					
						
						|  | Enable sliced attention computation. | 
					
						
						|  |  | 
					
						
						|  | When this option is enabled, the attention module will split the input tensor in slices, to compute attention | 
					
						
						|  | in several steps. This is useful to save some memory in exchange for a small speed decrease. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | slice_size (`str` or `int`, *optional*, defaults to `"auto"`): | 
					
						
						|  | When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If | 
					
						
						|  | a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, | 
					
						
						|  | `attention_head_dim` must be a multiple of `slice_size`. | 
					
						
						|  | """ | 
					
						
						|  | if slice_size == "auto": | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | slice_size = self.unet.config.attention_head_dim // 2 | 
					
						
						|  | self.unet.set_attention_slice(slice_size) | 
					
						
						|  |  | 
					
						
						|  | def disable_attention_slicing(self): | 
					
						
						|  | r""" | 
					
						
						|  | Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go | 
					
						
						|  | back to computing attention in one step. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | self.enable_attention_slicing(None) | 
					
						
						|  |  | 
					
						
						|  | @torch.no_grad() | 
					
						
						|  | def __call__( | 
					
						
						|  | self, | 
					
						
						|  | prompt: Union[str, List[str]], | 
					
						
						|  | image: Union[torch.FloatTensor, PIL.Image.Image], | 
					
						
						|  | inner_image: Union[torch.FloatTensor, PIL.Image.Image], | 
					
						
						|  | mask_image: Union[torch.FloatTensor, PIL.Image.Image], | 
					
						
						|  | height: int = 512, | 
					
						
						|  | width: int = 512, | 
					
						
						|  | num_inference_steps: int = 50, | 
					
						
						|  | guidance_scale: float = 7.5, | 
					
						
						|  | negative_prompt: Optional[Union[str, List[str]]] = None, | 
					
						
						|  | num_images_per_prompt: Optional[int] = 1, | 
					
						
						|  | eta: float = 0.0, | 
					
						
						|  | generator: Optional[torch.Generator] = None, | 
					
						
						|  | latents: Optional[torch.FloatTensor] = None, | 
					
						
						|  | output_type: Optional[str] = "pil", | 
					
						
						|  | return_dict: bool = True, | 
					
						
						|  | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, | 
					
						
						|  | callback_steps: int = 1, | 
					
						
						|  | **kwargs, | 
					
						
						|  | ): | 
					
						
						|  | r""" | 
					
						
						|  | Function invoked when calling the pipeline for generation. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | prompt (`str` or `List[str]`): | 
					
						
						|  | The prompt or prompts to guide the image generation. | 
					
						
						|  | image (`torch.Tensor` or `PIL.Image.Image`): | 
					
						
						|  | `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will | 
					
						
						|  | be masked out with `mask_image` and repainted according to `prompt`. | 
					
						
						|  | inner_image (`torch.Tensor` or `PIL.Image.Image`): | 
					
						
						|  | `Image`, or tensor representing an image batch which will be overlayed onto `image`. Non-transparent | 
					
						
						|  | regions of `inner_image` must fit inside white pixels in `mask_image`. Expects four channels, with | 
					
						
						|  | the last channel representing the alpha channel, which will be used to blend `inner_image` with | 
					
						
						|  | `image`. If not provided, it will be forcibly cast to RGBA. | 
					
						
						|  | mask_image (`PIL.Image.Image`): | 
					
						
						|  | `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be | 
					
						
						|  | repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted | 
					
						
						|  | to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) | 
					
						
						|  | instead of 3, so the expected shape would be `(B, H, W, 1)`. | 
					
						
						|  | height (`int`, *optional*, defaults to 512): | 
					
						
						|  | The height in pixels of the generated image. | 
					
						
						|  | width (`int`, *optional*, defaults to 512): | 
					
						
						|  | The width in pixels of the generated image. | 
					
						
						|  | num_inference_steps (`int`, *optional*, defaults to 50): | 
					
						
						|  | The number of denoising steps. More denoising steps usually lead to a higher quality image at the | 
					
						
						|  | expense of slower inference. | 
					
						
						|  | guidance_scale (`float`, *optional*, defaults to 7.5): | 
					
						
						|  | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). | 
					
						
						|  | `guidance_scale` is defined as `w` of equation 2. of [Imagen | 
					
						
						|  | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > | 
					
						
						|  | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, | 
					
						
						|  | usually at the expense of lower image quality. | 
					
						
						|  | negative_prompt (`str` or `List[str]`, *optional*): | 
					
						
						|  | The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored | 
					
						
						|  | if `guidance_scale` is less than `1`). | 
					
						
						|  | num_images_per_prompt (`int`, *optional*, defaults to 1): | 
					
						
						|  | The number of images to generate per prompt. | 
					
						
						|  | eta (`float`, *optional*, defaults to 0.0): | 
					
						
						|  | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to | 
					
						
						|  | [`schedulers.DDIMScheduler`], will be ignored for others. | 
					
						
						|  | generator (`torch.Generator`, *optional*): | 
					
						
						|  | A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation | 
					
						
						|  | deterministic. | 
					
						
						|  | latents (`torch.FloatTensor`, *optional*): | 
					
						
						|  | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image | 
					
						
						|  | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents | 
					
						
						|  | tensor will ge generated by sampling using the supplied random `generator`. | 
					
						
						|  | output_type (`str`, *optional*, defaults to `"pil"`): | 
					
						
						|  | The output format of the generate image. Choose between | 
					
						
						|  | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. | 
					
						
						|  | return_dict (`bool`, *optional*, defaults to `True`): | 
					
						
						|  | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a | 
					
						
						|  | plain tuple. | 
					
						
						|  | callback (`Callable`, *optional*): | 
					
						
						|  | A function that will be called every `callback_steps` steps during inference. The function will be | 
					
						
						|  | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. | 
					
						
						|  | callback_steps (`int`, *optional*, defaults to 1): | 
					
						
						|  | The frequency at which the `callback` function will be called. If not specified, the callback will be | 
					
						
						|  | called at every step. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: | 
					
						
						|  | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. | 
					
						
						|  | When returning a tuple, the first element is a list with the generated images, and the second element is a | 
					
						
						|  | list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" | 
					
						
						|  | (nsfw) content, according to the `safety_checker`. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | if isinstance(prompt, str): | 
					
						
						|  | batch_size = 1 | 
					
						
						|  | elif isinstance(prompt, list): | 
					
						
						|  | batch_size = len(prompt) | 
					
						
						|  | else: | 
					
						
						|  | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") | 
					
						
						|  |  | 
					
						
						|  | if height % 8 != 0 or width % 8 != 0: | 
					
						
						|  | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") | 
					
						
						|  |  | 
					
						
						|  | if (callback_steps is None) or ( | 
					
						
						|  | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) | 
					
						
						|  | ): | 
					
						
						|  | raise ValueError( | 
					
						
						|  | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" | 
					
						
						|  | f" {type(callback_steps)}." | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | check_size(image, height, width) | 
					
						
						|  | check_size(inner_image, height, width) | 
					
						
						|  | check_size(mask_image, height, width) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | text_inputs = self.tokenizer( | 
					
						
						|  | prompt, | 
					
						
						|  | padding="max_length", | 
					
						
						|  | max_length=self.tokenizer.model_max_length, | 
					
						
						|  | return_tensors="pt", | 
					
						
						|  | ) | 
					
						
						|  | text_input_ids = text_inputs.input_ids | 
					
						
						|  |  | 
					
						
						|  | if text_input_ids.shape[-1] > self.tokenizer.model_max_length: | 
					
						
						|  | removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) | 
					
						
						|  | logger.warning( | 
					
						
						|  | "The following part of your input was truncated because CLIP can only handle sequences up to" | 
					
						
						|  | f" {self.tokenizer.model_max_length} tokens: {removed_text}" | 
					
						
						|  | ) | 
					
						
						|  | text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] | 
					
						
						|  | text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | bs_embed, seq_len, _ = text_embeddings.shape | 
					
						
						|  | text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) | 
					
						
						|  | text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | do_classifier_free_guidance = guidance_scale > 1.0 | 
					
						
						|  |  | 
					
						
						|  | if do_classifier_free_guidance: | 
					
						
						|  | uncond_tokens: List[str] | 
					
						
						|  | if negative_prompt is None: | 
					
						
						|  | uncond_tokens = [""] | 
					
						
						|  | elif type(prompt) is not type(negative_prompt): | 
					
						
						|  | raise TypeError( | 
					
						
						|  | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" | 
					
						
						|  | f" {type(prompt)}." | 
					
						
						|  | ) | 
					
						
						|  | elif isinstance(negative_prompt, str): | 
					
						
						|  | uncond_tokens = [negative_prompt] | 
					
						
						|  | elif batch_size != len(negative_prompt): | 
					
						
						|  | raise ValueError( | 
					
						
						|  | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" | 
					
						
						|  | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" | 
					
						
						|  | " the batch size of `prompt`." | 
					
						
						|  | ) | 
					
						
						|  | else: | 
					
						
						|  | uncond_tokens = negative_prompt | 
					
						
						|  |  | 
					
						
						|  | max_length = text_input_ids.shape[-1] | 
					
						
						|  | uncond_input = self.tokenizer( | 
					
						
						|  | uncond_tokens, | 
					
						
						|  | padding="max_length", | 
					
						
						|  | max_length=max_length, | 
					
						
						|  | truncation=True, | 
					
						
						|  | return_tensors="pt", | 
					
						
						|  | ) | 
					
						
						|  | uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | seq_len = uncond_embeddings.shape[1] | 
					
						
						|  | uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1) | 
					
						
						|  | uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | num_channels_latents = self.vae.config.latent_channels | 
					
						
						|  | latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8) | 
					
						
						|  | latents_dtype = text_embeddings.dtype | 
					
						
						|  | if latents is None: | 
					
						
						|  | if self.device.type == "mps": | 
					
						
						|  |  | 
					
						
						|  | latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to( | 
					
						
						|  | self.device | 
					
						
						|  | ) | 
					
						
						|  | else: | 
					
						
						|  | latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) | 
					
						
						|  | else: | 
					
						
						|  | if latents.shape != latents_shape: | 
					
						
						|  | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") | 
					
						
						|  | latents = latents.to(self.device) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | image = overlay_inner_image(image, inner_image) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | mask, masked_image = prepare_mask_and_masked_image(image, mask_image) | 
					
						
						|  | mask = mask.to(device=self.device, dtype=text_embeddings.dtype) | 
					
						
						|  | masked_image = masked_image.to(device=self.device, dtype=text_embeddings.dtype) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator) | 
					
						
						|  | masked_image_latents = 0.18215 * masked_image_latents | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | mask = mask.repeat(batch_size * num_images_per_prompt, 1, 1, 1) | 
					
						
						|  | masked_image_latents = masked_image_latents.repeat(batch_size * num_images_per_prompt, 1, 1, 1) | 
					
						
						|  |  | 
					
						
						|  | mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask | 
					
						
						|  | masked_image_latents = ( | 
					
						
						|  | torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | num_channels_mask = mask.shape[1] | 
					
						
						|  | num_channels_masked_image = masked_image_latents.shape[1] | 
					
						
						|  |  | 
					
						
						|  | if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: | 
					
						
						|  | raise ValueError( | 
					
						
						|  | f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" | 
					
						
						|  | f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" | 
					
						
						|  | f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" | 
					
						
						|  | f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" | 
					
						
						|  | " `pipeline.unet` or your `mask_image` or `image` input." | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.scheduler.set_timesteps(num_inference_steps) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | timesteps_tensor = self.scheduler.timesteps.to(self.device) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | latents = latents * self.scheduler.init_noise_sigma | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) | 
					
						
						|  | extra_step_kwargs = {} | 
					
						
						|  | if accepts_eta: | 
					
						
						|  | extra_step_kwargs["eta"] = eta | 
					
						
						|  |  | 
					
						
						|  | for i, t in enumerate(self.progress_bar(timesteps_tensor)): | 
					
						
						|  |  | 
					
						
						|  | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) | 
					
						
						|  |  | 
					
						
						|  | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if do_classifier_free_guidance: | 
					
						
						|  | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) | 
					
						
						|  | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if callback is not None and i % callback_steps == 0: | 
					
						
						|  | step_idx = i // getattr(self.scheduler, "order", 1) | 
					
						
						|  | callback(step_idx, t, latents) | 
					
						
						|  |  | 
					
						
						|  | latents = 1 / 0.18215 * latents | 
					
						
						|  | image = self.vae.decode(latents).sample | 
					
						
						|  |  | 
					
						
						|  | image = (image / 2 + 0.5).clamp(0, 1) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | image = image.cpu().permute(0, 2, 3, 1).float().numpy() | 
					
						
						|  |  | 
					
						
						|  | if self.safety_checker is not None: | 
					
						
						|  | safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to( | 
					
						
						|  | self.device | 
					
						
						|  | ) | 
					
						
						|  | image, has_nsfw_concept = self.safety_checker( | 
					
						
						|  | images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype) | 
					
						
						|  | ) | 
					
						
						|  | else: | 
					
						
						|  | has_nsfw_concept = None | 
					
						
						|  |  | 
					
						
						|  | if output_type == "pil": | 
					
						
						|  | image = self.numpy_to_pil(image) | 
					
						
						|  |  | 
					
						
						|  | if not return_dict: | 
					
						
						|  | return (image, has_nsfw_concept) | 
					
						
						|  |  | 
					
						
						|  | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) | 
					
						
						|  |  |