lengocduc195's picture
pushNe
2359bda
import torch
from torch import nn, Tensor
from typing import Union, Tuple, List, Iterable, Dict
import torch.nn.functional as F
from enum import Enum
from ..SentenceTransformer import SentenceTransformer
class TripletDistanceMetric(Enum):
    """
    The metric for the triplet loss.

    Each attribute is a callable ``f(x, y)`` that returns the distance between
    corresponding rows of two embedding batches. Because the attributes are
    functions (lambdas), ``Enum`` leaves them as plain class attributes rather
    than turning them into enum members (Enum skips descriptors). So
    ``TripletDistanceMetric.EUCLIDEAN`` evaluates to the function itself and can
    be called directly, e.g. ``TripletDistanceMetric.EUCLIDEAN(a, b)``.
    """
    # 1 - cosine similarity. F.cosine_similarity reduces over dim=1 by default,
    # so this presumably expects (batch, dim) embeddings -- TODO confirm with callers.
    COSINE = lambda x, y: 1 - F.cosine_similarity(x, y)
    # Row-wise L2 (Euclidean) distance.
    EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2)
    # Row-wise L1 (Manhattan) distance.
    MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1)
class TripletLoss(nn.Module):
    """
    This class implements triplet loss. Given a triplet of (anchor, positive, negative),
    the loss minimizes the distance between anchor and positive while it maximizes the distance
    between anchor and negative. It computes the following loss function:

        loss = max(d(anchor, positive) - d(anchor, negative) + margin, 0)

    where ``d`` is the configured ``distance_metric`` (Euclidean by default). The
    per-example losses are averaged over the batch.

    Margin is an important hyperparameter and needs to be tuned for the chosen distance metric.
    For further details, see: https://en.wikipedia.org/wiki/Triplet_loss

    :param model: SentenceTransformer model used to embed the anchor, positive and negative inputs
    :param distance_metric: Function to compute distance between two embeddings. The class
        TripletDistanceMetric contains common distance metrics that can be used.
    :param triplet_margin: The negative should be at least this much further away from the
        anchor than the positive.

    Example::

        from sentence_transformers import SentenceTransformer, SentencesDataset, losses
        from sentence_transformers.readers import InputExample
        from torch.utils.data import DataLoader

        model = SentenceTransformer('distilbert-base-nli-mean-tokens')
        train_examples = [InputExample(texts=['Anchor 1', 'Positive 1', 'Negative 1']),
            InputExample(texts=['Anchor 2', 'Positive 2', 'Negative 2'])]
        train_dataset = SentencesDataset(train_examples, model)
        train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=16)
        train_loss = losses.TripletLoss(model=model)
    """

    def __init__(self, model: SentenceTransformer, distance_metric=TripletDistanceMetric.EUCLIDEAN, triplet_margin: float = 5):
        super().__init__()
        self.model = model
        self.distance_metric = distance_metric
        self.triplet_margin = triplet_margin

    def get_config_dict(self):
        """
        Return a dict describing this loss's configuration.

        If ``distance_metric`` is one of the functions stored on TripletDistanceMetric,
        it is reported as ``"TripletDistanceMetric.<NAME>"``. Otherwise its ``__name__``
        is used, falling back to ``repr()`` for callables that have no ``__name__``
        (e.g. ``functools.partial`` objects).
        """
        distance_metric_name = getattr(self.distance_metric, '__name__', repr(self.distance_metric))
        # The metrics are plain functions stored on the class (not Enum members),
        # so look them up in the class namespace to recover a readable name.
        for name, value in vars(TripletDistanceMetric).items():
            if value == self.distance_metric:
                distance_metric_name = "TripletDistanceMetric.{}".format(name)
                break
        return {'distance_metric': distance_metric_name, 'triplet_margin': self.triplet_margin}

    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
        """
        Compute the mean triplet loss for a batch.

        :param sentence_features: exactly three feature dicts, in the order
            (anchor, positive, negative). Each is passed through ``self.model`` and the
            ``'sentence_embedding'`` entry of its output is used.
        :param labels: not used by this loss.
        :return: 0-dim tensor, the batch mean of
            ``relu(d(anchor, pos) - d(anchor, neg) + triplet_margin)``.
        :raises ValueError: if ``sentence_features`` does not contain exactly three entries.
        """
        # Materialize first so a wrong number of inputs fails before running the model.
        sentence_features = list(sentence_features)
        if len(sentence_features) != 3:
            raise ValueError(
                "TripletLoss expects exactly 3 inputs (anchor, positive, negative), got {}".format(
                    len(sentence_features)
                )
            )
        rep_anchor, rep_pos, rep_neg = [
            self.model(sentence_feature)['sentence_embedding'] for sentence_feature in sentence_features
        ]
        distance_pos = self.distance_metric(rep_anchor, rep_pos)
        distance_neg = self.distance_metric(rep_anchor, rep_neg)
        # Hinge: only triplets where the negative is not at least `triplet_margin`
        # farther from the anchor than the positive contribute a non-zero loss.
        losses = F.relu(distance_pos - distance_neg + self.triplet_margin)
        return losses.mean()