import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from .decoder import DecoderBase
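

# LSTMDecoder conditions each time step on the latent code z: z is concatenated
# with the word embedding and also initializes the LSTM state via trans_linear.
# VarLSTMDecoder extends it to variable-length, padded batches (packed sequences
# and a pad-masked cross-entropy loss).
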
class LSTMDecoder(DecoderBase):
"""LSTM decoder with constant-length data"""
def __init__(self, args, vocab, model_init, emb_init):
super(LSTMDecoder, self).__init__()
self.ni = args.ni
self.nh = args.dec_nh
self.nz = args.nz
self.vocab = vocab
self.device = args.device
# no padding when setting padding_idx to -1
self.embed = nn.Embedding(len(vocab), args.ni, padding_idx=-1)
self.dropout_in = nn.Dropout(args.dec_dropout_in)
self.dropout_out = nn.Dropout(args.dec_dropout_out)
# for initializing hidden state and cell
self.trans_linear = nn.Linear(args.nz, args.dec_nh, bias=False)
# concatenate z with input
self.lstm = nn.LSTM(input_size=args.ni + args.nz,
hidden_size=args.dec_nh,
num_layers=1,
batch_first=True)
# prediction layer
self.pred_linear = nn.Linear(args.dec_nh, len(vocab), bias=False)
vocab_mask = torch.ones(len(vocab))
# vocab_mask[vocab['<pad>']] = 0
        self.loss = nn.CrossEntropyLoss(weight=vocab_mask, reduction='none')
self.reset_parameters(model_init, emb_init)
def reset_parameters(self, model_init, emb_init):
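        """Apply model_init to every parameter, then emb_init to the embedding weights."""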
# for name, param in self.lstm.named_parameters():
# # self.initializer(param)
# if 'bias' in name:
# nn.init.constant_(param, 0.0)
# # model_init(param)
# elif 'weight' in name:
# model_init(param)
# model_init(self.trans_linear.weight)
# model_init(self.pred_linear.weight)
for param in self.parameters():
model_init(param)
emb_init(self.embed.weight)
def sample_text(self, input, z, EOS, device):
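        """Sample a sentence from z one token at a time until EOS or 100 tokens.

        Assumes batch_size == 1 and n_sample == 1; `input` is the (1, 1)
        start-symbol index and EOS the end-symbol id. Returns the list of
        token index tensors, beginning with `input`.
        """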
sentence = [input]
        max_index = 0
        input_word = input
        batch_size, n_sample, _ = z.size()
        # z is broadcast to the single decoding step: (batch_size, 1, nz)
        z_ = z.expand(batch_size, 1, self.nz)
        # softmax over the flattened vocabulary logits (batch_size == 1)
        softmax = torch.nn.Softmax(dim=0)
while max_index != EOS and len(sentence) < 100:
# (batch_size, seq_len, ni)
word_embed = self.embed(input_word)
word_embed = torch.cat((word_embed, z_), -1)
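            # the LSTM state is built from z but only consumed on the first step;
            # the squeeze below assumes batch_size == 1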
c_init = self.trans_linear(z).unsqueeze(0)
h_init = torch.tanh(c_init)
if len(sentence) == 1:
h_init = h_init.squeeze(dim=1)
c_init = c_init.squeeze(dim=1)
output, hidden = self.lstm.forward(word_embed, (h_init, c_init))
else:
output, hidden = self.lstm.forward(word_embed, hidden)
# (batch_size * n_sample, seq_len, vocab_size)
output_logits = self.pred_linear(output)
output_logits = output_logits.view(-1)
probs = softmax(output_logits)
# max_index = torch.argmax(output_logits)
max_index = torch.multinomial(probs, num_samples=1)
            input_word = max_index.view(1, 1).to(device)
sentence.append(max_index)
return sentence
def decode(self, input, z):
"""
Args:
input: (batch_size, seq_len)
z: (batch_size, n_sample, nz)
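        Returns:
            output_logits: (batch_size * n_sample, seq_len, vocab_size)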
"""
# not predicting start symbol
# sents_len -= 1
batch_size, n_sample, _ = z.size()
seq_len = input.size(1)
# (batch_size, seq_len, ni)
word_embed = self.embed(input)
word_embed = self.dropout_in(word_embed)
if n_sample == 1:
z_ = z.expand(batch_size, seq_len, self.nz)
else:
word_embed = word_embed.unsqueeze(1).expand(batch_size, n_sample, seq_len, self.ni) \
.contiguous()
# (batch_size * n_sample, seq_len, ni)
word_embed = word_embed.view(batch_size * n_sample, seq_len, self.ni)
z_ = z.unsqueeze(2).expand(batch_size, n_sample, seq_len, self.nz).contiguous()
z_ = z_.view(batch_size * n_sample, seq_len, self.nz)
# (batch_size * n_sample, seq_len, ni + nz)
word_embed = torch.cat((word_embed, z_), -1)
z = z.view(batch_size * n_sample, self.nz)
c_init = self.trans_linear(z).unsqueeze(0)
h_init = torch.tanh(c_init)
# h_init = self.trans_linear(z).unsqueeze(0)
# c_init = h_init.new_zeros(h_init.size())
output, _ = self.lstm(word_embed, (h_init, c_init))
output = self.dropout_out(output)
# (batch_size * n_sample, seq_len, vocab_size)
output_logits = self.pred_linear(output)
return output_logits
def reconstruct_error(self, x, z):
"""Cross Entropy in the language case
Args:
x: (batch_size, seq_len)
z: (batch_size, n_sample, nz)
Returns:
loss: (batch_size, n_sample). Loss
across different sentence and z
"""
        # remove end symbol
src = x[:, :-1]
# remove start symbol
tgt = x[:, 1:]
batch_size, seq_len = src.size()
n_sample = z.size(1)
# (batch_size * n_sample, seq_len, vocab_size)
output_logits = self.decode(src, z)
if n_sample == 1:
tgt = tgt.contiguous().view(-1)
else:
# (batch_size * n_sample * seq_len)
tgt = tgt.unsqueeze(1).expand(batch_size, n_sample, seq_len) \
.contiguous().view(-1)
# (batch_size * n_sample * seq_len)
loss = self.loss(output_logits.view(-1, output_logits.size(2)),
tgt)
# (batch_size, n_sample)
return loss.view(batch_size, n_sample, -1).sum(-1)
def log_probability(self, x, z):
"""Cross Entropy in the language case
Args:
x: (batch_size, seq_len)
z: (batch_size, n_sample, nz)
Returns:
log_p: (batch_size, n_sample).
log_p(x|z) across different x and z
"""
return -self.reconstruct_error(x, z)
def greedy_decode(self, z):
return self.sample_decode(z, greedy=True)
def sample_decode(self, z, greedy=False):
"""sample/greedy decoding from z
Args:
z: (batch_size, nz)
Returns: List1
List1: the decoded word sentence list
"""
batch_size = z.size(0)
decoded_batch = [[] for _ in range(batch_size)]
        # (1, batch_size, dec_nh)
        c_init = self.trans_linear(z).unsqueeze(0)
        h_init = torch.tanh(c_init)
decoder_hidden = (h_init, c_init)
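        # feed <s> as the first input and decode until every sentence has emitted
        # </s> (or the 100-step cap is reached)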
decoder_input = torch.tensor([self.vocab["<s>"]] * batch_size, dtype=torch.long, device=self.device).unsqueeze(1)
end_symbol = torch.tensor([self.vocab["</s>"]] * batch_size, dtype=torch.long, device=self.device)
        mask = torch.ones(batch_size, dtype=torch.bool, device=self.device)
length_c = 1
while mask.sum().item() != 0 and length_c < 100:
# (batch_size, 1, ni) --> (batch_size, 1, ni+nz)
word_embed = self.embed(decoder_input)
word_embed = torch.cat((word_embed, z.unsqueeze(1)), dim=-1)
output, decoder_hidden = self.lstm(word_embed, decoder_hidden)
# (batch_size, 1, vocab_size) --> (batch_size, vocab_size)
decoder_output = self.pred_linear(output)
output_logits = decoder_output.squeeze(1)
# (batch_size)
if greedy:
max_index = torch.argmax(output_logits, dim=1)
else:
probs = F.softmax(output_logits, dim=1)
max_index = torch.multinomial(probs, num_samples=1).squeeze(1)
decoder_input = max_index.unsqueeze(1)
length_c += 1
            for i in range(batch_size):
                if mask[i].item():
                    decoded_batch[i].append(self.vocab.id2word(max_index[i].item()))
            # keep decoding only the sentences that have not emitted </s> yet
            mask = torch.mul((max_index != end_symbol), mask)
return decoded_batch
class VarLSTMDecoder(LSTMDecoder):
"""LSTM decoder with constant-length data"""
def __init__(self, args, vocab, model_init, emb_init):
super(VarLSTMDecoder, self).__init__(args, vocab, model_init, emb_init)
self.embed = nn.Embedding(len(vocab), args.ni, padding_idx=vocab['<pad>'])
vocab_mask = torch.ones(len(vocab))
vocab_mask[vocab['<pad>']] = 0
        self.loss = nn.CrossEntropyLoss(weight=vocab_mask, reduction='none')
self.reset_parameters(model_init, emb_init)
def decode(self, input, z):
"""
Args:
input: tuple which contains x and sents_len
x: (batch_size, seq_len)
sents_len: long tensor of sentence lengths
z: (batch_size, n_sample, nz)
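        Returns:
            output_logits: (batch_size * n_sample, max(sents_len), vocab_size)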
"""
input, sents_len = input
# not predicting start symbol
sents_len = sents_len - 1
batch_size, n_sample, _ = z.size()
seq_len = input.size(1)
# (batch_size, seq_len, ni)
word_embed = self.embed(input)
word_embed = self.dropout_in(word_embed)
if n_sample == 1:
z_ = z.expand(batch_size, seq_len, self.nz)
else:
word_embed = word_embed.unsqueeze(1).expand(batch_size, n_sample, seq_len, self.ni) \
.contiguous()
# (batch_size * n_sample, seq_len, ni)
word_embed = word_embed.view(batch_size * n_sample, seq_len, self.ni)
z_ = z.unsqueeze(2).expand(batch_size, n_sample, seq_len, self.nz).contiguous()
z_ = z_.view(batch_size * n_sample, seq_len, self.nz)
# (batch_size * n_sample, seq_len, ni + nz)
word_embed = torch.cat((word_embed, z_), -1)
sents_len = sents_len.unsqueeze(1).expand(batch_size, n_sample).contiguous().view(-1)
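        # pack so the LSTM skips padded positions; pack_padded_sequence expects the
        # lengths in decreasing order (or enforce_sorted=False in newer PyTorch)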
packed_embed = pack_padded_sequence(word_embed, sents_len.tolist(), batch_first=True)
z = z.view(batch_size * n_sample, self.nz)
# h_init = self.trans_linear(z).unsqueeze(0)
# c_init = h_init.new_zeros(h_init.size())
c_init = self.trans_linear(z).unsqueeze(0)
h_init = torch.tanh(c_init)
output, _ = self.lstm(packed_embed, (h_init, c_init))
output, _ = pad_packed_sequence(output, batch_first=True)
output = self.dropout_out(output)
# (batch_size * n_sample, seq_len, vocab_size)
output_logits = self.pred_linear(output)
return output_logits
def reconstruct_error(self, x, z):
"""Cross Entropy in the language case
Args:
x: tuple which contains x_ and sents_len
x_: (batch_size, seq_len)
sents_len: long tensor of sentence lengths
z: (batch_size, n_sample, nz)
Returns:
loss: (batch_size, n_sample). Loss
across different sentence and z
"""
x, sents_len = x
        # remove end symbol
src = x[:, :-1]
# remove start symbol
tgt = x[:, 1:]
batch_size, seq_len = src.size()
n_sample = z.size(1)
# (batch_size * n_sample, seq_len, vocab_size)
output_logits = self.decode((src, sents_len), z)
if n_sample == 1:
tgt = tgt.contiguous().view(-1)
else:
# (batch_size * n_sample * seq_len)
tgt = tgt.unsqueeze(1).expand(batch_size, n_sample, seq_len) \
.contiguous().view(-1)
# (batch_size * n_sample * seq_len)
loss = self.loss(output_logits.view(-1, output_logits.size(2)),
tgt)
# (batch_size, n_sample)
return loss.view(batch_size, n_sample, -1).sum(-1)
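

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It assumes a
# `vocab` object that supports len(), item lookup for '<s>', '</s>' and
# '<pad>', and id2word(); the hyperparameter values below are placeholders.
def _example_usage(vocab):
    import argparse

    args = argparse.Namespace(ni=64, dec_nh=128, nz=16,
                              dec_dropout_in=0.5, dec_dropout_out=0.5,
                              device=torch.device('cpu'))
    model_init = lambda p: nn.init.uniform_(p, -0.01, 0.01)
    emb_init = lambda p: nn.init.uniform_(p, -0.1, 0.1)
    decoder = LSTMDecoder(args, vocab, model_init, emb_init)

    x = torch.randint(0, len(vocab), (8, 12))    # (batch_size, seq_len)
    z = torch.randn(8, 1, args.nz)               # (batch_size, n_sample, nz)
    loss = decoder.reconstruct_error(x, z)       # (batch_size, n_sample)
    sents = decoder.sample_decode(z.squeeze(1))  # list of decoded word lists
    return loss, sents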