import torch |
import torch.nn.functional as F |
import numpy as np |
import torchvision |
from torch import nn |
from modules.real3d.facev2v_warp.func_utils import apply_imagenet_normalization, apply_vggface_normalization |
@torch.jit.script
def fuse_math_min_mean_pos(x):
    r"""Hinge-loss term for positive samples, compiled with TorchScript.

    Evaluates ``-mean(min(x - 1, 0))``, which equals the mean of
    ``max(1 - x, 0)``: elements that are at least 1 contribute nothing.

    Args:
        x (tensor): Scores to penalise.
    Returns:
        (tensor): Scalar loss.
    """
    # ``x * 0`` is a zero tensor with the same shape, dtype and device as x.
    clipped = torch.min(x - 1, x * 0)
    return -torch.mean(clipped)
@torch.jit.script
def fuse_math_min_mean_neg(x):
    r"""Hinge-loss term for negative samples, compiled with TorchScript.

    Evaluates ``-mean(min(-x - 1, 0))``, which equals the mean of
    ``max(1 + x, 0)``: elements that are at most -1 contribute nothing.

    Args:
        x (tensor): Scores to penalise.
    Returns:
        (tensor): Scalar loss.
    """
    # ``x * 0`` is a zero tensor with the same shape, dtype and device as x.
    clipped = torch.min(-x - 1, x * 0)
    return -torch.mean(clipped)
