import gradio as gr
import numpy as np
import onnxruntime as ort
from transformers import MarianTokenizer

# Load the tokenizer and the exported ONNX model once at import time so the
# Gradio handler does not pay the startup cost on every request.
tokenizer_path = "./onnx_model"
tokenizer = MarianTokenizer.from_pretrained(tokenizer_path)

onnx_model_path = "./model.onnx"
session = ort.InferenceSession(onnx_model_path)


def translate(text, max_length=512):
    """Translate ``text`` by greedy decoding with the ONNX MarianMT model.

    Args:
        text: Source-language input string (a single sentence or paragraph).
        max_length: Cap on both the tokenized input length and the number of
            decoding steps.

    Returns:
        The translated string for the first (and, from the Gradio UI, only)
        sequence in the batch.
    """
    # Tokenize input; NumPy tensors because ONNX Runtime consumes ndarrays.
    inputs = tokenizer(
        text,
        return_tensors="np",
        padding=True,
        truncation=True,
        max_length=max_length,
    )
    input_ids = inputs["input_ids"].astype(np.int64)
    attention_mask = inputs["attention_mask"].astype(np.int64)

    batch_size = input_ids.shape[0]
    # Marian models use pad_token_id as the decoder start token.
    decoder_input_ids = np.full(
        (batch_size, 1), tokenizer.pad_token_id, dtype=np.int64
    )
    eos_reached = np.zeros(batch_size, dtype=bool)

    # Greedy decode one token per step. NOTE: the exported model has no
    # past-key-value cache, so the full decoder prefix is re-run each step.
    for _ in range(max_length):
        onnx_outputs = session.run(
            None,
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "decoder_input_ids": decoder_input_ids,
            },
        )
        # Logits for the last position only; pick the argmax token.
        next_token_logits = onnx_outputs[0][:, -1, :]
        next_tokens = np.argmax(next_token_logits, axis=-1)

        # BUGFIX: sequences that already emitted EOS must keep emitting pad;
        # otherwise their post-EOS argmax tokens survive batch_decode
        # (skip_special_tokens removes EOS but not ordinary tokens after it)
        # and corrupt the output for batched inputs.
        next_tokens = np.where(eos_reached, tokenizer.pad_token_id, next_tokens)

        decoder_input_ids = np.hstack([decoder_input_ids, next_tokens.reshape(-1, 1)])

        # Stop once every sequence in the batch has produced EOS.
        eos_reached |= next_tokens == tokenizer.eos_token_id
        if eos_reached.all():
            break

    translated_texts = tokenizer.batch_decode(
        decoder_input_ids, skip_special_tokens=True
    )
    return translated_texts[0]


# Gradio interface: one text box in, one text box out.
with gr.Blocks() as interface:
    input_text = gr.Textbox(label="Input Text")
    output_translation = gr.Textbox(label="Translation")
    translate_button = gr.Button("Translate")
    translate_button.click(
        fn=translate, inputs=[input_text], outputs=[output_translation]
    )

if __name__ == "__main__":
    interface.launch()