offry committed on
Commit
4a508f2
·
1 Parent(s): 1a2aea2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -280
app.py CHANGED
@@ -1,297 +1,18 @@
1
- import os
2
  import pickle
3
  from operator import itemgetter
4
 
5
  import cv2
6
  import gradio as gr
7
  import kornia.filters
8
- import kornia.filters
9
- import scipy.ndimage
10
  import torch
11
  import torch.nn as nn
12
  import torch.nn.functional as F
13
  import numpy as np
14
  import matplotlib.pyplot as plt
15
- import random
16
  import zipfile
17
  # from skimage.transform import resize
18
  from torchvision import transforms, models
19
-
20
-
21
- def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
22
- """3x3 convolution with padding"""
23
- return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
24
- padding=dilation, groups=groups, bias=False, dilation=dilation)
25
-
26
-
27
- def conv1x1(in_planes, out_planes, stride=1):
28
- """1x1 convolution"""
29
- return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
30
-
31
-
32
- class DoubleConv(nn.Module):
33
- """(convolution => [BN] => ReLU) * 2"""
34
-
35
- def __init__(self, in_channels, out_channels, mid_channels=None):
36
- super().__init__()
37
- if not mid_channels:
38
- mid_channels = out_channels
39
- norm_layer = nn.BatchNorm2d
40
-
41
- self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False)
42
- self.bn1 = nn.BatchNorm2d(mid_channels)
43
- self.inst1 = nn.InstanceNorm2d(mid_channels)
44
- # self.gn1 = nn.GroupNorm(4, mid_channels)
45
- self.relu = nn.ReLU(inplace=True)
46
- self.conv2 = nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False)
47
- self.bn2 = nn.BatchNorm2d(out_channels)
48
- self.inst2 = nn.InstanceNorm2d(out_channels)
49
- # self.gn2 = nn.GroupNorm(4, out_channels)
50
- self.downsample = None
51
- if in_channels != out_channels:
52
- self.downsample = nn.Sequential(
53
- nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
54
- nn.BatchNorm2d(out_channels),
55
- )
56
-
57
- def forward(self, x):
58
- identity = x
59
-
60
- out = self.conv1(x)
61
- # out = self.bn1(out)
62
- out = self.inst1(out)
63
- # out = self.gn1(out)
64
- out = self.relu(out)
65
-
66
- out = self.conv2(out)
67
- # out = self.bn2(out)
68
- out = self.inst2(out)
69
- # out = self.gn2(out)
70
- if self.downsample is not None:
71
- identity = self.downsample(x)
72
-
73
- out += identity
74
- out = self.relu(out)
75
- return out
76
-
77
-
78
- class Down(nn.Module):
79
- """Downscaling with maxpool then double conv"""
80
-
81
- def __init__(self, in_channels, out_channels):
82
- super().__init__()
83
- self.maxpool_conv = nn.Sequential(
84
- nn.MaxPool2d(2),
85
- DoubleConv(in_channels, out_channels)
86
- )
87
-
88
- def forward(self, x):
89
- return self.maxpool_conv(x)
90
-
91
-
92
- class Up(nn.Module):
93
- """Upscaling then double conv"""
94
-
95
- def __init__(self, in_channels, out_channels, bilinear=True):
96
- super().__init__()
97
-
98
- # if bilinear, use the normal convolutions to reduce the number of channels
99
- if bilinear:
100
- self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
101
- self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
102
- else:
103
- if in_channels == out_channels:
104
- self.up = nn.Identity()
105
- else:
106
- self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
107
- self.conv = DoubleConv(in_channels, out_channels)
108
-
109
- def forward(self, x1, x2):
110
- x1 = self.up(x1)
111
- # input is CHW
112
- diffY = x2.size()[2] - x1.size()[2]
113
- diffX = x2.size()[3] - x1.size()[3]
114
-
115
- x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
116
- diffY // 2, diffY - diffY // 2])
117
- # if you have padding issues, see
118
- # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
119
- # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
120
- x = torch.cat([x2, x1], dim=1)
121
- return self.conv(x)
122
-
123
-
124
- class OutConv(nn.Module):
125
- def __init__(self, in_channels, out_channels):
126
- super(OutConv, self).__init__()
127
- self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
128
-
129
- def forward(self, x):
130
- return self.conv(x)
131
-
132
- class GaussianLayer(nn.Module):
133
- def __init__(self):
134
- super(GaussianLayer, self).__init__()
135
- self.seq = nn.Sequential(
136
- # nn.ReflectionPad2d(10),
137
- nn.Conv2d(1, 1, 5, stride=1, padding=2, bias=False)
138
- )
139
-
140
- self.weights_init()
141
- def forward(self, x):
142
- return self.seq(x)
143
-
144
- def weights_init(self):
145
- n= np.zeros((5,5))
146
- n[3,3] = 1
147
- k = scipy.ndimage.gaussian_filter(n,sigma=1)
148
- for name, f in self.named_parameters():
149
- f.data.copy_(torch.from_numpy(k))
150
-
151
- def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
152
- """3x3 convolution with padding"""
153
- return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
154
- padding=dilation, groups=groups, bias=False, dilation=dilation)
155
-
156
- class Decoder(nn.Module):
157
- def __init__(self):
158
- super(Decoder, self).__init__()
159
- self.up1 = Up(2048, 1024 // 1, False)
160
- self.up2 = Up(1024, 512 // 1, False)
161
- self.up3 = Up(512, 256 // 1, False)
162
- self.conv2d_2_1 = conv3x3(256, 128)
163
- self.gn1 = nn.GroupNorm(4, 128)
164
- self.instance1 = nn.InstanceNorm2d(128)
165
- self.up4 = Up(128, 64 // 1, False)
166
- self.upsample4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
167
- # self.upsample4 = nn.ConvTranspose2d(64, 64, 2, stride=2)
168
- self.upsample4_conv = DoubleConv(64, 64, 64 // 2)
169
- self.up_ = Up(128, 128 // 1, False)
170
- self.conv2d_2_2 = conv3x3(128, 6)
171
- self.instance2 = nn.InstanceNorm2d(6)
172
- self.gn2 = nn.GroupNorm(3, 6)
173
- self.gaussian_blur = GaussianLayer()
174
- self.up5 = Up(6, 3, False)
175
- self.conv2d_2_3 = conv3x3(3, 1)
176
- self.instance3 = nn.InstanceNorm2d(1)
177
- self.gaussian_blur = GaussianLayer()
178
- self.kernel = nn.Parameter(torch.tensor(
179
- [[[0.0, 0.0, 0.0], [0.0, 1.0, random.uniform(-1.0, 0.0)], [0.0, 0.0, 0.0]],
180
- [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, random.uniform(-1.0, 0.0)]],
181
- [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, random.uniform(random.uniform(-1.0, 0.0), -0.0), 0.0]],
182
- [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [random.uniform(-1.0, 0.0), 0.0, 0.0]],
183
- [[0.0, 0.0, 0.0], [random.uniform(-1.0, 0.0), 1.0, 0.0], [0.0, 0.0, 0.0]],
184
- [[random.uniform(-1.0, 0.0), 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
185
- [[0.0, random.uniform(-1.0, 0.0), 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
186
- [[0.0, 0.0, random.uniform(-1.0, 0.0)], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], ],
187
- ).unsqueeze(1))
188
-
189
- self.nms_conv = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False, groups=1)
190
- with torch.no_grad():
191
- self.nms_conv.weight = self.kernel.float()
192
-
193
-
194
- class Resnet_with_skip(nn.Module):
195
- def __init__(self, model):
196
- super(Resnet_with_skip, self).__init__()
197
- self.model = model
198
- self.decoder = Decoder()
199
-
200
- def forward_pred(self, image):
201
- pred_net = self.model(image)
202
- return pred_net
203
-
204
- def forward_decode(self, image):
205
- identity = image
206
-
207
- image = self.model.conv1(image)
208
- image = self.model.bn1(image)
209
- image = self.model.relu(image)
210
- image1 = self.model.maxpool(image)
211
-
212
- image2 = self.model.layer1(image1)
213
- image3 = self.model.layer2(image2)
214
- image4 = self.model.layer3(image3)
215
- image5 = self.model.layer4(image4)
216
-
217
- reconst1 = self.decoder.up1(image5, image4)
218
- reconst2 = self.decoder.up2(reconst1, image3)
219
- reconst3 = self.decoder.up3(reconst2, image2)
220
- reconst = self.decoder.conv2d_2_1(reconst3)
221
- # reconst = self.decoder.instance1(reconst)
222
- reconst = self.decoder.gn1(reconst)
223
- reconst = F.relu(reconst)
224
- reconst4 = self.decoder.up4(reconst, image1)
225
- # reconst5 = self.decoder.upsample4(reconst4)
226
- reconst5 = self.decoder.upsample4(reconst4)
227
- # reconst5 = self.decoder.upsample4_conv(reconst4)
228
- reconst5 = self.decoder.up_(reconst5, image)
229
- # reconst5 = reconst5 + image
230
- reconst5 = self.decoder.conv2d_2_2(reconst5)
231
- reconst5 = self.decoder.instance2(reconst5)
232
- # reconst5 = self.decoder.gn2(reconst5)
233
- reconst5 = F.relu(reconst5)
234
- reconst = self.decoder.up5(reconst5, identity)
235
- reconst = self.decoder.conv2d_2_3(reconst)
236
- # reconst = self.decoder.instance3(reconst)
237
- reconst = F.relu(reconst)
238
-
239
- # return reconst
240
-
241
- blurred = self.decoder.gaussian_blur(reconst)
242
-
243
- gradients = kornia.filters.spatial_gradient(blurred, normalized=False)
244
- # Unpack the edges
245
- gx = gradients[:, :, 0]
246
- gy = gradients[:, :, 1]
247
-
248
- angle = torch.atan2(gy, gx)
249
-
250
- # Radians to Degrees
251
- import math
252
- angle = 180.0 * angle / math.pi
253
-
254
- # Round angle to the nearest 45 degree
255
- angle = torch.round(angle / 45) * 45
256
- nms_magnitude = self.decoder.nms_conv(blurred)
257
- # nms_magnitude = F.conv2d(blurred, kernel.unsqueeze(1), padding=kernel.shape[-1]//2)
258
-
259
- # Non-maximal suppression
260
- # Get the indices for both directions
261
- positive_idx = (angle / 45) % 8
262
- positive_idx = positive_idx.long()
263
-
264
- negative_idx = ((angle / 45) + 4) % 8
265
- negative_idx = negative_idx.long()
266
-
267
- # Apply the non-maximum suppression to the different directions
268
- channel_select_filtered_positive = torch.gather(nms_magnitude, 1, positive_idx)
269
- channel_select_filtered_negative = torch.gather(nms_magnitude, 1, negative_idx)
270
-
271
- channel_select_filtered = torch.stack(
272
- [channel_select_filtered_positive, channel_select_filtered_negative], 1
273
- )
274
-
275
- # is_max = channel_select_filtered.min(dim=1)[0] > 0.0
276
-
277
- # magnitude = reconst * is_max
278
-
279
- thresh = nn.Threshold(0.01, 0.01)
280
- max_matrix = channel_select_filtered.min(dim=1)[0]
281
- max_matrix = thresh(max_matrix)
282
- magnitude = torch.mul(reconst, max_matrix)
283
- # magnitude = torchvision.transforms.functional.invert(magnitude)
284
- # magnitude = self.decoder.sharpen(magnitude)
285
- # magnitude = self.decoder.threshold(magnitude)
286
- magnitude = kornia.enhance.adjust_gamma(magnitude, 2.0)
287
- # magnitude = F.leaky_relu(magnitude)
288
- return magnitude
289
-
290
- def forward(self, image):
291
- reconst = self.forward_decode(image)
292
- pred = self.forward_pred(image)
293
- return pred, reconst
294
-
295
 
296
  def create_retrieval_figure(res):
297
  fig = plt.figure(figsize=[10 * 3, 10 * 3])
 
 
1
  import pickle
2
  from operator import itemgetter
3
 
4
  import cv2
5
  import gradio as gr
6
  import kornia.filters
 
 
7
  import torch
8
  import torch.nn as nn
9
  import torch.nn.functional as F
10
  import numpy as np
11
  import matplotlib.pyplot as plt
 
12
  import zipfile
13
  # from skimage.transform import resize
14
  from torchvision import transforms, models
15
+ from get_models import Resnet_with_skip
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  def create_retrieval_figure(res):
18
  fig = plt.figure(figsize=[10 * 3, 10 * 3])