TTP / mmdet /models /losses /triplet_loss.py
KyanChen's picture
Upload 1861 files
3b96cb1
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmengine.model import BaseModule
from mmdet.registry import MODELS
@MODELS.register_module()
class TripletLoss(BaseModule):
"""Triplet loss with hard positive/negative mining.
Reference:
Hermans et al. In Defense of the Triplet Loss for
Person Re-Identification. arXiv:1703.07737.
Imported from `<https://github.com/KaiyangZhou/deep-person-reid/blob/
master/torchreid/losses/hard_mine_triplet_loss.py>`_.
Args:
margin (float, optional): Margin for triplet loss. Defaults to 0.3.
loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
hard_mining (bool, optional): Whether to perform hard mining.
Defaults to True.
"""
def __init__(self,
margin: float = 0.3,
loss_weight: float = 1.0,
hard_mining=True):
super(TripletLoss, self).__init__()
self.margin = margin
self.ranking_loss = nn.MarginRankingLoss(margin=margin)
self.loss_weight = loss_weight
self.hard_mining = hard_mining
def hard_mining_triplet_loss_forward(
self, inputs: torch.Tensor,
targets: torch.LongTensor) -> torch.Tensor:
"""
Args:
inputs (torch.Tensor): feature matrix with shape
(batch_size, feat_dim).
targets (torch.LongTensor): ground truth labels with shape
(num_classes).
Returns:
torch.Tensor: triplet loss with hard mining.
"""
batch_size = inputs.size(0)
# Compute Euclidean distance
dist = torch.pow(inputs, 2).sum(
dim=1, keepdim=True).expand(batch_size, batch_size)
dist = dist + dist.t()
dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2)
dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
# For each anchor, find the furthest positive sample
# and nearest negative sample in the embedding space
mask = targets.expand(batch_size, batch_size).eq(
targets.expand(batch_size, batch_size).t())
dist_ap, dist_an = [], []
for i in range(batch_size):
dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))
dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))
dist_ap = torch.cat(dist_ap)
dist_an = torch.cat(dist_an)
# Compute ranking hinge loss
y = torch.ones_like(dist_an)
return self.loss_weight * self.ranking_loss(dist_an, dist_ap, y)
def forward(self, inputs: torch.Tensor,
targets: torch.LongTensor) -> torch.Tensor:
"""
Args:
inputs (torch.Tensor): feature matrix with shape
(batch_size, feat_dim).
targets (torch.LongTensor): ground truth labels with shape
(num_classes).
Returns:
torch.Tensor: triplet loss.
"""
if self.hard_mining:
return self.hard_mining_triplet_loss_forward(inputs, targets)
else:
raise NotImplementedError()