jocko commited on
Commit
855ea3c
·
1 Parent(s): 02d2e6f

fix image similarity detection

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +25 -24
src/streamlit_app.py CHANGED
@@ -174,30 +174,31 @@ if query:
174
  st.warning("OpenAI API key not found. Please set OPENAI_API_KEY as a secret environment variable.")
175
 
176
  if uploaded_file is not None:
177
- print(uploaded_file)
178
- st.write(f'uploading file {uploaded_file.name}')
179
- query_image = Image.open(uploaded_file).convert("RGB")
180
- st.image(query_image, caption="Your uploaded image", use_container_width=True)
181
-
182
- # Embed uploaded image
183
- inputs = clip_processor(images=query_image, return_tensors="pt")
184
- with torch.no_grad():
185
- query_feat = clip_model.get_image_features(**inputs)
186
- query_feat /= query_feat.norm(p=2, dim=-1, keepdim=True)
187
-
188
- # Compute cosine similarity
189
- similarities = (dataset_image_features @ query_feat.T).squeeze(1) # [num_dataset_images]
190
-
191
- top_k = 3
192
- top_results = torch.topk(similarities, k=top_k)
193
-
194
- st.write(f"Top {top_k} similar medical cases:")
195
-
196
- for rank, idx in enumerate(top_results.indices):
197
- score = top_results.values[rank].item()
198
- similar_img = data[int(idx)]['image']
199
- st.image(similar_img, caption=f"Similarity: {score:.3f}", use_container_width=True)
200
- st.markdown(f"**Case description:** {data[int(idx)]['complaints']}")
 
201
  else:
202
  st.write("no image")
203
 
 
174
  st.warning("OpenAI API key not found. Please set OPENAI_API_KEY as a secret environment variable.")
175
 
176
  if uploaded_file is not None:
177
+ with st.spinner("Searching medical cases..."):
178
+ print(uploaded_file)
179
+ st.write(f'uploading file {uploaded_file.name}')
180
+ query_image = Image.open(uploaded_file).convert("RGB")
181
+ st.image(query_image, caption="Your uploaded image", use_container_width=True)
182
+
183
+ # Embed uploaded image
184
+ inputs = clip_processor(images=query_image, return_tensors="pt")
185
+ with torch.no_grad():
186
+ query_feat = clip_model.get_image_features(**inputs)
187
+ query_feat /= query_feat.norm(p=2, dim=-1, keepdim=True)
188
+
189
+ # Compute cosine similarity
190
+ similarities = (dataset_image_features @ query_feat.T).squeeze(1) # [num_dataset_images]
191
+
192
+ top_k = 3
193
+ top_results = torch.topk(similarities, k=top_k)
194
+
195
+ st.write(f"Top {top_k} similar medical cases:")
196
+
197
+ for rank, idx in enumerate(top_results.indices):
198
+ score = top_results.values[rank].item()
199
+ similar_img = data[int(idx)]['image']
200
+ st.image(similar_img, caption=f"Similarity: {score:.3f}", use_container_width=True)
201
+ st.markdown(f"**Case description:** {data[int(idx)]['complaints']}")
202
  else:
203
  st.write("no image")
204