File size: 4,073 Bytes
9dce458 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import os
import py3langid as langid
from .common import OfflineTranslator
# Translates the ISO 639-1 code reported by ``langid.classify`` during
# auto-detection into the language code the mBART-50 tokenizer understands.
# Any detected language missing from this table is treated as unsupported.
ISO_639_1_TO_MBart50 = dict(
    ar='ar_AR',
    de='de_DE',
    en='en_XX',
    es='es_XX',
    fr='fr_XX',
    hi='hi_IN',
    it='it_IT',
    ja='ja_XX',
    nl='nl_XX',
    pl='pl_PL',
    pt='pt_XX',
    ru='ru_RU',
    sw='sw_KE',
    th='th_TH',
    tr='tr_TR',
    ur='ur_PK',
    vi='vi_VN',
    zh='zh_CN',
)
class MBart50Translator(OfflineTranslator):
    """Offline translator backed by the mBART-50 many-to-many model.

    Source and target languages are chosen per call (see
    ``_translate_sentence``); a source language of ``'auto'`` triggers
    detection with ``langid``.
    """

    # https://huggingface.co/facebook/mbart-large-50
    # Maps this project's language identifiers to mBART-50 tokenizer codes.
    # Other languages supported by the model can be added as well.
    _LANGUAGE_CODE_MAP = {
        "ARA": "ar_AR",
        "DEU": "de_DE",
        "ENG": "en_XX",
        "ESP": "es_XX",
        "FRA": "fr_XX",
        "HIN": "hi_IN",
        "ITA": "it_IT",
        "JPN": "ja_XX",
        "NLD": "nl_XX",
        "PLK": "pl_PL",
        "PTB": "pt_XX",
        "RUS": "ru_RU",
        "SWA": "sw_KE",
        "THA": "th_TH",
        "TRK": "tr_TR",
        "URD": "ur_PK",
        "VIN": "vi_VN",
        "CHS": "zh_CN",
    }
    # Cache directory that _download() fills and _load() reads from.
    _MODEL_SUB_DIR = os.path.join(
        OfflineTranslator._MODEL_DIR, OfflineTranslator._MODEL_SUB_DIR, 'mbart50'
    )
    _TRANSLATOR_MODEL = "facebook/mbart-large-50-many-to-many-mmt"

    async def _load(self, from_lang: str, to_lang: str, device: str):
        """Load the model and tokenizer, moving the model onto *device*.

        *from_lang* and *to_lang* are not used here; the language pair is
        selected per call in ``_translate_sentence``.
        """
        from transformers import (
            MBartForConditionalGeneration,
            AutoTokenizer,
        )
        # Give non-CPU devices an explicit index ('cuda' -> 'cuda:0'), but
        # leave 'cpu' alone so the `!= 'cpu'` checks below can match it.
        if device != 'cpu' and ':' not in device:
            device += ':0'
        self.device = device
        # Read from the same cache_dir that _download()/_check_downloaded()
        # use; without it transformers fetches a second copy of the weights
        # into the default Hugging Face cache.
        self.model = MBartForConditionalGeneration.from_pretrained(
            self._TRANSLATOR_MODEL, cache_dir=self._MODEL_SUB_DIR
        )
        if self.device != 'cpu':
            self.model.to(self.device)
        self.model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained(
            self._TRANSLATOR_MODEL, cache_dir=self._MODEL_SUB_DIR
        )

    async def _unload(self):
        """Drop references to the model and tokenizer."""
        del self.model
        del self.tokenizer

    async def _infer(self, from_lang: str, to_lang: str, queries: list[str]) -> list[str]:
        """Translate every query from *from_lang* to *to_lang*.

        With ``from_lang == 'auto'`` the language is first detected over all
        queries joined by newlines. If that detection yields an unsupported
        language, *from_lang* stays ``'auto'`` and each sentence is detected
        individually in ``_translate_sentence``.

        Returns one translation per query, in order ('' for failures).
        """
        if from_lang == 'auto':
            detected_lang = langid.classify('\n'.join(queries))[0]
            target_lang = self._map_detected_lang_to_translator(detected_lang)
            if target_lang is None:
                self.logger.warning('Could not detect language from over all sentence. Will try per sentence.')
            else:
                from_lang = target_lang
        return [self._translate_sentence(from_lang, to_lang, query) for query in queries]

    def _translate_sentence(self, from_lang: str, to_lang: str, query: str) -> str:
        """Translate a single *query* with the loaded model.

        *from_lang* / *to_lang* are mBART-50 codes (e.g. ``'ja_XX'``), or
        ``'auto'`` for *from_lang*. Returns '' when the model is not loaded
        or the source language cannot be mapped to a supported code.
        """
        if not self.is_loaded():
            return ''
        if from_lang == 'auto':
            detected_lang = langid.classify(query)[0]
            from_lang = self._map_detected_lang_to_translator(detected_lang)
            if from_lang is None:
                self.logger.warning(f'MBart50 Translation Failed. Could not detect language (Or language not supported for text: {query})')
                return ''
        # mBART-50 takes the source language from the tokenizer's src_lang
        # and the target language from the forced first generated token.
        self.tokenizer.src_lang = from_lang
        tokens = self.tokenizer(query, return_tensors="pt")
        if self.device != 'cpu':
            tokens = tokens.to(self.device)
        generated_tokens = self.model.generate(
            **tokens, forced_bos_token_id=self.tokenizer.lang_code_to_id[to_lang]
        )
        return self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

    def _map_detected_lang_to_translator(self, lang):
        """Return the mBART-50 code for ISO 639-1 *lang*, or None if unsupported."""
        return ISO_639_1_TO_MBart50.get(lang)

    async def _download(self):
        """Download the model snapshot into ``_MODEL_SUB_DIR``."""
        import huggingface_hub
        # Skip the TF (h5), Flax (msgpack), rust (ot), safetensors and dot
        # files; only the PyTorch weights checked by _check_downloaded are needed.
        huggingface_hub.snapshot_download(self._TRANSLATOR_MODEL, cache_dir=self._MODEL_SUB_DIR, ignore_patterns=["*.msgpack", "*.h5", '*.ot', ".*", "*.safetensors"])

    def _check_downloaded(self) -> bool:
        """Return True if ``pytorch_model.bin`` is present in ``_MODEL_SUB_DIR``."""
        import huggingface_hub
        return huggingface_hub.try_to_load_from_cache(self._TRANSLATOR_MODEL, 'pytorch_model.bin', cache_dir=self._MODEL_SUB_DIR) is not None