# File size: 19,054 Bytes
# e45d058
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
import re

import pytest
import torch
from einops import rearrange
from flash_attn.models.gpt import (
    GPTLMHeadModel,
    remap_state_dict_hf_gpt2,
    shard_state_dict_tp,
    combine_state_dicts_tp,
)
from flash_attn.utils.generation import InferenceParams
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import GPT2Config, GPT2Tokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF


@pytest.mark.parametrize("model_name", ["gpt2", "gpt2-medium"])
# @pytest.mark.parametrize('model_name', ["gpt2"])
def test_gpt2_state_dict(model_name):
    """Check that a remapped HF GPT-2 checkpoint lines up with our model's state dict.

    The remapped pretrained weights must expose exactly the same keys as a freshly
    constructed ``GPTLMHeadModel``, and every tensor must have the same shape.
    """
    config = GPT2Config.from_pretrained(model_name)
    hf_weights = state_dict_from_pretrained(model_name)
    remapped = remap_state_dict_hf_gpt2(hf_weights, config)
    expected = GPTLMHeadModel(config).state_dict()
    assert expected.keys() == remapped.keys()
    for name, tensor in expected.items():
        assert tensor.shape == remapped[name].shape


@pytest.mark.parametrize("model_name", ["gpt2", "gpt2-medium"])
# @pytest.mark.parametrize('model_name', ["gpt2"])
def test_gpt2_non_optimized(model_name):
    """Compare our GPT2 (no optimizations enabled) against the HF implementation.

    Our fp16 forward pass should be about as close to the HF fp32 forward pass as the
    HF fp16 forward pass is: our max abs error must stay below 3x the HF fp16 max abs
    error, both for the transformer output and for the LM-head logits.
    """

    def max_and_mean_abs_diff(a, b):
        # Reduce to Python floats right away so the full-size diff tensor is released.
        gap = (a - b).abs()
        return gap.max().item(), gap.mean().item()

    dtype = torch.float16
    config = GPT2Config.from_pretrained(model_name)

    model = GPTLMHeadModel.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)

    # fp32 HF model is the reference; the fp16 HF model sets the error budget.
    model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
    model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)

    for module in (model, model_ref, model_hf):
        module.eval()

    torch.manual_seed(0)
    batch_size, max_seqlen = 4, 512
    # Not used below, but drawing it keeps the RNG stream (and hence input_ids) unchanged.
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )

    # Hidden states out of the transformer trunk.
    out = model.transformer(input_ids)
    out_hf = model_hf.transformer(input_ids).last_hidden_state
    out_ref = model_ref.transformer(input_ids).last_hidden_state

    ours_max, ours_mean = max_and_mean_abs_diff(out, out_ref)
    hf_max, hf_mean = max_and_mean_abs_diff(out_hf, out_ref)
    print(f"Output max diff: {ours_max}")
    print(f"Output mean diff: {ours_mean}")
    print(f"HF fp16 max diff: {hf_max}")
    print(f"HF fp16 mean diff: {hf_mean}")
    assert ours_max < 3 * hf_max

    # Full model including the LM head.
    logits = model(input_ids).logits
    logits_hf = model_hf(input_ids).logits
    logits_ref = model_ref(input_ids).logits

    ours_max, ours_mean = max_and_mean_abs_diff(logits, logits_ref)
    hf_max, hf_mean = max_and_mean_abs_diff(logits_hf, logits_ref)
    print(f"Logits max diff: {ours_max}")
    print(f"Logits mean diff: {ours_mean}")
    print(f"HF fp16 max diff: {hf_max}")
    print(f"HF fp16 mean diff: {hf_mean}")
    assert ours_max < 3 * hf_max


