|
|
|
|
|
import numpy as np |
|
import os |
|
import ntpath |
|
import time |
|
import glob |
|
from scipy.misc import imresize |
|
import torchvision.utils as vutils |
|
from operator import itemgetter |
|
from tensorboardX import SummaryWriter |
|
|
|
|
|
class Visualizer:
    """Thin wrapper around a tensorboardX ``SummaryWriter``.

    Logs image grids, word lists, scalars and parameter histograms to
    ``checkpoints_dir``.

    Args:
        checkpoints_dir: directory the event files are written to. Any
            existing ``events*`` files in this directory are deleted on
            construction, so each run starts with an empty log.
        name: run name; stored as ``self.name`` but not read by the
            methods of this class.
    """

    def __init__(self, checkpoints_dir, name):
        # win_size, saved and ncols are not read by any method here;
        # presumably consumed by callers — kept for compatibility.
        self.win_size = 256
        self.name = name
        self.saved = False
        self.checkpoints_dir = checkpoints_dir
        self.ncols = 4

        # Delete event files left by previous runs in this directory.
        # glob.escape keeps a directory name containing '[', '*' or '?'
        # from being interpreted as a glob pattern.
        pattern = os.path.join(glob.escape(checkpoints_dir), "events*")
        for filename in glob.glob(pattern):
            os.remove(filename)

        self.writer = SummaryWriter(checkpoints_dir)

    def reset(self):
        """Clear the ``saved`` flag."""
        self.saved = False

    def image_summary(self, mode, epoch, images):
        """Log a batch of images as one grid under ``'<mode>/Image'``.

        Args:
            mode: tag prefix (e.g. 'train' / 'val').
            epoch: global step recorded with the image.
            images: batch accepted by ``torchvision.utils.make_grid``
                (presumably a (B, C, H, W) tensor — confirm with callers).
                Each image is min-max normalized independently
                (``normalize=True, scale_each=True``).
        """
        grid = vutils.make_grid(images, normalize=True, scale_each=True)
        self.writer.add_image('{}/Image'.format(mode), grid, epoch)

    @staticmethod
    def _as_index_list(idx):
        """Flatten ``idx`` (list, tuple, tensor/array, or a lone scalar)
        into a Python list of indices."""
        # tensors / numpy arrays -> Python ints (a 0-d one becomes a scalar)
        if hasattr(idx, 'tolist'):
            idx = idx.tolist()
        try:
            return list(idx)
        except TypeError:
            # a single scalar index
            return [idx]

    def text_summary(self, mode, epoch, type, text, vocabulary, gt=True, max_length=20):
        """Log each entry of ``text`` as a comma-separated list of words.

        One text record is written per entry, tagged
        ``'<mode>/<type>_<i>_gt'`` or ``'<mode>/<type>_<i>_prediction'``.

        Args:
            mode: tag prefix.
            epoch: global step recorded with the text.
            type: tag component naming the kind of text (shadows the
                builtin; name kept for API compatibility).
            text: iterable of per-sample entries. With ``gt=True`` each
                entry is a sequence of vocabulary indices (a single index
                and an empty sequence are both accepted). With
                ``gt=False`` each entry is a tensor whose nonzero
                positions, shifted by +1, are used as indices (presumably
                a multi-hot prediction that skips index 0 — confirm).
            vocabulary: indexable object mapping index -> word.
            gt: whether ``text`` holds ground-truth indices or predictions.
            max_length: entries with more indices than this (including
                '<pad>' ones) are logged as a "too big" message instead.
        """
        for i, el in enumerate(text):
            if gt:
                idx = el
            else:
                idx = el.nonzero().squeeze() + 1

            # Look words up one by one: itemgetter(*idx) returns a bare
            # item (not a tuple) for one index and fails for zero.
            words_list = [vocabulary[j] for j in self._as_index_list(idx)]

            tag = '{}/{}_{}_{}'.format(mode, type, i, 'gt' if gt else 'prediction')
            if len(words_list) <= max_length:
                content = ', '.join(w for w in words_list if w != '<pad>')
            else:
                content = 'Number of sampled ingredients is too big: {}'.format(len(words_list))
            self.writer.add_text(tag, content, epoch)

    def scalar_summary(self, mode, epoch, **scalars):
        """Log each keyword argument as a scalar tagged ``'<mode>/<key>'``.

        After logging, ``export_scalars_to_json`` is called, so
        ``<checkpoints_dir>/tensorboard_all_scalars.json`` is rewritten on
        every call.
        """
        for key, value in scalars.items():
            self.writer.add_scalar('{}/{}'.format(mode, key), value, epoch)

        self.writer.export_scalars_to_json("{}/tensorboard_all_scalars.json".format(self.checkpoints_dir))

    def histo_summary(self, model, step):
        """Log a histogram of every named parameter of ``model`` at ``step``,
        tagged with the parameter's name."""
        for name, param in model.named_parameters():
            self.writer.add_histogram(name, param, step)

    def close(self):
        """Close the underlying SummaryWriter."""
        self.writer.close()
|
|