File size: 5,224 Bytes
4187c6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from typing import Optional, Dict
import torch.nn as nn
import torch
from .schema import LossConfiguration


def dice_loss(input: torch.Tensor,
              target: torch.Tensor,
              loss_mask: torch.Tensor,
              class_weights: Optional[torch.Tensor | bool],
              smooth=1e-5):
    '''
    :param input: (B, H, W, C) Logits for each class
    :param target: (B, H, W, C) Ground truth class labels in one_hot
    :param loss_mask: (B, H, W) Mask indicating valid regions of the image
    :param class_weights: (C) Weights for each class
    :param smooth: Smoothing factor to avoid division by zero, default 1.0
    '''
    
    if isinstance(class_weights, torch.Tensor):
        class_weights = class_weights.unsqueeze(0)
    elif class_weights is None or class_weights == False:
        class_weights = torch.ones(
            1, target.size(-1), dtype=target.dtype, device=target.device)
    elif class_weights == True:
        class_weights = target.sum(1)
        class_weights = torch.reciprocal(target.mean(1) + 1e-3)
        class_weights = class_weights.clamp(min=1e-5)
        # Only consider classes that are present
        class_weights *= (target.sum(1) != 0).float()
        class_weights.requires_grad = False

    intersect = (2 * input * target)
    intersect = (intersect) + smooth

    union = (input + target)
    union = (union) + smooth

    loss = 1 - (intersect / union)  # B, H, W, C
    loss *= class_weights.unsqueeze(0).unsqueeze(0)
    loss = loss.sum(-1) / class_weights.sum()
    loss *= loss_mask
    loss = loss.sum() / loss_mask.sum()  # 1

    return loss


class EnhancedLoss(nn.Module):
    '''
    Weighted sum of a per-class binary cross-entropy term (optionally focal
    and/or label-smoothed) and a Dice term, restricted to a validity mask.
    '''

    def __init__(
        self,
        cfg: LossConfiguration,
    ):  # following params in the paper
        '''
        Args:
            cfg: Configuration; the fields read here are num_classes,
                xent_weight, focal_loss, focal_loss_gamma, dice_weight,
                class_weights, requires_frustrum, requires_flood_mask and
                label_smoothing.

        Raises:
            ValueError: If neither xent_weight nor dice_weight is positive.
        '''
        super().__init__()
        self.num_classes = cfg.num_classes
        self.xent_weight = cfg.xent_weight
        self.focal = cfg.focal_loss
        self.focal_gamma = cfg.focal_loss_gamma
        self.dice_weight = cfg.dice_weight

        # forward() only evaluates terms with a positive weight, so at least
        # one of them has to be enabled for 'total' to exist.
        if self.xent_weight <= 0. and self.dice_weight <= 0.:
            raise ValueError(
                "At least one of xent_weight and dice_weight must be greater than 0.")

        if self.xent_weight > 0.:
            self.xent_loss = nn.BCEWithLogitsLoss(
                reduction="none"
            )

        if self.dice_weight > 0.:
            self.dice_loss = dice_loss

        self.class_weights: Optional[torch.Tensor | bool]
        if cfg.class_weights is None or isinstance(cfg.class_weights, bool):
            # None / True / False are mode switches interpreted by the Dice
            # loss; they must not be converted into a tensor.
            self.class_weights = cfg.class_weights
        else:
            # Explicit per-class weights follow the module across devices but
            # are not saved in the state dict.
            self.register_buffer("class_weights", torch.tensor(
                cfg.class_weights, dtype=torch.float32), persistent=False)

        self.requires_frustrum = cfg.requires_frustrum
        self.requires_flood_mask = cfg.requires_flood_mask
        self.label_smoothing = cfg.label_smoothing

    def forward(self, pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor]):
        '''
        Args:
            pred: Dict containing the
                - logits: (B, C, H, W) Raw scores, read only if xent_weight > 0
                - output: (B, C, H, W) Probabilities for each class, read only
                  if dice_weight > 0
                - valid_bev: Validity mask, read only if requires_frustrum is
                  set; its last entry along the final axis is dropped before
                  use (NOTE(review): confirm the expected shape)
            data: Dict containing the
                - seg_masks: (B, H, W, C) Ground truth class labels, one-hot encoded
                - flood_masks: (B, H, W) Read only if requires_flood_mask is
                  set; pixels where it is non-zero are excluded from the loss

        Returns:
            Dict with 'total' plus 'cross_entropy' and/or 'dice', depending on
            which terms are enabled.
        '''
        loss = {}

        labels: torch.Tensor = data['seg_masks']  # (B, H, W, C)

        # Start with every pixel valid and narrow down with the enabled masks.
        loss_mask = torch.ones(
            labels.shape[:3], device=labels.device, dtype=labels.dtype)

        if self.requires_frustrum:
            frustrum_mask = pred["valid_bev"][..., :-1] != 0
            loss_mask = loss_mask * frustrum_mask.float()

        if self.requires_flood_mask:
            flood_mask = data["flood_masks"] == 0
            loss_mask = loss_mask * flood_mask.float()

        if self.xent_weight > 0.:
            logits = pred['logits'].permute(0, 2, 3, 1)  # (B, H, W, C)

            # BCEWithLogitsLoss needs floating point targets, whatever dtype
            # the one-hot labels are stored in.
            targets = labels.to(logits.dtype)
            if self.label_smoothing > 0.:
                targets = targets * (1 - self.label_smoothing) + \
                    self.label_smoothing / self.num_classes

            xent_loss = self.xent_loss(logits, targets)  # (B, H, W, C)

            if self.focal:
                # exp(-BCE) is the probability given to the target, so
                # well-classified entries are down-weighted.
                pt = torch.exp(-xent_loss)
                xent_loss = (1 - pt) ** self.focal_gamma * xent_loss

            xent_loss = xent_loss * loss_mask.unsqueeze(-1)
            # Summed over classes, averaged over valid pixels.
            xent_loss = xent_loss.sum() / (loss_mask.sum() + 1e-5)
            loss['cross_entropy'] = xent_loss
            loss['total'] = xent_loss * self.xent_weight

        if self.dice_weight > 0.:
            probs = pred['output'].permute(0, 2, 3, 1)  # (B, H, W, C)
            dloss = self.dice_loss(
                probs, labels, loss_mask, self.class_weights)
            loss['dice'] = dloss

            if 'total' in loss:
                loss['total'] = loss['total'] + dloss * self.dice_weight
            else:
                loss['total'] = dloss * self.dice_weight

        return loss