import numpy as np
import onnxruntime as ort
from transformers import MarianTokenizer
import gradio as gr

# Load the tokenizer from the local folder
model_path = "./onnx_model"  # Path to the folder containing the model files
tokenizer = MarianTokenizer.from_pretrained(model_path)

# Load the ONNX model
onnx_model_path = "./model.onnx"
session = ort.InferenceSession(onnx_model_path)
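
# Optional sanity check (sketch): the input/output names of the exported graph depend
# on the export tool, so printing them once helps confirm they match the feed dict
# built in translate_text below.
for graph_input in session.get_inputs():
    print("ONNX input:", graph_input.name, graph_input.shape)
for graph_output in session.get_outputs():
    print("ONNX output:", graph_output.name, graph_output.shape)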

def translate_text(input_texts):
    # Tokenize input texts (batch processing)
    tokenized_input = tokenizer(
        input_texts, return_tensors="np", padding=True, truncation=True, max_length=512
    )
    input_ids = tokenized_input["input_ids"]
    attention_mask = tokenized_input["attention_mask"]

    # MarianMT uses pad_token_id as the decoder start token; create one start row per
    # input so the batch dimension matches input_ids
    decoder_start_token_id = tokenizer.pad_token_id
    decoder_input_ids = np.full((input_ids.shape[0], 1), decoder_start_token_id, dtype=np.int64)
    # Prepare inputs for the ONNX model
    ort_inputs = {
        "input_ids": input_ids.astype(np.int64),
        "attention_mask": attention_mask.astype(np.int64),
        "decoder_input_ids": decoder_input_ids,
    }

    # Run inference using the ONNX model
    ort_outputs = session.run(None, ort_inputs)
    output = ort_outputs[0]  # First output of the exported graph

    # Debug: inspect the structure of the output
    print("Output shape:", output.shape)
    print("Output:", output)

    # If the graph returns logits of shape (batch, decoder_len, vocab_size), take the
    # argmax over the vocabulary axis to get token IDs; otherwise assume the output
    # already holds token IDs.
    output_ids = np.asarray(output)
    if output_ids.ndim == 3:
        output_ids = output_ids.argmax(axis=-1)

    # Decode the output tokens. Note: with a single decoder step this yields only the
    # first generated token for each input (see the greedy loop sketched below).
    translated_texts = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return "\n".join(translated_texts)  # Gradio's "text" output expects a single string
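
# Hedged sketch: the single decoder step above only yields the first target token.
# A full translation needs an autoregressive loop that feeds each predicted token
# back into decoder_input_ids. The version below assumes the exported graph takes
# input_ids / attention_mask / decoder_input_ids and returns logits as its first
# output; the function name, max_new_tokens and the early stop on eos_token_id are
# illustrative choices, not part of the original app.
def greedy_translate(input_texts, max_new_tokens=64):
    enc = tokenizer(
        input_texts, return_tensors="np", padding=True, truncation=True, max_length=512
    )
    input_ids = enc["input_ids"].astype(np.int64)
    attention_mask = enc["attention_mask"].astype(np.int64)
    decoder_input_ids = np.full(
        (input_ids.shape[0], 1), tokenizer.pad_token_id, dtype=np.int64
    )
    for _ in range(max_new_tokens):
        logits = session.run(None, {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
        })[0]
        # Greedy pick of the next token from the last decoder position
        next_tokens = logits[:, -1, :].argmax(axis=-1).reshape(-1, 1).astype(np.int64)
        decoder_input_ids = np.concatenate([decoder_input_ids, next_tokens], axis=1)
        if np.all(next_tokens == tokenizer.eos_token_id):
            break
    return tokenizer.batch_decode(decoder_input_ids, skip_special_tokens=True)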

# Gradio interface
interface = gr.Interface(
    fn=translate_text,
    inputs="text",
    outputs="text",
    title="Frenchizer Translation Model",
    description="Translate text with a MarianMT ONNX model using batched encoding."
)

# Launch the interface
interface.launch()
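
# Note: launch() also accepts options such as share=True (temporary public link) or
# server_name="0.0.0.0" when the app needs to be reachable from outside localhost.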