Spaces:
Runtime error
Runtime error
import torch.nn as nn | |
import torch | |
from .abstract_loss_func import AbstractLossClass | |
from metrics.registry import LOSSFUNC | |
class ConsistencyCos(nn.Module):
    """Cross-entropy plus a cosine-consistency penalty between paired features.

    ``forward`` pairs row ``i`` of the first half of the batch with row
    ``half + i`` of the second half (``half = N // 2``); presumably the batch
    holds two views of the same samples stacked along dim 0 -- confirm against
    the caller. The penalty is ``MSE(cos(f_i, f_{half+i}), 1)``, which pushes
    each pair toward the same direction.

    Args:
        consistency_rate: weight of the cosine-consistency term (the original
            code hard-coded 1.0 and reset it on every forward call).
        weight: optional per-class weight tensor forwarded to
            ``nn.CrossEntropyLoss``; ``None`` gives unweighted cross-entropy.
    """

    def __init__(self, consistency_rate=1.0, weight=None):
        super().__init__()
        self.consistency_rate = consistency_rate
        # CrossEntropyLoss registers ``weight`` as a buffer, so it follows
        # ``module.to(device)`` instead of needing a manual ``.cuda()``.
        self.loss_fn = nn.CrossEntropyLoss(weight=weight)
        self.mse_fn = nn.MSELoss()

    def forward(self, feat, inputs, targets):
        """Return ``consistency_rate * consistency + CE(inputs, targets)``.

        Args:
            feat: features with the batch on dim 0 and channels on dim 1
                (normalized along dim 1); 2-D ``(N, C)`` in the original
                einsum formulation. With an odd ``N`` the last row has no
                partner and is ignored by the consistency term.
            inputs: logits passed to ``nn.CrossEntropyLoss``.
            targets: class targets passed to ``nn.CrossEntropyLoss``.

        Returns:
            Scalar loss tensor.
        """
        ce_loss = self.loss_fn(inputs, targets)

        feat = nn.functional.normalize(feat, dim=1)
        half = feat.size(0) // 2
        if half == 0:
            # Fewer than two rows: there are no pairs, and MSE over an empty
            # tensor is NaN, so only the classification term applies.
            return ce_loss

        feat_0 = feat[:half]
        feat_1 = feat[half:2 * half]
        # Row-wise dot product of unit vectors == cosine similarity, shape (half, 1).
        cos = (feat_0 * feat_1).sum(dim=1, keepdim=True)
        # ones_like inherits cos's device and dtype; the former
        # ``if torch.cuda.is_available(): labels.cuda()`` broke CPU runs on
        # CUDA machines, non-default GPUs and half-precision features.
        labels = torch.ones_like(cos)
        return self.consistency_rate * self.mse_fn(cos, labels) + ce_loss
# | |
##FIXME to be implemented | |
class ConsistencyL2(nn.Module):
    """Mean-squared-error consistency between the two halves of a batch.

    Row ``i`` of the first half is compared with row ``half + i``, where
    ``half = N // 2`` (presumably two views of the same samples stacked along
    dim 0 -- confirm against the caller). With an odd batch size the trailing
    row has no partner and is ignored.
    """

    def __init__(self):
        super().__init__()
        self.mse_fn = nn.MSELoss()

    def forward(self, feat):
        """Return the MSE between paired rows of ``feat``.

        Args:
            feat: tensor with the batch on dim 0 (any trailing shape).

        Returns:
            Scalar loss tensor; 0 when the batch has fewer than two rows.
        """
        half = feat.size(0) // 2
        feat_0 = feat[:half]
        if half == 0:
            # No pairs exist and MSE over empty tensors is NaN. Summing the
            # empty slice gives a zero that stays attached to autograd.
            return feat_0.sum()
        # Take exactly ``half`` rows: the old ``feat[half:]`` kept an extra
        # row for odd batches, which raised a shape error or (N == 3)
        # silently broadcast one row against two.
        feat_1 = feat[half:2 * half]
        return self.mse_fn(feat_0, feat_1)
class ConsistencyL1(nn.Module):
    """Mean-absolute-error (L1) consistency between the two halves of a batch.

    Row ``i`` of the first half is compared with row ``half + i``, where
    ``half = N // 2`` (presumably two views of the same samples stacked along
    dim 0 -- confirm against the caller). With an odd batch size the trailing
    row has no partner and is ignored.
    """

    def __init__(self):
        super().__init__()
        self.L1_fn = nn.L1Loss()

    def forward(self, feat):
        """Return the L1 loss between paired rows of ``feat``.

        Args:
            feat: tensor with the batch on dim 0 (any trailing shape).

        Returns:
            Scalar loss tensor; 0 when the batch has fewer than two rows.
        """
        half = feat.size(0) // 2
        feat_0 = feat[:half]
        if half == 0:
            # No pairs exist and L1 over empty tensors is NaN. Summing the
            # empty slice gives a zero that stays attached to autograd.
            return feat_0.sum()
        # Take exactly ``half`` rows: the old ``feat[half:]`` kept an extra
        # row for odd batches, which raised a shape error or (N == 3)
        # silently broadcast one row against two.
        feat_1 = feat[half:2 * half]
        return self.L1_fn(feat_0, feat_1)