Spaces:
Sleeping
Sleeping
"""Losses for the generative models and baselines.""" | |
import torch as th | |
import numpy as np | |
import ttools.modules.image_operators as imops | |
class KLDivergence(th.nn.Module):
    """KL-divergence regularizer with a lower bound on the returned value.

    The per-element term is
    ``-0.5 * (1 + log_sigma - mu^2 - exp(log_sigma))``, i.e. the KL divergence
    from a unit Gaussian when ``log_sigma`` is read as a log-variance. The
    terms are averaged over every element of the inputs.

    Args:
        min_value(float): the loss is clipped so that values below this
            number don't affect the optimization.
    """

    def __init__(self, min_value=0.2):
        super(KLDivergence, self).__init__()
        self.min_value = min_value

    def forward(self, mu, log_sigma):
        """Return the clipped scalar loss.

        Args:
            mu(th.Tensor): predicted means.
            log_sigma(th.Tensor): predicted log-scale term; must broadcast
                against `mu`.

        Returns:
            th.Tensor: scalar tensor, never smaller than `min_value`.
        """
        loss = -0.5 * (1.0 + log_sigma - mu.pow(2) - log_sigma.exp())
        loss = loss.mean()
        # Once the mean drops below the floor the result is constant, so no
        # gradient pushes the divergence any lower.
        return loss.clamp(min=self.min_value)
class MultiscaleMSELoss(th.nn.Module):
    """Sum of MSE losses over a blur-and-downsample image pyramid.

    The first level compares the inputs at full resolution; each following
    level compares them after a Gaussian blur and a 2x nearest-neighbor
    downsampling of the previous level.

    Args:
        channels(int): number of image channels, forwarded to the Gaussian
            blur operator.
    """

    def __init__(self, channels=3):
        super(MultiscaleMSELoss, self).__init__()
        self.blur = imops.GaussianBlur(1, channels=channels)

    def _downsample(self, im):
        """Blur `im`, then halve its spatial resolution."""
        return th.nn.functional.interpolate(
            self.blur(im), scale_factor=0.5, mode="nearest")

    def forward(self, im, target):
        """Return the multiscale loss between `im` and `target`.

        Args:
            im(th.Tensor): 4D tensor (unpacked as bs, c, h, w).
            target(th.Tensor): reference tensor compared against `im`.

        Returns:
            th.Tensor: scalar tensor, the sum of the per-level MSE losses.
        """
        _, _, h, _ = im.shape
        # The number of levels is driven by the image height.
        num_levels = max(int(np.ceil(np.log2(h))) - 2, 1)

        losses = []
        for lvl in range(num_levels):
            losses.append(th.nn.functional.mse_loss(im, target))

            # Nothing consumes a pyramid level past the last loss, and an
            # axis shorter than 2 pixels cannot be halved (the output size
            # would be 0), so stop instead of downsampling again.
            if lvl == num_levels - 1 or min(im.shape[-2:]) < 2:
                break

            im = self._downsample(im)
            target = self._downsample(target)

        return th.stack(losses).sum()
def gaussian_pdfs(dx, dy, params):
    """Returns the pdf at (dx, dy) for each Gaussian in the mixture.

    Args:
        dx(th.Tensor): x offsets at which the densities are evaluated.
        dy(th.Tensor): y offsets, same shape as `dx`.
        params(th.Tensor): parameters of the bivariate Gaussians. The last
            axis holds, in order: mean x, mean y, log of the x standard
            deviation, log of the y standard deviation, and the correlation
            before its tanh squashing. The remaining axes must broadcast
            against `dx` with one trailing axis appended.

    Returns:
        th.Tensor: density of every Gaussian at every (dx, dy), with the
        Gaussians indexed along the last axis.
    """
    dx = dx.unsqueeze(-1)  # replicate dx, dy to evaluate all pdfs at once
    dy = dy.unsqueeze(-1)

    mu_x = params[..., 0]
    mu_y = params[..., 1]
    sigma_x = params[..., 2].exp()
    sigma_y = params[..., 3].exp()
    rho_xy = th.tanh(params[..., 4])

    # tanh saturates to exactly +/-1 in floating point for large inputs,
    # which would make 1 - rho^2 zero and the density NaN (0 / 0). Flooring
    # at machine epsilon guards that degenerate case.
    one_minus_rho2 = (1.0 - rho_xy.pow(2)).clamp(
        min=th.finfo(rho_xy.dtype).eps)

    diff_x = dx - mu_x
    diff_y = dy - mu_y
    x = (diff_x / sigma_x).pow(2)
    y = (diff_y / sigma_y).pow(2)
    xy = diff_x * diff_y / (sigma_x * sigma_y)

    arg = x + y - 2.0 * rho_xy * xy
    pdf = th.exp(-arg / (2 * one_minus_rho2))
    norm = 2.0 * np.pi * sigma_x * sigma_y * one_minus_rho2.sqrt()
    return pdf / norm
class GaussianMixtureReconstructionLoss(th.nn.Module):
    """Stroke reconstruction loss: position likelihood plus pen-state term.

    The position term is the negative log-likelihood of the target offsets
    under a mixture of bivariate Gaussians, averaged over the steps that are
    not end-of-sequence tokens. The pen term is a cross-entropy between the
    predicted pen logits and the target pen state.

    Args:
        eps(float): constant added to the mixture likelihood before taking
            its log, so that a zero likelihood does not yield an infinite
            loss.
    """

    def __init__(self, eps=1e-5):
        super(GaussianMixtureReconstructionLoss, self).__init__()
        self.eps = eps

    def forward(self, pen_logits, mixture_logits, gaussian_params, targets):
        """Return the scalar loss ``position_loss + pen_loss``.

        Args:
            pen_logits(th.Tensor): unnormalized pen-state scores, classes on
                the last axis.
            mixture_logits(th.Tensor): unnormalized mixture weights,
                components on the last axis.
            gaussian_params(th.Tensor): per-component Gaussian parameters,
                passed as-is to `gaussian_pdfs`.
            targets(th.Tensor): last axis holds (dx, dy) followed by the pen
                state scores; the final entry flags end-of-sequence steps.

        Returns:
            th.Tensor: scalar tensor.
        """
        dx = targets[..., 0]
        dy = targets[..., 1]
        pen_state = targets[..., 2:].argmax(-1)  # target index

        # Likelihood loss on the stroke position
        # No need to predict accurate pen position for end-of-sequence tokens
        valid_stroke = (targets[..., -1] != 1.0).float()
        mixture_weights = th.nn.functional.softmax(mixture_logits, -1)
        pdfs = gaussian_pdfs(dx, dy, gaussian_params)
        position_loss = -th.log(
            self.eps + (pdfs * mixture_weights).sum(-1))

        # Normalize by the actual non-empty count. The count is floored at 1
        # so a batch made only of end-of-sequence steps gives a zero position
        # loss instead of 0 / 0 = NaN.
        num_valid = valid_stroke.sum().clamp(min=1.0)
        position_loss = (position_loss * valid_stroke).sum() / num_valid

        # Classification loss for the stroke mode. `reshape` also accepts
        # non-contiguous logits, and the class count is read from the input.
        num_states = pen_logits.shape[-1]
        pen_loss = th.nn.functional.cross_entropy(
            pen_logits.reshape(-1, num_states), pen_state.reshape(-1))

        return position_loss + pen_loss