#!/usr/bin/env python
"""Train a Sketch-VAE."""
import argparse
from enum import Enum
import os
import wget
import time

import numpy as np
import torch as th
from torch.utils.data import DataLoader
import torchvision.datasets as dset
import torchvision.transforms as transforms

import ttools
import ttools.interfaces
from ttools.modules import networks

import rendering
import losses
import modules
import data

import pydiffvg


LOG = ttools.get_logger(__name__)

BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
OUTPUT = os.path.join(BASE_DIR, "results")


class SketchVAE(th.nn.Module):
    class ImageEncoder(th.nn.Module):
        def __init__(self, image_size=64, width=64, zdim=128):
            super(SketchVAE.ImageEncoder, self).__init__()
            self.zdim = zdim

            self.net = th.nn.Sequential(
                th.nn.Conv2d(4, width, 5, padding=2),
                th.nn.InstanceNorm2d(width),
                th.nn.ReLU(inplace=True),
                # 64x64
                th.nn.Conv2d(width, width, 5, padding=2),
                th.nn.InstanceNorm2d(width),
                th.nn.ReLU(inplace=True),
                # 64x64
                th.nn.Conv2d(width, 2*width, 5, stride=1, padding=2),
                th.nn.InstanceNorm2d(2*width),
                th.nn.ReLU(inplace=True),
                # 64x64
                th.nn.Conv2d(2*width, 2*width, 5, stride=2, padding=2),
                th.nn.InstanceNorm2d(2*width),
                th.nn.ReLU(inplace=True),
                # 32x32
                th.nn.Conv2d(2*width, 2*width, 5, stride=2, padding=2),
                th.nn.InstanceNorm2d(2*width),
                th.nn.ReLU(inplace=True),
                # 16x16
                th.nn.Conv2d(2*width, 2*width, 5, stride=2, padding=2),
                th.nn.InstanceNorm2d(2*width),
                th.nn.ReLU(inplace=True),
                # 8x8
                th.nn.Conv2d(2*width, 2*width, 5, stride=2, padding=2),
                th.nn.InstanceNorm2d(2*width),
                th.nn.ReLU(inplace=True),
                # 4x4
                modules.Flatten(),
                th.nn.Linear(4*4*2*width, 2*zdim)
            )

        def forward(self, images):
            features = self.net(images)

            # VAE params
            mu = features[:, :self.zdim]
            log_sigma = features[:, self.zdim:]

            # Sample a latent vector with the reparameterization trick,
            # drawing an independent noise sample per batch element
            sigma = th.exp(log_sigma/2.0)
            z0 = th.randn_like(mu)
            z = mu + sigma*z0

            # KL divergence needs mu/sigma
            return z, mu, log_sigma
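
        # A minimal shape sanity check (hypothetical usage, not part of the
        # training script): the encoder maps a 4-channel raster to a latent
        # sample plus the Gaussian parameters consumed by the KL term.
        #
        #   enc = SketchVAE.ImageEncoder(zdim=128)
        #   z, mu, log_sigma = enc(th.randn(8, 4, 64, 64))
        #   assert z.shape == mu.shape == log_sigma.shape == (8, 128)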

    class ImageDecoder(th.nn.Module):
        """Decodes a latent vector into a 4-channel (RGBA) raster image."""
        def __init__(self, zdim=128, image_size=64, width=64):
            super(SketchVAE.ImageDecoder, self).__init__()
            self.zdim = zdim
            self.width = width

            self.embedding = th.nn.Linear(zdim, 4*4*2*width)

            self.net = th.nn.Sequential(
                th.nn.ConvTranspose2d(2*width, 2*width, 4, padding=1, stride=2),
                th.nn.InstanceNorm2d(2*width),
                th.nn.ReLU(inplace=True),
                # 8x8
                th.nn.ConvTranspose2d(2*width, 2*width, 4, padding=1, stride=2),
                th.nn.InstanceNorm2d(2*width),
                th.nn.ReLU(inplace=True),
                # 16x16
                th.nn.ConvTranspose2d(2*width, 2*width, 4, padding=1, stride=2),
                th.nn.InstanceNorm2d(2*width),
                th.nn.ReLU(inplace=True),
                # 32x32
                th.nn.Conv2d(2*width, 2*width, 5, padding=2, stride=1),
                th.nn.InstanceNorm2d(2*width),
                th.nn.ReLU(inplace=True),
                # 32x32
                th.nn.ConvTranspose2d(2*width, 2*width, 4, padding=1, stride=2),
                th.nn.InstanceNorm2d(2*width),
                th.nn.ReLU(inplace=True),
                # 64x64
                th.nn.Conv2d(2*width, width, 5, padding=2, stride=1),
                th.nn.InstanceNorm2d(width),
                th.nn.ReLU(inplace=True),
                # 64x64
                th.nn.ConvTranspose2d(width, width, 5, padding=2, stride=1),
                th.nn.InstanceNorm2d(width),
                th.nn.ReLU(inplace=True),
                # 64x64
                th.nn.Conv2d(width, width, 5, padding=2, stride=1),
                th.nn.InstanceNorm2d(width),
                th.nn.ReLU(inplace=True),
                # 64x64
                th.nn.Conv2d(width, 4, 5, padding=2, stride=1),
            )

        def forward(self, z):
            bs = z.shape[0]
            im = self.embedding(z).view(bs, 2*self.width, 4, 4)
            out = self.net(im)
            return out
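
        # A minimal round-trip sketch (hypothetical usage): decoding a latent
        # vector yields an RGBA raster at the training resolution, so encoder
        # and decoder can be chained for reconstruction.
        #
        #   dec = SketchVAE.ImageDecoder(zdim=128)
        #   im = dec(th.randn(8, 128))
        #   assert im.shape == (8, 4, 64, 64)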

    class SketchDecoder(th.nn.Module):
        """Decodes a latent vector into a stroke sequence.

        The decoder outputs a sequence where each time step models
        (dx, dy, opacity).
        """
        def __init__(self, sequence_length, hidden_size=512, dropout=0.9,
                     zdim=128, num_layers=3):
            super(SketchVAE.SketchDecoder, self).__init__()
            self.sequence_length = sequence_length
            self.hidden_size = hidden_size
            self.num_layers = num_layers
            self.zdim = zdim

            # Maps the latent vector to an initial cell/hidden vector
            self.hidden_cell_predictor = th.nn.Linear(
                zdim, 2*hidden_size*num_layers)

            self.lstm = th.nn.LSTM(
                zdim, hidden_size,
                num_layers=self.num_layers, dropout=dropout,
                batch_first=True)

            # Relative pen displacements in [-1, 1]
            self.dxdy_predictor = th.nn.Sequential(
                th.nn.Linear(hidden_size, 2),
                th.nn.Tanh(),
            )

            # Stroke opacity in [0, 1]
            self.opacity_predictor = th.nn.Sequential(
                th.nn.Linear(hidden_size, 1),
                th.nn.Sigmoid(),
            )

        def forward(self, z, hidden_and_cell=None):
            # Every step in the sequence takes the latent vector as input, so
            # we replicate it here
            bs = z.shape[0]
            steps = self.sequence_length - 1  # no need to predict the start of sequence
            expanded_z = z.unsqueeze(1).repeat(1, steps, 1)

            if hidden_and_cell is None:
                # Initialize the recurrent state from the latent vector
                hidden_and_cell = self.hidden_cell_predictor(th.tanh(z))
                hidden = hidden_and_cell[:, :self.hidden_size*self.num_layers]
                hidden = hidden.view(-1, self.num_layers, self.hidden_size)
                hidden = hidden.permute(1, 0, 2).contiguous()
                cell = hidden_and_cell[:, self.hidden_size*self.num_layers:]
                cell = cell.view(-1, self.num_layers, self.hidden_size)
                cell = cell.permute(1, 0, 2).contiguous()
                hidden_and_cell = (hidden, cell)

            outputs, hidden_and_cell = self.lstm(expanded_z, hidden_and_cell)
            hidden, cell = hidden_and_cell

            dxdy = self.dxdy_predictor(
                outputs.reshape(bs*steps, self.hidden_size)).view(bs, steps, -1)
            opacity = self.opacity_predictor(
                outputs.reshape(bs*steps, self.hidden_size)).view(bs, steps, -1)

            strokes = th.cat([dxdy, opacity], -1)

            return strokes
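
        # A minimal sketch of the output layout (hypothetical usage): for a
        # sequence of length L the decoder emits L-1 steps, each holding
        # (dx, dy) in [-1, 1] and an opacity in [0, 1].
        #
        #   sketch_dec = SketchVAE.SketchDecoder(sequence_length=50, zdim=128)
        #   strokes = sketch_dec(th.randn(8, 128))
        #   assert strokes.shape == (8, 49, 3)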

    def __init__(self, sequence_length, zdim=128, image_size=64):
        super(SketchVAE, self).__init__()
        self.im_encoder = SketchVAE.ImageEncoder(
            zdim=zdim, image_size=image_size)
        self.im_decoder = SketchVAE.ImageDecoder(
            zdim=zdim, image_size=image_size)
        self.sketch_decoder = SketchVAE.SketchDecoder(
            sequence_length, zdim=zdim)

    def forward(self, images):
        # Encode the images as latent vectors
        z, mu, log_sigma = self.im_encoder(images)
        decoded_im = self.im_decoder(z)
        decoded_sketch = self.sketch_decoder(z)
        return {
            "decoded_im": decoded_im,
            "decoded_sketch": decoded_sketch,
            "z": z,
            "mu": mu,
            "log_sigma": log_sigma,
        }
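

# A minimal end-to-end sketch (hypothetical usage): one forward pass maps a
# batch of RGBA rasters to a reconstructed raster and a stroke sequence that
# share the same latent code.
#
#   model = SketchVAE(sequence_length=50, zdim=128, image_size=64)
#   out = model(th.randn(8, 4, 64, 64))
#   assert out["decoded_im"].shape == (8, 4, 64, 64)
#   assert out["decoded_sketch"].shape == (8, 49, 3)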


class SketchVAECallback(ttools.callbacks.ImageDisplayCallback):
    """Simple callback that visualizes images."""
    def visualized_image(self, batch, step_data, is_val=False):
        if is_val:
            return None

        # only display the first n drawings
        n = 8
        gt = step_data["gt_image"][:n].detach()
        vae_im = step_data["vae_image"][:n].detach()
        sketch_im = step_data["sketch_image"][:n].detach()

        # stack ground truth, VAE reconstruction and rendered sketch vertically
        viz = th.cat([gt, vae_im, sketch_im], 2)
        viz = th.clamp(viz, 0, 1)

        # composite the RGB channels over the alpha channel
        alpha = viz[:, 3:4]
        viz = viz[:, :3] * alpha

        return viz

    def caption(self, batch, step_data, is_val=False):
        if is_val:
            return ""
        else:
            return "top: ground truth, middle: vae sample, bottom: rnn output"


class Interface(ttools.ModelInterface):
    def __init__(self, model, lr=1e-4, lr_decay=0.9999,
                 kl_weight=0.5, kl_min_weight=0.01, kl_decay=0.99995,
                 raster_resolution=64, absolute_coords=False,
                 device="cpu", grad_clip=1.0):
        super(Interface, self).__init__()

        self.grad_clip = grad_clip
        self.raster_resolution = raster_resolution
        self.absolute_coords = absolute_coords

        self.model = model
        self.device = device
        self.model.to(self.device)

        self.im_enc_opt = th.optim.Adam(
            self.model.im_encoder.parameters(), lr=lr)
        self.im_dec_opt = th.optim.Adam(
            self.model.im_decoder.parameters(), lr=lr)
        self.sketch_dec_opt = th.optim.Adam(
            self.model.sketch_decoder.parameters(), lr=lr)

        self.kl_weight = kl_weight
        self.kl_min_weight = kl_min_weight
        self.kl_decay = kl_decay
        self.kl_loss = losses.KLDivergence()

        self.schedulers = [
            th.optim.lr_scheduler.ExponentialLR(self.im_enc_opt, lr_decay),
            th.optim.lr_scheduler.ExponentialLR(self.im_dec_opt, lr_decay),
            th.optim.lr_scheduler.ExponentialLR(self.sketch_dec_opt, lr_decay),
        ]

        # include loss on alpha
        self.im_loss = losses.MultiscaleMSELoss(channels=4).to(self.device)

    def optimizers(self):
        return [self.im_enc_opt, self.im_dec_opt, self.sketch_dec_opt]

    def kl_scaling(self):
        """Anneals the KL weight from `kl_min_weight` towards 1 over training."""
        try:
            state = self.im_enc_opt.state_dict()["param_groups"][0]["params"][0]
            optim_step = self.im_enc_opt.state_dict()["state"][state]["step"]
        except KeyError:
            optim_step = 0  # no optimizer step taken yet
        kl_scaling = 1.0 - (1.0 - self.kl_min_weight)*(self.kl_decay**optim_step)
        return kl_scaling
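
    # A worked example of the annealing schedule above (with the default
    # kl_min_weight=0.01 and kl_decay=0.99995): at step 0 the scaling is
    # 1 - 0.99 * 0.99995^0 = 0.01; after 10,000 steps it is roughly
    # 1 - 0.99 * 0.99995^10000 ~= 0.40; and it approaches 1.0 as training
    # continues, so the KL penalty is phased in gradually.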

    def training_step(self, batch):
        gt_strokes, gt_im = batch
        gt_strokes = gt_strokes.to(self.device)
        gt_im = gt_im.to(self.device)

        out = self.model(gt_im)

        kl_loss = self.kl_loss(out["mu"], out["log_sigma"])
        kl_weight = self.kl_weight * self.kl_scaling()

        # add start of sequence
        sos = gt_strokes[:, :1]
        sketch = th.cat([sos, out["decoded_sketch"]], 1)

        vae_im = out["decoded_im"]

        # start = time.time()
        sketch_im = rendering.opacityStroke2diffvg(
            sketch, canvas_size=self.raster_resolution, debug=False,
            force_cpu=True, relative=not self.absolute_coords)
        # elapsed = (time.time() - start)*1000
        # print("out rendering took %.2fms" % elapsed)

        vae_im_loss = self.im_loss(vae_im, gt_im)
        sketch_im_loss = self.im_loss(sketch_im, gt_im)
        # vae_im_loss = th.nn.functional.mse_loss(vae_im, gt_im)
        # sketch_im_loss = th.nn.functional.mse_loss(sketch_im, gt_im)

        loss = vae_im_loss + kl_loss*kl_weight + sketch_im_loss

        self.im_enc_opt.zero_grad()
        self.im_dec_opt.zero_grad()
        self.sketch_dec_opt.zero_grad()

        loss.backward()

        # clip gradients
        enc_nrm = th.nn.utils.clip_grad_norm_(
            self.model.im_encoder.parameters(), self.grad_clip)
        dec_nrm = th.nn.utils.clip_grad_norm_(
            self.model.im_decoder.parameters(), self.grad_clip)
        sketch_dec_nrm = th.nn.utils.clip_grad_norm_(
            self.model.sketch_decoder.parameters(), self.grad_clip)

        if enc_nrm > self.grad_clip:
            LOG.debug("Clipped encoder gradient (%.5f) to %.2f",
                      enc_nrm, self.grad_clip)
        if dec_nrm > self.grad_clip:
            LOG.debug("Clipped decoder gradient (%.5f) to %.2f",
                      dec_nrm, self.grad_clip)
        if sketch_dec_nrm > self.grad_clip:
            LOG.debug("Clipped sketch decoder gradient (%.5f) to %.2f",
                      sketch_dec_nrm, self.grad_clip)

        self.im_enc_opt.step()
        self.im_dec_opt.step()
        self.sketch_dec_opt.step()

        return {
            "vae_image": vae_im,
            "sketch_image": sketch_im,
            "gt_image": gt_im,
            "loss": loss.item(),
            "vae_im_loss": vae_im_loss.item(),
            "sketch_im_loss": sketch_im_loss.item(),
            "kl_loss": kl_loss.item(),
            "kl_weight": kl_weight,
            "lr": self.im_enc_opt.param_groups[0]["lr"],
        }

    def init_validation(self):
        return dict(sample=None)

    def validation_step(self, batch, running_data):
        # Switch to eval mode for dropout, batchnorm, etc.
        # self.model.eval()
        # with th.no_grad():
        #     sample = self.model.sample(
        #         batch.to(self.device), temperature=self.sampling_temperature)
        #     running_data["sample"] = sample
        # self.model.train()
        return running_data


def train(args):
    th.manual_seed(0)
    np.random.seed(0)

    dataset = data.FixedLengthQuickDrawDataset(
        args.dataset, max_seq_length=args.sequence_length,
        canvas_size=args.raster_resolution)
    dataloader = DataLoader(
        dataset, batch_size=args.bs, num_workers=args.workers, shuffle=True)

    # val_dataset = [s for idx, s in enumerate(dataset) if idx < 8]
    # val_dataloader = DataLoader(
    #     val_dataset, batch_size=8, num_workers=4, shuffle=False)
    val_dataloader = None

    model_params = {
        "zdim": args.zdim,
        "sequence_length": args.sequence_length,
        "image_size": args.raster_resolution,
        # "encoder_dim": args.encoder_dim,
        # "decoder_dim": args.decoder_dim,
    }
    model = SketchVAE(**model_params)
    model.train()

    LOG.info("Model parameters:\n%s", model_params)

    device = "cpu"
    if th.cuda.is_available():
        device = "cuda"
        LOG.info("Using CUDA")

    interface = Interface(model, raster_resolution=args.raster_resolution,
                          lr=args.lr, lr_decay=args.lr_decay,
                          kl_decay=args.kl_decay, kl_weight=args.kl_weight,
                          absolute_coords=args.absolute_coordinates,
                          device=device)

    env_name = "sketch_vae"
    if args.custom_name is not None:
        env_name += "_" + args.custom_name
    if args.absolute_coordinates:
        env_name += "_abs_coords"

    chkpt = os.path.join(OUTPUT, env_name)

    # Resume from checkpoint, if any
    checkpointer = ttools.Checkpointer(
        chkpt, model, meta=model_params,
        optimizers=interface.optimizers(),
        schedulers=interface.schedulers)
    extras, meta = checkpointer.load_latest()
    epoch = extras["epoch"] if extras and "epoch" in extras.keys() else 0
    if meta is not None and meta != model_params:
        LOG.info("Checkpoint's metaparameters differ "
                 "from the CLI arguments, aborting: %s and %s",
                 meta, model_params)
        return

    trainer = ttools.Trainer(interface)

    # Add callbacks (avoid shadowing the `losses` module)
    loss_keys = ["loss", "kl_loss", "vae_im_loss", "sketch_im_loss"]
    training_debug = ["lr", "kl_weight"]
    trainer.add_callback(ttools.callbacks.ProgressBarCallback(
        keys=loss_keys, val_keys=None))
    trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
        keys=loss_keys, val_keys=None, env=env_name, port=args.port))
    trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
        keys=training_debug, smoothing=0, val_keys=None, env=env_name,
        port=args.port))
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(
        checkpointer, max_files=2, interval=600, max_epochs=10))
    trainer.add_callback(
        ttools.callbacks.LRSchedulerCallback(interface.schedulers))
    trainer.add_callback(SketchVAECallback(
        env=env_name, win="samples", port=args.port, frequency=args.freq))

    # Start training
    trainer.train(dataloader, starting_epoch=epoch,
                  val_dataloader=val_dataloader,
                  num_epochs=args.num_epochs)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default="cat.npz")
    parser.add_argument("--absolute_coordinates", action="store_true",
                        default=False)
    parser.add_argument("--custom_name")

    # Training params
    parser.add_argument("--bs", type=int, default=1)
    parser.add_argument("--workers", type=int, default=0)
    parser.add_argument("--num_epochs", type=int, default=10000)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--lr_decay", type=float, default=0.9999)
    parser.add_argument("--kl_weight", type=float, default=0.5)
    parser.add_argument("--kl_decay", type=float, default=0.99995)

    # Model configuration
    parser.add_argument("--zdim", type=int, default=128)
    parser.add_argument("--sequence_length", type=int, default=50)
    parser.add_argument("--raster_resolution", type=int, default=64)
    # parser.add_argument("--encoder_dim", type=int, default=256)
    # parser.add_argument("--decoder_dim", type=int, default=512)

    # Viz params
    parser.add_argument("--freq", type=int, default=10)
    parser.add_argument("--port", type=int, default=5000)

    args = parser.parse_args()

    pydiffvg.set_use_gpu(False)

    train(args)