# Page residue captured with this file (not part of the module): "Spaces: Runtime error Runtime error"
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01a_losses.ipynb. | |
# %% ../nbs/01a_losses.ipynb 2 | |
from __future__ import annotations
from .imports import *
from .torch_imports import *
from .torch_core import *
from .layers import *
# %% auto 0 | |
# Public API of this module.
__all__ = ['BaseLoss', 'CrossEntropyLossFlat', 'FocalLoss', 'FocalLossFlat', 'BCEWithLogitsLossFlat', 'BCELossFlat',
           'MSELossFlat', 'L1LossFlat', 'LabelSmoothingCrossEntropy', 'LabelSmoothingCrossEntropyFlat', 'DiceLoss']
# %% ../nbs/01a_losses.ipynb 5 | |
class BaseLoss():
    "Same as `loss_cls`, but flattens input and target."
    # Identity defaults; subclasses override these with loss-specific versions.
    activation=decodes=noops
    def __init__(self,
        loss_cls, # Uninitialized PyTorch-compatible loss
        *args,
        axis:int=-1, # Class axis
        flatten:bool=True, # Flatten `inp` and `targ` before calculating loss
        floatify:bool=False, # Convert `targ` to `float`
        is_2d:bool=True, # Whether `flatten` keeps one or two channels when applied
        **kwargs
    ):
        store_attr("axis,flatten,floatify,is_2d")
        # Any remaining positional/keyword arguments configure the wrapped loss itself.
        self.func = loss_cls(*args,**kwargs)
        # Copies metadata (and the wrapped object's `__dict__`) from `self.func` onto this wrapper.
        functools.update_wrapper(self, self.func)

    def __repr__(self) -> str: return f"FlattenedLoss of {self.func}"

    # `reduction` must be a data descriptor (property WITH a setter): `update_wrapper` above may have
    # copied a `reduction` entry from the wrapped loss's `__dict__` into this instance, and only a data
    # descriptor takes precedence over an instance-dict entry. Without it, reads return a stale copy and
    # assignments never reach `self.func`.
    @property
    def reduction(self) -> str:
        "Reduction style of the wrapped loss"
        return self.func.reduction

    @reduction.setter
    def reduction(self, v:str):
        "Sets the reduction style (typically 'mean', 'sum', or 'none')"
        self.func.reduction = v

    def _contiguous(self, x:Tensor) -> TensorBase:
        "Move `self.axis` to the last dimension and ensure tensor is contiguous for `Tensor` otherwise just return"
        return TensorBase(x.transpose(self.axis,-1).contiguous()) if isinstance(x,torch.Tensor) else x

    def __call__(self,
        inp:Tensor|MutableSequence, # Predictions from a `Learner`
        targ:Tensor|MutableSequence, # Actual y label
        **kwargs
    ) -> TensorBase: # `loss_cls` calculated on `inp` and `targ`
        inp,targ = map(self._contiguous, (inp,targ))
        # float16 targets are left untouched when floatifying.
        if self.floatify and targ.dtype!=torch.float16: targ = targ.float()
        # Narrow integer targets are widened to int64 (`long`).
        if targ.dtype in [torch.int8, torch.int16, torch.int32]: targ = targ.long()
        # After `_contiguous` the class axis is last: keep it as a column (2-D) or flatten fully (1-D).
        if self.flatten: inp = inp.view(-1,inp.shape[-1]) if self.is_2d else inp.view(-1)
        return self.func.__call__(inp, targ.view(-1) if self.flatten else targ, **kwargs)

    def to(self, device:torch.device):
        "Move the loss function to a specified `device` (no-op unless the wrapped loss is an `nn.Module`)"
        if isinstance(self.func, nn.Module): self.func.to(device)
# %% ../nbs/01a_losses.ipynb 8 | |
class CrossEntropyLossFlat(BaseLoss):
    "Same as `nn.CrossEntropyLoss`, but flattens input and target."
    # `y_int` flag read outside this file (presumably: targets are integer class indices).
    y_int = True
    def __init__(self, *args, axis:int=-1, **kwargs):
        "Wrap `nn.CrossEntropyLoss`; `axis` is the class axis, everything else is forwarded to the loss."
        super().__init__(nn.CrossEntropyLoss, *args, axis=axis, **kwargs)

    def activation(self, x:Tensor) -> Tensor:
        "Softmax over the class axis, matching the activation fused into `nn.CrossEntropyLoss`"
        probs = F.softmax(x, dim=self.axis)
        return probs

    def decodes(self, x:Tensor) -> Tensor:
        "Index of the highest-scoring class along the class axis"
        class_idx = x.argmax(dim=self.axis)
        return class_idx
