import torch

import act_mem
import layers

if __name__ == "__main__":
    batch_size, seq_len, d_model, n_heads = 1, 128, 1024, 32
    print(f"Batch size: {batch_size}, sequence length: {seq_len}, d_model: {d_model}, n_heads: {n_heads}")
    dtype = torch.bfloat16
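
    # A single batch of input activations; requires_grad=True so autograd saves
    # whatever tensors it needs for the backward pass.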
    inputs = torch.randn(
        batch_size,
        seq_len,
        d_model,
        device="cuda",
        requires_grad=True,
        dtype=dtype,
    )
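
    # The attention layer whose activation memory we want to measure, built with
    # the same device and dtype as the inputs.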
    attn = layers.Attention(
        d_model=d_model,
        n_heads=n_heads,
        device="cuda",
        dtype=dtype,
    )
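
    # AllocatedMemContext tracks the change in allocated CUDA memory across the
    # block, while SavedTensorContext records the tensors autograd saves for the
    # backward pass; the layer's parameters are excluded via ignored_tensors.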
    with act_mem.AllocatedMemContext() as mem, act_mem.SavedTensorContext(
        ignored_tensors=attn.parameters()
    ) as saved:
        out = attn(inputs)
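
    # Report the net change in allocated memory, the total bytes saved for
    # backward, and the saved bytes per output element.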
    stm = saved.saved_tensor_mem
    print(f'{mem.delta["current"]=}')
    print(f"{stm=}")
    print(f"{stm/out.numel()=}")