TenzinGayche's picture
Update app.py
b508b39 verified
import os
import torch
import gradio as gr
import time
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
# UI language label -> language code handed to the translation pipeline
# as ``src_lang`` / ``tgt_lang``. Insertion order drives the dropdown order.
flores_codes = dict([
    ("Standard Tibetan", "bod_Tibt"),
    ("English", "eng_Latn"),
])
def load_models():
    """Load each fine-tuned NLLB checkpoint together with its tokenizer.

    Returns:
        dict: for every UI model name ``call_name`` (e.g. ``'nllb-biboen-1'``),
        ``model_dict[call_name + '_model']`` is the loaded
        ``AutoModelForSeq2SeqLM`` and ``model_dict[call_name + '_tokenizer']``
        is the tokenizer to use with it.
    """
    # UI name -> Hugging Face Hub repository id of the fine-tuned checkpoint.
    model_name_dict = {
        'nllb-biboen-1': 'TenzinGayche/nllb_600M_bi_boen',
        'nllb-biboen-2': 'TenzinGayche/nllb_600M_bi_boen_gold',
        'nllb-biboen-3': 'TenzinGayche/nllb_600M_bi_boen_3',
    }
    # Every checkpoint is paired with the same base NLLB tokenizer, so load it
    # once and share it instead of re-loading an identical copy per model.
    tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
    model_dict = {}
    for call_name, real_name in model_name_dict.items():
        print('\tLoading model: %s' % call_name)
        model_dict[call_name + '_model'] = AutoModelForSeq2SeqLM.from_pretrained(real_name)
        model_dict[call_name + '_tokenizer'] = tokenizer
    return model_dict
def translation(model_name, source, target, text):
    """Translate ``text`` from ``source`` to ``target`` using the selected model.

    Reads the module-level ``model_dict`` (assigned from ``load_models()`` in
    the ``__main__`` block) and the module-level ``flores_codes`` mapping.

    Args:
        model_name: UI model name (e.g. ``'nllb-biboen-1'``); prefix of the
            ``model_dict`` keys built by ``load_models``.
        source: source-language label; must be a key of ``flores_codes``.
        target: target-language label; must be a key of ``flores_codes``.
        text: text to translate.

    Returns:
        str: the translated text, or ``""`` when ``text`` is empty/whitespace.

    Raises:
        gr.Error: if no model is selected (``model_name`` is empty or None).
        KeyError: if a language label or model name is unknown.
    """
    # Nothing to translate: skip building a pipeline and running the model.
    if not text or not text.strip():
        return ""
    if not model_name:
        # Show a readable message in the UI instead of a TypeError from
        # concatenating None with '_model' below.
        raise gr.Error("Please select an NLLB model.")
    src_code = flores_codes[source]
    tgt_code = flores_codes[target]
    model = model_dict[model_name + '_model']
    tokenizer = model_dict[model_name + '_tokenizer']
    # Use the first GPU when available, otherwise CPU (-1).
    device = 0 if torch.cuda.is_available() else -1
    translator = pipeline('translation', model=model, tokenizer=tokenizer,
                          src_lang=src_code, tgt_lang=tgt_code, device=device)
    output = translator(text, max_length=400)
    return output[0]['translation_text']
if __name__ == '__main__':
    print('\tinit models')
    # Assigned at module scope, so this is the global that translation() reads;
    # no `global` statement is needed here.
    model_dict = load_models()
    # Define gradio demo
    lang_codes = list(flores_codes.keys())
    # Must match the call names used as keys in load_models().
    model_names = ['nllb-biboen-1', 'nllb-biboen-2', 'nllb-biboen-3']
    with gr.Blocks() as demo:
        gr.Markdown("# NLLB Distilled 600M Translation Demo")
        gr.Markdown("This demo allows you to test the translation models for English to Standard Tibetan and vice versa.")
        with gr.Row():
            with gr.Column():
                # Pre-select a model so clicking Translate never sends None as the model name.
                model_input = gr.Radio(model_names, value=model_names[0], label='Select NLLB Model')
                source_lang = gr.Dropdown(lang_codes, value='English', label='Source Language')
                target_lang = gr.Dropdown(lang_codes, value='Standard Tibetan', label='Target Language')
                input_text = gr.Textbox(lines=5, label="Input Text", placeholder="Enter the text you want to translate")
            with gr.Column():
                output_text = gr.Textbox(lines=5, label="Translated Text", interactive=False, placeholder="The translated text will appear here")
        translate_button = gr.Button("Translate")
        # Gradio writes the handler's return value into `outputs`; assigning
        # output_text.value inside a handler does not update the page.
        translate_button.click(translation, inputs=[model_input, source_lang, target_lang, input_text], outputs=output_text)
    demo.launch()