'''cortex_DIM losses.

'''

import math

import torch
import torch.nn.functional as F

from cortex_DIM.functions.gan_losses import get_positive_expectation, get_negative_expectation


def fenchel_dual_loss(l, g, measure=None):
    '''Computes the f-divergence distance between positive and negative joint distributions.

    Divergences supported are Jensen-Shannon `JSD`, `GAN` (equivalent to JSD),
    Squared Hellinger `H2`, Chi-squared `X2`, `KL`, and reverse KL `RKL`.

    Args:
        l: Local feature map.
        g: Global features.
        measure: f-divergence measure.

    Returns:
        torch.Tensor: Loss.

    '''
    N, local_units, n_locs = l.size()
    # Flatten the local map to (N * n_locs, local_units).
    l = l.permute(0, 2, 1)
    l = l.reshape(-1, local_units)

    # Score every global against every local location in the batch.
    u = torch.mm(g, l.t())
    u = u.reshape(N, N, -1)

    # Positive pairs sit on the diagonal (same sample); all others are negatives.
    # The mask follows the input device so the loss also runs on CPU.
    mask = torch.eye(N, device=l.device)
    n_mask = 1 - mask

    E_pos = get_positive_expectation(u, measure, average=False).mean(2)
    E_neg = get_negative_expectation(u, measure, average=False).mean(2)
    E_pos = (E_pos * mask).sum() / mask.sum()
    E_neg = (E_neg * n_mask).sum() / n_mask.sum()
    loss = E_neg - E_pos

    return loss
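

# Usage sketch (illustrative, not part of the original module): random tensors
# stand in for encoder outputs. Any of the supported measures ('JSD', 'GAN',
# 'H2', 'X2', 'KL', 'RKL') can be passed.
#
#   l = torch.randn(8, 64, 49)                      # N=8, 64 units, 7x7 map
#   g = torch.randn(8, 64)                          # one global per sample
#   loss = fenchel_dual_loss(l, g, measure='JSD')   # scalar to minimize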


def multi_fenchel_dual_loss(l, m, measure=None):
    '''Computes the f-divergence distance between positive and negative joint distributions.

    Used for multiple globals.

    Divergences supported are Jensen-Shannon `JSD`, `GAN` (equivalent to JSD),
    Squared Hellinger `H2`, Chi-squared `X2`, `KL`, and reverse KL `RKL`.

    Args:
        l: Local feature map.
        m: Multiple globals feature map.
        measure: f-divergence measure.

    Returns:
        torch.Tensor: Loss.

    '''
    N, units, n_locals = l.size()
    n_multis = m.size(2)

    # Flatten locals to (N * n_locals, units) and globals to (N * n_multis, units).
    l = l.view(N, units, n_locals)
    l = l.permute(0, 2, 1)
    l = l.reshape(-1, units)

    m = m.view(N, units, n_multis)
    m = m.permute(0, 2, 1)
    m = m.reshape(-1, units)

    # Pairwise scores, arranged as (N, N, n_locals, n_multis).
    u = torch.mm(m, l.t())
    u = u.reshape(N, n_multis, N, n_locals).permute(0, 2, 3, 1)

    mask = torch.eye(N, device=l.device)
    n_mask = 1 - mask

    # Average expectations over locations and globals, then split by mask.
    E_pos = get_positive_expectation(u, measure, average=False).mean(2).mean(2)
    E_neg = get_negative_expectation(u, measure, average=False).mean(2).mean(2)
    E_pos = (E_pos * mask).sum() / mask.sum()
    E_neg = (E_neg * n_mask).sum() / n_mask.sum()
    loss = E_neg - E_pos

    return loss
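

# Usage sketch (illustrative, not part of the original module): unlike
# `fenchel_dual_loss`, `m` carries several global vectors per sample along its
# last dimension, and the loss averages over all (location, global) pairs.
#
#   l = torch.randn(8, 64, 49)
#   m = torch.randn(8, 64, 4)                       # 4 globals per sample
#   loss = multi_fenchel_dual_loss(l, m, measure='JSD')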


def nce_loss(l, g):
    '''Computes the noise contrastive estimation-based loss.

    Args:
        l: Local feature map.
        g: Global features.

    Returns:
        torch.Tensor: Loss.

    '''
    N, local_units, n_locs = l.size()
    l_p = l.permute(0, 2, 1)

    # Positive logits: each location scored against its own sample's global.
    u_p = torch.matmul(l_p, g.unsqueeze(dim=2))

    # Negative logits: every global against every location in the batch.
    l_n = l_p.reshape(-1, local_units)
    u_n = torch.mm(g, l_n.t())
    u_n = u_n.reshape(N, N, n_locs)

    # Mask out "self" examples so they cannot serve as negatives.
    mask = torch.eye(N, device=l.device).unsqueeze(dim=2)
    n_mask = 1 - mask
    u_n = (n_mask * u_n) - (10. * (1 - n_mask))
    u_n = u_n.reshape(N, -1).unsqueeze(dim=1).expand(-1, n_locs, -1)

    # Concatenate the positive with all negatives; the loss is the negative
    # log-probability of the positive under a softmax over all candidates.
    pred_lgt = torch.cat([u_p, u_n], dim=2)
    pred_log = F.log_softmax(pred_lgt, dim=2)
    loss = -pred_log[:, :, 0].mean()

    return loss
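

# Usage sketch (illustrative, not part of the original module): NCE frames the
# problem as classification, picking the true (local, global) pair out of one
# positive and all in-batch negatives; "self" negatives are pinned at -10 so
# the softmax effectively ignores them.
#
#   l = torch.randn(8, 64, 49)
#   g = torch.randn(8, 64)
#   loss = nce_loss(l, g)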


def multi_nce_loss(l, m):
    '''Computes the noise contrastive estimation-based loss.

    Used for multiple globals.

    Args:
        l: Local feature map.
        m: Multiple globals feature map.

    Returns:
        torch.Tensor: Loss.

    '''
    N, units, n_locals = l.size()
    _, _, n_multis = m.size()

    l = l.view(N, units, n_locals)
    m = m.view(N, units, n_multis)
    l_p = l.permute(0, 2, 1)
    m_p = m.permute(0, 2, 1)

    # Positive logits: each location against its own sample's globals.
    u_p = torch.matmul(l_p, m).unsqueeze(2)

    # Negative logits: every global against every location in the batch,
    # arranged as (N, N, n_locals, n_multis).
    l_n = l_p.reshape(-1, units)
    m_n = m_p.reshape(-1, units)
    u_n = torch.mm(m_n, l_n.t())
    u_n = u_n.reshape(N, n_multis, N, n_locals).permute(0, 2, 3, 1)

    # Mask out "self" examples so they cannot serve as negatives.
    mask = torch.eye(N, device=l.device)[:, :, None, None]
    n_mask = 1 - mask
    u_n = (n_mask * u_n) - (10. * (1 - n_mask))
    u_n = u_n.reshape(N, N * n_locals, n_multis).unsqueeze(dim=1).expand(-1, n_locals, -1, -1)

    # Softmax over one positive and all negatives per (location, global) pair.
    pred_lgt = torch.cat([u_p, u_n], dim=2)
    pred_log = F.log_softmax(pred_lgt, dim=2)
    loss = -pred_log[:, :, 0].mean()

    return loss
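

# Usage sketch (illustrative, not part of the original module): the logits
# tensor here has shape (N, n_locals, 1 + N * n_locals, n_multis), i.e. one
# positive and N * n_locals negatives per (location, global) pair.
#
#   l = torch.randn(8, 64, 49)
#   m = torch.randn(8, 64, 4)
#   loss = multi_nce_loss(l, m)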


def donsker_varadhan_loss(l, g):
    '''Computes the negative Donsker-Varadhan representation of the KL divergence
    between positive and negative joint distributions.

    Args:
        l: Local feature map.
        g: Global features.

    Returns:
        torch.Tensor: Loss.

    '''
    N, local_units, n_locs = l.size()
    l = l.permute(0, 2, 1)
    l = l.reshape(-1, local_units)

    u = torch.mm(g, l.t())
    u = u.reshape(N, N, n_locs)

    mask = torch.eye(N, device=l.device)
    n_mask = (1 - mask)[:, :, None]

    # First DV term: mean score over positive (joint) pairs.
    E_pos = (u.mean(2) * mask).sum() / mask.sum()

    # Second DV term: log-mean-exp over negative pairs, stabilized by
    # subtracting the max; "self" scores are pushed far below the max so
    # they vanish under the exponential.
    u -= 100 * (1 - n_mask)
    u_max = torch.max(u)
    E_neg = torch.log((n_mask * torch.exp(u - u_max)).sum() + 1e-6) + u_max - math.log(n_mask.sum())
    loss = E_neg - E_pos

    return loss
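

# Usage sketch (illustrative, not part of the original module): the second DV
# term is a log-mean-exp over negatives, computed stably by subtracting the
# batch max before exponentiating.
#
#   l = torch.randn(8, 64, 49)
#   g = torch.randn(8, 64)
#   loss = donsker_varadhan_loss(l, g)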


def multi_donsker_varadhan_loss(l, m):
    '''Computes the negative Donsker-Varadhan representation of the KL divergence
    between positive and negative joint distributions.

    Used for multiple globals.

    Args:
        l: Local feature map.
        m: Multiple globals feature map.

    Returns:
        torch.Tensor: Loss.

    '''
    N, units, n_locals = l.size()
    n_multis = m.size(2)

    l = l.view(N, units, n_locals)
    l = l.permute(0, 2, 1)
    l = l.reshape(-1, units)

    m = m.view(N, units, n_multis)
    m = m.permute(0, 2, 1)
    m = m.reshape(-1, units)

    u = torch.mm(m, l.t())
    u = u.reshape(N, n_multis, N, n_locals).permute(0, 2, 3, 1)

    # Expand the negative mask so it broadcasts over the location and global
    # dimensions of `u`, which has shape (N, N, n_locals, n_multis).
    mask = torch.eye(N, device=l.device)
    n_mask = (1 - mask)[:, :, None, None]

    # First DV term: mean positive score, averaged over locations and globals.
    E_pos = (u.mean(2).mean(2) * mask).sum() / mask.sum()

    # Second DV term: log-mean-exp over negative pairs, stabilized by
    # subtracting the max; "self" scores are pushed far below the max.
    u -= 100 * (1 - n_mask)
    u_max = torch.max(u)
    E_neg = torch.log((n_mask * torch.exp(u - u_max)).sum() + 1e-6) + u_max - math.log(n_mask.sum())
    loss = E_neg - E_pos

    return loss
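

if __name__ == '__main__':
    # Smoke test (not part of the original module): exercise each loss with
    # random stand-ins for encoder outputs. Shapes follow the docstrings:
    # locals are (N, units, n_locs); globals are (N, units), or
    # (N, units, n_multis) for the multi-global variants.
    l = torch.randn(8, 64, 49)
    g = torch.randn(8, 64)
    m = torch.randn(8, 64, 4)

    print('fenchel_dual:', fenchel_dual_loss(l, g, measure='JSD').item())
    print('multi_fenchel_dual:', multi_fenchel_dual_loss(l, m, measure='JSD').item())
    print('nce:', nce_loss(l, g).item())
    print('multi_nce:', multi_nce_loss(l, m).item())
    print('donsker_varadhan:', donsker_varadhan_loss(l, g).item())
    print('multi_donsker_varadhan:', multi_donsker_varadhan_loss(l, m).item())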