# Spaces:
# Runtime error
# Runtime error
# import torch | |
import time | |
import argparse | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence | |
import numpy as np | |
from .decoder import DecoderBase | |
class LSTMDecoder(DecoderBase):
    """LSTM decoder with constant-length data.

    The latent code ``z`` conditions the decoder in two ways:
    it initialises the LSTM state (``c0 = trans_linear(z)``, ``h0 = tanh(c0)``)
    and it is concatenated to the word embedding at every time step.
    """

    def __init__(self, args, vocab, model_init, emb_init):
        """Build the decoder.

        Args:
            args: config namespace; reads ``ni``, ``dec_nh``, ``nz``, ``device``,
                ``dec_dropout_in`` and ``dec_dropout_out``.
            vocab: vocabulary; used via ``len(vocab)``, ``vocab[word]`` and
                ``vocab.id2word(idx)``.
            model_init: callable applied to every parameter tensor.
            emb_init: callable applied to the embedding weight (after ``model_init``).
        """
        super(LSTMDecoder, self).__init__()
        self.ni = args.ni
        self.nh = args.dec_nh
        self.nz = args.nz
        self.vocab = vocab
        self.device = args.device

        # No padding row. Note that ``padding_idx=-1`` would NOT mean "no
        # padding": PyTorch maps a negative index to ``len(vocab) - 1`` and
        # would silently freeze the gradient of the last vocabulary entry.
        self.embed = nn.Embedding(len(vocab), args.ni)

        self.dropout_in = nn.Dropout(args.dec_dropout_in)
        self.dropout_out = nn.Dropout(args.dec_dropout_out)

        # for initializing hidden state and cell
        self.trans_linear = nn.Linear(args.nz, args.dec_nh, bias=False)

        # concatenate z with input
        self.lstm = nn.LSTM(input_size=args.ni + args.nz,
                            hidden_size=args.dec_nh,
                            num_layers=1,
                            batch_first=True)

        # prediction layer
        self.pred_linear = nn.Linear(args.dec_nh, len(vocab), bias=False)

        # Constant-length data: every token (including '<pad>') is weighted 1.
        vocab_mask = torch.ones(len(vocab))
        # Per-token losses; ``reduction='none'`` is the non-deprecated
        # spelling of ``reduce=False``.
        self.loss = nn.CrossEntropyLoss(weight=vocab_mask, reduction='none')

        self.reset_parameters(model_init, emb_init)

    def reset_parameters(self, model_init, emb_init):
        """Apply ``model_init`` to every parameter, then ``emb_init`` to the
        embedding weight (so the embedding ends up initialised by ``emb_init``)."""
        for param in self.parameters():
            model_init(param)
        emb_init(self.embed.weight)

    def sample_text(self, input, z, EOS, device):
        """Sample one sentence token by token from the softmax distribution.

        Only a single sentence with a single latent sample is supported.

        Args:
            input: start token ids, (1, 1) long tensor.
            z: (1, 1, nz) latent code.
            EOS: id of the end token; sampling stops once it is drawn.
            device: device the next input token is moved to.

        Returns:
            list: ``input`` followed by the sampled token ids (each a
            1-element long tensor, ``EOS`` included when drawn); at most
            100 entries.

        Raises:
            ValueError: if ``z`` holds more than one sentence or sample.
        """
        batch_size, n_sample, _ = z.size()
        if batch_size != 1 or n_sample != 1:
            raise ValueError("sample_text supports a single sentence and a single "
                             "sample; expected z of shape (1, 1, nz), got %s"
                             % (tuple(z.size()),))

        # Initial LSTM state from z, computed once: (1, 1, dec_nh)
        c_init = self.trans_linear(z.view(1, self.nz)).unsqueeze(0)
        hidden = (torch.tanh(c_init), c_init)

        sentence = [input]
        input_word = input
        max_index = 0
        while max_index != EOS and len(sentence) < 100:
            # (1, 1, ni) -> (1, 1, ni + nz)
            word_embed = torch.cat((self.embed(input_word), z), -1)
            output, hidden = self.lstm(word_embed, hidden)
            # (1, 1, vocab_size) -> (vocab_size,)
            output_logits = self.pred_linear(output).view(-1)
            probs = F.softmax(output_logits, dim=0)
            # (1,) long tensor
            max_index = torch.multinomial(probs, num_samples=1)
            input_word = max_index.view(1, 1).to(device)
            sentence.append(max_index)
        return sentence

    def decode(self, input, z):
        """Teacher-forced decoding.

        Args:
            input: (batch_size, seq_len) token ids
            z: (batch_size, n_sample, nz)

        Returns:
            (batch_size * n_sample, seq_len, vocab_size) logits
        """
        batch_size, n_sample, _ = z.size()
        seq_len = input.size(1)

        # (batch_size, seq_len, ni)
        word_embed = self.embed(input)
        word_embed = self.dropout_in(word_embed)

        if n_sample == 1:
            z_ = z.expand(batch_size, seq_len, self.nz)
        else:
            word_embed = word_embed.unsqueeze(1).expand(batch_size, n_sample, seq_len, self.ni) \
                                   .contiguous()
            # (batch_size * n_sample, seq_len, ni)
            word_embed = word_embed.view(batch_size * n_sample, seq_len, self.ni)
            z_ = z.unsqueeze(2).expand(batch_size, n_sample, seq_len, self.nz).contiguous()
            z_ = z_.view(batch_size * n_sample, seq_len, self.nz)

        # (batch_size * n_sample, seq_len, ni + nz)
        word_embed = torch.cat((word_embed, z_), -1)

        # initial state: (1, batch_size * n_sample, dec_nh)
        z = z.view(batch_size * n_sample, self.nz)
        c_init = self.trans_linear(z).unsqueeze(0)
        h_init = torch.tanh(c_init)
        output, _ = self.lstm(word_embed, (h_init, c_init))
        output = self.dropout_out(output)

        # (batch_size * n_sample, seq_len, vocab_size)
        output_logits = self.pred_linear(output)
        return output_logits

    def reconstruct_error(self, x, z):
        """Cross Entropy in the language case
        Args:
            x: (batch_size, seq_len)
            z: (batch_size, n_sample, nz)
        Returns:
            loss: (batch_size, n_sample). Loss
            across different sentence and z
        """
        # remove end symbol
        src = x[:, :-1]
        # remove start symbol
        tgt = x[:, 1:]

        batch_size, seq_len = src.size()
        n_sample = z.size(1)

        # (batch_size * n_sample, seq_len, vocab_size)
        output_logits = self.decode(src, z)

        if n_sample == 1:
            tgt = tgt.contiguous().view(-1)
        else:
            # (batch_size * n_sample * seq_len)
            tgt = tgt.unsqueeze(1).expand(batch_size, n_sample, seq_len) \
                     .contiguous().view(-1)

        # (batch_size * n_sample * seq_len)
        loss = self.loss(output_logits.view(-1, output_logits.size(2)),
                         tgt)

        # (batch_size, n_sample)
        return loss.view(batch_size, n_sample, -1).sum(-1)

    def log_probability(self, x, z):
        """Cross Entropy in the language case
        Args:
            x: (batch_size, seq_len)
            z: (batch_size, n_sample, nz)
        Returns:
            log_p: (batch_size, n_sample).
            log_p(x|z) across different x and z
        """
        return -self.reconstruct_error(x, z)

    def greedy_decode(self, z):
        """Argmax decoding from z; see ``sample_decode``."""
        return self.sample_decode(z, greedy=True)

    def sample_decode(self, z, greedy=False):
        """sample/greedy decoding from z
        Args:
            z: (batch_size, nz)
            greedy: take the argmax token instead of sampling
        Returns: List1
            List1: the decoded word sentence list, one list of words per
            batch element ('</s>' is included when generated; at most 99
            words per sentence)
        """
        batch_size = z.size(0)
        decoded_batch = [[] for _ in range(batch_size)]

        # (1, batch_size, dec_nh)
        c_init = self.trans_linear(z).unsqueeze(0)
        h_init = torch.tanh(c_init)
        decoder_hidden = (h_init, c_init)

        # z is concatenated to every step's input: (batch_size, 1, nz)
        z_step = z.unsqueeze(1)

        decoder_input = torch.tensor([self.vocab["<s>"]] * batch_size, dtype=torch.long,
                                     device=self.device).unsqueeze(1)
        end_symbol = torch.tensor([self.vocab["</s>"]] * batch_size, dtype=torch.long,
                                  device=self.device)
        # True while the sentence has not yet produced '</s>'
        mask = torch.ones(batch_size, dtype=torch.bool, device=self.device)
        length_c = 1
        while mask.any().item() and length_c < 100:
            # (batch_size, 1, ni) --> (batch_size, 1, ni+nz)
            word_embed = self.embed(decoder_input)
            word_embed = torch.cat((word_embed, z_step), dim=-1)

            output, decoder_hidden = self.lstm(word_embed, decoder_hidden)

            # (batch_size, 1, vocab_size) --> (batch_size, vocab_size)
            output_logits = self.pred_linear(output).squeeze(1)

            # (batch_size)
            if greedy:
                max_index = torch.argmax(output_logits, dim=1)
            else:
                probs = F.softmax(output_logits, dim=1)
                max_index = torch.multinomial(probs, num_samples=1).squeeze(1)

            decoder_input = max_index.unsqueeze(1)
            length_c += 1

            # One host transfer per step instead of per-element ``.item()`` calls;
            # the token that ends a sentence ('</s>') is still recorded.
            for i, (idx, alive) in enumerate(zip(max_index.tolist(), mask.tolist())):
                if alive:
                    decoded_batch[i].append(self.vocab.id2word(idx))

            mask = mask & (max_index != end_symbol)

        return decoded_batch
