# Gradio demo: translate text between ~100 languages with facebook/m2m100_1.2B.
import os

import torch
import gradio as gr
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
# Prefer GPU when available; the 1.2B-parameter model is slow on CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Load the many-to-many multilingual translation model once at startup.
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_1.2B")
model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_1.2B").to(device)
model.eval()  # inference only — no gradients needed
# NOTE(review): `l1` appears unused below — kept for backward compatibility.
l1 = "Afrikaans"


class Language:
    """A (display name, m2m100 language code) pair.

    NOTE(review): leftover from a partial refactor — the app now uses the
    plain ``lang_id`` dict; kept so any external references keep working.
    """

    def __init__(self, name, code):
        self.name = name  # human-readable display name, e.g. "English"
        self.code = code  # m2m100 language code, e.g. "en"
# Display name -> m2m100 language code.
# BUG FIX: this literal was a broken hybrid — opened as a list, containing a
# bogus "":"" entry, Language(...) constructor calls, and thirteen empty
# Language("", "") placeholders.  All consumers index it like a dict
# (lang_id["English"], list(lang_id.keys())), so it is reconstructed here as
# a plain dict; the empty placeholders are dropped.
lang_id = {
    "Afrikaans": "af",
    "Albanian": "sq",
    "Amharic": "am",
    "Arabic": "ar",
    "Armenian": "hy",
    "Asturian": "ast",
    "Azerbaijani": "az",
    "Bashkir": "ba",
    "Belarusian": "be",
    "Bulgarian": "bg",
    "Bengali": "bn",
    "Breton": "br",
    "Bosnian": "bs",
    "Burmese": "my",
    "Catalan": "ca",
    "Cebuano": "ceb",
    "Chinese": "zh",
    "Croatian": "hr",
    "Czech": "cs",
    "Danish": "da",
    "Dutch": "nl",
    "English": "en",
    "Estonian": "et",
    "Fulah": "ff",
    "Finnish": "fi",
    "French": "fr",
    "Western Frisian": "fy",
    "Gaelic": "gd",
    "Galician": "gl",
    "Georgian": "ka",
    "German": "de",
    "Greek": "el",
    "Gujarati": "gu",
    "Hausa": "ha",
    "Hebrew": "he",
    "Hindi": "hi",
    "Haitian": "ht",
    "Hungarian": "hu",
    "Irish": "ga",
    "Indonesian": "id",
    "Igbo": "ig",
    "Iloko": "ilo",
    "Icelandic": "is",
    "Italian": "it",
    "Japanese": "ja",
    "Javanese": "jv",
    "Kazakh": "kk",
    "Central Khmer": "km",
    "Kannada": "kn",
    "Korean": "ko",
    "Luxembourgish": "lb",
    "Ganda": "lg",
    "Lingala": "ln",
    "Lao": "lo",
    "Lithuanian": "lt",
    "Latvian": "lv",
    "Malagasy": "mg",
    "Macedonian": "mk",
    "Malayalam": "ml",
    "Mongolian": "mn",
    "Marathi": "mr",
    "Malay": "ms",
    "Nepali": "ne",
    "Norwegian": "no",
    "Northern Sotho": "ns",
    "Occitan": "oc",
    "Oriya": "or",
    "Panjabi": "pa",
    "Persian": "fa",
    "Polish": "pl",
    "Pushto": "ps",
    "Portuguese": "pt",
    "Romanian": "ro",
    "Russian": "ru",
    "Sindhi": "sd",
    "Sinhala": "si",
    "Slovak": "sk",
    "Slovenian": "sl",
    "Spanish": "es",
    "Somali": "so",
    "Serbian": "sr",
    "Serbian (cyrillic)": "sr",
    "Serbian (latin)": "sr",
    "Swati": "ss",
    "Sundanese": "su",
    "Swedish": "sv",
    "Swahili": "sw",
    "Tamil": "ta",
    "Thai": "th",
    "Tagalog": "tl",
    "Tswana": "tn",
    "Turkish": "tr",
    "Ukrainian": "uk",
    "Urdu": "ur",
    "Uzbek": "uz",
    "Vietnamese": "vi",
    "Welsh": "cy",
    "Wolof": "wo",
    "Xhosa": "xh",
    "Yiddish": "yi",
    "Yoruba": "yo",
    "Zulu": "zu",
}
def trans_page(input, input1, trg):
    """Translate the page heading and the dropdown choices into *trg*.

    Parameters:
        input: markdown heading text to translate (assumed English source).
        input1: current "From" dropdown value — accepted for interface
            compatibility with the click wiring; not otherwise used.
            # NOTE(review): confirm intended use against the caller.
        trg: target language display name (a key of ``lang_id``).

    Returns:
        A tuple ``(translated heading, gr.Dropdown.update(...))`` matching
        the ``[md, lang_from]`` outputs of ``t_submit.click``.
    """
    src_lang = lang_id["English"]
    trg_lang = lang_id[trg]
    names = list(lang_id.keys())
    if trg_lang == src_lang:
        # Already in the source language: nothing to translate.
        return input, gr.Dropdown.update(choices=names)
    tokenizer.src_lang = src_lang
    with torch.no_grad():
        # Translate the page heading.
        encoded = tokenizer(input, return_tensors="pt").to(device)
        generated = model.generate(
            **encoded, forced_bos_token_id=tokenizer.get_lang_id(trg_lang)
        )
        heading = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
        # BUG FIX: the original passed the whole ``lang_id`` mapping to the
        # tokenizer and then called ``.keys()`` on a decoded *string* (both
        # crash).  Translate each language display name as a batch instead.
        encoded = tokenizer(names, return_tensors="pt", padding=True).to(device)
        generated = model.generate(
            **encoded, forced_bos_token_id=tokenizer.get_lang_id(trg_lang)
        )
        translated_names = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return heading, gr.Dropdown.update(choices=translated_names)
