import time
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence

import numpy as np

from .decoder import DecoderBase


class LSTMDecoder(DecoderBase):
    """LSTM decoder with constant-length data"""
    def __init__(self, args, vocab, model_init, emb_init):
        super(LSTMDecoder, self).__init__()
        self.ni = args.ni
        self.nh = args.dec_nh
        self.nz = args.nz
        self.vocab = vocab
        self.device = args.device

        # no padding when setting padding_idx to -1
        self.embed = nn.Embedding(len(vocab), args.ni, padding_idx=-1)

        self.dropout_in = nn.Dropout(args.dec_dropout_in)
        self.dropout_out = nn.Dropout(args.dec_dropout_out)

        # for initializing hidden state and cell
        self.trans_linear = nn.Linear(args.nz, args.dec_nh, bias=False)

        # concatenate z with input
        self.lstm = nn.LSTM(input_size=args.ni + args.nz,
                            hidden_size=args.dec_nh,
                            num_layers=1,
                            batch_first=True)

        # prediction layer
        self.pred_linear = nn.Linear(args.dec_nh, len(vocab), bias=False)

        vocab_mask = torch.ones(len(vocab))
        # vocab_mask[vocab['<pad>']] = 0
        self.loss = nn.CrossEntropyLoss(weight=vocab_mask, reduction='none')

        self.reset_parameters(model_init, emb_init)

    def reset_parameters(self, model_init, emb_init):
        # for name, param in self.lstm.named_parameters():
        #     # self.initializer(param)
        #     if 'bias' in name:
        #         nn.init.constant_(param, 0.0)
        #         # model_init(param)
        #     elif 'weight' in name:
        #         model_init(param)

        # model_init(self.trans_linear.weight)
        # model_init(self.pred_linear.weight)
        for param in self.parameters():
            model_init(param)
        emb_init(self.embed.weight)

    def sample_text(self, input, z, EOS, device):
        # ancestral sampling of a single sentence from z, one token at a time,
        # until EOS is sampled or 100 tokens have been generated
        sentence = [input]
        max_index = 0

        input_word = input
        batch_size, n_sample, _ = z.size()
        seq_len = 1
        z_ = z.expand(batch_size, seq_len, self.nz)
        seq_len = input.size(1)
        softmax = torch.nn.Softmax(dim=0)
        while max_index != EOS and len(sentence) < 100:
            # (batch_size, seq_len, ni)
            word_embed = self.embed(input_word)
            word_embed = torch.cat((word_embed, z_), -1)
            c_init = self.trans_linear(z).unsqueeze(0)
            h_init = torch.tanh(c_init)
            if len(sentence) == 1:
                h_init = h_init.squeeze(dim=1)
                c_init = c_init.squeeze(dim=1)
                output, hidden = self.lstm.forward(word_embed, (h_init, c_init))
            else:
                output, hidden = self.lstm.forward(word_embed, hidden)

            # (batch_size * n_sample, seq_len, vocab_size)
            output_logits = self.pred_linear(output)
            output_logits = output_logits.view(-1)
            probs = softmax(output_logits)
            # max_index = torch.argmax(output_logits)
            max_index = torch.multinomial(probs, num_samples=1)
            input_word = torch.tensor([[max_index]]).to(device)
            sentence.append(max_index)

        return sentence

    def decode(self, input, z):
        """
        Args:
            input: (batch_size, seq_len)
            z: (batch_size, n_sample, nz)
        """

        # not predicting start symbol
        # sents_len -= 1

        batch_size, n_sample, _ = z.size()
        seq_len = input.size(1)

        # (batch_size, seq_len, ni)
        word_embed = self.embed(input)
        word_embed = self.dropout_in(word_embed)

        if n_sample == 1:
            z_ = z.expand(batch_size, seq_len, self.nz)
        else:
            word_embed = word_embed.unsqueeze(1).expand(batch_size, n_sample, seq_len, self.ni) \
                                   .contiguous()

            # (batch_size * n_sample, seq_len, ni)
            word_embed = word_embed.view(batch_size * n_sample, seq_len, self.ni)

            z_ = z.unsqueeze(2).expand(batch_size, n_sample, seq_len, self.nz).contiguous()
            z_ = z_.view(batch_size * n_sample, seq_len, self.nz)

        # (batch_size * n_sample, seq_len, ni + nz)
        word_embed = torch.cat((word_embed, z_), -1)

        z = z.view(batch_size * n_sample, self.nz)
        c_init = self.trans_linear(z).unsqueeze(0)
        h_init = torch.tanh(c_init)
        # h_init = self.trans_linear(z).unsqueeze(0)
        # c_init = h_init.new_zeros(h_init.size())
        output, _ = self.lstm(word_embed, (h_init, c_init))

        output = self.dropout_out(output)

        # (batch_size * n_sample, seq_len, vocab_size)
        output_logits = self.pred_linear(output)

        return output_logits

    def reconstruct_error(self, x, z):
        """Cross Entropy in the language case
        Args:
            x: (batch_size, seq_len)
            z: (batch_size, n_sample, nz)
        Returns:
            loss: (batch_size, n_sample). Loss across different sentences and z samples
        """

        # remove end symbol
        src = x[:, :-1]

        # remove start symbol
        tgt = x[:, 1:]

        batch_size, seq_len = src.size()
        n_sample = z.size(1)

        # (batch_size * n_sample, seq_len, vocab_size)
        output_logits = self.decode(src, z)

        if n_sample == 1:
            tgt = tgt.contiguous().view(-1)
        else:
            # (batch_size * n_sample * seq_len)
            tgt = tgt.unsqueeze(1).expand(batch_size, n_sample, seq_len) \
                     .contiguous().view(-1)

        # (batch_size * n_sample * seq_len)
        loss = self.loss(output_logits.view(-1, output_logits.size(2)), tgt)

        # (batch_size, n_sample)
        return loss.view(batch_size, n_sample, -1).sum(-1)

    def log_probability(self, x, z):
        """Cross Entropy in the language case
        Args:
            x: (batch_size, seq_len)
            z: (batch_size, n_sample, nz)
        Returns:
            log_p: (batch_size, n_sample). log p(x|z) across different x and z
        """

        return -self.reconstruct_error(x, z)

    def greedy_decode(self, z):
        return self.sample_decode(z, greedy=True)

    def sample_decode(self, z, greedy=False):
        """sample/greedy decoding from z
        Args:
            z: (batch_size, nz)
        Returns:
            decoded_batch: a list (one entry per batch element) of decoded word lists
        """

        batch_size = z.size(0)
        decoded_batch = [[] for _ in range(batch_size)]

        # (batch_size, 1, nz)
        c_init = self.trans_linear(z).unsqueeze(0)
        h_init = torch.tanh(c_init)

        decoder_hidden = (h_init, c_init)
        decoder_input = torch.tensor([self.vocab["<s>"]] * batch_size,
                                     dtype=torch.long, device=self.device).unsqueeze(1)
        end_symbol = torch.tensor([self.vocab["</s>"]] * batch_size,
                                  dtype=torch.long, device=self.device)

        mask = torch.ones((batch_size), dtype=torch.uint8, device=self.device)
        length_c = 1
        while mask.sum().item() != 0 and length_c < 100:

            # (batch_size, 1, ni) --> (batch_size, 1, ni+nz)
            word_embed = self.embed(decoder_input)
            word_embed = torch.cat((word_embed, z.unsqueeze(1)), dim=-1)

            output, decoder_hidden = self.lstm(word_embed, decoder_hidden)

            # (batch_size, 1, vocab_size) --> (batch_size, vocab_size)
            decoder_output = self.pred_linear(output)
            output_logits = decoder_output.squeeze(1)

            # (batch_size)
            if greedy:
                max_index = torch.argmax(output_logits, dim=1)
            else:
                probs = F.softmax(output_logits, dim=1)
                max_index = torch.multinomial(probs, num_samples=1).squeeze(1)

            decoder_input = max_index.unsqueeze(1)
            length_c += 1

            for i in range(batch_size):
                word = self.vocab.id2word(max_index[i].item())
                if mask[i].item():
                    decoded_batch[i].append(word)

            mask = torch.mul((max_index != end_symbol), mask)

        return decoded_batch
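
# Usage sketch (not part of the original module): it shows the shapes
# LSTMDecoder expects, i.e. x as a padded (batch_size, seq_len) id tensor and
# z carrying n_sample latent codes per sentence. The argparse.Namespace
# fields, the toy vocabulary and the init lambdas below are assumptions
# standing in for the training script's real args, vocab object (which must
# support __len__, __getitem__ and id2word) and initializers.
class _ToyVocab:
    """Hypothetical stand-in for the project's vocabulary class."""

    def __init__(self, words):
        self.word2id = {w: i for i, w in enumerate(words)}
        self.id2word_dict = {i: w for w, i in self.word2id.items()}

    def __len__(self):
        return len(self.word2id)

    def __getitem__(self, word):
        return self.word2id[word]

    def id2word(self, idx):
        return self.id2word_dict[idx]


def _lstm_decoder_usage_sketch():
    args = argparse.Namespace(ni=32, dec_nh=64, nz=16, device=torch.device('cpu'),
                              dec_dropout_in=0.5, dec_dropout_out=0.5)
    vocab = _ToyVocab(['<pad>', '<s>', '</s>', '<unk>', 'a', 'b', 'c'])
    model_init = lambda p: nn.init.uniform_(p, -0.01, 0.01)
    emb_init = lambda p: nn.init.uniform_(p, -0.1, 0.1)
    decoder = LSTMDecoder(args, vocab, model_init, emb_init)

    batch_size, seq_len, n_sample = 4, 7, 3
    x = torch.randint(0, len(vocab), (batch_size, seq_len))  # padded token ids
    z = torch.randn(batch_size, n_sample, args.nz)           # n_sample codes per x

    log_p = decoder.log_probability(x, z)       # (batch_size, n_sample)
    sents = decoder.greedy_decode(z[:, 0, :])   # list of decoded word lists
    return log_p, sents
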

class VarLSTMDecoder(LSTMDecoder):
    """LSTM decoder with variable-length data"""
    def __init__(self, args, vocab, model_init, emb_init):
        super(VarLSTMDecoder, self).__init__(args, vocab, model_init, emb_init)

        self.embed = nn.Embedding(len(vocab), args.ni, padding_idx=vocab['<pad>'])
        vocab_mask = torch.ones(len(vocab))
        vocab_mask[vocab['<pad>']] = 0
        self.loss = nn.CrossEntropyLoss(weight=vocab_mask, reduction='none')

        self.reset_parameters(model_init, emb_init)

    def decode(self, input, z):
        """
        Args:
            input: tuple which contains x and sents_len
                    x: (batch_size, seq_len)
                    sents_len: long tensor of sentence lengths
            z: (batch_size, n_sample, nz)
        """

        input, sents_len = input

        # not predicting start symbol
        sents_len = sents_len - 1

        batch_size, n_sample, _ = z.size()
        seq_len = input.size(1)

        # (batch_size, seq_len, ni)
        word_embed = self.embed(input)
        word_embed = self.dropout_in(word_embed)

        if n_sample == 1:
            z_ = z.expand(batch_size, seq_len, self.nz)
        else:
            word_embed = word_embed.unsqueeze(1).expand(batch_size, n_sample, seq_len, self.ni) \
                                   .contiguous()

            # (batch_size * n_sample, seq_len, ni)
            word_embed = word_embed.view(batch_size * n_sample, seq_len, self.ni)

            z_ = z.unsqueeze(2).expand(batch_size, n_sample, seq_len, self.nz).contiguous()
            z_ = z_.view(batch_size * n_sample, seq_len, self.nz)

        # (batch_size * n_sample, seq_len, ni + nz)
        word_embed = torch.cat((word_embed, z_), -1)

        sents_len = sents_len.unsqueeze(1).expand(batch_size, n_sample).contiguous().view(-1)
        packed_embed = pack_padded_sequence(word_embed, sents_len.tolist(), batch_first=True)

        z = z.view(batch_size * n_sample, self.nz)
        # h_init = self.trans_linear(z).unsqueeze(0)
        # c_init = h_init.new_zeros(h_init.size())
        c_init = self.trans_linear(z).unsqueeze(0)
        h_init = torch.tanh(c_init)
        output, _ = self.lstm(packed_embed, (h_init, c_init))
        output, _ = pad_packed_sequence(output, batch_first=True)

        output = self.dropout_out(output)

        # (batch_size * n_sample, seq_len, vocab_size)
        output_logits = self.pred_linear(output)

        return output_logits

    def reconstruct_error(self, x, z):
        """Cross Entropy in the language case
        Args:
            x: tuple which contains x_ and sents_len
                    x_: (batch_size, seq_len)
                    sents_len: long tensor of sentence lengths
            z: (batch_size, n_sample, nz)
        Returns:
            loss: (batch_size, n_sample). Loss across different sentences and z samples
        """

        x, sents_len = x

        # remove end symbol
        src = x[:, :-1]

        # remove start symbol
        tgt = x[:, 1:]

        batch_size, seq_len = src.size()
        n_sample = z.size(1)

        # (batch_size * n_sample, seq_len, vocab_size)
        output_logits = self.decode((src, sents_len), z)

        if n_sample == 1:
            tgt = tgt.contiguous().view(-1)
        else:
            # (batch_size * n_sample * seq_len)
            tgt = tgt.unsqueeze(1).expand(batch_size, n_sample, seq_len) \
                     .contiguous().view(-1)

        # (batch_size * n_sample * seq_len)
        loss = self.loss(output_logits.view(-1, output_logits.size(2)), tgt)

        # (batch_size, n_sample)
        return loss.view(batch_size, n_sample, -1).sum(-1)
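
# A second sketch (again an assumption-based illustration, reusing the toy
# vocabulary defined above) for the variable-length variant: VarLSTMDecoder
# takes a tuple (x, sents_len), masks <pad> positions out of the loss, and
# relies on pack_padded_sequence, so the batch must be sorted by length in
# descending order and sents_len must count the <s>/</s> tokens.
def _var_lstm_decoder_usage_sketch():
    args = argparse.Namespace(ni=32, dec_nh=64, nz=16, device=torch.device('cpu'),
                              dec_dropout_in=0.5, dec_dropout_out=0.5)
    vocab = _ToyVocab(['<pad>', '<s>', '</s>', '<unk>', 'a', 'b', 'c'])
    model_init = lambda p: nn.init.uniform_(p, -0.01, 0.01)
    emb_init = lambda p: nn.init.uniform_(p, -0.1, 0.1)
    decoder = VarLSTMDecoder(args, vocab, model_init, emb_init)

    # three sentences of length 6, 5, 4 (including <s> and </s>), padded with
    # <pad>=0 and already sorted longest-first for pack_padded_sequence
    x = torch.tensor([[1, 4, 5, 6, 4, 2],
                      [1, 4, 5, 6, 2, 0],
                      [1, 4, 5, 2, 0, 0]], dtype=torch.long)
    sents_len = torch.tensor([6, 5, 4], dtype=torch.long)
    z = torch.randn(3, 2, args.nz)  # two latent samples per sentence

    # (3, 2): cross entropy summed over tokens; <pad> positions contribute zero
    loss = decoder.reconstruct_error((x, sents_len), z)
    return loss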