File size: 4,408 Bytes
9dce458 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import os
from typing import List
import py3langid as langid
from .common import OfflineTranslator
# https://github.com/facebookresearch/flores/blob/main/flores200/README.md
# Maps ISO 639-1 codes (as returned by py3langid) to FLORES-200 codes
# used by the NLLB models.
# https://github.com/facebookresearch/flores/blob/main/flores200/README.md
ISO_639_1_TO_FLORES_200 = {
    'zh': 'zho_Hans',
    'ja': 'jpn_Jpan',
    'en': 'eng_Latn',
    # BUGFIX: was 'kn' (ISO 639-1 for Kannada); Korean is 'ko', so detected
    # Korean text never matched this table before.
    'ko': 'kor_Hang',
    'cs': 'ces_Latn',
    'nl': 'nld_Latn',
    'fr': 'fra_Latn',
    'de': 'deu_Latn',
    'hu': 'hun_Latn',
    'it': 'ita_Latn',
    'pl': 'pol_Latn',
    'pt': 'por_Latn',
    'ro': 'ron_Latn',
    'ru': 'rus_Cyrl',
    'es': 'spa_Latn',
    'tr': 'tur_Latn',
    'uk': 'ukr_Cyrl',
    'vi': 'vie_Latn',
    'ar': 'arb_Arab',
    'sr': 'srp_Cyrl',
    'hr': 'hrv_Latn',
    'th': 'tha_Thai',
    'id': 'ind_Latn'
}
class NLLBTranslator(OfflineTranslator):
    """Offline translator backed by Meta's NLLB-200 distilled 600M model.

    Source/target languages are identified by FLORES-200 codes (e.g.
    ``zho_Hans``); auto-detection of the source language uses py3langid
    and the module-level ``ISO_639_1_TO_FLORES_200`` mapping.
    """

    # Maps this project's internal language codes to FLORES-200 codes.
    _LANGUAGE_CODE_MAP = {
        'CHS': 'zho_Hans',
        'CHT': 'zho_Hant',
        'JPN': 'jpn_Jpan',
        'ENG': 'eng_Latn',
        'KOR': 'kor_Hang',
        'CSY': 'ces_Latn',
        'NLD': 'nld_Latn',
        'FRA': 'fra_Latn',
        'DEU': 'deu_Latn',
        'HUN': 'hun_Latn',
        'ITA': 'ita_Latn',
        'PLK': 'pol_Latn',
        'PTB': 'por_Latn',
        'ROM': 'ron_Latn',
        'RUS': 'rus_Cyrl',
        'ESP': 'spa_Latn',
        'TRK': 'tur_Latn',
        # BUGFIX: was the plain name 'Ukrainian', which is not a valid
        # FLORES-200 code and would be passed verbatim to the NLLB pipeline.
        'UKR': 'ukr_Cyrl',
        'VIN': 'vie_Latn',
        'ARA': 'arb_Arab',
        'SRP': 'srp_Cyrl',
        'HRV': 'hrv_Latn',
        'THA': 'tha_Thai',
        'IND': 'ind_Latn'
    }
    _MODEL_SUB_DIR = os.path.join(OfflineTranslator._MODEL_DIR, OfflineTranslator._MODEL_SUB_DIR, 'nllb')
    _TRANSLATOR_MODEL = 'facebook/nllb-200-distilled-600M'

    async def _load(self, from_lang: str, to_lang: str, device: str):
        """Load the seq2seq model and tokenizer onto `device`.

        A device index of ':0' is appended when `device` has none, so the
        transformers pipeline receives a fully-qualified device string.
        """
        from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
        if ':' not in device:
            device += ':0'
        self.device = device
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self._TRANSLATOR_MODEL)
        self.tokenizer = AutoTokenizer.from_pretrained(self._TRANSLATOR_MODEL)

    async def _unload(self):
        """Drop references to the model and tokenizer so they can be GC'd."""
        del self.model
        del self.tokenizer

    async def _infer(self, from_lang: str, to_lang: str, queries: List[str]) -> List[str]:
        """Translate each query; when `from_lang` is 'auto', try to detect
        the source language from all queries joined together, falling back
        to per-sentence detection inside `_translate_sentence`."""
        if from_lang == 'auto':
            detected_lang = langid.classify('\n'.join(queries))[0]
            target_lang = self._map_detected_lang_to_translator(detected_lang)
            if target_lang is None:
                self.logger.warning('Could not detect language from over all sentence. Will try per sentence.')
            else:
                from_lang = target_lang
        return [self._translate_sentence(from_lang, to_lang, query) for query in queries]

    def _translate_sentence(self, from_lang: str, to_lang: str, query: str) -> str:
        """Translate a single sentence, returning '' when the model is not
        loaded or the source language cannot be determined."""
        from transformers import pipeline
        if not self.is_loaded():
            return ''
        if from_lang == 'auto':
            # Per-sentence fallback when whole-batch detection failed.
            detected_lang = langid.classify(query)[0]
            from_lang = self._map_detected_lang_to_translator(detected_lang)
            if from_lang is None:
                self.logger.warning(f'NLLB Translation Failed. Could not detect language (Or language not supported for text: {query})')
                return ''
        translator = pipeline('translation',
            device=self.device,
            model=self.model,
            tokenizer=self.tokenizer,
            src_lang=from_lang,
            tgt_lang=to_lang,
            max_length = 512,
        )
        result = translator(query)[0]['translation_text']
        return result

    def _map_detected_lang_to_translator(self, lang):
        """Return the FLORES-200 code for an ISO 639-1 code, or None if the
        detected language is not supported."""
        if lang not in ISO_639_1_TO_FLORES_200:
            return None
        return ISO_639_1_TO_FLORES_200[lang]

    async def _download(self):
        import huggingface_hub
        # do not download msgpack and h5 files as they are not needed to run the model
        huggingface_hub.snapshot_download(self._TRANSLATOR_MODEL, cache_dir=self._MODEL_SUB_DIR, ignore_patterns=["*.msgpack", "*.h5", '*.ot',".*", "*.safetensors"])

    def _check_downloaded(self) -> bool:
        # The pytorch weights file is the last large artifact fetched, so its
        # presence in the cache is used as the "fully downloaded" marker.
        import huggingface_hub
        return huggingface_hub.try_to_load_from_cache(self._TRANSLATOR_MODEL, 'pytorch_model.bin', cache_dir=self._MODEL_SUB_DIR) is not None
class NLLBBigTranslator(NLLBTranslator):
    """Variant of `NLLBTranslator` using the larger distilled 1.3B NLLB model;
    only the model id and cache subdirectory differ."""
    _MODEL_SUB_DIR = os.path.join(OfflineTranslator._MODEL_DIR, OfflineTranslator._MODEL_SUB_DIR, 'nllb_big')
    _TRANSLATOR_MODEL = 'facebook/nllb-200-distilled-1.3B'