def trans_to(input, src, trg):
    """Translate *input* from language *src* to *trg*.

    Parameters:
        input: the prompt text to translate.
        src: source language display name (a key of ``lang_id``).
        trg: target language display name (a key of ``lang_id``).

    Returns:
        The translated string, or *input* unchanged when both display
        names map to the same m2m100 code.
    """
    src_lang = lang_id[src]
    trg_lang = lang_id[trg]
    if trg_lang == src_lang:
        # Same language (e.g. "Serbian (latin)" -> "Serbian"): no-op.
        return input
    tokenizer.src_lang = src_lang
    with torch.no_grad():
        encoded = tokenizer(input, return_tensors="pt").to(device)
        generated = model.generate(
            **encoded, forced_bos_token_id=tokenizer.get_lang_id(trg_lang)
        )
    return tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
# NOTE(review): `md1` appears unused below — kept for backward compatibility.
md1 = "Translate - 100 Languages"

# Build the Gradio UI: a "translate the page itself" row on top, then the
# main From/To translation form.  Empty gr.Column() calls are layout spacers.
with gr.Blocks() as transbot:
    with gr.Row():
        gr.Column()  # left spacer
        with gr.Column():
            with gr.Row():
                t_space = gr.Dropdown(label="Translate Space", choices=list(lang_id.keys()), value="English")
                t_submit = gr.Button("Translate Space")
        gr.Column()  # right spacer
    with gr.Row():
        gr.Column()  # left spacer
        with gr.Column():
            md = gr.Markdown("""<h1><center>Translate - 100 Languages</center></h1><h4><center>Translation may not be accurate</center></h4>""")
            with gr.Row():
                lang_from = gr.Dropdown(label="From:", choices=list(lang_id.keys()), value="English")
                lang_to = gr.Dropdown(label="To:", choices=list(lang_id.keys()), value="Chinese")
            submit = gr.Button("Go")
            with gr.Row():
                with gr.Column():
                    message = gr.Textbox(label="Prompt", placeholder="Enter Prompt", lines=4)
                    translated = gr.Textbox(label="Translated", lines=4, interactive=False)
        gr.Column()  # right spacer
    # Translate the heading and refresh the "From" dropdown choices.
    t_submit.click(trans_page, [md, lang_from, t_space], [md, lang_from])
    # Translate the prompt text.
    submit.click(trans_to, inputs=[message, lang_from, lang_to], outputs=[translated])

transbot.queue(concurrency_count=20)
transbot.launch()