import random
import torch
import torch.nn as nn
import torch.nn.functional as F

from model.clip import build_model

from .layers import FPN, Projector, TransformerDecoder


# def MetricLoss(embeddings, n_pos, alpha = 0.5, args = None):
#     # embeddings: ((2*B), C, (H*W))
#     # n_pos : chunk size of positive pairs
#     # args: args
#     # returns: loss
#     metric_loss = 0

#     # flatten embeddings
#     B_, C, HW = embeddings.shape
#     emb = torch.mean(embeddings, dim=-1) # (2*B, C)
#     emb_i = emb.unsqueeze(1).repeat(1, B_, 1) # (2*B, 2*B, C)
#     emb_j = emb.unsqueeze(0).repeat(B_, 1, 1) # (2*B, 2*B, C)
#     emb_distance = torch.norm(emb_i - emb_j, dim=-1) # (2*B, 2*B)
#     assert torch.sum(torch.diag(emb_distance[:B_, :B_])) == 0, \
#     "Diagonals are not zero. please check the permutation on the batch"
#     # print("distance metrix : ", emb_distance)

#     # positive pairs and loss
#     positive_mask = torch.zeros_like(emb_distance)
#     for i in range(B_//2):
#         positive_mask[2*i, 2*i+1] = 1
#         positive_mask[2*i+1, 2*i] = 1
#     positive_mask.fill_diagonal_(1)
#     positive_loss = torch.sum(emb_distance * positive_mask) / B_

#     # negative pairs and loss
#     negative_mask = torch.ones_like(emb_distance) - positive_mask
#     negative_loss = -1.0 * torch.log(torch.sum(emb_distance * negative_mask) / (B_**2 - 2*B_))

#     # print(positive_mask, negative_mask)

#     metric_loss = alpha * positive_loss + (1-alpha) * negative_loss

#     return metric_loss
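

# Hedged sketch (not part of the original file): CRIS_S.forward references an
# AngularMetricLoss that is not defined in this module. The function below only
# illustrates one plausible angular (cosine-similarity) counterpart of the
# distance-based MetricLoss above; the name, signature, and exact formulation
# are assumptions, not the repository's implementation.
def AngularMetricLossSketch(embeddings, alpha=0.5, eps=1e-8):
    # embeddings: (2*B, C, H*W); rows 2i and 2i+1 form a positive pair
    B_, C, HW = embeddings.shape
    emb = F.normalize(torch.mean(embeddings, dim=-1), dim=-1)  # (2*B, C), unit norm
    sim = emb @ emb.t()                                        # (2*B, 2*B) cosine similarities

    # positive pairs: the diagonal plus the (2i, 2i+1) / (2i+1, 2i) entries
    positive_mask = torch.zeros_like(sim)
    for i in range(B_ // 2):
        positive_mask[2 * i, 2 * i + 1] = 1
        positive_mask[2 * i + 1, 2 * i] = 1
    positive_mask.fill_diagonal_(1)
    negative_mask = torch.ones_like(sim) - positive_mask

    # pull positives together (maximize similarity), push negatives apart
    positive_loss = 1.0 - torch.sum(sim * positive_mask) / positive_mask.sum()
    negative_loss = torch.sum(sim * negative_mask) / (negative_mask.sum() + eps)
    return alpha * positive_loss + (1 - alpha) * negative_loss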



class CRIS_S(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # Vision & Text Encoder
        clip_model = torch.jit.load(cfg.clip_pretrain,
                                    map_location="cpu").eval()
        self.backbone = build_model(clip_model.state_dict(), cfg.word_len).float()
        # Multi-Modal FPN
        self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # Decoder
        self.decoder = TransformerDecoder(num_layers=cfg.num_layers,
                                            d_model=cfg.vis_dim,
                                            nhead=cfg.num_head,
                                            dim_ffn=cfg.dim_ffn,
                                            dropout=cfg.dropout,
                                            return_intermediate=cfg.intermediate)
        # Projector
        self.proj = Projector(cfg.word_dim, cfg.vis_dim // 2, 3)
        self.metric_learning = cfg.metric_learning
        self.positive_strength = cfg.positive_strength
        self.metric_loss_weight = cfg.metric_loss_weight
        self.eps = cfg.ptb_rate  # std of the Gaussian noise added to the sentence state
        self.cfg = cfg

    def forward(self, image, text, target=None):
        '''
            image: b, 3, h, w
            text: b, words (token ids)
            if self.metric_learning:
                text: b, 2, words (two captions per image)
            target: b, 1, h, w (ground-truth mask, training only)
        '''
        metric_learning_flag = (self.metric_learning and self.training)
        # TODO : mixing option btw distance & angular loss
        mix_distance_angular = False
        metric_loss = 0

        # 1. Resizing: with metric learning, each sample carries two captions,
        # so the image and target batches are duplicated to match the doubled text batch
        if metric_learning_flag:
            b, c, h, w = image.size()
            # duplicate image and segmentation mask
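            # cat + reshape + transpose interleaves the copies so that the two
            # views of sample i end up at consecutive rows 2*i and 2*i+1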
            if image is not None:
                image = torch.cat([image, image], dim=0)
                image = image.reshape(-1, b, c, h, w).transpose(0, 1).reshape(-1, c, h, w)
            if target is not None:
                target = torch.cat([target, target], dim=0)
                target = target.reshape(-1, b, 1, h, w).transpose(0, 1).reshape(-1, 1, h, w)
            # build the noise mask: a sample is marked when its two captions are
            # identical, i.e. only a single caption was available for it
            b_, n_, l_ = text.size()
            assert n_ == 2, "word size should be 2"
            noise_mask = (text[:, 0, :] == text[:, 1, :])
            noise_mask = torch.all(noise_mask, dim=-1)
            noise_mask = noise_mask.unsqueeze(-1).repeat(1, 2).reshape(-1)  # (2*b_,)
            assert noise_mask.shape[0] == b_ * 2, "noise mask shape should be 2*b_"
            text = text.reshape(b_ * 2, l_)  # flatten caption pairs: (b_, 2, l) -> (2*b_, l)

        # print("text shape : ", text.shape)
        # print("image shape : ", image.shape)
        # print("target shape : ", target.shape)
        # print(torch.sum(image[0::2]) == torch.sum(image[1::2]))
        # print(torch.sum(target[0::2]) == torch.sum(target[1::2]))
        
        # padding mask used in decoder
        pad_mask = torch.zeros_like(text).masked_fill_(text == 0, 1).bool()
        # vis: C3 / C4 / C5
        # word: b, length, 1024
        # state: b, 1024
        vis = self.backbone.encode_image(image)
        word, state = self.backbone.encode_text(text)

        b_, d_ = state.size()
        assert b_ == word.size(0), "batch size of state and word should be same"


        # 2. State Noising Step: if a sample has only one caption (its two caption
        # copies are identical), perturb its sentence state with Gaussian noise
        if metric_learning_flag:
            noise = torch.randn_like(state) * self.eps
            state[noise_mask] = state[noise_mask] + noise[noise_mask]


        # 3. Multi-modal fusion and decoding
        # fq: b, 512, 26, 26 (C4 resolution)
        a3, a4, a5 = vis
        fq, f5 = self.neck(vis, state)
        b, c, h, w = fq.size()
        fq = self.decoder(fq, word, pad_mask)
        metric_tensor = fq
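        # the flattened decoder output (batch, c, h*w) is kept as metric_tensor so a
        # metric loss can be computed outside this module (the in-place call below is commented out)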
        # if metric_learning_flag:
        #     metric_loss = AngularMetricLoss(fq, 2, alpha=self.positive_strength, args = self.cfg) #  (1-self.positive_strength) *
        #     if mix_distance_angular:
        #         metric_loss += MetricLoss(fq, 2, alpha=self.positive_strength, args = self.cfg) #  self.positive_strength * 
        fq = fq.reshape(b, c, h, w)

        # b, 1, 104, 104
        pred = self.proj(fq, state)

        if self.training:
            # resize mask
            if pred.shape[-2:] != target.shape[-2:]:
                target = F.interpolate(target, pred.shape[-2:],
                                    mode='nearest').detach()
            CE_loss = F.binary_cross_entropy_with_logits(pred, target)

            # 4. If metric learning is enabled, add the metric loss and renormalize
            # if metric_learning_flag:
            #     loss = (loss + self.metric_loss_weight * metric_loss) / (1+self.metric_loss_weight)
            #     safety_loss = loss * 0.
            #     loss = loss + safety_loss

            return pred.detach(), target, CE_loss, metric_tensor
        else:
            return pred.detach()
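

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original file). The config fields below
# mirror the attributes read in CRIS_S.__init__, but every value and the CLIP
# checkpoint path are placeholders and must be adapted to the actual setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        clip_pretrain="pretrain/RN50.pt",  # placeholder path to a jit-scripted CLIP model
        word_len=17,
        fpn_in=[512, 1024, 1024], fpn_out=[256, 512, 1024],
        num_layers=3, vis_dim=512, num_head=8, dim_ffn=2048,
        dropout=0.1, intermediate=False, word_dim=1024,
        metric_learning=True, positive_strength=0.5,
        metric_loss_weight=0.2, ptb_rate=0.1,
    )
    model = CRIS_S(cfg).train()

    b = 2
    image = torch.randn(b, 3, 416, 416)                    # b, 3, h, w
    text = torch.randint(1, 100, (b, 2, cfg.word_len))     # two captions per image
    target = (torch.rand(b, 1, 416, 416) > 0.5).float()    # binary segmentation mask

    pred, tgt, ce_loss, metric_tensor = model(image, text, target)
    print(pred.shape, ce_loss.item(), metric_tensor.shape)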