| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline | |
| import gradio as gr | |
| import os | |
| from modules.utils.paths import TRANSLATION_OUTPUT_DIR, NLLB_MODELS_DIR | |
| from modules.translation.translation_base import TranslationBase | |
class NLLBInference(TranslationBase):
    """Translation backend that runs the ``facebook/nllb-200-*`` models
    through a ``transformers`` translation pipeline.

    ``update_model`` must be called before ``translate``: it loads the model
    and tokenizer (when needed) and builds the pipeline for a language pair.

    NOTE(review): ``self.model``, ``self.current_model_size`` and
    ``self.device`` are read here but never initialised in this class;
    presumably ``TranslationBase.__init__`` sets them up — confirm.
    """

    def __init__(self,
                 model_dir: str = NLLB_MODELS_DIR,
                 output_dir: str = TRANSLATION_OUTPUT_DIR
                 ):
        """
        Parameters
        ----------
        model_dir: str
            Cache directory the model is downloaded to / loaded from.
            Tokenizers are cached in its ``tokenizers`` sub-directory.
        output_dir: str
            Forwarded unchanged to ``TranslationBase``.
        """
        super().__init__(
            model_dir=model_dir,
            output_dir=output_dir
        )
        self.tokenizer = None
        self.available_models = [
            "facebook/nllb-200-3.3B",
            "facebook/nllb-200-1.3B",
            "facebook/nllb-200-distilled-1.3B",
            "facebook/nllb-200-distilled-600M",
        ]
        # Source and target offer the same choices: every display name in the table.
        self.available_source_langs = list(NLLB_AVAILABLE_LANGS.keys())
        self.available_target_langs = list(NLLB_AVAILABLE_LANGS.keys())
        self.pipeline = None

    def translate(self,
                  text: str,
                  max_length: int
                  ) -> str:
        """Translate ``text`` with the pipeline built by ``update_model``.

        Parameters
        ----------
        text: str
            Text to translate.
        max_length: int
            Passed to the pipeline as ``max_length``.

        Returns
        -------
        The ``translation_text`` field of the first pipeline result.

        Raises
        ------
        RuntimeError
            If ``update_model`` has not built a pipeline yet.
        """
        if self.pipeline is None:
            # Fail with an actionable message instead of "'NoneType' object is not callable".
            raise RuntimeError("NLLB pipeline is not initialized. Call update_model() first.")

        result = self.pipeline(
            text,
            max_length=max_length
        )
        return result[0]['translation_text']

    def update_model(self,
                     model_size: str,
                     src_lang: str,
                     tgt_lang: str,
                     progress: gr.Progress = gr.Progress()
                     ):
        """Load the requested model if necessary and rebuild the pipeline.

        Parameters
        ----------
        model_size: str
            Model id, e.g. one of ``self.available_models``.
        src_lang: str
            Source language, as a display name from ``NLLB_AVAILABLE_LANGS``
            or directly as one of its codes (e.g. ``"eng_Latn"``).
        tgt_lang: str
            Target language, same accepted forms as ``src_lang``.
        progress: gr.Progress
            Progress indicator; called once when a model load starts.

        Raises
        ------
        ValueError
            If either language is neither a known display name nor a known code.
        """
        src_lang = self._resolve_lang_code(src_lang)
        tgt_lang = self._resolve_lang_code(tgt_lang)

        if model_size != self.current_model_size or self.model is None:
            print("\nInitializing NLLB Model...\n")
            progress(0, desc="Initializing NLLB Model...")

            # The tokenizer lives in its own cache directory, so whether it is
            # already on disk has to be checked separately from the model.
            tokenizer_dir = os.path.join(self.model_dir, "tokenizers")

            model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=model_size,
                                                          cache_dir=self.model_dir,
                                                          local_files_only=self.is_model_exists(model_size))
            tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_size,
                                                      cache_dir=tokenizer_dir,
                                                      local_files_only=self._is_cached(tokenizer_dir, model_size))

            # Commit only after both loads succeeded; otherwise a failed load
            # would leave the new size recorded next to the old model/tokenizer
            # and the next call with the same size would skip reloading.
            self.model = model
            self.tokenizer = tokenizer
            self.current_model_size = model_size

        # The pipeline is rebuilt on every call because the language pair may
        # change even when the model does not.
        self.pipeline = pipeline("translation",
                                 model=self.model,
                                 tokenizer=self.tokenizer,
                                 src_lang=src_lang,
                                 tgt_lang=tgt_lang,
                                 device=self.device)

    def is_model_exists(self,
                        model_size: str) -> bool:
        """Return True if a non-empty cache folder for ``model_size`` exists
        under ``self.model_dir``.

        The folder name is derived from the model id (``org/name`` ->
        ``models--org--name``), so ids without an organisation part or with a
        different organisation are answered instead of raising.
        """
        return self._is_cached(self.model_dir, model_size)

    @staticmethod
    def _is_cached(cache_dir: str, model_size: str) -> bool:
        """Return True if ``cache_dir`` holds a non-empty folder for ``model_size``."""
        repo_dir_name = "models--" + model_size.replace("/", "--")
        repo_dir_path = os.path.join(cache_dir, repo_dir_name)
        # isdir (not exists) so a stray file with that name cannot make listdir raise.
        return os.path.isdir(repo_dir_path) and bool(os.listdir(repo_dir_path))

    @staticmethod
    def _resolve_lang_code(lang: str) -> str:
        """Map a display name to its language code; pass known codes through."""
        if lang in NLLB_AVAILABLE_LANGS:
            return NLLB_AVAILABLE_LANGS[lang]
        if lang not in NLLB_AVAILABLE_LANGS.values():
            raise ValueError(
                f"Language '{lang}' is not supported. Use one of: {list(NLLB_AVAILABLE_LANGS.keys())}")
        return lang
