File size: 2,134 Bytes
c9d7167
 
65ed74c
c9d7167
 
65ed74c
 
 
fc36581
65ed74c
 
 
c9d7167
65ed74c
 
fc36581
65ed74c
c9d7167
fc36581
 
051b67f
 
 
76d4210
65ed74c
 
 
 
 
 
 
 
 
 
 
89b5af7
80505e2
 
 
 
 
 
 
 
 
 
89b5af7
65ed74c
 
 
 
 
47e25c1
 
65ed74c
47e25c1
 
65ed74c
c9d7167
65ed74c
c9d7167
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import numpy as np
import onnxruntime as ort
from transformers import MarianTokenizer
import gradio as gr

# Load the tokenizer from the local folder.
# NOTE: both paths below are resolved relative to the current working directory.
model_path = "./onnx_model"  # Path to the folder containing the tokenizer files
tokenizer = MarianTokenizer.from_pretrained(model_path)

# Load the ONNX model. NOTE(review): the graph is read from outside `model_path`;
# confirm this is the intended location.
onnx_model_path = "./model.onnx"

# Since onnxruntime 1.9, builds that ship more than one execution provider
# (e.g. onnxruntime-gpu) raise ValueError unless `providers` is given explicitly.
# Prefer CUDA when this build offers it, and always keep CPU as the fallback.
_available_providers = ort.get_available_providers()
_providers = [
    provider
    for provider in ("CUDAExecutionProvider", "CPUExecutionProvider")
    if provider in _available_providers
]
session = ort.InferenceSession(onnx_model_path, providers=_providers)

def _run_session(input_ids, attention_mask, decoder_input_ids):
    """Run one forward pass of the ONNX graph and return its first output as an array."""
    ort_inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "decoder_input_ids": decoder_input_ids,
    }
    return np.asarray(session.run(None, ort_inputs)[0])


def translate_text(input_texts, max_new_tokens=512):
    """Translate text with the MarianMT ONNX model using greedy decoding.

    Args:
        input_texts: A single string or a list of strings. A list is tokenized
            and translated as one padded batch.
        max_new_tokens: Maximum number of tokens generated per sequence
            (the decoder start token is not counted).

    Returns:
        A list with one translated string per input sequence; a single string
        input yields a one-element list.
    """
    # Nothing to translate; the tokenizer cannot build a batch from an empty list.
    if isinstance(input_texts, (list, tuple)) and not input_texts:
        return []

    # Tokenize input texts (batch processing)
    encoded = tokenizer(
        input_texts, return_tensors="np", padding=True, truncation=True, max_length=512
    )
    input_ids = encoded["input_ids"].astype(np.int64)
    attention_mask = encoded["attention_mask"].astype(np.int64)
    batch_size = input_ids.shape[0]

    pad_id = tokenizer.pad_token_id
    eos_id = tokenizer.eos_token_id

    # Use pad_token_id as the decoder start token, one per sequence so the decoder
    # batch dimension always matches the encoder batch.
    decoder_input_ids = np.full((batch_size, 1), pad_id, dtype=np.int64)
    finished = np.zeros(batch_size, dtype=bool)

    outputs = _run_session(input_ids, attention_mask, decoder_input_ids)

    # If the graph already emits integer token IDs (e.g. an export with generation
    # built in), decode them directly instead of running the greedy loop.
    if not np.issubdtype(outputs.dtype, np.floating):
        if outputs.ndim > 2:
            outputs = outputs.squeeze(0)  # Remove extra leading dimension if present
        return tokenizer.batch_decode(outputs, skip_special_tokens=True)

    for _ in range(max_new_tokens):
        # assumes logits are (batch, decoder_len, vocab) — TODO confirm against the export
        step_logits = outputs[:, -1, :]
        # Never generate padding (mirrors Marian's generation-time logit adjustment).
        step_logits[:, pad_id] = -np.inf
        next_tokens = step_logits.argmax(axis=-1).astype(np.int64)
        # Sequences that already produced EOS keep emitting padding, which
        # skip_special_tokens removes when decoding.
        next_tokens = np.where(finished, pad_id, next_tokens)
        decoder_input_ids = np.concatenate(
            [decoder_input_ids, next_tokens[:, None]], axis=1
        )
        finished |= next_tokens == eos_id
        if finished.all():
            break
        outputs = _run_session(input_ids, attention_mask, decoder_input_ids)

    # Decode the generated tokens (start token, EOS and padding are special tokens).
    return tokenizer.batch_decode(decoder_input_ids, skip_special_tokens=True)

def _translate_single(text):
    """Adapt translate_text to a single Gradio textbox.

    translate_text returns a list of strings; a "text" output component would
    display that list's repr (e.g. "['...']"), so unwrap the only element.
    """
    return translate_text([text])[0]


# Gradio interface
interface = gr.Interface(
    fn=_translate_single,
    inputs="text",
    outputs="text",
    title="Frenchizer Translation Model",
    description="Translate text with MarianMT ONNX model and encoding by batches.",
)

# Launch the interface
interface.launch()