"""Collection of generative models."""
import torch as th
import ttools

import rendering
import modules

LOG = ttools.get_logger(__name__)


class BaseModel(th.nn.Module):
    def sample_z(self, bs, device="cpu"):
        """Sample a batch of latent codes from a standard normal."""
        return th.randn(bs, self.zdim).to(device)


class BaseVectorModel(BaseModel):
    """Base class for generators that produce both a raster and a vector scene."""

    def get_vector(self, z):
        """Return only the vector scene description for the latent codes `z`."""
        _, scenes = self._forward(z)
        return scenes

    def _forward(self, x):
        raise NotImplementedError()

    def forward(self, z):
        # Only return the raster
        return self._forward(z)[0]


class BezierVectorGenerator(BaseVectorModel):
    """Generates strokes as cubic Bezier curves with per-stroke width,
    opacity and (optionally) color."""
    NUM_SEGMENTS = 2

    def __init__(self, num_strokes=4,
                 zdim=128, width=32, imsize=32,
                 color_output=False,
                 stroke_width=None):
        super(BezierVectorGenerator, self).__init__()
        if stroke_width is None:
            self.stroke_width = (0.5, 3.0)
            LOG.warning("Setting default stroke width %s", self.stroke_width)
        else:
            self.stroke_width = stroke_width
        self.imsize = imsize
        self.num_strokes = num_strokes
        self.zdim = zdim
        self.trunk = th.nn.Sequential(
            th.nn.Linear(zdim, width),
            th.nn.SELU(inplace=True),
            th.nn.Linear(width, 2*width),
            th.nn.SELU(inplace=True),
            th.nn.Linear(2*width, 4*width),
            th.nn.SELU(inplace=True),
            th.nn.Linear(4*width, 8*width),
            th.nn.SELU(inplace=True),
        )
        # Cubic Bezier strokes: each segment has 4 control points and
        # adjacent segments share an endpoint, so NUM_SEGMENTS segments
        # need 3*NUM_SEGMENTS + 1 points (x, y) per stroke.
        self.point_predictor = th.nn.Sequential(
            th.nn.Linear(8*width,
                         2*self.num_strokes*(
                             BezierVectorGenerator.NUM_SEGMENTS*3 + 1)),
            th.nn.Tanh()  # bound spatial extent
        )
        self.width_predictor = th.nn.Sequential(
            th.nn.Linear(8*width, self.num_strokes),
            th.nn.Sigmoid()
        )

        self.alpha_predictor = th.nn.Sequential(
            th.nn.Linear(8*width, self.num_strokes),
            th.nn.Sigmoid()
        )

        self.color_predictor = None
        if color_output:
            self.color_predictor = th.nn.Sequential(
                th.nn.Linear(8*width, 3*self.num_strokes),
                th.nn.Sigmoid()
            )
    def _forward(self, z):
        bs = z.shape[0]
        feats = self.trunk(z)

        all_points = self.point_predictor(feats)
        all_alphas = self.alpha_predictor(feats)

        if self.color_predictor:
            all_colors = self.color_predictor(feats)
            all_colors = all_colors.view(bs, self.num_strokes, 3)
        else:
            all_colors = None

        all_widths = self.width_predictor(feats)
        min_width = self.stroke_width[0]
        max_width = self.stroke_width[1]
        all_widths = (max_width - min_width) * all_widths + min_width

        all_points = all_points.view(
            bs, self.num_strokes, BezierVectorGenerator.NUM_SEGMENTS*3+1, 2)

        output, scenes = rendering.bezier_render(all_points, all_widths,
                                                 all_alphas,
                                                 colors=all_colors,
                                                 canvas_size=self.imsize)

        # map to [-1, 1]
        output = output*2.0 - 1.0

        return output, scenes
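
# Shape walkthrough for BezierVectorGenerator._forward (illustrative only,
# assuming the defaults num_strokes=4 and NUM_SEGMENTS=2):
#   z           (bs, zdim)
#   all_points  (bs, 4, 3*2 + 1 = 7, 2)  control points in [-1, 1] (Tanh)
#   all_widths  (bs, 4)                  rescaled to [min_width, max_width]
#   all_alphas  (bs, 4)                  opacities in [0, 1]
#   output      raster mapped to [-1, 1]; `scenes` is whatever vector scene
#               description rendering.bezier_render returns.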


class VectorGenerator(BaseVectorModel):
    """Generates strokes as straight line segments with per-stroke width,
    opacity and (optionally) color."""

    def __init__(self, num_strokes=4,
                 zdim=128, width=32, imsize=32,
                 color_output=False,
                 stroke_width=None):
        super(VectorGenerator, self).__init__()
        if stroke_width is None:
            self.stroke_width = (0.5, 3.0)
            LOG.warning("Setting default stroke width %s", self.stroke_width)
        else:
            self.stroke_width = stroke_width
        self.imsize = imsize
        self.num_strokes = num_strokes
        self.zdim = zdim
        self.trunk = th.nn.Sequential(
            th.nn.Linear(zdim, width),
            th.nn.SELU(inplace=True),
            th.nn.Linear(width, 2*width),
            th.nn.SELU(inplace=True),
            th.nn.Linear(2*width, 4*width),
            th.nn.SELU(inplace=True),
            th.nn.Linear(4*width, 8*width),
            th.nn.SELU(inplace=True),
        )
        # Straight line strokes: each stroke is a single segment with 2
        # endpoints, i.e. 2 points (x, y) per stroke.
        self.point_predictor = th.nn.Sequential(
            th.nn.Linear(8*width, 2*(self.num_strokes*2)),
            th.nn.Tanh()  # bound spatial extent
        )
        self.width_predictor = th.nn.Sequential(
            th.nn.Linear(8*width, self.num_strokes),
            th.nn.Sigmoid()
        )

        self.alpha_predictor = th.nn.Sequential(
            th.nn.Linear(8*width, self.num_strokes),
            th.nn.Sigmoid()
        )

        self.color_predictor = None
        if color_output:
            self.color_predictor = th.nn.Sequential(
                th.nn.Linear(8*width, 3*self.num_strokes),
                th.nn.Sigmoid()
            )
    def _forward(self, z):
        bs = z.shape[0]
        feats = self.trunk(z)

        all_points = self.point_predictor(feats)
        all_alphas = self.alpha_predictor(feats)

        if self.color_predictor:
            all_colors = self.color_predictor(feats)
            all_colors = all_colors.view(bs, self.num_strokes, 3)
        else:
            all_colors = None

        all_widths = self.width_predictor(feats)
        min_width = self.stroke_width[0]
        max_width = self.stroke_width[1]
        all_widths = (max_width - min_width) * all_widths + min_width

        all_points = all_points.view(bs, self.num_strokes, 2, 2)

        output, scenes = rendering.line_render(all_points, all_widths,
                                               all_alphas,
                                               colors=all_colors,
                                               canvas_size=self.imsize)

        # map to [-1, 1]
        output = output*2.0 - 1.0

        return output, scenes


class RNNVectorGenerator(BaseVectorModel):
    """Generates one straight line stroke per LSTM step; strokes are
    independent of each other."""

    def __init__(self, num_strokes=64,
                 zdim=128, width=32, imsize=32,
                 hidden_size=512, dropout=0.9,
                 color_output=False,
                 num_layers=3, stroke_width=None):
        super(RNNVectorGenerator, self).__init__()
        if stroke_width is None:
            self.stroke_width = (0.5, 3.0)
            LOG.warning("Setting default stroke width %s", self.stroke_width)
        else:
            self.stroke_width = stroke_width
        self.num_layers = num_layers
        self.imsize = imsize
        self.num_strokes = num_strokes
        self.hidden_size = hidden_size
        self.zdim = zdim

        self.hidden_cell_predictor = th.nn.Linear(
            zdim, 2*hidden_size*num_layers)

        self.lstm = th.nn.LSTM(
            zdim, hidden_size,
            num_layers=self.num_layers, dropout=dropout,
            batch_first=True)
        # Straight line strokes: one segment per step, 2 points (x, y) each.
        self.point_predictor = th.nn.Sequential(
            th.nn.Linear(hidden_size, 2*2),  # 2 points, (x, y)
            th.nn.Tanh()  # bound spatial extent
        )

        self.width_predictor = th.nn.Sequential(
            th.nn.Linear(hidden_size, 1),
            th.nn.Sigmoid()
        )

        self.alpha_predictor = th.nn.Sequential(
            th.nn.Linear(hidden_size, 1),
            th.nn.Sigmoid()
        )
    def _forward(self, z, hidden_and_cell=None):
        steps = self.num_strokes

        # z is passed at each step, duplicate it
        bs = z.shape[0]
        expanded_z = z.unsqueeze(1).repeat(1, steps, 1)

        # First step in the RNN
        if hidden_and_cell is None:
            # Initialize from latent vector
            hidden_and_cell = self.hidden_cell_predictor(th.tanh(z))
            hidden = hidden_and_cell[:, :self.hidden_size*self.num_layers]
            hidden = hidden.view(-1, self.num_layers, self.hidden_size)
            hidden = hidden.permute(1, 0, 2).contiguous()
            cell = hidden_and_cell[:, self.hidden_size*self.num_layers:]
            cell = cell.view(-1, self.num_layers, self.hidden_size)
            cell = cell.permute(1, 0, 2).contiguous()
            hidden_and_cell = (hidden, cell)

        feats, hidden_and_cell = self.lstm(expanded_z, hidden_and_cell)
        hidden, cell = hidden_and_cell

        feats = feats.reshape(bs*steps, self.hidden_size)

        all_points = self.point_predictor(feats).view(bs, steps, 2, 2)
        all_alphas = self.alpha_predictor(feats).view(bs, steps)

        all_widths = self.width_predictor(feats).view(bs, steps)
        min_width = self.stroke_width[0]
        max_width = self.stroke_width[1]
        all_widths = (max_width - min_width) * all_widths + min_width

        output, scenes = rendering.line_render(all_points, all_widths,
                                               all_alphas,
                                               canvas_size=self.imsize)

        # map to [-1, 1]
        output = output*2.0 - 1.0

        return output, scenes


class ChainRNNVectorGenerator(BaseVectorModel):
    """Strokes form a single long chain: each stroke starts where the
    previous one ended."""

    def __init__(self, num_strokes=64,
                 zdim=128, width=32, imsize=32,
                 hidden_size=512, dropout=0.9,
                 color_output=False,
                 num_layers=3, stroke_width=None):
        super(ChainRNNVectorGenerator, self).__init__()
        if stroke_width is None:
            self.stroke_width = (0.5, 3.0)
            LOG.warning("Setting default stroke width %s", self.stroke_width)
        else:
            self.stroke_width = stroke_width
        self.num_layers = num_layers
        self.imsize = imsize
        self.num_strokes = num_strokes
        self.hidden_size = hidden_size
        self.zdim = zdim

        self.hidden_cell_predictor = th.nn.Linear(
            zdim, 2*hidden_size*num_layers)

        self.lstm = th.nn.LSTM(
            zdim, hidden_size,
            num_layers=self.num_layers, dropout=dropout,
            batch_first=True)
        # Chained straight line strokes: the network only predicts the end
        # point of each step; the start point is the previous end point.
        self.point_predictor = th.nn.Sequential(
            th.nn.Linear(hidden_size, 2),  # 1 point, (x, y)
            th.nn.Tanh()  # bound spatial extent
        )

        self.width_predictor = th.nn.Sequential(
            th.nn.Linear(hidden_size, 1),
            th.nn.Sigmoid()
        )

        self.alpha_predictor = th.nn.Sequential(
            th.nn.Linear(hidden_size, 1),
            th.nn.Sigmoid()
        )
    def _forward(self, z, hidden_and_cell=None):
        steps = self.num_strokes

        # z is passed at each step, duplicate it
        bs = z.shape[0]
        expanded_z = z.unsqueeze(1).repeat(1, steps, 1)

        # First step in the RNN
        if hidden_and_cell is None:
            # Initialize from latent vector
            hidden_and_cell = self.hidden_cell_predictor(th.tanh(z))
            hidden = hidden_and_cell[:, :self.hidden_size*self.num_layers]
            hidden = hidden.view(-1, self.num_layers, self.hidden_size)
            hidden = hidden.permute(1, 0, 2).contiguous()
            cell = hidden_and_cell[:, self.hidden_size*self.num_layers:]
            cell = cell.view(-1, self.num_layers, self.hidden_size)
            cell = cell.permute(1, 0, 2).contiguous()
            hidden_and_cell = (hidden, cell)

        feats, hidden_and_cell = self.lstm(expanded_z, hidden_and_cell)
        hidden, cell = hidden_and_cell

        feats = feats.reshape(bs*steps, self.hidden_size)
        # Construct the chain: stroke i starts at the end point of
        # stroke i-1, and the very first stroke starts at the canvas center.
        end_points = self.point_predictor(feats).view(bs, steps, 1, 2)
        start_points = th.cat([
            # first point is canvas center
            th.zeros(bs, 1, 1, 2, device=feats.device),
            end_points[:, :-1, :, :]], 1)
        all_points = th.cat([start_points, end_points], 2)

        all_alphas = self.alpha_predictor(feats).view(bs, steps)

        all_widths = self.width_predictor(feats).view(bs, steps)
        min_width = self.stroke_width[0]
        max_width = self.stroke_width[1]
        all_widths = (max_width - min_width) * all_widths + min_width

        output, scenes = rendering.line_render(all_points, all_widths,
                                               all_alphas,
                                               canvas_size=self.imsize)

        # map to [-1, 1]
        output = output*2.0 - 1.0

        return output, scenes


class Generator(BaseModel):
    """Convolutional raster generator (no vector output), fixed to 32x32
    images. `num_strokes` and `stroke_width` are accepted for interface
    compatibility with the vector generators but are unused here."""

    def __init__(self, width=64, imsize=32, zdim=128,
                 stroke_width=None,
                 color_output=False,
                 num_strokes=4):
        super(Generator, self).__init__()
        assert imsize == 32
        self.imsize = imsize
        self.zdim = zdim

        num_in_chans = self.zdim // (2*2)
        num_out_chans = 3 if color_output else 1
        self.net = th.nn.Sequential(
            th.nn.ConvTranspose2d(num_in_chans, width*8, 4, padding=1,
                                  stride=2),
            th.nn.LeakyReLU(0.2, inplace=True),
            th.nn.Conv2d(width*8, width*8, 3, padding=1),
            th.nn.BatchNorm2d(width*8),
            th.nn.LeakyReLU(0.2, inplace=True),

            # 4x4
            th.nn.ConvTranspose2d(8*width, 4*width, 4, padding=1, stride=2),
            th.nn.LeakyReLU(0.2, inplace=True),
            th.nn.Conv2d(4*width, 4*width, 3, padding=1),
            th.nn.BatchNorm2d(width*4),
            th.nn.LeakyReLU(0.2, inplace=True),

            # 8x8
            th.nn.ConvTranspose2d(4*width, 2*width, 4, padding=1, stride=2),
            th.nn.LeakyReLU(0.2, inplace=True),
            th.nn.Conv2d(2*width, 2*width, 3, padding=1),
            th.nn.BatchNorm2d(width*2),
            th.nn.LeakyReLU(0.2, inplace=True),

            # 16x16
            th.nn.ConvTranspose2d(2*width, width, 4, padding=1, stride=2),
            th.nn.LeakyReLU(0.2, inplace=True),
            th.nn.Conv2d(width, width, 3, padding=1),
            th.nn.BatchNorm2d(width),
            th.nn.LeakyReLU(0.2, inplace=True),

            # 32x32
            th.nn.Conv2d(width, width, 3, padding=1),
            th.nn.BatchNorm2d(width),
            th.nn.LeakyReLU(0.2, inplace=True),
            th.nn.Conv2d(width, width, 3, padding=1),
            th.nn.LeakyReLU(0.2, inplace=True),
            th.nn.Conv2d(width, num_out_chans, 1),
            th.nn.Tanh(),
        )
    def forward(self, z):
        bs = z.shape[0]
        num_in_chans = self.zdim // (2*2)
        raster = self.net(z.view(bs, num_in_chans, 2, 2))
        return raster


class Discriminator(th.nn.Module):
    def __init__(self, conditional=False, width=64, color_output=False):
        super(Discriminator, self).__init__()
        self.conditional = conditional

        sn = th.nn.utils.spectral_norm
        num_chan_in = 3 if color_output else 1

        self.net = th.nn.Sequential(
            th.nn.Conv2d(num_chan_in, width, 3, padding=1),
            th.nn.LeakyReLU(0.2, inplace=True),
            th.nn.Conv2d(width, 2*width, 4, padding=1, stride=2),
            th.nn.LeakyReLU(0.2, inplace=True),

            # 16x16
            sn(th.nn.Conv2d(2*width, 2*width, 3, padding=1)),
            th.nn.LeakyReLU(0.2, inplace=True),
            sn(th.nn.Conv2d(2*width, 4*width, 4, padding=1, stride=2)),
            th.nn.LeakyReLU(0.2, inplace=True),

            # 8x8
            sn(th.nn.Conv2d(4*width, 4*width, 3, padding=1)),
            th.nn.LeakyReLU(0.2, inplace=True),
            sn(th.nn.Conv2d(4*width, width*4, 4, padding=1, stride=2)),
            th.nn.LeakyReLU(0.2, inplace=True),

            # 4x4
            sn(th.nn.Conv2d(4*width, 4*width, 3, padding=1)),
            th.nn.LeakyReLU(0.2, inplace=True),
            sn(th.nn.Conv2d(4*width, width*4, 4, padding=1, stride=2)),
            th.nn.LeakyReLU(0.2, inplace=True),

            # 2x2
            modules.Flatten(),
            th.nn.Linear(width*4*2*2, 1),
        )

    def forward(self, x):
        out = self.net(x)
        return out
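

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original training code).
    # The raster Generator/Discriminator pair is self-contained; the vector
    # generators additionally require the local `rendering` module imported
    # above (rendering.bezier_render / rendering.line_render).
    bs = 2

    raster_gen = Generator(zdim=128, imsize=32)
    disc = Discriminator()
    z = raster_gen.sample_z(bs)
    fake = raster_gen(z)   # (bs, 1, 32, 32), values in [-1, 1]
    score = disc(fake)     # (bs, 1) realness scores
    LOG.info("raster %s -> score %s", tuple(fake.shape), tuple(score.shape))

    vector_gen = BezierVectorGenerator(num_strokes=4, zdim=128, imsize=32)
    zv = vector_gen.sample_z(bs)
    raster = vector_gen(zv)              # rasterized strokes in [-1, 1]
    scenes = vector_gen.get_vector(zv)   # vector scene description
    LOG.info("vector raster %s", tuple(raster.shape))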