M3000j's picture
Upload folder using huggingface_hub
31726e5 verified
#!/bin/env python
"""Train a GAN.
Usage:
* Train an MNIST model:
`python train_gan.py`
* Train a Quickdraw model:
`python train_gan.py --task quickdraw`
"""
import argparse
import os
import numpy as np
import torch as th
from torch.utils.data import DataLoader
import ttools
import ttools.interfaces
import losses
import data
import models
import pydiffvg
# Module-level logger obtained from ttools.
LOG = ttools.get_logger(__name__)

# Parent of the directory that contains this script. The path is joined with
# os.pardir and not normalized.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
BASE_DIR = os.path.join(_SCRIPT_DIR, os.pardir)

# Root folder under which train() places its per-configuration checkpoint
# directories.
OUTPUT = os.path.join(BASE_DIR, "results")
class Callback(ttools.callbacks.ImageDisplayCallback):
    """Display callback that stacks real and generated samples for viewing."""

    def visualized_image(self, batch, step_data, is_val=False):
        """Build the image tensor to display for the current training step.

        Args:
            batch: current data batch (unused).
            step_data: dict returned by the training step; the keys
                "gen_image", "gt_image" and "vector_image" are read.
            is_val: True when called during validation.

        Returns:
            None during validation. Otherwise the first 16 real images, the
            first 16 raster-generated images and, when available, the first 16
            vector-generated images, concatenated along dimension 2 and mapped
            to [0, 1].
        """
        if is_val:
            return None

        raster = step_data["gen_image"][:16].detach()
        real = step_data["gt_image"][:16].detach()
        rows = [real, raster]

        vector = step_data["vector_image"]
        if vector is not None:
            rows.append(vector[:16].detach())

        # One concatenation along dim 2 puts the rows on top of each other
        # (the captions refer to them as top / middle / bottom).
        stacked = th.cat(rows, 2)

        # Affine remap followed by a clamp; presumably the networks work with
        # values in [-1, 1] -- TODO confirm against the data pipeline.
        return th.clamp((stacked + 1.0) * 0.5, 0, 1)

    def caption(self, batch, step_data, is_val=False):
        """Return the legend matching the row layout of `visualized_image`."""
        has_vector = step_data["vector_image"] is not None
        return ("top: real, middle: raster, bottom: vector" if has_vector
                else "top: real, bottom: fake")