class VarLSTMDecoder(LSTMDecoder):
    """LSTM decoder with variable-length data.

    Sentences are right-padded with ``'<pad>'``. Padded positions are
    skipped by the LSTM (packed sequences), and ``'<pad>'`` targets get zero
    weight in the loss.
    """

    def __init__(self, args, vocab, model_init, emb_init):
        super(VarLSTMDecoder, self).__init__(args, vocab, model_init, emb_init)

        # Replace the parent's embedding with one whose '<pad>' row receives
        # no gradient.
        self.embed = nn.Embedding(len(vocab), args.ni, padding_idx=vocab['<pad>'])
        # '<pad>' targets contribute zero loss.
        vocab_mask = torch.ones(len(vocab))
        vocab_mask[vocab['<pad>']] = 0
        # Per-token losses; ``reduction='none'`` is the non-deprecated
        # spelling of ``reduce=False``.
        self.loss = nn.CrossEntropyLoss(weight=vocab_mask, reduction='none')

        # Re-initialise: the embedding above was freshly constructed.
        self.reset_parameters(model_init, emb_init)

    def decode(self, input, z):
        """
        Args:
            input: tuple which contains x and sents_len
                x: (batch_size, seq_len)
                sents_len: long tensor of sentence lengths
            z: (batch_size, n_sample, nz)

        Returns:
            (batch_size * n_sample, seq_len, vocab_size) logits; positions
            beyond a sentence's length hold the logits of a zero LSTM output.
        """
        input, sents_len = input

        # not predicting start symbol
        sents_len = sents_len - 1

        batch_size, n_sample, _ = z.size()
        seq_len = input.size(1)

        # (batch_size, seq_len, ni)
        word_embed = self.embed(input)
        word_embed = self.dropout_in(word_embed)

        if n_sample == 1:
            z_ = z.expand(batch_size, seq_len, self.nz)
        else:
            word_embed = word_embed.unsqueeze(1).expand(batch_size, n_sample, seq_len, self.ni) \
                                   .contiguous()
            # (batch_size * n_sample, seq_len, ni)
            word_embed = word_embed.view(batch_size * n_sample, seq_len, self.ni)
            z_ = z.unsqueeze(2).expand(batch_size, n_sample, seq_len, self.nz).contiguous()
            z_ = z_.view(batch_size * n_sample, seq_len, self.nz)

        # (batch_size * n_sample, seq_len, ni + nz)
        word_embed = torch.cat((word_embed, z_), -1)

        # one length per (sentence, sample) row
        sents_len = sents_len.unsqueeze(1).expand(batch_size, n_sample).contiguous().view(-1)
        # enforce_sorted=False: the batch need not be sorted by length; PyTorch
        # sorts internally (permuting the initial state too) and restores order.
        packed_embed = pack_padded_sequence(word_embed, sents_len.tolist(), batch_first=True,
                                            enforce_sorted=False)

        # initial state: (1, batch_size * n_sample, dec_nh)
        z = z.view(batch_size * n_sample, self.nz)
        c_init = self.trans_linear(z).unsqueeze(0)
        h_init = torch.tanh(c_init)
        output, _ = self.lstm(packed_embed, (h_init, c_init))
        # total_length keeps the time axis equal to seq_len even when every
        # sentence in the batch is shorter than the padded input, so the
        # logits always line up with the targets.
        output, _ = pad_packed_sequence(output, batch_first=True, total_length=seq_len)
        output = self.dropout_out(output)

        # (batch_size * n_sample, seq_len, vocab_size)
        output_logits = self.pred_linear(output)
        return output_logits

    def reconstruct_error(self, x, z):
        """Cross Entropy in the language case
        Args:
            x: tuple which contains x_ and sents_len
                x_: (batch_size, seq_len)
                sents_len: long tensor of sentence lengths
            z: (batch_size, n_sample, nz)
        Returns:
            loss: (batch_size, n_sample). Loss
            across different sentence and z
        """
        x, sents_len = x

        # remove end symbol
        src = x[:, :-1]
        # remove start symbol
        tgt = x[:, 1:]

        batch_size, seq_len = src.size()
        n_sample = z.size(1)

        # (batch_size * n_sample, seq_len, vocab_size)
        output_logits = self.decode((src, sents_len), z)

        if n_sample == 1:
            tgt = tgt.contiguous().view(-1)
        else:
            # (batch_size * n_sample * seq_len)
            tgt = tgt.unsqueeze(1).expand(batch_size, n_sample, seq_len) \
                     .contiguous().view(-1)

        # (batch_size * n_sample * seq_len); '<pad>' targets have zero weight
        loss = self.loss(output_logits.view(-1, output_logits.size(2)),
                         tgt)

        # (batch_size, n_sample)
        return loss.view(batch_size, n_sample, -1).sum(-1)