# Spaces: Build error
import logging

import gradio as gr
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer
# ONNX exports of the two models, resolved relative to the working directory
# the app is launched from.
context_model_file = "./bart-large-mnli.onnx"
translation_model_file = "./model.onnx"

# Since ONNX Runtime 1.9, builds that ship execution providers beyond CPU
# (e.g. onnxruntime-gpu) raise ValueError unless `providers` is passed
# explicitly. Prefer CUDA when this build offers it, otherwise fall back to
# CPU; on the plain CPU package this resolves to CPU only.
_available_providers = ort.get_available_providers()
_providers = [
    provider
    for provider in ("CUDAExecutionProvider", "CPUExecutionProvider")
    if provider in _available_providers
]

# Create inference sessions for both models.
context_session = ort.InferenceSession(context_model_file, providers=_providers)
translation_session = ort.InferenceSession(translation_model_file, providers=_providers)

# Tokenizers for the context (topic) model and the translation model.
# NOTE(review): these must match the checkpoints the ONNX files were exported
# from — assumed to be facebook/bart-large-mnli and Helsinki-NLP/opus-mt-en-fr;
# confirm against the export.
context_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")
translation_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
# Candidate topic labels scored by detect_context(). Their order matters:
# detect_context() uses a stable sort, so tied scores keep this order.
labels = [
    "aerospace", "agriculture", "anatomy", "anthropology", "architecture", "art",
    "automotive", "astronomy", "aviation",
    "banking", "biotechnology", "biology", "blockchain", "business",
    "chemistry", "climate change", "communication", "computer science",
    "construction", "consumer goods", "cryptocurrency", "cybersecurity",
    "dance", "diplomacy",
    "ecology", "economics", "education", "energy", "engineering",
    "entrepreneurship", "entertainment", "ethics",
    "fashion", "finance", "film", "fitness", "food commerce",
    "general", "gaming", "geography", "geology", "graphic design",
    "healthcare", "history", "html", "human resources",
    "immigration", "innovation",
    "journalism",
    "keywords",
    "language", "law enforcement", "legal", "logistics", "literature",
    "machine learning", "management", "manufacturing", "mathematics", "media",
    "military", "music",
    "nanotechnology", "nutrition",
    "pharmaceuticals", "photography", "psychology", "public health", "publishing",
    "religion", "renewable energy", "research",
    "sales", "science", "social media", "social work", "space exploration",
    "sports", "statistics", "supply chain", "sustainability",
    "telecommunications", "transportation",
    "urban planning",
    "veterinary medicine", "virtual reality",
    "web development", "writing",
    "zoology",
]
def softmax_with_temperature(logits, temperature=1.0, axis=-1):
    """Return ``softmax(logits / temperature)`` along *axis*.

    Args:
        logits: Array-like of raw scores (any shape; lists are accepted).
        temperature: Positive scale factor. Values > 1 flatten the
            distribution; values < 1 sharpen it.
        axis: Axis to normalise over (default: last axis).

    Returns:
        A NumPy array of the same shape whose entries along *axis* sum to 1.

    Raises:
        ValueError: If *temperature* is not strictly positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature!r}")
    scaled = np.asarray(logits) / temperature
    # Subtracting the max leaves the result mathematically unchanged but keeps
    # np.exp from overflowing to inf (which would otherwise produce nan).
    scaled = scaled - np.max(scaled, axis=axis, keepdims=True)
    exp_scaled = np.exp(scaled)
    return exp_scaled / np.sum(exp_scaled, axis=axis, keepdims=True)
def detect_context(input_text, temperature=2.0, top_n=3, score_threshold=0.05,
                   hypothesis_template="This example is {}.", entailment_index=2,
                   batch_size=16):
    """Return up to *top_n* labels from ``labels`` that best describe *input_text*.

    The context model is an NLI classifier: for a (premise, hypothesis) pair
    it emits one logit per NLI class (assumed contradiction / neutral /
    entailment, per facebook/bart-large-mnli's published config). It does
    not emit one logit per topic. Taking a single output row and zipping it
    with ``labels`` would therefore only ever rank the first few labels.

    This function performs zero-shot classification instead. It pairs the
    text with one hypothesis per label, keeps each pair's entailment logit,
    and applies a temperature softmax across labels.

    Args:
        input_text: Text to classify; used as the NLI premise.
        temperature: Softmax temperature applied across labels.
        top_n: Maximum number of labels returned.
        score_threshold: Labels whose probability is not above this are dropped.
        hypothesis_template: ``str.format`` template producing one hypothesis
            per label.
        entailment_index: Column of the model output holding the entailment
            logit. The default of 2 assumes the bart-large-mnli label order;
            confirm if the ONNX file was exported from another checkpoint.
        batch_size: Number of (premise, hypothesis) pairs per ONNX call.
            This assumes the export has a dynamic batch axis; use 1 if it was
            exported with a fixed batch size of 1.

    Returns:
        A non-empty list of labels, best first. Returns ``["general"]`` when
        the input is blank or no label clears *score_threshold*.
    """
    if not input_text or not input_text.strip():
        return ["general"]
    logger = logging.getLogger(__name__)

    hypotheses = [hypothesis_template.format(label) for label in labels]
    step = max(1, int(batch_size))
    entailment_chunks = []
    for start in range(0, len(hypotheses), step):
        batch = hypotheses[start:start + step]
        encoded = context_tokenizer(
            [input_text] * len(batch),
            text_pair=batch,
            return_tensors="np",
            padding=True,
            truncation="only_first",  # trim the premise, never the hypothesis
            max_length=512,
        )
        outputs = context_session.run(None, {
            "input_ids": encoded["input_ids"].astype(np.int64),
            "attention_mask": encoded["attention_mask"].astype(np.int64),
        })
        # outputs[0] assumed to be (batch, num_nli_classes) logits — TODO confirm.
        entailment_chunks.append(outputs[0][:, entailment_index])
    entailment_logits = np.concatenate(entailment_chunks)

    scores = softmax_with_temperature(entailment_logits, temperature=temperature)
    # sorted() is stable, so ties keep the order of `labels`.
    ranked = sorted(zip(labels, scores.tolist()), key=lambda pair: pair[1], reverse=True)
    logger.debug("All scores: %s", ranked)

    top_contexts = [label for label, score in ranked if score > score_threshold][:top_n]
    logger.debug("Selected contexts: %s", top_contexts)
    return top_contexts or ["general"]
def translate_text(input_text, max_new_tokens=512):
    """Translate *input_text* with the ONNX translation model using greedy decoding.

    Args:
        input_text: Source text. Blank or ``None`` input returns ``""``
            without running the model.
        max_new_tokens: Upper bound on decoding steps. Decoding stops earlier
            once the tokenizer's EOS token is produced.

    Returns:
        The decoded translation with special tokens stripped.
    """
    if not input_text or not input_text.strip():
        return ""

    encoded = translation_tokenizer(
        input_text,
        return_tensors="np",
        padding=True,
        truncation=True,
        max_length=512,
    )
    input_ids = encoded["input_ids"].astype(np.int64)
    attention_mask = encoded["attention_mask"].astype(np.int64)

    pad_id = translation_tokenizer.pad_token_id
    eos_id = translation_tokenizer.eos_token_id
    # Start from CLS if the tokenizer has one, else PAD. The Marian tokenizer is
    # expected to have no CLS token, so this is PAD — confirm if the checkpoint
    # changes. Compare against None: `cls_token_id or pad_id` would wrongly
    # discard a valid token id of 0.
    start_id = translation_tokenizer.cls_token_id
    if start_id is None:
        start_id = pad_id
    decoder_input_ids = np.array([[start_id]], dtype=np.int64)

    # NOTE(review): every step feeds the full decoder prefix and passes no
    # cached state, so total cost grows quadratically with the output length.
    for _ in range(max_new_tokens):
        logits = translation_session.run(
            None,
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "decoder_input_ids": decoder_input_ids,
            },
        )[0]
        next_logits = logits[0, -1, :].copy()
        if pad_id is not None and pad_id != eos_id:
            # Never emit PAD. transformers' Marian generation blocks it too
            # (verify for your version). An emitted PAD is silently dropped by
            # skip_special_tokens and can keep the loop running to the cap.
            next_logits[pad_id] = -np.inf
        next_token_id = int(np.argmax(next_logits))
        decoder_input_ids = np.concatenate(
            [decoder_input_ids, np.array([[next_token_id]], dtype=np.int64)], axis=1
        )
        if next_token_id == eos_id:
            break

    return translation_tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)
def process_request(input_text):
    """Gradio callback: return the translation of *input_text*.

    Topic detection (``detect_context``) previously ran here as well, but its
    result was never used. It was removed because it costs a full
    context-model inference on every keystroke (the interface is
    ``live=True``). Re-add ``detect_context(input_text)`` once the
    translation step actually consumes the detected context.

    Args:
        input_text: Text from the Gradio textbox; may be empty or ``None``.

    Returns:
        The translated string, or ``""`` for blank input.
    """
    if not input_text or not input_text.strip():
        return ""
    return translate_text(input_text)
# Build the UI at import time and bind it to `demo`, the name Gradio's reload
# mode (`gradio app.py`) looks for. Only start the server when this file is
# executed as a script, so importing the module has no launch side effect.
# NOTE(review): live=True re-runs the full translation on every keystroke;
# consider live=False if latency or CPU load is a problem.
demo = gr.Interface(
    fn=process_request,
    inputs="text",
    outputs="text",
    live=True,
)

if __name__ == "__main__":
    demo.launch()