class _PerceptualNetwork(nn.Module):
    r"""Frozen feature extractor that returns selected intermediate outputs.

    Args:
        network (iterable module, e.g. ``nn.Sequential``): Layers that are
            applied one after another.
        layer_name_mapping (dict): Maps a layer index inside ``network`` to
            the name under which its output is reported.
        layers (container of str): Names of the outputs to return.
    """

    def __init__(self, network, layer_name_mapping, layers):
        super().__init__()
        # Keep the historical behaviour of placing the network on the GPU,
        # but fall back to the CPU instead of crashing when CUDA is absent.
        self.network = network.cuda() if torch.cuda.is_available() else network
        self.layer_name_mapping = layer_name_mapping
        self.layers = layers
        # The extractor is only used to compute features; never train it.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        r"""Run ``x`` through every layer and collect the requested outputs.

        Args:
            x (tensor): Input passed to the first layer of the network.
        Returns:
            (dict): Maps each requested layer name to that layer's output.
        """
        output = {}
        for i, layer in enumerate(self.network):
            x = layer(x)
            # Indices without a name map to None, which is not requested.
            layer_name = self.layer_name_mapping.get(i, None)
            if layer_name in self.layers:
                output[layer_name] = x
        return output
def _vgg19(layers):
    r"""Build a VGG-19 feature extractor with downloaded pretrained weights.

    Args:
        layers (container of str): Names of the ReLU outputs to expose,
            e.g. ``"relu_3_1"``.
    Returns:
        (_PerceptualNetwork): Wrapper around ``vgg19.features``.
    """
    vgg = torchvision.models.vgg19()
    weights = torch.utils.model_zoo.load_url(
        "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
        map_location=torch.device("cpu"),
        progress=True,
    )
    vgg.load_state_dict(weights)
    features = vgg.features

    # Indices of the ReLU layers inside ``vgg19.features``, grouped by block.
    relu_indices_per_block = (
        (1, (1, 3)),
        (2, (6, 8)),
        (3, (11, 13, 15, 17)),
        (4, (20, 22, 24, 26)),
        (5, (29,)),
    )
    layer_name_mapping = {
        index: f"relu_{block_id}_{position}"
        for block_id, indices in relu_indices_per_block
        for position, index in enumerate(indices, start=1)
    }
    return _PerceptualNetwork(features, layer_name_mapping, layers)
def _vgg_face(layers):
    r"""Build a VGG-16 feature extractor loaded with VGG-Face weights.

    The downloaded checkpoint names its tensors ``conv1_1``, ``fc6`` and so
    on, so they are renamed to torchvision's ``features.N`` /
    ``classifier.N`` keys before loading.

    Args:
        layers (container of str): Names of the ReLU outputs to expose,
            e.g. ``"relu_3_1"``.
    Returns:
        (_PerceptualNetwork): Wrapper around ``vgg16.features``.
    """
    network = torchvision.models.vgg16(num_classes=2622)
    state_dict = torch.utils.model_zoo.load_url(
        "http://www.robots.ox.ac.uk/~albanie/models/pytorch-mcn/vgg_face_dag.pth",
        map_location=torch.device("cpu"),
        progress=True,
    )

    # torchvision index -> checkpoint tensor-name prefix.
    conv_names = {
        0: "conv1_1",
        2: "conv1_2",
        5: "conv2_1",
        7: "conv2_2",
        10: "conv3_1",
        12: "conv3_2",
        14: "conv3_3",
        17: "conv4_1",
        19: "conv4_2",
        21: "conv4_3",
        24: "conv5_1",
        26: "conv5_2",
        28: "conv5_3",
    }
    fc_names = {0: "fc6", 3: "fc7", 6: "fc8"}

    remapped = {}
    for prefix, names in (("features.", conv_names), ("classifier.", fc_names)):
        for index, source_name in names.items():
            for suffix in (".weight", ".bias"):
                remapped[prefix + str(index) + suffix] = state_dict[source_name + suffix]
    network.load_state_dict(remapped)

    # Indices of the ReLU layers inside ``vgg16.features``, grouped by block.
    relu_indices_per_block = (
        (1, (1, 3)),
        (2, (6, 8)),
        (3, (11, 13, 15)),
        (4, (18, 20, 22)),
        (5, (25,)),
    )
    layer_name_mapping = {
        index: f"relu_{block_id}_{position}"
        for block_id, indices in relu_indices_per_block
        for position, index in enumerate(indices, start=1)
    }
    return _PerceptualNetwork(network.features, layer_name_mapping, layers)
class PerceptualLoss(nn.Module):
    r"""Perceptual loss built on VGG-19 and VGG-Face feature maps.

    At full resolution the weighted L1 distance between input and target
    features is summed over every configured layer for both networks. The
    images are then halved ``n_scale`` times and, at each reduced scale, one
    more VGG-19 term is added for the last configured layer only.

    Args:
        layers_weight (dict or None): Maps feature-layer name to its weight.
            ``None`` selects the default ``relu_1_1`` .. ``relu_5_1`` table.
        n_scale (int): Number of additional half-resolution scales.
        vgg19_loss_weight (float): Multiplier for full-resolution VGG-19 terms.
        vggface_loss_weight (float): Multiplier for the VGG-Face terms.
    """

    def __init__(
        self,
        layers_weight=None,
        n_scale=3,
        vgg19_loss_weight=1.0,
        vggface_loss_weight=1.0,
    ):
        super().__init__()
        # Build the default per call rather than sharing one mutable dict
        # object between all instances through the signature.
        if layers_weight is None:
            layers_weight = {
                "relu_1_1": 0.03125,
                "relu_2_1": 0.0625,
                "relu_3_1": 0.125,
                "relu_4_1": 0.25,
                "relu_5_1": 1.0,
            }
        self.vgg19 = _vgg19(layers_weight.keys())
        self.vggface = _vgg_face(layers_weight.keys())
        self.mse_criterion = nn.MSELoss()
        self.criterion = nn.L1Loss()
        self.layers_weight, self.n_scale = layers_weight, n_scale
        self.vgg19_loss_weight = vgg19_loss_weight
        self.vggface_loss_weight = vggface_loss_weight
        self.vgg19.eval()
        self.vggface.eval()

    @staticmethod
    def _zero_if_nan(term):
        r"""Return ``term``, or a zero of the same shape if it contains NaN.

        A NaN term is replaced by a fresh zero tensor, so it contributes
        neither a value nor a gradient to the total.
        """
        if torch.any(torch.isnan(term)):
            return torch.zeros_like(term)
        return term

    def forward(self, input, target):
        r"""Compute the perceptual loss between ``input`` and ``target``.

        Args:
            input (tensor): ``[B, 3, H, W]`` image batch in 0..1 scale.
            target (tensor): Reference batch; assumed to have the same shape
                and scale as ``input`` (TODO confirm with callers).
        Returns:
            Scalar loss tensor, or the integer ``0`` when no term was added.
        """
        # Only the width is inspected; both tensors are resized together.
        if input.shape[-1] != 512:
            assert input.ndim == 4, "expected a 4-D [B, 3, H, W] input"
            input = F.interpolate(input, mode="bilinear", size=(512,512), antialias=True, align_corners=False)
            target = F.interpolate(target, mode="bilinear", size=(512,512), antialias=True, align_corners=False)

        # A parent ``.train()`` call would flip the extractors back to
        # training mode, so eval mode is re-asserted on every forward pass.
        self.vgg19.eval()
        self.vggface.eval()

        loss = 0
        face_input = self.vggface(apply_vggface_normalization(input))
        face_target = self.vggface(apply_vggface_normalization(target))
        input = apply_imagenet_normalization(input)
        target = apply_imagenet_normalization(target)
        vgg_input = self.vgg19(input)
        vgg_target = self.vgg19(target)

        for layer, weight in self.layers_weight.items():
            # NOTE(review): the /255 presumably compensates for the value
            # range produced by apply_vggface_normalization -- confirm.
            face_term = self.vggface_loss_weight * weight * self.criterion(face_input[layer], face_target[layer].detach()) / 255
            loss += self._zero_if_nan(face_term)
            vgg_term = self.vgg19_loss_weight * weight * self.criterion(vgg_input[layer], vgg_target[layer].detach())
            loss += self._zero_if_nan(vgg_term)

        if not self.layers_weight:
            # No layer configured: there is nothing to evaluate at lower scales.
            return loss

        # The reduced-scale terms use the last configured layer and its
        # weight. The selection is explicit here instead of relying on loop
        # variables left over from the per-layer loop above. As before,
        # ``vgg19_loss_weight`` is not applied to these terms.
        last_layer, last_weight = list(self.layers_weight.items())[-1]
        for _ in range(self.n_scale):
            input = F.interpolate(input, mode="bilinear", scale_factor=0.5, align_corners=False, recompute_scale_factor=True)
            target = F.interpolate(target, mode="bilinear", scale_factor=0.5, align_corners=False, recompute_scale_factor=True)
            vgg_input = self.vgg19(input)
            vgg_target = self.vgg19(target)
            scale_term = last_weight * self.criterion(vgg_input[last_layer], vgg_target[last_layer].detach())
            loss += self._zero_if_nan(scale_term)
        return loss
class GANLoss(nn.Module):
    r"""Hinge GAN loss for discriminator and generator updates."""

    def __init__(self):
        super().__init__()

    def forward(self, dis_output, t_real, dis_update=True):
        r"""Compute the GAN loss for one discriminator output.

        Args:
            dis_output (tensor): Discriminator output.
            t_real (bool): If ``True`` the real label is the target,
                otherwise the fake label. Ignored when ``dis_update`` is
                ``False``.
            dis_update (bool): If ``True`` the loss updates the
                discriminator, otherwise the generator.
        Returns:
            loss (tensor): Loss value.
        """
        # Generator update: simply push the discriminator score up.
        if not dis_update:
            return -torch.mean(dis_output)
        # Discriminator update: hinge on the side chosen by the label.
        if t_real:
            return fuse_math_min_mean_pos(dis_output)
        return fuse_math_min_mean_neg(dis_output)
class FeatureMatchingLoss(nn.Module):
    r"""L1 distance between fake and real discriminator feature maps."""

    def __init__(self):
        super().__init__()
        self.criterion = nn.L1Loss()

    def forward(self, fake_features, real_features):
        r"""Sum the per-feature L1 losses, weighted by ``1 / len(fake_features)``.

        Args:
            fake_features (list of lists of tensors): Features for the fake
                sample, indexed as ``[discriminator][feature]``.
            real_features (list of lists of tensors): Matching features for
                the real sample; they are detached before comparison.
        Returns:
            loss (tensor): Scalar loss.
        """
        per_dis_weight = 1.0 / len(fake_features)
        # Start from a zero with the dtype and device of the first feature.
        total = fake_features[0][0].new_tensor(0)
        for d, fake_maps in enumerate(fake_features):
            for k, fake_map in enumerate(fake_maps):
                real_map = real_features[d][k].detach()
                total += per_dis_weight * self.criterion(fake_map, real_map)
        return total
class EquivarianceLoss(nn.Module):
    r"""L1 loss between the first two keypoint coordinates and a reference."""

    def __init__(self):
        super().__init__()
        self.criterion = nn.L1Loss()

    def forward(self, kp_d, reverse_kp):
        r"""Compare ``kp_d[:, :, :2]`` with ``reverse_kp``.

        Args:
            kp_d (tensor): Keypoints; only the first two entries of the third
                axis are used.
            reverse_kp (tensor): Reference the sliced keypoints are compared to.
        Returns:
            loss (tensor): Scalar L1 loss.
        """
        leading_coords = kp_d[:, :, :2]
        return self.criterion(leading_coords, reverse_kp)
class KeypointPriorLoss(nn.Module):
    r"""Prior that penalises crowded keypoints and a shifted mean third coordinate.

    Args:
        Dt (float): Threshold on the squared pairwise distance.
        zt (float): Target for the per-sample mean of coordinate index 2.
    """

    def __init__(self, Dt=0.1, zt=0.33):
        super().__init__()
        self.Dt = Dt
        self.zt = zt

    def forward(self, kp_d):
        r"""Compute the keypoint prior loss.

        Args:
            kp_d (tensor): Keypoints indexed as ``[batch, keypoint, coord]``
                with at least three coordinates.
        Returns:
            loss (tensor): Scalar loss.
        """
        sq_dist = torch.cdist(kp_d, kp_d).square()
        # Hinge: pairs closer than Dt (in squared distance) are penalised.
        crowding = torch.max(0 * sq_dist, self.Dt - sq_dist)
        crowding_term = crowding.sum((1, 2)).mean()
        depth_term = torch.abs(kp_d[:, :, 2].mean(1) - self.zt).mean()
        # Every keypoint is at distance zero from itself and so adds Dt to
        # the crowding sum; this constant removes that contribution.
        self_pair_offset = kp_d.shape[1] * self.Dt
        return crowding_term + depth_term - self_pair_offset
class HeadPoseLoss(nn.Module):
    r"""Mean L1 error over yaw, pitch and roll, scaled by ``180 / pi``."""

    def __init__(self):
        super().__init__()
        self.criterion = nn.L1Loss()

    def forward(self, yaw, pitch, roll, real_yaw, real_pitch, real_roll):
        r"""Average the three angle errors and rescale the result.

        The ``180 / pi`` factor converts radians to degrees, assuming the
        angles are given in radians (TODO confirm with callers).

        Args:
            yaw, pitch, roll (tensor): Predicted angles.
            real_yaw, real_pitch, real_roll (tensor): Reference angles; they
                are detached before comparison.
        Returns:
            loss (tensor): Scalar loss.
        """
        yaw_err = self.criterion(yaw, real_yaw.detach())
        pitch_err = self.criterion(pitch, real_pitch.detach())
        roll_err = self.criterion(roll, real_roll.detach())
        mean_err = (yaw_err + pitch_err + roll_err) / 3
        return mean_err / np.pi * 180
class DeformationPriorLoss(nn.Module):
    r"""Prior that keeps deformations small via their mean absolute value."""

    def __init__(self):
        super().__init__()

    def forward(self, delta_d):
        r"""Return the mean absolute value of ``delta_d``.

        Args:
            delta_d (tensor): Deformation values.
        Returns:
            loss (tensor): Scalar loss.
        """
        magnitude = torch.abs(delta_d)
        return torch.mean(magnitude)
if __name__ == '__main__':
    # Smoke test: build the perceptual loss and run it once on two random
    # image batches that are moved to the GPU.
    loss_fn = PerceptualLoss()
    batch_shape = [4, 3, 512, 512]
    x1 = torch.randn(batch_shape).cuda()
    x2 = torch.randn(batch_shape).cuda()
    loss = loss_fn(x1, x2)