import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.clip import build_model
from .layers import FPN, Projector, TransformerDecoder
# def MetricLoss(embeddings, n_pos, alpha=0.5, args=None):
#     # embeddings: (2*B, C, H*W)
#     # n_pos: chunk size of positive pairs
#     # args: args
#     # returns: loss
#     metric_loss = 0
#     # flatten embeddings
#     B_, C, HW = embeddings.shape
#     emb = torch.mean(embeddings, dim=-1)               # (2*B, C)
#     emb_i = emb.unsqueeze(1).repeat(1, B_, 1)          # (2*B, 2*B, C)
#     emb_j = emb.unsqueeze(0).repeat(B_, 1, 1)          # (2*B, 2*B, C)
#     emb_distance = torch.norm(emb_i - emb_j, dim=-1)   # (2*B, 2*B)
#     assert torch.sum(torch.diag(emb_distance[:B_, :B_])) == 0, \
#         "Diagonals are not zero. Please check the permutation on the batch"
#     # print("distance matrix : ", emb_distance)
#     # positive pairs and loss
#     positive_mask = torch.zeros_like(emb_distance)
#     for i in range(B_ // 2):
#         positive_mask[2 * i, 2 * i + 1] = 1
#         positive_mask[2 * i + 1, 2 * i] = 1
#     positive_mask.fill_diagonal_(1)
#     positive_loss = torch.sum(emb_distance * positive_mask) / B_
#     # negative pairs and loss
#     negative_mask = torch.ones_like(emb_distance) - positive_mask
#     negative_loss = -1.0 * torch.log(torch.sum(emb_distance * negative_mask) / (B_**2 - 2 * B_))
#     # print(positive_mask, negative_mask)
#     metric_loss = alpha * positive_loss + (1 - alpha) * negative_loss
#     return metric_loss
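
# ---------------------------------------------------------------------------
# A minimal, self-contained sketch of the distance-based metric loss outlined
# in the comments above. The function name and the use of torch.cdist are
# illustrative choices; nothing below is called anywhere in this module.
def metric_loss_sketch(embeddings: torch.Tensor, alpha: float = 0.5) -> torch.Tensor:
    """embeddings: (2*B, C, H*W), arranged so rows 2i and 2i+1 form a positive pair."""
    B2 = embeddings.shape[0]
    emb = embeddings.mean(dim=-1)                  # (2*B, C) spatially pooled embeddings
    dist = torch.cdist(emb, emb, p=2)              # (2*B, 2*B) pairwise L2 distances
    # positive pairs: adjacent rows (2i, 2i+1); the zero diagonal is also marked positive
    positive_mask = torch.zeros_like(dist)
    idx = torch.arange(0, B2, 2, device=dist.device)
    positive_mask[idx, idx + 1] = 1
    positive_mask[idx + 1, idx] = 1
    positive_mask.fill_diagonal_(1)
    negative_mask = 1.0 - positive_mask
    # pull positives together; push negatives apart via a negative log of their mean distance
    positive_loss = (dist * positive_mask).sum() / B2
    negative_loss = -torch.log((dist * negative_mask).sum() / (B2 ** 2 - 2 * B2))
    return alpha * positive_loss + (1 - alpha) * negative_loss
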
class CRIS_S(nn.Module):
def __init__(self, cfg):
super().__init__()
# Vision & Text Encoder
clip_model = torch.jit.load(cfg.clip_pretrain,
map_location="cpu").eval()
self.backbone = build_model(clip_model.state_dict(), cfg.word_len).float()
# Multi-Modal FPN
self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
# Decoder
self.decoder = TransformerDecoder(num_layers=cfg.num_layers,
d_model=cfg.vis_dim,
nhead=cfg.num_head,
dim_ffn=cfg.dim_ffn,
dropout=cfg.dropout,
return_intermediate=cfg.intermediate)
# Projector
self.proj = Projector(cfg.word_dim, cfg.vis_dim // 2, 3)
self.metric_learning = cfg.metric_learning
self.positive_strength = cfg.positive_strength
self.metric_loss_weight = cfg.metric_loss_weight
        self.eps = cfg.ptb_rate  # std of the Gaussian noise added to the sentence state
self.cfg = cfg
def forward(self, image, text, target=None):
        '''
        image:  b, 3, h, w
        text:   b, words (token ids; zeros are padding)
        if self.metric_learning:
            text: b, 2, words (two captions per image)
        target: b, 1, h, w (ground-truth mask, used only during training)
        '''
metric_learning_flag = (self.metric_learning and self.training)
        # TODO: option to mix the distance-based and angular metric losses
mix_distance_angular = False
metric_loss = 0
        # 1. Resizing: with metric learning, each sample carries two captions, so
        #    images and targets are duplicated and interleaved as [x0, x0, x1, x1, ...],
        #    making rows 2i and 2i+1 of the batch a positive pair
if metric_learning_flag:
#print("image shape : ", image.shape)
b, c, h, w = image.size()
# duplicate image and segmentation mask
if image is not None:
image = torch.cat([image, image], dim=0)
image = image.reshape(-1, b, c, h, w).transpose(0, 1).reshape(-1, c, h, w)
if target is not None:
target = torch.cat([target, target], dim=0)
target = target.reshape(-1, b, 1, h, w).transpose(0, 1).reshape(-1, 1, h, w)
            # noise mask: marks samples whose two captions are identical
            b_, n_, l_ = text.size()
            assert n_ == 2, "word size should be 2"
noise_mask = (text[:, 0, :] == text[:, 1, :])
noise_mask = torch.all(noise_mask, dim=-1)
noise_mask = noise_mask.unsqueeze(-1).repeat(1, 2).reshape(-1) # 2*b_
assert noise_mask.shape[0] == b_ * 2, "noise mask shape should be 2*b_"
text = text.reshape(b_ * 2, l_) # 2*b, l
# print("text shape : ", text.shape)
# print("image shape : ", image.shape)
# print("target shape : ", target.shape)
# print(torch.sum(image[0::2]) == torch.sum(image[1::2]))
# print(torch.sum(target[0::2]) == torch.sum(target[1::2]))
# padding mask used in decoder
pad_mask = torch.zeros_like(text).masked_fill_(text == 0, 1).bool()
# vis: C3 / C4 / C5
# word: b, length, 1024
# state: b, 1024
vis = self.backbone.encode_image(image)
word, state = self.backbone.encode_text(text)
b_, d_ = state.size()
assert b_ == word.size(0), "batch size of state and word should be same"
        # 2. State noising step: for samples whose two captions are identical,
        #    perturb the duplicated sentence state with Gaussian noise
        if metric_learning_flag:
noise = torch.randn_like(state) * self.eps
state[noise_mask] = state[noise_mask] + noise[noise_mask]
# b, 512, 26, 26 (C4)
a3, a4, a5 = vis
fq, f5 = self.neck(vis, state)
b, c, h, w = fq.size()
fq = self.decoder(fq, word, pad_mask)
metric_tensor = fq
        # if metric_learning_flag:
        #     metric_loss = AngularMetricLoss(fq, 2, alpha=self.positive_strength, args=self.cfg)  # (1-self.positive_strength) *
        #     if mix_distance_angular:
        #         metric_loss += MetricLoss(fq, 2, alpha=self.positive_strength, args=self.cfg)  # self.positive_strength *
fq = fq.reshape(b, c, h, w)
# b, 1, 104, 104
pred = self.proj(fq, state)
if self.training:
# resize mask
if pred.shape[-2:] != target.shape[-2:]:
target = F.interpolate(target, pred.shape[-2:],
mode='nearest').detach()
CE_loss = F.binary_cross_entropy_with_logits(pred, target)
# 4. if metric learning, add metric loss and normalize
            # if metric_learning_flag:
            #     loss = (loss + self.metric_loss_weight * metric_loss) / (1 + self.metric_loss_weight)
            #     safety_loss = loss * 0.
            #     loss = loss + safety_loss
return pred.detach(), target, CE_loss, metric_tensor
else:
#print(self.cfg.gpu, f"; loss = {loss}")
return pred.detach()
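
# ---------------------------------------------------------------------------
# Minimal usage sketch, assuming a config namespace exposing the fields read in
# __init__. The values and checkpoint path below are illustrative assumptions,
# not the project's actual defaults; `clip_pretrain` must point to a real
# jit-traced CLIP checkpoint for this to run.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        clip_pretrain="pretrain/RN50.pt",        # assumed CLIP checkpoint path
        word_len=17, word_dim=1024, vis_dim=512,
        fpn_in=[512, 1024, 1024], fpn_out=[256, 512, 1024],
        num_layers=3, num_head=8, dim_ffn=2048, dropout=0.1, intermediate=False,
        metric_learning=False, positive_strength=0.5,
        metric_loss_weight=0.1, ptb_rate=0.1,
    )
    model = CRIS_S(cfg).eval()
    image = torch.randn(2, 3, 416, 416)               # b, 3, h, w
    text = torch.randint(1, 100, (2, cfg.word_len))   # b, words (0 would be padding)
    with torch.no_grad():
        pred = model(image, text)                     # b, 1, h/4, w/4 logits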