offry committed on
Commit
c6b8c55
·
1 Parent(s): 4a1b3a5

Create app.py

Files changed (1)
  1. app.py +426 -0
app.py ADDED
@@ -0,0 +1,426 @@
import math
import os
import pickle
import random
from operator import itemgetter

import cv2
import gradio as gr
import kornia.enhance
import kornia.filters
import matplotlib.pyplot as plt
import numpy as np
import scipy.ndimage
import torch
import torch.nn as nn
import torch.nn.functional as F
from skimage.transform import resize
from torchvision import transforms, models


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class DoubleConv(nn.Module):
    """(convolution => [IN] => ReLU) * 2, with a residual connection"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels

        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.inst1 = nn.InstanceNorm2d(mid_channels)
        # self.gn1 = nn.GroupNorm(4, mid_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.inst2 = nn.InstanceNorm2d(out_channels)
        # self.gn2 = nn.GroupNorm(4, out_channels)
        # 1x1 projection so the residual addition matches channel counts
        self.downsample = None
        if in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        # out = self.bn1(out)
        out = self.inst1(out)
        # out = self.gn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        # out = self.bn2(out)
        out = self.inst2(out)
        # out = self.gn2(out)
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            if in_channels == out_channels:
                self.up = nn.Identity()
            else:
                self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is BCHW; pad (or crop, via negative padding) x1 so its spatial
        # size matches the skip connection x2
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class GaussianLayer(nn.Module):
    def __init__(self):
        super(GaussianLayer, self).__init__()
        self.seq = nn.Sequential(
            # nn.ReflectionPad2d(10),
            nn.Conv2d(1, 1, 5, stride=1, padding=2, bias=False)
        )
        self.weights_init()

    def forward(self, x):
        return self.seq(x)

    def weights_init(self):
        # Build a 5x5 Gaussian kernel by filtering a unit impulse placed at
        # the kernel centre (index 2, not 3, on a 5x5 grid).
        n = np.zeros((5, 5))
        n[2, 2] = 1
        k = scipy.ndimage.gaussian_filter(n, sigma=1)
        for name, f in self.named_parameters():
            f.data.copy_(torch.from_numpy(k))

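
# kornia ships an equivalent fixed Gaussian blur; a one-line sketch (not used
# here, since the hand-rolled layer above keeps checkpoint compatibility):
#   blurred = kornia.filters.gaussian_blur2d(x, kernel_size=(5, 5), sigma=(1.0, 1.0))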


class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.up1 = Up(2048, 1024 // 1, False)
        self.up2 = Up(1024, 512 // 1, False)
        self.up3 = Up(512, 256 // 1, False)
        self.conv2d_2_1 = conv3x3(256, 128)
        self.gn1 = nn.GroupNorm(4, 128)
        self.instance1 = nn.InstanceNorm2d(128)
        self.up4 = Up(128, 64 // 1, False)
        self.upsample4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # self.upsample4 = nn.ConvTranspose2d(64, 64, 2, stride=2)
        self.upsample4_conv = DoubleConv(64, 64, 64 // 2)
        self.up_ = Up(128, 128 // 1, False)
        self.conv2d_2_2 = conv3x3(128, 6)
        self.instance2 = nn.InstanceNorm2d(6)
        self.gn2 = nn.GroupNorm(3, 6)
        self.up5 = Up(6, 3, False)
        self.conv2d_2_3 = conv3x3(3, 1)
        self.instance3 = nn.InstanceNorm2d(1)
        self.gaussian_blur = GaussianLayer()
        # Bank of eight 3x3 kernels: each keeps the centre pixel and subtracts
        # a randomly weighted neighbour in one of the 8 compass directions.
        self.kernel = nn.Parameter(torch.tensor(
            [[[0.0, 0.0, 0.0], [0.0, 1.0, random.uniform(-1.0, 0.0)], [0.0, 0.0, 0.0]],
             [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, random.uniform(-1.0, 0.0)]],
             [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, random.uniform(random.uniform(-1.0, 0.0), -0.0), 0.0]],
             [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [random.uniform(-1.0, 0.0), 0.0, 0.0]],
             [[0.0, 0.0, 0.0], [random.uniform(-1.0, 0.0), 1.0, 0.0], [0.0, 0.0, 0.0]],
             [[random.uniform(-1.0, 0.0), 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
             [[0.0, random.uniform(-1.0, 0.0), 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
             [[0.0, 0.0, random.uniform(-1.0, 0.0)], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]],
        ).unsqueeze(1))

        self.nms_conv = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False, groups=1)
        with torch.no_grad():
            # Tie the conv weight to the kernel bank (an 8x1x3x3 parameter), so
            # nms_conv maps a 1-channel map to 8 directional responses.
            self.nms_conv.weight = self.kernel.float()


class Resnet_with_skip(nn.Module):
    def __init__(self, model):
        super(Resnet_with_skip, self).__init__()
        self.model = model
        self.decoder = Decoder()

    def forward_pred(self, image):
        pred_net = self.model(image)
        return pred_net

    def forward_decode(self, image):
        identity = image

        image = self.model.conv1(image)
        image = self.model.bn1(image)
        image = self.model.relu(image)
        image1 = self.model.maxpool(image)

        image2 = self.model.layer1(image1)
        image3 = self.model.layer2(image2)
        image4 = self.model.layer3(image3)
        image5 = self.model.layer4(image4)

        reconst1 = self.decoder.up1(image5, image4)
        reconst2 = self.decoder.up2(reconst1, image3)
        reconst3 = self.decoder.up3(reconst2, image2)
        reconst = self.decoder.conv2d_2_1(reconst3)
        # reconst = self.decoder.instance1(reconst)
        reconst = self.decoder.gn1(reconst)
        reconst = F.relu(reconst)
        reconst4 = self.decoder.up4(reconst, image1)
        reconst5 = self.decoder.upsample4(reconst4)
        # reconst5 = self.decoder.upsample4_conv(reconst4)
        reconst5 = self.decoder.up_(reconst5, image)
        # reconst5 = reconst5 + image
        reconst5 = self.decoder.conv2d_2_2(reconst5)
        reconst5 = self.decoder.instance2(reconst5)
        # reconst5 = self.decoder.gn2(reconst5)
        reconst5 = F.relu(reconst5)
        reconst = self.decoder.up5(reconst5, identity)
        reconst = self.decoder.conv2d_2_3(reconst)
        # reconst = self.decoder.instance3(reconst)
        reconst = F.relu(reconst)
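
        # What follows is Canny-style edge thinning on the reconstruction:
        # Gaussian blur -> Sobel gradients -> gradient angles quantised to the
        # nearest 45 degrees -> eight directional responses via nms_conv ->
        # non-maximum suppression by comparing the responses along and against
        # the gradient direction -> soft threshold and gamma correction.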
        blurred = self.decoder.gaussian_blur(reconst)

        gradients = kornia.filters.spatial_gradient(blurred, normalized=False)
        # Unpack the edges
        gx = gradients[:, :, 0]
        gy = gradients[:, :, 1]

        angle = torch.atan2(gy, gx)

        # Radians to degrees
        angle = 180.0 * angle / math.pi

        # Round angle to the nearest 45 degrees
        angle = torch.round(angle / 45) * 45
        nms_magnitude = self.decoder.nms_conv(blurred)

        # Non-maximum suppression
        # Get the channel indices for both directions
        positive_idx = (angle / 45) % 8
        positive_idx = positive_idx.long()

        negative_idx = ((angle / 45) + 4) % 8
        negative_idx = negative_idx.long()

        # Apply the non-maximum suppression to the different directions
        channel_select_filtered_positive = torch.gather(nms_magnitude, 1, positive_idx)
        channel_select_filtered_negative = torch.gather(nms_magnitude, 1, negative_idx)

        channel_select_filtered = torch.stack(
            [channel_select_filtered_positive, channel_select_filtered_negative], 1
        )

        # Keep a pixel only to the extent that it dominates both of its
        # neighbours along the gradient, with a floor of 0.01 so the product
        # below never zeroes out entirely.
        thresh = nn.Threshold(0.01, 0.01)
        max_matrix = channel_select_filtered.min(dim=1)[0]
        max_matrix = thresh(max_matrix)
        magnitude = torch.mul(reconst, max_matrix)
        magnitude = kornia.enhance.adjust_gamma(magnitude, 2.0)
        return magnitude

    def forward(self, image):
        reconst = self.forward_decode(image)
        pred = self.forward_pred(image)
        return pred, reconst

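# A minimal shape smoke test for the skip model (a hypothetical sketch, not
# executed by the app; assumes a 13-class ResNet-101 backbone and 224x224 input):
#   m = Resnet_with_skip(models.resnet101(num_classes=13))
#   logits, edges = m(torch.randn(1, 3, 224, 224))
#   # logits: (1, 13); edges: (1, 1, 224, 224)
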

def create_retrieval_figure(res):
    fig = plt.figure(figsize=[10 * 3, 10 * 3])
    cols = 5
    rows = 2
    ax_query = fig.add_subplot(rows, 1, 1)
    plt.rcParams['figure.facecolor'] = 'white'
    plt.axis('off')
    ax_query.set_title('Top 10 most similar scarabs', fontsize=40)
    names = ""
    # res is ordered most-similar first; index 0 is the query image itself,
    # so it is skipped and the next ten results are drawn on a 2x5 grid.
    for i, current_image_path in enumerate(res):
        if i == 0:
            continue
        if i < 11:
            image = cv2.imread(current_image_path)
            ax = fig.add_subplot(rows, cols, i)
            plt.axis('off')
            plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
            item_uuid = current_image_path.split("/")[4].split("_photoUUID")[0].split("itemUUID_")[1]
            ax.set_title('Top {}'.format(i), fontsize=40)
            names = names + "Top " + str(i) + " item UUID is " + item_uuid + "\n"

    return fig, names


def knn_calc(image_name, query_feature, features):
    # Cosine similarity, negated so that sorting ascending puts the most
    # similar images first.
    current_image_feature = features[image_name]
    criterion = torch.nn.CosineSimilarity(dim=1)
    dist = criterion(query_feature, current_image_feature).mean()
    dist = -dist.item()
    return dist

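# 'features.pkl' is assumed to map image path -> embedding tensor. A minimal
# sketch of how it could be precomputed with the same embedding model
# (hypothetical helper, not part of this app):
#
#   from PIL import Image
#   features = {}
#   with torch.no_grad():
#       for path in query_images_paths:
#           img = transform(Image.open(path)).unsqueeze(0).to(device)
#           features[path] = embedding_model_test(img)
#   with open('features.pkl', 'wb') as fp:
#       pickle.dump(features, fp)
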
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = 'cpu'

# The training-time checkpoint lived under ../shapes_classification/checkpoints/;
# the deployed app loads the bundled weights instead.
checkpoint_path = "multi_label.pth.tar"

resnet = models.resnet101(pretrained=True)
num_ftrs = resnet.fc.in_features
resnet.fc = nn.Linear(num_ftrs, 13)
model = Resnet_with_skip(resnet).to(device)
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint)
embedding_model_test = torch.nn.Sequential(*(list(model.children())[:-1]))
embedding_model_test.to(device)

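# list(model.children())[:-1] drops the Decoder, leaving the classification
# ResNet; its 13-dim output doubles as the retrieval embedding compared
# against the precomputed features.
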
periods_model = models.resnet101(pretrained=True)
periods_model.fc = nn.Linear(num_ftrs, 5)
periods_checkpoint = torch.load("periods.pth.tar", map_location="cpu")
periods_model.load_state_dict(periods_checkpoint)
periods_model.to(device)

data_dir = "../cssl_dataset/all_image_base/1/"
query_images_paths = []
for path in os.listdir(data_dir):
    query_images_paths.append(os.path.join(data_dir, path))

with open('features.pkl', 'rb') as fp:
    features = pickle.load(fp)

model.eval()
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
invTrans = transforms.Compose([
    transforms.Normalize(mean=[0., 0., 0.], std=[1 / 0.5, 1 / 0.5, 1 / 0.5]),
    transforms.Normalize(mean=[-0.5, -0.5, -0.5], std=[1., 1., 1.]),
])

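# invTrans undoes Normalize(mean=0.5, std=0.5): the forward transform computes
# x' = (x - 0.5) / 0.5, so the inverse x = x' * 0.5 + 0.5 is expressed as two
# stacked Normalize ops (first divide by 1/std, then subtract -mean).
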
labels = sorted(os.listdir("../cssl_dataset/shape_multi_label/photos"))
periods_labels = ["MB1", "MB2", "LB", "Iron1", "Iron2"]
periods_model.eval()


def predict(inp):
    image_tensor = transform(inp)
    image_tensor = image_tensor.to(device)
    with torch.no_grad():
        classification, reconstruction = model(image_tensor.unsqueeze(0))
        periods_classification = periods_model(image_tensor.unsqueeze(0))
    recon_tensor = reconstruction[0].repeat(3, 1, 1)
    recon_tensor = invTrans(kornia.enhance.invert(recon_tensor))
    plot_recon = recon_tensor.to("cpu").permute(1, 2, 0).detach().numpy()
    w, h = inp.size
    plot_recon = resize(plot_recon, (h, w))
    m = nn.Sigmoid()
    y = m(classification)
    # Multi-label thresholding: any class whose sigmoid score reaches 0.8
    # counts as a positive prediction.
    preds = []
    for sample in y:
        for i in sample:
            if i >= 0.8:
                preds.append(1)
            else:
                preds.append(0)
    confidences = {}
    true_labels = ""
    for i in range(len(labels)):
        if preds[i] == 1:
            if true_labels == "":
                true_labels = true_labels + labels[i]
            else:
                true_labels = true_labels + "&" + labels[i]
    confidences[true_labels] = torch.tensor(1.0).to(device)

    periods_prediction = torch.nn.functional.softmax(periods_classification[0], dim=0)
    periods_confidences = {periods_labels[i]: periods_prediction[i] for i in range(len(periods_labels))}
    feature = embedding_model_test(image_tensor.unsqueeze(0)).to(device)
    dists = dict()
    with torch.no_grad():
        for image_name in query_images_paths:
            dist = knn_calc(image_name, feature, features)
            dists[image_name] = dist
    res = dict(sorted(dists.items(), key=itemgetter(1)))
    fig, names = create_retrieval_figure(res)
    return fig, names, plot_recon, confidences, periods_confidences


gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=['plot', 'text', 'image', gr.Label(num_top_classes=1), gr.Label(num_top_classes=1)],
).launch(share=True)
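
# Note: share=True creates a temporary public tunnel when the script runs
# locally; on a hosted Space the app is served directly, so a plain .launch()
# would also work.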