import torch |
from torch import nn, Tensor |
from typing import Iterable, Dict |
from ..SentenceTransformer import SentenceTransformer |
from .. import util |
import copy |
import random |
import math |
from .. import InputExample |
import numpy as np |
class ContrastiveTensionLoss(nn.Module):
    """
    Contrastive Tension (CT) loss over labelled sentence pairs.

    The loss expects a batch made of several mini-batches of sentence pairs
    (a_1, p_1), (a_2, p_2), ..., (a_{K+1}, p_{K+1}) where p_1 = a_1 = a_2 = ... = a_{K+1}
    and p_2, p_3, ..., p_{K+1} are expected to differ from p_1 (obtained via random sampling).
    The labels of each mini-batch are y_i = 1 if i == 1 and y_i = 0 otherwise, i.e. K is the
    number of negative pairs and the positive pair consists of two identical sentences.
    The original note for this class states that the data generation process is implemented
    in readers/ContrastiveTensionReader.py.

    For tractable optimization, two independent encoders are kept: ``model1`` (a deep copy of
    the given model) encodes a_i and ``model2`` (the given model itself) encodes p_i. For
    inference only model2 is used, which is reported to give better performance. The training
    objective is binary cross entropy on the dot product of the two embeddings.

    For more information, see: https://openreview.net/pdf?id=Ov_sMNau-PF
    """

    def __init__(self, model: SentenceTransformer):
        """
        :param model: SentenceTransformer model; it is kept as ``model2`` and deep-copied into ``model1``
        """
        super().__init__()
        # model2 is the very object passed in, so the caller's model is the one kept for inference.
        self.model2 = model
        # model1 starts from identical weights but is an independent copy.
        self.model1 = copy.deepcopy(model)
        # Summed (not averaged) binary cross entropy on raw similarity logits.
        self.criterion = nn.BCEWithLogitsLoss(reduction='sum')

    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
        """
        :param sentence_features: Exactly two feature dicts, one per side of the pair
        :param labels: Pair labels; cast to the dtype/device of the similarity scores
        :return: Summed binary cross entropy loss over all pairs
        """
        features_first, features_second = sentence_features
        emb_first = self.model1(features_first)['sentence_embedding']
        emb_second = self.model2(features_second)['sentence_embedding']

        # Row-wise dot product written as a batched (1 x d) @ (d x 1) matrix product
        # (assumes the embeddings are 2-D, one row per sentence -- TODO confirm).
        products = torch.matmul(emb_first.unsqueeze(1), emb_second.unsqueeze(2))
        pair_logits = products.squeeze(-1).squeeze(-1)

        targets = labels.type_as(pair_logits)
        return self.criterion(pair_logits, targets)
class ContrastiveTensionLossInBatchNegatives(nn.Module):
    """
    Contrastive Tension loss that uses the other pairs of the batch as negatives.

    Two encoders are kept: ``model2`` is the given model and ``model1`` is a deep copy of it.
    The first side of every pair is encoded with model1, the second side with model2. The full
    similarity matrix between both sets of embeddings is multiplied by ``exp(logit_scale)``
    and trained with cross entropy in both directions (rows and columns), where the matching
    pair of row i is column i.
    """

    def __init__(self, model: SentenceTransformer, scale: float = 20.0, similarity_fct = util.cos_sim):
        """
        :param model: SentenceTransformer model
        :param scale: Initial multiplier for the similarity scores; stored as the learnable parameter ``logit_scale = log(scale)``
        :param similarity_fct: Function producing the pairwise similarity matrix between two sets of embeddings
        """
        super().__init__()
        self.model2 = model
        self.model1 = copy.deepcopy(model)
        self.similarity_fct = similarity_fct
        self.cross_entropy_loss = nn.CrossEntropyLoss()
        # Kept in log space; forward() exponentiates it, so the effective multiplier stays positive.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(scale))

    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
        """
        :param sentence_features: Exactly two feature dicts, one per side of the pair
        :param labels: Ignored; the targets are always the diagonal of the similarity matrix
        :return: Mean of the row-wise and column-wise cross entropy losses
        """
        features_first, features_second = sentence_features
        emb_first = self.model1(features_first)['sentence_embedding']
        emb_second = self.model2(features_second)['sentence_embedding']

        scale_factor = self.logit_scale.exp()
        logits = self.similarity_fct(emb_first, emb_second) * scale_factor

        # Entry (i, i) is the true pair, so the target class of row i is i.
        targets = torch.arange(len(logits), dtype=torch.long, device=logits.device)

        loss_rows = self.cross_entropy_loss(logits, targets)
        loss_cols = self.cross_entropy_loss(logits.t(), targets)
        return (loss_rows + loss_cols) / 2
