Frenchizer commited on
Commit
4662c92
·
verified ·
1 Parent(s): e8f61e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -14
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  from transformers import AutoTokenizer, AutoModel
3
  from sklearn.metrics.pairwise import cosine_similarity
4
  import torch
 
5
  from gradio_client import Client
6
  from functools import lru_cache
7
 
@@ -41,8 +42,13 @@ def precompute_label_embeddings():
41
 
42
  label_embeddings = precompute_label_embeddings()
43
 
 
 
 
 
 
44
  # Function to detect context
45
- def detect_context(input_text, fallback_threshold=0.5): # Lowered threshold for debugging
46
  # Encode the input text
47
  inputs = tokenizer([input_text], padding=True, truncation=True, return_tensors="pt")
48
  with torch.no_grad():
@@ -52,19 +58,19 @@ def detect_context(input_text, fallback_threshold=0.5): # Lowered threshold for
52
  # Compute cosine similarities
53
  similarities = cosine_similarity(input_embedding, label_embeddings)[0]
54
 
55
- # Debugging: Print all labels and their similarity scores
56
- print("Debug: Similarity scores for all labels:")
57
- for label, score in zip(labels, similarities):
58
- print(f"{label}: {score:.4f}")
 
59
 
60
- # Filter contexts with confidence >= fallback_threshold
61
- high_confidence_contexts = [(labels[i], score) for i, score in enumerate(similarities) if score >= fallback_threshold]
62
 
63
- # If no contexts meet the threshold, include "general" as a fallback
64
- if not high_confidence_contexts:
65
- high_confidence_contexts = [("general", 1.0)] # Assign a default score of 1.0 for "general"
66
 
67
- return high_confidence_contexts
68
 
69
  # Translation client
70
  translation_client = Client("Frenchizer/space_7")
@@ -81,7 +87,7 @@ def process_request(input_text):
81
  context_results = detect_context(input_text)
82
 
83
  # Step 3: Print the list of high-confidence contexts
84
- print("High-confidence contexts:", context_results)
85
 
86
  # Return the translation and contexts
87
  return translation, context_results
@@ -90,9 +96,9 @@ def process_request(input_text):
90
  def gradio_interface(input_text):
91
  translation, contexts = process_request(input_text)
92
  # Format the output
93
- output = f"Translation: {translation}\n\nDetected Contexts:\n"
94
  for context, score in contexts:
95
- output += f"- {context} (confidence: {score:.2f})\n"
96
  return output.strip()
97
 
98
  # Create the Gradio interface
 
2
  from transformers import AutoTokenizer, AutoModel
3
  from sklearn.metrics.pairwise import cosine_similarity
4
  import torch
5
+ import numpy as np
6
  from gradio_client import Client
7
  from functools import lru_cache
8
 
 
42
 
43
  label_embeddings = precompute_label_embeddings()
44
 
45
+ # Softmax function to convert scores to probabilities
46
+ def softmax(x):
47
+ exp_x = np.exp(x - np.max(x)) # Subtract max for numerical stability
48
+ return exp_x / exp_x.sum()
49
+
50
  # Function to detect context
51
+ def detect_context(input_text, top_n=3):
52
  # Encode the input text
53
  inputs = tokenizer([input_text], padding=True, truncation=True, return_tensors="pt")
54
  with torch.no_grad():
 
58
  # Compute cosine similarities
59
  similarities = cosine_similarity(input_embedding, label_embeddings)[0]
60
 
61
+ # Apply softmax to convert similarities to probabilities
62
+ probabilities = softmax(similarities)
63
+
64
+ # Pair each label with its probability
65
+ label_probabilities = list(zip(labels, probabilities))
66
 
67
+ # Sort by probability in descending order
68
+ label_probabilities.sort(key=lambda x: x[1], reverse=True)
69
 
70
+ # Select the top N contexts
71
+ top_contexts = label_probabilities[:top_n]
 
72
 
73
+ return top_contexts
74
 
75
  # Translation client
76
  translation_client = Client("Frenchizer/space_7")
 
87
  context_results = detect_context(input_text)
88
 
89
  # Step 3: Print the list of high-confidence contexts
90
+ print("Detected Contexts (Top 3):", context_results)
91
 
92
  # Return the translation and contexts
93
  return translation, context_results
 
96
  def gradio_interface(input_text):
97
  translation, contexts = process_request(input_text)
98
  # Format the output
99
+ output = f"Translation: {translation}\n\nDetected Contexts (Top 3):\n"
100
  for context, score in contexts:
101
+ output += f"- {context} (confidence: {score:.4f})\n"
102
  return output.strip()
103
 
104
  # Create the Gradio interface