VRIS_vip / RIS-DMMI /lib /loss_functions.py
dianecy's picture
Upload folder using huggingface_hub
0b32e3c verified
raw
history blame
2.23 kB
import pdb
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class Contrastive_Loss(nn.Module):
    """Symmetric image-text contrastive loss.

    Builds a temperature-scaled cosine-similarity matrix between the visual
    embeddings and the (projected, pooled) language embeddings of a batch and
    applies cross-entropy in both directions (image->text and text->image),
    using the matching index in the batch as the positive class.
    """

    def __init__(self, temperature=0.05, lan_dim=768):
        """
        Args:
            temperature: divisor applied to the similarity logits.
            lan_dim: channel dimension of the language feature (and of the
                visual feature it is compared against). Defaults to 768.
        """
        super(Contrastive_Loss, self).__init__()
        self.temperature = temperature
        self.adpool = nn.AdaptiveAvgPool2d((1, 1))
        # Kept inside nn.Sequential so the parameter names
        # (``align_lan.0.weight`` / ``align_lan.0.bias``) stay unchanged
        # for previously saved checkpoints.
        self.align_lan = nn.Sequential(
            nn.Conv1d(lan_dim, lan_dim, kernel_size=1, stride=1),
        )

    def forward(self, vis_feature, lan_feature, target_flag):
        """Compute the bidirectional contrastive loss for one batch.

        Args:
            vis_feature: visual embeddings of shape (B, lan_dim).
            lan_feature: language features of shape (B, lan_dim, L); they are
                passed through a 1x1 Conv1d and average-pooled over L.
            target_flag: per-sample weight multiplied into the per-sample
                losses before averaging (assumed shape (B,) -- confirm with
                the caller). Samples with flag 0 contribute zero loss but
                still count in the mean's denominator.

        Returns:
            Scalar tensor: image->text loss plus text->image loss.
        """
        batch_size, channels = lan_feature.shape[0], lan_feature.shape[1]

        vis_embed = F.normalize(vis_feature, dim=1)

        # Project the language feature, then average over the last axis.
        lan_embed = self.align_lan(lan_feature)
        lan_embed = self.adpool(lan_embed.unsqueeze(3)).view(batch_size, channels)
        lan_embed = F.normalize(lan_embed, dim=1)

        # (B, B) similarity matrix; entry (i, j) compares image i with text j.
        img_text_logits = torch.matmul(vis_embed, lan_embed.permute(1, 0)) / self.temperature
        text_img_logits = img_text_logits.permute(1, 0)

        # Positives sit on the diagonal. Create the labels on the same device
        # as the logits instead of forcing CUDA, so CPU and multi-GPU runs work.
        labels = torch.arange(batch_size, device=img_text_logits.device)

        # ``reduction="none"`` replaces the deprecated ``reduce=False``.
        loss_a = F.cross_entropy(img_text_logits, labels, reduction="none")
        loss_b = F.cross_entropy(text_img_logits, labels, reduction="none")

        loss_a = torch.mean(loss_a * target_flag)
        loss_b = torch.mean(loss_b * target_flag)
        return loss_a + loss_b
class Cosine_Sim_Loss(nn.Module):
    """Masked cosine-similarity loss between two feature sequences.

    The loss is ``1 - mean(per-sample mean cosine similarity)``, where the
    first input is detached so gradients only flow through the second one.
    """

    def __init__(self):
        super(Cosine_Sim_Loss, self).__init__()
        self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)

    def forward(self, lan1, lan2, mask_full, target_flag):
        """Compute the masked cosine-similarity loss.

        Args:
            lan1: reference features; detached before comparison (assumed
                shape (B, C, L) -- confirm with the caller).
            lan2: features compared against ``lan1``, same shape as ``lan1``.
            mask_full: 3-D mask whose last two axes are swapped before being
                multiplied into the features (assumed shape (B, L, 1) with
                1 for valid positions -- confirm with the caller).
            target_flag: per-sample weight, reshaped to (B, 1) and multiplied
                into the per-position similarities.

        Returns:
            Scalar tensor ``1 - mean_b(sum_l(score) / sum_l(mask))``.
        """
        mask = mask_full.permute(0, 2, 1)
        sample_weight = target_flag.view(target_flag.shape[0], 1)

        # Zero out masked positions; the first input acts as a fixed target.
        reference = (lan1 * mask).detach()
        candidate = lan2 * mask

        # Cosine similarity along dim=1, weighted per sample.
        score = self.cos(reference, candidate) * sample_weight
        score_sum = torch.sum(score, dim=-1)

        length = torch.sum(mask, dim=-1).squeeze(-1).to(score_sum.dtype)
        # A fully masked sample has score_sum == 0 and length == 0; without a
        # floor on the denominator, 0/0 would turn the whole batch loss to NaN.
        mean_score = score_sum / length.clamp(min=1e-6)

        return 1 - torch.mean(mean_score)