truens66 committed on
Commit
23190a5
·
verified ·
1 Parent(s): 1781a02

Update resnet.py

Browse files
Files changed (1) hide show
  1. resnet.py +235 -0
resnet.py CHANGED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch.utils.model_zoo as model_zoo
3
+ from torch.nn import functional as F
4
+ from typing import Any, cast, Dict, List, Optional, Union
5
+ import numpy as np
6
+
7
# Public names exported by ``from <module> import *``.
__all__ = [
    'ResNet',
    'resnet18',
    'resnet34',
    'resnet50',
    'resnet101',
    'resnet152',
]
9
+
10
+
11
# Checkpoint download location for each architecture name, used by the
# factory functions when ``pretrained=True``.
model_urls = {
    "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
    "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
    "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
    "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
    "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
}
18
+
19
+
20
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (default 1).
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)
24
+
25
+
26
def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 ``nn.Conv2d``.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (default 1).
    """
    conv_options = dict(kernel_size=1, stride=stride, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)
29
+
30
+
31
class BasicBlock(nn.Module):
    """Residual block built from two 3x3 convolutions with batch norm.

    ``expansion`` is the ratio between the block's output channels and
    ``planes``; it is 1 for this block.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """Create the block.

        Args:
            inplanes: channels of the incoming tensor.
            planes: channels produced by both convolutions.
            stride: stride of the first convolution (the second uses 1).
            downsample: optional module applied to the input before it is
                added back onto the convolution branch.
        """
        super().__init__()
        # Submodules are created in a fixed order: it determines both the
        # state-dict key order and the order in which weights are initialised.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(branch(x) + shortcut(x))``."""
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))

        # The shortcut is the raw input unless a projection was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        branch += shortcut
        return self.relu(branch)
61
+
62
+
63
class Bottleneck(nn.Module):
    """Residual block made of a 1x1, a 3x3 and a 1x1 convolution.

    ``expansion`` is the ratio between the block's output channels and
    ``planes``; it is 4 for this block.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """Create the block.

        Args:
            inplanes: channels of the incoming tensor.
            planes: channels of the two inner convolutions; the block emits
                ``planes * expansion`` channels.
            stride: stride of the 3x3 convolution (the 1x1 ones use 1).
            downsample: optional module applied to the input before it is
                added back onto the convolution branch.
        """
        super().__init__()
        widened = planes * self.expansion
        # Submodules are created in a fixed order: it determines both the
        # state-dict key order and the order in which weights are initialised.
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, widened)
        self.bn3 = nn.BatchNorm2d(widened)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(branch(x) + shortcut(x))``."""
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.relu(self.bn2(self.conv2(branch)))
        branch = self.bn3(self.conv3(branch))

        # The shortcut is the raw input unless a projection was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        branch += shortcut
        return self.relu(branch)
