import math
from itertools import chain

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence

from .gaussian_encoder import GaussianEncoderBase
from ..utils import log_sum_exp

class GaussianLSTMEncoder(GaussianEncoderBase):
    """Gaussian LSTM encoder for fixed-length input"""
    def __init__(self, args, vocab_size, model_init, emb_init):
        super(GaussianLSTMEncoder, self).__init__()
        self.ni = args.ni
        self.nh = args.enc_nh
        self.nz = args.nz
        self.args = args

        self.embed = nn.Embedding(vocab_size, args.ni)

        self.lstm = nn.LSTM(input_size=args.ni,
                            hidden_size=args.enc_nh,
                            num_layers=1,
                            batch_first=True,
                            dropout=0)

        # a single projection emits both the mean and the logvar, hence 2 * nz
        self.linear = nn.Linear(args.enc_nh, 2 * args.nz, bias=False)

        self.reset_parameters(model_init, emb_init)

    def reset_parameters(self, model_init, emb_init):
        # for name, param in self.lstm.named_parameters():
        #     # self.initializer(param)
        #     if 'bias' in name:
        #         nn.init.constant_(param, 0.0)
        #         # model_init(param)
        #     elif 'weight' in name:
        #         model_init(param)

        # model_init(self.linear.weight)
        # emb_init(self.embed.weight)
        for param in self.parameters():
            model_init(param)
        emb_init(self.embed.weight)

    def forward(self, input):
        """
        Args:
            input: (batch_size, seq_len)

        Returns: Tensor1, Tensor2
            Tensor1: the mean tensor, shape (batch, nz)
            Tensor2: the logvar tensor, shape (batch, nz)
        """

        # (batch_size, seq_len, args.ni)
        word_embed = self.embed(input)

        _, (last_state, last_cell) = self.lstm(word_embed)

        # (1, batch_size, 2 * nz) -> two chunks of shape (1, batch_size, nz)
        mean, logvar = self.linear(last_state).chunk(2, -1)

        # fix the variance to a pre-defined value
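        # mean has shape (1, batch, nz) since num_layers=1, so a (1, 1, 1)
        # tensor holding log(fix_var) broadcasts over it via expand_as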
        if self.args.fix_var > 0:
            logvar = mean.new_tensor([[[math.log(self.args.fix_var)]]]).expand_as(mean)

        return mean.squeeze(0), logvar.squeeze(0)

    # def eval_inference_mode(self, x):
    #     """compute the mode points in the inference distribution
    #     (in Gaussian case)
    #
    #     Returns: Tensor
    #         Tensor: the posterior mode points with shape (*, nz)
    #     """
    #
    #     # (batch_size, nz)
    #     mu, logvar = self.forward(x)
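
    # A possible completion of the stub above (not in the original source,
    # added as a sketch): for a diagonal Gaussian the mode of the posterior
    # N(mu, diag(sigma^2)) is its mean, so returning mu suffices.
    def eval_inference_mode(self, x):
        """compute the mode points of the inference distribution
        (Gaussian case)

        Returns: Tensor
            Tensor: the posterior mode points with shape (batch, nz)
        """
        # (batch_size, nz)
        mu, logvar = self.forward(x)
        return mu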

class VarLSTMEncoder(GaussianLSTMEncoder):
    """Gaussian LSTM encoder for variable-length input"""
    def __init__(self, args, vocab_size, model_init, emb_init):
        super(VarLSTMEncoder, self).__init__(args, vocab_size, model_init, emb_init)

    def forward(self, input):
        """
        Args:
            input: tuple which contains x and sents_len
                x: (batch_size, seq_len)
                sents_len: long tensor of sentence lengths

        Returns: Tensor1, Tensor2
            Tensor1: the mean tensor, shape (batch, nz)
            Tensor2: the logvar tensor, shape (batch, nz)
        """

        input, sents_len = input

        # (batch_size, seq_len, args.ni)
        word_embed = self.embed(input)
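        # pack_padded_sequence with PyTorch's default enforce_sorted=True
        # expects the batch sorted by length in descending order; the caller
        # is assumed to sort sequences upstream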
        packed_embed = pack_padded_sequence(word_embed, sents_len.tolist(), batch_first=True)

        _, (last_state, last_cell) = self.lstm(packed_embed)

        mean, logvar = self.linear(last_state).chunk(2, -1)

        return mean.squeeze(0), logvar.squeeze(0)

    def encode(self, input, nsamples):
        """perform the encoding and compute the KL term

        Args:
            input: tuple which contains x and sents_len

        Returns: Tensor1, Tensor2
            Tensor1: the latent tensor z with shape [batch, nsamples, nz]
            Tensor2: the tensor of KL for each x with shape [batch]
        """

        # (batch_size, nz)
        mu, logvar = self.forward(input)
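        # self.reparameterize is inherited from GaussianEncoderBase;
        # presumably the standard reparameterization trick, mu + std * eps
        # with eps ~ N(0, I), drawing nsamples per input (an assumption;
        # see gaussian_encoder.py for the actual implementation)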
        # (batch, nsamples, nz)
        z = self.reparameterize(mu, logvar, nsamples)
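        # closed-form KL between the diagonal Gaussian posterior
        # N(mu, diag(sigma^2)) and the standard normal prior N(0, I):
        # KL = 0.5 * sum_i (mu_i^2 + sigma_i^2 - log sigma_i^2 - 1)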
        KL = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1).sum(dim=1)

        return z, KL
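
# Usage sketch (not part of the original file; the `args` values below are
# hypothetical):
#
#   from argparse import Namespace
#   uniform_init = lambda p: nn.init.uniform_(p, -0.01, 0.01)
#   args = Namespace(ni=64, enc_nh=128, nz=16, fix_var=0)
#   encoder = VarLSTMEncoder(args, vocab_size=1000,
#                            model_init=uniform_init, emb_init=uniform_init)
#
#   x = torch.randint(0, 1000, (4, 10))      # (batch, seq_len) token ids
#   sents_len = torch.tensor([10, 9, 7, 5])  # descending lengths
#   z, KL = encoder.encode((x, sents_len), nsamples=5)
#   # z: (4, 5, 16), KL: (4,)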