|
import unittest |
|
|
|
import itertools as it |
|
|
|
from apex import amp |
|
import torch |
|
from torch import nn |
|
import torch.nn.functional as F |
|
|
|
from utils import common_init, HALF, FLOAT, DTYPES |
|
|
|
class TestPromotion(unittest.TestCase):
    """Check amp's type-promotion behavior for mixed half/float operands.

    Every test runs with amp enabled (see ``setUp``) and compares result
    tensor type strings against ``FLOAT`` / ``HALF`` from ``utils``.
    """

    def setUp(self):
        # Enable amp for this test; keep the handle so tearDown can undo it.
        self.handle = amp.init(enabled=True)
        # Shared fixture setup from utils; it provides ``self.b``, which the
        # tests below use as the tensor size.
        common_init(self)

    def tearDown(self):
        # Deactivate the amp handle created in setUp so state does not leak
        # into the next test.
        self.handle._deactivate()

    def run_binary_promote_test(self, fns, input_shape, x_inplace=False):
        """Check result and gradient types of ``fn(x, y)`` over all dtype pairs.

        Args:
            fns: callables ``fn(x, y)`` performing a binary op.
            input_shape: shape passed to ``torch.randn`` for both operands.
            x_inplace: if True, ``fn`` modifies ``x`` in place and the result
                must keep ``x``'s type. Otherwise the result must have the
                widest operand type: ``FLOAT`` if either operand dtype is
                ``torch.float``, else ``HALF``.

        In every case the gradient reaching ``x``'s leaf tensor must have
        ``x``'s original dtype. Each (fn, xtype, ytype) combination runs in
        its own ``subTest``, so a failure reports which combination broke and
        the remaining combinations are still exercised.
        """
        # Same iteration order as it.product(fns, it.product(DTYPES, DTYPES)).
        for fn_idx, fn in enumerate(fns):
            for xtype, ytype in it.product(DTYPES, DTYPES):
                with self.subTest(fn=fn_idx, xtype=xtype, ytype=ytype):
                    x = torch.randn(input_shape, dtype=xtype).requires_grad_()
                    # Keep a handle on the leaf: it is where .grad lands.
                    x_leaf = x
                    if x_inplace:
                        # An in-place op on a leaf that requires grad is
                        # rejected by autograd, so operate on a non-leaf copy.
                        x = x.clone()
                    y = torch.randn(input_shape, dtype=ytype)
                    out = fn(x, y)
                    if x_inplace:
                        # In place: the result must keep x's own type.
                        self.assertEqual(out.type(), x.type())
                    else:
                        # Out of place: the result must match the widest type.
                        expected = FLOAT if torch.float in (xtype, ytype) else HALF
                        self.assertEqual(out.type(), expected)
                    # Backprop through a float copy; the gradient arriving at
                    # the leaf must still have x's original dtype.
                    out.float().sum().backward()
                    self.assertEqual(x_leaf.grad.dtype, xtype)

    def test_atan2_matches_widest(self):
        """atan2 (function and method form) returns the widest input type."""
        fns = [lambda x, y: torch.atan2(x, y),
               lambda x, y: x.atan2(y)]
        self.run_binary_promote_test(fns, (self.b,))

    def test_mul_matches_widest(self):
        """mul (function and method form) returns the widest input type."""
        fns = [lambda x, y: torch.mul(x, y),
               lambda x, y: x.mul(y)]
        self.run_binary_promote_test(fns, (self.b,))

    def test_cat_matches_widest(self):
        """cat of half tensors is FLOAT if any input is float, else HALF."""
        shape = self.b
        ys = [torch.randn(shape, dtype=torch.half) for _ in range(5)]
        # One float tensor among the halves promotes the whole result.
        x_float = torch.randn(shape)
        out = torch.cat(ys + [x_float])
        self.assertEqual(out.type(), FLOAT)
        # All-half inputs stay half.
        x_half = torch.randn(shape, dtype=torch.half)
        out = torch.cat(ys + [x_half])
        self.assertEqual(out.type(), HALF)

    def test_inplace_exp_is_error_for_half(self):
        """exp_ works on a float tensor but raises NotImplementedError on half."""
        xs = torch.randn(self.b)
        xs.exp_()
        self.assertEqual(xs.type(), FLOAT)
        xs = torch.randn(self.b, dtype=torch.half)
        with self.assertRaises(NotImplementedError):
            xs.exp_()

    def test_inplace_add_matches_self(self):
        """In-place add keeps the type of the tensor being modified."""
        fn = lambda x, y: x.add_(y)
        self.run_binary_promote_test([fn], (self.b,), x_inplace=True)
|
|
|
if __name__ == '__main__':
    # When this file is executed directly, hand control to unittest's
    # command-line runner, which runs the TestCase classes in this module.
    unittest.main()
|
|