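"""Gradio app: detect the domain of an English text with a zero-shot
bart-large-mnli ONNX model, then translate the text to French with an
opus-mt-en-fr ONNX model."""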
import gradio as gr
import onnxruntime as ort
from transformers import AutoTokenizer
import numpy as np

# Paths to the exported ONNX models (expected next to app.py).
context_model_file = "./bart-large-mnli.onnx"
translation_model_file = "./model.onnx"

# Create inference sessions for both models (explicit CPU provider;
# adjust if a GPU build of onnxruntime is available).
context_session = ort.InferenceSession(context_model_file, providers=["CPUExecutionProvider"])
translation_session = ort.InferenceSession(translation_model_file, providers=["CPUExecutionProvider"])

# Load the tokenizers matching the context and translation models.
context_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")
translation_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")

# Candidate domains for zero-shot context detection.
labels = [
    'aerospace', 'agriculture', 'anatomy', 'anthropology', 'architecture',
    'art', 'automotive', 'astronomy', 'aviation', 'banking',
    'biotechnology', 'biology', 'blockchain', 'business',
    'chemistry', 'climate change', 'communication', 'computer science',
    'construction', 'consumer goods', 'cryptocurrency', 'cybersecurity',
    'dance', 'diplomacy', 'ecology', 'economics',
    'education', 'energy', 'engineering', 'entrepreneurship',
    'entertainment', 'ethics', 'fashion', 'finance',
    'film', 'fitness', 'food commerce', 'general',
    'gaming', 'geography', 'geology', 'graphic design',
    'healthcare', 'history', 'html', 'human resources',
    'immigration', 'innovation', 'journalism',
    'keywords', 'language', 'law enforcement', 'legal', 'logistics', 'literature',
    'machine learning', 'management', 'manufacturing', 'mathematics', 'media', 'military',
    'music', 'nanotechnology', 'nutrition', 'pharmaceuticals', 'photography',
    'psychology', 'public health', 'publishing', 'religion', 'renewable energy',
    'research', 'sales', 'science', 'social media', 'social work',
    'space exploration', 'sports', 'statistics', 'supply chain',
    'sustainability', 'telecommunications', 'transportation',
    'urban planning', 'veterinary medicine', 'virtual reality',
    'web development', 'writing', 'zoology'
]

def softmax_with_temperature(logits, temperature=1.0):
    # Shift by the max logit for numerical stability before exponentiating.
    scaled = (logits - np.max(logits)) / temperature
    exp_logits = np.exp(scaled)
    return exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)
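# Example: softmax_with_temperature(np.array([2.0, 0.0]), temperature=2.0)
# returns roughly [0.73, 0.27]; raising the temperature flattens the scores.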

def detect_context(input_text, temperature=2.0, top_n=3, score_threshold=0.05):
    # bart-large-mnli is an NLI model: each premise/hypothesis pair yields
    # three logits [contradiction, neutral, entailment]. Zero-shot
    # classification therefore needs one forward pass per candidate label,
    # ranking labels by their entailment logit (this assumes the ONNX export
    # preserves the three-way classification head).
    entailment_logits = []
    for label in labels:
        hypothesis = f"This text is about {label}."
        inputs = context_tokenizer(
            input_text, hypothesis,
            return_tensors="np", padding=True, truncation=True, max_length=512
        )
        outputs = context_session.run(None, {
            "input_ids": inputs["input_ids"].astype(np.int64),
            "attention_mask": inputs["attention_mask"].astype(np.int64),
        })
        entailment_logits.append(outputs[0][0, 2])  # index 2 = entailment
    # Normalize the entailment logits across labels.
    scores = softmax_with_temperature(np.array(entailment_logits), temperature=temperature)
    # Pair labels with scores and sort in descending order.
    sorted_labels = sorted(zip(labels, scores), key=lambda x: x[1], reverse=True)
    # Keep labels above the threshold, capped at top_n.
    top_contexts = [label for label, score in sorted_labels if score > score_threshold][:top_n]
    print(f"Selected contexts: {top_contexts}")  # Debugging
    return top_contexts if top_contexts else ["general"]
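# Illustrative usage (expected ranking, not a verified output):
# detect_context("The rocket lifted off from the launch pad at dawn.")
# should place 'aerospace' or 'space exploration' among the top contexts.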

def translate_text(input_text):
    # Tokenize the English source text.
    tokenized_input = translation_tokenizer(
        input_text, return_tensors="np",
        padding=True, truncation=True, max_length=512
    )
    input_ids = tokenized_input["input_ids"].astype(np.int64)
    attention_mask = tokenized_input["attention_mask"].astype(np.int64)
    # Marian models (opus-mt-*) use the pad token as the decoder start token.
    decoder_input_ids = np.array([[translation_tokenizer.pad_token_id]], dtype=np.int64)
    # Greedy decoding: append the most likely next token until EOS or 512 steps.
    for _ in range(512):
outputs = translation_session.run(
None,
{
"input_ids": input_ids,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
}
)
logits = outputs[0]
next_token_id = np.argmax(logits[:, -1, :], axis=-1).item()
decoder_input_ids = np.concatenate(
[decoder_input_ids, np.array([[next_token_id]], dtype=np.int64)], axis=1
)
if next_token_id == translation_tokenizer.eos_token_id:
break
return translation_tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)
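# Illustrative usage (exact wording depends on the exported model):
# translate_text("Hello, how are you?")  # e.g. "Bonjour, comment allez-vous ?"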

def process_request(input_text):
    # The detected context is currently informational only (it is printed
    # inside detect_context); the translation does not consume it yet.
    context = detect_context(input_text)
    translation = translate_text(input_text)
    return translation

# live=True re-runs the full pipeline on every keystroke; with one NLI pass
# per candidate label this is expensive, so a submit button (live=False) may
# be preferable.
gr.Interface(
    fn=process_request,
    inputs="text",
    outputs="text",
    live=True,
).launch()