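# Losses for the facev2v warping model: perceptual (VGG19 + VGG-Face), hinge GAN,
# feature matching, and keypoint/deformation/head-pose priors.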
import torch
import torch.nn.functional as F
import numpy as np
import torchvision
from torch import nn
from modules.real3d.facev2v_warp.func_utils import apply_imagenet_normalization, apply_vggface_normalization


@torch.jit.script
def fuse_math_min_mean_pos(x):
    r"""Fused min-mean operation for the hinge loss on real (positive) samples.
    Equivalent to mean(relu(1 - x))."""
    minval = torch.min(x - 1, x * 0)
    loss = -torch.mean(minval)
    return loss


@torch.jit.script
def fuse_math_min_mean_neg(x):
    r"""Fused min-mean operation for the hinge loss on fake (negative) samples.
    Equivalent to mean(relu(1 + x))."""
    minval = torch.min(-x - 1, x * 0)
    loss = -torch.mean(minval)
    return loss


class _PerceptualNetwork(nn.Module):
    r"""Wraps a feature extractor and returns the intermediate activations
    named in ``layers``, keyed via ``layer_name_mapping``."""

    def __init__(self, network, layer_name_mapping, layers):
        super().__init__()
        self.network = network.cuda()  # assumes a CUDA device is available
        self.layer_name_mapping = layer_name_mapping
        self.layers = layers
        # Freeze the backbone: perceptual features serve as targets, not trainable weights.
        for param in self.parameters():
            param.requires_grad = False

def forward(self, x):
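        r"""Run ``x`` through the backbone and collect the requested activations."""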
output = {}
for i, layer in enumerate(self.network):
x = layer(x)
layer_name = self.layer_name_mapping.get(i, None)
if layer_name in self.layers:
output[layer_name] = x
return output


def _vgg19(layers):
    r"""Build an ImageNet-pretrained VGG19 feature extractor for the perceptual loss."""
    network = torchvision.models.vgg19()
    state_dict = torch.utils.model_zoo.load_url(
        "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", map_location=torch.device("cpu"), progress=True
    )
    network.load_state_dict(state_dict)
    network = network.features  # keep only the convolutional trunk
    # Indices of the ReLU activations inside torchvision's VGG19 ``features`` module.
layer_name_mapping = {
1: "relu_1_1",
3: "relu_1_2",
6: "relu_2_1",
8: "relu_2_2",
11: "relu_3_1",
13: "relu_3_2",
15: "relu_3_3",
17: "relu_3_4",
20: "relu_4_1",
22: "relu_4_2",
24: "relu_4_3",
26: "relu_4_4",
29: "relu_5_1",
}
return _PerceptualNetwork(network, layer_name_mapping, layers)
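
# Example (hypothetical names): ``extractor = _vgg19(["relu_1_1"])`` builds the network;
# ``extractor(img)`` then returns {"relu_1_1": activation} for a CUDA batch ``img``
# that has already been ImageNet-normalized.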


def _vgg_face(layers):
    r"""Build a VGG-Face (VGG16 trained on 2622 identities) feature extractor."""
    network = torchvision.models.vgg16(num_classes=2622)
    state_dict = torch.utils.model_zoo.load_url(
        "http://www.robots.ox.ac.uk/~albanie/models/pytorch-mcn/vgg_face_dag.pth", map_location=torch.device("cpu"), progress=True
    )
    # Map torchvision layer indices to the parameter names used in the vgg_face_dag checkpoint.
feature_layer_name_mapping = {
0: "conv1_1",
2: "conv1_2",
5: "conv2_1",
7: "conv2_2",
10: "conv3_1",
12: "conv3_2",
14: "conv3_3",
17: "conv4_1",
19: "conv4_2",
21: "conv4_3",
24: "conv5_1",
26: "conv5_2",
28: "conv5_3",
}
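    # Rename checkpoint tensors from vgg_face_dag names (e.g. "conv1_1.weight")
    # to torchvision's "features.<idx>" / "classifier.<idx>" keys.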
new_state_dict = {}
for k, v in feature_layer_name_mapping.items():
new_state_dict["features." + str(k) + ".weight"] = state_dict[v + ".weight"]
new_state_dict["features." + str(k) + ".bias"] = state_dict[v + ".bias"]
classifier_layer_name_mapping = {0: "fc6", 3: "fc7", 6: "fc8"}
for k, v in classifier_layer_name_mapping.items():
new_state_dict["classifier." + str(k) + ".weight"] = state_dict[v + ".weight"]
new_state_dict["classifier." + str(k) + ".bias"] = state_dict[v + ".bias"]
network.load_state_dict(new_state_dict)
layer_name_mapping = {
1: "relu_1_1",
3: "relu_1_2",
6: "relu_2_1",
8: "relu_2_2",
11: "relu_3_1",
13: "relu_3_2",
15: "relu_3_3",
18: "relu_4_1",
20: "relu_4_2",
22: "relu_4_3",
25: "relu_5_1",
}
return _PerceptualNetwork(network.features, layer_name_mapping, layers)


class PerceptualLoss(nn.Module):
    r"""Multi-scale perceptual loss that compares VGG19 (ImageNet) and VGG-Face
    activations between generated and target images."""

    def __init__(
        self,
        layers_weight={"relu_1_1": 0.03125, "relu_2_1": 0.0625, "relu_3_1": 0.125, "relu_4_1": 0.25, "relu_5_1": 1.0},
        n_scale=3,
        vgg19_loss_weight=1.0,
        vggface_loss_weight=1.0,
    ):
        super().__init__()
self.vgg19 = _vgg19(layers_weight.keys())
self.vggface = _vgg_face(layers_weight.keys())
self.mse_criterion = nn.MSELoss()
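        # Note: only the L1 criterion below is used in forward(); mse_criterion is unused.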
self.criterion = nn.L1Loss()
self.layers_weight, self.n_scale = layers_weight, n_scale
self.vgg19_loss_weight = vgg19_loss_weight
self.vggface_loss_weight = vggface_loss_weight
self.vgg19.eval()
self.vggface.eval()

    def forward(self, input, target):
        """
        input, target: [B, 3, H, W] images in the [0, 1] range.
        """
        if input.shape[-1] != 512:
            assert input.ndim == 4
            input = F.interpolate(input, mode="bilinear", size=(512, 512), antialias=True, align_corners=False)
            target = F.interpolate(target, mode="bilinear", size=(512, 512), antialias=True, align_corners=False)
        self.vgg19.eval()
        self.vggface.eval()
        loss = 0
        # Face-identity and generic perceptual terms at full resolution.
features_vggface_input = self.vggface(apply_vggface_normalization(input))
features_vggface_target = self.vggface(apply_vggface_normalization(target))
input = apply_imagenet_normalization(input)
target = apply_imagenet_normalization(target)
features_vgg19_input = self.vgg19(input)
features_vgg19_target = self.vgg19(target)
        for layer, weight in self.layers_weight.items():
            # VGG-Face preprocessing works on a 0-255 scale, so this term is divided
            # by 255 to keep its magnitude comparable to the VGG19 term.
            tmp = self.vggface_loss_weight * weight * self.criterion(features_vggface_input[layer], features_vggface_target[layer].detach()) / 255
            # Guard against NaNs so a single bad batch cannot poison the accumulated loss.
            if not torch.any(torch.isnan(tmp)):
                loss += tmp
            else:
                loss += torch.zeros_like(tmp)
            tmp = self.vgg19_loss_weight * weight * self.criterion(features_vgg19_input[layer], features_vgg19_target[layer].detach())
            if not torch.any(torch.isnan(tmp)):
                loss += tmp
            else:
                loss += torch.zeros_like(tmp)
        # Multi-scale term. Note: ``layer`` and ``weight`` retain their values from the
        # last iteration of the loop above, so only the final entry of layers_weight
        # (relu_5_1 with the default weights) is compared at the coarser scales.
        for i in range(self.n_scale):
            input = F.interpolate(input, mode="bilinear", scale_factor=0.5, align_corners=False, recompute_scale_factor=True)
            target = F.interpolate(target, mode="bilinear", scale_factor=0.5, align_corners=False, recompute_scale_factor=True)
            features_vgg19_input = self.vgg19(input)
            features_vgg19_target = self.vgg19(target)
            tmp = weight * self.criterion(features_vgg19_input[layer], features_vgg19_target[layer].detach())
            if not torch.any(torch.isnan(tmp)):
                loss += tmp
            else:
                loss += torch.zeros_like(tmp)
return loss


class GANLoss(nn.Module):
    # Update generator:     gan_loss(fake_output, True, dis_update=False) + other losses
    # Update discriminator: gan_loss(fake_output.detach(), False, dis_update=True)
    #                       + gan_loss(real_output, True, dis_update=True)
    def __init__(self):
        super().__init__()

    def forward(self, dis_output, t_real, dis_update=True):
        r"""Hinge GAN loss computation.
        Args:
            dis_output (tensor): Discriminator output.
            t_real (bool): If ``True``, uses the real label as target, otherwise
                uses the fake label as target.
            dis_update (bool): If ``True``, the loss will be used to update the
                discriminator, otherwise the generator.
        Returns:
            loss (tensor): Loss value.
        """
        if dis_update:
            if t_real:
                loss = fuse_math_min_mean_pos(dis_output)
            else:
                loss = fuse_math_min_mean_neg(dis_output)
        else:
            # Generator update: non-saturating hinge, maximize the score on fakes.
            loss = -torch.mean(dis_output)
        return loss
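

# A minimal usage sketch for GANLoss (``discriminator``, ``fake_img`` and ``real_img``
# are placeholder names, not defined in this module):
#   gan_loss = GANLoss()
#   d_loss = gan_loss(discriminator(fake_img.detach()), False, dis_update=True) \
#          + gan_loss(discriminator(real_img), True, dis_update=True)
#   g_loss = gan_loss(discriminator(fake_img), True, dis_update=False)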


class FeatureMatchingLoss(nn.Module):
    r"""L1 distance between discriminator features of fake and real images,
    averaged over the ``num_d`` discriminators."""

    def __init__(self):
        super().__init__()
        self.criterion = nn.L1Loss()

    def forward(self, fake_features, real_features):
        num_d = len(fake_features)
        dis_weight = 1.0 / num_d
        loss = fake_features[0][0].new_tensor(0)
        for i in range(num_d):
            for j in range(len(fake_features[i])):
                tmp_loss = self.criterion(fake_features[i][j], real_features[i][j].detach())
                loss += dis_weight * tmp_loss
        return loss


class EquivarianceLoss(nn.Module):
    r"""L1 consistency between the (x, y) components of the driving keypoints
    ``kp_d`` and the keypoints recovered from the reverse transform."""

    def __init__(self):
        super().__init__()
        self.criterion = nn.L1Loss()

    def forward(self, kp_d, reverse_kp):
        loss = self.criterion(kp_d[:, :, :2], reverse_kp)
        return loss


class KeypointPriorLoss(nn.Module):
    r"""Prior on the 3D keypoints: keeps pairwise squared distances above ``Dt``
    and the mean keypoint depth close to ``zt``."""

    def __init__(self, Dt=0.1, zt=0.33):
        super().__init__()
        self.Dt, self.zt = Dt, zt

    def forward(self, kp_d):
        # Use a squared-distance matrix to avoid an explicit pairwise loop.
        dist_mat = torch.cdist(kp_d, kp_d).square()
        loss = (
            torch.max(0 * dist_mat, self.Dt - dist_mat).sum((1, 2)).mean()
            + torch.abs(kp_d[:, :, 2].mean(1) - self.zt).mean()
            # Each keypoint has zero distance to itself, so the diagonal contributes
            # exactly num_kp * Dt; subtract that constant offset.
            - kp_d.shape[1] * self.Dt
        )
        return loss


class HeadPoseLoss(nn.Module):
    r"""Mean L1 error between predicted and reference head-pose angles, in degrees."""

    def __init__(self):
        super().__init__()
        self.criterion = nn.L1Loss()

    def forward(self, yaw, pitch, roll, real_yaw, real_pitch, real_roll):
        loss = (self.criterion(yaw, real_yaw.detach()) + self.criterion(pitch, real_pitch.detach()) + self.criterion(roll, real_roll.detach())) / 3
        return loss / np.pi * 180  # radians -> degrees


class DeformationPriorLoss(nn.Module):
    r"""L1 prior that keeps the predicted keypoint deformations ``delta_d`` small."""

    def __init__(self):
        super().__init__()

    def forward(self, delta_d):
        loss = delta_d.abs().mean()
        return loss


if __name__ == '__main__':
    # Quick smoke test; requires a CUDA device since _PerceptualNetwork moves the backbone to GPU.
    loss_fn = PerceptualLoss()
    x1 = torch.randn([4, 3, 512, 512]).cuda()
    x2 = torch.randn([4, 3, 512, 512]).cuda()
    loss = loss_fn(x1, x2)
    print(loss)