offry commited on
Commit
8ff538a
·
1 Parent(s): 0d6507b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -33
app.py CHANGED
@@ -13,7 +13,8 @@ import torch.nn.functional as F
13
  import numpy as np
14
  import matplotlib.pyplot as plt
15
  import random
16
- from skimage.transform import resize
 
17
  from torchvision import transforms, models
18
 
19
 
@@ -302,23 +303,19 @@ def create_retrieval_figure(res):
302
  ax_query.set_title('Top 10 most similar scarabs', fontsize=40)
303
  names = ""
304
  for i, image in zip(range(len(res)), res):
305
- current_image_path = image
306
  if i==0: continue
307
  if i < 11:
308
- image = cv2.imread(current_image_path)
 
 
309
  # image_resized = cv2.resize(image, (224, 224), interpolation=cv2.INTER_LINEAR)
310
  ax = fig.add_subplot(rows, cols, i)
311
  plt.axis('off')
312
  plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
313
- item_uuid = current_image_path.split("/")[4].split("_photoUUID")[0].split("itemUUID_")[1]
314
  ax.set_title('Top {}'.format(i), fontsize=40)
315
  names = names + "Top " + str(i) + " item UUID is " + item_uuid + "\n"
316
- # img_buf = io.BytesIO()
317
- # plt.savefig(img_buf, format='png')
318
- # im_fig = Image.open(img_buf)
319
- # img_buf.close()
320
- # return im_fig
321
-
322
  return fig, names
323
 
324
  def knn_calc(image_name, query_feature, features):
@@ -328,38 +325,29 @@ def knn_calc(image_name, query_feature, features):
328
  dist = -dist.item()
329
  return dist
330
 
331
- # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
332
- device = 'cpu'
333
-
334
- experiment = "experiment_0"
335
- checkpoint_path = os.path.join("../shapes_classification/checkpoints/"
336
- "50_50_pretrained_resnet101_experiment_0_train_images_with_drawings_batch_8_10:29:06/" +
337
- "experiment_0_last_auto_model.pth.tar")
338
  checkpoint_path = "multi_label.pth.tar"
339
 
340
  resnet = models.resnet101(pretrained=True)
341
  num_ftrs = resnet.fc.in_features
342
  resnet.fc = nn.Linear(num_ftrs, 13)
343
- model = Resnet_with_skip(resnet).to(device)
344
  checkpoint = torch.load(checkpoint_path, map_location="cpu")
345
  model.load_state_dict(checkpoint)
346
  embedding_model_test = torch.nn.Sequential(*(list(model.children())[:-1]))
347
- embedding_model_test.to(device)
348
 
349
  periods_model = models.resnet101(pretrained=True)
350
  periods_model.fc = nn.Linear(num_ftrs, 5)
351
  periods_checkpoint = torch.load("periods.pth.tar", map_location="cpu")
352
  periods_model.load_state_dict(periods_checkpoint)
353
- periods_model.to(device)
354
 
355
- data_dir = "../cssl_dataset/all_image_base/1/"
356
- query_images_paths = []
357
- for path in os.listdir(data_dir):
358
- query_images_paths.append(os.path.join(data_dir, path))
359
 
360
  with open('features.pkl', 'rb') as fp:
361
  features = pickle.load(fp)
362
 
 
 
363
  model.eval()
