"""Collection of generative models."""
import torch as th
import ttools
import rendering
import modules
LOG = ttools.get_logger(__name__)
class BaseModel(th.nn.Module):
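    """Base class for all generators; provides latent-vector sampling."""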
def sample_z(self, bs, device="cpu"):
        return th.randn(bs, self.zdim, device=device)
class BaseVectorModel(BaseModel):
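    """Base class for generators that output both a raster and a vector scene.

    Subclasses implement `_forward`, which returns `(raster, scenes)`;
    `forward` keeps only the raster so the model can drop into standard GAN
    training, while `get_vector` exposes the underlying vector scenes.
    """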
def get_vector(self, z):
_, scenes = self._forward(z)
return scenes
def _forward(self, x):
raise NotImplementedError()
def forward(self, z):
# Only return the raster
return self._forward(z)[0]
class BezierVectorGenerator(BaseVectorModel):
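    """Predicts strokes as cubic Bezier curves (NUM_SEGMENTS segments each),
    with per-stroke width, opacity and, optionally, RGB color.
    """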
NUM_SEGMENTS = 2
def __init__(self, num_strokes=4,
zdim=128, width=32, imsize=32,
color_output=False,
stroke_width=None):
super(BezierVectorGenerator, self).__init__()
if stroke_width is None:
self.stroke_width = (0.5, 3.0)
            LOG.warning("Setting default stroke width %s", self.stroke_width)
else:
self.stroke_width = stroke_width
self.imsize = imsize
self.num_strokes = num_strokes
self.zdim = zdim
self.trunk = th.nn.Sequential(
th.nn.Linear(zdim, width),
th.nn.SELU(inplace=True),
th.nn.Linear(width, 2*width),
th.nn.SELU(inplace=True),
th.nn.Linear(2*width, 4*width),
th.nn.SELU(inplace=True),
th.nn.Linear(4*width, 8*width),
th.nn.SELU(inplace=True),
)
        # Cubic Bezier: 4 control points per segment, endpoints shared, so
        # n_segments segments -> 3*n_segments + 1 points per stroke
self.point_predictor = th.nn.Sequential(
th.nn.Linear(8*width,
2*self.num_strokes*(
BezierVectorGenerator.NUM_SEGMENTS*3 + 1)),
th.nn.Tanh() # bound spatial extent
)
self.width_predictor = th.nn.Sequential(
th.nn.Linear(8*width, self.num_strokes),
th.nn.Sigmoid()
)
self.alpha_predictor = th.nn.Sequential(
th.nn.Linear(8*width, self.num_strokes),
th.nn.Sigmoid()
)
self.color_predictor = None
if color_output:
self.color_predictor = th.nn.Sequential(
th.nn.Linear(8*width, 3*self.num_strokes),
th.nn.Sigmoid()
)
def _forward(self, z):
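        """Predict stroke parameters from the latent z and rasterize them.

        Returns the rendered canvas, mapped to [-1, 1], and the vector scenes.
        """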
bs = z.shape[0]
feats = self.trunk(z)
all_points = self.point_predictor(feats)
all_alphas = self.alpha_predictor(feats)
        if self.color_predictor is not None:
all_colors = self.color_predictor(feats)
all_colors = all_colors.view(bs, self.num_strokes, 3)
else:
all_colors = None
all_widths = self.width_predictor(feats)
min_width = self.stroke_width[0]
max_width = self.stroke_width[1]
all_widths = (max_width - min_width) * all_widths + min_width
all_points = all_points.view(
bs, self.num_strokes, BezierVectorGenerator.NUM_SEGMENTS*3+1, 2)
output, scenes = rendering.bezier_render(all_points, all_widths, all_alphas,
colors=all_colors,
canvas_size=self.imsize)
# map to [-1, 1]
output = output*2.0 - 1.0
return output, scenes
class VectorGenerator(BaseVectorModel):
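    """Predicts strokes as straight line segments, with per-stroke width,
    opacity and, optionally, RGB color.
    """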
def __init__(self, num_strokes=4,
zdim=128, width=32, imsize=32,
color_output=False,
stroke_width=None):
super(VectorGenerator, self).__init__()
if stroke_width is None:
self.stroke_width = (0.5, 3.0)
            LOG.warning("Setting default stroke width %s", self.stroke_width)
else:
self.stroke_width = stroke_width
self.imsize = imsize
self.num_strokes = num_strokes
self.zdim = zdim
self.trunk = th.nn.Sequential(
th.nn.Linear(zdim, width),
th.nn.SELU(inplace=True),
th.nn.Linear(width, 2*width),
th.nn.SELU(inplace=True),
th.nn.Linear(2*width, 4*width),
th.nn.SELU(inplace=True),
th.nn.Linear(4*width, 8*width),
th.nn.SELU(inplace=True),
)
        # Straight line strokes: each stroke needs 2 endpoints, (x, y) each
self.point_predictor = th.nn.Sequential(
th.nn.Linear(8*width, 2*(self.num_strokes*2)),
th.nn.Tanh() # bound spatial extent
)
self.width_predictor = th.nn.Sequential(
th.nn.Linear(8*width, self.num_strokes),
th.nn.Sigmoid()
)
self.alpha_predictor = th.nn.Sequential(
th.nn.Linear(8*width, self.num_strokes),
th.nn.Sigmoid()
)
self.color_predictor = None
if color_output:
self.color_predictor = th.nn.Sequential(
th.nn.Linear(8*width, 3*self.num_strokes),
th.nn.Sigmoid()
)
def _forward(self, z):
bs = z.shape[0]
feats = self.trunk(z)
all_points = self.point_predictor(feats)
all_alphas = self.alpha_predictor(feats)
        if self.color_predictor is not None:
all_colors = self.color_predictor(feats)
all_colors = all_colors.view(bs, self.num_strokes, 3)
else:
all_colors = None
all_widths = self.width_predictor(feats)
min_width = self.stroke_width[0]
max_width = self.stroke_width[1]
all_widths = (max_width - min_width) * all_widths + min_width
all_points = all_points.view(bs, self.num_strokes, 2, 2)
output, scenes = rendering.line_render(all_points, all_widths, all_alphas,
colors=all_colors,
canvas_size=self.imsize)
# map to [-1, 1]
output = output*2.0 - 1.0
return output, scenes
class RNNVectorGenerator(BaseVectorModel):
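    """Predicts strokes sequentially with an LSTM.

    The latent z seeds the initial hidden and cell states and is also fed as
    the input at every step; each step emits one line segment with its width
    and opacity.
    """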
def __init__(self, num_strokes=64,
zdim=128, width=32, imsize=32,
hidden_size=512, dropout=0.9,
color_output=False,
num_layers=3, stroke_width=None):
super(RNNVectorGenerator, self).__init__()
if stroke_width is None:
self.stroke_width = (0.5, 3.0)
            LOG.warning("Setting default stroke width %s", self.stroke_width)
else:
self.stroke_width = stroke_width
self.num_layers = num_layers
self.imsize = imsize
self.num_strokes = num_strokes
self.hidden_size = hidden_size
self.zdim = zdim
self.hidden_cell_predictor = th.nn.Linear(
zdim, 2*hidden_size*num_layers)
self.lstm = th.nn.LSTM(
zdim, hidden_size,
num_layers=self.num_layers, dropout=dropout,
batch_first=True)
        # Straight line strokes: each step predicts 2 endpoints, (x, y) each
self.point_predictor = th.nn.Sequential(
th.nn.Linear(hidden_size, 2*2), # 2 points, (x,y)
th.nn.Tanh() # bound spatial extent
)
self.width_predictor = th.nn.Sequential(
th.nn.Linear(hidden_size, 1),
th.nn.Sigmoid()
)
self.alpha_predictor = th.nn.Sequential(
th.nn.Linear(hidden_size, 1),
th.nn.Sigmoid()
)
def _forward(self, z, hidden_and_cell=None):
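        """Run the LSTM for `num_strokes` steps and rasterize the strokes.

        When `hidden_and_cell` is None, the initial LSTM state is predicted
        from the latent z.
        """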
steps = self.num_strokes
# z is passed at each step, duplicate it
bs = z.shape[0]
expanded_z = z.unsqueeze(1).repeat(1, steps, 1)
# First step in the RNN
if hidden_and_cell is None:
# Initialize from latent vector
hidden_and_cell = self.hidden_cell_predictor(th.tanh(z))
hidden = hidden_and_cell[:, :self.hidden_size*self.num_layers]
hidden = hidden.view(-1, self.num_layers, self.hidden_size)
hidden = hidden.permute(1, 0, 2).contiguous()
cell = hidden_and_cell[:, self.hidden_size*self.num_layers:]
cell = cell.view(-1, self.num_layers, self.hidden_size)
cell = cell.permute(1, 0, 2).contiguous()
hidden_and_cell = (hidden, cell)
feats, hidden_and_cell = self.lstm(expanded_z, hidden_and_cell)
hidden, cell = hidden_and_cell
feats = feats.reshape(bs*steps, self.hidden_size)
all_points = self.point_predictor(feats).view(bs, steps, 2, 2)
all_alphas = self.alpha_predictor(feats).view(bs, steps)
all_widths = self.width_predictor(feats).view(bs, steps)
min_width = self.stroke_width[0]
max_width = self.stroke_width[1]
all_widths = (max_width - min_width) * all_widths + min_width
output, scenes = rendering.line_render(all_points, all_widths, all_alphas,
canvas_size=self.imsize)
# map to [-1, 1]
output = output*2.0 - 1.0
return output, scenes
class ChainRNNVectorGenerator(BaseVectorModel):
"""Strokes form a single long chain."""
def __init__(self, num_strokes=64,
zdim=128, width=32, imsize=32,
hidden_size=512, dropout=0.9,
color_output=False,
num_layers=3, stroke_width=None):
super(ChainRNNVectorGenerator, self).__init__()
if stroke_width is None:
self.stroke_width = (0.5, 3.0)
            LOG.warning("Setting default stroke width %s", self.stroke_width)
else:
self.stroke_width = stroke_width
self.num_layers = num_layers
self.imsize = imsize
self.num_strokes = num_strokes
self.hidden_size = hidden_size
self.zdim = zdim
self.hidden_cell_predictor = th.nn.Linear(
zdim, 2*hidden_size*num_layers)
self.lstm = th.nn.LSTM(
zdim, hidden_size,
num_layers=self.num_layers, dropout=dropout,
batch_first=True)
        # Chained straight lines: each step predicts one new endpoint (x, y);
        # the stroke starts where the previous one ended
self.point_predictor = th.nn.Sequential(
th.nn.Linear(hidden_size, 2), # 1 point, (x,y)
th.nn.Tanh() # bound spatial extent
)
self.width_predictor = th.nn.Sequential(
th.nn.Linear(hidden_size, 1),
th.nn.Sigmoid()
)
self.alpha_predictor = th.nn.Sequential(
th.nn.Linear(hidden_size, 1),
th.nn.Sigmoid()
)
def _forward(self, z, hidden_and_cell=None):
steps = self.num_strokes
# z is passed at each step, duplicate it
bs = z.shape[0]
expanded_z = z.unsqueeze(1).repeat(1, steps, 1)
# First step in the RNN
if hidden_and_cell is None:
# Initialize from latent vector
hidden_and_cell = self.hidden_cell_predictor(th.tanh(z))
hidden = hidden_and_cell[:, :self.hidden_size*self.num_layers]
hidden = hidden.view(-1, self.num_layers, self.hidden_size)
hidden = hidden.permute(1, 0, 2).contiguous()
cell = hidden_and_cell[:, self.hidden_size*self.num_layers:]
cell = cell.view(-1, self.num_layers, self.hidden_size)
cell = cell.permute(1, 0, 2).contiguous()
hidden_and_cell = (hidden, cell)
feats, hidden_and_cell = self.lstm(expanded_z, hidden_and_cell)
hidden, cell = hidden_and_cell
feats = feats.reshape(bs*steps, self.hidden_size)
# Construct the chain
end_points = self.point_predictor(feats).view(bs, steps, 1, 2)
        start_points = th.cat([
            # first stroke starts at the canvas center
            th.zeros(bs, 1, 1, 2, device=feats.device),
            # each subsequent stroke starts where the previous one ended
            end_points[:, :-1, :, :]], 1)
all_points = th.cat([start_points, end_points], 2)
all_alphas = self.alpha_predictor(feats).view(bs, steps)
all_widths = self.width_predictor(feats).view(bs, steps)
min_width = self.stroke_width[0]
max_width = self.stroke_width[1]
all_widths = (max_width - min_width) * all_widths + min_width
output, scenes = rendering.line_render(all_points, all_widths, all_alphas,
canvas_size=self.imsize)
# map to [-1, 1]
output = output*2.0 - 1.0
return output, scenes
class Generator(BaseModel):
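    """Raster baseline: a DCGAN-style decoder that reshapes the latent into a
    2x2 feature map and upsamples it to a 32x32 image; no vector output.
    """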
def __init__(self, width=64, imsize=32, zdim=128,
stroke_width=None,
color_output=False,
num_strokes=4):
super(Generator, self).__init__()
assert imsize == 32
self.imsize = imsize
self.zdim = zdim
num_in_chans = self.zdim // (2*2)
num_out_chans = 3 if color_output else 1
self.net = th.nn.Sequential(
th.nn.ConvTranspose2d(num_in_chans, width*8, 4, padding=1,
stride=2),
th.nn.LeakyReLU(0.2, inplace=True),
th.nn.Conv2d(width*8, width*8, 3, padding=1),
th.nn.BatchNorm2d(width*8),
th.nn.LeakyReLU(0.2, inplace=True),
# 4x4
th.nn.ConvTranspose2d(8*width, 4*width, 4, padding=1, stride=2),
th.nn.LeakyReLU(0.2, inplace=True),
th.nn.Conv2d(4*width, 4*width, 3, padding=1),
th.nn.BatchNorm2d(width*4),
th.nn.LeakyReLU(0.2, inplace=True),
# 8x8
th.nn.ConvTranspose2d(4*width, 2*width, 4, padding=1, stride=2),
th.nn.LeakyReLU(0.2, inplace=True),
th.nn.Conv2d(2*width, 2*width, 3, padding=1),
th.nn.BatchNorm2d(width*2),
th.nn.LeakyReLU(0.2, inplace=True),
# 16x16
th.nn.ConvTranspose2d(2*width, width, 4, padding=1, stride=2),
th.nn.LeakyReLU(0.2, inplace=True),
th.nn.Conv2d(width, width, 3, padding=1),
th.nn.BatchNorm2d(width),
th.nn.LeakyReLU(0.2, inplace=True),
# 32x32
th.nn.Conv2d(width, width, 3, padding=1),
th.nn.BatchNorm2d(width),
th.nn.LeakyReLU(0.2, inplace=True),
th.nn.Conv2d(width, width, 3, padding=1),
th.nn.LeakyReLU(0.2, inplace=True),
th.nn.Conv2d(width, num_out_chans, 1),
th.nn.Tanh(),
)
def forward(self, z):
bs = z.shape[0]
num_in_chans = self.zdim // (2*2)
raster = self.net(z.view(bs, num_in_chans, 2, 2))
return raster
class Discriminator(th.nn.Module):
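    """Convolutional discriminator for 32x32 images.

    Strided convolutions downsample the input to 2x2 and a final linear layer
    produces a single score; all but the first two convolutions use spectral
    normalization.
    """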
def __init__(self, conditional=False, width=64, color_output=False):
super(Discriminator, self).__init__()
self.conditional = conditional
sn = th.nn.utils.spectral_norm
num_chan_in = 3 if color_output else 1
self.net = th.nn.Sequential(
th.nn.Conv2d(num_chan_in, width, 3, padding=1),
th.nn.LeakyReLU(0.2, inplace=True),
th.nn.Conv2d(width, 2*width, 4, padding=1, stride=2),
th.nn.LeakyReLU(0.2, inplace=True),
# 16x16
sn(th.nn.Conv2d(2*width, 2*width, 3, padding=1)),
th.nn.LeakyReLU(0.2, inplace=True),
sn(th.nn.Conv2d(2*width, 4*width, 4, padding=1, stride=2)),
th.nn.LeakyReLU(0.2, inplace=True),
# 8x8
sn(th.nn.Conv2d(4*width, 4*width, 3, padding=1)),
th.nn.LeakyReLU(0.2, inplace=True),
            sn(th.nn.Conv2d(4*width, 4*width, 4, padding=1, stride=2)),
th.nn.LeakyReLU(0.2, inplace=True),
# 4x4
sn(th.nn.Conv2d(4*width, 4*width, 3, padding=1)),
th.nn.LeakyReLU(0.2, inplace=True),
            sn(th.nn.Conv2d(4*width, 4*width, 4, padding=1, stride=2)),
th.nn.LeakyReLU(0.2, inplace=True),
# 2x2
modules.Flatten(),
th.nn.Linear(width*4*2*2, 1),
)
def forward(self, x):
out = self.net(x)
return out
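

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original training code):
    # instantiate each generator with its defaults, sample a latent batch and
    # report the raster shape. Assumes the `rendering` module and its diffvg
    # dependency are importable and that CPU execution is acceptable.
    batch = 2
    for model_cls in (BezierVectorGenerator, VectorGenerator,
                      RNNVectorGenerator, ChainRNNVectorGenerator, Generator):
        model = model_cls()
        latent = model.sample_z(batch)
        raster = model(latent)
        LOG.info("%s -> raster shape %s",
                 model_cls.__name__, tuple(raster.shape))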