# %% ../nbs/01a_losses.ipynb 13 | |
class FocalLoss(Module):
    "Focal loss (Lin et al., https://arxiv.org/pdf/1708.02002.pdf) built on `F.cross_entropy`"
    y_int=True # y interpolation
    def __init__(self,
        gamma:float=2.0, # Focusing parameter. Higher values down-weight easy examples' contribution to loss
        weight:Tensor=None, # Manual rescaling weight given to each class
        reduction:str='mean' # PyTorch reduction to apply to the output
    ):
        "Applies Focal Loss: https://arxiv.org/pdf/1708.02002.pdf"
        store_attr()

    def forward(self, inp:Tensor, targ:Tensor) -> Tensor:
        "Applies focal loss based on https://arxiv.org/pdf/1708.02002.pdf"
        # `p_t` must come from the UNweighted cross-entropy: with a class `weight`, the weighted CE is
        # -w*log(p_t), so exp(-weighted_ce) would be p_t**w and distort the focusing term.
        unweighted_ce = F.cross_entropy(inp, targ, reduction="none")
        p_t = torch.exp(-unweighted_ce)
        # The class weighting (the paper's alpha) applies only to the CE factor.
        ce_loss = (unweighted_ce if self.weight is None
                   else F.cross_entropy(inp, targ, weight=self.weight, reduction="none"))
        loss = (1 - p_t)**self.gamma * ce_loss
        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()
        # Any other reduction (e.g. 'none') returns the per-element loss.
        return loss
class FocalLossFlat(BaseLoss):
    """
    Same as `CrossEntropyLossFlat`, but with the focusing parameter `gamma` from Lin et al.,
    https://arxiv.org/pdf/1708.02002.pdf. The paper's class-weighting factor, alpha, can be supplied via
    the `weight` keyword, which is forwarded to `FocalLoss` and from there to `F.cross_entropy`.
    """
    y_int = True # y interpolation
    def __init__(self, *args, gamma:float=2.0, axis:int=-1, **kwargs):
        "Wrap `FocalLoss` with focusing parameter `gamma`; `axis` is the class axis."
        super().__init__(FocalLoss, *args, gamma=gamma, axis=axis, **kwargs)

    def activation(self, x:Tensor) -> Tensor:
        "Softmax over the class axis, matching the activation fused into `F.cross_entropy`"
        probs = F.softmax(x, dim=self.axis)
        return probs

    def decodes(self, x:Tensor) -> Tensor:
        "Index of the highest-scoring class along the class axis"
        class_idx = x.argmax(dim=self.axis)
        return class_idx
# %% ../nbs/01a_losses.ipynb 16 | |
class BCEWithLogitsLossFlat(BaseLoss):
    "Same as `nn.BCEWithLogitsLoss`, but flattens input and target."
    def __init__(self,
        *args,
        axis:int=-1, # Class axis
        floatify:bool=True, # Convert `targ` to `float`
        thresh:float=0.5, # The threshold on which to predict
        **kwargs
    ):
        has_pos_weight = kwargs.get('pos_weight', None) is not None
        if has_pos_weight:
            # Flattening would cause a shape mismatch against `pos_weight`, so it is refused or disabled.
            if kwargs.get('flatten', None) is True:
                raise ValueError("`flatten` must be False when using `pos_weight` to avoid a RuntimeError due to shape mismatch")
            kwargs['flatten'] = False
        super().__init__(nn.BCEWithLogitsLoss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)
        self.thresh = thresh

    def activation(self, x:Tensor) -> Tensor:
        "Sigmoid, matching the activation fused into `nn.BCEWithLogitsLoss`"
        return torch.sigmoid(x)

    def decodes(self, x:Tensor) -> Tensor:
        "Boolean prediction: `True` where the value exceeds `self.thresh`"
        return x>self.thresh
# %% ../nbs/01a_losses.ipynb 18 | |
def BCELossFlat(
    *args,
    axis:int=-1, # Class axis
    floatify:bool=True, # Convert `targ` to `float`
    **kwargs
):
    "Same as `nn.BCELoss`, but flattens input and target."
    # `is_2d=False`: when flattening, predictions and targets are both reshaped to 1-D.
    wrapper_opts = dict(axis=axis, floatify=floatify, is_2d=False)
    return BaseLoss(nn.BCELoss, *args, **wrapper_opts, **kwargs)
# %% ../nbs/01a_losses.ipynb 20 | |
def MSELossFlat(
    *args,
    axis:int=-1, # Class axis
    floatify:bool=True, # Convert `targ` to `float`
    **kwargs
):
    "Same as `nn.MSELoss`, but flattens input and target."
    # `is_2d=False`: when flattening, predictions and targets are both reshaped to 1-D.
    flat_loss = BaseLoss(nn.MSELoss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)
    return flat_loss
# %% ../nbs/01a_losses.ipynb 23 | |
def L1LossFlat(
    *args,
    axis:int=-1, # Class axis
    floatify:bool=True, # Convert `targ` to `float`
    **kwargs
):
    "Same as `nn.L1Loss`, but flattens input and target."
    # `is_2d=False`: when flattening, predictions and targets are both reshaped to 1-D.
    return BaseLoss(nn.L1Loss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)
