|
import os |
|
import sys |
|
|
|
import subprocess |
|
import io |
|
import torch.nn as nn |
|
from torch.nn import functional as F |
|
import torch |
|
import torchvision.transforms.functional as TF |
|
import torchvision.transforms as T |
|
import math |
|
import requests |
|
import cv2 |
|
from resize_right import resize |
|
from guided_diffusion.guided_diffusion.script_util import model_and_diffusion_defaults |
|
from types import SimpleNamespace |
|
from PIL import Image |
|
import argparse |
|
from guided_diffusion.guided_diffusion.unet import HFUNetModel |
|
from tqdm.auto import tqdm  # tqdm.auto works in both notebooks and plain scripts
|
from datetime import datetime |
|
from guided_diffusion.guided_diffusion.script_util import create_model_and_diffusion |
|
import clip |
|
from transformers import BertForSequenceClassification, BertTokenizer |
|
import gc |
|
import random |
|
|
|
|
|
|
|
PROJECT_DIR = os.path.dirname(os.path.abspath(__file__)) |
|
|
|
useCPU = False |
|
skip_augs = False |
|
perlin_init = False |
|
|
|
use_secondary_model = False |
|
diffusion_model = "custom" |
|
|
|
|
|
side_x = 512 |
|
side_y = 512 |
|
|
|
diffusion_sampling_mode = 'ddim' |
|
use_checkpoint = True |
|
ViTB32 = False |
|
ViTB16 = False |
|
ViTL14 = True |
|
ViTL14_336px = False |
|
RN101 = False |
|
RN50 = False |
|
RN50x4 = False |
|
RN50x16 = False |
|
RN50x64 = False |
|
|
|
|
|
|
|
ViTB32_laion2b_e16 = False |
|
ViTB32_laion400m_e31 = False |
|
ViTB32_laion400m_e32 = False
|
ViTB32quickgelu_laion400m_e31 = False |
|
ViTB32quickgelu_laion400m_e32 = False |
|
ViTB16_laion400m_e31 = False |
|
ViTB16_laion400m_e32 = False |
|
RN50_yfcc15m = False
|
RN50_cc12m = False |
|
RN50_quickgelu_yfcc15m = False |
|
RN50_quickgelu_cc12m = False |
|
RN101_yfcc15m = False |
|
RN101_quickgelu_yfcc15m = False |
|
|
|
|
|
|
|
|
|
steps = 100 |
|
tv_scale = 0 |
|
range_scale = 150 |
|
sat_scale = 0 |
|
cutn_batches = 1 |
|
|
|
|
|
|
intermediate_saves = 0 |
|
intermediates_in_subfolder = True |
|
|
|
|
|
perlin_mode = 'mixed' |
|
set_seed = 'random_seed' |
|
eta = 0.8 |
|
clamp_grad = True |
|
clamp_max = 0.05 |
|
|
|
|
|
randomize_class = True |
|
clip_denoised = False |
|
fuzzy_prompt = False |
|
rand_mag = 0.05 |
|
|
|
|
|
cut_overview = "[12]*400+[4]*600" |
|
cut_innercut = "[4]*400+[12]*600" |
|
cut_ic_pow = "[1]*1000" |
|
cut_icgray_p = "[0.2]*400+[0]*600" |
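# The cut_* schedules above are Python expressions that eval to 1000-element lists, one
# value per diffusion timestep (the first entries apply to the earliest, noisiest steps).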
|
|
|
|
|
|
|
use_vertical_symmetry = False |
|
use_horizontal_symmetry = False |
|
transformation_percent = [0.09] |
|
|
|
display_rate = 3 |
|
n_batches = 1 |
|
|
|
|
|
check_model_SHA = False |
|
interp_spline = 'Linear' |
|
resume_run = False |
|
batch_size = 1 |
|
|
|
|
|
def createPath(filepath): |
|
os.makedirs(filepath, exist_ok=True) |
|
|
|
|
|
def wget(url, outputdir): |
|
res = subprocess.run(['wget', url, '-P', f'{outputdir}'], stdout=subprocess.PIPE).stdout.decode('utf-8') |
|
print(res) |
|
|
|
|
|
def alpha_sigma_to_t(alpha, sigma): |
|
return torch.atan2(sigma, alpha) * 2 / math.pi |
|
|
|
|
|
def interp(t): |
|
return 3 * t**2 - 2 * t ** 3 |
|
|
|
|
|
def perlin(width, height, scale=10, device=None): |
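    """One octave of 2-D Perlin noise, returned as a (width * scale, height * scale) tensor."""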
|
gx, gy = torch.randn(2, width + 1, height + 1, 1, 1, device=device) |
|
xs = torch.linspace(0, 1, scale + 1)[:-1, None].to(device) |
|
ys = torch.linspace(0, 1, scale + 1)[None, :-1].to(device) |
|
wx = 1 - interp(xs) |
|
wy = 1 - interp(ys) |
|
dots = 0 |
|
dots += wx * wy * (gx[:-1, :-1] * xs + gy[:-1, :-1] * ys) |
|
dots += (1 - wx) * wy * (-gx[1:, :-1] * (1 - xs) + gy[1:, :-1] * ys) |
|
dots += wx * (1 - wy) * (gx[:-1, 1:] * xs - gy[:-1, 1:] * (1 - ys)) |
|
dots += (1 - wx) * (1 - wy) * (-gx[1:, 1:] * (1 - xs) - gy[1:, 1:] * (1 - ys)) |
|
return dots.permute(0, 2, 1, 3).contiguous().view(width * scale, height * scale) |
|
|
|
|
|
def perlin_ms(octaves, width, height, grayscale, device=None): |
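    """Sum the weighted octaves in `octaves` for 1 (grayscale) or 3 channels and concatenate them."""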
|
out_array = [0.5] if grayscale else [0.5, 0.5, 0.5] |
|
|
|
for i in range(1 if grayscale else 3): |
|
scale = 2 ** len(octaves) |
|
oct_width = width |
|
oct_height = height |
|
for oct in octaves: |
|
p = perlin(oct_width, oct_height, scale, device) |
|
out_array[i] += p * oct |
|
scale //= 2 |
|
oct_width *= 2 |
|
oct_height *= 2 |
|
return torch.cat(out_array) |
|
|
|
|
|
def fetch(url_or_path): |
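    """Return a readable binary file object for a local path or an http(s) URL (downloaded into memory)."""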
|
if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'): |
|
r = requests.get(url_or_path) |
|
r.raise_for_status() |
|
fd = io.BytesIO() |
|
fd.write(r.content) |
|
fd.seek(0) |
|
return fd |
|
return open(url_or_path, 'rb') |
|
|
|
|
|
def read_image_workaround(path): |
|
"""OpenCV reads images as BGR, Pillow saves them as RGB. Work around |
|
this incompatibility to avoid colour inversions.""" |
|
im_tmp = cv2.imread(path) |
|
return cv2.cvtColor(im_tmp, cv2.COLOR_BGR2RGB) |
|
|
|
|
|
def parse_prompt(prompt): |
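    """Split a 'text:weight' prompt into (text, weight); weight defaults to 1 and colons inside http(s) URLs are preserved."""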
|
if prompt.startswith('http://') or prompt.startswith('https://'): |
|
vals = prompt.rsplit(':', 2) |
|
vals = [vals[0] + ':' + vals[1], *vals[2:]] |
|
else: |
|
vals = prompt.rsplit(':', 1) |
|
vals = vals + ['', '1'][len(vals):] |
|
return vals[0], float(vals[1]) |
|
|
|
|
|
def sinc(x): |
|
return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([])) |
|
|
|
|
|
def lanczos(x, a): |
|
cond = torch.logical_and(-a < x, x < a) |
|
out = torch.where(cond, sinc(x) * sinc(x / a), x.new_zeros([])) |
|
return out / out.sum() |
|
|
|
|
|
def ramp(ratio, width): |
|
n = math.ceil(width / ratio + 1) |
|
out = torch.empty([n]) |
|
cur = 0 |
|
for i in range(out.shape[0]): |
|
out[i] = cur |
|
cur += ratio |
|
return torch.cat([-out[1:].flip([0]), out])[1:-1] |
|
|
|
|
|
def resample(input, size, align_corners=True): |
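    """Resize to `size`: apply a Lanczos low-pass prefilter when downsampling, then bicubic interpolation."""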
|
n, c, h, w = input.shape |
|
dh, dw = size |
|
|
|
input = input.reshape([n * c, 1, h, w]) |
|
|
|
if dh < h: |
|
kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype) |
|
pad_h = (kernel_h.shape[0] - 1) // 2 |
|
input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect') |
|
input = F.conv2d(input, kernel_h[None, None, :, None]) |
|
|
|
if dw < w: |
|
kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype) |
|
pad_w = (kernel_w.shape[0] - 1) // 2 |
|
input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect') |
|
input = F.conv2d(input, kernel_w[None, None, None, :]) |
|
|
|
input = input.reshape([n, c, h, w]) |
|
return F.interpolate(input, size, mode='bicubic', align_corners=align_corners) |
|
|
|
|
|
class MakeCutouts(nn.Module): |
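    """Random square cutouts of an image batch, resampled to `cut_size` and lightly augmented, for scoring with CLIP."""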
|
def __init__(self, cut_size, cutn, skip_augs=False): |
|
super().__init__() |
|
self.cut_size = cut_size |
|
self.cutn = cutn |
|
self.skip_augs = skip_augs |
|
self.augs = T.Compose([ |
|
T.RandomHorizontalFlip(p=0.5), |
|
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), |
|
T.RandomAffine(degrees=15, translate=(0.1, 0.1)), |
|
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), |
|
T.RandomPerspective(distortion_scale=0.4, p=0.7), |
|
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), |
|
T.RandomGrayscale(p=0.15), |
|
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), |
|
|
|
]) |
|
|
|
def forward(self, input): |
|
input = T.Pad(input.shape[2] // 4, fill=0)(input) |
|
sideY, sideX = input.shape[2:4] |
|
max_size = min(sideX, sideY) |
|
|
|
cutouts = [] |
|
for ch in range(self.cutn): |
|
if ch > self.cutn - self.cutn // 4: |
|
cutout = input.clone() |
|
else: |
|
size = int(max_size * torch.zeros(1,).normal_(mean=.8, std=.3).clip(float(self.cut_size / max_size), 1.)) |
|
offsetx = torch.randint(0, abs(sideX - size + 1), ()) |
|
offsety = torch.randint(0, abs(sideY - size + 1), ()) |
|
cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size] |
|
|
|
if not self.skip_augs: |
|
cutout = self.augs(cutout) |
|
cutouts.append(resample(cutout, (self.cut_size, self.cut_size))) |
|
del cutout |
|
|
|
cutouts = torch.cat(cutouts, dim=0) |
|
return cutouts |
|
|
|
|
|
class MakeCutoutsDango(nn.Module): |
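    """Cutout generator with whole-image 'Overview' cutouts and random 'InnerCrop' cutouts,
    the first `IC_Grey_P` fraction of which is converted to grayscale. The `args` parameter is
    accepted for interface compatibility; augmentation is controlled by the module-level
    `skip_augs` flag."""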
|
def __init__(self, cut_size, args, |
|
Overview=4, |
|
InnerCrop=0, IC_Size_Pow=0.5, IC_Grey_P=0.2, |
|
): |
|
super().__init__() |
|
self.padargs = {} |
|
self.cutout_debug = False |
|
self.cut_size = cut_size |
|
self.Overview = Overview |
|
self.InnerCrop = InnerCrop |
|
self.IC_Size_Pow = IC_Size_Pow |
|
self.IC_Grey_P = IC_Grey_P |
|
self.augs = T.Compose([ |
|
T.RandomHorizontalFlip(p=0.5), |
|
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), |
|
T.RandomAffine(degrees=10, translate=(0.05, 0.05), interpolation=T.InterpolationMode.BILINEAR), |
|
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), |
|
T.RandomGrayscale(p=0.1), |
|
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), |
|
T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), |
|
]) |
|
|
|
def forward(self, input): |
|
cutouts = [] |
|
gray = T.Grayscale(3) |
|
sideY, sideX = input.shape[2:4] |
|
max_size = min(sideX, sideY) |
|
min_size = min(sideX, sideY, self.cut_size) |
|
output_shape = [1, 3, self.cut_size, self.cut_size] |
|
pad_input = F.pad(input, ((sideY - max_size) // 2, (sideY - max_size) // 2, (sideX - max_size) // 2, (sideX - max_size) // 2), **self.padargs) |
|
cutout = resize(pad_input, out_shape=output_shape) |
|
|
|
if self.Overview > 0: |
|
if self.Overview <= 4: |
|
if self.Overview >= 1: |
|
cutouts.append(cutout) |
|
if self.Overview >= 2: |
|
cutouts.append(gray(cutout)) |
|
if self.Overview >= 3: |
|
cutouts.append(TF.hflip(cutout)) |
|
if self.Overview == 4: |
|
cutouts.append(gray(TF.hflip(cutout))) |
|
else: |
|
cutout = resize(pad_input, out_shape=output_shape) |
|
for _ in range(self.Overview): |
|
cutouts.append(cutout) |
|
|
|
if self.cutout_debug: |
|
|
|
|
|
|
|
TF.to_pil_image(cutouts[0].clamp(0, 1).squeeze(0)).save("cutout_overview0.jpg", quality=99) |
|
|
|
if self.InnerCrop > 0: |
|
for i in range(self.InnerCrop): |
|
size = int(torch.rand([])**self.IC_Size_Pow * (max_size - min_size) + min_size) |
|
offsetx = torch.randint(0, sideX - size + 1, ()) |
|
offsety = torch.randint(0, sideY - size + 1, ()) |
|
cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size] |
|
if i <= int(self.IC_Grey_P * self.InnerCrop): |
|
cutout = gray(cutout) |
|
cutout = resize(cutout, out_shape=output_shape) |
|
cutouts.append(cutout) |
|
if self.cutout_debug: |
|
|
|
|
|
|
|
TF.to_pil_image(cutouts[-1].clamp(0, 1).squeeze(0)).save("cutout_InnerCrop.jpg", quality=99) |
|
cutouts = torch.cat(cutouts) |
|
if skip_augs is not True: |
|
cutouts = self.augs(cutouts) |
|
return cutouts |
|
|
|
|
|
def spherical_dist_loss(x, y): |
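    """Squared angular distance between L2-normalised embeddings (spherical distance loss)."""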
|
x = F.normalize(x, dim=-1) |
|
y = F.normalize(y, dim=-1) |
|
return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) |
|
|
|
|
|
def tv_loss(input): |
|
"""L2 total variation loss, as in Mahendran et al.""" |
|
input = F.pad(input, (0, 1, 0, 1), 'replicate') |
|
x_diff = input[..., :-1, 1:] - input[..., :-1, :-1] |
|
y_diff = input[..., 1:, :-1] - input[..., :-1, :-1] |
|
return (x_diff**2 + y_diff**2).mean([1, 2, 3]) |
|
|
|
|
|
def range_loss(input): |
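    """Penalise predicted-x0 values that fall outside the valid [-1, 1] range."""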
|
return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3]) |
|
|
|
|
|
def symmetry_transformation_fn(x):
    # Mirror one half of the image onto the other, according to the module-level
    # use_horizontal_symmetry / use_vertical_symmetry flags.
|
if use_horizontal_symmetry: |
|
[n, c, h, w] = x.size() |
|
x = torch.concat((x[:, :, :, :w // 2], torch.flip(x[:, :, :, :w // 2], [-1])), -1) |
|
print("horizontal symmetry applied") |
|
if use_vertical_symmetry: |
|
[n, c, h, w] = x.size() |
|
x = torch.concat((x[:, :, :h // 2, :], torch.flip(x[:, :, :h // 2, :], [-2])), -2) |
|
print("vertical symmetry applied") |
|
return x |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
other chaos settings |
|
""" |
|
|
|
|
|
outDirPath = f'{PROJECT_DIR}/images_out' |
|
createPath(outDirPath) |
|
model_path = f'{PROJECT_DIR}/models' |
|
createPath(model_path) |
|
|
|
|
|
|
|
DEVICE = torch.device('cuda:0' if (torch.cuda.is_available() and not useCPU) else 'cpu') |
|
print('Using device:', DEVICE) |
|
device = DEVICE |
|
if not useCPU and torch.cuda.is_available():
|
if torch.cuda.get_device_capability(DEVICE) == (8, 0): |
|
print('Disabling CUDNN for A100 gpu', file=sys.stderr) |
|
torch.backends.cudnn.enabled = False |
|
|
|
model_config = model_and_diffusion_defaults() |
|
model_config.update({ |
|
'attention_resolutions': '32, 16, 8', |
|
'class_cond': False, |
|
'diffusion_steps': 1000, |
|
'rescale_timesteps': True, |
|
'timestep_respacing': 250, |
|
'image_size': 512, |
|
'learn_sigma': True, |
|
'noise_schedule': 'linear', |
|
'num_channels': 256, |
|
'num_head_channels': 64, |
|
'num_res_blocks': 2, |
|
'resblock_updown': True, |
|
'use_checkpoint': use_checkpoint, |
|
'use_fp16': not useCPU, |
|
'use_scale_shift_norm': True, |
|
}) |
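# The settings above mirror the 512x512 guided-diffusion UNet configuration; the pretrained
# HFUNetModel loaded in Diffuser.model_setup is expected to match this architecture.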
|
|
|
model_default = model_config['image_size'] |
|
normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]) |
|
|
|
|
|
steps_per_checkpoint = steps + 10 |
|
|
|
timestep_respacing = f'ddim{steps}' |
|
diffusion_steps = (1000 // steps) * steps if steps < 1000 else steps |
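# The f'ddim{steps}' respacing string tells guided_diffusion to sample with `steps` DDIM
# steps; diffusion_steps is rounded down to a multiple of `steps` to match the respacing.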
|
model_config.update({ |
|
'timestep_respacing': timestep_respacing, |
|
'diffusion_steps': diffusion_steps, |
|
}) |
|
|
|
|
|
start_frame = 0 |
|
print('Starting Run:') |
|
if set_seed == 'random_seed': |
|
random.seed() |
|
seed = random.randint(0, 2**32) |
|
|
|
else: |
|
seed = int(set_seed) |
|
|
|
args = { |
|
|
|
'display_rate': display_rate, |
|
'n_batches': n_batches, |
|
'batch_size': batch_size, |
|
'steps': steps, |
|
'diffusion_sampling_mode': diffusion_sampling_mode, |
|
|
|
'tv_scale': tv_scale, |
|
'range_scale': range_scale, |
|
'sat_scale': sat_scale, |
|
'cutn_batches': cutn_batches, |
|
|
|
|
|
'timestep_respacing': timestep_respacing, |
|
'diffusion_steps': diffusion_steps, |
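    # The cut_* schedule strings are eval'd into per-timestep lists; see their definitions above.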
|
'cut_overview': eval(cut_overview), |
|
'cut_innercut': eval(cut_innercut), |
|
'cut_ic_pow': eval(cut_ic_pow), |
|
'cut_icgray_p': eval(cut_icgray_p), |
|
'intermediate_saves': intermediate_saves, |
|
'intermediates_in_subfolder': intermediates_in_subfolder, |
|
'steps_per_checkpoint': steps_per_checkpoint, |
|
'set_seed': set_seed, |
|
'eta': eta, |
|
'clamp_grad': clamp_grad, |
|
'clamp_max': clamp_max, |
|
'skip_augs': skip_augs, |
|
'randomize_class': randomize_class, |
|
'clip_denoised': clip_denoised, |
|
'fuzzy_prompt': fuzzy_prompt, |
|
'rand_mag': rand_mag, |
|
'use_vertical_symmetry': use_vertical_symmetry, |
|
'use_horizontal_symmetry': use_horizontal_symmetry, |
|
'transformation_percent': transformation_percent, |
|
} |
|
args = SimpleNamespace(**args) |
|
|
|
|
|
|
|
|
|
class Diffuser: |
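    """Bundles the Taiyi diffusion UNet, the guided_diffusion sampler, the Taiyi Chinese
    text encoder, and the selected CLIP image encoders for CLIP-guided image generation."""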
|
    def __init__(self, custom_path='IDEA-CCNL/Taiyi-Diffusion-532M-Nature'):
        self.model_setup(custom_path)
|
|
|
def model_setup(self, custom_path): |
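        """Create the guided_diffusion sampler from `model_config`, load the pretrained
        Taiyi UNet from the Hugging Face Hub, and load the Taiyi-CLIP text encoder
        together with the enabled OpenAI CLIP image encoders."""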
|
|
|
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' |
|
print(f'Prepping model...model name: {custom_path}') |
|
__, self.diffusion = create_model_and_diffusion(**model_config) |
|
self.model = HFUNetModel.from_pretrained(custom_path) |
|
|
|
|
|
|
|
|
|
self.model.requires_grad_(False).eval().to(device) |
|
for name, param in self.model.named_parameters(): |
|
if 'qkv' in name or 'norm' in name or 'proj' in name: |
|
param.requires_grad_() |
|
if model_config['use_fp16']: |
|
self.model.convert_to_fp16() |
|
print(f'Diffusion_model Loaded {diffusion_model}') |
|
|
|
|
|
print('Prepping model...model name: CLIP') |
|
self.taiyi_tokenizer = BertTokenizer.from_pretrained("IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese") |
|
self.taiyi_transformer = BertForSequenceClassification.from_pretrained("IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese").eval().to(device) |
|
self.clip_models = [] |
|
if ViTB32: |
|
self.clip_models.append(clip.load('ViT-B/32', jit=False)[0].eval().requires_grad_(False).to(device)) |
|
if ViTB16: |
|
self.clip_models.append(clip.load('ViT-B/16', jit=False)[0].eval().requires_grad_(False).to(device)) |
|
if ViTL14: |
|
self.clip_models.append(clip.load('ViT-L/14', jit=False)[0].eval().requires_grad_(False).to(device)) |
|
if ViTL14_336px: |
|
self.clip_models.append(clip.load('ViT-L/14@336px', jit=False)[0].eval().requires_grad_(False).to(device)) |
|
print('CLIP Loaded') |
|
|
|
|
|
def generate(self, |
|
input_text_prompts=['夕阳西下'], |
|
init_image=None, |
|
skip_steps=10, |
|
clip_guidance_scale=7500, |
|
init_scale=2000, |
|
st_dynamic_image=None, |
|
seed=None, |
|
side_x=512, |
|
side_y=512, |
|
): |
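        """Run CLIP-guided diffusion for `input_text_prompts` and return the last decoded
        image as a PIL.Image; intermediate images are also saved under `outDirPath`."""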
|
|
|
        frame_num = 0
        loss_values = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
frame_prompt = input_text_prompts |
|
|
|
print(f'Frame {frame_num} Prompt: {frame_prompt}') |
|
|
|
model_stats = [] |
|
for clip_model in self.clip_models: |
|
|
|
model_stat = {"clip_model": None, "target_embeds": [], "make_cutouts": None, "weights": []} |
|
model_stat["clip_model"] = clip_model |
|
|
|
for prompt in frame_prompt: |
|
txt, weight = parse_prompt(prompt) |
|
|
|
|
|
txt = self.taiyi_transformer(self.taiyi_tokenizer(txt, return_tensors='pt')['input_ids'].to(device)).logits |
|
if args.fuzzy_prompt: |
|
for i in range(25): |
|
model_stat["target_embeds"].append((txt + torch.randn(txt.shape).cuda() * args.rand_mag).clamp(0, 1)) |
|
model_stat["weights"].append(weight) |
|
else: |
|
model_stat["target_embeds"].append(txt) |
|
model_stat["weights"].append(weight) |
|
|
|
model_stat["target_embeds"] = torch.cat(model_stat["target_embeds"]) |
|
model_stat["weights"] = torch.tensor(model_stat["weights"], device=device) |
|
if model_stat["weights"].sum().abs() < 1e-3: |
|
raise RuntimeError('The weights must not sum to 0.') |
|
model_stat["weights"] /= model_stat["weights"].sum().abs() |
|
model_stats.append(model_stat) |
|
|
|
init = None |
|
if init_image is not None: |
|
|
|
init = init_image |
|
init = init.resize((side_x, side_y), Image.LANCZOS) |
|
init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1) |
|
|
|
cur_t = None |
|
|
|
def cond_fn(x, t, y=None): |
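            """Guidance gradient: estimate x_0, score CLIP cutouts of it against the text
            embeddings (spherical distance), add TV / range / saturation losses, and return
            the gradient w.r.t. x, optionally clamped to `clamp_max`."""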
|
with torch.enable_grad(): |
|
x_is_NaN = False |
|
x = x.detach().requires_grad_() |
|
n = x.shape[0] |
|
|
|
my_t = torch.ones([n], device=device, dtype=torch.long) * cur_t |
|
out = self.diffusion.p_mean_variance(self.model, x, my_t, clip_denoised=False, model_kwargs={'y': y}) |
|
fac = self.diffusion.sqrt_one_minus_alphas_cumprod[cur_t] |
|
x_in = out['pred_xstart'] * fac + x * (1 - fac) |
|
x_in_grad = torch.zeros_like(x_in) |
|
|
|
for model_stat in model_stats: |
|
for i in range(args.cutn_batches): |
|
t_int = int(t.item()) + 1 |
|
|
|
input_resolution = model_stat["clip_model"].visual.input_resolution |
|
|
|
|
|
|
|
cuts = MakeCutoutsDango(input_resolution, |
|
Overview=args.cut_overview[1000 - t_int], |
|
InnerCrop=args.cut_innercut[1000 - t_int], |
|
IC_Size_Pow=args.cut_ic_pow[1000 - t_int], |
|
IC_Grey_P=args.cut_icgray_p[1000 - t_int], |
|
args=args, |
|
) |
|
clip_in = normalize(cuts(x_in.add(1).div(2))) |
|
image_embeds = model_stat["clip_model"].encode_image(clip_in).float() |
|
dists = spherical_dist_loss(image_embeds.unsqueeze(1), model_stat["target_embeds"].unsqueeze(0)) |
|
dists = dists.view([args.cut_overview[1000 - t_int] + args.cut_innercut[1000 - t_int], n, -1]) |
|
losses = dists.mul(model_stat["weights"]).sum(2).mean(0) |
|
loss_values.append(losses.sum().item()) |
|
x_in_grad += torch.autograd.grad(losses.sum() * clip_guidance_scale, x_in)[0] / cutn_batches |
|
tv_losses = tv_loss(x_in) |
|
range_losses = range_loss(out['pred_xstart']) |
|
sat_losses = torch.abs(x_in - x_in.clamp(min=-1, max=1)).mean() |
|
loss = tv_losses.sum() * tv_scale + range_losses.sum() * range_scale + sat_losses.sum() * sat_scale |
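                # NOTE: self.lpips_model is never initialised in this script; to use an
                # init_image with init_scale > 0, attach a perceptual-loss model first
                # (e.g. lpips.LPIPS(net='vgg') from the `lpips` package).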
|
if init is not None and init_scale: |
|
init_losses = self.lpips_model(x_in, init) |
|
loss = loss + init_losses.sum() * init_scale |
|
x_in_grad += torch.autograd.grad(loss, x_in)[0] |
|
if not torch.isnan(x_in_grad).any(): |
|
grad = -torch.autograd.grad(x_in, x, x_in_grad)[0] |
|
else: |
|
x_is_NaN = True |
|
grad = torch.zeros_like(x) |
|
if args.clamp_grad and not x_is_NaN: |
|
magnitude = grad.square().mean().sqrt() |
|
return grad * magnitude.clamp(max=args.clamp_max) / magnitude |
|
return grad |
|
|
|
if args.diffusion_sampling_mode == 'ddim': |
|
sample_fn = self.diffusion.ddim_sample_loop_progressive |
|
else: |
|
sample_fn = self.diffusion.plms_sample_loop_progressive |
|
|
|
for i in range(args.n_batches): |
|
current_time = datetime.now().strftime('%y%m%d-%H%M%S_%f') |
|
|
|
batchBar = tqdm(range(args.n_batches), desc="Batches") |
|
batchBar.n = i |
|
batchBar.refresh() |
|
gc.collect() |
|
torch.cuda.empty_cache() |
|
cur_t = self.diffusion.num_timesteps - skip_steps - 1 |
|
|
|
|
|
if args.diffusion_sampling_mode == 'ddim': |
|
samples = sample_fn( |
|
self.model, |
|
(batch_size, 3, side_y, side_x), |
|
clip_denoised=clip_denoised, |
|
model_kwargs={}, |
|
cond_fn=cond_fn, |
|
progress=True, |
|
skip_timesteps=skip_steps, |
|
init_image=init, |
|
randomize_class=randomize_class, |
|
eta=eta, |
|
transformation_fn=symmetry_transformation_fn, |
|
transformation_percent=args.transformation_percent |
|
) |
|
else: |
|
samples = sample_fn( |
|
self.model, |
|
(batch_size, 3, side_y, side_x), |
|
clip_denoised=clip_denoised, |
|
model_kwargs={}, |
|
cond_fn=cond_fn, |
|
progress=True, |
|
skip_timesteps=skip_steps, |
|
init_image=init, |
|
randomize_class=randomize_class, |
|
order=2, |
|
) |
|
|
|
for j, sample in enumerate(samples): |
|
cur_t -= 1 |
|
intermediateStep = False |
|
if args.steps_per_checkpoint is not None: |
|
if j % steps_per_checkpoint == 0 and j > 0: |
|
intermediateStep = True |
|
elif j in args.intermediate_saves: |
|
intermediateStep = True |
|
if j % args.display_rate == 0 or cur_t == -1 or intermediateStep: |
|
for k, image in enumerate(sample['pred_xstart']): |
|
|
|
|
|
if args.n_batches > 0: |
|
filename = f'{current_time}-{parse_prompt(prompt)[0]}.png' |
|
image = TF.to_pil_image(image.add(1).div(2).clamp(0, 1)) |
|
if j % args.display_rate == 0 or cur_t == -1: |
|
image.save(f'{outDirPath}/{filename}') |
|
if st_dynamic_image: |
|
st_dynamic_image.image(image, use_column_width=True) |
|
|
|
return image |
|
|
|
|
|
if __name__ == '__main__': |
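    # Example invocation (assuming this file is saved as generate.py):
    #   python generate.py --prompt "夕阳西下" --width 512 --height 512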
|
parser = argparse.ArgumentParser(description="setting") |
|
parser.add_argument('--prompt', type=str, required=True) |
|
parser.add_argument('--text_scale', type=int, default=5000) |
|
parser.add_argument('--model_path', type=str, default="IDEA-CCNL/Taiyi-Diffusion-532M-Nature") |
|
parser.add_argument('--width', type=int, default=512) |
|
parser.add_argument('--height', type=int, default=512) |
|
|
|
user_args = parser.parse_args() |
|
|
|
dd = Diffuser(user_args.model_path) |
|
dd.generate([user_args.prompt], |
|
clip_guidance_scale=user_args.text_scale, |
|
side_x=user_args.width, |
|
side_y=user_args.height, |
|
) |
|
|