99
+
100
+
101
class ResNet(nn.Module):
    """Truncated ResNet that classifies an up-sampling residual of its input.

    Only the stem, ``layer1`` and ``layer2`` are built, followed by global
    average pooling and a single linear head ``fc1``. The network does not
    see the raw image: ``forward`` first subtracts from the input a copy that
    was down-sampled and re-up-sampled with nearest-neighbour interpolation
    (called ``NPR`` in the code) and feeds that residual to the stem.
    """

    def __init__(self, block, layers, num_classes=1, zero_init_residual=False):
        """Build the network.

        Args:
            block: residual block class (``BasicBlock`` or ``Bottleneck``);
                must expose an ``expansion`` class attribute.
            layers: number of blocks per stage. Only ``layers[0]`` and
                ``layers[1]`` are used; further entries are ignored.
            num_classes: number of outputs of the linear head.
            zero_init_residual: if True, zero the scale of the last batch
                norm in every residual block so each block starts out as an
                identity mapping.
        """
        super(ResNet, self).__init__()

        # Fixed settings kept from the original code; nothing in this class
        # reads them apart from the sanity checks below.
        self.unfoldSize = 2
        self.unfoldIndex = 0
        assert self.unfoldSize > 1
        assert -1 < self.unfoldIndex and self.unfoldIndex < self.unfoldSize*self.unfoldSize
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # layer2 is the last stage, so the pooled feature vector has
        # 128 * block.expansion entries: 512 for Bottleneck, 128 for
        # BasicBlock. A hard-coded 512 made BasicBlock models fail in forward.
        self.fc1 = nn.Linear(128 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks into one ``nn.Sequential`` stage.

        The first block applies ``stride`` and, when the stride or the
        channel count changes, receives a 1x1-conv + batch-norm projection
        for its shortcut. ``self.inplanes`` is updated to the stage's output
        channel count.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def interpolate(self, img, factor):
        """Resample ``img`` by ``factor`` and back, using nearest neighbour.

        The second step targets the exact spatial size of ``img``. Scaling
        back by ``1 / factor`` instead loses a row/column whenever a
        dimension is not a multiple of ``1 / factor`` (e.g. odd sizes with
        ``factor=0.5``), which broke the subtraction in ``forward``. When the
        scale-factor round trip already restored the size, the result is the
        same as before.

        Args:
            img: tensor laid out as (batch, channels, *spatial).
            factor: scale factor of the first resampling step.

        Returns:
            A tensor with the same shape as ``img``.
        """
        shrunk = F.interpolate(img, scale_factor=factor, mode='nearest', recompute_scale_factor=True)
        return F.interpolate(shrunk, size=img.shape[2:], mode='nearest')

    def forward(self, x):
        """Return the head's output for the batch ``x``.

        Args:
            x: image batch with 3 channels (required by ``conv1``).

        Returns:
            Tensor of shape (batch, num_classes).
        """
        # Residual between the input and its down/up-sampled copy.
        NPR = x - self.interpolate(x, 0.5)

        # The 2/3 scaling is kept from the original implementation.
        x = self.conv1(NPR*2.0/3.0)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)

        return x
181
+
182
+
183
def _load_compatible_weights(model, state_dict):
    """Copy into ``model`` every checkpoint tensor that fits it.

    A tensor fits when the model has a parameter or buffer with the same
    name and the same shape. Everything else is left at the model's current
    values. A strict ``load_state_dict`` cannot be used for the models in
    this file: their stem convolution is 3x3, they have no ``layer3`` /
    ``layer4``, and their head is called ``fc1``, so torchvision checkpoints
    never match completely.

    Args:
        model: the ``nn.Module`` to update in place.
        state_dict: mapping from tensor name to tensor.

    Returns:
        Sorted list of checkpoint keys that were not loaded.
    """
    import warnings

    target = model.state_dict()
    usable = {
        name: tensor
        for name, tensor in state_dict.items()
        if name in target and target[name].shape == tensor.shape
    }
    model.load_state_dict(usable, strict=False)

    skipped = sorted(set(state_dict) - set(usable))
    if skipped:
        warnings.warn(
            "%d pretrained tensor(s) were not loaded because the model has no "
            "tensor with the same name and shape: %s"
            % (len(skipped), ", ".join(skipped))
        )
    return skipped


def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, download the ImageNet checkpoint and copy
            the tensors whose name and shape match this model; all other
            weights keep their fresh initialisation.
        **kwargs: forwarded to the ``ResNet`` constructor.
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        _load_compatible_weights(model, model_zoo.load_url(model_urls['resnet18']))
    return model


def resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.
    Args:
        pretrained (bool): If True, download the ImageNet checkpoint and copy
            the tensors whose name and shape match this model; all other
            weights keep their fresh initialisation.
        **kwargs: forwarded to the ``ResNet`` constructor.
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        _load_compatible_weights(model, model_zoo.load_url(model_urls['resnet34']))
    return model


def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, download the ImageNet checkpoint and copy
            the tensors whose name and shape match this model; all other
            weights keep their fresh initialisation.
        **kwargs: forwarded to the ``ResNet`` constructor.
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        _load_compatible_weights(model, model_zoo.load_url(model_urls['resnet50']))
    return model


def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.
    Args:
        pretrained (bool): If True, download the ImageNet checkpoint and copy
            the tensors whose name and shape match this model; all other
            weights keep their fresh initialisation.
        **kwargs: forwarded to the ``ResNet`` constructor.
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        _load_compatible_weights(model, model_zoo.load_url(model_urls['resnet101']))
    return model


def resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.
    Args:
        pretrained (bool): If True, download the ImageNet checkpoint and copy
            the tensors whose name and shape match this model; all other
            weights keep their fresh initialisation.
        **kwargs: forwarded to the ``ResNet`` constructor.
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        _load_compatible_weights(model, model_zoo.load_url(model_urls['resnet152']))
    return model