offry committed
Commit c08cf9f · 1 Parent(s): cf05ca8
Files changed (1)
  1. app.py +140 -0
app.py ADDED
@@ -0,0 +1,140 @@
+ import pickle
+ from operator import itemgetter
+
+ import cv2
+ import gradio as gr
+ import kornia.enhance
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import zipfile
+ # from skimage.transform import resize
+ from torchvision import transforms, models
+ from get_models import Resnet_with_skip
+
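+ # Build a matplotlib figure showing the top-10 most similar scarab images
+ # (read from dataset.zip) and collect their item UUIDs into a text summary.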
+ def create_retrieval_figure(res):
+     fig = plt.figure(figsize=[10 * 3, 10 * 3])
+     cols = 5
+     rows = 2
+     ax_query = fig.add_subplot(rows, 1, 1)
+     plt.rcParams['figure.facecolor'] = 'white'
+     plt.axis('off')
+     ax_query.set_title('Top 10 most similar scarabs', fontsize=40)
+     names = ""
+     for i, image in zip(range(len(res)), res):
+         current_image_path = image.split("/")[3] + "/" + image.split("/")[4]
+         if i == 0:
+             continue
+         if i < 11:
+             archive = zipfile.ZipFile('dataset.zip', 'r')
+             current_image_path = current_image_path.split(".")[0] + ".gif"
+             imgfile = archive.read(current_image_path)
+             image = cv2.imdecode(np.frombuffer(imgfile, np.uint8), 1)
+             # image_resized = cv2.resize(image, (224, 224), interpolation=cv2.INTER_LINEAR)
+             ax = fig.add_subplot(rows, cols, i)
+             plt.axis('off')
+             plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
+             item_uuid = current_image_path.split("/")[1].split("_photoUUID")[0].split("itemUUID_")[1]
+             ax.set_title('Top {}'.format(i), fontsize=40)
+             names = names + "Top " + str(i) + " item UUID is " + item_uuid + "\n"
+     return fig, names
+
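+ # Negated cosine similarity between the query embedding and a stored one,
+ # so that sorting the distances in ascending order ranks the most similar images first.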
+ def knn_calc(image_name, query_feature, features):
+     current_image_feature = features[image_name]
+     criterion = torch.nn.CosineSimilarity(dim=1)
+     dist = criterion(query_feature, current_image_feature).mean()
+     dist = -dist.item()
+     return dist
+
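+ # Multi-label symbol classifier: ResNet-101 with a 13-way head wrapped in Resnet_with_skip;
+ # a truncated copy of it (embedding_model_test) serves as the feature extractor for retrieval.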
+ checkpoint_path = "multi_label.pth.tar"
+
+ resnet = models.resnet101(pretrained=True)
+ num_ftrs = resnet.fc.in_features
+ resnet.fc = nn.Linear(num_ftrs, 13)
+ model = Resnet_with_skip(resnet)
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
+ model.load_state_dict(checkpoint)
+ embedding_model_test = torch.nn.Sequential(*(list(model.children())[:-1]))
+
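+ # Separate ResNet-101 classifier for the archaeological period (5 classes).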
+ periods_model = models.resnet101(pretrained=True)
+ periods_model.fc = nn.Linear(num_ftrs, 5)
+ periods_checkpoint = torch.load("periods.pth.tar", map_location="cpu")
+ periods_model.load_state_dict(periods_checkpoint)
+
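+ # Precomputed retrieval index: paths of the catalogue images and their embeddings.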
+ with open('query_images_paths.pkl', 'rb') as fp:
+     query_images_paths = pickle.load(fp)
+
+ with open('features.pkl', 'rb') as fp:
+     features = pickle.load(fp)
+
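+ # Put the classifier in eval mode and define the inference-time preprocessing: 224x224,
+ # 3-channel grayscale, normalized to [-1, 1]; invTrans undoes the normalization for display.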
+ model.eval()
+ transform = transforms.Compose([
+     transforms.Resize((224, 224)),
+     transforms.Grayscale(num_output_channels=3),
+     transforms.ToTensor(),
+     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
+ ])
+ invTrans = transforms.Compose([
+     transforms.Normalize(mean=[0., 0., 0.], std=[1 / 0.5, 1 / 0.5, 1 / 0.5]),
+     transforms.Normalize(mean=[-0.5, -0.5, -0.5], std=[1., 1., 1.]),
+ ])
+
+ labels = ['ankh', 'anthropomorphic', 'bands', 'beetle', 'bird', 'circles', 'cross', 'duck', 'head', 'ibex', 'lion', 'sa', 'snake']
+
+ periods_labels = ["MB1", "MB2", "LB", "Iron1", "Iron2"]
+ periods_model.eval()
+
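+ # Gradio callback: classifies the drawn symbols and the period, generates the inverted
+ # reconstruction for display, and retrieves the most similar catalogued items.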
+ def predict(inp):
+     image_tensor = transform(inp)
+     with torch.no_grad():
+         classification, reconstruction = model(image_tensor.unsqueeze(0))
+         periods_classification = periods_model(image_tensor.unsqueeze(0))
+     recon_tensor = reconstruction[0].repeat(3, 1, 1)
+     recon_tensor = invTrans(kornia.enhance.invert(recon_tensor))
+     plot_recon = recon_tensor.permute(1, 2, 0).detach().numpy()
+     w, h = inp.size
+     # plot_recon = resize(plot_recon, (h, w))
+     m = nn.Sigmoid()
+     y = m(classification)
+     preds = []
+     for sample in y:
+         for i in sample:
+             if i >= 0.8:
+                 preds.append(1)
+             else:
+                 preds.append(0)
+     confidences = {}
+     true_labels = ""
+     for i in range(len(labels)):
+         if preds[i] == 1:
+             if true_labels == "":
+                 true_labels = true_labels + labels[i]
+             else:
+                 true_labels = true_labels + "&" + labels[i]
+     confidences[true_labels] = torch.tensor(1.0)
+
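+     # Period prediction (softmax over the five periods) and embedding-based retrieval.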
+     periods_prediction = torch.nn.functional.softmax(periods_classification[0], dim=0)
+     periods_confidences = {periods_labels[i]: periods_prediction[i] for i in range(len(periods_labels))}
+     feature = embedding_model_test(image_tensor.unsqueeze(0))
+     dists = dict()
+     with torch.no_grad():
+         for i, image_name in enumerate(query_images_paths):
+             dist = knn_calc(image_name, feature, features)
+             dists[image_name] = dist
+     res = dict(sorted(dists.items(), key=itemgetter(1)))
+     fig, names = create_retrieval_figure(res)
+     return confidences, periods_confidences, plot_recon, fig, names
+
+
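+ # Gradio UI: image input; outputs are the symbol labels, the period, the reconstructed
+ # drawing, the retrieval figure, and the retrieved item UUIDs as text.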
+ a = gr.Interface(fn=predict,
+                  inputs=gr.Image(type="pil"),
+                  title="ArcAid: Analysis of Archaeological Artifacts using Drawings",
+                  description="Easily classify artifacts, retrieve similar ones and generate drawings. "
+                              "https://arxiv.org/abs/2211.09480.",
+                  # examples=['anth.jpg', 'beetle_snakes.jpg', 'bird.jpg', 'cross.jpg', 'ibex.jpg',
+                  #           'lion.jpg', 'lion2.jpg', 'sa.jpg'],
+                  outputs=[gr.Label(num_top_classes=3), gr.Label(num_top_classes=1), "image", "plot", "text"]).launch(share=True, enable_queue=True)