# %%
import torch as t
import torch.nn.functional as F
import transformers
import numpy as np
gpt = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
def apply_sampling_methods(
    input_ids: t.Tensor, logits: t.Tensor, temperature=1.0, freq_penalty=0.0, top_k=0, top_p=0.0
) -> int:
    '''
    Return the next token, sampled from the model's probability distribution with modifiers.

    input_ids: shape (seq,)
    '''
    assert input_ids.ndim == 1, "input_ids should be a 1D sequence of token ids"
    assert temperature >= 0, "Temperature should be non-negative"
    assert 0 <= top_p <= 1.0, "Top-p must be a probability"
    assert 0 <= top_k, "Top-k must be non-negative"
    assert not (top_p != 0 and top_k != 0), "At most one of top-p and top-k supported"
    if temperature == 0:
        return greedy_search(logits)
    if temperature != 1.0:
        logits = apply_temperature(logits, temperature)
    if freq_penalty != 0.0:
        logits = apply_freq_penalty(input_ids, logits, freq_penalty)
    if top_k > 0:
        return sample_top_k(logits, top_k)
    if top_p > 0:
        return sample_top_p(logits, top_p)
    return sample_basic(logits)
def sample_tokens(
    model,
    tokenizer,
    initial_text: str,
    max_tokens_generated: int = 30,
    **kwargs
) -> str:
    '''
    Sample tokens until the model outputs `tokenizer.eos_token_id` or the specified token limit is reached.

    Return: the prompt and continuation concatenated
    '''
    model.eval()
    input_ids: list = tokenizer.encode(initial_text)
    generated = []
    device = next(model.parameters()).device
    for _ in range(max_tokens_generated):
        new_input_ids = t.tensor(np.array(input_ids + generated), dtype=t.int64, device=device)
        new_input_ids_truncated = new_input_ids[-min(tokenizer.model_max_length, new_input_ids.shape[0]):].unsqueeze(0)
        output = model(new_input_ids_truncated)
        all_logits = output if isinstance(output, t.Tensor) else output.logits
        logits = all_logits[0, -1]  # batch 0, final sequence position -> shape (vocab_size,)
        new_token = apply_sampling_methods(new_input_ids, logits, **kwargs)
        generated.append(new_token)
        if new_token == getattr(tokenizer, "eos_token_id", None):
            break
    return tokenizer.decode(input_ids + generated)
# %%
def greedy_search(logits: t.Tensor) -> int:
    '''
    logits: shape (vocab_size, )

    Return: the most likely token (as an integer)
    '''
    return logits.argmax().item()
if __name__ == "__main__":
    prompt = "Jingle bells, jingle bells, jingle all the way"
    print("Greedy decoding with prompt: ", prompt)
    output = sample_tokens(gpt, tokenizer, prompt, max_tokens_generated=8, temperature=0.0)
    print(f"Your model said: {output}")
    expected = "Jingle bells, jingle bells, jingle all the way up to the top of the mountain."
    assert output == expected
    print("Greedy decoding a second time (should be deterministic): ")
    output = sample_tokens(gpt, tokenizer, prompt, max_tokens_generated=8, temperature=0.0)
    print(f"Your model said: {output}")
    expected = "Jingle bells, jingle bells, jingle all the way up to the top of the mountain."
    assert output == expected
    print("Tests passed!")
# %%
def sample_basic(logits: t.Tensor) -> int:
    '''
    logits: shape (vocab_size, ) - unnormalized log-probabilities

    Return: a sampled token
    '''
    return t.distributions.categorical.Categorical(logits=logits).sample().item()
if __name__ == "__main__":
    N = 20000
    probs = t.linspace(0, 0.4, 5)
    unnormalized_logits = probs.log() + 1.2345
    samples = t.tensor([sample_basic(unnormalized_logits) for _ in range(N)])
    counts = t.bincount(samples, minlength=len(probs)) / N
    print("Checking empirical frequencies (try to increase N if this test fails): ", counts)
    t.testing.assert_close(counts, probs, atol=0.01, rtol=0)
    print("Tests passed!")
# %%
def apply_temperature(logits: t.Tensor, temperature: float) -> t.Tensor:
    '''
    logits: shape (vocab_size, )

    Return: shape (vocab_size, )
    '''
    assert temperature > 0
    return logits / temperature
if __name__ == '__main__':
    logits = t.tensor([1.0, 2.0]).log()
    cold_logits = apply_temperature(logits, 0.001)
    print('A low temperature "sharpens" or "peaks" the distribution: ', cold_logits)
    t.testing.assert_close(cold_logits, 1000.0 * logits)
    hot_logits = apply_temperature(logits, 1000.0)
    print("A high temperature flattens the distribution: ", hot_logits)
    t.testing.assert_close(hot_logits, 0.001 * logits)
    print("Tests passed!")
# %%
def apply_freq_penalty(input_ids: t.Tensor, logits: t.Tensor, freq_penalty: float) -> t.Tensor:
    '''
    input_ids: shape (seq, )
    logits: shape (vocab_size, )

    Return: shape (vocab_size, )
    '''
    count = input_ids.bincount(minlength=len(logits))
    # Subtract freq_penalty once per prior occurrence of each token,
    # out of place so the caller's logits tensor is not mutated.
    return logits - count * freq_penalty
