Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModel | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| import torch | |
| import numpy as np | |
| from gradio_client import Client | |
| from functools import lru_cache | |
# Cache the model and tokenizer so repeated calls reuse the loaded objects.
@lru_cache(maxsize=1)
def load_model_and_tokenizer():
    """Load the tokenizer and encoder model from ``./all-MiniLM-L6-v2``.

    The result is memoized with ``lru_cache``: the first call reads the
    pretrained files, and every later call returns the same
    ``(tokenizer, model)`` pair without loading again.

    Returns:
        tuple: ``(tokenizer, model)`` as produced by
        ``AutoTokenizer.from_pretrained`` and ``AutoModel.from_pretrained``.
    """
    model_name = "./all-MiniLM-L6-v2"  # Replace with your Space and model path
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    loaded_model = AutoModel.from_pretrained(model_name)
    return loaded_tokenizer, loaded_model


# Load the model and tokenizer once at import time; both are used as
# module-level globals by the embedding functions.
tokenizer, model = load_model_and_tokenizer()
# Candidate context labels. Each label is embedded once and later compared
# against the embedding of the input text; order matters because scores are
# paired with labels by position.
labels = [
    "aerospace",
    "anatomy",
    "anthropology",
    "art",
    "automotive",
    "blockchain",
    "biology",
    "chemistry",
    "cryptocurrency",
    "data science",
    "design",
    "e-commerce",
    "education",
    "engineering",
    "entertainment",
    "environment",
    "fashion",
    "finance",
    "food commerce",
    "general",
    "gaming",
    "healthcare",
    "history",
    "html",
    "information technology",
    "IT",
    "keywords",
    "legal",
    "literature",
    "machine learning",
    "marketing",
    "medicine",
    "music",
    "personal development",
    "philosophy",
    "physics",
    "politics",
    "poetry",
    "programming",
    "real estate",
    "retail",
    "robotics",
    "slang",
    "social media",
    "speech",
    "sports",
    "sustained",
    "technical",
    "theater",
    "tourism",
    "travel",
]
def precompute_label_embeddings():
    """Embed every entry of ``labels`` with the module-level model.

    The labels are tokenized as one padded batch, so shorter labels carry
    padding tokens. Pooling is therefore a mean weighted by the attention
    mask: only real tokens contribute, and padding does not skew the
    embedding of short labels.

    Returns:
        numpy.ndarray: one embedding row per label, in the order of
        ``labels``.
    """
    inputs = tokenizer(labels, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    token_embeddings = outputs.last_hidden_state
    # Broadcast the mask over the hidden dimension: 1 for real tokens, 0 for padding.
    mask = inputs["attention_mask"].unsqueeze(-1).to(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(dim=1)
    # Clamp guards against division by zero for a fully masked row.
    token_counts = mask.sum(dim=1).clamp(min=1e-9)
    return (summed / token_counts).numpy()  # Masked mean pooling


# Computed once at import time and reused for every request.
label_embeddings = precompute_label_embeddings()
# Turn a vector of raw scores into a probability distribution.
def softmax(x):
    """Return the softmax of ``x``, computed over all of its elements.

    The maximum is subtracted before exponentiating so that large scores
    cannot overflow; this shift does not change the result.
    """
    shifted = x - np.max(x)
    weights = np.exp(shifted)
    return weights / weights.sum()
# Rank the candidate labels by similarity to the input text.
def detect_context(input_text, top_n=3):
    """Return the ``top_n`` labels most similar to ``input_text``.

    The text is embedded by mean-pooling the model's last hidden state,
    compared to every precomputed label embedding with cosine similarity,
    and the similarities are passed through ``softmax``.

    Args:
        input_text: Text to classify.
        top_n: Number of highest-scoring labels to return.

    Returns:
        list: ``(label, probability)`` tuples sorted by descending
        probability, at most ``top_n`` long.
    """
    encoded = tokenizer([input_text], padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        hidden_states = model(**encoded).last_hidden_state
    query_embedding = hidden_states.mean(dim=1).numpy()

    # Only one query row, so take row 0 of the similarity matrix.
    scores = softmax(cosine_similarity(query_embedding, label_embeddings)[0])

    ranked = sorted(zip(labels, scores), key=lambda pair: pair[1], reverse=True)
    return ranked[:top_n]
# Client for the remote Space that performs the translation.
translation_client = Client("Frenchizer/space_7")


def translate_text(input_text):
    """Send ``input_text`` to the translation Space and return its prediction."""
    translated = translation_client.predict(input_text)
    return translated
def process_request(input_text):
    """Translate ``input_text`` and detect its most likely contexts.

    The translation runs first, then context detection on the original
    (untranslated) text. The detected contexts are also printed to stdout.

    Returns:
        tuple: ``(translation, contexts)`` where ``contexts`` is the list
        returned by ``detect_context``.
    """
    translated = translate_text(input_text)
    contexts = detect_context(input_text)
    print("Detected Contexts (Top 3):", contexts)
    return translated, contexts
# Callback wired into the Gradio UI.
def gradio_interface(input_text):
    """Build the text shown to the user for ``input_text``.

    Returns:
        str: the translation followed by one bullet per detected context
        with its confidence to four decimals, with surrounding whitespace
        stripped.
    """
    translation, contexts = process_request(input_text)
    header = f"Translation: {translation}\n\nDetected Contexts (Top 3):\n"
    bullets = "".join(
        f"- {label} (confidence: {prob:.4f})\n" for label, prob in contexts
    )
    return (header + bullets).strip()
# Build the text-in / text-out Gradio app around gradio_interface and start it.
interface = gr.Interface(
    title="Frenchizer",
    description="Translate text from English to French with context detection.",
    fn=gradio_interface,
    inputs="text",
    outputs="text",
)

interface.launch()