Upload folder using huggingface_hub

#1
by dianecy - opened
dianecy/VerbCentric-RIS/model_/.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__/
dianecy/VerbCentric-RIS/model_/__init__.py ADDED
@@ -0,0 +1,105 @@
+ from .segmenter import CRIS
+ # from .segmenter_angular import CRIS_S
+ from .segmenter_verbonly import CRIS_PosOnly
+ from .segmenter_verbonly_fin import CRIS_PosOnly_rev
+ from .segmenter_verbonly_hardneg import CRIS_VerbOnly
+ from loguru import logger
+
+ def build_segmenter_pos_rev(args):
+     model = CRIS_PosOnly_rev(args)
+     backbone = []
+     head = []
+     for k, v in model.named_parameters():
+         if k.startswith('backbone') and 'positional_embedding' not in k:
+             backbone.append(v)
+         else:
+             head.append(v)
+     logger.info('Backbone with decay={}, Head={}'.format(len(backbone), len(head)))
+     param_list = [{
+         'params': backbone,
+         'initial_lr': args.lr_multi * args.base_lr
+     }, {
+         'params': head,
+         'initial_lr': args.base_lr
+     }]
+     return model, param_list
+
+ def build_segmenter_pos(args):
+     model = CRIS_PosOnly(args)
+     backbone = []
+     head = []
+     for k, v in model.named_parameters():
+         if k.startswith('backbone') and 'positional_embedding' not in k:
+             backbone.append(v)
+         else:
+             head.append(v)
+     logger.info('Backbone with decay={}, Head={}'.format(len(backbone), len(head)))
+     param_list = [{
+         'params': backbone,
+         'initial_lr': args.lr_multi * args.base_lr
+     }, {
+         'params': head,
+         'initial_lr': args.base_lr
+     }]
+     return model, param_list
+
+
+ def build_segmenter(args):
+     model = CRIS_VerbOnly(args)
+     backbone = []
+     head = []
+     for k, v in model.named_parameters():
+         if k.startswith('backbone') and 'positional_embedding' not in k:
+             backbone.append(v)
+         else:
+             head.append(v)
+     logger.info('Backbone with decay={}, Head={}'.format(len(backbone), len(head)))
+     param_list = [{
+         'params': backbone,
+         'initial_lr': args.lr_multi * args.base_lr
+     }, {
+         'params': head,
+         'initial_lr': args.base_lr
+     }]
+     return model, param_list
+
+
+
+ def build_segmenter_original(args):
+     model = CRIS(args)
+     backbone = []
+     head = []
+     for k, v in model.named_parameters():
+         if k.startswith('backbone') and 'positional_embedding' not in k:
+             backbone.append(v)
+         else:
+             head.append(v)
+     logger.info('Backbone with decay={}, Head={}'.format(len(backbone), len(head)))
+     param_list = [{
+         'params': backbone,
+         'initial_lr': args.lr_multi * args.base_lr
+     }, {
+         'params': head,
+         'initial_lr': args.base_lr
+     }]
+     return model, param_list
+
+
+ # def build_segmenter_textaug(args):
+ #     model = CRIS_Wo_Noise(args)
+ #     backbone = []
+ #     head = []
+ #     for k, v in model.named_parameters():
+ #         if k.startswith('backbone') and 'positional_embedding' not in k:
+ #             backbone.append(v)
+ #         else:
+ #             head.append(v)
+ #     logger.info('Backbone with decay={}, Head={}'.format(len(backbone), len(head)))
+ #     param_list = [{
+ #         'params': backbone,
+ #         'initial_lr': args.lr_multi * args.base_lr
+ #     }, {
+ #         'params': head,
+ #         'initial_lr': args.base_lr
+ #     }]
+ #     return model, param_list
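
For reference, the sketch below shows how the `param_list` returned by these builders is typically consumed: parameters whose names start with `backbone` get a scaled learning rate, everything else uses the base rate, and the extra `initial_lr` key is what PyTorch LR schedulers read when resumed. The toy module, the field values for `base_lr`/`lr_multi`, and the AdamW choice are illustrative assumptions, not part of this upload.

import torch
import torch.nn as nn

# Toy stand-in for CRIS; only the name-based backbone/head split matters here.
toy = nn.Sequential()
toy.add_module('backbone', nn.Linear(8, 8))
toy.add_module('head', nn.Linear(8, 1))

base_lr, lr_multi = 1e-4, 0.1  # assumed stand-ins for args.base_lr, args.lr_multi
backbone, head = [], []
for k, v in toy.named_parameters():
    (backbone if k.startswith('backbone') else head).append(v)

param_list = [
    {'params': backbone, 'initial_lr': lr_multi * base_lr},
    {'params': head, 'initial_lr': base_lr},
]
# Groups without an explicit 'lr' fall back to the optimizer default (lr=base_lr);
# 'initial_lr' is the key LR schedulers require when constructed with last_epoch >= 0.
optimizer = torch.optim.AdamW(param_list, lr=base_lr)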
dianecy/VerbCentric-RIS/model_/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (1.88 kB).
 
dianecy/VerbCentric-RIS/model_/__pycache__/clip.cpython-39.pyc ADDED
Binary file (16.8 kB).
 
dianecy/VerbCentric-RIS/model_/__pycache__/layers.cpython-39.pyc ADDED
Binary file (9.06 kB).
 
dianecy/VerbCentric-RIS/model_/__pycache__/segmenter.cpython-39.pyc ADDED
Binary file (4.83 kB).
 
dianecy/VerbCentric-RIS/model_/__pycache__/segmenter_verbonly.cpython-39.pyc ADDED
Binary file (4.76 kB).
 
dianecy/VerbCentric-RIS/model_/__pycache__/segmenter_verbonly_fin.cpython-39.pyc ADDED
Binary file (4.78 kB).
 
dianecy/VerbCentric-RIS/model_/__pycache__/segmenter_verbonly_hardneg.cpython-39.pyc ADDED
Binary file (5.96 kB).
 
dianecy/VerbCentric-RIS/model_/clip.py ADDED
@@ -0,0 +1,560 @@
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+
10
+ class Bottleneck(nn.Module):
11
+ expansion = 4
12
+
13
+ def __init__(self, inplanes, planes, stride=1):
14
+ super().__init__()
15
+
16
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18
+ self.bn1 = nn.BatchNorm2d(planes)
19
+
20
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
21
+ self.bn2 = nn.BatchNorm2d(planes)
22
+
23
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
24
+
25
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
26
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
27
+
28
+ self.relu = nn.ReLU(inplace=True)
29
+ self.downsample = None
30
+ self.stride = stride
31
+
32
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
33
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
34
+ self.downsample = nn.Sequential(
35
+ OrderedDict([("-1", nn.AvgPool2d(stride)),
36
+ ("0",
37
+ nn.Conv2d(inplanes,
38
+ planes * self.expansion,
39
+ 1,
40
+ stride=1,
41
+ bias=False)),
42
+ ("1", nn.BatchNorm2d(planes * self.expansion))]))
43
+
44
+ def forward(self, x: torch.Tensor):
45
+ identity = x
46
+
47
+ out = self.relu(self.bn1(self.conv1(x)))
48
+ out = self.relu(self.bn2(self.conv2(out)))
49
+ out = self.avgpool(out)
50
+ out = self.bn3(self.conv3(out))
51
+
52
+ if self.downsample is not None:
53
+ identity = self.downsample(x)
54
+
55
+ out += identity
56
+ out = self.relu(out)
57
+ return out
58
+
59
+
60
+ class AttentionPool2d(nn.Module):
61
+ def __init__(self,
62
+ spacial_dim: int,
63
+ embed_dim: int,
64
+ num_heads: int,
65
+ output_dim: int = None):
66
+ super().__init__()
67
+ self.spacial_dim = spacial_dim
68
+ self.positional_embedding = nn.Parameter(
69
+ torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
70
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
71
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
72
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
73
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
74
+ self.num_heads = num_heads
75
+ # residual
76
+ self.connect = nn.Sequential(
77
+ nn.Conv2d(embed_dim, output_dim, 1, stride=1, bias=False),
78
+ nn.BatchNorm2d(output_dim))
79
+
80
+ def resize_pos_embed(self, pos_embed, input_shpae):
81
+ """Resize pos_embed weights.
82
+ Resize pos_embed using bicubic interpolate method.
83
+ Args:
84
+ pos_embed (torch.Tensor): Position embedding weights.
85
+ input_shpae (tuple): Tuple for (downsampled input image height,
86
+ downsampled input image width).
87
+ pos_shape (tuple): The resolution of downsampled origin training
88
+ image.
89
+ mode (str): Algorithm used for upsampling:
90
+ ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
91
+ ``'trilinear'``. Default: ``'nearest'``
92
+ Return:
93
+ torch.Tensor: The resized pos_embed of shape [B, C, L_new]
94
+ """
95
+ assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
96
+ pos_h = pos_w = self.spacial_dim
97
+ cls_token_weight = pos_embed[:, 0]
98
+ pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
99
+ pos_embed_weight = pos_embed_weight.reshape(
100
+ 1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
101
+ pos_embed_weight = F.interpolate(pos_embed_weight,
102
+ size=input_shpae,
103
+ align_corners=False,
104
+ mode='bicubic')
105
+ cls_token_weight = cls_token_weight.unsqueeze(1)
106
+ pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
107
+ # pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
108
+ return pos_embed_weight.transpose(-2, -1)
109
+
110
+ def forward(self, x):
111
+ B, C, H, W = x.size()
112
+ res = self.connect(x)
113
+ x = x.reshape(B, C, -1) # NC(HW)
114
+ # x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(1+HW)
115
+ pos_embed = self.positional_embedding.unsqueeze(0)
116
+ pos_embed = self.resize_pos_embed(pos_embed, (H, W)) # NC(HW)
117
+ x = x + pos_embed.to(x.dtype) # NC(HW)
118
+ x = x.permute(2, 0, 1) # (HW)NC
119
+ x, _ = F.multi_head_attention_forward(
120
+ query=x,
121
+ key=x,
122
+ value=x,
123
+ embed_dim_to_check=x.shape[-1],
124
+ num_heads=self.num_heads,
125
+ q_proj_weight=self.q_proj.weight,
126
+ k_proj_weight=self.k_proj.weight,
127
+ v_proj_weight=self.v_proj.weight,
128
+ in_proj_weight=None,
129
+ in_proj_bias=torch.cat(
130
+ [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
131
+ bias_k=None,
132
+ bias_v=None,
133
+ add_zero_attn=False,
134
+ dropout_p=0,
135
+ out_proj_weight=self.c_proj.weight,
136
+ out_proj_bias=self.c_proj.bias,
137
+ use_separate_proj_weight=True,
138
+ training=self.training,
139
+ need_weights=False)
140
+ x = x.permute(1, 2, 0).reshape(B, -1, H, W)
141
+ x = x + res
142
+ x = F.relu(x, True)
143
+
144
+ return x
145
+
146
+
147
+ class ModifiedResNet(nn.Module):
148
+ """
149
+ A ResNet class that is similar to torchvision's but contains the following changes:
150
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
151
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
152
+ - The final pooling layer is a QKV attention instead of an average pool
153
+ """
154
+ def __init__(self,
155
+ layers,
156
+ output_dim,
157
+ heads,
158
+ input_resolution=224,
159
+ width=64):
160
+ super().__init__()
161
+ self.output_dim = output_dim
162
+ self.input_resolution = input_resolution
163
+
164
+ # the 3-layer stem
165
+ self.conv1 = nn.Conv2d(3,
166
+ width // 2,
167
+ kernel_size=3,
168
+ stride=2,
169
+ padding=1,
170
+ bias=False)
171
+ self.bn1 = nn.BatchNorm2d(width // 2)
172
+ self.conv2 = nn.Conv2d(width // 2,
173
+ width // 2,
174
+ kernel_size=3,
175
+ padding=1,
176
+ bias=False)
177
+ self.bn2 = nn.BatchNorm2d(width // 2)
178
+ self.conv3 = nn.Conv2d(width // 2,
179
+ width,
180
+ kernel_size=3,
181
+ padding=1,
182
+ bias=False)
183
+ self.bn3 = nn.BatchNorm2d(width)
184
+ self.avgpool = nn.AvgPool2d(2)
185
+ self.relu = nn.ReLU(inplace=True)
186
+
187
+ # residual layers
188
+ self._inplanes = width # this is a *mutable* variable used during construction
189
+ self.layer1 = self._make_layer(width, layers[0])
190
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
191
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
192
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
193
+
194
+ embed_dim = width * 32 # the ResNet feature dimension
195
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim,
196
+ heads, output_dim)
197
+
198
+ def _make_layer(self, planes, blocks, stride=1):
199
+ layers = [Bottleneck(self._inplanes, planes, stride)]
200
+
201
+ self._inplanes = planes * Bottleneck.expansion
202
+ for _ in range(1, blocks):
203
+ layers.append(Bottleneck(self._inplanes, planes))
204
+
205
+ return nn.Sequential(*layers)
206
+
207
+ def forward(self, x):
208
+ def stem(x):
209
+ for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2),
210
+ (self.conv3, self.bn3)]:
211
+ x = self.relu(bn(conv(x)))
212
+ x = self.avgpool(x)
213
+ return x
214
+
215
+ x = x.type(self.conv1.weight.dtype)
216
+ x = stem(x)
217
+ x = self.layer1(x)
218
+ x2 = self.layer2(x)
219
+ x3 = self.layer3(x2)
220
+ x4 = self.layer4(x3)
221
+ x4 = self.attnpool(x4)
222
+
223
+ return (x2, x3, x4)
224
+
225
+
226
+ class LayerNorm(nn.LayerNorm):
227
+ """Subclass torch's LayerNorm to handle fp16."""
228
+ def forward(self, x: torch.Tensor):
229
+ orig_type = x.dtype
230
+ ret = super().forward(x.type(torch.float32))
231
+ return ret.type(orig_type)
232
+
233
+
234
+ class QuickGELU(nn.Module):
235
+ def forward(self, x: torch.Tensor):
236
+ return x * torch.sigmoid(1.702 * x)
237
+
238
+
239
+ class ResidualAttentionBlock(nn.Module):
240
+ def __init__(self,
241
+ d_model: int,
242
+ n_head: int,
243
+ attn_mask: torch.Tensor = None):
244
+ super().__init__()
245
+
246
+ self.attn = nn.MultiheadAttention(d_model, n_head)
247
+ self.ln_1 = LayerNorm(d_model)
248
+ self.mlp = nn.Sequential(
249
+ OrderedDict([("c_fc", nn.Linear(d_model, d_model * 4)),
250
+ ("gelu", QuickGELU()),
251
+ ("c_proj", nn.Linear(d_model * 4, d_model))]))
252
+ self.ln_2 = LayerNorm(d_model)
253
+ self.attn_mask = attn_mask
254
+
255
+ def attention(self, x: torch.Tensor):
256
+ self.attn_mask = self.attn_mask.to(
257
+ dtype=x.dtype,
258
+ device=x.device) if self.attn_mask is not None else None
259
+ return self.attn(x, x, x, need_weights=False,
260
+ attn_mask=self.attn_mask)[0]
261
+
262
+ def forward(self, x: torch.Tensor):
263
+ x = x + self.attention(self.ln_1(x))
264
+ x = x + self.mlp(self.ln_2(x))
265
+ return x
266
+
267
+
268
+ class Transformer(nn.Module):
269
+ def __init__(self,
270
+ width: int,
271
+ layers: int,
272
+ heads: int,
273
+ attn_mask: torch.Tensor = None):
274
+ super().__init__()
275
+ self.width = width
276
+ self.layers = layers
277
+ self.resblocks = nn.Sequential(*[
278
+ ResidualAttentionBlock(width, heads, attn_mask)
279
+ for _ in range(layers)
280
+ ])
281
+
282
+ def forward(self, x: torch.Tensor):
283
+ return self.resblocks(x)
284
+
285
+
286
+ class VisionTransformer(nn.Module):
287
+ def __init__(self, input_resolution: int, patch_size: int, width: int,
288
+ layers: int, heads: int, output_dim: int):
289
+ super().__init__()
290
+ self.input_resolution = input_resolution
291
+ self.output_dim = output_dim
292
+ self.conv1 = nn.Conv2d(in_channels=3,
293
+ out_channels=width,
294
+ kernel_size=patch_size,
295
+ stride=patch_size,
296
+ bias=False)
297
+
298
+ scale = width**-0.5
299
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
300
+ self.positional_embedding = nn.Parameter(scale * torch.randn(
301
+ (input_resolution // patch_size)**2 + 1, width))
302
+ self.ln_pre = LayerNorm(width)
303
+
304
+ self.transformer = Transformer(width, layers, heads)
305
+
306
+ self.ln_post = LayerNorm(width)
307
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
308
+
309
+ def forward(self, x: torch.Tensor):
310
+ x = self.conv1(x) # shape = [*, width, grid, grid]
311
+ x = x.reshape(x.shape[0], x.shape[1],
312
+ -1) # shape = [*, width, grid ** 2]
313
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
314
+ x = torch.cat([
315
+ self.class_embedding.to(x.dtype) + torch.zeros(
316
+ x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x
317
+ ],
318
+ dim=1) # shape = [*, grid ** 2 + 1, width]
319
+ x = x + self.positional_embedding.to(x.dtype)
320
+ x = self.ln_pre(x)
321
+
322
+ x = x.permute(1, 0, 2) # NLD -> LND
323
+ x = self.transformer(x)
324
+ x = x.permute(1, 0, 2) # LND -> NLD
325
+
326
+ # x = self.ln_post(x[:, 0, :])
327
+ x = self.ln_post(x[:, 1:, :])
328
+
329
+ if self.proj is not None:
330
+ x = x @ self.proj
331
+
332
+ return x
333
+
334
+
335
+ class CLIP(nn.Module):
336
+ def __init__(
337
+ self,
338
+ embed_dim: int,
339
+ # vision
340
+ image_resolution: int,
341
+ vision_layers: Union[Tuple[int, int, int, int], int],
342
+ vision_width: int,
343
+ vision_patch_size: int,
344
+ # text
345
+ context_length: int,
346
+ txt_length: int,
347
+ vocab_size: int,
348
+ transformer_width: int,
349
+ transformer_heads: int,
350
+ transformer_layers: int):
351
+ super().__init__()
352
+
353
+ self.context_length = context_length
354
+
355
+ if isinstance(vision_layers, (tuple, list)):
356
+ vision_heads = vision_width * 32 // 64
357
+ self.visual = ModifiedResNet(layers=vision_layers,
358
+ output_dim=embed_dim,
359
+ heads=vision_heads,
360
+ input_resolution=image_resolution,
361
+ width=vision_width)
362
+ else:
363
+ vision_heads = vision_width // 64
364
+ self.visual = VisionTransformer(input_resolution=image_resolution,
365
+ patch_size=vision_patch_size,
366
+ width=vision_width,
367
+ layers=vision_layers,
368
+ heads=vision_heads,
369
+ output_dim=embed_dim)
370
+
371
+ self.transformer = Transformer(
372
+ width=transformer_width,
373
+ layers=transformer_layers,
374
+ heads=transformer_heads,
375
+ attn_mask=self.build_attention_mask(txt_length))
376
+
377
+ self.vocab_size = vocab_size
378
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
379
+ self.positional_embedding = nn.Parameter(
380
+ torch.empty(self.context_length, transformer_width))
381
+ self.ln_final = LayerNorm(transformer_width)
382
+
383
+ self.text_projection = nn.Parameter(
384
+ torch.empty(transformer_width, embed_dim))
385
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
386
+
387
+ self.token_embedding.requires_grad_ = False
388
+ self.initialize_parameters()
389
+
390
+ def initialize_parameters(self):
391
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
392
+ nn.init.normal_(self.positional_embedding, std=0.01)
393
+
394
+ if isinstance(self.visual, ModifiedResNet):
395
+ if self.visual.attnpool is not None:
396
+ std = self.visual.attnpool.c_proj.in_features**-0.5
397
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
398
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
399
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
400
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
401
+
402
+ for resnet_block in [
403
+ self.visual.layer1, self.visual.layer2, self.visual.layer3,
404
+ self.visual.layer4
405
+ ]:
406
+ for name, param in resnet_block.named_parameters():
407
+ if name.endswith("bn3.weight"):
408
+ nn.init.zeros_(param)
409
+
410
+ proj_std = (self.transformer.width**-0.5) * (
411
+ (2 * self.transformer.layers)**-0.5)
412
+ attn_std = self.transformer.width**-0.5
413
+ fc_std = (2 * self.transformer.width)**-0.5
414
+ for block in self.transformer.resblocks:
415
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
416
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
417
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
418
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
419
+
420
+ if self.text_projection is not None:
421
+ nn.init.normal_(self.text_projection,
422
+ std=self.transformer.width**-0.5)
423
+
424
+ def build_attention_mask(self, context_length):
425
+ # lazily create causal attention mask, with full attention between the vision tokens
426
+ # pytorch uses additive attention mask; fill with -inf
427
+ mask = torch.empty(context_length, context_length)
428
+ mask.fill_(float("-inf"))
429
+ mask.triu_(1) # zero out the lower diagonal
430
+ return mask
431
+
432
+ @property
433
+ def dtype(self):
434
+ return self.visual.conv1.weight.dtype
435
+
436
+ def encode_image(self, image):
437
+ return self.visual(image.type(self.dtype))
438
+
439
+ def encode_text(self, text):
440
+ x = self.token_embedding(text).type(
441
+ self.dtype) # [batch_size, n_ctx, d_model]
442
+
443
+ x = x + self.positional_embedding.type(self.dtype)[:x.size(1)]
444
+ x = x.permute(1, 0, 2) # NLD -> LND
445
+ x = self.transformer(x)
446
+ x = x.permute(1, 0, 2) # LND -> NLD
447
+ x = self.ln_final(x).type(self.dtype)
448
+
449
+ # x.shape = [batch_size, n_ctx, transformer.width]
450
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
451
+ state = x[torch.arange(x.shape[0]),
452
+ text.argmax(dim=-1)] @ self.text_projection
453
+ # x = x @ self.text_projection
454
+ # state = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
455
+
456
+ return x, state
457
+
458
+ def forward(self, image, text):
459
+ image_features = self.encode_image(image)
460
+ text_features = self.encode_text(text)
461
+
462
+ # normalized features
463
+ image_features = image_features / image_features.norm(dim=-1,
464
+ keepdim=True)
465
+ text_features = text_features / text_features.norm(dim=-1,
466
+ keepdim=True)
467
+
468
+ # cosine similarity as logits
469
+ logit_scale = self.logit_scale.exp()
470
+ logits_per_image = logit_scale * image_features @ text_features.t()
471
+ logits_per_text = logits_per_image.t()
472
+
473
+ # shape = [global_batch_size, global_batch_size]
474
+ return logits_per_image, logits_per_text
475
+
476
+
477
+ def convert_weights(model: nn.Module):
478
+ """Convert applicable model parameters to fp16"""
479
+ def _convert_weights_to_fp16(l):
480
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
481
+ l.weight.data = l.weight.data.half()
482
+ if l.bias is not None:
483
+ l.bias.data = l.bias.data.half()
484
+
485
+ if isinstance(l, nn.MultiheadAttention):
486
+ for attr in [
487
+ *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
488
+ "in_proj_bias", "bias_k", "bias_v"
489
+ ]:
490
+ tensor = getattr(l, attr)
491
+ if tensor is not None:
492
+ tensor.data = tensor.data.half()
493
+
494
+ for name in ["text_projection", "proj"]:
495
+ if hasattr(l, name):
496
+ attr = getattr(l, name)
497
+ if attr is not None:
498
+ attr.data = attr.data.half()
499
+
500
+ model.apply(_convert_weights_to_fp16)
501
+
502
+
503
+ def build_model(state_dict: dict, txt_length: int, freeze: bool):
504
+ vit = "visual.proj" in state_dict
505
+
506
+ if vit:
507
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
508
+ vision_layers = len([
509
+ k for k in state_dict.keys()
510
+ if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")
511
+ ])
512
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
513
+ grid_size = round(
514
+ (state_dict["visual.positional_embedding"].shape[0] - 1)**0.5)
515
+ image_resolution = vision_patch_size * grid_size
516
+ else:
517
+ counts: list = [
518
+ len(
519
+ set(
520
+ k.split(".")[2] for k in state_dict
521
+ if k.startswith(f"visual.layer{b}")))
522
+ for b in [1, 2, 3, 4]
523
+ ]
524
+ vision_layers = tuple(counts)
525
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
526
+ output_width = round(
527
+ (state_dict["visual.attnpool.positional_embedding"].shape[0] -
528
+ 1)**0.5)
529
+ vision_patch_size = None
530
+ assert output_width**2 + 1 == state_dict[
531
+ "visual.attnpool.positional_embedding"].shape[0]
532
+ image_resolution = output_width * 32
533
+
534
+ embed_dim = state_dict["text_projection"].shape[1]
535
+ context_length = state_dict["positional_embedding"].shape[0]
536
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
537
+ transformer_width = state_dict["ln_final.weight"].shape[0]
538
+ transformer_heads = transformer_width // 64
539
+ transformer_layers = len(
540
+ set(
541
+ k.split(".")[2] for k in state_dict
542
+ if k.startswith(f"transformer.resblocks")))
543
+
544
+ model = CLIP(embed_dim, image_resolution, vision_layers, vision_width,
545
+ vision_patch_size, context_length, txt_length, vocab_size,
546
+ transformer_width, transformer_heads, transformer_layers)
547
+
548
+ for key in ["input_resolution", "context_length", "vocab_size"]:
549
+ if key in state_dict:
550
+ del state_dict[key]
551
+
552
+ convert_weights(model)
553
+ model.load_state_dict(state_dict, False)
554
+
555
+ if freeze:
556
+ print(f"CLIP FROZEN")
557
+ return model.eval()
558
+ else:
559
+ print(f"CLIP FINE TUNING")
560
+ return model.train()
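
For orientation, the segmenter modules later in this upload drive `build_model` from a TorchScript CLIP checkpoint. A minimal sketch, assuming a local ResNet-50 CLIP checkpoint; the path and the `txt_length` value are placeholders that `cfg.clip_pretrain` and `cfg.word_len` supply in the real code:

import torch

clip_ckpt = 'pretrain/RN50.pt'  # hypothetical path to a TorchScript CLIP checkpoint
jit_model = torch.jit.load(clip_ckpt, map_location='cpu').eval()
# txt_length mirrors cfg.word_len; freeze=True returns the backbone in eval mode.
clip_backbone = build_model(jit_model.state_dict(), txt_length=17, freeze=True).float()
# For the ResNet variant, encode_image returns the (C3, C4, C5) feature pyramid.
feats = clip_backbone.encode_image(torch.randn(1, 3, 416, 416))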
dianecy/VerbCentric-RIS/model_/layers.py ADDED
@@ -0,0 +1,309 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ def conv_layer(in_dim, out_dim, kernel_size=1, padding=0, stride=1):
9
+ return nn.Sequential(
10
+ nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, bias=False),
11
+ nn.BatchNorm2d(out_dim), nn.ReLU(True))
12
+
13
+
14
+ def linear_layer(in_dim, out_dim, bias=False):
15
+ return nn.Sequential(nn.Linear(in_dim, out_dim, bias),
16
+ nn.BatchNorm1d(out_dim), nn.ReLU(True))
17
+
18
+
19
+ class CoordConv(nn.Module):
20
+ def __init__(self,
21
+ in_channels,
22
+ out_channels,
23
+ kernel_size=3,
24
+ padding=1,
25
+ stride=1):
26
+ super().__init__()
27
+ self.conv1 = conv_layer(in_channels + 2, out_channels, kernel_size,
28
+ padding, stride)
29
+
30
+ def add_coord(self, input):
31
+ b, _, h, w = input.size()
32
+ x_range = torch.linspace(-1, 1, w, device=input.device)
33
+ y_range = torch.linspace(-1, 1, h, device=input.device)
34
+ y, x = torch.meshgrid(y_range, x_range)
35
+ y = y.expand([b, 1, -1, -1])
36
+ x = x.expand([b, 1, -1, -1])
37
+ coord_feat = torch.cat([x, y], 1)
38
+ input = torch.cat([input, coord_feat], 1)
39
+ return input
40
+
41
+ def forward(self, x):
42
+ x = self.add_coord(x)
43
+ x = self.conv1(x)
44
+ return x
45
+
46
+
47
+ class Projector(nn.Module):
48
+ def __init__(self, word_dim=1024, in_dim=256, kernel_size=3):
49
+ super().__init__()
50
+ self.in_dim = in_dim
51
+ self.kernel_size = kernel_size
52
+ # visual projector
53
+ self.vis = nn.Sequential( # os16 -> os4
54
+ nn.Upsample(scale_factor=2, mode='bilinear'),
55
+ conv_layer(in_dim * 2, in_dim * 2, 3, padding=1),
56
+ nn.Upsample(scale_factor=2, mode='bilinear'),
57
+ conv_layer(in_dim * 2, in_dim, 3, padding=1),
58
+ nn.Conv2d(in_dim, in_dim, 1))
59
+ # textual projector
60
+ out_dim = 1 * in_dim * kernel_size * kernel_size + 1
61
+ self.txt = nn.Linear(word_dim, out_dim)
62
+
63
+ def forward(self, x, word):
64
+ '''
65
+ x: b, 512, 26, 26
66
+ word: b, 512
67
+ '''
68
+ x = self.vis(x)
69
+ B, C, H, W = x.size()
70
+ # 1, b*256, 104, 104
71
+ x = x.reshape(1, B * C, H, W)
72
+ # txt: b, (256*3*3 + 1) -> b, 256, 3, 3 / b
73
+ word = self.txt(word)
74
+ weight, bias = word[:, :-1], word[:, -1]
75
+ weight = weight.reshape(B, C, self.kernel_size, self.kernel_size)
76
+ # Conv2d - 1, b*256, 104, 104 -> 1, b, 104, 104
77
+ out = F.conv2d(x,
78
+ weight,
79
+ padding=self.kernel_size // 2,
80
+ groups=weight.size(0),
81
+ bias=bias)
82
+ out = out.transpose(0, 1)
83
+ # b, 1, 104, 104
84
+ return out
85
+
86
+
87
+ class TransformerDecoder(nn.Module):
88
+ def __init__(self,
89
+ num_layers,
90
+ d_model,
91
+ nhead,
92
+ dim_ffn,
93
+ dropout,
94
+ return_intermediate=False):
95
+ super().__init__()
96
+ self.layers = nn.ModuleList([
97
+ TransformerDecoderLayer(d_model=d_model,
98
+ nhead=nhead,
99
+ dim_feedforward=dim_ffn,
100
+ dropout=dropout) for _ in range(num_layers)
101
+ ])
102
+ self.num_layers = num_layers
103
+ self.norm = nn.LayerNorm(d_model)
104
+ self.return_intermediate = return_intermediate
105
+
106
+ @staticmethod
107
+ def pos1d(d_model, length):
108
+ """
109
+ :param d_model: dimension of the model
110
+ :param length: length of positions
111
+ :return: length*d_model position matrix
112
+ """
113
+ if d_model % 2 != 0:
114
+ raise ValueError("Cannot use sin/cos positional encoding with "
115
+ "odd dim (got dim={:d})".format(d_model))
116
+ pe = torch.zeros(length, d_model)
117
+ position = torch.arange(0, length).unsqueeze(1)
118
+ div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) *
119
+ -(math.log(10000.0) / d_model)))
120
+ pe[:, 0::2] = torch.sin(position.float() * div_term)
121
+ pe[:, 1::2] = torch.cos(position.float() * div_term)
122
+
123
+ return pe.unsqueeze(1) # n, 1, 512
124
+
125
+ @staticmethod
126
+ def pos2d(d_model, height, width):
127
+ """
128
+ :param d_model: dimension of the model
129
+ :param height: height of the positions
130
+ :param width: width of the positions
131
+ :return: d_model*height*width position matrix
132
+ """
133
+ if d_model % 4 != 0:
134
+ raise ValueError("Cannot use sin/cos positional encoding with "
135
+ "odd dimension (got dim={:d})".format(d_model))
136
+ pe = torch.zeros(d_model, height, width)
137
+ # Each dimension use half of d_model
138
+ d_model = int(d_model / 2)
139
+ div_term = torch.exp(
140
+ torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))
141
+ pos_w = torch.arange(0., width).unsqueeze(1)
142
+ pos_h = torch.arange(0., height).unsqueeze(1)
143
+ pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(
144
+ 0, 1).unsqueeze(1).repeat(1, height, 1)
145
+ pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(
146
+ 0, 1).unsqueeze(1).repeat(1, height, 1)
147
+ pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(
148
+ 0, 1).unsqueeze(2).repeat(1, 1, width)
149
+ pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(
150
+ 0, 1).unsqueeze(2).repeat(1, 1, width)
151
+
152
+ return pe.reshape(-1, 1, height * width).permute(2, 1, 0) # hw, 1, 512
153
+
154
+ def forward(self, vis, txt, pad_mask):
155
+ '''
156
+ vis: b, 512, h, w
157
+ txt: b, L, 512
158
+ pad_mask: b, L
159
+ '''
160
+ B, C, H, W = vis.size()
161
+ _, L, D = txt.size()
162
+ # position encoding
163
+ vis_pos = self.pos2d(C, H, W)
164
+ txt_pos = self.pos1d(D, L)
165
+ # reshape & permute
166
+ vis = vis.reshape(B, C, -1).permute(2, 0, 1)
167
+ txt = txt.permute(1, 0, 2)
168
+ # forward
169
+ output = vis
170
+ intermediate = []
171
+ for layer in self.layers:
172
+ output = layer(output, txt, vis_pos, txt_pos, pad_mask)
173
+ if self.return_intermediate:
174
+ # HW, b, 512 -> b, 512, HW
175
+ intermediate.append(self.norm(output).permute(1, 2, 0))
176
+
177
+ if self.norm is not None:
178
+ # HW, b, 512 -> b, 512, HW
179
+ output = self.norm(output).permute(1, 2, 0)
180
+ if self.return_intermediate:
181
+ intermediate.pop()
182
+ intermediate.append(output)
183
+ # [output1, output2, ..., output_n]
184
+ return intermediate
185
+ else:
186
+ # b, 512, HW
187
+ return output
188
+ return output
189
+
190
+
191
+ class TransformerDecoderLayer(nn.Module):
192
+ def __init__(self,
193
+ d_model=512,
194
+ nhead=9,
195
+ dim_feedforward=2048,
196
+ dropout=0.1):
197
+ super().__init__()
198
+ # Normalization Layer
199
+ self.self_attn_norm = nn.LayerNorm(d_model)
200
+ self.cross_attn_norm = nn.LayerNorm(d_model)
201
+ # Attention Layer
202
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
203
+ self.multihead_attn = nn.MultiheadAttention(d_model,
204
+ nhead,
205
+ dropout=dropout,
206
+ kdim=d_model,
207
+ vdim=d_model)
208
+ # FFN
209
+ self.ffn = nn.Sequential(nn.Linear(d_model, dim_feedforward),
210
+ nn.ReLU(True), nn.Dropout(dropout),
211
+ nn.LayerNorm(dim_feedforward),
212
+ nn.Linear(dim_feedforward, d_model))
213
+ # LayerNorm & Dropout
214
+ self.norm1 = nn.LayerNorm(d_model)
215
+ self.norm2 = nn.LayerNorm(d_model)
216
+ self.norm3 = nn.LayerNorm(d_model)
217
+ self.dropout1 = nn.Dropout(dropout)
218
+ self.dropout2 = nn.Dropout(dropout)
219
+ self.dropout3 = nn.Dropout(dropout)
220
+
221
+ def with_pos_embed(self, tensor, pos):
222
+ return tensor if pos is None else tensor + pos.to(tensor.device)
223
+
224
+ def forward(self, vis, txt, vis_pos, txt_pos, pad_mask):
225
+ '''
226
+ vis: 26*26, b, 512
227
+ txt: L, b, 512
228
+ vis_pos: 26*26, 1, 512
229
+ txt_pos: L, 1, 512
230
+ pad_mask: b, L
231
+ '''
232
+ # Self-Attention
233
+ vis2 = self.norm1(vis)
234
+ q = k = self.with_pos_embed(vis2, vis_pos)
235
+ vis2 = self.self_attn(q, k, value=vis2)[0]
236
+ vis2 = self.self_attn_norm(vis2)
237
+ vis = vis + self.dropout1(vis2)
238
+ # Cross-Attention
239
+ vis2 = self.norm2(vis)
240
+ vis2 = self.multihead_attn(query=self.with_pos_embed(vis2, vis_pos),
241
+ key=self.with_pos_embed(txt, txt_pos),
242
+ value=txt,
243
+ key_padding_mask=pad_mask)[0]
244
+ vis2 = self.cross_attn_norm(vis2)
245
+ vis = vis + self.dropout2(vis2)
246
+ # FFN
247
+ vis2 = self.norm3(vis)
248
+ vis2 = self.ffn(vis2)
249
+ vis = vis + self.dropout3(vis2)
250
+ return vis
251
+
252
+
253
+ class FPN(nn.Module):
254
+ def __init__(self,
255
+ in_channels=[512, 1024, 1024],
256
+ out_channels=[256, 512, 1024]):
257
+ super(FPN, self).__init__()
258
+ # text projection
259
+ self.txt_proj = linear_layer(in_channels[2], out_channels[2])
260
+ # fusion 1: v5 & seq -> f_5: b, 1024, 13, 13
261
+ self.f1_v_proj = conv_layer(in_channels[2], out_channels[2], 1, 0)
262
+ self.norm_layer = nn.Sequential(nn.BatchNorm2d(out_channels[2]),
263
+ nn.ReLU(True))
264
+ # fusion 2: v4 & fm -> f_4: b, 512, 26, 26
265
+ self.f2_v_proj = conv_layer(in_channels[1], out_channels[1], 3, 1)
266
+ self.f2_cat = conv_layer(out_channels[2] + out_channels[1],
267
+ out_channels[1], 1, 0)
268
+ # fusion 3: v3 & fm_mid -> f_3: b, 512, 52, 52
269
+ self.f3_v_proj = conv_layer(in_channels[0], out_channels[0], 3, 1)
270
+ self.f3_cat = conv_layer(out_channels[0] + out_channels[1],
271
+ out_channels[1], 1, 0)
272
+ # fusion 4: f_3 & f_4 & f_5 -> fq: b, 256, 26, 26
273
+ self.f4_proj5 = conv_layer(out_channels[2], out_channels[1], 3, 1)
274
+ self.f4_proj4 = conv_layer(out_channels[1], out_channels[1], 3, 1)
275
+ self.f4_proj3 = conv_layer(out_channels[1], out_channels[1], 3, 1)
276
+ # aggregation
277
+ self.aggr = conv_layer(3 * out_channels[1], out_channels[1], 1, 0)
278
+ self.coordconv = nn.Sequential(
279
+ CoordConv(out_channels[1], out_channels[1], 3, 1),
280
+ conv_layer(out_channels[1], out_channels[1], 3, 1))
281
+
282
+ def forward(self, imgs, state):
283
+ # v3, v4, v5: 256, 52, 52 / 512, 26, 26 / 1024, 13, 13
284
+ v3, v4, v5 = imgs
285
+ # fusion 1: b, 1024, 13, 13
286
+ # text projection: b, 1024 -> b, 1024
287
+ state = self.txt_proj(state).unsqueeze(-1).unsqueeze(
288
+ -1) # b, 1024, 1, 1
289
+ f5 = self.f1_v_proj(v5)
290
+ f5 = self.norm_layer(f5 * state)
291
+ # fusion 2: b, 512, 26, 26
292
+ f4 = self.f2_v_proj(v4)
293
+ f5_ = F.interpolate(f5, scale_factor=2, mode='bilinear')
294
+ f4 = self.f2_cat(torch.cat([f4, f5_], dim=1))
295
+ # fusion 3: b, 256, 26, 26
296
+ f3 = self.f3_v_proj(v3)
297
+ f3 = F.avg_pool2d(f3, 2, 2)
298
+ f3 = self.f3_cat(torch.cat([f3, f4], dim=1))
299
+ # fusion 4: b, 512, 13, 13 / b, 512, 26, 26 / b, 512, 26, 26
300
+ fq5 = self.f4_proj5(f5)
301
+ fq4 = self.f4_proj4(f4)
302
+ fq3 = self.f4_proj3(f3)
303
+ # query
304
+ fq5 = F.interpolate(fq5, scale_factor=2, mode='bilinear')
305
+ fq = torch.cat([fq3, fq4, fq5], dim=1)
306
+ fq = self.aggr(fq)
307
+ fq = self.coordconv(fq)
308
+ # b, 512, 26, 26
309
+ return fq, f5
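
As a quick sanity check on the shapes documented in `FPN.forward`, the sketch below pushes random tensors through the default configuration; the batch size and spatial sizes are illustrative assumptions matching a 416x416 input.

import torch

fpn = FPN(in_channels=[512, 1024, 1024], out_channels=[256, 512, 1024])
v3 = torch.randn(2, 512, 52, 52)    # C3 from the CLIP ResNet
v4 = torch.randn(2, 1024, 26, 26)   # C4
v5 = torch.randn(2, 1024, 13, 13)   # C5 (after attention pooling)
state = torch.randn(2, 1024)        # sentence-level text feature
fq, f5 = fpn((v3, v4, v5), state)
print(fq.shape, f5.shape)  # expected: [2, 512, 26, 26] and [2, 1024, 13, 13]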
dianecy/VerbCentric-RIS/model_/segmenter.py ADDED
@@ -0,0 +1,204 @@
1
+ import random
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from model_.clip import build_model
7
+ from .layers import FPN, Projector, TransformerDecoder
8
+
9
+
10
+ def MetricLoss(embeddings, n_pos, alpha = 0.5, args = None):
11
+ # embeddings: ((2*B), C, (H*W))
12
+ # n_pos : chunk size of positive pairs
13
+ # args: args
14
+ # returns: loss
15
+ metric_loss = 0
16
+
17
+ # flatten embeddings
18
+ B_, C, HW = embeddings.shape
19
+ emb = torch.mean(embeddings, dim=-1) # (2*B, C)
20
+ emb_i = emb.unsqueeze(1).repeat(1, B_, 1) # (2*B, 2*B, C)
21
+ emb_j = emb.unsqueeze(0).repeat(B_, 1, 1) # (2*B, 2*B, C)
22
+ emb_distance = torch.norm(emb_i - emb_j, dim=-1) # (2*B, 2*B)
23
+ assert torch.sum(torch.diag(emb_distance[:B_, :B_])) == 0, \
24
+ "Diagonals are not zero. please check the permutation on the batch"
25
+ # print("distance metrix : ", emb_distance)
26
+
27
+ # positive pairs and loss
28
+ positive_mask = torch.zeros_like(emb_distance)
29
+ for i in range(B_//2):
30
+ positive_mask[2*i, 2*i+1] = 1
31
+ positive_mask[2*i+1, 2*i] = 1
32
+ positive_mask.fill_diagonal_(1)
33
+ positive_loss = torch.sum(emb_distance * positive_mask) / B_
34
+
35
+ # negative pairs and loss
36
+ negative_mask = torch.ones_like(emb_distance) - positive_mask
37
+
38
+ if args.div_batch:
39
+ negative_loss = -1.0 * torch.log(torch.sum(emb_distance * negative_mask) / B_)
40
+ else:
41
+ negative_loss = -1.0 * torch.log(torch.sum(emb_distance * negative_mask) / (B_**2 - 2*B_))
42
+
43
+ # print(positive_mask, negative_mask)
44
+
45
+ metric_loss = alpha * positive_loss + (1-alpha) * negative_loss
46
+
47
+ return metric_loss
48
+
49
+
50
+ def AngularMetricLoss(embeddings, n_pos, alpha = 0.5, args = None):
51
+ # embeddings: ((2*B), C, (H*W))
52
+ # n_pos : chunk size of positive pairs
53
+ # args: args
54
+ # returns: loss
55
+ geometric_loss = 0
56
+
57
+ # flatten embeddings
58
+ B_, C, HW = embeddings.shape
59
+ emb = torch.mean(embeddings, dim=-1) # (2*B, C)
60
+ emb_i = emb.unsqueeze(1).repeat(1, B_, 1) # (2*B, 2*B, C)
61
+ emb_j = emb.unsqueeze(0).repeat(B_, 1, 1) # (2*B, 2*B, C)
62
+ sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
63
+ sim_matrix = sim(emb_i, emb_j).reshape(B_, B_) # (2*B , 2*B)
64
+ print(sim_matrix)
65
+ assert torch.trace(sim_matrix) == B_, \
66
+ "similarity diagonals are not one. please check the permutation on the batch"
67
+ print("similarity metrix : ", sim_matrix)
68
+ phi = torch.acos(sim_matrix) # (2*B, 2*B)
69
+ print("phi metrix : ", phi)
70
+
71
+ # positive pairs and loss
72
+ positive_mask = torch.zeros_like(sim_matrix)
73
+ for i in range(B_//2):
74
+ positive_mask[2*i, 2*i+1] = 1
75
+ positive_mask[2*i+1, 2*i] = 1
76
+ positive_mask.fill_diagonal_(1)
77
+ positive_loss = torch.sum((phi**2) * positive_mask) / B_
78
+
79
+ # negative pairs and loss
80
+ negative_mask = torch.ones_like(sim_matrix) - positive_mask
81
+ phi_mask = phi < args.phi_threshold
82
+ negative_loss = (args.phi_threshold - phi)**2
83
+ print(negative_mask * phi_mask)
84
+ negative_loss = torch.sum(negative_loss * negative_mask * phi_mask) / (B_**2 - 2*B_)
85
+
86
+ print("pos loss, neg loss : ", positive_loss, negative_loss)
87
+
88
+ geometric_loss = alpha * positive_loss + (1-alpha) * negative_loss
89
+
90
+ return geometric_loss
91
+
92
+
93
+ class CRIS(nn.Module):
94
+ def __init__(self, cfg):
95
+ super().__init__()
96
+ # Vision & Text Encoder
97
+ clip_model = torch.jit.load(cfg.clip_pretrain,
98
+ map_location="cpu").eval()
99
+ self.backbone = build_model(clip_model.state_dict(), cfg.word_len, cfg.freeze).float()
100
+ # Multi-Modal FPN
101
+ self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
102
+ # Decoder
103
+ self.decoder = TransformerDecoder(num_layers=cfg.num_layers,
104
+ d_model=cfg.vis_dim,
105
+ nhead=cfg.num_head,
106
+ dim_ffn=cfg.dim_ffn,
107
+ dropout=cfg.dropout,
108
+ return_intermediate=cfg.intermediate)
109
+ # Projector
110
+ self.proj = Projector(cfg.word_dim, cfg.vis_dim // 2, 3)
111
+ self.metric_learning = cfg.metric_learning
112
+ # self.positive_strength = cfg.positive_strength
113
+ # self.metric_loss_weight = cfg.metric_loss_weight
114
+ # self.eps = cfg.ptb_rate
115
+ self.cfg = cfg
116
+
117
+ def forward(self, image, text, target=None):
118
+ '''
119
+ img: b, 3, h, w
120
+ word: b, words
121
+ word_mask: b, words
122
+ if self.metric_learning:
123
+ word: b, 2, words
124
+ word_mask: b, 2, words
125
+ mask: b, 1, h, w
126
+ '''
127
+ metric_learning_flag = (self.metric_learning and self.training)
128
+ metric_loss = 0
129
+
130
+ # 1.Resizing : if metric learning, batch size of the word is doubled
131
+ if metric_learning_flag:
132
+ #print("image shape : ", image.shape)
133
+ b, c, h, w = image.size()
134
+ # duplicate image and segmentation mask
135
+ if image is not None:
136
+ image = torch.cat([image, image], dim=0)
137
+ image = image.reshape(-1, b, c, h, w).transpose(0, 1).reshape(-1, c, h, w)
138
+ if target is not None:
139
+ target = torch.cat([target, target], dim=0)
140
+ target = target.reshape(-1, b, 1, h, w).transpose(0, 1).reshape(-1, 1, h, w)
141
+ # duplicate noise mask
142
+ b_, n_, l_ = text.size()
143
+ assert n_ == 2 ,"word size should be 2"
144
+ noise_mask = (text[:, 0, :] == text[:, 1, :])
145
+ noise_mask = torch.all(noise_mask, dim=-1)
146
+ noise_mask = noise_mask.unsqueeze(-1).repeat(1, 2).reshape(-1) # 2*b_
147
+ assert noise_mask.shape[0] == b_ * 2, "noise mask shape should be 2*b_"
148
+ text = text.reshape(b_ * 2, l_) # 2*b, l
149
+
150
+ # print("text shape : ", text.shape)
151
+ # print("image shape : ", image.shape)
152
+ # print("target shape : ", target.shape)
153
+ # print(torch.sum(image[0::2]) == torch.sum(image[1::2]))
154
+ # print(torch.sum(target[0::2]) == torch.sum(target[1::2]))
155
+
156
+ # padding mask used in decoder
157
+ pad_mask = torch.zeros_like(text).masked_fill_(text == 0, 1).bool()
158
+ # vis: C3 / C4 / C5
159
+ # word: b, length, 1024
160
+ # state: b, 1024
161
+ vis = self.backbone.encode_image(image)
162
+ word, state = self.backbone.encode_text(text)
163
+
164
+ b_, d_ = state.size()
165
+ assert b_ == word.size(0), "batch size of state and word should be same"
166
+
167
+
168
+ # 2. State Noising Step : if number of caption is 1,
169
+ # add noise to the corresponding indices
170
+ if metric_learning_flag :
171
+ noise = torch.randn_like(state) * self.eps
172
+ state[noise_mask] = state[noise_mask] + noise[noise_mask]
173
+
174
+ # print("shape of word, state : ", word.shape, state.shape)
175
+
176
+ # b, 512, 26, 26 (C4)
177
+ a3, a4, a5 = vis
178
+ # print("vis shape in model " , a3.shape, a4.shape, a5.shape)
179
+ fq, f5 = self.neck(vis, state)
180
+ b, c, h, w = fq.size()
181
+ fq = self.decoder(fq, word, pad_mask)
182
+ # print("decoder output shape : ", fq.shape)
183
+ # 3. Get metric loss
184
+ if metric_learning_flag:
185
+ metric_loss = MetricLoss(fq, 2, alpha=self.positive_strength, args = self.cfg)
186
+
187
+ fq = fq.reshape(b, c, h, w)
188
+
189
+ # b, 1, 104, 104
190
+ pred = self.proj(fq, state)
191
+
192
+ if self.training:
193
+ # resize mask
194
+ if pred.shape[-2:] != target.shape[-2:]:
195
+ target = F.interpolate(target, pred.shape[-2:],
196
+ mode='nearest').detach()
197
+ loss = F.binary_cross_entropy_with_logits(pred, target)
198
+ # 4. if metric learning, add metric loss and normalize
199
+ if metric_learning_flag:
200
+ #print("CE loss : ", loss, "metric loss : ", metric_loss)
201
+ loss = (loss + self.metric_loss_weight * metric_loss) / (1+self.metric_loss_weight)
202
+ return pred.detach(), target, loss
203
+ else:
204
+ return pred.detach()
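
The pairing logic in `MetricLoss`/`AngularMetricLoss` above assumes the batch is laid out as consecutive (caption, paired caption) rows, which is what the duplication step in `CRIS.forward` produces. A small illustration of the masks for B_ = 4 (the printed values are exactly what the loop builds):

import torch

B_ = 4
positive_mask = torch.zeros(B_, B_)
for i in range(B_ // 2):
    positive_mask[2 * i, 2 * i + 1] = 1
    positive_mask[2 * i + 1, 2 * i] = 1
positive_mask.fill_diagonal_(1)
negative_mask = torch.ones(B_, B_) - positive_mask
print(positive_mask)
# tensor([[1., 1., 0., 0.],
#         [1., 1., 0., 0.],
#         [0., 0., 1., 1.],
#         [0., 0., 1., 1.]])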
dianecy/VerbCentric-RIS/model_/segmenter_ang_nonoise_ddp.py ADDED
@@ -0,0 +1,305 @@
1
+ import random
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from model.clip import build_model
7
+
8
+ from .layers import FPN, Projector, TransformerDecoder
9
+
10
+ def check_nan(x) :
11
+ """ Check if there is NaN in tensor """
12
+ checker = False
13
+ if True in torch.isnan(x):
14
+ checker = True
15
+ return checker
16
+
17
+ def zero_filtering(x) :
18
+ """
19
+ Add eps value for zero embedding, because competition metric is cosine similarity
20
+ Cosine Similarity will be returned NaN, when input value has zero, like as torch.clamp()
21
+ """
22
+ eps = 1e-4
23
+ x[x <= eps] = eps
24
+ return x
25
+
26
+ def nan_filtering(x, eps = 1e-4) :
27
+ """
28
+ Change eps value for NaN Embedding, because competition metric is cosine similarity
29
+ Cosine Similarity will be returned NaN
30
+ """
31
+ return torch.nan_to_num(x, nan=eps)
32
+
33
+ # def MetricLoss(embeddings, num_pos, alpha = 0.5, args = None):
34
+ # # embeddings: ((2*B), C, (H*W))
35
+ # # n_pos : chunk size of positive pairs
36
+ # # args: args
37
+ # # returns: loss
38
+ # metric_loss = 0
39
+ # # flatten embeddings
40
+ # B_, C, HW = embeddings.shape
41
+ # emb = torch.mean(embeddings, dim=-1) # (2*B, C)
42
+ # emb_i = emb.unsqueeze(1).repeat(1, B_, 1) # (2*B, 2*B, C)
43
+ # emb_j = emb.unsqueeze(0).repeat(B_, 1, 1) # (2*B, 2*B, C)
44
+ # emb_distance = torch.norm(emb_i - emb_j, dim=-1) # (2*B, 2*B)
45
+ # assert torch.sum(torch.diag(emb_distance[:B_, :B_])) == 0, \
46
+ # "Diagonals are not zero. please check the permutation on the batch"
47
+ # # print("distance metrix : ", emb_distance)
48
+
49
+ # # positive pairs and loss
50
+ # positive_mask = torch.zeros_like(emb_distance)
51
+ # for i in range(B_//2):
52
+ # positive_mask[2*i, 2*i+1] = 1
53
+ # positive_mask[2*i+1, 2*i] = 1
54
+ # positive_mask.fill_diagonal_(1)
55
+ # positive_loss = torch.sum(emb_distance * positive_mask) / B_
56
+
57
+ # # negative pairs and loss
58
+ # negative_mask = torch.ones_like(emb_distance) - positive_mask
59
+ # negative_loss = -1.0 * torch.log(torch.sum(emb_distance * negative_mask) / (B_**2 - 2*B_))
60
+
61
+ # # print(positive_mask, negative_mask)
62
+
63
+ # metric_loss = alpha * positive_loss + (1-alpha) * negative_loss
64
+
65
+ # return metric_loss
66
+
67
+ # def return_mask(emb_distance, nsent):
68
+ # B_, B_ = emb_distance.shape
69
+ # positive_mask = torch.zeros_like(emb_distance)
70
+ # for i in range(B_//nsent):
71
+ # positive_mask[nsent*i, nsent*i+1] = 1
72
+ # positive_mask[nsent*i+1, nsent*i] = 1
73
+ # positive_mask.fill_diagonal_(1)
74
+ # negative_mask = torch.ones_like(emb_distance) - positive_mask
75
+ # return positive_mask, negative_mask
76
+
77
+ # def AngularMetricLoss(embeddings, num_pos, num_neg, alpha = 0.5, args = None):
78
+ # # embeddings: ((6*B), C, (H*W))
79
+ # # n_pos : chunk size of positive pairs
80
+ # # args: args
81
+ # # returns: loss
82
+ # geometric_loss = 0
83
+ # nsent = num_pos + num_neg
84
+ # assert nsent == 6, "number of sentences doesn't match" # nsent : S
85
+
86
+ # # flatten embeddings
87
+ # B_, C, HW = embeddings.shape
88
+ # emb = torch.mean(embeddings, dim=-1) # (S*B, C)
89
+ # emb_i = emb.unsqueeze(1).repeat(1, B_, 1) # (S*B, S*B, C)
90
+ # emb_j = emb.unsqueeze(0).repeat(B_, 1, 1) # (S*B, S*B, C)
91
+
92
+ # ## zero filtering
93
+ # sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
94
+ # sim_matrix = sim(emb_i, emb_j).reshape(B_, B_) # (S*B , S*B)
95
+ # sim_matrix = zero_filtering(sim_matrix)
96
+ # if check_nan(sim_matrix) :
97
+ # sim_matrix = nan_filtering(sim_matrix)
98
+ # sim_matrix = torch.clamp(sim_matrix, min=-0.999, max=0.999)
99
+ # phi = torch.acos(sim_matrix) # (S*B, S*B)
100
+ # phi[torch.isnan(phi)] = 0
101
+
102
+ # # positive pairs and loss
103
+ # positive_mask, negative_mask = return_mask(sim_matrix, nsent)
104
+ # positive_loss = torch.sum((phi**2) * positive_mask) / B_
105
+
106
+ # # negative pairs and loss
107
+ # # negative_mask = torch.ones_like(sim_matrix) - positive_mask
108
+ # phi_mask = phi < args.phi_threshold
109
+
110
+ # negative_loss = (args.phi_threshold - phi)**2
111
+ # negative_loss = zero_filtering(negative_loss)
112
+ # if check_nan(negative_loss) :
113
+ # negative_loss = nan_filtering(negative_loss)
114
+
115
+ # if args.div_batch:
116
+ # negative_loss = torch.sum(negative_loss * negative_mask * phi_mask) / B_
117
+ # else:
118
+ # negative_loss = torch.sum(negative_loss * negative_mask * phi_mask) / (B_**2 - nsent*B_)
119
+ # # print("pos loss, neg loss : ", positive_loss, negative_loss)
120
+
121
+ # geometric_loss = alpha * positive_loss + (1-alpha) * negative_loss
122
+
123
+ # return geometric_loss
124
+
125
+
126
+ class CRIS_Wo_Noise(nn.Module):
127
+ def __init__(self, cfg):
128
+ super().__init__()
129
+ # Vision & Text Encoder
130
+ clip_model = torch.jit.load(cfg.clip_pretrain,
131
+ map_location="cpu").eval()
132
+ self.backbone = build_model(clip_model.state_dict(), cfg.word_len).float()
133
+ # Multi-Modal FPN
134
+ self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
135
+ # Decoder
136
+ self.decoder = TransformerDecoder(num_layers=cfg.num_layers,
137
+ d_model=cfg.vis_dim,
138
+ nhead=cfg.num_head,
139
+ dim_ffn=cfg.dim_ffn,
140
+ dropout=cfg.dropout,
141
+ return_intermediate=cfg.intermediate)
142
+ # Projector
143
+ self.proj = Projector(cfg.word_dim, cfg.vis_dim // 2, 3)
144
+ self.metric_learning = cfg.metric_learning
145
+ self.positive_strength = cfg.positive_strength
146
+ self.metric_loss_weight = cfg.metric_loss_weight
147
+ self.add_noise = cfg.add_noise
148
+ self.eps = cfg.ptb_rate
149
+ self.cfg = cfg
150
+ # self.bn_fq = nn.BatchNorm2d(1024)
151
+
152
+ def forward(self, image, text, target=None):
153
+ '''
154
+ img: b, 3, h, w
155
+ word: b, words
156
+ word_mask: b, words
157
+ if self.metric_learning:
158
+ word: b, 6, words
159
+ word_mask: b, 6, words
160
+ mask: b, 1, h, w
161
+ '''
162
+ metric_learning_flag = (self.metric_learning and self.training)
163
+ add_noise_flag = self.add_noise
164
+ # TODO : mixing option btw distance & angular loss
165
+ mix_distance_angular = False
166
+ metric_loss = 0
167
+ #print("text shape : ", text.shape)
168
+ if self.training:
169
+ bt, nt, lt = text.size()
170
+ else:
171
+ nt = 1
172
+ bt, lt = text.size()
173
+
174
+ npos= 2
175
+ nneg =nt-npos
176
+
177
+ # 1.Resizing : if metric learning, batch size of the word is doubled
178
+ if metric_learning_flag:
179
+ #print("image shape : ", image.shape)
180
+ b, c, h, w = image.size()
181
+ # duplicate image and segmentation mask
182
+ if image is not None:
183
+ image = torch.cat([image, image, image, image, image, image], dim=0)
184
+ image = image.reshape(-1, b, c, h, w).transpose(0, 1).reshape(-1, c, h, w)
185
+ if target is not None:
186
+ target = torch.cat([target, target, target, target, target, target], dim=0)
187
+ target = target.reshape(-1, b, 1, h, w).transpose(0, 1).reshape(-1, 1, h, w)
188
+
189
+ if add_noise_flag :
190
+ noise_mask = (text[:, 0, :] == text[:, 1, :])
191
+ noise_mask = torch.all(noise_mask, dim=-1)
192
+ noise_mask = noise_mask.unsqueeze(-1).repeat(1, 2).reshape(-1) # 2*b_
193
+ assert noise_mask.shape[0] == bt * npos, "noise mask shape should be 2*B_"
194
+
195
+ text = text.reshape(bt * nt, lt) # 2*b, l
196
+ # print(image.shape, image.dtype, image.type()) #float32
197
+ # print(target.shape, target.dtype, target.type()) #float32
198
+ # print(noise_mask.dtype, noise_mask.type()) #bool
199
+
200
+ # padding mask used in decoder
201
+ pad_mask = torch.zeros_like(text).masked_fill_(text == 0, 1).bool()
202
+ # print(pad_mask.dtype, pad_mask.type())
203
+
204
+ # vis: C3 / C4 / C5
205
+ # word: b, length, 1024
206
+ # state: 6b, 1024
207
+ vis = self.backbone.encode_image(image)
208
+ word, state = self.backbone.encode_text(text)
209
+ # print(vis.dtype, vis.type())
210
+ # state= state.float()
211
+ if check_nan(state) :
212
+ print('state has nan valuses')
213
+ state = nan_filtering(state)
214
+ # print(state)
215
+
216
+
217
+ b_, d_ = state.size()
218
+ assert b_ == word.size(0), "batch size of state and word should be same"
219
+
220
+ #npos = 2, nneg=4
221
+ if (add_noise_flag and self.training) :
222
+ tmp_state = state.reshape(bt, nt, -1)
223
+ pos_state = tmp_state[:, :npos, :].reshape(bt*npos, -1)
224
+ neg_state = tmp_state[:, npos:, :].reshape(bt*nneg, -1)
225
+ noise = torch.randn_like(pos_state) * self.eps
226
+ pos_state_noisy = pos_state.clone() # Clone pos_state to avoid in-place operations
227
+ pos_state_noisy[noise_mask] += noise[noise_mask] # Add noise where the mask is True
228
+ new_state = torch.cat([pos_state_noisy, neg_state], dim=0)
229
+ else:
230
+ new_state = state.reshape(bt*nt, -1)
231
+
232
+
233
+ # b, 512, 26, 26 (C4)
234
+ a3, a4, a5 = vis
235
+
236
+ fq, f5 = self.neck(vis, new_state)
237
+ b, c, h, w = fq.size()
238
+ fq = self.decoder(fq, word, pad_mask)
239
+
240
+ metric_tensor = fq
241
+
242
+ # # 3. Get metric loss
243
+ # if metric_learning_flag:
244
+ # metric_loss = AngularMetricLoss(fq, npos, nneg, alpha=self.positive_strength, args = self.cfg)
245
+
246
+ fq = fq.reshape(b, c, h, w)
247
+
248
+ # b, 1, 104, 104
249
+ pred = self.proj(fq, new_state)
250
+ #print("pred shape : ", pred.shape, " fq shape : ", fq.shape, " new_state shape : ", new_state.shape)
251
+ #breakpoint()
252
+
253
+ if self.training:
254
+ if pred.shape[-2:] != target.shape[-2:]:
255
+ target = F.interpolate(target, pred.shape[-2:],
256
+ mode='nearest').detach()
257
+ # seunghoon: temporary fix, only the shapes are matched here #
258
+ b, _, h, w = pred.shape
259
+ assert (pred.shape == target.shape), "pred shape and target shape should be same"
260
+ pred = pred.reshape(-1, 6, h, w)[:, :2, :, :]
261
+ # pred_neg = pred.reshape(-1, 6, h, w)[:, 2:, :, :]
262
+ target = target.reshape(-1, 6, h, w)[:, :2, :, :]
263
+ # target_neg = target.reshape(-1, 6, h, w)[:, 2:, :, :]
264
+
265
+ CE_loss_pos = F.binary_cross_entropy_with_logits(pred, target)
266
+ # loss_neg = nn.MSELoss()(pred_neg,torch.zeros_like(pred_neg))
267
+ # loss = loss_pos + loss_neg/(target_neg.shape[-1])**2
268
+
269
+ return pred.detach(), target, CE_loss_pos, metric_tensor
270
+ else:
271
+ #print(self.cfg.gpu, f"; loss = {loss}")
272
+ return pred.detach()
273
+
274
+
275
+ ## Original code
276
+ # if self.training:
277
+ # if pred.shape[-2:] != target.shape[-2:]:
278
+ # target = F.interpolate(target, pred.shape[-2:],
279
+ # mode='nearest').detach()
280
+ # # seunghoon: temporary fix, only the shapes are matched here #
281
+ # b, _, h, w = pred.shape
282
+ # assert (pred.shape == target.shape), "pred shape and target shape should be same"
283
+ # pred = pred.reshape(-1, 6, h, w)[:, :2, :, :]
284
+ # # pred_neg = pred.reshape(-1, 6, h, w)[:, 2:, :, :]
285
+ # target = target.reshape(-1, 6, h, w)[:, :2, :, :]
286
+ # # target_neg = target.reshape(-1, 6, h, w)[:, 2:, :, :]
287
+
288
+ # CE_loss_pos = F.binary_cross_entropy_with_logits(pred, target)
289
+ # # loss_neg = nn.MSELoss()(pred_neg,torch.zeros_like(pred_neg))
290
+ # # loss = loss_pos + loss_neg/(target_neg.shape[-1])**2
291
+
292
+
293
+ # # 4. if metric learning, add metric loss and normalize
294
+ # if metric_learning_flag:
295
+ # # print("CE loss : ", CE_loss_pos, "metric loss : ", metric_loss)
296
+ # loss = (CE_loss_pos + self.metric_loss_weight * metric_loss) / (1+self.metric_loss_weight)
297
+ # # DDP error handling : if there is no negative(BS = 1 or 0 for some GPUs), \
298
+ # # connect graph to avoid error
299
+ # safety_loss = loss * 0.
300
+ # loss = loss + safety_loss
301
+ # # print(self.cfg.gpu, f"; loss = {loss}")
302
+ # return pred.detach(), target, loss
303
+ # else:
304
+ # #print(self.cfg.gpu, f"; loss = {loss}")
305
+ # return pred.detach()
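
One detail worth noting in the commented-out AngularMetricLoss above: the clamp on the cosine-similarity matrix before `torch.acos` is what keeps gradients finite, since the derivative of acos diverges at +-1 and the diagonal entries are exactly 1. A minimal reproduction of that behaviour:

import torch

sim = torch.tensor([1.0, 0.999], requires_grad=True)
phi = torch.acos(torch.clamp(sim, min=-0.999, max=0.999))
phi.sum().backward()
print(sim.grad)  # finite for both entries; without the clamp the first entry gives a non-finite gradient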
dianecy/VerbCentric-RIS/model_/segmenter_angular.py ADDED
@@ -0,0 +1,163 @@
1
+ import random
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from model.clip import build_model
7
+
8
+ from .layers import FPN, Projector, TransformerDecoder
9
+
10
+
11
+ # def MetricLoss(embeddings, n_pos, alpha = 0.5, args = None):
12
+ # # embeddings: ((2*B), C, (H*W))
13
+ # # n_pos : chunk size of positive pairs
14
+ # # args: args
15
+ # # returns: loss
16
+ # metric_loss = 0
17
+
18
+ # # flatten embeddings
19
+ # B_, C, HW = embeddings.shape
20
+ # emb = torch.mean(embeddings, dim=-1) # (2*B, C)
21
+ # emb_i = emb.unsqueeze(1).repeat(1, B_, 1) # (2*B, 2*B, C)
22
+ # emb_j = emb.unsqueeze(0).repeat(B_, 1, 1) # (2*B, 2*B, C)
23
+ # emb_distance = torch.norm(emb_i - emb_j, dim=-1) # (2*B, 2*B)
24
+ # assert torch.sum(torch.diag(emb_distance[:B_, :B_])) == 0, \
25
+ # "Diagonals are not zero. please check the permutation on the batch"
26
+ # # print("distance metrix : ", emb_distance)
27
+
28
+ # # positive pairs and loss
29
+ # positive_mask = torch.zeros_like(emb_distance)
30
+ # for i in range(B_//2):
31
+ # positive_mask[2*i, 2*i+1] = 1
32
+ # positive_mask[2*i+1, 2*i] = 1
33
+ # positive_mask.fill_diagonal_(1)
34
+ # positive_loss = torch.sum(emb_distance * positive_mask) / B_
35
+
36
+ # # negative pairs and loss
37
+ # negative_mask = torch.ones_like(emb_distance) - positive_mask
38
+ # negative_loss = -1.0 * torch.log(torch.sum(emb_distance * negative_mask) / (B_**2 - 2*B_))
39
+
40
+ # # print(positive_mask, negative_mask)
41
+
42
+ # metric_loss = alpha * positive_loss + (1-alpha) * negative_loss
43
+
44
+ # return metric_loss
45
+
46
+
47
+
48
+ class CRIS_S(nn.Module):
49
+ def __init__(self, cfg):
50
+ super().__init__()
51
+ # Vision & Text Encoder
52
+ clip_model = torch.jit.load(cfg.clip_pretrain,
53
+ map_location="cpu").eval()
54
+ self.backbone = build_model(clip_model.state_dict(), cfg.word_len).float()
55
+ # Multi-Modal FPN
56
+ self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
57
+ # Decoder
58
+ self.decoder = TransformerDecoder(num_layers=cfg.num_layers,
59
+ d_model=cfg.vis_dim,
60
+ nhead=cfg.num_head,
61
+ dim_ffn=cfg.dim_ffn,
62
+ dropout=cfg.dropout,
63
+ return_intermediate=cfg.intermediate)
64
+ # Projector
65
+ self.proj = Projector(cfg.word_dim, cfg.vis_dim // 2, 3)
66
+ self.metric_learning = cfg.metric_learning
67
+ self.positive_strength = cfg.positive_strength
68
+ self.metric_loss_weight = cfg.metric_loss_weight
69
+ self.eps = cfg.ptb_rate
70
+ self.cfg = cfg
71
+
72
+ def forward(self, image, text, target=None):
73
+ '''
74
+ img: b, 3, h, w
75
+ word: b, words
76
+ word_mask: b, words
77
+ if self.metric_learning:
78
+ word: b, 2, words
79
+ word_mask: b, 2, words
80
+ mask: b, 1, h, w
81
+ '''
82
+ metric_learning_flag = (self.metric_learning and self.training)
83
+ # TODO : mixing option btw distance & angular loss
84
+ mix_distance_angular = False
85
+ metric_loss = 0
86
+
87
+ # 1.Resizing : if metric learning, batch size of the word is doubled
88
+ if metric_learning_flag:
89
+ #print("image shape : ", image.shape)
90
+ b, c, h, w = image.size()
91
+ # duplicate image and segmentation mask
92
+ if image is not None:
93
+ image = torch.cat([image, image], dim=0)
94
+ image = image.reshape(-1, b, c, h, w).transpose(0, 1).reshape(-1, c, h, w)
95
+ if target is not None:
96
+ target = torch.cat([target, target], dim=0)
97
+ target = target.reshape(-1, b, 1, h, w).transpose(0, 1).reshape(-1, 1, h, w)
98
+ # duplicate noise mask
99
+ b_, n_, l_ = text.size()
100
+ assert n_ == 2 ,"word size should be 2"
101
+ noise_mask = (text[:, 0, :] == text[:, 1, :])
102
+ noise_mask = torch.all(noise_mask, dim=-1)
103
+ noise_mask = noise_mask.unsqueeze(-1).repeat(1, 2).reshape(-1) # 2*b_
104
+ assert noise_mask.shape[0] == b_ * 2, "noise mask shape should be 2*b_"
105
+ text = text.reshape(b_ * 2, l_) # 2*b, l
106
+
107
+ # print("text shape : ", text.shape)
108
+ # print("image shape : ", image.shape)
109
+ # print("target shape : ", target.shape)
110
+ # print(torch.sum(image[0::2]) == torch.sum(image[1::2]))
111
+ # print(torch.sum(target[0::2]) == torch.sum(target[1::2]))
112
+
113
+ # padding mask used in decoder
114
+ pad_mask = torch.zeros_like(text).masked_fill_(text == 0, 1).bool()
115
+ # vis: C3 / C4 / C5
116
+ # word: b, length, 1024
117
+ # state: b, 1024
118
+ vis = self.backbone.encode_image(image)
119
+ word, state = self.backbone.encode_text(text)
120
+
121
+ b_, d_ = state.size()
122
+ assert b_ == word.size(0), "batch size of state and word should be same"
123
+
124
+
125
+ # 2. State Noising Step : if number of caption is 1,
126
+ # add noise to the corresponding indices
127
+ if metric_learning_flag :
128
+ noise = torch.randn_like(state) * self.eps
129
+ state[noise_mask] = state[noise_mask] + noise[noise_mask]
130
+
131
+
132
+ # b, 512, 26, 26 (C4)
133
+ a3, a4, a5 = vis
134
+ fq, f5 = self.neck(vis, state)
135
+ b, c, h, w = fq.size()
136
+ fq = self.decoder(fq, word, pad_mask)
137
+ metric_tensor = fq
138
+ # if metric_learning_flag:
139
+ # metric_loss = AngularMetricLoss(fq, 2, alpha=self.positive_strength, args = self.cfg) # (1-self.positive_strength) *
140
+ # if mix_distance_angular:
141
+ # metric_loss += MetricLoss(fq, 2, alpha=self.positive_strength, args = self.cfg) # self.positive_strength *
142
+ fq = fq.reshape(b, c, h, w)
143
+
144
+ # b, 1, 104, 104
145
+ pred = self.proj(fq, state)
146
+
147
+ if self.training:
148
+ # resize mask
149
+ if pred.shape[-2:] != target.shape[-2:]:
150
+ target = F.interpolate(target, pred.shape[-2:],
151
+ mode='nearest').detach()
152
+ CE_loss = F.binary_cross_entropy_with_logits(pred, target)
153
+
154
+ # 4. if metric learning, add metric loss and normalize
155
+ # if metric_learning_flag:
156
+ # loss = (loss + self.metric_loss_weight * metric_loss) / (1+self.metric_loss_weight)
157
+ # safety_loss = loss * 0.
158
+ # loss = loss + safety_loss
159
+
160
+ return pred.detach(), target, CE_loss, metric_tensor
161
+ else:
162
+ #print(self.cfg.gpu, f"; loss = {loss}")
163
+ return pred.detach()
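The distance-based metric loss that is commented out at the top of segmenter_angular.py pulls consecutive rows (2i, 2i+1) together and pushes everything else apart. A hedged, standalone sketch of that computation is below; `pairwise_metric_loss` is an illustrative name, and the batch is assumed to have an even number of rows as in the commented code.

import torch

def pairwise_metric_loss(embeddings: torch.Tensor, alpha: float = 0.5) -> torch.Tensor:
    # embeddings: (2*B, C, HW), pooled over the spatial dimension as in the comment
    emb = embeddings.mean(dim=-1)                        # (2*B, C)
    B2 = emb.shape[0]
    dist = torch.cdist(emb, emb, p=2)                    # (2*B, 2*B) pairwise L2 distances

    positive = torch.zeros_like(dist)
    positive[torch.arange(0, B2, 2), torch.arange(1, B2, 2)] = 1
    positive[torch.arange(1, B2, 2), torch.arange(0, B2, 2)] = 1
    positive.fill_diagonal_(1)
    negative = 1.0 - positive

    pos_loss = (dist * positive).sum() / B2
    neg_loss = -torch.log((dist * negative).sum() / (B2 ** 2 - 2 * B2))
    return alpha * pos_loss + (1 - alpha) * neg_loss

print(pairwise_metric_loss(torch.randn(6, 4, 10)))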
dianecy/VerbCentric-RIS/model_/segmenter_verbonly.py ADDED
@@ -0,0 +1,178 @@
1
+ import random
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from model_.clip import build_model
7
+
8
+ from .layers import FPN, Projector, TransformerDecoder
9
+
10
+ class CRIS_PosOnly(nn.Module):
11
+ def __init__(self, cfg):
12
+ super().__init__()
13
+ # Vision & Text Encoder
14
+ clip_model = torch.jit.load(cfg.clip_pretrain,
15
+ map_location="cpu").eval()
16
+ self.backbone = build_model(clip_model.state_dict(), cfg.word_len, cfg.freeze).float()
17
+
18
+ # Multi-Modal FPN
19
+ self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
20
+ # Decoder
21
+ self.decoder = TransformerDecoder(num_layers=cfg.num_layers,
22
+ d_model=cfg.vis_dim,
23
+ nhead=cfg.num_head,
24
+ dim_ffn=cfg.dim_ffn,
25
+ dropout=cfg.dropout,
26
+ return_intermediate=cfg.intermediate)
27
+ # Projector
28
+ self.proj = Projector(cfg.word_dim, cfg.vis_dim // 2, 3)
29
+ self.metric_learning = False # cfg.metric_learning
30
+ self.metric_loss_weight = cfg.metric_loss_weight
31
+ self.cfg = cfg
32
+
33
+
34
+
35
+
36
+ def forward(self, image, text, target=None, verb=None):
37
+ '''
38
+ image: b, 3, h, w
39
+ text: b, words
40
+ target: b, 1, h, w
41
+ verb: b, words (if applicable, only used in training mode for contrastive learning)
42
+ '''
43
+
44
+ sentences, images, targets, pad_masks = [], [], [], []
45
+
46
+ if self.training:
47
+ verb_masks = []
48
+ cl_masks = []
49
+
50
+ for idx in range(len(text)):
51
+ sentences.append(text[idx])
52
+ images.append(image[idx])
53
+ targets.append(target[idx])
54
+ pad_masks.append(torch.zeros_like(text[idx]).masked_fill_(text[idx] == 0, 1).bool())
55
+
56
+ # If verb exists, process it
57
+ if verb[idx].numel() > 0 and verb[idx].sum().item() > 0:
58
+ verb_masks.extend([1, 1]) # Both original sentence and verb are marked
59
+ cl_masks.extend([1, 0]) # Only the original sentence gets marked
60
+ sentences.append(verb[idx])
61
+ images.append(image[idx])
62
+ targets.append(target[idx])
63
+ pad_masks.append(torch.zeros_like(verb[idx]).masked_fill_(verb[idx] == 0, 1).bool())
64
+ else:
65
+ verb_masks.append(0)
66
+ cl_masks.append(1)
67
+
68
+
69
+ sentences = torch.stack(sentences)
70
+ images = torch.stack(images)
71
+ targets = torch.stack(targets)
72
+ pad_masks = torch.stack(pad_masks)
73
+ verb_masks = torch.tensor(verb_masks, dtype=torch.bool)
74
+ cl_masks = torch.tensor(cl_masks, dtype=torch.bool)
75
+
76
+ else:
77
+ sentences = text
78
+ images = image
79
+ targets = target
80
+ pad_masks = torch.zeros_like(text).masked_fill_(text == 0, 1).bool()
81
+
82
+ # Encoding images and text
83
+ vis = self.backbone.encode_image(images)
84
+ word, state = self.backbone.encode_text(sentences)
85
+
86
+ # FPN neck and decoder
87
+ fq, f5 = self.neck(vis, state)
88
+ b, c, h, w = fq.size()
89
+ fq = self.decoder(fq, word, pad_masks)
90
+ metric_tensor = fq # b, c, h*w
91
+ fq = fq.reshape(b, c, h, w)
92
+
93
+ # Final prediction
94
+ pred = self.proj(fq, state)
95
+
96
+ if self.training:
97
+ if pred.shape[-2:] != targets.shape[-2:]:
98
+ targets = F.interpolate(targets, pred.shape[-2:], mode='nearest').detach()
99
+ loss = F.binary_cross_entropy_with_logits(pred[cl_masks], targets[cl_masks])
100
+
101
+ if self.metric_learning:
102
+ metric_loss = self.compute_metric_loss(metric_tensor, verb_masks, args=self.cfg)
103
+ loss = (loss + self.metric_loss_weight * metric_loss) / (1 + self.metric_loss_weight)
104
+
105
+ return pred[cl_masks].detach(), targets[cl_masks], loss
106
+
107
+ return pred.detach() # In eval mode, only return the predictions
108
+
109
+
110
+ def compute_metric_loss(self, metric_tensor, verb_mask, args):
111
+ if args.loss_option == "ACL_verbonly" :
112
+ metric_loss = self.UniAngularContrastLoss(metric_tensor, verb_mask, m=args.margin_value, tau=args.temperature, verbonly=True, args=args)
113
+ elif args.loss_option == "ACL" :
114
+ metric_loss = self.UniAngularContrastLoss(metric_tensor, verb_mask, m=args.margin_value, tau=args.temperature, verbonly=False, args=args)
115
+
116
+ return metric_loss
117
+
118
+
119
+ def return_mask(self, emb_distance, verb_mask=None):
120
+ B_, B_ = emb_distance.shape
121
+ positive_mask = torch.zeros_like(emb_distance)
122
+ positive_mask.fill_diagonal_(1) # Set diagonal elements to 1 for all cases
123
+
124
+ if B_ < len(verb_mask):
125
+ # B_ equals 2*K here (twice the number of verb phrases)
126
+ for i in range(B_ // 2):
127
+ positive_mask[2 * i, 2 * i + 1] = 1
128
+ positive_mask[2 * i + 1, 2 * i] = 1
129
+ else:
130
+ # Process the case where we have a mix of sentences with and without verbs
131
+ i = 0
132
+ while i < B_:
133
+ if verb_mask[i] == 1:
134
+ positive_mask[i, i + 1] = 1
135
+ positive_mask[i + 1, i] = 1
136
+ i += 2
137
+ else:
138
+ i += 1
139
+ negative_mask = torch.ones_like(emb_distance) - positive_mask
140
+ return positive_mask, negative_mask
141
+
142
+
143
+ def UniAngularContrastLoss(self, total_fq, verb_mask, alpha=0.5, verbonly=True, m=0.5, tau=0.05, args=None):
144
+ _, C, HW = total_fq.shape
145
+
146
+ if verbonly :
147
+ emb = torch.mean(total_fq[verb_mask], dim=-1)
148
+ assert emb.shape[0] % 2 == 0, f"Embedding count {emb.shape[0]} is not divisible by 2."
149
+ else :
150
+ emb = torch.mean(total_fq, dim=-1)
151
+
152
+ B_ = emb.shape[0]
153
+ # emb = F.normalize(emb, p=2, dim=1)
154
+ emb_i = emb.unsqueeze(1).repeat(1, B_, 1) # (B_, B_, C)
155
+ emb_j = emb.unsqueeze(0).repeat(B_, 1, 1) # (B_, B_, C)
156
+ sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
157
+ sim_matrix = sim(emb_i, emb_j).reshape(B_, B_) # (B_, B_)
158
+ sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)
159
+
160
+ positive_mask, negative_mask = self.return_mask(sim_matrix, verb_mask)
161
+
162
+ # Apply margin to positive pairs
163
+ sim_matrix_with_margin = sim_matrix.clone()
164
+ sim_matrix_with_margin[positive_mask.bool()] = torch.cos(torch.acos(sim_matrix[positive_mask.bool()]) + m / 57.2958)
165
+
166
+ # Scale logits with temperature
167
+ logits = sim_matrix_with_margin / tau
168
+
169
+ # Compute the softmax loss for all pairs
170
+ exp_logits = torch.exp(logits)
171
+ pos_exp_logits = exp_logits[positive_mask.bool()]
172
+ total_exp_logits = exp_logits.sum(dim=-1)
173
+
174
+ # Compute the final loss: L_arc = -log(e^(cos(theta + m)/tau) / sum(e^(cos(theta)/tau)))
175
+ pos_rows = positive_mask.bool().nonzero(as_tuple=True)[0]  # anchor row of each positive entry
+ positive_loss = -torch.log(pos_exp_logits / total_exp_logits[pos_rows])
176
+ angular_loss = positive_loss.mean()
177
+
178
+ return angular_loss
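A small, self-contained illustration of the ArcFace-style step used in UniAngularContrastLoss above: an angular margin m (given in degrees; the division by 57.2958 in the code is the degrees-to-radians conversion) is added to the angle of each positive pair before temperature scaling. `add_angular_margin` and the toy tensors are illustrative only.

import math
import torch

def add_angular_margin(cos_sim: torch.Tensor, positive: torch.Tensor,
                       m_deg: float = 0.5, tau: float = 0.05) -> torch.Tensor:
    cos_sim = cos_sim.clamp(-0.9999, 0.9999)             # keep acos numerically safe
    out = cos_sim.clone()
    theta = torch.acos(cos_sim[positive.bool()])          # angle of each positive pair
    out[positive.bool()] = torch.cos(theta + math.radians(m_deg))
    return out / tau                                      # logits fed to the softmax

cos_sim = torch.tensor([[1.0, 0.8], [0.8, 1.0]])
positive = torch.eye(2)
print(add_angular_margin(cos_sim, positive))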
dianecy/VerbCentric-RIS/model_/segmenter_verbonly_fin.py ADDED
@@ -0,0 +1,179 @@
1
+ import random
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from model_.clip import build_model
7
+
8
+ from .layers import FPN, Projector, TransformerDecoder
9
+
10
+ class CRIS_PosOnly_rev(nn.Module):
11
+ def __init__(self, cfg):
12
+ super().__init__()
13
+ # Vision & Text Encoder
14
+ clip_model = torch.jit.load(cfg.clip_pretrain,
15
+ map_location="cpu").eval()
16
+ self.backbone = build_model(clip_model.state_dict(), cfg.word_len, cfg.freeze).float()
17
+
18
+ # Multi-Modal FPN
19
+ self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
20
+ # Decoder
21
+ self.decoder = TransformerDecoder(num_layers=cfg.num_layers,
22
+ d_model=cfg.vis_dim,
23
+ nhead=cfg.num_head,
24
+ dim_ffn=cfg.dim_ffn,
25
+ dropout=cfg.dropout,
26
+ return_intermediate=cfg.intermediate)
27
+ # Projector
28
+ self.proj = Projector(cfg.word_dim, cfg.vis_dim // 2, 3)
29
+ self.metric_learning = False # cfg.metric_learning
30
+ self.metric_loss_weight = cfg.metric_loss_weight
31
+ self.cfg = cfg
32
+
33
+
34
+
35
+
36
+ def forward(self, image, text, target=None, verb=None):
37
+ '''
38
+ image: b, 3, h, w
39
+ text: b, words
40
+ target: b, 1, h, w
41
+ verb: b, words (if applicable, only used in training mode for contrastive learning)
42
+ '''
43
+
44
+ sentences, images, targets, pad_masks = [], [], [], []
45
+
46
+ if self.training:
47
+ verb_masks = []
48
+ cl_masks = []
49
+
50
+ for idx in range(len(text)):
51
+ sentences.append(text[idx])
52
+ images.append(image[idx])
53
+ targets.append(target[idx])
54
+ pad_masks.append(torch.zeros_like(text[idx]).masked_fill_(text[idx] == 0, 1).bool())
55
+
56
+ # If verb exists, process it
57
+ if verb[idx].numel() > 0 and verb[idx].sum().item() > 0:
58
+ verb_masks.extend([1, 1]) # Both original sentence and verb are marked
59
+ cl_masks.extend([1, 0]) # Only the original sentence gets marked
60
+ sentences.append(verb[idx])
61
+ images.append(image[idx])
62
+ targets.append(target[idx])
63
+ pad_masks.append(torch.zeros_like(verb[idx]).masked_fill_(verb[idx] == 0, 1).bool())
64
+ else:
65
+ verb_masks.append(0)
66
+ cl_masks.append(1)
67
+
68
+
69
+ sentences = torch.stack(sentences)
70
+ images = torch.stack(images)
71
+ targets = torch.stack(targets)
72
+ pad_masks = torch.stack(pad_masks)
73
+ verb_masks = torch.tensor(verb_masks, dtype=torch.bool)
74
+ cl_masks = torch.tensor(cl_masks, dtype=torch.bool)
75
+
76
+ else:
77
+ sentences = text
78
+ images = image
79
+ targets = target
80
+ pad_masks = torch.zeros_like(text).masked_fill_(text == 0, 1).bool()
81
+
82
+ # Encoding images and text
83
+ vis = self.backbone.encode_image(images)
84
+ word, state = self.backbone.encode_text(sentences)
85
+
86
+ # FPN neck and decoder
87
+ fq, f5 = self.neck(vis, state)
88
+ b, c, h, w = fq.size()
89
+ fq = self.decoder(fq, word, pad_masks)
90
+ metric_tensor = fq # b, c, h*w
91
+ fq = fq.reshape(b, c, h, w)
92
+
93
+ # Final prediction
94
+ pred = self.proj(fq, state)
95
+
96
+ if self.training:
97
+ if pred.shape[-2:] != targets.shape[-2:]:
98
+ targets = F.interpolate(targets, pred.shape[-2:], mode='nearest').detach()
99
+
100
+ loss = F.binary_cross_entropy_with_logits(pred[cl_masks], targets[cl_masks])
101
+
102
+ if self.metric_learning:
103
+ metric_loss = self.compute_metric_loss(metric_tensor, verb_masks, args=self.cfg)
104
+ loss = (loss + self.metric_loss_weight * metric_loss) / (1 + self.metric_loss_weight)
105
+
106
+ return pred[cl_masks].detach(), targets[cl_masks], loss
107
+
108
+ return pred.detach() # In eval mode, only return the predictions
109
+
110
+
111
+ def compute_metric_loss(self, metric_tensor, verb_mask, args):
112
+ if args.loss_option == "ACL_verbonly" :
113
+ metric_loss = self.UniAngularContrastLoss(metric_tensor, verb_mask, m=args.margin_value, tau=args.temperature, verbonly=True, args=args)
114
+ elif args.loss_option == "ACL" :
115
+ metric_loss = self.UniAngularContrastLoss(metric_tensor, verb_mask, m=args.margin_value, tau=args.temperature, verbonly=False, args=args)
116
+
117
+ return metric_loss
118
+
119
+
120
+ def return_mask(self, emb_distance, verb_mask=None):
121
+ B_, B_ = emb_distance.shape
122
+ positive_mask = torch.zeros_like(emb_distance)
123
+ positive_mask.fill_diagonal_(1) # Set diagonal elements to 1 for all cases
124
+
125
+ if B_ < len(verb_mask):
126
+ # B_ equals 2*K here (twice the number of verb phrases)
127
+ for i in range(B_ // 2):
128
+ positive_mask[2 * i, 2 * i + 1] = 1
129
+ positive_mask[2 * i + 1, 2 * i] = 1
130
+ else:
131
+ # Process the case where we have a mix of sentences with and without verbs
132
+ i = 0
133
+ while i < B_:
134
+ if verb_mask[i] == 1:
135
+ positive_mask[i, i + 1] = 1
136
+ positive_mask[i + 1, i] = 1
137
+ i += 2
138
+ else:
139
+ i += 1
140
+ negative_mask = torch.ones_like(emb_distance) - positive_mask
141
+ return positive_mask, negative_mask
142
+
143
+
144
+ def UniAngularContrastLoss(self, total_fq, verb_mask, alpha=0.5, verbonly=True, m=0.5, tau=0.05, args=None):
145
+ _, C, HW = total_fq.shape
146
+
147
+ if verbonly :
148
+ emb = torch.mean(total_fq[verb_mask], dim=-1)
149
+ assert emb.shape[0] % 2 == 0, f"Embedding count {emb.shape[0]} is not divisible by 2."
150
+ else :
151
+ emb = torch.mean(total_fq, dim=-1)
152
+
153
+ B_ = emb.shape[0]
154
+ # emb = F.normalize(emb, p=2, dim=1)
155
+ emb_i = emb.unsqueeze(1).repeat(1, B_, 1) # (B_, B_, C)
156
+ emb_j = emb.unsqueeze(0).repeat(B_, 1, 1) # (B_, B_, C)
157
+ sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
158
+ sim_matrix = sim(emb_i, emb_j).reshape(B_, B_) # (B_, B_)
159
+ sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)
160
+
161
+ positive_mask, negative_mask = self.return_mask(sim_matrix, verb_mask)
162
+
163
+ # Apply margin to positive pairs
164
+ sim_matrix_with_margin = sim_matrix.clone()
165
+ sim_matrix_with_margin[positive_mask.bool()] = torch.cos(torch.acos(sim_matrix[positive_mask.bool()]) + m / 57.2958)
166
+
167
+ # Scale logits with temperature
168
+ logits = sim_matrix_with_margin / tau
169
+
170
+ # Compute the softmax loss for all pairs
171
+ exp_logits = torch.exp(logits)
172
+ pos_exp_logits = exp_logits[positive_mask.bool()]
173
+ total_exp_logits = exp_logits.sum(dim=-1)
174
+
175
+ # Compute the final loss: L_arc = -log(e^(cos(theta + m)/tau) / sum(e^(cos(theta)/tau)))
176
+ pos_rows = positive_mask.bool().nonzero(as_tuple=True)[0]  # anchor row of each positive entry
+ positive_loss = -torch.log(pos_exp_logits / total_exp_logits[pos_rows])
177
+ angular_loss = positive_loss.mean()
178
+
179
+ return angular_loss
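The pairing logic in return_mask above can be hard to read from the while-loop alone. Below is a hedged sketch of its mixed-batch branch: a sample with a verb phrase contributes two adjacent rows (sentence, verb) that form a positive pair, while a verb-less sample only matches itself on the diagonal. `build_pair_masks` is a hypothetical helper name, not part of the repo.

import torch

def build_pair_masks(verb_mask):
    # verb_mask: 1 where a row belongs to a sentence/verb pair, 0 otherwise
    B = verb_mask.shape[0]
    positive = torch.eye(B)
    i = 0
    while i < B:
        if verb_mask[i] == 1:          # rows i (sentence) and i+1 (verb) are a pair
            positive[i, i + 1] = 1
            positive[i + 1, i] = 1
            i += 2
        else:
            i += 1
    return positive, 1.0 - positive

pos, neg = build_pair_masks(torch.tensor([1, 1, 0, 1, 1]))
print(pos)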
dianecy/VerbCentric-RIS/model_/segmenter_verbonly_hardneg.py ADDED
@@ -0,0 +1,259 @@
1
+ import random
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from model_.clip import build_model
7
+
8
+ from .layers import FPN, Projector, TransformerDecoder
9
+
10
+ class CRIS_VerbOnly(nn.Module):
11
+ def __init__(self, cfg):
12
+ super().__init__()
13
+ # Vision & Text Encoder
14
+ clip_model = torch.jit.load(cfg.clip_pretrain,
15
+ map_location="cpu").eval()
16
+ self.backbone = build_model(clip_model.state_dict(), cfg.word_len, cfg.freeze).float()
17
+ # Multi-Modal FPN
18
+ self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
19
+ # Decoder
20
+ self.decoder = TransformerDecoder(num_layers=cfg.num_layers,
21
+ d_model=cfg.vis_dim,
22
+ nhead=cfg.num_head,
23
+ dim_ffn=cfg.dim_ffn,
24
+ dropout=cfg.dropout,
25
+ return_intermediate=cfg.intermediate)
26
+ # Projector
27
+ self.proj = Projector(cfg.word_dim, cfg.vis_dim // 2, 3)
28
+ self.metric_learning = False # cfg.metric_learning
29
+ # self.positive_strength = cfg.positive_strength
30
+ self.metric_loss_weight = cfg.metric_loss_weight
31
+ # self.eps = cfg.ptb_rate
32
+ self.cfg = cfg
33
+
34
+
35
+ def forward(self, image, text, target=None, hardpos=None, hardneg=None):
36
+ '''
37
+ image: b, 3, h, w
38
+ text: b, words
39
+ target: b, 1, h, w
40
+ verb: b, words (if applicable, only used in training mode for contrastive learning)
41
+ '''
42
+
43
+ if self.training:
44
+ sentences, images, targets, pad_masks = [], [], [], []
45
+ posverb_mask, negverb_mask = [], []
46
+ cl_masks = []
47
+
48
+ for idx in range(len(text)):
49
+ sentences.append(text[idx])
50
+ images.append(image[idx])
51
+ targets.append(target[idx])
52
+ pad_masks.append(torch.zeros_like(text[idx]).masked_fill_(text[idx] == 0, 1).bool())
53
+
54
+ if hardpos[idx].numel() > 0 and hardpos[idx].sum().item() > 0:
55
+ # if hard positive exists, check the condition
56
+ if hardneg[idx].numel() > 0 and hardneg[idx].sum().item() > 0:
57
+ # if hard positive and hard negative exists
58
+ posverb_mask.extend([1, 1, 0]) # mark original, hard positive as 1, negative as 0
59
+ negverb_mask.extend([0, 0, 1]) # mark only negative as 1
60
+
61
+ if not self.cfg.hn_celoss :
62
+ cl_masks.extend([1, 0, 0]) # mark only original as 1
63
+ else :
64
+ cl_masks.extend([1, 0, 1])
65
+ sentences.extend([hardpos[idx], hardneg[idx]])
66
+ images.extend([image[idx], image[idx]])
67
+ targets.extend([target[idx], torch.zeros_like(target[idx])])  # hard negative is supervised with an all-zero mask
68
+ pad_masks.extend([
69
+ torch.zeros_like(hardpos[idx]).masked_fill_(hardpos[idx] == 0, 1).bool(),
70
+ torch.zeros_like(hardneg[idx]).masked_fill_(hardneg[idx] == 0, 1).bool()
71
+ ])
72
+
73
+ else :
74
+ # only hard positive exists, no negatives
75
+ posverb_mask.extend([1, 1])
76
+ negverb_mask.extend([0, 0])
77
+ cl_masks.extend([1, 0])
78
+
79
+ sentences.append(hardpos[idx])
80
+ images.append(image[idx])
81
+ targets.append(target[idx])
82
+ pad_masks.append(torch.zeros_like(hardpos[idx]).masked_fill_(hardpos[idx] == 0, 1).bool())
83
+ else :
84
+ # no hard positive, no hard negative. only original sentence itself.
85
+ posverb_mask.append(0)
86
+ negverb_mask.append(0)
87
+ cl_masks.append(1)
88
+
89
+
90
+ sentences = torch.stack(sentences)
91
+ images = torch.stack(images)
92
+ targets = torch.stack(targets)
93
+ pad_masks = torch.stack(pad_masks)
94
+ posverb_mask = torch.tensor(posverb_mask, dtype=torch.bool)
95
+ negverb_mask = torch.tensor(negverb_mask, dtype=torch.bool)
96
+ cl_masks = torch.tensor(cl_masks, dtype=torch.bool)
97
+
98
+ else:
99
+ sentences = text
100
+ images = image
101
+ targets = target
102
+ pad_masks = torch.zeros_like(text).masked_fill_(text == 0, 1).bool()
103
+
104
+ # Encoding images and text
105
+ vis = self.backbone.encode_image(images)
106
+ word, state = self.backbone.encode_text(sentences)
107
+
108
+ # FPN neck and decoder
109
+ fq, f5 = self.neck(vis, state)
110
+ b, c, h, w = fq.size()
111
+ fq = self.decoder(fq, word, pad_masks)
112
+ metric_tensor = fq # b, c, h*w
113
+ fq = fq.reshape(b, c, h, w)
114
+
115
+ # Final prediction
116
+ pred = self.proj(fq, state)
117
+
118
+ if self.training:
119
+ if pred.shape[-2:] != targets.shape[-2:]:
120
+ targets = F.interpolate(targets, pred.shape[-2:], mode='nearest').detach()
121
+
122
+ loss = F.binary_cross_entropy_with_logits(pred[cl_masks], targets[cl_masks])
123
+
124
+ if self.metric_learning:
125
+ metric_loss = self.compute_metric_loss(metric_tensor, posverb_mask, negverb_mask, args=self.cfg)
126
+ loss = (loss + self.metric_loss_weight * metric_loss) / (1 + self.metric_loss_weight)
127
+
128
+ return pred.detach(), targets, loss
129
+
130
+ return pred.detach()
131
+
132
+
133
+ def compute_metric_loss(self, metric_tensor, positive_verbs, negative_verbs, args) :
134
+ if args.loss_option == "ACL_verbonly" :
135
+ metric_loss = self.UniAngularContrastLoss(metric_tensor, positive_verbs, negative_verbs, m=args.margin_value, tau=args.temperature, verbonly=True, args=args)
136
+ elif args.loss_option == "ACL" :
137
+ metric_loss = self.UniAngularContrastLoss(metric_tensor, positive_verbs, negative_verbs, m=args.margin_value, tau=args.temperature, verbonly=False, args=args)
138
+
139
+ return metric_loss
140
+
141
+
142
+ def return_mask(self, emb_distance, positive_verbs, negative_verbs, verb_mask):
143
+ B_, B_ = emb_distance.shape
144
+ positive_mask = torch.zeros_like(emb_distance)
145
+ negative_mask = torch.ones_like(emb_distance)
146
+ hard_negative_mask = torch.zeros_like(emb_distance)
147
+ positive_mask.fill_diagonal_(1)
148
+
149
+ if B_ < len(verb_mask):
150
+ # Considering only verbs that pass the verb_mask filter
151
+ positive_verbs = torch.tensor(positive_verbs)[verb_mask]
152
+ negative_verbs = torch.tensor(negative_verbs)[verb_mask]
153
+
154
+ # Exclude hard negatives from both masks (diagonal)
155
+ for i in range(B_):
156
+ if negative_verbs[i] == 1:
157
+ positive_mask[i, i] = 0
158
+ negative_mask[i, i] = 0
159
+ # Set the entire row and column for the hard negative, except the diagonal
160
+ hard_negative_mask[i, :] = 1 # Mark the i-th row
161
+ hard_negative_mask[:, i] = 1 # Mark the i-th column
162
+ hard_negative_mask[i, i] = 0 # Ensure diagonal element (i, i) is 0
163
+
164
+ i = 0
165
+ while i < B_:
166
+ if positive_verbs[i] == 1:
167
+ if i + 1 < B_ and positive_verbs[i + 1] == 1:
168
+ positive_mask[i, i + 1] = 1
169
+ positive_mask[i + 1, i] = 1
170
+ i += 2
171
+ else:
172
+ i += 1
173
+ else:
174
+ # Exclude hard negatives from both masks (diagonal)
175
+ for i in range(B_):
176
+ if negative_verbs[i] == 1:
177
+ positive_mask[i, i] = 0
178
+ negative_mask[i, i] = 0
179
+ # Set the entire row and column for the hard negative, except the diagonal
180
+ hard_negative_mask[i, :] = 1 # Mark the i-th row
181
+ hard_negative_mask[:, i] = 1 # Mark the i-th column
182
+ hard_negative_mask[i, i] = 0 # Ensure diagonal element (i, i) is 0
183
+
184
+ # Apply the positive pairs logic similarly as above
185
+ i = 0
186
+ while i < B_:
187
+ if positive_verbs[i] == 1 and i + 1 < B_ and positive_verbs[i + 1] == 1:
188
+ positive_mask[i, i + 1] = 1
189
+ positive_mask[i + 1, i] = 1
190
+ i += 2
191
+ else:
192
+ i += 1
193
+
194
+ negative_mask = negative_mask - positive_mask
195
+ negative_mask[hard_negative_mask.bool()] = 0 # Set hard negative indices to 0 in negative_mask
196
+ return positive_mask, negative_mask, hard_negative_mask
197
+
198
+
199
+ def UniAngularContrastLoss(self, total_fq, positive_verbs, negative_verbs, m=0.5, tau=0.05, verbonly=True, args=None):
200
+ """
201
+ Angular Margin Contrastive Loss function with mask visualization.
202
+ """
203
+ verb_mask = positive_verbs | negative_verbs  # rows that carry either a hard positive or a hard negative
204
+
205
+ if verbonly:
206
+ emb = torch.mean(total_fq[verb_mask], dim=-1)
207
+ else:
208
+ emb = torch.mean(total_fq, dim=-1) # (B, C)
209
+
210
+ B_ = emb.shape[0]
211
+ # emb = F.normalize(emb, p=2, dim=1)
212
+ emb_i = emb.unsqueeze(1).repeat(1, B_, 1) # (B_, B_, C)
213
+ emb_j = emb.unsqueeze(0).repeat(B_, 1, 1) # (B_, B_, C)
214
+ sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
215
+ sim_matrix = sim(emb_i, emb_j).reshape(B_, B_) # (B_, B_)
216
+ sim_matrix = torch.clamp(sim_matrix, min=-1+1e-10, max=1-1e-10)
217
+
218
+ # ranking based loss prep
219
+ '''
220
+ l2_dist = torch.cdist(emb, emb, p=2) # i-> j distances
221
+ KLD_
222
+ ranking_per_i = get_ranking() # load the hardness of the i-th instance from somewhere
223
+
224
+
225
+ '''
226
+
227
+ # Apply angular margin for positive pairs using return_mask
228
+ positive_mask, negative_mask, hard_negative_mask = self.return_mask(sim_matrix, positive_verbs, negative_verbs, verb_mask)
229
+ assert positive_mask.shape == sim_matrix.shape, f"Positive mask shape {positive_mask.shape} does not match sim_matrix shape {sim_matrix.shape}"
230
+ print(f"Positive mask: {positive_mask}")
231
+ print(f"Negative mask: {negative_mask}")
232
+ print(f"Hard negative mask: {hard_negative_mask}")
233
+
234
+ # Apply margin to positive pairs
235
+ sim_matrix_with_margin = sim_matrix.clone()
236
+ sim_matrix_with_margin[positive_mask.bool()] = torch.cos(torch.acos(sim_matrix[positive_mask.bool()]) + m / 57.2958)
237
+
238
+ # Scale logits with temperature
239
+ logits = sim_matrix_with_margin / tau
240
+
241
+ # Compute the softmax loss for all pairs
242
+ exp_logits = torch.exp(logits)
243
+
244
+ pos_exp_logits = exp_logits[positive_mask.bool()]
245
+ neg_exp_logits = exp_logits[negative_mask.bool()]
246
+ hardneg_exp_logits = exp_logits[hard_negative_mask.bool()]
247
+
248
+ # total_exp_logits = exp_logits.sum(dim=-1)
249
+ total_exp_logits = (
250
+ pos_exp_logits.sum(dim=-1) +
251
+ neg_exp_logits.sum(dim=-1) +
252
+ (hardneg_exp_logits.sum(dim=-1) * args.acl_hn_weight)
253
+ )
254
+
255
+ # Compute the final loss: L_arc = -log(e^(cos(theta + m)/tau) / sum(e^(cos(theta)/tau)))
256
+ positive_loss = -torch.log(pos_exp_logits / total_exp_logits)
257
+ angular_loss = positive_loss.mean()
258
+
259
+ return angular_loss
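To make the hard-negative handling above easier to follow, here is a toy sketch of how the denominator combines the three groups of logits before the -log ratio is taken: positive and plain negative terms are summed as usual, while hard-negative terms are re-weighted by an `acl_hn_weight`-style factor (called `hn_weight` below). As in the code above, the denominator is a single pooled sum over the selected entries rather than a per-anchor row sum. All tensors and the function name are illustrative, not from the model.

import torch

def weighted_contrastive_loss(logits, positive, negative, hard_negative, hn_weight=2.0):
    exp_logits = torch.exp(logits)
    pos = exp_logits[positive.bool()]
    denom = (pos.sum() + exp_logits[negative.bool()].sum()
             + hn_weight * exp_logits[hard_negative.bool()].sum())
    return (-torch.log(pos / denom)).mean()

logits = torch.randn(4, 4)
positive = torch.eye(4)
hard_negative = torch.zeros(4, 4)
hard_negative[3, :3] = 1                      # row/column of a hard-negative sample
hard_negative[:3, 3] = 1
negative = torch.ones(4, 4) - positive - hard_negative
print(weighted_contrastive_loss(logits, positive, negative, hard_negative))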