Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import torch | |
import torch.nn as nn | |
from mmengine.model import BaseModule | |
from mmdet.registry import MODELS | |
class TripletLoss(BaseModule): | |
"""Triplet loss with hard positive/negative mining. | |
Reference: | |
Hermans et al. In Defense of the Triplet Loss for | |
Person Re-Identification. arXiv:1703.07737. | |
Imported from `<https://github.com/KaiyangZhou/deep-person-reid/blob/ | |
master/torchreid/losses/hard_mine_triplet_loss.py>`_. | |
Args: | |
margin (float, optional): Margin for triplet loss. Defaults to 0.3. | |
loss_weight (float, optional): Weight of the loss. Defaults to 1.0. | |
hard_mining (bool, optional): Whether to perform hard mining. | |
Defaults to True. | |
""" | |
def __init__(self, | |
margin: float = 0.3, | |
loss_weight: float = 1.0, | |
hard_mining=True): | |
super(TripletLoss, self).__init__() | |
self.margin = margin | |
self.ranking_loss = nn.MarginRankingLoss(margin=margin) | |
self.loss_weight = loss_weight | |
self.hard_mining = hard_mining | |
def hard_mining_triplet_loss_forward( | |
self, inputs: torch.Tensor, | |
targets: torch.LongTensor) -> torch.Tensor: | |
""" | |
Args: | |
inputs (torch.Tensor): feature matrix with shape | |
(batch_size, feat_dim). | |
targets (torch.LongTensor): ground truth labels with shape | |
(num_classes). | |
Returns: | |
torch.Tensor: triplet loss with hard mining. | |
""" | |
batch_size = inputs.size(0) | |
# Compute Euclidean distance | |
dist = torch.pow(inputs, 2).sum( | |
dim=1, keepdim=True).expand(batch_size, batch_size) | |
dist = dist + dist.t() | |
dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2) | |
dist = dist.clamp(min=1e-12).sqrt() # for numerical stability | |
# For each anchor, find the furthest positive sample | |
# and nearest negative sample in the embedding space | |
mask = targets.expand(batch_size, batch_size).eq( | |
targets.expand(batch_size, batch_size).t()) | |
dist_ap, dist_an = [], [] | |
for i in range(batch_size): | |
dist_ap.append(dist[i][mask[i]].max().unsqueeze(0)) | |
dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0)) | |
dist_ap = torch.cat(dist_ap) | |
dist_an = torch.cat(dist_an) | |
# Compute ranking hinge loss | |
y = torch.ones_like(dist_an) | |
return self.loss_weight * self.ranking_loss(dist_an, dist_ap, y) | |
def forward(self, inputs: torch.Tensor, | |
targets: torch.LongTensor) -> torch.Tensor: | |
""" | |
Args: | |
inputs (torch.Tensor): feature matrix with shape | |
(batch_size, feat_dim). | |
targets (torch.LongTensor): ground truth labels with shape | |
(num_classes). | |
Returns: | |
torch.Tensor: triplet loss. | |
""" | |
if self.hard_mining: | |
return self.hard_mining_triplet_loss_forward(inputs, targets) | |
else: | |
raise NotImplementedError() | |