# Page header residue (not code): LuyangZ's picture / Upload 30 files / 01df1d6 verified
'''cortex_DIM losses.
'''
import math
import torch
import torch.nn.functional as F
from cortex_DIM.functions.gan_losses import get_positive_expectation, get_negative_expectation
def fenchel_dual_loss(l, g, measure=None):
    '''Computes the f-divergence distance between positive and negative joint distributions.

    Divergences supported are Jensen-Shannon `JSD`, `GAN` (equivalent to JSD),
    Squared Hellinger `H2`, Chi-squeared `X2`, `KL`, and reverse KL `RKL`.

    Every global vector is scored against every location of every sample.
    Same-sample pairs (the diagonal of the N x N grid) are the positives and
    cross-sample pairs are the negatives.

    Args:
        l: Local feature map, unpacked as (N, local_units, n_locs).
        g: Global features; a 2-D tensor whose second dimension is
            local_units (expected to be (N, local_units)).
        measure: f-divergence measure, forwarded unchanged to
            `get_positive_expectation` / `get_negative_expectation`.

    Returns:
        torch.Tensor: Loss.

    '''
    N, local_units, _ = l.size()

    # One row per (sample, location) so a single matmul scores all pairs.
    l = l.permute(0, 2, 1).reshape(-1, local_units)

    # u[i, j, k]: global i scored against location k of sample j.
    u = torch.mm(g, l.t()).reshape(N, N, -1)

    # Build the masks on the scores' device instead of hard-coding CUDA, so
    # the loss also runs on CPU and on a non-default GPU.
    mask = torch.eye(N, device=u.device)
    n_mask = 1 - mask

    # Per-pair expectations averaged over locations -> (N, N).
    # NOTE(review): assumes the helpers keep the (N, N, n_locs) shape when
    # average=False -- confirm against cortex_DIM.functions.gan_losses.
    E_pos = get_positive_expectation(u, measure, average=False).mean(2)
    E_neg = get_negative_expectation(u, measure, average=False).mean(2)

    # Positives: mean over the diagonal; negatives: mean over off-diagonal.
    E_pos = (E_pos * mask).sum() / mask.sum()
    E_neg = (E_neg * n_mask).sum() / n_mask.sum()

    return E_neg - E_pos
def multi_fenchel_dual_loss(l, m, measure=None):
    '''Computes the f-divergence distance between positive and negative joint distributions.

    Used for multiple globals.

    Divergences supported are Jensen-Shannon `JSD`, `GAN` (equivalent to JSD),
    Squared Hellinger `H2`, Chi-squeared `X2`, `KL`, and reverse KL `RKL`.

    Args:
        l: Local feature map, unpacked as (N, units, n_locals).
        m: Multiple globals feature map, viewed as (N, units, n_multis).
        measure: f-divergence measure, forwarded unchanged to
            `get_positive_expectation` / `get_negative_expectation`.

    Returns:
        torch.Tensor: Loss.

    '''
    N, units, n_locals = l.size()
    n_multis = m.size(2)

    # One row per (sample, location) and per (sample, global).
    l = l.permute(0, 2, 1).reshape(-1, units)
    m = m.view(N, units, n_multis).permute(0, 2, 1).reshape(-1, units)

    # u[i, j, loc, k]: global k of sample i scored against location loc of
    # sample j.
    u = torch.mm(m, l.t())
    u = u.reshape(N, n_multis, N, n_locals).permute(0, 2, 3, 1)

    # Build the masks on the scores' device instead of hard-coding CUDA, so
    # the loss also runs on CPU and on a non-default GPU.
    mask = torch.eye(N, device=u.device)
    n_mask = 1 - mask

    # Average over locations, then over globals -> (N, N).
    # NOTE(review): assumes the helpers keep u's shape when average=False --
    # confirm against cortex_DIM.functions.gan_losses.
    E_pos = get_positive_expectation(u, measure, average=False).mean(2).mean(2)
    E_neg = get_negative_expectation(u, measure, average=False).mean(2).mean(2)

    # Positives: mean over the diagonal; negatives: mean over off-diagonal.
    E_pos = (E_pos * mask).sum() / mask.sum()
    E_neg = (E_neg * n_mask).sum() / n_mask.sum()

    return E_neg - E_pos
