offry committed on
Commit feaac02 · 1 Parent(s): c08cf9f
Files changed (2)
  1. get_models.py +283 -0
  2. use_gradio.py +191 -0
get_models.py ADDED
@@ -0,0 +1,283 @@
+ import math
+ import random
+
+ import kornia.enhance
+ import kornia.filters
+ import numpy as np
+ import scipy.ndimage
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+     """3x3 convolution with padding"""
+     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                      padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+ def conv1x1(in_planes, out_planes, stride=1):
+     """1x1 convolution"""
+     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+ class DoubleConv(nn.Module):
+     """(convolution => [IN] => ReLU) * 2, with a residual connection"""
+
+     def __init__(self, in_channels, out_channels, mid_channels=None):
+         super().__init__()
+         if not mid_channels:
+             mid_channels = out_channels
+
+         self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(mid_channels)
+         self.inst1 = nn.InstanceNorm2d(mid_channels)
+         # self.gn1 = nn.GroupNorm(4, mid_channels)
+         self.relu = nn.ReLU(inplace=True)
+         self.conv2 = nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False)
+         self.bn2 = nn.BatchNorm2d(out_channels)
+         self.inst2 = nn.InstanceNorm2d(out_channels)
+         # self.gn2 = nn.GroupNorm(4, out_channels)
+         # 1x1 projection so the residual matches when channel counts differ
+         self.downsample = None
+         if in_channels != out_channels:
+             self.downsample = nn.Sequential(
+                 nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
+                 nn.BatchNorm2d(out_channels),
+             )
+
+     def forward(self, x):
+         identity = x
+
+         out = self.conv1(x)
+         # out = self.bn1(out)
+         out = self.inst1(out)
+         # out = self.gn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         # out = self.bn2(out)
+         out = self.inst2(out)
+         # out = self.gn2(out)
+         if self.downsample is not None:
+             identity = self.downsample(x)
+
+         out += identity
+         out = self.relu(out)
+         return out
+
+
+ class Down(nn.Module):
+     """Downscaling with maxpool then double conv"""
+
+     def __init__(self, in_channels, out_channels):
+         super().__init__()
+         self.maxpool_conv = nn.Sequential(
+             nn.MaxPool2d(2),
+             DoubleConv(in_channels, out_channels)
+         )
+
+     def forward(self, x):
+         return self.maxpool_conv(x)
+
+
+ class Up(nn.Module):
+     """Upscaling then double conv"""
+
+     def __init__(self, in_channels, out_channels, bilinear=True):
+         super().__init__()
+
+         # if bilinear, use the normal convolutions to reduce the number of channels
+         if bilinear:
+             self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+             self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
+         else:
+             if in_channels == out_channels:
+                 self.up = nn.Identity()
+             else:
+                 self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
+             self.conv = DoubleConv(in_channels, out_channels)
+
+     def forward(self, x1, x2):
+         x1 = self.up(x1)
+         # inputs are NCHW; pad x1 to match x2 (negative diffs crop instead)
+         diffY = x2.size()[2] - x1.size()[2]
+         diffX = x2.size()[3] - x1.size()[3]
+
+         x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
+                         diffY // 2, diffY - diffY // 2])
+         # if you have padding issues, see
+         # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
+         # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
+         x = torch.cat([x2, x1], dim=1)
+         return self.conv(x)
+
+
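+ # Shape example for Up (illustration only, not used by the model): with
+ # bilinear=False and in_channels != out_channels, a 64x64 map is
+ # transposed-conv upsampled to 128x128; if the skip tensor x2 is 120x120,
+ # diffY = diffX = -8, so the F.pad call with negative values crops x1
+ # back to 120x120 before the concatenation.
+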
+ class OutConv(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(OutConv, self).__init__()
+         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
+
+     def forward(self, x):
+         return self.conv(x)
+
+
+ class GaussianLayer(nn.Module):
+     """Fixed 5x5 Gaussian blur implemented as a non-trainable convolution."""
+
+     def __init__(self):
+         super(GaussianLayer, self).__init__()
+         self.seq = nn.Sequential(
+             # nn.ReflectionPad2d(10),
+             nn.Conv2d(1, 1, 5, stride=1, padding=2, bias=False)
+         )
+         self.weights_init()
+
+     def forward(self, x):
+         return self.seq(x)
+
+     def weights_init(self):
+         # Build the 5x5 Gaussian kernel by blurring a unit impulse at the centre
+         n = np.zeros((5, 5))
+         n[2, 2] = 1
+         k = scipy.ndimage.gaussian_filter(n, sigma=1)
+         for name, f in self.named_parameters():
+             f.data.copy_(torch.from_numpy(k))
+
+
+ class Decoder(nn.Module):
+     def __init__(self):
+         super(Decoder, self).__init__()
+         self.up1 = Up(2048, 1024, False)
+         self.up2 = Up(1024, 512, False)
+         self.up3 = Up(512, 256, False)
+         self.conv2d_2_1 = conv3x3(256, 128)
+         self.gn1 = nn.GroupNorm(4, 128)
+         self.instance1 = nn.InstanceNorm2d(128)
+         self.up4 = Up(128, 64, False)
+         self.upsample4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+         # self.upsample4 = nn.ConvTranspose2d(64, 64, 2, stride=2)
+         self.upsample4_conv = DoubleConv(64, 64, 64 // 2)
+         self.up_ = Up(128, 128, False)
+         self.conv2d_2_2 = conv3x3(128, 6)
+         self.instance2 = nn.InstanceNorm2d(6)
+         self.gn2 = nn.GroupNorm(3, 6)
+         self.up5 = Up(6, 3, False)
+         self.conv2d_2_3 = conv3x3(3, 1)
+         self.instance3 = nn.InstanceNorm2d(1)
+         self.gaussian_blur = GaussianLayer()
+         # Eight directional 3x3 kernels, one per 45-degree step: each keeps the
+         # centre pixel and subtracts a random fraction of one neighbour.
+         self.kernel = nn.Parameter(torch.tensor(
+             [[[0.0, 0.0, 0.0], [0.0, 1.0, random.uniform(-1.0, 0.0)], [0.0, 0.0, 0.0]],
+              [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, random.uniform(-1.0, 0.0)]],
+              [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, random.uniform(-1.0, 0.0), 0.0]],
+              [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [random.uniform(-1.0, 0.0), 0.0, 0.0]],
+              [[0.0, 0.0, 0.0], [random.uniform(-1.0, 0.0), 1.0, 0.0], [0.0, 0.0, 0.0]],
+              [[random.uniform(-1.0, 0.0), 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
+              [[0.0, random.uniform(-1.0, 0.0), 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
+              [[0.0, 0.0, random.uniform(-1.0, 0.0)], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]],
+         ).unsqueeze(1))
+
+         # One output channel per direction so the gather in forward_decode can
+         # select the response matching each pixel's gradient orientation.
+         self.nms_conv = nn.Conv2d(1, 8, kernel_size=3, stride=1, padding=1, bias=False)
+         with torch.no_grad():
+             self.nms_conv.weight = self.kernel
+
+
+ class Resnet_with_skip(nn.Module):
+     """ResNet classifier with a U-Net-style decoder over its feature maps."""
+
+     def __init__(self, model):
+         super(Resnet_with_skip, self).__init__()
+         self.model = model
+         self.decoder = Decoder()
+
+     def forward_pred(self, image):
+         pred_net = self.model(image)
+         return pred_net
+
+     def forward_decode(self, image):
+         identity = image
+
+         image = self.model.conv1(image)
+         image = self.model.bn1(image)
+         image = self.model.relu(image)
+         image1 = self.model.maxpool(image)
+
+         image2 = self.model.layer1(image1)
+         image3 = self.model.layer2(image2)
+         image4 = self.model.layer3(image3)
+         image5 = self.model.layer4(image4)
+
+         reconst1 = self.decoder.up1(image5, image4)
+         reconst2 = self.decoder.up2(reconst1, image3)
+         reconst3 = self.decoder.up3(reconst2, image2)
+         reconst = self.decoder.conv2d_2_1(reconst3)
+         # reconst = self.decoder.instance1(reconst)
+         reconst = self.decoder.gn1(reconst)
+         reconst = F.relu(reconst)
+         reconst4 = self.decoder.up4(reconst, image1)
+         reconst5 = self.decoder.upsample4(reconst4)
+         # reconst5 = self.decoder.upsample4_conv(reconst4)
+         reconst5 = self.decoder.up_(reconst5, image)
+         # reconst5 = reconst5 + image
+         reconst5 = self.decoder.conv2d_2_2(reconst5)
+         reconst5 = self.decoder.instance2(reconst5)
+         # reconst5 = self.decoder.gn2(reconst5)
+         reconst5 = F.relu(reconst5)
+         reconst = self.decoder.up5(reconst5, identity)
+         reconst = self.decoder.conv2d_2_3(reconst)
+         # reconst = self.decoder.instance3(reconst)
+         reconst = F.relu(reconst)
+
+         blurred = self.decoder.gaussian_blur(reconst)
+
+         gradients = kornia.filters.spatial_gradient(blurred, normalized=False)
+         # Unpack the x/y edge responses
+         gx = gradients[:, :, 0]
+         gy = gradients[:, :, 1]
+
+         angle = torch.atan2(gy, gx)
+
+         # Radians to degrees
+         angle = 180.0 * angle / math.pi
+
+         # Round angle to the nearest 45 degrees
+         angle = torch.round(angle / 45) * 45
+         nms_magnitude = self.decoder.nms_conv(blurred)
+         # nms_magnitude = F.conv2d(blurred, kernel.unsqueeze(1), padding=kernel.shape[-1]//2)
+
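+         # Worked example of the indexing below (illustration only): a pixel
+         # whose rounded gradient angle is 90 degrees gets
+         # positive_idx = (90 / 45) % 8 = 2 and negative_idx = (2 + 4) % 8 = 6,
+         # selecting the directional kernel along the gradient and the one
+         # opposite to it; negative angles wrap via the modulo,
+         # e.g. (-45 / 45) % 8 = 7.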
+         # Non-maximum suppression: get the channel indices for both directions
+         positive_idx = (angle / 45) % 8
+         positive_idx = positive_idx.long()
+
+         negative_idx = ((angle / 45) + 4) % 8
+         negative_idx = negative_idx.long()
+
+         # Apply the non-maximum suppression to the different directions
+         channel_select_filtered_positive = torch.gather(nms_magnitude, 1, positive_idx)
+         channel_select_filtered_negative = torch.gather(nms_magnitude, 1, negative_idx)
+
+         channel_select_filtered = torch.stack(
+             [channel_select_filtered_positive, channel_select_filtered_negative], 1
+         )
+
+         # is_max = channel_select_filtered.min(dim=1)[0] > 0.0
+         # magnitude = reconst * is_max
+
+         # A pixel survives only if it dominates in both directions
+         thresh = nn.Threshold(0.01, 0.01)
+         max_matrix = channel_select_filtered.min(dim=1)[0]
+         max_matrix = thresh(max_matrix)
+         magnitude = torch.mul(reconst, max_matrix)
+         # magnitude = torchvision.transforms.functional.invert(magnitude)
+         magnitude = kornia.enhance.adjust_gamma(magnitude, 2.0)
+         # magnitude = F.leaky_relu(magnitude)
+         return magnitude
+
+     def forward(self, image):
+         reconst = self.forward_decode(image)
+         pred = self.forward_pred(image)
+         return pred, reconst
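
A minimal usage sketch for Resnet_with_skip (an illustration, not part of the commit; it mirrors how use_gradio.py below wraps a 13-way ResNet-101, and assumes a normalized 224x224 input):

    import torch
    import torch.nn as nn
    from torchvision import models
    from get_models import Resnet_with_skip

    # Wrap a 13-way ResNet-101 classifier, as use_gradio.py does below.
    backbone = models.resnet101(pretrained=True)
    backbone.fc = nn.Linear(backbone.fc.in_features, 13)
    net = Resnet_with_skip(backbone).eval()

    x = torch.randn(1, 3, 224, 224)  # stand-in for a normalized RGB crop
    with torch.no_grad():
        logits, drawing = net(x)
    # logits:  (1, 13) multi-label class scores
    # drawing: (1, 1, 224, 224) edge-like reconstruction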
use_gradio.py ADDED
@@ -0,0 +1,191 @@
+ import io
+ import os
+ import pickle
+ from operator import itemgetter
+
+ import cv2
+ import gradio as gr
+ import kornia.enhance
+ import kornia.filters
+ import matplotlib.pyplot as plt
+ import requests
+ import torch
+ import torch.nn as nn
+ import torchvision.transforms.functional
+ from PIL import Image
+ from skimage.transform import resize
+ from torchvision import models, transforms
+
+ from utils_functions.imports import *
+
+ from util_models.resnet_with_skip import *
+ from util_models.densenet_with_skip import *
+ from util_models.glyphnet_with_skip import *
+
+
+ def create_retrieval_figure(res):
+     """Plot the query's ten nearest neighbours and collect their item UUIDs."""
+     fig = plt.figure(figsize=[10 * 3, 10 * 3])
+     cols = 5
+     rows = 2
+     ax_query = fig.add_subplot(rows, 1, 1)
+     plt.rcParams['figure.facecolor'] = 'white'
+     plt.axis('off')
+     ax_query.set_title('Top 10 most similar scarabs', fontsize=40)
+     names = ""
+     for i, current_image_path in enumerate(res):
+         if i == 0:  # skip the query image itself
+             continue
+         if i < 11:
+             image = cv2.imread(current_image_path)
+             ax = fig.add_subplot(rows, cols, i)
+             plt.axis('off')
+             plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
+             item_uuid = current_image_path.split("/")[4].split("_photoUUID")[0].split("itemUUID_")[1]
+             ax.set_title('Top {}'.format(i), fontsize=40)
+             names = names + "Top " + str(i) + " item UUID is " + item_uuid + "\n"
+     # img_buf = io.BytesIO()
+     # plt.savefig(img_buf, format='png')
+     # im_fig = Image.open(img_buf)
+     # img_buf.close()
+     # return im_fig
+     return fig, names
+
+
+ def knn_calc(image_name, query_feature, features):
+     current_image_feature = features[image_name].to(device)
+     criterion = torch.nn.CosineSimilarity(dim=1)
+     dist = criterion(query_feature, current_image_feature).mean()
+     # Negate so that sorting ascending puts the most similar images first
+     dist = -dist.item()
+     return dist
+
+
+ def return_all_features(model_test, query_images_paths, glyph=False):
+     """Embed every gallery image once; returns a dict mapping path -> feature."""
+     model_test.eval()
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     model_test.to(device)
+     features = dict()
+     i = 0
+     transform = transforms.Compose([
+         transforms.RandomApply([transforms.ToPILImage()], p=1),
+         transforms.Resize((224, 224)),
+         transforms.Grayscale(num_output_channels=3),
+         transforms.ToTensor(),
+         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
+     ])
+     gray_scale = transforms.Grayscale(num_output_channels=1)
+     with torch.no_grad():
+         for image_path in query_images_paths:
+             print(i)
+             i = i + 1
+             img = cv2.imread(image_path)
+             img = transform(img)
+             img = img.unsqueeze(0).contiguous().to(device)
+             if glyph:
+                 img = gray_scale(img)
+             current_image_features = model_test(img)
+             features[image_path] = current_image_features
+             del current_image_features
+             torch.cuda.empty_cache()
+     return features
+
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ # device = 'cpu'
+
+ experiment = "experiment_0"
+ # checkpoint_path = os.path.join("../shapes_classification/checkpoints/"
+ #                                "50_50_pretrained_resnet101_experiment_0_train_images_with_drawings_batch_8_10:29:06/" +
+ #                                "experiment_0_last_auto_model.pth.tar")
+ checkpoint_path = "multi_label.pth.tar"
+
+ resnet = models.resnet101(pretrained=True)
+ num_ftrs = resnet.fc.in_features
+ resnet.fc = nn.Linear(num_ftrs, 13)
+ model = Resnet_with_skip(resnet).to(device)
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
+ model.load_state_dict(checkpoint)
+ # Drop the decoder to get a pure feature extractor for retrieval
+ embedding_model_test = torch.nn.Sequential(*(list(model.children())[:-1]))
+ embedding_model_test.to(device)
+
+ periods_model = models.resnet101(pretrained=True)
+ periods_model.fc = nn.Linear(num_ftrs, 5)
+ periods_checkpoint = torch.load("periods.pth.tar", map_location="cpu")
+ periods_model.load_state_dict(periods_checkpoint)
+ periods_model.to(device)
+
+ data_dir = "../cssl_dataset/all_image_base/1/"
+ query_images_paths = []
+ for path in os.listdir(data_dir):
+     query_images_paths.append(os.path.join(data_dir, path))
+
+ # features = return_all_features(embedding_model_test, query_images_paths)
+ # with open('features.pkl', 'wb') as fp:
+ #     pickle.dump(features, fp)
+
+ with open('features.pkl', 'rb') as fp:
+     features = pickle.load(fp)
+
+ model.eval()
+ transform = transforms.Compose([
+     transforms.Resize((224, 224)),
+     transforms.Grayscale(num_output_channels=3),
+     transforms.ToTensor(),
+     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
+ ])
+ # Inverse of the Normalize above, to bring reconstructions back to [0, 1]
+ invTrans = transforms.Compose([transforms.Normalize(mean=[0., 0., 0.],
+                                                     std=[1 / 0.5, 1 / 0.5, 1 / 0.5]),
+                                transforms.Normalize(mean=[-0.5, -0.5, -0.5],
+                                                     std=[1., 1., 1.]),
+                                ])
+
+ labels = sorted(os.listdir("../cssl_dataset/shape_multi_label/photos"))
+ periods_labels = ["MB1", "MB2", "LB", "Iron1", "Iron2"]
+ periods_model.eval()
+
+ def predict(inp):
+     """Gradio callback: classify shape and period, reconstruct a drawing, retrieve similar items."""
+     image_tensor = transform(inp)
+     image_tensor = image_tensor.to(device)
+     with torch.no_grad():
+         classification, reconstruction = model(image_tensor.unsqueeze(0))
+         periods_classification = periods_model(image_tensor.unsqueeze(0))
+     recon_tensor = reconstruction[0].repeat(3, 1, 1)
+     recon_tensor = invTrans(kornia.enhance.invert(recon_tensor))
+     plot_recon = recon_tensor.to("cpu").permute(1, 2, 0).detach().numpy()
+     w, h = inp.size
+     plot_recon = resize(plot_recon, (h, w))
+     m = nn.Sigmoid()
+     y = m(classification)
+     # Threshold the per-class sigmoid scores into a multi-label prediction
+     preds = []
+     for sample in y:
+         for i in sample:
+             if i >= 0.8:
+                 preds.append(1)
+             else:
+                 preds.append(0)
+     confidences = {}
+     true_labels = ""
+     for i in range(len(labels)):
+         if preds[i] == 1:
+             if true_labels == "":
+                 true_labels = true_labels + labels[i]
+             else:
+                 true_labels = true_labels + "&" + labels[i]
+     confidences[true_labels] = torch.tensor(1.0).to(device)
+
+     periods_prediction = torch.nn.functional.softmax(periods_classification[0], dim=0)
+     periods_confidences = {periods_labels[i]: periods_prediction[i] for i in range(len(periods_labels))}
+     feature = embedding_model_test(image_tensor.unsqueeze(0)).to(device)
+     dists = dict()
+     with torch.no_grad():
+         for image_name in query_images_paths:
+             dist = knn_calc(image_name, feature, features)
+             dists[image_name] = dist
+     res = dict(sorted(dists.items(), key=itemgetter(1)))
+     fig, names = create_retrieval_figure(res)
+     return fig, names, plot_recon, confidences, periods_confidences
+
+
+ gr.Interface(fn=predict,
+              inputs=gr.Image(type="pil"),
+              outputs=['plot', 'text', 'image', gr.Label(num_top_classes=1), gr.Label(num_top_classes=1)],
+              ).launch(share=True)
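
The features.pkl loaded above must exist before the app starts. A one-time precomputation sketch based on the commented-out lines in this file (same embedding_model_test and query_images_paths as above):

    import pickle

    # Embed the whole gallery once and cache the features to disk so the
    # Gradio app only needs to unpickle them at startup.
    features = return_all_features(embedding_model_test, query_images_paths)
    with open('features.pkl', 'wb') as fp:
        pickle.dump(features, fp)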