Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
from src.utils import feature_normalize | |
### START### CONTEXTUAL LOSS #### | |
class ContextualLoss(nn.Module):
    """Contextual loss between two feature maps (target-side matching).

    Inputs are feature tensors shaped ``(batch, channels, *spatial)`` (e.g.
    network activations), not images. Every pair of source/target positions
    gets a normalized cosine affinity ``A_ij``. For each *target* position
    ``j`` the best-matching *source* position ``i`` is kept, and the loss is
    ``-log`` of the mean of those best matches.
    """

    def __init__(self):
        super().__init__()

    def forward(self, X_features, Y_features, h=0.1, feature_centering=True):
        """Compute the per-sample contextual loss.

        Args:
            X_features: source features, shape ``(batch, channels, *spatial)``.
            Y_features: target features with the same batch size, channel
                count and rank as ``X_features``.
            h: bandwidth of the exponential affinity (smaller is sharper).
            feature_centering: if True, subtract the per-sample, per-channel
                mean of ``Y_features`` from both inputs before normalizing.

        Returns:
            Tensor of shape ``(batch,)`` holding ``-log(CX)`` for each sample.
        """
        batch_size = X_features.shape[0]
        feature_depth = X_features.shape[1]

        if feature_centering:
            # Both inputs are centred on the *target* mean so they share one
            # reference point. Reshaping the mean to (batch, channels, 1, ..., 1)
            # broadcasts over any number of spatial dims, not only two.
            spatial_ones = (1,) * (Y_features.dim() - 2)
            y_mean = (
                Y_features.reshape(batch_size, feature_depth, -1)
                .mean(dim=-1)
                .reshape(batch_size, feature_depth, *spatial_ones)
            )
            X_features = X_features - y_mean
            Y_features = Y_features - y_mean

        # feature_normalize (src.utils) presumably L2-normalizes each feature
        # vector along the channel axis - confirm there; the cosine distance
        # below relies on unit-norm vectors. reshape (not view) also accepts
        # non-contiguous tensors. Result: (batch, channels, N).
        X_features = feature_normalize(X_features).reshape(batch_size, feature_depth, -1)
        Y_features = feature_normalize(Y_features).reshape(batch_size, feature_depth, -1)

        # Cosine distance d[b, i, j] = 1 - <x_i, y_j>: (batch, N_x, N_y).
        X_features_permute = X_features.permute(0, 2, 1)
        d = 1 - torch.matmul(X_features_permute, Y_features)
        # Relative distance: divide each row i by its closest target match.
        d_norm = d / (torch.min(d, dim=-1, keepdim=True)[0] + 1e-5)
        # Pairwise affinity, normalized over target positions j.
        w = torch.exp((1 - d_norm) / h)
        A_ij = w / torch.sum(w, dim=-1, keepdim=True)
        # Best source match for every target j (max over dim 1), averaged over j.
        CX = torch.mean(torch.max(A_ij, dim=1)[0], dim=-1)
        return -torch.log(CX)
class ContextualLoss_forward(nn.Module):
    """Contextual loss between two feature maps (source-side matching).

    Identical to the target-side variant except for the final reduction. For
    each *source* position ``i`` the best-matching *target* position ``j`` is
    kept, and the loss is ``-log`` of the mean over ``i``. Inputs are feature
    tensors shaped ``(batch, channels, *spatial)``, not images.
    """

    def __init__(self):
        super().__init__()

    def forward(self, X_features, Y_features, h=0.1, feature_centering=True):
        """Compute the per-sample (source-side) contextual loss.

        Args:
            X_features: source features, shape ``(batch, channels, *spatial)``.
            Y_features: target features with the same batch size, channel
                count and rank as ``X_features``.
            h: bandwidth of the exponential affinity (smaller is sharper).
            feature_centering: if True, subtract the per-sample, per-channel
                mean of ``Y_features`` from both inputs before normalizing.

        Returns:
            Tensor of shape ``(batch,)`` holding ``-log(CX)`` for each sample.
        """
        batch_size = X_features.shape[0]
        feature_depth = X_features.shape[1]

        if feature_centering:
            # Both inputs are centred on the *target* mean so they share one
            # reference point. Reshaping the mean to (batch, channels, 1, ..., 1)
            # broadcasts over any number of spatial dims, not only two.
            spatial_ones = (1,) * (Y_features.dim() - 2)
            y_mean = (
                Y_features.reshape(batch_size, feature_depth, -1)
                .mean(dim=-1)
                .reshape(batch_size, feature_depth, *spatial_ones)
            )
            X_features = X_features - y_mean
            Y_features = Y_features - y_mean

        # feature_normalize (src.utils) presumably L2-normalizes each feature
        # vector along the channel axis - confirm there; the cosine distance
        # below relies on unit-norm vectors. reshape (not view) also accepts
        # non-contiguous tensors. Result: (batch, channels, N).
        X_features = feature_normalize(X_features).reshape(batch_size, feature_depth, -1)
        Y_features = feature_normalize(Y_features).reshape(batch_size, feature_depth, -1)

        # Cosine distance d[b, i, j] = 1 - <x_i, y_j>: (batch, N_x, N_y).
        X_features_permute = X_features.permute(0, 2, 1)
        d = 1 - torch.matmul(X_features_permute, Y_features)
        # Relative distance: divide each row i by its closest target match.
        d_norm = d / (torch.min(d, dim=-1, keepdim=True)[0] + 1e-5)
        # Pairwise affinity, normalized over target positions j.
        w = torch.exp((1 - d_norm) / h)
        A_ij = w / torch.sum(w, dim=-1, keepdim=True)
        # Best target match for every source i (max over last dim), averaged over i.
        CX = torch.mean(torch.max(A_ij, dim=-1)[0], dim=1)
        return -torch.log(CX)
