Gent (PG/R - Comp Sci & Elec Eng) committed on
Commit bc8c24d · 1 Parent(s): f91ccd0

dependent files

Files changed (2)
  1. utils.py +83 -0
  2. vision_transformer.py +288 -0
utils.py ADDED
@@ -0,0 +1,83 @@
+ import torch, os
+ import torch.distributed as dist
+
+ def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs):
+     """
+     Re-start a run from a checkpoint.
+     """
+     if not os.path.isfile(ckp_path):
+         return
+     print("Found checkpoint at {}".format(ckp_path))
+     if ckp_path.startswith('https'):
+         checkpoint = torch.hub.load_state_dict_from_url(
+             ckp_path, map_location='cpu', check_hash=True)
+     else:
+         checkpoint = torch.load(ckp_path, map_location='cpu')
+
+     for key, value in kwargs.items():
+         if key in checkpoint and value is not None:
+             if key == "model_ema":
+                 value.ema.load_state_dict(checkpoint[key])
+             else:
+                 value.load_state_dict(checkpoint[key])
+         else:
+             print("=> key '{}' not found in checkpoint: '{}'".format(key, ckp_path))
+
+     # reload variables important for the run
+     if run_variables is not None:
+         for var_name in run_variables:
+             if var_name in checkpoint:
+                 run_variables[var_name] = checkpoint[var_name]
+
+
+
+ def load_pretrained_weights(model, pretrained_weights, checkpoint_key=None, prefixes=None, drop_head="head"):
+     """Load ViT weights."""
+     if pretrained_weights == '':
+         return
+     elif pretrained_weights.startswith('https'):
+         state_dict = torch.hub.load_state_dict_from_url(
+             pretrained_weights, map_location='cpu', check_hash=True)
+     else:
+         state_dict = torch.load(pretrained_weights, map_location='cpu')
+
+     epoch = state_dict['epoch'] if 'epoch' in state_dict else -1
+     if not checkpoint_key:
+         for key in ['model', 'teacher', 'encoder']:
+             if key in state_dict: checkpoint_key = key
+
+     print("Load pre-trained checkpoint from: %s[%s] at %d epoch" % (pretrained_weights, checkpoint_key, epoch))
+
+     state_dict = state_dict[checkpoint_key]
+     # remove `module.` prefix
+     if prefixes is None: prefixes = ["module.", "backbone."]
+     for prefix in prefixes:
+         state_dict = {k.replace(prefix, ""): v for k, v in state_dict.items() if drop_head not in k}
+     # remove `backbone.` prefix induced by multicrop wrapper
+     checkpoint_model = state_dict
+
+
+     # interpolate position embedding
+     pos_embed_checkpoint = checkpoint_model['pos_embed']
+     embedding_size = pos_embed_checkpoint.shape[-1]
+     num_patches = model.patch_embed.num_patches
+     num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+     # height (== width) for the checkpoint position embedding
+     orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+     # height (== width) for the new position embedding
+     new_size = int(num_patches ** 0.5)
+     # class_token and dist_token are kept unchanged
+     extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+     # only the position tokens are interpolated
+     pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+     # print('debug:', pos_embed_checkpoint.shape, orig_size, new_size, num_extra_tokens)
+
+     pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+     pos_tokens = torch.nn.functional.interpolate(
+         pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+     pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+     new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+     checkpoint_model['pos_embed'] = new_pos_embed
+
+     msg = model.load_state_dict(checkpoint_model, strict=False)
+     print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg))
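A minimal usage sketch for the two helpers above (not part of the commit; the checkpoint paths, the `teacher` key, and the optimizer are assumptions for illustration, and the files are assumed to be importable as `utils` and `vision_transformer`). `load_pretrained_weights` strips wrapper prefixes, drops classifier-head keys, and bicubically resizes the positional grid when resolutions differ: a checkpoint trained at 224x224 with 16x16 patches carries a 14x14 grid of position tokens, which would become 24x24 for 384x384 inputs. `restart_from_checkpoint` restores modules via `load_state_dict` and returns plain values (such as the epoch counter) through `run_variables`.

# Hypothetical usage - paths, keys, and hyperparameters are placeholders.
import torch
import vision_transformer as vits
from utils import load_pretrained_weights, restart_from_checkpoint

model = vits.vit_small(patch_size=16, num_classes=0)

# Load backbone weights from a pretrained checkpoint; pos_embed is
# interpolated automatically if the checkpoint used a different image size.
load_pretrained_weights(model, "checkpoints/vits16_pretrain.pth", checkpoint_key="teacher")

# Resume an interrupted run: modules are restored in place, and the epoch
# counter comes back through the run_variables dict.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
to_restore = {"epoch": 0}
restart_from_checkpoint(
    "checkpoints/last.pth",
    run_variables=to_restore,
    model=model,
    optimizer=optimizer,
)
start_epoch = to_restore["epoch"]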
vision_transformer.py ADDED
@@ -0,0 +1,288 @@
+ import math
+ from functools import partial
+ import numpy as np
+
+ import torch
+ import torch.nn as nn
+
+ import warnings
+
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+     def norm_cdf(x):
+         return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+     if (mean < a - 2 * std) or (mean > b + 2 * std):
+         warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+                       "The distribution of values may be incorrect.",
+                       stacklevel=2)
+
+     with torch.no_grad():
+         l = norm_cdf((a - mean) / std)
+         u = norm_cdf((b - mean) / std)
+
+         tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+         tensor.erfinv_()
+
+         tensor.mul_(std * math.sqrt(2.))
+         tensor.add_(mean)
+
+         tensor.clamp_(min=a, max=b)
+         return tensor
+
+
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+     return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
+     if drop_prob == 0. or not training:
+         return x
+     keep_prob = 1 - drop_prob
+     shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+     random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+     random_tensor.floor_()  # binarize
+     output = x.div(keep_prob) * random_tensor
+     return output
+
+
+ class DropPath(nn.Module):
+     def __init__(self, drop_prob=None):
+         super(DropPath, self).__init__()
+         self.drop_prob = drop_prob
+
+     def forward(self, x):
+         return drop_path(x, self.drop_prob, self.training)
+
+
+ class Mlp(nn.Module):
+     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Linear(in_features, hidden_features)
+         self.act = act_layer()
+         self.fc2 = nn.Linear(hidden_features, out_features)
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+
+ class Attention(nn.Module):
+     def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         self.scale = qk_scale or head_dim ** -0.5
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+         self.requires_attn = True
+
+     def forward(self, x):
+         B, N, C = x.shape
+         qkv = self.qkv(x)  # B, N, 3, self.num_heads x C // self.num_heads
+         if self.requires_attn:
+             qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+             q, k, v = qkv[0], qkv[1], qkv[2]  # 1, B, self.num_heads, N, C // self.num_heads
+             attn = (q @ k.transpose(-2, -1)) * self.scale
+             attn = attn.softmax(dim=-1)
+             attn = self.attn_drop(attn)
+
+             x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+         else:
+             qkv = qkv.reshape(B, N, 3, C)
+             x = qkv[:, :, 2]
+             attn = None
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x, attn
+
+
+ class Block(nn.Module):
+     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                  drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+         super().__init__()
+         self.norm1 = norm_layer(dim)
+         self.attn = Attention(
+             dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+         self.norm2 = norm_layer(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+     def forward(self, x, return_attention=False):
+         y, attn = self.attn(self.norm1(x))
+         if return_attention:
+             return attn
+         x = x + self.drop_path(y)
+         x = x + self.drop_path(self.mlp(self.norm2(x)))
+         return x
+
+
+ class PatchEmbed(nn.Module):
+     def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+         super().__init__()
+         num_patches = (img_size // patch_size) * (img_size // patch_size)
+         self.img_size = img_size
+         self.patch_size = patch_size
+         self.num_patches = num_patches
+
+         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+     def forward(self, x):
+         B, C, H, W = x.shape
+         x = self.proj(x).flatten(2).transpose(1, 2)
+         return x
+
+
+ class VisionTransformer(nn.Module):
+     def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
+                  num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+                  drop_path_rate=0., norm_layer=nn.LayerNorm, head_type=2, **kwargs):
+         super().__init__()
+         self.num_features = self.embed_dim = embed_dim
+         self.head_type = head_type
+         if isinstance(img_size, list): img_size = img_size[0]
+         self.patch_embed = PatchEmbed(
+             img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+         num_patches = self.patch_embed.num_patches
+
+         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+         self.pos_drop = nn.Dropout(p=drop_rate)
+
+         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
+         self.blocks = nn.ModuleList([
+             Block(
+                 dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                 drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
+             for i in range(depth)])
+         self.norm = norm_layer(embed_dim)
+
+         # Classifier head
+         if self.head_type == 2:
+             self.head = nn.Linear(2 * embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+         else:
+             self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+
+         trunc_normal_(self.pos_embed, std=.02)
+         trunc_normal_(self.cls_token, std=.02)
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+     def interpolate_pos_encoding(self, x, w, h):
+         npatch = x.shape[1] - 1
+         N = self.pos_embed.shape[1] - 1
+         if npatch == N and w == h:
+             return self.pos_embed
+         class_pos_embed = self.pos_embed[:, 0]
+         patch_pos_embed = self.pos_embed[:, 1:]
+         dim = x.shape[-1]
+         w0 = w // self.patch_embed.patch_size
+         h0 = h // self.patch_embed.patch_size
+         w0, h0 = w0 + 0.1, h0 + 0.1
+         patch_pos_embed = nn.functional.interpolate(
+             patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+             scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
+             mode='bicubic',
+         )
+         assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
+         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+         return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+
+     @torch.jit.ignore
+     def no_weight_decay(self):
+         return {'pos_embed', 'cls_token', 'dist_token'}
+
+     @torch.jit.ignore
+     def group_matcher(self, coarse=False):
+         return dict(
+             stem=r'^cls_token|pos_embed|patch_embed',  # stem and embed
+             blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
+         )
+
+
+     def prepare_tokens(self, x):
+         B, nc, w, h = x.shape
+         x = self.patch_embed(x)
+
+         cls_tokens = self.cls_token.expand(B, -1, -1)
+         x = torch.cat((cls_tokens, x), dim=1)
+
+         x = x + self.interpolate_pos_encoding(x, w, h)
+
+         return self.pos_drop(x)
+
+     def forward_features(self, x):
+         x = self.prepare_tokens(x)
+         for blk in self.blocks:
+             x = blk(x)
+
+         x = self.norm(x)
+         return x
+
+     def forward(self, x):
+         x = self.forward_features(x)
+         if self.head_type == 0:
+             return self.head(x[:, 0])
+         elif self.head_type == 1:
+             return self.head(x[:, 1:].mean(1))
+         elif self.head_type == 2:
+             return self.head(torch.cat((x[:, 0], torch.mean(x[:, 1:], dim=1)), dim=1))
+
+     def get_intermediate_layers(self, x, n=1):
+         x = self.prepare_tokens(x)
+         # we return the output tokens from the `n` last blocks
+         output = []
+         for i, blk in enumerate(self.blocks):
+             x = blk(x)
+             if len(self.blocks) - i <= n:
+                 output.append(self.norm(x))
+         return output
+
+ def vit_tiny(patch_size=16, **kwargs):
+     model = VisionTransformer(
+         patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
+         qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+     return model
+
+
+ def vit_small(patch_size=16, **kwargs):
+     model = VisionTransformer(
+         patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
+         qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+     return model
+
+
+ def vit_base(patch_size=16, **kwargs):
+     model = VisionTransformer(
+         patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
+         qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+     return model
+
+
+ def vit_large(patch_size=16, **kwargs):
+     model = VisionTransformer(
+         patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
+         qkv_bias=True, **kwargs)
+     return model
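A short sketch of how the model factories above might be exercised (illustrative only; the batch size, input resolution, and `num_classes` are arbitrary). With the default `head_type=2`, the classifier receives the [CLS] token concatenated with the mean of the patch tokens, which is why the head is `nn.Linear(2 * embed_dim, num_classes)`.

# Hypothetical usage - shapes below assume 224x224 inputs and patch_size=16.
import torch
import vision_transformer as vits

model = vits.vit_small(patch_size=16, num_classes=1000)  # head: Linear(2 * 384, 1000)
model.eval()

x = torch.randn(2, 3, 224, 224)                      # batch of 2 RGB images
with torch.no_grad():
    logits = model(x)                                 # (2, 1000)
    feats = model.get_intermediate_layers(x, n=4)     # 4 tensors, each (2, 197, 384)
    attn = model.blocks[0](model.prepare_tokens(x), return_attention=True)
    # first block's attention maps over the 196 patch tokens + [CLS]: (2, 6, 197, 197)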