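"""Tests for fairseq.modules.multihead_attention.

Covers mask conversion for the optional xFormers backends and parity checks
between the xFormers-backed and the native MultiheadAttention implementations.
"""
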
import random
import unittest

import pytest
import torch

from fairseq.modules.multihead_attention import MultiheadAttention, _mask_for_xformers

BATCH = [20, 41, 97]
SEQ = [64]
EMB = [48]
HEADS = [4]
DROP = 0.1
DEVICE = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
ATTN_MASK_DTYPE = [None, torch.uint8, torch.bool, torch.float]
KEY_PADDING_MASK_DTYPE = [None, torch.uint8, torch.bool]


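# Loose element-wise comparison (used by the fp16 blocksparse parity test below);
# tensors are moved to CPU and compared via numpy.testing.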
def assert_almost_equal(x, y, decimal=1, err_msg=""):
    import numpy.testing as npt

    if isinstance(x, torch.Tensor):
        x = x.cpu().detach().numpy()
    if isinstance(y, torch.Tensor):
        y = y.cpu().detach().numpy()
    npt.assert_array_almost_equal(x, y, err_msg=err_msg, decimal=decimal)


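# Re-seed all RNGs so that the two MultiheadAttention instances under
# comparison are initialized with identical weights.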
def _reset_seeds():
    torch.manual_seed(0)
    torch.random.manual_seed(0)
    random.seed(0)
    torch.cuda.manual_seed_all(0)


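# Build a random mask of the requested dtype: float masks are additive
# (-inf marks masked positions), uint8/bool masks hold 0/1 entries.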
def _get_mask(to_dtype: torch.dtype, dim0: int, dim1: int):
    if to_dtype == torch.float:
        mask = torch.randint(0, 2, (dim0, dim1)).to(dtype=torch.bool)
        return mask.to(dtype=to_dtype).masked_fill(mask, -float("inf"))
    return torch.randint(0, 2, (dim0, dim1)).to(dtype=to_dtype)


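# _mask_for_xformers leaves additive float masks as-is, flips uint8/bool masks,
# and can convert between mask dtypes via to_dtype.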
def test_mask_for_xformers():
    m_float_add = torch.tensor([float("-inf"), 0]).to(torch.float)
    m_float_add_flipped = torch.tensor([0, float("-inf")]).to(torch.float)
    m_float16_add = torch.tensor([float("-inf"), 0]).to(torch.float16)
    m_float16_add_flipped = torch.tensor([0, float("-inf")]).to(torch.float16)
    m_uint = torch.tensor([1, 0]).to(torch.uint8)
    m_uint_flipped = torch.tensor([0, 1]).to(torch.uint8)
    m_bool = torch.tensor([False, True])

    assert torch.equal(_mask_for_xformers(m_float_add), m_float_add)
    assert torch.equal(_mask_for_xformers(m_float16_add), m_float16_add)
    assert torch.equal(_mask_for_xformers(m_uint), m_uint_flipped)
    assert torch.equal(_mask_for_xformers(m_bool), ~m_bool)

    assert torch.equal(
        _mask_for_xformers(m_float_add, to_dtype=torch.float16), m_float16_add
    )
    assert torch.equal(
        _mask_for_xformers(m_float_add, to_dtype=torch.float), m_float_add
    )
    assert torch.equal(_mask_for_xformers(m_float_add, to_dtype=torch.bool), m_bool)
    assert torch.equal(
        _mask_for_xformers(m_float_add, to_dtype=torch.uint8), m_uint_flipped
    )

    assert torch.equal(
        _mask_for_xformers(m_float16_add, to_dtype=torch.float16), m_float16_add
    )
    assert torch.equal(
        _mask_for_xformers(m_float16_add, to_dtype=torch.float), m_float_add
    )
    assert torch.equal(_mask_for_xformers(m_float16_add, to_dtype=torch.bool), m_bool)
    assert torch.equal(
        _mask_for_xformers(m_float16_add, to_dtype=torch.uint8), m_uint_flipped
    )

    assert torch.equal(
        _mask_for_xformers(m_bool, to_dtype=torch.float16), m_float16_add_flipped
    )
    assert torch.equal(
        _mask_for_xformers(m_bool, to_dtype=torch.float), m_float_add_flipped
    )
    assert torch.equal(_mask_for_xformers(m_bool, to_dtype=torch.bool), ~m_bool)
    assert torch.equal(_mask_for_xformers(m_bool, to_dtype=torch.uint8), m_uint)

    assert torch.equal(
        _mask_for_xformers(m_uint, to_dtype=torch.float16), m_float16_add
    )
    assert torch.equal(_mask_for_xformers(m_uint, to_dtype=torch.float), m_float_add)
    assert torch.equal(_mask_for_xformers(m_uint, to_dtype=torch.bool), m_bool)
    assert torch.equal(_mask_for_xformers(m_uint, to_dtype=torch.uint8), m_uint_flipped)


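# Parity check between blocksparse and dense scaled-dot-product attention from
# xFormers in fp16 (currently skipped: blocksparse is not part of the latest xformers).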
@pytest.mark.skipif(not torch.cuda.is_available(), reason="blocksparse requires gpu")
@pytest.mark.skip(reason="not part of latest xformers")
@pytest.mark.parametrize("device", ["cuda"])
@pytest.mark.parametrize("add_zero_attn", [False])
@pytest.mark.parametrize("batch_size", [20])
@pytest.mark.parametrize("embedding", [64])
@pytest.mark.parametrize("seq_len", [64])
@pytest.mark.parametrize("num_heads", [4])
def test_xformers_blocksparse_parity(
    device,
    add_zero_attn,
    batch_size,
    embedding,
    seq_len,
    num_heads,
):
    xformers_att_config = '{"name": "scaled_dot_product"}'
    xformers_blocksparse_blocksize = 16
    xformers_blocksparse_layout = torch.ones(
        seq_len // xformers_blocksparse_blocksize,
        seq_len // xformers_blocksparse_blocksize,
        dtype=torch.int32,
    )

    q = torch.rand(seq_len, batch_size, embedding).to(device).half()
    q.requires_grad = True
    k = torch.rand(seq_len, batch_size, embedding).to(device).half()
    k.requires_grad = True
    v = torch.rand(seq_len, batch_size, embedding).to(device).half()
    v.requires_grad = True

    q_ = q.detach().clone().half()
    q_.requires_grad = True
    k_ = k.detach().clone().half()
    k_.requires_grad = True
    v_ = v.detach().clone().half()
    v_.requires_grad = True

    _reset_seeds()
    xf_blocksparse_mha = (
        MultiheadAttention(
            embedding,
            num_heads,
            dropout=0.0,
            add_zero_attn=add_zero_attn,
            xformers_att_config=xformers_att_config,
            xformers_blocksparse_layout=xformers_blocksparse_layout,
            xformers_blocksparse_blocksize=xformers_blocksparse_blocksize,
        )
        .to(device)
        .half()
    )

    xf_blocksparse_output, _ = xf_blocksparse_mha(
        q,
        k,
        v,
    )

    _reset_seeds()
    xformers_mha = (
        MultiheadAttention(
            embedding,
            num_heads,
            dropout=0.0,
            add_zero_attn=add_zero_attn,
            xformers_att_config=xformers_att_config,
            xformers_blocksparse_layout=None,
        )
        .to(device)
        .half()
    )

    xformers_output, _ = xformers_mha(
        q_,
        k_,
        v_,
    )

    # NaNs are not comparable; replace them with the same random value on both
    # sides before checking the outputs
    rand = random.uniform(0, 1)
    xformers_output = xformers_output.masked_fill(xformers_output.isnan(), rand)
    xf_blocksparse_output = xf_blocksparse_output.masked_fill(
        xf_blocksparse_output.isnan(), rand
    )

    assert_almost_equal(xformers_output, xf_blocksparse_output)

    loss_blocksparse = torch.norm(xf_blocksparse_output)
    loss_original = torch.norm(xformers_output)
    loss_blocksparse.backward()
    loss_original.backward()

    # fill any NaNs in the gradients with the same value on both sides before
    # comparing them
    q.grad.masked_fill_(q.grad.isnan(), rand)
    q_.grad.masked_fill_(q_.grad.isnan(), rand)
    k.grad.masked_fill_(k.grad.isnan(), rand)
    k_.grad.masked_fill_(k_.grad.isnan(), rand)
    v.grad.masked_fill_(v.grad.isnan(), rand)
    v_.grad.masked_fill_(v_.grad.isnan(), rand)

    assert_almost_equal(q.grad, q_.grad)
    assert_almost_equal(k.grad, k_.grad)
    assert_almost_equal(v.grad, v_.grad)


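# Forward/backward parity between the xFormers scaled-dot-product backend and
# the native fairseq MultiheadAttention, across mask dtypes, batch sizes, and
# bias/zero-attention settings.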
@pytest.mark.parametrize("device", DEVICE)
@pytest.mark.parametrize("attn_dtype", ATTN_MASK_DTYPE)
@pytest.mark.parametrize("key_padding_dtype", KEY_PADDING_MASK_DTYPE)
@pytest.mark.parametrize("add_bias_kv", [True, False])
@pytest.mark.parametrize("add_zero_attn", [True, False])
@pytest.mark.parametrize("static_kv", [False])
@pytest.mark.parametrize("batch_size", BATCH)
@pytest.mark.parametrize("embedding", EMB)
@pytest.mark.parametrize("seq_len", SEQ)
@pytest.mark.parametrize("num_heads", HEADS)
def test_xformers_single_forward_parity(
    device,
    attn_dtype,
    key_padding_dtype,
    add_bias_kv,
    add_zero_attn,
    static_kv,
    batch_size,
    embedding,
    seq_len,
    num_heads,
):
    xformers_att_config = '{"name": "scaled_dot_product"}'

    attn_mask = (
        None
        if attn_dtype is None
        else _get_mask(to_dtype=attn_dtype, dim0=seq_len, dim1=seq_len).to(device)
    )
    key_padding_mask = (
        None
        if key_padding_dtype is None
        else _get_mask(to_dtype=key_padding_dtype, dim0=batch_size, dim1=seq_len).to(
            device
        )
    )

    q = torch.rand(seq_len, batch_size, embedding).to(device)
    q.requires_grad = True
    k = torch.rand(seq_len, batch_size, embedding).to(device)
    k.requires_grad = True
    v = torch.rand(seq_len, batch_size, embedding).to(device)
    v.requires_grad = True

    q_ = q.detach().clone()
    q_.requires_grad = True
    k_ = k.detach().clone()
    k_.requires_grad = True
    v_ = v.detach().clone()
    v_.requires_grad = True

    _reset_seeds()
    xformers_mha = MultiheadAttention(
        embedding,
        num_heads,
        dropout=0.0,
        xformers_att_config=xformers_att_config,
        add_bias_kv=add_bias_kv,
        add_zero_attn=add_zero_attn,
    ).to(device)
    xformers_output, _ = xformers_mha(
        q,
        k,
        v,
        key_padding_mask=key_padding_mask,
        attn_mask=attn_mask,
        static_kv=static_kv,
    )

    _reset_seeds()
    original_mha = MultiheadAttention(
        embedding,
        num_heads,
        dropout=0.0,
        xformers_att_config=None,
        add_bias_kv=add_bias_kv,
        add_zero_attn=add_zero_attn,
    ).to(device)
    original_output, _ = original_mha(
        q_,
        k_,
        v_,
        key_padding_mask=key_padding_mask,
        attn_mask=attn_mask,
        static_kv=static_kv,
    )

    if xformers_output.isnan().any() or original_output.isnan().any():
        rand = random.uniform(0, 1)
        xformers_output = xformers_output.masked_fill(xformers_output.isnan(), rand)
        original_output = original_output.masked_fill(original_output.isnan(), rand)

    assert torch.allclose(
        xformers_output, original_output, atol=1e-06
    ), f"max diff is {torch.max(torch.abs(xformers_output - original_output))}"

    loss_xformers = torch.norm(xformers_output)
    loss_original = torch.norm(original_output)
    loss_xformers.backward()
    loss_original.backward()

    assert torch.allclose(
        q.grad, q_.grad
    ), f"max diff is {torch.max(torch.abs(q.grad - q_.grad))}"
    assert torch.allclose(
        k.grad, k_.grad
    ), f"max diff is {torch.max(torch.abs(k.grad - k_.grad))}"
    assert torch.allclose(
        v.grad, v_.grad
    ), f"max diff is {torch.max(torch.abs(v.grad - v_.grad))}"


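# MultiheadAttention._pad_masks should match the old inline padding logic.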
def test_mask_padding_parity():
    def old_padding_code(key_padding_mask, attn_mask):
        if attn_mask is not None:
            attn_mask = torch.cat(
                [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
            )
        if key_padding_mask is not None:
            key_padding_mask = torch.cat(
                [
                    key_padding_mask,
                    torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
                ],
                dim=1,
            )
        return key_padding_mask, attn_mask

    mha = MultiheadAttention(
        embed_dim=8,
        num_heads=2,
        dropout=0.0,
        add_bias_kv=True,
        add_zero_attn=True,
    )

    key_padding_mask = torch.rand((8, 64))
    attn_mask = torch.rand((64, 64))

    kp_mask_orig, a_mask_orig = old_padding_code(key_padding_mask, attn_mask)
    kp_mask_new, a_mask_new = mha._pad_masks(key_padding_mask, attn_mask)

    assert kp_mask_orig.size() == kp_mask_new.size()
    assert a_mask_orig.size() == a_mask_new.size()
    assert torch.equal(kp_mask_orig, kp_mask_new)
    assert torch.equal(a_mask_orig, a_mask_new)


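# MultiheadAttention._add_bias should match the old inline bias_k/bias_v code.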
def test_add_bias_parity():
    mha = MultiheadAttention(
        embed_dim=8,
        num_heads=2,
        dropout=0.0,
        add_bias_kv=True,
        add_zero_attn=True,
    )

    def old_bias_code(k, v, key_padding_mask, attn_mask, bsz):
        k = torch.cat([k, mha.bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, mha.bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            attn_mask = torch.cat(
                [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
            )
        if key_padding_mask is not None:
            key_padding_mask = torch.cat(
                [
                    key_padding_mask,
                    key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                ],
                dim=1,
            )
        return k, v, key_padding_mask, attn_mask

    seq_len = 64
    bsz = 8
    embedding = 8
    key_padding_mask = torch.rand((bsz, seq_len))
    attn_mask = torch.rand((seq_len, seq_len))
    k = torch.rand((seq_len, bsz, embedding))
    v = torch.rand((seq_len, bsz, embedding))

    k_orig, v_orig, kp_mask_orig, a_mask_orig = old_bias_code(
        k, v, key_padding_mask, attn_mask, bsz
    )
    k_new, v_new, kp_mask_new, a_mask_new = mha._add_bias(
        k, v, key_padding_mask, attn_mask, bsz
    )

    assert torch.equal(k_orig, k_new)
    assert torch.equal(v_orig, v_new)
    assert torch.equal(kp_mask_orig, kp_mask_new)
    assert torch.equal(a_mask_orig, a_mask_new)


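# unittest-style tests for key padding mask bookkeeping and head pruning.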
class TestMultiheadAttention(unittest.TestCase):
    def test_append_prev_key_padding_mask(self):
        bsz = 1
        src_len = 4

        cases = [
            # no padding mask
            (None, None, None),
            # current padding mask only
            (
                torch.tensor([[1]]).bool(),
                None,
                torch.tensor([[0, 0, 0, 1]]).bool(),
            ),
            # previous padding mask only
            (
                None,
                torch.tensor([[0, 1, 0]]).bool(),
                torch.tensor([[0, 1, 0, 0]]).bool(),
            ),
            # both current and previous padding masks
            (
                torch.tensor([[1]]).bool(),
                torch.tensor([[0, 1, 0]]).bool(),
                torch.tensor([[0, 1, 0, 1]]).bool(),
            ),
            # current padding mask already spans src_len
            (
                torch.tensor([[0, 1, 0, 1]]).bool(),
                None,
                torch.tensor([[0, 1, 0, 1]]).bool(),
            ),
            # previous padding mask already spans src_len
            (
                None,
                torch.tensor([[0, 1, 0, 1]]).bool(),
                torch.tensor([[0, 1, 0, 1]]).bool(),
            ),
        ]
        for c in cases:
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                c[0],
                c[1],
                batch_size=bsz,
                src_len=src_len,
                static_kv=False,
            )

            if key_padding_mask is not None:
                self.assertTrue(
                    torch.all(torch.eq(key_padding_mask, c[2])),
                    f"Unexpected resultant key padding mask: {key_padding_mask}"
                    f" given current: {c[0]} and previous: {c[1]}",
                )
                self.assertEqual(key_padding_mask.size(0), bsz)
                self.assertEqual(key_padding_mask.size(1), src_len)
            else:
                self.assertIsNone(c[2])

    def test_pruning_heads(self):
        embed_dim = 768
        num_heads = 12
        num_heads_to_keep = 8
        dummy_input = torch.randn(32, 2, embed_dim)
        mha = MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)
        reserve_head_index = mha._get_reserve_head_index(
            num_heads_to_keep=num_heads_to_keep
        )
        mha._adaptive_prune_heads(reserve_head_index=reserve_head_index)
        mha._set_skip_embed_dim_check()
        mha(query=dummy_input, key=dummy_input, value=dummy_input)
        self.assertEqual(mha.head_dim, embed_dim / num_heads)
        self.assertEqual(mha.num_heads, num_heads_to_keep)


if __name__ == "__main__":
    unittest.main()