import random

import torch
import torch.nn as nn
import torch.nn.functional as F

from model.clip import build_model

from .layers import FPN, Projector, TransformerDecoder

# def MetricLoss(embeddings, n_pos, alpha=0.5, args=None):
#     # embeddings: ((2*B), C, (H*W))
#     # n_pos: chunk size of positive pairs
#     # args: args
#     # returns: loss
#     metric_loss = 0
#     # flatten embeddings by mean-pooling over the spatial dimension
#     B_, C, HW = embeddings.shape
#     emb = torch.mean(embeddings, dim=-1)  # (2*B, C)
#     emb_i = emb.unsqueeze(1).repeat(1, B_, 1)  # (2*B, 2*B, C)
#     emb_j = emb.unsqueeze(0).repeat(B_, 1, 1)  # (2*B, 2*B, C)
#     emb_distance = torch.norm(emb_i - emb_j, dim=-1)  # (2*B, 2*B)
#     assert torch.sum(torch.diag(emb_distance[:B_, :B_])) == 0, \
#         "Diagonals are not zero. Please check the permutation on the batch."
#     # positive pairs and loss
#     positive_mask = torch.zeros_like(emb_distance)
#     for i in range(B_ // 2):
#         positive_mask[2 * i, 2 * i + 1] = 1
#         positive_mask[2 * i + 1, 2 * i] = 1
#     positive_mask.fill_diagonal_(1)
#     positive_loss = torch.sum(emb_distance * positive_mask) / B_
#     # negative pairs and loss
#     negative_mask = torch.ones_like(emb_distance) - positive_mask
#     negative_loss = -1.0 * torch.log(torch.sum(emb_distance * negative_mask) / (B_**2 - 2 * B_))
#     metric_loss = alpha * positive_loss + (1 - alpha) * negative_loss
#     return metric_loss
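
# A minimal sketch (not called anywhere in this file) of the pairing
# convention the commented-out MetricLoss above relies on: after the
# duplication/interleaving done in CRIS_S.forward, the two captions of the
# same image sit at adjacent batch indices (2*i, 2*i + 1), so those rows form
# positive pairs and every other off-diagonal entry is a negative pair.
# The helper name is illustrative only.
def _pair_masks_sketch(doubled_batch_size):
    """Illustrative only: build the positive/negative pair masks described above."""
    positive_mask = torch.zeros(doubled_batch_size, doubled_batch_size)
    for i in range(doubled_batch_size // 2):
        # captions (2*i, 2*i + 1) come from the same image -> positive pair
        positive_mask[2 * i, 2 * i + 1] = 1
        positive_mask[2 * i + 1, 2 * i] = 1
    positive_mask.fill_diagonal_(1)
    negative_mask = torch.ones_like(positive_mask) - positive_mask
    return positive_mask, negative_mask
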

class CRIS_S(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # Vision & Text Encoder
        clip_model = torch.jit.load(cfg.clip_pretrain,
                                    map_location="cpu").eval()
        self.backbone = build_model(clip_model.state_dict(), cfg.word_len).float()
        # Multi-Modal FPN
        self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # Decoder
        self.decoder = TransformerDecoder(num_layers=cfg.num_layers,
                                          d_model=cfg.vis_dim,
                                          nhead=cfg.num_head,
                                          dim_ffn=cfg.dim_ffn,
                                          dropout=cfg.dropout,
                                          return_intermediate=cfg.intermediate)
        # Projector
        self.proj = Projector(cfg.word_dim, cfg.vis_dim // 2, 3)
        # Metric-learning options
        self.metric_learning = cfg.metric_learning
        self.positive_strength = cfg.positive_strength
        self.metric_loss_weight = cfg.metric_loss_weight
        self.eps = cfg.ptb_rate  # perturbation scale for state noising
        self.cfg = cfg

    def forward(self, image, text, target=None):
        '''
        image:  b, 3, h, w
        text:   b, words (or b, 2, words when metric learning is enabled,
                i.e. two captions per image)
        target: b, 1, h, w
        A usage sketch is provided at the bottom of this file.
        '''
        metric_learning_flag = (self.metric_learning and self.training)
        # TODO: mixing option between distance & angular loss
        mix_distance_angular = False
        metric_loss = 0
        # 1. Resizing: if metric learning, the word batch is doubled, so
        #    duplicate image / target and interleave them so that the two
        #    captions of the same image sit at adjacent batch indices.
        if metric_learning_flag:
            b, c, h, w = image.size()
            # duplicate image and segmentation mask
            if image is not None:
                image = torch.cat([image, image], dim=0)
                image = image.reshape(-1, b, c, h, w).transpose(0, 1).reshape(-1, c, h, w)
            if target is not None:
                target = torch.cat([target, target], dim=0)
                target = target.reshape(-1, b, 1, h, w).transpose(0, 1).reshape(-1, 1, h, w)
            # mark samples whose two captions are identical; only those
            # receive state noise later
            b_, n_, l_ = text.size()
            assert n_ == 2, "word size should be 2"
            noise_mask = (text[:, 0, :] == text[:, 1, :])
            noise_mask = torch.all(noise_mask, dim=-1)
            noise_mask = noise_mask.unsqueeze(-1).repeat(1, 2).reshape(-1)  # 2*b_
            assert noise_mask.shape[0] == b_ * 2, "noise mask shape should be 2*b_"
            text = text.reshape(b_ * 2, l_)  # 2*b, l

        # padding mask used in decoder
        pad_mask = torch.zeros_like(text).masked_fill_(text == 0, 1).bool()
        # vis:   C3 / C4 / C5 feature maps
        # word:  b, length, 1024
        # state: b, 1024
        vis = self.backbone.encode_image(image)
        word, state = self.backbone.encode_text(text)
        b_, d_ = state.size()
        assert b_ == word.size(0), "batch size of state and word should be same"

        # 2. State Noising Step: where an image has only one caption (its two
        #    entries are identical), perturb the sentence state with noise
        if metric_learning_flag:
            noise = torch.randn_like(state) * self.eps
            state[noise_mask] = state[noise_mask] + noise[noise_mask]

        # b, 512, 26, 26 (C4)
        a3, a4, a5 = vis
        fq, f5 = self.neck(vis, state)
        b, c, h, w = fq.size()
        fq = self.decoder(fq, word, pad_mask)
        metric_tensor = fq
        # 3. Metric loss (currently disabled here); the decoder features are
        #    returned as metric_tensor so the loss can be computed outside
        # if metric_learning_flag:
        #     metric_loss = AngularMetricLoss(fq, 2, alpha=self.positive_strength, args=self.cfg)
        #     if mix_distance_angular:
        #         metric_loss += MetricLoss(fq, 2, alpha=self.positive_strength, args=self.cfg)
        fq = fq.reshape(b, c, h, w)

        # b, 1, 104, 104
        pred = self.proj(fq, state)
        if self.training:
            # resize mask to the prediction resolution
            if pred.shape[-2:] != target.shape[-2:]:
                target = F.interpolate(target, pred.shape[-2:],
                                       mode='nearest').detach()
            CE_loss = F.binary_cross_entropy_with_logits(pred, target)
            # 4. if metric learning, add metric loss and normalize
            # if metric_learning_flag:
            #     loss = (CE_loss + self.metric_loss_weight * metric_loss) / (1 + self.metric_loss_weight)
            #     safety_loss = loss * 0.
            #     loss = loss + safety_loss
            return pred.detach(), target, CE_loss, metric_tensor
        else:
            return pred.detach()
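

# Usage sketch, not part of this file's API. Assumptions: the cfg fields
# below only mirror typical values from the repo's config, and
# "pretrain/RN50.pt" stands in for a valid CLIP checkpoint path; adjust both
# to your setup. Shapes follow the forward() docstring: with metric learning
# enabled, text carries two captions per image.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        clip_pretrain="pretrain/RN50.pt", word_len=17,
        fpn_in=[512, 1024, 1024], fpn_out=[256, 512, 1024],
        num_layers=3, vis_dim=512, num_head=8, dim_ffn=2048,
        dropout=0.1, intermediate=False, word_dim=1024,
        metric_learning=True, positive_strength=0.5,
        metric_loss_weight=0.5, ptb_rate=0.02)
    model = CRIS_S(cfg).train()

    image = torch.randn(2, 3, 416, 416)                  # b, 3, h, w
    text = torch.randint(0, 100, (2, 2, cfg.word_len))   # b, 2, words
    target = torch.rand(2, 1, 416, 416)                  # b, 1, h, w
    pred, tgt, ce_loss, metric_tensor = model(image, text, target)
    print(pred.shape, ce_loss.item(), metric_tensor.shape)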