# Source: jukebox/tensorboardX/tests/test_pytorch_graph.py
# Uploaded by bds2714 ("Upload 331 files"), commit c508d7f.
from __future__ import absolute_import, division, print_function, unicode_literals
import unittest
import torch
from tensorboardX import SummaryWriter
class PytorchGraphTest(unittest.TestCase):
    """Smoke tests for ``SummaryWriter.add_graph`` on small PyTorch models."""

    def test_pytorch_graph(self):
        """Exporting the graph of a tiny single-layer module succeeds."""
        dummy_input = (torch.zeros(1, 3),)

        class MyLinear(torch.nn.Module):
            """Minimal module wrapping ``torch.nn.Linear(3, 5)``."""

            def __init__(self):
                super(MyLinear, self).__init__()
                self.linear = torch.nn.Linear(3, 5)

            def forward(self, x):
                return self.linear(x)

        with SummaryWriter(comment='LinearModel') as w:
            # Pass the flag by name instead of as an anonymous positional bool.
            w.add_graph(MyLinear(), dummy_input, verbose=True)

    def test_wrong_input_size(self):
        """``add_graph`` propagates the RuntimeError from a shape mismatch.

        The model expects 3 input features but receives 9, so running it
        while building the graph fails.
        """
        print('expect error here:')
        dummy_input = torch.rand(1, 9)  # 9 features; Linear(3, 5) expects 3
        model = torch.nn.Linear(3, 5)
        with SummaryWriter(comment='expect_error') as w:
            # Only the add_graph call is expected to raise. Keeping the
            # tensor/model/writer setup outside assertRaises stops an
            # unrelated RuntimeError there from making the test pass.
            with self.assertRaises(RuntimeError):
                w.add_graph(model, dummy_input)  # error