import argparse
import math
import random
from vqgan_clip.grad import *
from vqgan_clip.helpers import *
from vqgan_clip.inits import *
from vqgan_clip.masking import *
from vqgan_clip.optimizers import *
from urllib.request import urlopen
from tqdm import tqdm
import sys
import os

from omegaconf import OmegaConf
from taming.models import cond_transformer, vqgan

import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from torch.cuda import get_device_properties
torch.backends.cudnn.benchmark = False
from torch_optimizer import DiffGrad, AdamP, RAdam

import clip
import kornia.augmentation as K
import numpy as np
import imageio

from PIL import ImageFile, Image, PngImagePlugin, ImageChops
ImageFile.LOAD_TRUNCATED_IMAGES = True

from subprocess import Popen, PIPE
import re
from packaging import version
# Suppress warnings
import warnings
warnings.filterwarnings('ignore')
# Check for GPU and reduce the default image size if low VRAM
default_image_size = 512  # >8GB VRAM
if not torch.cuda.is_available():
    default_image_size = 256  # no GPU found
elif get_device_properties(0).total_memory <= 2 ** 33:  # 2 ** 33 = 8,589,934,592 bytes = 8 GB
    default_image_size = 318  # <8GB VRAM
def parse():
    vq_parser = argparse.ArgumentParser(description='Image generation using VQGAN+CLIP')
| vq_parser.add_argument("-aug", "--augments", nargs='+', action='append', type=str, choices=['Hf','Ji','Sh','Pe','Ro','Af','Et','Ts','Er'], | |
| help="Enabled augments (latest vut method only)", default=[['Hf','Af', 'Pe', 'Ji', 'Er']], dest='augments') | |
| vq_parser.add_argument("-cd", "--cuda_device", type=str, help="Cuda device to use", default="cuda:0", dest='cuda_device') | |
| vq_parser.add_argument("-ckpt", "--vqgan_checkpoint", type=str, help="VQGAN checkpoint", default=f'checkpoints/vqgan_imagenet_f16_16384.ckpt', | |
| dest='vqgan_checkpoint') | |
| vq_parser.add_argument("-conf", "--vqgan_config", type=str, help="VQGAN config", default=f'checkpoints/vqgan_imagenet_f16_16384.yaml', dest='vqgan_config') | |
| vq_parser.add_argument("-cpe", "--change_prompt_every", type=int, help="Prompt change frequency", default=0, dest='prompt_frequency') | |
| vq_parser.add_argument("-cutm", "--cut_method", type=str, help="Cut method", choices=['original','latest'], | |
| default='latest', dest='cut_method') | |
| vq_parser.add_argument("-cutp", "--cut_power", type=float, help="Cut power", default=1., dest='cut_pow') | |
| vq_parser.add_argument("-cuts", "--num_cuts", type=int, help="Number of cuts", default=32, dest='cutn') | |
| vq_parser.add_argument("-d", "--deterministic", action='store_true', help="Enable cudnn.deterministic?", dest='cudnn_determinism') | |
| vq_parser.add_argument("-i", "--iterations", type=int, help="Number of iterations", default=500, dest='max_iterations') | |
| vq_parser.add_argument("-ifps", "--input_video_fps", type=float, | |
| help="When creating an interpolated video, use this as the input fps to interpolate from (>0 & <ofps)", default=15, | |
| dest='input_video_fps') | |
| vq_parser.add_argument("-ii", "--init_image", type=str, help="Initial image", default=None, dest='init_image') | |
| vq_parser.add_argument("-in", "--init_noise", type=str, help="Initial noise image (pixels or gradient)", default=None, dest='init_noise') | |
| vq_parser.add_argument("-ip", "--image_prompts", type=str, help="Image prompts / target image", default=[], dest='image_prompts') | |
| vq_parser.add_argument("-iw", "--init_weight", type=float, help="Initial weight", default=0., dest='init_weight') | |
| vq_parser.add_argument("-lr", "--learning_rate", type=float, help="Learning rate", default=0.1, dest='step_size') | |
| vq_parser.add_argument("-m", "--clip_model", type=str, help="CLIP model (e.g. ViT-B/32, ViT-B/16)", default='ViT-B/32', dest='clip_model') | |
| vq_parser.add_argument("-nps", "--noise_prompt_seeds", nargs="*", type=int, help="Noise prompt seeds", default=[], dest='noise_prompt_seeds') | |
| vq_parser.add_argument("-npw", "--noise_prompt_weights", nargs="*", type=float, help="Noise prompt weights", default=[], dest='noise_prompt_weights') | |
| vq_parser.add_argument("-o", "--output", type=str, help="Output filename", default="output.png", dest='output') | |
| vq_parser.add_argument("-ofps", "--output_video_fps", type=float, | |
| help="Create an interpolated video (Nvidia GPU only) with this fps (min 10. best set to 30 or 60)", default=0, dest='output_video_fps') | |
| vq_parser.add_argument("-opt", "--optimiser", type=str, help="Optimiser", choices=['Adam','AdamW','Adagrad','Adamax','DiffGrad','AdamP','RAdam','RMSprop'], | |
| default='Adam', dest='optimiser') | |
| vq_parser.add_argument("-p", "--prompts", type=str, help="Text prompts", default=None, dest='prompts') | |
| vq_parser.add_argument("-s", "--size", nargs=2, type=int, help="Image size (width height) (default: %(default)s)", | |
| default=[default_image_size, default_image_size], dest='size') | |
| vq_parser.add_argument("-sd", "--seed", type=int, help="Seed", default=None, dest='seed') | |
| vq_parser.add_argument("-se", "--save_every", type=int, help="Save image iterations", default=50, dest='display_freq') | |
| vq_parser.add_argument("-vid", "--video", action='store_true', help="Create video frames?", dest='make_video') | |
| vq_parser.add_argument("-vl", "--video_length", type=float, help="Video length in seconds (not interpolated)", default=10, dest='video_length') | |
| vq_parser.add_argument("-vsd", "--video_style_dir", type=str, help="Directory with video frames to style", default=None, dest='video_style_dir') | |
| vq_parser.add_argument("-zs", "--zoom_start", type=int, help="Zoom start iteration", default=0, dest='zoom_start') | |
| vq_parser.add_argument("-zsc", "--zoom_scale", type=float, help="Zoom scale %", default=0.99, dest='zoom_scale') | |
| vq_parser.add_argument("-zse", "--zoom_save_every", type=int, help="Save zoom image iterations", default=10, dest='zoom_frequency') | |
| vq_parser.add_argument("-zsx", "--zoom_shift_x", type=int, help="Zoom shift x (left/right) amount in pixels", default=0, dest='zoom_shift_x') | |
| vq_parser.add_argument("-zsy", "--zoom_shift_y", type=int, help="Zoom shift y (up/down) amount in pixels", default=0, dest='zoom_shift_y') | |
| vq_parser.add_argument("-zvid", "--zoom_video", action='store_true', help="Create zoom video?", dest='make_zoom_video') | |
    args = vq_parser.parse_args()

    if not args.prompts and not args.image_prompts:
        raise Exception("You must supply a text or image prompt")

    torch.backends.cudnn.deterministic = args.cudnn_determinism

    # Split text prompts using the pipe character (weights are split later).
    # all_phrases is declared global because the main loop re-reads it when
    # --change_prompt_every cycles through story phrases.
    global all_phrases
    all_phrases = []
    if args.prompts:
        # For stories, there will be many phrases
        story_phrases = [phrase.strip() for phrase in args.prompts.split("^")]

        # Make a list of all phrases
        for phrase in story_phrases:
            all_phrases.append(phrase.split("|"))

        # First phrase
        args.prompts = all_phrases[0]

    # Split target images using the pipe character (weights are split later)
    if args.image_prompts:
        args.image_prompts = args.image_prompts.split("|")
        args.image_prompts = [image.strip() for image in args.image_prompts]

    if args.make_video and args.make_zoom_video:
        print("Warning: Make video and make zoom video are mutually exclusive.")
        args.make_video = False

    # Make video steps directory
    if args.make_video or args.make_zoom_video:
        if not os.path.exists('steps'):
            os.mkdir('steps')

    return args
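# Example invocation (illustrative; assumes this script is saved as generate.py):
#   python generate.py -p "A painting of an apple in a fruit bowl" -i 500 -se 50 -o apple.png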
class Prompt(nn.Module):
    def __init__(self, embed, weight=1., stop=float('-inf')):
        super().__init__()
        self.register_buffer('embed', embed)
        self.register_buffer('weight', torch.as_tensor(weight))
        self.register_buffer('stop', torch.as_tensor(stop))
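    # forward() scores a batch of CLIP image embeddings against this prompt:
    # the squared spherical distance to the normalised prompt embedding,
    # sign-flipped for negative weights; replace_grad cuts the gradient once
    # the distance falls below `stop`.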
    def forward(self, input):
        input_normed = F.normalize(input.unsqueeze(1), dim=2)
        embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)
        dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
        dists = dists * self.weight.sign()
        return self.weight.abs() * replace_grad(dists, torch.maximum(dists, self.stop)).mean()
# NR: Split prompts and weights
def split_prompt(prompt):
    vals = prompt.rsplit(':', 2)
    vals = vals + ['', '1', '-inf'][len(vals):]
    return vals[0], float(vals[1]), float(vals[2])
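# Example: split_prompt("a watercolour lighthouse:0.8") -> ("a watercolour lighthouse", 0.8, -inf);
# a third field sets the stop value, i.e. "text:weight:stop".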
def load_vqgan_model(config_path, checkpoint_path):
    global gumbel
    gumbel = False
    config = OmegaConf.load(config_path)
    if config.model.target == 'taming.models.vqgan.VQModel':
        model = vqgan.VQModel(**config.model.params)
        model.eval().requires_grad_(False)
        model.init_from_ckpt(checkpoint_path)
    elif config.model.target == 'taming.models.vqgan.GumbelVQ':
        model = vqgan.GumbelVQ(**config.model.params)
        model.eval().requires_grad_(False)
        model.init_from_ckpt(checkpoint_path)
        gumbel = True
    elif config.model.target == 'taming.models.cond_transformer.Net2NetTransformer':
        parent_model = cond_transformer.Net2NetTransformer(**config.model.params)
        parent_model.eval().requires_grad_(False)
        parent_model.init_from_ckpt(checkpoint_path)
        model = parent_model.first_stage_model
    else:
        raise ValueError(f'unknown model type: {config.model.target}')
    del model.loss
    return model
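# Illustrative call, using the default checkpoint paths from parse():
#   model = load_vqgan_model('checkpoints/vqgan_imagenet_f16_16384.yaml',
#                            'checkpoints/vqgan_imagenet_f16_16384.ckpt')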
# Vector quantize
def synth(z):
    if gumbel:
        z_q = vector_quantize(z.movedim(1, 3), model.quantize.embed.weight).movedim(3, 1)
    else:
        z_q = vector_quantize(z.movedim(1, 3), model.quantize.embedding.weight).movedim(3, 1)
    return clamp_with_grad(model.decode(z_q).add(1).div(2), 0, 1)
def checkin(i, losses):
    losses_str = ', '.join(f'{loss.item():g}' for loss in losses)
    tqdm.write(f'i: {i}, loss: {sum(losses).item():g}, losses: {losses_str}')
    out = synth(z)
    info = PngImagePlugin.PngInfo()
    info.add_text('comment', f'{args.prompts}')
    TF.to_pil_image(out[0].cpu()).save(args.output, pnginfo=info)
def ascend_txt():
    global i
    out = synth(z)
    iii = perceptor.encode_image(normalize(make_cutouts(out))).float()

    result = []

    if args.init_weight:
        # result.append(F.mse_loss(z, z_orig) * args.init_weight / 2)
        result.append(F.mse_loss(z, torch.zeros_like(z_orig)) * ((1/torch.tensor(i*2 + 1))*args.init_weight) / 2)

    for prompt in pMs:
        result.append(prompt(iii))

    if args.make_video:
        img = np.array(out.mul(255).clamp(0, 255)[0].cpu().detach().numpy().astype(np.uint8))[:,:,:]
        img = np.transpose(img, (1, 2, 0))
        imageio.imwrite('./steps/' + str(i) + '.png', np.array(img))

    return result  # return loss
def train(i):
    opt.zero_grad(set_to_none=True)
    lossAll = ascend_txt()

    if i % args.display_freq == 0:
        checkin(i, lossAll)

    loss = sum(lossAll)
    loss.backward()
    opt.step()
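    # Clamp the latent back into the codebook's per-channel min/max range;
    # inference_mode keeps the in-place copy out of the autograd graph.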
    #with torch.no_grad():
    with torch.inference_mode():
        z.copy_(z.maximum(z_min).minimum(z_max))
if __name__ == '__main__':
    args = parse()

    # Do it
    device = torch.device(args.cuda_device)
    model = load_vqgan_model(args.vqgan_config, args.vqgan_checkpoint).to(device)
    jit = True if version.parse(torch.__version__) < version.parse('1.8.0') else False
    perceptor = clip.load(args.clip_model, jit=jit)[0].eval().requires_grad_(False).to(device)

    cut_size = perceptor.visual.input_resolution
    f = 2**(model.decoder.num_resolutions - 1)
    # Cutout class options:
    # 'latest','original','updated' or 'updatedpooling' (only 'latest' and 'original' are exposed via --cut_method)
    if args.cut_method == 'latest':
        make_cutouts = MakeCutouts(args, cut_size, args.cutn)
    elif args.cut_method == 'original':
        make_cutouts = MakeCutoutsOrig(args, cut_size, args.cutn)

    toksX, toksY = args.size[0] // f, args.size[1] // f
    sideX, sideY = toksX * f, toksY * f
    # Gumbel or not?
    if gumbel:
        e_dim = 256
        n_toks = model.quantize.n_embed
        z_min = model.quantize.embed.weight.min(dim=0).values[None, :, None, None]
        z_max = model.quantize.embed.weight.max(dim=0).values[None, :, None, None]
    else:
        e_dim = model.quantize.e_dim
        n_toks = model.quantize.n_e
        z_min = model.quantize.embedding.weight.min(dim=0).values[None, :, None, None]
        z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None]
    if args.init_image:
        if 'http' in args.init_image:
            img = Image.open(urlopen(args.init_image))
        else:
            img = Image.open(args.init_image)
        pil_image = img.convert('RGB')
        pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)
        pil_tensor = TF.to_tensor(pil_image)
        z, *_ = model.encode(pil_tensor.to(device).unsqueeze(0) * 2 - 1)
    elif args.init_noise == 'pixels':
        img = random_noise_image(args.size[0], args.size[1])
        pil_image = img.convert('RGB')
        pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)
        pil_tensor = TF.to_tensor(pil_image)
        z, *_ = model.encode(pil_tensor.to(device).unsqueeze(0) * 2 - 1)
    elif args.init_noise == 'gradient':
        img = random_gradient_image(args.size[0], args.size[1])
        pil_image = img.convert('RGB')
        pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)
        pil_tensor = TF.to_tensor(pil_image)
        z, *_ = model.encode(pil_tensor.to(device).unsqueeze(0) * 2 - 1)
    else:
        one_hot = F.one_hot(torch.randint(n_toks, [toksY * toksX], device=device), n_toks).float()
        # z = one_hot @ model.quantize.embedding.weight
        if gumbel:
            z = one_hot @ model.quantize.embed.weight
        else:
            z = one_hot @ model.quantize.embedding.weight

        z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
        #z = torch.rand_like(z)*2 # NR: check

    z_orig = z.clone()
    z.requires_grad_(True)
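    # z_orig is only referenced by the optional init_weight term in ascend_txt();
    # z itself is the tensor being optimised.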
    pMs = []
    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])
    # CLIP tokenize/encode
    if args.prompts:
        for prompt in args.prompts:
            txt, weight, stop = split_prompt(prompt)
            embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
            pMs.append(Prompt(embed, weight, stop).to(device))

    for prompt in args.image_prompts:
        path, weight, stop = split_prompt(prompt)
        img = Image.open(path)
        pil_image = img.convert('RGB')
        img = resize_image(pil_image, (sideX, sideY))
        batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
        embed = perceptor.encode_image(normalize(batch)).float()
        pMs.append(Prompt(embed, weight, stop).to(device))

    for seed, weight in zip(args.noise_prompt_seeds, args.noise_prompt_weights):
        gen = torch.Generator().manual_seed(seed)
        embed = torch.empty([1, perceptor.visual.output_dim]).normal_(generator=gen)
        pMs.append(Prompt(embed, weight).to(device))
    # Set the optimiser
    opt, z = get_opt(args.optimiser, z, args.step_size)

    # Output for the user
    print('Using device:', device)
    print('Optimising using:', args.optimiser)

    if args.prompts:
        print('Using text prompts:', args.prompts)
    if args.image_prompts:
        print('Using image prompts:', args.image_prompts)
    if args.init_image:
        print('Using initial image:', args.init_image)
    if args.noise_prompt_weights:
        print('Noise prompt weights:', args.noise_prompt_weights)

    if args.seed is None:
        seed = torch.seed()
    else:
        seed = args.seed
    torch.manual_seed(seed)
    print('Using seed:', seed)
    i = 0                 # Iteration counter
    j = 0                 # Zoom video frame counter
    p = 1                 # Phrase counter
    smoother = 0          # Smoother counter
    this_video_frame = 0  # for video styling
    with tqdm() as pbar:
        while i < args.max_iterations:
            # Change text prompt
            if args.prompt_frequency > 0:
                if i % args.prompt_frequency == 0 and i > 0:
                    # In case there aren't enough phrases, just loop
                    if p >= len(all_phrases):
                        p = 0

                    pMs = []
                    args.prompts = all_phrases[p]

                    # Show user we're changing prompt
                    print(args.prompts)

                    for prompt in args.prompts:
                        txt, weight, stop = split_prompt(prompt)
                        embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
                        pMs.append(Prompt(embed, weight, stop).to(device))

                    p += 1

            train(i)
            i += 1
            pbar.update()

    print("done")