Frenchizer committed on
Commit
e8f61e6
·
verified ·
1 Parent(s): 0747082

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -25
app.py CHANGED
@@ -2,7 +2,6 @@ import gradio as gr
2
  from transformers import AutoTokenizer, AutoModel
3
  from sklearn.metrics.pairwise import cosine_similarity
4
  import torch
5
- import numpy as np
6
  from gradio_client import Client
7
  from functools import lru_cache
8
 
@@ -43,7 +42,7 @@ def precompute_label_embeddings():
43
  label_embeddings = precompute_label_embeddings()
44
 
45
  # Function to detect context
46
- def detect_context(input_text, fallback_threshold=0.8, max_results=3):
47
  # Encode the input text
48
  inputs = tokenizer([input_text], padding=True, truncation=True, return_tensors="pt")
49
  with torch.no_grad():
@@ -53,41 +52,48 @@ def detect_context(input_text, fallback_threshold=0.8, max_results=3):
53
  # Compute cosine similarities
54
  similarities = cosine_similarity(input_embedding, label_embeddings)[0]
55
 
56
- # Check for fallback matches
57
- fallback_labels = [(labels[i], score) for i, score in enumerate(similarities) if score >= fallback_threshold]
58
- fallback_labels = sorted(fallback_labels, key=lambda x: x[1], reverse=True)[:max_results]
59
- return fallback_labels
 
 
 
 
 
 
 
 
 
60
 
61
  # Translation client
62
  translation_client = Client("Frenchizer/space_7")
63
 
64
- def translate_text(input_text, context="general"):
65
- # Append the context to the input text for the translation client (if needed)
66
  return translation_client.predict(input_text)
67
 
68
  def process_request(input_text):
69
- # Step 1: Return the general translation immediately
70
- general_translation = translate_text(input_text, context="general")
71
 
72
- # Step 2: Detect context in the background
73
  context_results = detect_context(input_text)
74
 
75
- # Step 3: Generate additional translations for high-confidence contexts
76
- additional_translations = {}
77
- for context, score in context_results:
78
- if context != "general":
79
- additional_translations[context] = translate_text(input_text, context=context)
80
 
81
- # Return the general translation and additional context translations
82
- return general_translation, additional_translations
83
 
84
- # Gradio interface with multiple outputs
85
  def gradio_interface(input_text):
86
- general_translation, additional_translations = process_request(input_text)
87
- outputs = f"{general_translation}\n\n"
88
- for context, translation in additional_translations.items():
89
- outputs += f"Context ({context}): {translation}\n\n"
90
- return outputs.strip()
 
91
 
92
  # Create the Gradio interface
93
  interface = gr.Interface(
@@ -95,7 +101,7 @@ interface = gr.Interface(
95
  inputs="text",
96
  outputs="text",
97
  title="Frenchizer",
98
- description="Translate text from English to French with optimized context detection and MarianMT model."
99
  )
100
 
101
  interface.launch()
 
2
  from transformers import AutoTokenizer, AutoModel
3
  from sklearn.metrics.pairwise import cosine_similarity
4
  import torch
 
5
  from gradio_client import Client
6
  from functools import lru_cache
7
 
 
42
  label_embeddings = precompute_label_embeddings()
43
 
44
  # Function to detect context
45
+ def detect_context(input_text, fallback_threshold=0.5): # Lowered threshold for debugging
46
  # Encode the input text
47
  inputs = tokenizer([input_text], padding=True, truncation=True, return_tensors="pt")
48
  with torch.no_grad():
 
52
  # Compute cosine similarities
53
  similarities = cosine_similarity(input_embedding, label_embeddings)[0]
54
 
55
+ # Debugging: Print all labels and their similarity scores
56
+ print("Debug: Similarity scores for all labels:")
57
+ for label, score in zip(labels, similarities):
58
+ print(f"{label}: {score:.4f}")
59
+
60
+ # Filter contexts with confidence >= fallback_threshold
61
+ high_confidence_contexts = [(labels[i], score) for i, score in enumerate(similarities) if score >= fallback_threshold]
62
+
63
+ # If no contexts meet the threshold, include "general" as a fallback
64
+ if not high_confidence_contexts:
65
+ high_confidence_contexts = [("general", 1.0)] # Assign a default score of 1.0 for "general"
66
+
67
+ return high_confidence_contexts
68
 
69
  # Translation client
70
  translation_client = Client("Frenchizer/space_7")
71
 
72
+ def translate_text(input_text):
73
+ # Translate the input text
74
  return translation_client.predict(input_text)
75
 
76
  def process_request(input_text):
77
+ # Step 1: Translate the text
78
+ translation = translate_text(input_text)
79
 
80
+ # Step 2: Detect context
81
  context_results = detect_context(input_text)
82
 
83
+ # Step 3: Print the list of high-confidence contexts
84
+ print("High-confidence contexts:", context_results)
 
 
 
85
 
86
+ # Return the translation and contexts
87
+ return translation, context_results
88
 
89
+ # Gradio interface
90
  def gradio_interface(input_text):
91
+ translation, contexts = process_request(input_text)
92
+ # Format the output
93
+ output = f"Translation: {translation}\n\nDetected Contexts:\n"
94
+ for context, score in contexts:
95
+ output += f"- {context} (confidence: {score:.2f})\n"
96
+ return output.strip()
97
 
98
  # Create the Gradio interface
99
  interface = gr.Interface(
 
101
  inputs="text",
102
  outputs="text",
103
  title="Frenchizer",
104
+ description="Translate text from English to French with context detection."
105
  )
106
 
107
  interface.launch()