VerbCentric-RIS / model_ /segmenter_verbonly_fin.py
dianecy's picture
Upload folder using huggingface_hub
da6d0ff verified
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from model_.clip import build_model
from .layers import FPN, Projector, TransformerDecoder
class CRIS_PosOnly_rev(nn.Module):
    """CRIS-style referring segmentation model with optional verb-phrase augmentation.

    In training mode every sample whose ``verb`` tokens are non-empty is
    duplicated: the original sentence and its verb phrase are both run through
    the network against the same image and target.  Only rows coming from the
    original sentences contribute to the segmentation (BCE) loss; the
    sentence/verb pairing is recorded in ``verb_masks`` for the angular
    contrastive metric loss, which is currently hard-disabled in ``__init__``.
    """

    def __init__(self, cfg):
        super().__init__()
        # Vision & text encoder, rebuilt from a TorchScript checkpoint's state dict.
        clip_model = torch.jit.load(cfg.clip_pretrain, map_location="cpu").eval()
        self.backbone = build_model(clip_model.state_dict(), cfg.word_len, cfg.freeze).float()
        # Multi-modal FPN
        self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # Decoder
        self.decoder = TransformerDecoder(
            num_layers=cfg.num_layers,
            d_model=cfg.vis_dim,
            nhead=cfg.num_head,
            dim_ffn=cfg.dim_ffn,
            dropout=cfg.dropout,
            return_intermediate=cfg.intermediate,
        )
        # Projector
        self.proj = Projector(cfg.word_dim, cfg.vis_dim // 2, 3)
        # Metric learning is deliberately forced off; cfg.metric_learning is ignored.
        self.metric_learning = False  # cfg.metric_learning
        self.metric_loss_weight = cfg.metric_loss_weight
        self.cfg = cfg

    @staticmethod
    def _pad_mask(tokens):
        # Token id 0 marks padding; True where the position should be ignored.
        return tokens == 0

    def _build_training_batch(self, image, text, target, verb):
        """Expand the batch with verb-phrase rows.

        Returns ``(sentences, images, targets, pad_masks, verb_masks, cl_masks)``.
        For a sample with a usable verb phrase, two consecutive rows are emitted
        (sentence, then verb) with ``verb_masks`` = [1, 1] and ``cl_masks`` = [1, 0].
        Otherwise a single row is emitted with ``verb_masks`` = [0] and ``cl_masks`` = [1].
        The two bool masks live on ``image.device`` so they can index model outputs directly.
        """
        sentences, images, targets, pad_masks = [], [], [], []
        verb_masks, cl_masks = [], []
        for idx, (txt, img, tgt) in enumerate(zip(text, image, target)):
            sentences.append(txt)
            images.append(img)
            targets.append(tgt)
            pad_masks.append(self._pad_mask(txt))
            # Use the verb phrase only if one was given and its token ids sum to > 0.
            has_verb = verb is not None and verb[idx].numel() > 0 and verb[idx].sum().item() > 0
            if has_verb:
                verb_masks.extend([1, 1])  # Both original sentence and verb are marked
                cl_masks.extend([1, 0])  # Only the original sentence enters the BCE loss
                sentences.append(verb[idx])
                images.append(img)
                targets.append(tgt)
                pad_masks.append(self._pad_mask(verb[idx]))
            else:
                verb_masks.append(0)
                cl_masks.append(1)
        device = image.device
        return (
            torch.stack(sentences),
            torch.stack(images),
            torch.stack(targets),
            torch.stack(pad_masks),
            torch.tensor(verb_masks, dtype=torch.bool, device=device),
            torch.tensor(cl_masks, dtype=torch.bool, device=device),
        )

    def forward(self, image, text, target=None, verb=None):
        '''
        image: b, 3, h, w
        text: b, words
        target: b, 1, h, w
        verb: b, words (if applicable, only used in training mode for contrastive learning)

        Training mode returns ``(pred, targets, loss)`` restricted to rows that come
        from original sentences; eval mode returns the detached ``pred`` only.
        '''
        if self.training:
            (sentences, images, targets, pad_masks,
             verb_masks, cl_masks) = self._build_training_batch(image, text, target, verb)
        else:
            sentences, images = text, image
            pad_masks = self._pad_mask(text)

        # Encoding images and text
        vis = self.backbone.encode_image(images)
        word, state = self.backbone.encode_text(sentences)

        # FPN neck and decoder
        fq, _ = self.neck(vis, state)
        b, c, h, w = fq.size()
        fq = self.decoder(fq, word, pad_masks)
        metric_tensor = fq  # b, c, h*w (flattened decoder output, reshaped below)
        fq = fq.reshape(b, c, h, w)

        # Final prediction
        pred = self.proj(fq, state)

        if not self.training:
            return pred.detach()  # In eval mode, only return the predictions

        if pred.shape[-2:] != targets.shape[-2:]:
            targets = F.interpolate(targets, pred.shape[-2:], mode='nearest').detach()
        loss = F.binary_cross_entropy_with_logits(pred[cl_masks], targets[cl_masks])
        if self.metric_learning:
            metric_loss = self.compute_metric_loss(metric_tensor, verb_masks, args=self.cfg)
            loss = (loss + self.metric_loss_weight * metric_loss) / (1 + self.metric_loss_weight)
        return pred[cl_masks].detach(), targets[cl_masks], loss

    def compute_metric_loss(self, metric_tensor, positive_verbs, negative_verbs=None, args=None):
        """Dispatch to the angular contrastive loss selected by ``args.loss_option``.

        Args:
            metric_tensor: decoder features, indexed as (rows, C, HW) by the loss.
            positive_verbs: per-row bool mask pairing sentences with verb phrases.
            negative_verbs: unused; accepted only for backward compatibility.
            args: config with ``loss_option``, ``margin_value`` and ``temperature``;
                defaults to ``self.cfg``.

        Raises:
            ValueError: if ``args.loss_option`` is neither "ACL_verbonly" nor "ACL".
        """
        if args is None:
            args = self.cfg
        if args.loss_option == "ACL_verbonly":
            verbonly = True
        elif args.loss_option == "ACL":
            verbonly = False
        else:
            raise ValueError(f"Unsupported loss_option: {args.loss_option!r}")
        return self.UniAngularContrastLoss(
            metric_tensor, positive_verbs,
            m=args.margin_value, tau=args.temperature, verbonly=verbonly, args=args,
        )

    def return_mask(self, emb_distance, verb_mask=None):
        """Build positive/negative masks for a square (n, n) similarity matrix.

        The diagonal is always positive.  If ``verb_mask`` is longer than ``n``,
        the rows are assumed to be a verb-only subset made of consecutive
        (sentence, verb) pairs.  Otherwise rows follow the ``verb_masks`` layout
        produced in training, where a 1 starts a two-row pair.  With
        ``verb_mask=None`` only the diagonal is positive.

        Returns ``(positive_mask, negative_mask)`` with ``negative = 1 - positive``.
        """
        n = emb_distance.shape[0]
        positive_mask = torch.zeros_like(emb_distance)
        positive_mask.fill_diagonal_(1)  # Set diagonal elements to 1 for all cases
        if verb_mask is not None:
            if n < len(verb_mask):
                # Verb-only subset: rows (2i, 2i+1) are a sentence and its verb phrase.
                for i in range(n // 2):
                    positive_mask[2 * i, 2 * i + 1] = 1
                    positive_mask[2 * i + 1, 2 * i] = 1
            else:
                # Mix of sentences with and without verbs: a marked row starts a pair.
                i = 0
                while i < n:
                    if verb_mask[i] == 1:
                        positive_mask[i, i + 1] = 1
                        positive_mask[i + 1, i] = 1
                        i += 2
                    else:
                        i += 1
        negative_mask = torch.ones_like(emb_distance) - positive_mask
        return positive_mask, negative_mask

    def UniAngularContrastLoss(self, total_fq, verb_mask, alpha=0.5, verbonly=True, m=0.5, tau=0.05, args=None):
        """Angular-margin contrastive loss over pooled decoder features.

        For every positive pair (i, j) from ``return_mask`` (including i == j):
            loss_ij = -log(exp(l_ij) / sum_k exp(l_ik)),
            l = cos(theta + m) / tau on positives and cos(theta) / tau elsewhere.
        The denominator uses the margin-adjusted logits.  The result is the mean
        over all positives.

        Args:
            total_fq: features indexed as (rows, C, HW); mean-pooled over the last dim.
            verb_mask: per-row bool mask; in ``verbonly`` mode it selects the rows used.
            alpha, args: unused; kept for interface compatibility.
            verbonly: restrict to rows selected by ``verb_mask``.
            m: angular margin in degrees.
            tau: temperature.

        Returns a scalar tensor.  It is 0 (still attached to the graph) when
        verb-only mode selects no rows.

        Raises:
            ValueError: if verb-only mode selects an odd number of rows.
        """
        if verbonly:
            total_fq = total_fq[verb_mask]
            if total_fq.shape[0] % 2 != 0:
                raise ValueError(f"Embedding count {total_fq.shape[0]} is not divisible by 2.")
            if total_fq.shape[0] == 0:
                # No verb pairs in this batch: nothing to contrast (avoids mean-of-empty NaN).
                return total_fq.sum() * 0.0
        emb = total_fq.mean(dim=-1)  # (n, C)
        n = emb.shape[0]
        # Pairwise cosine similarity via views (no materialized copies).
        sim_matrix = F.cosine_similarity(
            emb.unsqueeze(1).expand(n, n, -1),
            emb.unsqueeze(0).expand(n, n, -1),
            dim=-1,
            eps=1e-6,
        )
        # Keep acos away from +/-1, where its gradient is infinite.
        sim_matrix = sim_matrix.clamp(min=-0.9999, max=0.9999)
        positive_mask, _ = self.return_mask(sim_matrix, verb_mask)
        pos = positive_mask.bool()

        # Apply the angular margin to positive pairs (m / 57.2958 converts degrees to radians).
        sim_with_margin = sim_matrix.clone()
        sim_with_margin[pos] = torch.cos(torch.acos(sim_matrix[pos]) + m / 57.2958)
        logits = sim_with_margin / tau

        # Row-wise log-softmax keeps each positive paired with its own row's
        # denominator (a row may hold several positives) and is numerically stable.
        log_prob = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
        return -log_prob[pos].mean()