import torch
from .decode_strategy import DecodeStrategy


def sample_with_temperature(logits, sampling_temp, keep_topk):
    """Select next tokens randomly from the top k possible next tokens.

    Samples from a categorical distribution over the ``keep_topk`` words using
    the category probabilities ``logits / sampling_temp``.
    """

    if sampling_temp == 0.0 or keep_topk == 1:
        # Greedy case: temperature 0 and top-1 sampling both reduce to
        # argmax. For keep_topk == 1 with a nonzero temperature, still
        # rescale the returned score so it matches the sampling branch.
        topk_scores, topk_ids = logits.topk(1, dim=-1)
        if sampling_temp > 0:
            topk_scores /= sampling_temp
    else:
        logits = torch.div(logits, sampling_temp)
        if keep_topk > 0:
            # Restrict sampling to the top-k tokens: everything below the
            # k-th best logit is masked to a large negative value.
            top_values, _ = torch.topk(logits, keep_topk, dim=1)
            kth_best = top_values[:, -1].view([-1, 1])
            kth_best = kth_best.repeat([1, logits.shape[1]]).float()
            ignore = torch.lt(logits, kth_best)
            logits = logits.masked_fill(ignore, -10000)

        # Multinomial with total_count=1 yields a one-hot sample per row;
        # argmax recovers the sampled token index.
        dist = torch.distributions.Multinomial(logits=logits, total_count=1)
        topk_ids = torch.argmax(dist.sample(), dim=1, keepdim=True)
        topk_scores = logits.gather(dim=1, index=topk_ids)

    return topk_ids, topk_scores
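
# Usage sketch (illustrative only; the shapes and values below are
# assumptions for demonstration, not part of the module):
#
#     logits = torch.log_softmax(torch.randn(4, 100), dim=-1)  # b=4, v=100
#     ids, scores = sample_with_temperature(logits, 0.7, keep_topk=10)
#     # ids: (4, 1) sampled token indices; scores: their scaled logits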


class GreedySearch(DecodeStrategy):
    """Select next tokens randomly from the top k possible next tokens.
    """

    def __init__(self, pad, bos, eos, batch_size, min_length, max_length,
                 return_attention=False, return_hidden=False,
                 sampling_temp=1, keep_topk=1):
        super().__init__(
            pad, bos, eos, batch_size, 1, min_length, max_length,
            return_attention, return_hidden)
        self.sampling_temp = sampling_temp
        self.keep_topk = keep_topk
        self.topk_scores = None

    def initialize(self, memory_bank, device=None):
        """Set up decoding state; returns ``(fn_map_state, memory_bank)``."""
        # Greedy search keeps a single hypothesis per sequence, so no
        # beam-style state-mapping function is needed.
        fn_map_state = None

        if device is None:
            device = memory_bank.device

        self.memory_length = memory_bank.size(1)
        super().initialize(memory_bank, device)

        self.select_indices = torch.arange(
            self.batch_size, dtype=torch.long, device=device)
        self.original_batch_idx = torch.arange(
            self.batch_size, dtype=torch.long, device=device)

        return fn_map_state, memory_bank

    @property
    def current_predictions(self):
        return self.alive_seq[:, -1]

    @property
    def batch_offset(self):
        return self.select_indices

    def _pick(self, log_probs):
        """Function used to pick next tokens.
        """
        topk_ids, topk_scores = sample_with_temperature(
            log_probs, self.sampling_temp, self.keep_topk)
        return topk_ids, topk_scores

    def advance(self, log_probs, attn=None, hidden=None, label=None):
        """Select next tokens randomly from the top k possible next tokens.
        """
        self.ensure_min_length(log_probs)
        # log_probs: b x v; topk_ids and self.topk_scores: b x 1
        topk_ids, self.topk_scores = self._pick(log_probs)
        self.is_finished = topk_ids.eq(self.eos)
        if label is not None:
            # When gold labels are supplied, a sequence finishes when the
            # *label* is EOS rather than the sampled token.
            label = label.view_as(self.is_finished)
            self.is_finished = label.eq(self.eos)
        # alive_seq: b x (l+1); the first element is <bos>, so l = len(self) - 1
        self.alive_seq = torch.cat([self.alive_seq, topk_ids], -1)
        self.alive_log_token_scores = torch.cat(
            [self.alive_log_token_scores, self.topk_scores], -1)

        if self.return_attention:
            if self.alive_attn is None:
                self.alive_attn = attn
            else:
                self.alive_attn = torch.cat([self.alive_attn, attn], 1)
        if self.return_hidden:
            if self.alive_hidden is None:
                self.alive_hidden = hidden
            else:
                self.alive_hidden = torch.cat([self.alive_hidden, hidden], 1)  # b x l x h
        self.ensure_max_length()

    def update_finished(self):
        """Finalize scores and predictions."""
        # is_finished marks sequences whose decoding just ended: record their
        # results, then drop them from the alive batch.
        finished_batches = self.is_finished.view(-1).nonzero(as_tuple=False)
        for b in finished_batches.view(-1):
            b_orig = self.original_batch_idx[b]
            # scores/predictions/attention are lists,
            # (to be compatible with beam-search)
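            # The sequence score is exp(mean log token score): the geometric
            # mean of the per-token probabilities (length-normalized).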
            self.scores[b_orig].append(
                torch.exp(torch.mean(self.alive_log_token_scores[b])).item())
            self.token_scores[b_orig].append(
                torch.exp(self.alive_log_token_scores[b]).tolist())
            self.predictions[b_orig].append(self.alive_seq[b, 1:])  # skip <bos>
            self.attention[b_orig].append(
                self.alive_attn[b, :, :self.memory_length] if self.alive_attn is not None else [])
            self.hidden[b_orig].append(
                self.alive_hidden[b, :] if self.alive_hidden is not None else [])
        self.done = self.is_finished.all()
        if self.done:
            return
        is_alive = ~self.is_finished.view(-1)
        self.alive_seq = self.alive_seq[is_alive]
        self.alive_log_token_scores = self.alive_log_token_scores[is_alive]
        if self.alive_attn is not None:
            self.alive_attn = self.alive_attn[is_alive]
        if self.alive_hidden is not None:
            self.alive_hidden = self.alive_hidden[is_alive]
        self.select_indices = is_alive.nonzero(as_tuple=False).view(-1)
        self.original_batch_idx = self.original_batch_idx[is_alive]
        # Note: select_indices indexes into the batch as it was before this
        # filtering step, while original_batch_idx maps to positions in the
        # original batch; the two can differ once sequences have been removed
        # in earlier steps.
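

# Illustrative decoding loop (a sketch, not part of the original module).
# ``model.decode_step`` and the surrounding names are hypothetical stand-ins
# for whatever produces per-step log-probabilities:
#
#     strategy = GreedySearch(pad, bos, eos, batch_size, min_length,
#                             max_length, sampling_temp=0.7, keep_topk=10)
#     _, memory_bank = strategy.initialize(memory_bank)
#     while not strategy.done:
#         log_probs, attn = model.decode_step(
#             strategy.current_predictions, memory_bank)
#         strategy.advance(log_probs, attn)
#         if strategy.is_finished.any():
#             strategy.update_finished()
#             if not strategy.done:
#                 memory_bank = memory_bank[strategy.select_indices]
#     # strategy.predictions / strategy.scores now hold per-batch results.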