Frenchizer commited on
Commit
bf1807a
·
verified ·
1 Parent(s): 6b27907

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -42
app.py CHANGED
@@ -1,8 +1,6 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModel
3
- from sklearn.metrics.pairwise import cosine_similarity
4
  import torch
5
- import numpy as np
6
  from gradio_client import Client
7
  from functools import lru_cache
8
 
@@ -17,49 +15,21 @@ def load_model_and_tokenizer():
17
  # Load the model and tokenizer
18
  tokenizer, model = load_model_and_tokenizer()
19
 
20
- # Precompute label embeddings
21
- labels = [
22
- "aerospace", "anatomy", "anthropology", "art",
23
- "automotive", "blockchain", "biology", "chemistry",
24
- "cryptocurrency", "data science", "design", "e-commerce",
25
- "education", "engineering", "entertainment", "environment",
26
- "fashion", "finance", "food commerce", "general",
27
- "gaming", "healthcare", "history", "html",
28
- "information technology", "IT", "keywords", "legal",
29
- "literature", "machine learning", "marketing", "medicine",
30
- "music", "personal development", "philosophy", "physics",
31
- "politics", "poetry", "programming", "real estate", "retail",
32
- "robotics", "slang", "social media", "speech", "sports",
33
- "sustained", "technical", "theater", "tourism", "travel"
34
- ]
35
-
36
- @lru_cache(maxsize=1)
37
- def precompute_label_embeddings():
38
- inputs = tokenizer(labels, padding=True, truncation=True, return_tensors="pt")
39
- with torch.no_grad():
40
- outputs = model(**inputs)
41
- return outputs.last_hidden_state.mean(dim=1).numpy() # Mean pooling for embeddings
42
-
43
- label_embeddings = precompute_label_embeddings()
44
-
45
- # Function to detect context (optimized)
46
- def detect_context(input_text, high_confidence_threshold=0.9, fallback_threshold=0.8, max_results=3):
47
- # Encode the input text
48
  inputs = tokenizer([input_text], padding=True, truncation=True, return_tensors="pt")
 
 
49
  with torch.no_grad():
50
  outputs = model(**inputs)
51
- input_embedding = outputs.last_hidden_state.mean(dim=1).numpy() # Mean pooling for embedding
52
-
53
- # Compute cosine similarities (optimized)
54
- similarities = cosine_similarity(input_embedding, label_embeddings)[0]
55
-
56
- # Find top-N labels based on thresholds
57
- top_indices = np.argsort(similarities)[-max_results:][::-1]
58
- top_labels = [labels[i] for i in top_indices if similarities[i] >= fallback_threshold]
59
-
60
- # Return high-confidence labels if any, otherwise fallback labels
61
- high_conf_labels = [label for label in top_labels if similarities[labels.index(label)] >= high_confidence_threshold]
62
- return high_conf_labels if high_conf_labels else top_labels[:max_results]
63
 
64
  # Translation client
65
  translation_client = Client("Frenchizer/space_3")
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModel
 
3
  import torch
 
4
  from gradio_client import Client
5
  from functools import lru_cache
6
 
 
15
  # Load the model and tokenizer
16
  tokenizer, model = load_model_and_tokenizer()
17
 
18
+ # Function to detect context (simplified)
19
+ def detect_context(input_text):
20
+ # Tokenize the input text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  inputs = tokenizer([input_text], padding=True, truncation=True, return_tensors="pt")
22
+
23
+ # Run the model
24
  with torch.no_grad():
25
  outputs = model(**inputs)
26
+
27
+ # Get the embedding (mean pooling)
28
+ input_embedding = outputs.last_hidden_state.mean(dim=1).numpy()
29
+
30
+ # For now, return a placeholder context
31
+ # You can replace this with a more sophisticated logic if needed
32
+ return ["general"]
 
 
 
 
 
33
 
34
  # Translation client
35
  translation_client = Client("Frenchizer/space_3")