space_23 / app.py
Frenchizer's picture
Update app.py
80505e2 verified
raw
history blame
2.13 kB
import numpy as np
import onnxruntime as ort
from transformers import MarianTokenizer
import gradio as gr
# Locations of the model artifacts, relative to the working directory:
# the folder holding the tokenizer files and the exported ONNX graph.
model_path, onnx_model_path = "./onnx_model", "./model.onnx"

# Both objects are created once at import time and shared by every request.
tokenizer = MarianTokenizer.from_pretrained(model_path)
session = ort.InferenceSession(onnx_model_path)
def translate_text(input_texts, *, max_new_tokens=512):
    """Translate one string or a batch of strings with greedy decoding.

    The ONNX graph is fed ``input_ids``, ``attention_mask`` and
    ``decoder_input_ids`` and is run once per generated token: the most
    likely next token of every sequence is appended to the decoder input
    until each sequence has produced the end-of-sequence token or
    ``max_new_tokens`` steps have been taken.

    NOTE(review): this assumes the first model output is the logits tensor
    with shape (batch, target_length, vocab_size) -- confirm against the
    exported graph. A ``ValueError`` is raised if the rank differs.

    Args:
        input_texts: A single string or an iterable of strings to translate.
        max_new_tokens: Upper bound on the number of decoding steps. The
            default mirrors the 512-token input truncation limit.

    Returns:
        A list with one translated string per input text (a one-element
        list for a single string, an empty list for an empty batch).

    Raises:
        ValueError: If the first model output is not a rank-3 array.
    """
    if isinstance(input_texts, str):
        input_texts = [input_texts]
    else:
        input_texts = list(input_texts)
    if not input_texts:
        return []

    # Tokenize the whole batch at once; padding aligns the sequence lengths.
    tokenized_input = tokenizer(
        input_texts, return_tensors="np", padding=True, truncation=True, max_length=512
    )
    input_ids = tokenized_input["input_ids"].astype(np.int64)
    attention_mask = tokenized_input["attention_mask"].astype(np.int64)
    batch_size = input_ids.shape[0]

    # The pad token doubles as the decoder start token (as in the original
    # code); every sequence in the batch gets its own start token.
    pad_id = tokenizer.pad_token_id
    eos_id = tokenizer.eos_token_id
    decoder_input_ids = np.full((batch_size, 1), pad_id, dtype=np.int64)
    finished = np.zeros(batch_size, dtype=bool)

    for _ in range(max_new_tokens):
        ort_inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
        }
        logits = np.asarray(session.run(None, ort_inputs)[0])
        if logits.ndim != 3:
            raise ValueError(
                "Expected model output of shape (batch, target_length, vocab_size), "
                f"got array with shape {logits.shape}"
            )

        # Greedy choice: highest-scoring token at the last decoder position.
        next_ids = logits[:, -1, :].argmax(axis=-1)
        # Sequences that already emitted EOS are only padded from now on.
        next_ids = np.where(finished, pad_id, next_ids).astype(np.int64)

        decoder_input_ids = np.concatenate(
            [decoder_input_ids, next_ids[:, None]], axis=1
        )
        finished |= next_ids == eos_id
        if finished.all():
            break

    # Drop the start token; EOS/padding are removed by skip_special_tokens.
    return tokenizer.batch_decode(decoder_input_ids[:, 1:], skip_special_tokens=True)
def _translate_single(text):
    """Adapt ``translate_text`` to the single-textbox Gradio UI.

    ``translate_text`` returns a list of translations, which a text output
    would display as the list's string representation. This adapter returns
    the first (only) translation as a plain string, and an empty string for
    blank input or an empty result.
    """
    if not text or not text.strip():
        return ""
    results = translate_text(text)
    return results[0] if results else ""


# Gradio interface: one input textbox, one output textbox.
interface = gr.Interface(
    fn=_translate_single,
    inputs="text",
    outputs="text",
    title="Frenchizer Translation Model",
    description="Translate text with MarianMT ONNX model and encoding by batches.",
)

# Launch the interface
interface.launch()