def nce_loss(l, g):
    '''Computes the noise contrastive estimation-based loss.

    For every (sample, location) the positive logit is that location scored
    against the sample's own global vector. The negatives are the same global
    scored against every location of every other sample. The loss is the mean
    negative log-softmax probability of the positive.

    Args:
        l: Local feature map, unpacked as (N, local_units, n_locs).
        g: Global features, (N, local_units).

    Returns:
        torch.Tensor: Loss.

    '''
    N, local_units, n_locs = l.size()

    # Positive logits: (N, n_locs, local_units) @ (N, local_units, 1)
    # -> (N, n_locs, 1).
    l_p = l.permute(0, 2, 1)
    u_p = torch.matmul(l_p, g.unsqueeze(dim=2))

    # All-pairs logits: u_n[i, j, k] = global i against location k of sample j.
    l_n = l_p.reshape(-1, local_units)
    u_n = torch.mm(g, l_n.t()).reshape(N, N, n_locs)

    # Build the mask on the logits' device instead of hard-coding CUDA, so the
    # loss also runs on CPU and on a non-default GPU.
    mask = torch.eye(N, device=u_n.device).unsqueeze(dim=2)
    n_mask = 1 - mask

    # Mask out "self" examples: same-sample entries are overwritten with a
    # fixed logit of -10 so they contribute little to the softmax denominator.
    u_n = (n_mask * u_n) - (10. * mask)

    # Every location of sample i shares the same N * n_locs negatives.
    u_n = u_n.reshape(N, -1).unsqueeze(dim=1).expand(-1, n_locs, -1)

    # Positive sits at index 0 of the softmax axis.
    pred_lgt = torch.cat([u_p, u_n], dim=2)
    pred_log = F.log_softmax(pred_lgt, dim=2)

    return -pred_log[:, :, 0].mean()
def multi_nce_loss(l, m):
    '''Computes the noise contrastive estimation-based loss.

    Used for multiple globals.

    For every (sample, location, global) triple the positive logit is the
    location scored against that global of the same sample. The negatives are
    that global scored against every location of every other sample.

    Args:
        l: Local feature map, unpacked as (N, units, n_locals).
        m: Multiple globals feature map, (N, units, n_multis).

    Returns:
        torch.Tensor: Loss.

    '''
    N, units, n_locals = l.size()
    _, _, n_multis = m.size()

    m = m.view(N, units, n_multis)
    l_p = l.permute(0, 2, 1)  # (N, n_locals, units)
    m_p = m.permute(0, 2, 1)  # (N, n_multis, units)

    # Positive logits: (N, n_locals, n_multis) -> (N, n_locals, 1, n_multis).
    u_p = torch.matmul(l_p, m).unsqueeze(2)

    # All-pairs logits: u_n[i, j, loc, k] = global k of sample i against
    # location loc of sample j.
    l_n = l_p.reshape(-1, units)
    m_n = m_p.reshape(-1, units)
    u_n = torch.mm(m_n, l_n.t())
    u_n = u_n.reshape(N, n_multis, N, n_locals).permute(0, 2, 3, 1)

    # Build the mask on the logits' device instead of hard-coding CUDA. The
    # trailing singleton axes make it broadcast over (n_locals, n_multis).
    mask = torch.eye(N, device=u_n.device)[:, :, None, None]
    n_mask = 1 - mask

    # Mask out "self" examples: same-sample entries are overwritten with a
    # fixed logit of -10 so they contribute little to the softmax denominator.
    u_n = (n_mask * u_n) - (10. * mask)

    # Every location of sample i shares the same N * n_locals negatives.
    u_n = u_n.reshape(N, N * n_locals, n_multis)
    u_n = u_n.unsqueeze(dim=1).expand(-1, n_locals, -1, -1)

    # Positive sits at index 0 of the softmax axis (dim 2).
    pred_lgt = torch.cat([u_p, u_n], dim=2)
    pred_log = F.log_softmax(pred_lgt, dim=2)

    return -pred_log[:, :, 0].mean()