# Display name -> language code table.
# The keys populate the selectable source/target language lists; the values
# are what is handed to the translation pipeline as ``src_lang``/``tgt_lang``.
# Insertion order is the order in which the names are listed.
NLLB_AVAILABLE_LANGS = {
    "Acehnese (Arabic script)": "ace_Arab",
    "Acehnese (Latin script)": "ace_Latn",
    "Mesopotamian Arabic": "acm_Arab",
    "Ta’izzi-Adeni Arabic": "acq_Arab",
    "Tunisian Arabic": "aeb_Arab",
    "Afrikaans": "afr_Latn",
    "South Levantine Arabic": "ajp_Arab",
    "Akan": "aka_Latn",
    "Amharic": "amh_Ethi",
    "North Levantine Arabic": "apc_Arab",
    "Modern Standard Arabic": "arb_Arab",
    "Modern Standard Arabic (Romanized)": "arb_Latn",
    "Najdi Arabic": "ars_Arab",
    "Moroccan Arabic": "ary_Arab",
    "Egyptian Arabic": "arz_Arab",
    "Assamese": "asm_Beng",
    "Asturian": "ast_Latn",
    "Awadhi": "awa_Deva",
    "Central Aymara": "ayr_Latn",
    "South Azerbaijani": "azb_Arab",
    "North Azerbaijani": "azj_Latn",
    "Bashkir": "bak_Cyrl",
    "Bambara": "bam_Latn",
    "Balinese": "ban_Latn",
    "Belarusian": "bel_Cyrl",
    "Bemba": "bem_Latn",
    "Bengali": "ben_Beng",
    "Bhojpuri": "bho_Deva",
    "Banjar (Arabic script)": "bjn_Arab",
    "Banjar (Latin script)": "bjn_Latn",
    "Standard Tibetan": "bod_Tibt",
    "Bosnian": "bos_Latn",
    "Buginese": "bug_Latn",
    "Bulgarian": "bul_Cyrl",
    "Catalan": "cat_Latn",
    "Cebuano": "ceb_Latn",
    "Czech": "ces_Latn",
    "Chokwe": "cjk_Latn",
    "Central Kurdish": "ckb_Arab",
    "Crimean Tatar": "crh_Latn",
    "Welsh": "cym_Latn",
    "Danish": "dan_Latn",
    "German": "deu_Latn",
    "Southwestern Dinka": "dik_Latn",
    "Dyula": "dyu_Latn",
    "Dzongkha": "dzo_Tibt",
    "Greek": "ell_Grek",
    "English": "eng_Latn",
    "Esperanto": "epo_Latn",
    "Estonian": "est_Latn",
    "Basque": "eus_Latn",
    "Ewe": "ewe_Latn",
    "Faroese": "fao_Latn",
    "Fijian": "fij_Latn",
    "Finnish": "fin_Latn",
    "Fon": "fon_Latn",
    "French": "fra_Latn",
    "Friulian": "fur_Latn",
    "Nigerian Fulfulde": "fuv_Latn",
    "Scottish Gaelic": "gla_Latn",
    "Irish": "gle_Latn",
    "Galician": "glg_Latn",
    "Guarani": "grn_Latn",
    "Gujarati": "guj_Gujr",
    "Haitian Creole": "hat_Latn",
    "Hausa": "hau_Latn",
    "Hebrew": "heb_Hebr",
    "Hindi": "hin_Deva",
    "Chhattisgarhi": "hne_Deva",
    "Croatian": "hrv_Latn",
    "Hungarian": "hun_Latn",
    "Armenian": "hye_Armn",
    "Igbo": "ibo_Latn",
    "Ilocano": "ilo_Latn",
    "Indonesian": "ind_Latn",
    "Icelandic": "isl_Latn",
    "Italian": "ita_Latn",
    "Javanese": "jav_Latn",
    "Japanese": "jpn_Jpan",
    "Kabyle": "kab_Latn",
    "Jingpho": "kac_Latn",
    "Kamba": "kam_Latn",
    "Kannada": "kan_Knda",
    "Kashmiri (Arabic script)": "kas_Arab",
    "Kashmiri (Devanagari script)": "kas_Deva",
    "Georgian": "kat_Geor",
    "Central Kanuri (Arabic script)": "knc_Arab",
    "Central Kanuri (Latin script)": "knc_Latn",
    "Kazakh": "kaz_Cyrl",
    "Kabiyè": "kbp_Latn",
    "Kabuverdianu": "kea_Latn",
    "Khmer": "khm_Khmr",
    "Kikuyu": "kik_Latn",
    "Kinyarwanda": "kin_Latn",
    "Kyrgyz": "kir_Cyrl",
    "Kimbundu": "kmb_Latn",
    "Northern Kurdish": "kmr_Latn",
    "Kikongo": "kon_Latn",
    "Korean": "kor_Hang",
    "Lao": "lao_Laoo",
    "Ligurian": "lij_Latn",
    "Limburgish": "lim_Latn",
    "Lingala": "lin_Latn",
    "Lithuanian": "lit_Latn",
    "Lombard": "lmo_Latn",
    "Latgalian": "ltg_Latn",
    "Luxembourgish": "ltz_Latn",
    "Luba-Kasai": "lua_Latn",
    "Ganda": "lug_Latn",
    "Luo": "luo_Latn",
    "Mizo": "lus_Latn",
    "Standard Latvian": "lvs_Latn",
    "Magahi": "mag_Deva",
    "Maithili": "mai_Deva",
    "Malayalam": "mal_Mlym",
    "Marathi": "mar_Deva",
    "Minangkabau (Arabic script)": "min_Arab",
    "Minangkabau (Latin script)": "min_Latn",
    "Macedonian": "mkd_Cyrl",
    "Plateau Malagasy": "plt_Latn",
    "Maltese": "mlt_Latn",
    "Meitei (Bengali script)": "mni_Beng",
    "Halh Mongolian": "khk_Cyrl",
    "Mossi": "mos_Latn",
    "Maori": "mri_Latn",
    "Burmese": "mya_Mymr",
    "Dutch": "nld_Latn",
    "Norwegian Nynorsk": "nno_Latn",
    "Norwegian Bokmål": "nob_Latn",
    "Nepali": "npi_Deva",
    "Northern Sotho": "nso_Latn",
    "Nuer": "nus_Latn",
    "Nyanja": "nya_Latn",
    "Occitan": "oci_Latn",
    "West Central Oromo": "gaz_Latn",
    "Odia": "ory_Orya",
    "Pangasinan": "pag_Latn",
    "Eastern Panjabi": "pan_Guru",
    "Papiamento": "pap_Latn",
    "Western Persian": "pes_Arab",
    "Polish": "pol_Latn",
    "Portuguese": "por_Latn",
    "Dari": "prs_Arab",
    "Southern Pashto": "pbt_Arab",
    "Ayacucho Quechua": "quy_Latn",
    "Romanian": "ron_Latn",
    "Rundi": "run_Latn",
    "Russian": "rus_Cyrl",
    "Sango": "sag_Latn",
    "Sanskrit": "san_Deva",
    "Santali": "sat_Olck",
    "Sicilian": "scn_Latn",
    "Shan": "shn_Mymr",
    "Sinhala": "sin_Sinh",
    "Slovak": "slk_Latn",
    "Slovenian": "slv_Latn",
    "Samoan": "smo_Latn",
    "Shona": "sna_Latn",
    "Sindhi": "snd_Arab",
    "Somali": "som_Latn",
    "Southern Sotho": "sot_Latn",
    "Spanish": "spa_Latn",
    "Tosk Albanian": "als_Latn",
    "Sardinian": "srd_Latn",
    "Serbian": "srp_Cyrl",
    "Swati": "ssw_Latn",
    "Sundanese": "sun_Latn",
    "Swedish": "swe_Latn",
    "Swahili": "swh_Latn",
    "Silesian": "szl_Latn",
    "Tamil": "tam_Taml",
    "Tatar": "tat_Cyrl",
    "Telugu": "tel_Telu",
    "Tajik": "tgk_Cyrl",
    "Tagalog": "tgl_Latn",
    "Thai": "tha_Thai",
    "Tigrinya": "tir_Ethi",
    "Tamasheq (Latin script)": "taq_Latn",
    "Tamasheq (Tifinagh script)": "taq_Tfng",
    "Tok Pisin": "tpi_Latn",
    "Tswana": "tsn_Latn",
    "Tsonga": "tso_Latn",
    "Turkmen": "tuk_Latn",
    "Tumbuka": "tum_Latn",
    "Turkish": "tur_Latn",
    "Twi": "twi_Latn",
    "Central Atlas Tamazight": "tzm_Tfng",
    "Uyghur": "uig_Arab",
    "Ukrainian": "ukr_Cyrl",
    "Umbundu": "umb_Latn",
    "Urdu": "urd_Arab",
    "Northern Uzbek": "uzn_Latn",
    "Venetian": "vec_Latn",
    "Vietnamese": "vie_Latn",
    "Waray": "war_Latn",
    "Wolof": "wol_Latn",
    "Xhosa": "xho_Latn",
    "Eastern Yiddish": "ydd_Hebr",
    "Yoruba": "yor_Latn",
    "Yue Chinese": "yue_Hant",
    "Chinese (Simplified)": "zho_Hans",
    "Chinese (Traditional)": "zho_Hant",
    "Standard Malay": "zsm_Latn",
    "Zulu": "zul_Latn",
}