File size: 956 Bytes
c508d7f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
from __future__ import absolute_import, division, print_function, unicode_literals
import unittest
import torch
from tensorboardX import SummaryWriter
class PytorchGraphTest(unittest.TestCase):
    """Exercise ``SummaryWriter.add_graph`` with small PyTorch models."""

    def test_pytorch_graph(self):
        """Log the graph of a one-layer linear module for a matching input."""
        # A 1-tuple holding a single (1, 3) tensor of zeros; its last
        # dimension matches the 3 input features of the layer below.
        sample_inputs = (torch.zeros(1, 3),)

        class myLinear(torch.nn.Module):
            """Minimal module wrapping a single ``Linear(3, 5)`` layer."""

            def __init__(self):
                super(myLinear, self).__init__()
                self.linear = torch.nn.Linear(3, 5)

            def forward(self, x):
                """Apply the wrapped linear layer to ``x``."""
                projected = self.linear(x)
                return projected

        with SummaryWriter(comment='LinearModel') as writer:
            net = myLinear()
            # Third positional argument is passed as True, exactly as before
            # (presumably the verbose flag -- confirm against add_graph).
            writer.add_graph(net, sample_inputs, True)

    def test_wrong_input_size(self):
        """Check that a mismatched input makes graph logging raise RuntimeError."""
        print('expect error here:')
        with self.assertRaises(RuntimeError):
            # 9 input columns versus the 3 features ``Linear(3, 5)`` is
            # built with: the mismatch is deliberate.
            mismatched_input = torch.rand(1, 9)
            net = torch.nn.Linear(3, 5)
            with SummaryWriter(comment='expect_error') as writer:
                writer.add_graph(net, mismatched_input)  # error
|