Frenchizer committed on
Commit
60b3f0e
·
1 Parent(s): 2d236b6

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +74 -0
  2. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import gradio as gr
import numpy as np
import onnxruntime as ort  # FIX: was missing — `ort.InferenceSession(...)` below raised NameError
from transformers import AutoTokenizer, pipeline

# Local ONNX export of the Helsinki-NLP/opus-mt-en-fr translation model.
# NOTE(review): model.onnx must sit next to this script — confirm it is shipped
# with the Space/app deployment.
MODEL_FILE = "./model.onnx"
session = ort.InferenceSession(MODEL_FILE)
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")

# Predefined labels for context detection
labels = [
    "general", "pharma", "legal", "technical", "UI", "medicine", "it", "marketing",
    "e-commerce", "programming", "website", "html", "keywords", "food commerce",
    "personal development", "literature", "poetry", "physics", "chemistry", "biology",
    "theater", "finance", "sports", "education", "politics", "economics", "art",
    "history", "music", "gaming", "aerospace", "engineering", "robotics", "travel",
    "tourism", "healthcare", "psychology", "environment", "fashion", "design",
    "real estate", "retail", "news", "entertainment", "social media", "automotive",
    "machine learning", "cryptocurrency", "blockchain", "philosophy", "anthropology",
    "archaeology", "data science"
]

# Context detection pipeline: zero-shot classification of the input text
# against the `labels` list above.
context_pipeline = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
def detect_context(input_text):
    """Classify *input_text* against the predefined `labels` and return the top label."""
    classification = context_pipeline(input_text, candidate_labels=labels)
    top_label = classification["labels"][0]
    return top_label
def gradio_predict(input_text):
    """Translate English text to French with the ONNX model and detect its domain.

    Returns
    -------
    tuple[str, str]
        (translation, context_label) — two values, in the same order as the two
        "text" outputs declared on the Gradio interface.  On failure both slots
        carry the error message.  (FIX: the previous version returned a dict on
        success and a bare string on error, which does not match the interface's
        two positional text outputs.)
    """
    try:
        tokenized_input = tokenizer(
            input_text, return_tensors="np", padding=True, truncation=True, max_length=512
        )
        input_ids = tokenized_input["input_ids"].astype(np.int64)
        attention_mask = tokenized_input["attention_mask"].astype(np.int64)

        # Marian tokenizers have no CLS token, so this resolves to pad_token_id,
        # which Marian models use as the decoder start token —
        # NOTE(review): confirm against the exported model's generation config.
        decoder_start_token_id = tokenizer.cls_token_id or tokenizer.pad_token_id
        decoder_input_ids = np.array([[decoder_start_token_id]], dtype=np.int64)

        # Greedy autoregressive decoding, capped at 512 generated tokens.
        for _ in range(512):
            outputs = session.run(
                None,
                {
                    "input_ids": input_ids,
                    "attention_mask": attention_mask,
                    "decoder_input_ids": decoder_input_ids,
                }
            )

            # outputs[0] is the decoder logits; take the argmax of the last step.
            logits = outputs[0]
            next_token_id = np.argmax(logits[:, -1, :], axis=-1).item()
            decoder_input_ids = np.concatenate(
                [decoder_input_ids, np.array([[next_token_id]], dtype=np.int64)], axis=1
            )

            if next_token_id == tokenizer.eos_token_id:
                break

        translated_text = tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)

        return translated_text, detect_context(input_text)

    except Exception as e:
        # Surface the failure in the UI instead of crashing the app; mirror the
        # message into both output boxes so each declared output gets a value.
        error_message = f"Error: {str(e)}"
        return error_message, error_message
# Wire the translator into a simple two-output Gradio UI.
# `live=True` re-runs the prediction as the user types.
demo = gr.Interface(
    fn=gradio_predict,
    inputs="text",
    outputs=["text", "text"],
    live=True,
)
demo.launch()
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ onnxruntime
4
+ transformers
5
+ numpy
6
+ pydantic
7
+ requests
8
+ gradio
9
+ sentencepiece
10
+ sacremoses