import os

import py3langid as langid

from .common import OfflineTranslator
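
# py3langid reports bare ISO 639-1 codes, while mBART-50 expects its own
# region-suffixed language codes (e.g. 'en' -> 'en_XX'), hence this mapping.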
ISO_639_1_TO_MBart50 = {
    'ar': 'ar_AR',
    'de': 'de_DE',
    'en': 'en_XX',
    'es': 'es_XX',
    'fr': 'fr_XX',
    'hi': 'hi_IN',
    'it': 'it_IT',
    'ja': 'ja_XX',
    'nl': 'nl_XX',
    'pl': 'pl_PL',
    'pt': 'pt_XX',
    'ru': 'ru_RU',
    'sw': 'sw_KE',
    'th': 'th_TH',
    'tr': 'tr_TR',
    'ur': 'ur_PK',
    'vi': 'vi_VN',
    'zh': 'zh_CN',
}


class MBart50Translator(OfflineTranslator):
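    """Offline translator backed by the mBART-50 many-to-many model.

    Source languages passed as 'auto' are detected with py3langid, first over
    the whole batch of queries and then per sentence as a fallback.
    """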
    # Maps this project's internal language codes to mBART-50 language codes.
    _LANGUAGE_CODE_MAP = {
        "ARA": "ar_AR",
        "DEU": "de_DE",
        "ENG": "en_XX",
        "ESP": "es_XX",
        "FRA": "fr_XX",
        "HIN": "hi_IN",
        "ITA": "it_IT",
        "JPN": "ja_XX",
        "NLD": "nl_XX",
        "PLK": "pl_PL",
        "PTB": "pt_XX",
        "RUS": "ru_RU",
        "SWA": "sw_KE",
        "THA": "th_TH",
        "TRK": "tr_TR",
        "URD": "ur_PK",
        "VIN": "vi_VN",
        "CHS": "zh_CN",
    }

    _MODEL_SUB_DIR = os.path.join(OfflineTranslator._MODEL_DIR, OfflineTranslator._MODEL_SUB_DIR, 'mbart50')
    _TRANSLATOR_MODEL = "facebook/mbart-large-50-many-to-many-mmt"
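
    # Note: the checkpoint above ships roughly 2.4 GB of PyTorch weights; it is
    # downloaded on demand by _download() below.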

    async def _load(self, from_lang: str, to_lang: str, device: str):
        from transformers import (
            MBartForConditionalGeneration,
            AutoTokenizer,
        )
        # Normalize bare accelerator names such as 'cuda' to an explicit device
        # index; leave 'cpu' untouched so the device checks below stay valid.
        if device != 'cpu' and ':' not in device:
            device += ':0'
        self.device = device
        # Load from the same cache directory that _download() populates.
        self.model = MBartForConditionalGeneration.from_pretrained(self._TRANSLATOR_MODEL, cache_dir=self._MODEL_SUB_DIR)
        if self.device != 'cpu':
            self.model.to(self.device)
        self.model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained(self._TRANSLATOR_MODEL, cache_dir=self._MODEL_SUB_DIR)

    async def _unload(self):
        del self.model
        del self.tokenizer

    async def _infer(self, from_lang: str, to_lang: str, queries: list[str]) -> list[str]:
        if from_lang == 'auto':
            # Detect once over the combined queries; if that fails, fall back to
            # per-sentence detection inside _translate_sentence().
            detected_lang = langid.classify('\n'.join(queries))[0]
            target_lang = self._map_detected_lang_to_translator(detected_lang)

            if target_lang is None:
                self.logger.warning('Could not detect language from the overall text. Will try per sentence.')
            else:
                from_lang = target_lang

        return [self._translate_sentence(from_lang, to_lang, query) for query in queries]

    def _translate_sentence(self, from_lang: str, to_lang: str, query: str) -> str:
        if not self.is_loaded():
            return ''

        if from_lang == 'auto':
            detected_lang = langid.classify(query)[0]
            from_lang = self._map_detected_lang_to_translator(detected_lang)

        if from_lang is None:
            self.logger.warning(f'MBart50 translation failed: could not detect the language, or the language is unsupported, for text: {query}')
            return ''

        self.tokenizer.src_lang = from_lang
        tokens = self.tokenizer(query, return_tensors="pt")
        if self.device != 'cpu':
            tokens = tokens.to(self.device)
        # Force the decoder to emit the target language token first, as required
        # by the many-to-many mBART-50 checkpoint.
        generated_tokens = self.model.generate(**tokens, forced_bos_token_id=self.tokenizer.lang_code_to_id[to_lang])
        return self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

    def _map_detected_lang_to_translator(self, lang):
        # Returns the mBART-50 code for a detected ISO 639-1 code, or None if
        # the language is not supported.
        return ISO_639_1_TO_MBart50.get(lang)

    async def _download(self):
        import huggingface_hub

        # Fetch only what the PyTorch backend needs; skip Flax (*.msgpack),
        # TF (*.h5), Rust (*.ot) and safetensors weight variants.
        huggingface_hub.snapshot_download(
            self._TRANSLATOR_MODEL,
            cache_dir=self._MODEL_SUB_DIR,
            ignore_patterns=["*.msgpack", "*.h5", "*.ot", ".*", "*.safetensors"],
        )

    def _check_downloaded(self) -> bool:
        import huggingface_hub

        # try_to_load_from_cache() returns a file path when the weights are
        # cached; it may also return a special sentinel for files recorded as
        # missing, so check for an actual path rather than merely non-None.
        return isinstance(
            huggingface_hub.try_to_load_from_cache(self._TRANSLATOR_MODEL, 'pytorch_model.bin', cache_dir=self._MODEL_SUB_DIR),
            str,
        )
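

# Minimal usage sketch (hypothetical; in practice loading and inference are
# driven through the OfflineTranslator base-class entry points of this project):
#
#   import asyncio
#
#   async def demo():
#       translator = MBart50Translator()
#       await translator._load('auto', 'en_XX', 'cpu')
#       print(await translator._infer('auto', 'en_XX', ['Bonjour le monde']))
#
#   asyncio.run(demo())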