import unittest | |
import torch | |
from fairseq.modules import RelPositionalEncoding | |
import numpy as np | |
class TestRelPositionalEncoding(unittest.TestCase): | |
def setUp(self) -> None: | |
self.T = 3 | |
self.B = 1 | |
self.C = 2 | |
torch.manual_seed(0) | |
self.sample = torch.randn(self.T, self.B, self.C) # TBC | |
self.rel_pos_enc = RelPositionalEncoding(max_len=4, d_model=self.C) | |
def test_extend_pe(self): | |
inp = self.sample.transpose(0, 1) | |
self.rel_pos_enc.extend_pe(inp) | |
expected_pe = torch.tensor( | |
[ | |
[ | |
[0.1411, -0.9900], | |
[0.9093, -0.4161], | |
[0.8415, 0.5403], | |
[0.0000, 1.0000], | |
[-0.8415, 0.5403], | |
[-0.9093, -0.4161], | |
[-0.1411, -0.9900], | |
] | |
] | |
) | |
self.assertTrue( | |
np.allclose( | |
expected_pe.cpu().detach().numpy(), | |
self.rel_pos_enc.pe.cpu().detach().numpy(), | |
atol=1e-4, | |
) | |
) | |
def test_forward(self): | |
pos_enc = self.rel_pos_enc(self.sample) | |
expected_pos_enc = torch.tensor( | |
[ | |
[[0.9093, -0.4161]], | |
[[0.8415, 0.5403]], | |
[[0.0000, 1.0000]], | |
[[-0.8415, 0.5403]], | |
[[-0.9093, -0.4161]], | |
] | |
) | |
self.assertTrue( | |
np.allclose( | |
pos_enc.cpu().detach().numpy(), | |
expected_pos_enc.cpu().detach().numpy(), | |
atol=1e-4, | |
) | |
) | |
# Allow running this module directly (``python <file>.py``) as well as through
# a test runner; unittest.main() discovers and runs the TestCase classes above.
if __name__ == "__main__":
    unittest.main()