# Metric3D: training/mono/model/losses/NormalBranchLoss.py
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from .depth_to_normal import Depth2Normal
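
# This module implements the losses around Metric3D's surface-normal branch:
# - NormalBranchLoss supervises predicted normals (3 normal channels + 1
#   kappa/confidence channel) against ground-truth normals, including a
#   GRU-iteration variant.
# - sample_points() draws a confidence-guided subset of pixels.
# - DeNoConsistencyLoss couples the two branches: it converts predicted depth to
#   normals via Depth2Normal and penalises disagreement with the normal branch.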
# Loss on predicted surface normals (with a per-pixel kappa confidence channel).
class NormalBranchLoss(nn.Module):
def __init__(self, loss_weight=1.0, data_type=['sfm', 'stereo', 'denselidar', 'denselidar_syn'], d2n_dataset=['ScanNetAll'], loss_fn='UG_NLL_ours', **kwargs):
"""loss_fn can be one of following:
- L1 - L1 loss (no uncertainty)
- L2 - L2 loss (no uncertainty)
- AL - Angular loss (no uncertainty)
- NLL_vMF - NLL of vonMF distribution
- NLL_ours - NLL of Angular vonMF distribution
- UG_NLL_vMF - NLL of vonMF distribution (+ pixel-wise MLP + uncertainty-guided sampling)
- UG_NLL_ours - NLL of Angular vonMF distribution (+ pixel-wise MLP + uncertainty-guided sampling)
- NLL_ours_GRU - NLL of Angular vonMF distribution for GRU sequence
"""
super(NormalBranchLoss, self).__init__()
self.loss_type = loss_fn
        if self.loss_type in ['L1', 'L2', 'AL', 'NLL_vMF', 'NLL_ours', 'UG_NLL_vMF']:
            # listed in the docstring for reference, but not implemented as top-level losses here
            raise NotImplementedError
elif self.loss_type in ['UG_NLL_ours']:
self.loss_fn = self.forward_UG
elif self.loss_type in ['NLL_ours_GRU', 'NLL_ours_GRU_auxi']:
self.loss_type = 'NLL_ours'
self.loss_fn = self.forward_GRU
self.loss_gamma = 0.9
            self.loss_weight_auxi = kwargs.get('loss_weight_auxi', 0.0)
else:
raise Exception('invalid loss type')
self.loss_weight = loss_weight
self.data_type = data_type
#self.d2n_dataset = d2n_dataset
#self.depth2normal = Depth2Normal()
def forward(self, **kwargs):
# device = kwargs['mask'].device
# B, _, H, W = kwargs['mask'].shape
# pad_mask = torch.zeros_like(kwargs['mask'], device=device)
# for b in range(B):
# pad = kwargs['pad'][b].squeeze()
# pad_mask[b, :, pad[0]:H-pad[1], pad[2]:W-pad[3]] = True
# loss = self.loss_fn(pad_mask=pad_mask, **kwargs)
loss = self.loss_fn(**kwargs)
return loss * self.loss_weight
def forward_GRU(self, normal_out_list, normal, target, mask, intrinsic, pad_mask=None, auxi_normal=None, **kwargs):
n_predictions = len(normal_out_list)
assert n_predictions >= 1
loss = 0.0
# device = pad_mask.device
# batches_dataset = kwargs['dataset']
# self.batch_with_d2n = torch.tensor([0 if batch_dataset not in self.d2n_dataset else 1 \
# for batch_dataset in batches_dataset], device=device)[:,None,None,None]
# scale = kwargs['scale'][:, None, None].float()
# normal_d2n, new_mask_d2n = self.depth2normal(target, intrinsic, pad_mask, scale)
gt_normal_mask = ~torch.all(normal == 0, dim=1, keepdim=True) & mask
        if auxi_normal is not None:
auxi_normal_mask = ~gt_normal_mask
#normal = normal * (1 - self.batch_with_d2n) + normal_d2n * self.batch_with_d2n
# gt_normal_mask = gt_normal_mask * (1 - self.batch_with_d2n) + mask * new_mask_d2n * self.batch_with_d2n
if gt_normal_mask.sum() < 10:
            if auxi_normal is None:
for norm_out in normal_out_list:
loss += norm_out.sum() * 0
return loss
        # We adjust the loss_gamma so it is consistent for any number of RAFT-Stereo
        # iterations; the exponent is guarded against n_predictions == 1.
        adjusted_loss_gamma = self.loss_gamma ** (15 / max(n_predictions - 1, 1))
        for i, norm_out in enumerate(normal_out_list):
            i_weight = adjusted_loss_gamma ** (n_predictions - i - 1)
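            # e.g. with loss_gamma = 0.9 the weights always run from 0.9**15 ~= 0.206
            # (first iteration) up to 1.0 (last iteration), regardless of n_predictions.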
curr_loss = self.forward_R(norm_out.clone(), normal, gt_normal_mask)
            if auxi_normal is not None:
auxi_loss = self.forward_R(norm_out.clone(), auxi_normal[:, :3], auxi_normal_mask)
curr_loss = curr_loss + self.loss_weight_auxi * auxi_loss
            if torch.isnan(curr_loss).item() or torch.isinf(curr_loss).item():
                curr_loss = 0 * torch.sum(norm_out)
                print(f'NormalBranchLoss forward_GRU NaN/Inf error, {curr_loss}')
loss += curr_loss * i_weight
return loss
def forward_R(self, norm_out, gt_norm, gt_norm_mask):
pred_norm, pred_kappa = norm_out[:, 0:3, :, :], norm_out[:, 3:, :, :]
if self.loss_type == 'L1':
l1 = torch.sum(torch.abs(gt_norm - pred_norm), dim=1, keepdim=True)
loss = torch.mean(l1[gt_norm_mask])
elif self.loss_type == 'L2':
l2 = torch.sum(torch.square(gt_norm - pred_norm), dim=1, keepdim=True)
loss = torch.mean(l2[gt_norm_mask])
elif self.loss_type == 'AL':
dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
valid_mask = gt_norm_mask[:, 0, :, :].float() \
* (dot.detach() < 0.999).float() \
* (dot.detach() > -0.999).float()
valid_mask = valid_mask > 0.0
al = torch.acos(dot[valid_mask])
loss = torch.mean(al)
elif self.loss_type == 'NLL_vMF':
dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
valid_mask = gt_norm_mask[:, 0, :, :].float() \
* (dot.detach() < 0.999).float() \
* (dot.detach() > -0.999).float()
valid_mask = valid_mask > 0.0
dot = dot[valid_mask]
kappa = pred_kappa[:, 0, :, :][valid_mask]
loss_pixelwise = - torch.log(kappa) \
- (kappa * (dot - 1)) \
+ torch.log(1 - torch.exp(- 2 * kappa))
loss = torch.mean(loss_pixelwise)
elif self.loss_type == 'NLL_ours':
dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
valid_mask = gt_norm_mask[:, 0, :, :].float() \
* (dot.detach() < 0.999).float() \
* (dot.detach() > -0.999).float()
valid_mask = valid_mask > 0.5
dot = dot[valid_mask]
kappa = pred_kappa[:, 0, :, :][valid_mask]
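            # per-pixel angular vMF NLL: -log(kappa^2 + 1) + kappa * theta + log(1 + exp(-kappa * pi)),
            # where theta = acos(<n_pred, n_gt>); larger kappa means a more confident prediction.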
loss_pixelwise = - torch.log(torch.square(kappa) + 1) \
+ kappa * torch.acos(dot) \
+ torch.log(1 + torch.exp(-kappa * np.pi))
loss = torch.mean(loss_pixelwise)
else:
raise Exception('invalid loss type')
return loss
def forward_UG(self, normal_pred_list, normal_coord_list, normal, mask, **kwargs):
gt_normal_mask = ~torch.all(normal == 0, dim=1, keepdim=True) & mask
# gt_norm = norms[0]
# gt_normal_mask = (gt_norm[:, 0:1, :, :] == 0) & (gt_norm[:, 1:2, :, :] == 0) & (gt_norm[:, 2:3, :, :] == 0)
# gt_normal_mask = ~gt_normal_mask
loss = 0.0
        if gt_normal_mask.sum() < 10:
for (pred, coord) in zip(normal_pred_list, normal_coord_list):
if pred is not None:
loss += pred.sum() * 0.
if coord is not None:
loss += coord.sum() * 0.
return loss
for (pred, coord) in zip(normal_pred_list, normal_coord_list):
if coord is None:
pred = F.interpolate(pred, size=[normal.size(2), normal.size(3)], mode='bilinear', align_corners=True)
pred_norm, pred_kappa = pred[:, 0:3, :, :], pred[:, 3:, :, :]
# if self.loss_type == 'UG_NLL_vMF':
# dot = torch.cosine_similarity(pred_norm, normal, dim=1)
# valid_mask = normal_mask[:, 0, :, :].float() \
# * (dot.detach() < 0.999).float() \
# * (dot.detach() > -0.999).float()
# valid_mask = valid_mask > 0.5
# # mask
# dot = dot[valid_mask]
# kappa = pred_kappa[:, 0, :, :][valid_mask]
# loss_pixelwise = - torch.log(kappa) \
# - (kappa * (dot - 1)) \
# + torch.log(1 - torch.exp(- 2 * kappa))
# loss = loss + torch.mean(loss_pixelwise)
if self.loss_type == 'UG_NLL_ours':
dot = torch.cosine_similarity(pred_norm, normal, dim=1)
valid_mask = gt_normal_mask[:, 0, :, :].float() \
* (dot.detach() < 0.999).float() \
* (dot.detach() > -0.999).float()
valid_mask = valid_mask > 0.5
dot = dot[valid_mask]
kappa = pred_kappa[:, 0, :, :][valid_mask]
loss_pixelwise = - torch.log(torch.square(kappa) + 1) \
+ kappa * torch.acos(dot) \
+ torch.log(1 + torch.exp(-kappa * np.pi))
loss = loss + torch.mean(loss_pixelwise)
else:
                    raise Exception('invalid loss type')
else:
# coord: B, 1, N, 2
# pred: B, 4, N
gt_norm_ = F.grid_sample(normal, coord, mode='nearest', align_corners=True) # (B, 3, 1, N)
gt_norm_mask_ = F.grid_sample(gt_normal_mask.float(), coord, mode='nearest', align_corners=True) # (B, 1, 1, N)
gt_norm_ = gt_norm_[:, :, 0, :] # (B, 3, N)
gt_norm_mask_ = gt_norm_mask_[:, :, 0, :] > 0.5 # (B, 1, N)
pred_norm, pred_kappa = pred[:, 0:3, :], pred[:, 3:, :]
# if self.loss_type == 'UG_NLL_vMF':
# dot = torch.cosine_similarity(pred_norm, gt_norm_, dim=1) # (B, N)
# valid_mask = gt_norm_mask_[:, 0, :].float() \
# * (dot.detach() < 0.999).float() \
# * (dot.detach() > -0.999).float()
# valid_mask = valid_mask > 0.5
# dot = dot[valid_mask]
# kappa = pred_kappa[:, 0, :][valid_mask]
# loss_pixelwise = - torch.log(kappa) \
# - (kappa * (dot - 1)) \
# + torch.log(1 - torch.exp(- 2 * kappa))
# loss = loss + torch.mean(loss_pixelwise)
if self.loss_type == 'UG_NLL_ours':
dot = torch.cosine_similarity(pred_norm, gt_norm_, dim=1) # (B, N)
valid_mask = gt_norm_mask_[:, 0, :].float() \
* (dot.detach() < 0.999).float() \
* (dot.detach() > -0.999).float()
valid_mask = valid_mask > 0.5
dot = dot[valid_mask]
kappa = pred_kappa[:, 0, :][valid_mask]
loss_pixelwise = - torch.log(torch.square(kappa) + 1) \
+ kappa * torch.acos(dot) \
+ torch.log(1 + torch.exp(-kappa * np.pi))
loss = loss + torch.mean(loss_pixelwise)
else:
                    raise Exception('invalid loss type')
return loss
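
# A minimal smoke-test sketch (illustration only, not part of the original training
# code). Shapes are assumptions inferred from forward_GRU: each entry of
# normal_out_list is (B, 4, H, W) = 3 normal channels + 1 kappa channel; normal is
# (B, 3, H, W); mask is a (B, 1, H, W) bool tensor. target/intrinsic are unused by
# this path, so dummies suffice.
def _demo_normal_branch_loss():
    B, H, W = 2, 8, 8
    loss_module = NormalBranchLoss(loss_fn='NLL_ours_GRU')
    normal_out_list = [torch.cat([F.normalize(torch.randn(B, 3, H, W), dim=1),
                                  torch.rand(B, 1, H, W) * 10], dim=1) for _ in range(3)]
    gt_normal = F.normalize(torch.randn(B, 3, H, W), dim=1)
    loss = loss_module(normal_out_list=normal_out_list, normal=gt_normal,
                       target=torch.ones(B, 1, H, W),
                       mask=torch.ones(B, 1, H, W, dtype=torch.bool),
                       intrinsic=torch.eye(3).expand(B, 3, 3))
    print('demo NormalBranchLoss:', float(loss))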
# confidence-guided sampling
@torch.no_grad()
def sample_points(init_normal, confidence_map, gt_norm_mask, sampling_ratio, beta=1):
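    """Sample N = sampling_ratio * H * W pixels per image: the top beta * N by
    confidence ("importance"), plus the remaining N - beta * N drawn uniformly at
    random ("coverage"). Returns a (B, 1, H, W) boolean mask. Note that
    confidence_map is modified in place (invalid pixels are set to -1e4)."""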
device = init_normal.device
B, _, H, W = init_normal.shape
N = int(sampling_ratio * H * W)
# confidence map
# confidence_map = init_normal[:, 3, :, :] # B, H, W
# gt_invalid_mask (B, H, W)
if gt_norm_mask is not None:
gt_invalid_mask = F.interpolate(gt_norm_mask.float(), size=[H, W], mode='nearest')
gt_invalid_mask = gt_invalid_mask < 0.5
confidence_map[gt_invalid_mask] = -1e4
# (B, H*W)
_, idx = confidence_map.view(B, -1).sort(1, descending=True)
# confidence sampling
if int(beta * N) > 0:
importance = idx[:, :int(beta * N)] # B, beta*N
# remaining
remaining = idx[:, int(beta * N):] # B, H*W - beta*N
# coverage
num_coverage = N - int(beta * N)
if num_coverage <= 0:
samples = importance
else:
coverage_list = []
for i in range(B):
idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N"
coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N
coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N
samples = torch.cat((importance, coverage), dim=1) # B, N
else:
# remaining
remaining = idx[:, :] # B, H*W
# coverage
num_coverage = N
coverage_list = []
for i in range(B):
            idx_c = torch.randperm(remaining.size()[1])  # shuffles all H*W indices
            coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1))  # 1, N
        coverage = torch.cat(coverage_list, dim=0)  # B, N
samples = coverage
# point coordinates
rows_int = samples // W # 0 for first row, H-1 for last row
# rows_float = rows_int / float(H-1) # 0 to 1.0
# rows_float = (rows_float * 2.0) - 1.0 # -1.0 to 1.0
cols_int = samples % W # 0 for first column, W-1 for last column
# cols_float = cols_int / float(W-1) # 0 to 1.0
# cols_float = (cols_float * 2.0) - 1.0 # -1.0 to 1.0
# point_coords = torch.zeros(B, 1, N, 2)
# point_coords[:, 0, :, 0] = cols_float # x coord
# point_coords[:, 0, :, 1] = rows_float # y coord
# point_coords = point_coords.to(device)
# return point_coords, rows_int, cols_int
sample_mask = torch.zeros((B,1,H,W), dtype=torch.bool, device=device)
for i in range(B):
sample_mask[i, :, rows_int[i,:], cols_int[i,:]] = True
return sample_mask
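
# A hedged illustration of sample_points (not part of the original file). With the
# default beta=1 every sample comes from the highest-confidence pixels, so the
# returned mask has exactly int(sampling_ratio * H * W) True entries per image.
# confidence is cloned because sample_points writes into it in place.
def _demo_sample_points():
    B, H, W = 1, 4, 4
    confidence = torch.rand(B, 1, H, W)
    valid = torch.ones(B, 1, H, W, dtype=torch.bool)
    sample_mask = sample_points(torch.randn(B, 4, H, W), confidence.clone(), valid,
                                sampling_ratio=0.5)
    assert sample_mask.sum().item() == int(0.5 * H * W)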
# depth-normal consistency loss
class DeNoConsistencyLoss(nn.Module):
def __init__(self, loss_weight=1.0, data_type=['stereo', 'lidar', 'denselidar', 'denselidar_nometric', 'denselidar_syn'], loss_fn='NLL_ours', \
sky_id=142, scale=1, norm_dataset=['Taskonomy', 'Matterport3D', 'Replica', 'Hypersim', 'NYU'], no_sky_dataset=['BigData', 'DIODE', 'Completion', 'Matterport3D'], disable_dataset=[], depth_detach=False, **kwargs):
"""loss_fn can be one of following:
- L1 - L1 loss (no uncertainty)
- L2 - L2 loss (no uncertainty)
- AL - Angular loss (no uncertainty)
- NLL_vMF - NLL of vonMF distribution
- NLL_ours - NLL of Angular vonMF distribution
- UG_NLL_vMF - NLL of vonMF distribution (+ pixel-wise MLP + uncertainty-guided sampling)
- UG_NLL_ours - NLL of Angular vonMF distribution (+ pixel-wise MLP + uncertainty-guided sampling)
- NLL_ours_GRU - NLL of Angular vonMF distribution for GRU sequence
- CEL - cosine embedding loss
- CEL_GRU
"""
super(DeNoConsistencyLoss, self).__init__()
self.loss_type = loss_fn
        if self.loss_type in ['L1', 'L2', 'NLL_vMF', 'UG_NLL_vMF', 'UG_NLL_ours']:
            raise NotImplementedError
elif self.loss_type in ['NLL_ours']:
self.loss_fn = self.forward_J # confidence Joint optimization
self.loss_gamma = 0.9
elif self.loss_type in ['AL', 'CEL', 'CEL_L2']:
self.loss_fn = self.forward_S # confidence Sample
elif self.loss_type in ['CEL_GRU']:
self.loss_fn = self.forward_S_GRU # gru
self.loss_gamma = 0.9
elif 'Search' in self.loss_type:
self.loss_fn = self.forward_S_Search
else:
raise Exception('invalid loss type')
self.loss_weight = loss_weight
self.data_type = data_type
self.sky_id = sky_id
        # Datasets without surface-normal GT get consistency weight 'scale'; datasets with GT keep weight 1.
self.nonorm_data_scale = scale
self.norm_dataset = norm_dataset
self.no_sky_dataset = no_sky_dataset
self.disable_dataset = disable_dataset
self.depth_detach = depth_detach
self.depth2normal = Depth2Normal()
def forward(self, **kwargs):
device = kwargs['mask'].device
batches_dataset = kwargs['dataset']
self.batch_with_norm = torch.tensor([self.nonorm_data_scale if batch_dataset not in self.norm_dataset else 1 \
for batch_dataset in batches_dataset], device=device)[:,None,None,None]
        self.batch_enabled = torch.tensor([1 if batch_dataset not in self.disable_dataset else 0 \
            for batch_dataset in batches_dataset], device=device, dtype=torch.bool)[:,None,None,None]
self.batch_with_norm = self.batch_with_norm * self.batch_enabled
self.batch_with_norm_sky = torch.tensor([1 if batch_dataset not in self.no_sky_dataset else 0 \
for batch_dataset in batches_dataset], device=device, dtype=torch.bool)[:,None,None,None]
B, _, H, W = kwargs['mask'].shape
pad_mask = torch.zeros_like(kwargs['mask'], device=device)
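        # 'pad' stores per-image (top, bottom, left, right) padding; pad_mask marks
        # the valid (un-padded) region of each image.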
for b in range(B):
pad = kwargs['pad'][b].squeeze()
pad_mask[b, :, pad[0]:H-pad[1], pad[2]:W-pad[3]] = True
loss = self.loss_fn(pad_mask=pad_mask, **kwargs)
return loss * self.loss_weight
def forward_J(self, prediction, confidence, normal_out_list, intrinsic, pad_mask, sem_mask=None, **kwargs):
prediction_normal = normal_out_list[-1].clone()
# get normal from depth-prediction
normal, new_mask = self.depth2normal(prediction.detach() if self.depth_detach else prediction, intrinsic, pad_mask)
        # mask out sky pixels (sky_mask is True where the pixel is NOT sky)
        sky_mask = sem_mask != self.sky_id
        new_mask = new_mask & sky_mask
# normal = normal * (~sky_mask)
# normal[:,1:2,:,:][sky_mask] = 1
        # confidence-guided sampling: keep pixels whose predicted depth is reliable, so the derived normals are too
sample_mask_d = sample_points(prediction, confidence, new_mask, sampling_ratio=0.7)
# all mask
normal_mask = ~torch.all(normal == 0, dim=1, keepdim=True) & new_mask & sample_mask_d
if normal_mask.sum() < 10:
return 0 * prediction_normal.sum()
loss = self.forward_R(prediction_normal, normal, normal_mask)
        if torch.isnan(loss).item() or torch.isinf(loss).item():
            loss = 0 * torch.sum(prediction_normal)
            print(f'DeNoConsistencyLoss forward_J NaN/Inf error, {loss}')
return loss
def forward_S(self, prediction, confidence, intrinsic, pad_mask, normal_pred=None, sem_mask=None, target=None, is_initial_pair=False, **kwargs):
if normal_pred is None:
prediction_normal = kwargs['normal_out_list'][-1]
else:
prediction_normal = normal_pred
        # get normal from the depth prediction ('scale' is a required kwarg here)
        scale = kwargs['scale'][:, None, None].float()
normal, new_mask = self.depth2normal(prediction.detach() if self.depth_detach else prediction, intrinsic, pad_mask, scale)
        sky_mask = sem_mask != self.sky_id  # True where the pixel is NOT sky
        if target is not None:
            sampling_ratio = 0.7
            target_mask = (target > 0)
            if is_initial_pair:
                # the initial pair is predicted at 1/4 resolution, so downsample the masks to match
                sky_mask = F.interpolate(sky_mask.float(), scale_factor=0.25).bool()
                target_mask = F.interpolate(target_mask.float(), scale_factor=0.25).bool()
            # mask sky (unless the dataset keeps sky) or keep pixels with valid target depth
            new_mask = new_mask & ((sky_mask & self.batch_with_norm_sky) | target_mask)
else:
new_mask = torch.ones_like(prediction).bool()
sampling_ratio = 0.5
# normal = normal * (~sky_mask)
# normal[:,1:2,:,:][sky_mask] = 1
# dual sampling
confidence_normal = prediction_normal[:, 3:, :, :]
sample_mask_n = sample_points(prediction_normal, confidence_normal, new_mask, sampling_ratio=sampling_ratio)
sample_mask_d = sample_points(prediction, confidence, new_mask, sampling_ratio=sampling_ratio)
conf_mask = confidence > 0.5
# all mask
normal_mask = ~torch.all(normal == 0, dim=1, keepdim=True) & new_mask & sample_mask_n & sample_mask_d & conf_mask
if normal_mask.sum() < 10:
return 0 * prediction_normal.sum()
loss = self.forward_R(prediction_normal, normal, normal_mask)
        if torch.isnan(loss).item() or torch.isinf(loss).item():
            loss = 0 * torch.sum(prediction_normal)
            print(f'DeNoConsistencyLoss forward_S NaN/Inf error, {loss}')
return loss
def forward_S_GRU(self, predictions_list, confidence_list, normal_out_list, intrinsic, pad_mask, sem_mask, target, low_resolution_init, **kwargs):
n_predictions = len(normal_out_list)
assert n_predictions >= 1
loss = 0.0
        # We adjust the loss_gamma so it is consistent for any number of RAFT-Stereo
        # iterations; the exponent is guarded against n_predictions == 1.
        adjusted_loss_gamma = self.loss_gamma ** (15 / max(n_predictions - 1, 1))
        for i, (norm, conf, depth) in enumerate(zip(normal_out_list, confidence_list, predictions_list)):
            i_weight = adjusted_loss_gamma ** (n_predictions - i - 1)
if i == 0:
is_initial_pair = True
                # scale the intrinsics to the 1/4-resolution initial prediction
                new_intrinsic = torch.cat((intrinsic[:, :2, :] / 4, intrinsic[:, 2:3, :]), dim=1)
                curr_loss = self.forward_S(low_resolution_init[0], low_resolution_init[1], new_intrinsic,
                                           F.interpolate(pad_mask.float(), scale_factor=0.25).bool(),
                                           low_resolution_init[2], sem_mask, target, is_initial_pair,
                                           scale=kwargs['scale'])
else:
is_initial_pair = False
curr_loss = self.forward_S(depth, conf, intrinsic, pad_mask, norm, sem_mask, target, is_initial_pair, scale=kwargs['scale'])
            if torch.isnan(curr_loss).item() or torch.isinf(curr_loss).item():
                curr_loss = 0 * torch.sum(norm)
                print(f'DeNoConsistencyLoss forward_S_GRU NaN/Inf error, {curr_loss}')
loss += curr_loss * i_weight
return loss
def forward_R(self, norm_out, gt_norm, gt_norm_mask, pred_kappa=None):
pred_norm = norm_out[:, 0:3, :, :]
if pred_kappa is None:
pred_kappa = norm_out[:, 3:, :, :]
if self.loss_type == 'L1':
l1 = torch.sum(torch.abs(gt_norm - pred_norm), dim=1, keepdim=True)
loss = torch.mean(l1[gt_norm_mask])
elif self.loss_type == 'L2' or self.loss_type == 'CEL_L2':
l2 = torch.sum(torch.square(gt_norm - pred_norm), dim=1, keepdim=True)
loss = torch.mean(l2[gt_norm_mask])
elif self.loss_type == 'AL':
dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
valid_mask = gt_norm_mask[:, 0, :, :].float() \
* (dot.detach() < 0.999).float() \
* (dot.detach() > -0.999).float()
valid_mask = valid_mask > 0.0
al = torch.acos(dot * valid_mask)
al = al * self.batch_with_norm[:, 0, :, :]
loss = torch.mean(al)
elif self.loss_type == 'CEL' or self.loss_type == 'CEL_GRU':
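            # cosine embedding loss: 1 - cos(theta) ~= theta^2 / 2 for small angles,
            # a smooth surrogate for the angular loss without acos gradients near +/-1.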
dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
valid_mask = gt_norm_mask[:, 0, :, :].float() \
* (dot.detach() < 0.999).float() \
* (dot.detach() > -0.999).float()
valid_mask = valid_mask > 0.0
al = 1 - dot * valid_mask
al = al * self.batch_with_norm[:, 0, :, :]
loss = torch.mean(al)
elif self.loss_type == 'NLL_vMF':
dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
valid_mask = gt_norm_mask[:, 0, :, :].float() \
* (dot.detach() < 0.999).float() \
* (dot.detach() > -0.999).float()
valid_mask = valid_mask > 0.0
dot = dot[valid_mask]
kappa = pred_kappa[:, 0, :, :][valid_mask]
loss_pixelwise = - torch.log(kappa) \
- (kappa * (dot - 1)) \
+ torch.log(1 - torch.exp(- 2 * kappa))
loss = torch.mean(loss_pixelwise)
elif self.loss_type == 'NLL_ours':
dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
valid_mask = gt_norm_mask[:, 0, :, :].float() \
* (dot.detach() < 0.999).float() \
* (dot.detach() > -0.999).float()
valid_mask = valid_mask > 0.5
dot = dot * valid_mask
kappa = pred_kappa[:, 0, :, :] * valid_mask
loss_pixelwise = - torch.log(torch.square(kappa) + 1) \
+ kappa * torch.acos(dot) \
+ torch.log(1 + torch.exp(-kappa * np.pi))
loss_pixelwise = loss_pixelwise * self.batch_with_norm[:, 0, :, :]
loss = torch.mean(loss_pixelwise)
else:
raise Exception('invalid loss type')
return loss
def forward_S_Search(self, prediction, confidence, intrinsic, pad_mask, normal_pred=None, sem_mask=None, target=None, is_initial_pair=False, **kwargs):
if normal_pred is None:
prediction_normal = kwargs['normal_out_list'][-1]
else:
prediction_normal = normal_pred
# get normal from depth-prediction
scale = kwargs['scale'][:, None, None].float()
candidate_scale = kwargs['candidate_scale'][:, None, None, None].float()
normal, new_mask = self.depth2normal(prediction.detach() if self.depth_detach else prediction, intrinsic, pad_mask, scale)
        sky_mask = sem_mask != self.sky_id  # True where the pixel is NOT sky
        if target is not None:
            sampling_ratio = 0.7
            target_mask = (target > 0)
            if is_initial_pair:
                # the initial pair is predicted at 1/4 resolution, so downsample the masks to match
                sky_mask = F.interpolate(sky_mask.float(), scale_factor=0.25).bool()
                target_mask = F.interpolate(target_mask.float(), scale_factor=0.25).bool()
            # mask sky (unless the dataset keeps sky) or keep pixels with valid target depth
            new_mask = new_mask & ((sky_mask & self.batch_with_norm_sky) | target_mask)
else:
new_mask = torch.ones_like(prediction).bool()
sampling_ratio = 0.5
# normal = normal * (~sky_mask)
# normal[:,1:2,:,:][sky_mask] = 1
# dual sampling
confidence_normal = prediction_normal[:, 3:, :, :]
sample_mask_n = sample_points(prediction_normal, confidence_normal, new_mask, sampling_ratio=sampling_ratio)
sample_mask_d = sample_points(prediction, confidence, new_mask, sampling_ratio=sampling_ratio)
conf_mask = confidence > 0.5
# all mask
normal_mask = ~torch.all(normal == 0, dim=1, keepdim=True) & new_mask & sample_mask_n & sample_mask_d & conf_mask
if normal_mask.sum() < 10:
return 0 * prediction_normal.sum()
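        # Scale search: the z component of each predicted normal is rescaled by its
        # per-sample candidate scale and the vector re-normalised; forward_R_Search
        # then returns one loss per sample (reduced over spatial dims only), so the
        # caller can pick the best candidate.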
        prediction_normal = torch.cat((prediction_normal[:, :2] * torch.ones_like(candidate_scale),
                                       prediction_normal[:, 2:3] * candidate_scale,
                                       prediction_normal[:, 3:4] * torch.ones_like(candidate_scale)), dim=1)
norm_x = prediction_normal[:,0:1]
norm_y = prediction_normal[:,1:2]
norm_z = prediction_normal[:,2:3]
prediction_normal[:,:3] = prediction_normal[:,:3] / (torch.sqrt(norm_x ** 2.0 + norm_y ** 2.0 + norm_z ** 2.0) + 1e-10)
loss = self.forward_R_Search(prediction_normal, normal, normal_mask)
#if torch.isnan(loss).item() | torch.isinf(loss).item():
#loss = 0 * torch.sum(prediction_normal)
#print(f'NormalBranchLoss forward_GRU NAN error, {loss}')
return loss
def forward_R_Search(self, norm_out, gt_norm, gt_norm_mask, pred_kappa=None):
pred_norm = norm_out[:, 0:3, :, :]
if pred_kappa is None:
pred_kappa = norm_out[:, 3:, :, :]
if 'L1' in self.loss_type:
l1 = torch.sum(torch.abs(gt_norm - pred_norm), dim=1, keepdim=True)
loss = torch.mean(l1*gt_norm_mask, dim=[1, 2, 3])
elif 'L2' in self.loss_type:
l2 = torch.sum(torch.square(gt_norm - pred_norm), dim=1, keepdim=True)
loss = torch.mean(l2*gt_norm_mask, dim=[1, 2, 3])
elif 'AL' in self.loss_type:
dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
valid_mask = gt_norm_mask[:, 0, :, :].float() \
* (dot.detach() < 0.999).float() \
* (dot.detach() > -0.999).float()
valid_mask = valid_mask > 0.0
al = torch.acos(dot * valid_mask)
loss = torch.mean(al, dim=[1, 2])
elif 'CEL' in self.loss_type:
dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
valid_mask = gt_norm_mask[:, 0, :, :].float() \
* (dot.detach() < 0.999).float() \
* (dot.detach() > -0.999).float()
valid_mask = valid_mask > 0.0
al = 1 - dot * valid_mask
loss = torch.mean(al, dim=[1, 2])
        elif 'NLL_vMF' in self.loss_type:
            dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
            valid_mask = gt_norm_mask[:, 0, :, :].float() \
                         * (dot.detach() < 0.999).float() \
                         * (dot.detach() > -0.999).float()
            valid_mask = valid_mask > 0.0
            # keep the spatial layout: boolean indexing would flatten the tensor and
            # break the per-sample reduction over dim=[1, 2] below
            dot = torch.where(valid_mask, dot, torch.ones_like(dot))
            kappa = torch.where(valid_mask, pred_kappa[:, 0, :, :], torch.ones_like(dot))
            loss_pixelwise = (- torch.log(kappa)
                              - (kappa * (dot - 1))
                              + torch.log(1 - torch.exp(- 2 * kappa))) * valid_mask.float()
            loss = torch.mean(loss_pixelwise, dim=[1, 2])
elif 'NLL_ours' in self.loss_type:
dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
valid_mask = gt_norm_mask[:, 0, :, :].float() \
* (dot.detach() < 0.999).float() \
* (dot.detach() > -0.999).float()
valid_mask = valid_mask > 0.5
dot = dot * valid_mask
kappa = pred_kappa[:, 0, :, :] * valid_mask
loss_pixelwise = - torch.log(torch.square(kappa) + 1) \
+ kappa * torch.acos(dot) \
+ torch.log(1 + torch.exp(-kappa * np.pi))
loss = torch.mean(loss_pixelwise, dim=[1, 2])
else:
raise Exception('invalid loss type')
return loss
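
# Hedged smoke test for the sketches above (illustration only). Because of the
# relative import at the top, run it as a module within its package, e.g.
# `python -m mono.model.losses.NormalBranchLoss` (exact package path assumed).
if __name__ == '__main__':
    _demo_normal_branch_loss()
    _demo_sample_points()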