@pytest.mark.parametrize("model_name", ["gpt2", "gpt2-medium"])
# @pytest.mark.parametrize('model_name', ["gpt2"])
def test_gpt2_optimized(model_name):
    """Compare our GPT2 (all optimizations enabled) against the HF implementation.

    Our fp16 forward pass should be about as close to the HF fp32 forward pass as the
    HF fp16 forward pass is: our max abs error must stay below 3x the HF fp16 max abs
    error, both for the transformer output and for the LM-head logits.
    """

    def max_and_mean_abs_diff(a, b):
        # Reduce to Python floats right away so the full-size diff tensor is released.
        gap = (a - b).abs()
        return gap.max().item(), gap.mean().item()

    dtype = torch.float16
    config = GPT2Config.from_pretrained(model_name)
    # Remember the checkpoint's vocab size: tokens are sampled from it and our logits
    # are sliced back to it before comparing with HF.
    vocab_size_og = config.vocab_size
    for flag in (
        "use_flash_attn",
        "fused_bias_fc",
        "fused_mlp",
        "fused_dropout_add_ln",
        "residual_in_fp32",
    ):
        setattr(config, flag, True)
    config.pad_vocab_size_multiple = 8

    model = GPTLMHeadModel.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)

    # fp32 HF model is the reference; the fp16 HF model sets the error budget.
    model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
    model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)

    for module in (model, model_ref, model_hf):
        module.eval()

    torch.manual_seed(0)
    batch_size, max_seqlen = 4, 512
    # Not used below, but drawing it keeps the RNG stream (and hence input_ids) unchanged.
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
    input_ids = torch.randint(
        0, vocab_size_og, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )

    # Hidden states out of the transformer trunk.
    out = model.transformer(input_ids)
    out_hf = model_hf.transformer(input_ids).last_hidden_state
    out_ref = model_ref.transformer(input_ids).last_hidden_state

    ours_max, ours_mean = max_and_mean_abs_diff(out, out_ref)
    hf_max, hf_mean = max_and_mean_abs_diff(out_hf, out_ref)
    print(f"Output max diff: {ours_max}")
    print(f"Output mean diff: {ours_mean}")
    print(f"HF fp16 max diff: {hf_max}")
    print(f"HF fp16 mean diff: {hf_mean}")
    assert ours_max < 3 * hf_max

    # Full model including the LM head; keep only the original vocabulary columns.
    logits = model(input_ids).logits[..., :vocab_size_og]
    logits_hf = model_hf(input_ids).logits
    logits_ref = model_ref(input_ids).logits

    ours_max, ours_mean = max_and_mean_abs_diff(logits, logits_ref)
    hf_max, hf_mean = max_and_mean_abs_diff(logits_hf, logits_ref)
    print(f"Logits max diff: {ours_max}")
    print(f"Logits mean diff: {ours_mean}")
    print(f"HF fp16 max diff: {hf_max}")
    print(f"HF fp16 mean diff: {hf_mean}")
    assert ours_max < 3 * hf_max


