"""A simple training interface using ttools.""" | |
import argparse | |
import os | |
import logging | |
import random | |
import numpy as np | |
import torch | |
from torchvision.datasets import MNIST | |
import torchvision.transforms as xforms | |
from torch.utils.data import DataLoader | |
import ttools | |
import ttools.interfaces | |
import pydiffvg | |
LOG = ttools.get_logger(__name__)
# Silence pydiffvg's per-render timing printouts during training.
pydiffvg.render_pytorch.print_timing = False
# Fixed seeds + deterministic cudnn so runs are reproducible.
torch.manual_seed(123)
np.random.seed(123)
torch.backends.cudnn.deterministic = True
# Size of the generator's latent noise vector.
latent_dim = 100
# Height/width in pixels of both the rendered and the training images.
img_size = 32
# Strokes per generated image, and Bezier segments per stroke.
num_paths = 8
num_segments = 8
def weights_init_normal(m):
    """DCGAN weight init: N(0, 0.02) for conv layers; N(1, 0.02) weights
    and zero bias for 2-D batch-norm layers. Other modules are untouched."""
    layer_name = type(m).__name__
    if "Conv" in layer_name:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif "BatchNorm2d" in layer_name:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
class VisdomImageCallback(ttools.callbacks.ImageDisplayCallback):
    """Displays real digits next to generated renderings in visdom."""

    def visualized_image(self, batch, fwd_result):
        real = batch[0]
        fake = fwd_result.cpu()
        # Concatenate along tensor dim 2 for a side-by-side comparison
        # (presumably the height axis of NCHW batches — confirm upstream).
        return torch.cat([real, fake], dim=2)
# From https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/dcgan/dcgan.py | |
class Generator(torch.nn.Module):
    """Generates a raster image by predicting vector paths and rendering them.

    The MLP maps a ``latent_dim`` noise vector to, for each of the
    ``num_paths`` open curves: 2 * (num_segments + 1) point coordinates,
    one stroke width and one stroke-color scalar — all squashed to [0, 1]
    by the final sigmoid — then rasterizes each sample with pydiffvg.
    """
    def __init__(self):
        super(Generator, self).__init__()
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(latent_dim, 128),
            torch.nn.LeakyReLU(0.2, inplace=True),
            torch.nn.Linear(128, 256),
            torch.nn.LeakyReLU(0.2, inplace=True),
            torch.nn.Linear(256, 512),
            torch.nn.LeakyReLU(0.2, inplace=True),
            torch.nn.Linear(512, 1024),
            torch.nn.LeakyReLU(0.2, inplace=True),
            # Per path: 2 * (num_segments + 1) point coords, +1 stroke
            # width, +1 stroke-color scalar (hence the two num_paths terms).
            torch.nn.Linear(1024, 2 * num_paths * (num_segments + 1) + num_paths + num_paths),
            torch.nn.Sigmoid()
        )

    def forward(self, z):
        """Render one image per latent code.

        Args:
            z: (batch, latent_dim) tensor of latent vectors.

        Returns:
            (batch, 1, img_size, img_size) tensor of rendered images.
        """
        out = self.fc(z)
        # construct paths
        imgs = []
        for b in range(out.shape[0]):
            index = 0  # read cursor into this sample's flat parameter vector
            shapes = []
            shape_groups = []
            for i in range(num_paths):
                # Path points, scaled from [0, 1] up to pixel coordinates.
                points = img_size * out[b, index: index + 2 * (num_segments + 1)].view(-1, 2).cpu()
                index += 2 * (num_segments + 1)
                stroke_width = img_size * out[b, index].view(1).cpu()
                index += 1
                # 2 control points per segment (cubic segments in pydiffvg's
                # Path encoding — verify against the pydiffvg API).
                num_control_points = torch.zeros(num_segments, dtype=torch.int32) + 2
                path = pydiffvg.Path(num_control_points=num_control_points,
                                     points=points,
                                     stroke_width=stroke_width,
                                     is_closed=False)
                shapes.append(path)
                stroke_color = out[b, index].view(1).cpu()
                index += 1
                # Predicted scalar becomes the first color channel; the
                # appended constants presumably fill the remaining RGBA
                # components with alpha = 1 — confirm against pydiffvg.
                stroke_color = torch.cat([stroke_color, torch.tensor([0.0, 0.0, 1.0])])
                # Stroke-only group (fill_color=None) referencing the path
                # just appended.
                path_group = pydiffvg.ShapeGroup(shape_ids=torch.tensor([len(shapes) - 1]),
                                                 fill_color=None,
                                                 stroke_color=stroke_color)
                shape_groups.append(path_group)
            scene_args = pydiffvg.RenderFunction.serialize_scene(img_size, img_size, shapes, shape_groups)
            render = pydiffvg.RenderFunction.apply
            # Fresh random seed per sample so antialiasing noise is not
            # correlated across the batch.
            img = render(img_size,  # width
                         img_size,  # height
                         2,  # num_samples_x
                         2,  # num_samples_y
                         random.randint(0, 1048576),  # seed
                         None,
                         *scene_args)
            # Keep only the first channel of the rendered output.
            img = img[:, :, :1]
            # HWC -> NCHW
            img = img.unsqueeze(0)
            img = img.permute(0, 3, 1, 2)  # NHWC -> NCHW
            imgs.append(img)
        img = torch.cat(imgs, dim=0)
        return img