class Interface(ttools.ModelInterface):
    """Adversarial training logic for a raster GAN and an optional vector GAN.

    Both generators receive the same latent sample on every step and each one
    is trained against its own discriminator. Updates alternate: `n_critic`
    discriminator steps are followed by one generator step.

    Args:
        generator: raster generator; must provide
            `sample_z(batch_size, device=...)`.
        vect_generator: vector generator, or None to train the raster GAN
            only.
        discriminator: discriminator paired with `generator`.
        vect_discriminator: discriminator paired with `vect_generator`;
            required whenever `vect_generator` is given.
        lr: learning rate shared by all Adam optimizers.
        lr_decay: gamma of the exponential learning-rate schedulers.
        gradient_penalty: weight of the gradient penalty (WGAN-GP mode only).
        wgan_gp: if True, use the Wasserstein loss with gradient penalty and
            5 critic steps per generator step; otherwise use the binary
            cross-entropy GAN loss with 1 discriminator step per generator
            step.
        raster_resolution: stored on the instance; not used by this class.
        device: device the networks are moved to.
        grad_clip: maximum gradient norm enforced before each optimizer step.
    """

    def __init__(self, generator, vect_generator,
                 discriminator, vect_discriminator,
                 lr=1e-4, lr_decay=0.9999,
                 gradient_penalty=10,
                 wgan_gp=False,
                 raster_resolution=32, device="cpu", grad_clip=1.0):
        super(Interface, self).__init__()

        self.wgan_gp = wgan_gp
        self.w_gradient_penalty = gradient_penalty

        # Number of discriminator updates per generator update.
        self.n_critic = 1
        if self.wgan_gp:
            self.n_critic = 5

        self.grad_clip = grad_clip
        self.raster_resolution = raster_resolution

        self.gen = generator
        self.vect_gen = vect_generator
        self.discrim = discriminator
        self.vect_discrim = vect_discriminator

        self.device = device
        self.gen.to(self.device)
        self.discrim.to(self.device)

        beta1 = 0.5
        beta2 = 0.9

        self.gen_opt = th.optim.Adam(
            self.gen.parameters(), lr=lr, betas=(beta1, beta2))
        self.discrim_opt = th.optim.Adam(
            self.discrim.parameters(), lr=lr, betas=(beta1, beta2))

        self.schedulers = [
            th.optim.lr_scheduler.ExponentialLR(self.gen_opt, lr_decay),
            th.optim.lr_scheduler.ExponentialLR(self.discrim_opt, lr_decay),
        ]
        self.optimizers = [self.gen_opt, self.discrim_opt]

        if self.vect_gen is not None:
            assert self.vect_discrim is not None

            self.vect_gen.to(self.device)
            self.vect_discrim.to(self.device)

            self.vect_gen_opt = th.optim.Adam(
                self.vect_gen.parameters(), lr=lr, betas=(beta1, beta2))
            self.vect_discrim_opt = th.optim.Adam(
                self.vect_discrim.parameters(), lr=lr, betas=(beta1, beta2))

            self.schedulers += [
                th.optim.lr_scheduler.ExponentialLR(self.vect_gen_opt,
                                                    lr_decay),
                th.optim.lr_scheduler.ExponentialLR(self.vect_discrim_opt,
                                                    lr_decay),
            ]
            self.optimizers += [self.vect_gen_opt, self.vect_discrim_opt]

        # include loss on alpha
        self.im_loss = losses.MultiscaleMSELoss(channels=4).to(self.device)

        # Counts the discriminator steps taken since the last generator step.
        self.iter = 0

        self.cross_entropy = th.nn.BCEWithLogitsLoss()
        self.mse = th.nn.MSELoss()

    def _networks(self):
        """Return the networks that are present (vector ones may be None)."""
        candidates = (self.gen, self.discrim, self.vect_gen, self.vect_discrim)
        return [net for net in candidates if net is not None]

    def _gradient_penalty(self, discrim, fake, real):
        """Compute the weighted zero-centered gradient penalty.

        The discriminator is evaluated at random per-sample interpolations
        between `real` and `fake`, and the squared norm of its input gradient
        is penalized.

        Args:
            discrim: discriminator to regularize.
            fake: generated samples; must have the same shape as `real`.
            real: real samples; the interpolation weights are created with
                shape (batch, 1, 1, 1), so a 4-D tensor is expected.

        Returns:
            Scalar tensor, already multiplied by `self.w_gradient_penalty`,
            that can be backpropagated through.
        """
        bs = real.size(0)
        epsilon = th.rand(bs, 1, 1, 1, device=real.device)
        epsilon = epsilon.expand_as(real)

        # The interpolation is a fresh leaf tensor: no gradient flows back to
        # the inputs, only to the discriminator parameters.
        interpolation = epsilon * real.detach() + (1 - epsilon) * fake.detach()
        interpolation.requires_grad_(True)

        interpolation_logits = discrim(interpolation)
        grad_outputs = th.ones_like(interpolation_logits)

        # create_graph=True so the penalty itself is differentiable.
        gradients = th.autograd.grad(outputs=interpolation_logits,
                                     inputs=interpolation,
                                     grad_outputs=grad_outputs,
                                     create_graph=True, retain_graph=True)[0]

        gradients = gradients.reshape(bs, -1)
        # The epsilon keeps the square root differentiable at zero.
        gradients_norm = th.sqrt(th.sum(gradients ** 2, dim=1) + 1e-12)

        # Zero-centered penalty, see [Thanh-Tung et al. 2019]
        # https://openreview.net/pdf?id=ByxPYjC5KQ
        # (the one-centered variant would use (gradients_norm - 1) ** 2).
        return self.w_gradient_penalty * ((gradients_norm - 0) ** 2).mean()

    def _discriminator_step(self, discrim, opt, fake, real):
        """Try to classify fake as 0 and real as 1.

        Args:
            discrim: discriminator to update.
            opt: optimizer holding the parameters of `discrim`.
            fake: generated samples; detached here so the generator receives
                no gradient.
            real: real samples.

        Returns:
            Tuple `(loss, gradient_penalty)` of Python floats; the penalty is
            None when not in WGAN-GP mode.
        """
        opt.zero_grad()

        # no backprop to gen
        fake = fake.detach()

        fake_pred = discrim(fake)
        real_pred = discrim(real)

        if self.wgan_gp:
            gradient_penalty = self._gradient_penalty(discrim, fake, real)
            loss_d = fake_pred.mean() - real_pred.mean() + gradient_penalty
            gradient_penalty = gradient_penalty.item()
        else:
            fake_loss = self.cross_entropy(fake_pred, th.zeros_like(fake_pred))
            real_loss = self.cross_entropy(real_pred, th.ones_like(real_pred))
            loss_d = 0.5*(fake_loss + real_loss)
            gradient_penalty = None

        loss_d.backward()
        nrm = th.nn.utils.clip_grad_norm_(
            discrim.parameters(), self.grad_clip)
        if nrm > self.grad_clip:
            LOG.debug("Clipped discriminator gradient (%.5f) to %.2f",
                      nrm, self.grad_clip)

        opt.step()

        return loss_d.item(), gradient_penalty

    def _generator_step(self, gen, discrim, opt, fake):
        """Try to classify fake as 1.

        Args:
            gen: generator that produced `fake`; its gradients are clipped.
            discrim: discriminator used to score `fake`.
            opt: optimizer holding the parameters of `gen`.
            fake: generated samples, still attached to the graph of `gen`.

        Returns:
            The generator loss as a Python float.
        """
        opt.zero_grad()

        fake_pred = discrim(fake)

        if self.wgan_gp:
            loss_g = -fake_pred.mean()
        else:
            loss_g = self.cross_entropy(fake_pred, th.ones_like(fake_pred))

        loss_g.backward()

        # clip gradients
        nrm = th.nn.utils.clip_grad_norm_(
            gen.parameters(), self.grad_clip)
        if nrm > self.grad_clip:
            LOG.debug("Clipped generator gradient (%.5f) to %.2f",
                      nrm, self.grad_clip)

        opt.step()

        return loss_g.item()

    def training_step(self, batch):
        """Run one discriminator update or one generator update.

        Args:
            batch: tensor of real images.

        Returns:
            Dict with the losses and gradient penalties (None for whatever was
            not computed on this step), the real and generated images, and the
            current learning rate of the raster generator.
        """
        # validation_step() switches the networks to eval mode; make sure
        # training always runs in train mode.
        for net in self._networks():
            net.train()

        im = batch
        im = im.to(self.device)

        # Both generators are driven by the same latent sample.
        z = self.gen.sample_z(im.shape[0], device=self.device)

        generated = self.gen(z)

        vect_generated = None
        if self.vect_gen is not None:
            vect_generated = self.vect_gen(z)

        loss_g = None
        loss_d = None
        loss_g_vect = None
        loss_d_vect = None

        gp = None
        gp_vect = None

        if self.iter < self.n_critic:  # Discriminator update
            self.iter += 1

            loss_d, gp = self._discriminator_step(
                self.discrim, self.discrim_opt, generated, im)

            if vect_generated is not None:
                loss_d_vect, gp_vect = self._discriminator_step(
                    self.vect_discrim, self.vect_discrim_opt, vect_generated,
                    im)

        else:  # Generator update
            self.iter = 0

            loss_g = self._generator_step(
                self.gen, self.discrim, self.gen_opt, generated)

            if vect_generated is not None:
                loss_g_vect = self._generator_step(
                    self.vect_gen, self.vect_discrim, self.vect_gen_opt,
                    vect_generated)

        return {
            "loss_g": loss_g,
            "loss_d": loss_d,
            "loss_g_vect": loss_g_vect,
            "loss_d_vect": loss_d_vect,
            "gp": gp,
            "gp_vect": gp_vect,
            "gt_image": im,
            "gen_image": generated,
            "vector_image": vect_generated,
            "lr": self.gen_opt.param_groups[0]["lr"],
        }

    def init_validation(self):
        """Return the initial running data for a validation pass."""
        return dict(sample=None)

    def validation_step(self, batch, running_data):
        """Switch the networks to eval mode; no validation metric is computed.

        Args:
            batch: validation batch (unused).
            running_data: running validation data, returned unchanged.

        Returns:
            `running_data`, unmodified.
        """
        # Switch to eval mode for dropout, batchnorm, etc. This class has no
        # single `model` attribute, so every network it holds is switched.
        for net in self._networks():
            net.eval()
        return running_data
