import torch
import torch.nn as nn


class ScaleAlignLoss(nn.Module):
    """
    Scale-alignment loss: penalizes the absolute difference between the
    predicted per-sample scale and the median of the masked ground-truth depth.
    """

    def __init__(self, data_type=['lidar', 'denselidar', 'stereo', 'denselidar_syn'],
                 loss_weight=1.0, disable_dataset=['MapillaryPSD'], **kwargs):
        super(ScaleAlignLoss, self).__init__()
        self.loss_weight = loss_weight
        self.data_type = data_type
        self.disable_dataset = disable_dataset

    def forward(self, prediction, target, mask, scale, **kwargs):
        device = target.device

        # prediction is expected as (B, C, H, W); scale as (B, 1, 1, 1).
        B, C, H, W = prediction.shape

        # One flag per batch sample: 0 if its source dataset is disabled for
        # this loss, 1 otherwise.
        batches_dataset = kwargs['dataset']
        self.batch_valid = torch.tensor([1 if batch_dataset not in self.disable_dataset else 0
                                         for batch_dataset in batches_dataset], device=device)

        # Per-sample scale target: median of the ground-truth depth over the
        # valid-pixel mask; samples with too few valid pixels get target 0.
        scale_tgt = torch.zeros_like(scale).squeeze(3).squeeze(2).squeeze(1)
        for i in range(B):
            mask_i = mask[i, ...]
            if torch.sum(mask_i) > 10:
                scale_tgt[i] = torch.median(target[i, ...][mask_i])
            else:
                scale_tgt[i] = 0

        # Average the absolute scale error over samples that are both enabled
        # and have a valid (non-zero) scale target.
        batch_valid = self.batch_valid * (scale_tgt > 1e-8)
        scale_diff = torch.abs(scale.squeeze(3).squeeze(2).squeeze(1) - scale_tgt)
        loss = torch.sum(scale_diff * batch_valid) / (torch.sum(batch_valid) + 1e-8)

        return loss * self.loss_weight
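

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: the tensor shapes
    # and the 'DatasetA' name below are assumptions inferred from how forward()
    # indexes its inputs, not values taken from an actual training config.
    B, C, H, W = 2, 1, 8, 8
    prediction = torch.rand(B, C, H, W)                # predicted depth
    target = torch.rand(B, C, H, W) + 0.5              # ground-truth depth
    mask = torch.ones(B, C, H, W, dtype=torch.bool)    # valid-depth mask
    scale = torch.rand(B, 1, 1, 1)                     # predicted per-sample scale

    loss_fn = ScaleAlignLoss(loss_weight=1.0)
    # The second sample comes from MapillaryPSD, which is in disable_dataset,
    # so only the first sample contributes to the loss.
    loss = loss_fn(prediction, target, mask, scale,
                   dataset=['DatasetA', 'MapillaryPSD'])
    print(loss.item())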