# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01a_losses.ipynb.

# %% ../nbs/01a_losses.ipynb 2
from __future__ import annotations
from .imports import *
from .torch_imports import *
from .torch_core import *
from .layers import *

# %% auto 0
__all__ = ['BaseLoss', 'CrossEntropyLossFlat', 'FocalLoss', 'FocalLossFlat', 'BCEWithLogitsLossFlat', 'BCELossFlat',
           'MSELossFlat', 'L1LossFlat', 'LabelSmoothingCrossEntropy', 'LabelSmoothingCrossEntropyFlat', 'DiceLoss']

# %% ../nbs/01a_losses.ipynb 5
class BaseLoss():
    "Same as `loss_cls`, but flattens input and target."
    activation=decodes=noops
    def __init__(self,
        loss_cls, # Uninitialized PyTorch-compatible loss
        *args,
        axis:int=-1, # Class axis
        flatten:bool=True, # Flatten `inp` and `targ` before calculating loss
        floatify:bool=False, # Convert `targ` to `float`
        is_2d:bool=True, # Whether `flatten` keeps one or two channels when applied
        **kwargs
    ):
        store_attr("axis,flatten,floatify,is_2d")
        self.func = loss_cls(*args,**kwargs)
        functools.update_wrapper(self, self.func)

    def __repr__(self) -> str: return f"FlattenedLoss of {self.func}"

    @property
    def reduction(self) -> str: return self.func.reduction

    @reduction.setter
    def reduction(self, v:str):
        "Sets the reduction style (typically 'mean', 'sum', or 'none')"
        self.func.reduction = v

    def _contiguous(self, x:Tensor) -> TensorBase:
        "Move `self.axis` to the last dimension and ensure the tensor is contiguous for `Tensor`, otherwise just return"
        return TensorBase(x.transpose(self.axis,-1).contiguous()) if isinstance(x,torch.Tensor) else x

    def __call__(self,
        inp:Tensor|MutableSequence, # Predictions from a `Learner`
        targ:Tensor|MutableSequence, # Actual y label
        **kwargs
    ) -> TensorBase: # `loss_cls` calculated on `inp` and `targ`
        inp,targ = map(self._contiguous, (inp,targ))
        if self.floatify and targ.dtype!=torch.float16: targ = targ.float()
        if targ.dtype in [torch.int8, torch.int16, torch.int32]: targ = targ.long()
        if self.flatten: inp = inp.view(-1,inp.shape[-1]) if self.is_2d else inp.view(-1)
        return self.func.__call__(inp, targ.view(-1) if self.flatten else targ, **kwargs)

    def to(self, device:torch.device):
        "Move the loss function to a specified `device`"
        if isinstance(self.func, nn.Module): self.func.to(device)
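
# Usage sketch (illustrative, kept as comments so nothing runs on import):
# `BaseLoss` wraps an uninitialized loss class, flattens `inp`/`targ` before
# calling it, and forwards `reduction` to the wrapped loss via the property
# setter above. For example, with plain classification logits:
#
#   out = torch.randn(8, 3)                  # logits for 8 items, 3 classes
#   y   = torch.randint(0, 3, (8,))          # integer targets
#   loss_func = BaseLoss(nn.CrossEntropyLoss, axis=-1)
#   loss = loss_func(out, y)                 # same value as nn.CrossEntropyLoss()(out, y)
#   loss_func.reduction = 'none'             # forwarded to the wrapped nn.CrossEntropyLoss
#   per_item = loss_func(out, y)             # shape (8,)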

# %% ../nbs/01a_losses.ipynb 8
class CrossEntropyLossFlat(BaseLoss):
    "Same as `nn.CrossEntropyLoss`, but flattens input and target."
    y_int = True # y interpolation
    def __init__(self,
        *args,
        axis:int=-1, # Class axis
        **kwargs
    ):
        super().__init__(nn.CrossEntropyLoss, *args, axis=axis, **kwargs)

    def decodes(self, x:Tensor) -> Tensor:
        "Converts model output to target format"
        return x.argmax(dim=self.axis)

    def activation(self, x:Tensor) -> Tensor:
        "`nn.CrossEntropyLoss`'s fused activation function applied to model output"
        return F.softmax(x, dim=self.axis)
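
# Usage sketch (illustrative): because input and target are flattened along
# `axis`, the same loss handles both (bs, n_classes) classification logits and
# segmentation logits of shape (bs, n_classes, H, W) with (bs, H, W) masks:
#
#   seg_logits  = torch.randn(2, 5, 4, 4)             # 5 classes, 4x4 masks
#   seg_targets = torch.randint(0, 5, (2, 4, 4))
#   loss_func = CrossEntropyLossFlat(axis=1)           # class axis is 1 here
#   loss  = loss_func(seg_logits, seg_targets)
#   preds = loss_func.decodes(seg_logits)              # argmax over the class axis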

# %% ../nbs/01a_losses.ipynb 13
class FocalLoss(Module):
    y_int=True # y interpolation
    def __init__(self,
        gamma:float=2.0, # Focusing parameter. Higher values down-weight easy examples' contribution to loss
        weight:Tensor=None, # Manual rescaling weight given to each class
        reduction:str='mean' # PyTorch reduction to apply to the output
    ):
        "Applies Focal Loss: https://arxiv.org/pdf/1708.02002.pdf"
        store_attr()

    def forward(self, inp:Tensor, targ:Tensor) -> Tensor:
        "Applies focal loss based on https://arxiv.org/pdf/1708.02002.pdf"
        ce_loss = F.cross_entropy(inp, targ, weight=self.weight, reduction="none")
        p_t = torch.exp(-ce_loss)
        loss = (1 - p_t)**self.gamma * ce_loss
        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()
        return loss
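
# Worked example (illustrative): with gamma=2, an easy example with p_t=0.9
# gets a modulating factor (1-0.9)**2 = 0.01, while a hard example with
# p_t=0.1 gets (1-0.1)**2 = 0.81, so easy examples are down-weighted roughly
# 80x relative to plain cross-entropy:
#
#   logits  = torch.randn(8, 3)
#   targets = torch.randint(0, 3, (8,))
#   loss_func = FocalLoss(gamma=2.0)
#   loss = loss_func(logits, targets)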

class FocalLossFlat(BaseLoss):
    """
    Same as CrossEntropyLossFlat but with a focal parameter, `gamma`. Focal loss was introduced by Lin et al.
    (https://arxiv.org/pdf/1708.02002.pdf). Note that the class weighting factor in the paper, alpha, can be
    implemented through the PyTorch `weight` argument passed through to F.cross_entropy.
    """
    y_int = True # y interpolation
    def __init__(self,
        *args,
        gamma:float=2.0, # Focusing parameter. Higher values down-weight easy examples' contribution to loss
        axis:int=-1, # Class axis
        **kwargs
    ):
        super().__init__(FocalLoss, *args, gamma=gamma, axis=axis, **kwargs)

    def decodes(self, x:Tensor) -> Tensor:
        "Converts model output to target format"
        return x.argmax(dim=self.axis)

    def activation(self, x:Tensor) -> Tensor:
        "`F.cross_entropy`'s fused activation function applied to model output"
        return F.softmax(x, dim=self.axis)
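
# Usage sketch (illustrative): a drop-in replacement for CrossEntropyLossFlat
# when classes are imbalanced, e.g. in segmentation. The paper's alpha factor
# maps onto the `weight` kwarg forwarded to FocalLoss (`class_weights` below
# is a hypothetical per-class tensor):
#
#   class_weights = torch.tensor([0.25, 0.75])
#   seg_logits  = torch.randn(2, 2, 4, 4)
#   seg_targets = torch.randint(0, 2, (2, 4, 4))
#   loss_func = FocalLossFlat(gamma=2.0, axis=1, weight=class_weights)
#   loss = loss_func(seg_logits, seg_targets)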

# %% ../nbs/01a_losses.ipynb 16
class BCEWithLogitsLossFlat(BaseLoss):
    "Same as `nn.BCEWithLogitsLoss`, but flattens input and target."
    def __init__(self,
        *args,
        axis:int=-1, # Class axis
        floatify:bool=True, # Convert `targ` to `float`
        thresh:float=0.5, # The threshold on which to predict
        **kwargs
    ):
        if kwargs.get('pos_weight', None) is not None and kwargs.get('flatten', None) is True:
            raise ValueError("`flatten` must be False when using `pos_weight` to avoid a RuntimeError due to shape mismatch")
        if kwargs.get('pos_weight', None) is not None: kwargs['flatten'] = False
        super().__init__(nn.BCEWithLogitsLoss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)
        self.thresh = thresh

    def decodes(self, x:Tensor) -> Tensor:
        "Converts model output to target format"
        return x>self.thresh

    def activation(self, x:Tensor) -> Tensor:
        "`nn.BCEWithLogitsLoss`'s fused activation function applied to model output"
        return torch.sigmoid(x)
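
# Usage sketch (illustrative): typical multi-label setup where each item can
# have several positive classes at once. Note that passing `pos_weight`
# forces `flatten=False`, so inputs then keep their (bs, n_classes) shape:
#
#   logits  = torch.randn(8, 4)
#   targets = torch.randint(0, 2, (8, 4))          # multi-hot, floatified internally
#   loss_func = BCEWithLogitsLossFlat(thresh=0.5)
#   loss  = loss_func(logits, targets)
#   probs = loss_func.activation(logits)           # sigmoid
#   preds = loss_func.decodes(probs)               # boolean mask at thresh=0.5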

# %% ../nbs/01a_losses.ipynb 18
def BCELossFlat(
    *args,
    axis:int=-1, # Class axis
    floatify:bool=True, # Convert `targ` to `float`
    **kwargs
):
    "Same as `nn.BCELoss`, but flattens input and target."
    return BaseLoss(nn.BCELoss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)

# %% ../nbs/01a_losses.ipynb 20
def MSELossFlat(
    *args,
    axis:int=-1, # Class axis
    floatify:bool=True, # Convert `targ` to `float`
    **kwargs
):
    "Same as `nn.MSELoss`, but flattens input and target."
    return BaseLoss(nn.MSELoss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)

# %% ../nbs/01a_losses.ipynb 23
def L1LossFlat(
    *args,
    axis=-1, # Class axis
    floatify=True, # Convert `targ` to `float`
    **kwargs
):
    "Same as `nn.L1Loss`, but flattens input and target."
    return BaseLoss(nn.L1Loss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)
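
# Usage sketch (illustrative): the flattened element-wise losses above
# (BCELossFlat, MSELossFlat, L1LossFlat) view both tensors as 1D and cast the
# target to float, so predictions and targets only need the same number of
# elements:
#
#   preds = torch.randn(8, 2)                      # e.g. an (x, y) point per item
#   targs = torch.randint(0, 10, (8, 2))           # integer targets are floatified
#   loss_func = MSELossFlat()
#   loss = loss_func(preds, targs)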

# %% ../nbs/01a_losses.ipynb 24
class LabelSmoothingCrossEntropy(Module):
    y_int = True # y interpolation
    def __init__(self,
        eps:float=0.1, # The weight for the interpolation formula
        weight:Tensor=None, # Manual rescaling weight given to each class, passed to `F.nll_loss`
        reduction:str='mean' # PyTorch reduction to apply to the output
    ):
        store_attr()

    def forward(self, output:Tensor, target:Tensor) -> Tensor:
        "Apply `F.log_softmax` on output, then blend `loss/num_classes` (`c`) with `F.nll_loss`"
        c = output.size()[1]
        log_preds = F.log_softmax(output, dim=1)
        if self.reduction=='sum': loss = -log_preds.sum()
        else:
            loss = -log_preds.sum(dim=1) # We divide by the number of classes at the return line, so sum rather than mean here
            if self.reduction=='mean': loss = loss.mean()
        return loss*self.eps/c + (1-self.eps) * F.nll_loss(log_preds, target.long(), weight=self.weight, reduction=self.reduction)

    def activation(self, out:Tensor) -> Tensor:
        "`F.log_softmax`'s fused activation function applied to model output"
        return F.softmax(out, dim=-1)

    def decodes(self, out:Tensor) -> Tensor:
        "Converts model output to target format"
        return out.argmax(dim=-1)
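
# Worked form (illustrative): with smoothing `eps` and `c` classes the result is
#   (1 - eps) * nll_loss(log_preds, target) + eps * mean_over_classes(-log_preds)
# i.e. the target distribution puts 1 - eps on the true class and spreads eps/c
# uniformly over all `c` classes (the true class included):
#
#   logits  = torch.randn(8, 10)
#   targets = torch.randint(0, 10, (8,))
#   loss_func = LabelSmoothingCrossEntropy(eps=0.1)
#   loss = loss_func(logits, targets)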

# %% ../nbs/01a_losses.ipynb 27
class LabelSmoothingCrossEntropyFlat(BaseLoss):
    "Same as `LabelSmoothingCrossEntropy`, but flattens input and target."
    y_int = True
    def __init__(self,
        *args,
        axis:int=-1, # Class axis
        **kwargs
    ):
        super().__init__(LabelSmoothingCrossEntropy, *args, axis=axis, **kwargs)

    def activation(self, out:Tensor) -> Tensor:
        "`LabelSmoothingCrossEntropy`'s fused activation function applied to model output"
        return F.softmax(out, dim=-1)

    def decodes(self, out:Tensor) -> Tensor:
        "Converts model output to target format"
        return out.argmax(dim=-1)
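
# Usage sketch (illustrative): same flattening behaviour as CrossEntropyLossFlat,
# so the smoothed loss also applies to segmentation-style outputs:
#
#   seg_logits  = torch.randn(2, 5, 4, 4)
#   seg_targets = torch.randint(0, 5, (2, 4, 4))
#   loss_func = LabelSmoothingCrossEntropyFlat(eps=0.1, axis=1)
#   loss = loss_func(seg_logits, seg_targets)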

# %% ../nbs/01a_losses.ipynb 30
class DiceLoss:
    "Dice loss for segmentation"
    def __init__(self,
        axis:int=1, # Class axis
        smooth:float=1e-6, # Helps with numerical stability in the IoU division
        reduction:str="sum", # PyTorch reduction to apply to the output
        square_in_union:bool=False # Squares predictions to increase slope of gradients
    ):
        store_attr()

    def __call__(self, pred:Tensor, targ:Tensor) -> Tensor:
        "One-hot encodes `targ`, runs the IoU calculation, then takes `1 - dice` as the loss"
        targ = self._one_hot(targ, pred.shape[self.axis])
        pred, targ = TensorBase(pred), TensorBase(targ)
        assert pred.shape == targ.shape, 'input and target dimensions differ, DiceLoss expects non one-hot targs'
        pred = self.activation(pred)
        sum_dims = list(range(2, len(pred.shape)))
        inter = torch.sum(pred*targ, dim=sum_dims)
        union = (torch.sum(pred**2+targ, dim=sum_dims) if self.square_in_union
                 else torch.sum(pred+targ, dim=sum_dims))
        dice_score = (2. * inter + self.smooth)/(union + self.smooth)
        loss = 1 - dice_score
        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()
        return loss

    @staticmethod
    def _one_hot(
        x:Tensor, # Non one-hot encoded targs
        classes:int, # The number of classes
        axis:int=1 # The axis to stack for encoding (class dimension)
    ) -> Tensor:
        "Creates one binary mask per class"
        return torch.stack([torch.where(x==c, 1, 0) for c in range(classes)], axis=axis)

    def activation(self, x:Tensor) -> Tensor:
        "Activation function applied to model output"
        return F.softmax(x, dim=self.axis)

    def decodes(self, x:Tensor) -> Tensor:
        "Converts model output to target format"
        return x.argmax(dim=self.axis)
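
# Usage sketch (illustrative): segmentation with integer masks. DiceLoss
# one-hot encodes the target itself, so `targ` stays (bs, H, W) while `pred`
# is raw logits of shape (bs, n_classes, H, W). A common pattern is to combine
# it with CrossEntropyLossFlat in a hypothetical helper like `combo_loss`:
#
#   pred = torch.randn(2, 3, 4, 4)
#   targ = torch.randint(0, 3, (2, 4, 4))
#   dice = DiceLoss(reduction='mean')
#   ce   = CrossEntropyLossFlat(axis=1)
#   def combo_loss(p, t): return ce(p, t) + dice(p, t)
#   loss = combo_loss(pred, targ)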