LuyangZ's picture
Upload 30 files
01df1d6 verified
"""
"""
import math
import torch
import torch.nn.functional as F
from cortex_DIM.functions.misc import log_sum_exp
def raise_measure_error(measure):
    """Signal that the requested measure is not one of the supported ones.

    Args:
        measure: The measure identifier that was requested.

    Raises:
        NotImplementedError: Always. The message echoes `measure` and lists
            every supported measure name.

    """
    known_measures = ['GAN', 'JSD', 'X2', 'KL', 'RKL', 'DV', 'H2', 'W1']
    template = 'Measure `{}` not supported. Supported: {}'
    raise NotImplementedError(template.format(measure, known_measures))
def get_positive_expectation(p_samples, measure, average=True):
    """Evaluate the positive-sample term of a divergence / difference.

    Args:
        p_samples: Positive samples.
        measure: Name of the measure to evaluate: one of 'GAN', 'JSD', 'X2',
            'KL', 'RKL', 'DV', 'H2' or 'W1'.
        average: When true, reduce the result with `.mean()`; otherwise
            return the element-wise term unchanged.

    Returns:
        torch.Tensor

    Raises:
        NotImplementedError: Raised through `raise_measure_error` when
            `measure` is not one of the names above.

    """
    if measure in ('DV', 'W1'):
        # Both measures use the raw scores as the positive term.
        positive = p_samples
    elif measure == 'GAN':
        positive = - F.softplus(-p_samples)
    elif measure == 'JSD':
        positive = math.log(2.) - F.softplus(-p_samples)
    elif measure == 'X2':
        positive = p_samples ** 2
    elif measure == 'KL':
        positive = p_samples + 1.
    elif measure == 'RKL':
        positive = -torch.exp(-p_samples)
    elif measure == 'H2':
        positive = 1. - torch.exp(-p_samples)
    else:
        raise_measure_error(measure)

    return positive.mean() if average else positive
def get_negative_expectation(q_samples, measure, average=True):
    """Computes the negative part of a divergence / difference.

    Args:
        q_samples: Negative samples.
        measure: Measure to compute for: one of 'GAN', 'JSD', 'X2', 'KL',
            'RKL', 'DV', 'H2' or 'W1'.
        average: Average the result over samples. When false the
            un-averaged term is returned as-is.

    Returns:
        torch.Tensor

    Raises:
        NotImplementedError: Raised through `raise_measure_error` when
            `measure` is not one of the names above.

    """
    log_2 = math.log(2.)

    if measure == 'GAN':
        Eq = F.softplus(-q_samples) + q_samples
    elif measure == 'JSD':
        Eq = F.softplus(-q_samples) + q_samples - log_2
    elif measure == 'X2':
        # |q| via abs() rather than sqrt(q ** 2): the forward value is the
        # same, but sqrt has an infinite derivative at 0, which turns the
        # gradient into NaN (inf * 0) whenever a sample is exactly zero.
        Eq = -0.5 * ((torch.abs(q_samples) + 1.) ** 2)
    elif measure == 'KL':
        Eq = torch.exp(q_samples)
    elif measure == 'RKL':
        Eq = q_samples - 1.
    elif measure == 'DV':
        # Reduces over dim 0 and subtracts the log of that dimension's size.
        # NOTE(review): presumably a log-mean-exp over the first dimension;
        # confirm against the log_sum_exp helper.
        Eq = log_sum_exp(q_samples, 0) - math.log(q_samples.size(0))
    elif measure == 'H2':
        Eq = torch.exp(q_samples) - 1.
    elif measure == 'W1':
        Eq = q_samples
    else:
        raise_measure_error(measure)

    if average:
        return Eq.mean()
    else:
        return Eq