|
|
|
|
|
import unittest |
|
import tempfile |
|
import os |
|
from translate import main |
|
import transformers |
|
|
|
|
|
class Inputs(unittest.TestCase):
    """Smoke tests for the two input modes accepted by ``main``.

    The tests contain no assertions: they pass when ``main`` returns
    without raising.
    """

    def test_m2m100_inputs(self):
        """Call ``main`` once with a single file and once with a directory."""

        def run_main(sentences_path, sentences_dir, output_path):
            # Only the input/output locations differ between the two runs;
            # every other argument is held constant.
            main(
                sentences_path=sentences_path,
                sentences_dir=sentences_dir,
                files_extension="txt",
                output_path=output_path,
                source_lang="en",
                target_lang="es",
                starting_batch_size=32,
                model_name="facebook/m2m100_418M",
                lora_weights_name_or_path=None,
                force_auto_device_map=True,
                precision=None,
                max_length=64,
                num_beams=2,
                num_return_sequences=1,
                do_sample=False,
                temperature=1.0,
                top_k=50,
                top_p=1.0,
                keep_special_tokens=False,
                keep_tokenization_spaces=False,
                repetition_penalty=None,
                prompt=None,
            )

        with tempfile.TemporaryDirectory() as tmpdirname:
            input_path = os.path.join(tmpdirname, "source.txt")
            output_path = os.path.join(tmpdirname, "target.txt")

            # One-sentence fixture used by both runs.
            with open(input_path, "w", encoding="utf8") as source_file:
                source_file.write("Hello, world, my name is Iker!\n")

            # Run 1: a single input file and a single output file.
            run_main(input_path, None, output_path)

            # Run 2: the whole temporary directory as input, with an
            # extension-less "target" path as output.
            # NOTE(review): if the first run wrote target.txt into this
            # directory, the second run may pick it up as input too -- confirm
            # this is intended.
            run_main(None, tmpdirname, os.path.join(tmpdirname, "target"))
|
|
|
|
|
def _release_tuple(version_string):
    """Return the leading numeric release components of a version string.

    For example ``"4.35.0.dev0"`` and ``"4.35.0rc1"`` both give
    ``(4, 35, 0)``. Parsing stops at the first dot-separated component that
    is not made up entirely of digits; the leading digits of that component,
    if any, are still included.

    Comparing these tuples avoids the pitfalls of comparing version strings
    lexicographically (where ``"4.4.0" > "4.34.0"``).
    """
    release = []
    for component in version_string.split("."):
        digits = ""
        for char in component:
            if not ("0" <= char <= "9"):
                break
            digits += char
        if not digits:
            break
        release.append(int(digits))
        if digits != component:
            # A suffix such as "rc1" ends the numeric release segment.
            break
    return tuple(release)


