jocko commited on
Commit
6f5e256
·
1 Parent(s): 7f5755e

fix image similarity detection

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +15 -14
src/streamlit_app.py CHANGED
@@ -72,7 +72,22 @@ def load_medical_data():
72
  )
73
  return dataset
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  data = load_medical_data()
 
76
 
77
  from openai import OpenAI
78
  client = OpenAI(api_key=openai.api_key)
@@ -133,21 +148,7 @@ def get_similar_prompt(query):
133
  idx = top_result.indices[0].item()
134
  return data[idx]
135
 
136
- # Cache dataset image embeddings (takes time, so cached)
137
- @st.cache_data(show_spinner=True)
138
- def embed_dataset_images(_dataset):
139
- features = []
140
- for item in _dataset:
141
- # Load image from URL/path or raw bytes - adapt this if needed
142
- img = item["image"]
143
- inputs = clip_processor(images=img, return_tensors="pt")
144
- with torch.no_grad():
145
- feat = clip_model.get_image_features(**inputs)
146
- feat /= feat.norm(p=2, dim=-1, keepdim=True)
147
- features.append(feat.cpu())
148
- return torch.cat(features, dim=0)
149
 
150
- dataset_image_features = embed_dataset_images(data)
151
 
152
  if query:
153
  with st.spinner("Searching medical cases..."):
 
72
  )
73
  return dataset
74
 
75
+ # Cache dataset image embeddings (takes time, so cached)
76
+ @st.cache_data(show_spinner=True)
77
+ def embed_dataset_images(_dataset):
78
+ features = []
79
+ for item in _dataset:
80
+ # Load image from URL/path or raw bytes - adapt this if needed
81
+ img = item["image"]
82
+ inputs = clip_processor(images=img, return_tensors="pt")
83
+ with torch.no_grad():
84
+ feat = clip_model.get_image_features(**inputs)
85
+ feat /= feat.norm(p=2, dim=-1, keepdim=True)
86
+ features.append(feat.cpu())
87
+ return torch.cat(features, dim=0)
88
+
89
  data = load_medical_data()
90
+ dataset_image_features = embed_dataset_images(data)
91
 
92
  from openai import OpenAI
93
  client = OpenAI(api_key=openai.api_key)
 
148
  idx = top_result.indices[0].item()
149
  return data[idx]
150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
 
152
 
153
  if query:
154
  with st.spinner("Searching medical cases..."):