Frenchizer committed on
Commit
828614b
·
verified ·
1 Parent(s): bf1807a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -22
app.py CHANGED
@@ -1,10 +1,13 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModel
 
3
  import torch
 
4
  from gradio_client import Client
5
- from functools import lru_cache
6
 
7
  # Cache the model and tokenizer using lru_cache
 
 
8
  @lru_cache(maxsize=1)
9
  def load_model_and_tokenizer():
10
  model_name = "./all-MiniLM-L6-v2" # Replace with your Space and model path
@@ -15,41 +18,85 @@ def load_model_and_tokenizer():
15
  # Load the model and tokenizer
16
  tokenizer, model = load_model_and_tokenizer()
17
 
18
- # Function to detect context (simplified)
19
- def detect_context(input_text):
20
- # Tokenize the input text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  inputs = tokenizer([input_text], padding=True, truncation=True, return_tensors="pt")
22
-
23
- # Run the model
24
  with torch.no_grad():
25
  outputs = model(**inputs)
26
-
27
- # Get the embedding (mean pooling)
28
- input_embedding = outputs.last_hidden_state.mean(dim=1).numpy()
29
-
30
- # For now, return a placeholder context
31
- # You can replace this with a more sophisticated logic if needed
32
- return ["general"]
 
 
33
 
34
  # Translation client
35
  translation_client = Client("Frenchizer/space_3")
36
 
37
- def translate_text(input_text):
 
38
  return translation_client.predict(input_text)
39
 
40
  def process_request(input_text):
41
- context = detect_context(input_text)
42
- print(f"Detected context: {context}")
43
- translation = translate_text(input_text)
44
- return translation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- # Gradio interface
47
  interface = gr.Interface(
48
- fn=process_request,
49
  inputs="text",
50
  outputs="text",
51
  title="Frenchizer",
52
- description="Translate text from English to French with context detection."
53
  )
54
 
55
- interface.launch()
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModel
3
+ from sklearn.metrics.pairwise import cosine_similarity
4
  import torch
5
+ import numpy as np
6
  from gradio_client import Client
 
7
 
8
  # Cache the model and tokenizer using lru_cache
9
+ from functools import lru_cache
10
+
11
  @lru_cache(maxsize=1)
12
  def load_model_and_tokenizer():
13
  model_name = "./all-MiniLM-L6-v2" # Replace with your Space and model path
 
18
  # Load the model and tokenizer
19
  tokenizer, model = load_model_and_tokenizer()
20
 
21
+ # Precompute label embeddings
22
# Candidate context labels. Order is significant: a label's position in this
# list is the row index of its embedding when similarities are mapped back to
# label names, so do not reorder or de-duplicate without recomputing.
labels = [
    "aerospace", "anatomy", "anthropology", "art", "automotive",
    "blockchain", "biology",
    "chemistry", "cryptocurrency",
    "data science", "design",
    "e-commerce", "education", "engineering", "entertainment", "environment",
    "fashion", "finance", "food commerce",
    "general", "gaming",
    "healthcare", "history", "html",
    "information technology", "IT",
    "keywords",
    "legal", "literature",
    "machine learning", "marketing", "medicine", "music",
    "personal development", "philosophy", "physics",
    "politics", "poetry", "programming",
    "real estate", "retail", "robotics",
    "slang", "social media", "speech", "sports", "sustained",
    "technical", "theater", "tourism", "travel",
]
36
+
37
@lru_cache(maxsize=1)
def precompute_label_embeddings():
    """Embed every entry of ``labels`` once and return the result as an array.

    The labels are tokenized as a single padded batch, so shorter labels are
    followed by padding tokens. Pooling must therefore be weighted by the
    attention mask; a plain mean over the sequence axis would average the
    padding positions' hidden states into the embedding of every label that
    is shorter than the longest one in the batch.

    Returns:
        A NumPy array with one row per label, in the same order as ``labels``.
    """
    inputs = tokenizer(labels, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    token_embeddings = outputs.last_hidden_state
    # Broadcast the mask over the hidden dimension: 1.0 for real tokens,
    # 0.0 for padding.
    mask = inputs["attention_mask"].unsqueeze(-1).to(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(dim=1)
    # Clamp guards against division by zero should a row ever be fully masked.
    token_counts = mask.sum(dim=1).clamp(min=1e-9)
    return (summed / token_counts).numpy()


label_embeddings = precompute_label_embeddings()
45
+
46
+ # Function to detect context
47
def detect_context(input_text, fallback_threshold=0.8, max_results=3):
    """Return the context labels most similar to ``input_text``.

    Args:
        input_text: Text to classify.
        fallback_threshold: Minimum cosine similarity a label must reach to
            be reported.
        max_results: Maximum number of labels to return.

    Returns:
        A list of ``(label, score)`` tuples sorted by descending score and
        truncated to ``max_results``. The list is empty when no label reaches
        ``fallback_threshold``.
    """
    # Embed the text with the same model used for the labels (mean pooling
    # over the token axis).
    encoded = tokenizer([input_text], padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        model_output = model(**encoded)
    input_embedding = model_output.last_hidden_state.mean(dim=1).numpy()

    # One similarity per label; row 0 because there is a single input text.
    similarities = cosine_similarity(input_embedding, label_embeddings)[0]

    matches = [
        (label, score)
        for label, score in zip(labels, similarities)
        if score >= fallback_threshold
    ]
    matches.sort(key=lambda pair: pair[1], reverse=True)
    return matches[:max_results]
61
 
62
# gradio_client connection used for all translation requests.
translation_client = Client("Frenchizer/space_3")


def translate_text(input_text, context="general"):
    """Send ``input_text`` to the translation client and return its result.

    Args:
        input_text: Text to translate.
        context: Context label for the translation. NOTE(review): accepted
            but currently not forwarded to the client, so every context
            yields the same request.

    Returns:
        Whatever ``translation_client.predict`` returns for ``input_text``.
    """
    translated = translation_client.predict(input_text)
    return translated
68
 
69
def process_request(input_text):
    """Translate ``input_text`` and add one translation per detected context.

    All steps run synchronously, in order: the general translation, context
    detection, then one extra translation per detected non-"general" context.

    Args:
        input_text: Text to translate.

    Returns:
        A ``(general_translation, additional_translations)`` tuple, where
        ``additional_translations`` maps each detected context label (other
        than ``"general"``) to its translation.
    """
    general_translation = translate_text(input_text, context="general")

    detected = detect_context(input_text)

    # "general" is already covered by the translation above, so skip it.
    additional_translations = {
        label: translate_text(input_text, context=label)
        for label, _score in detected
        if label != "general"
    }

    return general_translation, additional_translations
84
+
85
+ # Gradio interface with multiple outputs
86
def gradio_interface(input_text):
    """Format the translations for ``input_text`` as one display string.

    Args:
        input_text: Text entered in the Gradio UI.

    Returns:
        The general translation followed by one "Context (<label>): ..."
        paragraph per additional translation, separated by blank lines and
        stripped of leading/trailing whitespace.
    """
    general_translation, additional_translations = process_request(input_text)

    sections = [f"General Translation: {general_translation}"]
    sections.extend(
        f"Context ({context}): {translation}"
        for context, translation in additional_translations.items()
    )
    return "\n\n".join(sections).strip()
92
 
93
# Text-in / text-out Gradio UI backed by gradio_interface.
interface = gr.Interface(
    fn=gradio_interface,
    inputs="text",
    outputs="text",
    title="Frenchizer",
    description="Translate text from English to French with optimized context detection.",
)

# Start serving the UI when this module is executed.
interface.launch()