File size: 956 Bytes
c508d7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from __future__ import absolute_import, division, print_function, unicode_literals
import unittest
import torch
from tensorboardX import SummaryWriter


class PytorchGraphTest(unittest.TestCase):
    """Smoke tests for ``SummaryWriter.add_graph`` with PyTorch modules."""

    def test_pytorch_graph(self):
        """Record the graph of a single-layer linear module."""
        # Example inputs are passed as a tuple; the (1, 3) tensor matches
        # the 3 input features of the layer defined below.
        dummy_input = (torch.zeros(1, 3),)

        class myLinear(torch.nn.Module):
            """Minimal module wrapping one ``Linear(3, 5)`` layer."""

            def __init__(self):
                super(myLinear, self).__init__()
                self.linear = torch.nn.Linear(3, 5)

            def forward(self, x):
                return self.linear(x)

        with SummaryWriter(comment='LinearModel') as w:
            w.add_graph(myLinear(), dummy_input, verbose=True)

    def test_wrong_input_size(self):
        """Expect ``RuntimeError`` for an input of the wrong width."""
        print('expect error here:')
        # Build the fixtures outside ``assertRaises`` so that a RuntimeError
        # raised during setup is reported as a failure instead of being
        # mistaken for the expected error.
        dummy_input = torch.rand(1, 9)  # 9 features; the model accepts 3
        model = torch.nn.Linear(3, 5)
        with SummaryWriter(comment='expect_error') as w:
            # Only the call under test may raise.
            with self.assertRaises(RuntimeError):
                w.add_graph(model, dummy_input)  # error