import random
import torch
import torch.nn as nn
import torch.nn.functional as F

from model.clip import build_model
from .layers import FPN, Projector, TransformerDecoder


def MetricLoss(embeddings, n_pos, alpha=0.5, args=None):
    # embeddings: (2*B, C, H*W), positive pairs interleaved along the batch as (2i, 2i+1)
    # n_pos: chunk size of positive pairs (currently unused; the (2i, 2i+1) pairing is assumed)
    # alpha: weight between the positive and negative terms
    # args: config namespace (uses args.div_batch)
    # returns: scalar metric loss
    metric_loss = 0

    # flatten embeddings
    B_, C, HW = embeddings.shape
    emb = torch.mean(embeddings, dim=-1) # (2*B, C)
    emb_i = emb.unsqueeze(1).repeat(1, B_, 1) # (2*B, 2*B, C)
    emb_j = emb.unsqueeze(0).repeat(B_, 1, 1) # (2*B, 2*B, C)
    emb_distance = torch.norm(emb_i - emb_j, dim=-1) # (2*B, 2*B)
    assert torch.sum(torch.diag(emb_distance)) == 0, \
        "Diagonals are not zero. Please check the pairing order of the batch"
    # print("distance matrix : ", emb_distance)

    # positive pairs and loss
    positive_mask = torch.zeros_like(emb_distance)
    for i in range(B_//2):
        positive_mask[2*i, 2*i+1] = 1
        positive_mask[2*i+1, 2*i] = 1
    positive_mask.fill_diagonal_(1)
    positive_loss = torch.sum(emb_distance * positive_mask) / B_

    # negative pairs and loss
    negative_mask = torch.ones_like(emb_distance) - positive_mask

    if args.div_batch:
        negative_loss = -1.0 * torch.log(torch.sum(emb_distance * negative_mask) / B_)
    else:
        negative_loss = -1.0 * torch.log(torch.sum(emb_distance * negative_mask) / (B_**2 - 2*B_))

    # print(positive_mask, negative_mask)

    metric_loss = alpha * positive_loss + (1-alpha) * negative_loss

    return metric_loss
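
# Note (summary of MetricLoss above): with M = 2*B mean-pooled embeddings and
# pairwise Euclidean distances d(i, j), the loss computed above is
#   L = alpha * (1/M) * sum_{positive (i, j)} d(i, j)
#       - (1 - alpha) * log( sum_{negative (i, j)} d(i, j) / D )
# where D = M if args.div_batch is set and M * (M - 2) otherwise, i.e. positive
# pairs are pulled together while the (log) average negative distance is pushed up.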


def AngularMetricLoss(embeddings, n_pos, alpha=0.5, args=None):
    # embeddings: (2*B, C, H*W), positive pairs interleaved along the batch as (2i, 2i+1)
    # n_pos: chunk size of positive pairs (currently unused; the (2i, 2i+1) pairing is assumed)
    # alpha: weight between the positive and negative terms
    # args: config namespace (uses args.phi_threshold)
    # returns: scalar angular (geodesic) metric loss
    geometric_loss = 0

    # flatten embeddings
    B_, C, HW = embeddings.shape
    emb = torch.mean(embeddings, dim=-1) # (2*B, C)
    emb_i = emb.unsqueeze(1).repeat(1, B_, 1) # (2*B, 2*B, C)
    emb_j = emb.unsqueeze(0).repeat(B_, 1, 1) # (2*B, 2*B, C)
    sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
    sim_matrix = sim(emb_i, emb_j).reshape(B_, B_) # (2*B, 2*B)
    assert abs(torch.trace(sim_matrix).item() - B_) < 1e-4, \
        "Similarity diagonals are not one. Please check the pairing order of the batch"
    # print("similarity matrix : ", sim_matrix)
    # clamp into the open interval (-1, 1) so acos stays finite under float error
    sim_matrix = sim_matrix.clamp(-1.0 + 1e-7, 1.0 - 1e-7)
    phi = torch.acos(sim_matrix) # pairwise angles, (2*B, 2*B)
    # print("phi matrix : ", phi)

    # positive pairs and loss
    positive_mask = torch.zeros_like(sim_matrix)
    for i in range(B_//2):
        positive_mask[2*i, 2*i+1] = 1
        positive_mask[2*i+1, 2*i] = 1
    positive_mask.fill_diagonal_(1)
    positive_loss = torch.sum((phi**2) * positive_mask) / B_

    # negative pairs and loss
    negative_mask = torch.ones_like(sim_matrix) - positive_mask
    phi_mask = phi < args.phi_threshold  # only penalize negatives inside the angular margin
    negative_loss = (args.phi_threshold - phi)**2
    negative_loss = torch.sum(negative_loss * negative_mask * phi_mask) / (B_**2 - 2*B_)

    # print("pos loss, neg loss : ", positive_loss, negative_loss)

    geometric_loss = alpha * positive_loss + (1-alpha) * negative_loss

    return geometric_loss
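
# Note (summary of AngularMetricLoss above): with pairwise angles
# phi(i, j) = arccos(cos_sim(e_i, e_j)) and margin t = args.phi_threshold,
# the loss computed above is
#   L = alpha * (1/M) * sum_{positive (i, j)} phi(i, j)^2
#       + (1 - alpha) * (1 / (M * (M - 2))) * sum_{negative (i, j) with phi < t} (t - phi(i, j))^2
# i.e. positive pairs are pulled toward zero angle while negatives are pushed out of the margin t.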


class CRIS(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # Vision & Text Encoder
        clip_model = torch.jit.load(cfg.clip_pretrain,
                                    map_location="cpu").eval()
        self.backbone = build_model(clip_model.state_dict(), cfg.word_len).float()
        # Multi-Modal FPN
        self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # Decoder
        self.decoder = TransformerDecoder(num_layers=cfg.num_layers,
                                            d_model=cfg.vis_dim,
                                            nhead=cfg.num_head,
                                            dim_ffn=cfg.dim_ffn,
                                            dropout=cfg.dropout,
                                            return_intermediate=cfg.intermediate)
        # Projector
        self.proj = Projector(cfg.word_dim, cfg.vis_dim // 2, 3)
        self.metric_learning = cfg.metric_learning
        self.positive_strength = cfg.positive_strength
        self.metric_loss_weight = cfg.metric_loss_weight
        self.eps = cfg.ptb_rate
        self.cfg = cfg

    def forward(self, image, text, target=None):
        '''
            image: b, 3, h, w
            text: b, words (token ids)
                  if self.metric_learning: b, 2, words (two captions per sample)
            target: b, 1, h, w (segmentation mask)
        '''
        metric_learning_flag = (self.metric_learning and self.training)
        metric_loss = 0

        # 1. Resizing step: if metric learning, images and targets are duplicated and
        # the two captions per sample are flattened, so the effective batch size doubles
        if metric_learning_flag:
            #print("image shape : ", image.shape)
            b, c, h, w = image.size()
            # duplicate image and segmentation mask
            if image is not None:
                image = torch.cat([image, image], dim=0)
                image = image.reshape(-1, b, c, h, w).transpose(0, 1).reshape(-1, c, h, w)
            if target is not None:
                target = torch.cat([target, target], dim=0)
                target = target.reshape(-1, b, 1, h, w).transpose(0, 1).reshape(-1, 1, h, w)
            # build the noise mask: mark samples whose two captions are identical
            b_, n_, l_ = text.size()
            assert n_ == 2, "two captions per sample are expected"
            noise_mask = (text[:, 0, :] == text[:, 1, :])
            noise_mask = torch.all(noise_mask, dim=-1)
            noise_mask = noise_mask.unsqueeze(-1).repeat(1, 2).reshape(-1) # (2*b_,)
            assert noise_mask.shape[0] == b_ * 2, "noise mask shape should be 2*b_"
            text = text.reshape(b_ * 2, l_) # 2*b, l
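            # after this block the batch is interleaved so that rows (2i, 2i+1) hold the
            # same image/target with its two captions; MetricLoss/AngularMetricLoss above
            # rely on exactly this (2i, 2i+1) pairing when building their positive masks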

        # print("text shape : ", text.shape)
        # print("image shape : ", image.shape)
        # print("target shape : ", target.shape)
        # print(torch.sum(image[0::2]) == torch.sum(image[1::2]))
        # print(torch.sum(target[0::2]) == torch.sum(target[1::2]))
        
        # padding mask used in decoder
        pad_mask = torch.zeros_like(text).masked_fill_(text == 0, 1).bool()
        # vis: C3 / C4 / C5
        # word: b, length, 1024
        # state: b, 1024
        vis = self.backbone.encode_image(image)
        word, state = self.backbone.encode_text(text)

        b_, d_ = state.size()
        assert b_ == word.size(0), "batch size of state and word should be same"


        # 2. State noising step: if a sample has only one distinct caption (its two
        # copies are identical), perturb the duplicated sentence states with noise
        if metric_learning_flag :
            noise = torch.randn_like(state) * self.eps
            state[noise_mask] = state[noise_mask] + noise[noise_mask]
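            # only samples whose two captions are identical are perturbed, so their two
            # duplicated sentence states still differ and form a non-degenerate positive pair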

        # print("shape of word, state : ", word.shape, state.shape)

        # b, 512, 26, 26 (C4)
        a3, a4, a5 = vis
        # print("vis shape in model " , a3.shape, a4.shape, a5.shape)
        fq, f5 = self.neck(vis, state)
        b, c, h, w = fq.size()
        fq = self.decoder(fq, word, pad_mask)
        # print("decoder output shape : ", fq.shape)
        # 3. Get metric loss
        if metric_learning_flag:
            metric_loss = MetricLoss(fq, 2, alpha=self.positive_strength, args=self.cfg)
            
        fq = fq.reshape(b, c, h, w)

        # b, 1, 104, 104
        pred = self.proj(fq, state)

        if self.training:
            # resize mask
            if pred.shape[-2:] != target.shape[-2:]:
                target = F.interpolate(target, pred.shape[-2:],
                                    mode='nearest').detach()
            loss = F.binary_cross_entropy_with_logits(pred, target)
            # 4. if metric learning, add metric loss and normalize
            if metric_learning_flag:
                #print("CE loss : ", loss, "metric loss : ", metric_loss)
                loss = (loss + self.metric_loss_weight * metric_loss) / (1+self.metric_loss_weight)
            return pred.detach(), target, loss
        else:
            return pred.detach()
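

if __name__ == "__main__":
    # Minimal sanity-check sketch for the two metric losses (not part of the original
    # training pipeline). `SimpleNamespace` stands in for the real training config, and
    # the module should be run with `python -m <package>.<this_module>` so that the
    # relative import of `.layers` above resolves.
    from types import SimpleNamespace

    _args = SimpleNamespace(div_batch=False, phi_threshold=0.5)
    # 8 = 2*B embeddings (B = 4 samples, duplicated and interleaved as (2i, 2i+1)),
    # 512 channels, 26x26 spatial grid flattened to H*W
    _emb = torch.randn(8, 512, 26 * 26)
    print("MetricLoss:", MetricLoss(_emb, n_pos=2, alpha=0.5, args=_args).item())
    print("AngularMetricLoss:", AngularMetricLoss(_emb, n_pos=2, alpha=0.5, args=_args).item())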