import torch
import gradio as gr
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration

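# Run inference on a CUDA GPU when one is available; otherwise fall back to the CPU.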
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

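# Load the facebook/m2m100_1.2B many-to-many translation model and its tokenizer.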
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_1.2B")
model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_1.2B").to(device)
model.eval()


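# Simple container pairing a display name with its M2M100 language code.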
class Language:
    def __init__(self, name, code):
        self.name = name
        self.code = code

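# Languages supported by M2M100; the names populate the dropdown menus below.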
lang_id = [
    Language("Afrikaans", "af"),
    Language("Albanian", "sq"),
    Language("Amharic", "am"),
    Language("Arabic", "ar"),
    Language("Armenian", "hy"),
    Language("Asturian", "ast"),
    Language("Azerbaijani", "az"),
    Language("Bashkir", "ba"),
    Language("Belarusian", "be"),
    Language("Bulgarian", "bg"),
    Language("Bengali", "bn"),
    Language("Breton", "br"),
    Language("Bosnian", "bs"),
    Language("Burmese", "my"),
    Language("Catalan", "ca"),
    Language("Cebuano", "ceb"),
    Language("Chinese","zh"),
    Language("Croatian","hr"),
    Language("Czech","cs"),
    Language("Danish","da"),
    Language("Dutch","nl"),
    Language("English","en"),
    Language("Estonian","et"),
    Language("Fulah","ff"),
    Language("Finnish","fi"),
    Language("French","fr"),
    Language("Western Frisian","fy"),
    Language("Gaelic","gd"),
    Language("Galician","gl"),
    Language("Georgian","ka"),
    Language("German","de"),
    Language("Greek","el"),
    Language("Gujarati","gu"),
    Language("Hausa","ha"),
    Language("Hebrew","he"),
    Language("Hindi","hi"),
    Language("Haitian","ht"),
    Language("Hungarian","hu"),
    Language("Irish","ga"),
    Language("Indonesian","id"),
    Language("Igbo","ig"),
    Language("Iloko","ilo"),
    Language("Icelandic","is"),
    Language("Italian","it"),
    Language("Japanese","ja"),
    Language("Javanese","jv"),
    Language("Kazakh","kk"),
    Language("Central Khmer","km"),
    Language("Kannada","kn"),
    Language("Korean","ko"),
    Language("Luxembourgish","lb"),
    Language("Ganda","lg"),
    Language("Lingala","ln"),
    Language("Lao","lo"),
    Language("Lithuanian","lt"),
    Language("Latvian","lv"),
    Language("Malagasy","mg"),
    Language("Macedonian","mk"),
    Language("Malayalam","ml"),
    Language("Mongolian","mn"),
    Language("Marathi","mr"),
    Language("Malay","ms"),
    Language("Nepali","ne"),
    Language("Norwegian","no"),
    Language("Northern Sotho","ns"),
    Language("Occitan","oc"),
    Language("Oriya","or"),
    Language("Panjabi","pa"),
    Language("Persian","fa"),
    Language("Polish","pl"),
    Language("Pushto","ps"),
    Language("Portuguese","pt"),
    Language("Romanian","ro"),
    Language("Russian","ru"),
    Language("Sindhi","sd"),
    Language("Sinhala","si"),
    Language("Slovak","sk"),
    Language("Slovenian","sl"),
    Language("Spanish","es"),
    Language("Somali","so"),
    Language("Serbian","sr"),
    Language("Serbian (cyrillic)","sr"),
    Language("Serbian (latin)","sr"),
    Language("Swati","ss"),
    Language("Sundanese","su"),
    Language("Swedish","sv"),
    Language("Swahili","sw"),
    Language("Tamil","ta"),
    Language("Thai","th"),
    Language("Tagalog","tl"),
    Language("Tswana","tn"),
    Language("Turkish","tr"),
    Language("Ukrainian","uk"),
    Language("Urdu","ur"),
    Language("Uzbek","uz"),
    Language("Vietnamese","vi"),
    Language("Welsh","cy"),
    Language("Wolof","wo"),
    Language("Xhosa","xh"),
    Language("Yiddish","yi"),
    Language("Yoruba","yo"),
    Language("Zulu","zu"),
]
# Default source language for the page header.
d_lang = next(lang for lang in lang_id if lang.name == "English")

def trans_page(text, trg):
    """Translate the page header from the default language into the selected language."""
    src_lang = d_lang.code
    trg_lang = src_lang
    for lang in lang_id:
        if lang.name == trg:
            trg_lang = lang.code
    if trg_lang != src_lang:
        tokenizer.src_lang = src_lang
        with torch.no_grad():
            encoded_input = tokenizer(text, return_tensors="pt").to(device)
            generated_tokens = model.generate(**encoded_input, forced_bos_token_id=tokenizer.get_lang_id(trg_lang))
            translated_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
    else:
        translated_text = text
    return translated_text

def trans_to(text, src, trg):
    """Translate text from the selected source language to the selected target language."""
    src_lang = None
    trg_lang = None
    for lang in lang_id:
        if lang.name == src:
            src_lang = lang.code
        if lang.name == trg:
            trg_lang = lang.code
    if trg_lang != src_lang:
        tokenizer.src_lang = src_lang
        with torch.no_grad():
            encoded_input = tokenizer(text, return_tensors="pt").to(device)
            generated_tokens = model.generate(**encoded_input, forced_bos_token_id=tokenizer.get_lang_id(trg_lang))
            translated_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
    else:
        translated_text = text
    return translated_text

md1 = "Translate - 100 Languages"



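# Gradio Blocks UI: a header translator plus a free-text translation form.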
with gr.Blocks() as transbot:
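    # Top row: translate the page header itself into a chosen language.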
    with gr.Row():
        gr.Column()
        with gr.Column():
            with gr.Row():
                t_space = gr.Dropdown(label="Translate Space to:", choices=[l.name for l in lang_id], value="English")
                t_submit = gr.Button("Translate Space")
        gr.Column()
        
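    # Main translation form: source/target dropdowns, prompt box, and output box.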
    with gr.Row():
        gr.Column()
        with gr.Column():
            md = gr.Markdown("""<h1><center>Translate - 100 Languages</center></h1><h4><center>Translation may not be accurate</center></h4>""")
            with gr.Row():

                lang_from = gr.Dropdown(label="From:", choices=[l.name for l in lang_id], value="English")
                lang_to = gr.Dropdown(label="To:", choices=[l.name for l in lang_id], value="Chinese")
                
            submit = gr.Button("Go")
            with gr.Row():
                with gr.Column():
                    message = gr.Textbox(label="Prompt", placeholder="Enter Prompt", lines=4)
                    translated = gr.Textbox(label="Translated", lines=4, interactive=False)
        gr.Column()
    # Translate the page header markdown in place.
    t_submit.click(trans_page, inputs=[md, t_space], outputs=[md])

    # Translate the prompt text into the selected target language.
    submit.click(trans_to, inputs=[message, lang_from, lang_to], outputs=[translated])
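# Queue requests; concurrency_count controls how many run in parallel (Gradio 3.x queue API).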
transbot.queue(concurrency_count=20)
transbot.launch()