Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
from opencd.registry import MODELS | |
def bcl_loss(
        pred,
        target,
        margin=2.0,
        eps=1e-4,
        ignore_index=255,
        **kwargs):
    """Compute the batch-balanced contrastive loss between ``pred`` and ``target``.

    Positions where ``target`` is 0 are penalised by ``pred ** 2``; positions
    where ``target`` is non-zero are penalised by
    ``clamp(margin - pred, min=0) ** 2`` weighted by the target value. Each
    of the two terms is normalised by the sum of its own weights (plus
    ``eps``), which balances the two classes within the batch.

    Args:
        pred (Tensor): Predicted map. All size-1 dimensions are squeezed
            before comparison. Presumably a per-pixel distance map --
            confirm against the caller.
        target (Tensor): Ground-truth map with the same squeezed size as
            ``pred``. Assumed to hold 0 (unchanged) / 1 (changed) apart
            from ``ignore_index`` -- confirm against the dataset.
        margin (float): Values of ``pred`` at or above this margin add no
            penalty at changed positions. Defaults to 2.0.
        eps (float): Added to both normalisers to avoid division by zero
            when one class is absent. Defaults to 1e-4.
        ignore_index (int): Target value marking positions that are
            excluded from the loss. Defaults to 255.
        **kwargs: Accepted for signature compatibility and ignored.

    Returns:
        Tensor: Scalar loss, the sum of the unchanged and changed terms.

    Raises:
        AssertionError: If the squeezed sizes differ or ``target`` is empty.
    """
    pred = pred.squeeze()
    target = target.squeeze()
    assert pred.size() == target.size() and target.numel() > 0

    # 1.0 at valid positions, 0.0 where target == ignore_index.
    mask = (target != ignore_index).float()

    # Zero out ignored positions so they carry no "changed" weight.
    changed = target * mask
    # Ignored positions were just set to 0, so ``1 - changed`` alone would
    # count them as unchanged; mask them out of the indicator so they are
    # excluded from the normaliser as well as from the numerator.
    unchanged = (1 - changed) * mask

    n_u = unchanged.sum() + eps
    n_c = changed.sum() + eps

    unchanged_term = torch.sum(unchanged * torch.pow(pred, 2)) / n_u
    changed_term = torch.sum(
        changed * torch.pow(torch.clamp(margin - pred, min=0.), 2)) / n_c
    return unchanged_term + changed_term
class BCLLoss(nn.Module):
    """Batch-balanced Contrastive Loss.

    Thin ``nn.Module`` wrapper around :func:`bcl_loss` that stores the loss
    hyper-parameters and scales the result by ``loss_weight``.

    NOTE(review): ``MODELS`` is imported at the top of the file but this
    class is not registered with it in the visible code -- confirm whether
    a ``@MODELS.register_module()`` decorator is expected here.

    Args:
        margin (float): Margin forwarded to :func:`bcl_loss`.
            Defaults to 2.0.
        loss_weight (float): Multiplier applied to the computed loss.
            Defaults to 1.0.
        ignore_index (int): Target value excluded from the loss.
            Defaults to 255.
        loss_name (str): Name returned by :meth:`loss_name`.
            Defaults to 'bcl_loss'.
        **kwargs: Accepted for config compatibility and ignored.
    """

    def __init__(
            self,
            margin=2.0,
            loss_weight=1.0,
            ignore_index=255,
            loss_name='bcl_loss',
            **kwargs):
        super().__init__()
        self.margin = margin
        self.loss_weight = loss_weight
        self.ignore_index = ignore_index
        self._loss_name = loss_name

    def forward(self,
                pred,
                target,
                **kwargs):
        """Return the weighted batch-balanced contrastive loss.

        Args:
            pred (Tensor): Prediction passed through to :func:`bcl_loss`.
            target (Tensor): Ground truth passed through to
                :func:`bcl_loss`.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            Tensor: ``loss_weight * bcl_loss(...)``.
        """
        # Pass by keyword: bcl_loss's third optional positional parameter is
        # ``eps``, so a positional ``ignore_index`` would be bound to ``eps``.
        loss = self.loss_weight * bcl_loss(
            pred,
            target,
            margin=self.margin,
            ignore_index=self.ignore_index)
        return loss

    # NOTE(review): frameworks in the mmseg family usually read ``loss_name``
    # as an attribute; this is a plain method in the visible code -- confirm
    # whether ``@property`` is intended before changing the interface.
    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name