|
import torch |
|
from torch import nn, Tensor |
|
from typing import Union, Tuple, List, Iterable, Dict |
|
import torch.nn.functional as F |
|
from enum import Enum |
|
from ..SentenceTransformer import SentenceTransformer |
|
|
|
class TripletDistanceMetric(Enum):
    """
    The metric for the triplet loss.

    Every attribute is a two-argument callable ``(x, y)`` that returns the
    distance between ``x`` and ``y`` as computed by the corresponding
    ``torch.nn.functional`` routine.

    NOTE: functions defined in an ``Enum`` body are descriptors, so ``Enum``
    does not turn them into members. ``TripletDistanceMetric.EUCLIDEAN`` is
    therefore the plain function itself: it can be called directly and it is
    found unchanged in ``vars(TripletDistanceMetric)``.
    """

    # Cosine distance: one minus the cosine similarity of x and y.
    COSINE = lambda x, y: 1 - F.cosine_similarity(x, y)

    # Euclidean (L2) distance: pairwise distance with p=2.
    EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2)

    # Manhattan (L1) distance: pairwise distance with p=1.
    MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1)
|
|
|
class TripletLoss(nn.Module):
    """
    This class implements triplet loss. Given a triplet of (anchor, positive, negative),
    the loss minimizes the distance between anchor and positive while it maximizes the distance
    between anchor and negative. It computes the following loss function:

    loss = max(||anchor - positive|| - ||anchor - negative|| + margin, 0).

    Margin is an important hyperparameter and needs to be tuned accordingly.

    For further details, see: https://en.wikipedia.org/wiki/Triplet_loss

    :param model: SentenceTransformerModel
    :param distance_metric: Function to compute distance between two embeddings. The class TripletDistanceMetric contains common distance metrics that can be used.
    :param triplet_margin: The negative should be at least this much further away from the anchor than the positive.

    Example::

        from torch.utils.data import DataLoader
        from sentence_transformers import SentenceTransformer, SentencesDataset, LoggingHandler, losses
        from sentence_transformers.readers import InputExample

        train_batch_size = 16
        model = SentenceTransformer('distilbert-base-nli-mean-tokens')
        train_examples = [InputExample(texts=['Anchor 1', 'Positive 1', 'Negative 1']),
            InputExample(texts=['Anchor 2', 'Positive 2', 'Negative 2'])]
        train_dataset = SentencesDataset(train_examples, model)
        train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)
        train_loss = losses.TripletLoss(model=model)
    """

    def __init__(self, model: SentenceTransformer, distance_metric=TripletDistanceMetric.EUCLIDEAN, triplet_margin: float = 5):
        super().__init__()
        self.model = model
        self.distance_metric = distance_metric
        self.triplet_margin = triplet_margin

    def get_config_dict(self) -> Dict[str, Union[str, float]]:
        """
        Return the configuration of this loss as a plain dictionary.

        :return: dict with the keys ``'distance_metric'`` (a string naming the
            metric) and ``'triplet_margin'`` (the margin stored on this instance).
        """
        metric = self.distance_metric

        # Not every callable has a __name__ (e.g. functools.partial objects or
        # nn.Module instances); fall back to the type name instead of raising
        # AttributeError.
        distance_metric_name = getattr(metric, "__name__", type(metric).__name__)

        # If the metric is one of the functions defined on TripletDistanceMetric,
        # report it by its qualified attribute name. Identity is the right test:
        # we want the very same function object, and `is` never invokes a
        # user-defined __eq__.
        for name, value in vars(TripletDistanceMetric).items():
            if value is metric:
                distance_metric_name = "TripletDistanceMetric.{}".format(name)
                break

        return {'distance_metric': distance_metric_name, 'triplet_margin': self.triplet_margin}

    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor) -> Tensor:
        """
        Compute the triplet loss for a batch.

        :param sentence_features: Exactly three feature dicts, in the order
            anchor, positive, negative. Each one is passed through ``self.model``
            and the ``'sentence_embedding'`` entry of the output is used.
        :param labels: Not used by this loss; accepted so the signature stays
            as callers expect.
        :return: The mean over all triplets of
            ``relu(d(anchor, positive) - d(anchor, negative) + triplet_margin)``.
        :raises ValueError: If ``sentence_features`` does not yield exactly three items.
        """
        reps = [self.model(sentence_feature)['sentence_embedding'] for sentence_feature in sentence_features]

        # Fail with a descriptive message instead of a bare unpacking error.
        if len(reps) != 3:
            raise ValueError(
                "TripletLoss expects exactly 3 inputs (anchor, positive, negative), got {}".format(len(reps))
            )

        rep_anchor, rep_pos, rep_neg = reps
        distance_pos = self.distance_metric(rep_anchor, rep_pos)
        distance_neg = self.distance_metric(rep_anchor, rep_neg)

        # Hinge: only triplets that violate the margin contribute to the loss.
        per_triplet_loss = F.relu(distance_pos - distance_neg + self.triplet_margin)
        return per_triplet_loss.mean()