|
import torch |
|
import gradio as gr |
|
from transformers import pipeline |
|
import ast |
|
|
|
# Dropdown label -> transformers pipeline task string.
# Each value is passed verbatim as `task=` to `transformers.pipeline` in
# translate_text, and has the form "translation_<src>_to_<tgt>".
# All language codes are lower-case so the task strings stay uniform.
# NOTE: nothing in this file checks that the selected model actually supports
# the selected direction; pairing a model with a task it was not trained for
# is left to the user.
translation_task_names = {
    'English to French': 'translation_en_to_fr',
    'English to German': 'translation_en_to_de',
    'English to Dutch': 'translation_en_to_nl',
    'Dutch to English': 'translation_nl_to_en',
    'English to Russian': 'translation_en_to_ru',
    'Russian to English': 'translation_ru_to_en',
    'English to Chinese': 'translation_en_to_zh',
    'Chinese to English': 'translation_zh_to_en',
    'English to Romanian': 'translation_en_to_ro',
    # Previously 'translation_SV_to_EN'; lower-cased for consistency with
    # every other entry.
    'Swedish to English': 'translation_sv_to_en',
}
|
|
|
# Dropdown label -> model identifier passed as `model=` to
# `transformers.pipeline` in translate_text.
# The pairs are listed in the order they appear in the "Select Model" dropdown;
# dict() keeps that insertion order.
model_names = dict((
    ('T5-Base', 't5-base'),
    ('T5-Small', 't5-small'),
    ('T5-Large', 't5-large'),
    ('Opus-En-ZH', 'liam168/trans-opus-mt-en-zh'),
    ('Opus-ZH-En', 'Helsinki-NLP/opus-mt-zh-en'),
    ('DDDSSS/translation_en-zh', 'DDDSSS/translation_en-zh'),
    ('T5-Base-nl-en', 'yhavinga/t5-base-36L-ccmatrix-multi'),
    ('T5-Small-nl-en', 'yhavinga/t5-small-24L-ccmatrix-multi'),
    ('Opus-Sv-En', 'Helsinki-NLP/opus-mt-sv-en'),
    ('Opus-En-Ru', 'Helsinki-NLP/opus-mt-en-ru'),
    ('Opus-Ru-En', 'Helsinki-NLP/opus-mt-ru-en'),
))
|
|
|
|
|
# Cache of constructed translation pipelines, so repeated requests with the
# same settings reuse an already-loaded model. Keyed by
# (model_choice, task_choice, load_in_8bit, device); see translate_text.
loaded_models = {}


def _build_translator(model_choice, task_choice, load_in_8bit, device):
    """Construct a transformers translation pipeline for one UI selection."""
    if load_in_8bit:
        # Quantized (8-bit) weights are placed on devices through a device map.
        # Per the transformers pipeline docs, a model loaded that way must not
        # also be given an explicit `device`, so none is forwarded here.
        model_kwargs = {"load_in_8bit": True, "device_map": "auto"}
        pipeline_device = None
        dtype = torch.float16
    else:
        model_kwargs = {}
        pipeline_device = device
        dtype = torch.float32

    return pipeline(
        task=translation_task_names[task_choice],
        model=model_names[model_choice],
        device=pipeline_device,
        model_kwargs=model_kwargs,
        torch_dtype=dtype,
        use_fast=True,
    )


def translate_text(model_choice, task_choice, text_input, load_in_8bit, device):
    """Translate ``text_input`` with the selected model and task.

    Pipelines are built lazily on first use and cached in ``loaded_models``.

    Args:
        model_choice: Key of ``model_names`` (the "Select Model" dropdown label).
        task_choice: Key of ``translation_task_names`` (the task dropdown label).
        text_input: Text to translate.
        load_in_8bit: If truthy, load the model weights in 8-bit. In that case
            placement is delegated to ``device_map="auto"`` and ``device`` is
            ignored.
        device: Device name from the UI (``'cpu'`` or ``'cuda'``) used when
            not loading in 8-bit.

    Returns:
        The translated text with surrounding whitespace stripped, or ``""``
        when ``text_input`` is empty or whitespace-only.
    """
    # Nothing to translate: avoid loading (possibly large) model weights.
    if not text_input or not text_input.strip():
        return ""

    # The device is part of the key so that switching device in the UI builds
    # a new pipeline instead of reusing one bound to the previous device.
    # For 8-bit loads `device` is not used, so it is normalized out of the key.
    key_device = None if load_in_8bit else device
    model_key = (model_choice, task_choice, load_in_8bit, key_device)

    translator = loaded_models.get(model_key)
    if translator is None:
        translator = _build_translator(model_choice, task_choice, load_in_8bit, device)
        loaded_models[model_key] = translator

    translation = translator(text_input)[0]['translation_text']
    return str(translation).strip()
|
|
|
def launch(model_choice, task_choice, text_input, load_in_8bit, device):
    """Gradio callback that forwards the UI inputs to ``translate_text``.

    The positional parameters arrive in the same order as the ``inputs`` list
    given to ``gr.Interface``. Each one is forwarded by keyword, so the
    mapping to ``translate_text`` is explicit.
    """
    forwarded = translate_text(
        model_choice=model_choice,
        task_choice=task_choice,
        text_input=text_input,
        load_in_8bit=load_in_8bit,
        device=device,
    )
    return forwarded
|
|
|
# --- Gradio UI wiring -------------------------------------------------------
# Dropdown choices come from the label tables above. Each dropdown defaults to
# its first entry; without a default, submitting before making a selection
# passes None into translate_text, whose dict lookups then raise KeyError.
_model_choices = list(model_names.keys())
_task_choices = list(translation_task_names.keys())

model_dropdown = gr.Dropdown(
    choices=_model_choices,
    value=_model_choices[0],
    label='Select Model',
)
task_dropdown = gr.Dropdown(
    choices=_task_choices,
    value=_task_choices[0],
    label='Select Translation Task',
)
text_input = gr.Textbox(label="Input Text")
load_in_8bit = gr.Checkbox(label="Load model in 8bit")
device = gr.Radio(['cpu', 'cuda'], label='Select device', value='cpu')

# The order of `inputs` must match the positional parameters of `launch`.
iface = gr.Interface(
    launch,
    inputs=[model_dropdown, task_dropdown, text_input, load_in_8bit, device],
    outputs=gr.Textbox(type="text", label="Translation"),
)

iface.launch()