import random

import torch
import torch.nn as nn
import torch.nn.functional as F

from model_.clip import build_model
from .layers import FPN, Projector, TransformerDecoder


def MetricLoss(embeddings, n_pos, alpha=0.5, args=None):
    """Distance-based metric loss over pooled decoder embeddings.

    The batch is expected to be arranged as consecutive positive pairs
    (rows 2i and 2i+1); every remaining off-diagonal pair is a negative.
    """
    metric_loss = 0

    B_, C, HW = embeddings.shape
    # Average over the spatial dimension: one C-dim vector per sample.
    emb = torch.mean(embeddings, dim=-1)
    emb_i = emb.unsqueeze(1).repeat(1, B_, 1)
    emb_j = emb.unsqueeze(0).repeat(B_, 1, 1)
    # Pairwise L2 distances, shape (B_, B_); the diagonal must be zero.
    emb_distance = torch.norm(emb_i - emb_j, dim=-1)
    assert torch.sum(torch.diag(emb_distance[:B_, :B_])) == 0, \
        "Diagonals are not zero. Please check the permutation on the batch."

    # Positive pairs: each sample with its paired view (and itself).
    positive_mask = torch.zeros_like(emb_distance)
    for i in range(B_ // 2):
        positive_mask[2 * i, 2 * i + 1] = 1
        positive_mask[2 * i + 1, 2 * i] = 1
    positive_mask.fill_diagonal_(1)
    positive_loss = torch.sum(emb_distance * positive_mask) / B_

    # Negative pairs: everything not marked as positive.
    negative_mask = torch.ones_like(emb_distance) - positive_mask

    if args.div_batch:
        negative_loss = -1.0 * torch.log(torch.sum(emb_distance * negative_mask) / B_)
    else:
        negative_loss = -1.0 * torch.log(torch.sum(emb_distance * negative_mask) / (B_**2 - 2 * B_))

    # Pull positives together, push negatives apart.
    metric_loss = alpha * positive_loss + (1 - alpha) * negative_loss

    return metric_loss


def AngularMetricLoss(embeddings, n_pos, alpha=0.5, args=None):
    """Angle-based metric loss over pooled decoder embeddings.

    Works on the arccos of pairwise cosine similarities: positive pairs
    (rows 2i and 2i+1) are pulled towards an angle of zero, negative pairs
    are pushed out to at least `args.phi_threshold` radians.
    """
    geometric_loss = 0

    B_, C, HW = embeddings.shape
    # Average over the spatial dimension: one C-dim vector per sample.
    emb = torch.mean(embeddings, dim=-1)
    emb_i = emb.unsqueeze(1).repeat(1, B_, 1)
    emb_j = emb.unsqueeze(0).repeat(B_, 1, 1)
    sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
    sim_matrix = sim(emb_i, emb_j).reshape(B_, B_)
    print(sim_matrix)
    # Exact float equality is fragile for cosine similarities, so allow a
    # small tolerance when checking that the diagonal is all ones.
    assert abs(torch.trace(sim_matrix).item() - B_) < 1e-3, \
        "Similarity diagonals are not one. Please check the permutation on the batch."
    print("similarity matrix : ", sim_matrix)
    # Clamp into the open interval (-1, 1) so acos stays finite and its
    # gradient does not blow up on the diagonal (self-similarity of 1).
    phi = torch.acos(torch.clamp(sim_matrix, -1.0 + 1e-7, 1.0 - 1e-7))
    print("phi matrix : ", phi)

    # Positive pairs: each sample with its paired view (and itself).
    positive_mask = torch.zeros_like(sim_matrix)
    for i in range(B_ // 2):
        positive_mask[2 * i, 2 * i + 1] = 1
        positive_mask[2 * i + 1, 2 * i] = 1
    positive_mask.fill_diagonal_(1)
    positive_loss = torch.sum((phi ** 2) * positive_mask) / B_

    # Negative pairs: only those closer than the angular margin are penalized.
    negative_mask = torch.ones_like(sim_matrix) - positive_mask
    phi_mask = phi < args.phi_threshold
    negative_loss = (args.phi_threshold - phi) ** 2
    print(negative_mask * phi_mask)
    negative_loss = torch.sum(negative_loss * negative_mask * phi_mask) / (B_ ** 2 - 2 * B_)

    print("pos loss, neg loss : ", positive_loss, negative_loss)

    geometric_loss = alpha * positive_loss + (1 - alpha) * negative_loss

    return geometric_loss
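
# Usage note (inferred from the code above, not a spec): both metric losses
# expect the batch laid out as [s0, s0', s1, s1', ...], so rows 2i and 2i+1
# hold the two views of sample i, and `args` only needs the fields they read
# (`div_batch` for MetricLoss, `phi_threshold` for AngularMetricLoss). A small
# smoke test is sketched in the __main__ block at the bottom of the file.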


class CRIS(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        # CLIP vision-language backbone (loaded from a jit checkpoint).
        clip_model = torch.jit.load(cfg.clip_pretrain,
                                    map_location="cpu").eval()
        self.backbone = build_model(clip_model.state_dict(), cfg.word_len, cfg.freeze).float()

        # Multi-modal FPN neck.
        self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)

        # Multi-modal transformer decoder.
        self.decoder = TransformerDecoder(num_layers=cfg.num_layers,
                                          d_model=cfg.vis_dim,
                                          nhead=cfg.num_head,
                                          dim_ffn=cfg.dim_ffn,
                                          dropout=cfg.dropout,
                                          return_intermediate=cfg.intermediate)

        # Projector producing the segmentation logits.
        self.proj = Projector(cfg.word_dim, cfg.vis_dim // 2, 3)

        # Metric-learning hyper-parameters consumed in forward(). The cfg field
        # names below are assumed to exist in the training configuration.
        self.metric_learning = cfg.metric_learning
        self.positive_strength = cfg.positive_strength
        self.metric_loss_weight = cfg.metric_loss_weight
        self.eps = cfg.eps

        self.cfg = cfg

    def forward(self, image, text, target=None):
        '''
        img: b, 3, h, w
        word: b, words
        word_mask: b, words
        if self.metric_learning:
            word: b, 2, words
            word_mask: b, 2, words
        mask: b, 1, h, w
        '''
        metric_learning_flag = (self.metric_learning and self.training)
        metric_loss = 0

        if metric_learning_flag:
            b, c, h, w = image.size()

            # Duplicate images (and targets) and interleave the copies so the
            # batch becomes [x0, x0, x1, x1, ...], matching the two text views.
            if image is not None:
                image = torch.cat([image, image], dim=0)
                image = image.reshape(-1, b, c, h, w).transpose(0, 1).reshape(-1, c, h, w)
            if target is not None:
                target = torch.cat([target, target], dim=0)
                target = target.reshape(-1, b, 1, h, w).transpose(0, 1).reshape(-1, 1, h, w)

            b_, n_, l_ = text.size()
            assert n_ == 2, "word size should be 2"
            # Mark samples whose two text views are identical; these get
            # Gaussian noise added to their sentence state further below.
            noise_mask = (text[:, 0, :] == text[:, 1, :])
            noise_mask = torch.all(noise_mask, dim=-1)
            noise_mask = noise_mask.unsqueeze(-1).repeat(1, 2).reshape(-1)
            assert noise_mask.shape[0] == b_ * 2, "noise mask shape should be 2*b_"
            # Flatten the two text views into the batch dimension.
            text = text.reshape(b_ * 2, l_)

        # Padding mask for the decoder (True at padded tokens).
        pad_mask = torch.zeros_like(text).masked_fill_(text == 0, 1).bool()

        # Encode image and text with the CLIP backbone.
        vis = self.backbone.encode_image(image)
        word, state = self.backbone.encode_text(text)

        b_, d_ = state.size()
        assert b_ == word.size(0), "batch size of state and word should be same"

        # Perturb the sentence state of samples whose two text views are
        # identical, so that paired embeddings are not trivially equal.
        if metric_learning_flag:
            noise = torch.randn_like(state) * self.eps
            state[noise_mask] = state[noise_mask] + noise[noise_mask]

        a3, a4, a5 = vis

        fq, f5 = self.neck(vis, state)
        b, c, h, w = fq.size()
        fq = self.decoder(fq, word, pad_mask)

        if metric_learning_flag:
            metric_loss = MetricLoss(fq, 2, alpha=self.positive_strength, args=self.cfg)

        fq = fq.reshape(b, c, h, w)

        pred = self.proj(fq, state)

        if self.training:
            # Resize the target mask to match the prediction resolution.
            if pred.shape[-2:] != target.shape[-2:]:
                target = F.interpolate(target, pred.shape[-2:],
                                       mode='nearest').detach()
            loss = F.binary_cross_entropy_with_logits(pred, target)

            # Blend the segmentation loss with the metric loss.
            if metric_learning_flag:
                loss = (loss + self.metric_loss_weight * metric_loss) / (1 + self.metric_loss_weight)
            return pred.detach(), target, loss
        else:
            return pred.detach()
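

if __name__ == "__main__":
    # Minimal smoke test for the two metric losses above (run as a module,
    # e.g. `python -m <package>.<this_module>`, so the imports resolve).
    # This is only a sketch: the namespace below mimics the cfg fields those
    # functions read, and its values are placeholder assumptions rather than
    # the actual training settings.
    from types import SimpleNamespace

    dummy_args = SimpleNamespace(div_batch=False, phi_threshold=0.5)
    # Batch of 4 pooled embeddings arranged as [s0, s0', s1, s1'], shape (B, C, HW).
    dummy_emb = torch.randn(4, 8, 16)
    print("MetricLoss:", MetricLoss(dummy_emb, 2, alpha=0.5, args=dummy_args))
    print("AngularMetricLoss:", AngularMetricLoss(dummy_emb, 2, alpha=0.5, args=dummy_args))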