class Translations(unittest.TestCase):
    """Smoke tests that run ``main`` on one sentence for several models.

    The tests contain no assertions: each one passes when ``main`` returns
    without raising for both precision settings.
    """

    def _run_both_precisions(
        self, model_name, src_lang, tgt_lang, force_auto_device_map=True
    ):
        """Write a one-sentence source file and call ``main`` twice on it.

        ``main`` is invoked first with ``precision="bf16"`` and then with
        ``precision="4"``; both runs share the same input and output paths
        inside a fresh temporary directory.

        Args:
            model_name: Value forwarded as ``model_name``.
            src_lang: Value forwarded as ``source_lang`` (may be ``None``).
            tgt_lang: Value forwarded as ``target_lang`` (may be ``None``).
            force_auto_device_map: Value forwarded under the same name.
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            input_path = os.path.join(tmpdirname, "source.txt")
            output_path = os.path.join(tmpdirname, "target.txt")

            with open(input_path, "w", encoding="utf8") as f:
                print("Hello, world, my name is Iker!", file=f)

            for precision in ("bf16", "4"):
                main(
                    sentences_path=input_path,
                    sentences_dir=None,
                    files_extension="txt",
                    output_path=output_path,
                    source_lang=src_lang,
                    target_lang=tgt_lang,
                    starting_batch_size=32,
                    model_name=model_name,
                    lora_weights_name_or_path=None,
                    force_auto_device_map=force_auto_device_map,
                    precision=precision,
                    max_length=64,
                    num_beams=2,
                    num_return_sequences=1,
                    do_sample=False,
                    temperature=1.0,
                    top_k=50,
                    top_p=1.0,
                    keep_special_tokens=False,
                    keep_tokenization_spaces=False,
                    repetition_penalty=None,
                    prompt=None,
                )

    def test_m2m100(self):
        """Run ``main`` with facebook/m2m100_418M, "en" -> "es"."""
        self._run_both_precisions("facebook/m2m100_418M", "en", "es")

    def test_nllb200(self):
        """Run ``main`` with NLLB-200 distilled 600M, "eng_Latn" -> "spa_Latn"."""
        self._run_both_precisions(
            "facebook/nllb-200-distilled-600M", "eng_Latn", "spa_Latn"
        )

    def test_mbart(self):
        """Run ``main`` with facebook/mbart-large-50, "en_XX" -> "es_XX"."""
        self._run_both_precisions("facebook/mbart-large-50", "en_XX", "es_XX")

    def test_opus(self):
        """Run ``main`` with Helsinki-NLP/opus-mt-en-es and no language codes.

        This is the only test that passes ``force_auto_device_map=False``.
        """
        self._run_both_precisions(
            "Helsinki-NLP/opus-mt-en-es",
            None,
            None,
            force_auto_device_map=False,
        )

    # Compare numeric release tuples, not raw strings: lexicographic string
    # comparison mis-orders versions such as "4.4.0" or "4.100.0" against
    # "4.34.0".
    @unittest.skipIf(
        _release_tuple(transformers.__version__) > (4, 34, 0),
        "Small100 tokenizer is not supported in transformers > 4.34.0. Please use transformers <= 4.34.0 if you want to use small100",
    )
    def test_small100(self):
        """Run ``main`` with alirezamsh/small100, target "es", no source code."""
        self._run_both_precisions("alirezamsh/small100", None, "es")

    def test_seamless(self):
        """Run ``main`` with facebook/hf-seamless-m4t-medium, "eng" -> "spa"."""
        self._run_both_precisions(
            "facebook/hf-seamless-m4t-medium", "eng", "spa"
        )
|
|
|
|
|
class Prompting(unittest.TestCase):
    """Smoke test for calling ``main`` with a prompt template.

    The test contains no assertions: it passes when ``main`` returns without
    raising.
    """

    def test_llama(self):
        """Call ``main`` with a tiny random Llama-2 checkpoint and a prompt.

        ``main`` is run with ``precision="bf16"`` and then ``precision="4"``,
        sampling enabled, and no source/target language codes.
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            input_path = os.path.join(tmpdirname, "source.txt")
            output_path = os.path.join(tmpdirname, "target.txt")

            # One-sentence fixture shared by both runs.
            with open(input_path, "w", encoding="utf8") as source_file:
                source_file.write("Hello, world, my name is Iker!\n")

            model_name = "stas/tiny-random-llama-2"
            prompt = "Translate English to Spanish: %%SENTENCE%%"

            for precision in ("bf16", "4"):
                main(
                    sentences_path=input_path,
                    sentences_dir=None,
                    files_extension="txt",
                    output_path=output_path,
                    source_lang=None,
                    target_lang=None,
                    starting_batch_size=32,
                    model_name=model_name,
                    lora_weights_name_or_path=None,
                    force_auto_device_map=True,
                    precision=precision,
                    max_length=64,
                    num_beams=2,
                    num_return_sequences=1,
                    do_sample=True,
                    temperature=1.0,
                    top_k=50,
                    top_p=1.0,
                    keep_special_tokens=False,
                    keep_tokenization_spaces=False,
                    repetition_penalty=None,
                    prompt=prompt,
                )
|
|