File size: 6,800 Bytes
2359bda |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
from torch.utils.data import Dataset
import logging
import gzip
from .. import SentenceTransformer
from ..readers import InputExample
from typing import List
import random
logger = logging.getLogger(__name__)


class ParallelSentencesDataset(Dataset):
    """
    This dataset reader can be used to read-in parallel sentences, i.e., it reads in a file with tab-separated
    sentences with the same sentence in different languages. For example, the file can look like this (EN\\tDE\\tES):
    hello world     hallo welt  hola mundo
    second sentence zweiter satz    segunda oración

    The sentence in the first column will be mapped to a sentence embedding using the given teacher model. For
    example, the teacher is a mono-lingual sentence embedding method for English. The sentences in the other
    languages will also be mapped to this English sentence embedding.

    When getting a sample from the dataset, we get one sentence with the according sentence embedding for this
    sentence.

    teacher_model can be any class that implements an encode function. The encode function gets a list of
    sentences and returns a list of sentence embeddings.
    """

    def __init__(self, student_model: "SentenceTransformer", teacher_model: "SentenceTransformer", batch_size: int = 8, use_embedding_cache: bool = True):
        """
        Parallel sentences dataset reader to train student model given a teacher model

        :param student_model: Student sentence embedding model that should be trained
        :param teacher_model: Teacher model, that provides the sentence embeddings for the first column in the dataset file
        :param batch_size: Batch size passed to teacher_model.encode when embedding source sentences
        :param use_embedding_cache: If True, teacher embeddings are kept in memory and re-used for sentences seen before
        """
        self.student_model = student_model
        self.teacher_model = teacher_model
        # One list of (source_sentence, set_of_sentences) tuples per loaded dataset
        self.datasets = []
        # Read position inside each dataset, same order as self.datasets
        self.datasets_iterator = []
        self.datasets_tokenized = []
        # Each dataset id is repeated `weight` times; generate_data walks this list
        self.dataset_indices = []
        self.copy_dataset_indices = []
        # Ready-to-serve InputExamples; refilled by generate_data when empty
        self.cache = []
        self.batch_size = batch_size
        self.use_embedding_cache = use_embedding_cache
        # sentence -> teacher embedding (only filled when use_embedding_cache is True)
        self.embedding_cache = {}
        self.num_sentences = 0

    def load_data(self, filepath: str, weight: int = 100, max_sentences: int = None, max_sentence_length: int = 128):
        """
        Reads in a tab-separated .txt/.csv/.tsv or .gz file. The different columns contain the different
        translations of the sentence in the first column. Blank lines are ignored.

        :param filepath: Filepath to the file. Files ending in .gz are opened with gzip
        :param weight: If more than one dataset is loaded with load_data: With which frequency should data be sampled from this dataset?
        :param max_sentences: Max number of lines to be read from filepath. None or a value <= 0 disables the limit
        :param max_sentence_length: Skip the example if one of the sentences has more characters than max_sentence_length. None or a value <= 0 disables the check
        :return:
        """
        logger.info("Load %s", filepath)
        limit_length = max_sentence_length is not None and max_sentence_length > 0
        limit_count = max_sentences is not None and max_sentences > 0

        parallel_sentences = []
        opener = gzip.open if filepath.endswith('.gz') else open
        with opener(filepath, 'rt', encoding='utf8') as f_in:
            for line in f_in:
                stripped = line.strip()
                if not stripped:
                    # A blank line would otherwise register "" as a source sentence
                    continue

                sentences = stripped.split("\t")
                if limit_length and max(len(sent) for sent in sentences) > max_sentence_length:
                    continue

                parallel_sentences.append(sentences)
                if limit_count and len(parallel_sentences) >= max_sentences:
                    break

        self.add_dataset(parallel_sentences, weight=weight, max_sentences=max_sentences, max_sentence_length=max_sentence_length)

    def add_dataset(self, parallel_sentences: List[List[str]], weight: int = 100, max_sentences: int = None, max_sentence_length: int = 128):
        """
        Registers a list of parallel sentences as a new dataset.

        Entries sharing the same first (source) sentence are merged: all their sentences, including the
        source sentence itself, end up in one set for that source. Nothing is registered if no entry
        survives the filters.

        :param parallel_sentences: List of entries; each entry is a list with the source sentence first, followed by its translations. Empty entries are skipped
        :param weight: Number of entries drawn from this dataset each time the example cache is refilled
        :param max_sentences: Stop once this many distinct source sentences were collected. None or a value <= 0 disables the limit
        :param max_sentence_length: Skip an entry if one of its sentences has more characters than this. None or a value <= 0 disables the check
        """
        limit_length = max_sentence_length is not None and max_sentence_length > 0
        limit_count = max_sentences is not None and max_sentences > 0

        sentences_map = {}
        for sentences in parallel_sentences:
            if not sentences:
                # No source sentence available; max() and sentences[0] would raise
                continue

            if limit_length and max(len(sent) for sent in sentences) > max_sentence_length:
                continue

            sentences_map.setdefault(sentences[0], set()).update(sentences)

            if limit_count and len(sentences_map) >= max_sentences:
                break

        if not sentences_map:
            return

        self.num_sentences += sum(len(targets) for targets in sentences_map.values())

        dataset_id = len(self.datasets)
        self.datasets.append(list(sentences_map.items()))
        self.datasets_iterator.append(0)
        self.dataset_indices.extend([dataset_id] * weight)

    def generate_data(self):
        """
        Refills the example cache.

        Draws one entry per element of dataset_indices (so `weight` entries per dataset), embeds the
        source sentences with the teacher model and appends one InputExample per sentence of the entry,
        labelled with the embedding of the source sentence. The cache is shuffled afterwards.
        """
        source_sentences_list = []
        target_sentences_list = []
        for data_idx in self.dataset_indices:
            src_sentence, trg_sentences = self.next_entry(data_idx)
            source_sentences_list.append(src_sentence)
            target_sentences_list.append(trg_sentences)

        # Generate embeddings
        src_embeddings = self.get_embeddings(source_sentences_list)

        for src_embedding, trg_sentences in zip(src_embeddings, target_sentences_list):
            for trg_sentence in trg_sentences:
                self.cache.append(InputExample(texts=[trg_sentence], label=src_embedding))

        random.shuffle(self.cache)

    def next_entry(self, data_idx):
        """
        Returns the next (source_sentence, sentences) entry of the dataset with id data_idx and advances
        its read position. After the last entry the position is reset and the dataset is reshuffled.
        """
        dataset = self.datasets[data_idx]
        source, target_sentences = dataset[self.datasets_iterator[data_idx]]

        self.datasets_iterator[data_idx] += 1
        if self.datasets_iterator[data_idx] >= len(dataset):  # Restart iterator
            self.datasets_iterator[data_idx] = 0
            random.shuffle(dataset)

        return source, target_sentences

    def get_embeddings(self, sentences):
        """
        Returns the teacher embeddings for the given sentences, in the same order.

        With use_embedding_cache enabled, only sentences not yet in the cache are sent to
        teacher_model.encode, and each distinct sentence is encoded at most once per call.
        """
        if not self.use_embedding_cache:
            return self.teacher_model.encode(sentences, batch_size=self.batch_size, show_progress_bar=False, convert_to_numpy=True)

        # Use caching. dict.fromkeys drops duplicates while keeping first-seen order,
        # so a sentence repeated within this call is not encoded several times.
        new_sentences = [sent for sent in dict.fromkeys(sentences) if sent not in self.embedding_cache]

        if new_sentences:
            new_embeddings = self.teacher_model.encode(new_sentences, batch_size=self.batch_size, show_progress_bar=False, convert_to_numpy=True)
            for sent, embedding in zip(new_sentences, new_embeddings):
                self.embedding_cache[sent] = embedding

        return [self.embedding_cache[sent] for sent in sentences]

    def __len__(self):
        """Returns the number of sentences over all loaded datasets (distinct sentences per source sentence, summed)."""
        return self.num_sentences

    def __getitem__(self, idx):
        """
        Returns the next example from the shuffled cache, refilling the cache first if it is empty.
        idx is not used to select the example.

        :raises IndexError: if no example can be generated, e.g. because no dataset was loaded
        """
        if not self.cache:
            self.generate_data()

        if not self.cache:
            raise IndexError("No examples available: load data with load_data() or add_dataset() first")

        return self.cache.pop()
|