import numpy as np | |
import torch | |
import torch.nn as nn | |
from torch.nn.utils.rnn import pad_sequence | |
#from pytorch3d.loss import chamfer_distance | |
class AdabinsLoss(nn.Module):
    """
    Losses employed in AdaBins: a scale-invariant log (SILog) loss on
    predicted depth, plus an optional chamfer-distance loss between the
    predicted bin centers and the valid ground-truth depth values.

    NOTE(review): the chamfer term requires ``chamfer_distance`` from
    pytorch3d (its import is commented out at the top of the file) and is
    currently disabled in ``forward``; the cross-entropy bin-classification
    variant was likewise disabled and its dead code removed.

    Args:
        depth_normalize (sequence): ``(min_depth, max_depth)`` valid range.
        variance_focus (float): lambda of the SILog loss; smaller values put
            more weight on the absolute (non-scale-invariant) error term.
        loss_weight (float): multiplier applied to the final loss.
        out_channel (int): number of depth bins; kept for interface
            compatibility (unused while the CE loss is disabled).
        data_type (list | tuple | None): data sources this loss applies to.
            ``None`` (the default) means ``['stereo', 'lidar']``.
        w_ce (bool): whether to add the (currently disabled) CE loss.
        w_chamber (bool): accepted for interface compatibility; unused.
    """

    def __init__(self, depth_normalize, variance_focus=0.85, loss_weight=1,
                 out_channel=100, data_type=None, w_ce=False, w_chamber=False,
                 **kwargs):
        super(AdabinsLoss, self).__init__()
        self.variance_focus = variance_focus
        self.loss_weight = loss_weight
        # None-sentinel instead of a shared mutable default argument.
        self.data_type = ['stereo', 'lidar'] if data_type is None else data_type
        self.depth_min = depth_normalize[0]
        self.depth_max = depth_normalize[1]
        self.w_ce = w_ce
        # Guards the divisions below when the mask selects zero pixels.
        self.eps = 1e-6

    def silog_loss(self, prediction, target, mask):
        """Scale-invariant log loss over the pixels where ``mask`` is True.

        sqrt( mean(d^2) - variance_focus * mean(d)^2 ), d = log(pred) - log(gt).
        """
        d = torch.log(prediction[mask]) - torch.log(target[mask])
        d_square_mean = torch.sum(d ** 2) / (d.numel() + self.eps)
        d_mean = torch.sum(d) / (d.numel() + self.eps)
        return torch.sqrt(d_square_mean - self.variance_focus * (d_mean ** 2))

    def chamfer_distance_loss(self, bins, target_depth_maps, mask):
        """Chamfer distance between bin centers and valid GT depth values.

        NOTE(review): needs pytorch3d's ``chamfer_distance``; not called by
        ``forward`` at the moment.
        """
        bin_centers = 0.5 * (bins[:, 1:] + bins[:, :-1])
        n, p = bin_centers.shape
        input_points = bin_centers.view(n, p, 1)  # (n, p, 1)
        # Keep only the valid GT depths of each sample -> ragged lengths.
        target_points = [t[m] for t, m in zip(target_depth_maps, mask)]
        # torch.Tensor() takes no dtype/device kwargs — use torch.tensor(),
        # and follow the bins' device instead of hard-coding CUDA.
        target_lengths = torch.tensor([len(t) for t in target_points],
                                      dtype=torch.long, device=bins.device)
        target_points = pad_sequence(target_points, batch_first=True).unsqueeze(2)  # (n, T, 1)
        loss, _ = chamfer_distance(x=input_points, y=target_points, y_lengths=target_lengths)
        return loss

    def forward(self, prediction, target, bins_edges, mask=None, **kwargs):
        """Compute the weighted AdaBins loss.

        Args:
            prediction: predicted depth map.
            target: ground-truth depth map, same shape as ``prediction``.
            bins_edges: predicted bin edges; unused while the chamfer term
                is disabled, kept for interface compatibility.
            mask: boolean mask of valid ground-truth pixels.

        Returns:
            Scalar loss tensor (SILog * 10 * loss_weight).

        Raises:
            RuntimeError: if the loss evaluates to NaN or +/-inf.
        """
        silog = self.silog_loss(prediction=prediction, target=target, mask=mask)
        loss = silog * 10
        if torch.isnan(loss) or torch.isinf(loss):
            raise RuntimeError(f'Adabins loss error, {loss}')
        return loss * self.loss_weight