### END### CONTEXTUAL LOSS #### | |
########################## | |
def mse_loss_fn(input, target=0):
    """Return the mean squared error between ``input`` and ``target``.

    With the default ``target=0`` this is simply the mean of ``input ** 2``.
    """
    residual = input - target
    return torch.mean(torch.square(residual))
### START### PERCEPTUAL LOSS ### | |
def Perceptual_loss(domain_invariant, weight_perceptual):
    """Build a feature-reconstruction (perceptual) loss callable.

    Args:
        domain_invariant: if True, both feature maps are passed through a
            non-affine ``InstanceNorm2d`` before comparison, and the loss is
            additionally scaled by ``1e5 * 0.2``.
        weight_perceptual: multiplier applied to the mean squared error.

    Returns:
        A function ``(A_relu5_1, predict_relu5_1) -> loss``. It returns the
        weighted MSE between ``predict_relu5_1`` and ``A_relu5_1``; the
        reference ``A_relu5_1`` is detached so no gradient flows into it.
    """
    norm = nn.InstanceNorm2d(512, affine=False)

    def perceptual(A_relu5_1, predict_relu5_1):
        reference = A_relu5_1.detach()
        if not domain_invariant:
            return mse_loss_fn(predict_relu5_1, reference) * weight_perceptual
        # Instance norm removes each sample's per-channel mean/variance before
        # the two feature maps are compared.
        normalized_loss = mse_loss_fn(norm(predict_relu5_1), norm(reference))
        return normalized_loss * weight_perceptual * 1e5 * 0.2

    return perceptual
### END### PERCEPTUAL LOSS ### | |
def l1_loss_fn(input, target=0):
    """Return the mean absolute error between ``input`` and ``target``.

    With the default ``target=0`` this is the mean of ``|input|``.
    """
    residual = input - target
    return torch.abs(residual).mean()
### END################# | |
### START### ADVERSARIAL LOSS ###
def generator_loss_fn(real_data_lab, fake_data_lab, discriminator, weight_gan, device):
    """Relativistic-average least-squares GAN loss for the generator.

    ``discriminator`` must return a 2-tuple whose first item is the score
    tensor. It is called on the fake batch first, then on the real batch.

    Returns:
        ``weight_gan * (E[(D(real) - E[D(fake)] + 1)^2]
        + E[(D(fake) - E[D(real)] - 1)^2]) / 2`` when ``weight_gan > 0``;
        otherwise a float zero tensor of shape ``(1,)`` on ``device``.
    """
    if not weight_gan > 0:
        return torch.zeros(1, device=device)

    fake_score, _ = discriminator(fake_data_lab)
    real_score, _ = discriminator(real_data_lab)
    ones = torch.ones_like(real_score)

    # The generator wants real samples to look *less* real than the average
    # fake, and fakes to look *more* real than the average real sample.
    real_term = torch.mean((real_score - torch.mean(fake_score) + ones) ** 2)
    fake_term = torch.mean((fake_score - torch.mean(real_score) - ones) ** 2)
    return (real_term + fake_term) / 2 * weight_gan
def discriminator_loss_fn(real_data_lab, fake_data_lab, discriminator):
    """Relativistic-average least-squares GAN loss for the discriminator.

    Both inputs are detached, so gradients stop at the discriminator's inputs
    and do not reach whatever produced them. ``discriminator`` must return a
    2-tuple whose first item is the score tensor.

    Returns:
        ``(E[(D(real) - E[D(fake)] - 1)^2]
        + E[(D(fake) - E[D(real)] + 1)^2]) / 2``.
    """
    fake_score, _ = discriminator(fake_data_lab.detach())
    real_score, _ = discriminator(real_data_lab.detach())
    ones = torch.ones_like(real_score)

    real_term = torch.mean((real_score - torch.mean(fake_score) - ones) ** 2)
    fake_term = torch.mean((fake_score - torch.mean(real_score) + ones) ** 2)
    return (real_term + fake_term) / 2
