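"""Measure the activation memory of an attention layer's forward pass.

Assumes a CUDA device and the local `act_mem` and `layers` helper modules.
"""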
import torch

import act_mem
import layers

if __name__ == "__main__":
    batch_size, seq_len, d_model, n_heads = 1, 128, 1024, 32
    print(
        f"Batch size: {batch_size}, sequence length: {seq_len}, "
        f"d_model: {d_model}, n_heads: {n_heads}"
    )
    dtype = torch.bfloat16
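    # Random input activations; requires_grad=True so autograd saves tensors
    # for the backward pass. bfloat16 stores two bytes per element.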
    inputs = torch.randn(
        batch_size,
        seq_len,
        d_model,
        device="cuda",
        requires_grad=True,
        dtype=dtype,
    )

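    # A single multi-head self-attention layer from the local layers module.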
    attn = layers.Attention(
        d_model=d_model,
        n_heads=n_heads,
        device="cuda",
        dtype=dtype,
    )
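    # Run the forward pass while tracking (1) the change in CUDA allocated
    # memory and (2) the memory held by tensors saved for backward, ignoring
    # the layer's own parameters so that only activations are counted.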
    with act_mem.AllocatedMemContext() as mem, act_mem.SavedTensorContext(
        ignored_tensors=attn.parameters()
    ) as saved:
        out = attn(inputs)
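    # Report the allocated-memory delta, total bytes saved for backward, and
    # bytes saved per output element.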
    stm = saved.saved_tensor_mem
    print(f'{mem.delta["current"]=}')
    print(f"{stm=}")
    print(f"{stm/out.numel()=}")