class ContrastiveTensionDataLoader:
    """
    Turns a collection of sentences into batches of labelled pairs for Contrastive Tension training.

    Pairs are produced in groups of ``pos_neg_ratio``: the first pair of each group is a positive
    pair (a sentence paired with itself, label 1), the remaining ``pos_neg_ratio - 1`` pairs are
    negative pairs (two consecutive, different positions of the shuffled order, label 0).

    The ``collate_fn`` attribute is ``None`` after construction; when it is set to a callable,
    every batch is passed through it before being yielded.
    """

    def __init__(self, sentences, batch_size, pos_neg_ratio=8):
        """
        :param sentences: Sequence of sentences to build the pairs from
        :param batch_size: Number of pairs per batch; must be a positive multiple of pos_neg_ratio
        :param pos_neg_ratio: Size of each group of pairs (one positive pair followed by pos_neg_ratio - 1 negative pairs)
        :raises ValueError: If batch_size or pos_neg_ratio is not positive, or batch_size is not a multiple of pos_neg_ratio
        """
        # Checked first so that a zero ratio gives a clear error instead of a ZeroDivisionError below.
        if batch_size <= 0 or pos_neg_ratio <= 0:
            raise ValueError(f"ContrastiveTensionDataLoader requires a positive batch size and pos_neg_ratio, but got a batch size of {batch_size} and a pos_neg_ratio of {pos_neg_ratio}")

        if batch_size % pos_neg_ratio != 0:
            raise ValueError(f"ContrastiveTensionDataLoader was loaded with a pos_neg_ratio of {pos_neg_ratio} and a batch size of {batch_size}. The batch size must be devisable by the pos_neg_ratio")

        self.sentences = sentences
        self.batch_size = batch_size
        self.pos_neg_ratio = pos_neg_ratio
        self.collate_fn = None

    def __iter__(self):
        """
        Yield batches of ``batch_size`` InputExample pairs (or the result of ``collate_fn`` on them).

        A trailing, incomplete batch is dropped.
        """
        # Shuffle a private copy: the caller's sequence is not reordered as a side effect,
        # and immutable sequences (e.g. tuples) are supported as well.
        shuffled = list(self.sentences)
        random.shuffle(shuffled)
        num_sentences = len(shuffled)

        sentence_idx = 0
        batch = []
        # The guard always requires one sentence after the current one, because a
        # negative pair reads two consecutive positions.
        while sentence_idx + 1 < num_sentences:
            s1 = shuffled[sentence_idx]
            if len(batch) % self.pos_neg_ratio > 0:
                # Negative pair: two different positions, both consumed.
                sentence_idx += 1
                s2 = shuffled[sentence_idx]
                label = 0
            else:
                # Positive pair: the sentence paired with itself, one position consumed.
                s2 = s1
                label = 1

            sentence_idx += 1
            batch.append(InputExample(texts=[s1, s2], label=label))

            if len(batch) >= self.batch_size:
                yield self.collate_fn(batch) if self.collate_fn is not None else batch
                batch = []

    def __len__(self):
        """
        :return: The exact number of batches one pass of ``__iter__`` yields
        """
        num_sentences = len(self.sentences)

        # A group uses 1 sentence for its positive pair and 2 for each of its
        # (pos_neg_ratio - 1) negative pairs; a batch holds batch_size / pos_neg_ratio groups.
        groups_per_batch = self.batch_size // self.pos_neg_ratio
        sentences_per_batch = groups_per_batch * (2 * self.pos_neg_ratio - 1)

        if self.pos_neg_ratio == 1:
            # Only positive pairs: the loop guard still demands one sentence beyond the
            # current one, so the very last sentence can never start a pair.
            return max(0, (num_sentences - 1) // sentences_per_batch)

        # Otherwise every batch ends with a negative pair, whose second sentence is
        # exactly the look-ahead sentence the loop guard asks for.
        return num_sentences // sentences_per_batch