Spaces:
Sleeping
Sleeping
jocko
commited on
Commit
·
6f5e256
1
Parent(s):
7f5755e
fix image similarity detection
Browse files- src/streamlit_app.py +15 -14
src/streamlit_app.py
CHANGED
@@ -72,7 +72,22 @@ def load_medical_data():
|
|
72 |
)
|
73 |
return dataset
|
74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
data = load_medical_data()
|
|
|
76 |
|
77 |
from openai import OpenAI
|
78 |
client = OpenAI(api_key=openai.api_key)
|
@@ -133,21 +148,7 @@ def get_similar_prompt(query):
|
|
133 |
idx = top_result.indices[0].item()
|
134 |
return data[idx]
|
135 |
|
136 |
-
# Cache dataset image embeddings (takes time, so cached)
|
137 |
-
@st.cache_data(show_spinner=True)
|
138 |
-
def embed_dataset_images(_dataset):
|
139 |
-
features = []
|
140 |
-
for item in _dataset:
|
141 |
-
# Load image from URL/path or raw bytes - adapt this if needed
|
142 |
-
img = item["image"]
|
143 |
-
inputs = clip_processor(images=img, return_tensors="pt")
|
144 |
-
with torch.no_grad():
|
145 |
-
feat = clip_model.get_image_features(**inputs)
|
146 |
-
feat /= feat.norm(p=2, dim=-1, keepdim=True)
|
147 |
-
features.append(feat.cpu())
|
148 |
-
return torch.cat(features, dim=0)
|
149 |
|
150 |
-
dataset_image_features = embed_dataset_images(data)
|
151 |
|
152 |
if query:
|
153 |
with st.spinner("Searching medical cases..."):
|
|
|
72 |
)
|
73 |
return dataset
|
74 |
|
75 |
+
# Cache dataset image embeddings (takes time, so cached)
|
76 |
+
@st.cache_data(show_spinner=True)
|
77 |
+
def embed_dataset_images(_dataset):
|
78 |
+
features = []
|
79 |
+
for item in _dataset:
|
80 |
+
# Load image from URL/path or raw bytes - adapt this if needed
|
81 |
+
img = item["image"]
|
82 |
+
inputs = clip_processor(images=img, return_tensors="pt")
|
83 |
+
with torch.no_grad():
|
84 |
+
feat = clip_model.get_image_features(**inputs)
|
85 |
+
feat /= feat.norm(p=2, dim=-1, keepdim=True)
|
86 |
+
features.append(feat.cpu())
|
87 |
+
return torch.cat(features, dim=0)
|
88 |
+
|
89 |
data = load_medical_data()
|
90 |
+
dataset_image_features = embed_dataset_images(data)
|
91 |
|
92 |
from openai import OpenAI
|
93 |
client = OpenAI(api_key=openai.api_key)
|
|
|
148 |
idx = top_result.indices[0].item()
|
149 |
return data[idx]
|
150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
|
|
|
152 |
|
153 |
if query:
|
154 |
with st.spinner("Searching medical cases..."):
|