@pytest.mark.parametrize("optimized", [False, True])
# @pytest.mark.parametrize('optimized', [True])
@pytest.mark.parametrize("rotary", [False, True])
# @pytest.mark.parametrize('rotary', [False])
@pytest.mark.parametrize("model_name", ["gpt2"])
def test_gpt2_generation(model_name, rotary, optimized):
    """Check that ``generate`` agrees with step-by-step decoding and with HF.

    The sequences and scores from ``model.generate`` must match a slow loop that
    re-runs the full forward pass for every new token.  When flash attention is on,
    CUDA-graph decoding must give bit-identical scores.  Without rotary embeddings the
    result is also compared with HF: same sequences, and our fp16 scores must be within
    3x of the HF fp16 error when both are measured against the HF fp32 scores.
    """
    dtype = torch.float16
    device = "cuda"
    rtol, atol = 3e-3, 3e-1
    config = GPT2Config.from_pretrained(model_name)
    if rotary:
        config.n_positions = 0
        config.rotary_emb_fraction = 0.5
        config.rotary_emb_base = 24000
    config.residual_in_fp32 = True
    if optimized:
        for flag in ("use_flash_attn", "fused_bias_fc", "fused_mlp", "fused_dropout_add_ln"):
            setattr(config, flag, True)

    # With rotary on, the HF weights are loaded non-strictly (strict=False) -- presumably
    # because the learned position embeddings are ignored in that configuration.
    # The model would be nonsense but it doesn't matter for the test.
    model = GPTLMHeadModel.from_pretrained(
        model_name, config, strict=not rotary, device=device, dtype=dtype
    )
    model.eval()

    if not rotary:
        # HF fp32 model is the reference; the HF fp16 model sets the error budget.
        model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device)
        model_hf = GPT2LMHeadModelHF.from_pretrained(model_name, torch_dtype=dtype).to(
            device=device
        )
        model_ref.eval()
        model_hf.eval()

    torch.manual_seed(0)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    prompt = tokenizer("Hello, my dog is cute and he", return_tensors="pt")
    input_ids = prompt.input_ids.to(device=device)
    max_length = 25
    # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
    # max_length = input_ids.shape[1] + 40

    # Slow generation for reference: greedy decoding that re-encodes the whole
    # sequence for every new token.
    greedy_tokens = []
    scores = []
    cur_input_ids = input_ids
    with torch.inference_mode():
        last_logits = model(cur_input_ids).logits[:, -1]
        scores.append(last_logits)
        greedy_tokens.append(last_logits.argmax(dim=-1))
        for _ in range(input_ids.shape[1] + 1, max_length):
            newest = rearrange(greedy_tokens[-1], "b -> b 1")
            cur_input_ids = torch.cat([cur_input_ids, newest], dim=-1)
            last_logits = model(cur_input_ids).logits[:, -1]
            scores.append(last_logits)
            greedy_tokens.append(last_logits.argmax(dim=-1))
    sequences = torch.cat([input_ids, torch.stack(greedy_tokens, dim=1)], dim=1)
    scores = tuple(scores)

    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
    )
    print(out.sequences)
    print(tokenizer.batch_decode(out.sequences.tolist()))
    generated_scores = torch.stack(out.scores, dim=1)

    if getattr(config, "use_flash_attn", False):
        out_cg = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            cg=True,
            return_dict_in_generate=True,
            output_scores=True,
            enable_timing=True,
        )
        print(out_cg.sequences)
        assert torch.equal(generated_scores, torch.stack(out_cg.scores, dim=1))

    if not rotary:
        out_hf = model_hf.generate(
            input_ids=input_ids,
            max_length=max_length,
            return_dict_in_generate=True,
            output_scores=True,
        )
        out_ref = model_ref.generate(
            input_ids=input_ids,
            max_length=max_length,
            return_dict_in_generate=True,
            output_scores=True,
        )
        ref_scores = torch.stack(out_ref.scores, 1)
        ours_gap = (generated_scores - ref_scores).abs()
        hf_gap = (torch.stack(out_hf.scores, 1) - ref_scores).abs()

        print(f"Scores max diff: {ours_gap.max().item()}")
        print(f"Scores mean diff: {ours_gap.mean().item()}")
        print(f"HF fp16 max diff: {hf_gap.max().item()}")
        print(f"HF fp16 mean diff: {hf_gap.mean().item()}")
        print(tokenizer.batch_decode(out_ref.sequences.tolist()))

    # Fast generation must reproduce the slow loop.
    assert torch.all(out.sequences == sequences)
    assert torch.allclose(generated_scores, torch.stack(scores, dim=1), rtol=rtol, atol=atol)
    if not rotary:
        assert torch.all(out.sequences == out_ref.sequences)
        assert torch.all(out.sequences == out_hf.sequences)

        assert ours_gap.max().item() < 3 * hf_gap.max().item()


def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
    """Run ``model.generate`` and return the per-step scores as one tensor.

    Dict-style output and scores are always requested and timing is enabled; any extra
    keyword arguments (e.g. ``cg=True``) are forwarded to ``generate`` unchanged.

    Returns:
        The entries of ``out.scores`` stacked along dim 1.
    """
    generation = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        teacher_outputs=teacher_outputs,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
        **kwargs,
    )
    per_step_scores = generation.scores
    return torch.stack(per_step_scores, dim=1)