class Discriminator(torch.nn.Module):
    """DCGAN-style convolutional discriminator.

    Four stride-2 conv blocks shrink the input by 16x, then a linear +
    sigmoid head outputs a per-image probability of being real.

    Args:
        img_size: input height/width in pixels; must be divisible by 16.
            Defaults to 32, matching the module-level ``img_size`` constant
            (previously this was read from the global directly).
        in_channels: number of channels in the input images (1 for MNIST).
    """
    def __init__(self, img_size=32, in_channels=1):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            # Conv (stride 2, halves H and W) -> LeakyReLU -> Dropout
            # [-> BatchNorm].
            block = [torch.nn.Conv2d(in_filters, out_filters, 3, 2, 1),
                     torch.nn.LeakyReLU(0.2, inplace=True),
                     torch.nn.Dropout2d(0.25)]
            if bn:
                # NOTE(review): 0.8 is passed positionally as `eps` (kept
                # from the reference DCGAN implementation); `momentum` was
                # likely intended — verify before changing.
                block.append(torch.nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.model = torch.nn.Sequential(
            *discriminator_block(in_channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )
        # The height and width of the downsampled image: each of the 4
        # blocks halves the spatial resolution.
        ds_size = img_size // 2 ** 4
        self.adv_layer = torch.nn.Sequential(
            torch.nn.Linear(128 * ds_size ** 2, 1),
            torch.nn.Sigmoid())

    def forward(self, img):
        """Score a batch of images.

        Args:
            img: (batch, in_channels, img_size, img_size) tensor.

        Returns:
            (batch, 1) tensor of real/fake probabilities in (0, 1).
        """
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        return validity
class MNISTInterface(ttools.interfaces.SGANInterface):
    """Adapter wiring the generator/discriminator pair into ttools' SGAN
    training loop."""

    def __init__(self, gen, discrim, lr=2e-4):
        super(MNISTInterface, self).__init__(gen, discrim, lr, opt='adam')

    def forward(self, batch):
        # Draw one standard-normal latent vector per element in the batch
        # and decode it into an image.
        batch_size = batch[0].shape[0]
        z = torch.zeros([batch_size, latent_dim], device=self.device)
        return self.gen(z.normal_())

    def _discriminator_input(self, batch, fwd_data, fake=False):
        # Fake samples come straight from the generator; real ones from
        # the data loader, moved onto the training device.
        return fwd_data if fake else batch[0].to(self.device)
def train(args):
    """Train the vector-graphics GAN on MNIST.

    Args:
        args: parsed CLI namespace; reads ``args.data`` (dataset root),
            ``args.out`` (checkpoint directory), ``args.lr``, ``args.bs``
            and ``args.epochs``.
    """
    # Setup train and val data; digits are resized to 32x32 to match the
    # renderer's output resolution.
    _xform = xforms.Compose([xforms.Resize([32, 32]), xforms.ToTensor()])
    # BUG FIX: the dataset root was hard-coded to "data/mnist", silently
    # ignoring the positional `data` CLI argument.
    data = MNIST(args.data, train=True, download=True, transform=_xform)
    # Initialize asynchronous dataloaders. Shuffle so the discriminator
    # does not see the same batch ordering every epoch.
    loader = DataLoader(data, batch_size=args.bs, num_workers=2, shuffle=True)
    # Instantiate the models
    gen = Generator()
    discrim = Discriminator()
    gen.apply(weights_init_normal)
    discrim.apply(weights_init_normal)
    # Checkpointers to save/recall model parameters (one per model, told
    # apart by prefix).
    checkpointer_gen = ttools.Checkpointer(os.path.join(args.out, "checkpoints"), model=gen, prefix="gen_")
    checkpointer_discrim = ttools.Checkpointer(os.path.join(args.out, "checkpoints"), model=discrim, prefix="discrim_")
    # Resume from a previous checkpoint, if any.
    checkpointer_gen.load_latest()
    checkpointer_discrim.load_latest()
    # Setup a training interface for the model.
    interface = MNISTInterface(gen, discrim, lr=args.lr)
    # Create a training looper with the interface we defined.
    trainer = ttools.Trainer(interface)
    # Adds several callbacks, that will be called by the trainer --------------
    # A periodic checkpointing operation.
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(checkpointer_gen))
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(checkpointer_discrim))
    # A simple progress bar.
    trainer.add_callback(ttools.callbacks.ProgressBarCallback(
        keys=["loss_g", "loss_d", "loss"]))
    # Volatile loss logging through visdom.
    trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
        keys=["loss_g", "loss_d", "loss"],
        port=8080, env="mnist_demo"))
    # Image display (real vs. generated).
    trainer.add_callback(VisdomImageCallback(port=8080, env="mnist_demo"))
    # -------------------------------------------------------------------------
    # Start the training.
    LOG.info("Training started, press Ctrl-C to interrupt.")
    trainer.train(loader, num_epochs=args.epochs)
if __name__ == "__main__":
    # TODO: subparsers
    # Command-line entry point: parse arguments, enable debug logging,
    # then hand off to train().
    cli = argparse.ArgumentParser()
    cli.add_argument("data",
                     help="directory where we download and store the MNIST dataset.")
    cli.add_argument("out",
                     help="directory where we write the checkpoints and visualizations.")
    cli.add_argument("--lr", type=float, default=1e-4,
                     help="learning rate for the optimizer.")
    cli.add_argument("--epochs", type=int, default=500,
                     help="number of epochs to train for.")
    cli.add_argument("--bs", type=int, default=64,
                     help="number of elements per batch.")
    parsed = cli.parse_args()
    ttools.set_logger(True)  # activate debug prints
    train(parsed)