def donsker_varadhan_loss(l, g):
    '''Computes a Donsker-Varadhan style loss between locals and globals.

    The positive term is the mean score of same-sample (global, local) pairs.
    The negative term is a stabilised log-sum-exp over all cross-sample pairs
    minus log(N * (N - 1)).

    Args:
        l: Local feature map, unpacked as (N, local_units, n_locs).
        g: Global features, (N, local_units).

    Returns:
        torch.Tensor: Loss.

    Raises:
        ValueError: If the batch holds fewer than two samples, so there are no
            negative pairs.

    '''
    N, local_units, n_locs = l.size()
    if N < 2:
        # Previously surfaced as an opaque "math domain error" from log(0).
        raise ValueError(
            'donsker_varadhan_loss needs at least 2 samples to form negative '
            'pairs, got N={}'.format(N))

    # One row per (sample, location) so a single matmul scores all pairs.
    l = l.permute(0, 2, 1).reshape(-1, local_units)

    # u[i, j, k]: global i scored against location k of sample j.
    u = torch.mm(g, l.t()).reshape(N, N, n_locs)

    # Build the masks on the scores' device instead of hard-coding CUDA.
    mask = torch.eye(N, device=u.device)
    n_mask = (1 - mask)[:, :, None]

    # Positives: location-averaged scores on the diagonal.
    E_pos = (u.mean(2) * mask).sum() / mask.sum()

    # Push same-sample scores down by 100 so the stabilising max is taken
    # (in practice) over negatives only. Done out of place so the matmul
    # result is not mutated under autograd.
    u = u - 100 * mask[:, :, None]
    u_max = torch.max(u)

    # Stabilised log-sum-exp over negatives. The normaliser counts sample
    # pairs only (N * (N - 1)), not locations, so the value carries a constant
    # log(n_locs) offset relative to a log-mean-exp; gradients are unaffected.
    E_neg = (torch.log((n_mask * torch.exp(u - u_max)).sum() + 1e-6)
             + u_max - math.log(N * (N - 1)))

    return E_neg - E_pos
def multi_donsker_varadhan_loss(l, m):
    '''Computes a Donsker-Varadhan style loss between locals and globals.

    Used for multiple globals.

    The positive term is the mean score of same-sample (global, local) pairs.
    The negative term is a stabilised log-sum-exp over all cross-sample pairs
    minus log(N * (N - 1)).

    Args:
        l: Local feature map, unpacked as (N, units, n_locals).
        m: Multiple globals feature map, viewed as (N, units, n_multis).

    Returns:
        torch.Tensor: Loss.

    Raises:
        ValueError: If the batch holds fewer than two samples, so there are no
            negative pairs.

    '''
    N, units, n_locals = l.size()
    n_multis = m.size(2)
    if N < 2:
        # Previously surfaced as an opaque "math domain error" from log(0).
        raise ValueError(
            'multi_donsker_varadhan_loss needs at least 2 samples to form '
            'negative pairs, got N={}'.format(N))

    # One row per (sample, location) and per (sample, global).
    l = l.permute(0, 2, 1).reshape(-1, units)
    m = m.view(N, units, n_multis).permute(0, 2, 1).reshape(-1, units)

    # u[i, j, loc, k]: global k of sample i scored against location loc of
    # sample j.
    u = torch.mm(m, l.t())
    u = u.reshape(N, n_multis, N, n_locals).permute(0, 2, 3, 1)

    # Build the masks on the scores' device instead of hard-coding CUDA.
    mask = torch.eye(N, device=u.device)
    # The 4-D masks need trailing singleton axes: a bare (N, N) mask would
    # broadcast against the trailing (n_locals, n_multis) axes of u instead of
    # the two sample axes.
    pos_mask = mask[:, :, None, None]
    n_mask = 1 - pos_mask

    # Positives: scores averaged over locations and globals, on the diagonal.
    E_pos = (u.mean(2).mean(2) * mask).sum() / mask.sum()

    # Push same-sample scores down by 100 so the stabilising max is taken
    # (in practice) over negatives only. Done out of place so the matmul
    # result is not mutated under autograd.
    u = u - 100 * pos_mask
    u_max = torch.max(u)

    # Stabilised log-sum-exp over negatives. The normaliser counts sample
    # pairs only (N * (N - 1)), matching donsker_varadhan_loss, so the value
    # carries a constant offset relative to a log-mean-exp; gradients are
    # unaffected.
    E_neg = (torch.log((n_mask * torch.exp(u - u_max)).sum() + 1e-6)
             + u_max - math.log(N * (N - 1)))

    return E_neg - E_pos