offry commited on
Commit
2bc4604
·
verified ·
1 Parent(s): 1ff6e08

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -55
app.py CHANGED
@@ -14,6 +14,7 @@ import zipfile
14
  from torchvision import transforms, models
15
  from get_models import Resnet_with_skip
16
 
 
17
  def create_retrieval_figure(res):
18
  fig = plt.figure(figsize=[10 * 3, 10 * 3])
19
  cols = 5
@@ -21,22 +22,22 @@ def create_retrieval_figure(res):
21
  ax_query = fig.add_subplot(rows, 1, 1)
22
  plt.rcParams['figure.facecolor'] = 'white'
23
  plt.axis('off')
24
- ax_query.set_title('Top 10 most similar scarabs', fontsize=40)
25
  names = ""
26
  for i, image in zip(range(len(res)), res):
27
- current_image_path = "dataset/" + image.split("/")[3]+"/"+image.split("/")[4]
28
- if i==0: continue
29
- if i < 11:
30
- archive = zipfile.ZipFile('dataset.zip', 'r')
31
- imgfile = archive.read(current_image_path)
32
- image = cv2.imdecode(np.frombuffer(imgfile, np.uint8), 1)
33
- # image_resized = cv2.resize(image, (224, 224), interpolation=cv2.INTER_LINEAR)
34
- ax = fig.add_subplot(rows, cols, i)
35
- plt.axis('off')
36
- plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
37
- item_uuid = current_image_path.split("/")[2].split("_photoUUID")[0].split("itemUUID_")[1]
38
- ax.set_title('Top {}'.format(i), fontsize=40)
39
- names = names + "Top " + str(i) + " item UUID is " + item_uuid + "\n"
40
  return fig, names
41
 
42
  def knn_calc(image_name, query_feature, features):
@@ -54,12 +55,13 @@ resnet.fc = nn.Linear(num_ftrs, 13)
54
  model = Resnet_with_skip(resnet)
55
  checkpoint = torch.load(checkpoint_path, map_location="cpu")
56
  model.load_state_dict(checkpoint)
 
57
  embedding_model_test = torch.nn.Sequential(*(list(model.children())[:-1]))
58
 
59
- periods_model = models.resnet101(pretrained=True)
60
- periods_model.fc = nn.Linear(num_ftrs, 5)
61
- periods_checkpoint = torch.load("periods.pth.tar", map_location="cpu")
62
- periods_model.load_state_dict(periods_checkpoint)
63
 
64
  with open('query_images_paths.pkl', 'rb') as fp:
65
  query_images_paths = pickle.load(fp)
@@ -71,7 +73,7 @@ with open('features.pkl', 'rb') as fp:
71
 
72
  model.eval()
