|
import csv
import logging
import os
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch

from sentence_transformers.evaluation import SentenceEvaluator
from sentence_transformers.util import batch_to_device
from sentence_transformers import SentenceTransformer
|
|
|
|
|
# Module-level logger, named after this module so its output can be filtered per module.
logger = logging.getLogger(__name__)
|
|
|
|
|
class MSEEvaluatorFromDataFrame(SentenceEvaluator):
    """
    Computes the mean squared error (x100) between the computed sentence embedding
    and some target sentence embedding.

    For every (source language, target language) combination, the source sentences are
    embedded once by the teacher model at construction time. On each call the target
    sentences are embedded by the model under evaluation and compared to the cached
    teacher embeddings.

    :param dataframe:
        It must have the following format. Rows contain different, parallel sentences.
        Columns are the respective language codes:
        [{'en': 'My sentence', 'es': 'Sentence in Spanish', 'fr': 'Sentence in French'...},
         {'en': 'My second sentence', ....]
        Rows in which the source or the target sentence is empty (after stripping
        whitespace) are skipped for that combination.
    :param teacher_model:
        Model used to embed the source sentences. It is only used inside the constructor.
    :param combinations:
        Must be of the format [('en', 'es'), ('en', 'fr'), ...]
        First entry in a tuple is the source language. The sentence in the respective
        language will be fetched from the dataframe and passed to the teacher model.
        Second entry in a tuple is the target language. The sentence will be fetched from
        the dataframe and passed to the student model.
    :param batch_size:
        Batch size passed to ``encode`` for both the teacher and the evaluated model.
    :param name:
        Optional name of the evaluation set. If given, it is used in the log output and
        as an infix of the CSV file name.
    :param write_csv:
        If True, results are appended to a CSV file whenever an output path is passed
        to ``__call__``.
    """

    def __init__(
        self,
        dataframe: List[Dict[str, str]],
        teacher_model: SentenceTransformer,
        combinations: List[Tuple[str, str]],
        batch_size: int = 8,
        name: str = '',
        write_csv: bool = True,
    ):
        # Let the base class set up whatever state it defines; this is a no-op when
        # SentenceEvaluator does not declare an __init__ of its own.
        super().__init__()

        self.combinations = combinations
        self.name = name
        self.batch_size = batch_size

        # Only the CSV file name gets the "_" prefix; self.name keeps the raw value.
        if name:
            name = "_" + name

        self.csv_file = "mse_evaluation" + name + "_results.csv"
        self.csv_headers = ["epoch", "steps"]
        self.write_csv = write_csv
        # Maps (src_lang, trg_lang) -> (source sentences, target sentences), index-aligned.
        self.data = {}

        logger.info("Compute teacher embeddings")
        # A set, so a sentence shared by several combinations is encoded only once.
        all_source_sentences = set()
        for src_lang, trg_lang in self.combinations:
            src_sentences = []
            trg_sentences = []

            for row in dataframe:
                # Keep a pair only when both sides contain non-whitespace text.
                if row[src_lang].strip() != "" and row[trg_lang].strip() != "":
                    all_source_sentences.add(row[src_lang])
                    src_sentences.append(row[src_lang])
                    trg_sentences.append(row[trg_lang])

            self.data[(src_lang, trg_lang)] = (src_sentences, trg_sentences)
            self.csv_headers.append("{}-{}".format(src_lang, trg_lang))

        # Freeze the iteration order so sentences and embeddings line up in zip() below.
        all_source_sentences = list(all_source_sentences)
        all_src_embeddings = teacher_model.encode(all_source_sentences, batch_size=self.batch_size)
        # Lookup table: source sentence -> teacher embedding.
        self.teacher_embeddings = dict(zip(all_source_sentences, all_src_embeddings))

    def __call__(self, model, output_path: Optional[str] = None, epoch: int = -1, steps: int = -1):
        """
        Evaluates ``model`` against the cached teacher embeddings.

        :param model:
            The model to evaluate. It must provide ``eval()`` and
            ``encode(sentences, batch_size=...)``.
        :param output_path:
            Directory for the CSV results file. Nothing is written when it is None or
            when ``write_csv`` is False. The directory is created if it does not exist.
        :param epoch:
            Written to the "epoch" column of the CSV file.
        :param steps:
            Written to the "steps" column of the CSV file.
        :return:
            The negated mean over all combinations of the MSE (x100), so that a smaller
            error yields a larger score.
        """
        model.eval()

        mse_scores = []
        for src_lang, trg_lang in self.combinations:
            src_sentences, trg_sentences = self.data[(src_lang, trg_lang)]

            # Teacher embeddings come from the cache; only the evaluated model encodes here.
            src_embeddings = np.asarray([self.teacher_embeddings[sent] for sent in src_sentences])
            trg_embeddings = np.asarray(model.encode(trg_sentences, batch_size=self.batch_size))

            # Mean over all sentences and all embedding dimensions, scaled by 100.
            mse = ((src_embeddings - trg_embeddings) ** 2).mean()
            mse *= 100
            mse_scores.append(mse)

            logger.info("MSE evaluation on %s dataset - %s-%s:", self.name, src_lang, trg_lang)
            # ".4f" (four decimals); the former "4f" only set a minimum field width.
            logger.info("MSE (*100):\t%.4f", mse)

        if output_path is not None and self.write_csv:
            # An empty string means the current directory, which needs no creation.
            if output_path:
                os.makedirs(output_path, exist_ok=True)

            csv_path = os.path.join(output_path, self.csv_file)
            output_file_exists = os.path.isfile(csv_path)
            with open(csv_path, newline='', mode="a" if output_file_exists else 'w', encoding="utf-8") as f:
                writer = csv.writer(f)
                # The header row is written only when the file is first created.
                if not output_file_exists:
                    writer.writerow(self.csv_headers)

                writer.writerow([epoch, steps] + mse_scores)

        return -np.mean(mse_scores)
|
|
|
|