import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.clip import build_model
from .layers import FPN, Projector, TransformerDecoder
def MetricLoss(embeddings, n_pos, alpha=0.5, args=None):
    # embeddings: (2*B, C, H*W), positive pairs interleaved at rows (2i, 2i+1)
    # n_pos: chunk size of positive pairs
    # args: argparse/config namespace (must provide `div_batch`)
    # returns: scalar metric loss
metric_loss = 0
# flatten embeddings
B_, C, HW = embeddings.shape
emb = torch.mean(embeddings, dim=-1) # (2*B, C)
emb_i = emb.unsqueeze(1).repeat(1, B_, 1) # (2*B, 2*B, C)
emb_j = emb.unsqueeze(0).repeat(B_, 1, 1) # (2*B, 2*B, C)
emb_distance = torch.norm(emb_i - emb_j, dim=-1) # (2*B, 2*B)
    assert torch.sum(torch.diag(emb_distance)) == 0, \
        "Diagonals are not zero. Please check the permutation of the batch"
    # print("distance matrix : ", emb_distance)
# positive pairs and loss
positive_mask = torch.zeros_like(emb_distance)
for i in range(B_//2):
positive_mask[2*i, 2*i+1] = 1
positive_mask[2*i+1, 2*i] = 1
positive_mask.fill_diagonal_(1)
positive_loss = torch.sum(emb_distance * positive_mask) / B_
# negative pairs and loss
negative_mask = torch.ones_like(emb_distance) - positive_mask
if args.div_batch:
negative_loss = -1.0 * torch.log(torch.sum(emb_distance * negative_mask) / B_)
else:
negative_loss = -1.0 * torch.log(torch.sum(emb_distance * negative_mask) / (B_**2 - 2*B_))
# print(positive_mask, negative_mask)
metric_loss = alpha * positive_loss + (1-alpha) * negative_loss
return metric_loss
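
# In symbols, with d_ij = ||mean_hw(e_i) - mean_hw(e_j)||_2, B_ = 2*B, P the positive
# pairs (rows 2i <-> 2i+1; the diagonal contributes zero) and N all remaining pairs:
#   loss = alpha * (sum_{(i,j) in P} d_ij) / B_  -  (1 - alpha) * log((sum_{(i,j) in N} d_ij) / |N|)
# where |N| = B_**2 - 2*B_, or simply B_ when args.div_batch is set.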
def AngularMetricLoss(embeddings, n_pos, alpha=0.5, args=None):
    # embeddings: (2*B, C, H*W), positive pairs interleaved at rows (2i, 2i+1)
    # n_pos: chunk size of positive pairs
    # args: argparse/config namespace (must provide `phi_threshold`)
    # returns: scalar angular (geometric) metric loss
geometric_loss = 0
# flatten embeddings
B_, C, HW = embeddings.shape
emb = torch.mean(embeddings, dim=-1) # (2*B, C)
emb_i = emb.unsqueeze(1).repeat(1, B_, 1) # (2*B, 2*B, C)
emb_j = emb.unsqueeze(0).repeat(B_, 1, 1) # (2*B, 2*B, C)
sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
sim_matrix = sim(emb_i, emb_j).reshape(B_, B_) # (2*B , 2*B)
    # print("similarity matrix : ", sim_matrix)
    assert torch.allclose(torch.diagonal(sim_matrix),
                          torch.ones_like(torch.diagonal(sim_matrix)), atol=1e-4), \
        "Similarity diagonals are not one. Please check the permutation of the batch"
    # clamp before acos: cosine values slightly outside [-1, 1] would otherwise give NaN
    phi = torch.acos(sim_matrix.clamp(-1 + 1e-6, 1 - 1e-6))  # (2*B, 2*B)
    # print("phi matrix : ", phi)
# positive pairs and loss
positive_mask = torch.zeros_like(sim_matrix)
for i in range(B_//2):
positive_mask[2*i, 2*i+1] = 1
positive_mask[2*i+1, 2*i] = 1
positive_mask.fill_diagonal_(1)
positive_loss = torch.sum((phi**2) * positive_mask) / B_
# negative pairs and loss
negative_mask = torch.ones_like(sim_matrix) - positive_mask
phi_mask = phi < args.phi_threshold
negative_loss = (args.phi_threshold - phi)**2
    # print(negative_mask * phi_mask)
    negative_loss = torch.sum(negative_loss * negative_mask * phi_mask) / (B_**2 - 2*B_)
    # print("pos loss, neg loss : ", positive_loss, negative_loss)
geometric_loss = alpha * positive_loss + (1-alpha) * negative_loss
return geometric_loss
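
# The angular variant works on angles rather than distances: phi_ij = arccos(cosine
# similarity of the pooled embeddings). Positive pairs are pulled together via phi**2,
# while a negative pair contributes only while its angle is still below
# args.phi_threshold, i.e. a hinge (phi_threshold - phi)**2 * 1[phi < phi_threshold],
# averaged over the B_**2 - 2*B_ negative entries.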
class CRIS(nn.Module):
def __init__(self, cfg):
super().__init__()
# Vision & Text Encoder
clip_model = torch.jit.load(cfg.clip_pretrain,
map_location="cpu").eval()
self.backbone = build_model(clip_model.state_dict(), cfg.word_len).float()
# Multi-Modal FPN
self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
# Decoder
self.decoder = TransformerDecoder(num_layers=cfg.num_layers,
d_model=cfg.vis_dim,
nhead=cfg.num_head,
dim_ffn=cfg.dim_ffn,
dropout=cfg.dropout,
return_intermediate=cfg.intermediate)
# Projector
self.proj = Projector(cfg.word_dim, cfg.vis_dim // 2, 3)
self.metric_learning = cfg.metric_learning
self.positive_strength = cfg.positive_strength
self.metric_loss_weight = cfg.metric_loss_weight
self.eps = cfg.ptb_rate
self.cfg = cfg
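        # cfg attributes consumed by this module (as read above): clip_pretrain,
        # word_len, fpn_in, fpn_out, num_layers, vis_dim, num_head, dim_ffn, dropout,
        # intermediate, word_dim, metric_learning, positive_strength,
        # metric_loss_weight, ptb_rate.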
def forward(self, image, text, target=None):
        '''
        image:  b, 3, h, w
        text:   b, words (token ids);
                b, 2, words when self.metric_learning is on (two captions per image)
        target: b, 1, h, w (ground-truth segmentation mask, training only)
        '''
metric_learning_flag = (self.metric_learning and self.training)
metric_loss = 0
        # 1. Resizing: with metric learning each sample carries two captions, so the
        #    image and target are duplicated to match the doubled text batch
if metric_learning_flag:
#print("image shape : ", image.shape)
b, c, h, w = image.size()
# duplicate image and segmentation mask
if image is not None:
image = torch.cat([image, image], dim=0)
image = image.reshape(-1, b, c, h, w).transpose(0, 1).reshape(-1, c, h, w)
if target is not None:
target = torch.cat([target, target], dim=0)
target = target.reshape(-1, b, 1, h, w).transpose(0, 1).reshape(-1, 1, h, w)
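            # after the reshape/transpose above, the two copies of sample i sit at
            # rows 2i and 2i+1, matching the positive-pair layout assumed by
            # MetricLoss / AngularMetricLoss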
# duplicate noise mask
b_, n_, l_ = text.size()
            assert n_ == 2, "each sample should carry exactly two captions"
noise_mask = (text[:, 0, :] == text[:, 1, :])
noise_mask = torch.all(noise_mask, dim=-1)
noise_mask = noise_mask.unsqueeze(-1).repeat(1, 2).reshape(-1) # 2*b_
assert noise_mask.shape[0] == b_ * 2, "noise mask shape should be 2*b_"
text = text.reshape(b_ * 2, l_) # 2*b, l
# print("text shape : ", text.shape)
# print("image shape : ", image.shape)
# print("target shape : ", target.shape)
# print(torch.sum(image[0::2]) == torch.sum(image[1::2]))
# print(torch.sum(target[0::2]) == torch.sum(target[1::2]))
# padding mask used in decoder
pad_mask = torch.zeros_like(text).masked_fill_(text == 0, 1).bool()
# vis: C3 / C4 / C5
# word: b, length, 1024
# state: b, 1024
vis = self.backbone.encode_image(image)
word, state = self.backbone.encode_text(text)
b_, d_ = state.size()
assert b_ == word.size(0), "batch size of state and word should be same"
        # 2. State noising step: samples whose two captions are identical (i.e. only
        #    one caption was available) get Gaussian noise added to their sentence
        #    state, so the two duplicated states differ
if metric_learning_flag :
noise = torch.randn_like(state) * self.eps
state[noise_mask] = state[noise_mask] + noise[noise_mask]
# print("shape of word, state : ", word.shape, state.shape)
# b, 512, 26, 26 (C4)
a3, a4, a5 = vis
# print("vis shape in model " , a3.shape, a4.shape, a5.shape)
fq, f5 = self.neck(vis, state)
b, c, h, w = fq.size()
fq = self.decoder(fq, word, pad_mask)
# print("decoder output shape : ", fq.shape)
# 3. Get metric loss
if metric_learning_flag:
metric_loss = MetricLoss(fq, 2, alpha=self.positive_strength, args = self.cfg)
fq = fq.reshape(b, c, h, w)
# b, 1, 104, 104
pred = self.proj(fq, state)
if self.training:
# resize mask
if pred.shape[-2:] != target.shape[-2:]:
target = F.interpolate(target, pred.shape[-2:],
mode='nearest').detach()
loss = F.binary_cross_entropy_with_logits(pred, target)
# 4. if metric learning, add metric loss and normalize
if metric_learning_flag:
#print("CE loss : ", loss, "metric loss : ", metric_loss)
loss = (loss + self.metric_loss_weight * metric_loss) / (1+self.metric_loss_weight)
return pred.detach(), target, loss
else:
return pred.detach()
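

if __name__ == "__main__":
    # Minimal smoke test for the two metric losses (a sketch, not part of the original
    # training pipeline). `_Args` is a hypothetical stand-in for the argparse/config
    # namespace and only provides the fields the losses actually read. Because of the
    # relative import of `.layers` above, run this as a module from the repo root,
    # e.g. `python -m model.segmenter` (the module path is an assumption).
    class _Args:
        div_batch = False
        phi_threshold = 3.0

    torch.manual_seed(0)
    # (2*B, C, H*W) with B = 4; positive pairs are interleaved at rows (2i, 2i+1)
    dummy = torch.randn(8, 512, 26 * 26)
    print("MetricLoss        :", MetricLoss(dummy, 2, alpha=0.5, args=_Args()).item())
    print("AngularMetricLoss :", AngularMetricLoss(dummy, 2, alpha=0.5, args=_Args()).item())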