# vqgan_clip/main.py
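# Example invocation (a sketch; the flags are defined in parse() below and the
# checkpoint paths are the argparse defaults):
#   python main.py -p "a watercolor painting of a forest" -i 500 -se 50 -o forest.png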
import argparse
import math
import random
from vqgan_clip.grad import *
from vqgan_clip.helpers import *
from vqgan_clip.inits import *
from vqgan_clip.masking import *
from vqgan_clip.optimizers import *
from urllib.request import urlopen
from tqdm import tqdm
import sys
import os
from omegaconf import OmegaConf
from taming.models import cond_transformer, vqgan
import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from torch.cuda import get_device_properties
torch.backends.cudnn.benchmark = False
from torch_optimizer import DiffGrad, AdamP, RAdam
import clip
import kornia.augmentation as K
import numpy as np
import imageio
from PIL import ImageFile, Image, PngImagePlugin, ImageChops
ImageFile.LOAD_TRUNCATED_IMAGES = True
from subprocess import Popen, PIPE
import re
from packaging import version
# Suppress warnings
import warnings
warnings.filterwarnings('ignore')
# Check for GPU and reduce the default image size if low VRAM
default_image_size = 512 # >8GB VRAM
if not torch.cuda.is_available():
    default_image_size = 256  # no GPU found
elif get_device_properties(0).total_memory <= 2 ** 33:  # 2 ** 33 = 8,589,934,592 bytes = 8 GB
    default_image_size = 318  # <8GB VRAM
def parse():
    global all_phrases  # the prompt-cycling loop under __main__ reads this

    vq_parser = argparse.ArgumentParser(description='Image generation using VQGAN+CLIP')

    vq_parser.add_argument("-aug", "--augments", nargs='+', action='append', type=str,
                           choices=['Hf', 'Ji', 'Sh', 'Pe', 'Ro', 'Af', 'Et', 'Ts', 'Er'],
                           help="Enabled augments (latest cut method only)", default=[['Hf', 'Af', 'Pe', 'Ji', 'Er']], dest='augments')
    vq_parser.add_argument("-cd", "--cuda_device", type=str, help="CUDA device to use", default="cuda:0", dest='cuda_device')
    vq_parser.add_argument("-ckpt", "--vqgan_checkpoint", type=str, help="VQGAN checkpoint", default='checkpoints/vqgan_imagenet_f16_16384.ckpt',
                           dest='vqgan_checkpoint')
    vq_parser.add_argument("-conf", "--vqgan_config", type=str, help="VQGAN config", default='checkpoints/vqgan_imagenet_f16_16384.yaml', dest='vqgan_config')
    vq_parser.add_argument("-cpe", "--change_prompt_every", type=int, help="Prompt change frequency", default=0, dest='prompt_frequency')
    vq_parser.add_argument("-cutm", "--cut_method", type=str, help="Cut method", choices=['original', 'latest'],
                           default='latest', dest='cut_method')
    vq_parser.add_argument("-cutp", "--cut_power", type=float, help="Cut power", default=1., dest='cut_pow')
    vq_parser.add_argument("-cuts", "--num_cuts", type=int, help="Number of cuts", default=32, dest='cutn')
    vq_parser.add_argument("-d", "--deterministic", action='store_true', help="Enable cudnn.deterministic?", dest='cudnn_determinism')
    vq_parser.add_argument("-i", "--iterations", type=int, help="Number of iterations", default=500, dest='max_iterations')
    vq_parser.add_argument("-ifps", "--input_video_fps", type=float,
                           help="When creating an interpolated video, use this as the input fps to interpolate from (>0 & <ofps)", default=15,
                           dest='input_video_fps')
    vq_parser.add_argument("-ii", "--init_image", type=str, help="Initial image", default=None, dest='init_image')
    vq_parser.add_argument("-in", "--init_noise", type=str, help="Initial noise image (pixels or gradient)", default=None, dest='init_noise')
    vq_parser.add_argument("-ip", "--image_prompts", type=str, help="Image prompts / target image", default=[], dest='image_prompts')
    vq_parser.add_argument("-iw", "--init_weight", type=float, help="Initial weight", default=0., dest='init_weight')
    vq_parser.add_argument("-lr", "--learning_rate", type=float, help="Learning rate", default=0.1, dest='step_size')
    vq_parser.add_argument("-m", "--clip_model", type=str, help="CLIP model (e.g. ViT-B/32, ViT-B/16)", default='ViT-B/32', dest='clip_model')
    vq_parser.add_argument("-nps", "--noise_prompt_seeds", nargs="*", type=int, help="Noise prompt seeds", default=[], dest='noise_prompt_seeds')
    vq_parser.add_argument("-npw", "--noise_prompt_weights", nargs="*", type=float, help="Noise prompt weights", default=[], dest='noise_prompt_weights')
    vq_parser.add_argument("-o", "--output", type=str, help="Output filename", default="output.png", dest='output')
    vq_parser.add_argument("-ofps", "--output_video_fps", type=float,
                           help="Create an interpolated video (Nvidia GPU only) with this fps (min 10, best set to 30 or 60)", default=0, dest='output_video_fps')
    vq_parser.add_argument("-opt", "--optimiser", type=str, help="Optimiser",
                           choices=['Adam', 'AdamW', 'Adagrad', 'Adamax', 'DiffGrad', 'AdamP', 'RAdam', 'RMSprop'],
                           default='Adam', dest='optimiser')
    vq_parser.add_argument("-p", "--prompts", type=str, help="Text prompts", default=None, dest='prompts')
    vq_parser.add_argument("-s", "--size", nargs=2, type=int, help="Image size (width height) (default: %(default)s)",
                           default=[default_image_size, default_image_size], dest='size')
    vq_parser.add_argument("-sd", "--seed", type=int, help="Seed", default=None, dest='seed')
    vq_parser.add_argument("-se", "--save_every", type=int, help="Save image iterations", default=50, dest='display_freq')
    vq_parser.add_argument("-vid", "--video", action='store_true', help="Create video frames?", dest='make_video')
    vq_parser.add_argument("-vl", "--video_length", type=float, help="Video length in seconds (not interpolated)", default=10, dest='video_length')
    vq_parser.add_argument("-vsd", "--video_style_dir", type=str, help="Directory with video frames to style", default=None, dest='video_style_dir')
    vq_parser.add_argument("-zs", "--zoom_start", type=int, help="Zoom start iteration", default=0, dest='zoom_start')
    vq_parser.add_argument("-zsc", "--zoom_scale", type=float, help="Zoom scale %%", default=0.99, dest='zoom_scale')
    vq_parser.add_argument("-zse", "--zoom_save_every", type=int, help="Save zoom image iterations", default=10, dest='zoom_frequency')
    vq_parser.add_argument("-zsx", "--zoom_shift_x", type=int, help="Zoom shift x (left/right) amount in pixels", default=0, dest='zoom_shift_x')
    vq_parser.add_argument("-zsy", "--zoom_shift_y", type=int, help="Zoom shift y (up/down) amount in pixels", default=0, dest='zoom_shift_y')
    vq_parser.add_argument("-zvid", "--zoom_video", action='store_true', help="Create zoom video?", dest='make_zoom_video')

    args = vq_parser.parse_args()

    if not args.prompts and not args.image_prompts:
        raise Exception("You must supply a text or image prompt")

    torch.backends.cudnn.deterministic = args.cudnn_determinism

    # Split text prompts using the pipe character (weights are split later)
    all_phrases = []
    if args.prompts:
        # For stories, there will be many phrases separated by "^"
        story_phrases = [phrase.strip() for phrase in args.prompts.split("^")]

        # Make a list of all phrases
        for phrase in story_phrases:
            all_phrases.append(phrase.split("|"))

        # First phrase
        args.prompts = all_phrases[0]

    # Split target images using the pipe character (weights are split later)
    if args.image_prompts:
        args.image_prompts = args.image_prompts.split("|")
        args.image_prompts = [image.strip() for image in args.image_prompts]

    if args.make_video and args.make_zoom_video:
        print("Warning: Make video and make zoom video are mutually exclusive.")
        args.make_video = False

    # Make video steps directory
    if args.make_video or args.make_zoom_video:
        if not os.path.exists('steps'):
            os.mkdir('steps')

    return args
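
# Prompt wraps a target CLIP embedding. forward() computes the squared spherical
# distance between each normalised image-cutout embedding u and the target v,
#   dist = 2 * arcsin(||u - v|| / 2) ** 2,
# scales it by the prompt weight (a negative weight pushes the image away from the
# prompt), and uses replace_grad so gradients vanish once dist falls below `stop`.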
class Prompt(nn.Module):
    def __init__(self, embed, weight=1., stop=float('-inf')):
        super().__init__()
        self.register_buffer('embed', embed)
        self.register_buffer('weight', torch.as_tensor(weight))
        self.register_buffer('stop', torch.as_tensor(stop))

    def forward(self, input):
        input_normed = F.normalize(input.unsqueeze(1), dim=2)
        embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)
        dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
        dists = dists * self.weight.sign()
        return self.weight.abs() * replace_grad(dists, torch.maximum(dists, self.stop)).mean()
#NR: Split prompts and weights
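# A prompt string has the form "text:weight:stop"; e.g.
# split_prompt("a forest:1.5:-0.5") returns ("a forest", 1.5, -0.5).
# Missing fields default to weight 1 and stop -inf.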
def split_prompt(prompt):
    vals = prompt.rsplit(':', 2)
    vals = vals + ['', '1', '-inf'][len(vals):]
    return vals[0], float(vals[1]), float(vals[2])
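
# load_vqgan_model dispatches on config.model.target: a plain VQModel, a GumbelVQ
# (which sets the module-level `gumbel` flag), or the first-stage model of a
# Net2NetTransformer. The training loss head is deleted since only decoding is
# needed here.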
def load_vqgan_model(config_path, checkpoint_path):
    global gumbel
    gumbel = False
    config = OmegaConf.load(config_path)

    if config.model.target == 'taming.models.vqgan.VQModel':
        model = vqgan.VQModel(**config.model.params)
        model.eval().requires_grad_(False)
        model.init_from_ckpt(checkpoint_path)
    elif config.model.target == 'taming.models.vqgan.GumbelVQ':
        model = vqgan.GumbelVQ(**config.model.params)
        model.eval().requires_grad_(False)
        model.init_from_ckpt(checkpoint_path)
        gumbel = True
    elif config.model.target == 'taming.models.cond_transformer.Net2NetTransformer':
        parent_model = cond_transformer.Net2NetTransformer(**config.model.params)
        parent_model.eval().requires_grad_(False)
        parent_model.init_from_ckpt(checkpoint_path)
        model = parent_model.first_stage_model
    else:
        raise ValueError(f'unknown model type: {config.model.target}')

    del model.loss
    return model
# Vector quantize
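# synth() quantises the continuous latent z against the VQGAN codebook, decodes it,
# and maps the decoder output from [-1, 1] to [0, 1] via clamp_with_grad so
# gradients still flow through the clamp.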
def synth(z):
    if gumbel:
        z_q = vector_quantize(z.movedim(1, 3), model.quantize.embed.weight).movedim(3, 1)
    else:
        z_q = vector_quantize(z.movedim(1, 3), model.quantize.embedding.weight).movedim(3, 1)
    return clamp_with_grad(model.decode(z_q).add(1).div(2), 0, 1)
@torch.inference_mode()
def checkin(i, losses):
    losses_str = ', '.join(f'{loss.item():g}' for loss in losses)
    tqdm.write(f'i: {i}, loss: {sum(losses).item():g}, losses: {losses_str}')
    out = synth(z)
    info = PngImagePlugin.PngInfo()
    info.add_text('comment', f'{args.prompts}')
    TF.to_pil_image(out[0].cpu()).save(args.output, pnginfo=info)
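
# ascend_txt() assembles the loss terms for the current iteration: an optional MSE
# regulariser towards a zero latent, scaled by --init_weight and decayed as
# 1/(2i + 1), plus one Prompt loss per text/image/noise prompt in pMs. It also
# writes the current frame to ./steps/ when --video is set.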
def ascend_txt():
    global i
    out = synth(z)
    iii = perceptor.encode_image(normalize(make_cutouts(out))).float()

    result = []

    if args.init_weight:
        # result.append(F.mse_loss(z, z_orig) * args.init_weight / 2)
        result.append(F.mse_loss(z, torch.zeros_like(z_orig)) * ((1 / torch.tensor(i * 2 + 1)) * args.init_weight) / 2)

    for prompt in pMs:
        result.append(prompt(iii))

    if args.make_video:
        img = np.array(out.mul(255).clamp(0, 255)[0].cpu().detach().numpy().astype(np.uint8))[:, :, :]
        img = np.transpose(img, (1, 2, 0))
        imageio.imwrite('./steps/' + str(i) + '.png', np.array(img))

    return result  # return loss
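
# train() performs one optimisation step: zero grads, backprop the summed losses,
# step the optimiser, then clamp z back into the per-channel codebook range
# (z_min/z_max) under inference_mode.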
def train(i):
    opt.zero_grad(set_to_none=True)
    lossAll = ascend_txt()

    if i % args.display_freq == 0:
        checkin(i, lossAll)

    loss = sum(lossAll)
    loss.backward()
    opt.step()

    # with torch.no_grad():
    with torch.inference_mode():
        z.copy_(z.maximum(z_min).minimum(z_max))
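
# Entry point: parse arguments, load VQGAN and CLIP, build the prompt losses,
# then optimise the latent z for --iterations steps.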
if __name__ == '__main__':
    args = parse()

    # Do it
    device = torch.device(args.cuda_device)
    model = load_vqgan_model(args.vqgan_config, args.vqgan_checkpoint).to(device)
    jit = version.parse(torch.__version__) < version.parse('1.8.0')
    perceptor = clip.load(args.clip_model, jit=jit)[0].eval().requires_grad_(False).to(device)

    cut_size = perceptor.visual.input_resolution
    f = 2 ** (model.decoder.num_resolutions - 1)

    # Cutout method: 'latest' or 'original'
    if args.cut_method == 'latest':
        make_cutouts = MakeCutouts(args, cut_size, args.cutn)
    elif args.cut_method == 'original':
        make_cutouts = MakeCutoutsOrig(args, cut_size, args.cutn)

    toksX, toksY = args.size[0] // f, args.size[1] // f
    sideX, sideY = toksX * f, toksY * f

    # Gumbel or not?
    if gumbel:
        e_dim = 256
        n_toks = model.quantize.n_embed
        z_min = model.quantize.embed.weight.min(dim=0).values[None, :, None, None]
        z_max = model.quantize.embed.weight.max(dim=0).values[None, :, None, None]
    else:
        e_dim = model.quantize.e_dim
        n_toks = model.quantize.n_e
        z_min = model.quantize.embedding.weight.min(dim=0).values[None, :, None, None]
        z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None]
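
    # z_min/z_max are the per-channel extremes of the codebook embeddings; train()
    # clamps the optimised latent into this box after every step.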
    if args.init_image:
        if 'http' in args.init_image:
            img = Image.open(urlopen(args.init_image))
        else:
            img = Image.open(args.init_image)
        pil_image = img.convert('RGB')
        pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)
        pil_tensor = TF.to_tensor(pil_image)
        z, *_ = model.encode(pil_tensor.to(device).unsqueeze(0) * 2 - 1)
    elif args.init_noise == 'pixels':
        img = random_noise_image(args.size[0], args.size[1])
        pil_image = img.convert('RGB')
        pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)
        pil_tensor = TF.to_tensor(pil_image)
        z, *_ = model.encode(pil_tensor.to(device).unsqueeze(0) * 2 - 1)
    elif args.init_noise == 'gradient':
        img = random_gradient_image(args.size[0], args.size[1])
        pil_image = img.convert('RGB')
        pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)
        pil_tensor = TF.to_tensor(pil_image)
        z, *_ = model.encode(pil_tensor.to(device).unsqueeze(0) * 2 - 1)
    else:
        one_hot = F.one_hot(torch.randint(n_toks, [toksY * toksX], device=device), n_toks).float()
        # z = one_hot @ model.quantize.embedding.weight
        if gumbel:
            z = one_hot @ model.quantize.embed.weight
        else:
            z = one_hot @ model.quantize.embedding.weight

        z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
        # z = torch.rand_like(z)*2 # NR: check

    z_orig = z.clone()
    z.requires_grad_(True)
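
    # z is the only trainable tensor: a [1, e_dim, toksY, toksX] latent updated
    # directly by the optimiser. z_orig snapshots the starting latent; the active
    # init-weight regulariser in ascend_txt() only uses its shape (zeros_like).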
    pMs = []
    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])

    # CLIP tokenize/encode
    if args.prompts:
        for prompt in args.prompts:
            txt, weight, stop = split_prompt(prompt)
            embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
            pMs.append(Prompt(embed, weight, stop).to(device))

    for prompt in args.image_prompts:
        path, weight, stop = split_prompt(prompt)
        img = Image.open(path)
        pil_image = img.convert('RGB')
        img = resize_image(pil_image, (sideX, sideY))
        batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
        embed = perceptor.encode_image(normalize(batch)).float()
        pMs.append(Prompt(embed, weight, stop).to(device))

    for seed, weight in zip(args.noise_prompt_seeds, args.noise_prompt_weights):
        gen = torch.Generator().manual_seed(seed)
        embed = torch.empty([1, perceptor.visual.output_dim]).normal_(generator=gen)
        pMs.append(Prompt(embed, weight).to(device))

    # Set the optimiser
    opt, z = get_opt(args.optimiser, z, args.step_size)
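
    # get_opt (imported from vqgan_clip.optimizers) returns the optimiser instance
    # selected by --optimiser together with z; the construction details live in
    # that module.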
    # Output for the user
    print('Using device:', device)
    print('Optimising using:', args.optimiser)

    if args.prompts:
        print('Using text prompts:', args.prompts)
    if args.image_prompts:
        print('Using image prompts:', args.image_prompts)
    if args.init_image:
        print('Using initial image:', args.init_image)
    if args.noise_prompt_weights:
        print('Noise prompt weights:', args.noise_prompt_weights)

    if args.seed is None:
        seed = torch.seed()
    else:
        seed = args.seed
    torch.manual_seed(seed)
    print('Using seed:', seed)

    i = 0  # Iteration counter
    j = 0  # Zoom video frame counter
    p = 1  # Phrase counter
    smoother = 0  # Smoother counter
    this_video_frame = 0  # for video styling
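
    # Main optimisation loop: every --change_prompt_every iterations (if set),
    # rotate to the next "^"-separated story phrase and re-encode the text prompts,
    # then run one train() step per iteration.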
    with tqdm() as pbar:
        while i < args.max_iterations:
            # Change text prompt
            if args.prompt_frequency > 0:
                if i % args.prompt_frequency == 0 and i > 0:
                    # In case there aren't enough phrases, just loop
                    if p >= len(all_phrases):
                        p = 0

                    pMs = []
                    args.prompts = all_phrases[p]

                    # Show user we're changing prompt
                    print(args.prompts)

                    for prompt in args.prompts:
                        txt, weight, stop = split_prompt(prompt)
                        embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
                        pMs.append(Prompt(embed, weight, stop).to(device))

                    p += 1

            train(i)

            i += 1
            pbar.update()

    print("done")