@pytest.mark.parametrize("seqlen,maxlen", [(10, 20), (30, 150), (3000, 3400), (14000, 15000)])
# @pytest.mark.parametrize('seqlen,maxlen', [(10, 20)])
@pytest.mark.parametrize("rotary", [None, "interleaved", "contiguous"])
# @pytest.mark.parametrize('rotary', [None])
@pytest.mark.parametrize("model_name", ["gpt2"])
def test_gpt2_generation_cg(model_name, rotary, seqlen, maxlen):
    """Check that decoding with CUDA graph is the same as decoding without CUDA graph.

    A randomly initialised model is decoded with fixed ``teacher_outputs`` three times,
    with different batch sizes and maximum lengths; each time the scores with and
    without ``cg=True`` must be bit-identical.
    """
    dtype = torch.float16
    device = "cuda"
    config = GPT2Config.from_pretrained(model_name)
    config.n_positions = 16 * 1024
    assert seqlen <= maxlen <= config.n_positions
    if rotary is not None:
        config.n_positions = 0
        config.rotary_emb_dim = 32
        config.rotary_emb_interleaved = rotary == "interleaved"
    for flag in (
        "residual_in_fp32",
        "use_flash_attn",
        "fused_bias_fc",
        "fused_mlp",
        "fused_dropout_add_ln",
    ):
        setattr(config, flag, True)

    model = GPTLMHeadModel(config, device=device, dtype=dtype)
    model.eval()

    torch.manual_seed(0)
    # (batch size, change applied to maxlen) per round: start small, then try increasing
    # batch size and maxlen, then decrease them to see if it's still correct.
    rounds = ((1, 0), (3, 30), (2, -35))
    for batch_size, maxlen_delta in rounds:
        maxlen += maxlen_delta
        input_ids = torch.randint(
            0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
        )
        teacher_outputs = torch.randint(
            0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
        )
        logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
        logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
        assert torch.equal(logits, logits_cg)


@pytest.mark.parametrize("optimized", [False, True])
# @pytest.mark.parametrize("optimized", [False])
@pytest.mark.parametrize("model_name", ["gpt2"])
def test_gpt2_multiple_token_generation(model_name, optimized):
    """Generation when we pass in multiple tokens at a time, not just one.

    Feeding a 20-token sequence in chunks of 10, 4 and 6 tokens through the same
    ``InferenceParams`` must give (within tolerance) the same logits as a single
    forward pass over the full sequence.
    """
    dtype = torch.float16
    device = "cuda"
    rtol, atol = 3e-3, 3e-1
    config = GPT2Config.from_pretrained(model_name)
    config.residual_in_fp32 = True
    if optimized:
        for flag in ("use_flash_attn", "fused_bias_fc", "fused_mlp", "fused_dropout_add_ln"):
            setattr(config, flag, True)

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    torch.manual_seed(0)
    total_len = 20
    input_ids = torch.randint(
        0, config.vocab_size, (1, total_len), dtype=torch.long, device=device
    )
    # Reference logits
    logits_ref = model(input_ids).logits

    # Run 10 tokens, then pass in another 4, then another 6, to see if we get the same logits
    inference_params = InferenceParams(max_seqlen=20, max_batch_size=1)
    chunk_logits = []
    start = 0
    for chunk_len in (10, 4, 6):
        stop = start + chunk_len
        chunk = input_ids[:, start:stop]
        if start == 0:
            # The first chunk is run without explicit position ids.
            piece = model(chunk, inference_params=inference_params).logits
        else:
            position_ids = torch.arange(start, stop, dtype=torch.long, device=device)
            piece = model(
                chunk, position_ids=position_ids, inference_params=inference_params
            ).logits
        chunk_logits.append(piece)
        start = stop
        if start < total_len:
            # Advance the offset past the tokens just processed before the next chunk.
            inference_params.seqlen_offset += chunk_len
    logits = torch.cat(chunk_logits, dim=1)
    abs_diff = (logits - logits_ref).abs()
    print(f"Logits max diff: {abs_diff.max().item()}")
    print(f"Logits mean diff: {abs_diff.mean().item()}")
    assert torch.allclose(logits, logits_ref, rtol=rtol, atol=atol)


