import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from model_.clip import build_model

from .layers import FPN, Projector, TransformerDecoder


class CRIS_PosOnly_rev(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # Vision & Text Encoder
        clip_model = torch.jit.load(cfg.clip_pretrain,
                                    map_location="cpu").eval()
        self.backbone = build_model(clip_model.state_dict(), cfg.word_len,
                                    cfg.freeze).float()
        # Multi-Modal FPN
        self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # Decoder
        self.decoder = TransformerDecoder(num_layers=cfg.num_layers,
                                          d_model=cfg.vis_dim,
                                          nhead=cfg.num_head,
                                          dim_ffn=cfg.dim_ffn,
                                          dropout=cfg.dropout,
                                          return_intermediate=cfg.intermediate)
        # Projector
        self.proj = Projector(cfg.word_dim, cfg.vis_dim // 2, 3)
        self.metric_learning = False  # cfg.metric_learning (disabled here)
        self.metric_loss_weight = cfg.metric_loss_weight
        self.cfg = cfg

    def forward(self, image, text, target=None, verb=None):
        '''
        image:  b, 3, h, w
        text:   b, words
        target: b, 1, h, w
        verb:   b, words (training only; used to build contrastive pairs)
        '''
        if self.training:
            # Rebuild the batch so that each sentence with a valid verb
            # phrase is immediately followed by that phrase. verb_masks marks
            # rows that take part in the contrastive loss; cl_masks marks
            # rows that take part in the segmentation (BCE) loss.
            sentences, images, targets, pad_masks = [], [], [], []
            verb_masks, cl_masks = [], []
            for idx in range(len(text)):
                sentences.append(text[idx])
                images.append(image[idx])
                targets.append(target[idx])
                pad_masks.append(text[idx] == 0)  # True at padding tokens
                # If a non-empty verb phrase exists, append it as well
                if verb[idx].numel() > 0 and verb[idx].sum().item() > 0:
                    # Both the sentence and its verb phrase join the
                    # contrastive loss; only the sentence joins the BCE loss.
                    verb_masks.extend([1, 1])
                    cl_masks.extend([1, 0])
                    sentences.append(verb[idx])
                    images.append(image[idx])
                    targets.append(target[idx])
                    pad_masks.append(verb[idx] == 0)
                else:
                    verb_masks.append(0)
                    cl_masks.append(1)
            sentences = torch.stack(sentences)
            images = torch.stack(images)
            targets = torch.stack(targets)
            pad_masks = torch.stack(pad_masks)
            verb_masks = torch.tensor(verb_masks, dtype=torch.bool)
            cl_masks = torch.tensor(cl_masks, dtype=torch.bool)
        else:
            sentences = text
            images = image
            targets = target
            pad_masks = (text == 0)

        # Encode images and text
        vis = self.backbone.encode_image(images)
        word, state = self.backbone.encode_text(sentences)

        # FPN neck and decoder
        fq, f5 = self.neck(vis, state)
        b, c, h, w = fq.size()
        fq = self.decoder(fq, word, pad_masks)
        metric_tensor = fq  # b, c, h*w
        fq = fq.reshape(b, c, h, w)

        # Final prediction
        pred = self.proj(fq, state)

        if self.training:
            if pred.shape[-2:] != targets.shape[-2:]:
                targets = F.interpolate(targets, pred.shape[-2:],
                                        mode='nearest').detach()
            loss = F.binary_cross_entropy_with_logits(pred[cl_masks],
                                                      targets[cl_masks])
            if self.metric_learning:
                metric_loss = self.compute_metric_loss(metric_tensor,
                                                       verb_masks,
                                                       args=self.cfg)
                loss = (loss + self.metric_loss_weight * metric_loss) \
                    / (1 + self.metric_loss_weight)
            return pred[cl_masks].detach(), targets[cl_masks], loss

        return pred.detach()  # In eval mode, only return the predictions

    def compute_metric_loss(self, metric_tensor, verb_mask, args):
        if args.loss_option == "ACL_verbonly":
            metric_loss = self.UniAngularContrastLoss(metric_tensor,
                                                      verb_mask,
                                                      m=args.margin_value,
                                                      tau=args.temperature,
                                                      verbonly=True,
                                                      args=args)
        elif args.loss_option == "ACL":
            metric_loss = self.UniAngularContrastLoss(metric_tensor,
                                                      verb_mask,
                                                      m=args.margin_value,
                                                      tau=args.temperature,
                                                      verbonly=False,
                                                      args=args)
        else:
            raise ValueError(f"Unknown loss_option: {args.loss_option}")
        return metric_loss

    def
 return_mask(self, emb_distance, verb_mask=None):
        B_, _ = emb_distance.shape
        positive_mask = torch.zeros_like(emb_distance)
        positive_mask.fill_diagonal_(1)  # Every row is positive with itself
        if B_ < len(verb_mask):
            # verbonly case: only the verb-masked rows were kept, so
            # B_ == 2*K and rows alternate (sentence, verb phrase).
            for i in range(B_ // 2):
                positive_mask[2 * i, 2 * i + 1] = 1
                positive_mask[2 * i + 1, 2 * i] = 1
        else:
            # Mixed case: a sentence with a verb phrase is immediately
            # followed by it; sentences without one stand alone.
            i = 0
            while i < B_:
                if verb_mask[i] == 1:
                    positive_mask[i, i + 1] = 1
                    positive_mask[i + 1, i] = 1
                    i += 2
                else:
                    i += 1
        negative_mask = torch.ones_like(emb_distance) - positive_mask
        return positive_mask, negative_mask

    def UniAngularContrastLoss(self, total_fq, verb_mask, alpha=0.5,
                               verbonly=True, m=0.5, tau=0.05, args=None):
        # Note: alpha is currently unused.
        _, C, HW = total_fq.shape
        # Pool decoder features over spatial positions to get one embedding
        # per (image, sentence) row.
        if verbonly:
            emb = torch.mean(total_fq[verb_mask], dim=-1)
            assert emb.shape[0] % 2 == 0, \
                f"Embedding count {emb.shape[0]} is not divisible by 2."
        else:
            emb = torch.mean(total_fq, dim=-1)
        B_ = emb.shape[0]

        # emb = F.normalize(emb, p=2, dim=1)
        emb_i = emb.unsqueeze(1).repeat(1, B_, 1)  # (B_, B_, C)
        emb_j = emb.unsqueeze(0).repeat(B_, 1, 1)  # (B_, B_, C)
        sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
        sim_matrix = sim(emb_i, emb_j).reshape(B_, B_)  # (B_, B_)
        sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)

        positive_mask, negative_mask = self.return_mask(sim_matrix, verb_mask)
        pos = positive_mask.bool()

        # Apply the additive angular margin (given in degrees) to positive
        # pairs: cos(theta) -> cos(theta + m).
        sim_matrix_with_margin = sim_matrix.clone()
        sim_matrix_with_margin[pos] = torch.cos(
            torch.acos(sim_matrix[pos]) + math.radians(m))

        # Scale logits with temperature
        logits = sim_matrix_with_margin / tau

        # Row-wise softmax: align each positive entry with the denominator
        # of its own row.
        exp_logits = torch.exp(logits)
        pos_exp_logits = exp_logits[pos]
        total_exp_logits = exp_logits.sum(dim=-1)
        row_idx = pos.nonzero(as_tuple=True)[0]

        # L_arc = -log(e^(cos(theta + m)/tau) / sum_k e^(cos(theta_k)/tau))
        positive_loss = -torch.log(pos_exp_logits / total_exp_logits[row_idx])
        angular_loss = positive_loss.mean()
        return angular_loss
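

# ---------------------------------------------------------------------------
# Construction sketch (assumption-heavy): the cfg fields below are inferred
# from the attribute reads in this file, and the values and checkpoint path
# are placeholders in the spirit of CRIS's RN50 setup, not this project's
# actual config. Building the model also requires a real jit-scripted CLIP
# checkpoint at cfg.clip_pretrain.
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(
#       clip_pretrain="pretrain/RN50.pt",          # hypothetical path
#       word_len=17, freeze=True,
#       fpn_in=[512, 1024, 1024], fpn_out=[256, 512, 1024],
#       num_layers=3, vis_dim=512, num_head=8, dim_ffn=2048,
#       dropout=0.1, intermediate=False, word_dim=1024,
#       metric_loss_weight=0.15,
#       loss_option="ACL_verbonly", margin_value=0.5, temperature=0.05,
#   )
#   model = CRIS_PosOnly_rev(cfg)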
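

# ---------------------------------------------------------------------------
# Minimal self-contained sketch of the angular-margin contrastive step above,
# run on random stand-in embeddings (assumption: a rebuilt batch of 6 rows
# where rows 0/1 and 3/4 are (sentence, verb) pairs, matching return_mask's
# mixed-case layout). Purely illustrative; it does not touch the model or
# real decoder features.
if __name__ == "__main__":
    torch.manual_seed(0)
    emb = torch.randn(6, 16)  # stand-in for spatially pooled decoder features

    # Pairwise cosine similarity, clamped as in UniAngularContrastLoss
    sim = F.cosine_similarity(emb.unsqueeze(1), emb.unsqueeze(0), dim=-1)
    sim = torch.clamp(sim, -0.9999, 0.9999)

    # Positive mask: diagonal plus the (0, 1) and (3, 4) sentence/verb pairs,
    # i.e. what return_mask would build for verb_mask = [1, 1, 0, 1, 1, 0].
    pos = torch.eye(6)
    for i in (0, 3):
        pos[i, i + 1] = pos[i + 1, i] = 1.0
    pos = pos.bool()

    # Additive angular margin (degrees) on positives, then temperature scaling
    m_deg, tau = 0.5, 0.05
    sim_m = sim.clone()
    sim_m[pos] = torch.cos(torch.acos(sim[pos]) + math.radians(m_deg))

    # Row-wise softmax loss over the positive entries
    exp_logits = torch.exp(sim_m / tau)
    rows = pos.nonzero(as_tuple=True)[0]
    loss = -torch.log(exp_logits[pos] / exp_logits.sum(-1)[rows])
    print("toy angular contrastive loss:", loss.mean().item())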