# Copyright (c) OpenMMLab. All rights reserved.
import logging
from functools import partial
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from mmengine.logging import print_log
from torch import Tensor

from mmdet.registry import MODELS

@MODELS.register_module()
class EQLV2Loss(nn.Module):

    def __init__(self,
                 use_sigmoid: bool = True,
                 reduction: str = 'mean',
                 class_weight: Optional[Tensor] = None,
                 loss_weight: float = 1.0,
                 num_classes: int = 1203,
                 use_distributed: bool = False,
                 mu: float = 0.8,
                 alpha: float = 4.0,
                 gamma: int = 12,
                 vis_grad: bool = False,
                 test_with_obj: bool = True) -> None:
        """`Equalization Loss v2 <https://arxiv.org/abs/2012.08548>`_

        Args:
            use_sigmoid (bool): EQLv2 uses the sigmoid function to transform
                the predicted logits to an estimated probability
                distribution, so this is always treated as True.
            reduction (str, optional): The method used to reduce the loss
                into a scalar. Defaults to 'mean'.
            class_weight (Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            loss_weight (float, optional): The weight of the total EQLv2
                loss. Defaults to 1.0.
            num_classes (int): 1203 for LVIS v1.0, 1230 for LVIS v0.5.
                Defaults to 1203.
            use_distributed (bool, optional): Whether to accumulate the
                gradient statistics across all GPUs. Set it to True when
                using distributed training. Defaults to False.
            mu (float, optional): The center of the sigmoid mapping function
                that converts the accumulated positive/negative gradient
                ratio into a gradient weight. Defaults to 0.8.
            alpha (float, optional): A balance factor that scales up the
                positive weights, pos_w = 1 + alpha * (1 - neg_w).
                Defaults to 4.0.
            gamma (int, optional): The gamma for calculating the modulating
                factor. Defaults to 12.
            vis_grad (bool, optional): Whether to record the collected
                gradient statistics for visualization. Defaults to False.
            test_with_obj (bool, optional): Whether to rescale the foreground
                scores by (1 - background score) at test time. Defaults to
                True.

        Returns:
            None.
        """
        super().__init__()
        # EQLv2 is formulated on sigmoid outputs, so sigmoid is always used
        # regardless of the `use_sigmoid` argument.
        self.use_sigmoid = True
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = class_weight
        self.num_classes = num_classes
        self.group = True

        # cfg for eqlv2
        self.vis_grad = vis_grad
        self.mu = mu
        self.alpha = alpha
        self.gamma = gamma
        self.use_distributed = use_distributed

        # initial variables
        self.register_buffer('pos_grad', torch.zeros(self.num_classes))
        self.register_buffer('neg_grad', torch.zeros(self.num_classes))
        # At the beginning of training, we set a high value (e.g. 100) for
        # the initial gradient ratio, so that the initial weights for both
        # pos and neg gradients are 1.
        self.register_buffer('pos_neg', torch.ones(self.num_classes) * 100)
        self.test_with_obj = test_with_obj

        def _func(x, gamma, mu):
            return 1 / (1 + torch.exp(-gamma * (x - mu)))

        self.map_func = partial(_func, gamma=self.gamma, mu=self.mu)
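
        # A rough feel for the mapping with the defaults gamma=12, mu=0.8
        # (values computed here for illustration, not taken from the paper):
        #   pos_neg = 0.0 (rare class)   -> neg_w ~ 6.8e-5, pos_w ~ 5.0
        #   pos_neg = 0.8 (= mu)         -> neg_w = 0.5,    pos_w = 3.0
        #   pos_neg = 100 (init value)   -> neg_w ~ 1.0,    pos_w ~ 1.0
        # i.e. rare classes get their negative gradients suppressed and
        # their positive gradients boosted by up to (1 + alpha).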
        print_log(
            f'build EQL v2, gamma: {gamma}, mu: {mu}, alpha: {alpha}',
            logger='current',
            level=logging.DEBUG)

    def forward(self,
                cls_score: Tensor,
                label: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[int] = None,
                reduction_override: Optional[str] = None) -> Tensor:
        """`Equalization Loss v2 <https://arxiv.org/abs/2012.08548>`_

        Args:
            cls_score (Tensor): The prediction with shape (N, C), where C is
                the number of classes.
            label (Tensor): The ground truth class indices of the predicted
                targets, with shape (N, ).
            weight (Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Options are "none", "mean" and "sum". Defaults to None.

        Returns:
            Tensor: The calculated loss.
        """
        self.n_i, self.n_c = cls_score.size()
        self.gt_classes = label
        self.pred_class_logits = cls_score

        def expand_label(pred, gt_classes):
            # Scatter the class indices into a dense (N, C) one-hot target.
            target = pred.new_zeros(self.n_i, self.n_c)
            target[torch.arange(self.n_i), gt_classes] = 1
            return target
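
        # e.g. with n_i = 2, n_c = 3 and label = [0, 2], the target becomes
        #   [[1., 0., 0.],
        #    [0., 0., 1.]]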
        target = expand_label(cls_score, label)

        pos_w, neg_w = self.get_weight(cls_score)
        # NOTE: `weight`, `avg_factor` and `reduction_override` are accepted
        # for API compatibility but are not used; `weight` is overwritten
        # with the EQLv2 per-element weights here.
        weight = pos_w * target + neg_w * (1 - target)
        cls_loss = F.binary_cross_entropy_with_logits(
            cls_score, target, reduction='none')
        cls_loss = torch.sum(cls_loss * weight) / self.n_i

        self.collect_grad(cls_score.detach(), target.detach(),
                          weight.detach())

        return self.loss_weight * cls_loss

    def get_channel_num(self, num_classes):
        # One extra channel for the objectness / background score.
        num_channel = num_classes + 1
        return num_channel

    def get_activation(self, pred):
        pred = torch.sigmoid(pred)
        n_i, n_c = pred.size()
        # The last channel holds the background (objectness) score.
        bg_score = pred[:, -1].view(n_i, 1)
        if self.test_with_obj:
            # Rescale the foreground scores by the probability of not
            # being background.
            pred[:, :-1] *= (1 - bg_score)
        return pred

    def collect_grad(self, pred, target, weight):
        prob = torch.sigmoid(pred)
        # For BCE with logits, dL/dlogit = sigmoid(logit) - target, i.e.
        # (prob - 1) on positives and prob on negatives, so `grad` is the
        # per-element gradient magnitude.
        grad = target * (prob - 1) + (1 - target) * prob
        grad = torch.abs(grad)

        # do not collect grad for the objectness branch [:-1]
        pos_grad = torch.sum(grad * target * weight, dim=0)[:-1]
        neg_grad = torch.sum(grad * (1 - target) * weight, dim=0)[:-1]

        if self.use_distributed:
            dist.all_reduce(pos_grad)
            dist.all_reduce(neg_grad)

        self.pos_grad += pos_grad
        self.neg_grad += neg_grad
        self.pos_neg = self.pos_grad / (self.neg_grad + 1e-10)

    def get_weight(self, pred):
        # The extra objectness channel keeps neg_w = 1 (and hence
        # pos_w = 1), i.e. it is left unweighted.
        neg_w = torch.cat([self.map_func(self.pos_neg), pred.new_ones(1)])
        pos_w = 1 + self.alpha * (1 - neg_w)
        neg_w = neg_w.view(1, -1).expand(self.n_i, self.n_c)
        pos_w = pos_w.view(1, -1).expand(self.n_i, self.n_c)
        return pos_w, neg_w
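

# A minimal smoke-test sketch, not part of the upstream module: it assumes
# mmdet/mmengine are installed and that `cls_score` carries num_classes + 1
# logits per sample, with the last channel being the objectness/background
# score (see get_channel_num). The numbers below are illustrative only.
if __name__ == '__main__':
    torch.manual_seed(0)
    num_classes = 4
    loss_fn = EQLV2Loss(num_classes=num_classes)
    # 8 samples; labels lie in [0, num_classes], where index num_classes
    # denotes the background class.
    cls_score = torch.randn(8, num_classes + 1)
    label = torch.randint(0, num_classes + 1, (8, ))
    loss = loss_fn(cls_score, label)
    print(f'loss: {loss.item():.4f}, pos_neg: {loss_fn.pos_neg}')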