SwinTExCo-cpu / src /losses.py
duongttr's picture
Upload app cpu version
62b9b3d
import torch
import torch.nn as nn
from src.utils import feature_normalize
### START### CONTEXTUAL LOSS ####
class ContextualLoss(nn.Module):
"""
input is Al, Bl, channel = 1, range ~ [0, 255]
"""
def __init__(self):
super(ContextualLoss, self).__init__()
return None
def forward(self, X_features, Y_features, h=0.1, feature_centering=True):
"""
X_features&Y_features are are feature vectors or feature 2d array
h: bandwidth
return the per-sample loss
"""
batch_size = X_features.shape[0]
feature_depth = X_features.shape[1]
# to normalized feature vectors
if feature_centering:
X_features = X_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze(dim=-1).unsqueeze(
dim=-1
)
Y_features = Y_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze(dim=-1).unsqueeze(
dim=-1
)
X_features = feature_normalize(X_features).view(
batch_size, feature_depth, -1
) # batch_size * feature_depth * feature_size^2
Y_features = feature_normalize(Y_features).view(
batch_size, feature_depth, -1
) # batch_size * feature_depth * feature_size^2
# conine distance = 1 - similarity
X_features_permute = X_features.permute(0, 2, 1) # batch_size * feature_size^2 * feature_depth
d = 1 - torch.matmul(X_features_permute, Y_features) # batch_size * feature_size^2 * feature_size^2
# normalized distance: dij_bar
d_norm = d / (torch.min(d, dim=-1, keepdim=True)[0] + 1e-5) # batch_size * feature_size^2 * feature_size^2
# pairwise affinity
w = torch.exp((1 - d_norm) / h)
A_ij = w / torch.sum(w, dim=-1, keepdim=True)
# contextual loss per sample
CX = torch.mean(torch.max(A_ij, dim=1)[0], dim=-1)
return -torch.log(CX)
class ContextualLoss_forward(nn.Module):
"""
input is Al, Bl, channel = 1, range ~ [0, 255]
"""
def __init__(self):
super(ContextualLoss_forward, self).__init__()
return None
def forward(self, X_features, Y_features, h=0.1, feature_centering=True):
"""
X_features&Y_features are are feature vectors or feature 2d array
h: bandwidth
return the per-sample loss
"""
batch_size = X_features.shape[0]
feature_depth = X_features.shape[1]
# to normalized feature vectors
if feature_centering:
X_features = X_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze(dim=-1).unsqueeze(
dim=-1
)
Y_features = Y_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze(dim=-1).unsqueeze(
dim=-1
)
X_features = feature_normalize(X_features).view(
batch_size, feature_depth, -1
) # batch_size * feature_depth * feature_size^2
Y_features = feature_normalize(Y_features).view(
batch_size, feature_depth, -1
) # batch_size * feature_depth * feature_size^2
# conine distance = 1 - similarity
X_features_permute = X_features.permute(0, 2, 1) # batch_size * feature_size^2 * feature_depth
d = 1 - torch.matmul(X_features_permute, Y_features) # batch_size * feature_size^2 * feature_size^2
# normalized distance: dij_bar
d_norm = d / (torch.min(d, dim=-1, keepdim=True)[0] + 1e-5) # batch_size * feature_size^2 * feature_size^2
# pairwise affinity
w = torch.exp((1 - d_norm) / h)
A_ij = w / torch.sum(w, dim=-1, keepdim=True)
# contextual loss per sample
CX = torch.mean(torch.max(A_ij, dim=-1)[0], dim=1)
return -torch.log(CX)
### END### CONTEXTUAL LOSS ####
##########################
def mse_loss_fn(input, target=0):
return torch.mean((input - target) ** 2)
### START### PERCEPTUAL LOSS ###
def Perceptual_loss(domain_invariant, weight_perceptual):
instancenorm = nn.InstanceNorm2d(512, affine=False)
def __call__(A_relu5_1, predict_relu5_1):
if domain_invariant:
feat_loss = (
mse_loss_fn(instancenorm(predict_relu5_1), instancenorm(A_relu5_1.detach())) * weight_perceptual * 1e5 * 0.2
)
else:
feat_loss = mse_loss_fn(predict_relu5_1, A_relu5_1.detach()) * weight_perceptual
return feat_loss
return __call__
### END### PERCEPTUAL LOSS ###
def l1_loss_fn(input, target=0):
return torch.mean(torch.abs(input - target))
### END#################
### START### ADVERSIAL LOSS ###
def generator_loss_fn(real_data_lab, fake_data_lab, discriminator, weight_gan, device):
if weight_gan > 0:
y_pred_fake, _ = discriminator(fake_data_lab)
y_pred_real, _ = discriminator(real_data_lab)
y = torch.ones_like(y_pred_real)
generator_loss = (
(
torch.mean((y_pred_real - torch.mean(y_pred_fake) + y) ** 2)
+ torch.mean((y_pred_fake - torch.mean(y_pred_real) - y) ** 2)
)
/ 2
* weight_gan
)
return generator_loss
return torch.Tensor([0]).to(device)
def discriminator_loss_fn(real_data_lab, fake_data_lab, discriminator):
y_pred_fake, _ = discriminator(fake_data_lab.detach())
y_pred_real, _ = discriminator(real_data_lab.detach())
y = torch.ones_like(y_pred_real)
discriminator_loss = (
torch.mean((y_pred_real - torch.mean(y_pred_fake) - y) ** 2)
+ torch.mean((y_pred_fake - torch.mean(y_pred_real) + y) ** 2)
) / 2
return discriminator_loss
### END### ADVERSIAL LOSS #####
def consistent_loss_fn(
I_current_lab_predict,
I_last_ab_predict,
I_current_nonlocal_lab_predict,
I_last_nonlocal_lab_predict,
flow_forward,
mask,
warping_layer,
weight_consistent=0.02,
weight_nonlocal_consistent=0.0,
device="cuda",
):
def weighted_mse_loss(input, target, weights):
out = (input - target) ** 2
out = out * weights.expand_as(out)
return out.mean()
def consistent():
I_current_lab_predict_warp = warping_layer(I_current_lab_predict, flow_forward)
I_current_ab_predict_warp = I_current_lab_predict_warp[:, 1:3, :, :]
consistent_loss = weighted_mse_loss(I_current_ab_predict_warp, I_last_ab_predict, mask) * weight_consistent
return consistent_loss
def nonlocal_consistent():
I_current_nonlocal_lab_predict_warp = warping_layer(I_current_nonlocal_lab_predict, flow_forward)
nonlocal_consistent_loss = (
weighted_mse_loss(
I_current_nonlocal_lab_predict_warp[:, 1:3, :, :],
I_last_nonlocal_lab_predict[:, 1:3, :, :],
mask,
)
* weight_nonlocal_consistent
)
return nonlocal_consistent_loss
consistent_loss = consistent() if weight_consistent else torch.Tensor([0]).to(device)
nonlocal_consistent_loss = nonlocal_consistent() if weight_nonlocal_consistent else torch.Tensor([0]).to(device)
return consistent_loss + nonlocal_consistent_loss
### END### CONSISTENCY LOSS #####
### START### SMOOTHNESS LOSS ###
def smoothness_loss_fn(
I_current_l,
I_current_lab,
I_current_ab_predict,
A_relu2_1,
weighted_layer_color,
nonlocal_weighted_layer,
weight_smoothness=5.0,
weight_nonlocal_smoothness=0.0,
device="cuda",
):
def smoothness(scale_factor=1.0):
I_current_lab_predict = torch.cat((I_current_l, I_current_ab_predict), dim=1)
IA_ab_weighed = weighted_layer_color(
I_current_lab,
I_current_lab_predict,
patch_size=3,
alpha=10,
scale_factor=scale_factor,
)
smoothness_loss = (
mse_loss_fn(
nn.functional.interpolate(I_current_ab_predict, scale_factor=scale_factor),
IA_ab_weighed,
)
* weight_smoothness
)
return smoothness_loss
def nonlocal_smoothness(scale_factor=0.25, alpha_nonlocal_smoothness=0.5):
nonlocal_smooth_feature = feature_normalize(A_relu2_1)
I_current_lab_predict = torch.cat((I_current_l, I_current_ab_predict), dim=1)
I_current_ab_weighted_nonlocal = nonlocal_weighted_layer(
I_current_lab_predict,
nonlocal_smooth_feature.detach(),
patch_size=3,
alpha=alpha_nonlocal_smoothness,
scale_factor=scale_factor,
)
nonlocal_smoothness_loss = (
mse_loss_fn(
nn.functional.interpolate(I_current_ab_predict, scale_factor=scale_factor),
I_current_ab_weighted_nonlocal,
)
* weight_nonlocal_smoothness
)
return nonlocal_smoothness_loss
smoothness_loss = smoothness() if weight_smoothness else torch.Tensor([0]).to(device)
nonlocal_smoothness_loss = nonlocal_smoothness() if weight_nonlocal_smoothness else torch.Tensor([0]).to(device)
return smoothness_loss + nonlocal_smoothness_loss
### END### SMOOTHNESS LOSS #####