dianecy committed (verified)
Commit 31dfd6a · 1 parent: 47f6054

Upload folder using huggingface_hub

model/.gitignore ADDED
@@ -0,0 +1 @@
1
+ __pycache__/
model/__init__.py ADDED
@@ -0,0 +1,90 @@
1
+ from .segmenter import CRIS
2
+ from .segmenter_angular import CRIS_S
3
+ # from .segmenter_ang_nonoise_ddp import CRIS_Wo_Noise
4
+ from loguru import logger
5
+
6
+ # def build_segmenter(args):
7
+ # model = CRIS(args)
8
+ # backbone = []
9
+ # backbone_no_decay = []
10
+ # head = []
11
+ # for k, v in model.named_parameters():
12
+ # if k.startswith('backbone') and 'positional_embedding' not in k:
13
+ # backbone.append(v)
14
+ # elif 'positional_embedding' in k:
15
+ # backbone_no_decay.append(v)
16
+ # else:
17
+ # head.append(v)
18
+ # print('Backbone with decay: {}, Backbone without decay: {}, Head: {}'.format(
19
+ # len(backbone), len(backbone_no_decay), len(head)))
20
+ # param_list = [{
21
+ # 'params': backbone,
22
+ # 'initial_lr': args.lr_multi * args.base_lr
23
+ # }, {
24
+ # 'params': backbone_no_decay,
25
+ # 'initial_lr': args.lr_multi * args.base_lr,
26
+ # 'weight_decay': 0
27
+ # }, {
28
+ # 'params': head,
29
+ # 'initial_lr': args.base_lr
30
+ # }]
31
+ # return model, param_list
32
+
33
+
34
+ def build_segmenter(args):
35
+ model = CRIS_S(args)
36
+ backbone = []
37
+ head = []
38
+ for k, v in model.named_parameters():
39
+ if k.startswith('backbone') and 'positional_embedding' not in k:
40
+ backbone.append(v)
41
+ else:
42
+ head.append(v)
43
+ logger.info('Backbone with decay={}, Head={}'.format(len(backbone), len(head)))
44
+ param_list = [{
45
+ 'params': backbone,
46
+ 'initial_lr': args.lr_multi * args.base_lr
47
+ }, {
48
+ 'params': head,
49
+ 'initial_lr': args.base_lr
50
+ }]
51
+ return model, param_list
52
+
53
+ def build_segmenter_original(args):
54
+ model = CRIS(args)
55
+ backbone = []
56
+ head = []
57
+ for k, v in model.named_parameters():
58
+ if k.startswith('backbone') and 'positional_embedding' not in k:
59
+ backbone.append(v)
60
+ else:
61
+ head.append(v)
62
+ logger.info('Backbone with decay={}, Head={}'.format(len(backbone), len(head)))
63
+ param_list = [{
64
+ 'params': backbone,
65
+ 'initial_lr': args.lr_multi * args.base_lr
66
+ }, {
67
+ 'params': head,
68
+ 'initial_lr': args.base_lr
69
+ }]
70
+ return model, param_list
71
+
72
+
73
+ # def build_segmenter_textaug(args):
74
+ # model = CRIS_Wo_Noise(args)
75
+ # backbone = []
76
+ # head = []
77
+ # for k, v in model.named_parameters():
78
+ # if k.startswith('backbone') and 'positional_embedding' not in k:
79
+ # backbone.append(v)
80
+ # else:
81
+ # head.append(v)
82
+ # logger.info('Backbone with decay={}, Head={}'.format(len(backbone), len(head)))
83
+ # param_list = [{
84
+ # 'params': backbone,
85
+ # 'initial_lr': args.lr_multi * args.base_lr
86
+ # }, {
87
+ # 'params': head,
88
+ # 'initial_lr': args.base_lr
89
+ # }]
90
+ # return model, param_list
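
A minimal, self-contained sketch (editor's illustration, not part of the upload) of the backbone/head parameter split used by build_segmenter above. It uses a dummy module so it runs without a CLIP checkpoint; base_lr and lr_multi are hypothetical stand-ins for args.base_lr and args.lr_multi.

import torch
import torch.nn as nn

class DummySegmenter(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)   # stands in for the CLIP backbone
        self.head = nn.Linear(8, 1)       # stands in for neck / decoder / projector

base_lr, lr_multi = 1e-4, 0.1             # hypothetical values
model = DummySegmenter()
backbone, head = [], []
for k, v in model.named_parameters():
    (backbone if k.startswith('backbone') else head).append(v)

# Mirror the param_list structure above; PyTorch optimizers keep the extra
# 'initial_lr' key, which LR schedulers resumed with last_epoch != -1 read.
optimizer = torch.optim.Adam([
    {'params': backbone, 'lr': lr_multi * base_lr, 'initial_lr': lr_multi * base_lr},
    {'params': head, 'lr': base_lr, 'initial_lr': base_lr},
])
print(len(optimizer.param_groups))  # 2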
model/clip.py ADDED
@@ -0,0 +1,554 @@
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+
10
+ class Bottleneck(nn.Module):
11
+ expansion = 4
12
+
13
+ def __init__(self, inplanes, planes, stride=1):
14
+ super().__init__()
15
+
16
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18
+ self.bn1 = nn.BatchNorm2d(planes)
19
+
20
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
21
+ self.bn2 = nn.BatchNorm2d(planes)
22
+
23
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
24
+
25
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
26
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
27
+
28
+ self.relu = nn.ReLU(inplace=True)
29
+ self.downsample = None
30
+ self.stride = stride
31
+
32
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
33
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
34
+ self.downsample = nn.Sequential(
35
+ OrderedDict([("-1", nn.AvgPool2d(stride)),
36
+ ("0",
37
+ nn.Conv2d(inplanes,
38
+ planes * self.expansion,
39
+ 1,
40
+ stride=1,
41
+ bias=False)),
42
+ ("1", nn.BatchNorm2d(planes * self.expansion))]))
43
+
44
+ def forward(self, x: torch.Tensor):
45
+ identity = x
46
+
47
+ out = self.relu(self.bn1(self.conv1(x)))
48
+ out = self.relu(self.bn2(self.conv2(out)))
49
+ out = self.avgpool(out)
50
+ out = self.bn3(self.conv3(out))
51
+
52
+ if self.downsample is not None:
53
+ identity = self.downsample(x)
54
+
55
+ out += identity
56
+ out = self.relu(out)
57
+ return out
58
+
59
+
60
+ class AttentionPool2d(nn.Module):
61
+ def __init__(self,
62
+ spacial_dim: int,
63
+ embed_dim: int,
64
+ num_heads: int,
65
+ output_dim: int = None):
66
+ super().__init__()
67
+ self.spacial_dim = spacial_dim
68
+ self.positional_embedding = nn.Parameter(
69
+ torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
70
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
71
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
72
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
73
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
74
+ self.num_heads = num_heads
75
+ # residual
76
+ self.connect = nn.Sequential(
77
+ nn.Conv2d(embed_dim, output_dim, 1, stride=1, bias=False),
78
+ nn.BatchNorm2d(output_dim))
79
+
80
+ def resize_pos_embed(self, pos_embed, input_shape):
81
+ """Resize pos_embed weights.
82
+ Resize pos_embed using bicubic interpolate method.
83
+ Args:
84
+ pos_embed (torch.Tensor): Position embedding weights.
85
+ input_shpae (tuple): Tuple for (downsampled input image height,
86
+ downsampled input image width).
87
+ pos_shape (tuple): The resolution of downsampled origin training
88
+ image.
89
+ mode (str): Algorithm used for upsampling:
90
+ ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
91
+ ``'trilinear'``. Default: ``'nearest'``
92
+ Return:
93
+ torch.Tensor: The resized pos_embed of shape [B, C, L_new]
94
+ """
95
+ assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
96
+ pos_h = pos_w = self.spacial_dim
97
+ cls_token_weight = pos_embed[:, 0]
98
+ pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
99
+ pos_embed_weight = pos_embed_weight.reshape(
100
+ 1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
101
+ pos_embed_weight = F.interpolate(pos_embed_weight,
102
+ size=input_shape,
103
+ align_corners=False,
104
+ mode='bicubic')
105
+ cls_token_weight = cls_token_weight.unsqueeze(1)
106
+ pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
107
+ # pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
108
+ return pos_embed_weight.transpose(-2, -1)
109
+
110
+ def forward(self, x):
111
+ B, C, H, W = x.size()
112
+ res = self.connect(x)
113
+ x = x.reshape(B, C, -1) # NC(HW)
114
+ # x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(1+HW)
115
+ pos_embed = self.positional_embedding.unsqueeze(0)
116
+ pos_embed = self.resize_pos_embed(pos_embed, (H, W)) # NC(HW)
117
+ x = x + pos_embed.to(x.dtype) # NC(HW)
118
+ x = x.permute(2, 0, 1) # (HW)NC
119
+ x, _ = F.multi_head_attention_forward(
120
+ query=x,
121
+ key=x,
122
+ value=x,
123
+ embed_dim_to_check=x.shape[-1],
124
+ num_heads=self.num_heads,
125
+ q_proj_weight=self.q_proj.weight,
126
+ k_proj_weight=self.k_proj.weight,
127
+ v_proj_weight=self.v_proj.weight,
128
+ in_proj_weight=None,
129
+ in_proj_bias=torch.cat(
130
+ [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
131
+ bias_k=None,
132
+ bias_v=None,
133
+ add_zero_attn=False,
134
+ dropout_p=0,
135
+ out_proj_weight=self.c_proj.weight,
136
+ out_proj_bias=self.c_proj.bias,
137
+ use_separate_proj_weight=True,
138
+ training=self.training,
139
+ need_weights=False)
140
+ x = x.permute(1, 2, 0).reshape(B, -1, H, W)
141
+ x = x + res
142
+ x = F.relu(x, True)
143
+
144
+ return x
145
+
146
+
147
+ class ModifiedResNet(nn.Module):
148
+ """
149
+ A ResNet class that is similar to torchvision's but contains the following changes:
150
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
151
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
152
+ - The final pooling layer is a QKV attention instead of an average pool
153
+ """
154
+ def __init__(self,
155
+ layers,
156
+ output_dim,
157
+ heads,
158
+ input_resolution=224,
159
+ width=64):
160
+ super().__init__()
161
+ self.output_dim = output_dim
162
+ self.input_resolution = input_resolution
163
+
164
+ # the 3-layer stem
165
+ self.conv1 = nn.Conv2d(3,
166
+ width // 2,
167
+ kernel_size=3,
168
+ stride=2,
169
+ padding=1,
170
+ bias=False)
171
+ self.bn1 = nn.BatchNorm2d(width // 2)
172
+ self.conv2 = nn.Conv2d(width // 2,
173
+ width // 2,
174
+ kernel_size=3,
175
+ padding=1,
176
+ bias=False)
177
+ self.bn2 = nn.BatchNorm2d(width // 2)
178
+ self.conv3 = nn.Conv2d(width // 2,
179
+ width,
180
+ kernel_size=3,
181
+ padding=1,
182
+ bias=False)
183
+ self.bn3 = nn.BatchNorm2d(width)
184
+ self.avgpool = nn.AvgPool2d(2)
185
+ self.relu = nn.ReLU(inplace=True)
186
+
187
+ # residual layers
188
+ self._inplanes = width # this is a *mutable* variable used during construction
189
+ self.layer1 = self._make_layer(width, layers[0])
190
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
191
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
192
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
193
+
194
+ embed_dim = width * 32 # the ResNet feature dimension
195
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim,
196
+ heads, output_dim)
197
+
198
+ def _make_layer(self, planes, blocks, stride=1):
199
+ layers = [Bottleneck(self._inplanes, planes, stride)]
200
+
201
+ self._inplanes = planes * Bottleneck.expansion
202
+ for _ in range(1, blocks):
203
+ layers.append(Bottleneck(self._inplanes, planes))
204
+
205
+ return nn.Sequential(*layers)
206
+
207
+ def forward(self, x):
208
+ def stem(x):
209
+ for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2),
210
+ (self.conv3, self.bn3)]:
211
+ x = self.relu(bn(conv(x)))
212
+ x = self.avgpool(x)
213
+ return x
214
+
215
+ x = x.type(self.conv1.weight.dtype)
216
+ x = stem(x)
217
+ x = self.layer1(x)
218
+ x2 = self.layer2(x)
219
+ x3 = self.layer3(x2)
220
+ x4 = self.layer4(x3)
221
+ x4 = self.attnpool(x4)
222
+
223
+ return (x2, x3, x4)
224
+
225
+
226
+ class LayerNorm(nn.LayerNorm):
227
+ """Subclass torch's LayerNorm to handle fp16."""
228
+ def forward(self, x: torch.Tensor):
229
+ orig_type = x.dtype
230
+ ret = super().forward(x.type(torch.float32))
231
+ return ret.type(orig_type)
232
+
233
+
234
+ class QuickGELU(nn.Module):
235
+ def forward(self, x: torch.Tensor):
236
+ return x * torch.sigmoid(1.702 * x)
237
+
238
+
239
+ class ResidualAttentionBlock(nn.Module):
240
+ def __init__(self,
241
+ d_model: int,
242
+ n_head: int,
243
+ attn_mask: torch.Tensor = None):
244
+ super().__init__()
245
+
246
+ self.attn = nn.MultiheadAttention(d_model, n_head)
247
+ self.ln_1 = LayerNorm(d_model)
248
+ self.mlp = nn.Sequential(
249
+ OrderedDict([("c_fc", nn.Linear(d_model, d_model * 4)),
250
+ ("gelu", QuickGELU()),
251
+ ("c_proj", nn.Linear(d_model * 4, d_model))]))
252
+ self.ln_2 = LayerNorm(d_model)
253
+ self.attn_mask = attn_mask
254
+
255
+ def attention(self, x: torch.Tensor):
256
+ self.attn_mask = self.attn_mask.to(
257
+ dtype=x.dtype,
258
+ device=x.device) if self.attn_mask is not None else None
259
+ return self.attn(x, x, x, need_weights=False,
260
+ attn_mask=self.attn_mask)[0]
261
+
262
+ def forward(self, x: torch.Tensor):
263
+ x = x + self.attention(self.ln_1(x))
264
+ x = x + self.mlp(self.ln_2(x))
265
+ return x
266
+
267
+
268
+ class Transformer(nn.Module):
269
+ def __init__(self,
270
+ width: int,
271
+ layers: int,
272
+ heads: int,
273
+ attn_mask: torch.Tensor = None):
274
+ super().__init__()
275
+ self.width = width
276
+ self.layers = layers
277
+ self.resblocks = nn.Sequential(*[
278
+ ResidualAttentionBlock(width, heads, attn_mask)
279
+ for _ in range(layers)
280
+ ])
281
+
282
+ def forward(self, x: torch.Tensor):
283
+ return self.resblocks(x)
284
+
285
+
286
+ class VisionTransformer(nn.Module):
287
+ def __init__(self, input_resolution: int, patch_size: int, width: int,
288
+ layers: int, heads: int, output_dim: int):
289
+ super().__init__()
290
+ self.input_resolution = input_resolution
291
+ self.output_dim = output_dim
292
+ self.conv1 = nn.Conv2d(in_channels=3,
293
+ out_channels=width,
294
+ kernel_size=patch_size,
295
+ stride=patch_size,
296
+ bias=False)
297
+
298
+ scale = width**-0.5
299
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
300
+ self.positional_embedding = nn.Parameter(scale * torch.randn(
301
+ (input_resolution // patch_size)**2 + 1, width))
302
+ self.ln_pre = LayerNorm(width)
303
+
304
+ self.transformer = Transformer(width, layers, heads)
305
+
306
+ self.ln_post = LayerNorm(width)
307
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
308
+
309
+ def forward(self, x: torch.Tensor):
310
+ x = self.conv1(x) # shape = [*, width, grid, grid]
311
+ x = x.reshape(x.shape[0], x.shape[1],
312
+ -1) # shape = [*, width, grid ** 2]
313
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
314
+ x = torch.cat([
315
+ self.class_embedding.to(x.dtype) + torch.zeros(
316
+ x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x
317
+ ],
318
+ dim=1) # shape = [*, grid ** 2 + 1, width]
319
+ x = x + self.positional_embedding.to(x.dtype)
320
+ x = self.ln_pre(x)
321
+
322
+ x = x.permute(1, 0, 2) # NLD -> LND
323
+ x = self.transformer(x)
324
+ x = x.permute(1, 0, 2) # LND -> NLD
325
+
326
+ # x = self.ln_post(x[:, 0, :])
327
+ x = self.ln_post(x[:, 1:, :])
328
+
329
+ if self.proj is not None:
330
+ x = x @ self.proj
331
+
332
+ return x
333
+
334
+
335
+ class CLIP(nn.Module):
336
+ def __init__(
337
+ self,
338
+ embed_dim: int,
339
+ # vision
340
+ image_resolution: int,
341
+ vision_layers: Union[Tuple[int, int, int, int], int],
342
+ vision_width: int,
343
+ vision_patch_size: int,
344
+ # text
345
+ context_length: int,
346
+ txt_length: int,
347
+ vocab_size: int,
348
+ transformer_width: int,
349
+ transformer_heads: int,
350
+ transformer_layers: int):
351
+ super().__init__()
352
+
353
+ self.context_length = context_length
354
+
355
+ if isinstance(vision_layers, (tuple, list)):
356
+ vision_heads = vision_width * 32 // 64
357
+ self.visual = ModifiedResNet(layers=vision_layers,
358
+ output_dim=embed_dim,
359
+ heads=vision_heads,
360
+ input_resolution=image_resolution,
361
+ width=vision_width)
362
+ else:
363
+ vision_heads = vision_width // 64
364
+ self.visual = VisionTransformer(input_resolution=image_resolution,
365
+ patch_size=vision_patch_size,
366
+ width=vision_width,
367
+ layers=vision_layers,
368
+ heads=vision_heads,
369
+ output_dim=embed_dim)
370
+
371
+ self.transformer = Transformer(
372
+ width=transformer_width,
373
+ layers=transformer_layers,
374
+ heads=transformer_heads,
375
+ attn_mask=self.build_attention_mask(txt_length))
376
+
377
+ self.vocab_size = vocab_size
378
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
379
+ self.positional_embedding = nn.Parameter(
380
+ torch.empty(self.context_length, transformer_width))
381
+ self.ln_final = LayerNorm(transformer_width)
382
+
383
+ self.text_projection = nn.Parameter(
384
+ torch.empty(transformer_width, embed_dim))
385
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
386
+
387
+ self.token_embedding.requires_grad_(False)  # freeze the token embedding (the original assigned to the method instead of calling it)
388
+ self.initialize_parameters()
389
+
390
+ def initialize_parameters(self):
391
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
392
+ nn.init.normal_(self.positional_embedding, std=0.01)
393
+
394
+ if isinstance(self.visual, ModifiedResNet):
395
+ if self.visual.attnpool is not None:
396
+ std = self.visual.attnpool.c_proj.in_features**-0.5
397
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
398
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
399
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
400
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
401
+
402
+ for resnet_block in [
403
+ self.visual.layer1, self.visual.layer2, self.visual.layer3,
404
+ self.visual.layer4
405
+ ]:
406
+ for name, param in resnet_block.named_parameters():
407
+ if name.endswith("bn3.weight"):
408
+ nn.init.zeros_(param)
409
+
410
+ proj_std = (self.transformer.width**-0.5) * (
411
+ (2 * self.transformer.layers)**-0.5)
412
+ attn_std = self.transformer.width**-0.5
413
+ fc_std = (2 * self.transformer.width)**-0.5
414
+ for block in self.transformer.resblocks:
415
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
416
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
417
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
418
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
419
+
420
+ if self.text_projection is not None:
421
+ nn.init.normal_(self.text_projection,
422
+ std=self.transformer.width**-0.5)
423
+
424
+ def build_attention_mask(self, context_length):
425
+ # lazily create causal attention mask, with full attention between the vision tokens
426
+ # pytorch uses additive attention mask; fill with -inf
427
+ mask = torch.empty(context_length, context_length)
428
+ mask.fill_(float("-inf"))
429
+ mask.triu_(1) # zero out the lower diagonal
430
+ return mask
431
+
432
+ @property
433
+ def dtype(self):
434
+ return self.visual.conv1.weight.dtype
435
+
436
+ def encode_image(self, image):
437
+ return self.visual(image.type(self.dtype))
438
+
439
+ def encode_text(self, text):
440
+ x = self.token_embedding(text).type(
441
+ self.dtype) # [batch_size, n_ctx, d_model]
442
+
443
+ x = x + self.positional_embedding.type(self.dtype)[:x.size(1)]
444
+ x = x.permute(1, 0, 2) # NLD -> LND
445
+ x = self.transformer(x)
446
+ x = x.permute(1, 0, 2) # LND -> NLD
447
+ x = self.ln_final(x).type(self.dtype)
448
+
449
+ # x.shape = [batch_size, n_ctx, transformer.width]
450
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
451
+ state = x[torch.arange(x.shape[0]),
452
+ text.argmax(dim=-1)] @ self.text_projection
453
+ # x = x @ self.text_projection
454
+ # state = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
455
+
456
+ return x, state
457
+
458
+ def forward(self, image, text):
459
+ image_features = self.encode_image(image)
460
+ text_features = self.encode_text(text)
461
+
462
+ # normalized features
463
+ image_features = image_features / image_features.norm(dim=-1,
464
+ keepdim=True)
465
+ text_features = text_features / text_features.norm(dim=-1,
466
+ keepdim=True)
467
+
468
+ # cosine similarity as logits
469
+ logit_scale = self.logit_scale.exp()
470
+ logits_per_image = logit_scale * image_features @ text_features.t()
471
+ logits_per_text = logits_per_image.t()
472
+
473
+ # shape = [global_batch_size, global_batch_size]
474
+ return logits_per_image, logits_per_text
475
+
476
+
477
+ def convert_weights(model: nn.Module):
478
+ """Convert applicable model parameters to fp16"""
479
+ def _convert_weights_to_fp16(l):
480
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
481
+ l.weight.data = l.weight.data.half()
482
+ if l.bias is not None:
483
+ l.bias.data = l.bias.data.half()
484
+
485
+ if isinstance(l, nn.MultiheadAttention):
486
+ for attr in [
487
+ *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
488
+ "in_proj_bias", "bias_k", "bias_v"
489
+ ]:
490
+ tensor = getattr(l, attr)
491
+ if tensor is not None:
492
+ tensor.data = tensor.data.half()
493
+
494
+ for name in ["text_projection", "proj"]:
495
+ if hasattr(l, name):
496
+ attr = getattr(l, name)
497
+ if attr is not None:
498
+ attr.data = attr.data.half()
499
+
500
+ model.apply(_convert_weights_to_fp16)
501
+
502
+
503
+ def build_model(state_dict: dict, txt_length: int):
504
+ vit = "visual.proj" in state_dict
505
+
506
+ if vit:
507
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
508
+ vision_layers = len([
509
+ k for k in state_dict.keys()
510
+ if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")
511
+ ])
512
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
513
+ grid_size = round(
514
+ (state_dict["visual.positional_embedding"].shape[0] - 1)**0.5)
515
+ image_resolution = vision_patch_size * grid_size
516
+ else:
517
+ counts: list = [
518
+ len(
519
+ set(
520
+ k.split(".")[2] for k in state_dict
521
+ if k.startswith(f"visual.layer{b}")))
522
+ for b in [1, 2, 3, 4]
523
+ ]
524
+ vision_layers = tuple(counts)
525
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
526
+ output_width = round(
527
+ (state_dict["visual.attnpool.positional_embedding"].shape[0] -
528
+ 1)**0.5)
529
+ vision_patch_size = None
530
+ assert output_width**2 + 1 == state_dict[
531
+ "visual.attnpool.positional_embedding"].shape[0]
532
+ image_resolution = output_width * 32
533
+
534
+ embed_dim = state_dict["text_projection"].shape[1]
535
+ context_length = state_dict["positional_embedding"].shape[0]
536
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
537
+ transformer_width = state_dict["ln_final.weight"].shape[0]
538
+ transformer_heads = transformer_width // 64
539
+ transformer_layers = len(
540
+ set(
541
+ k.split(".")[2] for k in state_dict
542
+ if k.startswith(f"transformer.resblocks")))
543
+
544
+ model = CLIP(embed_dim, image_resolution, vision_layers, vision_width,
545
+ vision_patch_size, context_length, txt_length, vocab_size,
546
+ transformer_width, transformer_heads, transformer_layers)
547
+
548
+ for key in ["input_resolution", "context_length", "vocab_size"]:
549
+ if key in state_dict:
550
+ del state_dict[key]
551
+
552
+ convert_weights(model)
553
+ model.load_state_dict(state_dict, False)
554
+ return model.eval()
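
An editor's sketch of how build_model is consumed by the segmenters later in this upload (see model/segmenter.py below): the JIT-scripted CLIP checkpoint is loaded only for its state_dict and rebuilt as a regular nn.Module. "RN50.pt" is a placeholder path to an OpenAI CLIP ResNet-50 checkpoint, and word_len=17 is a hypothetical token length.

import torch
from model.clip import build_model

clip_jit = torch.jit.load("RN50.pt", map_location="cpu").eval()  # placeholder path
word_len = 17
clip = build_model(clip_jit.state_dict(), txt_length=word_len).float()

images = torch.randn(2, 3, 416, 416)             # resolution is flexible: the attention pool resizes its positional embedding
tokens = torch.randint(1, 49408, (2, word_len))  # dummy token ids (CLIP vocab size 49408)
v3, v4, v5 = clip.encode_image(images)           # (C3, C4, C5) pyramid from the ResNet variant
word, state = clip.encode_text(tokens)           # per-token features and pooled sentence state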
model/layers.py ADDED
@@ -0,0 +1,309 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ def conv_layer(in_dim, out_dim, kernel_size=1, padding=0, stride=1):
9
+ return nn.Sequential(
10
+ nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, bias=False),
11
+ nn.BatchNorm2d(out_dim), nn.ReLU(True))
12
+
13
+
14
+ def linear_layer(in_dim, out_dim, bias=False):
15
+ return nn.Sequential(nn.Linear(in_dim, out_dim, bias),
16
+ nn.BatchNorm1d(out_dim), nn.ReLU(True))
17
+
18
+
19
+ class CoordConv(nn.Module):
20
+ def __init__(self,
21
+ in_channels,
22
+ out_channels,
23
+ kernel_size=3,
24
+ padding=1,
25
+ stride=1):
26
+ super().__init__()
27
+ self.conv1 = conv_layer(in_channels + 2, out_channels, kernel_size,
28
+ padding, stride)
29
+
30
+ def add_coord(self, input):
31
+ b, _, h, w = input.size()
32
+ x_range = torch.linspace(-1, 1, w, device=input.device)
33
+ y_range = torch.linspace(-1, 1, h, device=input.device)
34
+ y, x = torch.meshgrid(y_range, x_range)
35
+ y = y.expand([b, 1, -1, -1])
36
+ x = x.expand([b, 1, -1, -1])
37
+ coord_feat = torch.cat([x, y], 1)
38
+ input = torch.cat([input, coord_feat], 1)
39
+ return input
40
+
41
+ def forward(self, x):
42
+ x = self.add_coord(x)
43
+ x = self.conv1(x)
44
+ return x
45
+
46
+
47
+ class Projector(nn.Module):
48
+ def __init__(self, word_dim=1024, in_dim=256, kernel_size=3):
49
+ super().__init__()
50
+ self.in_dim = in_dim
51
+ self.kernel_size = kernel_size
52
+ # visual projector
53
+ self.vis = nn.Sequential( # os16 -> os4
54
+ nn.Upsample(scale_factor=2, mode='bilinear'),
55
+ conv_layer(in_dim * 2, in_dim * 2, 3, padding=1),
56
+ nn.Upsample(scale_factor=2, mode='bilinear'),
57
+ conv_layer(in_dim * 2, in_dim, 3, padding=1),
58
+ nn.Conv2d(in_dim, in_dim, 1))
59
+ # textual projector
60
+ out_dim = 1 * in_dim * kernel_size * kernel_size + 1
61
+ self.txt = nn.Linear(word_dim, out_dim)
62
+
63
+ def forward(self, x, word):
64
+ '''
65
+ x: b, 512, 26, 26
66
+ word: b, 512
67
+ '''
68
+ x = self.vis(x)
69
+ B, C, H, W = x.size()
70
+ # 1, b*256, 104, 104
71
+ x = x.reshape(1, B * C, H, W)
72
+ # txt: b, (256*3*3 + 1) -> b, 256, 3, 3 / b
73
+ word = self.txt(word)
74
+ weight, bias = word[:, :-1], word[:, -1]
75
+ weight = weight.reshape(B, C, self.kernel_size, self.kernel_size)
76
+ # Conv2d - 1, b*256, 104, 104 -> 1, b, 104, 104
77
+ out = F.conv2d(x,
78
+ weight,
79
+ padding=self.kernel_size // 2,
80
+ groups=weight.size(0),
81
+ bias=bias)
82
+ out = out.transpose(0, 1)
83
+ # b, 1, 104, 104
84
+ return out
85
+
86
+
87
+ class TransformerDecoder(nn.Module):
88
+ def __init__(self,
89
+ num_layers,
90
+ d_model,
91
+ nhead,
92
+ dim_ffn,
93
+ dropout,
94
+ return_intermediate=False):
95
+ super().__init__()
96
+ self.layers = nn.ModuleList([
97
+ TransformerDecoderLayer(d_model=d_model,
98
+ nhead=nhead,
99
+ dim_feedforward=dim_ffn,
100
+ dropout=dropout) for _ in range(num_layers)
101
+ ])
102
+ self.num_layers = num_layers
103
+ self.norm = nn.LayerNorm(d_model)
104
+ self.return_intermediate = return_intermediate
105
+
106
+ @staticmethod
107
+ def pos1d(d_model, length):
108
+ """
109
+ :param d_model: dimension of the model
110
+ :param length: length of positions
111
+ :return: length*d_model position matrix
112
+ """
113
+ if d_model % 2 != 0:
114
+ raise ValueError("Cannot use sin/cos positional encoding with "
115
+ "odd dim (got dim={:d})".format(d_model))
116
+ pe = torch.zeros(length, d_model)
117
+ position = torch.arange(0, length).unsqueeze(1)
118
+ div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) *
119
+ -(math.log(10000.0) / d_model)))
120
+ pe[:, 0::2] = torch.sin(position.float() * div_term)
121
+ pe[:, 1::2] = torch.cos(position.float() * div_term)
122
+
123
+ return pe.unsqueeze(1) # n, 1, 512
124
+
125
+ @staticmethod
126
+ def pos2d(d_model, height, width):
127
+ """
128
+ :param d_model: dimension of the model
129
+ :param height: height of the positions
130
+ :param width: width of the positions
131
+ :return: d_model*height*width position matrix
132
+ """
133
+ if d_model % 4 != 0:
134
+ raise ValueError("Cannot use sin/cos positional encoding with "
135
+ "odd dimension (got dim={:d})".format(d_model))
136
+ pe = torch.zeros(d_model, height, width)
137
+ # Each dimension use half of d_model
138
+ d_model = int(d_model / 2)
139
+ div_term = torch.exp(
140
+ torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))
141
+ pos_w = torch.arange(0., width).unsqueeze(1)
142
+ pos_h = torch.arange(0., height).unsqueeze(1)
143
+ pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(
144
+ 0, 1).unsqueeze(1).repeat(1, height, 1)
145
+ pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(
146
+ 0, 1).unsqueeze(1).repeat(1, height, 1)
147
+ pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(
148
+ 0, 1).unsqueeze(2).repeat(1, 1, width)
149
+ pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(
150
+ 0, 1).unsqueeze(2).repeat(1, 1, width)
151
+
152
+ return pe.reshape(-1, 1, height * width).permute(2, 1, 0) # hw, 1, 512
153
+
154
+ def forward(self, vis, txt, pad_mask):
155
+ '''
156
+ vis: b, 512, h, w
157
+ txt: b, L, 512
158
+ pad_mask: b, L
159
+ '''
160
+ B, C, H, W = vis.size()
161
+ _, L, D = txt.size()
162
+ # position encoding
163
+ vis_pos = self.pos2d(C, H, W)
164
+ txt_pos = self.pos1d(D, L)
165
+ # reshape & permute
166
+ vis = vis.reshape(B, C, -1).permute(2, 0, 1)
167
+ txt = txt.permute(1, 0, 2)
168
+ # forward
169
+ output = vis
170
+ intermediate = []
171
+ for layer in self.layers:
172
+ output = layer(output, txt, vis_pos, txt_pos, pad_mask)
173
+ if self.return_intermediate:
174
+ # HW, b, 512 -> b, 512, HW
175
+ intermediate.append(self.norm(output).permute(1, 2, 0))
176
+
177
+ if self.norm is not None:
178
+ # HW, b, 512 -> b, 512, HW
179
+ output = self.norm(output).permute(1, 2, 0)
180
+ if self.return_intermediate:
181
+ intermediate.pop()
182
+ intermediate.append(output)
183
+ # [output1, output2, ..., output_n]
184
+ return intermediate
185
+ else:
186
+ # b, 512, HW
187
+ return output
188
+ return output
189
+
190
+
191
+ class TransformerDecoderLayer(nn.Module):
192
+ def __init__(self,
193
+ d_model=512,
194
+ nhead=9,
195
+ dim_feedforward=2048,
196
+ dropout=0.1):
197
+ super().__init__()
198
+ # Normalization Layer
199
+ self.self_attn_norm = nn.LayerNorm(d_model)
200
+ self.cross_attn_norm = nn.LayerNorm(d_model)
201
+ # Attention Layer
202
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
203
+ self.multihead_attn = nn.MultiheadAttention(d_model,
204
+ nhead,
205
+ dropout=dropout,
206
+ kdim=d_model,
207
+ vdim=d_model)
208
+ # FFN
209
+ self.ffn = nn.Sequential(nn.Linear(d_model, dim_feedforward),
210
+ nn.ReLU(True), nn.Dropout(dropout),
211
+ nn.LayerNorm(dim_feedforward),
212
+ nn.Linear(dim_feedforward, d_model))
213
+ # LayerNorm & Dropout
214
+ self.norm1 = nn.LayerNorm(d_model)
215
+ self.norm2 = nn.LayerNorm(d_model)
216
+ self.norm3 = nn.LayerNorm(d_model)
217
+ self.dropout1 = nn.Dropout(dropout)
218
+ self.dropout2 = nn.Dropout(dropout)
219
+ self.dropout3 = nn.Dropout(dropout)
220
+
221
+ def with_pos_embed(self, tensor, pos):
222
+ return tensor if pos is None else tensor + pos.to(tensor.device)
223
+
224
+ def forward(self, vis, txt, vis_pos, txt_pos, pad_mask):
225
+ '''
226
+ vis: 26*26, b, 512
227
+ txt: L, b, 512
228
+ vis_pos: 26*26, 1, 512
229
+ txt_pos: L, 1, 512
230
+ pad_mask: b, L
231
+ '''
232
+ # Self-Attention
233
+ vis2 = self.norm1(vis)
234
+ q = k = self.with_pos_embed(vis2, vis_pos)
235
+ vis2 = self.self_attn(q, k, value=vis2)[0]
236
+ vis2 = self.self_attn_norm(vis2)
237
+ vis = vis + self.dropout1(vis2)
238
+ # Cross-Attention
239
+ vis2 = self.norm2(vis)
240
+ vis2 = self.multihead_attn(query=self.with_pos_embed(vis2, vis_pos),
241
+ key=self.with_pos_embed(txt, txt_pos),
242
+ value=txt,
243
+ key_padding_mask=pad_mask)[0]
244
+ vis2 = self.cross_attn_norm(vis2)
245
+ vis = vis + self.dropout2(vis2)
246
+ # FFN
247
+ vis2 = self.norm3(vis)
248
+ vis2 = self.ffn(vis2)
249
+ vis = vis + self.dropout3(vis2)
250
+ return vis
251
+
252
+
253
+ class FPN(nn.Module):
254
+ def __init__(self,
255
+ in_channels=[512, 1024, 1024],
256
+ out_channels=[256, 512, 1024]):
257
+ super(FPN, self).__init__()
258
+ # text projection
259
+ self.txt_proj = linear_layer(in_channels[2], out_channels[2])
260
+ # fusion 1: v5 & seq -> f_5: b, 1024, 13, 13
261
+ self.f1_v_proj = conv_layer(in_channels[2], out_channels[2], 1, 0)
262
+ self.norm_layer = nn.Sequential(nn.BatchNorm2d(out_channels[2]),
263
+ nn.ReLU(True))
264
+ # fusion 2: v4 & fm -> f_4: b, 512, 26, 26
265
+ self.f2_v_proj = conv_layer(in_channels[1], out_channels[1], 3, 1)
266
+ self.f2_cat = conv_layer(out_channels[2] + out_channels[1],
267
+ out_channels[1], 1, 0)
268
+ # fusion 3: v3 & fm_mid -> f_3: b, 512, 52, 52
269
+ self.f3_v_proj = conv_layer(in_channels[0], out_channels[0], 3, 1)
270
+ self.f3_cat = conv_layer(out_channels[0] + out_channels[1],
271
+ out_channels[1], 1, 0)
272
+ # fusion 4: f_3 & f_4 & f_5 -> fq: b, 256, 26, 26
273
+ self.f4_proj5 = conv_layer(out_channels[2], out_channels[1], 3, 1)
274
+ self.f4_proj4 = conv_layer(out_channels[1], out_channels[1], 3, 1)
275
+ self.f4_proj3 = conv_layer(out_channels[1], out_channels[1], 3, 1)
276
+ # aggregation
277
+ self.aggr = conv_layer(3 * out_channels[1], out_channels[1], 1, 0)
278
+ self.coordconv = nn.Sequential(
279
+ CoordConv(out_channels[1], out_channels[1], 3, 1),
280
+ conv_layer(out_channels[1], out_channels[1], 3, 1))
281
+
282
+ def forward(self, imgs, state):
283
+ # v3, v4, v5: 256, 52, 52 / 512, 26, 26 / 1024, 13, 13
284
+ v3, v4, v5 = imgs
285
+ # fusion 1: b, 1024, 13, 13
286
+ # text projection: b, 1024 -> b, 1024
287
+ state = self.txt_proj(state).unsqueeze(-1).unsqueeze(
288
+ -1) # b, 1024, 1, 1
289
+ f5 = self.f1_v_proj(v5)
290
+ f5 = self.norm_layer(f5 * state)
291
+ # fusion 2: b, 512, 26, 26
292
+ f4 = self.f2_v_proj(v4)
293
+ f5_ = F.interpolate(f5, scale_factor=2, mode='bilinear')
294
+ f4 = self.f2_cat(torch.cat([f4, f5_], dim=1))
295
+ # fusion 3: b, 256, 26, 26
296
+ f3 = self.f3_v_proj(v3)
297
+ f3 = F.avg_pool2d(f3, 2, 2)
298
+ f3 = self.f3_cat(torch.cat([f3, f4], dim=1))
299
+ # fusion 4: b, 512, 13, 13 / b, 512, 26, 26 / b, 512, 26, 26
300
+ fq5 = self.f4_proj5(f5)
301
+ fq4 = self.f4_proj4(f4)
302
+ fq3 = self.f4_proj3(f3)
303
+ # query
304
+ fq5 = F.interpolate(fq5, scale_factor=2, mode='bilinear')
305
+ fq = torch.cat([fq3, fq4, fq5], dim=1)
306
+ fq = self.aggr(fq)
307
+ fq = self.coordconv(fq)
308
+ # b, 512, 26, 26
309
+ return fq, f5
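
An editor's shape check (not part of the uploaded file) for the FPN → TransformerDecoder → Projector chain, using the channel sizes the comments above assume for CLIP-RN50 features (512/1024/1024 inputs, vis_dim=512, word_dim=1024); batch size and sequence length are arbitrary.

import torch
from model.layers import FPN, Projector, TransformerDecoder

b, L = 2, 17
v3 = torch.randn(b, 512, 52, 52)     # C3
v4 = torch.randn(b, 1024, 26, 26)    # C4
v5 = torch.randn(b, 1024, 13, 13)    # C5 (after CLIP's attention pooling)
word = torch.randn(b, L, 512)        # per-token text features
state = torch.randn(b, 1024)         # pooled sentence state

neck = FPN(in_channels=[512, 1024, 1024], out_channels=[256, 512, 1024])
decoder = TransformerDecoder(num_layers=1, d_model=512, nhead=8,
                             dim_ffn=2048, dropout=0.1)
proj = Projector(word_dim=1024, in_dim=256, kernel_size=3)

fq, f5 = neck((v3, v4, v5), state)               # fq: (b, 512, 26, 26)
pad_mask = torch.zeros(b, L, dtype=torch.bool)   # no padded tokens in this dummy batch
fq = decoder(fq, word, pad_mask)                 # (b, 512, 26*26)
pred = proj(fq.reshape(b, 512, 26, 26), state)   # (b, 1, 104, 104)
print(pred.shape)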
model/segmenter.py ADDED
@@ -0,0 +1,204 @@
1
+ import random
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from model.clip import build_model
7
+ from .layers import FPN, Projector, TransformerDecoder
8
+
9
+
10
+ def MetricLoss(embeddings, n_pos, alpha = 0.5, args = None):
11
+ # embeddings: ((2*B), C, (H*W))
12
+ # n_pos : chunk size of positive pairs
13
+ # args: args
14
+ # returns: loss
15
+ metric_loss = 0
16
+
17
+ # flatten embeddings
18
+ B_, C, HW = embeddings.shape
19
+ emb = torch.mean(embeddings, dim=-1) # (2*B, C)
20
+ emb_i = emb.unsqueeze(1).repeat(1, B_, 1) # (2*B, 2*B, C)
21
+ emb_j = emb.unsqueeze(0).repeat(B_, 1, 1) # (2*B, 2*B, C)
22
+ emb_distance = torch.norm(emb_i - emb_j, dim=-1) # (2*B, 2*B)
23
+ assert torch.sum(torch.diag(emb_distance[:B_, :B_])) == 0, \
24
+ "Diagonals are not zero. please check the permutation on the batch"
25
+ # print("distance metrix : ", emb_distance)
26
+
27
+ # positive pairs and loss
28
+ positive_mask = torch.zeros_like(emb_distance)
29
+ for i in range(B_//2):
30
+ positive_mask[2*i, 2*i+1] = 1
31
+ positive_mask[2*i+1, 2*i] = 1
32
+ positive_mask.fill_diagonal_(1)
33
+ positive_loss = torch.sum(emb_distance * positive_mask) / B_
34
+
35
+ # negative pairs and loss
36
+ negative_mask = torch.ones_like(emb_distance) - positive_mask
37
+
38
+ if args.div_batch:
39
+ negative_loss = -1.0 * torch.log(torch.sum(emb_distance * negative_mask) / B_)
40
+ else:
41
+ negative_loss = -1.0 * torch.log(torch.sum(emb_distance * negative_mask) / (B_**2 - 2*B_))
42
+
43
+ # print(positive_mask, negative_mask)
44
+
45
+ metric_loss = alpha * positive_loss + (1-alpha) * negative_loss
46
+
47
+ return metric_loss
48
+
49
+
50
+ def AngularMetricLoss(embeddings, n_pos, alpha = 0.5, args = None):
51
+ # embeddings: ((2*B), C, (H*W))
52
+ # n_pos : chunk size of positive pairs
53
+ # args: args
54
+ # returns: loss
55
+ geometric_loss = 0
56
+
57
+ # flatten embeddings
58
+ B_, C, HW = embeddings.shape
59
+ emb = torch.mean(embeddings, dim=-1) # (2*B, C)
60
+ emb_i = emb.unsqueeze(1).repeat(1, B_, 1) # (2*B, 2*B, C)
61
+ emb_j = emb.unsqueeze(0).repeat(B_, 1, 1) # (2*B, 2*B, C)
62
+ sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
63
+ sim_matrix = sim(emb_i, emb_j).reshape(B_, B_) # (2*B , 2*B)
64
+ print(sim_matrix)
65
+ assert torch.trace(sim_matrix) == B_, \
66
+ "similarity diagonals are not one. please check the permutation on the batch"
67
+ print("similarity metrix : ", sim_matrix)
68
+ phi = torch.acos(sim_matrix) # (2*B, 2*B)
69
+ print("phi metrix : ", phi)
70
+
71
+ # positive pairs and loss
72
+ positive_mask = torch.zeros_like(sim_matrix)
73
+ for i in range(B_//2):
74
+ positive_mask[2*i, 2*i+1] = 1
75
+ positive_mask[2*i+1, 2*i] = 1
76
+ positive_mask.fill_diagonal_(1)
77
+ positive_loss = torch.sum((phi**2) * positive_mask) / B_
78
+
79
+ # negative pairs and loss
80
+ negative_mask = torch.ones_like(sim_matrix) - positive_mask
81
+ phi_mask = phi < args.phi_threshold
82
+ negative_loss = (args.phi_threshold - phi)**2
83
+ print(negative_mask * phi_mask)
84
+ negative_loss = torch.sum(negative_loss * negative_mask * phi_mask) / (B_**2 - 2*B_)
85
+
86
+ print("pos loss, neg loss : ", positive_loss, negative_loss)
87
+
88
+ geometric_loss = alpha * positive_loss + (1-alpha) * negative_loss
89
+
90
+ return geometric_loss
91
+
92
+
93
+ class CRIS(nn.Module):
94
+ def __init__(self, cfg):
95
+ super().__init__()
96
+ # Vision & Text Encoder
97
+ clip_model = torch.jit.load(cfg.clip_pretrain,
98
+ map_location="cpu").eval()
99
+ self.backbone = build_model(clip_model.state_dict(), cfg.word_len).float()
100
+ # Multi-Modal FPN
101
+ self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
102
+ # Decoder
103
+ self.decoder = TransformerDecoder(num_layers=cfg.num_layers,
104
+ d_model=cfg.vis_dim,
105
+ nhead=cfg.num_head,
106
+ dim_ffn=cfg.dim_ffn,
107
+ dropout=cfg.dropout,
108
+ return_intermediate=cfg.intermediate)
109
+ # Projector
110
+ self.proj = Projector(cfg.word_dim, cfg.vis_dim // 2, 3)
111
+ self.metric_learning = cfg.metric_learning
112
+ self.positive_strength = cfg.positive_strength
113
+ self.metric_loss_weight = cfg.metric_loss_weight
114
+ self.eps = cfg.ptb_rate
115
+ self.cfg = cfg
116
+
117
+ def forward(self, image, text, target=None):
118
+ '''
119
+ img: b, 3, h, w
120
+ word: b, words
121
+ word_mask: b, words
122
+ if self.metric_learning:
123
+ word: b, 2, words
124
+ word_mask: b, 2, words
125
+ mask: b, 1, h, w
126
+ '''
127
+ metric_learning_flag = (self.metric_learning and self.training)
128
+ metric_loss = 0
129
+
130
+ # 1. Resizing: with metric learning, the text batch size is doubled
131
+ if metric_learning_flag:
132
+ #print("image shape : ", image.shape)
133
+ b, c, h, w = image.size()
134
+ # duplicate image and segmentation mask
135
+ if image is not None:
136
+ image = torch.cat([image, image], dim=0)
137
+ image = image.reshape(-1, b, c, h, w).transpose(0, 1).reshape(-1, c, h, w)
138
+ if target is not None:
139
+ target = torch.cat([target, target], dim=0)
140
+ target = target.reshape(-1, b, 1, h, w).transpose(0, 1).reshape(-1, 1, h, w)
141
+ # duplicate noise mask
142
+ b_, n_, l_ = text.size()
143
+ assert n_ == 2 ,"word size should be 2"
144
+ noise_mask = (text[:, 0, :] == text[:, 1, :])
145
+ noise_mask = torch.all(noise_mask, dim=-1)
146
+ noise_mask = noise_mask.unsqueeze(-1).repeat(1, 2).reshape(-1) # 2*b_
147
+ assert noise_mask.shape[0] == b_ * 2, "noise mask shape should be 2*b_"
148
+ text = text.reshape(b_ * 2, l_) # 2*b, l
149
+
150
+ # print("text shape : ", text.shape)
151
+ # print("image shape : ", image.shape)
152
+ # print("target shape : ", target.shape)
153
+ # print(torch.sum(image[0::2]) == torch.sum(image[1::2]))
154
+ # print(torch.sum(target[0::2]) == torch.sum(target[1::2]))
155
+
156
+ # padding mask used in decoder
157
+ pad_mask = torch.zeros_like(text).masked_fill_(text == 0, 1).bool()
158
+ # vis: C3 / C4 / C5
159
+ # word: b, length, 1024
160
+ # state: b, 1024
161
+ vis = self.backbone.encode_image(image)
162
+ word, state = self.backbone.encode_text(text)
163
+
164
+ b_, d_ = state.size()
165
+ assert b_ == word.size(0), "batch size of state and word should be same"
166
+
167
+
168
+ # 2. State Noising Step: if a sample has only one caption,
169
+ # add noise to the corresponding indices
170
+ if metric_learning_flag :
171
+ noise = torch.randn_like(state) * self.eps
172
+ state[noise_mask] = state[noise_mask] + noise[noise_mask]
173
+
174
+ # print("shape of word, state : ", word.shape, state.shape)
175
+
176
+ # b, 512, 26, 26 (C4)
177
+ a3, a4, a5 = vis
178
+ # print("vis shape in model " , a3.shape, a4.shape, a5.shape)
179
+ fq, f5 = self.neck(vis, state)
180
+ b, c, h, w = fq.size()
181
+ fq = self.decoder(fq, word, pad_mask)
182
+ # print("decoder output shape : ", fq.shape)
183
+ # 3. Get metric loss
184
+ if metric_learning_flag:
185
+ metric_loss = MetricLoss(fq, 2, alpha=self.positive_strength, args = self.cfg)
186
+
187
+ fq = fq.reshape(b, c, h, w)
188
+
189
+ # b, 1, 104, 104
190
+ pred = self.proj(fq, state)
191
+
192
+ if self.training:
193
+ # resize mask
194
+ if pred.shape[-2:] != target.shape[-2:]:
195
+ target = F.interpolate(target, pred.shape[-2:],
196
+ mode='nearest').detach()
197
+ loss = F.binary_cross_entropy_with_logits(pred, target)
198
+ # 4. if metric learning, add metric loss and normalize
199
+ if metric_learning_flag:
200
+ #print("CE loss : ", loss, "metric loss : ", metric_loss)
201
+ loss = (loss + self.metric_loss_weight * metric_loss) / (1+self.metric_loss_weight)
202
+ return pred.detach(), target, loss
203
+ else:
204
+ return pred.detach()
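
A standalone check by the editor of MetricLoss with dummy embeddings arranged the way forward() arranges them (rows 2i and 2i+1 are the two captions of image i). SimpleNamespace stands in for the training args; only div_batch is read inside the function.

import torch
from types import SimpleNamespace
from model.segmenter import MetricLoss

torch.manual_seed(0)
b = 2                                      # images per batch; forward() doubles it
emb = torch.randn(2 * b, 512, 26 * 26)     # (2*B, C, H*W)
args = SimpleNamespace(div_batch=False)
loss = MetricLoss(emb, n_pos=2, alpha=0.5, args=args)
print(loss)  # alpha * mean positive distance + (1 - alpha) * (-log of normalized negative distance)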
model/segmenter_angular.py ADDED
@@ -0,0 +1,163 @@
1
+ import random
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from model.clip import build_model
7
+
8
+ from .layers import FPN, Projector, TransformerDecoder
9
+
10
+
11
+ # def MetricLoss(embeddings, n_pos, alpha = 0.5, args = None):
12
+ # # embeddings: ((2*B), C, (H*W))
13
+ # # n_pos : chunk size of positive pairs
14
+ # # args: args
15
+ # # returns: loss
16
+ # metric_loss = 0
17
+
18
+ # # flatten embeddings
19
+ # B_, C, HW = embeddings.shape
20
+ # emb = torch.mean(embeddings, dim=-1) # (2*B, C)
21
+ # emb_i = emb.unsqueeze(1).repeat(1, B_, 1) # (2*B, 2*B, C)
22
+ # emb_j = emb.unsqueeze(0).repeat(B_, 1, 1) # (2*B, 2*B, C)
23
+ # emb_distance = torch.norm(emb_i - emb_j, dim=-1) # (2*B, 2*B)
24
+ # assert torch.sum(torch.diag(emb_distance[:B_, :B_])) == 0, \
25
+ # "Diagonals are not zero. please check the permutation on the batch"
26
+ # # print("distance metrix : ", emb_distance)
27
+
28
+ # # positive pairs and loss
29
+ # positive_mask = torch.zeros_like(emb_distance)
30
+ # for i in range(B_//2):
31
+ # positive_mask[2*i, 2*i+1] = 1
32
+ # positive_mask[2*i+1, 2*i] = 1
33
+ # positive_mask.fill_diagonal_(1)
34
+ # positive_loss = torch.sum(emb_distance * positive_mask) / B_
35
+
36
+ # # negative pairs and loss
37
+ # negative_mask = torch.ones_like(emb_distance) - positive_mask
38
+ # negative_loss = -1.0 * torch.log(torch.sum(emb_distance * negative_mask) / (B_**2 - 2*B_))
39
+
40
+ # # print(positive_mask, negative_mask)
41
+
42
+ # metric_loss = alpha * positive_loss + (1-alpha) * negative_loss
43
+
44
+ # return metric_loss
45
+
46
+
47
+
48
+ class CRIS_S(nn.Module):
49
+ def __init__(self, cfg):
50
+ super().__init__()
51
+ # Vision & Text Encoder
52
+ clip_model = torch.jit.load(cfg.clip_pretrain,
53
+ map_location="cpu").eval()
54
+ self.backbone = build_model(clip_model.state_dict(), cfg.word_len).float()
55
+ # Multi-Modal FPN
56
+ self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
57
+ # Decoder
58
+ self.decoder = TransformerDecoder(num_layers=cfg.num_layers,
59
+ d_model=cfg.vis_dim,
60
+ nhead=cfg.num_head,
61
+ dim_ffn=cfg.dim_ffn,
62
+ dropout=cfg.dropout,
63
+ return_intermediate=cfg.intermediate)
64
+ # Projector
65
+ self.proj = Projector(cfg.word_dim, cfg.vis_dim // 2, 3)
66
+ self.metric_learning = cfg.metric_learning
67
+ self.positive_strength = cfg.positive_strength
68
+ self.metric_loss_weight = cfg.metric_loss_weight
69
+ self.eps = cfg.ptb_rate
70
+ self.cfg = cfg
71
+
72
+ def forward(self, image, text, target=None):
73
+ '''
74
+ img: b, 3, h, w
75
+ word: b, words
76
+ word_mask: b, words
77
+ if self.metric_learning:
78
+ word: b, 2, words
79
+ word_mask: b, 2, words
80
+ mask: b, 1, h, w
81
+ '''
82
+ metric_learning_flag = (self.metric_learning and self.training)
83
+ # TODO : mixing option btw distance & angular loss
84
+ mix_distance_angular = False
85
+ metric_loss = 0
86
+
87
+ # 1. Resizing: with metric learning, the text batch size is doubled
88
+ if metric_learning_flag:
89
+ #print("image shape : ", image.shape)
90
+ b, c, h, w = image.size()
91
+ # duplicate image and segmentation mask
92
+ if image is not None:
93
+ image = torch.cat([image, image], dim=0)
94
+ image = image.reshape(-1, b, c, h, w).transpose(0, 1).reshape(-1, c, h, w)
95
+ if target is not None:
96
+ target = torch.cat([target, target], dim=0)
97
+ target = target.reshape(-1, b, 1, h, w).transpose(0, 1).reshape(-1, 1, h, w)
98
+ # duplicate noise mask
99
+ b_, n_, l_ = text.size()
100
+ assert n_ == 2 ,"word size should be 2"
101
+ noise_mask = (text[:, 0, :] == text[:, 1, :])
102
+ noise_mask = torch.all(noise_mask, dim=-1)
103
+ noise_mask = noise_mask.unsqueeze(-1).repeat(1, 2).reshape(-1) # 2*b_
104
+ assert noise_mask.shape[0] == b_ * 2, "noise mask shape should be 2*b_"
105
+ text = text.reshape(b_ * 2, l_) # 2*b, l
106
+
107
+ # print("text shape : ", text.shape)
108
+ # print("image shape : ", image.shape)
109
+ # print("target shape : ", target.shape)
110
+ # print(torch.sum(image[0::2]) == torch.sum(image[1::2]))
111
+ # print(torch.sum(target[0::2]) == torch.sum(target[1::2]))
112
+
113
+ # padding mask used in decoder
114
+ pad_mask = torch.zeros_like(text).masked_fill_(text == 0, 1).bool()
115
+ # vis: C3 / C4 / C5
116
+ # word: b, length, 1024
117
+ # state: b, 1024
118
+ vis = self.backbone.encode_image(image)
119
+ word, state = self.backbone.encode_text(text)
120
+
121
+ b_, d_ = state.size()
122
+ assert b_ == word.size(0), "batch size of state and word should be same"
123
+
124
+
125
+ # 2. State Noising Step: if a sample has only one caption,
126
+ # add noise to the corresponding indices
127
+ if metric_learning_flag :
128
+ noise = torch.randn_like(state) * self.eps
129
+ state[noise_mask] = state[noise_mask] + noise[noise_mask]
130
+
131
+
132
+ # b, 512, 26, 26 (C4)
133
+ a3, a4, a5 = vis
134
+ fq, f5 = self.neck(vis, state)
135
+ b, c, h, w = fq.size()
136
+ fq = self.decoder(fq, word, pad_mask)
137
+ metric_tensor = fq
138
+ # if metric_learning_flag:
139
+ # metric_loss = AngularMetricLoss(fq, 2, alpha=self.positive_strength, args = self.cfg) # (1-self.positive_strength) *
140
+ # if mix_distance_angular:
141
+ # metric_loss += MetricLoss(fq, 2, alpha=self.positive_strength, args = self.cfg) # self.positive_strength *
142
+ fq = fq.reshape(b, c, h, w)
143
+
144
+ # b, 1, 104, 104
145
+ pred = self.proj(fq, state)
146
+
147
+ if self.training:
148
+ # resize mask
149
+ if pred.shape[-2:] != target.shape[-2:]:
150
+ target = F.interpolate(target, pred.shape[-2:],
151
+ mode='nearest').detach()
152
+ CE_loss = F.binary_cross_entropy_with_logits(pred, target)
153
+
154
+ # 4. if metric learning, add metric loss and normalize
155
+ # if metric_learning_flag:
156
+ # loss = (loss + self.metric_loss_weight * metric_loss) / (1+self.metric_loss_weight)
157
+ # safety_loss = loss * 0.
158
+ # loss = loss + safety_loss
159
+
160
+ return pred.detach(), target, CE_loss, metric_tensor
161
+ else:
162
+ #print(self.cfg.gpu, f"; loss = {loss}")
163
+ return pred.detach()
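
An editor's illustration of the caption-pair handling shared by CRIS and CRIS_S (steps 1 and 2 above): text arrives as (B, 2, L), samples whose two caption slots are identical are flagged, and Gaussian noise is added to their sentence state. All sizes here are hypothetical.

import torch

torch.manual_seed(0)
b, L, D, eps = 2, 5, 8, 0.1
text = torch.randint(1, 100, (b, 2, L))
text[1, 1] = text[1, 0]                    # sample 1 has only one distinct caption

noise_mask = torch.all(text[:, 0, :] == text[:, 1, :], dim=-1)
noise_mask = noise_mask.unsqueeze(-1).repeat(1, 2).reshape(-1)   # (2*b,)
text = text.reshape(2 * b, L)              # captions interleaved per image

state = torch.randn(2 * b, D)              # stand-in for the CLIP sentence state
noise = torch.randn_like(state) * eps
state[noise_mask] = state[noise_mask] + noise[noise_mask]
print(noise_mask)                          # rows of the duplicated-caption sample are True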
model/segmenter_verbonly.py ADDED
@@ -0,0 +1,375 @@
1
+ import random
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from model.clip import build_model
7
+
8
+ from .layers import FPN, Projector, TransformerDecoder
9
+
10
+ class CRIS_VerbOnly(nn.Module):
11
+ def __init__(self, cfg):
12
+ super().__init__()
13
+ # Vision & Text Encoder
14
+ clip_model = torch.jit.load(cfg.clip_pretrain,
15
+ map_location="cpu").eval()
16
+ self.backbone = build_model(clip_model.state_dict(), cfg.word_len).float()
17
+ # Multi-Modal FPN
18
+ self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
19
+ # Decoder
20
+ self.decoder = TransformerDecoder(num_layers=cfg.num_layers,
21
+ d_model=cfg.vis_dim,
22
+ nhead=cfg.num_head,
23
+ dim_ffn=cfg.dim_ffn,
24
+ dropout=cfg.dropout,
25
+ return_intermediate=cfg.intermediate)
26
+ # Projector
27
+ self.proj = Projector(cfg.word_dim, cfg.vis_dim // 2, 3)
28
+ self.metric_learning = False # cfg.metric_learning
29
+ self.positive_strength = cfg.positive_strength
30
+ self.metric_loss_weight = cfg.metric_loss_weight
31
+ self.eps = cfg.ptb_rate
32
+ self.cfg = cfg
33
+
34
+
35
+
36
+
37
+ def forward(self, image, text, target=None, verb=None):
38
+ '''
39
+ image: b, 3, h, w
40
+ text: b, words
41
+ target: b, 1, h, w
42
+ verb: b, words (if applicable, only used in training mode for contrastive learning)
43
+ '''
44
+
45
+ sentences, images, targets, pad_masks = [], [], [], []
46
+
47
+ if self.training:
48
+ verb_masks = []
49
+ cl_masks = []
50
+
51
+ for idx in range(len(text)):
52
+ sentences.append(text[idx])
53
+ images.append(image[idx])
54
+ targets.append(target[idx])
55
+ pad_masks.append(torch.zeros_like(text[idx]).masked_fill_(text[idx] == 0, 1).bool())
56
+
57
+ # If verb exists, process it
58
+ if verb[idx].numel() > 0 and verb[idx].sum().item() > 0:
59
+ verb_masks.extend([1, 1]) # Both original sentence and verb are marked
60
+ cl_masks.extend([0, 1]) # Only verb gets marked for exclusion from CE loss
61
+ sentences.append(verb[idx])
62
+ images.append(image[idx])
63
+ targets.append(target[idx])
64
+ pad_masks.append(torch.zeros_like(verb[idx]).masked_fill_(verb[idx] == 0, 1).bool())
65
+ else:
66
+ verb_masks.append(0)
67
+ cl_masks.append(0)
68
+
69
+
70
+ sentences = torch.stack(sentences)
71
+ images = torch.stack(images)
72
+ targets = torch.stack(targets)
73
+ pad_masks = torch.stack(pad_masks)
74
+ verb_masks = torch.tensor(verb_masks, dtype=torch.bool)
75
+ cl_masks = torch.tensor(cl_masks, dtype=torch.bool)
76
+
77
+ else:
78
+ sentences = text
79
+ images = image
80
+ targets = target
81
+ pad_masks = torch.zeros_like(text).masked_fill_(text == 0, 1).bool()
82
+
83
+ # Encoding images and text
84
+ vis = self.backbone.encode_image(images)
85
+ word, state = self.backbone.encode_text(sentences)
86
+
87
+ # FPN neck and decoder
88
+ fq, f5 = self.neck(vis, state)
89
+ b, c, h, w = fq.size()
90
+ fq = self.decoder(fq, word, pad_masks)
91
+ metric_tensor = fq # b, c, h*w
92
+ fq = fq.reshape(b, c, h, w)
93
+
94
+ # Final prediction
95
+ pred = self.proj(fq, state)
96
+
97
+ if self.training:
98
+ if pred.shape[-2:] != targets.shape[-2:]:
99
+ targets = F.interpolate(targets, pred.shape[-2:], mode='nearest').detach()
100
+
101
+ loss = F.binary_cross_entropy_with_logits(pred[~cl_masks], targets[~cl_masks])
102
+
103
+ if self.metric_learning:
104
+ metric_loss = self.compute_metric_loss(metric_tensor, verb_masks, args=self.cfg)
105
+ loss = (loss + self.metric_loss_weight * metric_loss) / (1 + self.metric_loss_weight)
106
+
107
+ return pred.detach(), targets, loss
108
+
109
+ return pred.detach() # In eval mode, only return the predictions
110
+
111
+
112
+
113
+
114
+ def return_mask_hponly(self, emb_distance, verb_mask=None):
115
+ B_, B_ = emb_distance.shape
116
+ positive_mask = torch.zeros_like(emb_distance)
117
+ positive_mask.fill_diagonal_(1) # Set diagonal elements to 1 for all cases
118
+
119
+ if B_ < len(verb_mask):
120
+ # If B_ equals to 2*K (double the number of verb phrase)
121
+ for i in range(B_ // 2):
122
+ positive_mask[2 * i, 2 * i + 1] = 1
123
+ positive_mask[2 * i + 1, 2 * i] = 1
124
+ else:
125
+ # Process the case where we have a mix of sentences with and without verbs
126
+ i = 0
127
+ while i < B_:
128
+ if verb_mask[i] == 1:
129
+ positive_mask[i, i + 1] = 1
130
+ positive_mask[i + 1, i] = 1
131
+ i += 2
132
+ else:
133
+ i += 1
134
+ negative_mask = torch.ones_like(emb_distance) - positive_mask
135
+ return positive_mask, negative_mask
136
+
137
+
138
+ def return_mask_hphn(self, emb_distance, positive_verbs, negative_verbs, verb_mask):
139
+ B_, B_ = emb_distance.shape
140
+ positive_mask = torch.zeros_like(emb_distance)
141
+ negative_mask = torch.ones_like(emb_distance)
142
+ positive_mask.fill_diagonal_(1)
143
+
144
+ if B_ < len(verb_mask):
145
+ # Considering only verbs that pass the verb_mask filter
146
+ positive_verbs = torch.tensor(positive_verbs)[verb_mask]
147
+ negative_verbs = torch.tensor(negative_verbs)[verb_mask]
148
+
149
+ # Exclude hard negatives from both masks (diagonal)
150
+ for i in range(B_):
151
+ if negative_verbs[i] == 1:
152
+ positive_mask[i, i] = 0
153
+ negative_mask[i, i] = 0
154
+
155
+ i = 0
156
+ while i < B_:
157
+ if positive_verbs[i] == 1:
158
+ if i + 1 < B_ and positive_verbs[i + 1] == 1:
159
+ positive_mask[i, i + 1] = 1
160
+ positive_mask[i + 1, i] = 1
161
+ i += 2
162
+ else:
163
+ i += 1
164
+ else:
165
+ # Exclude hard negatives from both masks (diagonal)
166
+ for i in range(B_):
167
+ if negative_verbs[i] == 1:
168
+ positive_mask[i, i] = 0
169
+ negative_mask[i, i] = 0
170
+
171
+ # Apply the positive pairs logic similarly as above
172
+ i = 0
173
+ while i < B_:
174
+ if positive_verbs[i] == 1 and i + 1 < B_ and positive_verbs[i + 1] == 1:
175
+ positive_mask[i, i + 1] = 1
176
+ positive_mask[i + 1, i] = 1
177
+ i += 2
178
+ else:
179
+ i += 1
180
+
181
+ negative_mask = negative_mask - positive_mask
182
+
183
+ return positive_mask, negative_mask
184
+
185
+
186
+
187
+
188
+
189
+ def compute_contrastive_loss(self, fq, state, verb_masks, temperature=0.05):
190
+ """
191
+ Compute contrastive loss (NCE) only for the samples with verb phrases.
192
+ fq: shape (b, c, h*w) -> Encoded image features
193
+ state: shape (b, d) -> Encoded text features (word representations)
194
+ verb_masks: boolean mask indicating samples containing verb phrases
195
+ temperature: scaling factor for contrastive loss
196
+ """
197
+
198
+ # Extract only the samples that contain verbs using verb_masks
199
+ fq_verb_samples = torch.mean(fq[verb_masks], dim=-1) # (num_verbs, c): pool over h*w so the feature can be matched with state
200
+ state_verb_samples = state[verb_masks] # (num_verbs, d)
201
+
202
+ fq_verb_samples = F.normalize(fq_verb_samples, p=2, dim=1)
203
+ state_verb_samples = F.normalize(state_verb_samples, p=2, dim=1)
204
+
205
+ # Inner product between the language-conditioned features and the encoded verb phrases (assumes c == d, as with the CLIP backbone)
206
+ logits = torch.matmul(fq_verb_samples, state_verb_samples.t())
207
+
208
+ logits = logits / temperature
209
+
210
+ # Create labels for the contrastive loss (positive pairs are diagonals)
211
+ labels = torch.arange(logits.size(0), device=logits.device)
212
+
213
+ contrastive_loss = F.cross_entropy(logits, labels)
214
+
215
+ return contrastive_loss
216
+
217
+
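compute_contrastive_loss is a standard InfoNCE objective: matching (image, verb-phrase) pairs sit on the diagonal of the similarity matrix and are treated as the positive class of a softmax. A minimal, self-contained sketch with random stand-in features (the shapes and the c == d assumption are hypothetical, not read from the model config):

# Standalone InfoNCE sketch with random tensors.
import torch
import torch.nn.functional as F

num_verbs, d = 4, 512
img_feat = F.normalize(torch.randn(num_verbs, d), dim=1)  # pooled, language-conditioned features
txt_feat = F.normalize(torch.randn(num_verbs, d), dim=1)  # verb-phrase text embeddings

logits = img_feat @ txt_feat.t() / 0.05                   # temperature = 0.05, as in the method above
labels = torch.arange(num_verbs)                          # positives are the diagonal entries
loss = F.cross_entropy(logits, labels)
print(loss.item())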
218
+ # cosine sim only on metric_tensor
219
+ def AngularContrastiveLoss_1(self, embeddings, verb_mask, alpha=0.5, m=0.5, tau=0.05, args=None):
220
+ """
221
+ Angular Margin Contrastive Loss function.
222
+ - \( \theta_{i, i^*} \) is the angle between the anchor \( h_i \) and the positive sample \( h_{i^*} \); \( \cos(\theta_{i,i^*}) \) is their cosine similarity.
223
+ - An angular margin \( m \) is added to increase the distance between the positive and negative pairs.
224
+ - \( \tau \) is a temperature scaling factor to control the sharpness of the probability distribution.
225
+
226
+ https://aclanthology.org/2022.acl-long.336.pdf
227
+ \[
228
+ \mathcal{L}_{arc} = -\log \frac{\exp\left(\cos(\theta_{i,i^*} + m)/\tau\right)}{\exp\left(\cos(\theta_{i,i^*} + m)/\tau\right) + \sum_{j \neq i} \exp\left(\cos(\theta_{i,j})/\tau\right)}
229
+ \]
230
+
231
+ Args:
232
+ embeddings: Encoded embeddings with shape (B, C, H*W) for image-text fused features.
233
+ verb_mask: A mask indicating the samples with verb phrases.
234
+ alpha: Weight for balancing positive and negative loss components.
235
+ m: Angular margin to add to the cosine similarity of positive pairs.
236
+ tau: Temperature scaling factor for softmax.
237
+ args: Optional arguments for additional control.
238
+
239
+ Returns:
240
+ angular_loss: the computed angular-margin contrastive loss.
241
+ """
242
+
243
+ # Mean pooling across the spatial dimension (H*W) over the verb-masked embeddings, then normalize
244
+ emb = torch.mean(embeddings[verb_mask], dim=-1) # (num_verbs, C)
245
+ emb = F.normalize(emb, p=2, dim=1) # Normalize the embeddings
246
+
247
+ # Similarities are computed only among the masked samples, so B_ is their count, not the full batch size
248
+ B_ = emb.shape[0]
249
+
250
+ # Create cosine similarity matrix
251
+ sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
252
+
253
+ # Pairwise cosine similarities
254
+ emb_i = emb.unsqueeze(1).repeat(1, B_, 1) # Expand emb_i to pair with all other embeddings
255
+ emb_j = emb.unsqueeze(0).repeat(B_, 1, 1) # Expand emb_j to pair with all other embeddings
256
+ sim_matrix = sim(emb_i, emb_j).reshape(B_, B_) # (B_, B_)
257
+
258
+ # Clamp values to avoid numerical instability
259
+ sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)
260
+
261
+ # Apply angular margin for positive pairs
262
+ positive_mask = torch.eye(B_, device=embeddings.device).bool() # Diagonal is the positive pairs
263
+ sim_matrix_with_margin = sim_matrix.clone()
264
+
265
+ # Apply the angular margin `m` only to positive pairs (diagonal)
266
+ sim_matrix_with_margin[positive_mask] = torch.cos(torch.acos(sim_matrix[positive_mask]) + m)
267
+
268
+ # Scale logits with temperature
269
+ logits = sim_matrix_with_margin / tau
270
+
271
+ # Compute the softmax loss for all pairs
272
+ exp_logits = torch.exp(logits)
273
+ pos_exp_logits = exp_logits[positive_mask]
274
+ total_exp_logits = exp_logits.sum(dim=-1)
275
+
276
+ # Compute the final loss: L_arc = -log(e^(cos(theta + m)/tau) / sum(e^(cos(theta)/tau)))
277
+ positive_loss = -torch.log(pos_exp_logits / total_exp_logits)
278
+
279
+ # Average the loss over the batch size
280
+ angular_loss = positive_loss.mean()
281
+
282
+ return angular_loss
283
+
284
+
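The margin in AngularContrastiveLoss_1 works on the angle rather than on the cosine: cos(acos(s) + m) is strictly smaller than s for s in (-1, 1), so the positive logit has to clear a tighter bar before the softmax. A quick numeric check (made-up similarity values):

# Illustrative only: the angular margin always lowers the positive similarity.
import torch

s = torch.tensor([0.95, 0.70, 0.30])   # example cosine similarities
m = 0.5                                # margin in radians, as in the default above
print(torch.cos(torch.acos(s) + m))    # ≈ tensor([0.684, 0.272, -0.194])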
285
+ # cosine similarity on metric_tensor (image-text) and state (text eos)
286
+ def AngularContrastiveLoss_2(self, fq, state, verb_masks, alpha=1.0, margin=0.5, temperature=0.05):
287
+ """
288
+ Angular Margin Contrastive Loss function.
289
+ - \( \theta_{i, i^*} \) is the angle between the anchor \( h_i \) and the positive sample \( h_{i^*} \); \( \cos(\theta_{i,i^*}) \) is their cosine similarity.
290
+ - An angular margin \( m \) is added to increase the distance between the positive and negative pairs.
291
+ - \( \tau \) is a temperature scaling factor to control the sharpness of the probability distribution.
292
+
293
+ https://aclanthology.org/2022.acl-long.336.pdf
294
+ \[
295
+ \mathcal{L}_{arc} = -\log \frac{\exp\left(\cos(\theta_{i,i^*} + m)/\tau\right)}{\exp\left(\cos(\theta_{i,i^*} + m)/\tau\right) + \sum_{j \neq i} \exp\left(\cos(\theta_{i,j})/\tau\right)}
296
+ \]
297
+
298
+
299
+ fq: (b, c, h*w) -> Encoded language-fused multimodal feature (metric_tensor)
300
+ state: (b, d) -> sentence-level text embedding (the EOS state from the text encoder)
301
+ verb_masks: boolean mask indicating samples containing verb phrases
302
+ alpha: weight for positive samples
303
+ margin: the angular margin to enforce between positive pairs
304
+ temperature: scaling factor for contrastive loss
305
+ """
306
+
307
+ # Select only the verb-containing samples
308
+ # Assume c equals d (CLIP model backbone)
309
+ fq_verb_samples = torch.mean(fq[verb_masks], dim=-1) # (num_verbs, d)
310
+ state_verb_samples = state[verb_masks] # (num_verbs, d)
311
+ fq_verb_samples = F.normalize(fq_verb_samples, p=2, dim=1) # (num_verbs, d)
312
+ state_verb_samples = F.normalize(state_verb_samples, p=2, dim=1) # (num_verbs, d)
313
+
314
+ # Compute cosine similarity (logits) between image and text features
315
+ logits = torch.matmul(fq_verb_samples, state_verb_samples.t()) # (num_verbs, num_verbs)
316
+
317
+ # Apply the angular margin to positive pairs (diagonal entries)
318
+ diagonal_indices = torch.arange(logits.size(0), device=logits.device)
319
+ positive_logits = logits[diagonal_indices, diagonal_indices]
320
+ positive_logits_with_margin = torch.cos(torch.acos(torch.clamp(positive_logits, -0.9999, 0.9999)) + margin) # add the margin to the angle, as in the formula above, so positives become harder rather than easier
321
+
322
+ # Replace the diagonal (positive) entries with the margin-added values
323
+ logits[diagonal_indices, diagonal_indices] = positive_logits_with_margin
324
+ logits = logits / temperature
325
+
326
+ # Create positive mask (diagonal) and negative mask (non-diagonal)
327
+ positive_mask = torch.eye(logits.size(0), device=logits.device).bool() # Diagonal for positive pairs
328
+ negative_mask = ~positive_mask # Non-diagonal for negative pairs
329
+ exp_logits = torch.exp(logits) # Exponentials of logits
330
+
331
+ # Positive and negative softmax components
332
+ pos_exp_logits = exp_logits[positive_mask].view(-1) # Positive pairs (diagonal entries)
333
+ neg_exp_logits = exp_logits[negative_mask].view(logits.size(0), -1).sum(dim=1) # Sum of negative pairs
334
+
335
+ # Final loss: -log(e^(cos(theta + m)/tau) / (e^(cos(theta + m)/tau) + sum(e^(cos(theta)/tau)))
336
+ positive_loss = -torch.log(pos_exp_logits / (pos_exp_logits + neg_exp_logits))
337
+
338
+ loss = positive_loss.mean()
339
+
340
+ return loss
341
+
342
+
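The two variants apply the same ArcFace-style objective to different pairings: AngularContrastiveLoss_1 compares the verb-masked fused features with each other and takes the diagonal of that self-similarity matrix as positives, while AngularContrastiveLoss_2 pools the fused features, pairs each with the text state of the same sample, and puts the margin on that cross-modal diagonal. As a purely hypothetical sketch (not the repository's actual wiring; compute_metric_loss is defined elsewhere in this file), a dispatcher method on the same class could look like:

# Hypothetical sketch only: choose a variant depending on whether the text state is passed in.
def compute_metric_loss(self, metric_tensor, verb_masks, state=None, args=None):
    if state is None:
        return self.AngularContrastiveLoss_1(metric_tensor, verb_masks, m=0.5, tau=0.05, args=args)
    return self.AngularContrastiveLoss_2(metric_tensor, state, verb_masks, margin=0.5, temperature=0.05)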
343
+ # def AngularMetricLoss_Seunghoon(self, embeddings, n_pos , verb_mask, alpha = 0.5, args = None):
344
+ # # embeddings: ((2*B), C, (H*W))
345
+ # # n_pos : chunk size of positive pairs
346
+ # # args: args
347
+ # # returns: loss
348
+ # geometric_loss = 0
349
+
350
+ # # flatten embeddings
351
+ # B_, C, HW = embeddings.shape
352
+ # emb = torch.mean(embeddings[verb_mask], dim=-1)
353
+ # emb_i = emb.unsqueeze(1).repeat(1, B_, 1)
354
+ # emb_j = emb.unsqueeze(0).repeat(B_, 1, 1)
355
+ # sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
356
+ # sim_matrix = sim(emb_i, emb_j).reshape(B_, B_)
357
+ # sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)
358
+ # phi = torch.acos(sim_matrix)
359
+
360
+ # # positive pairs and loss
361
+ # positive_mask = torch.zeros_like(sim_matrix)
362
+ # positive_mask.fill_diagonal_(1)
363
+ # positive_loss = torch.sum((phi**2) * positive_mask) / B_
364
+
365
+ # # negative pairs and loss
366
+ # negative_mask = torch.ones_like(sim_matrix) - positive_mask
367
+ # phi_mask = phi < args.phi_threshold
368
+ # negative_loss = (args.phi_threshold - phi)**2
369
+ # #print(negative_mask * phi_mask)
370
+ # negative_loss = torch.sum(negative_loss * negative_mask * phi_mask) / (B_**2 - 2*B_)
371
+
372
+ # #print("pos loss, neg loss : ", positive_loss, negative_loss)
373
+ # geometric_loss = alpha * positive_loss + (1-alpha) * negative_loss
374
+
375
+ # return geometric_loss
model/segmenter_verbonly_hardneg.py ADDED
@@ -0,0 +1,204 @@
1
+ import random
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from model.clip import build_model
7
+
8
+ from .layers import FPN, Projector, TransformerDecoder
9
+
10
+ class CRIS_VerbOnly(nn.Module):
11
+ def __init__(self, cfg):
12
+ super().__init__()
13
+ # Vision & Text Encoder
14
+ clip_model = torch.jit.load(cfg.clip_pretrain,
15
+ map_location="cpu").eval()
16
+ self.backbone = build_model(clip_model.state_dict(), cfg.word_len).float()
17
+ # Multi-Modal FPN
18
+ self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
19
+ # Decoder
20
+ self.decoder = TransformerDecoder(num_layers=cfg.num_layers,
21
+ d_model=cfg.vis_dim,
22
+ nhead=cfg.num_head,
23
+ dim_ffn=cfg.dim_ffn,
24
+ dropout=cfg.dropout,
25
+ return_intermediate=cfg.intermediate)
26
+ # Projector
27
+ self.proj = Projector(cfg.word_dim, cfg.vis_dim // 2, 3)
28
+ self.metric_learning = False # cfg.metric_learning
29
+ self.positive_strength = cfg.positive_strength
30
+ self.metric_loss_weight = cfg.metric_loss_weight
31
+ self.eps = cfg.ptb_rate
32
+ self.cfg = cfg
33
+
34
+
35
+ def forward(self, image, text, target=None, hardpos=None, hardneg=None):
36
+ '''
37
+ image: b, 3, h, w
38
+ text: b, words
39
+ target: b, 1, h, w
40
+ verb: b, words (if applicable, only used in training mode for contrastive learning)
41
+ '''
42
+
43
+ sentences, images, targets, pad_masks = [], [], [], []
44
+ positive_verbs, negative_verbs = [], []
45
+
46
+
47
+ if self.training:
48
+ verb_masks = []
49
+ cl_masks = []
50
+
51
+ for idx in range(len(text)):
52
+ sentences.append(text[idx])
53
+ images.append(image[idx])
54
+ targets.append(target[idx])
55
+ pad_masks.append(torch.zeros_like(text[idx]).masked_fill_(text[idx] == 0, 1).bool())
56
+
57
+
58
+
59
+ # If verb exists, process it
60
+ if verb[idx].numel() > 0 and verb[idx].sum().item() > 0:
61
+ verb_masks.extend([1, 1]) # Both original sentence and verb are marked
62
+ cl_masks.extend([0, 1]) # Only verb gets marked for exclusion from CE loss
63
+ sentences.append(verb[idx])
64
+ images.append(image[idx])
65
+ targets.append(target[idx])
66
+ pad_masks.append(torch.zeros_like(verb[idx]).masked_fill_(verb[idx] == 0, 1).bool())
67
+ else:
68
+ verb_masks.append(0)
69
+ cl_masks.append(0)
70
+
71
+
72
+ sentences = torch.stack(sentences)
73
+ images = torch.stack(images)
74
+ targets = torch.stack(targets)
75
+ pad_masks = torch.stack(pad_masks)
76
+ verb_masks = torch.tensor(verb_masks, dtype=torch.bool)
77
+ cl_masks = torch.tensor(cl_masks, dtype=torch.bool)
78
+
79
+ else:
80
+ sentences = text
81
+ images = image
82
+ targets = target
83
+ pad_masks = torch.zeros_like(text).masked_fill_(text == 0, 1).bool()
84
+
85
+ # Encoding images and text
86
+ vis = self.backbone.encode_image(images)
87
+ word, state = self.backbone.encode_text(sentences)
88
+
89
+ # FPN neck and decoder
90
+ fq, f5 = self.neck(vis, state)
91
+ b, c, h, w = fq.size()
92
+ fq = self.decoder(fq, word, pad_masks)
93
+ metric_tensor = fq # b, c, h*w
94
+ fq = fq.reshape(b, c, h, w)
95
+
96
+ # Final prediction
97
+ pred = self.proj(fq, state)
98
+
99
+ if self.training:
100
+ if pred.shape[-2:] != targets.shape[-2:]:
101
+ targets = F.interpolate(targets, pred.shape[-2:], mode='nearest').detach()
102
+
103
+ loss = F.binary_cross_entropy_with_logits(pred[~cl_masks], targets[~cl_masks])
104
+
105
+ if self.metric_learning:
106
+ metric_loss = self.compute_metric_loss(metric_tensor, verb_masks, args=self.cfg)
107
+ loss = (loss + self.metric_loss_weight * metric_loss) / (1 + self.metric_loss_weight)
108
+
109
+ return pred.detach(), targets, loss
110
+
111
+ return pred.detach() # In eval mode, only return the predictions
112
+
113
+
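A standalone sketch (toy token tensors, not real tokenizer output) of the row expansion performed in the training branch of forward(): a sample with a hard-positive phrase contributes two rows, verb_masks flags both for the metric loss, and cl_masks flags only the duplicated row so it is excluded from the BCE term.

# Illustrative only: made-up token ids.
import torch

text = [torch.tensor([5, 9, 2, 0]), torch.tensor([7, 3, 0, 0])]     # two samples
hardpos = [torch.tensor([9, 2, 0, 0]), torch.tensor([0, 0, 0, 0])]  # only sample 0 has a verb phrase

sentences, verb_masks, cl_masks = [], [], []
for idx in range(len(text)):
    sentences.append(text[idx])
    if hardpos[idx].numel() > 0 and hardpos[idx].sum().item() > 0:
        verb_masks.extend([1, 1])
        cl_masks.extend([0, 1])
        sentences.append(hardpos[idx])
    else:
        verb_masks.append(0)
        cl_masks.append(0)

print(len(sentences))                              # 3 rows: sample 0, its verb phrase, sample 1
print(torch.tensor(cl_masks, dtype=torch.bool))    # [False, True, False] -> row 1 skips the BCE loss
print(torch.tensor(verb_masks, dtype=torch.bool))  # [True, True, False] -> rows 0 and 1 enter the metric loss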
114
+ def compute_metric_loss(self, metric_tensor, verb_mask, args):
115
+ return None # placeholder: metric learning is disabled for this variant (self.metric_learning = False)
116
+
117
+
118
+ def return_mask(self, emb_distance, positive_verbs, negative_verbs, verb_mask):
119
+ B_, B_ = emb_distance.shape
120
+ positive_mask = torch.zeros_like(emb_distance)
121
+ negative_mask = torch.ones_like(emb_distance)
122
+ positive_mask.fill_diagonal_(1)
123
+
124
+ if B_ < len(verb_mask):
125
+ # Considering only verbs that pass the verb_mask filter
126
+ positive_verbs = torch.tensor(positive_verbs)[verb_mask]
127
+ negative_verbs = torch.tensor(negative_verbs)[verb_mask]
128
+
129
+ # Exclude hard negatives from both masks (diagonal)
130
+ for i in range(B_):
131
+ if negative_verbs[i] == 1:
132
+ positive_mask[i, i] = 0
133
+ negative_mask[i, i] = 0
134
+
135
+ i = 0
136
+ while i < B_:
137
+ if positive_verbs[i] == 1:
138
+ if i + 1 < B_ and positive_verbs[i + 1] == 1:
139
+ positive_mask[i, i + 1] = 1
140
+ positive_mask[i + 1, i] = 1
141
+ i += 2
142
+ else:
143
+ i += 1
144
+ else:
145
+ # Exclude hard negatives from both masks (diagonal)
146
+ for i in range(B_):
147
+ if negative_verbs[i] == 1:
148
+ positive_mask[i, i] = 0
149
+ negative_mask[i, i] = 0
150
+
151
+ # Apply the positive pairs logic similarly as above
152
+ i = 0
153
+ while i < B_:
154
+ if positive_verbs[i] == 1 and i + 1 < B_ and positive_verbs[i + 1] == 1:
155
+ positive_mask[i, i + 1] = 1
156
+ positive_mask[i + 1, i] = 1
157
+ i += 2
158
+ else:
159
+ i += 1
160
+
161
+ negative_mask = negative_mask - positive_mask
162
+
163
+ return positive_mask, negative_mask
164
+
165
+
166
+ def UniAngularContrastLoss(self, total_fq, positive_verbs, negative_verbs, m=0.5, tau=0.05, verbonly=True, args=None):
167
+ """
168
+ Angular Margin Contrastive Loss function with mask visualization.
169
+ """
170
+ verb_mask = (torch.as_tensor(positive_verbs) + torch.as_tensor(negative_verbs)).bool() # samples that have either a hard-positive or a hard-negative verb phrase
171
+
172
+ if verbonly:
173
+ emb = torch.mean(total_fq[verb_mask], dim=-1)
174
+ else:
175
+ emb = torch.mean(total_fq, dim=-1) # (B, C)
176
+
177
+ B_ = emb.shape[0]
178
+ # emb = F.normalize(emb, p=2, dim=1)
179
+ emb_i = emb.unsqueeze(1).repeat(1, B_, 1) # (B_, B_, C)
180
+ emb_j = emb.unsqueeze(0).repeat(B_, 1, 1) # (B_, B_, C)
181
+ sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
182
+ sim_matrix = sim(emb_i, emb_j).reshape(B_, B_) # (B_, B_)
183
+ sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)
184
+
185
+ # Apply angular margin for positive pairs using return_mask
186
+ positive_mask, negative_mask = self.return_mask(sim_matrix, positive_verbs, negative_verbs, verb_mask)
187
+
188
+ # Apply margin to positive pairs
189
+ sim_matrix_with_margin = sim_matrix.clone()
190
+ sim_matrix_with_margin[positive_mask.bool()] = torch.cos(torch.acos(sim_matrix[positive_mask.bool()]) + m / 57.2958) # 57.2958 ≈ 180/pi: m is given in degrees and converted to radians here
191
+
192
+ # Scale logits with temperature
193
+ logits = sim_matrix_with_margin / tau
194
+
195
+ # Compute the softmax loss for all pairs
196
+ exp_logits = torch.exp(logits)
197
+ pos_exp_logits = exp_logits[positive_mask.bool()]
198
+ total_exp_logits = exp_logits.sum(dim=-1)
199
+
200
+ # Compute the final loss: L_arc = -log(e^(cos(theta + m)/tau) / sum(e^(cos(theta)/tau)))
201
+ positive_loss = -torch.log(pos_exp_logits / total_exp_logits[positive_mask.bool()])
202
+ angular_loss = positive_loss.mean()
203
+
204
+ return angular_loss
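To make the last steps of UniAngularContrastLoss concrete, here is a standalone numeric sketch with made-up similarities and the mask pattern that return_mask would produce for one (sentence, hard-positive) pair plus one hard negative; it mirrors the margin, temperature, and per-anchor normalization used above.

# Illustrative only: 3 rows -- a sentence, its hard positive, and a hard negative.
import torch

sim = torch.tensor([[1.00, 0.80, 0.40],
                    [0.80, 1.00, 0.35],
                    [0.40, 0.35, 1.00]]).clamp(-0.9999, 0.9999)
positive = torch.tensor([[1., 1., 0.],
                         [1., 1., 0.],
                         [0., 0., 0.]])  # return_mask output for positive_verbs=[1,1,0], negative_verbs=[0,0,1]
m_deg, tau = 0.5, 0.05
sim_m = sim.clone()
sim_m[positive.bool()] = torch.cos(torch.acos(sim[positive.bool()]) + m_deg / 57.2958)

exp_logits = torch.exp(sim_m / tau)
anchor_rows = positive.bool().nonzero(as_tuple=True)[0]  # row index of each positive entry
loss = -torch.log(exp_logits[positive.bool()] / exp_logits.sum(dim=-1)[anchor_rows]).mean()
print(loss.item())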