# VerbCentric-RIS / model/segmenter_angular.py
# Uploaded by dianecy via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 31dfd6a (verified), 6.61 kB.
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.clip import build_model
from .layers import FPN, Projector, TransformerDecoder
# def MetricLoss(embeddings, n_pos, alpha = 0.5, args = None):
# # embeddings: ((2*B), C, (H*W))
# # n_pos : chunk size of positive pairs
# # args: args
# # returns: loss
# metric_loss = 0
# # flatten embeddings
# B_, C, HW = embeddings.shape
# emb = torch.mean(embeddings, dim=-1) # (2*B, C)
# emb_i = emb.unsqueeze(1).repeat(1, B_, 1) # (2*B, 2*B, C)
# emb_j = emb.unsqueeze(0).repeat(B_, 1, 1) # (2*B, 2*B, C)
# emb_distance = torch.norm(emb_i - emb_j, dim=-1) # (2*B, 2*B)
# assert torch.sum(torch.diag(emb_distance[:B_, :B_])) == 0, \
# "Diagonals are not zero. please check the permutation on the batch"
# # print("distance metrix : ", emb_distance)
# # positive pairs and loss
# positive_mask = torch.zeros_like(emb_distance)
# for i in range(B_//2):
# positive_mask[2*i, 2*i+1] = 1
# positive_mask[2*i+1, 2*i] = 1
# positive_mask.fill_diagonal_(1)
# positive_loss = torch.sum(emb_distance * positive_mask) / B_
# # negative pairs and loss
# negative_mask = torch.ones_like(emb_distance) - positive_mask
# negative_loss = -1.0 * torch.log(torch.sum(emb_distance * negative_mask) / (B_**2 - 2*B_))
# # print(positive_mask, negative_mask)
# metric_loss = alpha * positive_loss + (1-alpha) * negative_loss
# return metric_loss
class CRIS_S(nn.Module):
    """Segmentation model: CLIP backbone + multi-modal FPN + transformer decoder + projector.

    When ``cfg.metric_learning`` is enabled and the module is in training
    mode, every sample carries two captions. The image and target are
    duplicated (interleaved) so each caption gets its own forward pass, and
    the raw decoder output is returned as ``metric_tensor`` next to the
    BCE loss.
    """

    def __init__(self, cfg):
        super().__init__()
        # Vision & text encoder: weights come from a TorchScript CLIP checkpoint
        # loaded on CPU and rebuilt as an eager float32 module.
        clip_model = torch.jit.load(cfg.clip_pretrain,
                                    map_location="cpu").eval()
        self.backbone = build_model(clip_model.state_dict(), cfg.word_len).float()
        # Multi-modal FPN
        self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # Decoder
        self.decoder = TransformerDecoder(num_layers=cfg.num_layers,
                                          d_model=cfg.vis_dim,
                                          nhead=cfg.num_head,
                                          dim_ffn=cfg.dim_ffn,
                                          dropout=cfg.dropout,
                                          return_intermediate=cfg.intermediate)
        # Projector
        self.proj = Projector(cfg.word_dim, cfg.vis_dim // 2, 3)
        # Metric-learning settings.
        self.metric_learning = cfg.metric_learning
        self.positive_strength = cfg.positive_strength  # stored; not read inside this class
        self.metric_loss_weight = cfg.metric_loss_weight  # stored; not read inside this class
        self.eps = cfg.ptb_rate  # scale (std-dev) of the Gaussian noise added to the text state
        self.cfg = cfg

    def forward(self, image, text, target=None):
        """Predict segmentation logits and, in training, the BCE loss.

        Args:
            image: image batch of shape (b, c, h, w).
            text: token ids of shape (b, l); when metric learning is active in
                training mode, (b, 2, l) with two captions per sample. Token id
                0 is treated as padding.
            target: ground-truth mask, presumably (b, 1, h', w') -- TODO
                confirm. Required in training mode.

        Returns:
            Training: ``(pred.detach(), target, CE_loss, metric_tensor)``.
                ``target`` is duplicated when metric learning is active and
                resized (nearest) to ``pred``'s spatial size. ``CE_loss`` is
                BCE-with-logits of ``pred`` vs ``target``. ``metric_tensor``
                is the decoder output before reshaping.
            Evaluation: ``pred.detach()``.

        Raises:
            ValueError: if called in training mode without ``target``.
        """
        if self.training and target is None:
            raise ValueError("target mask is required in training mode")

        metric_learning_flag = self.metric_learning and self.training
        # TODO: mixing option between distance & angular metric losses.

        noise_mask = None
        if metric_learning_flag:
            # 1. Duplicate image and mask so each of the two captions gets its
            #    own copy. The ordering is interleaved: [x0, x0, x1, x1, ...].
            #    repeat_interleave does not depend on the image's h/w, so a
            #    target with a different spatial size is still duplicated
            #    correctly.
            image = image.repeat_interleave(2, dim=0)
            target = target.repeat_interleave(2, dim=0)

            b_, n_, l_ = text.size()
            assert n_ == 2, "word size should be 2"
            # Samples whose two captions are token-identical (i.e. only one
            # real caption) get noise on their text state in step 2.
            noise_mask = torch.all(text[:, 0, :] == text[:, 1, :], dim=-1)
            noise_mask = noise_mask.repeat_interleave(2)  # (2*b_,), same order as image
            text = text.reshape(b_ * 2, l_)  # (2*b_, l): [s0c0, s0c1, s1c0, s1c1, ...]

        # Padding mask used in the decoder (token id 0 == padding).
        pad_mask = text == 0

        vis = self.backbone.encode_image(image)
        word, state = self.backbone.encode_text(text)
        assert state.size(0) == word.size(0), "batch size of state and word should be same"

        # 2. State noising for single-caption samples. Done out-of-place so no
        #    in-place write hits a tensor autograd may need for backward.
        if metric_learning_flag:
            noise = torch.randn_like(state) * self.eps
            state = torch.where(noise_mask.unsqueeze(-1), state + noise, state)

        fq, _ = self.neck(vis, state)
        b, c, h, w = fq.size()
        fq = self.decoder(fq, word, pad_mask)
        # Raw decoder output, returned so a metric loss can be computed outside
        # this module (no metric loss is computed here).
        metric_tensor = fq
        fq = fq.reshape(b, c, h, w)

        pred = self.proj(fq, state)

        if not self.training:
            return pred.detach()

        # Resize the mask to the prediction resolution.
        if pred.shape[-2:] != target.shape[-2:]:
            target = F.interpolate(target, pred.shape[-2:],
                                   mode='nearest').detach()
        CE_loss = F.binary_cross_entropy_with_logits(pred, target)
        return pred.detach(), target, CE_loss, metric_tensor