import torch


def multinomial(input: torch.Tensor, num_samples: int, replacement: bool = False, *, generator=None) -> torch.Tensor:
    """Like torch.multinomial, but with an arbitrary number of leading dimensions; candidates live on the last dimension.

    Args:
        input (torch.Tensor): The input tensor containing probabilities.
        num_samples (int): Number of samples to draw.
        replacement (bool): Whether to draw with replacement or not.
    Keywords args:
        generator (torch.Generator): A pseudorandom number generator for sampling.
    Returns:
        torch.Tensor: Last dimension contains num_samples indices
            sampled from the multinomial probability distribution
            located in the last dimension of tensor input.
    """

    if num_samples == 1:
        # Fast path via the exponential-clock variant of the Gumbel-max trick:
        # argmax(input / q) with q ~ Exponential(1) selects index i with
        # probability proportional to input[i], with no reshape round-trip.
        q = torch.empty_like(input).exponential_(1, generator=generator)
        return torch.argmax(input / q, dim=-1, keepdim=True).to(torch.int64)

    # torch.multinomial only accepts 1D or 2D input, so flatten the leading
    # dimensions, sample, and restore the original shape afterwards.
    input_ = input.reshape(-1, input.shape[-1])
    output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
    output = output_.reshape(*input.shape[:-1], -1)
    return output
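

# A minimal usage sketch, not part of the original module (the `_demo_multinomial`
# name and the example values are ours): leading dimensions are preserved, and the
# single-sample fast path should match the input distribution empirically.
def _demo_multinomial() -> None:
    probs = torch.tensor([0.1, 0.2, 0.7]).repeat(10_000, 1)
    draws = multinomial(probs, num_samples=1)  # shape (10_000, 1)
    freqs = torch.bincount(draws.flatten(), minlength=3) / draws.numel()
    # freqs is approximately tensor([0.1, 0.2, 0.7]).

    # Arbitrary leading dimensions also work: (2, 3, 5) -> (2, 3, 4).
    probs_3d = torch.softmax(torch.randn(2, 3, 5), dim=-1)
    assert multinomial(probs_3d, num_samples=4).shape == (2, 3, 4)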


def apply_top_k(
    probs: torch.Tensor,
    k: int,
) -> torch.Tensor:
    """Sample next token from top K values along the last dimension of the input probs tensor.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        k (int): The k in “top-k”.
    Returns:
        torch.Tensor: The filtered and renormalized probabilities.
    """
    # Find the k-th largest probability and zero out everything strictly below it.
    v, _ = torch.topk(probs, min(k, probs.size(-1)))
    pivot = v.select(-1, -1).unsqueeze(-1)
    probs = torch.where(probs < pivot, 0.0, probs)
    # Renormalize so each distribution sums to 1 again.
    probs.div_(probs.sum(dim=-1, keepdim=True))
    return probs
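

# A minimal sketch of apply_top_k, not from the original module; the demo name and
# numbers are illustrative. With k=2, only the two largest probabilities survive
# and the row is renormalized.
def _demo_apply_top_k() -> None:
    probs = torch.tensor([[0.1, 0.2, 0.3, 0.4]])
    filtered = apply_top_k(probs, k=2)
    # filtered is tensor([[0.0000, 0.0000, 0.4286, 0.5714]]), i.e. 0.3/0.7 and 0.4/0.7.
    assert torch.allclose(filtered.sum(dim=-1), torch.ones(1))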


def apply_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    """Sample next token from top P probabilities along the last dimension of the input probs tensor.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        p (float): The p in “top-p”.
    Returns:
        torch.Tensor: The filtered and renormalized probabilities.
    """
    # Sort descending so the cumulative sum runs from the most likely token down.
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Drop every token whose preceding cumulative mass already exceeds p; the
    # most likely token is therefore always kept.
    mask = probs_sum - probs_sort > p
    probs_sort *= (~mask).float()
    # Scatter the filtered values back to their original positions and renormalize.
    probs = probs.scatter(-1, probs_idx, probs_sort)
    probs.div_(probs.sum(dim=-1, keepdim=True))
    return probs
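

# A minimal sketch of apply_top_p, not from the original module; values are
# illustrative. With p=0.7 the nucleus is {0.5, 0.3}: the cumulative mass before
# 0.3 is 0.5 <= 0.7, while the mass before each 0.1 already exceeds p.
def _demo_apply_top_p() -> None:
    probs = torch.tensor([[0.5, 0.3, 0.1, 0.1]])
    filtered = apply_top_p(probs, p=0.7)
    # filtered is tensor([[0.6250, 0.3750, 0.0000, 0.0000]]), i.e. 0.5/0.8 and 0.3/0.8.
    assert torch.allclose(filtered.sum(dim=-1), torch.ones(1))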


def apply_min_p(probs: torch.Tensor, min_p: float) -> torch.Tensor:
    """Sample next token using min-p sampling.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        min_p (float): Minimum token probability, scaled by the probability of the most likely token.
                       Must be between 0 and 1. Typical values are in the 0.01-0.2 range.
    Returns:
        torch.Tensor: The filtered and renormalized probabilities.
    """
    # The cutoff scales with the most likely token: confident distributions are
    # pruned aggressively, flat ones keep more candidates.
    top_probs, _ = probs.max(dim=-1, keepdim=True)
    tokens_to_remove = probs < (min_p * top_probs)
    probs = probs.masked_fill(tokens_to_remove, 0.0)
    probs.div_(probs.sum(dim=-1, keepdim=True))
    return probs
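

# A minimal sketch of apply_min_p, not from the original module; values are
# illustrative. With min_p=0.2 the cutoff is 0.2 * 0.5 = 0.1, so only the 0.05
# candidate is dropped before renormalizing.
def _demo_apply_min_p() -> None:
    probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
    filtered = apply_min_p(probs, min_p=0.2)
    # filtered is approximately tensor([[0.5263, 0.3158, 0.1579, 0.0000]]).
    assert torch.allclose(filtered.sum(dim=-1), torch.ones(1))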


def modify_logit_for_repetition_penalty(
    logits: torch.Tensor,
    generated_tokens: torch.Tensor,
    repetition_penalty: float,
    repetition_penalty_window: int,
) -> torch.Tensor:
    """Apply a repetition penalty (CTRL, https://arxiv.org/abs/1909.05858) over a
    sliding window of the last `repetition_penalty_window` generated tokens.

    Args:
        logits (torch.Tensor): Logits of shape (batch_size, n_codebooks, vocab_size).
        generated_tokens (torch.Tensor): Token history of shape (batch_size, n_codebooks, seq_len).
        repetition_penalty (float): Multiplicative penalty; values > 1.0 discourage repetition.
        repetition_penalty_window (int): Number of most recent tokens to penalize.
    Returns:
        torch.Tensor: The penalized logits.
    """
    # Only penalize tokens inside the most recent window.
    generated_tokens = generated_tokens[..., -repetition_penalty_window:]
    generated_tokens = generated_tokens.clamp_max(logits.shape[-1] - 1).to(torch.int64)
    # Build a per-vocabulary factor: repetition_penalty multiplied once for each
    # occurrence of the token in the window, 1.0 for tokens that never occurred.
    rp = torch.full_like(logits, repetition_penalty)
    factors = torch.ones_like(logits).scatter_reduce(2, generated_tokens, rp, reduce="prod")
    # Scale negative logits up and positive logits down so repeats become less likely.
    return torch.where(logits <= 0, logits * factors, logits / factors)
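

# A minimal sketch of the repetition penalty, not from the original module; the
# shapes and numbers are illustrative. Token 2 occurs twice in the window, so its
# factor is 2.0 * 2.0 = 4.0 and its positive logit is divided by 4; tokens that
# never occurred keep a factor of 1.0.
def _demo_repetition_penalty() -> None:
    logits = torch.tensor([[[1.0, -1.0, 2.0, 0.5]]])  # (batch=1, n_codebooks=1, vocab=4)
    history = torch.tensor([[[2, 2]]])                # (batch=1, n_codebooks=1, seq_len=2)
    penalized = modify_logit_for_repetition_penalty(logits, history, 2.0, repetition_penalty_window=4)
    # penalized is tensor([[[1.0, -1.0, 0.5, 0.5]]]): only index 2 changed.
    assert torch.allclose(penalized[..., 2], torch.tensor(0.5))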


def sample_from_logits(
    logits: torch.Tensor,
    temperature: float = 1.0,
    top_p: float = 0.0,
    top_k: int = 0,
    min_p: float = 0.0,
    generated_tokens: torch.Tensor | None = None,
    repetition_penalty: float = 3.0,
    repetition_penalty_window: int = 2,
) -> torch.Tensor:
    """Sample next token from logits using temperature, top-p, top-k, or min-p sampling.

    Args:
        logits (torch.Tensor): Input logits with token candidates on the last dimension.
        temperature (float): Sampling temperature. Lower temperature results in more deterministic samples.
        top_p (float): The p in “top-p”.
        top_k (int): The k in “top-k”.
        min_p (float): Minimum token probability, scaled by the probability of the most likely token.
                       Must be between 0 and 1. Typical values are in the 0.01-0.2 range.
        generated_tokens (torch.Tensor | None): Previously generated tokens of shape
                       (batch_size, n_codebooks, seq_len), used for the repetition penalty.
        repetition_penalty (float): Strength of the repetition penalty; 1.0 disables it.
        repetition_penalty_window (int): Number of most recent tokens the penalty considers.

    Returns:
        torch.Tensor: Sampled tokens.
    """
    # Penalize recently generated tokens before any temperature or filtering step.
    if repetition_penalty != 1.0 and generated_tokens is not None:
        logits = modify_logit_for_repetition_penalty(logits, generated_tokens, repetition_penalty, repetition_penalty_window)

    if temperature > 0:
        probs = torch.softmax(logits / temperature, dim=-1)

        # Each filter zeroes out candidates and renormalizes, so they compose.
        if top_p > 0:
            probs = apply_top_p(probs, top_p)
        if top_k > 0:
            probs = apply_top_k(probs, top_k)
        if min_p > 0:
            probs = apply_min_p(probs, min_p)

        next_token = multinomial(probs, num_samples=1)
    else:
        # A temperature of 0 means greedy decoding: take the most likely token.
        next_token = torch.argmax(logits, dim=-1, keepdim=True)

    return next_token  # (batch_size, n_codebooks, 1)
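

# An end-to-end usage sketch, not from the original module; shapes and
# hyperparameters are illustrative. A short token history exercises the
# repetition-penalty path before temperature scaling and top-p filtering.
def _demo_sample_from_logits() -> None:
    batch_size, n_codebooks, vocab_size = 2, 4, 256
    logits = torch.randn(batch_size, n_codebooks, vocab_size)
    history = torch.randint(0, vocab_size, (batch_size, n_codebooks, 10))
    next_token = sample_from_logits(
        logits,
        temperature=0.8,
        top_p=0.9,
        generated_tokens=history,
        repetition_penalty=1.3,
    )
    assert next_token.shape == (batch_size, n_codebooks, 1)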