jocko commited on
Commit
61e5bfd
·
1 Parent(s): dc78eca

merge code

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +18 -3
src/streamlit_app.py CHANGED
@@ -77,14 +77,29 @@ client = OpenAI(api_key=openai.api_key)
77
  # After seeing the real column name, let's say it's "text" instead of "description":
78
  text_field = "text" if "text" in data.features else list(data.features.keys())[0]
79
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  # Then use dynamic access:
81
  #text_embeddings = embed_texts(data[text_field])
82
 
83
  # ========== 🧠 Embedding Function ==========
84
  @st.cache_data(show_spinner=False)
85
- def embed_texts(_texts):
86
  return text_model.encode(_texts, convert_to_tensor=True)
87
 
 
 
 
88
  # Pick which text column to use
89
  TEXT_COLUMN = "complaints" # or "general_complaint", depending on your needs
90
 
@@ -112,8 +127,8 @@ def get_similar_prompt(query_embedding, text_embeddings):
112
 
113
  if query:
114
  with st.spinner("Searching medical cases..."):
115
- text_embeddings = embed_texts(data[TEXT_COLUMN])
116
- query_embedding = embed_texts([query])[0]
117
 
118
  # Compute similarity
119
  selected = get_similar_prompt(query_embedding, text_embeddings)
 
77
  # After seeing the real column name, let's say it's "text" instead of "description":
78
  text_field = "text" if "text" in data.features else list(data.features.keys())[0]
79
 
80
+
81
+ @st.cache_data(show_spinner=False)
82
+ def prepare_combined_texts(_dataset):
83
+ combined = []
84
+ for gc, c in zip(_dataset["general_complaint"], _dataset["complaints"]):
85
+ gc_str = gc if gc else ""
86
+ c_str = c if c else ""
87
+ combined.append(f"General complaint: {gc_str}. Additional details: {c_str}")
88
+ return combined
89
+
90
+ combined_texts = prepare_combined_texts(data)
91
+
92
  # Then use dynamic access:
93
  #text_embeddings = embed_texts(data[text_field])
94
 
95
  # ========== 🧠 Embedding Function ==========
96
  @st.cache_data(show_spinner=False)
97
+ def embed_dataset_texts(_texts):
98
  return text_model.encode(_texts, convert_to_tensor=True)
99
 
100
+ def embed_query_text(query):
101
+ return text_model.encode([query], convert_to_tensor=True)[0]
102
+
103
  # Pick which text column to use
104
  TEXT_COLUMN = "complaints" # or "general_complaint", depending on your needs
105
 
 
127
 
128
  if query:
129
  with st.spinner("Searching medical cases..."):
130
+ text_embeddings = embed_dataset_texts(combined_texts) # cached
131
+ query_embedding = embed_query_text(query) # recalculated each time
132
 
133
  # Compute similarity
134
  selected = get_similar_prompt(query_embedding, text_embeddings)