Frenchizer commited on
Commit
65ed74c
·
verified ·
1 Parent(s): 8d5320c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -22
app.py CHANGED
@@ -1,37 +1,50 @@
1
  import numpy as np
2
  import onnxruntime as ort
3
- import torch
4
- from transformers import MarianMTModel, MarianTokenizer
5
  import gradio as gr
6
 
7
- # Load the MarianMT model and tokenizer from the local folder
8
- model_path = "./model.onnx" # Path to the folder containing the model files
9
- tokenizer = MarianTokenizer.from_pretrained(model_name)
10
- decoder_model = MarianMTModel.from_pretrained(model_name).get_decoder()
11
 
12
- # Load the ONNX encoder
13
- encoder_session = ort.InferenceSession("./onnx_model/encoder.onnx")
 
14
 
15
- def translate_text(input_text):
16
- # Tokenize input text
17
  tokenized_input = tokenizer(
18
- input_text, return_tensors="pt", padding=True, truncation=True, max_length=512
19
  )
20
  input_ids = tokenized_input["input_ids"]
21
  attention_mask = tokenized_input["attention_mask"]
22
 
23
- # Generate translation using the model
24
- with torch.no_grad():
25
- outputs = model.generate(
26
- input_ids=input_ids,
27
- attention_mask=attention_mask,
28
- max_length=512, # Maximum length of the output
29
- num_beams=5, # Use beam search for better translations
30
- early_stopping=True, # Stop generation when the model predicts the end-of-sequence token
31
- )
 
 
 
 
32
 
33
  # Decode the output tokens
34
- translated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
35
- return translated_text
 
 
 
 
 
 
 
 
 
36
 
 
37
  interface.launch()
 
1
  import numpy as np
2
  import onnxruntime as ort
3
+ from transformers import MarianTokenizer
 
4
  import gradio as gr
5
 
6
+ # Load the tokenizer from the local folder
7
+ model_path = "./onnx_model" # Path to the folder containing the model files
8
+ tokenizer = MarianTokenizer.from_pretrained(model_path)
 
9
 
10
+ # Load the ONNX model
11
+ onnx_model_path = "./model.onnx"
12
+ session = ort.InferenceSession(onnx_model_path)
13
 
14
+ def translate_text(input_texts):
15
+ # Tokenize input texts (batch processing)
16
  tokenized_input = tokenizer(
17
+ input_texts, return_tensors="np", padding=True, truncation=True, max_length=512
18
  )
19
  input_ids = tokenized_input["input_ids"]
20
  attention_mask = tokenized_input["attention_mask"]
21
 
22
+ decoder_start_token_id = translation_tokenizer.cls_token_id or translation_tokenizer.pad_token_id
23
+ decoder_input_ids = np.array([[decoder_start_token_id]], dtype=np.int64)
24
+
25
+ # Prepare inputs for ONNX model
26
+ ort_inputs = {
27
+ "input_ids": input_ids.astype(np.int64),
28
+ "attention_mask": attention_mask.astype(np.int64),
29
+ "decoder_input_ids": decoder_input_ids,
30
+ }
31
+
32
+ # Run inference using the ONNX model
33
+ ort_outputs = session.run(None, ort_inputs)
34
+ output_ids = ort_outputs[0] # Get the output token IDs
35
 
36
  # Decode the output tokens
37
+ translated_texts = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
38
+ return translated_texts
39
+
40
+ # Gradio interface
41
+ interface = gr.Interface(
42
+ fn=translate_text,
43
+ inputs=gr.inputs.Textbox(lines=2, placeholder="Enter text to translate..."),
44
+ outputs="text",
45
+ title="MarianMT Translation",
46
+ description="Translate text using MarianMT model with ONNX runtime.",
47
+ )
48
 
49
+ # Launch the interface
50
  interface.launch()