364
  transform = transforms.Compose([
365
  transforms.Resize((224, 224)),
@@ -373,21 +361,21 @@ invTrans = transforms.Compose([transforms.Normalize(mean=[0., 0., 0.],
373
  std=[1., 1., 1.]),
374
  ])
375
 
376
- labels = sorted(os.listdir("../cssl_dataset/shape_multi_label/photos"))
 
377
  periods_labels = ["MB1", "MB2", "LB", "Iron1", 'Iron2']
378
  periods_model.eval()
379
 
380
  def predict(inp):
381
  image_tensor = transform(inp)
382
- image_tensor = image_tensor.to(device)
383
  with torch.no_grad():
384
  classification, reconstruction = model(image_tensor.unsqueeze(0))
385
  periods_classification = periods_model(image_tensor.unsqueeze(0))
386
  recon_tensor = reconstruction[0].repeat(3, 1, 1)
387
  recon_tensor = invTrans(kornia.enhance.invert(recon_tensor))
388
- plot_recon = recon_tensor.to("cpu").permute(1, 2, 0).detach().numpy()
389
  w, h = inp.size
390
- plot_recon = resize(plot_recon, (h, w))
391
  m = nn.Sigmoid()
392
  y = m(classification)
393
  preds = []
@@ -397,7 +385,6 @@ def predict(inp):
397
  preds.append(1)
398
  else:
399
  preds.append(0)
400
- # prediction = torch.tensor(preds).to(device)
401
  confidences = {}
402
  true_labels = ""
403
  for i in range(len(labels)):
@@ -406,11 +393,11 @@ def predict(inp):
406
  true_labels = true_labels + labels[i]
407
  else:
408
  true_labels = true_labels + "&" + labels[i]
409
- confidences[true_labels] = torch.tensor(1.0).to(device)
410
 
411
  periods_prediction = torch.nn.functional.softmax(periods_classification[0], dim=0)
412
  periods_confidences = {periods_labels[i]: periods_prediction[i] for i in range(len(periods_labels))}
413
- feature = embedding_model_test(image_tensor.unsqueeze(0)).to(device)
414
  dists = dict()
415
  with torch.no_grad():
416
  for i, image_name in enumerate(query_images_paths):
@@ -418,9 +405,14 @@ def predict(inp):
418
  dists[image_name] = dist
419
  res = dict(sorted(dists.items(), key=itemgetter(1)))
420
  fig, names = create_retrieval_figure(res)
421
- return fig, names, plot_recon, confidences, periods_confidences
422
 
423
 
424
  gr.Interface(fn=predict,
425
  inputs=gr.Image(type="pil"),
426
- outputs=['plot', 'text', "image", gr.Label(num_top_classes=1), gr.Label(num_top_classes=1)], ).launch(share=True)
 
 
 
 
 
 
13
  import numpy as np
14
  import matplotlib.pyplot as plt
15
  import random
16
+ import zipfile
17
+ # from skimage.transform import resize
18
  from torchvision import transforms, models
19
 
20
 
 
303
  ax_query.set_title('Top 10 most similar scarabs', fontsize=40)
304
  names = ""
305
  for i, image in zip(range(len(res)), res):
306
+ current_image_path = image.split("/")[3]+"/"+image.split("/")[4]
307
  if i==0: continue
308
  if i < 11:
309
+ archive = zipfile.ZipFile('dataset.zip', 'r')
310
+ imgfile = archive.read(current_image_path)
311
+ image = cv2.imdecode(np.frombuffer(imgfile, np.uint8), 1)
312
  # image_resized = cv2.resize(image, (224, 224), interpolation=cv2.INTER_LINEAR)
313
  ax = fig.add_subplot(rows, cols, i)
314
  plt.axis('off')
315
  plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
316
+ item_uuid = current_image_path.split("/")[1].split("_photoUUID")[0].split("itemUUID_")[1]
317
  ax.set_title('Top {}'.format(i), fontsize=40)
318
  names = names + "Top " + str(i) + " item UUID is " + item_uuid + "\n"
 
 
 
 
 
 
319
  return fig, names
320
 
321
  def knn_calc(image_name, query_feature, features):
 
325
  dist = -dist.item()
326
  return dist
327
 
 
 
 
 
 
 
 
328
  checkpoint_path = "multi_label.pth.tar"
329
 
330
  resnet = models.resnet101(pretrained=True)
331
  num_ftrs = resnet.fc.in_features
332
  resnet.fc = nn.Linear(num_ftrs, 13)
333
+ model = Resnet_with_skip(resnet)
334
  checkpoint = torch.load(checkpoint_path, map_location="cpu")
335
  model.load_state_dict(checkpoint)
336
  embedding_model_test = torch.nn.Sequential(*(list(model.children())[:-1]))
 
337
 
338
  periods_model = models.resnet101(pretrained=True)
339
  periods_model.fc = nn.Linear(num_ftrs, 5)
340
  periods_checkpoint = torch.load("periods.pth.tar", map_location="cpu")
341
  periods_model.load_state_dict(periods_checkpoint)
 
342
 
343
+ with open('query_images_paths.pkl', 'rb') as fp:
344
+ query_images_paths = pickle.load(fp)
 
 
345
 
346
  with open('features.pkl', 'rb') as fp:
347
  features = pickle.load(fp)
348
 
349
+
350
+
351
  model.eval()
352
  transform = transforms.Compose([
353
  transforms.Resize((224, 224)),
 
361
  std=[1., 1., 1.]),
362
  ])
363
 
364
+ labels = ['ankh', 'anthropomorphic', 'bands', 'beetle', 'bird', 'circles', 'cross', 'duck', 'head', 'ibex', 'lion', 'sa', 'snake']
365
+
366
  periods_labels = ["MB1", "MB2", "LB", "Iron1", 'Iron2']
367
  periods_model.eval()
368
 
369
  def predict(inp):
370
  image_tensor = transform(inp)
 
371
  with torch.no_grad():
372
  classification, reconstruction = model(image_tensor.unsqueeze(0))
373
  periods_classification = periods_model(image_tensor.unsqueeze(0))
374
  recon_tensor = reconstruction[0].repeat(3, 1, 1)
375
  recon_tensor = invTrans(kornia.enhance.invert(recon_tensor))
376
+ plot_recon = recon_tensor.permute(1, 2, 0).detach().numpy()
377
  w, h = inp.size
378
+ # plot_recon = resize(plot_recon, (h, w))
379
  m = nn.Sigmoid()
380
  y = m(classification)
381
  preds = []
 
385
  preds.append(1)
386
  else:
387
  preds.append(0)
 
388
  confidences = {}
389
  true_labels = ""
390
  for i in range(len(labels)):
 
393
  true_labels = true_labels + labels[i]
394
  else:
395
  true_labels = true_labels + "&" + labels[i]
396
+ confidences[true_labels] = torch.tensor(1.0)
397
 
398
  periods_prediction = torch.nn.functional.softmax(periods_classification[0], dim=0)
399
  periods_confidences = {periods_labels[i]: periods_prediction[i] for i in range(len(periods_labels))}
400
+ feature = embedding_model_test(image_tensor.unsqueeze(0))
401
  dists = dict()
402
  with torch.no_grad():
403
  for i, image_name in enumerate(query_images_paths):
 
405
  dists[image_name] = dist
406
  res = dict(sorted(dists.items(), key=itemgetter(1)))
407
  fig, names = create_retrieval_figure(res)
408
+ return confidences, periods_confidences, plot_recon, fig, names
409
 
410
 
411
  gr.Interface(fn=predict,
412
  inputs=gr.Image(type="pil"),
413
+ title="ArcAid: Analysis of Archaeological Artifacts using Drawings",
414
+ description="Easily classify artifacs, retrieve similar ones and generate drawings. "
415
+ "https://arxiv.org/abs/2211.09480.",
416
+ examples=['anth.jpg', 'beetle_snakes.jpg', 'bird.jpg', 'cross.jpg', 'ibex.jpg',
417
+ 'lion.jpg', 'lion2.jpg', 'sa.jpg'],
418
+ outputs=[gr.Label(num_top_classes=1), gr.Label(num_top_classes=1), "image", 'plot', 'text', "image"], ).launch(share=True)