import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from torch import nn

from modules.real3d.facev2v_warp.func_utils import apply_imagenet_normalization, apply_vggface_normalization


@torch.jit.script
def fuse_math_min_mean_pos(x):
    r"""Fuse operation min mean for hinge loss computation of positive samples.

    Computes ``-mean(min(x - 1, 0))``, i.e. ``mean(relu(1 - x))``.
    """
    minval = torch.min(x - 1, x * 0)
    loss = -torch.mean(minval)
    return loss


@torch.jit.script
def fuse_math_min_mean_neg(x):
    r"""Fuse operation min mean for hinge loss computation of negative samples.

    Computes ``-mean(min(-x - 1, 0))``, i.e. ``mean(relu(1 + x))``.
    """
    minval = torch.min(-x - 1, x * 0)
    loss = -torch.mean(minval)
    return loss


class _PerceptualNetwork(nn.Module):
    """Frozen feature extractor returning the activations of selected layers.

    Args:
        network (nn.Module): Sequential backbone whose children are applied in order.
        layer_name_mapping (dict[int, str]): Child index -> feature name.
        layers (iterable of str): Feature names to return from ``forward``.
    """

    def __init__(self, network, layer_name_mapping, layers):
        super().__init__()
        # Only move to the GPU when one exists, so the loss can also be built on CPU-only hosts.
        self.network = network.cuda() if torch.cuda.is_available() else network
        self.layer_name_mapping = layer_name_mapping
        self.layers = layers
        # The backbone is a fixed loss network: it is never trained.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Run ``x`` through the backbone and collect the requested activations.

        Returns:
            dict[str, Tensor]: One entry per requested name that appears in ``layer_name_mapping``.
        """
        wanted = set(self.layers)
        needed_indices = [i for i, name in self.layer_name_mapping.items() if name in wanted]
        output = {}
        if not needed_indices:
            return output
        last_needed = max(needed_indices)
        for i, layer in enumerate(self.network):
            x = layer(x)
            layer_name = self.layer_name_mapping.get(i, None)
            if layer_name in wanted:
                output[layer_name] = x
            # Layers after the deepest requested one cannot contribute to the output.
            if i == last_needed:
                break
        return output


def _vgg19(layers):
    """Build a frozen VGG-19 feature extractor from the torchvision checkpoint.

    Args:
        layers (iterable of str): ``relu_*`` feature names to return.
    """
    network = torchvision.models.vgg19()
    state_dict = torch.hub.load_state_dict_from_url(
        "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", map_location=torch.device("cpu"), progress=True
    )
    network.load_state_dict(state_dict)
    network = network.features
    # Indices of the ReLU modules inside `vgg19().features`.
    layer_name_mapping = {
        1: "relu_1_1",
        3: "relu_1_2",
        6: "relu_2_1",
        8: "relu_2_2",
        11: "relu_3_1",
        13: "relu_3_2",
        15: "relu_3_3",
        17: "relu_3_4",
        20: "relu_4_1",
        22: "relu_4_2",
        24: "relu_4_3",
        26: "relu_4_4",
        29: "relu_5_1",
    }
    return _PerceptualNetwork(network, layer_name_mapping, layers)


def _vgg_face(layers):
    """Build a frozen VGG-Face feature extractor on a torchvision ``vgg16(num_classes=2622)``.

    The downloaded checkpoint names its parameters ``conv1_1.weight``, ``fc6.bias``, ...;
    they are renamed to torchvision's ``features.<i>.*`` / ``classifier.<i>.*`` layout.

    Args:
        layers (iterable of str): ``relu_*`` feature names to return.
    """
    network = torchvision.models.vgg16(num_classes=2622)
    state_dict = torch.hub.load_state_dict_from_url(
        "http://www.robots.ox.ac.uk/~albanie/models/pytorch-mcn/" "vgg_face_dag.pth",
        map_location=torch.device("cpu"),
        progress=True,
    )
    feature_layer_name_mapping = {
        0: "conv1_1",
        2: "conv1_2",
        5: "conv2_1",
        7: "conv2_2",
        10: "conv3_1",
        12: "conv3_2",
        14: "conv3_3",
        17: "conv4_1",
        19: "conv4_2",
        21: "conv4_3",
        24: "conv5_1",
        26: "conv5_2",
        28: "conv5_3",
    }
    classifier_layer_name_mapping = {0: "fc6", 3: "fc7", 6: "fc8"}
    new_state_dict = {}
    for prefix, mapping in (("features", feature_layer_name_mapping), ("classifier", classifier_layer_name_mapping)):
        for index, source_name in mapping.items():
            for param in ("weight", "bias"):
                new_state_dict[prefix + "." + str(index) + "." + param] = state_dict[source_name + "." + param]
    network.load_state_dict(new_state_dict)
    # Indices of the ReLU modules inside `vgg16().features`.
    layer_name_mapping = {
        1: "relu_1_1",
        3: "relu_1_2",
        6: "relu_2_1",
        8: "relu_2_2",
        11: "relu_3_1",
        13: "relu_3_2",
        15: "relu_3_3",
        18: "relu_4_1",
        20: "relu_4_2",
        22: "relu_4_3",
        25: "relu_5_1",
    }
    return _PerceptualNetwork(network.features, layer_name_mapping, layers)


# Default per-layer weights of PerceptualLoss (copied per instance, never shared/mutated).
_DEFAULT_PERCEPTUAL_LAYERS_WEIGHT = {
    "relu_1_1": 0.03125,
    "relu_2_1": 0.0625,
    "relu_3_1": 0.125,
    "relu_4_1": 0.25,
    "relu_5_1": 1.0,
}


class PerceptualLoss(nn.Module):
    """Weighted L1 distance between VGG-19 and VGG-Face activations of two images.

    Args:
        layers_weight (dict[str, float] | None): Feature name -> weight. ``None`` selects
            ``relu_{1..5}_1`` with weights ``1/32, 1/16, 1/8, 1/4, 1``.
        n_scale (int): Number of extra 2x-downsampled VGG-19 scales.
        vgg19_loss_weight (float): Multiplier for the full-resolution VGG-19 terms.
        vggface_loss_weight (float): Multiplier for the VGG-Face terms.
    """

    # Both images are resized to this spatial size before feature extraction.
    _INPUT_SIZE = (512, 512)

    def __init__(
        self,
        layers_weight=None,
        n_scale=3,
        vgg19_loss_weight=1.0,
        vggface_loss_weight=1.0,
    ):
        super().__init__()
        if layers_weight is None:
            layers_weight = dict(_DEFAULT_PERCEPTUAL_LAYERS_WEIGHT)
        self.vgg19 = _vgg19(layers_weight.keys())
        self.vggface = _vgg_face(layers_weight.keys())
        self.mse_criterion = nn.MSELoss()
        self.criterion = nn.L1Loss()
        self.layers_weight, self.n_scale = layers_weight, n_scale
        self.vgg19_loss_weight = vgg19_loss_weight
        self.vggface_loss_weight = vggface_loss_weight
        self.vgg19.eval()
        self.vggface.eval()

    def _resize_to_input_size(self, x):
        """Bilinearly resize ``x`` to ``_INPUT_SIZE`` unless its last two dims already match.

        Raises:
            ValueError: If ``x`` must be resized but is not a 4-D ``[B, C, H, W]`` tensor.
        """
        if tuple(x.shape[-2:]) == self._INPUT_SIZE:
            return x
        if x.ndim != 4:
            raise ValueError(f"expected a 4-D [B, C, H, W] tensor to resize, got shape {tuple(x.shape)}")
        return F.interpolate(x, mode="bilinear", size=self._INPUT_SIZE, antialias=True, align_corners=False)

    @staticmethod
    def _nan_to_zero(term):
        """Return ``term``, or an unconnected zero tensor of the same shape if it contains NaN.

        Replacing (rather than masking) the term means no gradient flows through it at all.
        """
        if torch.any(torch.isnan(term)):
            return torch.zeros_like(term)
        return term

    def forward(self, input, target):
        """Compute the perceptual loss.

        Args:
            input (Tensor): [B, 3, H, W] in 0.~1. scale.
            target (Tensor): Same layout as ``input``; treated as a constant (detached features).

        Returns:
            Tensor: Scalar loss.
        """
        input = self._resize_to_input_size(input)
        target = self._resize_to_input_size(target)

        # A parent `.train()` call would flip the frozen backbones into training mode.
        self.vgg19.eval()
        self.vggface.eval()

        loss = 0
        features_vggface_input = self.vggface(apply_vggface_normalization(input))
        features_vggface_target = self.vggface(apply_vggface_normalization(target))
        input = apply_imagenet_normalization(input)
        target = apply_imagenet_normalization(target)
        features_vgg19_input = self.vgg19(input)
        features_vgg19_target = self.vgg19(target)
        for layer, weight in self.layers_weight.items():
            # NOTE(review): the /255 presumably compensates for VGG-Face activations being on a
            # 0-255 pixel scale after apply_vggface_normalization — confirm.
            loss = loss + self._nan_to_zero(
                self.vggface_loss_weight
                * weight
                * self.criterion(features_vggface_input[layer], features_vggface_target[layer].detach())
                / 255
            )
            loss = loss + self._nan_to_zero(
                self.vgg19_loss_weight
                * weight
                * self.criterion(features_vgg19_input[layer], features_vgg19_target[layer].detach())
            )

        # NOTE(review): at the downsampled scales only the LAST entry of `layers_weight`
        # (relu_5_1 by default) is compared, and `vgg19_loss_weight` is not applied. This is
        # kept as-is so existing loss magnitudes do not change; confirm whether summing over
        # all layers was intended.
        if self.layers_weight and self.n_scale > 0:
            pyramid_layer, pyramid_weight = list(self.layers_weight.items())[-1]
            for _ in range(self.n_scale):
                input = F.interpolate(
                    input, mode="bilinear", scale_factor=0.5, align_corners=False, recompute_scale_factor=True
                )
                target = F.interpolate(
                    target, mode="bilinear", scale_factor=0.5, align_corners=False, recompute_scale_factor=True
                )
                features_vgg19_input = self.vgg19(input)
                features_vgg19_target = self.vgg19(target)
                loss = loss + self._nan_to_zero(
                    pyramid_weight
                    * self.criterion(
                        features_vgg19_input[pyramid_layer], features_vgg19_target[pyramid_layer].detach()
                    )
                )
        return loss


class GANLoss(nn.Module):
    # Update generator: gan_loss(fake_output, True, False) + other losses
    # Update discriminator: gan_loss(fake_output(detached), False, True) + gan_loss(real_output, True, True)
    def __init__(self):
        super().__init__()

    def forward(self, dis_output, t_real, dis_update=True):
        r"""Hinge GAN loss computation.

        Args:
            dis_output (tensor or list of tensors): Discriminator outputs. A list (e.g. one
                entry per discriminator scale) is averaged: the loss is computed per entry,
                then divided by the number of entries.
            t_real (bool): If ``True``, uses the real label as target, otherwise uses the fake label as target.
            dis_update (bool): If ``True``, the loss will be used to update the discriminator,
                otherwise the generator.

        Returns:
            loss (tensor): Loss value.

        Raises:
            ValueError: If ``dis_output`` is an empty list/tuple.
        """
        if isinstance(dis_output, (list, tuple)):
            if not dis_output:
                raise ValueError("dis_output must not be empty")
            return sum(self.forward(output, t_real, dis_update) for output in dis_output) / len(dis_output)
        if dis_update:
            if t_real:
                loss = fuse_math_min_mean_pos(dis_output)
            else:
                loss = fuse_math_min_mean_neg(dis_output)
        else:
            loss = -torch.mean(dis_output)
        return loss


class FeatureMatchingLoss(nn.Module):
    """L1 distance between intermediate discriminator features of fake and real samples.

    ``fake_features[i][j]`` is feature ``j`` of discriminator ``i``; each discriminator's
    contribution is weighted by ``1 / len(fake_features)``.
    """

    def __init__(self):
        super().__init__()
        self.criterion = nn.L1Loss()

    def forward(self, fake_features, real_features):
        num_d = len(fake_features)
        dis_weight = 1.0 / num_d
        # Start from a tensor with the features' dtype/device.
        loss = fake_features[0][0].new_tensor(0)
        for i in range(num_d):
            for j in range(len(fake_features[i])):
                tmp_loss = self.criterion(fake_features[i][j], real_features[i][j].detach())
                loss += dis_weight * tmp_loss
        return loss


class EquivarianceLoss(nn.Module):
    """L1 distance between the first two coordinates of ``kp_d`` and ``reverse_kp``."""

    def __init__(self):
        super().__init__()
        self.criterion = nn.L1Loss()

    def forward(self, kp_d, reverse_kp):
        # NOTE(review): assumes kp_d is [B, K, >=3] and reverse_kp is [B, K, 2] — confirm.
        loss = self.criterion(kp_d[:, :, :2], reverse_kp)
        return loss


class KeypointPriorLoss(nn.Module):
    """Keypoint prior: penalize keypoint pairs closer than ``Dt`` and mean depth far from ``zt``.

    Args:
        Dt (float): Threshold on the squared pairwise distance.
        zt (float): Target mean of the third keypoint coordinate.
    """

    def __init__(self, Dt=0.1, zt=0.33):
        super().__init__()
        self.Dt, self.zt = Dt, zt

    def forward(self, kp_d):
        # use distance matrix to avoid loop
        dist_mat = torch.cdist(kp_d, kp_d).square()
        # The diagonal (self-distance 0) contributes exactly Dt per keypoint, which the
        # final `- K * Dt` term cancels.
        loss = (
            torch.max(0 * dist_mat, self.Dt - dist_mat).sum((1, 2)).mean()
            + torch.abs(kp_d[:, :, 2].mean(1) - self.zt).mean()
            - kp_d.shape[1] * self.Dt
        )
        return loss


class HeadPoseLoss(nn.Module):
    """Mean L1 error of yaw/pitch/roll, scaled by 180/pi (degrees if inputs are radians — confirm)."""

    def __init__(self):
        super().__init__()
        self.criterion = nn.L1Loss()

    def forward(self, yaw, pitch, roll, real_yaw, real_pitch, real_roll):
        loss = (
            self.criterion(yaw, real_yaw.detach())
            + self.criterion(pitch, real_pitch.detach())
            + self.criterion(roll, real_roll.detach())
        ) / 3
        return loss / np.pi * 180


class DeformationPriorLoss(nn.Module):
    """Mean absolute value of the deformation ``delta_d``."""

    def __init__(self):
        super().__init__()

    def forward(self, delta_d):
        loss = delta_d.abs().mean()
        return loss


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    loss_fn = PerceptualLoss().to(device)
    # PerceptualLoss expects images in [0, 1].
    x1 = torch.rand([4, 3, 512, 512], device=device)
    x2 = torch.rand([4, 3, 512, 512], device=device)
    loss = loss_fn(x1, x2)
    print(loss)