Frenchizer committed
Commit f80fc89 · verified · 1 parent: 0760540

Update app.py

Files changed (1)
  1. app.py +15 -12
app.py CHANGED
@@ -4,16 +4,16 @@ from transformers import MarianTokenizer
 import gradio as gr
 
 # Load the tokenizer from the local folder
-model_path = "./onnx_model" # Path to the folder containing the tokenizer files
-tokenizer = MarianTokenizer.from_pretrained(model_path)
+tokenizer_path = "./onnx_model" # Path to the local tokenizer folder
+tokenizer = MarianTokenizer.from_pretrained(tokenizer_path)
 
 # Load the ONNX model
 onnx_model_path = "./model.onnx"
 session = ort.InferenceSession(onnx_model_path)
 
-def translate_text(input_texts, max_length=512):
+def translate(texts, max_length=512):
     # Tokenize the input texts
-    inputs = tokenizer(input_texts, return_tensors="np", padding=True, truncation=True, max_length=max_length)
+    inputs = tokenizer(texts, return_tensors="np", padding=True, truncation=True, max_length=max_length)
     input_ids = inputs["input_ids"].astype(np.int64)
     attention_mask = inputs["attention_mask"].astype(np.int64)
 
@@ -24,7 +24,7 @@ def translate_text(input_texts, max_length=512):
     # Generate output tokens iteratively
     for _ in range(max_length):
         # Run the ONNX model
-        ort_outputs = session.run(
+        onnx_outputs = session.run(
             None,
             {
                 "input_ids": input_ids,
@@ -34,7 +34,7 @@ def translate_text(input_texts, max_length=512):
         )
 
         # Get the next token logits (output of the ONNX model)
-        next_token_logits = ort_outputs[0][:, -1, :] # Shape: (batch_size, vocab_size)
+        next_token_logits = onnx_outputs[0][:, -1, :] # Shape: (batch_size, vocab_size)
 
         # Greedy decoding: select the token with the highest probability
         next_tokens = np.argmax(next_token_logits, axis=-1) # Shape: (batch_size,)
@@ -51,17 +51,20 @@ def translate_text(input_texts, max_length=512):
     return translations
 
 # Gradio interface
-def gradio_translate(input_texts):
-    translations = translate_text(input_texts)
-    return translations
+def gradio_translate(input_text):
+    # Split the input text into lines (assuming one sentence per line)
+    texts = input_text.strip().split("\n")
+    translations = translate(texts)
+    # Join the translations into a single string with line breaks
+    return "\n".join(translations)
 
 # Create the Gradio interface
 interface = gr.Interface(
     fn=gradio_translate,
-    inputs=gr.Textbox(lines=2, placeholder="Enter text to translate...", label="Input Text"),
-    outputs=gr.Textbox(label="Translated Text"),
+    inputs=gr.Textbox(lines=5, placeholder="Enter text to translate...", label="Input Text"),
+    outputs=gr.Textbox(lines=5, label="Translated Text"),
     title="ONNX English to French Translation",
-    description="Translate English text to French using a MarianMT ONNX model.",
+    description="Translate English text to French using an ONNX model.",
 )
 
 # Launch the Gradio app
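
The user-visible change is in the Gradio wrapper: the textbox contents are now split into one sentence per line, batch-translated through the renamed translate() helper, and rejoined with newlines. A minimal, self-contained sketch of that split/translate/join pattern follows; translate_batch is a hypothetical stand-in for the app's ONNX-backed translate() and is not part of this commit.

# Sketch of the new per-line behavior; translate_batch is a placeholder for
# app.py's translate(), which actually runs the MarianTokenizer + ONNX session.
def translate_batch(texts):
    return ["[fr] " + t for t in texts]  # placeholder "translations"

def gradio_translate(input_text):
    # One sentence per line, as the updated app assumes
    texts = input_text.strip().split("\n")
    translations = translate_batch(texts)
    # One translation per input line in the output textbox
    return "\n".join(translations)

print(gradio_translate("Hello, how are you?\nThe weather is nice today."))
# [fr] Hello, how are you?
# [fr] The weather is nice today.

Swapping the stub for the real translate() reproduces the app's behavior without changing the wrapper itself.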