jocko committed on
Commit 234ca8e · 1 Parent(s): 1e92b95

merge code

Files changed (1)
  1. src/streamlit_app.py +40 -0
src/streamlit_app.py CHANGED

@@ -107,6 +107,7 @@ TEXT_COLUMN = "complaints" # or "general_complaint", depending on your needs
 st.title("🩺 Multimodal Medical Chatbot")
 
 query = st.text_input("Enter your medical question or symptom description:")
+uploaded_file = st.file_uploader("Upload an image to find similar medical cases:", type=["png", "jpg", "jpeg"])
 
 @track
 def get_chat_completion_openai(client, prompt: str):
@@ -127,6 +128,21 @@ def get_similar_prompt(query):
     idx = top_result.indices[0].item()
     return data[idx]
 
+# Cache dataset image embeddings (takes time, so cached)
+@st.cache_data(show_spinner=True)
+def embed_dataset_images(dataset):
+    features = []
+    for item in dataset:
+        # Load image from URL/path or raw bytes - adapt this if needed
+        img = Image.open(item["image"]).convert("RGB")
+        inputs = clip_processor(images=img, return_tensors="pt")
+        with torch.no_grad():
+            feat = clip_model.get_image_features(**inputs)
+        feat /= feat.norm(p=2, dim=-1, keepdim=True)
+        features.append(feat.cpu())
+    return torch.cat(features, dim=0)
+
+dataset_image_features = embed_dataset_images(data)
 
 if query:
     with st.spinner("Searching medical cases..."):
@@ -152,4 +168,28 @@ if query:
     else:
         st.warning("OpenAI API key not found. Please set OPENAI_API_KEY as a secret environment variable.")
 
+if uploaded_file is not None:
+    query_image = Image.open(uploaded_file).convert("RGB")
+    st.image(query_image, caption="Your uploaded image", use_container_width=True)
+
+    # Embed uploaded image
+    inputs = clip_processor(images=query_image, return_tensors="pt")
+    with torch.no_grad():
+        query_feat = clip_model.get_image_features(**inputs)
+    query_feat /= query_feat.norm(p=2, dim=-1, keepdim=True)
+
+    # Compute cosine similarity
+    similarities = (dataset_image_features @ query_feat.T).squeeze(1)  # [num_dataset_images]
+
+    top_k = 3
+    top_results = torch.topk(similarities, k=top_k)
+
+    st.write(f"Top {top_k} similar medical cases:")
+
+    for rank, idx in enumerate(top_results.indices):
+        score = top_results.values[rank].item()
+        similar_img = Image.open(data[int(idx)]["image"]).convert("RGB")
+        st.image(similar_img, caption=f"Similarity: {score:.3f}", use_container_width=True)
+        st.markdown(f"**Case description:** {data[int(idx)]['complaints']}")
+
 st.caption("This chatbot is for educational purposes only and does not provide medical advice.")