import torch
import gradio as gr
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Load the M2M100 1.2B multilingual translation model and its tokenizer once at startup.
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_1.2B")
model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_1.2B").to(device)
model.eval()
# Language names shown in the UI mapped to M2M100 language codes.
lang_id = {
    "": "",
    "Afrikaans": "af",
    "Albanian": "sq",
    "Amharic": "am",
    "Arabic": "ar",
    "Armenian": "hy",
    "Asturian": "ast",
    "Azerbaijani": "az",
    "Bashkir": "ba",
    "Belarusian": "be",
    "Bulgarian": "bg",
    "Bengali": "bn",
    "Breton": "br",
    "Bosnian": "bs",
    "Burmese": "my",
    "Catalan": "ca",
    "Cebuano": "ceb",
    "Chinese": "zh",
    "Chinese (simplified)": "zh",
    "Chinese (traditional)": "zh",
    "Croatian": "hr",
    "Czech": "cs",
    "Danish": "da",
    "Dutch": "nl",
    "English": "en",
    "Estonian": "et",
    "Fulah": "ff",
    "Finnish": "fi",
    "French": "fr",
    "Western Frisian": "fy",
    "Gaelic": "gd",
    "Galician": "gl",
    "Georgian": "ka",
    "German": "de",
    "Greek": "el",
    "Gujarati": "gu",
    "Hausa": "ha",
    "Hebrew": "he",
    "Hindi": "hi",
    "Haitian": "ht",
    "Hungarian": "hu",
    "Irish": "ga",
    "Indonesian": "id",
    "Igbo": "ig",
    "Iloko": "ilo",
    "Icelandic": "is",
    "Italian": "it",
    "Japanese": "ja",
    "Javanese": "jv",
    "Kazakh": "kk",
    "Central Khmer": "km",
    "Kannada": "kn",
    "Korean": "ko",
    "Luxembourgish": "lb",
    "Ganda": "lg",
    "Lingala": "ln",
    "Lao": "lo",
    "Lithuanian": "lt",
    "Latvian": "lv",
    "Malagasy": "mg",
    "Macedonian": "mk",
    "Malayalam": "ml",
    "Mongolian": "mn",
    "Marathi": "mr",
    "Malay": "ms",
    "Nepali": "ne",
    "Norwegian": "no",
    "Northern Sotho": "ns",
    "Occitan": "oc",
    "Oriya": "or",
    "Panjabi": "pa",
    "Persian": "fa",
    "Polish": "pl",
    "Pushto": "ps",
    "Portuguese": "pt",
    "Romanian": "ro",
    "Russian": "ru",
    "Sindhi": "sd",
    "Sinhala": "si",
    "Slovak": "sk",
    "Slovenian": "sl",
    "Spanish": "es",
    "Somali": "so",
    "Serbian": "sr",
    "Serbian (cyrillic)": "sr",
    "Serbian (latin)": "sr",
    "Swati": "ss",
    "Sundanese": "su",
    "Swedish": "sv",
    "Swahili": "sw",
    "Tamil": "ta",
    "Thai": "th",
    "Tagalog": "tl",
    "Tswana": "tn",
    "Turkish": "tr",
    "Ukrainian": "uk",
    "Urdu": "ur",
    "Uzbek": "uz",
    "Vietnamese": "vi",
    "Welsh": "cy",
    "Wolof": "wo",
    "Xhosa": "xh",
    "Yiddish": "yi",
    "Yoruba": "yo",
    "Zulu": "zu",
}
def trans_page(input, trg):
    """Translate the page header (English source) into the selected language."""
    src_lang = lang_id["English"]
    trg_lang = lang_id[trg]
    if trg_lang != src_lang:
        tokenizer.src_lang = src_lang
        with torch.no_grad():
            encoded_input = tokenizer(input, return_tensors="pt").to(device)
            # Force the decoder to start with the target-language token.
            generated_tokens = model.generate(**encoded_input, forced_bos_token_id=tokenizer.get_lang_id(trg_lang))
            translated_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
    else:
        # Source and target are the same language; return the text unchanged.
        translated_text = input
    return translated_text
def trans_to(input, src, trg):
    """Translate text from the selected source language to the selected target language."""
    src_lang = lang_id[src]
    trg_lang = lang_id[trg]
    if trg_lang != src_lang:
        tokenizer.src_lang = src_lang
        with torch.no_grad():
            encoded_input = tokenizer(input, return_tensors="pt").to(device)
            # Force the decoder to start with the target-language token.
            generated_tokens = model.generate(**encoded_input, forced_bos_token_id=tokenizer.get_lang_id(trg_lang))
            translated_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
    else:
        # Source and target are the same language; return the text unchanged.
        translated_text = input
    return translated_text
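
# A minimal usage sketch, kept as comments so it does not run at import time:
# assuming the model and tokenizer above loaded successfully, the helper can be
# called directly. The sentence below is illustrative, not output captured from
# this Space.
#
#     sample = trans_to("Hello, how are you?", "English", "French")
#     print(sample)  # prints the French translation returned by M2M100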
md1 = "Translate - 100 Languages"
with gr.Blocks() as transbot:
#lang_id=gr.State({})
with gr.Row():
gr.Column()
with gr.Column():
with gr.Row():
t_space = gr.Dropdown(label="Translate Space", choices=list(lang_id.keys()),value="English")
t_submit = gr.Button("Translate Space")
gr.Column()
with gr.Row():
gr.Column()
with gr.Column():
md = gr.Markdown(""f"<h1><center>{md1}</center></h1><h4><center>Translation may not be accurate</center></h4>""")
with gr.Row():
lang_from = gr.Dropdown(label="From:", choices=list(lang_id.keys()),value="English")
lang_to = gr.Dropdown(label="To:", choices=list(lang_id.keys()),value="Chinese")
submit = gr.Button("Go")
with gr.Row():
with gr.Column():
message = gr.Textbox(label="Prompt",placeholder="Enter Prompt",lines=4)
translated = gr.Textbox(label="Translated",lines=4,interactive=False)
gr.Column()
t_submit.click(trans_page,[md,t_space],md)
submit.click(trans_to, inputs=[message,lang_from,lang_to], outputs=[translated])
transbot.queue(concurrency_count=20)
transbot.launch()