import torch import torch.nn as nn class ConfidenceLoss(nn.Module): """ confidence loss. """ def __init__(self, loss_weight=1, data_type=['stereo', 'lidar', 'denselidar'], **kwargs): super(ConfidenceLoss, self).__init__() self.loss_weight = loss_weight self.data_type = data_type self.eps = 1e-6 def forward(self, prediction, target, confidence, mask=None, **kwargs): conf_mask = torch.abs(target - prediction) < target conf_mask = conf_mask & mask gt_confidence = (1 - torch.abs((prediction - target) / target)) * conf_mask loss = torch.sum(torch.abs(confidence - gt_confidence) * conf_mask) / (torch.sum(conf_mask) + self.eps) if torch.isnan(loss).item() | torch.isinf(loss).item(): loss = 0 * torch.sum(confidence) print(f'ConfidenceLoss NAN error, {loss}') return loss * self.loss_weight