# space_2 / app.py  (Frenchizer HuggingFace Space, commit 60b3f0e)
# English -> French translation via a local ONNX Marian model,
# plus zero-shot topical context detection, served with Gradio.
import gradio as gr
import numpy as np
import onnxruntime as ort  # was missing: L12 referenced `ort` without importing it
from transformers import AutoTokenizer, pipeline

# Local ONNX export of the Helsinki-NLP Marian en->fr translation model.
MODEL_FILE = "./model.onnx"
session = ort.InferenceSession(MODEL_FILE)
# Tokenizer matching the exported translation model.
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
# Candidate domain labels handed to the zero-shot classifier; the first
# (highest-scoring) label becomes the detected "context" of the input text.
labels = [
    "general", "pharma", "legal", "technical", "UI", "medicine", "it",
    "marketing", "e-commerce", "programming", "website", "html", "keywords",
    "food commerce", "personal development", "literature", "poetry",
    "physics", "chemistry", "biology", "theater", "finance", "sports",
    "education", "politics", "economics", "art", "history", "music",
    "gaming", "aerospace", "engineering", "robotics", "travel", "tourism",
    "healthcare", "psychology", "environment", "fashion", "design",
    "real estate", "retail", "news", "entertainment", "social media",
    "automotive", "machine learning", "cryptocurrency", "blockchain",
    "philosophy", "anthropology", "archaeology", "data science",
]
# Context detection pipeline
# Zero-shot classifier; downloads facebook/bart-large-mnli at import time.
context_pipeline = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
def detect_context(input_text):
    """Return the best-matching domain label for *input_text*.

    Runs the zero-shot classifier over the module-level ``labels`` and
    picks the top-ranked candidate.
    """
    classification = context_pipeline(input_text, candidate_labels=labels)
    # The pipeline returns labels sorted by score, best first.
    return classification["labels"][0]
def gradio_predict(input_text):
    """Translate *input_text* to French and detect its topical context.

    Returns a ``(translation, context)`` tuple so the result maps onto the
    two text outputs declared in the Gradio interface (the original code
    returned a dict / a bare string, which does not fit a two-output
    Interface). On error the message goes in the first slot and the second
    is left empty, preserving the two-value shape.
    """
    try:
        # Tokenize once; inputs are truncated to the model's 512-token limit.
        tokenized_input = tokenizer(
            input_text, return_tensors="np", padding=True, truncation=True, max_length=512
        )
        input_ids = tokenized_input["input_ids"].astype(np.int64)
        attention_mask = tokenized_input["attention_mask"].astype(np.int64)

        # Marian tokenizers have no CLS token (cls_token_id is None), so the
        # `or` falls through to pad_token_id — the Marian decoder start token.
        decoder_start_token_id = tokenizer.cls_token_id or tokenizer.pad_token_id
        decoder_input_ids = np.array([[decoder_start_token_id]], dtype=np.int64)

        # Greedy decoding loop, capped at 512 generated tokens.
        for _ in range(512):
            outputs = session.run(
                None,
                {
                    "input_ids": input_ids,
                    "attention_mask": attention_mask,
                    "decoder_input_ids": decoder_input_ids,
                },
            )
            logits = outputs[0]
            # Highest-scoring token at the last decoder position.
            next_token_id = np.argmax(logits[:, -1, :], axis=-1).item()
            decoder_input_ids = np.concatenate(
                [decoder_input_ids, np.array([[next_token_id]], dtype=np.int64)], axis=1
            )
            if next_token_id == tokenizer.eos_token_id:
                break

        translated_text = tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)
        return translated_text, detect_context(input_text)
    except Exception as e:
        # UI boundary: surface the failure in the translation slot while
        # keeping the two-output shape the Interface expects.
        return f"Error: {str(e)}", ""
# Live-updating Gradio UI: one text input, two text outputs
# (translation and detected context).
demo = gr.Interface(
    fn=gradio_predict,
    inputs="text",
    outputs=["text", "text"],
    live=True,
)
demo.launch()