# manga_translator/inpainting/guided_ldm_inpainting.py
import os

import numpy as np
import torch
from tqdm import tqdm
from PIL import Image, ImageFilter, ImageOps
from omegaconf import OmegaConf

from .ldm.modules.diffusionmodules.util import noise_like
from .ldm.models.diffusion.ddpm import LatentDiffusion
from .ldm.models.diffusion.ddim import DDIMSampler
from .ldm.util import instantiate_from_config
class GuidedDDIMSample(DDIMSampler):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @torch.no_grad()
    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None,
                      dynamic_threshold=None):
        b, *_, device = *x.shape, x.device

        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            model_output = self.model.apply_model(x, t, c)
        else:
            # classifier-free guidance: run the conditional and unconditional
            # branches in a single batched forward pass
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            if isinstance(c, dict):
                assert isinstance(unconditional_conditioning, dict)
                c_in = dict()
                for k in c:
                    if isinstance(c[k], list):
                        c_in[k] = [torch.cat([
                            unconditional_conditioning[k][i],
                            c[k][i]]) for i in range(len(c[k]))]
                    else:
                        c_in[k] = torch.cat([
                            unconditional_conditioning[k],
                            c[k]])
            elif isinstance(c, list):
                c_in = list()
                assert isinstance(unconditional_conditioning, list)
                for i in range(len(c)):
                    c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
            else:
                c_in = torch.cat([unconditional_conditioning, c])
            model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)

        e_t = model_output

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0
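
    # Added commentary (not original code): with guidance scale s, the combined
    # prediction above is e_t = e_uncond + s * (e_cond - e_uncond), i.e. standard
    # classifier-free guidance; s = 1 reduces to the conditional branch alone.
    # The step itself is the usual DDIM update,
    #   x_{t-1} = sqrt(a_prev) * pred_x0 + dir_xt + noise,
    # and with sigma_t = 0 (eta = 0, as used below) it is fully deterministic.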
    @torch.no_grad()
    def decode(self, x_latent, cond, t_start, init_latent=None, nmask=None, unconditional_guidance_scale=1.0,
               unconditional_conditioning=None, use_original_steps=False, callback=None):
        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
        timesteps = timesteps[:t_start]
        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]
        print(f"Running Guided DDIM Sampling with {total_steps} timesteps, t_start={t_start}")

        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
        x_dec = x_latent
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
            if nmask is not None:
                # keep the unmasked region anchored to a re-noised copy of the
                # original latent so only the masked area is freely denoised
                noised_input = self.model.q_sample(init_latent.to(x_latent.device), ts.to(x_latent.device))
                x_dec = (1 - nmask) * noised_input + nmask * x_dec
            x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
                                          unconditional_guidance_scale=unconditional_guidance_scale,
                                          unconditional_conditioning=unconditional_conditioning)
            if callback:
                callback(i)
        return x_dec
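
# Added note (not original code): decode() differs from the stock DDIMSampler
# only in the nmask branch above -- at every step the unmasked region is reset
# to a freshly noised copy of init_latent, which is what keeps the background
# faithful to the input image while the masked region is synthesized.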
def get_inpainting_image_condition(model, image, mask):
    # binarized, single-channel mask in the image's dtype/device
    conditioning_mask = np.array(mask.convert("L"))
    conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
    conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
    conditioning_mask = torch.round(conditioning_mask)
    conditioning_mask = conditioning_mask.to(device=image.device, dtype=image.dtype)

    # zero out the masked pixels, then encode the result to latent space
    conditioning_image = image * (1.0 - conditioning_mask)
    conditioning_image = model.get_first_stage_encoding(model.encode_first_stage(conditioning_image))

    # downsample the mask to the latent resolution and concatenate channel-wise
    conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=conditioning_image.shape[-2:])
    conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
    image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
    return image_conditioning
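
# Hedged shape sketch (added, not original code), assuming a 512x512 input and
# a standard SD-style VAE with 4 latent channels and 8x spatial downsampling:
#   image : (1, 3, 512, 512) tensor in [-1, 1]
#   mask  : 512x512 PIL image, mode "L"
#   out   : (1, 5, 64, 64) -- 1 binarized mask channel + 4 masked-image latent channels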
def get_empty_image_condition(latent):
    # 5 channels: 1 (mask) + 4 (latent image), all zero when no image conditioning is used
    return latent.new_zeros(latent.shape[0], 5, latent.shape[2], latent.shape[3])
def fill_mask_input(image, mask):
    """Fills masked regions with colors from the image using progressively
    smaller Gaussian blurs. Not extremely effective, but it gives the encoder
    something plausible to work with inside the mask."""
    image_mod = Image.new('RGBA', (image.width, image.height))
    image_masked = Image.new('RGBa', (image.width, image.height))
    image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))
    for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
        blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
        for _ in range(repeats):
            image_mod.alpha_composite(blurred)
    return image_mod.convert("RGB")
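
def _example_fill_mask_input(image_path: str) -> Image.Image:
    """Hedged usage sketch, not part of the original file: pre-fill a centered
    rectangular mask with blurred surrounding colors. ``image_path`` is a
    hypothetical input supplied by the caller."""
    from PIL import ImageDraw
    img = Image.open(image_path).convert('RGB')
    w, h = img.size
    mask = Image.new('L', (w, h), 0)
    ImageDraw.Draw(mask).rectangle((w // 4, h // 4, 3 * w // 4, 3 * h // 4), fill=255)
    return fill_mask_input(img, mask)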
class GuidedLDM(LatentDiffusion):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @torch.no_grad()
    def img2img_inpaint(
            self,
            image: Image.Image,
            c_text: str,
            uc_text: str,
            mask: Image.Image,
            ddim_steps: int = 50,
            mask_blur: int = 16,
            device: str = 'cpu',
            **kwargs) -> torch.Tensor:
        ddim_sampler = GuidedDDIMSample(self)

        # move the text encoder and VAE to the target device (mps, cuda or cpu)
        if device.startswith('cuda') or device == 'mps':
            self.cond_stage_model.to(device)
            self.first_stage_model.to(device)
        c_text = self.get_learned_conditioning([c_text])
        uc_text = self.get_learned_conditioning([uc_text])
        cond = {"c_crossattn": [c_text]}
        uc_cond = {"c_crossattn": [uc_text]}

        # soften the mask edges so the inpainted region blends into its surroundings
        image_mask = mask.convert('L')
        image_mask = image_mask.filter(ImageFilter.GaussianBlur(mask_blur))
        latent_mask = image_mask

        # pre-fill the masked pixels, normalize to [-1, 1] and encode to latent space
        image = fill_mask_input(image, latent_mask)
        image = np.array(image).astype(np.float32) / 127.5 - 1.0
        image = np.moveaxis(image, 2, 0)
        image = torch.from_numpy(image).to(device)[None]
        init_latent = self.get_first_stage_encoding(self.encode_first_stage(image))

        # binarize the mask at latent resolution; nmask == 1 marks pixels to regenerate
        latmask = latent_mask.convert('RGB').resize((init_latent.shape[3], init_latent.shape[2]))
        latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
        latmask = latmask[0]
        latmask = np.around(latmask)
        latmask = np.tile(latmask[None], (4, 1, 1))
        nmask = torch.asarray(latmask).to(init_latent.device).float()
        init_latent = (1 - nmask) * init_latent + nmask * torch.randn_like(init_latent)

        denoising_strength = 1
        if self.model.conditioning_key == 'hybrid':
            # inpainting-specialized checkpoints also take the masked image as conditioning
            image_cdt = get_inpainting_image_condition(self, image, image_mask)
            cond["c_concat"] = [image_cdt]
            uc_cond["c_concat"] = [image_cdt]

        steps = ddim_steps
        t_enc = int(min(denoising_strength, 0.999) * steps)
        eta = 0
        noise = torch.randn_like(init_latent)
        ddim_sampler.make_schedule(ddim_num_steps=steps, ddim_eta=eta, ddim_discretize="uniform", verbose=False)
        x1 = ddim_sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * int(init_latent.shape[0])).to(device), noise=noise)

        # swap the text/VAE stages out for the UNet to limit peak memory usage
        if device.startswith('cuda') or device == 'mps':
            self.cond_stage_model.cpu()
            self.first_stage_model.cpu()
            self.model.to(device)
        decoded = ddim_sampler.decode(x1, cond, t_enc, init_latent=init_latent, nmask=nmask,
                                      unconditional_guidance_scale=7, unconditional_conditioning=uc_cond)
        if device.startswith('cuda') or device == 'mps':
            self.model.cpu()

        if mask is not None:
            # paste the untouched latent back outside the mask
            decoded = init_latent * (1 - nmask) + decoded * nmask

        if device.startswith('cuda') or device == 'mps':
            self.first_stage_model.to(device)
        x_samples = self.decode_first_stage(decoded)
        if device.startswith('cuda') or device == 'mps':
            self.first_stage_model.cpu()
        return torch.clip(x_samples, -1, 1)
def get_state_dict(d):
    return d.get('state_dict', d)


def load_state_dict(ckpt_path, location='cpu'):
    _, extension = os.path.splitext(ckpt_path)
    if extension.lower() == ".safetensors":
        import safetensors.torch
        state_dict = safetensors.torch.load_file(ckpt_path, device=location)
    else:
        state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
    state_dict = get_state_dict(state_dict)
    return state_dict


def create_model(config_path):
    config = OmegaConf.load(config_path)
    model = instantiate_from_config(config.model).cpu()
    return model
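
def _example_inpaint(config_path: str, ckpt_path: str, image_path: str, mask_path: str) -> Image.Image:
    """Hedged end-to-end sketch, not part of the original file: build a model
    from an OmegaConf config, load its weights, and inpaint the masked region.
    All four paths are hypothetical inputs, and the config is assumed to
    instantiate a GuidedLDM (or compatible LatentDiffusion) model."""
    model = create_model(config_path)
    model.load_state_dict(load_state_dict(ckpt_path), strict=False)
    image = Image.open(image_path).convert('RGB')
    mask = Image.open(mask_path).convert('L')
    samples = model.img2img_inpaint(image, c_text='high quality', uc_text='low quality', mask=mask)
    # img2img_inpaint returns a tensor in [-1, 1]; map back to an 8-bit RGB image
    arr = ((samples[0].permute(1, 2, 0).cpu().numpy() + 1.0) * 127.5).round().astype(np.uint8)
    return Image.fromarray(arr)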