# %%
import torch as t
import torch.nn.functional as F
import transformers
gpt = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
def apply_sampling_methods(
input_ids: t.Tensor, logits: t.Tensor, temperature=1.0, freq_penalty=0.0, top_k=0, top_p=0.0
) -> int:
    '''
    Return the next token, sampled from the model's probability distribution
    with the given modifiers applied.

    input_ids: shape (seq,) - the tokens generated so far
    logits: shape (vocab_size,) - unnormalized log-probabilities for the next token
    '''
assert input_ids.ndim == 1, "input_ids should be a 1D sequence of token ids"
assert temperature >= 0, "Temperature should be non-negative"
assert 0 <= top_p <= 1.0, "Top-p must be a probability"
assert 0 <= top_k, "Top-k must be non-negative"
assert not (top_p != 0 and top_k != 0), "At most one of top-p and top-k supported"
if temperature == 0:
return greedy_search(logits)
if temperature != 1.0:
logits = apply_temperature(logits, temperature)
if freq_penalty != 0.0:
logits = apply_freq_penalty(input_ids, logits, freq_penalty)
if top_k > 0:
return sample_top_k(logits, top_k)
if top_p > 0:
return sample_top_p(logits, top_p)
return sample_basic(logits)
def sample_tokens(
model,
tokenizer,
initial_text: str,
max_tokens_generated: int = 30,
**kwargs
) -> str:
'''
Sample tokens until the model outputs `tokenizer.eos_token_id` or the specified token limit is reached.
Return: the prompt and continuation concatenated
'''
model.eval()
input_ids: list = tokenizer.encode(initial_text)
generated = []
device = next(model.parameters()).device
for _ in range(max_tokens_generated):
        # Build the full sequence so far as a tensor on the model's device
        new_input_ids = t.tensor(input_ids + generated, dtype=t.int64, device=device)
        # Truncate to the model's context window and add a batch dimension
        new_input_ids_truncated = new_input_ids[-tokenizer.model_max_length:].unsqueeze(0)
        with t.inference_mode():
            output = model(new_input_ids_truncated)
        all_logits = output if isinstance(output, t.Tensor) else output.logits
        logits = all_logits[0, -1]  # logits for the final position: shape (vocab_size,)
new_token = apply_sampling_methods(new_input_ids, logits, **kwargs)
generated.append(new_token)
if new_token == getattr(tokenizer, "eos_token_id", None):
break
return tokenizer.decode(input_ids + generated)
# %%
def greedy_search(logits: t.Tensor) -> int:
'''
logits: shape (vocab_size, )
Return: the most likely token (as an integer)
'''
    return logits.argmax().item()
if __name__ == "__main__":
prompt = "Jingle bells, jingle bells, jingle all the way"
print("Greedy decoding with prompt: ", prompt)
output = sample_tokens(gpt, tokenizer, prompt, max_tokens_generated=8, temperature=0.0)
print(f"Your model said: {output}")
expected = "Jingle bells, jingle bells, jingle all the way up to the top of the mountain."
assert output == expected
print("Greedy decoding a second time (should be deterministic): ")
output = sample_tokens(gpt, tokenizer, prompt, max_tokens_generated=8, temperature=0.0)
print(f"Your model said: {output}")
expected = "Jingle bells, jingle bells, jingle all the way up to the top of the mountain."
assert output == expected
print("Tests passed!")
# %%
def sample_basic(logits: t.Tensor) -> int:
'''
logits: shape (vocab_size, ) - unnormalized log-probabilities
Return: a sampled token
'''
    return t.distributions.categorical.Categorical(logits=logits).sample().item()
if __name__ == "__main__":
N = 20000
probs = t.linspace(0, 0.4, 5)
unnormalized_logits = probs.log() + 1.2345
samples = t.tensor([sample_basic(unnormalized_logits) for _ in range(N)])
counts = t.bincount(samples, minlength=len(probs)) / N
print("Checking empirical frequencies (try to increase N if this test fails): ", counts)
t.testing.assert_close(counts, probs, atol=0.01, rtol=0)
print("Tests passed!")
# %%
def apply_temperature(logits: t.Tensor, temperature: float) -> t.Tensor:
'''
logits: shape (vocab_size, )
Return: shape (vocab_size, )
'''
assert temperature > 0
return logits / temperature
if __name__ == '__main__':
    logits = t.tensor([1.0, 2.0]).log()
cold_logits = apply_temperature(logits, 0.001)
print('A low temperature "sharpens" or "peaks" the distribution: ', cold_logits)
t.testing.assert_close(cold_logits, 1000.0 * logits)
hot_logits = apply_temperature(logits, 1000.0)
print("A high temperature flattens the distribution: ", hot_logits)
t.testing.assert_close(hot_logits, 0.001 * logits)
print("Tests passed!")
# %%
def apply_freq_penalty(input_ids: t.Tensor, logits: t.Tensor, freq_penalty: float) -> t.Tensor:
'''
input_ids: shape (seq, )
logits: shape (vocab_size, )
Return: shape (vocab_size, )
'''
    # Count occurrences of each vocab token in the sequence so far
    count = input_ids.bincount(minlength=len(logits))
    # Subtract a penalty proportional to each token's frequency (out-of-place,
    # so the caller's logits are left untouched)
    return logits - count * freq_penalty
if __name__ == "__main__":
bieber_prompt = "And I was like Baby, baby, baby, oh Like, Baby, baby, baby, no Like, Baby, baby, baby, oh I thought you'd always be mine, mine"
input_ids = tokenizer.encode(bieber_prompt, return_tensors="pt").squeeze()
logits = t.ones(tokenizer.vocab_size)
penalized_logits = apply_freq_penalty(input_ids, logits, 2.0)
assert penalized_logits[5156].item() == -11, "Expected 6 occurrences of ' baby' with leading space"
assert penalized_logits[14801].item() == -5, "Expected 3 occurrences of ' Baby' with leading space"
print("Tests passed!")
# %%
if __name__ == "__main__":
    N_RUNS = 0  # increase to print sample completions for each of the cases below
    your_prompt = "Jingle bells, jingle bells, jingle all the way"
    cases = [
        ("High freq penalty", dict(freq_penalty=100.0)),
        ("Negative freq penalty", dict(freq_penalty=-1.0)),
        ("Too hot!", dict(temperature=2.0)),
        ("Pleasantly cool", dict(temperature=0.7)),
        ("Pleasantly warm", dict(temperature=0.9)),
        ("Too cold!", dict(temperature=0.01)),
    ]
    for (name, kwargs) in cases:
        for i in range(N_RUNS):
            output = sample_tokens(gpt, tokenizer, your_prompt, max_tokens_generated=24, **kwargs)
            print(f"Sample {i} with: {name} ({kwargs}):")
            print(f"Your model said: {repr(output)}\n")
# %%
def sample_top_k(logits: t.Tensor, top_k: int) -> int:
'''
logits: shape (vocab_size, ) - unnormalized log-probabilities
top_k: only consider this many of the most likely tokens for sampling
Return: a sampled token
'''
    # Keep only the k largest logits, sample among them, then map back to vocab ids
    values, indices = t.topk(logits, top_k)
    return indices[sample_basic(values)].item()
if __name__ == "__main__":
N = 50000
k = 3
probs = t.linspace(0, 0.4, 5)
unnormalized_logits = probs.log() + 1.2345
samples = t.tensor([sample_top_k(unnormalized_logits, k) for _ in range(N)])
counts = t.bincount(samples, minlength=len(probs)) / N
expected = probs.clone()
expected[:-k] = 0
expected /= expected.sum()
print("Checking empirical frequencies (try to increase N if this test fails): ", counts)
t.testing.assert_close(counts, expected, atol=0.01, rtol=0)
print("Tests passed!")
# %%
if __name__ == "__main__":
your_prompt = "In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English."
output = sample_tokens(gpt, tokenizer, your_prompt, temperature=0.7, top_k=40, max_tokens_generated=64)
print(f"Your model said: {repr(output)}")
# %%
def sample_top_p(logits: t.Tensor, top_p: float, min_tokens_to_keep: int = 1) -> int:
'''
logits: shape (vocab_size, ) - unnormalized log-probabilities
Return: a sampled token
'''
    # Sort the probabilities in descending order and find the smallest prefix
    # (the "nucleus") whose cumulative probability reaches top_p
    probs = logits.double().softmax(-1)
    sorted_probs, sorted_indices = probs.sort(descending=True)
    cum_probs = sorted_probs.cumsum(-1)
    last_index = max(min_tokens_to_keep, t.where(cum_probs >= top_p)[0][0].item() + 1)
    # Categorical renormalizes the kept probabilities; map the sample back to vocab ids
    masked_probs = sorted_probs[:last_index]
    sample = t.distributions.categorical.Categorical(probs=masked_probs).sample()
    return sorted_indices[sample].item()
if __name__ == "__main__":
N = 2000
unnormalized_logits = t.tensor([0.2, 0.3, 0.5]).log() + 2.3456
samples = t.tensor([sample_top_p(unnormalized_logits, 0.5) for _ in range(N)])
counts = t.bincount(samples, minlength=len(unnormalized_logits)) / N
print("top_p of 0.5 or lower should only return token 2: ", counts)
assert counts[0] == 0 and counts[1] == 0
N = 2000
unnormalized_logits = t.tensor([0.2, 0.3, 0.5]).log() + 2.3456
samples = t.tensor([sample_top_p(unnormalized_logits, 0.50001) for _ in range(N)])
counts = t.bincount(samples, minlength=len(unnormalized_logits)) / N
print("top_p in (0.5, 0.8] should return tokens 1 and 2: ", counts)
assert counts[0] == 0
N = 50000
top_p = 0.71
probs = t.linspace(0, 0.4, 5)
unnormalized_logits = probs.log() + 1.2345
samples = t.tensor([sample_top_p(unnormalized_logits, top_p) for _ in range(N)])
counts = t.bincount(samples, minlength=len(probs)) / N
expected = probs.clone()
expected[0:2] = 0
expected /= expected.sum()
print("Checking empirical frequencies (try to increase N if this test fails): ", counts)
t.testing.assert_close(counts, expected, atol=0.01, rtol=0.0)
print("All tests passed!")
# %%
if __name__ == "__main__":
your_prompt = "Eliezer Shlomo Yudkowsky (born September 11, 1979) is an American decision and artificial intelligence (AI) theorist and writer, best known for"
output = sample_tokens(gpt, tokenizer, your_prompt, temperature=0.7, top_p=0.95, max_tokens_generated=64)
print(f"Your model said: {repr(output)}")
# %%
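# For comparison, Hugging Face's built-in `generate` exposes the same sampling
# knobs used above (this assumes the standard transformers sampling arguments;
# note that its repetition_penalty is multiplicative, unlike the additive
# frequency penalty implemented in apply_freq_penalty):
if __name__ == "__main__":
    encoded = tokenizer(your_prompt, return_tensors="pt")
    hf_output = gpt.generate(
        encoded.input_ids,
        attention_mask=encoded.attention_mask,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        max_new_tokens=64,
        pad_token_id=tokenizer.eos_token_id,
    )
    print(f"Hugging Face's generate said: {repr(tokenizer.decode(hf_output[0]))}")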