offry commited on
Commit
1a2aea2
·
1 Parent(s): 0c2a598

Upload get_models.py

Browse files
Files changed (1) hide show
  1. get_models.py +284 -0
get_models.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import kornia.filters
2
+ import kornia.filters
3
+ import scipy.ndimage
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+ import random
9
+
10
+
11
+
12
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
13
+ """3x3 convolution with padding"""
14
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
15
+ padding=dilation, groups=groups, bias=False, dilation=dilation)
16
+
17
+
18
+ def conv1x1(in_planes, out_planes, stride=1):
19
+ """1x1 convolution"""
20
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
21
+
22
+
23
+ class DoubleConv(nn.Module):
24
+ """(convolution => [BN] => ReLU) * 2"""
25
+
26
+ def __init__(self, in_channels, out_channels, mid_channels=None):
27
+ super().__init__()
28
+ if not mid_channels:
29
+ mid_channels = out_channels
30
+ norm_layer = nn.BatchNorm2d
31
+
32
+ self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False)
33
+ self.bn1 = nn.BatchNorm2d(mid_channels)
34
+ self.inst1 = nn.InstanceNorm2d(mid_channels)
35
+ # self.gn1 = nn.GroupNorm(4, mid_channels)
36
+ self.relu = nn.ReLU(inplace=True)
37
+ self.conv2 = nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False)
38
+ self.bn2 = nn.BatchNorm2d(out_channels)
39
+ self.inst2 = nn.InstanceNorm2d(out_channels)
40
+ # self.gn2 = nn.GroupNorm(4, out_channels)
41
+ self.downsample = None
42
+ if in_channels != out_channels:
43
+ self.downsample = nn.Sequential(
44
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
45
+ nn.BatchNorm2d(out_channels),
46
+ )
47
+
48
+ def forward(self, x):
49
+ identity = x
50
+
51
+ out = self.conv1(x)
52
+ # out = self.bn1(out)
53
+ out = self.inst1(out)
54
+ # out = self.gn1(out)
55
+ out = self.relu(out)
56
+
57
+ out = self.conv2(out)
58
+ # out = self.bn2(out)
59
+ out = self.inst2(out)
60
+ # out = self.gn2(out)
61
+ if self.downsample is not None:
62
+ identity = self.downsample(x)
63
+
64
+ out += identity
65
+ out = self.relu(out)
66
+ return out
67
+
68
+
69
+ class Down(nn.Module):
70
+ """Downscaling with maxpool then double conv"""
71
+
72
+ def __init__(self, in_channels, out_channels):
73
+ super().__init__()
74
+ self.maxpool_conv = nn.Sequential(
75
+ nn.MaxPool2d(2),
76
+ DoubleConv(in_channels, out_channels)
77
+ )
78
+
79
+ def forward(self, x):
80
+ return self.maxpool_conv(x)
81
+
82
+
83
+ class Up(nn.Module):
84
+ """Upscaling then double conv"""
85
+
86
+ def __init__(self, in_channels, out_channels, bilinear=True):
87
+ super().__init__()
88
+
89
+ # if bilinear, use the normal convolutions to reduce the number of channels
90
+ if bilinear:
91
+ self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
92
+ self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
93
+ else:
94
+ if in_channels == out_channels:
95
+ self.up = nn.Identity()
96
+ else:
97
+ self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
98
+ self.conv = DoubleConv(in_channels, out_channels)
99
+
100
+ def forward(self, x1, x2):
101
+ x1 = self.up(x1)
102
+ # input is CHW
103
+ diffY = x2.size()[2] - x1.size()[2]
104
+ diffX = x2.size()[3] - x1.size()[3]
105
+
106
+ x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
107
+ diffY // 2, diffY - diffY // 2])
108
+ # if you have padding issues, see
109
+ # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
110
+ # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
111
+ x = torch.cat([x2, x1], dim=1)
112
+ return self.conv(x)
113
+
114
+
115
+ class OutConv(nn.Module):
116
+ def __init__(self, in_channels, out_channels):
117
+ super(OutConv, self).__init__()
118
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
119
+
120
+ def forward(self, x):
121
+ return self.conv(x)
122
+
123
+ class GaussianLayer(nn.Module):
124
+ def __init__(self):
125
+ super(GaussianLayer, self).__init__()
126
+ self.seq = nn.Sequential(
127
+ # nn.ReflectionPad2d(10),
128
+ nn.Conv2d(1, 1, 5, stride=1, padding=2, bias=False)
129
+ )
130
+
131
+ self.weights_init()
132
+ def forward(self, x):
133
+ return self.seq(x)
134
+
135
+ def weights_init(self):
136
+ n= np.zeros((5,5))
137
+ n[3,3] = 1
138
+ k = scipy.ndimage.gaussian_filter(n,sigma=1)
139
+ for name, f in self.named_parameters():
140
+ f.data.copy_(torch.from_numpy(k))
141
+
142
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
143
+ """3x3 convolution with padding"""
144
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
145
+ padding=dilation, groups=groups, bias=False, dilation=dilation)
146
+
147
+ class Decoder(nn.Module):
148
+ def __init__(self):
149
+ super(Decoder, self).__init__()
150
+ self.up1 = Up(2048, 1024 // 1, False)
151
+ self.up2 = Up(1024, 512 // 1, False)
152
+ self.up3 = Up(512, 256 // 1, False)
153
+ self.conv2d_2_1 = conv3x3(256, 128)
154
+ self.gn1 = nn.GroupNorm(4, 128)
155
+ self.instance1 = nn.InstanceNorm2d(128)
156
+ self.up4 = Up(128, 64 // 1, False)
157
+ self.upsample4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
158
+ # self.upsample4 = nn.ConvTranspose2d(64, 64, 2, stride=2)
159
+ self.upsample4_conv = DoubleConv(64, 64, 64 // 2)
160
+ self.up_ = Up(128, 128 // 1, False)
161
+ self.conv2d_2_2 = conv3x3(128, 6)
162
+ self.instance2 = nn.InstanceNorm2d(6)
163
+ self.gn2 = nn.GroupNorm(3, 6)
164
+ self.gaussian_blur = GaussianLayer()
165
+ self.up5 = Up(6, 3, False)
166
+ self.conv2d_2_3 = conv3x3(3, 1)
167
+ self.instance3 = nn.InstanceNorm2d(1)
168
+ self.gaussian_blur = GaussianLayer()
169
+ self.kernel = nn.Parameter(torch.tensor(
170
+ [[[0.0, 0.0, 0.0], [0.0, 1.0, random.uniform(-1.0, 0.0)], [0.0, 0.0, 0.0]],
171
+ [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, random.uniform(-1.0, 0.0)]],
172
+ [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, random.uniform(random.uniform(-1.0, 0.0), -0.0), 0.0]],
173
+ [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [random.uniform(-1.0, 0.0), 0.0, 0.0]],
174
+ [[0.0, 0.0, 0.0], [random.uniform(-1.0, 0.0), 1.0, 0.0], [0.0, 0.0, 0.0]],
175
+ [[random.uniform(-1.0, 0.0), 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
176
+ [[0.0, random.uniform(-1.0, 0.0), 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
177
+ [[0.0, 0.0, random.uniform(-1.0, 0.0)], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], ],
178
+ ).unsqueeze(1))
179
+
180
+ self.nms_conv = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False, groups=1)
181
+ with torch.no_grad():
182
+ self.nms_conv.weight = self.kernel.float()
183
+
184
+
185
+ class Resnet_with_skip(nn.Module):
186
+ def __init__(self, model):
187
+ super(Resnet_with_skip, self).__init__()
188
+ self.model = model
189
+ self.decoder = Decoder()
190
+
191
+ def forward_pred(self, image):
192
+ pred_net = self.model(image)
193
+ return pred_net
194
+
195
+ def forward_decode(self, image):
196
+ identity = image
197
+
198
+ image = self.model.conv1(image)
199
+ image = self.model.bn1(image)
200
+ image = self.model.relu(image)
201
+ image1 = self.model.maxpool(image)
202
+
203
+ image2 = self.model.layer1(image1)
204
+ image3 = self.model.layer2(image2)
205
+ image4 = self.model.layer3(image3)
206
+ image5 = self.model.layer4(image4)
207
+
208
+ reconst1 = self.decoder.up1(image5, image4)
209
+ reconst2 = self.decoder.up2(reconst1, image3)
210
+ reconst3 = self.decoder.up3(reconst2, image2)
211
+ reconst = self.decoder.conv2d_2_1(reconst3)
212
+ # reconst = self.decoder.instance1(reconst)
213
+ reconst = self.decoder.gn1(reconst)
214
+ reconst = F.relu(reconst)
215
+ reconst4 = self.decoder.up4(reconst, image1)
216
+ # reconst5 = self.decoder.upsample4(reconst4)
217
+ reconst5 = self.decoder.upsample4(reconst4)
218
+ # reconst5 = self.decoder.upsample4_conv(reconst4)
219
+ reconst5 = self.decoder.up_(reconst5, image)
220
+ # reconst5 = reconst5 + image
221
+ reconst5 = self.decoder.conv2d_2_2(reconst5)
222
+ reconst5 = self.decoder.instance2(reconst5)
223
+ # reconst5 = self.decoder.gn2(reconst5)
224
+ reconst5 = F.relu(reconst5)
225
+ reconst = self.decoder.up5(reconst5, identity)
226
+ reconst = self.decoder.conv2d_2_3(reconst)
227
+ # reconst = self.decoder.instance3(reconst)
228
+ reconst = F.relu(reconst)
229
+
230
+ # return reconst
231
+
232
+ blurred = self.decoder.gaussian_blur(reconst)
233
+
234
+ gradients = kornia.filters.spatial_gradient(blurred, normalized=False)
235
+ # Unpack the edges
236
+ gx = gradients[:, :, 0]
237
+ gy = gradients[:, :, 1]
238
+
239
+ angle = torch.atan2(gy, gx)
240
+
241
+ # Radians to Degrees
242
+ import math
243
+ angle = 180.0 * angle / math.pi
244
+
245
+ # Round angle to the nearest 45 degree
246
+ angle = torch.round(angle / 45) * 45
247
+ nms_magnitude = self.decoder.nms_conv(blurred)
248
+ # nms_magnitude = F.conv2d(blurred, kernel.unsqueeze(1), padding=kernel.shape[-1]//2)
249
+
250
+ # Non-maximal suppression
251
+ # Get the indices for both directions
252
+ positive_idx = (angle / 45) % 8
253
+ positive_idx = positive_idx.long()
254
+
255
+ negative_idx = ((angle / 45) + 4) % 8
256
+ negative_idx = negative_idx.long()
257
+
258
+ # Apply the non-maximum suppression to the different directions
259
+ channel_select_filtered_positive = torch.gather(nms_magnitude, 1, positive_idx)
260
+ channel_select_filtered_negative = torch.gather(nms_magnitude, 1, negative_idx)
261
+
262
+ channel_select_filtered = torch.stack(
263
+ [channel_select_filtered_positive, channel_select_filtered_negative], 1
264
+ )
265
+
266
+ # is_max = channel_select_filtered.min(dim=1)[0] > 0.0
267
+
268
+ # magnitude = reconst * is_max
269
+
270
+ thresh = nn.Threshold(0.01, 0.01)
271
+ max_matrix = channel_select_filtered.min(dim=1)[0]
272
+ max_matrix = thresh(max_matrix)
273
+ magnitude = torch.mul(reconst, max_matrix)
274
+ # magnitude = torchvision.transforms.functional.invert(magnitude)
275
+ # magnitude = self.decoder.sharpen(magnitude)
276
+ # magnitude = self.decoder.threshold(magnitude)
277
+ magnitude = kornia.enhance.adjust_gamma(magnitude, 2.0)
278
+ # magnitude = F.leaky_relu(magnitude)
279
+ return magnitude
280
+
281
+ def forward(self, image):
282
+ reconst = self.forward_decode(image)
283
+ pred = self.forward_pred(image)
284
+ return pred, reconst