|
import torch |
|
import torch.nn as nn |
|
|
|
class SSILoss(nn.Module):
    """
    Scale- and shift-invariant MAE loss.

    Each depth map is normalised by its own median (shift) and by the mean
    absolute deviation around that median (scale), both computed over valid
    pixels only, before the absolute error is taken:

        loss = mean(|(d - median(d)) / s - (d' - median(d')) / s'|)
        s  = mean(|d  - median(d)|),   s' = mean(|d' - median(d')|)

    Normalisation is done per sample; the absolute errors are then averaged
    over all valid pixels of the whole batch (not per sample).
    """

    # Default for ``data_type``. A tuple, copied into a fresh list per
    # instance, so instances never share one mutable default list.
    _DEFAULT_DATA_TYPE = ('sfm', 'stereo', 'lidar')

    def __init__(self, loss_weight=1, data_type=None, **kwargs):
        """
        Args:
            loss_weight: Multiplier applied to the final loss.
            data_type: Dataset types this loss is meant for; stored as
                ``self.data_type`` and not read inside this class (presumably
                consulted by the caller). Defaults to a new
                ``['sfm', 'stereo', 'lidar']`` list.
            **kwargs: Accepted and ignored.
        """
        super(SSILoss, self).__init__()
        self.loss_weight = loss_weight
        self.data_type = (
            list(self._DEFAULT_DATA_TYPE) if data_type is None else data_type
        )
        # Guards divisions against zero valid pixels / zero scale.
        self.eps = 1e-6

    def _shift_scale_normalize(self, values, valid_pixes):
        """Subtract the median of ``values`` and divide by the mean absolute deviation.

        ``torch.median`` returns the lower of the two middle values for an even
        count. Empty input is shifted by 0 and yields an empty tensor.
        """
        median = torch.median(values) if values.numel() else 0
        scale = torch.abs(values - median).sum() / valid_pixes
        return (values - median) / (scale + self.eps)

    def ssi_mae(self, target, prediction, mask):
        """
        Sum of absolute differences between normalised target and prediction.

        Args:
            target: Valid ground-truth values of one sample (``forward`` passes
                the mask-selected, flattened values).
            prediction: Valid predicted values matching ``target`` element-wise.
            mask: Mask the values were selected with; only its sum is used, as
                the pixel count for the mean absolute deviation.

        Returns:
            ``(ssi_mae_sum, valid_pixes)``: the un-averaged sum of absolute
            errors and the valid-pixel count plus ``eps``.
        """
        valid_pixes = torch.sum(mask) + self.eps
        gt_trans = self._shift_scale_normalize(target, valid_pixes)
        pred_trans = self._shift_scale_normalize(prediction, valid_pixes)
        ssi_mae_sum = torch.sum(torch.abs(gt_trans - pred_trans))
        return ssi_mae_sum, valid_pixes

    def forward(self, prediction, target, mask=None, **kwargs):
        """
        Calculate the weighted scale/shift-invariant loss for a batch.

        Args:
            prediction: Predicted depth, batch dimension first
                (e.g. ``(B, C, H, W)``).
            target: Ground-truth depth with the same shape as ``prediction``.
            mask: Validity mask with the same shape as ``target``; any dtype
                is accepted and converted to bool (non-zero = valid). If
                ``None``, every pixel is treated as valid.
            **kwargs: Accepted and ignored.

        Returns:
            0-dim tensor: summed per-pixel errors of the batch divided by the
            total valid-pixel count, multiplied by ``loss_weight``.
        """
        if mask is None:
            mask = torch.ones_like(target, dtype=torch.bool)
        else:
            # Boolean indexing below requires a bool mask (a float mask raises).
            mask = mask.to(torch.bool)

        # Start from a tensor so an empty batch still returns a tensor.
        loss = prediction.new_zeros(())
        valid_pix = 0
        for i in range(prediction.shape[0]):
            mask_i = mask[i, ...]
            gt_depth_i = target[i, ...][mask_i]
            pred_depth_i = prediction[i, ...][mask_i]
            ssi_sum, valid_pix_i = self.ssi_mae(gt_depth_i, pred_depth_i, mask_i)
            loss = loss + ssi_sum
            valid_pix = valid_pix + valid_pix_i

        loss = loss / (valid_pix + self.eps)
        return loss * self.loss_weight
|
|
|
if __name__ == '__main__':
    # Demo: random prediction vs. random ground truth whose region
    # rows 100: / columns :100 is set to -1 and excluded via the mask.
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)

    # Fall back to CPU so the demo also runs on machines without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    ssil = SSILoss()
    pred = torch.rand((2, 1, 256, 256)).to(device)
    gt = torch.rand((2, 1, 256, 256)).to(device)
    gt[:, :, 100:256, 0:100] = -1
    mask = gt > 0
    out = ssil(pred, gt, mask)
    print(out)
|
|