Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -184,30 +184,31 @@ def _to_html_table(S: np.ndarray, names: List[str]) -> str:
|
|
| 184 |
|
| 185 |
@spaces.GPU(duration=_gpu_duration_gallery)
|
| 186 |
def batch_similarity(files: List[str], model_name: str, pooling: str):
|
| 187 |
-
# files is a list of filepaths from gr.Files
|
| 188 |
paths = files or []
|
| 189 |
if len(paths) < 2:
|
| 190 |
-
return "Upload at least 2 images", None
|
| 191 |
-
|
| 192 |
if not torch.cuda.is_available():
|
| 193 |
-
raise RuntimeError("CUDA not available. Ensure
|
| 194 |
|
| 195 |
model_id = MODELS[model_name]
|
|
|
|
| 196 |
embs = []
|
| 197 |
-
for img in
|
| 198 |
e, _, _ = _extract_core(img, model_id, pooling, want_overlay=False)
|
| 199 |
embs.append(e)
|
| 200 |
|
| 201 |
if len(embs) < 2:
|
| 202 |
-
return "Failed to read or embed images", None
|
| 203 |
|
| 204 |
X = np.vstack(embs).astype(np.float32)
|
| 205 |
Xn = X / np.clip(np.linalg.norm(X, axis=1, keepdims=True), 1e-8, None)
|
| 206 |
S = Xn @ Xn.T
|
| 207 |
|
|
|
|
| 208 |
csv_path = os.path.join(tempfile.gettempdir(), f"cosine_{uuid4().hex}.csv")
|
| 209 |
np.savetxt(csv_path, S, delimiter=",", fmt="%.6f")
|
| 210 |
-
|
|
|
|
| 211 |
|
| 212 |
|
| 213 |
# ---------------------------
|
|
@@ -235,7 +236,7 @@ with gr.Blocks() as app:
|
|
| 235 |
gr.Markdown("Upload multiple images. We compute a cosine similarity matrix on GPU and return a CSV.")
|
| 236 |
# Input as Files so you can multi-upload, plus a Gallery preview
|
| 237 |
files_in = gr.Files(label="Upload images", file_types=["image"], file_count="multiple", type="filepath")
|
| 238 |
-
|
| 239 |
model_dd2 = gr.Dropdown(choices=list(MODELS.keys()), value=DEFAULT_MODEL, label="Backbone")
|
| 240 |
pooling2 = gr.Radio(["CLS", "Mean of patch tokens"], value="CLS", label="Pooling")
|
| 241 |
go = gr.Button("Compute cosine on GPU")
|
|
@@ -255,7 +256,7 @@ with gr.Blocks() as app:
|
|
| 255 |
return imgs
|
| 256 |
|
| 257 |
files_in.change(_preview, inputs=files_in, outputs=gallery_preview)
|
| 258 |
-
go.click(batch_similarity, [
|
| 259 |
|
| 260 |
|
| 261 |
|
|
|
|
| 184 |
|
| 185 |
@spaces.GPU(duration=_gpu_duration_gallery)
|
| 186 |
def batch_similarity(files: List[str], model_name: str, pooling: str):
|
|
|
|
| 187 |
paths = files or []
|
| 188 |
if len(paths) < 2:
|
| 189 |
+
return "<em>Upload at least 2 images</em>", None
|
|
|
|
| 190 |
if not torch.cuda.is_available():
|
| 191 |
+
raise RuntimeError("CUDA not available. Ensure ZeroGPU is selected.")
|
| 192 |
|
| 193 |
model_id = MODELS[model_name]
|
| 194 |
+
imgs = _open_images_from_paths(paths)
|
| 195 |
embs = []
|
| 196 |
+
for img in imgs:
|
| 197 |
e, _, _ = _extract_core(img, model_id, pooling, want_overlay=False)
|
| 198 |
embs.append(e)
|
| 199 |
|
| 200 |
if len(embs) < 2:
|
| 201 |
+
return "<em>Failed to read or embed images</em>", None
|
| 202 |
|
| 203 |
X = np.vstack(embs).astype(np.float32)
|
| 204 |
Xn = X / np.clip(np.linalg.norm(X, axis=1, keepdims=True), 1e-8, None)
|
| 205 |
S = Xn @ Xn.T
|
| 206 |
|
| 207 |
+
# save CSV and build HTML table
|
| 208 |
csv_path = os.path.join(tempfile.gettempdir(), f"cosine_{uuid4().hex}.csv")
|
| 209 |
np.savetxt(csv_path, S, delimiter=",", fmt="%.6f")
|
| 210 |
+
html = _to_html_table(S, paths)
|
| 211 |
+
return html, csv_path
|
| 212 |
|
| 213 |
|
| 214 |
# ---------------------------
|
|
|
|
| 236 |
gr.Markdown("Upload multiple images. We compute a cosine similarity matrix on GPU and return a CSV.")
|
| 237 |
# Input as Files so you can multi-upload, plus a Gallery preview
|
| 238 |
files_in = gr.Files(label="Upload images", file_types=["image"], file_count="multiple", type="filepath")
|
| 239 |
+
gal = gr.Gallery(label="Images", columns=4, height=360, allow_preview=True)
|
| 240 |
model_dd2 = gr.Dropdown(choices=list(MODELS.keys()), value=DEFAULT_MODEL, label="Backbone")
|
| 241 |
pooling2 = gr.Radio(["CLS", "Mean of patch tokens"], value="CLS", label="Pooling")
|
| 242 |
go = gr.Button("Compute cosine on GPU")
|
|
|
|
| 256 |
return imgs
|
| 257 |
|
| 258 |
files_in.change(_preview, inputs=files_in, outputs=gallery_preview)
|
| 259 |
+
go.click(batch_similarity, [gal, model_dd2, pooling2], [status, csv])
|
| 260 |
|
| 261 |
|
| 262 |
|