# Optimus/code/modules/encoders/gaussian_encoder.py
import math
import torch
import torch.nn as nn
from .encoder import EncoderBase
from ..utils import log_sum_exp
class GaussianEncoderBase(EncoderBase):
"""docstring for EncoderBase"""
def __init__(self):
super(GaussianEncoderBase, self).__init__()
def freeze(self):
for param in self.parameters():
param.requires_grad = False
def forward(self, x):
"""
Args:
x: (batch_size, *)
Returns: Tensor1, Tensor2
Tensor1: the mean tensor, shape (batch, nz)
Tensor2: the logvar tensor, shape (batch, nz)
"""
raise NotImplementedError
def encode_stats(self, x):
return self.forward(x)
    def sample(self, input, nsamples):
        """sampling from the encoder posterior
        Returns: Tensor1, Tuple
            Tensor1: the latent tensor z with shape [batch, nsamples, nz]
            Tuple: the posterior parameters (mu, logvar), each with shape [batch, nz]
        """
# (batch_size, nz)
mu, logvar = self.forward(input)
# (batch, nsamples, nz)
z = self.reparameterize(mu, logvar, nsamples)
return z, (mu, logvar)
def encode(self, input, nsamples):
"""perform the encoding and compute the KL term
Returns: Tensor1, Tensor2
Tensor1: the tensor latent z with shape [batch, nsamples, nz]
            Tensor2: the tensor of KL(q(z|x) || p(z)) for each x, with shape [batch]
"""
# (batch_size, nz)
mu, logvar = self.forward(input)
# (batch, nsamples, nz)
z = self.reparameterize(mu, logvar, nsamples)
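        # Closed-form KL(q(z|x) || p(z)) against a standard normal prior:
        # 0.5 * sum_i (mu_i^2 + sigma_i^2 - log sigma_i^2 - 1)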
KL = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1).sum(dim=1)
return z, KL
def reparameterize(self, mu, logvar, nsamples=1):
"""sample from posterior Gaussian family
Args:
mu: Tensor
Mean of gaussian distribution with shape (batch, nz)
logvar: Tensor
logvar of gaussian distibution with shape (batch, nz)
Returns: Tensor
Sampled z with shape (batch, nsamples, nz)
"""
batch_size, nz = mu.size()
std = logvar.mul(0.5).exp()
mu_expd = mu.unsqueeze(1).expand(batch_size, nsamples, nz)
std_expd = std.unsqueeze(1).expand(batch_size, nsamples, nz)
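        # Draw standard normal noise and shift/scale it: z = mu + sigma * eps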
eps = torch.zeros_like(std_expd).normal_()
return mu_expd + torch.mul(eps, std_expd)
def eval_inference_dist(self, x, z, param=None):
"""this function computes log q(z | x)
Args:
z: tensor
different z points that will be evaluated, with
shape [batch, nsamples, nz]
Returns: Tensor1
Tensor1: log q(z|x) with shape [batch, nsamples]
"""
nz = z.size(2)
        if param is None:
mu, logvar = self.forward(x)
else:
mu, logvar = param
# (batch_size, 1, nz)
mu, logvar = mu.unsqueeze(1), logvar.unsqueeze(1)
var = logvar.exp()
# (batch_size, nsamples, nz)
dev = z - mu
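        # Diagonal Gaussian log-density:
        # -0.5 * sum((z - mu)^2 / sigma^2) - 0.5 * (nz * log(2*pi) + sum(log sigma^2))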
# (batch_size, nsamples)
log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \
0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1))
return log_density
def calc_mi(self, x):
"""Approximate the mutual information between x and z
I(x, z) = E_xE_{q(z|x)}log(q(z|x)) - E_xE_{q(z|x)}log(q(z))
Returns: Float
"""
# [x_batch, nz]
mu, logvar = self.forward(x)
x_batch, nz = mu.size()
# E_{q(z|x)}log(q(z|x)) = -0.5*nz*log(2*\pi) - 0.5*(1+logvar).sum(-1)
        neg_entropy = (-0.5 * nz * math.log(2 * math.pi) - 0.5 * (1 + logvar).sum(-1)).mean()
        # [x_batch, 1, nz] (these samples act as the z_batch dimension below)
        z_samples = self.reparameterize(mu, logvar, 1)
# [1, x_batch, nz]
mu, logvar = mu.unsqueeze(0), logvar.unsqueeze(0)
var = logvar.exp()
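        # Broadcast so that every sampled z_i is evaluated under every posterior q(z|x_j)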
# (z_batch, x_batch, nz)
dev = z_samples - mu
# (z_batch, x_batch)
log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \
0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1))
# log q(z): aggregate posterior
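        # Monte Carlo estimate over the batch: log q(z) ~= log((1/x_batch) * sum_j q(z|x_j))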
# [z_batch]
log_qz = log_sum_exp(log_density, dim=1) - math.log(x_batch)
return (neg_entropy - log_qz.mean(-1)).item()
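

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original Optimus module): a
# minimal concrete subclass with a small feed-forward featurizer. The class
# name, layer sizes, and dimensions below are assumptions for demonstration.
class _ToyGaussianEncoder(GaussianEncoderBase):
    def __init__(self, input_dim, hidden_dim, nz):
        super(_ToyGaussianEncoder, self).__init__()
        self.net = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.Tanh())
        self.mu_layer = nn.Linear(hidden_dim, nz)
        self.logvar_layer = nn.Linear(hidden_dim, nz)

    def forward(self, x):
        # Produce the mean and log-variance of q(z|x)
        h = self.net(x)
        return self.mu_layer(h), self.logvar_layer(h)


if __name__ == "__main__":
    enc = _ToyGaussianEncoder(input_dim=784, hidden_dim=256, nz=32)
    x = torch.randn(16, 784)
    z, KL = enc.encode(x, nsamples=5)   # z: [16, 5, 32], KL: [16]
    mi = enc.calc_mi(x)                 # scalar estimate of I(x, z)
    print(z.shape, KL.shape, mi)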