# %% ../nbs/01a_losses.ipynb 24 | |
class LabelSmoothingCrossEntropy(Module):
    "Cross-entropy with label smoothing: `eps/c * sum(-log_probs) + (1-eps) * nll`, reduced per `reduction`"
    y_int = True # y interpolation
    def __init__(self,
        eps:float=0.1, # The weight for the interpolation formula
        weight:Tensor=None, # Manual rescaling weight given to each class passed to `F.nll_loss`
        reduction:str='mean' # PyTorch reduction to apply to the output
    ):
        store_attr()

    def forward(self, output:Tensor, target:Tensor) -> Tensor:
        "Apply `F.log_softmax` on output then blend the loss/num_classes(`c`) with the `F.nll_loss`"
        # Classes live on dim 1, as `F.nll_loss` expects.
        c = output.size()[1]
        log_preds = F.log_softmax(output, dim=1)
        if self.reduction=='sum': loss = -log_preds.sum()
        else:
            loss = -log_preds.sum(dim=1) #We divide by that size at the return line so sum and not mean
            if self.reduction=='mean': loss = loss.mean()
        return loss*self.eps/c + (1-self.eps) * F.nll_loss(log_preds, target.long(), weight=self.weight, reduction=self.reduction)

    # `activation`/`decodes` use the same class axis (dim 1) as `forward`; for 2-D inputs this is
    # identical to dim -1, and for N-D inputs (e.g. segmentation) it keeps them consistent with the loss.
    def activation(self, out:Tensor) -> Tensor:
        "`F.log_softmax`'s fused activation function applied to model output"
        return F.softmax(out, dim=1)

    def decodes(self, out:Tensor) -> Tensor:
        "Converts model output to target format"
        return out.argmax(dim=1)
# %% ../nbs/01a_losses.ipynb 27 | |
class LabelSmoothingCrossEntropyFlat(BaseLoss):
    "Same as `LabelSmoothingCrossEntropy`, but flattens input and target."
    y_int = True
    def __init__(self, *args, axis:int=-1, **kwargs):
        "Wrap `LabelSmoothingCrossEntropy`; `axis` is the class axis, everything else is forwarded to it."
        super().__init__(LabelSmoothingCrossEntropy, *args, axis=axis, **kwargs)

    def decodes(self, out:Tensor) -> Tensor:
        "Index of the highest-scoring class along the last dimension"
        predicted = out.argmax(dim=-1)
        return predicted

    def activation(self, out:Tensor) -> Tensor:
        "Softmax over the last dimension, matching the activation fused into `LabelSmoothingCrossEntropy`"
        probs = F.softmax(out, dim=-1)
        return probs
# %% ../nbs/01a_losses.ipynb 30 | |
class DiceLoss:
    "Dice loss for segmentation"
    def __init__(self,
        axis:int=1, # Class axis
        smooth:float=1e-6, # Helps with numerical stabilities in the IoU division
        reduction:str="sum", # PyTorch reduction to apply to the output
        square_in_union:bool=False # Squares predictions to increase slope of gradients
    ):
        store_attr()

    def __call__(self, pred:Tensor, targ:Tensor) -> Tensor:
        "One-hot encodes targ, then runs IoU calculation then takes 1-dice value"
        # One-hot along the same class axis as `pred` so the two shapes line up.
        targ = self._one_hot(targ, pred.shape[self.axis], axis=self.axis)
        pred, targ = TensorBase(pred), TensorBase(targ)
        assert pred.shape == targ.shape, 'input and target dimensions differ, DiceLoss expects non one-hot targs'
        pred = self.activation(pred)
        # Sum over every dim except batch (0) and the class axis -> one score per (item, class).
        # For the default `axis=1` this is dims 2..ndim-1, as before.
        cls_dim = self.axis % pred.ndim
        sum_dims = [d for d in range(1, pred.ndim) if d != cls_dim]
        inter = torch.sum(pred*targ, dim=sum_dims)
        union = (torch.sum(pred**2+targ, dim=sum_dims) if self.square_in_union
            else torch.sum(pred+targ, dim=sum_dims))
        # `smooth` in numerator and denominator guards against 0/0 for empty classes.
        dice_score = (2. * inter + self.smooth)/(union + self.smooth)
        loss = 1- dice_score
        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()
        # Any other reduction (e.g. 'none') returns the per-(item, class) loss.
        return loss

    @staticmethod
    def _one_hot(
        x:Tensor, # Non one-hot encoded targs
        classes:int, # The number of classes
        axis:int=1 # The axis to stack for encoding (class dimension)
    ) -> Tensor:
        "Creates one binary mask per class"
        return torch.stack([torch.where(x==c, 1, 0) for c in range(classes)], dim=axis)

    def activation(self, x:Tensor) -> Tensor:
        "Activation function applied to model output"
        return F.softmax(x, dim=self.axis)

    def decodes(self, x:Tensor) -> Tensor:
        "Converts model output to target format"
        return x.argmax(dim=self.axis)