"""Loss functions for the facev2v warping module: VGG19/VGGFace perceptual loss,
hinge GAN loss, feature matching, equivariance, and keypoint/pose/deformation priors."""
import torch
import torch.nn.functional as F
import numpy as np
import torchvision
from torch import nn

from modules.real3d.facev2v_warp.func_utils import apply_imagenet_normalization, apply_vggface_normalization


@torch.jit.script
def fuse_math_min_mean_pos(x):
    r"""Fused min/mean op for the hinge loss of positive (real) samples:
    -mean(min(x - 1, 0)) == mean(relu(1 - x))."""
    minval = torch.min(x - 1, x * 0)
    loss = -torch.mean(minval)
    return loss


@torch.jit.script
def fuse_math_min_mean_neg(x):
    r"""Fused min/mean op for the hinge loss of negative (fake) samples:
    -mean(min(-x - 1, 0)) == mean(relu(1 + x))."""
    minval = torch.min(-x - 1, x * 0)
    loss = -torch.mean(minval)
    return loss


class _PerceptualNetwork(nn.Module):
    r"""Wraps a feature extractor and returns the activations of the requested
    layers. The network is frozen and moved to the GPU."""

    def __init__(self, network, layer_name_mapping, layers):
        super().__init__()
        self.network = network.cuda()
        self.layer_name_mapping = layer_name_mapping
        self.layers = layers
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        output = {}
        for i, layer in enumerate(self.network):
            x = layer(x)
            layer_name = self.layer_name_mapping.get(i, None)
            if layer_name in self.layers:
                output[layer_name] = x
        return output


def _vgg19(layers):
    r"""Build a VGG19 feature extractor exposing the requested ReLU layers."""
    network = torchvision.models.vgg19()
    state_dict = torch.utils.model_zoo.load_url(
        "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", map_location=torch.device("cpu"), progress=True
    )
    network.load_state_dict(state_dict)
    network = network.features
    layer_name_mapping = {
        1: "relu_1_1",
        3: "relu_1_2",
        6: "relu_2_1",
        8: "relu_2_2",
        11: "relu_3_1",
        13: "relu_3_2",
        15: "relu_3_3",
        17: "relu_3_4",
        20: "relu_4_1",
        22: "relu_4_2",
        24: "relu_4_3",
        26: "relu_4_4",
        29: "relu_5_1",
    }
    return _PerceptualNetwork(network, layer_name_mapping, layers)


def _vgg_face(layers):
    r"""Build a VGG-Face (VGG16, 2622 identities) feature extractor exposing the
    requested ReLU layers."""
    network = torchvision.models.vgg16(num_classes=2622)
    state_dict = torch.utils.model_zoo.load_url(
        "http://www.robots.ox.ac.uk/~albanie/models/pytorch-mcn/vgg_face_dag.pth",
        map_location=torch.device("cpu"),
        progress=True,
    )
    feature_layer_name_mapping = {
        0: "conv1_1",
        2: "conv1_2",
        5: "conv2_1",
        7: "conv2_2",
        10: "conv3_1",
        12: "conv3_2",
        14: "conv3_3",
        17: "conv4_1",
        19: "conv4_2",
        21: "conv4_3",
        24: "conv5_1",
        26: "conv5_2",
        28: "conv5_3",
    }
    # Rename the vgg_face_dag parameters to match torchvision's VGG16 layout.
    new_state_dict = {}
    for k, v in feature_layer_name_mapping.items():
        new_state_dict["features." + str(k) + ".weight"] = state_dict[v + ".weight"]
        new_state_dict["features." + str(k) + ".bias"] = state_dict[v + ".bias"]
    classifier_layer_name_mapping = {0: "fc6", 3: "fc7", 6: "fc8"}
    for k, v in classifier_layer_name_mapping.items():
        new_state_dict["classifier." + str(k) + ".weight"] = state_dict[v + ".weight"]
        new_state_dict["classifier." + str(k) + ".bias"] = state_dict[v + ".bias"]
    network.load_state_dict(new_state_dict)
    layer_name_mapping = {
        1: "relu_1_1",
        3: "relu_1_2",
        6: "relu_2_1",
        8: "relu_2_2",
        11: "relu_3_1",
        13: "relu_3_2",
        15: "relu_3_3",
        18: "relu_4_1",
        20: "relu_4_2",
        22: "relu_4_3",
        25: "relu_5_1",
    }
    return _PerceptualNetwork(network.features, layer_name_mapping, layers)


class PerceptualLoss(nn.Module):
    def __init__(
        self,
        layers_weight={"relu_1_1": 0.03125, "relu_2_1": 0.0625, "relu_3_1": 0.125, "relu_4_1": 0.25, "relu_5_1": 1.0},
        n_scale=3,
        vgg19_loss_weight=1.0,
        vggface_loss_weight=1.0,
    ):
        super().__init__()
        self.vgg19 = _vgg19(layers_weight.keys())
        self.vggface = _vgg_face(layers_weight.keys())
        self.mse_criterion = nn.MSELoss()
        self.criterion = nn.L1Loss()
        self.layers_weight, self.n_scale = layers_weight, n_scale
        self.vgg19_loss_weight = vgg19_loss_weight
        self.vggface_loss_weight = vggface_loss_weight
        self.vgg19.eval()
        self.vggface.eval()

    def forward(self, input, target):
        """
        input, target: [B, 3, H, W], values in the 0~1 range
        """
        if input.shape[-1] != 512:
            assert input.ndim == 4
            input = F.interpolate(input, mode="bilinear", size=(512, 512), antialias=True, align_corners=False)
            target = F.interpolate(target, mode="bilinear", size=(512, 512), antialias=True, align_corners=False)

        self.vgg19.eval()
        self.vggface.eval()
        loss = 0
        features_vggface_input = self.vggface(apply_vggface_normalization(input))
        features_vggface_target = self.vggface(apply_vggface_normalization(target))
        input = apply_imagenet_normalization(input)
        target = apply_imagenet_normalization(target)
        features_vgg19_input = self.vgg19(input)
        features_vgg19_target = self.vgg19(target)
        for layer, weight in self.layers_weight.items():
            # VGGFace perceptual term (down-weighted by 1/255); NaN contributions are skipped.
            tmp = self.vggface_loss_weight * weight * self.criterion(features_vggface_input[layer], features_vggface_target[layer].detach()) / 255
            if not torch.any(torch.isnan(tmp)):
                loss += tmp
            else:
                loss += torch.zeros_like(tmp)
            # VGG19 perceptual term.
            tmp = self.vgg19_loss_weight * weight * self.criterion(features_vgg19_input[layer], features_vgg19_target[layer].detach())
            if not torch.any(torch.isnan(tmp)):
                loss += tmp
            else:
                loss += torch.zeros_like(tmp)
        # Multi-scale VGG19 term on a downsampled pyramid; reuses `layer` and `weight`
        # from the last iteration of the loop above (the deepest layer).
        for i in range(self.n_scale):
            input = F.interpolate(input, mode="bilinear", scale_factor=0.5, align_corners=False, recompute_scale_factor=True)
            target = F.interpolate(target, mode="bilinear", scale_factor=0.5, align_corners=False, recompute_scale_factor=True)
            features_vgg19_input = self.vgg19(input)
            features_vgg19_target = self.vgg19(target)
            tmp = weight * self.criterion(features_vgg19_input[layer], features_vgg19_target[layer].detach())
            if not torch.any(torch.isnan(tmp)):
                loss += tmp
            else:
                loss += torch.zeros_like(tmp)
        return loss


class GANLoss(nn.Module):
    r"""Hinge GAN loss."""

    def __init__(self):
        super().__init__()

    def forward(self, dis_output, t_real, dis_update=True):
        r"""GAN loss computation.
        Args:
            dis_output (tensor or list of tensors): Discriminator outputs.
            t_real (bool): If ``True``, uses the real label as target, otherwise
                uses the fake label as target.
            dis_update (bool): If ``True``, the loss will be used to update the
                discriminator, otherwise the generator.
        Returns:
            loss (tensor): Loss value.
        """
        if dis_update:
            if t_real:
                loss = fuse_math_min_mean_pos(dis_output)
            else:
                loss = fuse_math_min_mean_neg(dis_output)
        else:
            loss = -torch.mean(dis_output)
        return loss


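# Usage sketch for GANLoss (illustrative only; `dis_real` and `dis_fake` stand for
# discriminator outputs on real and generated images and are not defined in this module):
#   gan = GANLoss()
#   d_loss = gan(dis_real, t_real=True) + gan(dis_fake, t_real=False)   # discriminator update
#   g_loss = gan(dis_fake, t_real=True, dis_update=False)               # generator update

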
class FeatureMatchingLoss(nn.Module):
    r"""L1 feature-matching loss between fake and real discriminator features,
    averaged over discriminators."""

    def __init__(self):
        super().__init__()
        self.criterion = nn.L1Loss()

    def forward(self, fake_features, real_features):
        num_d = len(fake_features)
        dis_weight = 1.0 / num_d
        loss = fake_features[0][0].new_tensor(0)
        for i in range(num_d):
            for j in range(len(fake_features[i])):
                tmp_loss = self.criterion(fake_features[i][j], real_features[i][j].detach())
                loss += dis_weight * tmp_loss
        return loss


class EquivarianceLoss(nn.Module):
    r"""L1 equivariance constraint on the 2D (x, y) components of the keypoints."""

    def __init__(self):
        super().__init__()
        self.criterion = nn.L1Loss()

    def forward(self, kp_d, reverse_kp):
        loss = self.criterion(kp_d[:, :, :2], reverse_kp)
        return loss


class KeypointPriorLoss(nn.Module):
    r"""Keypoint prior: penalizes keypoint pairs whose squared distance falls below
    Dt and keeps the mean keypoint depth close to zt."""

    def __init__(self, Dt=0.1, zt=0.33):
        super().__init__()
        self.Dt, self.zt = Dt, zt

    def forward(self, kp_d):
        # Pairwise squared distances between keypoints, shape [B, K, K].
        dist_mat = torch.cdist(kp_d, kp_d).square()
        loss = (
            torch.max(0 * dist_mat, self.Dt - dist_mat).sum((1, 2)).mean()
            + torch.abs(kp_d[:, :, 2].mean(1) - self.zt).mean()
            # Subtract the diagonal (self-distance) contributions, each of which equals Dt.
            - kp_d.shape[1] * self.Dt
        )
        return loss


class HeadPoseLoss(nn.Module):
    r"""Mean L1 error over yaw/pitch/roll against detached reference angles."""

    def __init__(self):
        super().__init__()
        self.criterion = nn.L1Loss()

    def forward(self, yaw, pitch, roll, real_yaw, real_pitch, real_roll):
        loss = (self.criterion(yaw, real_yaw.detach()) + self.criterion(pitch, real_pitch.detach()) + self.criterion(roll, real_roll.detach())) / 3
        # Scale by 180/pi (radians -> degrees).
        return loss / np.pi * 180


class DeformationPriorLoss(nn.Module):
    r"""L1 prior that keeps the predicted keypoint deformations delta_d small."""

    def __init__(self):
        super().__init__()

    def forward(self, delta_d):
        loss = delta_d.abs().mean()
        return loss


if __name__ == '__main__':
    # Quick smoke test; requires a CUDA device since _PerceptualNetwork calls .cuda().
    loss_fn = PerceptualLoss()
    x1 = torch.randn([4, 3, 512, 512]).cuda()
    x2 = torch.randn([4, 3, 512, 512]).cuda()
    loss = loss_fn(x1, x2)
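    print("perceptual loss:", loss.item())

    # Extra quick checks (illustrative sketch): the tensors below are random
    # stand-ins with assumed shapes, not outputs of a real discriminator or
    # keypoint detector.
    gan_loss_fn = GANLoss()
    dis_out = torch.randn([4, 1, 30, 30])  # hypothetical patch-discriminator logits
    d_loss = gan_loss_fn(dis_out, t_real=True) + gan_loss_fn(dis_out, t_real=False)
    g_loss = gan_loss_fn(dis_out, t_real=True, dis_update=False)
    print("hinge GAN d_loss:", d_loss.item(), "g_loss:", g_loss.item())

    kp_prior_fn = KeypointPriorLoss()
    kp_d = torch.randn([4, 15, 3])  # assumed [B, num_keypoints, xyz] canonical keypoints
    print("keypoint prior loss:", kp_prior_fn(kp_d).item())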