|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
from __future__ import unicode_literals |
|
|
|
import matplotlib.pyplot as plt |
|
import unittest |
|
|
|
from tensorboardX import SummaryWriter |
|
|
|
|
|
class FigureTest(unittest.TestCase):
    """Exercise ``SummaryWriter.add_figure`` with one figure and a list of figures.

    Both tests log the same figure(s) twice: first with ``close=False`` and
    then with the default arguments, checking via ``plt.fignum_exists`` that
    the figures are still open after the first call and gone after the second.
    """

    def _make_writer(self):
        """Return a ``SummaryWriter`` that is closed even if the test fails."""
        writer = SummaryWriter()
        self.addCleanup(writer.close)
        return writer

    def test_figure(self):
        """A single figure survives ``close=False`` and is closed by default."""
        writer = self._make_writer()

        figure = plt.figure()
        # Closing an already-closed figure is harmless; this only matters when
        # an assertion fails while the figure is still open.
        self.addCleanup(plt.close, figure)
        axes = plt.gca()
        axes.add_patch(plt.Circle((0.2, 0.5), 0.2, color='r'))
        axes.add_patch(plt.Circle((0.8, 0.5), 0.2, color='g'))
        plt.axis('scaled')
        plt.tight_layout()

        writer.add_figure("add_figure/figure", figure, 0, close=False)
        self.assertTrue(plt.fignum_exists(figure.number))

        writer.add_figure("add_figure/figure", figure, 1)
        self.assertFalse(plt.fignum_exists(figure.number))

    def test_figure_list(self):
        """A list of figures survives ``close=False`` and is closed by default."""
        writer = self._make_writer()

        figures = []
        for i in range(5):
            figure = plt.figure()
            self.addCleanup(plt.close, figure)
            plt.plot([i * 1, i * 2, i * 3], label="Plot " + str(i))
            plt.xlabel("X")
            # The original called xlabel twice, overwriting "X" and leaving
            # the y-axis unlabeled.
            plt.ylabel("Y")
            plt.legend()
            plt.tight_layout()
            figures.append(figure)

        writer.add_figure("add_figure/figure_list", figures, 0, close=False)
        for figure in figures:
            self.assertTrue(plt.fignum_exists(figure.number))

        writer.add_figure("add_figure/figure_list", figures, 1)
        for figure in figures:
            self.assertFalse(plt.fignum_exists(figure.number))
|
|