73
  transform = transforms.Compose([
74
- transforms.Resize((112, 112)),
75
  transforms.Grayscale(num_output_channels=3),
76
  transforms.ToTensor(),
77
  transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
@@ -82,42 +84,41 @@ invTrans = transforms.Compose([transforms.Normalize(mean=[0., 0., 0.],
82
  std=[1., 1., 1.]),
83
  ])
84
 
85
- labels = ['ankh', 'anthropomorphic', 'bands', 'beetle', 'bird', 'circles', 'cross', 'duck', 'head', 'ibex', 'lion', 'sa', 'snake']
86
-
87
- periods_labels = ["MB1", "MB2", "LB", "Iron1", 'Iron2']
88
- periods_model.eval()
89
 
90
  def predict(inp):
91
  image_tensor = transform(inp)
92
  with torch.no_grad():
93
- classification, reconstruction = model(image_tensor.unsqueeze(0))
94
- periods_classification = periods_model(image_tensor.unsqueeze(0))
95
- recon_tensor = reconstruction[0].repeat(3, 1, 1)
96
- recon_tensor = invTrans(kornia.enhance.invert(recon_tensor))
97
- plot_recon = recon_tensor.permute(1, 2, 0).detach().numpy()
98
- w, h = inp.size
99
- # plot_recon = resize(plot_recon, (h, w))
100
- m = nn.Sigmoid()
101
- y = m(classification)
102
- preds = []
103
- for sample in y:
104
- for i in sample:
105
- if i >=0.8:
106
- preds.append(1)
107
- else:
108
- preds.append(0)
109
- confidences = {}
110
- true_labels = ""
111
- for i in range(len(labels)):
112
- if preds[i]==1:
113
- if true_labels=="":
114
- true_labels = true_labels + labels[i]
115
- else:
116
- true_labels = true_labels + "&" + labels[i]
117
- confidences[true_labels] = torch.tensor(1.0)
118
-
119
- periods_prediction = torch.nn.functional.softmax(periods_classification[0], dim=0)
120
- periods_confidences = {periods_labels[i]: periods_prediction[i] for i in range(len(periods_labels))}
121
  feature = embedding_model_test(image_tensor.unsqueeze(0))
122
  dists = dict()
123
  with torch.no_grad():
@@ -126,7 +127,7 @@ def predict(inp):
126
  dists[image_name] = dist
127
  res = dict(sorted(dists.items(), key=itemgetter(1)))
128
  fig, names = create_retrieval_figure(res)
129
- return confidences, periods_confidences, fig, names, plot_recon
130
 
131
 
132
  a = gr.Interface(fn=predict,
@@ -134,6 +135,7 @@ a = gr.Interface(fn=predict,
134
  title="ArcAid: Analysis of Archaeological Artifacts using Drawings",
135
  description="Easily classify artifacs, retrieve similar ones and generate drawings. "
136
  "https://arxiv.org/abs/2211.09480.",
137
- # examples=['anth.jpg', 'beetle_snakes.jpg', 'bird.jpg', 'cross.jpg', 'ibex.jpg',
138
- # 'lion.jpg', 'lion2.jpg', 'sa.jpg'],
139
- outputs=[gr.Label(num_top_classes=3), gr.Label(num_top_classes=1), 'plot', 'text', "image"], ).launch()
 
 
14
  from torchvision import transforms, models
15
  from get_models import Resnet_with_skip
16
 
17
+
18
  def create_retrieval_figure(res):
19
  fig = plt.figure(figsize=[10 * 3, 10 * 3])
20
  cols = 5
 
22
  ax_query = fig.add_subplot(rows, 1, 1)
23
  plt.rcParams['figure.facecolor'] = 'white'
24
  plt.axis('off')
25
+ ax_query.set_title('Top 10 most similar items', fontsize=40)
26
  names = ""
27
  for i, image in zip(range(len(res)), res):
28
+ if i >= 10:
29
+ break
30
+ current_image_path = image.split("/")[3]+"/"+image.split("/")[4]
31
+ archive = zipfile.ZipFile('dataset.zip', 'r')
32
+ imgfile = archive.read(current_image_path)
33
+ image = cv2.imdecode(np.frombuffer(imgfile, np.uint8), 1)
34
+ # image_resized = cv2.resize(image, (224, 224), interpolation=cv2.INTER_LINEAR)
35
+ ax = fig.add_subplot(rows, cols, i + 1)
36
+ plt.axis('off')
37
+ plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
38
+ item_uuid = current_image_path.split("/")[1].split("_photoUUID")[0].split("itemUUID_")[1]
39
+ ax.set_title('Top {}'.format(i), fontsize=40)
40
+ names = names + "Top " + str(i) + " item UUID is " + item_uuid + "\n"
41
  return fig, names
42
 
43
  def knn_calc(image_name, query_feature, features):
 
55
  model = Resnet_with_skip(resnet)
56
  checkpoint = torch.load(checkpoint_path, map_location="cpu")
57
  model.load_state_dict(checkpoint)
58
+ model.eval()
59
  embedding_model_test = torch.nn.Sequential(*(list(model.children())[:-1]))
60
 
61
+ # periods_model = models.resnet101(pretrained=True)
62
+ # periods_model.fc = nn.Linear(num_ftrs, 5)
63
+ # periods_checkpoint = torch.load("periods.pth.tar", map_location="cpu")
64
+ # periods_model.load_state_dict(periods_checkpoint)
65
 
66
  with open('query_images_paths.pkl', 'rb') as fp:
67
  query_images_paths = pickle.load(fp)
 
73
 
74
  model.eval()
75
  transform = transforms.Compose([
76
+ transforms.Resize((224, 224)),
77
  transforms.Grayscale(num_output_channels=3),
78
  transforms.ToTensor(),
79
  transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
 
84
  std=[1., 1., 1.]),
85
  ])
86
 
87
+ # labels = ['ankh', 'anthropomorphic', 'bands', 'beetle', 'bird', 'circles', 'cross', 'duck', 'head', 'ibex', 'lion', 'sa', 'snake']
88
+ #
89
+ # periods_labels = ["MB1", "MB2", "LB", "Iron1", 'Iron2']
90
+ # periods_model.eval()
91
 
92
  def predict(inp):
93
  image_tensor = transform(inp)
94
  with torch.no_grad():
95
+ # classification, reconstruction = model(image_tensor.unsqueeze(0))
96
+ # periods_classification = periods_model(image_tensor.unsqueeze(0))
97
+ # recon_tensor = reconstruction[0].repeat(3, 1, 1)
98
+ # recon_tensor = invTrans(kornia.enhance.invert(recon_tensor))
99
+ # plot_recon = recon_tensor.permute(1, 2, 0).detach().numpy()
100
+ # w, h = inp.size
101
+ # m = nn.Sigmoid()
102
+ # y = m(classification)
103
+ # preds = []
104
+ # for sample in y:
105
+ # for i in sample:
106
+ # if i >=0.8:
107
+ # preds.append(1)
108
+ # else:
109
+ # preds.append(0)
110
+ # confidences = {}
111
+ # true_labels = ""
112
+ # for i in range(len(labels)):
113
+ # if preds[i]==1:
114
+ # if true_labels=="":
115
+ # true_labels = true_labels + labels[i]
116
+ # else:
117
+ # true_labels = true_labels + "&" + labels[i]
118
+ # confidences[true_labels] = torch.tensor(1.0)
119
+ #
120
+ # periods_prediction = torch.nn.functional.softmax(periods_classification[0], dim=0)
121
+ # periods_confidences = {periods_labels[i]: periods_prediction[i] for i in range(len(periods_labels))}
 
122
  feature = embedding_model_test(image_tensor.unsqueeze(0))
123
  dists = dict()
124
  with torch.no_grad():
 
127
  dists[image_name] = dist
128
  res = dict(sorted(dists.items(), key=itemgetter(1)))
129
  fig, names = create_retrieval_figure(res)
130
+ return fig, names
131
 
132
 
133
  a = gr.Interface(fn=predict,
 
135
  title="ArcAid: Analysis of Archaeological Artifacts using Drawings",
136
  description="Easily classify artifacs, retrieve similar ones and generate drawings. "
137
  "https://arxiv.org/abs/2211.09480.",
138
+ examples=['anth.jpg', 'beetle_snakes.jpg', 'bird.jpg', 'cross.jpg', 'ibex.jpg',
139
+ 'lion.jpg', 'lion2.jpg', 'sa.jpg'],
140
+ outputs=['plot', 'text'], ).launch(share=True, enable_queue=True)
141
+ # outputs=[gr.Label(num_top_classes=3), gr.Label(num_top_classes=1), "image", 'plot', 'text'], ).launch(share=True, enable_queue=True)