Upload folder using huggingface_hub
#1
by dianecy - opened
- dianecy/VerbCentric-RIS/model_/.gitignore +1 -0
- dianecy/VerbCentric-RIS/model_/__init__.py +105 -0
- dianecy/VerbCentric-RIS/model_/__pycache__/__init__.cpython-39.pyc +0 -0
- dianecy/VerbCentric-RIS/model_/__pycache__/clip.cpython-39.pyc +0 -0
- dianecy/VerbCentric-RIS/model_/__pycache__/layers.cpython-39.pyc +0 -0
- dianecy/VerbCentric-RIS/model_/__pycache__/segmenter.cpython-39.pyc +0 -0
- dianecy/VerbCentric-RIS/model_/__pycache__/segmenter_verbonly.cpython-39.pyc +0 -0
- dianecy/VerbCentric-RIS/model_/__pycache__/segmenter_verbonly_fin.cpython-39.pyc +0 -0
- dianecy/VerbCentric-RIS/model_/__pycache__/segmenter_verbonly_hardneg.cpython-39.pyc +0 -0
- dianecy/VerbCentric-RIS/model_/clip.py +560 -0
- dianecy/VerbCentric-RIS/model_/layers.py +309 -0
- dianecy/VerbCentric-RIS/model_/segmenter.py +204 -0
- dianecy/VerbCentric-RIS/model_/segmenter_ang_nonoise_ddp.py +305 -0
- dianecy/VerbCentric-RIS/model_/segmenter_angular.py +163 -0
- dianecy/VerbCentric-RIS/model_/segmenter_verbonly.py +178 -0
- dianecy/VerbCentric-RIS/model_/segmenter_verbonly_fin.py +179 -0
- dianecy/VerbCentric-RIS/model_/segmenter_verbonly_hardneg.py +259 -0
dianecy/VerbCentric-RIS/model_/.gitignore
ADDED
@@ -0,0 +1 @@
__pycache__/
dianecy/VerbCentric-RIS/model_/__init__.py
ADDED
@@ -0,0 +1,105 @@
from .segmenter import CRIS
# from .segmenter_angular import CRIS_S
from .segmenter_verbonly import CRIS_PosOnly
from .segmenter_verbonly_fin import CRIS_PosOnly_rev
from .segmenter_verbonly_hardneg import CRIS_VerbOnly
from loguru import logger

def build_segmenter_pos_rev(args):
    model = CRIS_PosOnly_rev(args)
    backbone = []
    head = []
    for k, v in model.named_parameters():
        if k.startswith('backbone') and 'positional_embedding' not in k:
            backbone.append(v)
        else:
            head.append(v)
    logger.info('Backbone with decay={}, Head={}'.format(len(backbone), len(head)))
    param_list = [{
        'params': backbone,
        'initial_lr': args.lr_multi * args.base_lr
    }, {
        'params': head,
        'initial_lr': args.base_lr
    }]
    return model, param_list

def build_segmenter_pos(args):
    model = CRIS_PosOnly(args)
    backbone = []
    head = []
    for k, v in model.named_parameters():
        if k.startswith('backbone') and 'positional_embedding' not in k:
            backbone.append(v)
        else:
            head.append(v)
    logger.info('Backbone with decay={}, Head={}'.format(len(backbone), len(head)))
    param_list = [{
        'params': backbone,
        'initial_lr': args.lr_multi * args.base_lr
    }, {
        'params': head,
        'initial_lr': args.base_lr
    }]
    return model, param_list


def build_segmenter(args):
    model = CRIS_VerbOnly(args)
    backbone = []
    head = []
    for k, v in model.named_parameters():
        if k.startswith('backbone') and 'positional_embedding' not in k:
            backbone.append(v)
        else:
            head.append(v)
    logger.info('Backbone with decay={}, Head={}'.format(len(backbone), len(head)))
    param_list = [{
        'params': backbone,
        'initial_lr': args.lr_multi * args.base_lr
    }, {
        'params': head,
        'initial_lr': args.base_lr
    }]
    return model, param_list



def build_segmenter_original(args):
    model = CRIS(args)
    backbone = []
    head = []
    for k, v in model.named_parameters():
        if k.startswith('backbone') and 'positional_embedding' not in k:
            backbone.append(v)
        else:
            head.append(v)
    logger.info('Backbone with decay={}, Head={}'.format(len(backbone), len(head)))
    param_list = [{
        'params': backbone,
        'initial_lr': args.lr_multi * args.base_lr
    }, {
        'params': head,
        'initial_lr': args.base_lr
    }]
    return model, param_list


# def build_segmenter_textaug(args):
#     model = CRIS_Wo_Noise(args)
#     backbone = []
#     head = []
#     for k, v in model.named_parameters():
#         if k.startswith('backbone') and 'positional_embedding' not in k:
#             backbone.append(v)
#         else:
#             head.append(v)
#     logger.info('Backbone with decay={}, Head={}'.format(len(backbone), len(head)))
#     param_list = [{
#         'params': backbone,
#         'initial_lr': args.lr_multi * args.base_lr
#     }, {
#         'params': head,
#         'initial_lr': args.base_lr
#     }]
#     return model, param_list
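For context, every builder above follows the same recipe: parameters whose names start with "backbone" (except positional embeddings) go into a backbone group trained at args.lr_multi * args.base_lr, everything else into a head group at args.base_lr. Below is a minimal, hypothetical usage sketch of that grouping rule applied to a stand-in module (DummySegmenter and the argument values are illustrative, not part of the repository):

# Hypothetical sketch: how the returned param_list might feed an optimizer.
import argparse
import torch
import torch.nn as nn

args = argparse.Namespace(base_lr=1e-4, lr_multi=0.1)

class DummySegmenter(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)   # grouped at args.lr_multi * args.base_lr
        self.decoder = nn.Linear(8, 1)    # grouped at args.base_lr

model = DummySegmenter()
backbone, head = [], []
for k, v in model.named_parameters():
    if k.startswith('backbone') and 'positional_embedding' not in k:
        backbone.append(v)
    else:
        head.append(v)
param_list = [
    {'params': backbone, 'initial_lr': args.lr_multi * args.base_lr},
    {'params': head, 'initial_lr': args.base_lr},
]
# 'initial_lr' is the key PyTorch LR schedulers expect when resumed with
# last_epoch >= 0; the optimizer itself still needs an explicit lr per group.
optimizer = torch.optim.Adam(
    [{'params': g['params'], 'lr': g['initial_lr']} for g in param_list])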
dianecy/VerbCentric-RIS/model_/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (1.88 kB)
dianecy/VerbCentric-RIS/model_/__pycache__/clip.cpython-39.pyc
ADDED
Binary file (16.8 kB)
dianecy/VerbCentric-RIS/model_/__pycache__/layers.cpython-39.pyc
ADDED
Binary file (9.06 kB)
dianecy/VerbCentric-RIS/model_/__pycache__/segmenter.cpython-39.pyc
ADDED
Binary file (4.83 kB)
dianecy/VerbCentric-RIS/model_/__pycache__/segmenter_verbonly.cpython-39.pyc
ADDED
Binary file (4.76 kB)
dianecy/VerbCentric-RIS/model_/__pycache__/segmenter_verbonly_fin.cpython-39.pyc
ADDED
Binary file (4.78 kB)
dianecy/VerbCentric-RIS/model_/__pycache__/segmenter_verbonly_hardneg.cpython-39.pyc
ADDED
Binary file (5.96 kB)
dianecy/VerbCentric-RIS/model_/clip.py
ADDED
@@ -0,0 +1,560 @@
from collections import OrderedDict
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(
                OrderedDict([("-1", nn.AvgPool2d(stride)),
                             ("0",
                              nn.Conv2d(inplanes,
                                        planes * self.expansion,
                                        1,
                                        stride=1,
                                        bias=False)),
                             ("1", nn.BatchNorm2d(planes * self.expansion))]))

    def forward(self, x: torch.Tensor):
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class AttentionPool2d(nn.Module):
    def __init__(self,
                 spacial_dim: int,
                 embed_dim: int,
                 num_heads: int,
                 output_dim: int = None):
        super().__init__()
        self.spacial_dim = spacial_dim
        self.positional_embedding = nn.Parameter(
            torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads
        # residual
        self.connect = nn.Sequential(
            nn.Conv2d(embed_dim, output_dim, 1, stride=1, bias=False),
            nn.BatchNorm2d(output_dim))

    def resize_pos_embed(self, pos_embed, input_shpae):
        """Resize pos_embed weights.
        Resize pos_embed using bicubic interpolate method.
        Args:
            pos_embed (torch.Tensor): Position embedding weights.
            input_shpae (tuple): Tuple for (downsampled input image height,
                downsampled input image width).
            pos_shape (tuple): The resolution of downsampled origin training
                image.
            mode (str): Algorithm used for upsampling:
                ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
                ``'trilinear'``. Default: ``'nearest'``
        Return:
            torch.Tensor: The resized pos_embed of shape [B, C, L_new]
        """
        assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
        pos_h = pos_w = self.spacial_dim
        cls_token_weight = pos_embed[:, 0]
        pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
        pos_embed_weight = pos_embed_weight.reshape(
            1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
        pos_embed_weight = F.interpolate(pos_embed_weight,
                                         size=input_shpae,
                                         align_corners=False,
                                         mode='bicubic')
        cls_token_weight = cls_token_weight.unsqueeze(1)
        pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
        # pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
        return pos_embed_weight.transpose(-2, -1)

    def forward(self, x):
        B, C, H, W = x.size()
        res = self.connect(x)
        x = x.reshape(B, C, -1)  # NC(HW)
        # x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(1+HW)
        pos_embed = self.positional_embedding.unsqueeze(0)
        pos_embed = self.resize_pos_embed(pos_embed, (H, W))  # NC(HW)
        x = x + pos_embed.to(x.dtype)  # NC(HW)
        x = x.permute(2, 0, 1)  # (HW)NC
        x, _ = F.multi_head_attention_forward(
            query=x,
            key=x,
            value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat(
                [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False)
        x = x.permute(1, 2, 0).reshape(B, -1, H, W)
        x = x + res
        x = F.relu(x, True)

        return x


class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """
    def __init__(self,
                 layers,
                 output_dim,
                 heads,
                 input_resolution=224,
                 width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3,
                               width // 2,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.conv2 = nn.Conv2d(width // 2,
                               width // 2,
                               kernel_size=3,
                               padding=1,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.conv3 = nn.Conv2d(width // 2,
                               width,
                               kernel_size=3,
                               padding=1,
                               bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.avgpool = nn.AvgPool2d(2)
        self.relu = nn.ReLU(inplace=True)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim,
                                        heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        def stem(x):
            for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2),
                             (self.conv3, self.bn3)]:
                x = self.relu(bn(conv(x)))
            x = self.avgpool(x)
            return x

        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x2 = self.layer2(x)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        x4 = self.attnpool(x4)

        return (x2, x3, x4)


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""
    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self,
                 d_model: int,
                 n_head: int,
                 attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict([("c_fc", nn.Linear(d_model, d_model * 4)),
                         ("gelu", QuickGELU()),
                         ("c_proj", nn.Linear(d_model * 4, d_model))]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(
            dtype=x.dtype,
            device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False,
                         attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(self,
                 width: int,
                 layers: int,
                 heads: int,
                 attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[
            ResidualAttentionBlock(width, heads, attn_mask)
            for _ in range(layers)
        ])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)


class VisionTransformer(nn.Module):
    def __init__(self, input_resolution: int, patch_size: int, width: int,
                 layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3,
                               out_channels=width,
                               kernel_size=patch_size,
                               stride=patch_size,
                               bias=False)

        scale = width**-0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn(
            (input_resolution // patch_size)**2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1],
                      -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([
            self.class_embedding.to(x.dtype) + torch.zeros(
                x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x
        ],
                      dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        # x = self.ln_post(x[:, 0, :])
        x = self.ln_post(x[:, 1:, :])

        if self.proj is not None:
            x = x @ self.proj

        return x


class CLIP(nn.Module):
    def __init__(
            self,
            embed_dim: int,
            # vision
            image_resolution: int,
            vision_layers: Union[Tuple[int, int, int, int], int],
            vision_width: int,
            vision_patch_size: int,
            # text
            context_length: int,
            txt_length: int,
            vocab_size: int,
            transformer_width: int,
            transformer_heads: int,
            transformer_layers: int):
        super().__init__()

        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(layers=vision_layers,
                                         output_dim=embed_dim,
                                         heads=vision_heads,
                                         input_resolution=image_resolution,
                                         width=vision_width)
        else:
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(input_resolution=image_resolution,
                                            patch_size=vision_patch_size,
                                            width=vision_width,
                                            layers=vision_layers,
                                            heads=vision_heads,
                                            output_dim=embed_dim)

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask(txt_length))

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(
            torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(
            torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.token_embedding.requires_grad_ = False
        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features**-0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

            for resnet_block in [
                    self.visual.layer1, self.visual.layer2, self.visual.layer3,
                    self.visual.layer4
            ]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        nn.init.zeros_(param)

        proj_std = (self.transformer.width**-0.5) * (
            (2 * self.transformer.layers)**-0.5)
        attn_std = self.transformer.width**-0.5
        fc_std = (2 * self.transformer.width)**-0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection,
                            std=self.transformer.width**-0.5)

    def build_attention_mask(self, context_length):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(context_length, context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        return self.visual(image.type(self.dtype))

    def encode_text(self, text):
        x = self.token_embedding(text).type(
            self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)[:x.size(1)]
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        state = x[torch.arange(x.shape[0]),
                  text.argmax(dim=-1)] @ self.text_projection
        # x = x @ self.text_projection
        # state = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]

        return x, state

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=-1,
                                                               keepdim=True)
        text_features = text_features / text_features.norm(dim=-1,
                                                            keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text


def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16"""
    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.half()
            if l.bias is not None:
                l.bias.data = l.bias.data.half()

        if isinstance(l, nn.MultiheadAttention):
            for attr in [
                    *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
                    "in_proj_bias", "bias_k", "bias_v"
            ]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        for name in ["text_projection", "proj"]:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.half()

    model.apply(_convert_weights_to_fp16)


def build_model(state_dict: dict, txt_length: int, freeze: bool):
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len([
            k for k in state_dict.keys()
            if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")
        ])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round(
            (state_dict["visual.positional_embedding"].shape[0] - 1)**0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        counts: list = [
            len(
                set(
                    k.split(".")[2] for k in state_dict
                    if k.startswith(f"visual.layer{b}")))
            for b in [1, 2, 3, 4]
        ]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round(
            (state_dict["visual.attnpool.positional_embedding"].shape[0] -
             1)**0.5)
        vision_patch_size = None
        assert output_width**2 + 1 == state_dict[
            "visual.attnpool.positional_embedding"].shape[0]
        image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(
        set(
            k.split(".")[2] for k in state_dict
            if k.startswith(f"transformer.resblocks")))

    model = CLIP(embed_dim, image_resolution, vision_layers, vision_width,
                 vision_patch_size, context_length, txt_length, vocab_size,
                 transformer_width, transformer_heads, transformer_layers)

    for key in ["input_resolution", "context_length", "vocab_size"]:
        if key in state_dict:
            del state_dict[key]

    convert_weights(model)
    model.load_state_dict(state_dict, False)

    if freeze:
        print(f"CLIP FROZEN")
        return model.eval()
    else:
        print(f"CLIP FINE TUNING")
        return model.train()
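For reference, build_attention_mask above is the only place the text branch departs from a plain transformer: it builds an additive causal mask (zeros on and below the diagonal, -inf above it) so each token attends only to earlier tokens. A small self-contained sketch of that construction, with the function name and size chosen purely for illustration:

# Minimal sketch of the additive causal mask used by the text transformer above.
import torch

def causal_mask(context_length: int) -> torch.Tensor:
    # -inf above the diagonal blocks attention to future tokens;
    # zeros on/below the diagonal keep past tokens visible.
    mask = torch.empty(context_length, context_length)
    mask.fill_(float("-inf"))
    mask.triu_(1)
    return mask

m = causal_mask(4)
assert torch.isinf(m[0, 1]) and m[1, 0] == 0 and m[2, 2] == 0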
dianecy/VerbCentric-RIS/model_/layers.py
ADDED
@@ -0,0 +1,309 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def conv_layer(in_dim, out_dim, kernel_size=1, padding=0, stride=1):
    return nn.Sequential(
        nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, bias=False),
        nn.BatchNorm2d(out_dim), nn.ReLU(True))


def linear_layer(in_dim, out_dim, bias=False):
    return nn.Sequential(nn.Linear(in_dim, out_dim, bias),
                         nn.BatchNorm1d(out_dim), nn.ReLU(True))


class CoordConv(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 padding=1,
                 stride=1):
        super().__init__()
        self.conv1 = conv_layer(in_channels + 2, out_channels, kernel_size,
                                padding, stride)

    def add_coord(self, input):
        b, _, h, w = input.size()
        x_range = torch.linspace(-1, 1, w, device=input.device)
        y_range = torch.linspace(-1, 1, h, device=input.device)
        y, x = torch.meshgrid(y_range, x_range)
        y = y.expand([b, 1, -1, -1])
        x = x.expand([b, 1, -1, -1])
        coord_feat = torch.cat([x, y], 1)
        input = torch.cat([input, coord_feat], 1)
        return input

    def forward(self, x):
        x = self.add_coord(x)
        x = self.conv1(x)
        return x


class Projector(nn.Module):
    def __init__(self, word_dim=1024, in_dim=256, kernel_size=3):
        super().__init__()
        self.in_dim = in_dim
        self.kernel_size = kernel_size
        # visual projector
        self.vis = nn.Sequential(  # os16 -> os4
            nn.Upsample(scale_factor=2, mode='bilinear'),
            conv_layer(in_dim * 2, in_dim * 2, 3, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            conv_layer(in_dim * 2, in_dim, 3, padding=1),
            nn.Conv2d(in_dim, in_dim, 1))
        # textual projector
        out_dim = 1 * in_dim * kernel_size * kernel_size + 1
        self.txt = nn.Linear(word_dim, out_dim)

    def forward(self, x, word):
        '''
            x: b, 512, 26, 26
            word: b, 512
        '''
        x = self.vis(x)
        B, C, H, W = x.size()
        # 1, b*256, 104, 104
        x = x.reshape(1, B * C, H, W)
        # txt: b, (256*3*3 + 1) -> b, 256, 3, 3 / b
        word = self.txt(word)
        weight, bias = word[:, :-1], word[:, -1]
        weight = weight.reshape(B, C, self.kernel_size, self.kernel_size)
        # Conv2d - 1, b*256, 104, 104 -> 1, b, 104, 104
        out = F.conv2d(x,
                       weight,
                       padding=self.kernel_size // 2,
                       groups=weight.size(0),
                       bias=bias)
        out = out.transpose(0, 1)
        # b, 1, 104, 104
        return out


class TransformerDecoder(nn.Module):
    def __init__(self,
                 num_layers,
                 d_model,
                 nhead,
                 dim_ffn,
                 dropout,
                 return_intermediate=False):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerDecoderLayer(d_model=d_model,
                                    nhead=nhead,
                                    dim_feedforward=dim_ffn,
                                    dropout=dropout) for _ in range(num_layers)
        ])
        self.num_layers = num_layers
        self.norm = nn.LayerNorm(d_model)
        self.return_intermediate = return_intermediate

    @staticmethod
    def pos1d(d_model, length):
        """
        :param d_model: dimension of the model
        :param length: length of positions
        :return: length*d_model position matrix
        """
        if d_model % 2 != 0:
            raise ValueError("Cannot use sin/cos positional encoding with "
                             "odd dim (got dim={:d})".format(d_model))
        pe = torch.zeros(length, d_model)
        position = torch.arange(0, length).unsqueeze(1)
        div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) *
                              -(math.log(10000.0) / d_model)))
        pe[:, 0::2] = torch.sin(position.float() * div_term)
        pe[:, 1::2] = torch.cos(position.float() * div_term)

        return pe.unsqueeze(1)  # n, 1, 512

    @staticmethod
    def pos2d(d_model, height, width):
        """
        :param d_model: dimension of the model
        :param height: height of the positions
        :param width: width of the positions
        :return: d_model*height*width position matrix
        """
        if d_model % 4 != 0:
            raise ValueError("Cannot use sin/cos positional encoding with "
                             "odd dimension (got dim={:d})".format(d_model))
        pe = torch.zeros(d_model, height, width)
        # Each dimension use half of d_model
        d_model = int(d_model / 2)
        div_term = torch.exp(
            torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))
        pos_w = torch.arange(0., width).unsqueeze(1)
        pos_h = torch.arange(0., height).unsqueeze(1)
        pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(
            0, 1).unsqueeze(1).repeat(1, height, 1)
        pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(
            0, 1).unsqueeze(1).repeat(1, height, 1)
        pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(
            0, 1).unsqueeze(2).repeat(1, 1, width)
        pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(
            0, 1).unsqueeze(2).repeat(1, 1, width)

        return pe.reshape(-1, 1, height * width).permute(2, 1, 0)  # hw, 1, 512

    def forward(self, vis, txt, pad_mask):
        '''
            vis: b, 512, h, w
            txt: b, L, 512
            pad_mask: b, L
        '''
        B, C, H, W = vis.size()
        _, L, D = txt.size()
        # position encoding
        vis_pos = self.pos2d(C, H, W)
        txt_pos = self.pos1d(D, L)
        # reshape & permute
        vis = vis.reshape(B, C, -1).permute(2, 0, 1)
        txt = txt.permute(1, 0, 2)
        # forward
        output = vis
        intermediate = []
        for layer in self.layers:
            output = layer(output, txt, vis_pos, txt_pos, pad_mask)
            if self.return_intermediate:
                # HW, b, 512 -> b, 512, HW
                intermediate.append(self.norm(output).permute(1, 2, 0))

        if self.norm is not None:
            # HW, b, 512 -> b, 512, HW
            output = self.norm(output).permute(1, 2, 0)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)
                # [output1, output2, ..., output_n]
                return intermediate
            else:
                # b, 512, HW
                return output
        return output


class TransformerDecoderLayer(nn.Module):
    def __init__(self,
                 d_model=512,
                 nhead=9,
                 dim_feedforward=2048,
                 dropout=0.1):
        super().__init__()
        # Normalization Layer
        self.self_attn_norm = nn.LayerNorm(d_model)
        self.cross_attn_norm = nn.LayerNorm(d_model)
        # Attention Layer
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model,
                                                    nhead,
                                                    dropout=dropout,
                                                    kdim=d_model,
                                                    vdim=d_model)
        # FFN
        self.ffn = nn.Sequential(nn.Linear(d_model, dim_feedforward),
                                 nn.ReLU(True), nn.Dropout(dropout),
                                 nn.LayerNorm(dim_feedforward),
                                 nn.Linear(dim_feedforward, d_model))
        # LayerNorm & Dropout
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def with_pos_embed(self, tensor, pos):
        return tensor if pos is None else tensor + pos.to(tensor.device)

    def forward(self, vis, txt, vis_pos, txt_pos, pad_mask):
        '''
            vis: 26*26, b, 512
            txt: L, b, 512
            vis_pos: 26*26, 1, 512
            txt_pos: L, 1, 512
            pad_mask: b, L
        '''
        # Self-Attention
        vis2 = self.norm1(vis)
        q = k = self.with_pos_embed(vis2, vis_pos)
        vis2 = self.self_attn(q, k, value=vis2)[0]
        vis2 = self.self_attn_norm(vis2)
        vis = vis + self.dropout1(vis2)
        # Cross-Attention
        vis2 = self.norm2(vis)
        vis2 = self.multihead_attn(query=self.with_pos_embed(vis2, vis_pos),
                                   key=self.with_pos_embed(txt, txt_pos),
                                   value=txt,
                                   key_padding_mask=pad_mask)[0]
        vis2 = self.cross_attn_norm(vis2)
        vis = vis + self.dropout2(vis2)
        # FFN
        vis2 = self.norm3(vis)
        vis2 = self.ffn(vis2)
        vis = vis + self.dropout3(vis2)
        return vis


class FPN(nn.Module):
    def __init__(self,
                 in_channels=[512, 1024, 1024],
                 out_channels=[256, 512, 1024]):
        super(FPN, self).__init__()
        # text projection
        self.txt_proj = linear_layer(in_channels[2], out_channels[2])
        # fusion 1: v5 & seq -> f_5: b, 1024, 13, 13
        self.f1_v_proj = conv_layer(in_channels[2], out_channels[2], 1, 0)
        self.norm_layer = nn.Sequential(nn.BatchNorm2d(out_channels[2]),
                                        nn.ReLU(True))
        # fusion 2: v4 & fm -> f_4: b, 512, 26, 26
        self.f2_v_proj = conv_layer(in_channels[1], out_channels[1], 3, 1)
        self.f2_cat = conv_layer(out_channels[2] + out_channels[1],
                                 out_channels[1], 1, 0)
        # fusion 3: v3 & fm_mid -> f_3: b, 512, 52, 52
        self.f3_v_proj = conv_layer(in_channels[0], out_channels[0], 3, 1)
        self.f3_cat = conv_layer(out_channels[0] + out_channels[1],
                                 out_channels[1], 1, 0)
        # fusion 4: f_3 & f_4 & f_5 -> fq: b, 256, 26, 26
        self.f4_proj5 = conv_layer(out_channels[2], out_channels[1], 3, 1)
        self.f4_proj4 = conv_layer(out_channels[1], out_channels[1], 3, 1)
        self.f4_proj3 = conv_layer(out_channels[1], out_channels[1], 3, 1)
        # aggregation
        self.aggr = conv_layer(3 * out_channels[1], out_channels[1], 1, 0)
        self.coordconv = nn.Sequential(
            CoordConv(out_channels[1], out_channels[1], 3, 1),
            conv_layer(out_channels[1], out_channels[1], 3, 1))

    def forward(self, imgs, state):
        # v3, v4, v5: 256, 52, 52 / 512, 26, 26 / 1024, 13, 13
        v3, v4, v5 = imgs
        # fusion 1: b, 1024, 13, 13
        # text projection: b, 1024 -> b, 1024
        state = self.txt_proj(state).unsqueeze(-1).unsqueeze(
            -1)  # b, 1024, 1, 1
        f5 = self.f1_v_proj(v5)
        f5 = self.norm_layer(f5 * state)
        # fusion 2: b, 512, 26, 26
        f4 = self.f2_v_proj(v4)
        f5_ = F.interpolate(f5, scale_factor=2, mode='bilinear')
        f4 = self.f2_cat(torch.cat([f4, f5_], dim=1))
        # fusion 3: b, 256, 26, 26
        f3 = self.f3_v_proj(v3)
        f3 = F.avg_pool2d(f3, 2, 2)
        f3 = self.f3_cat(torch.cat([f3, f4], dim=1))
        # fusion 4: b, 512, 13, 13 / b, 512, 26, 26 / b, 512, 26, 26
        fq5 = self.f4_proj5(f5)
        fq4 = self.f4_proj4(f4)
        fq3 = self.f4_proj3(f3)
        # query
        fq5 = F.interpolate(fq5, scale_factor=2, mode='bilinear')
        fq = torch.cat([fq3, fq4, fq5], dim=1)
        fq = self.aggr(fq)
        fq = self.coordconv(fq)
        # b, 512, 26, 26
        return fq, f5
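Projector above implements a language-conditioned dynamic convolution: the pooled text state is mapped to one (C, k, k) kernel plus a bias per sample, the batch is folded into the channel dimension, and a grouped convolution applies each sample's kernel to its own feature map. A toy standalone version of that step, with shapes chosen only for illustration:

# Standalone illustration of the per-sample dynamic convolution used by Projector.
import torch
import torch.nn.functional as F

B, C, H, W, k = 2, 4, 8, 8, 3
feat = torch.randn(B, C, H, W)             # visual features
word = torch.randn(B, C * k * k + 1)       # per-sample kernel weights + bias

x = feat.reshape(1, B * C, H, W)           # fold the batch into the channel axis
weight = word[:, :-1].reshape(B, C, k, k)  # one (C, k, k) kernel per sample
bias = word[:, -1]
out = F.conv2d(x, weight, bias=bias, padding=k // 2, groups=B)
out = out.transpose(0, 1)                  # back to (B, 1, H, W)
assert out.shape == (B, 1, H, W)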
dianecy/VerbCentric-RIS/model_/segmenter.py
ADDED
@@ -0,0 +1,204 @@
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

from model_.clip import build_model
from .layers import FPN, Projector, TransformerDecoder


def MetricLoss(embeddings, n_pos, alpha=0.5, args=None):
    # embeddings: ((2*B), C, (H*W))
    # n_pos : chunk size of positive pairs
    # args: args
    # returns: loss
    metric_loss = 0

    # flatten embeddings
    B_, C, HW = embeddings.shape
    emb = torch.mean(embeddings, dim=-1)  # (2*B, C)
    emb_i = emb.unsqueeze(1).repeat(1, B_, 1)  # (2*B, 2*B, C)
    emb_j = emb.unsqueeze(0).repeat(B_, 1, 1)  # (2*B, 2*B, C)
    emb_distance = torch.norm(emb_i - emb_j, dim=-1)  # (2*B, 2*B)
    assert torch.sum(torch.diag(emb_distance[:B_, :B_])) == 0, \
        "Diagonals are not zero. please check the permutation on the batch"
    # print("distance metrix : ", emb_distance)

    # positive pairs and loss
    positive_mask = torch.zeros_like(emb_distance)
    for i in range(B_//2):
        positive_mask[2*i, 2*i+1] = 1
        positive_mask[2*i+1, 2*i] = 1
    positive_mask.fill_diagonal_(1)
    positive_loss = torch.sum(emb_distance * positive_mask) / B_

    # negative pairs and loss
    negative_mask = torch.ones_like(emb_distance) - positive_mask

    if args.div_batch:
        negative_loss = -1.0 * torch.log(torch.sum(emb_distance * negative_mask) / B_)
    else:
        negative_loss = -1.0 * torch.log(torch.sum(emb_distance * negative_mask) / (B_**2 - 2*B_))

    # print(positive_mask, negative_mask)

    metric_loss = alpha * positive_loss + (1-alpha) * negative_loss

    return metric_loss


def AngularMetricLoss(embeddings, n_pos, alpha=0.5, args=None):
    # embeddings: ((2*B), C, (H*W))
    # n_pos : chunk size of positive pairs
    # args: args
    # returns: loss
    geometric_loss = 0

    # flatten embeddings
    B_, C, HW = embeddings.shape
    emb = torch.mean(embeddings, dim=-1)  # (2*B, C)
    emb_i = emb.unsqueeze(1).repeat(1, B_, 1)  # (2*B, 2*B, C)
    emb_j = emb.unsqueeze(0).repeat(B_, 1, 1)  # (2*B, 2*B, C)
    sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
    sim_matrix = sim(emb_i, emb_j).reshape(B_, B_)  # (2*B , 2*B)
    print(sim_matrix)
    assert torch.trace(sim_matrix) == B_, \
        "similarity diagonals are not one. please check the permutation on the batch"
    print("similarity metrix : ", sim_matrix)
    phi = torch.acos(sim_matrix)  # (2*B, 2*B)
    print("phi metrix : ", phi)

    # positive pairs and loss
    positive_mask = torch.zeros_like(sim_matrix)
    for i in range(B_//2):
        positive_mask[2*i, 2*i+1] = 1
        positive_mask[2*i+1, 2*i] = 1
    positive_mask.fill_diagonal_(1)
    positive_loss = torch.sum((phi**2) * positive_mask) / B_

    # negative pairs and loss
    negative_mask = torch.ones_like(sim_matrix) - positive_mask
    phi_mask = phi < args.phi_threshold
    negative_loss = (args.phi_threshold - phi)**2
    print(negative_mask * phi_mask)
    negative_loss = torch.sum(negative_loss * negative_mask * phi_mask) / (B_**2 - 2*B_)

    print("pos loss, neg loss : ", positive_loss, negative_loss)

    geometric_loss = alpha * positive_loss + (1-alpha) * negative_loss

    return geometric_loss


class CRIS(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # Vision & Text Encoder
        clip_model = torch.jit.load(cfg.clip_pretrain,
                                    map_location="cpu").eval()
        self.backbone = build_model(clip_model.state_dict(), cfg.word_len, cfg.freeze).float()
        # Multi-Modal FPN
        self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # Decoder
        self.decoder = TransformerDecoder(num_layers=cfg.num_layers,
                                          d_model=cfg.vis_dim,
                                          nhead=cfg.num_head,
                                          dim_ffn=cfg.dim_ffn,
                                          dropout=cfg.dropout,
                                          return_intermediate=cfg.intermediate)
        # Projector
        self.proj = Projector(cfg.word_dim, cfg.vis_dim // 2, 3)
        self.metric_learning = cfg.metric_learning
        # self.positive_strength = cfg.positive_strength
        # self.metric_loss_weight = cfg.metric_loss_weight
        # self.eps = cfg.ptb_rate
        self.cfg = cfg

    def forward(self, image, text, target=None):
        '''
            img: b, 3, h, w
            word: b, words
            word_mask: b, words
            if self.metric_learning:
                word: b, 2, words
                word_mask: b, 2, words
            mask: b, 1, h, w
        '''
        metric_learning_flag = (self.metric_learning and self.training)
        metric_loss = 0

        # 1.Resizing : if metric learning, batch size of the word is doubled
        if metric_learning_flag:
            # print("image shape : ", image.shape)
            b, c, h, w = image.size()
            # duplicate image and segmentation mask
            if image is not None:
                image = torch.cat([image, image], dim=0)
                image = image.reshape(-1, b, c, h, w).transpose(0, 1).reshape(-1, c, h, w)
            if target is not None:
                target = torch.cat([target, target], dim=0)
                target = target.reshape(-1, b, 1, h, w).transpose(0, 1).reshape(-1, 1, h, w)
            # duplicate noise mask
            b_, n_, l_ = text.size()
            assert n_ == 2, "word size should be 2"
            noise_mask = (text[:, 0, :] == text[:, 1, :])
            noise_mask = torch.all(noise_mask, dim=-1)
            noise_mask = noise_mask.unsqueeze(-1).repeat(1, 2).reshape(-1)  # 2*b_
            assert noise_mask.shape[0] == b_ * 2, "noise mask shape should be 2*b_"
            text = text.reshape(b_ * 2, l_)  # 2*b, l

        # print("text shape : ", text.shape)
        # print("image shape : ", image.shape)
        # print("target shape : ", target.shape)
        # print(torch.sum(image[0::2]) == torch.sum(image[1::2]))
        # print(torch.sum(target[0::2]) == torch.sum(target[1::2]))

        # padding mask used in decoder
        pad_mask = torch.zeros_like(text).masked_fill_(text == 0, 1).bool()
        # vis: C3 / C4 / C5
        # word: b, length, 1024
        # state: b, 1024
        vis = self.backbone.encode_image(image)
        word, state = self.backbone.encode_text(text)

        b_, d_ = state.size()
        assert b_ == word.size(0), "batch size of state and word should be same"


        # 2. State Noising Step : if number of caption is 1,
        # add noise to the corresponding indices
        if metric_learning_flag:
            noise = torch.randn_like(state) * self.eps
            state[noise_mask] = state[noise_mask] + noise[noise_mask]

        # print("shape of word, state : ", word.shape, state.shape)

        # b, 512, 26, 26 (C4)
        a3, a4, a5 = vis
        # print("vis shape in model " , a3.shape, a4.shape, a5.shape)
        fq, f5 = self.neck(vis, state)
        b, c, h, w = fq.size()
        fq = self.decoder(fq, word, pad_mask)
        # print("decoder output shape : ", fq.shape)
        # 3. Get metric loss
        if metric_learning_flag:
            metric_loss = MetricLoss(fq, 2, alpha=self.positive_strength, args=self.cfg)

        fq = fq.reshape(b, c, h, w)

        # b, 1, 104, 104
        pred = self.proj(fq, state)

        if self.training:
            # resize mask
            if pred.shape[-2:] != target.shape[-2:]:
                target = F.interpolate(target, pred.shape[-2:],
                                       mode='nearest').detach()
            loss = F.binary_cross_entropy_with_logits(pred, target)
            # 4. if metric learning, add metric loss and normalize
            if metric_learning_flag:
                # print("CE loss : ", loss, "metric loss : ", metric_loss)
                loss = (loss + self.metric_loss_weight * metric_loss) / (1+self.metric_loss_weight)
            return pred.detach(), target, loss
        else:
            return pred.detach()
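MetricLoss and AngularMetricLoss above assume the doubled batch is interleaved so that rows 2i and 2i+1 hold the two text views of the same sample; positives are those pairs plus the diagonal, everything else is a negative, and the negative term is normalised by B^2 - 2B (or by B when args.div_batch is set). Note also that CRIS.forward reads self.eps, self.positive_strength and self.metric_loss_weight, which are commented out in __init__, so the metric-learning path would only run once those attributes are restored from the config. A quick standalone check of the assumed mask layout (sizes are illustrative only):

# Sanity check of the positive/negative pair masks assumed by MetricLoss above.
import torch

B_ = 6  # doubled batch, interleaved as (x0, x0', x1, x1', x2, x2')
positive_mask = torch.zeros(B_, B_)
for i in range(B_ // 2):
    positive_mask[2 * i, 2 * i + 1] = 1
    positive_mask[2 * i + 1, 2 * i] = 1
positive_mask.fill_diagonal_(1)
negative_mask = torch.ones(B_, B_) - positive_mask

assert positive_mask.sum() == 2 * B_             # B_ diagonal + B_ paired entries
assert negative_mask.sum() == B_ ** 2 - 2 * B_   # matches the normaliser in MetricLoss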
dianecy/VerbCentric-RIS/model_/segmenter_ang_nonoise_ddp.py
ADDED
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

from model.clip import build_model

from .layers import FPN, Projector, TransformerDecoder


def check_nan(x):
    """Check whether the tensor contains any NaN."""
    checker = False
    if torch.isnan(x).any():
        checker = True
    return checker


def zero_filtering(x):
    """
    Replace (near-)zero values with eps, because the metric is cosine similarity
    and an all-zero embedding makes it return NaN (works like torch.clamp()).
    """
    eps = 1e-4
    x[x <= eps] = eps
    return x


def nan_filtering(x, eps=1e-4):
    """
    Replace NaN values with eps, since cosine similarity would otherwise return NaN.
    """
    return torch.nan_to_num(x, nan=eps)


# def MetricLoss(embeddings, num_pos, alpha=0.5, args=None):
#     # embeddings: ((2*B), C, (H*W))
#     # n_pos : chunk size of positive pairs
#     # args: args
#     # returns: loss
#     metric_loss = 0
#     # flatten embeddings
#     B_, C, HW = embeddings.shape
#     emb = torch.mean(embeddings, dim=-1)               # (2*B, C)
#     emb_i = emb.unsqueeze(1).repeat(1, B_, 1)          # (2*B, 2*B, C)
#     emb_j = emb.unsqueeze(0).repeat(B_, 1, 1)          # (2*B, 2*B, C)
#     emb_distance = torch.norm(emb_i - emb_j, dim=-1)   # (2*B, 2*B)
#     assert torch.sum(torch.diag(emb_distance[:B_, :B_])) == 0, \
#         "Diagonals are not zero. please check the permutation on the batch"
#     # print("distance matrix : ", emb_distance)
#
#     # positive pairs and loss
#     positive_mask = torch.zeros_like(emb_distance)
#     for i in range(B_ // 2):
#         positive_mask[2*i, 2*i+1] = 1
#         positive_mask[2*i+1, 2*i] = 1
#     positive_mask.fill_diagonal_(1)
#     positive_loss = torch.sum(emb_distance * positive_mask) / B_
#
#     # negative pairs and loss
#     negative_mask = torch.ones_like(emb_distance) - positive_mask
#     negative_loss = -1.0 * torch.log(torch.sum(emb_distance * negative_mask) / (B_**2 - 2*B_))
#
#     # print(positive_mask, negative_mask)
#
#     metric_loss = alpha * positive_loss + (1-alpha) * negative_loss
#
#     return metric_loss
#
# def return_mask(emb_distance, nsent):
#     B_, B_ = emb_distance.shape
#     positive_mask = torch.zeros_like(emb_distance)
#     for i in range(B_ // nsent):
#         positive_mask[nsent*i, nsent*i+1] = 1
#         positive_mask[nsent*i+1, nsent*i] = 1
#     positive_mask.fill_diagonal_(1)
#     negative_mask = torch.ones_like(emb_distance) - positive_mask
#     return positive_mask, negative_mask
#
# def AngularMetricLoss(embeddings, num_pos, num_neg, alpha=0.5, args=None):
#     # embeddings: ((6*B), C, (H*W))
#     # n_pos : chunk size of positive pairs
#     # args: args
#     # returns: loss
#     geometric_loss = 0
#     nsent = num_pos + num_neg
#     assert nsent == 6, "number of sentences doesn't match"  # nsent : S
#
#     # flatten embeddings
#     B_, C, HW = embeddings.shape
#     emb = torch.mean(embeddings, dim=-1)               # (S*B, C)
#     emb_i = emb.unsqueeze(1).repeat(1, B_, 1)          # (S*B, S*B, C)
#     emb_j = emb.unsqueeze(0).repeat(B_, 1, 1)          # (S*B, S*B, C)
#
#     ## zero filtering
#     sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
#     sim_matrix = sim(emb_i, emb_j).reshape(B_, B_)     # (S*B, S*B)
#     sim_matrix = zero_filtering(sim_matrix)
#     if check_nan(sim_matrix):
#         sim_matrix = nan_filtering(sim_matrix)
#     sim_matrix = torch.clamp(sim_matrix, min=-0.999, max=0.999)
#     phi = torch.acos(sim_matrix)                       # (S*B, S*B)
#     phi[torch.isnan(phi)] = 0
#
#     # positive pairs and loss
#     positive_mask, negative_mask = return_mask(sim_matrix, nsent)
#     positive_loss = torch.sum((phi**2) * positive_mask) / B_
#
#     # negative pairs and loss
#     # negative_mask = torch.ones_like(sim_matrix) - positive_mask
#     phi_mask = phi < args.phi_threshold
#
#     negative_loss = (args.phi_threshold - phi)**2
#     negative_loss = zero_filtering(negative_loss)
#     if check_nan(negative_loss):
#         negative_loss = nan_filtering(negative_loss)
#
#     if args.div_batch:
#         negative_loss = torch.sum(negative_loss * negative_mask * phi_mask) / B_
#     else:
#         negative_loss = torch.sum(negative_loss * negative_mask * phi_mask) / (B_**2 - nsent*B_)
#     # print("pos loss, neg loss : ", positive_loss, negative_loss)
#
#     geometric_loss = alpha * positive_loss + (1-alpha) * negative_loss
#
#     return geometric_loss


class CRIS_Wo_Noise(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # Vision & Text Encoder
        clip_model = torch.jit.load(cfg.clip_pretrain,
                                    map_location="cpu").eval()
        self.backbone = build_model(clip_model.state_dict(), cfg.word_len).float()
        # Multi-Modal FPN
        self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # Decoder
        self.decoder = TransformerDecoder(num_layers=cfg.num_layers,
                                          d_model=cfg.vis_dim,
                                          nhead=cfg.num_head,
                                          dim_ffn=cfg.dim_ffn,
                                          dropout=cfg.dropout,
                                          return_intermediate=cfg.intermediate)
        # Projector
        self.proj = Projector(cfg.word_dim, cfg.vis_dim // 2, 3)
        self.metric_learning = cfg.metric_learning
        self.positive_strength = cfg.positive_strength
        self.metric_loss_weight = cfg.metric_loss_weight
        self.add_noise = cfg.add_noise
        self.eps = cfg.ptb_rate
        self.cfg = cfg
        # self.bn_fq = nn.BatchNorm2d(1024)

    def forward(self, image, text, target=None):
        '''
        img: b, 3, h, w
        word: b, words
        word_mask: b, words
        if self.metric_learning:
            word: b, 6, words
            word_mask: b, 6, words
        mask: b, 1, h, w
        '''
        metric_learning_flag = (self.metric_learning and self.training)
        add_noise_flag = self.add_noise
        # TODO : mixing option btw distance & angular loss
        mix_distance_angular = False
        metric_loss = 0
        # print("text shape : ", text.shape)
        if self.training:
            bt, nt, lt = text.size()
        else:
            nt = 1
            bt, lt = text.size()

        npos = 2
        nneg = nt - npos

        # 1. Resizing : if metric learning, each sample is expanded to its six sentences
        if metric_learning_flag:
            # print("image shape : ", image.shape)
            b, c, h, w = image.size()
            # duplicate image and segmentation mask
            if image is not None:
                image = torch.cat([image, image, image, image, image, image], dim=0)
                image = image.reshape(-1, b, c, h, w).transpose(0, 1).reshape(-1, c, h, w)
            if target is not None:
                target = torch.cat([target, target, target, target, target, target], dim=0)
                target = target.reshape(-1, b, 1, h, w).transpose(0, 1).reshape(-1, 1, h, w)

            if add_noise_flag:
                noise_mask = (text[:, 0, :] == text[:, 1, :])
                noise_mask = torch.all(noise_mask, dim=-1)
                noise_mask = noise_mask.unsqueeze(-1).repeat(1, 2).reshape(-1)  # 2*b_
                assert noise_mask.shape[0] == bt * npos, "noise mask shape should be 2*B_"

            text = text.reshape(bt * nt, lt)  # (bt*nt), l
            # print(image.shape, image.dtype, image.type())    # float32
            # print(target.shape, target.dtype, target.type())  # float32
            # print(noise_mask.dtype, noise_mask.type())        # bool

        # padding mask used in decoder
        pad_mask = torch.zeros_like(text).masked_fill_(text == 0, 1).bool()
        # print(pad_mask.dtype, pad_mask.type())

        # vis: C3 / C4 / C5
        # word: b, length, 1024
        # state: 6b, 1024
        vis = self.backbone.encode_image(image)
        word, state = self.backbone.encode_text(text)
        # print(vis.dtype, vis.type())
        # state = state.float()
        if check_nan(state):
            print('state has nan values')
            state = nan_filtering(state)
        # print(state)

        b_, d_ = state.size()
        assert b_ == word.size(0), "batch size of state and word should be same"

        # npos = 2, nneg = 4
        if (add_noise_flag and self.training):
            tmp_state = state.reshape(bt, nt, -1)
            pos_state = tmp_state[:, :npos, :].reshape(bt*npos, -1)
            neg_state = tmp_state[:, npos:, :].reshape(bt*nneg, -1)
            noise = torch.randn_like(pos_state) * self.eps
            pos_state_noisy = pos_state.clone()               # Clone pos_state to avoid in-place operations
            pos_state_noisy[noise_mask] += noise[noise_mask]  # Add noise where the mask is True
            new_state = torch.cat([pos_state_noisy, neg_state], dim=0)
        else:
            new_state = state.reshape(bt*nt, -1)

        # b, 512, 26, 26 (C4)
        a3, a4, a5 = vis

        fq, f5 = self.neck(vis, new_state)
        b, c, h, w = fq.size()
        fq = self.decoder(fq, word, pad_mask)

        metric_tensor = fq

        # # 3. Get metric loss
        # if metric_learning_flag:
        #     metric_loss = AngularMetricLoss(fq, npos, nneg, alpha=self.positive_strength, args=self.cfg)

        fq = fq.reshape(b, c, h, w)

        # b, 1, 104, 104
        pred = self.proj(fq, new_state)
        # print("pred shape : ", pred.shape, " fq shape : ", fq.shape, " new_state shape : ", new_state.shape)
        # breakpoint()

        if self.training:
            if pred.shape[-2:] != target.shape[-2:]:
                target = F.interpolate(target, pred.shape[-2:],
                                       mode='nearest').detach()
            # seunghoon : temporary fix, only matching the sizes for now
            b, _, h, w = pred.shape
            assert (pred.shape == target.shape), "pred shape and target shape should be same"
            pred = pred.reshape(-1, 6, h, w)[:, :2, :, :]
            # pred_neg = pred.reshape(-1, 6, h, w)[:, 2:, :, :]
            target = target.reshape(-1, 6, h, w)[:, :2, :, :]
            # target_neg = target.reshape(-1, 6, h, w)[:, 2:, :, :]

            CE_loss_pos = F.binary_cross_entropy_with_logits(pred, target)
            # loss_neg = nn.MSELoss()(pred_neg, torch.zeros_like(pred_neg))
            # loss = loss_pos + loss_neg/(target_neg.shape[-1])**2

            return pred.detach(), target, CE_loss_pos, metric_tensor
        else:
            # print(self.cfg.gpu, f"; loss = {loss}")
            return pred.detach()


## Original code
# if self.training:
#     if pred.shape[-2:] != target.shape[-2:]:
#         target = F.interpolate(target, pred.shape[-2:],
#                                mode='nearest').detach()
#     # seunghoon : temporary fix, only matching the sizes for now
#     b, _, h, w = pred.shape
#     assert (pred.shape == target.shape), "pred shape and target shape should be same"
#     pred = pred.reshape(-1, 6, h, w)[:, :2, :, :]
#     # pred_neg = pred.reshape(-1, 6, h, w)[:, 2:, :, :]
#     target = target.reshape(-1, 6, h, w)[:, :2, :, :]
#     # target_neg = target.reshape(-1, 6, h, w)[:, 2:, :, :]
#
#     CE_loss_pos = F.binary_cross_entropy_with_logits(pred, target)
#     # loss_neg = nn.MSELoss()(pred_neg, torch.zeros_like(pred_neg))
#     # loss = loss_pos + loss_neg/(target_neg.shape[-1])**2
#
#     # 4. if metric learning, add metric loss and normalize
#     if metric_learning_flag:
#         # print("CE loss : ", CE_loss_pos, "metric loss : ", metric_loss)
#         loss = (CE_loss_pos + self.metric_loss_weight * metric_loss) / (1 + self.metric_loss_weight)
#         # DDP error handling : if there is no negative (BS = 1 or 0 for some GPUs), \
#         # connect graph to avoid error
#         safety_loss = loss * 0.
#         loss = loss + safety_loss
#     # print(self.cfg.gpu, f"; loss = {loss}")
#     return pred.detach(), target, loss
# else:
#     # print(self.cfg.gpu, f"; loss = {loss}")
#     return pred.detach()
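The cat/reshape/transpose pattern in CRIS_Wo_Noise.forward is easy to misread, so here is a minimal, self-contained sketch (not part of the uploaded file; tensor sizes are invented) showing that it replicates each image six times while keeping all copies of the same sample contiguous, which the later reshape(-1, 6, h, w) relies on.

import torch

# Illustrative sketch only -- not part of the uploaded repository.
b, c, h, w = 2, 3, 4, 4
image = torch.arange(b).float().reshape(b, 1, 1, 1).expand(b, c, h, w)

rep = torch.cat([image] * 6, dim=0)                                     # (6*b, c, h, w), grouped by copy
rep = rep.reshape(-1, b, c, h, w).transpose(0, 1).reshape(-1, c, h, w)  # (6*b, c, h, w), grouped by sample

# Rows 0..5 all come from sample 0, rows 6..11 from sample 1.
assert torch.equal(rep[:6], image[0].expand(6, c, h, w))
assert torch.equal(rep[6:], image[1].expand(6, c, h, w))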
dianecy/VerbCentric-RIS/model_/segmenter_angular.py
ADDED
@@ -0,0 +1,163 @@
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

from model.clip import build_model

from .layers import FPN, Projector, TransformerDecoder


# def MetricLoss(embeddings, n_pos, alpha=0.5, args=None):
#     # embeddings: ((2*B), C, (H*W))
#     # n_pos : chunk size of positive pairs
#     # args: args
#     # returns: loss
#     metric_loss = 0
#
#     # flatten embeddings
#     B_, C, HW = embeddings.shape
#     emb = torch.mean(embeddings, dim=-1)               # (2*B, C)
#     emb_i = emb.unsqueeze(1).repeat(1, B_, 1)          # (2*B, 2*B, C)
#     emb_j = emb.unsqueeze(0).repeat(B_, 1, 1)          # (2*B, 2*B, C)
#     emb_distance = torch.norm(emb_i - emb_j, dim=-1)   # (2*B, 2*B)
#     assert torch.sum(torch.diag(emb_distance[:B_, :B_])) == 0, \
#         "Diagonals are not zero. please check the permutation on the batch"
#     # print("distance metrix : ", emb_distance)
#
#     # positive pairs and loss
#     positive_mask = torch.zeros_like(emb_distance)
#     for i in range(B_ // 2):
#         positive_mask[2*i, 2*i+1] = 1
#         positive_mask[2*i+1, 2*i] = 1
#     positive_mask.fill_diagonal_(1)
#     positive_loss = torch.sum(emb_distance * positive_mask) / B_
#
#     # negative pairs and loss
#     negative_mask = torch.ones_like(emb_distance) - positive_mask
#     negative_loss = -1.0 * torch.log(torch.sum(emb_distance * negative_mask) / (B_**2 - 2*B_))
#
#     # print(positive_mask, negative_mask)
#
#     metric_loss = alpha * positive_loss + (1-alpha) * negative_loss
#
#     return metric_loss


class CRIS_S(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # Vision & Text Encoder
        clip_model = torch.jit.load(cfg.clip_pretrain,
                                    map_location="cpu").eval()
        self.backbone = build_model(clip_model.state_dict(), cfg.word_len).float()
        # Multi-Modal FPN
        self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # Decoder
        self.decoder = TransformerDecoder(num_layers=cfg.num_layers,
                                          d_model=cfg.vis_dim,
                                          nhead=cfg.num_head,
                                          dim_ffn=cfg.dim_ffn,
                                          dropout=cfg.dropout,
                                          return_intermediate=cfg.intermediate)
        # Projector
        self.proj = Projector(cfg.word_dim, cfg.vis_dim // 2, 3)
        self.metric_learning = cfg.metric_learning
        self.positive_strength = cfg.positive_strength
        self.metric_loss_weight = cfg.metric_loss_weight
        self.eps = cfg.ptb_rate
        self.cfg = cfg

    def forward(self, image, text, target=None):
        '''
        img: b, 3, h, w
        word: b, words
        word_mask: b, words
        if self.metric_learning:
            word: b, 2, words
            word_mask: b, 2, words
        mask: b, 1, h, w
        '''
        metric_learning_flag = (self.metric_learning and self.training)
        # TODO : mixing option btw distance & angular loss
        mix_distance_angular = False
        metric_loss = 0

        # 1. Resizing : if metric learning, batch size of the word is doubled
        if metric_learning_flag:
            # print("image shape : ", image.shape)
            b, c, h, w = image.size()
            # duplicate image and segmentation mask
            if image is not None:
                image = torch.cat([image, image], dim=0)
                image = image.reshape(-1, b, c, h, w).transpose(0, 1).reshape(-1, c, h, w)
            if target is not None:
                target = torch.cat([target, target], dim=0)
                target = target.reshape(-1, b, 1, h, w).transpose(0, 1).reshape(-1, 1, h, w)
            # duplicate noise mask
            b_, n_, l_ = text.size()
            assert n_ == 2, "word size should be 2"
            noise_mask = (text[:, 0, :] == text[:, 1, :])
            noise_mask = torch.all(noise_mask, dim=-1)
            noise_mask = noise_mask.unsqueeze(-1).repeat(1, 2).reshape(-1)  # 2*b_
            assert noise_mask.shape[0] == b_ * 2, "noise mask shape should be 2*b_"
            text = text.reshape(b_ * 2, l_)  # 2*b, l

            # print("text shape : ", text.shape)
            # print("image shape : ", image.shape)
            # print("target shape : ", target.shape)
            # print(torch.sum(image[0::2]) == torch.sum(image[1::2]))
            # print(torch.sum(target[0::2]) == torch.sum(target[1::2]))

        # padding mask used in decoder
        pad_mask = torch.zeros_like(text).masked_fill_(text == 0, 1).bool()
        # vis: C3 / C4 / C5
        # word: b, length, 1024
        # state: b, 1024
        vis = self.backbone.encode_image(image)
        word, state = self.backbone.encode_text(text)

        b_, d_ = state.size()
        assert b_ == word.size(0), "batch size of state and word should be same"

        # 2. State Noising Step : if number of caption is 1,
        #    add noise to the corresponding indices
        if metric_learning_flag:
            noise = torch.randn_like(state) * self.eps
            state[noise_mask] = state[noise_mask] + noise[noise_mask]

        # b, 512, 26, 26 (C4)
        a3, a4, a5 = vis
        fq, f5 = self.neck(vis, state)
        b, c, h, w = fq.size()
        fq = self.decoder(fq, word, pad_mask)
        metric_tensor = fq
        # if metric_learning_flag:
        #     metric_loss = AngularMetricLoss(fq, 2, alpha=self.positive_strength, args=self.cfg)  # (1-self.positive_strength) *
        #     if mix_distance_angular:
        #         metric_loss += MetricLoss(fq, 2, alpha=self.positive_strength, args=self.cfg)  # self.positive_strength *
        fq = fq.reshape(b, c, h, w)

        # b, 1, 104, 104
        pred = self.proj(fq, state)

        if self.training:
            # resize mask
            if pred.shape[-2:] != target.shape[-2:]:
                target = F.interpolate(target, pred.shape[-2:],
                                       mode='nearest').detach()
            CE_loss = F.binary_cross_entropy_with_logits(pred, target)

            # 4. if metric learning, add metric loss and normalize
            # if metric_learning_flag:
            #     loss = (loss + self.metric_loss_weight * metric_loss) / (1 + self.metric_loss_weight)
            #     safety_loss = loss * 0.
            #     loss = loss + safety_loss

            return pred.detach(), target, CE_loss, metric_tensor
        else:
            # print(self.cfg.gpu, f"; loss = {loss}")
            return pred.detach()
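A quick sketch (not from the repository; token ids, dimensions, and the 1e-3 scale are invented, the model itself uses cfg.ptb_rate) of the noise-mask logic in CRIS_S.forward: samples whose two tokenized captions are identical are flagged, and only their duplicated text states receive Gaussian perturbation.

import torch

# Illustrative sketch only -- values are made up.
text = torch.tensor([
    [[5, 9, 0], [5, 9, 0]],   # sample 0: the caption is repeated -> gets noise
    [[5, 9, 0], [7, 2, 0]],   # sample 1: two distinct captions   -> left untouched
])                            # (b, 2, l)

noise_mask = torch.all(text[:, 0, :] == text[:, 1, :], dim=-1)   # (b,)
noise_mask = noise_mask.unsqueeze(-1).repeat(1, 2).reshape(-1)   # (2*b,)
print(noise_mask)                                                # tensor([ True,  True, False, False])

state = torch.zeros(4, 8)                                        # stand-in for the 2*b text states
noise = torch.randn_like(state) * 1e-3
state[noise_mask] = state[noise_mask] + noise[noise_mask]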
dianecy/VerbCentric-RIS/model_/segmenter_verbonly.py
ADDED
@@ -0,0 +1,178 @@
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

from model_.clip import build_model

from .layers import FPN, Projector, TransformerDecoder


class CRIS_PosOnly(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # Vision & Text Encoder
        clip_model = torch.jit.load(cfg.clip_pretrain,
                                    map_location="cpu").eval()
        self.backbone = build_model(clip_model.state_dict(), cfg.word_len, cfg.freeze).float()

        # Multi-Modal FPN
        self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # Decoder
        self.decoder = TransformerDecoder(num_layers=cfg.num_layers,
                                          d_model=cfg.vis_dim,
                                          nhead=cfg.num_head,
                                          dim_ffn=cfg.dim_ffn,
                                          dropout=cfg.dropout,
                                          return_intermediate=cfg.intermediate)
        # Projector
        self.proj = Projector(cfg.word_dim, cfg.vis_dim // 2, 3)
        self.metric_learning = False  # cfg.metric_learning
        self.metric_loss_weight = cfg.metric_loss_weight
        self.cfg = cfg

    def forward(self, image, text, target=None, verb=None):
        '''
        image: b, 3, h, w
        text: b, words
        target: b, 1, h, w
        verb: b, words (if applicable, only used in training mode for contrastive learning)
        '''
        sentences, images, targets, pad_masks = [], [], [], []

        if self.training:
            verb_masks = []
            cl_masks = []

            for idx in range(len(text)):
                sentences.append(text[idx])
                images.append(image[idx])
                targets.append(target[idx])
                pad_masks.append(torch.zeros_like(text[idx]).masked_fill_(text[idx] == 0, 1).bool())

                # If a verb phrase exists, process it as an extra sample
                if verb[idx].numel() > 0 and verb[idx].sum().item() > 0:
                    verb_masks.extend([1, 1])  # both the original sentence and the verb are marked
                    cl_masks.extend([1, 0])    # only the original sentence is marked
                    sentences.append(verb[idx])
                    images.append(image[idx])
                    targets.append(target[idx])
                    pad_masks.append(torch.zeros_like(verb[idx]).masked_fill_(verb[idx] == 0, 1).bool())
                else:
                    verb_masks.append(0)
                    cl_masks.append(1)

            sentences = torch.stack(sentences)
            images = torch.stack(images)
            targets = torch.stack(targets)
            pad_masks = torch.stack(pad_masks)
            verb_masks = torch.tensor(verb_masks, dtype=torch.bool)
            cl_masks = torch.tensor(cl_masks, dtype=torch.bool)
        else:
            sentences = text
            images = image
            targets = target
            pad_masks = torch.zeros_like(text).masked_fill_(text == 0, 1).bool()

        # Encoding images and text
        vis = self.backbone.encode_image(images)
        word, state = self.backbone.encode_text(sentences)

        # FPN neck and decoder
        fq, f5 = self.neck(vis, state)
        b, c, h, w = fq.size()
        fq = self.decoder(fq, word, pad_masks)
        metric_tensor = fq  # b, c, h*w
        fq = fq.reshape(b, c, h, w)

        # Final prediction
        pred = self.proj(fq, state)

        if self.training:
            if pred.shape[-2:] != targets.shape[-2:]:
                targets = F.interpolate(targets, pred.shape[-2:], mode='nearest').detach()
            loss = F.binary_cross_entropy_with_logits(pred[cl_masks], targets[cl_masks])

            if self.metric_learning:
                metric_loss = self.compute_metric_loss(metric_tensor, verb_masks, args=self.cfg)
                loss = (loss + self.metric_loss_weight * metric_loss) / (1 + self.metric_loss_weight)

            return pred[cl_masks].detach(), targets[cl_masks], loss

        return pred.detach()  # In eval mode, only return the predictions

    def compute_metric_loss(self, metric_tensor, positive_verbs, negative_verbs, args):
        if args.loss_option == "ACL_verbonly":
            metric_loss = self.UniAngularContrastLoss(metric_tensor, positive_verbs, negative_verbs,
                                                      m=args.margin_value, tau=args.temperature,
                                                      verbonly=True, args=args)
        elif args.loss_option == "ACL":
            metric_loss = self.UniAngularContrastLoss(metric_tensor, positive_verbs, negative_verbs,
                                                      m=args.margin_value, tau=args.temperature,
                                                      verbonly=False, args=args)
        return metric_loss

    def return_mask(self, emb_distance, verb_mask=None):
        B_, B_ = emb_distance.shape
        positive_mask = torch.zeros_like(emb_distance)
        positive_mask.fill_diagonal_(1)  # Set diagonal elements to 1 for all cases

        if B_ < len(verb_mask):
            # Here B_ equals 2*K (double the number of verb phrases)
            for i in range(B_ // 2):
                positive_mask[2 * i, 2 * i + 1] = 1
                positive_mask[2 * i + 1, 2 * i] = 1
        else:
            # Process the case where we have a mix of sentences with and without verbs
            i = 0
            while i < B_:
                if verb_mask[i] == 1:
                    positive_mask[i, i + 1] = 1
                    positive_mask[i + 1, i] = 1
                    i += 2
                else:
                    i += 1
        negative_mask = torch.ones_like(emb_distance) - positive_mask
        return positive_mask, negative_mask

    def UniAngularContrastLoss(self, total_fq, verb_mask, alpha=0.5, verbonly=True, m=0.5, tau=0.05, args=None):
        _, C, HW = total_fq.shape

        if verbonly:
            emb = torch.mean(total_fq[verb_mask], dim=-1)
            assert emb.shape[0] % 2 == 0, f"Embedding count {emb.shape[0]} is not divisible by 2."
        else:
            emb = torch.mean(total_fq, dim=-1)

        B_ = emb.shape[0]
        # emb = F.normalize(emb, p=2, dim=1)
        emb_i = emb.unsqueeze(1).repeat(1, B_, 1)  # (B_, B_, C)
        emb_j = emb.unsqueeze(0).repeat(B_, 1, 1)  # (B_, B_, C)
        sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
        sim_matrix = sim(emb_i, emb_j).reshape(B_, B_)  # (B_, B_)
        sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)

        positive_mask, negative_mask = self.return_mask(sim_matrix, verb_mask)

        # Apply margin to positive pairs
        sim_matrix_with_margin = sim_matrix.clone()
        sim_matrix_with_margin[positive_mask.bool()] = torch.cos(torch.acos(sim_matrix[positive_mask.bool()]) + m / 57.2958)

        # Scale logits with temperature
        logits = sim_matrix_with_margin / tau

        # Compute the softmax loss for all pairs
        exp_logits = torch.exp(logits)
        pos_exp_logits = exp_logits[positive_mask.bool()]
        total_exp_logits = exp_logits.sum(dim=-1)

        # Compute the final loss: L_arc = -log(e^(cos(theta + m)/tau) / sum(e^(cos(theta)/tau)))
        # (expand the row-wise denominator so it lines up with the selected positive entries)
        row_totals = total_exp_logits.unsqueeze(1).expand_as(exp_logits)
        positive_loss = -torch.log(pos_exp_logits / row_totals[positive_mask.bool()])
        angular_loss = positive_loss.mean()

        return angular_loss
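The margin applied to positive pairs in UniAngularContrastLoss is ArcFace-style: m defaults to 0.5 and is divided by 57.2958 (roughly 180/pi), i.e. it is interpreted in degrees, added to the angle between a positive pair, and the cosine is re-taken before temperature scaling. A small sketch of that step (not part of the uploaded file; the similarity values are made up):

import torch

# Illustrative sketch only -- similarity values are made up.
s = torch.tensor([0.90, 0.50, 0.00])               # cosine similarities of positive pairs
m = 0.5                                            # margin in degrees (default of UniAngularContrastLoss)
tau = 0.05                                         # temperature (default of UniAngularContrastLoss)

s_margin = torch.cos(torch.acos(s) + m / 57.2958)  # add the margin in angle space
assert torch.all(s_margin < s)                     # a positive margin always lowers the cosine

logits = s_margin / tau                            # temperature-scaled logits fed to the softmax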
dianecy/VerbCentric-RIS/model_/segmenter_verbonly_fin.py
ADDED
@@ -0,0 +1,179 @@
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

from model_.clip import build_model

from .layers import FPN, Projector, TransformerDecoder


class CRIS_PosOnly_rev(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # Vision & Text Encoder
        clip_model = torch.jit.load(cfg.clip_pretrain,
                                    map_location="cpu").eval()
        self.backbone = build_model(clip_model.state_dict(), cfg.word_len, cfg.freeze).float()

        # Multi-Modal FPN
        self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # Decoder
        self.decoder = TransformerDecoder(num_layers=cfg.num_layers,
                                          d_model=cfg.vis_dim,
                                          nhead=cfg.num_head,
                                          dim_ffn=cfg.dim_ffn,
                                          dropout=cfg.dropout,
                                          return_intermediate=cfg.intermediate)
        # Projector
        self.proj = Projector(cfg.word_dim, cfg.vis_dim // 2, 3)
        self.metric_learning = False  # cfg.metric_learning
        self.metric_loss_weight = cfg.metric_loss_weight
        self.cfg = cfg

    def forward(self, image, text, target=None, verb=None):
        '''
        image: b, 3, h, w
        text: b, words
        target: b, 1, h, w
        verb: b, words (if applicable, only used in training mode for contrastive learning)
        '''
        sentences, images, targets, pad_masks = [], [], [], []

        if self.training:
            verb_masks = []
            cl_masks = []

            for idx in range(len(text)):
                sentences.append(text[idx])
                images.append(image[idx])
                targets.append(target[idx])
                pad_masks.append(torch.zeros_like(text[idx]).masked_fill_(text[idx] == 0, 1).bool())

                # If a verb phrase exists, process it as an extra sample
                if verb[idx].numel() > 0 and verb[idx].sum().item() > 0:
                    verb_masks.extend([1, 1])  # both the original sentence and the verb are marked
                    cl_masks.extend([1, 0])    # only the original sentence is marked
                    sentences.append(verb[idx])
                    images.append(image[idx])
                    targets.append(target[idx])
                    pad_masks.append(torch.zeros_like(verb[idx]).masked_fill_(verb[idx] == 0, 1).bool())
                else:
                    verb_masks.append(0)
                    cl_masks.append(1)

            sentences = torch.stack(sentences)
            images = torch.stack(images)
            targets = torch.stack(targets)
            pad_masks = torch.stack(pad_masks)
            verb_masks = torch.tensor(verb_masks, dtype=torch.bool)
            cl_masks = torch.tensor(cl_masks, dtype=torch.bool)
        else:
            sentences = text
            images = image
            targets = target
            pad_masks = torch.zeros_like(text).masked_fill_(text == 0, 1).bool()

        # Encoding images and text
        vis = self.backbone.encode_image(images)
        word, state = self.backbone.encode_text(sentences)

        # FPN neck and decoder
        fq, f5 = self.neck(vis, state)
        b, c, h, w = fq.size()
        fq = self.decoder(fq, word, pad_masks)
        metric_tensor = fq  # b, c, h*w
        fq = fq.reshape(b, c, h, w)

        # Final prediction
        pred = self.proj(fq, state)

        if self.training:
            if pred.shape[-2:] != targets.shape[-2:]:
                targets = F.interpolate(targets, pred.shape[-2:], mode='nearest').detach()

            loss = F.binary_cross_entropy_with_logits(pred[cl_masks], targets[cl_masks])

            if self.metric_learning:
                metric_loss = self.compute_metric_loss(metric_tensor, verb_masks, args=self.cfg)
                loss = (loss + self.metric_loss_weight * metric_loss) / (1 + self.metric_loss_weight)

            return pred[cl_masks].detach(), targets[cl_masks], loss

        return pred.detach()  # In eval mode, only return the predictions

    def compute_metric_loss(self, metric_tensor, positive_verbs, negative_verbs, args):
        if args.loss_option == "ACL_verbonly":
            metric_loss = self.UniAngularContrastLoss(metric_tensor, positive_verbs, negative_verbs,
                                                      m=args.margin_value, tau=args.temperature,
                                                      verbonly=True, args=args)
        elif args.loss_option == "ACL":
            metric_loss = self.UniAngularContrastLoss(metric_tensor, positive_verbs, negative_verbs,
                                                      m=args.margin_value, tau=args.temperature,
                                                      verbonly=False, args=args)
        return metric_loss

    def return_mask(self, emb_distance, verb_mask=None):
        B_, B_ = emb_distance.shape
        positive_mask = torch.zeros_like(emb_distance)
        positive_mask.fill_diagonal_(1)  # Set diagonal elements to 1 for all cases

        if B_ < len(verb_mask):
            # Here B_ equals 2*K (double the number of verb phrases)
            for i in range(B_ // 2):
                positive_mask[2 * i, 2 * i + 1] = 1
                positive_mask[2 * i + 1, 2 * i] = 1
        else:
            # Process the case where we have a mix of sentences with and without verbs
            i = 0
            while i < B_:
                if verb_mask[i] == 1:
                    positive_mask[i, i + 1] = 1
                    positive_mask[i + 1, i] = 1
                    i += 2
                else:
                    i += 1
        negative_mask = torch.ones_like(emb_distance) - positive_mask
        return positive_mask, negative_mask

    def UniAngularContrastLoss(self, total_fq, verb_mask, alpha=0.5, verbonly=True, m=0.5, tau=0.05, args=None):
        _, C, HW = total_fq.shape

        if verbonly:
            emb = torch.mean(total_fq[verb_mask], dim=-1)
            assert emb.shape[0] % 2 == 0, f"Embedding count {emb.shape[0]} is not divisible by 2."
        else:
            emb = torch.mean(total_fq, dim=-1)

        B_ = emb.shape[0]
        # emb = F.normalize(emb, p=2, dim=1)
        emb_i = emb.unsqueeze(1).repeat(1, B_, 1)  # (B_, B_, C)
        emb_j = emb.unsqueeze(0).repeat(B_, 1, 1)  # (B_, B_, C)
        sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
        sim_matrix = sim(emb_i, emb_j).reshape(B_, B_)  # (B_, B_)
        sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)

        positive_mask, negative_mask = self.return_mask(sim_matrix, verb_mask)

        # Apply margin to positive pairs
        sim_matrix_with_margin = sim_matrix.clone()
        sim_matrix_with_margin[positive_mask.bool()] = torch.cos(torch.acos(sim_matrix[positive_mask.bool()]) + m / 57.2958)

        # Scale logits with temperature
        logits = sim_matrix_with_margin / tau

        # Compute the softmax loss for all pairs
        exp_logits = torch.exp(logits)
        pos_exp_logits = exp_logits[positive_mask.bool()]
        total_exp_logits = exp_logits.sum(dim=-1)

        # Compute the final loss: L_arc = -log(e^(cos(theta + m)/tau) / sum(e^(cos(theta)/tau)))
        # (expand the row-wise denominator so it lines up with the selected positive entries)
        row_totals = total_exp_logits.unsqueeze(1).expand_as(exp_logits)
        positive_loss = -torch.log(pos_exp_logits / row_totals[positive_mask.bool()])
        angular_loss = positive_loss.mean()

        return angular_loss
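For reference, a toy run (not from the repository; the batch layout is invented) of the pairing logic used by return_mask when verbonly=True: only the entries selected by verb_mask survive, so consecutive rows (0, 1), (2, 3), ... of the masked similarity matrix are (original sentence, verb phrase) pairs.

import torch

# Illustrative sketch only -- a 3-sample batch where the last sample has no verb phrase.
verb_mask = torch.tensor([1, 1, 1, 1, 0], dtype=torch.bool)  # flattened (sentence, verb, sentence, verb, sentence)
B_ = int(verb_mask.sum())                                    # 4 embeddings survive the mask
sim = torch.rand(B_, B_)

positive_mask = torch.zeros_like(sim)
positive_mask.fill_diagonal_(1)
if B_ < len(verb_mask):                  # the same branch return_mask takes here
    for i in range(B_ // 2):
        positive_mask[2 * i, 2 * i + 1] = 1
        positive_mask[2 * i + 1, 2 * i] = 1
negative_mask = torch.ones_like(sim) - positive_mask

print(positive_mask)
# tensor([[1., 1., 0., 0.],
#         [1., 1., 0., 0.],
#         [0., 0., 1., 1.],
#         [0., 0., 1., 1.]])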
dianecy/VerbCentric-RIS/model_/segmenter_verbonly_hardneg.py
ADDED
@@ -0,0 +1,259 @@
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

from model_.clip import build_model

from .layers import FPN, Projector, TransformerDecoder


class CRIS_VerbOnly(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # Vision & Text Encoder
        clip_model = torch.jit.load(cfg.clip_pretrain,
                                    map_location="cpu").eval()
        self.backbone = build_model(clip_model.state_dict(), cfg.word_len, cfg.freeze).float()
        # Multi-Modal FPN
        self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # Decoder
        self.decoder = TransformerDecoder(num_layers=cfg.num_layers,
                                          d_model=cfg.vis_dim,
                                          nhead=cfg.num_head,
                                          dim_ffn=cfg.dim_ffn,
                                          dropout=cfg.dropout,
                                          return_intermediate=cfg.intermediate)
        # Projector
        self.proj = Projector(cfg.word_dim, cfg.vis_dim // 2, 3)
        self.metric_learning = False  # cfg.metric_learning
        # self.positive_strength = cfg.positive_strength
        self.metric_loss_weight = cfg.metric_loss_weight
        # self.eps = cfg.ptb_rate
        self.cfg = cfg

    def forward(self, image, text, target=None, hardpos=None, hardneg=None):
        '''
        image: b, 3, h, w
        text: b, words
        target: b, 1, h, w
        hardpos / hardneg: b, words (if applicable, only used in training mode for contrastive learning)
        '''
        if self.training:
            sentences, images, targets, pad_masks = [], [], [], []
            posverb_mask, negverb_mask = [], []
            cl_masks = []

            for idx in range(len(text)):
                sentences.append(text[idx])
                images.append(image[idx])
                targets.append(target[idx])
                pad_masks.append(torch.zeros_like(text[idx]).masked_fill_(text[idx] == 0, 1).bool())

                if hardpos[idx].numel() > 0 and hardpos[idx].sum().item() > 0:
                    # a hard positive exists, check whether a hard negative exists too
                    if hardneg[idx].numel() > 0 and hardneg[idx].sum().item() > 0:
                        # both a hard positive and a hard negative exist
                        posverb_mask.extend([1, 1, 0])  # mark original and hard positive as 1, negative as 0
                        negverb_mask.extend([0, 0, 1])  # mark only the negative as 1

                        if not self.cfg.hn_celoss:
                            cl_masks.extend([1, 0, 0])  # mark only the original as 1
                        else:
                            cl_masks.extend([1, 0, 1])
                        sentences.extend([hardpos[idx], hardneg[idx]])
                        images.extend([image[idx], image[idx]])
                        # the hard-negative caption is supervised with an all-zero mask
                        targets.extend([target[idx], torch.zeros_like(target[idx])])
                        pad_masks.extend([
                            torch.zeros_like(hardpos[idx]).masked_fill_(hardpos[idx] == 0, 1).bool(),
                            torch.zeros_like(hardneg[idx]).masked_fill_(hardneg[idx] == 0, 1).bool()
                        ])
                    else:
                        # only a hard positive exists, no negatives
                        posverb_mask.extend([1, 1])
                        negverb_mask.extend([0, 0])
                        cl_masks.extend([1, 0])

                        sentences.append(hardpos[idx])
                        images.append(image[idx])
                        targets.append(target[idx])
                        pad_masks.append(torch.zeros_like(hardpos[idx]).masked_fill_(hardpos[idx] == 0, 1).bool())
                else:
                    # no hard positive, no hard negative; only the original sentence itself
                    posverb_mask.append(0)
                    negverb_mask.append(0)
                    cl_masks.append(1)

            sentences = torch.stack(sentences)
            images = torch.stack(images)
            targets = torch.stack(targets)
            pad_masks = torch.stack(pad_masks)
            posverb_mask = torch.tensor(posverb_mask, dtype=torch.bool)
            negverb_mask = torch.tensor(negverb_mask, dtype=torch.bool)
            cl_masks = torch.tensor(cl_masks, dtype=torch.bool)
        else:
            sentences = text
            images = image
            targets = target
            pad_masks = torch.zeros_like(text).masked_fill_(text == 0, 1).bool()

        # Encoding images and text
        vis = self.backbone.encode_image(images)
        word, state = self.backbone.encode_text(sentences)

        # FPN neck and decoder
        fq, f5 = self.neck(vis, state)
        b, c, h, w = fq.size()
        fq = self.decoder(fq, word, pad_masks)
        metric_tensor = fq  # b, c, h*w
        fq = fq.reshape(b, c, h, w)

        # Final prediction
        pred = self.proj(fq, state)

        if self.training:
            if pred.shape[-2:] != targets.shape[-2:]:
                targets = F.interpolate(targets, pred.shape[-2:], mode='nearest').detach()

            loss = F.binary_cross_entropy_with_logits(pred[cl_masks], targets[cl_masks])

            if self.metric_learning:
                metric_loss = self.compute_metric_loss(metric_tensor, posverb_mask, negverb_mask, args=self.cfg)
                loss = (loss + self.metric_loss_weight * metric_loss) / (1 + self.metric_loss_weight)

            return pred.detach(), targets, loss

        return pred.detach()

    def compute_metric_loss(self, metric_tensor, positive_verbs, negative_verbs, args):
        if args.loss_option == "ACL_verbonly":
            metric_loss = self.UniAngularContrastLoss(metric_tensor, positive_verbs, negative_verbs,
                                                      m=args.margin_value, tau=args.temperature,
                                                      verbonly=True, args=args)
        elif args.loss_option == "ACL":
            metric_loss = self.UniAngularContrastLoss(metric_tensor, positive_verbs, negative_verbs,
                                                      m=args.margin_value, tau=args.temperature,
                                                      verbonly=False, args=args)
        return metric_loss

    def return_mask(self, emb_distance, positive_verbs, negative_verbs, verb_mask):
        B_, B_ = emb_distance.shape
        positive_mask = torch.zeros_like(emb_distance)
        negative_mask = torch.ones_like(emb_distance)
        hard_negative_mask = torch.zeros_like(emb_distance)
        positive_mask.fill_diagonal_(1)

        if B_ < len(verb_mask):
            # Consider only verbs that pass the verb_mask filter
            positive_verbs = torch.tensor(positive_verbs)[verb_mask]
            negative_verbs = torch.tensor(negative_verbs)[verb_mask]

            # Exclude hard negatives from both masks (diagonal)
            for i in range(B_):
                if negative_verbs[i] == 1:
                    positive_mask[i, i] = 0
                    negative_mask[i, i] = 0
                    # Set the entire row and column for the hard negative, except the diagonal
                    hard_negative_mask[i, :] = 1  # Mark the i-th row
                    hard_negative_mask[:, i] = 1  # Mark the i-th column
                    hard_negative_mask[i, i] = 0  # Ensure diagonal element (i, i) is 0

            i = 0
            while i < B_:
                if positive_verbs[i] == 1:
                    if i + 1 < B_ and positive_verbs[i + 1] == 1:
                        positive_mask[i, i + 1] = 1
                        positive_mask[i + 1, i] = 1
                        i += 2
                else:
                    i += 1
        else:
            # Exclude hard negatives from both masks (diagonal)
            for i in range(B_):
                if negative_verbs[i] == 1:
                    positive_mask[i, i] = 0
                    negative_mask[i, i] = 0
                    # Set the entire row and column for the hard negative, except the diagonal
                    hard_negative_mask[i, :] = 1  # Mark the i-th row
                    hard_negative_mask[:, i] = 1  # Mark the i-th column
                    hard_negative_mask[i, i] = 0  # Ensure diagonal element (i, i) is 0

            # Apply the positive pairs logic similarly as above
            i = 0
            while i < B_:
                if positive_verbs[i] == 1 and i + 1 < B_ and positive_verbs[i + 1] == 1:
                    positive_mask[i, i + 1] = 1
                    positive_mask[i + 1, i] = 1
                    i += 2
                else:
                    i += 1

        negative_mask = negative_mask - positive_mask
        negative_mask[hard_negative_mask.bool()] = 0  # Set hard negative indices to 0 in negative_mask
        return positive_mask, negative_mask, hard_negative_mask

    def UniAngularContrastLoss(self, total_fq, positive_verbs, negative_verbs, m=0.5, tau=0.05, verbonly=True, args=None):
        """
        Angular Margin Contrastive Loss function with mask visualization.
        """
        verb_mask = positive_verbs + negative_verbs

        if verbonly:
            emb = torch.mean(total_fq[verb_mask], dim=-1)
        else:
            emb = torch.mean(total_fq, dim=-1)  # (B, C)

        B_ = emb.shape[0]
        # emb = F.normalize(emb, p=2, dim=1)
        emb_i = emb.unsqueeze(1).repeat(1, B_, 1)  # (B_, B_, C)
        emb_j = emb.unsqueeze(0).repeat(B_, 1, 1)  # (B_, B_, C)
        sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
        sim_matrix = sim(emb_i, emb_j).reshape(B_, B_)  # (B_, B_)
        sim_matrix = torch.clamp(sim_matrix, min=-1 + 1e-10, max=1 - 1e-10)

        # ranking based loss prep
        '''
        l2_dist = torch.cdist(emb, emb, p=2)  # i -> j distances
        KLD_
        ranking_per_i = get_ranking()  # load the hardness of the i-th instance from somewhere
        '''

        # Apply angular margin for positive pairs using return_mask
        positive_mask, negative_mask, hard_negative_mask = self.return_mask(sim_matrix, positive_verbs, negative_verbs, verb_mask)
        assert positive_mask.shape == sim_matrix.shape, f"Positive mask shape {positive_mask.shape} does not match sim_matrix shape {sim_matrix.shape}"
        print(f"Positive mask: {positive_mask}")
        print(f"Negative mask: {negative_mask}")
        print(f"Hard negative mask: {hard_negative_mask}")

        # Apply margin to positive pairs
        sim_matrix_with_margin = sim_matrix.clone()
        sim_matrix_with_margin[positive_mask.bool()] = torch.cos(torch.acos(sim_matrix[positive_mask.bool()]) + m / 57.2958)

        # Scale logits with temperature
        logits = sim_matrix_with_margin / tau

        # Compute the softmax loss for all pairs
        exp_logits = torch.exp(logits)

        pos_exp_logits = exp_logits[positive_mask.bool()]
        neg_exp_logits = exp_logits[negative_mask.bool()]
        hardneg_exp_logits = exp_logits[hard_negative_mask.bool()]

        # total_exp_logits = exp_logits.sum(dim=-1)
        total_exp_logits = (
            pos_exp_logits.sum(dim=-1) +
            neg_exp_logits.sum(dim=-1) +
            (hardneg_exp_logits.sum(dim=-1) * args.acl_hn_weight)
        )

        # Compute the final loss: L_arc = -log(e^(cos(theta + m)/tau) / sum(e^(cos(theta)/tau)))
        positive_loss = -torch.log(pos_exp_logits / total_exp_logits)
        angular_loss = positive_loss.mean()

        return angular_loss
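As a closing note on the hard-negative variant: the softmax denominator in CRIS_VerbOnly.UniAngularContrastLoss is a single global sum in which the hard-negative terms are re-weighted by cfg.acl_hn_weight. A tiny numeric sketch of that weighting (values invented, not from the repository):

import torch

# Illustrative sketch only -- exp-logit values and the weight are made up.
pos_exp = torch.tensor([2.0, 1.5])        # exp-logits at positive-mask positions
neg_exp = torch.tensor([0.5, 0.7, 0.6])   # exp-logits at ordinary negative positions
hn_exp = torch.tensor([1.2])              # exp-logits at hard-negative positions
acl_hn_weight = 2.0                       # hypothetical value for cfg.acl_hn_weight

denom = pos_exp.sum() + neg_exp.sum() + hn_exp.sum() * acl_hn_weight   # 3.5 + 1.8 + 2.4 = 7.7
loss = (-torch.log(pos_exp / denom)).mean()                            # ~1.49
print(float(denom), float(loss))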