|
import random |
|
import math |
|
|
|
class NoDuplicatesDataLoader:
    """Batch iterator that never places the same text twice in one batch.

    Texts are compared after ``strip()`` and ``lower()``, so ``"Hello "`` and
    ``"hello"`` count as the same text.

    Attributes set by ``__init__``:
        batch_size: Number of examples per yielded batch.
        data_pointer: Index of the next example to inspect. It is kept between
            iterations, so a new ``iter()`` resumes where the last one stopped.
        collate_fn: Optional callable. When not ``None`` it is applied to each
            batch (a list of examples) and its result is yielded instead.
        train_examples: The example sequence. It is shuffled in place.
    """

    def __init__(self, train_examples, batch_size):
        """
        A special data loader to be used with MultipleNegativesRankingLoss.
        The data loader ensures that there are no duplicate sentences within the same batch

        :param train_examples: Mutable sequence of objects exposing a ``texts``
            iterable of strings. NOTE: it is shuffled in place, so the caller's
            sequence is reordered as well.
        :param batch_size: Number of examples per batch.
        """
        self.batch_size = batch_size
        self.data_pointer = 0
        self.collate_fn = None
        self.train_examples = train_examples
        random.shuffle(self.train_examples)

    @staticmethod
    def _normalized_texts(example):
        """Return the set of stripped, lower-cased texts of ``example``."""
        return {text.strip().lower() for text in example.texts}

    def __iter__(self):
        """Yield ``len(self)`` batches, each free of duplicate texts.

        Examples are read sequentially from ``data_pointer``. An example is
        skipped for the current batch if any of its normalised texts is already
        present in that batch. When the end of the data is reached the pointer
        wraps to 0 and the data is reshuffled.

        :raises ValueError: if a complete pass over the data could not add a
            single example to the current batch, i.e. the data does not contain
            ``batch_size`` examples with mutually distinct texts. Without this
            check the search would never terminate.
        """
        for _ in range(self.__len__()):
            batch = []
            texts_in_batch = set()
            # Size of the batch at the previous wrap-around while building
            # *this* batch; None until the first wrap-around happens.
            batch_len_at_last_wrap = None

            while len(batch) < self.batch_size:
                example = self.train_examples[self.data_pointer]

                # Normalise once and reuse for both the check and the update.
                example_texts = self._normalized_texts(example)
                if example_texts.isdisjoint(texts_in_batch):
                    batch.append(example)
                    texts_in_batch.update(example_texts)

                self.data_pointer += 1
                if self.data_pointer >= len(self.train_examples):
                    self.data_pointer = 0
                    random.shuffle(self.train_examples)

                    # Two consecutive wrap-arounds with an unchanged batch mean
                    # every example was inspected and rejected against the same
                    # set of texts, so no further progress is possible.
                    if (
                        len(batch) < self.batch_size
                        and batch_len_at_last_wrap == len(batch)
                    ):
                        raise ValueError(
                            f"Cannot build a batch of {self.batch_size} examples "
                            f"without duplicate texts: only {len(batch)} mutually "
                            "distinct examples were found in a full pass over "
                            "the data."
                        )
                    batch_len_at_last_wrap = len(batch)

            yield self.collate_fn(batch) if self.collate_fn is not None else batch

    def __len__(self):
        """Return the number of full batches per iteration (remainder dropped)."""
        return math.floor(len(self.train_examples) / self.batch_size)