Update app.py
app.py CHANGED
@@ -13,7 +13,8 @@ import torch.nn.functional as F
 import numpy as np
 import matplotlib.pyplot as plt
 import random
-from skimage.transform import resize
+import zipfile
+# from skimage.transform import resize
 from torchvision import transforms, models


@@ -302,23 +303,19 @@ def create_retrieval_figure(res):
     ax_query.set_title('Top 10 most similar scarabs', fontsize=40)
     names = ""
     for i, image in zip(range(len(res)), res):
-        current_image_path = image
+        current_image_path = image.split("/")[3]+"/"+image.split("/")[4]
         if i==0: continue
         if i < 11:
-
+            archive = zipfile.ZipFile('dataset.zip', 'r')
+            imgfile = archive.read(current_image_path)
+            image = cv2.imdecode(np.frombuffer(imgfile, np.uint8), 1)
             # image_resized = cv2.resize(image, (224, 224), interpolation=cv2.INTER_LINEAR)
             ax = fig.add_subplot(rows, cols, i)
             plt.axis('off')
             plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
-            item_uuid = current_image_path.split("/")[
+            item_uuid = current_image_path.split("/")[1].split("_photoUUID")[0].split("itemUUID_")[1]
             ax.set_title('Top {}'.format(i), fontsize=40)
             names = names + "Top " + str(i) + " item UUID is " + item_uuid + "\n"
-            # img_buf = io.BytesIO()
-            # plt.savefig(img_buf, format='png')
-            # im_fig = Image.open(img_buf)
-            # img_buf.close()
-            # return im_fig
-
     return fig, names

 def knn_calc(image_name, query_feature, features):
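The retrieval hunk above switches from reading result images off the filesystem to decoding them straight out of `dataset.zip`. A minimal sketch of that pattern, with a placeholder member path (the real paths come from `query_images_paths`):

import zipfile
import cv2
import numpy as np

# Open the archive once; the diff reopens it on every loop iteration,
# which works but redoes the open for each of the ten results.
archive = zipfile.ZipFile('dataset.zip', 'r')

# Read one member's raw bytes and decode them into a BGR image array.
# The member path below is a placeholder for illustration only.
raw = archive.read('images/itemUUID_1234_photoUUID_5678.jpg')
image = cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_COLOR)

# OpenCV decodes to BGR; convert to RGB before plt.imshow.
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)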
@@ -328,38 +325,29 @@ def knn_calc(image_name, query_feature, features):
     dist = -dist.item()
     return dist

-# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-device = 'cpu'
-
-experiment = "experiment_0"
-checkpoint_path = os.path.join("../shapes_classification/checkpoints/"
-                               "50_50_pretrained_resnet101_experiment_0_train_images_with_drawings_batch_8_10:29:06/" +
-                               "experiment_0_last_auto_model.pth.tar")
 checkpoint_path = "multi_label.pth.tar"

 resnet = models.resnet101(pretrained=True)
 num_ftrs = resnet.fc.in_features
 resnet.fc = nn.Linear(num_ftrs, 13)
-model = Resnet_with_skip(resnet)
+model = Resnet_with_skip(resnet)
 checkpoint = torch.load(checkpoint_path, map_location="cpu")
 model.load_state_dict(checkpoint)
 embedding_model_test = torch.nn.Sequential(*(list(model.children())[:-1]))
-embedding_model_test.to(device)

 periods_model = models.resnet101(pretrained=True)
 periods_model.fc = nn.Linear(num_ftrs, 5)
 periods_checkpoint = torch.load("periods.pth.tar", map_location="cpu")
 periods_model.load_state_dict(periods_checkpoint)
-periods_model.to(device)

-
-query_images_paths = []
-for path in os.listdir(data_dir):
-    query_images_paths.append(os.path.join(data_dir, path))
+with open('query_images_paths.pkl', 'rb') as fp:
+    query_images_paths = pickle.load(fp)

 with open('features.pkl', 'rb') as fp:
     features = pickle.load(fp)

+
+
 model.eval()
 transform = transforms.Compose([
     transforms.Resize((224, 224)),
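The `embedding_model_test` line above builds the retrieval feature extractor by dropping the network's last child module (the classification head) and keeping everything up to the pooling layer. A minimal sketch of the same slicing idiom on a plain ResNet-101:

import torch
import torch.nn as nn
from torchvision import models

resnet = models.resnet101(pretrained=True)

# Dropping the final fc layer leaves the convolutional trunk plus global
# average pooling, so the output is a pooled feature map, not class logits.
embedder = nn.Sequential(*list(resnet.children())[:-1])
embedder.eval()

with torch.no_grad():
    x = torch.randn(1, 3, 224, 224)  # stand-in for a normalized input image
    feature = embedder(x)            # shape: (1, 2048, 1, 1)
    feature = feature.flatten(1)     # shape: (1, 2048)

For `Resnet_with_skip` the children, and therefore the sliced output, may differ; the sketch only illustrates the idiom.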
@@ -373,21 +361,21 @@ invTrans = transforms.Compose([transforms.Normalize(mean=[0., 0., 0.],
                                                      std=[1., 1., 1.]),
                                ])

-labels =
+labels = ['ankh', 'anthropomorphic', 'bands', 'beetle', 'bird', 'circles', 'cross', 'duck', 'head', 'ibex', 'lion', 'sa', 'snake']
+
 periods_labels = ["MB1", "MB2", "LB", "Iron1", 'Iron2']
 periods_model.eval()

 def predict(inp):
     image_tensor = transform(inp)
-    image_tensor = image_tensor.to(device)
     with torch.no_grad():
         classification, reconstruction = model(image_tensor.unsqueeze(0))
         periods_classification = periods_model(image_tensor.unsqueeze(0))
     recon_tensor = reconstruction[0].repeat(3, 1, 1)
     recon_tensor = invTrans(kornia.enhance.invert(recon_tensor))
-    plot_recon = recon_tensor.
+    plot_recon = recon_tensor.permute(1, 2, 0).detach().numpy()
     w, h = inp.size
-    plot_recon = resize(plot_recon, (h, w))
+    # plot_recon = resize(plot_recon, (h, w))
     m = nn.Sigmoid()
     y = m(classification)
     preds = []
@@ -397,7 +385,6 @@ def predict(inp):
             preds.append(1)
         else:
             preds.append(0)
-    # prediction = torch.tensor(preds).to(device)
     confidences = {}
     true_labels = ""
     for i in range(len(labels)):
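The thresholding above treats the 13 shape classes as a multi-label problem: each logit passes through a sigmoid independently and counts as present when it clears 0.5, rather than through a softmax that would force exactly one class. A self-contained sketch with toy logits:

import torch
import torch.nn as nn

logits = torch.tensor([[2.3, -1.1, 0.7, -3.0]])  # toy logits for 4 labels
probs = nn.Sigmoid()(logits)                     # independent per-label probabilities

# A label is predicted present when its probability exceeds 0.5.
preds = [1 if p > 0.5 else 0 for p in probs[0]]  # -> [1, 0, 1, 0]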
@@ -406,11 +393,11 @@ def predict(inp):
             true_labels = true_labels + labels[i]
         else:
             true_labels = true_labels + "&" + labels[i]
-    confidences[true_labels] = torch.tensor(1.0)
+    confidences[true_labels] = torch.tensor(1.0)

     periods_prediction = torch.nn.functional.softmax(periods_classification[0], dim=0)
     periods_confidences = {periods_labels[i]: periods_prediction[i] for i in range(len(periods_labels))}
-    feature = embedding_model_test(image_tensor.unsqueeze(0))
+    feature = embedding_model_test(image_tensor.unsqueeze(0))
     dists = dict()
     with torch.no_grad():
         for i, image_name in enumerate(query_images_paths):
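The loop above scores the query feature against every entry in `features` and ranks by ascending score; `knn_calc` ends with `dist = -dist.item()`, so whatever similarity it computes is negated and the best match sorts first. Its full body is not shown in this diff, so the cosine-similarity scoring below is an assumption for illustration only:

import torch
import torch.nn.functional as F
from operator import itemgetter

# Toy stand-ins for the pickled query paths and feature dict.
features = {f"img_{i}.jpg": torch.randn(1, 2048) for i in range(5)}
query_feature = torch.randn(1, 2048)

def knn_calc(image_name, query_feature, features):
    # Assumed scoring: cosine similarity, negated so that an ascending
    # sort over the values puts the most similar image first.
    dist = F.cosine_similarity(query_feature, features[image_name])
    return -dist.item()

dists = {name: knn_calc(name, query_feature, features) for name in features}
res = dict(sorted(dists.items(), key=itemgetter(1)))  # best match first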
@@ -418,9 +405,14 @@ def predict(inp):
             dists[image_name] = dist
     res = dict(sorted(dists.items(), key=itemgetter(1)))
     fig, names = create_retrieval_figure(res)
-    return
+    return confidences, periods_confidences, plot_recon, fig, names


 gr.Interface(fn=predict,
              inputs=gr.Image(type="pil"),
-
+             title="ArcAid: Analysis of Archaeological Artifacts using Drawings",
+             description="Easily classify artifacts, retrieve similar ones and generate drawings. "
+                         "https://arxiv.org/abs/2211.09480.",
+             examples=['anth.jpg', 'beetle_snakes.jpg', 'bird.jpg', 'cross.jpg', 'ibex.jpg',
+                       'lion.jpg', 'lion2.jpg', 'sa.jpg'],
+             outputs=[gr.Label(num_top_classes=1), gr.Label(num_top_classes=1), "image", 'plot', 'text', "image"], ).launch(share=True)
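For reference, a Gradio `Interface` passes each value the function returns to the corresponding entry of `outputs`, in order, so the length of the outputs list should match the function's return arity. A minimal multi-output sketch (toy function and labels, not the app's real model):

import gradio as gr

def analyze(img):
    # Toy outputs: a label-confidence dict and a caption string.
    return {"beetle": 0.9, "bird": 0.1}, "Top 1 item UUID is 1234"

gr.Interface(fn=analyze,
             inputs=gr.Image(type="pil"),
             outputs=[gr.Label(num_top_classes=1), "text"],
             ).launch()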