Spaces: Build error
import gradio as gr
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer, pipeline

# Load the exported ONNX translation model and the matching Marian tokenizer
MODEL_FILE = "./model.onnx"
session = ort.InferenceSession(MODEL_FILE)
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
# Predefined labels for context detection
labels = [
    "general", "pharma", "legal", "technical", "UI", "medicine", "it", "marketing",
    "e-commerce", "programming", "website", "html", "keywords", "food commerce",
    "personal development", "literature", "poetry", "physics", "chemistry", "biology",
    "theater", "finance", "sports", "education", "politics", "economics", "art",
    "history", "music", "gaming", "aerospace", "engineering", "robotics", "travel",
    "tourism", "healthcare", "psychology", "environment", "fashion", "design",
    "real estate", "retail", "news", "entertainment", "social media", "automotive",
    "machine learning", "cryptocurrency", "blockchain", "philosophy", "anthropology",
    "archaeology", "data science"
]
# Context detection pipeline
context_pipeline = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

def detect_context(input_text):
    result = context_pipeline(input_text, candidate_labels=labels)
    return result["labels"][0]
def gradio_predict(input_text):
    try:
        tokenized_input = tokenizer(
            input_text, return_tensors="np", padding=True, truncation=True, max_length=512
        )
        input_ids = tokenized_input["input_ids"].astype(np.int64)
        attention_mask = tokenized_input["attention_mask"].astype(np.int64)

        # Marian models have no CLS token, so fall back to the pad token as the
        # decoder start token
        decoder_start_token_id = tokenizer.cls_token_id or tokenizer.pad_token_id
        decoder_input_ids = np.array([[decoder_start_token_id]], dtype=np.int64)

        # Greedy decoding, one token at a time, up to 512 tokens
        for _ in range(512):
            outputs = session.run(
                None,
                {
                    "input_ids": input_ids,
                    "attention_mask": attention_mask,
                    "decoder_input_ids": decoder_input_ids,
                },
            )
            logits = outputs[0]
            next_token_id = np.argmax(logits[:, -1, :], axis=-1).item()
            decoder_input_ids = np.concatenate(
                [decoder_input_ids, np.array([[next_token_id]], dtype=np.int64)], axis=1
            )
            if next_token_id == tokenizer.eos_token_id:
                break

        translated_text = tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)
        # The interface declares two text outputs, so return a tuple rather than a dict
        return translated_text, detect_context(input_text)
    except Exception as e:
        return f"Error: {str(e)}", ""
gr.Interface(
    fn=gradio_predict,
    inputs="text",
    outputs=["text", "text"],
    live=True,
).launch()
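
The build error is most likely tied to the missing onnxruntime import (fixed above) and to dependencies that are not installed in the Space: gradio, transformers, numpy, onnxruntime, torch (needed by the zero-shot pipeline), and sentencepiece (needed by the Marian tokenizer) would all have to be listed in the Space's requirements.txt. Separately, the input names passed to session.run ("input_ids", "attention_mask", "decoder_input_ids") are an assumption about how model.onnx was exported; a minimal sketch to check them against the actual graph, assuming model.onnx sits next to the script:

# Sanity check: print the input and output names the exported graph expects,
# so they can be matched against the keys passed to session.run above.
import onnxruntime as ort

session = ort.InferenceSession("./model.onnx")
for inp in session.get_inputs():
    print("input:", inp.name, inp.shape, inp.type)
for out in session.get_outputs():
    print("output:", out.name, out.shape, out.type)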