def train(args):
    """Build the dataset, networks and callbacks, then run GAN training.

    Checkpoints are read from and written to `OUTPUT/<env_name>`, where
    `env_name` encodes the task, whether the vector generator is trained
    (and which one), and whether the WGAN loss is used. Training resumes from
    the latest checkpoints found there, if any.

    Args:
        args: parsed command-line namespace. The attributes read here are
            task, raster_resolution, bs, workers, zdim, num_strokes,
            stroke_width, raster_only, generator, lr, lr_decay, wgan_gp,
            freq, port, server, base_url and num_epochs.

    Raises:
        NotImplementedError: if `args.task` or `args.generator` is not one of
            the values handled below.
    """
    th.manual_seed(0)
    np.random.seed(0)

    color_output = False
    if args.task == "mnist":
        dataset = data.MNISTDataset(args.raster_resolution, train=True)
    elif args.task == "quickdraw":
        dataset = data.QuickDrawImageDataset(
            args.raster_resolution, train=True)
    else:
        raise NotImplementedError("unsupported task: %s" % args.task)

    dataloader = DataLoader(
        dataset, batch_size=args.bs, num_workers=args.workers, shuffle=True)

    # No validation set is used.
    val_dataloader = None

    model_params = {
        "zdim": args.zdim,
        "num_strokes": args.num_strokes,
        "imsize": args.raster_resolution,
        "stroke_width": args.stroke_width,
        "color_output": color_output,
    }
    gen = models.Generator(**model_params)
    gen.train()

    discrim = models.Discriminator(color_output=color_output)
    discrim.train()

    if args.raster_only:
        vect_gen = None
        vect_discrim = None
    else:
        if args.generator == "fc":
            vect_gen = models.VectorGenerator(**model_params)
        elif args.generator == "bezier_fc":
            vect_gen = models.BezierVectorGenerator(**model_params)
        elif args.generator == "rnn":
            vect_gen = models.RNNVectorGenerator(**model_params)
        elif args.generator == "chain_rnn":
            vect_gen = models.ChainRNNVectorGenerator(**model_params)
        else:
            raise NotImplementedError(
                "unsupported generator: %s" % args.generator)
        vect_gen.train()

        vect_discrim = models.Discriminator(color_output=color_output)
        vect_discrim.train()

    LOG.info("Model parameters:\n%s", model_params)

    device = "cpu"
    if th.cuda.is_available():
        device = "cuda"
        LOG.info("Using CUDA")

    interface = Interface(gen, vect_gen, discrim, vect_discrim,
                          raster_resolution=args.raster_resolution, lr=args.lr,
                          wgan_gp=args.wgan_gp,
                          lr_decay=args.lr_decay, device=device)

    env_name = args.task + "_gan"

    if args.raster_only:
        env_name += "_raster"
    else:
        env_name += "_vector"

    env_name += "_" + args.generator

    if args.wgan_gp:
        env_name += "_wgan"

    chkpt = os.path.join(OUTPUT, env_name)

    meta = {
        "model_params": model_params,
        "task": args.task,
        "generator": args.generator,
    }
    checkpointer = ttools.Checkpointer(
        chkpt, gen, meta=meta,
        optimizers=interface.optimizers,
        schedulers=interface.schedulers,
        prefix="g_")
    checkpointer_d = ttools.Checkpointer(
        chkpt, discrim,
        prefix="d_")

    # Resume from checkpoint, if any
    extras, _ = checkpointer.load_latest()
    checkpointer_d.load_latest()

    if not args.raster_only:
        checkpointer_vect = ttools.Checkpointer(
            chkpt, vect_gen, meta=meta,
            optimizers=interface.optimizers,
            schedulers=interface.schedulers,
            prefix="vect_g_")
        checkpointer_d_vect = ttools.Checkpointer(
            chkpt, vect_discrim,
            prefix="vect_d_")
        vect_extras, _ = checkpointer_vect.load_latest()
        checkpointer_d_vect.load_latest()
        # Prefer the vector checkpoint's bookkeeping, but keep the raster
        # one when no vector checkpoint was found, so the epoch counter is
        # not reset.
        if vect_extras:
            extras = vect_extras

    epoch = extras.get("epoch", 0) if extras else 0

    # NOTE: the metadata stored in the checkpoints is discarded above and is
    # not compared against `model_params`; resuming with different CLI
    # parameters is therefore not detected here.

    trainer = ttools.Trainer(interface)

    # Add callbacks
    # Named `loss_keys` so the imported `losses` module is not shadowed.
    loss_keys = ["loss_g", "loss_d", "loss_g_vect", "loss_d_vect", "gp",
                 "gp_vect"]
    training_debug = ["lr"]

    trainer.add_callback(Callback(
        env=env_name, win="samples", port=args.port, frequency=args.freq))
    trainer.add_callback(ttools.callbacks.ProgressBarCallback(
        keys=loss_keys, val_keys=None))
    trainer.add_callback(ttools.callbacks.MultiPlotCallback(
        keys=loss_keys, val_keys=None, env=env_name, port=args.port,
        server=args.server, base_url=args.base_url,
        win="losses", frequency=args.freq))
    trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
        keys=training_debug, smoothing=0, val_keys=None, env=env_name,
        server=args.server, base_url=args.base_url,
        port=args.port))
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(
        checkpointer, max_files=2, interval=600, max_epochs=10))
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(
        checkpointer_d, max_files=2, interval=600, max_epochs=10))

    if not args.raster_only:
        trainer.add_callback(ttools.callbacks.CheckpointingCallback(
            checkpointer_vect, max_files=2, interval=600, max_epochs=10))
        trainer.add_callback(ttools.callbacks.CheckpointingCallback(
            checkpointer_d_vect, max_files=2, interval=600, max_epochs=10))

    trainer.add_callback(
        ttools.callbacks.LRSchedulerCallback(interface.schedulers))

    # Start training
    trainer.train(dataloader, starting_epoch=epoch,
                  val_dataloader=val_dataloader,
                  num_epochs=args.num_epochs)
