from typing import Dict, Optional

import torch
import torch.nn as nn

from .schema import LossConfiguration
def dice_loss(input: torch.Tensor,
              target: torch.Tensor,
              loss_mask: torch.Tensor,
              class_weights: Optional[torch.Tensor | bool],
              smooth: float = 1e-5):
    '''
    :param input: (B, H, W, C) Predicted probabilities for each class
    :param target: (B, H, W, C) Ground truth class labels, one-hot encoded
    :param loss_mask: (B, H, W) Mask indicating valid regions of the image
    :param class_weights: (C) Weights for each class; True derives
        inverse-frequency weights from the target, None/False uses uniform weights
    :param smooth: Smoothing factor to avoid division by zero, default 1e-5
    '''
    if isinstance(class_weights, torch.Tensor):
        class_weights = class_weights.unsqueeze(0)  # (1, C)
    elif class_weights is None or class_weights is False:
        # Uniform weighting: every class contributes equally.
        class_weights = torch.ones(
            1, target.size(-1), dtype=target.dtype, device=target.device)
    elif class_weights is True:
        # Inverse-frequency weights, computed over the spatial dimensions.
        class_weights = torch.reciprocal(
            target.mean(dim=(1, 2)) + 1e-3)  # (B, C)
        class_weights = class_weights.clamp(min=1e-5)
        # Only consider classes that are present
        class_weights = class_weights * (target.sum(dim=(1, 2)) != 0).float()
        class_weights = class_weights.detach()
    intersect = 2 * input * target + smooth
    union = input + target + smooth
    loss = 1 - intersect / union  # (B, H, W, C)
    # Weighted mean over classes, then masked mean over pixels.
    loss = loss * class_weights.unsqueeze(1).unsqueeze(1)
    loss = loss.sum(-1) / class_weights.sum(-1).unsqueeze(-1).unsqueeze(-1)
    loss = loss * loss_mask
    loss = loss.sum() / loss_mask.sum()  # scalar
    return loss
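

# Illustrative sanity check (a sketch, not part of the original module): with
# random probabilities over four hypothetical classes and a full-image mask,
# dice_loss should return a finite scalar even with class_weights=True.
def _demo_dice_loss() -> torch.Tensor:
    probs = torch.rand(2, 16, 16, 4).softmax(dim=-1)  # (B, H, W, C) pseudo-probabilities
    target = nn.functional.one_hot(
        torch.randint(0, 4, (2, 16, 16)), num_classes=4).float()  # one-hot labels
    mask = torch.ones(2, 16, 16)  # (B, H, W) all pixels valid
    return dice_loss(probs, target, mask, class_weights=True)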


class EnhancedLoss(nn.Module):
    def __init__(
        self,
        cfg: LossConfiguration,
    ):  # following params in the paper
        super().__init__()
        self.num_classes = cfg.num_classes
        self.xent_weight = cfg.xent_weight
        self.focal = cfg.focal_loss
        self.focal_gamma = cfg.focal_loss_gamma
        self.dice_weight = cfg.dice_weight
        if self.xent_weight == 0. and self.dice_weight == 0.:
            raise ValueError(
                "At least one of xent_weight and dice_weight must be greater than 0.")
        if self.xent_weight > 0.:
            self.xent_loss = nn.BCEWithLogitsLoss(reduction="none")
        if self.dice_weight > 0.:
            self.dice_loss = dice_loss
        self.class_weights: Optional[torch.Tensor | bool]
        if cfg.class_weights is not None and cfg.class_weights is not True:
            # Fixed per-class weights supplied by the config.
            self.register_buffer("class_weights", torch.tensor(
                cfg.class_weights), persistent=False)
        else:
            # None/False disables weighting; True derives weights from the target.
            self.class_weights = cfg.class_weights
        self.requires_frustrum = cfg.requires_frustrum
        self.requires_flood_mask = cfg.requires_flood_mask
        self.label_smoothing = cfg.label_smoothing
    def forward(self, pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor]):
        '''
        Args:
            pred: Dict containing
                - output: (B, C, H, W) Probabilities for each class
                - logits: (B, C, H, W) Logits for each class
                - valid_bev: (B, H, W) Mask indicating valid regions of the image
            data: Dict containing
                - seg_masks: (B, H, W, C) Ground truth class labels, one-hot encoded
                - flood_masks: (B, H, W) Flood mask; non-zero pixels are excluded
        '''
        loss = {}
        probs = pred['output'].permute(0, 2, 3, 1)  # (B, H, W, C)
        logits = pred['logits'].permute(0, 2, 3, 1)  # (B, H, W, C)
        labels: torch.Tensor = data['seg_masks']  # (B, H, W, C)
        loss_mask = torch.ones(
            labels.shape[:3], device=labels.device, dtype=labels.dtype)
        if self.requires_frustrum:
            # Restrict the loss to pixels inside the camera frustum.
            frustrum_mask = pred["valid_bev"][..., :-1] != 0
            loss_mask = loss_mask * frustrum_mask.float()
        if self.requires_flood_mask:
            flood_mask = data["flood_masks"] == 0
            loss_mask = loss_mask * flood_mask.float()
        if self.xent_weight > 0.:
            if self.label_smoothing > 0.:
                # Soften one-hot targets: 1 -> 1 - eps + eps/C, 0 -> eps/C.
                labels_ls = labels.float().clone()
                labels_ls = labels_ls * (1 - self.label_smoothing) \
                    + self.label_smoothing / self.num_classes
                xent_loss = self.xent_loss(logits, labels_ls)
            else:
                xent_loss = self.xent_loss(logits, labels)
            if self.focal:
                # Focal modulation: down-weight well-classified pixels.
                pt = torch.exp(-xent_loss)
                xent_loss = (1 - pt) ** self.focal_gamma * xent_loss
            xent_loss = xent_loss * loss_mask.unsqueeze(-1)
            xent_loss = xent_loss.sum() / (loss_mask.sum() + 1e-5)
            loss['cross_entropy'] = xent_loss
            loss['total'] = xent_loss * self.xent_weight
        if self.dice_weight > 0.:
            dloss = self.dice_loss(
                probs, labels, loss_mask, self.class_weights)
            loss['dice'] = dloss
            if 'total' in loss:
                loss['total'] += dloss * self.dice_weight
            else:
                loss['total'] = dloss * self.dice_weight
        return loss
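

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the original
    # module). SimpleNamespace stands in for LossConfiguration here; the real
    # schema class may carry extra fields or validation. Because this module
    # uses a relative import, run it via `python -m` from the package root.
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        num_classes=4, xent_weight=1.0, focal_loss=True, focal_loss_gamma=2.0,
        dice_weight=1.0, class_weights=None, requires_frustrum=False,
        requires_flood_mask=False, label_smoothing=0.1)
    criterion = EnhancedLoss(cfg)  # type: ignore[arg-type]
    logits = torch.randn(2, 4, 16, 16)  # (B, C, H, W)
    pred = {"output": logits.softmax(dim=1), "logits": logits}
    labels = nn.functional.one_hot(
        torch.randint(0, 4, (2, 16, 16)), num_classes=4).float()
    data = {"seg_masks": labels}
    losses = criterion(pred, data)
    print({k: v.item() for k, v in losses.items()})  # cross_entropy, dice, total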