import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
# pytorch3d is optional; it is imported lazily inside `chamfer_distance_loss`.
#from pytorch3d.loss import chamfer_distance


class AdabinsLoss(nn.Module):
    """
    Losses employed in Adabins.

    The active objective is the scale-invariant log (SILog) loss. A chamfer
    term on the bin centers and a cross-entropy term on discretized depth
    exist in the file but are disabled in `forward`.
    """

    def __init__(self, depth_normalize, variance_focus=0.85, loss_weight=1,
                 out_channel=100, data_type=None, w_ce=False, w_chamber=False,
                 **kwargs):
        """
        Args:
            depth_normalize: indexable pair; element 0 is stored as
                `depth_min`, element 1 as `depth_max`.
            variance_focus: weight of the squared-mean term subtracted inside
                the SILog square root.
            loss_weight: scalar multiplied onto the loss returned by `forward`.
            out_channel: accepted for config compatibility; currently unused.
            data_type: value stored on `self.data_type`. Defaults to a fresh
                `['stereo', 'lidar']` list for every instance.
            w_ce: stored on `self.w_ce`; the branch that reads it is disabled.
            w_chamber: accepted for config compatibility; currently unused.
            **kwargs: ignored.
        """
        super().__init__()
        self.variance_focus = variance_focus
        self.loss_weight = loss_weight
        # A `None` sentinel avoids sharing one mutable default list between
        # all instances of the loss.
        self.data_type = ['stereo', 'lidar'] if data_type is None else data_type
        #self.bins_num = out_channel
        #self.cel = nn.CrossEntropyLoss(ignore_index=self.bins_num + 1)
        self.depth_min = depth_normalize[0]
        self.depth_max = depth_normalize[1]
        self.w_ce = w_ce
        self.eps = 1e-6

    def silog_loss(self, prediction, target, mask):
        """
        Scale-invariant log loss over the elements selected by `mask`.

        Computes `sqrt(mean(d^2) - variance_focus * mean(d)^2)` with
        `d = log(prediction) - log(target)`; both means divide by
        `d.numel() + eps`.

        Args:
            prediction: predicted depth tensor.
            target: ground-truth depth tensor, indexable by the same mask.
            mask: boolean selection mask, or `None` to use every element.

        Returns:
            Scalar tensor holding the loss.
        """
        if mask is None:
            # Explicit "no mask" path; numerically the same as indexing with
            # `None`, which only adds a leading axis.
            log_diff = torch.log(prediction) - torch.log(target)
        else:
            log_diff = torch.log(prediction[mask]) - torch.log(target[mask])

        # eps keeps the division defined when the mask selects nothing.
        count = log_diff.numel() + self.eps
        d_square_mean = torch.sum(log_diff ** 2) / count
        d_mean = torch.sum(log_diff) / count
        return torch.sqrt(d_square_mean - self.variance_focus * (d_mean ** 2))

    def chamfer_distance_loss(self, bins, target_depth_maps, mask):
        """
        Chamfer distance between bin centers and the valid target depths.

        Args:
            bins: 2-D tensor of bin edges, one row per sample; centers are the
                midpoints of consecutive edges.
            target_depth_maps: batched target depth; iterated along dim 0.
            mask: batched boolean mask selecting valid target values per
                sample; iterated along dim 0 together with the targets.

        Returns:
            The first element returned by pytorch3d's `chamfer_distance`.

        Raises:
            ImportError: if pytorch3d is not installed.
        """
        bin_centers = 0.5 * (bins[:, 1:] + bins[:, :-1])
        n, p = bin_centers.shape
        input_points = bin_centers.view(n, p, 1)  # .shape = n, p, 1

        # Per-sample 1-D tensors of valid depths; lengths differ per sample.
        valid_targets = [depth[valid] for depth, valid in zip(target_depth_maps, mask)]
        # `torch.tensor` (the factory) accepts dtype/device; the legacy
        # `torch.Tensor` constructor does not. Follow the data's device
        # instead of assuming CUDA.
        target_lengths = torch.tensor(
            [len(t) for t in valid_targets],
            dtype=torch.long,
            device=target_depth_maps.device,
        )
        target_points = pad_sequence(valid_targets, batch_first=True).unsqueeze(2)  # .shape = n, T, 1

        # Imported lazily so the module stays importable without pytorch3d.
        try:
            from pytorch3d.loss import chamfer_distance
        except ImportError as err:
            raise ImportError(
                'chamfer_distance_loss requires pytorch3d '
                '(pytorch3d.loss.chamfer_distance)'
            ) from err

        loss, _ = chamfer_distance(x=input_points, y=target_points, y_lengths=target_lengths)
        return loss

    # NOTE: disabled cross-entropy branch, kept for reference.
    # def depth_to_bins(self, depth, mask, depth_edges, size_limite=(512, 960)):
    #     """
    #     Discretize depth into depth bins. Predefined bins edges are provided.
    #     Mark invalid padding area as bins_num + 1
    #     Args:
    #         @depth: 1-channel depth, [B, 1, h, w]
    #     return: depth bins [B, C, h, w]
    #     """
    #     def _depth_to_bins_block_(depth, mask, depth_edges):
    #         bins_id = torch.sum(depth_edges[:, None, None, None, :] < torch.abs(depth)[:, :, :, :, None], dim=-1)
    #         bins_id = bins_id - 1
    #         invalid_mask = ~mask
    #         mask_lower = (depth <= self.depth_min)
    #         mask_higher = (depth >= self.depth_max)
    #         bins_id[mask_lower] = 0
    #         bins_id[mask_higher] = self.bins_num - 1
    #         bins_id[bins_id == self.bins_num] = self.bins_num - 1
    #         bins_id[invalid_mask] = self.bins_num + 1
    #         return bins_id
    #     # _, _, H, W = depth.shape
    #     # bins = mask.clone().long()
    #     # h_blocks = np.ceil(H / size_limite[0]).astype(np.int)
    #     # w_blocks = np.ceil(W/ size_limite[1]).astype(np.int)
    #     # for i in range(h_blocks):
    #     #     for j in range(w_blocks):
    #     #         h_start = i*size_limite[0]
    #     #         h_end_proposal = (i + 1) * size_limite[0]
    #     #         h_end = h_end_proposal if h_end_proposal < H else H
    #     #         w_start = j*size_limite[1]
    #     #         w_end_proposal = (j + 1) * size_limite[1]
    #     #         w_end = w_end_proposal if w_end_proposal < W else W
    #     #         bins_ij = _depth_to_bins_block_(
    #     #             depth[:, :, h_start:h_end, w_start:w_end],
    #     #             mask[:, :, h_start:h_end, w_start:w_end],
    #     #             depth_edges
    #     #         )
    #     #         bins[:, :, h_start:h_end, w_start:w_end] = bins_ij
    #     bins = _depth_to_bins_block_(depth, mask, depth_edges)
    #     return bins

    # def ce_loss(self, pred_logit, target, mask, bins_edges):
    #     target_depth_bins = self.depth_to_bins(target, mask, bins_edges)
    #     loss = self.cel(pred_logit, target_depth_bins.squeeze().long())
    #     return loss

    def forward(self, prediction, target, bins_edges, mask=None, **kwargs):
        """
        Compute the weighted Adabins loss.

        Args:
            prediction: predicted depth tensor.
            target: ground-truth depth tensor.
            bins_edges: bin edges; unused while the chamfer and cross-entropy
                terms are disabled.
            mask: boolean validity mask, or `None` to use every element.
            **kwargs: ignored while the cross-entropy term is disabled.

        Returns:
            `10 * silog_loss * loss_weight` as a scalar tensor.

        Raises:
            RuntimeError: if the loss is NaN or infinite.
        """
        silog_loss = self.silog_loss(prediction=prediction, target=target, mask=mask)
        #cf_loss = self.chamfer_distance_loss(bins=bins_edges, target_depth_maps=target, mask=mask)
        loss = silog_loss * 10  #+ 0.1 * cf_loss
        # if self.w_ce:
        #     loss = loss + self.ce_loss(kwargs['pred_logit'], target, mask, bins_edges)
        # Fail loudly rather than let a NaN/inf loss reach the optimizer.
        if not torch.isfinite(loss).item():
            raise RuntimeError(f'Adabins loss error, {loss}')
        return loss * self.loss_weight