import random

import torch
import torch.nn as nn
import torch.nn.functional as F

from model.clip import build_model

from .layers import FPN, Projector, TransformerDecoder


def check_nan(x):
    """Return True if the tensor contains any NaN value."""
    return bool(torch.isnan(x).any())


def zero_filtering(x):
    """Clamp values <= eps up to eps (in place), similar to torch.clamp().

    The competition metric is cosine similarity, which returns NaN when an
    input embedding is all zeros.
    """
    eps = 1e-4
    x[x <= eps] = eps
    return x


def nan_filtering(x, eps=1e-4):
    """Replace NaN entries with eps.

    The competition metric is cosine similarity, which can return NaN.
    """
    return torch.nan_to_num(x, nan=eps)


# def MetricLoss(embeddings, num_pos, alpha=0.5, args=None):
#     # embeddings: ((2*B), C, (H*W))
#     # num_pos: chunk size of positive pairs
#     # args: args
#     # returns: loss
#     metric_loss = 0
#
#     # flatten embeddings
#     B_, C, HW = embeddings.shape
#     emb = torch.mean(embeddings, dim=-1)        # (2*B, C)
#     emb_i = emb.unsqueeze(1).repeat(1, B_, 1)   # (2*B, 2*B, C)
#     emb_j = emb.unsqueeze(0).repeat(B_, 1, 1)   # (2*B, 2*B, C)
#     emb_distance = torch.norm(emb_i - emb_j, dim=-1)  # (2*B, 2*B)
#     assert torch.sum(torch.diag(emb_distance[:B_, :B_])) == 0, \
#         "Diagonals are not zero. please check the permutation on the batch"
#     # print("distance matrix : ", emb_distance)
#
#     # positive pairs and loss
#     positive_mask = torch.zeros_like(emb_distance)
#     for i in range(B_ // 2):
#         positive_mask[2 * i, 2 * i + 1] = 1
#         positive_mask[2 * i + 1, 2 * i] = 1
#     positive_mask.fill_diagonal_(1)
#     positive_loss = torch.sum(emb_distance * positive_mask) / B_
#
#     # negative pairs and loss
#     negative_mask = torch.ones_like(emb_distance) - positive_mask
#     negative_loss = -1.0 * torch.log(torch.sum(emb_distance * negative_mask) / (B_**2 - 2 * B_))
#     # print(positive_mask, negative_mask)
#     metric_loss = alpha * positive_loss + (1 - alpha) * negative_loss
#     return metric_loss


# def return_mask(emb_distance, nsent):
#     B_, B_ = emb_distance.shape
#     positive_mask = torch.zeros_like(emb_distance)
#     for i in range(B_ // nsent):
#         positive_mask[nsent * i, nsent * i + 1] = 1
#         positive_mask[nsent * i + 1, nsent * i] = 1
#     positive_mask.fill_diagonal_(1)
#     negative_mask = torch.ones_like(emb_distance) - positive_mask
#     return positive_mask, negative_mask


# def AngularMetricLoss(embeddings, num_pos, num_neg, alpha=0.5, args=None):
#     # embeddings: ((6*B), C, (H*W))
#     # num_pos: chunk size of positive pairs
#     # args: args
#     # returns: loss
#     geometric_loss = 0
#     nsent = num_pos + num_neg  # nsent : S
#     assert nsent == 6, "number of sentences doesn't match"
#
#     # flatten embeddings
#     B_, C, HW = embeddings.shape
#     emb = torch.mean(embeddings, dim=-1)        # (S*B, C)
#     emb_i = emb.unsqueeze(1).repeat(1, B_, 1)   # (S*B, S*B, C)
#     emb_j = emb.unsqueeze(0).repeat(B_, 1, 1)   # (S*B, S*B, C)
#
#     ## zero filtering
#     sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
#     sim_matrix = sim(emb_i, emb_j).reshape(B_, B_)  # (S*B, S*B)
#     sim_matrix = zero_filtering(sim_matrix)
#     if check_nan(sim_matrix):
#         sim_matrix = nan_filtering(sim_matrix)
#     sim_matrix = torch.clamp(sim_matrix, min=-0.999, max=0.999)
#     phi = torch.acos(sim_matrix)  # (S*B, S*B)
#     phi[torch.isnan(phi)] = 0
#
#     # positive pairs and loss
#     positive_mask, negative_mask = return_mask(sim_matrix, nsent)
#     positive_loss = torch.sum((phi**2) * positive_mask) / B_
#
#     # negative pairs and loss
#     # negative_mask = torch.ones_like(sim_matrix) - positive_mask
#     phi_mask = phi < args.phi_threshold
#     negative_loss = (args.phi_threshold - phi)**2
#     negative_loss = zero_filtering(negative_loss)
#     if check_nan(negative_loss):
#         negative_loss = nan_filtering(negative_loss)
#     if args.div_batch:
#         negative_loss = torch.sum(negative_loss * negative_mask * phi_mask) / B_
#     else:
#         negative_loss = torch.sum(negative_loss * negative_mask * phi_mask) / (B_**2 - nsent * B_)
#     # print("pos loss, neg loss : ", positive_loss, negative_loss)
#     geometric_loss = alpha * positive_loss + (1 - alpha) * negative_loss
#     return geometric_loss
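# A minimal, runnable sketch (not used by the model) of the cosine -> clamp ->
# acos pipeline that the commented-out AngularMetricLoss above relies on. The
# helper name `angular_distance_matrix` is introduced here purely for
# illustration and does not exist elsewhere in this repo.
def angular_distance_matrix(emb):
    """Return the (N, N) matrix of angles phi between row embeddings of shape (N, C)."""
    emb_i = emb.unsqueeze(1)                         # (N, 1, C)
    emb_j = emb.unsqueeze(0)                         # (1, N, C)
    sim = F.cosine_similarity(emb_i, emb_j, dim=-1)  # (N, N) via broadcasting
    sim = torch.clamp(sim, min=-0.999, max=0.999)    # keep acos away from +-1
    return torch.acos(sim)                           # angles in radians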
class CRIS_Wo_Noise(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # Vision & Text Encoder
        clip_model = torch.jit.load(cfg.clip_pretrain, map_location="cpu").eval()
        self.backbone = build_model(clip_model.state_dict(), cfg.word_len).float()
        # Multi-Modal FPN
        self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # Decoder
        self.decoder = TransformerDecoder(num_layers=cfg.num_layers,
                                          d_model=cfg.vis_dim,
                                          nhead=cfg.num_head,
                                          dim_ffn=cfg.dim_ffn,
                                          dropout=cfg.dropout,
                                          return_intermediate=cfg.intermediate)
        # Projector
        self.proj = Projector(cfg.word_dim, cfg.vis_dim // 2, 3)
        self.metric_learning = cfg.metric_learning
        self.positive_strength = cfg.positive_strength
        self.metric_loss_weight = cfg.metric_loss_weight
        self.add_noise = cfg.add_noise
        self.eps = cfg.ptb_rate
        self.cfg = cfg
        # self.bn_fq = nn.BatchNorm2d(1024)
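    # Hedged sketch (illustration only; forward() does not call it): the
    # cat/reshape/transpose dance in forward repeats each sample `nsent`
    # times consecutively, which is equivalent to torch.repeat_interleave.
    # `_repeat_per_sentence` is a hypothetical helper added for clarity.
    @staticmethod
    def _repeat_per_sentence(x, nsent):
        """(B, ...) -> (B * nsent, ...), each sample repeated nsent times in a row."""
        return torch.repeat_interleave(x, nsent, dim=0)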
    def forward(self, image, text, target=None):
        '''
        img: b, 3, h, w
        word: b, words
        word_mask: b, words
        if self.metric_learning:
            word: b, 6, words
            word_mask: b, 6, words
        mask: b, 1, h, w
        '''
        metric_learning_flag = (self.metric_learning and self.training)
        add_noise_flag = self.add_noise
        # TODO: mixing option between distance & angular loss
        mix_distance_angular = False
        metric_loss = 0
        # print("text shape : ", text.shape)
        if self.training:
            bt, nt, lt = text.size()
        else:
            nt = 1
            bt, lt = text.size()
        npos = 2
        nneg = nt - npos

        # 1. Resizing: if metric learning, the batch is expanded to one entry
        #    per sentence (nt = 6), so images and masks are duplicated to match.
        if metric_learning_flag:
            # print("image shape : ", image.shape)
            b, c, h, w = image.size()
            # duplicate image and segmentation mask; the cat/reshape/transpose
            # below repeats each sample 6 times consecutively (see the
            # _repeat_per_sentence sketch above)
            if image is not None:
                image = torch.cat([image] * 6, dim=0)
                image = image.reshape(-1, b, c, h, w).transpose(0, 1).reshape(-1, c, h, w)
            if target is not None:
                target = torch.cat([target] * 6, dim=0)
                target = target.reshape(-1, b, 1, h, w).transpose(0, 1).reshape(-1, 1, h, w)
            if add_noise_flag:
                # a sample needs noise when its two positive sentences are identical
                noise_mask = (text[:, 0, :] == text[:, 1, :])
                noise_mask = torch.all(noise_mask, dim=-1)
                noise_mask = noise_mask.unsqueeze(-1).repeat(1, 2).reshape(-1)  # 2*b_
                assert noise_mask.shape[0] == bt * npos, "noise mask shape should be 2*B_"
            text = text.reshape(bt * nt, lt)  # 6*b, l
        # print(image.shape, image.dtype, image.type())    # float32
        # print(target.shape, target.dtype, target.type())  # float32
        # print(noise_mask.dtype, noise_mask.type())        # bool

        # padding mask used in decoder
        pad_mask = torch.zeros_like(text).masked_fill_(text == 0, 1).bool()
        # print(pad_mask.dtype, pad_mask.type())

        # vis: C3 / C4 / C5
        # word: b, length, 1024
        # state: 6b, 1024
        vis = self.backbone.encode_image(image)
        word, state = self.backbone.encode_text(text)
        # print(vis.dtype, vis.type())
        # state = state.float()
        if check_nan(state):
            print('state has nan values')
            state = nan_filtering(state)
        # print(state)
        b_, d_ = state.size()
        assert b_ == word.size(0), "batch size of state and word should be same"

        # npos = 2, nneg = 4; noise_mask is defined above, so this branch
        # assumes metric learning is enabled whenever add_noise is used
        if (add_noise_flag and self.training):
            tmp_state = state.reshape(bt, nt, -1)
            pos_state = tmp_state[:, :npos, :].reshape(bt * npos, -1)
            neg_state = tmp_state[:, npos:, :].reshape(bt * nneg, -1)
            noise = torch.randn_like(pos_state) * self.eps
            pos_state_noisy = pos_state.clone()  # clone pos_state to avoid in-place operations
            pos_state_noisy[noise_mask] += noise[noise_mask]  # add noise where the mask is True
            new_state = torch.cat([pos_state_noisy, neg_state], dim=0)
        else:
            new_state = state.reshape(bt * nt, -1)

        # b, 512, 26, 26 (C4)
        a3, a4, a5 = vis
        fq, f5 = self.neck(vis, new_state)
        b, c, h, w = fq.size()
        fq = self.decoder(fq, word, pad_mask)
        metric_tensor = fq
        # # 3. Get metric loss
        # if metric_learning_flag:
        #     metric_loss = AngularMetricLoss(fq, npos, nneg, alpha=self.positive_strength, args=self.cfg)
        fq = fq.reshape(b, c, h, w)

        # b, 1, 104, 104
        pred = self.proj(fq, new_state)
        # print("pred shape : ", pred.shape, " fq shape : ", fq.shape, " new_state shape : ", new_state.shape)
        # breakpoint()
        if self.training:
            if pred.shape[-2:] != target.shape[-2:]:
                target = F.interpolate(target, pred.shape[-2:], mode='nearest').detach()
            # seunghoon: temporary fix that only matches the sizes
            # use pred's own spatial size here; h, w from fq would scramble the reshape below
            b, _, h, w = pred.shape
            assert (pred.shape == target.shape), "pred shape and target shape should be same"
            pred = pred.reshape(-1, 6, h, w)[:, :2, :, :]
            # pred_neg = pred.reshape(-1, 6, h, w)[:, 2:, :, :]
            target = target.reshape(-1, 6, h, w)[:, :2, :, :]
            # target_neg = target.reshape(-1, 6, h, w)[:, 2:, :, :]
            CE_loss_pos = F.binary_cross_entropy_with_logits(pred, target)
            # loss_neg = nn.MSELoss()(pred_neg, torch.zeros_like(pred_neg))
            # loss = loss_pos + loss_neg / (target_neg.shape[-1])**2
            return pred.detach(), target, CE_loss_pos, metric_tensor
        else:
            # print(self.cfg.gpu, f"; loss = {loss}")
            return pred.detach()

    ## Original code
    # if self.training:
    #     if pred.shape[-2:] != target.shape[-2:]:
    #         target = F.interpolate(target, pred.shape[-2:],
    #                                mode='nearest').detach()
    #     # seunghoon: temporary fix that only matches the sizes
    #     # b, _, h, w = pred.shape
    #     assert (pred.shape == target.shape), "pred shape and target shape should be same"
    #     pred = pred.reshape(-1, 6, h, w)[:, :2, :, :]
    #     # pred_neg = pred.reshape(-1, 6, h, w)[:, 2:, :, :]
    #     target = target.reshape(-1, 6, h, w)[:, :2, :, :]
    #     # target_neg = target.reshape(-1, 6, h, w)[:, 2:, :, :]
    #     CE_loss_pos = F.binary_cross_entropy_with_logits(pred, target)
    #     # loss_neg = nn.MSELoss()(pred_neg, torch.zeros_like(pred_neg))
    #     # loss = loss_pos + loss_neg / (target_neg.shape[-1])**2
    #
    #     # 4. if metric learning, add metric loss and normalize
    #     if metric_learning_flag:
    #         # print("CE loss : ", CE_loss_pos, "metric loss : ", metric_loss)
    #         loss = (CE_loss_pos + self.metric_loss_weight * metric_loss) / (1 + self.metric_loss_weight)
    #         # DDP error handling: if there is no negative (BS = 1 or 0 for some GPUs),
    #         # connect the graph to avoid an error
    #         safety_loss = loss * 0.
    #         loss = loss + safety_loss
    #         # print(self.cfg.gpu, f"; loss = {loss}")
    #         return pred.detach(), target, loss
    #     else:
    #         # print(self.cfg.gpu, f"; loss = {loss}")
    #         return pred.detach()
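# Hedged usage sketch (illustration only, with dummy tensors): how the training
# branch of forward() slices the two positive sentence predictions out of the
# six per-image channels before computing the BCE loss. Shapes follow the
# comments in forward; the 104x104 output size is an assumption based on the
# default CRIS configuration.
if __name__ == "__main__":
    b, nsent, npos, h, w = 2, 6, 2, 104, 104
    pred = torch.randn(b * nsent, 1, h, w)                        # projector output: (6b, 1, h, w)
    target = torch.randint(0, 2, (b * nsent, 1, h, w)).float()    # dummy masks, duplicated per sentence
    pred_pos = pred.reshape(-1, nsent, h, w)[:, :npos, :, :]      # (b, 2, h, w)
    target_pos = target.reshape(-1, nsent, h, w)[:, :npos, :, :]  # (b, 2, h, w)
    loss = F.binary_cross_entropy_with_logits(pred_pos, target_pos)
    print(pred_pos.shape, target_pos.shape, float(loss))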