File size: 2,030 Bytes
5e9bd47 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import torch
class DecodeStrategy(object):
    """Base class holding state shared by step-wise decoding strategies.

    Keeps per-example result lists, the tensor of currently alive
    sequences, and min/max length enforcement. Concrete strategies must
    implement :meth:`advance` and :meth:`update_finished`.

    Length convention: ``len(self)`` is the number of columns in
    ``alive_seq``. That count includes the leading BOS column written by
    :meth:`initialize`. So, assuming subclasses append one token column
    per step, ``len(self) - 1`` tokens have been generated so far.

    Args:
        pad: Padding token id (stored only; not used in this class).
        bos: Token id every alive sequence starts with.
        eos: End-of-sequence token id; masked out by
            :meth:`ensure_min_length`.
        batch_size: Number of examples decoded together.
        parallel_paths: Number of decoding paths tracked per example.
        min_length: EOS is forbidden until this many tokens follow BOS.
        max_length: All paths are marked finished once this many tokens
            follow BOS.
        return_attention: Flag stored for subclasses.
        return_hidden: Flag stored for subclasses.
    """

    def __init__(self, pad, bos, eos, batch_size, parallel_paths, min_length,
                 max_length, return_attention=False, return_hidden=False):
        self.pad = pad
        self.bos = bos
        self.eos = eos
        self.batch_size = batch_size
        self.parallel_paths = parallel_paths
        # Result containers, one independent list per example (built with a
        # comprehension so the inner lists are not aliased). Presumably
        # filled by subclasses as hypotheses finish.
        self.predictions = [[] for _ in range(batch_size)]
        self.scores = [[] for _ in range(batch_size)]
        self.token_scores = [[] for _ in range(batch_size)]
        self.attention = [[] for _ in range(batch_size)]
        self.hidden = [[] for _ in range(batch_size)]
        self.alive_attn = None
        self.alive_hidden = None
        self.min_length = min_length
        self.max_length = max_length
        self.return_attention = return_attention
        self.return_hidden = return_hidden
        self.done = False

    def initialize(self, memory_bank, device=None):
        """Create fresh alive-sequence state tensors for a decoding run.

        Args:
            memory_bank: Returned unchanged; not inspected here.
            device: torch device for the new tensors. Defaults to CPU.

        Returns:
            ``(None, memory_bank)``. NOTE(review): the ``None`` slot looks
            like a placeholder that subclasses replace; confirm with callers.
        """
        if device is None:
            device = torch.device('cpu')
        n_paths = self.batch_size * self.parallel_paths
        # Every path starts as a single BOS token: shape (n_paths, 1).
        self.alive_seq = torch.full(
            [n_paths, 1], self.bos, dtype=torch.long, device=device)
        # One finished flag per (example, path). Kept as uint8 for
        # compatibility with existing subclasses.
        # NOTE(review): recent torch prefers torch.bool masks.
        self.is_finished = torch.zeros(
            [self.batch_size, self.parallel_paths],
            dtype=torch.uint8, device=device)
        # Empty (n_paths, 0) score buffer; presumably subclasses append one
        # column per decoding step.
        self.alive_log_token_scores = torch.zeros(
            [n_paths, 0], dtype=torch.float, device=device)
        return None, memory_bank

    def __len__(self):
        """Return the number of columns in ``alive_seq``, BOS included."""
        return self.alive_seq.shape[1]

    def ensure_min_length(self, log_probs):
        """Forbid EOS until ``min_length`` tokens follow BOS.

        Args:
            log_probs: Tensor indexed as ``log_probs[:, token_id]``; it is
                modified in place. Presumably (n_paths, vocab_size); TODO
                confirm.
        """
        if len(self) <= self.min_length:
            log_probs[:, self.eos] = -1e20  # effectively -inf: EOS not allowed

    def ensure_max_length(self):
        """Mark every path finished once ``max_length`` tokens follow BOS."""
        # ``>=`` rather than ``==``: if alive_seq ever overshoots the limit,
        # decoding still terminates instead of running forever.
        if len(self) >= self.max_length + 1:
            self.is_finished.fill_(1)

    def advance(self, log_probs, attn):
        """Perform one decoding step. Must be implemented by subclasses."""
        raise NotImplementedError()

    def update_finished(self):
        """Subclass hook, presumably for handling finished paths."""
        raise NotImplementedError()
|