### END### ADVERSARIAL LOSS #####
### START### CONSISTENCY LOSS ###
def consistent_loss_fn(
    I_current_lab_predict,
    I_last_ab_predict,
    I_current_nonlocal_lab_predict,
    I_last_nonlocal_lab_predict,
    flow_forward,
    mask,
    warping_layer,
    weight_consistent=0.02,
    weight_nonlocal_consistent=0.0,
    device="cuda",
):
    """Temporal-consistency loss between the current and previous frame.

    The current prediction is warped with ``warping_layer(x, flow_forward)``.
    Its channels 1:3 are then compared with the previous frame's prediction
    via a squared error weighted per element by ``mask`` (expanded to the
    error's shape) and averaged.

    Two terms are summed:
      * consistency: warped ``I_current_lab_predict[:, 1:3]`` vs
        ``I_last_ab_predict``, scaled by ``weight_consistent``;
      * non-local consistency: warped ``I_current_nonlocal_lab_predict[:, 1:3]``
        vs ``I_last_nonlocal_lab_predict[:, 1:3]``, scaled by
        ``weight_nonlocal_consistent``.

    A term whose weight is falsy is replaced by a float zero tensor of shape
    ``(1,)`` on ``device``, without calling ``warping_layer``.
    """

    def masked_mse(prediction, reference, weights):
        squared_error = (prediction - reference) ** 2
        return (squared_error * weights.expand_as(squared_error)).mean()

    if weight_consistent:
        warped = warping_layer(I_current_lab_predict, flow_forward)
        consistent_loss = masked_mse(warped[:, 1:3, :, :], I_last_ab_predict, mask) * weight_consistent
    else:
        consistent_loss = torch.zeros(1, device=device)

    if weight_nonlocal_consistent:
        warped_nonlocal = warping_layer(I_current_nonlocal_lab_predict, flow_forward)
        nonlocal_error = masked_mse(
            warped_nonlocal[:, 1:3, :, :],
            I_last_nonlocal_lab_predict[:, 1:3, :, :],
            mask,
        )
        nonlocal_consistent_loss = nonlocal_error * weight_nonlocal_consistent
    else:
        nonlocal_consistent_loss = torch.zeros(1, device=device)

    return consistent_loss + nonlocal_consistent_loss
### END### CONSISTENCY LOSS ##### | |
### START### SMOOTHNESS LOSS ### | |
def smoothness_loss_fn(
    I_current_l,
    I_current_lab,
    I_current_ab_predict,
    A_relu2_1,
    weighted_layer_color,
    nonlocal_weighted_layer,
    weight_smoothness=5.0,
    weight_nonlocal_smoothness=0.0,
    device="cuda",
):
    """Spatial smoothness loss on the predicted ab channels.

    Two terms are summed:
      * local smoothness: the predicted Lab image (``I_current_l`` concatenated
        with ``I_current_ab_predict`` along dim 1) goes through
        ``weighted_layer_color(I_current_lab, ..., patch_size=3, alpha=10,
        scale_factor=1.0)``. The MSE between that output and the prediction
        (interpolated at scale 1.0) is scaled by ``weight_smoothness``.
      * non-local smoothness: ``nonlocal_weighted_layer`` is guided by the
        detached ``feature_normalize(A_relu2_1)`` with ``patch_size=3``,
        ``alpha=0.5`` and ``scale_factor=0.25``. Its output is compared with
        the prediction downsampled by 0.25, and the MSE is scaled by
        ``weight_nonlocal_smoothness``.

    A term whose weight is falsy is replaced by a float zero tensor of shape
    ``(1,)`` on ``device`` and its layers are not called. The exact semantics
    of the two weighting layers live in their own implementations.
    """
    if weight_smoothness:
        lab_predict = torch.cat((I_current_l, I_current_ab_predict), dim=1)
        ab_reference = weighted_layer_color(
            I_current_lab,
            lab_predict,
            patch_size=3,
            alpha=10,
            scale_factor=1.0,
        )
        ab_resized = nn.functional.interpolate(I_current_ab_predict, scale_factor=1.0)
        smoothness_loss = mse_loss_fn(ab_resized, ab_reference) * weight_smoothness
    else:
        smoothness_loss = torch.zeros(1, device=device)

    if weight_nonlocal_smoothness:
        guide_features = feature_normalize(A_relu2_1)
        lab_predict = torch.cat((I_current_l, I_current_ab_predict), dim=1)
        ab_reference = nonlocal_weighted_layer(
            lab_predict,
            guide_features.detach(),
            patch_size=3,
            alpha=0.5,
            scale_factor=0.25,
        )
        ab_resized = nn.functional.interpolate(I_current_ab_predict, scale_factor=0.25)
        nonlocal_smoothness_loss = mse_loss_fn(ab_resized, ab_reference) * weight_nonlocal_smoothness
    else:
        nonlocal_smoothness_loss = torch.zeros(1, device=device)

    return smoothness_loss + nonlocal_smoothness_loss
### END### SMOOTHNESS LOSS ##### | |