# space_18 / app.py

import numpy as np
import onnxruntime as ort
from transformers import MarianTokenizer
import gradio as gr
# Load tokenizer
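# The MarianTokenizer files (SentencePiece models and vocab) are assumed to have
# been saved alongside the ONNX export in ./onnx_model.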
tokenizer_path = "./onnx_model"
tokenizer = MarianTokenizer.from_pretrained(tokenizer_path)
# Load ONNX model
onnx_model_path = "./model.onnx"
session = ort.InferenceSession(onnx_model_path)
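# The exported graph is assumed to take "input_ids", "attention_mask", and
# "decoder_input_ids" and to return the decoder logits as its first output;
# adjust the names below if your export differs.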
def translate(text, max_length=512):
    # Tokenize the input text into encoder ids and an attention mask
    inputs = tokenizer(text, return_tensors="np", padding=True, truncation=True, max_length=max_length)
    input_ids = inputs["input_ids"].astype(np.int64)
    attention_mask = inputs["attention_mask"].astype(np.int64)

    # Initialize the decoder with the pad token, which MarianMT uses as its decoder start token
    batch_size = input_ids.shape[0]
    decoder_input_ids = np.full((batch_size, 1), tokenizer.pad_token_id, dtype=np.int64)
    eos_reached = np.zeros(batch_size, dtype=bool)
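
    # Greedy decoding: each step feeds the full decoder prefix back through the
    # ONNX session (the export is assumed not to expose past key/value caching)
    # and appends the highest-probability token until every sequence emits EOS.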
    for _ in range(max_length):
        onnx_outputs = session.run(
            None,
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "decoder_input_ids": decoder_input_ids,
            },
        )
        next_token_logits = onnx_outputs[0][:, -1, :]  # Logits for the last decoder position
        next_tokens = np.argmax(next_token_logits, axis=-1).reshape(-1, 1)  # Pick the highest-probability token

        # Keep already-finished sequences padded instead of appending fresh argmax tokens
        next_tokens[eos_reached] = tokenizer.pad_token_id

        # Append the new tokens to the decoder input
        decoder_input_ids = np.hstack([decoder_input_ids, next_tokens])

        # Stop once every sequence in the batch has produced EOS
        eos_reached |= (next_tokens == tokenizer.eos_token_id).flatten()
        if eos_reached.all():
            break
    # Decode output tokens, skipping pad/eos special tokens
    translated_texts = tokenizer.batch_decode(decoder_input_ids, skip_special_tokens=True)
    return translated_texts[0]
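
# Example (hypothetical, assuming an English-to-French MarianMT export in
# ./onnx_model and ./model.onnx):
#   translate("Hello, how are you?")  # -> a French sentence such as "Bonjour, comment allez-vous ?"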
# Gradio interface
with gr.Blocks() as interface:
    input_text = gr.Textbox(label="Input Text")
    output_translation = gr.Textbox(label="Translation")
    translate_button = gr.Button("Translate")
    translate_button.click(fn=translate, inputs=[input_text], outputs=[output_translation])
if __name__ == "__main__":
    interface.launch()