Spaces:
Sleeping
Sleeping
Upload predict.py
Browse files
img2art_search/models/predict.py
CHANGED
|
@@ -11,13 +11,10 @@ from img2art_search.models.compute_embeddings import search_image
|
|
| 11 |
|
| 12 |
|
| 13 |
def predict(img: Image.Image) -> list:
|
| 14 |
-
tmp_img_path = "tmp_img.png"
|
| 15 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 16 |
if img:
|
| 17 |
-
img.
|
| 18 |
-
|
| 19 |
-
pred_dataset = ImageRetrievalDataset(pred_img, transform=transform)
|
| 20 |
-
pred_image_data = pred_dataset[0][0].unsqueeze(0).to(DEVICE)
|
| 21 |
indices, distances = search_image(pred_image_data)
|
| 22 |
results = []
|
| 23 |
for index, distance in zip(indices, distances):
|
|
@@ -31,7 +28,6 @@ def predict(img: Image.Image) -> list:
|
|
| 31 |
str(distance),
|
| 32 |
)
|
| 33 |
)
|
| 34 |
-
os.remove(tmp_img_path)
|
| 35 |
return results
|
| 36 |
else:
|
| 37 |
return []
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
def predict(img: Image.Image) -> list:
|
|
|
|
| 14 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 15 |
if img:
|
| 16 |
+
img = img.convert("RGB")
|
| 17 |
+
pred_image_data = transform(img).unsqueeze(0).to(DEVICE)
|
|
|
|
|
|
|
| 18 |
indices, distances = search_image(pred_image_data)
|
| 19 |
results = []
|
| 20 |
for index, distance in zip(indices, distances):
|
|
|
|
| 28 |
str(distance),
|
| 29 |
)
|
| 30 |
)
|
|
|
|
| 31 |
return results
|
| 32 |
else:
|
| 33 |
return []
|