""" """ import math import torch import torch.nn.functional as F from cortex_DIM.functions.misc import log_sum_exp def raise_measure_error(measure): supported_measures = ['GAN', 'JSD', 'X2', 'KL', 'RKL', 'DV', 'H2', 'W1'] raise NotImplementedError( 'Measure `{}` not supported. Supported: {}'.format(measure, supported_measures)) def get_positive_expectation(p_samples, measure, average=True): """Computes the positive part of a divergence / difference. Args: p_samples: Positive samples. measure: Measure to compute for. average: Average the result over samples. Returns: torch.Tensor """ log_2 = math.log(2.) if measure == 'GAN': Ep = - F.softplus(-p_samples) elif measure == 'JSD': Ep = log_2 - F.softplus(- p_samples) elif measure == 'X2': Ep = p_samples ** 2 elif measure == 'KL': Ep = p_samples + 1. elif measure == 'RKL': Ep = -torch.exp(-p_samples) elif measure == 'DV': Ep = p_samples elif measure == 'H2': Ep = 1. - torch.exp(-p_samples) elif measure == 'W1': Ep = p_samples else: raise_measure_error(measure) if average: return Ep.mean() else: return Ep def get_negative_expectation(q_samples, measure, average=True): """Computes the negative part of a divergence / difference. Args: q_samples: Negative samples. measure: Measure to compute for. average: Average the result over samples. Returns: torch.Tensor """ log_2 = math.log(2.) if measure == 'GAN': Eq = F.softplus(-q_samples) + q_samples elif measure == 'JSD': Eq = F.softplus(-q_samples) + q_samples - log_2 elif measure == 'X2': Eq = -0.5 * ((torch.sqrt(q_samples ** 2) + 1.) ** 2) elif measure == 'KL': Eq = torch.exp(q_samples) elif measure == 'RKL': Eq = q_samples - 1. elif measure == 'DV': Eq = log_sum_exp(q_samples, 0) - math.log(q_samples.size(0)) elif measure == 'H2': Eq = torch.exp(q_samples) - 1. elif measure == 'W1': Eq = q_samples else: raise_measure_error(measure) if average: return Eq.mean() else: return Eq