# Hugging Face Space: MarianMT translation model exported to ONNX, served via Gradio.
import numpy as np
import onnxruntime as ort
from transformers import MarianTokenizer
import gradio as gr

# Load the Marian tokenizer from the exported model directory.
tokenizer_path = "./onnx_model"
tokenizer = MarianTokenizer.from_pretrained(tokenizer_path)

# Create an ONNX Runtime inference session over the exported seq2seq graph.
onnx_model_path = "./model.onnx"
session = ort.InferenceSession(onnx_model_path)
def translate(text, max_length=512):
    """Greedily translate *text* with the exported ONNX seq2seq model.

    Args:
        text: Source-language string (or list of strings) to translate.
        max_length: Cap on both the tokenized input length and the number
            of greedy decoding steps.

    Returns:
        The translated string for the first sequence in the batch.
    """
    # Tokenize the source text into int64 arrays, as the ONNX graph expects.
    inputs = tokenizer(text, return_tensors="np", padding=True, truncation=True, max_length=max_length)
    input_ids = inputs["input_ids"].astype(np.int64)
    attention_mask = inputs["attention_mask"].astype(np.int64)

    # Marian models use the pad token as the decoder start token.
    batch_size = input_ids.shape[0]
    decoder_input_ids = np.full((batch_size, 1), tokenizer.pad_token_id, dtype=np.int64)
    eos_reached = np.zeros(batch_size, dtype=bool)

    for _ in range(max_length):
        # NOTE(review): the full decoder prefix is re-run every step (no KV
        # cache in this exported graph), so each step costs O(prefix length).
        onnx_outputs = session.run(
            None,
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "decoder_input_ids": decoder_input_ids,
            },
        )
        # Greedy decoding: take the highest-logit token at the last position.
        next_token_logits = onnx_outputs[0][:, -1, :]
        next_tokens = np.argmax(next_token_logits, axis=-1).reshape(-1, 1)

        # Sequences that already emitted EOS get pad tokens so finished rows
        # in a batch don't accumulate garbage tokens.
        next_tokens = np.where(eos_reached.reshape(-1, 1), tokenizer.pad_token_id, next_tokens)
        decoder_input_ids = np.hstack([decoder_input_ids, next_tokens])

        # Stop once every sequence in the batch has produced EOS.
        eos_reached |= (next_tokens == tokenizer.eos_token_id).flatten()
        if eos_reached.all():
            break

    # Special tokens (pad / eos) are stripped while decoding back to text.
    translated_texts = tokenizer.batch_decode(decoder_input_ids, skip_special_tokens=True)
    return translated_texts[0]
# Gradio interface | |
with gr.Blocks() as interface: | |
input_text = gr.Textbox(label="Input Text") | |
output_translation = gr.Textbox(label="Translation") | |
translate_button = gr.Button("Translate") | |
translate_button.click(fn=translate, inputs=[input_text], outputs=[output_translation]) | |
if __name__ == "__main__": | |
interface.launch() | |