# VerbCentric-RIS/model/segmenter_verbonly_hardneg.py
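"""CRIS-style referring image segmentation with verb-centric hard negatives.

Training batches are expanded so that samples whose expression carries a verb
phrase contribute an extra row; those rows are excluded from the BCE mask loss
and reserved for the angular-margin contrastive loss defined below (disabled
by default via `metric_learning`).
"""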
import math
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

from model.clip import build_model

from .layers import FPN, Projector, TransformerDecoder

class CRIS_VerbOnly(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # Vision & Text Encoder (CLIP weights loaded from a TorchScript checkpoint)
        clip_model = torch.jit.load(cfg.clip_pretrain,
                                    map_location="cpu").eval()
        self.backbone = build_model(clip_model.state_dict(), cfg.word_len).float()
        # Multi-Modal FPN
        self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # Decoder
        self.decoder = TransformerDecoder(num_layers=cfg.num_layers,
                                          d_model=cfg.vis_dim,
                                          nhead=cfg.num_head,
                                          dim_ffn=cfg.dim_ffn,
                                          dropout=cfg.dropout,
                                          return_intermediate=cfg.intermediate)
        # Projector
        self.proj = Projector(cfg.word_dim, cfg.vis_dim // 2, 3)
        # Metric learning is hard-coded off in this release;
        # compute_metric_loss below is the hook it would call.
        self.metric_learning = False  # cfg.metric_learning
        self.positive_strength = cfg.positive_strength
        self.metric_loss_weight = cfg.metric_loss_weight
        self.eps = cfg.ptb_rate
        self.cfg = cfg
    def forward(self, image, text, target=None, hardpos=None, hardneg=None):
        '''
        image:   b, 3, h, w
        text:    b, words
        target:  b, 1, h, w
        hardpos: b, words -- hard-positive verb phrases, used only in training
                 mode for contrastive learning
        hardneg: b, words -- hard-negative verb phrases; accepted by the
                 signature but unused in this release
        '''
        sentences, images, targets, pad_masks = [], [], [], []
        # Hard-positive / hard-negative bookkeeping; initialised but never
        # populated or consumed in this release.
        positive_verbs, negative_verbs = [], []
        if self.training:
            # The hard-positive verb phrases are referred to as `verb` below.
            verb = hardpos
            verb_masks = []
            cl_masks = []
            for idx in range(len(text)):
                sentences.append(text[idx])
                images.append(image[idx])
                targets.append(target[idx])
                pad_masks.append(torch.zeros_like(text[idx]).masked_fill_(text[idx] == 0, 1).bool())
                # If a verb phrase exists, append it as an extra batch row that
                # shares this sample's image and target.
                if verb[idx].numel() > 0 and verb[idx].sum().item() > 0:
                    verb_masks.extend([1, 1])  # both the sentence and its verb are marked
                    cl_masks.extend([0, 1])    # only the verb row is excluded from the CE loss
                    sentences.append(verb[idx])
                    images.append(image[idx])
                    targets.append(target[idx])
                    pad_masks.append(torch.zeros_like(verb[idx]).masked_fill_(verb[idx] == 0, 1).bool())
                else:
                    verb_masks.append(0)
                    cl_masks.append(0)
            sentences = torch.stack(sentences)
            images = torch.stack(images)
            targets = torch.stack(targets)
            pad_masks = torch.stack(pad_masks)
            verb_masks = torch.tensor(verb_masks, dtype=torch.bool)
            cl_masks = torch.tensor(cl_masks, dtype=torch.bool)
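            # Resulting layout, e.g. for a batch of 2 where only sample 0
            # carries a verb phrase: rows = [sent0, verb0, sent1],
            # verb_masks = [1, 1, 0], cl_masks = [0, 1, 0].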
        else:
            sentences = text
            images = image
            targets = target
            pad_masks = torch.zeros_like(text).masked_fill_(text == 0, 1).bool()
        # Encoding images and text
        vis = self.backbone.encode_image(images)
        word, state = self.backbone.encode_text(sentences)

        # FPN neck and decoder
        fq, f5 = self.neck(vis, state)
        b, c, h, w = fq.size()
        fq = self.decoder(fq, word, pad_masks)
        metric_tensor = fq  # b, c, h*w
        fq = fq.reshape(b, c, h, w)

        # Final prediction
        pred = self.proj(fq, state)
        if self.training:
            if pred.shape[-2:] != targets.shape[-2:]:
                targets = F.interpolate(targets, pred.shape[-2:], mode='nearest').detach()
            # The CE loss sees only rows not excluded via cl_masks, i.e. the
            # original sentences, not the appended verb rows.
            loss = F.binary_cross_entropy_with_logits(pred[~cl_masks], targets[~cl_masks])
            if self.metric_learning:
                metric_loss = self.compute_metric_loss(metric_tensor, verb_masks, args=self.cfg)
                loss = (loss + self.metric_loss_weight * metric_loss) / (1 + self.metric_loss_weight)
            return pred.detach(), targets, loss
        return pred.detach()  # In eval mode, only return the predictions
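    # Typical training call (a sketch; shapes follow the forward() docstring,
    # and `img`, `txt`, `mask`, `verb_tokens` are assumed to come from the
    # surrounding data pipeline):
    #   model = CRIS_VerbOnly(cfg).train()
    #   pred, tgt, loss = model(img, txt, target=mask, hardpos=verb_tokens)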
    def compute_metric_loss(self, metric_tensor, verb_mask, args):
        # Disabled hook (self.metric_learning is hard-coded to False above).
        # Note that returning None would break the weighted sum in forward()
        # if metric learning were ever switched on.
        return None
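    # A minimal sketch of what this hook could compute, assuming the
    # hard-positive / hard-negative indicators were threaded through from the
    # data pipeline (the names below are illustrative, not part of this release):
    #   return self.UniAngularContrastLoss(metric_tensor, positive_verbs,
    #                                      negative_verbs, args=args)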
    def return_mask(self, emb_distance, positive_verbs, negative_verbs, verb_mask):
        B_ = emb_distance.shape[0]
        positive_mask = torch.zeros_like(emb_distance)
        negative_mask = torch.ones_like(emb_distance)
        positive_mask.fill_diagonal_(1)

        if B_ < len(verb_mask):
            # The embeddings were already filtered by verb_mask upstream, so
            # filter the indicator vectors the same way to keep them aligned.
            positive_verbs = torch.as_tensor(positive_verbs)[torch.as_tensor(verb_mask, dtype=torch.bool)]
            negative_verbs = torch.as_tensor(negative_verbs)[torch.as_tensor(verb_mask, dtype=torch.bool)]

        # Exclude hard negatives from both masks (diagonal entries).
        for i in range(B_):
            if negative_verbs[i] == 1:
                positive_mask[i, i] = 0
                negative_mask[i, i] = 0

        # Hard positives arrive as adjacent (i, i+1) row pairs (see the batch
        # expansion in forward); mark each pair as positives for each other.
        i = 0
        while i < B_:
            if positive_verbs[i] == 1 and i + 1 < B_ and positive_verbs[i + 1] == 1:
                positive_mask[i, i + 1] = 1
                positive_mask[i + 1, i] = 1
                i += 2
            else:
                i += 1

        negative_mask = negative_mask - positive_mask
        return positive_mask, negative_mask
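    # Example with B_ = 3, positive_verbs = [1, 1, 0], negative_verbs = [0, 0, 1]:
    # rows 0 and 1 become positives for each other, row 2 (a hard negative) is
    # dropped from the diagonal of both masks, and negative_mask keeps the
    # off-diagonal pairs that are not marked positive.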
    def UniAngularContrastLoss(self, total_fq, positive_verbs, negative_verbs, m=0.5, tau=0.05, verbonly=True, args=None):
        """
        Angular-margin contrastive loss over decoder embeddings.
        m is the margin in degrees; tau is the softmax temperature.
        """
        # For bool indicator tensors, `+` acts as an element-wise OR.
        verb_mask = positive_verbs + negative_verbs
        if verbonly:
            emb = torch.mean(total_fq[verb_mask], dim=-1)  # (B_, C)
        else:
            emb = torch.mean(total_fq, dim=-1)  # (B, C)
        B_ = emb.shape[0]
        # emb = F.normalize(emb, p=2, dim=1)
        emb_i = emb.unsqueeze(1).repeat(1, B_, 1)  # (B_, B_, C)
        emb_j = emb.unsqueeze(0).repeat(B_, 1, 1)  # (B_, B_, C)
        sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
        sim_matrix = sim(emb_i, emb_j).reshape(B_, B_)  # (B_, B_)
        sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)

        # Build positive/negative pair masks (negative_mask is unused below).
        positive_mask, negative_mask = self.return_mask(sim_matrix, positive_verbs, negative_verbs, verb_mask)

        # Apply the angular margin to positive pairs: cos(theta) -> cos(theta + m).
        sim_matrix_with_margin = sim_matrix.clone()
        sim_matrix_with_margin[positive_mask.bool()] = torch.cos(
            torch.acos(sim_matrix[positive_mask.bool()]) + math.radians(m))

        # Scale logits with temperature.
        logits = sim_matrix_with_margin / tau

        # L_arc = -log( e^(cos(theta_ij + m)/tau) / sum_k e^(cos(theta_ik)/tau) ),
        # averaged over positive pairs; each positive (i, j) is normalised by
        # its row sum, broadcast to match the masked selection.
        exp_logits = torch.exp(logits)
        row_sums = exp_logits.sum(dim=-1, keepdim=True).expand_as(exp_logits)
        positive_loss = -torch.log(exp_logits[positive_mask.bool()] / row_sums[positive_mask.bool()])
        angular_loss = positive_loss.mean()
        return angular_loss
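

if __name__ == "__main__":
    # Shape-level smoke test for the contrastive loss alone (a sketch:
    # __init__ is bypassed on purpose so no CLIP checkpoint or cfg object
    # is needed; the dummy shapes below are illustrative).
    model = CRIS_VerbOnly.__new__(CRIS_VerbOnly)
    total_fq = torch.randn(4, 512, 26 * 26)             # (B, C, H*W) decoder features
    pos = torch.tensor([1, 1, 0, 0], dtype=torch.bool)  # adjacent hard-positive pair
    neg = torch.tensor([0, 0, 1, 0], dtype=torch.bool)  # one hard negative
    print(model.UniAngularContrastLoss(total_fq, pos, neg))  # scalar loss tensor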