import torch
from torch import nn, Tensor
from typing import Iterable, Dict
from ..SentenceTransformer import SentenceTransformer
from .. import util
|
|
class MultipleNegativesSymmetricRankingLoss(nn.Module):
    """
    This loss is an adaptation of MultipleNegativesRankingLoss. MultipleNegativesRankingLoss computes the following loss:
    for a given anchor and a list of candidates, find the positive candidate.

    In MultipleNegativesSymmetricRankingLoss, we add another loss term: given the positive and the list of all anchors,
    find the correct (matching) anchor.

    For example, with question-answering: you have (question, answer) pairs. MultipleNegativesRankingLoss just computes
    the loss to find the answer for a given question. MultipleNegativesSymmetricRankingLoss additionally computes the
    loss to find the question for a given answer.

    Note: If you pass triplets, the negatives are used only as extra candidates when searching the positive for a given
    anchor; in the symmetric direction, each positive is matched against the anchors alone.

    Example::

        from sentence_transformers import SentenceTransformer, losses, InputExample
        from torch.utils.data import DataLoader

        model = SentenceTransformer('distilbert-base-uncased')
        train_examples = [InputExample(texts=['Anchor 1', 'Positive 1']),
                          InputExample(texts=['Anchor 2', 'Positive 2'])]
        train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=32)
        train_loss = losses.MultipleNegativesSymmetricRankingLoss(model=model)
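        # The (dataloader, loss) pair then goes to the standard fit API; epochs and
        # warmup_steps below are illustrative values, not recommendations.
        model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1, warmup_steps=100)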
|
""" |
|
    def __init__(self, model: SentenceTransformer, scale: float = 20.0, similarity_fct=util.cos_sim):
        """
        :param model: SentenceTransformer model
        :param scale: Output of similarity function is multiplied by scale value
        :param similarity_fct: similarity function between sentence embeddings. By default, cos_sim. Can also be set to dot product (and then set scale to 1)
        """
        super(MultipleNegativesSymmetricRankingLoss, self).__init__()
        self.model = model
        self.scale = scale
        self.similarity_fct = similarity_fct
        self.cross_entropy_loss = nn.CrossEntropyLoss()
|
    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
        # One embedding matrix per input column: reps[0] holds the anchors,
        # reps[1] the positives, and reps[2:] any additional hard negatives
        reps = [self.model(sentence_feature)['sentence_embedding'] for sentence_feature in sentence_features]
        anchor = reps[0]
        candidates = torch.cat(reps[1:])

        # scores[i][j]: similarity between anchor i and candidate j, shape (batch_size, num_candidates)
        scores = self.similarity_fct(anchor, candidates) * self.scale
        # The labels argument is unused; the target for anchor i is candidate i (in-batch order)
        labels = torch.arange(len(scores), dtype=torch.long, device=scores.device)

        # Square anchor-positive block: the symmetric direction ranks each positive
        # against the anchors only, so hard negatives are excluded from it
        anchor_positive_scores = scores[:, 0:len(reps[1])]
        forward_loss = self.cross_entropy_loss(scores, labels)
        backward_loss = self.cross_entropy_loss(anchor_positive_scores.transpose(0, 1), labels)
        return (forward_loss + backward_loss) / 2

    def get_config_dict(self):
        return {'scale': self.scale, 'similarity_fct': self.similarity_fct.__name__}
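

# A quick sanity-check sketch, not part of the library API: it mirrors what forward()
# computes, using random tensors as stand-in embeddings. The symmetric loss is the mean
# of the anchor->positive and positive->anchor cross-entropy terms over the same score
# matrix. Run via `python -m` from the package so the relative imports resolve.
if __name__ == "__main__":
    torch.manual_seed(0)
    batch_size, dim = 4, 8
    anchor = torch.randn(batch_size, dim)    # stand-in anchor embeddings
    positive = torch.randn(batch_size, dim)  # stand-in positive embeddings

    scores = util.cos_sim(anchor, positive) * 20.0  # (batch_size, batch_size), scale=20.0 as in the default
    labels = torch.arange(batch_size)               # anchor i matches positive i

    ce = nn.CrossEntropyLoss()
    forward_loss = ce(scores, labels)                   # rank positives for each anchor
    backward_loss = ce(scores.transpose(0, 1), labels)  # rank anchors for each positive
    print("symmetric loss:", ((forward_loss + backward_loss) / 2).item())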