#!/bin/env python """Train a Sketch-VAE.""" import argparse from enum import Enum import os import wget import time import numpy as np import torch as th from torch.utils.data import DataLoader import torchvision.datasets as dset import torchvision.transforms as transforms import ttools import ttools.interfaces from ttools.modules import networks import rendering import losses import modules import data import pydiffvg LOG = ttools.get_logger(__name__) BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir) OUTPUT = os.path.join(BASE_DIR, "results") class SketchVAE(th.nn.Module): class ImageEncoder(th.nn.Module): def __init__(self, image_size=64, width=64, zdim=128): super(SketchVAE.ImageEncoder, self).__init__() self.zdim = zdim self.net = th.nn.Sequential( th.nn.Conv2d(4, width, 5, padding=2), th.nn.InstanceNorm2d(width), th.nn.ReLU(inplace=True), # 64x64 th.nn.Conv2d(width, width, 5, padding=2), th.nn.InstanceNorm2d(width), th.nn.ReLU( inplace=True), # 64x64 th.nn.Conv2d(width, 2*width, 5, stride=1, padding=2), th.nn.InstanceNorm2d(2*width), th.nn.ReLU( inplace=True), # 32x32 th.nn.Conv2d(2*width, 2*width, 5, stride=2, padding=2), th.nn.InstanceNorm2d(2*width), th.nn.ReLU( inplace=True), # 16x16 th.nn.Conv2d(2*width, 2*width, 5, stride=2, padding=2), th.nn.InstanceNorm2d(2*width), th.nn.ReLU( inplace=True), # 16x16 th.nn.Conv2d(2*width, 2*width, 5, stride=2, padding=2), th.nn.InstanceNorm2d(2*width), th.nn.ReLU( inplace=True), # 8x8 th.nn.Conv2d(2*width, 2*width, 5, stride=2, padding=2), th.nn.InstanceNorm2d(2*width), th.nn.ReLU( inplace=True), # 4x4 modules.Flatten(), th.nn.Linear(4*4*2*width, 2*zdim) ) def forward(self, images): features = self.net(images) # VAE params mu = features[:, :self.zdim] log_sigma = features[:, self.zdim:] # Sample a latent vector sigma = th.exp(log_sigma/2.0) z0 = th.randn(self.zdim, device=mu.device) z = mu + sigma*z0 # KL divergence needs mu/sigma return z, mu, log_sigma class ImageDecoder(th.nn.Module): """""" def __init__(self, zdim=128, image_size=64, width=64): super(SketchVAE.ImageDecoder, self).__init__() self.zdim = zdim self.width = width self.embedding = th.nn.Linear(zdim, 4*4*2*width) self.net = th.nn.Sequential( th.nn.ConvTranspose2d(2*width, 2*width, 4, padding=1, stride=2), th.nn.InstanceNorm2d(2*width), th.nn.ReLU( inplace=True), # 8x8 th.nn.ConvTranspose2d(2*width, 2*width, 4, padding=1, stride=2), th.nn.InstanceNorm2d(2*width), th.nn.ReLU( inplace=True), # 16x16 th.nn.ConvTranspose2d(2*width, 2*width, 4, padding=1, stride=2), th.nn.InstanceNorm2d(2*width), th.nn.ReLU( inplace=True), # 16x16 th.nn.Conv2d(2*width, 2*width, 5, padding=2, stride=1), th.nn.InstanceNorm2d(2*width), th.nn.ReLU( inplace=True), # 16x16 th.nn.ConvTranspose2d(2*width, 2*width, 4, padding=1, stride=2), th.nn.InstanceNorm2d(2*width), th.nn.ReLU( inplace=True), # 32x32 th.nn.Conv2d(2*width, width, 5, padding=2, stride=1), th.nn.InstanceNorm2d(width), th.nn.ReLU( inplace=True), # 32x32 th.nn.ConvTranspose2d(width, width, 5, padding=2, stride=1), th.nn.InstanceNorm2d(width), th.nn.ReLU( inplace=True), # 64x64 th.nn.Conv2d(width, width, 5, padding=2, stride=1), th.nn.InstanceNorm2d(width), th.nn.ReLU( inplace=True), # 64x64 th.nn.Conv2d(width, 4, 5, padding=2, stride=1), ) def forward(self, z): bs = z.shape[0] im = self.embedding(z).view(bs, 2*self.width, 4, 4) out = self.net(im) return out class SketchDecoder(th.nn.Module): """ The decoder outputs a sequence where each time step models (dx, dy, opacity). """ def __init__(self, sequence_length, hidden_size=512, dropout=0.9, zdim=128, num_layers=3): super(SketchVAE.SketchDecoder, self).__init__() self.sequence_length = sequence_length self.hidden_size = hidden_size self.num_layers = num_layers self.zdim = zdim # Maps the latent vector to an initial cell/hidden vector self.hidden_cell_predictor = th.nn.Linear(zdim, 2*hidden_size*num_layers) self.lstm = th.nn.LSTM( zdim, hidden_size, num_layers=self.num_layers, dropout=dropout, batch_first=True) self.dxdy_predictor = th.nn.Sequential( th.nn.Linear(hidden_size, 2), th.nn.Tanh(), ) self.opacity_predictor = th.nn.Sequential( th.nn.Linear(hidden_size, 1), th.nn.Sigmoid(), ) def forward(self, z, hidden_and_cell=None): # Every step in the sequence takes the latent vector as input so we # replicate it here bs = z.shape[0] steps = self.sequence_length - 1 # no need to predict the start of sequence expanded_z = z.unsqueeze(1).repeat(1, steps, 1) if hidden_and_cell is None: # Initialize from latent vector hidden_and_cell = self.hidden_cell_predictor( th.tanh(z)) hidden = hidden_and_cell[:, :self.hidden_size*self.num_layers] hidden = hidden.view(-1, self.num_layers, self.hidden_size) hidden = hidden.permute(1, 0, 2).contiguous() # hidden = hidden.unsqueeze(1).contiguous() cell = hidden_and_cell[:, self.hidden_size*self.num_layers:] cell = cell.view(-1, self.num_layers, self.hidden_size) cell = cell.permute(1, 0, 2).contiguous() # cell = cell.unsqueeze(1).contiguous() hidden_and_cell = (hidden, cell) outputs, hidden_and_cell = self.lstm(expanded_z, hidden_and_cell) hidden, cell = hidden_and_cell dxdy = self.dxdy_predictor( outputs.reshape(bs*steps, self.hidden_size)).view(bs, steps, -1) opacity = self.opacity_predictor( outputs.reshape(bs*steps, self.hidden_size)).view(bs, steps, -1) strokes = th.cat([dxdy, opacity], -1) return strokes def __init__(self, sequence_length, zdim=128, image_size=64): super(SketchVAE, self).__init__() self.im_encoder = SketchVAE.ImageEncoder( zdim=zdim, image_size=image_size) self.im_decoder = SketchVAE.ImageDecoder( zdim=zdim, image_size=image_size) self.sketch_decoder = SketchVAE.SketchDecoder( sequence_length, zdim=zdim) def forward(self, images): # Encode the images as latent vectors z, mu, log_sigma = self.im_encoder(images) decoded_im = self.im_decoder(z) decoded_sketch = self.sketch_decoder(z) return { "decoded_im": decoded_im, "decoded_sketch": decoded_sketch, "z": z, "mu": mu, "log_sigma": log_sigma, } class SketchVAECallback(ttools.callbacks.ImageDisplayCallback): """Simple callback that visualize images.""" def visualized_image(self, batch, step_data, is_val=False): if is_val: return None # only display the first n drawings n = 8 gt = step_data["gt_image"][:n].detach() vae_im = step_data["vae_image"][:n].detach() sketch_im = step_data["sketch_image"][:n].detach() rendering = th.cat([gt, vae_im, sketch_im], 2) rendering = th.clamp(rendering, 0, 1) alpha = rendering[:, 3:4] rendering = rendering[:, :3] * alpha return rendering def caption(self, batch, step_data, is_val=False): if is_val: return "" else: return "top: truth, middle: vae sample, output: rnn-output" class Interface(ttools.ModelInterface): def __init__(self, model, lr=1e-4, lr_decay=0.9999, kl_weight=0.5, kl_min_weight=0.01, kl_decay=0.99995, raster_resolution=64, absolute_coords=False, device="cpu", grad_clip=1.0): super(Interface, self).__init__() self.grad_clip = grad_clip self.raster_resolution = raster_resolution self.absolute_coords = absolute_coords self.model = model self.device = device self.model.to(self.device) self.im_enc_opt = th.optim.Adam( self.model.im_encoder.parameters(), lr=lr) self.im_dec_opt = th.optim.Adam( self.model.im_decoder.parameters(), lr=lr) self.sketch_dec_opt = th.optim.Adam( self.model.sketch_decoder.parameters(), lr=lr) self.kl_weight = kl_weight self.kl_min_weight = kl_min_weight self.kl_decay = kl_decay self.kl_loss = losses.KLDivergence() self.schedulers = [ th.optim.lr_scheduler.ExponentialLR(self.im_enc_opt, lr_decay), th.optim.lr_scheduler.ExponentialLR(self.im_dec_opt, lr_decay), th.optim.lr_scheduler.ExponentialLR(self.sketch_dec_opt, lr_decay), ] # include loss on alpha self.im_loss = losses.MultiscaleMSELoss(channels=4).to(self.device) def optimizers(self): return [self.im_enc_opt, self.im_dec_opt, self.sketch_dec_opt] def kl_scaling(self): # Scale the KL divergence weight try: state = self.im_enc_opt.state_dict()["param_groups"][0]["params"][0] optim_step = self.im_enc_opt.state_dict()["state"][state]["step"] except KeyError: optim_step = 0 # no step taken yet kl_scaling = 1.0 - (1.0 - self.kl_min_weight)*(self.kl_decay**optim_step) return kl_scaling def training_step(self, batch): gt_strokes, gt_im = batch gt_strokes = gt_strokes.to(self.device) gt_im = gt_im.to(self.device) out = self.model(gt_im) kl_loss = self.kl_loss( out["mu"], out["log_sigma"]) kl_weight = self.kl_weight * self.kl_scaling() # add start of sequence sos = gt_strokes[:, :1] sketch = th.cat([sos, out["decoded_sketch"]], 1) vae_im = out["decoded_im"] # start = time.time() sketch_im = rendering.opacityStroke2diffvg( sketch, canvas_size=self.raster_resolution, debug=False, force_cpu=True, relative=not self.absolute_coords) # elapsed = (time.time() - start)*1000 # print("out rendering took %.2fms" % elapsed) vae_im_loss = self.im_loss(vae_im, gt_im) sketch_im_loss = self.im_loss(sketch_im, gt_im) # vae_im_loss = th.nn.functional.mse_loss(vae_im, gt_im) # sketch_im_loss = th.nn.functional.mse_loss(sketch_im, gt_im) loss = vae_im_loss + kl_loss*kl_weight + sketch_im_loss self.im_enc_opt.zero_grad() self.im_dec_opt.zero_grad() self.sketch_dec_opt.zero_grad() loss.backward() # clip gradients enc_nrm = th.nn.utils.clip_grad_norm_( self.model.im_encoder.parameters(), self.grad_clip) dec_nrm = th.nn.utils.clip_grad_norm_( self.model.im_decoder.parameters(), self.grad_clip) sketch_dec_nrm = th.nn.utils.clip_grad_norm_( self.model.sketch_decoder.parameters(), self.grad_clip) if enc_nrm > self.grad_clip: LOG.debug("Clipped encoder gradient (%.5f) to %.2f", enc_nrm, self.grad_clip) if dec_nrm > self.grad_clip: LOG.debug("Clipped decoder gradient (%.5f) to %.2f", dec_nrm, self.grad_clip) if sketch_dec_nrm > self.grad_clip: LOG.debug("Clipped sketch decoder gradient (%.5f) to %.2f", sketch_dec_nrm, self.grad_clip) self.im_enc_opt.step() self.im_dec_opt.step() self.sketch_dec_opt.step() return { "vae_image": vae_im, "sketch_image": sketch_im, "gt_image": gt_im, "loss": loss.item(), "vae_im_loss": vae_im_loss.item(), "sketch_im_loss": sketch_im_loss.item(), "kl_loss": kl_loss.item(), "kl_weight": kl_weight, "lr": self.im_enc_opt.param_groups[0]["lr"], } def init_validation(self): return dict(sample=None) def validation_step(self, batch, running_data): # Switch to eval mode for dropout, batchnorm, etc # self.model.eval() # with th.no_grad(): # # sample = self.model.sample( # # batch.to(self.device), temperature=self.sampling_temperature) # # running_data["sample"] = sample # self.model.train() return running_data def train(args): th.manual_seed(0) np.random.seed(0) dataset = data.FixedLengthQuickDrawDataset( args.dataset, max_seq_length=args.sequence_length, canvas_size=args.raster_resolution) dataloader = DataLoader( dataset, batch_size=args.bs, num_workers=args.workers, shuffle=True) # val_dataset = [s for idx, s in enumerate(dataset) if idx < 8] # val_dataloader = DataLoader( # val_dataset, batch_size=8, num_workers=4, shuffle=False) val_dataloader = None model_params = { "zdim": args.zdim, "sequence_length": args.sequence_length, "image_size": args.raster_resolution, # "encoder_dim": args.encoder_dim, # "decoder_dim": args.decoder_dim, } model = SketchVAE(**model_params) model.train() LOG.info("Model parameters:\n%s", model_params) device = "cpu" if th.cuda.is_available(): device = "cuda" LOG.info("Using CUDA") interface = Interface(model, raster_resolution=args.raster_resolution, lr=args.lr, lr_decay=args.lr_decay, kl_decay=args.kl_decay, kl_weight=args.kl_weight, absolute_coords=args.absolute_coordinates, device=device) env_name = "sketch_vae" if args.custom_name is not None: env_name += "_" + args.custom_name if args.absolute_coordinates: env_name += "_abs_coords" chkpt = os.path.join(OUTPUT, env_name) # Resume from checkpoint, if any checkpointer = ttools.Checkpointer( chkpt, model, meta=model_params, optimizers=interface.optimizers(), schedulers=interface.schedulers) extras, meta = checkpointer.load_latest() epoch = extras["epoch"] if extras and "epoch" in extras.keys() else 0 if meta is not None and meta != model_params: LOG.info("Checkpoint's metaparams differ " "from CLI, aborting: %s and %s", meta, model_params) trainer = ttools.Trainer(interface) # Add callbacks losses = ["loss", "kl_loss", "vae_im_loss", "sketch_im_loss"] training_debug = ["lr", "kl_weight"] trainer.add_callback(ttools.callbacks.ProgressBarCallback( keys=losses, val_keys=None)) trainer.add_callback(ttools.callbacks.VisdomLoggingCallback( keys=losses, val_keys=None, env=env_name, port=args.port)) trainer.add_callback(ttools.callbacks.VisdomLoggingCallback( keys=training_debug, smoothing=0, val_keys=None, env=env_name, port=args.port)) trainer.add_callback(ttools.callbacks.CheckpointingCallback( checkpointer, max_files=2, interval=600, max_epochs=10)) trainer.add_callback( ttools.callbacks.LRSchedulerCallback(interface.schedulers)) trainer.add_callback(SketchVAECallback( env=env_name, win="samples", port=args.port, frequency=args.freq)) # Start training trainer.train(dataloader, starting_epoch=epoch, val_dataloader=val_dataloader, num_epochs=args.num_epochs) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--dataset", default="cat.npz") parser.add_argument("--absolute_coordinates", action="store_true", default=False) parser.add_argument("--custom_name") # Training params parser.add_argument("--bs", type=int, default=1) parser.add_argument("--workers", type=int, default=0) parser.add_argument("--num_epochs", type=int, default=10000) parser.add_argument("--lr", type=float, default=1e-4) parser.add_argument("--lr_decay", type=float, default=0.9999) parser.add_argument("--kl_weight", type=float, default=0.5) parser.add_argument("--kl_decay", type=float, default=0.99995) # Model configuration parser.add_argument("--zdim", type=int, default=128) parser.add_argument("--sequence_length", type=int, default=50) parser.add_argument("--raster_resolution", type=int, default=64) # parser.add_argument("--encoder_dim", type=int, default=256) # parser.add_argument("--decoder_dim", type=int, default=512) # Viz params parser.add_argument("--freq", type=int, default=10) parser.add_argument("--port", type=int, default=5000) args = parser.parse_args() pydiffvg.set_use_gpu(False) train(args)