if __name__ == "__main__":
    # Command-line entry point: declare the options, then launch training.
    parser = argparse.ArgumentParser()
    add_arg = parser.add_argument

    # What to train
    add_arg("--task", choices=["mnist", "quickdraw"], default="mnist")
    add_arg("--generator",
            choices=["bezier_fc", "fc", "rnn", "chain_rnn"],
            default="bezier_fc",
            help="model to use as generator")
    add_arg("--raster_only", action="store_true", default=False,
            help="if true only train the raster baseline")
    # The flag is stored inverted: passing --standard_gan sets wgan_gp=False.
    add_arg("--standard_gan", action="store_false", dest="wgan_gp",
            default=True,
            help="if true, use regular GAN instead of WGAN")

    # Training params
    add_arg("--bs", default=4, type=int, help="batch size")
    add_arg("--workers", default=4, type=int,
            help="number of dataloader threads")
    add_arg("--num_epochs", default=200, type=int,
            help="number of epochs to train for")
    add_arg("--lr", default=1e-4, type=float, help="learning rate")
    add_arg("--lr_decay", default=0.9999, type=float,
            help="exponential learning rate decay rate")

    # Model configuration
    add_arg("--zdim", default=32, type=int, help="latent space dimension")
    add_arg("--stroke_width", nargs=2, default=(0.5, 1.5), type=float,
            help="min and max stroke width")
    add_arg("--num_strokes", default=16, type=int,
            help="number of strokes to generate")
    add_arg("--raster_resolution", default=32, type=int,
            help="raster canvas resolution on each side")

    # Viz params
    add_arg("--freq", default=10, type=int, help="visualization frequency")
    add_arg("--port", default=8097, type=int, help="visdom port")
    add_arg("--server", default=None, help="visdom server if not local.")
    add_arg("--base_url", default="", help="visdom entrypoint URL")

    args = parser.parse_args()

    # Global library setup performed before training starts.
    pydiffvg.set_use_gpu(False)
    ttools.set_logger(False)

    train(args)