"""
Given a tab-separated file (.tsv) with parallel sentences, where the second column is the translation of the sentence in the first column, in the format:
src1 trg1
src2 trg2
...
where trg_i is the translation of src_i.
Given src_i, the TranslationEvaluator checks which trg_j has the highest cosine similarity to it. If i == j, we count this as
a match, i.e., the correct translation of src_i has been found among all target sentences.
It then computes the accuracy over all source sentences src_i, and likewise the accuracy for the reverse direction (trg_i -> src_j).
A high accuracy indicates that the model can find the correct translation in a large pool of sentences.
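
The matching step described above can be sketched in plain Python as follows. This is a minimal illustration, not the library's implementation (which works on batched tensors): it assumes the sentences have already been embedded, uses made-up 2-d toy vectors in place of real sentence embeddings, and `cosine`, `argmax` and `translation_accuracy` are hypothetical helpers written for this sketch, not sentence_transformers functions.

```python
import math


def cosine(u, v):
    # Cosine similarity of two non-zero, equal-length vectors
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def argmax(values):
    # Index of the largest value; max() keeps the first one on ties
    return max(range(len(values)), key=values.__getitem__)


def translation_accuracy(src_emb, trg_emb):
    # sim[i][j] is the cosine similarity between src_i and trg_j
    sim = [[cosine(s, t) for t in trg_emb] for s in src_emb]
    n = len(sim)
    # src -> trg: src_i is a match if its most similar target is trg_i
    src2trg = sum(argmax(sim[i]) == i for i in range(n)) / n
    # trg -> src: the same check on the columns of the similarity matrix
    columns = [list(col) for col in zip(*sim)]
    trg2src = sum(argmax(columns[j]) == j for j in range(n)) / n
    return src2trg, trg2src


# Toy embeddings: row i of trg is meant to be the translation of row i of src
src = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
trg = [[0.9, 0.1], [0.1, 0.9], [0.2, 0.8]]
print(translation_accuracy(src, trg))  # (1.0, 0.6666666666666666)
```

In this toy data the third target is closer to the second source than to the third, so the reverse direction scores 2/3 while the forward direction scores 1.0: the two accuracies need not agree.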
Usage:
python evaluation_translation_matching.py [model_name_or_path] [parallel-file1] [parallel-file2] ...
Files ending in .gz are decompressed on the fly. For example:
python evaluation_translation_matching.py distiluse-base-multilingual-cased TED2020-en-de.tsv.gz
See the training_multilingual/get_parallel_data_...py scripts for downloading parallel sentence data from different sources.
"""
from sentence_transformers import SentenceTransformer, evaluation, LoggingHandler
import sys
import gzip
import os
import logging
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)
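# Command-line arguments: the model name or path, followed by one or more parallel files
model_name = sys.argv[1]
filepaths = sys.argv[2:]
inference_batch_size = 32  # batch size used when encoding the sentences

model = SentenceTransformer(model_name)

for filepath in filepaths:
    src_sentences = []
    trg_sentences = []
    # Read tab-separated sentence pairs; .gz files are decompressed on the fly
    with gzip.open(filepath, 'rt', encoding='utf8') if filepath.endswith('.gz') else open(filepath, 'r', encoding='utf8') as fIn:
        for line in fIn:
            splits = line.strip().split('\t')
            # Skip lines that do not contain at least a source and a target column
            if len(splits) >= 2:
                src_sentences.append(splits[0])
                trg_sentences.append(splits[1])

    logger.info(f"{os.path.basename(filepath)}: {len(src_sentences)} sentence pairs")
    # Measure how often each sentence retrieves its own translation, in both directions
    dev_trans_acc = evaluation.TranslationEvaluator(src_sentences, trg_sentences, name=os.path.basename(filepath), batch_size=inference_batch_size)
    dev_trans_acc(model)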