# VerbCentric-RIS / model_/segmenter_verbonly_hardneg.py
# Uploaded by dianecy using huggingface_hub (commit da6d0ff, verified).
import math
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

from model_.clip import build_model
from .layers import FPN, Projector, TransformerDecoder
class CRIS_VerbOnly(nn.Module):
    """CRIS-style segmenter trained with verb-centric hard positives/negatives.

    In training mode each sample can be expanded with a hard-positive and a
    hard-negative sentence (see ``_build_training_batch``).  The segmentation
    BCE loss is computed on the rows selected by ``cl_masks``, and an optional
    angular-margin contrastive loss over decoder features can be blended in.
    """

    def __init__(self, cfg):
        super().__init__()
        # Vision & text encoder, built from the state dict of a TorchScript
        # CLIP checkpoint.
        clip_model = torch.jit.load(cfg.clip_pretrain, map_location="cpu").eval()
        self.backbone = build_model(clip_model.state_dict(), cfg.word_len, cfg.freeze).float()
        # Multi-modal FPN
        self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # Decoder
        self.decoder = TransformerDecoder(
            num_layers=cfg.num_layers,
            d_model=cfg.vis_dim,
            nhead=cfg.num_head,
            dim_ffn=cfg.dim_ffn,
            dropout=cfg.dropout,
            return_intermediate=cfg.intermediate,
        )
        # Projector
        self.proj = Projector(cfg.word_dim, cfg.vis_dim // 2, 3)
        # NOTE(review): metric learning is hard-disabled here rather than read
        # from cfg.metric_learning; flip deliberately to enable the ACL loss.
        self.metric_learning = False
        self.metric_loss_weight = cfg.metric_loss_weight
        self.cfg = cfg

    @staticmethod
    def _pad_mask(tokens):
        """Return a bool mask that is True where the token id is 0 (padding)."""
        return tokens == 0

    @staticmethod
    def _has_tokens(tokens):
        """True if ``tokens`` is non-empty and its token ids sum to more than 0."""
        return tokens.numel() > 0 and tokens.sum().item() > 0

    def _build_training_batch(self, image, text, target, hardpos, hardneg):
        """Expand the batch with hard-positive / hard-negative sentences.

        For every sample ``idx`` the original (image, text, target) row is kept,
        then:
          * if both ``hardpos[idx]`` and ``hardneg[idx]`` contain tokens, a
            hard-positive row (same image and target) and a hard-negative row
            (same image, all-zero target) are appended right after it;
          * if only ``hardpos[idx]`` contains tokens, just the hard-positive row
            is appended;
          * otherwise nothing is appended.

        Returns:
            ``(sentences, images, targets, pad_masks, posverb_mask,
            negverb_mask, cl_masks)``. The last three are 1-D bool CPU tensors
            aligned with the expanded batch:
              * ``posverb_mask`` marks original and hard-positive rows of
                expanded samples;
              * ``negverb_mask`` marks hard-negative rows;
              * ``cl_masks`` marks rows that enter the segmentation BCE loss.
        """
        sentences, images, targets, pad_masks = [], [], [], []
        posverb_mask, negverb_mask, cl_masks = [], [], []
        for idx in range(len(text)):
            img, sent, tgt = image[idx], text[idx], target[idx]
            hp, hn = hardpos[idx], hardneg[idx]

            sentences.append(sent)
            images.append(img)
            targets.append(tgt)
            pad_masks.append(self._pad_mask(sent))

            if not self._has_tokens(hp):
                # No hard positive (and hence no hard negative): original only.
                posverb_mask.append(0)
                negverb_mask.append(0)
                cl_masks.append(1)
                continue

            if self._has_tokens(hn):
                # Original + hard positive + hard negative.
                posverb_mask.extend([1, 1, 0])
                negverb_mask.extend([0, 0, 1])
                # With hn_celoss the hard negative is also supervised (against
                # an empty mask); otherwise only the original row is.
                cl_masks.extend([1, 0, 1] if self.cfg.hn_celoss else [1, 0, 0])
                sentences.extend([hp, hn])
                images.extend([img, img])
                targets.extend([tgt, torch.zeros_like(tgt)])
                pad_masks.extend([self._pad_mask(hp), self._pad_mask(hn)])
            else:
                # Original + hard positive only.
                posverb_mask.extend([1, 1])
                negverb_mask.extend([0, 0])
                cl_masks.extend([1, 0])
                sentences.append(hp)
                images.append(img)
                targets.append(tgt)
                pad_masks.append(self._pad_mask(hp))

        return (
            torch.stack(sentences),
            torch.stack(images),
            torch.stack(targets),
            torch.stack(pad_masks),
            torch.tensor(posverb_mask, dtype=torch.bool),
            torch.tensor(negverb_mask, dtype=torch.bool),
            torch.tensor(cl_masks, dtype=torch.bool),
        )

    def forward(self, image, text, target=None, hardpos=None, hardneg=None):
        """Run the segmenter.

        Args:
            image: b, 3, h, w
            text: b, words (token ids; id 0 is treated as padding)
            target: b, 1, h, w (training only)
            hardpos: per-sample hard-positive token ids (training only)
            hardneg: per-sample hard-negative token ids (training only)

        Returns:
            Training mode: ``(pred.detach(), targets, loss)``, where ``targets``
            is the expanded (and, if needed, resized) target batch.
            Eval mode: ``pred.detach()``.
        """
        if self.training:
            (sentences, images, targets, pad_masks,
             posverb_mask, negverb_mask, cl_masks) = self._build_training_batch(
                image, text, target, hardpos, hardneg)
        else:
            sentences = text
            images = image
            targets = target
            pad_masks = self._pad_mask(text)

        # Encoding images and text
        vis = self.backbone.encode_image(images)
        word, state = self.backbone.encode_text(sentences)

        # FPN neck and decoder
        fq, _ = self.neck(vis, state)
        b, c, h, w = fq.size()
        fq = self.decoder(fq, word, pad_masks)
        metric_tensor = fq  # b, c, h*w
        fq = fq.reshape(b, c, h, w)

        # Final prediction
        pred = self.proj(fq, state)

        if not self.training:
            return pred.detach()

        if pred.shape[-2:] != targets.shape[-2:]:
            targets = F.interpolate(targets, pred.shape[-2:], mode='nearest').detach()
        seg_rows = cl_masks.to(pred.device)
        loss = F.binary_cross_entropy_with_logits(pred[seg_rows], targets[seg_rows])
        if self.metric_learning:
            metric_loss = self.compute_metric_loss(metric_tensor, posverb_mask, negverb_mask, args=self.cfg)
            loss = (loss + self.metric_loss_weight * metric_loss) / (1 + self.metric_loss_weight)
        return pred.detach(), targets, loss

    def compute_metric_loss(self, metric_tensor, positive_verbs, negative_verbs, args):
        """Dispatch to the contrastive loss selected by ``args.loss_option``.

        ``"ACL_verbonly"`` restricts the loss to rows flagged in
        ``positive_verbs`` or ``negative_verbs``. ``"ACL"`` uses every row of
        ``metric_tensor``.

        Raises:
            ValueError: if ``args.loss_option`` is not one of the options above.
        """
        verbonly_by_option = {"ACL_verbonly": True, "ACL": False}
        if args.loss_option not in verbonly_by_option:
            raise ValueError(f"Unsupported loss_option: {args.loss_option!r}")
        return self.UniAngularContrastLoss(
            metric_tensor, positive_verbs, negative_verbs,
            m=args.margin_value, tau=args.temperature,
            verbonly=verbonly_by_option[args.loss_option], args=args,
        )

    def return_mask(self, emb_distance, positive_verbs, negative_verbs, verb_mask):
        """Build positive / negative / hard-negative masks for a square matrix.

        Args:
            emb_distance: (B_, B_) tensor. Only its shape, dtype and device are
                used.
            positive_verbs, negative_verbs: per-row 0/1 flags for the full
                expanded batch.
            verb_mask: bool mask over the full batch. When ``B_`` is smaller
                than ``len(verb_mask)``, the flags are first filtered by it so
                they line up with the rows of ``emb_distance``.

        Returns:
            ``(positive_mask, negative_mask, hard_negative_mask)``, each shaped
            and typed like ``emb_distance``:
              * positive_mask: 1 on the diagonal of every non-hard-negative
                row, plus both directions of each consecutive pair (i, i+1) of
                positive-flagged rows.
              * hard_negative_mask: 1 on the full row and column of every
                hard-negative row, with a zero diagonal.
              * negative_mask: 1 on every entry that is off-diagonal, not
                positive and not in the hard-negative mask.
        """
        n = emb_distance.shape[0]
        positive_verbs = torch.as_tensor(positive_verbs, dtype=torch.bool)
        negative_verbs = torch.as_tensor(negative_verbs, dtype=torch.bool)
        if n < len(verb_mask):
            # Rows of emb_distance correspond only to verb-flagged samples.
            keep = torch.as_tensor(verb_mask, dtype=torch.bool, device=positive_verbs.device)
            positive_verbs = positive_verbs[keep]
            negative_verbs = negative_verbs[keep]
        pos_flags = positive_verbs.tolist()
        neg_flags = negative_verbs.tolist()

        positive_mask = torch.zeros_like(emb_distance)
        negative_mask = torch.ones_like(emb_distance)
        hard_negative_mask = torch.zeros_like(emb_distance)
        positive_mask.fill_diagonal_(1)

        hard_neg_rows = [i for i in range(n) if neg_flags[i]]
        if hard_neg_rows:
            rows = torch.tensor(hard_neg_rows, device=emb_distance.device)
            # A hard negative is neither its own positive nor a plain negative.
            positive_mask[rows, rows] = 0
            negative_mask[rows, rows] = 0
            # Its whole row and column become hard-negative entries...
            hard_negative_mask[rows, :] = 1
            hard_negative_mask[:, rows] = 1
            # ...except the diagonal.
            hard_negative_mask.fill_diagonal_(0)

        # Pair consecutive positive-flagged rows, e.g. (original, hard positive).
        i = 0
        while i < n - 1:
            if pos_flags[i] and pos_flags[i + 1]:
                positive_mask[i, i + 1] = 1
                positive_mask[i + 1, i] = 1
                i += 2
            else:
                i += 1

        negative_mask = negative_mask - positive_mask
        negative_mask[hard_negative_mask.bool()] = 0
        return positive_mask, negative_mask, hard_negative_mask

    def UniAngularContrastLoss(self, total_fq, positive_verbs, negative_verbs, m=0.5, tau=0.05, verbonly=True, args=None):
        """Angular-margin contrastive loss over features pooled on the last dim.

        Args:
            total_fq: features whose last dim is averaged away. Rows must line
                up with ``positive_verbs`` / ``negative_verbs``. Presumably the
                (B, C, H*W) decoder output; confirm at the call site.
            positive_verbs, negative_verbs: per-row 0/1 flags (see
                ``return_mask``).
            m: angular margin in degrees, added to the angle of every positive
                pair.
            tau: softmax temperature.
            verbonly: if True, only rows flagged positive or negative are used.
            args: config object. ``args.acl_hn_weight`` scales the hard-negative
                terms in the normaliser; it defaults to 1.0 when absent.

        Returns:
            Scalar loss: the mean over positive entries (i, j) of
            ``log Z - l_ij``. Here ``l = cos(theta + m) / tau`` on positive
            entries and ``cos(theta) / tau`` elsewhere. ``Z`` is a single
            normaliser for the whole matrix: the sum of ``exp(l)`` over positive
            and negative entries, plus ``acl_hn_weight`` times the sum over
            hard-negative entries. A graph-connected zero is returned when there
            are no positive entries.
        """
        positive_verbs = torch.as_tensor(positive_verbs, dtype=torch.bool)
        negative_verbs = torch.as_tensor(negative_verbs, dtype=torch.bool)
        verb_mask = positive_verbs | negative_verbs
        if verbonly:
            emb = torch.mean(total_fq[verb_mask.to(total_fq.device)], dim=-1)
        else:
            emb = torch.mean(total_fq, dim=-1)  # (B, C)
        B_ = emb.shape[0]
        if B_ == 0:
            # Nothing to contrast; a NaN here would poison the total loss.
            return total_fq.sum() * 0.0

        # Pairwise cosine similarity by broadcasting (B_, 1, C) against (1, B_, C).
        sim_matrix = F.cosine_similarity(emb.unsqueeze(1), emb.unsqueeze(0), dim=-1, eps=1e-6)
        # Keep acos() strictly inside (-1, 1), where its derivative is finite.
        # The epsilon must be representable in the working dtype (1 - 1e-10
        # rounds to exactly 1.0 in float32), so use the dtype's machine epsilon.
        clamp_eps = torch.finfo(sim_matrix.dtype).eps
        sim_matrix = torch.clamp(sim_matrix, min=-1 + clamp_eps, max=1 - clamp_eps)

        # TODO: a ranking-based hardness weighting (cdist distances / KLD) was
        # sketched here but never implemented.

        positive_mask, negative_mask, hard_negative_mask = self.return_mask(
            sim_matrix, positive_verbs, negative_verbs, verb_mask)
        assert positive_mask.shape == sim_matrix.shape, (
            f"Positive mask shape {positive_mask.shape} does not match "
            f"sim_matrix shape {sim_matrix.shape}")
        pos = positive_mask.bool()
        neg = negative_mask.bool()
        hard_neg = hard_negative_mask.bool()

        # Additive angular margin on positive pairs: cos(theta) -> cos(theta + m).
        margin_rad = m * (math.pi / 180.0)
        sim_with_margin = sim_matrix.clone()
        sim_with_margin[pos] = torch.cos(torch.acos(sim_matrix[pos]) + margin_rad)
        logits = sim_with_margin / tau

        pos_logits = logits[pos]
        if pos_logits.numel() == 0:
            return total_fq.sum() * 0.0

        # log Z in log-space (logsumexp) so small tau cannot overflow exp().
        terms = [pos_logits, logits[neg]]
        hn_weight = getattr(args, "acl_hn_weight", 1.0)
        if hn_weight > 0:
            terms.append(logits[hard_neg] + math.log(hn_weight))
        log_total = torch.logsumexp(torch.cat(terms), dim=0)
        # -log(exp(l_pos) / Z) == log Z - l_pos
        return (log_total - pos_logits).mean()