# DreamMix/FooocusSDXLInpaintAllInOnePipeline.py
import copy
import torch
import numpy as np
from PIL import Image
from typing import Any, Dict, List, Optional, Tuple, Union
from diffusers import StableDiffusionXLInpaintPipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
retrieve_timesteps,
rescale_noise_cfg,
)
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import (
StableDiffusionXLPipelineOutput,
)
from transformers import set_seed
import random
from utils import (
add_fooocus_inpaint_patch,
add_fooocus_inpaint_head_patch_with_work,
sks_decompose,
orthogonal_decomposition,
KSampler,
)
import modules.anisotropic as anisotropic
import modules.inpaint_worker as inpaint_worker
def blur_guidance(latents, positive_x0, timestep, sharpness):
    # ! Fooocus trick
    # A carefully tuned variation of Section 5.1 of "Improving Sample Quality of
    # Diffusion Models Using Self-Attention Guidance". The weight is set very low,
    # but this is Fooocus's final guarantee that SDXL never yields an overly
    # smooth or plastic appearance; it almost eliminates the cases where SDXL
    # still occasionally produces overly smooth results even with negative ADM
    # guidance. (Update 2023 Aug 18: the Gaussian kernel of SAG was changed to an
    # anisotropic kernel for better structure preservation and fewer artifacts.)
current_step = 1.0 - timestep.to(latents) / 999.0
global_diffusion_progress = current_step.detach().cpu().numpy().tolist()
positive_eps = latents - positive_x0
alpha = 0.001 * sharpness * global_diffusion_progress
positive_eps_degraded = anisotropic.adaptive_anisotropic_filter(
x=positive_eps, g=positive_x0
)
positive_eps_degraded_weighted = positive_eps_degraded * alpha + positive_eps * (
1.0 - alpha
)
return latents - positive_eps_degraded_weighted
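# A minimal sketch of where blur_guidance() plugs in, mirroring its use in the
# denoising loop below (names as in this file):
#
#     positive_x0 = ksampler.calculate_denoised(i, noise_pred_text, latents)
#     if sharpness > 0:
#         positive_x0 = blur_guidance(latents, positive_x0, t, sharpness)
#
# i.e. the x0 prediction's eps is blended with an anisotropically filtered copy,
# with a weight that grows linearly over the diffusion progress.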
def prepare_noise(latent_image, seed=None, noise_inds=None):
    """
    Creates random noise matching the shape and dtype of `latent_image`.

    `noise_inds` optionally selects (possibly repeated) noise samples by index;
    noise is drawn for every index up to max(noise_inds) and unused draws are
    discarded, so the RNG stream stays consistent for a given seed.
    The `seed` argument is currently unused here: seeding happens globally via
    seed_everything() / set_seed() before sampling.
    """
    generator = None
    # if seed is not None:
    #     generator = torch.manual_seed(seed)
if noise_inds is None:
return torch.randn(
latent_image.size(),
dtype=latent_image.dtype,
layout=latent_image.layout,
generator=generator,
device="cpu",
)
unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
noises = []
for i in range(unique_inds[-1] + 1):
noise = torch.randn(
[1] + list(latent_image.size())[1:],
dtype=latent_image.dtype,
layout=latent_image.layout,
generator=generator,
device="cpu",
)
if i in unique_inds:
noises.append(noise)
noises = [noises[i] for i in inverse]
    noises = torch.cat(noises, dim=0)
return noises
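# Usage sketch (hypothetical shapes): for a batch where samples 0 and 2 should
# share the same noise, pass noise_inds=[0, 1, 0]; noise is drawn once per
# unique index, so results are reproducible for a fixed global seed.
#
#     latent = torch.zeros(3, 4, 128, 128)
#     noise = prepare_noise(latent, noise_inds=[0, 1, 0])
#     assert torch.equal(noise[0], noise[2])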
def seed_everything(seed=1234):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
class FooocusSDXLInpaintPipeline(StableDiffusionXLInpaintPipeline):
def only_load_fooocus_unet_and_cover_pipe_unet_for_train(
self, fooocus_model_path
    ):
        """Patch this pipeline's own UNet in place with the Fooocus inpaint
        weights, instead of keeping a separate fooocus_unet copy (training path)."""
        print(f"Loading fooocus unet from {fooocus_model_path} ...")
# _device = self.device
# self.unet = self.unet.to("cpu")
add_fooocus_inpaint_patch(
self.unet,
model_path=fooocus_model_path,
)
print("Finish loading fooocus unet")
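    # Training-path usage sketch (hypothetical patch path): this patches
    # self.unet in place, so no second unet copy is kept in memory:
    #
    #     pipe.only_load_fooocus_unet_and_cover_pipe_unet_for_train(
    #         "./models/inpaint/inpaint_v26.fooocus.patch"
    #     )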
    def preload_fooocus_unet(
        self, fooocus_model_path, lora_configs=[], add_double_sa=False
    ):
        """
        Each entry of `lora_configs` is a dict of the form:
            {"model_path": str, "scale": float,
             "for_unet": bool, "for_fooocus_unet": bool}
        """
if hasattr(self, "fooocus_unet"):
print("fooocus_unet already loaded. Reloading.")
print(f"Loading fooocus unet from {fooocus_model_path} ...")
self.unload_lora_weights()
_device = self.device
self.unet = self.unet.to("cpu")
self.fooocus_unet = copy.deepcopy(self.unet).to(_device)
add_fooocus_inpaint_patch(
self.fooocus_unet,
model_path=fooocus_model_path,
)
print("fooocus unet loaded")
if add_double_sa:
self._add_double_sa(self.fooocus_unet)
        if not lora_configs:
print("Finish loading fooocus unet without lora")
return
adapter_names_unet, adapter_names_fooocus = [], []
adapter_scales_unet, adapter_scales_fooocus = [], []
        for lora_config in lora_configs:
            # e.g. {"model_path": "./lora-dreambooth-model/pytorch_lora_weights.safetensors",
            #       "scale": 1, "for_unet": True, "for_fooocus_unet": True}
            assert (
                lora_config["for_fooocus_unet"] or lora_config["for_unet"]
            ), "lora_config should be for_fooocus_unet or for_unet or both"
            print(f"Loading lora... config: {lora_config} ...")
            adapter_name = lora_config["model_path"].replace(".", "_")
            if lora_config["for_unet"]:
self.load_lora_weights(
lora_config["model_path"], adapter_name=adapter_name
)
adapter_names_unet.append(adapter_name)
adapter_scales_unet.append(lora_config["scale"])
            if lora_config["for_fooocus_unet"]:
                # load_lora_weights() always targets self.unet, so temporarily
                # swap fooocus_unet into that slot, load, then swap back
                self.unet, self.fooocus_unet = self.fooocus_unet, self.unet
                self.load_lora_weights(
                    lora_config["model_path"], adapter_name=adapter_name
                )
                adapter_names_fooocus.append(adapter_name)
                adapter_scales_fooocus.append(lora_config["scale"])
                self.unet, self.fooocus_unet = self.fooocus_unet, self.unet
        # set_adapters() also targets self.unet: swap in fooocus_unet, set its
        # adapter weights, then swap back and configure the raw unet below
        self.unet, self.fooocus_unet = self.fooocus_unet, self.unet
        self.set_adapters(adapter_names_fooocus, adapter_weights=adapter_scales_fooocus)
        self.unet, self.fooocus_unet = self.fooocus_unet, self.unet
print("lora loaded")
self.fooocus_unet.to("cpu")
self.unet = self.unet.to(_device)
self.set_adapters(adapter_names_unet, adapter_weights=adapter_scales_unet)
print("Finish loading fooocus unet")
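    # Inference-time usage sketch (hypothetical paths), mirroring the config
    # format documented above:
    #
    #     pipe.preload_fooocus_unet(
    #         "./models/inpaint/inpaint_v26.fooocus.patch",
    #         lora_configs=[{
    #             "model_path": "./lora-dreambooth-model/pytorch_lora_weights.safetensors",
    #             "scale": 1.0, "for_unet": True, "for_fooocus_unet": True,
    #         }],
    #     )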
@torch.no_grad()
def __call__(
self,
debug=False,
decompose_prefix_prompt="",
isf_global_time=-1,
        isf_global_ia=1,
soft_blending=False,
sks_decompose_words=[],
fooocus_model_head_path=None,
fooocus_model_head_upscale_path=None,
sharpness=2,
fooocus_time=0.7,
inpaint_respective_field=0.618,
adm_scaler_positive=1,
adm_scaler_negative=1,
adm_scaler_end=0.0,
seed=None,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
image=None,
mask_image=None,
masked_image_latents: torch.FloatTensor = None,
height: Optional[int] = None,
width: Optional[int] = None,
padding_mask_crop: Optional[int] = None,
strength: float = 0.9999,
num_inference_steps: int = 50,
timesteps: List[int] = None,
denoising_start: Optional[float] = None,
denoising_end: Optional[float] = None,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image=None,
ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
original_size: Tuple[int, int] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Tuple[int, int] = None,
negative_original_size: Optional[Tuple[int, int]] = None,
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
negative_target_size: Optional[Tuple[int, int]] = None,
aesthetic_score: float = 6.0,
negative_aesthetic_score: float = 2.5,
clip_skip: Optional[int] = None,
callback_on_step_end=None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
):
assert hasattr(
self, "fooocus_unet"
), "fooocus_unet not loaded. Use pipe.preload_fooocus_unet() first."
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# ! load Fooocus model
if seed is not None:
SEED_LIMIT_NUMPY = 2**32
seed = int(seed) % SEED_LIMIT_NUMPY
set_seed(seed)
seed_everything(seed)
device = self.vae.device
        self.fooocus_unet = self.fooocus_unet.to("cpu")
self.unet = self.unet.to(device)
        target_size = (height, width)
        # PIL's Image.resize() expects (width, height)
        image = image.resize((width, height))
        mask_image = mask_image.resize((width, height))
image_for_inpaint_work = image.copy()
mask_image_for_inpaint_work = mask_image.copy()
inpaint_work = inpaint_worker.InpaintWorker(
image=np.asarray(image),
mask=np.asarray(mask_image)[:, :, 0],
use_fill=strength > 0.99,
k=inpaint_respective_field,
path_upscale_models=fooocus_model_head_upscale_path,
)
if debug:
raise NotImplementedError("debug mode not implemented yet")
add_fooocus_inpaint_head_patch_with_work(
self.fooocus_unet, self, fooocus_model_head_path, inpaint_work
)
self.fooocus_unet = self.fooocus_unet.to(device)
# image = Image.fromarray(inpaint_work.interested_fill)
image = Image.fromarray(inpaint_work.interested_image)
mask_image = Image.fromarray(inpaint_work.interested_mask)
# ! load Fooocus model end
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
# 1. Check inputs
self.check_inputs(
prompt,
prompt_2,
image,
mask_image,
height,
width,
strength,
callback_steps,
output_type,
negative_prompt,
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
ip_adapter_image,
ip_adapter_image_embeds,
callback_on_step_end_tensor_inputs,
padding_mask_crop,
)
self._guidance_scale = guidance_scale
self._guidance_rescale = guidance_rescale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._denoising_end = denoising_end
self._denoising_start = denoising_start
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# 3. Encode input prompt
text_encoder_lora_scale = (
self.cross_attention_kwargs.get("scale", None)
if self.cross_attention_kwargs is not None
else None
)
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
device=device,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=self.clip_skip,
)
        # ! sks_decompose
prompt_embeds_decomposed = None
if len(sks_decompose_words) > 0:
decompose_words_num = len(sks_decompose_words)
decompose_str = " ".join(sks_decompose_words)
decompose_str = decompose_prefix_prompt + " " + decompose_str
(
sks_raw_prompt_embeds,
_,
pooled_sks_raw_prompt_embeds,
_,
) = self.encode_prompt(
prompt=decompose_str,
prompt_2=decompose_str,
device=device,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=False,
lora_scale=text_encoder_lora_scale,
clip_skip=self.clip_skip,
)
            # alpha blends the raw prompt embedding with its decomposed version;
            # fixed at 0.0, i.e. the decomposed embedding is used as-is
            alpha = 0.0
prompt_embeds_decomposed = prompt_embeds.clone()
prompt_embeds_decomposed[0] = alpha * prompt_embeds[0] + (
1 - alpha
) * sks_decompose(
prompt,
prompt_embeds[0],
sks_raw_prompt_embeds[0],
decompose_words_num,
decompose_prefix_prompt,
)
prompt_embeds_decomposed_pooled = orthogonal_decomposition(
pooled_prompt_embeds[0], pooled_sks_raw_prompt_embeds[0]
).unsqueeze(0)
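            # For reference, orthogonal_decomposition(v, u) is assumed to remove
            # the component of v along u (a sketch, not the utils implementation):
            #
            #     v_orth = v - (v @ u) / (u @ u) * u
            #
            # so the pooled embedding loses the direction spanned by the raw
            # subject-word embedding.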
# 4. set timesteps
def denoising_value_valid(dnv):
return isinstance(dnv, float) and 0 < dnv < 1
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps
)
timesteps, num_inference_steps = self.get_timesteps(
num_inference_steps,
strength,
device,
denoising_start=(
self.denoising_start
if denoising_value_valid(self.denoising_start)
else None
),
)
        # check that the number of inference steps is not < 1 - as this doesn't make sense
        if num_inference_steps < 1:
            raise ValueError(
                f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline "
                f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
            )
# 5. Preprocess mask and image
image_latents = inpaint_work.latent
mask_latent = inpaint_work.latent_mask
ksampler = KSampler(image_latents, num_inference_steps, device)
noise = prepare_noise(image_latents, seed=seed).to(device=device)
        # noise scaling follows the k-diffusion / ComfyUI convention:
        # sqrt(1 + sigma_0^2) for a full-strength (max denoise) pass, sigma_0 otherwise
        if strength > 0.9999:
            noise = noise * torch.sqrt(1.0 + ksampler.sigmas[0] ** 2.0)
        else:
            noise = noise * ksampler.sigmas[0]
        latents = image_latents + noise
        # latents = noise
        # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
        # 9. Recompute output size from the (possibly cropped) inpaint latents
        height, width = latents.shape[-2:]
        height = height * self.vae_scale_factor
        width = width * self.vae_scale_factor
original_size = original_size or (height, width)
target_size = target_size or (height, width)
# 10. Prepare added time ids & embeddings
if negative_original_size is None:
negative_original_size = original_size
if negative_target_size is None:
negative_target_size = target_size
add_text_embeds = pooled_prompt_embeds
if self.text_encoder_2 is None:
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
else:
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
add_time_ids, add_neg_time_ids = self._get_add_time_ids(
original_size,
crops_coords_top_left,
target_size,
aesthetic_score,
negative_aesthetic_score,
negative_original_size,
negative_crops_coords_top_left,
negative_target_size,
dtype=prompt_embeds.dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
if prompt_embeds_decomposed is not None:
prompt_embeds_decomposed = torch.cat([negative_prompt_embeds, prompt_embeds_decomposed], dim=0)
add_text_embeds_pooled = torch.cat(
[negative_pooled_prompt_embeds, prompt_embeds_decomposed_pooled], dim=0
)
add_text_embeds = torch.cat(
[negative_pooled_prompt_embeds, add_text_embeds], dim=0
)
add_neg_time_ids = add_neg_time_ids.repeat(
batch_size * num_images_per_prompt, 1
)
add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
prompt_embeds = prompt_embeds.to(device)
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device)
        if prompt_embeds_decomposed is not None:
            prompt_embeds_decomposed = prompt_embeds_decomposed.to(device)
            # start from the decomposed embeddings; the denoising loop swaps the
            # two sets every step, alternating decomposed / original conditioning
            prompt_embeds, prompt_embeds_decomposed = prompt_embeds_decomposed, prompt_embeds
            add_text_embeds_pooled = add_text_embeds_pooled.to(device)
            add_text_embeds, add_text_embeds_pooled = add_text_embeds_pooled, add_text_embeds
# ! Negative ADM guidance
original_size_scaler = (
original_size[0] * adm_scaler_positive,
original_size[1] * adm_scaler_positive,
)
negative_original_size_scaler = (
negative_original_size[0] * adm_scaler_negative,
negative_original_size[1] * adm_scaler_negative,
)
add_time_ids_scaler, add_neg_time_ids_scaler = self._get_add_time_ids(
original_size_scaler,
crops_coords_top_left,
target_size,
aesthetic_score,
negative_aesthetic_score,
negative_original_size_scaler,
negative_crops_coords_top_left,
negative_target_size,
dtype=prompt_embeds.dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
add_time_ids_scaler = add_time_ids_scaler.repeat(
batch_size * num_images_per_prompt, 1
)
if self.do_classifier_free_guidance:
add_neg_time_ids_scaler = add_neg_time_ids_scaler.repeat(
batch_size * num_images_per_prompt, 1
)
add_time_ids_scaler = torch.cat(
[add_neg_time_ids_scaler, add_time_ids_scaler], dim=0
)
add_time_ids_scaler = add_time_ids_scaler.to(device)
# ! Negative ADM guidance end
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image,
ip_adapter_image_embeds,
device,
batch_size * num_images_per_prompt,
self.do_classifier_free_guidance,
)
if (
self.denoising_end is not None
and self.denoising_start is not None
and denoising_value_valid(self.denoising_end)
and denoising_value_valid(self.denoising_start)
and self.denoising_start >= self.denoising_end
):
raise ValueError(
f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
+ f" {self.denoising_end} when using type float."
)
elif self.denoising_end is not None and denoising_value_valid(
self.denoising_end
):
discrete_timestep_cutoff = int(
round(
self.scheduler.config.num_train_timesteps
- (self.denoising_end * self.scheduler.config.num_train_timesteps)
)
)
num_inference_steps = len(
list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))
)
timesteps = timesteps[:num_inference_steps]
# 11.1 Optionally get Guidance Scale Embedding
timestep_cond = None
if self.unet.config.time_cond_proj_dim is not None:
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
batch_size * num_images_per_prompt
)
timestep_cond = self.get_guidance_scale_embedding(
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
).to(device=device, dtype=latents.dtype)
energy_generator = None
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
# for i, t in enumerate(timesteps):
for i in range(num_inference_steps):
if self.interrupt:
continue
            if i == isf_global_time:

                def image_blending_toglobal(latents, inpaint_work, isf_global_ia=1):
                    # overwrite with the previous step's x0 prediction from the
                    # enclosing scope (so isf_global_time must be >= 1)
                    latents = pred_x0
needs_upcasting = (self.vae.dtype == torch.float16 and self.vae.config.force_upcast)
if needs_upcasting:
self.upcast_vae()
latents = latents.to(
next(iter(self.vae.post_quant_conv.parameters())).dtype
)
latents = latents / self.vae.config.scaling_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# image[0].save("test_inpaint_toglobal_before.png")
image = [np.array(x) for x in image]
image = [inpaint_work.post_process(x, soft_blending) for x in image]
image = [Image.fromarray(x) for x in image]
image = image[0]
# image.save("test_inpaint_toglobal_after.png")
if isf_global_ia < 1:
image = inpaint_worker.InpaintWorker(
image=np.asarray(image),
mask=np.asarray(mask_image_for_inpaint_work)[:, :, 0],
use_fill=False,
k=isf_global_ia,
path_upscale_models=fooocus_model_head_upscale_path,
).interested_image
image = Image.fromarray(image)
# image.save("test_inpaint_toglobal_after_crop.png")
image = self.image_processor.preprocess(image).to(latents)
latents = self._encode_vae_image(image=image, generator=None)
# cast back to fp16 if needed
if needs_upcasting:
self.vae.to(dtype=torch.float16)
return latents
latents = image_blending_toglobal(latents, inpaint_work, isf_global_ia)
inpaint_work = inpaint_worker.InpaintWorker(
image=np.asarray(image_for_inpaint_work),
mask=np.asarray(mask_image_for_inpaint_work)[:, :, 0],
use_fill=False,
k=isf_global_ia,
path_upscale_models=fooocus_model_head_upscale_path,
)
                    ksampler = KSampler(latents, num_inference_steps, device)
                    # re-noise the freshly encoded latents back to the current sigma
                    sigma = ksampler.sigmas[i]
                    energy_sigma = sigma.reshape([1] + [1] * (len(latents.shape) - 1))
                    current_energy = torch.randn(
                        latents.size(), dtype=latents.dtype, generator=energy_generator, device="cpu"
                    ).to(latents) * energy_sigma
                    latents = latents + current_energy
add_fooocus_inpaint_head_patch_with_work(
self.fooocus_unet,
self,
fooocus_model_head_path,
inpaint_work,
)
image_latents = inpaint_work.latent
mask_latent = inpaint_work.latent_mask
                if prompt_embeds_decomposed is not None:
                    # alternate between decomposed and original prompt embeddings
                    prompt_embeds, prompt_embeds_decomposed = prompt_embeds_decomposed, prompt_embeds
                    add_text_embeds, add_text_embeds_pooled = add_text_embeds_pooled, add_text_embeds
t = ksampler.timestep(i)
                # ! fooocus add noise: keep the evolving latents inside the masked
                # (inpaint) region and replace the rest with re-noised image latents
                sigma = ksampler.sigmas[i]
                energy_sigma = sigma.reshape([1] + [1] * (len(latents.shape) - 1))
                current_energy = torch.randn(
                    latents.size(), dtype=latents.dtype, generator=energy_generator, device="cpu"
                ).to(latents) * energy_sigma
                latents = latents * mask_latent + (image_latents + current_energy) * (1.0 - mask_latent)
# ! fooocus add noise end
latent_model_input = (
torch.cat([latents] * 2)
if self.do_classifier_free_guidance
else latents
)
latent_model_input = ksampler.calculate_input(i, latent_model_input).to(
dtype=self.fooocus_unet.dtype
)
#! Fooocus part
                # ADM-scaled time ids are used only for the first adm_scaler_end
                # fraction of the steps
                if i <= int(num_inference_steps * adm_scaler_end):
added_cond_kwargs = {
"text_embeds": add_text_embeds,
"time_ids": add_time_ids_scaler,
}
else:
added_cond_kwargs = {
"text_embeds": add_text_embeds,
"time_ids": add_time_ids,
}
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
added_cond_kwargs["image_embeds"] = image_embeds
# predict the noise residual
# if i <= int(len(timesteps)*fooocus_time*strength):
                if i <= int(num_inference_steps * fooocus_time * strength):
                    # early steps run the Fooocus inpaint unet; keep the other
                    # unet on CPU to save CUDA memory
                    self.unet = self.unet.to("cpu")
                    self.fooocus_unet = self.fooocus_unet.to(device)
noise_pred = self.fooocus_unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
timestep_cond=timestep_cond,
cross_attention_kwargs=self.cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
                else:
                    # later steps switch back to the raw unet (with its LoRAs, if any)
                    self.fooocus_unet = self.fooocus_unet.to("cpu")
                    self.unet = self.unet.to(device)
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
timestep_cond=timestep_cond,
cross_attention_kwargs=self.cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
positive_x0 = ksampler.calculate_denoised(
i, noise_pred_text, latents
)
negative_x0 = ksampler.calculate_denoised(
i, noise_pred_uncond, latents
)
if sharpness > 0:
positive_x0 = blur_guidance(latents, positive_x0, t, sharpness)
negative_eps = latents - negative_x0
positive_eps = latents - positive_x0
final_eps = negative_eps + self.guidance_scale * (
positive_eps - negative_eps
)
                    if self.guidance_rescale > 0.0:
                        # Based on Section 3.4 of https://arxiv.org/pdf/2305.08891.pdf
final_eps = rescale_noise_cfg(
final_eps,
positive_eps,
guidance_rescale=self.guidance_rescale,
)
pred_x0 = latents - final_eps
else:
pred_x0 = ksampler.calculate_denoised(i, noise_pred, latents)
if sharpness > 0:
pred_x0 = blur_guidance(latents, pred_x0, t, sharpness)
# compute the previous noisy sample x_t -> x_t-1
latents = ksampler.step(i, pred_x0, latents)
#! Fooocus part end
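                # (As used here, utils.KSampler mirrors the k-diffusion sampler
                # interface: calculate_input() pre-scales the latents for the
                # unet, calculate_denoised() converts the noise prediction into
                # an x0 estimate at step i, and step() advances x_t -> x_t-1.)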
if (i + 1) % self.scheduler.order == 0:
progress_bar.update()
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = (
self.vae.dtype == torch.float16 and self.vae.config.force_upcast
)
if needs_upcasting:
self.upcast_vae()
latents = latents.to(
next(iter(self.vae.post_quant_conv.parameters())).dtype
)
# unscale/denormalize the latents
# denormalize with the mean and std if available and not None
has_latents_mean = (
hasattr(self.vae.config, "latents_mean")
and self.vae.config.latents_mean is not None
)
has_latents_std = (
hasattr(self.vae.config, "latents_std")
and self.vae.config.latents_std is not None
)
if has_latents_mean and has_latents_std:
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, 4, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = (
torch.tensor(self.vae.config.latents_std)
.view(1, 4, 1, 1)
.to(latents.device, latents.dtype)
)
latents = (
latents * latents_std / self.vae.config.scaling_factor
+ latents_mean
)
else:
latents = latents / self.vae.config.scaling_factor
image = self.vae.decode(latents, return_dict=False)[0]
# cast back to fp16 if needed
if needs_upcasting:
self.vae.to(dtype=torch.float16)
else:
return StableDiffusionXLPipelineOutput(images=latents)
# apply watermark if available
if self.watermark is not None:
image = self.watermark.apply_watermark(image)
image = self.image_processor.postprocess(image, output_type=output_type)
image = [np.array(x) for x in image]
image = [inpaint_work.post_process(x) for x in image]
image = [Image.fromarray(x) for x in image]
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return StableDiffusionXLPipelineOutput(images=image)
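
if __name__ == "__main__":
    # Minimal end-to-end sketch with hypothetical local paths; assumes a CUDA
    # device plus the Fooocus inpaint patch/head/upscaler files are available.
    pipe = FooocusSDXLInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")
    pipe.preload_fooocus_unet("./models/inpaint/inpaint_v26.fooocus.patch")
    # mask must be 3-channel: the pipeline indexes channel 0
    init_image = Image.open("./example/image.png").convert("RGB")
    mask = Image.open("./example/mask.png").convert("RGB")
    result = pipe(
        prompt="a photo of a sks toy on the beach",
        image=init_image,
        mask_image=mask,
        fooocus_model_head_path="./models/inpaint/fooocus_inpaint_head.pth",
        fooocus_model_head_upscale_path="./models/upscale_models",
        height=1024,
        width=1024,
        num_inference_steps=30,
        seed=1234,
    ).images[0]
    result.save("output.png")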