|
from __future__ import absolute_import, division, print_function, unicode_literals |
|
import unittest |
|
import torch |
|
from tensorboardX import SummaryWriter |
|
|
|
|
|
class PytorchGraphTest(unittest.TestCase):
    """Tests for ``SummaryWriter.add_graph`` on small PyTorch modules."""

    def test_pytorch_graph(self):
        """Logging the graph of a one-layer module completes without raising."""
        # Shape (1, 3): the last dimension matches in_features=3 of the
        # Linear layer wrapped below.
        dummy_input = (torch.zeros(1, 3),)

        class MyLinear(torch.nn.Module):
            """Minimal module wrapping a single ``Linear(3, 5)`` layer."""

            def __init__(self):
                super(MyLinear, self).__init__()
                self.linear = torch.nn.Linear(3, 5)

            def forward(self, x):
                return self.linear(x)

        with SummaryWriter(comment='LinearModel') as w:
            w.add_graph(MyLinear(), dummy_input, verbose=True)

    def test_wrong_input_size(self):
        """A (1, 9) input to ``Linear(3, 5)`` makes ``add_graph`` raise RuntimeError."""
        # Marks the start of the error output the failing export is
        # expected to produce in the test log.
        print('expect error here:')
        dummy_input = torch.rand(1, 9)
        model = torch.nn.Linear(3, 5)
        with SummaryWriter(comment='expect_error') as w:
            # Only the graph export is expected to fail.
            # Input/model creation and opening the writer stay outside
            # assertRaises, so an unrelated RuntimeError there fails the
            # test instead of satisfying it by accident.
            with self.assertRaises(RuntimeError):
                w.add_graph(model, dummy_input)
|
|