Spaces:
Running
Running
Evgeny Zhukov
Origin: https://github.com/ali-vilab/UniAnimate/commit/d7814fa44a0a1154524b92fce0e3133a2604d333
2ba4412
import torch | |
import math | |
# Public API of this module; `standard_normal_cdf` is an internal helper.
__all__ = [
    'kl_divergence',
    'discretized_gaussian_log_likelihood',
]
def kl_divergence(mu1, logvar1, mu2, logvar2):
    r"""Elementwise KL divergence between two diagonal Gaussians.

    Evaluates the closed form of
    KL( N(mu1, exp(logvar1)) || N(mu2, exp(logvar2)) ), i.e.
    0.5 * (-1 + logvar2 - logvar1 + exp(logvar1 - logvar2)
           + (mu1 - mu2)^2 * exp(-logvar2)).

    Args:
        mu1, logvar1: mean and log-variance of the first distribution.
        mu2, logvar2: mean and log-variance of the second distribution.
            All four are combined with ordinary tensor arithmetic, so they
            follow PyTorch broadcasting rules.

    Returns:
        A tensor of per-element KL values (no reduction is applied).
    """
    # Ratio of variances, var1 / var2, computed in log space.
    variance_ratio = torch.exp(logvar1 - logvar2)
    # Squared mean gap scaled by the second distribution's precision.
    mean_penalty = ((mu1 - mu2) ** 2) * torch.exp(-logvar2)
    inner = -1.0 + logvar2 - logvar1 + variance_ratio + mean_penalty
    return 0.5 * inner
def standard_normal_cdf(x):
    r"""Approximate the standard normal CDF Phi(x) elementwise.

    Uses the tanh-based approximation
    Phi(x) ~= 0.5 * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))),
    which trades a small amount of accuracy for speed compared to an
    erf-based evaluation.

    Args:
        x: input tensor.

    Returns:
        A tensor of the same shape as ``x`` with values in [0, 1].
    """
    scale = math.sqrt(2.0 / math.pi)
    # Cubic correction term of the tanh approximation.
    cubic_arg = x + 0.044715 * torch.pow(x, 3)
    squashed = torch.tanh(scale * cubic_arg)
    return 0.5 * (1.0 + squashed)
def discretized_gaussian_log_likelihood(x0, mean, log_scale):
    r"""Elementwise log-likelihood of ``x0`` under a binned Gaussian.

    Each value of ``x0`` is assigned the probability mass that
    N(mean, exp(log_scale)^2) places on a bin of half-width 1/255 centred
    on it.

    - Values below -0.999 take all mass from -inf up to the bin's upper
      edge.
    - Values above 0.999 take all mass from the bin's lower edge to +inf.

    NOTE(review): the 1/255 half-width and the +/-0.999 edge thresholds
    suggest ``x0`` is 8-bit data rescaled to [-1, 1]. This is an
    assumption; confirm it with the callers.

    Args:
        x0: observed values.
        mean: Gaussian means, same shape as ``x0``.
        log_scale: log standard deviations, same shape as ``x0``.

    Returns:
        Log-probabilities with the same shape as ``x0``. Probabilities are
        clamped at 1e-12 before the log to avoid -inf.
    """
    assert x0.shape == mean.shape == log_scale.shape
    half_bin = 1.0 / 255.0
    floor = 1e-12

    centered = x0 - mean
    precision = torch.exp(-log_scale)  # 1 / std
    upper_cdf = standard_normal_cdf(precision * (centered + half_bin))
    lower_cdf = standard_normal_cdf(precision * (centered - half_bin))

    # Leftmost bin: mass from -inf to the upper edge.
    log_left_tail = torch.log(upper_cdf.clamp(min=floor))
    # Rightmost bin: mass from the lower edge to +inf.
    log_right_tail = torch.log((1.0 - lower_cdf).clamp(min=floor))
    # Interior bins: mass between the two edges.
    log_interior = torch.log((upper_cdf - lower_cdf).clamp(min=floor))

    at_left_edge = x0 < -0.999
    at_right_edge = x0 > 0.999
    log_probs = torch.where(
        at_left_edge,
        log_left_tail,
        torch.where(at_right_edge, log_right_tail, log_interior),
    )
    assert log_probs.shape == x0.shape
    return log_probs