if __name__ == "__main__":
    bieber_prompt = "And I was like Baby, baby, baby, oh Like, Baby, baby, baby, no Like, Baby, baby, baby, oh I thought you'd always be mine, mine"
    input_ids = tokenizer.encode(bieber_prompt, return_tensors="pt").squeeze()
    logits = t.ones(tokenizer.vocab_size)
    penalized_logits = apply_freq_penalty(input_ids, logits, 2.0)
    assert penalized_logits[5156].item() == -11, "Expected 6 occurrences of ' baby' with leading space"  # 1 - 6 * 2.0
    assert penalized_logits[14801].item() == -5, "Expected 3 occurrences of ' Baby' with leading space"  # 1 - 3 * 2.0
    print("Tests passed!")
# %%
if __name__ == "__main__":
    N_RUNS = 1  # samples per case; increase for more variety
    your_prompt = "Jingle bells, jingle bells, jingle all the way"
    cases = [
        ("High freq penalty", dict(freq_penalty=100.0)),
        ("Negative freq penalty", dict(freq_penalty=-1.0)),
        ("Too hot!", dict(temperature=2.0)),
        ("Pleasantly cool", dict(temperature=0.7)),
        ("Pleasantly warm", dict(temperature=0.9)),
        ("Too cold!", dict(temperature=0.01)),
    ]
    for (name, kwargs) in cases:
        for i in range(N_RUNS):
            output = sample_tokens(gpt, tokenizer, your_prompt, max_tokens_generated=24, **kwargs)
            print(f"Sample {i} with: {name} ({kwargs}):")
            print(f"Your model said: {repr(output)}\n")
# %%
def sample_top_k(logits: t.Tensor, top_k: int) -> int:
    '''
    logits: shape (vocab_size, ) - unnormalized log-probabilities
    top_k: only consider this many of the most likely tokens for sampling

    Return: a sampled token
    '''
    values, indices = t.topk(logits, top_k)
    return indices[sample_basic(values)].item()
if __name__ == "__main__":
    N = 50000
    k = 3
    probs = t.linspace(0, 0.4, 5)
    unnormalized_logits = probs.log() + 1.2345
    samples = t.tensor([sample_top_k(unnormalized_logits, k) for _ in range(N)])
    counts = t.bincount(samples, minlength=len(probs)) / N
    expected = probs.clone()
    expected[:-k] = 0
    expected /= expected.sum()
    print("Checking empirical frequencies (try to increase N if this test fails): ", counts)
    t.testing.assert_close(counts, expected, atol=0.01, rtol=0)
    print("Tests passed!")
# %%
if __name__ == "__main__":
    your_prompt = "In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English."
    output = sample_tokens(gpt, tokenizer, your_prompt, temperature=0.7, top_k=40, max_tokens_generated=64)
    print(f"Your model said: {repr(output)}")
# %%
def sample_top_p(logits: t.Tensor, top_p: float, min_tokens_to_keep: int = 1) -> int:
    '''
    logits: shape (vocab_size, ) - unnormalized log-probabilities
    top_p: keep the smallest set of most-likely tokens whose cumulative probability is at least top_p

    Return: a sampled token
    '''
    probs = t.exp(logits.double())
    probs = probs / probs.sum()
    sorted_probs, sorted_indices = probs.sort(descending=True)
    cum_probs = sorted_probs.cumsum(-1)
    last_index = max(min_tokens_to_keep, t.where(cum_probs >= top_p)[0][0].item() + 1)
    masked_probs = sorted_probs[:last_index]
    # Categorical renormalizes the kept probabilities automatically.
    sample = t.distributions.categorical.Categorical(probs=masked_probs).sample()
    return sorted_indices[sample].item()
if __name__ == "__main__":
    N = 2000
    unnormalized_logits = t.tensor([0.2, 0.3, 0.5]).log() + 2.3456
    samples = t.tensor([sample_top_p(unnormalized_logits, 0.5) for _ in range(N)])
    counts = t.bincount(samples, minlength=len(unnormalized_logits)) / N
    print("top_p of 0.5 or lower should only return token 2: ", counts)
    assert counts[0] == 0 and counts[1] == 0
    N = 2000
    unnormalized_logits = t.tensor([0.2, 0.3, 0.5]).log() + 2.3456
    samples = t.tensor([sample_top_p(unnormalized_logits, 0.50001) for _ in range(N)])
    counts = t.bincount(samples, minlength=len(unnormalized_logits)) / N
    print("top_p in (0.5, 0.8] should return tokens 1 and 2: ", counts)
    assert counts[0] == 0
    N = 50000
    top_p = 0.71
    probs = t.linspace(0, 0.4, 5)
    unnormalized_logits = probs.log() + 1.2345
    samples = t.tensor([sample_top_p(unnormalized_logits, top_p) for _ in range(N)])
    counts = t.bincount(samples, minlength=len(probs)) / N
    expected = probs.clone()
    expected[0:2] = 0
    expected /= expected.sum()
    print("Checking empirical frequencies (try to increase N if this test fails): ", counts)
    t.testing.assert_close(counts, expected, atol=0.01, rtol=0.0)
    print("All tests passed!")
# %%
if __name__ == "__main__":
    your_prompt = "Eliezer Shlomo Yudkowsky (born September 11, 1979) is an American decision and artificial intelligence (AI) theorist and writer, best known for"
    output = sample_tokens(gpt, tokenizer, your_prompt, temperature=0.7, top_p=0.95, max_tokens_generated=64)
    print(f"Your model said: {repr(output)}")
# %%