Spaces:
Building
Building
File size: 5,224 Bytes
4187c6f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
from typing import Optional, Dict
import torch.nn as nn
import torch
from .schema import LossConfiguration
def dice_loss(input: torch.Tensor,
target: torch.Tensor,
loss_mask: torch.Tensor,
class_weights: Optional[torch.Tensor | bool],
smooth=1e-5):
'''
:param input: (B, H, W, C) Logits for each class
:param target: (B, H, W, C) Ground truth class labels in one_hot
:param loss_mask: (B, H, W) Mask indicating valid regions of the image
:param class_weights: (C) Weights for each class
:param smooth: Smoothing factor to avoid division by zero, default 1.0
'''
if isinstance(class_weights, torch.Tensor):
class_weights = class_weights.unsqueeze(0)
elif class_weights is None or class_weights == False:
class_weights = torch.ones(
1, target.size(-1), dtype=target.dtype, device=target.device)
elif class_weights == True:
class_weights = target.sum(1)
class_weights = torch.reciprocal(target.mean(1) + 1e-3)
class_weights = class_weights.clamp(min=1e-5)
# Only consider classes that are present
class_weights *= (target.sum(1) != 0).float()
class_weights.requires_grad = False
intersect = (2 * input * target)
intersect = (intersect) + smooth
union = (input + target)
union = (union) + smooth
loss = 1 - (intersect / union) # B, H, W, C
loss *= class_weights.unsqueeze(0).unsqueeze(0)
loss = loss.sum(-1) / class_weights.sum()
loss *= loss_mask
loss = loss.sum() / loss_mask.sum() # 1
return loss
class EnhancedLoss(nn.Module):
    '''Weighted combination of (optionally focal) BCE and soft Dice losses
    for one-hot segmentation targets, with optional frustum / flood masking.

    At least one of ``xent_weight`` and ``dice_weight`` must be positive.
    '''

    def __init__(
        self,
        cfg: LossConfiguration,
    ):  # following params in the paper
        super().__init__()
        self.num_classes = cfg.num_classes
        self.xent_weight = cfg.xent_weight
        self.focal = cfg.focal_loss
        self.focal_gamma = cfg.focal_loss_gamma
        self.dice_weight = cfg.dice_weight
        if self.xent_weight == 0. and self.dice_weight == 0.:
            raise ValueError(
                "At least one of xent_weight and dice_weight must be greater than 0.")
        if self.xent_weight > 0.:
            # Per-element BCE; masking and reduction happen in forward().
            self.xent_loss = nn.BCEWithLogitsLoss(
                reduction="none"
            )
        if self.dice_weight > 0.:
            self.dice_loss = dice_loss
        self.class_weights: Optional[torch.Tensor | bool]
        if cfg.class_weights is not None and cfg.class_weights is not True:
            # Explicit per-class weights: a non-persistent buffer follows the
            # module across devices without bloating checkpoints.
            self.register_buffer("class_weights", torch.tensor(
                cfg.class_weights), persistent=False)
        else:
            # None / True are forwarded to dice_loss unchanged.
            self.class_weights = cfg.class_weights
        self.requires_frustrum = cfg.requires_frustrum
        self.requires_flood_mask = cfg.requires_flood_mask
        self.label_smoothing = cfg.label_smoothing

    def forward(self, pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor]):
        '''
        Args:
            pred: Dict containing
                - output: (B, C, H, W) Probabilities for each class
                - logits: (B, C, H, W) Pre-activation scores for each class
                - valid_bev: Mask of valid BEV regions; only read when
                  ``requires_frustrum`` is set (last column is dropped —
                  presumably padding; TODO confirm against producer)
            data: Dict containing
                - seg_masks: (B, H, W, C) one-hot ground truth labels
                - flood_masks: (B, H, W) only read when
                  ``requires_flood_mask`` is set; pixels where it is 0 are kept

        Returns:
            Dict with 'total' plus whichever of 'cross_entropy' and 'dice'
            were computed (per their weights).
        '''
        loss = {}
        probs = pred['output'].permute(0, 2, 3, 1)   # (B, H, W, C)
        logits = pred['logits'].permute(0, 2, 3, 1)  # (B, H, W, C)
        labels: torch.Tensor = data['seg_masks']     # (B, H, W, C)

        # Start from an all-valid mask and intersect the optional masks.
        loss_mask = torch.ones(
            labels.shape[:3], device=labels.device, dtype=labels.dtype)
        if self.requires_frustrum:
            frustrum_mask = pred["valid_bev"][..., :-1] != 0
            loss_mask = loss_mask * frustrum_mask.float()
        if self.requires_flood_mask:
            flood_mask = data["flood_masks"] == 0
            loss_mask = loss_mask * flood_mask.float()

        if self.xent_weight > 0.:
            if self.label_smoothing > 0.:
                # Standard label smoothing: soften the one-hot targets.
                labels_ls = (labels.float() * (1 - self.label_smoothing)
                             + self.label_smoothing / self.num_classes)
                xent_loss = self.xent_loss(logits, labels_ls)
            else:
                xent_loss = self.xent_loss(logits, labels)
            if self.focal:
                # Focal modulation: down-weight easy (high-confidence) pixels.
                pt = torch.exp(-xent_loss)
                xent_loss = (1 - pt) ** self.focal_gamma * xent_loss
            # Masked mean over pixels; epsilon guards an all-masked batch.
            xent_loss *= loss_mask.unsqueeze(-1)
            xent_loss = xent_loss.sum() / (loss_mask.sum() + 1e-5)
            loss['cross_entropy'] = xent_loss
            loss['total'] = xent_loss * self.xent_weight

        if self.dice_weight > 0.:
            dloss = self.dice_loss(
                probs, labels, loss_mask, self.class_weights)
            loss['dice'] = dloss
            loss['total'] = loss.get('total', 0.) + dloss * self.dice_weight
        return loss
|