import random

import torch
import torch.nn as nn
import torch.nn.functional as F

from model.clip import build_model

from .layers import FPN, Projector, TransformerDecoder


# def MetricLoss(embeddings, n_pos, alpha = 0.5, args = None):
#     # embeddings: ((2*B), C, (H*W))
#     # n_pos : chunk size of positive pairs
#     # args: args
#     # returns: loss
#     metric_loss = 0
#     # flatten embeddings
#     B_, C, HW = embeddings.shape
#     emb = torch.mean(embeddings, dim=-1)  # (2*B, C)
#     emb_i = emb.unsqueeze(1).repeat(1, B_, 1)  # (2*B, 2*B, C)
#     emb_j = emb.unsqueeze(0).repeat(B_, 1, 1)  # (2*B, 2*B, C)
#     emb_distance = torch.norm(emb_i - emb_j, dim=-1)  # (2*B, 2*B)
#     assert torch.sum(torch.diag(emb_distance[:B_, :B_])) == 0, \
#         "Diagonals are not zero. please check the permutation on the batch"
#     # print("distance metrix : ", emb_distance)
#     # positive pairs and loss
#     positive_mask = torch.zeros_like(emb_distance)
#     for i in range(B_//2):
#         positive_mask[2*i, 2*i+1] = 1
#         positive_mask[2*i+1, 2*i] = 1
#     positive_mask.fill_diagonal_(1)
#     positive_loss = torch.sum(emb_distance * positive_mask) / B_
#     # negative pairs and loss
#     negative_mask = torch.ones_like(emb_distance) - positive_mask
#     negative_loss = -1.0 * torch.log(torch.sum(emb_distance * negative_mask) / (B_**2 - 2*B_))
#     # print(positive_mask, negative_mask)
#     metric_loss = alpha * positive_loss + (1-alpha) * negative_loss
#     return metric_loss


class CRIS_S(nn.Module):
    """CLIP-backed referring-segmentation model with optional paired captions.

    When ``cfg.metric_learning`` is set and the module is training, every
    image comes with two captions. The image/mask batch is duplicated so that
    each caption gets its own row, and sentence features of samples whose
    two captions are identical are perturbed with Gaussian noise.
    """

    def __init__(self, cfg):
        """Build the backbone, neck, decoder and projector from ``cfg``.

        Args:
            cfg: config object. Attributes read here: ``clip_pretrain``,
                ``word_len``, ``fpn_in``, ``fpn_out``, ``num_layers``,
                ``vis_dim``, ``num_head``, ``dim_ffn``, ``dropout``,
                ``intermediate``, ``word_dim``, ``metric_learning``,
                ``positive_strength``, ``metric_loss_weight``, ``ptb_rate``.
        """
        super().__init__()
        # Vision & Text Encoder: weights are taken from a TorchScript CLIP checkpoint.
        clip_model = torch.jit.load(cfg.clip_pretrain, map_location="cpu").eval()
        self.backbone = build_model(clip_model.state_dict(), cfg.word_len).float()
        # Multi-Modal FPN
        self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # Decoder
        self.decoder = TransformerDecoder(
            num_layers=cfg.num_layers,
            d_model=cfg.vis_dim,
            nhead=cfg.num_head,
            dim_ffn=cfg.dim_ffn,
            dropout=cfg.dropout,
            return_intermediate=cfg.intermediate,
        )
        # Projector
        self.proj = Projector(cfg.word_dim, cfg.vis_dim // 2, 3)
        # Metric-learning settings. positive_strength / metric_loss_weight are
        # not used by forward() in this file; kept for external consumers.
        self.metric_learning = cfg.metric_learning
        self.positive_strength = cfg.positive_strength
        self.metric_loss_weight = cfg.metric_loss_weight
        # Scale of the Gaussian noise added to sentence features (see forward).
        self.eps = cfg.ptb_rate
        self.cfg = cfg

    def forward(self, image, text, target=None):
        """Predict segmentation logits for each (image, text) pair.

        Args:
            image: image batch, ``(b, 3, h, w)``.
            text: token ids, ``(b, words)``; when metric learning is enabled
                and the module is training, ``(b, 2, words)`` (two captions
                per image). Token id 0 is treated as padding.
            target: ground-truth mask, ``(b, 1, H, W)``. Required in training;
                it is resized to the prediction's spatial size if needed.

        Returns:
            In training: ``(pred.detach(), target, ce_loss, metric_tensor)``,
            where ``target`` is the (possibly duplicated/resized) mask,
            ``ce_loss`` is the BCE-with-logits loss and ``metric_tensor`` is
            the decoder output before it is reshaped to a feature map.
            In eval: ``pred.detach()``.

        Raises:
            ValueError: if the module is training and ``target`` is None.
        """
        if self.training and target is None:
            raise ValueError("target is required when the model is training")

        metric_learning_flag = self.metric_learning and self.training

        # 1. Resizing: with paired captions the effective batch size doubles.
        if metric_learning_flag:
            # Rows become [x0, x0, x1, x1, ...], matching the row order of
            # text.reshape(2b, words) below. repeat_interleave works for any
            # mask resolution, unlike reshaping with the image's h, w.
            image = image.repeat_interleave(2, dim=0)
            target = target.repeat_interleave(2, dim=0)

            b_, n_, l_ = text.size()
            assert n_ == 2, "word size should be 2"
            # A sample whose two captions are identical (only one caption was
            # available) is flagged so its sentence features get noise below.
            noise_mask = torch.all(text[:, 0, :] == text[:, 1, :], dim=-1)
            noise_mask = noise_mask.repeat_interleave(2)  # (2b,)
            text = text.reshape(b_ * 2, l_)  # 2*b, l

        # padding mask used in decoder
        pad_mask = text == 0

        # vis: C3 / C4 / C5
        # word: b, length, 1024
        # state: b, 1024
        vis = self.backbone.encode_image(image)
        word, state = self.backbone.encode_text(text)
        b_, _ = state.size()
        assert b_ == word.size(0), "batch size of state and word should be same"

        # 2. State noising step: perturb sentence features of flagged samples.
        # Out-of-place, so the encoder output tensor is not modified in place.
        if metric_learning_flag:
            noise = torch.randn_like(state) * self.eps
            state = torch.where(noise_mask.unsqueeze(-1), state + noise, state)

        # b, 512, 26, 26 (C4)
        fq, _ = self.neck(vis, state)
        b, c, h, w = fq.size()
        fq = self.decoder(fq, word, pad_mask)
        # NOTE: the in-module metric loss (MetricLoss / AngularMetricLoss) is
        # disabled; the decoder output is returned instead, presumably so the
        # caller can compute it.
        metric_tensor = fq
        fq = fq.reshape(b, c, h, w)

        # b, 1, 104, 104
        pred = self.proj(fq, state)

        if not self.training:
            return pred.detach()

        # resize mask to the prediction's spatial size
        if pred.shape[-2:] != target.shape[-2:]:
            target = F.interpolate(target, pred.shape[-2:], mode='nearest').detach()
        ce_loss = F.binary_cross_entropy_with_logits(pred, target)
        return pred.detach(), target, ce_loss, metric_tensor