import numpy as np
import onnxruntime as ort
import torch
from transformers import MarianMTModel, MarianTokenizer
import gradio as gr
# Load the MarianMT model and tokenizer.
# NOTE: the checkpoint name below is an assumption (English -> French);
# point it at the local model folder instead if the weights are bundled with this Space.
model_name = "Helsinki-NLP/opus-mt-en-fr"
tokenizer = MarianTokenizer.from_pretrained(model_name)
model = MarianMTModel.from_pretrained(model_name)

# Load the ONNX-exported encoder. Generation in translate_text below still uses
# the full PyTorch model; the ONNX session is available for an encoder-only path.
encoder_session = ort.InferenceSession("./onnx_model/encoder.onnx")
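# Sketch: how the ONNX encoder session could be run directly with onnxruntime.
# This is an assumption about the export; the input names ("input_ids",
# "attention_mask") must match the names used when the encoder was converted.
def encode_with_onnx(input_ids, attention_mask):
    ort_inputs = {
        "input_ids": input_ids.cpu().numpy().astype(np.int64),
        "attention_mask": attention_mask.cpu().numpy().astype(np.int64),
    }
    # Returns the encoder's last hidden state as a NumPy array.
    return encoder_session.run(None, ort_inputs)[0]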
def translate_text(input_text):
    # Tokenize the input text
    tokenized_input = tokenizer(
        input_text, return_tensors="pt", padding=True, truncation=True, max_length=512
    )
    input_ids = tokenized_input["input_ids"]
    attention_mask = tokenized_input["attention_mask"]

    # Generate the translation
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=512,       # maximum length of the output
            num_beams=5,          # beam search for better translations
            early_stopping=True,  # stop when the end-of-sequence token is predicted
        )

    # Decode the output tokens
    translated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return translated_text
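
# Example usage (assumes the checkpoint above translates English to French):
#   translate_text("Hello, how are you?")
#   -> a French sentence along the lines of "Bonjour, comment allez-vous ?"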
# Minimal text-in / text-out Gradio interface exposing the translator.
interface = gr.Interface(fn=translate_text, inputs="text", outputs="text")
interface.launch()