Spaces:
Sleeping
Sleeping
#!/bin/env python | |
"""Train a Sketch-RNN.""" | |
import argparse | |
from enum import Enum | |
import os | |
import wget | |
import numpy as np | |
import torch as th | |
from torch.utils.data import DataLoader | |
import torchvision.datasets as dset | |
import torchvision.transforms as transforms | |
import ttools | |
import ttools.interfaces | |
from ttools.modules import networks | |
import pydiffvg | |
import rendering | |
import losses | |
import data | |
LOG = ttools.get_logger(__name__) | |
BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir) | |
OUTPUT = os.path.join(BASE_DIR, "results", "sketch_rnn_diffvg") | |
OUTPUT_BASELINE = os.path.join(BASE_DIR, "results", "sketch_rnn") | |
class SketchRNN(th.nn.Module): | |
class Encoder(th.nn.Module): | |
def __init__(self, hidden_size=512, dropout=0.9, zdim=128, | |
num_layers=1): | |
super(SketchRNN.Encoder, self).__init__() | |
self.hidden_size = hidden_size | |
self.num_layers = num_layers | |
self.zdim = zdim | |
self.lstm = th.nn.LSTM(5, hidden_size, num_layers=self.num_layers, | |
dropout=dropout, bidirectional=True, | |
batch_first=True) | |
# bidirectional model -> *2 | |
self.mu_predictor = th.nn.Linear(2*hidden_size, zdim) | |
self.sigma_predictor = th.nn.Linear(2*hidden_size, zdim) | |
def forward(self, sequences, hidden_and_cell=None): | |
bs = sequences.shape[0] | |
if hidden_and_cell is None: | |
hidden = th.zeros(self.num_layers*2, bs, self.hidden_size).to( | |
sequences.device) | |
cell = th.zeros(self.num_layers*2, bs, self.hidden_size).to( | |
sequences.device) | |
hidden_and_cell = (hidden, cell) | |
out, hidden_and_cell = self.lstm(sequences, hidden_and_cell) | |
hidden = hidden_and_cell[0] | |
# Concat the forward/backward states | |
fc_input = th.cat([hidden[0], hidden[1]], 1) | |
# VAE params | |
mu = self.mu_predictor(fc_input) | |
log_sigma = self.sigma_predictor(fc_input) | |
# Sample a latent vector | |
sigma = th.exp(log_sigma/2.0) | |
z0 = th.randn(self.zdim, device=mu.device) | |
z = mu + sigma*z0 | |
# KL divergence needs mu/sigma | |
return z, mu, log_sigma | |
class Decoder(th.nn.Module): | |
""" | |
The decoder outputs a sequence where each time step models (dx, dy) as | |
a mixture of `num_gaussians` 2D Gaussians and the state triplet is a | |
categorical distribution. | |
The model outputs at each time step: | |
- 5 parameters for each Gaussian: mu_x, mu_y, sigma_x, sigma_y, | |
rho_xy | |
- 1 logit for each Gaussian (the mixture weight) | |
- 3 logits for the state triplet probabilities | |
""" | |
def __init__(self, hidden_size=512, dropout=0.9, zdim=128, | |
num_layers=1, num_gaussians=20): | |
super(SketchRNN.Decoder, self).__init__() | |
self.hidden_size = hidden_size | |
self.num_layers = num_layers | |
self.zdim = zdim | |
self.num_gaussians = num_gaussians | |
# Maps the latent vector to an initial cell/hidden vector | |
self.hidden_cell_predictor = th.nn.Linear(zdim, 2*hidden_size) | |
self.lstm = th.nn.LSTM( | |
5 + zdim, hidden_size, | |
num_layers=self.num_layers, dropout=dropout, | |
batch_first=True) | |
self.parameters_predictor = th.nn.Linear( | |
hidden_size, num_gaussians + 5*num_gaussians + 3) | |
def forward(self, inputs, z, hidden_and_cell=None): | |
# Every step in the sequence takes the latent vector as input so we | |
# replicate it here | |
expanded_z = z.unsqueeze(1).repeat(1, inputs.shape[1], 1) | |
inputs = th.cat([inputs, expanded_z], 2) | |
bs, steps = inputs.shape[:2] | |
if hidden_and_cell is None: | |
# Initialize from latent vector | |
hidden_and_cell = self.hidden_cell_predictor(th.tanh(z)) | |
hidden = hidden_and_cell[:, :self.hidden_size] | |
hidden = hidden.unsqueeze(0).contiguous() | |
cell = hidden_and_cell[:, self.hidden_size:] | |
cell = cell.unsqueeze(0).contiguous() | |
hidden_and_cell = (hidden, cell) | |
outputs, hidden_and_cell = self.lstm(inputs, hidden_and_cell) | |
hidden, cell = hidden_and_cell | |
# if self.training: | |
# At train time we want parameters for each time step | |
outputs = outputs.reshape(bs*steps, self.hidden_size) | |
params = self.parameters_predictor(outputs).view(bs, steps, -1) | |
pen_logits = params[..., -3:] | |
gaussian_params = params[..., :-3] | |
mixture_logits = gaussian_params[..., :self.num_gaussians] | |
gaussian_params = gaussian_params[..., self.num_gaussians:].view( | |
bs, steps, self.num_gaussians, -1) | |
return pen_logits, mixture_logits, gaussian_params, hidden_and_cell | |
def __init__(self, zdim=128, num_gaussians=20, encoder_dim=256, | |
decoder_dim=512): | |
super(SketchRNN, self).__init__() | |
self.encoder = SketchRNN.Encoder(zdim=zdim, hidden_size=encoder_dim) | |
self.decoder = SketchRNN.Decoder(zdim=zdim, hidden_size=decoder_dim, | |
num_gaussians=num_gaussians) | |
def forward(self, sequences): | |
# Encode the sequences as latent vectors | |
# We skip the first time step since it is the same for all sequences: | |
# (0, 0, 1, 0, 0) | |
z, mu, log_sigma = self.encoder(sequences[:, 1:]) | |
# Decode the latent vector into a model sequence | |
# Do not process the last time step (it is an end-of-sequence token) | |
pen_logits, mixture_logits, gaussian_params, hidden_and_cell = \ | |
self.decoder(sequences[:, :-1], z) | |
return { | |
"pen_logits": pen_logits, | |
"mixture_logits": mixture_logits, | |
"gaussian_params": gaussian_params, | |
"z": z, | |
"mu": mu, | |
"log_sigma": log_sigma, | |
"hidden_and_cell": hidden_and_cell, | |
} | |
def sample(self, sequences, temperature=1.0): | |
# Compute a latent vector conditionned based on a real sequence | |
z, _, _ = self.encoder(sequences[:, 1:]) | |
start_of_seq = sequences[:, :1] | |
max_steps = sequences.shape[1] - 1 # last step is an end-of-seq token | |
output_sequences = th.zeros_like(sequences) | |
output_sequences[:, 0] = start_of_seq.squeeze(1) | |
current_input = start_of_seq | |
hidden_and_cell = None | |
for step in range(max_steps): | |
pen_logits, mixture_logits, gaussian_params, hidden_and_cell = \ | |
self.decoder(current_input, z, hidden_and_cell=hidden_and_cell) | |
# Pen and displacement state for the next step | |
next_state = th.zeros_like(current_input) | |
# Adjust temperature to control randomness | |
mixture_logits = mixture_logits*temperature | |
pen_logits = pen_logits*temperature | |
# Select one of 3 pen states | |
pen_distrib = \ | |
th.distributions.categorical.Categorical(logits=pen_logits) | |
pen_state = pen_distrib.sample() | |
# One-hot encoding of the state | |
next_state[:, :, 2:].scatter_(2, pen_state.unsqueeze(-1), | |
th.ones_like(next_state[:, :, 2:])) | |
# Select one of the Gaussians from the mixture | |
mixture_distrib = \ | |
th.distributions.categorical.Categorical(logits=mixture_logits) | |
mixture_idx = mixture_distrib.sample() | |
# select the Gaussian parameter | |
mixture_idx = mixture_idx.unsqueeze(-1).unsqueeze(-1) | |
mixture_idx = mixture_idx.repeat(1, 1, 1, 5) | |
params = th.gather(gaussian_params, 2, mixture_idx).squeeze(2) | |
# Sample a Gaussian from the corresponding Gaussian | |
mu = params[..., :2] | |
sigma_x = params[..., 2].exp() | |
sigma_y = params[..., 3].exp() | |
rho_xy = th.tanh(params[..., 4]) | |
cov = th.zeros(params.shape[0], params.shape[1], 2, 2, | |
device=params.device) | |
cov[..., 0, 0] = sigma_x.pow(2)*temperature | |
cov[..., 1, 1] = sigma_x.pow(2)*temperature | |
cov[..., 1, 0] = sigma_x*sigma_y*rho_xy*temperature | |
point_distrib = \ | |
th.distributions.multivariate_normal.MultivariateNormal( | |
mu, scale_tril=cov) | |
point = point_distrib.sample() | |
next_state[:, :, :2] = point | |
# Commit step to output | |
output_sequences[:, step + 1] = next_state.squeeze(1) | |
# Prepare next recurrent step | |
current_input = next_state | |
return output_sequences | |
class SketchRNNCallback(ttools.callbacks.ImageDisplayCallback): | |
"""Simple callback that visualize images.""" | |
def visualized_image(self, batch, step_data, is_val=False): | |
if not is_val: | |
# No need to render training data | |
return None | |
with th.no_grad(): | |
# only display the first n drawings | |
n = 8 | |
batch = batch[:n] | |
out_im = rendering.stroke2diffvg(step_data["sample"][:n]) | |
im = rendering.stroke2diffvg(batch) | |
im = th.cat([im, out_im], 2) | |
return im | |
def caption(self, batch, step_data, is_val=False): | |
if is_val: | |
return "top: truth, bottom: sample" | |
else: | |
return "top: truth, bottom: sample" | |
class Interface(ttools.ModelInterface): | |
def __init__(self, model, lr=1e-3, lr_decay=0.9999, | |
kl_weight=0.5, kl_min_weight=0.01, kl_decay=0.99995, | |
device="cpu", grad_clip=1.0, sampling_temperature=0.4): | |
super(Interface, self).__init__() | |
self.grad_clip = grad_clip | |
self.sampling_temperature = sampling_temperature | |
self.model = model | |
self.device = device | |
self.model.to(self.device) | |
self.enc_opt = th.optim.Adam(self.model.encoder.parameters(), lr=lr) | |
self.dec_opt = th.optim.Adam(self.model.decoder.parameters(), lr=lr) | |
self.kl_weight = kl_weight | |
self.kl_min_weight = kl_min_weight | |
self.kl_decay = kl_decay | |
self.kl_loss = losses.KLDivergence() | |
self.schedulers = [ | |
th.optim.lr_scheduler.ExponentialLR(self.enc_opt, lr_decay), | |
th.optim.lr_scheduler.ExponentialLR(self.dec_opt, lr_decay), | |
] | |
self.reconstruction_loss = losses.GaussianMixtureReconstructionLoss() | |
def optimizers(self): | |
return [self.enc_opt, self.dec_opt] | |
def training_step(self, batch): | |
batch = batch.to(self.device) | |
out = self.model(batch) | |
kl_loss = self.kl_loss( | |
out["mu"], out["log_sigma"]) | |
# The target to predict is the next sequence step | |
targets = batch[:, 1:].to(self.device) | |
# Scale the KL divergence weight | |
try: | |
state = self.enc_opt.state_dict()["param_groups"][0]["params"][0] | |
optim_step = self.enc_opt.state_dict()["state"][state]["step"] | |
except KeyError: | |
optim_step = 0 # no step taken yet | |
kl_scaling = 1.0 - (1.0 - | |
self.kl_min_weight)*(self.kl_decay**optim_step) | |
kl_weight = self.kl_weight * kl_scaling | |
reconstruction_loss = self.reconstruction_loss( | |
out["pen_logits"], out["mixture_logits"], | |
out["gaussian_params"], targets) | |
loss = kl_loss*self.kl_weight + reconstruction_loss | |
self.enc_opt.zero_grad() | |
self.dec_opt.zero_grad() | |
loss.backward() | |
# clip gradients | |
enc_nrm = th.nn.utils.clip_grad_norm_( | |
self.model.encoder.parameters(), self.grad_clip) | |
dec_nrm = th.nn.utils.clip_grad_norm_( | |
self.model.decoder.parameters(), self.grad_clip) | |
if enc_nrm > self.grad_clip: | |
LOG.debug("Clipped encoder gradient (%.5f) to %.2f", | |
enc_nrm, self.grad_clip) | |
if dec_nrm > self.grad_clip: | |
LOG.debug("Clipped decoder gradient (%.5f) to %.2f", | |
dec_nrm, self.grad_clip) | |
self.enc_opt.step() | |
self.dec_opt.step() | |
return { | |
"loss": loss.item(), | |
"kl_loss": kl_loss.item(), | |
"kl_weight": kl_weight, | |
"recons_loss": reconstruction_loss.item(), | |
"lr": self.enc_opt.param_groups[0]["lr"], | |
} | |
def init_validation(self): | |
return dict(sample=None) | |
def validation_step(self, batch, running_data): | |
# Switch to eval mode for dropout, batchnorm, etc | |
self.model.eval() | |
with th.no_grad(): | |
sample = self.model.sample( | |
batch.to(self.device), temperature=self.sampling_temperature) | |
running_data["sample"] = sample | |
self.model.train() | |
return running_data | |
def train(args): | |
th.manual_seed(0) | |
np.random.seed(0) | |
dataset = data.QuickDrawDataset(args.dataset) | |
dataloader = DataLoader( | |
dataset, batch_size=args.bs, num_workers=4, shuffle=True, | |
pin_memory=False) | |
val_dataset = [s for idx, s in enumerate(dataset) if idx < 8] | |
val_dataloader = DataLoader( | |
val_dataset, batch_size=8, num_workers=4, shuffle=False, | |
pin_memory=False) | |
model_params = { | |
"zdim": args.zdim, | |
"num_gaussians": args.num_gaussians, | |
"encoder_dim": args.encoder_dim, | |
"decoder_dim": args.decoder_dim, | |
} | |
model = SketchRNN(**model_params) | |
model.train() | |
device = "cpu" | |
if th.cuda.is_available(): | |
device = "cuda" | |
LOG.info("Using CUDA") | |
interface = Interface(model, lr=args.lr, lr_decay=args.lr_decay, | |
kl_decay=args.kl_decay, kl_weight=args.kl_weight, | |
sampling_temperature=args.sampling_temperature, | |
device=device) | |
chkpt = OUTPUT_BASELINE | |
env_name = "sketch_rnn" | |
# Resume from checkpoint, if any | |
checkpointer = ttools.Checkpointer( | |
chkpt, model, meta=model_params, | |
optimizers=interface.optimizers(), | |
schedulers=interface.schedulers) | |
extras, meta = checkpointer.load_latest() | |
epoch = extras["epoch"] if extras and "epoch" in extras.keys() else 0 | |
if meta is not None and meta != model_params: | |
LOG.info("Checkpoint's metaparams differ " | |
"from CLI, aborting: %s and %s", meta, model_params) | |
trainer = ttools.Trainer(interface) | |
# Add callbacks | |
losses = ["loss", "kl_loss", "recons_loss"] | |
training_debug = ["lr", "kl_weight"] | |
trainer.add_callback(ttools.callbacks.ProgressBarCallback( | |
keys=losses, val_keys=None)) | |
trainer.add_callback(ttools.callbacks.VisdomLoggingCallback( | |
keys=losses, val_keys=None, env=env_name, port=args.port)) | |
trainer.add_callback(ttools.callbacks.VisdomLoggingCallback( | |
keys=training_debug, smoothing=0, val_keys=None, env=env_name, | |
port=args.port)) | |
trainer.add_callback(ttools.callbacks.CheckpointingCallback( | |
checkpointer, max_files=2, interval=600, max_epochs=10)) | |
trainer.add_callback( | |
ttools.callbacks.LRSchedulerCallback(interface.schedulers)) | |
trainer.add_callback(SketchRNNCallback( | |
env=env_name, win="samples", port=args.port, frequency=args.freq)) | |
# Start training | |
trainer.train(dataloader, starting_epoch=epoch, | |
val_dataloader=val_dataloader, | |
num_epochs=args.num_epochs) | |
if __name__ == "__main__": | |
parser = argparse.ArgumentParser() | |
parser.add_argument("--dataset", default="cat.npz") | |
# Training params | |
parser.add_argument("--bs", type=int, default=100) | |
parser.add_argument("--num_epochs", type=int, default=10000) | |
parser.add_argument("--lr", type=float, default=1e-4) | |
parser.add_argument("--lr_decay", type=float, default=0.9999) | |
parser.add_argument("--kl_weight", type=float, default=0.5) | |
parser.add_argument("--kl_decay", type=float, default=0.99995) | |
# Model configuration | |
parser.add_argument("--zdim", type=int, default=128) | |
parser.add_argument("--num_gaussians", type=int, default=20) | |
parser.add_argument("--encoder_dim", type=int, default=256) | |
parser.add_argument("--decoder_dim", type=int, default=512) | |
parser.add_argument("--sampling_temperature", type=float, default=0.4, | |
help="controls sampling randomness. " | |
"0.0: deterministic, 1.0: unchanged") | |
# Viz params | |
parser.add_argument("--freq", type=int, default=100) | |
parser.add_argument("--port", type=int, default=5000) | |
args = parser.parse_args() | |
pydiffvg.set_use_gpu(th.cuda.is_available()) | |
train(args) | |