import re

import pytest
import torch
from einops import rearrange
from flash_attn.models.gpt import (
    GPTLMHeadModel,
    remap_state_dict_hf_gpt2,
    shard_state_dict_tp,
    combine_state_dicts_tp,
)
from flash_attn.utils.generation import InferenceParams
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import GPT2Config, GPT2Tokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF


@pytest.mark.parametrize("model_name", ["gpt2"])
def test_gpt2_state_dict(model_name):
    config = GPT2Config.from_pretrained(model_name)
    pretrained_state_dict = remap_state_dict_hf_gpt2(state_dict_from_pretrained(model_name), config)
    model = GPTLMHeadModel(config)
    state_dict = model.state_dict()
    assert state_dict.keys() == pretrained_state_dict.keys()
    for k in state_dict.keys():
        assert state_dict[k].shape == pretrained_state_dict[k].shape


@pytest.mark.parametrize("model_name", ["gpt2"])
def test_gpt2_non_optimized(model_name):
    """Check that our implementation of GPT2 (without any optimizations enabled) matches the
    HF implementation: the error of our fp16 forward pass relative to the HF fp32 forward
    pass should be comparable to the error of the HF fp16 forward pass itself.
    """
    dtype = torch.float16
    config = GPT2Config.from_pretrained(model_name)
    model = GPTLMHeadModel.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)
    model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
    model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)
    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    out = model.transformer(input_ids)
    out_hf = model_hf.transformer(input_ids).last_hidden_state
    out_ref = model_ref.transformer(input_ids).last_hidden_state
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

    logits = model(input_ids).logits
    logits_hf = model_hf(input_ids).logits
    logits_ref = model_ref(input_ids).logits
    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
    assert (logits - logits_ref).abs().max().item() < 3 * (
        logits_hf - logits_ref
    ).abs().max().item()


@pytest.mark.parametrize("model_name", ["gpt2"])
def test_gpt2_optimized(model_name):
    """Check that our implementation of GPT2 (with all optimizations enabled) matches the
    HF implementation: the error of our fp16 forward pass relative to the HF fp32 forward
    pass should be comparable to the error of the HF fp16 forward pass itself.
    """
    dtype = torch.float16
    config = GPT2Config.from_pretrained(model_name)
    vocab_size_og = config.vocab_size
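    # Enable the fused/optimized code paths. pad_vocab_size_multiple=8 pads the
    # embedding and LM head to a multiple of 8 (better tensor-core utilization);
    # the padded logits are sliced back off below with [..., :vocab_size_og].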
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True
    config.pad_vocab_size_multiple = 8
    model = GPTLMHeadModel.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)
    model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
    model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)
    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
    input_ids = torch.randint(
        0, vocab_size_og, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    out = model.transformer(input_ids)
    out_hf = model_hf.transformer(input_ids).last_hidden_state
    out_ref = model_ref.transformer(input_ids).last_hidden_state
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

    logits = model(input_ids).logits[..., :vocab_size_og]
    logits_hf = model_hf(input_ids).logits
    logits_ref = model_ref(input_ids).logits
    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
    assert (logits - logits_ref).abs().max().item() < 3 * (
        logits_hf - logits_ref
    ).abs().max().item()


@pytest.mark.parametrize("optimized", [True])
@pytest.mark.parametrize("rotary", [False])
@pytest.mark.parametrize("model_name", ["gpt2"])
def test_gpt2_generation(model_name, rotary, optimized):
    """Check that our implementation of GPT2 generation matches the HF implementation:
    the error of our fp16 scores relative to the HF fp32 scores should be comparable to
    the error of the HF fp16 scores themselves.
    """
    dtype = torch.float16
    device = "cuda"
    rtol, atol = 3e-3, 3e-1
    config = GPT2Config.from_pretrained(model_name)
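    # With rotary, drop the learned position embeddings (n_positions=0) and apply
    # rotary embeddings to a fraction of each head's dimensions.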
    if rotary:
        config.n_positions = 0
        config.rotary_emb_fraction = 0.5
        config.rotary_emb_base = 24000
    config.residual_in_fp32 = True
    if optimized:
        config.use_flash_attn = True
        config.fused_bias_fc = True
        config.fused_mlp = True
        config.fused_dropout_add_ln = True
    # If not rotary, we load the weights from HF but ignore the position embeddings.
    # The model output would be nonsense, but that doesn't matter for this test.
    model = GPTLMHeadModel.from_pretrained(
        model_name, config, strict=not rotary, device=device, dtype=dtype
    )
    model.eval()

    if not rotary:
        model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device)
        model_hf = GPT2LMHeadModelHF.from_pretrained(model_name, torch_dtype=dtype).to(
            device=device
        )
        model_ref.eval()
        model_hf.eval()

    torch.manual_seed(0)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to(
        device=device
    )
    max_length = 25
    # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
    # max_length = input_ids.shape[1] + 40

    # Slow generation for reference
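    # Greedy-decode one token at a time, re-running the forward pass on the whole
    # prefix at each step (no KV cache), to get reference sequences and scores.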
    sequences = []
    scores = []
    cur_input_ids = input_ids
    with torch.inference_mode():
        scores.append(model(cur_input_ids).logits[:, -1])
        sequences.append(scores[-1].argmax(dim=-1))
        for _ in range(input_ids.shape[1] + 1, max_length):
            cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], "b -> b 1")], dim=-1)
            scores.append(model(cur_input_ids).logits[:, -1])
            sequences.append(scores[-1].argmax(dim=-1))
    sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
    scores = tuple(scores)

    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
    )
    print(out.sequences)
    print(tokenizer.batch_decode(out.sequences.tolist()))
    if getattr(config, "use_flash_attn", False):
        out_cg = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            cg=True,
            return_dict_in_generate=True,
            output_scores=True,
            enable_timing=True,
        )
        print(out_cg.sequences)
        assert torch.equal(torch.stack(out.scores, dim=1), torch.stack(out_cg.scores, dim=1))

    if not rotary:
        out_hf = model_hf.generate(
            input_ids=input_ids,
            max_length=max_length,
            return_dict_in_generate=True,
            output_scores=True,
        )
        out_ref = model_ref.generate(
            input_ids=input_ids,
            max_length=max_length,
            return_dict_in_generate=True,
            output_scores=True,
        )
        print(
            f"Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
        )
        print(
            f"Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
        )
        print(
            f"HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
        )
        print(
            f"HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
        )
        print(tokenizer.batch_decode(out_ref.sequences.tolist()))

    assert torch.all(out.sequences == sequences)
    assert torch.allclose(
        torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol
    )
    if not rotary:
        assert torch.all(out.sequences == out_ref.sequences)
        assert torch.all(out.sequences == out_hf.sequences)
        assert (
            torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)
        ).abs().max().item() < 3 * (
            torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
        ).abs().max().item()


def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
    """Generate (optionally with teacher-forced tokens) and return the stacked per-step scores."""
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        teacher_outputs=teacher_outputs,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
        **kwargs,
    )
    return torch.stack(out.scores, dim=1)


@pytest.mark.parametrize("seqlen,maxlen", [(10, 20)])
@pytest.mark.parametrize("rotary", [None])
@pytest.mark.parametrize("model_name", ["gpt2"])
def test_gpt2_generation_cg(model_name, rotary, seqlen, maxlen):
    """Check that decoding with CUDA graph is the same as decoding without CUDA graph."""
    dtype = torch.float16
    device = "cuda"
    rtol, atol = 3e-3, 3e-1
    config = GPT2Config.from_pretrained(model_name)
    config.n_positions = 16 * 1024
    assert seqlen <= maxlen <= config.n_positions
    if rotary is not None:
        config.n_positions = 0
        config.rotary_emb_dim = 32
        config.rotary_emb_interleaved = rotary == "interleaved"
    config.residual_in_fp32 = True
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True
    model = GPTLMHeadModel(config, device=device, dtype=dtype)
    model.eval()

    torch.manual_seed(0)
    batch_size = 1
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
    )
    teacher_outputs = torch.randint(
        0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
    )
    logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
    logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
    assert torch.equal(logits, logits_cg)

    # Try increasing batch size and seqlen, then decrease them, to see if it's still correct
    batch_size = 3
    maxlen += 30
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
    )
    teacher_outputs = torch.randint(
        0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
    )
    logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
    logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
    assert torch.equal(logits, logits_cg)

    batch_size = 2
    maxlen -= 35
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
    )
    teacher_outputs = torch.randint(
        0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
    )
    logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
    logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
    assert torch.equal(logits, logits_cg)
# @pytest.mark.parametrize("optimized", [False]) | |
def test_gpt2_multiple_token_generation(model_name, optimized):
    """Generation when we pass in multiple tokens at a time, not just one."""
    dtype = torch.float16
    device = "cuda"
    rtol, atol = 3e-3, 3e-1
    config = GPT2Config.from_pretrained(model_name)
    config.residual_in_fp32 = True
    if optimized:
        config.use_flash_attn = True
        config.fused_bias_fc = True
        config.fused_mlp = True
        config.fused_dropout_add_ln = True
    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    torch.manual_seed(0)
    input_ids = torch.randint(0, config.vocab_size, (1, 20), dtype=torch.long, device=device)
    # Reference logits
    logits_ref = model(input_ids).logits

    # Run 10 tokens, then pass in another 4, then another 6, to see if we get the same logits
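    # InferenceParams carries the KV cache between calls; seqlen_offset must be
    # advanced manually, and position_ids passed explicitly, so the later chunks
    # are treated as continuations of the cached prefix.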
    inference_params = InferenceParams(max_seqlen=20, max_batch_size=1)
    logits_10 = model(input_ids[:, :10], inference_params=inference_params).logits
    inference_params.seqlen_offset += 10
    position_ids = torch.arange(10, 14, dtype=torch.long, device=device)
    logits_1014 = model(
        input_ids[:, 10:14], position_ids=position_ids, inference_params=inference_params
    ).logits
    inference_params.seqlen_offset += 4
    position_ids = torch.arange(14, 20, dtype=torch.long, device=device)
    logits_1420 = model(
        input_ids[:, 14:20], position_ids=position_ids, inference_params=inference_params
    ).logits
    logits = torch.cat([logits_10, logits_1014, logits_1420], dim=1)
    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
    assert torch.allclose(logits, logits_ref, rtol=rtol, atol=atol)
# @pytest.mark.parametrize("cg", [True]) | |
# @pytest.mark.parametrize("optimized", [True]) | |
# @pytest.mark.parametrize("model_name", ["gpt2-medium"]) | |
def test_gpt2_speculative_decoding(model_name, optimized, cg):
    if cg and not optimized:
        pytest.skip()  # CG requires use_flash_attn
    dtype = torch.float16
    device = "cuda"
    rtol, atol = 3e-3, 3e-1
    config = GPT2Config.from_pretrained(model_name)
    config.residual_in_fp32 = True
    if optimized:
        config.use_flash_attn = True
        config.fused_bias_fc = True
        config.fused_mlp = True
        config.fused_dropout_add_ln = True
    config_draft = GPT2Config.from_pretrained("gpt2")
    config_draft.residual_in_fp32 = True
    if optimized:
        config_draft.use_flash_attn = True
        config_draft.fused_bias_fc = True
        config_draft.fused_mlp = True
        config_draft.fused_dropout_add_ln = True
    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()
    model_draft = GPTLMHeadModel.from_pretrained("gpt2", config_draft, device=device, dtype=dtype)
    model_draft.eval()

    torch.manual_seed(0)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to(
        device=device
    )
    max_length = 100

    from flash_attn.utils.generation import decode_speculative

    torch.manual_seed(42)
    print(f"Speculative decoding, {optimized = }")
    out = decode_speculative(
        input_ids,
        model,
        model_draft,
        max_length=max_length,
        top_k=5,
        cg=cg,
        speculative_lookahead=4,
        enable_timing=True,
        # debug=True,
    )
    print(tokenizer.batch_decode(out.sequences))

    print(f"Without speculative decoding, {cg = }")
    out_og = model.generate(
        input_ids,
        max_length=max_length,
        top_k=5,
        cg=cg,
        enable_timing=True,
        return_dict_in_generate=True,
    )
    print(tokenizer.batch_decode(out_og.sequences))


# NOTE: the (n_head, n_head_kv) pairs below are illustrative; any values where both
# are divisible by world_size should work.
@pytest.mark.parametrize("n_heads_q_kv", [(8, 8), (8, 4), (8, 2)])
def test_gpt2_shard_unshard(n_heads_q_kv):
    world_size = 2
    config = GPT2Config.from_pretrained("gpt2")
    config.vocab_size = 1024
    config.n_head, config.n_head_kv = n_heads_q_kv
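    # n_head_kv < n_head exercises the grouped-query / multi-query sharding path.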
    model = GPTLMHeadModel(config, device="cuda", dtype=torch.float16)
    state_dict = model.state_dict()
    shards = [
        # NOTE: Shallow copy as `state_dict` is modified in-place
        shard_state_dict_tp(dict(state_dict), config, world_size, rank)
        for rank in range(world_size)
    ]
    state_dict2 = combine_state_dicts_tp(shards, config)
    assert state_dict2.keys() == state_dict.keys()
    for k in state_dict.keys():
        ref = state_dict[k]
        new = state_dict2[k]  # compare against the recombined state dict, not the original
        assert torch.allclose(ref, new, atol=0.0, rtol=0.0)