# NOTE(review): removed non-code scraper artifacts that preceded this file
# ("Spaces:" and two "Runtime error" lines from a judge/paste capture).
import math

import torch
import torch.nn as nn

from .encoder import EncoderBase
from ..utils import log_sum_exp
class GaussianEncoderBase(EncoderBase):
    """Base class for VAE encoders with a diagonal-Gaussian posterior q(z|x).

    Subclasses implement :meth:`forward`, which maps an input batch to the
    Gaussian parameters ``(mu, logvar)``; this base class builds sampling,
    the analytic KL term, posterior density evaluation, and a Monte-Carlo
    mutual-information estimate on top of that contract.
    """

    def __init__(self):
        super(GaussianEncoderBase, self).__init__()

    def freeze(self):
        """Disable gradient updates for every encoder parameter."""
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Compute the posterior parameters for a batch.

        Args:
            x: (batch_size, *)

        Returns: Tensor1, Tensor2
            Tensor1: the mean tensor, shape (batch, nz)
            Tensor2: the logvar tensor, shape (batch, nz)
        """
        raise NotImplementedError

    def encode_stats(self, x):
        # Thin alias for forward(): returns (mu, logvar).
        return self.forward(x)

    def sample(self, input, nsamples):
        """Sample latent codes from the posterior q(z|x).

        Args:
            input: encoder input, shape (batch_size, *)
            nsamples: number of z samples to draw per input

        Returns: Tensor1, Tuple
            Tensor1: the tensor latent z with shape [batch, nsamples, nz]
            Tuple: the posterior parameters (mu, logvar), each [batch, nz]
        """
        # (batch_size, nz)
        mu, logvar = self.forward(input)
        # (batch, nsamples, nz)
        z = self.reparameterize(mu, logvar, nsamples)
        return z, (mu, logvar)

    def encode(self, input, nsamples):
        """Perform the encoding and compute the KL term.

        Args:
            input: encoder input, shape (batch_size, *)
            nsamples: number of z samples to draw per input

        Returns: Tensor1, Tensor2
            Tensor1: the tensor latent z with shape [batch, nsamples, nz]
            Tensor2: the tensor of KL for each x with shape [batch]
        """
        # (batch_size, nz)
        mu, logvar = self.forward(input)
        # (batch, nsamples, nz)
        z = self.reparameterize(mu, logvar, nsamples)
        # Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian,
        # summed over the latent dimensions.
        KL = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1).sum(dim=1)
        return z, KL

    def reparameterize(self, mu, logvar, nsamples=1):
        """Sample from the posterior Gaussian family via the
        reparameterization trick: z = mu + std * eps, eps ~ N(0, I).

        Args:
            mu: Tensor
                Mean of gaussian distribution with shape (batch, nz)
            logvar: Tensor
                logvar of gaussian distribution with shape (batch, nz)
            nsamples: number of samples to draw per row

        Returns: Tensor
            Sampled z with shape (batch, nsamples, nz)
        """
        batch_size, nz = mu.size()
        # std = exp(logvar / 2)
        std = logvar.mul(0.5).exp()

        # Broadcast parameters across the sample dimension (views, no copy).
        mu_expd = mu.unsqueeze(1).expand(batch_size, nsamples, nz)
        std_expd = std.unsqueeze(1).expand(batch_size, nsamples, nz)

        eps = torch.zeros_like(std_expd).normal_()
        return mu_expd + torch.mul(eps, std_expd)

    def eval_inference_dist(self, x, z, param=None):
        """Compute log q(z | x) under the diagonal-Gaussian posterior.

        Args:
            x: encoder input used to obtain (mu, logvar) when ``param``
                is not supplied
            z: tensor
                different z points that will be evaluated, with
                shape [batch, nsamples, nz]
            param: optional precomputed (mu, logvar) pair; avoids a
                second forward pass

        Returns: Tensor1
            Tensor1: log q(z|x) with shape [batch, nsamples]
        """
        nz = z.size(2)

        # BUG FIX: was `if not param:` — an explicit None check is the
        # correct test for "parameters not supplied"; truthiness of the
        # container is fragile and unidiomatic here.
        if param is None:
            mu, logvar = self.forward(x)
        else:
            mu, logvar = param

        # (batch_size, 1, nz) so parameters broadcast against z's samples
        mu, logvar = mu.unsqueeze(1), logvar.unsqueeze(1)
        var = logvar.exp()

        # (batch_size, nsamples, nz)
        dev = z - mu

        # Diagonal-Gaussian log-density, summed over nz:
        # (batch_size, nsamples)
        log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \
            0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1))

        return log_density

    def calc_mi(self, x):
        """Approximate the mutual information between x and z.

        I(x, z) = E_xE_{q(z|x)}log(q(z|x)) - E_xE_{q(z|x)}log(q(z))

        Uses one z sample per x and estimates the aggregate posterior
        q(z) by averaging q(z|x') over the batch.

        Returns: Float
        """
        # [x_batch, nz]
        mu, logvar = self.forward(x)
        x_batch, nz = mu.size()

        # E_{q(z|x)}log(q(z|x)) = -0.5*nz*log(2*\pi) - 0.5*(1+logvar).sum(-1)
        neg_entropy = (
            -0.5 * nz * math.log(2 * math.pi) - 0.5 * (1 + logvar).sum(-1)
        ).mean()

        # [z_batch, 1, nz]
        z_samples = self.reparameterize(mu, logvar, 1)

        # [1, x_batch, nz] — broadcast every sample against every posterior
        mu, logvar = mu.unsqueeze(0), logvar.unsqueeze(0)
        var = logvar.exp()

        # (z_batch, x_batch, nz)
        dev = z_samples - mu

        # (z_batch, x_batch)
        log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \
            0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1))

        # log q(z): aggregate posterior, log-mean-exp over the x batch
        # [z_batch]
        log_qz = log_sum_exp(log_density, dim=1) - math.log(x_batch)

        return (neg_entropy - log_qz.mean(-1)).item()