# Source: TTP/mmdet/models/losses/l2_loss.py
# Last commit: KyanChen, "Upload 1861 files" (3b96cb1)
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple, Union
import numpy as np
import torch
from mmengine.model import BaseModule
from torch import Tensor
from mmdet.registry import MODELS
from .utils import weighted_loss
@weighted_loss
def l2_loss(pred: Tensor, target: Tensor) -> Tensor:
    """Element-wise squared error between a prediction and its target.

    This function is wrapped by the ``weighted_loss`` decorator; callers in
    this file pass ``weight``, ``reduction`` and ``avg_factor`` to the
    wrapped version.

    Args:
        pred (torch.Tensor): The prediction.
        target (torch.Tensor): The learning target of the prediction. Must
            have the same size as ``pred``.

    Returns:
        torch.Tensor: ``|pred - target| ** 2`` for every element.
    """
    assert target.size() == pred.size()
    residual = pred - target
    return residual.abs()**2
@MODELS.register_module()
class L2Loss(BaseModule):
    """L2 loss with optional margins and negative-sample subsampling.

    Targets equal to 1 are treated as positives and targets equal to 0 as
    negatives. Entries whose weight is not positive are relabelled to -1
    (on a private copy) so they count as neither.

    Args:
        neg_pos_ub (int, optional): Upper bound on the ratio of negatives to
            positives. When positive and the ratio is exceeded, only
            ``num_pos * neg_pos_ub`` negatives keep their weight; the rest
            get weight 0. Non-positive values disable the subsampling.
            Defaults to -1.
        pos_margin (float, optional): Value subtracted from the prediction
            at positive locations. Only applied when positive.
            Defaults to -1.
        neg_margin (float, optional): Value subtracted from the prediction
            at negative locations. Only applied when positive.
            Defaults to -1.
        hard_mining (bool, optional): When subsampling negatives, keep the
            ones with the largest loss instead of a random subset.
            Defaults to False.
        reduction (str, optional): The method to reduce the loss.
            Options are "none", "mean" and "sum". Defaults to 'mean'.
        loss_weight (float, optional): The weight of loss. Defaults to 1.0.
    """

    def __init__(self,
                 neg_pos_ub: int = -1,
                 pos_margin: float = -1,
                 neg_margin: float = -1,
                 hard_mining: bool = False,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0):
        super(L2Loss, self).__init__()
        self.neg_pos_ub = neg_pos_ub
        self.pos_margin = pos_margin
        self.neg_margin = neg_margin
        self.hard_mining = hard_mining
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[float] = None,
                reduction_override: Optional[str] = None) -> Tensor:
        """Forward function.

        None of the input tensors is modified in place.

        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning target of the prediction.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (float, optional): Average factor that is used to
                average the loss. It is replaced by the value computed in
                :meth:`update_weight`. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.

        Returns:
            torch.Tensor: The loss scaled by ``loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if weight is not None:
            # The loss must see the same -1 ignore label that the weight
            # update uses, but the caller's target is left untouched.
            target = self._mask_invalid(target, weight)
        pred, weight, avg_factor = self.update_weight(pred, target, weight,
                                                      avg_factor)
        loss = self.loss_weight * l2_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss

    @staticmethod
    def _mask_invalid(target: Tensor, weight: Tensor) -> Tensor:
        """Return a copy of ``target`` with -1 wherever ``weight <= 0``."""
        masked = target.clone()
        masked[weight <= 0] = -1
        return masked

    def update_weight(
            self, pred: Tensor, target: Tensor, weight: Optional[Tensor],
            avg_factor: Optional[float]) -> Tuple[Tensor, Tensor, Tensor]:
        """Update the weight according to targets.

        All work is done on copies, so ``pred``, ``target`` and ``weight``
        passed by the caller are never written to.

        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning target of the prediction.
            weight (torch.Tensor, optional): Per-element loss weight. When
                None, a tensor of ones shaped like ``target`` is used.
            avg_factor (float, optional): Ignored; the returned average
                factor is always recomputed from the final weight.

        Returns:
            tuple: ``(pred, weight, avg_factor)`` where ``pred`` has the
            margins applied and is clamped to ``[0, 1]``, ``weight`` has the
            discarded negatives set to 0, and ``avg_factor`` is the number
            of elements with positive weight (a 0-dim tensor).
        """
        if weight is None:
            weight = target.new_ones(target.size())
        else:
            weight = weight.clone()
        target = self._mask_invalid(target, weight)

        pos_inds = target == 1
        neg_inds = target == 0

        if self.pos_margin > 0 or self.neg_margin > 0:
            # Indexed in-place subtraction on the caller's tensor would
            # fail for leaf tensors requiring grad and leak the margin back
            # to the caller, so operate on a copy.
            pred = pred.clone()
        if self.pos_margin > 0:
            pred[pos_inds] -= self.pos_margin
        if self.neg_margin > 0:
            pred[neg_inds] -= self.neg_margin
        pred = torch.clamp(pred, min=0, max=1)

        num_pos = int(pos_inds.sum())
        num_neg = int(neg_inds.sum())
        if self.neg_pos_ub > 0 and num_neg / (num_pos +
                                              1e-6) > self.neg_pos_ub:
            num_neg = num_pos * self.neg_pos_ub
            # NOTE: the indexing below uses exactly two index columns, i.e.
            # it expects 2-D targets.
            neg_idx = torch.nonzero(neg_inds, as_tuple=False)

            if self.hard_mining:
                # Keep the negatives with the largest per-element loss.
                costs = l2_loss(
                    pred, target, reduction='none')[neg_idx[:, 0],
                                                    neg_idx[:, 1]].detach()
                neg_idx = neg_idx[costs.topk(num_neg)[1], :]
            else:
                neg_idx = self.random_choice(neg_idx, num_neg)

            kept_neg_inds = torch.zeros_like(neg_inds, dtype=torch.bool)
            kept_neg_inds[neg_idx[:, 0], neg_idx[:, 1]] = True

            # Negatives that were not kept no longer contribute to the loss.
            dropped_neg_inds = torch.logical_xor(neg_inds, kept_neg_inds)
            weight[dropped_neg_inds] = 0

        avg_factor = (weight > 0).sum()
        return pred, weight, avg_factor

    @staticmethod
    def random_choice(gallery: Union[list, np.ndarray, Tensor],
                      num: int) -> Union[np.ndarray, Tensor]:
        """Random select some elements from the gallery.

        It seems that Pytorch's implementation is slower than numpy so we use
        numpy to randperm the indices.

        Args:
            gallery (list | np.ndarray | torch.Tensor): Candidates to sample
                from along the first dimension. A list is converted to a
                numpy array first.
            num (int): Number of elements to select. Must not exceed
                ``len(gallery)``.

        Returns:
            np.ndarray | torch.Tensor: The ``num`` selected elements; a
            Tensor when ``gallery`` is a Tensor, otherwise a numpy array.
        """
        assert len(gallery) >= num
        if isinstance(gallery, list):
            gallery = np.array(gallery)
        cands = np.arange(len(gallery))
        np.random.shuffle(cands)
        rand_inds = cands[:num]
        if not isinstance(gallery, np.ndarray):
            # Tensor gallery: index with a LongTensor on the same device.
            rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device)
        return gallery[rand_inds]