import gradio as gr
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity
import torch
from gradio_client import Client
from functools import lru_cache

# Cache the model and tokenizer using lru_cache
@lru_cache(maxsize=1)
def load_model_and_tokenizer():
    model_name = "./all-MiniLM-L6-v2"  # Replace with your Space and model path
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    return tokenizer, model

# Load the model and tokenizer
tokenizer, model = load_model_and_tokenizer()

# Context labels to score against the input text
labels = [
    "aerospace", "anatomy", "anthropology", "art", "automotive",
    "blockchain", "biology", "chemistry", "cryptocurrency", "data science",
    "design", "e-commerce", "education", "engineering", "entertainment",
    "environment", "fashion", "finance", "food commerce", "general",
    "gaming", "healthcare", "history", "html", "information technology",
    "IT", "keywords", "legal", "literature", "machine learning",
    "marketing", "medicine", "music", "personal development", "philosophy",
    "physics", "politics", "poetry", "programming", "real estate",
    "retail", "robotics", "slang", "social media", "speech",
    "sports", "sustainability", "technical", "theater", "tourism", "travel"
]

def mean_pooling(last_hidden_state, attention_mask):
    # Mask-aware mean pooling: averaging only over real tokens keeps padding
    # tokens from diluting the sentence embedding in padded batches.
    mask = attention_mask.unsqueeze(-1).float()
    return (last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

# Precompute label embeddings once at startup
@lru_cache(maxsize=1)
def precompute_label_embeddings():
    inputs = tokenizer(labels, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    return mean_pooling(outputs.last_hidden_state, inputs["attention_mask"]).numpy()

label_embeddings = precompute_label_embeddings()

# Detect the most likely context(s) for the input text
def detect_context(input_text, fallback_threshold=0.8, max_results=3):
    # Encode the input text
    inputs = tokenizer([input_text], padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    input_embedding = mean_pooling(outputs.last_hidden_state, inputs["attention_mask"]).numpy()

    # Compute cosine similarities against every label embedding
    similarities = cosine_similarity(input_embedding, label_embeddings)[0]

    # Keep the labels whose similarity clears the threshold, best first
    fallback_labels = [
        (labels[i], score)
        for i, score in enumerate(similarities)
        if score >= fallback_threshold
    ]
    fallback_labels = sorted(fallback_labels, key=lambda x: x[1], reverse=True)[:max_results]
    return fallback_labels

# Translation client
translation_client = Client("Frenchizer/space_3")

def translate_text(input_text, context="general"):
    # The context is accepted for future use but not yet forwarded:
    # Frenchizer/space_3 currently receives only the raw input text.
    return translation_client.predict(input_text)

def process_request(input_text):
    # Step 1: Get the general translation first
    general_translation = translate_text(input_text, context="general")

    # Step 2: Detect the context of the input text
    context_results = detect_context(input_text)

    # Step 3: Generate additional translations for high-confidence contexts
    additional_translations = {}
    for context, score in context_results:
        if context != "general":
            additional_translations[context] = translate_text(input_text, context=context)

    # Return the general translation plus any context-specific translations
    return general_translation, additional_translations

# Gradio handler that flattens all translations into a single text output
def gradio_interface(input_text):
    general_translation, additional_translations = process_request(input_text)
    outputs = f"General Translation: {general_translation}\n\n"
    for context, translation in additional_translations.items():
        outputs += f"Context ({context}): {translation}\n\n"
    return outputs.strip()
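
# A minimal smoke test for the detector, runnable before wiring up the UI.
# The sample sentence is hypothetical; scores depend on the local model
# weights, and the default fallback_threshold of 0.8 is strict enough that
# many inputs return no labels, so a lower threshold is used here.
if __name__ == "__main__":
    print(detect_context("Deploy the smart contract to the testnet",
                         fallback_threshold=0.5))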
# Create the Gradio interface
interface = gr.Interface(
    fn=gradio_interface,
    inputs="text",
    outputs="text",
    title="Frenchizer",
    description="Translate text from English to French with optimized context detection.",
)

interface.launch()
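
# Once running, the app can also be driven programmatically with
# gradio_client, mirroring how translation_client wraps Frenchizer/space_3
# above. The URL below is a placeholder for wherever this interface is hosted:
#
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860")
#   print(client.predict("Hello world"))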