import gradio as gr
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer, pipeline

# model.onnx is assumed to be a seq2seq export of the MarianMT model that
# takes input_ids, attention_mask and decoder_input_ids and returns logits.
MODEL_FILE = "./model.onnx"
session = ort.InferenceSession(MODEL_FILE)
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")

# Predefined labels for context detection
labels = [
    "general", "pharma", "legal", "technical", "UI", "medicine", "it", "marketing", 
    "e-commerce", "programming", "website", "html", "keywords", "food commerce", 
    "personal development", "literature", "poetry", "physics", "chemistry", "biology", 
    "theater", "finance", "sports", "education", "politics", "economics", "art", 
    "history", "music", "gaming", "aerospace", "engineering", "robotics", "travel", 
    "tourism", "healthcare", "psychology", "environment", "fashion", "design", 
    "real estate", "retail", "news", "entertainment", "social media", "automotive", 
    "machine learning", "cryptocurrency", "blockchain", "philosophy", "anthropology", 
    "archaeology", "data science"
]

# Zero-shot classification scores the input against every candidate label
# without task-specific training; the top-scoring label is the context.
context_pipeline = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

def detect_context(input_text):
    # Labels come back sorted by score, so the first one is the best match.
    result = context_pipeline(input_text, candidate_labels=labels)
    return result["labels"][0]

def gradio_predict(input_text):
    try:
        # Tokenize to NumPy arrays; the ONNX graph expects int64 tensors.
        tokenized_input = tokenizer(
            input_text, return_tensors="np", padding=True, truncation=True, max_length=512
        )
        input_ids = tokenized_input["input_ids"].astype(np.int64)
        attention_mask = tokenized_input["attention_mask"].astype(np.int64)

        # MarianMT tokenizers define no CLS token, so this falls back to the
        # pad token, which Marian models use as the decoder start token.
        decoder_start_token_id = tokenizer.cls_token_id or tokenizer.pad_token_id
        decoder_input_ids = np.array([[decoder_start_token_id]], dtype=np.int64)

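        # Greedy decoding: re-run the model each step and append the arg-max
        # token until EOS is produced or the 512-token cap is reached.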
        for _ in range(512):
            outputs = session.run(
                None,
                {
                    "input_ids": input_ids,
                    "attention_mask": attention_mask,
                    "decoder_input_ids": decoder_input_ids,
                }
            )

            logits = outputs[0]
            next_token_id = np.argmax(logits[:, -1, :], axis=-1).item()
            decoder_input_ids = np.concatenate(
                [decoder_input_ids, np.array([[next_token_id]], dtype=np.int64)], axis=1
            )

            if next_token_id == tokenizer.eos_token_id:
                break

        translated_text = tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)

        # Return a tuple matching the two output components of the interface.
        return translated_text, detect_context(input_text)

    except Exception as e:
        # Surface the error in the translation box and leave the context empty.
        return f"Error: {e}", ""

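# live=True re-runs both (fairly heavy) models on every input change;
# drop it if the app feels sluggish.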
gr.Interface(
    fn=gradio_predict,
    inputs=gr.Textbox(label="English text"),
    outputs=[gr.Textbox(label="French translation"), gr.Textbox(label="Detected context")],
    live=True,
).launch()