from __future__ import absolute_import from __future__ import division from __future__ import print_function from __future__ import unicode_literals import matplotlib.pyplot as plt import unittest from tensorboardX import SummaryWriter class FigureTest(unittest.TestCase): def test_figure(self): writer = SummaryWriter() figure, axes = plt.figure(), plt.gca() circle1 = plt.Circle((0.2, 0.5), 0.2, color='r') circle2 = plt.Circle((0.8, 0.5), 0.2, color='g') axes.add_patch(circle1) axes.add_patch(circle2) plt.axis('scaled') plt.tight_layout() writer.add_figure("add_figure/figure", figure, 0, close=False) assert plt.fignum_exists(figure.number) is True writer.add_figure("add_figure/figure", figure, 1) assert plt.fignum_exists(figure.number) is False writer.close() def test_figure_list(self): writer = SummaryWriter() figures = [] for i in range(5): figure = plt.figure() plt.plot([i * 1, i * 2, i * 3], label="Plot " + str(i)) plt.xlabel("X") plt.xlabel("Y") plt.legend() plt.tight_layout() figures.append(figure) writer.add_figure("add_figure/figure_list", figures, 0, close=False) assert all([plt.fignum_exists(figure.number) is True for figure in figures]) writer.add_figure("add_figure/figure_list", figures, 1) assert all([plt.fignum_exists(figure.number) is False for figure in figures]) writer.close()