import math
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

from model_.clip import build_model

from .layers import FPN, Projector, TransformerDecoder
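# CRIS-style referring image segmentation model: CLIP backbone + FPN neck +
# transformer decoder, with an optional angular contrastive ("metric") loss
# over verb-augmented positive pairs during training.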

class CRIS_PosOnly_rev(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        # Vision-language backbone: a jit-exported CLIP checkpoint rebuilt as a
        # regular nn.Module so it can be fine-tuned in fp32.
        clip_model = torch.jit.load(cfg.clip_pretrain,
                                    map_location="cpu").eval()
        self.backbone = build_model(clip_model.state_dict(), cfg.word_len,
                                    cfg.freeze).float()

        # Multi-modal FPN neck and query-based transformer decoder.
        self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        self.decoder = TransformerDecoder(num_layers=cfg.num_layers,
                                          d_model=cfg.vis_dim,
                                          nhead=cfg.num_head,
                                          dim_ffn=cfg.dim_ffn,
                                          dropout=cfg.dropout,
                                          return_intermediate=cfg.intermediate)

        # Projects decoder features and the sentence state to a segmentation map.
        self.proj = Projector(cfg.word_dim, cfg.vis_dim // 2, 3)
        self.metric_learning = False
        self.metric_loss_weight = cfg.metric_loss_weight
        self.cfg = cfg

    def forward(self, image, text, target=None, verb=None):
        '''
        image:  b, 3, h, w
        text:   b, words
        target: b, 1, h, w
        verb:   b, words (only used in training mode, where each sample with a
                non-empty verb phrase is duplicated to form a positive pair for
                contrastive learning)
        '''
        sentences, images, targets, pad_masks = [], [], [], []

        if self.training:
            # verb_masks marks entries belonging to a (sentence, verb) positive
            # pair; cl_masks marks the entries used for the segmentation loss
            # (originals only, so duplicated verb entries are excluded).
            verb_masks = []
            cl_masks = []

            for idx in range(len(text)):
                sentences.append(text[idx])
                images.append(image[idx])
                targets.append(target[idx])
                pad_masks.append(torch.zeros_like(text[idx]).masked_fill_(text[idx] == 0, 1).bool())

                if verb[idx].numel() > 0 and verb[idx].sum().item() > 0:
                    # Non-empty verb phrase: append a second entry that shares
                    # the image/target but is conditioned on the verb tokens.
                    verb_masks.extend([1, 1])
                    cl_masks.extend([1, 0])
                    sentences.append(verb[idx])
                    images.append(image[idx])
                    targets.append(target[idx])
                    pad_masks.append(torch.zeros_like(verb[idx]).masked_fill_(verb[idx] == 0, 1).bool())
                else:
                    verb_masks.append(0)
                    cl_masks.append(1)

            sentences = torch.stack(sentences)
            images = torch.stack(images)
            targets = torch.stack(targets)
            pad_masks = torch.stack(pad_masks)
            verb_masks = torch.tensor(verb_masks, dtype=torch.bool)
            cl_masks = torch.tensor(cl_masks, dtype=torch.bool)
        else:
            sentences = text
            images = image
            targets = target
            pad_masks = torch.zeros_like(text).masked_fill_(text == 0, 1).bool()
        # Encode both modalities, fuse with the FPN neck, and decode.
        vis = self.backbone.encode_image(images)
        word, state = self.backbone.encode_text(sentences)

        fq, f5 = self.neck(vis, state)
        b, c, h, w = fq.size()
        fq = self.decoder(fq, word, pad_masks)
        metric_tensor = fq  # b, c, h*w decoder features kept for the metric loss
        fq = fq.reshape(b, c, h, w)

        pred = self.proj(fq, state)

        if self.training:
            if pred.shape[-2:] != targets.shape[-2:]:
                targets = F.interpolate(targets, pred.shape[-2:],
                                        mode='nearest').detach()

            # Segmentation loss on original (non-duplicated) entries only.
            loss = F.binary_cross_entropy_with_logits(pred[cl_masks], targets[cl_masks])

            if self.metric_learning:
                metric_loss = self.compute_metric_loss(metric_tensor, verb_masks,
                                                       args=self.cfg)
                loss = (loss + self.metric_loss_weight * metric_loss) / (1 + self.metric_loss_weight)

            return pred[cl_masks].detach(), targets[cl_masks], loss

        return pred.detach()
    def compute_metric_loss(self, metric_tensor, verb_masks, args):
        # Dispatch on the configured contrastive-loss variant.
        if args.loss_option == "ACL_verbonly":
            metric_loss = self.UniAngularContrastLoss(metric_tensor, verb_masks,
                                                      m=args.margin_value,
                                                      tau=args.temperature,
                                                      verbonly=True, args=args)
        elif args.loss_option == "ACL":
            metric_loss = self.UniAngularContrastLoss(metric_tensor, verb_masks,
                                                      m=args.margin_value,
                                                      tau=args.temperature,
                                                      verbonly=False, args=args)
        else:
            raise ValueError(f"Unknown loss_option: {args.loss_option}")

        return metric_loss
    def return_mask(self, emb_distance, verb_mask=None):
        B_ = emb_distance.shape[0]
        positive_mask = torch.zeros_like(emb_distance)
        positive_mask.fill_diagonal_(1)

        if B_ < len(verb_mask):
            # verbonly mode: the similarity matrix only contains verb entries,
            # which were appended in consecutive (sentence, verb) pairs.
            for i in range(B_ // 2):
                positive_mask[2 * i, 2 * i + 1] = 1
                positive_mask[2 * i + 1, 2 * i] = 1
        else:
            # Full batch: pair up adjacent entries wherever verb_mask marks a
            # (sentence, verb) pair; unpaired entries keep only the diagonal.
            i = 0
            while i < B_:
                if verb_mask[i] == 1:
                    positive_mask[i, i + 1] = 1
                    positive_mask[i + 1, i] = 1
                    i += 2
                else:
                    i += 1

        negative_mask = torch.ones_like(emb_distance) - positive_mask
        return positive_mask, negative_mask
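
    # Illustration (hypothetical toy batch): in full-batch mode with
    # verb_mask = [1, 1, 0], entries 0 and 1 form a positive pair and entry 2
    # is positive only with itself:
    #
    #     positive_mask = [[1, 1, 0],
    #                      [1, 1, 0],
    #                      [0, 0, 1]]
    #
    # negative_mask is the elementwise complement.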

    def UniAngularContrastLoss(self, total_fq, verb_mask, alpha=0.5, verbonly=True, m=0.5, tau=0.05, args=None):
        '''
        Angular-margin contrastive loss over decoder features: for each
        positive pair (i, j),

            L_ij = -log( exp(cos(theta_ij + m) / tau) / sum_k exp(s_ik / tau) )

        where s_ik is the cosine similarity between embeddings i and k,
        theta_ij = acos(s_ij), and the margin m is given in degrees.
        '''
        _, C, HW = total_fq.shape

        if verbonly:
            # Only entries that belong to a (sentence, verb) pair contribute.
            emb = torch.mean(total_fq[verb_mask], dim=-1)
            assert emb.shape[0] % 2 == 0, f"Embedding count {emb.shape[0]} is not divisible by 2."
        else:
            emb = torch.mean(total_fq, dim=-1)

        B_ = emb.shape[0]

        # Pairwise cosine similarity matrix, clamped away from +/-1 so that
        # acos stays numerically stable and differentiable.
        emb_i = emb.unsqueeze(1).repeat(1, B_, 1)
        emb_j = emb.unsqueeze(0).repeat(B_, 1, 1)
        sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
        sim_matrix = sim(emb_i, emb_j).reshape(B_, B_)
        sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)

        positive_mask, negative_mask = self.return_mask(sim_matrix, verb_mask)

        # Additive angular margin on the positive pairs (m is in degrees).
        sim_matrix_with_margin = sim_matrix.clone()
        sim_matrix_with_margin[positive_mask.bool()] = torch.cos(
            torch.acos(sim_matrix[positive_mask.bool()]) + math.radians(m))

        logits = sim_matrix_with_margin / tau

        # Normalize each positive logit by the summed exponentiated logits of
        # its own row, so rows with more than one positive are handled correctly.
        exp_logits = torch.exp(logits)
        row_totals = exp_logits.sum(dim=-1, keepdim=True).expand_as(exp_logits)
        pos = positive_mask.bool()
        positive_loss = -torch.log(exp_logits[pos] / row_totals[pos])
        angular_loss = positive_loss.mean()

        return angular_loss
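
# A minimal usage sketch, not part of the model: it assumes a jit-exported CLIP
# checkpoint on disk and config values compatible with the FPN / decoder layers.
# Every path and hyperparameter below is a placeholder, not a value taken from
# this repository's configs.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        clip_pretrain="pretrain/RN50.pt",      # placeholder checkpoint path
        word_len=17, freeze=True,
        fpn_in=[512, 1024, 1024], fpn_out=[256, 512, 1024],
        num_layers=3, vis_dim=512, num_head=8, dim_ffn=2048,
        dropout=0.1, intermediate=False,
        word_dim=1024, metric_loss_weight=0.15,
        loss_option="ACL_verbonly", margin_value=0.5, temperature=0.05,
    )
    model = CRIS_PosOnly_rev(cfg).eval()

    image = torch.randn(2, 3, 416, 416)               # b, 3, h, w
    text = torch.randint(1, 100, (2, cfg.word_len))   # b, words (token ids)
    with torch.no_grad():
        pred = model(image, text)                     # b, 1, h', w' logits
    print(pred.shape)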