@pytest.mark.parametrize("cg", [False, True])
# @pytest.mark.parametrize("cg", [True])
@pytest.mark.parametrize("optimized", [False, True])
# @pytest.mark.parametrize("optimized", [True])
# @pytest.mark.parametrize("model_name", ["gpt2-medium"])
@pytest.mark.parametrize("model_name", ["gpt2-xl"])
def test_gpt2_speculative_decoding(model_name, optimized, cg):
    """Smoke-test speculative decoding with ``gpt2`` as the draft model.

    Runs ``decode_speculative`` and then plain ``generate`` on the same prompt with the
    same top-k setting and prints the decoded text.  Nothing is asserted about the
    outputs; the test only checks that both paths run.
    """
    if cg and not optimized:
        pytest.skip()  # CG requires use_flash_attn
    dtype = torch.float16
    device = "cuda"
    draft_name = "gpt2"

    # The main and draft configs get exactly the same feature flags.
    config = GPT2Config.from_pretrained(model_name)
    config_draft = GPT2Config.from_pretrained(draft_name)
    for cfg in (config, config_draft):
        cfg.residual_in_fp32 = True
        if optimized:
            cfg.use_flash_attn = True
            cfg.fused_bias_fc = True
            cfg.fused_mlp = True
            cfg.fused_dropout_add_ln = True

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()
    model_draft = GPTLMHeadModel.from_pretrained(
        draft_name, config_draft, device=device, dtype=dtype
    )
    model_draft.eval()

    torch.manual_seed(0)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    prompt = tokenizer("Hello, my dog is cute and he", return_tensors="pt")
    input_ids = prompt.input_ids.to(device=device)
    max_length = 100

    from flash_attn.utils.generation import decode_speculative

    # Settings shared by the speculative and the regular decoding run.
    shared_kwargs = dict(max_length=max_length, top_k=5, cg=cg, enable_timing=True)

    torch.manual_seed(42)
    print(f"Speculative decoding, {optimized = }")
    out = decode_speculative(
        input_ids,
        model,
        model_draft,
        speculative_lookahead=4,
        # debug=True,
        **shared_kwargs,
    )
    print(tokenizer.batch_decode(out.sequences))
    print(f"Without speculative decoding, {cg = }")
    out_og = model.generate(input_ids, return_dict_in_generate=True, **shared_kwargs)
    print(tokenizer.batch_decode(out_og.sequences))


@pytest.mark.parametrize(
    "n_heads_q_kv",
    [
        (8, 8),  # Regular attention
        (8, 4),  # GQA
        (8, 2),  # MQA
    ],
)
def test_gpt2_shard_unshard(n_heads_q_kv):
    """Check that sharding a state dict for tensor parallelism and recombining is lossless.

    The state dict of a randomly initialised model is split into ``world_size`` shards
    with ``shard_state_dict_tp`` and merged back with ``combine_state_dicts_tp``; the
    result must have the same keys and exactly the same tensors as the original.

    Args:
        n_heads_q_kv: ``(n_head, n_head_kv)`` pair written into the config.
    """
    world_size = 2

    config = GPT2Config.from_pretrained("gpt2")
    config.vocab_size = 1024
    config.n_head, config.n_head_kv = n_heads_q_kv
    model = GPTLMHeadModel(config, device="cuda", dtype=torch.float16)
    state_dict = model.state_dict()
    shards = [
        # NOTE: Shallow copy as `state_dict` is modified in-place
        shard_state_dict_tp(dict(state_dict), config, world_size, rank)
        for rank in range(world_size)
    ]
    state_dict2 = combine_state_dicts_tp(shards, config)
    assert state_dict2.keys() == state_dict.keys()
    for k, ref in state_dict.items():
        # Compare against the recombined dict; reading `state_dict` again here would
        # compare each tensor with itself and make the test pass vacuously.
        new = state_dict2[k]
        assert new.shape == ref.shape, f"shape mismatch for {k}"
        # Zero tolerances: the round trip only moves data, so values must match exactly.
        assert torch.allclose(ref, new, atol=0.0, rtol=0.0), f"value mismatch for {k}"