File size: 2,440 Bytes
bd07fad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import numpy as np
import os
import ntpath
import time
import glob
from scipy.misc import imresize
import torchvision.utils as vutils
from operator import itemgetter
from tensorboardX import SummaryWriter


class Visualizer():
    """Thin wrapper around a tensorboardX ``SummaryWriter`` for training logs.

    Provides helpers to log image grids, decoded text, scalars and parameter
    histograms under a ``<mode>/...`` tag hierarchy.
    """

    def __init__(self, checkpoints_dir, name):
        """Create the writer, wiping any previous event files first.

        Args:
            checkpoints_dir: Directory the event files are written to. Any
                file matching ``events*`` directly inside it is deleted.
            name: Identifier stored on the instance as ``self.name``.
        """
        self.win_size = 256
        self.name = name
        self.saved = False
        self.checkpoints_dir = checkpoints_dir
        self.ncols = 4

        # Remove event files left over from a previous run so the new writer
        # starts from a clean log.
        for filename in glob.glob(self.checkpoints_dir+"/events*"):
            os.remove(filename)
        self.writer = SummaryWriter(checkpoints_dir)

    def reset(self):
        """Clear the ``saved`` flag."""
        self.saved = False

    def image_summary(self, mode, epoch, images):
        """Log a batch of images as a single normalized grid.

        Args:
            mode: Tag prefix; the grid is logged under ``<mode>/Image``.
            epoch: Global step value attached to the entry.
            images: Batch of images handed to ``torchvision.utils.make_grid``
                (original note: "(b, c, 0, 1) array of images" - presumably
                batch, channels, height, width; confirm against callers).
        """
        images = vutils.make_grid(images, normalize=True, scale_each=True)
        self.writer.add_image('{}/Image'.format(mode), images, epoch)

    def text_summary(self, mode, epoch, type, text, vocabulary, gt=True, max_length=20):
        """Log each element of ``text`` as a comma-separated list of words.

        Args:
            mode: Tag prefix.
            epoch: Global step value attached to each entry.
            type: Kind of text being logged (original note:
                "ingredients/recipe"); used in the tag.
            text: Iterable of per-sample entries. With ``gt=True`` each entry
                is itself an iterable of vocabulary indices. With ``gt=False``
                each entry must support ``.nonzero()``; the positions of its
                nonzero values, shifted by one, are used as indices.
            vocabulary: Indexable mapping from index to word.
            gt: True when logging ground truth, False when logging a sample.
            max_length: Entries with more words than this are replaced by a
                short "too big" message instead of the full list.
        """
        suffix = 'gt' if gt else 'prediction'
        for i, el in enumerate(text):  # text_list
            if not gt:  # we are printing a sample
                # reshape(-1) keeps the result 1-D even when there is exactly
                # one nonzero position (squeeze() would collapse it to a 0-d
                # tensor that cannot be iterated).
                # The +1 offset presumably skips a reserved index 0 - confirm
                # against the vocabulary layout.
                idx = el.nonzero().reshape(-1) + 1
            else:
                idx = el  # we are printing the ground truth

            # Always build a list: itemgetter(*idx) returns a bare item for a
            # single index (so the string would be treated as a sequence of
            # characters) and raises TypeError for zero indices.
            words_list = [vocabulary[j] for j in idx]

            tag = '{}/{}_{}_{}'.format(mode, type, i, suffix)
            if len(words_list) <= max_length:
                message = ', '.join(w for w in words_list if w != '<pad>')
            else:
                message = 'Number of sampled ingredients is too big: {}'.format(len(words_list))
            self.writer.add_text(tag, message, epoch)

    def scalar_summary(self, mode, epoch, **args):
        """Log every keyword argument as a scalar under ``<mode>/<key>``.

        After logging, all scalars known to the writer are exported to
        ``tensorboard_all_scalars.json`` inside the checkpoints directory.

        Args:
            mode: Tag prefix.
            epoch: Global step value attached to each scalar.
            **args: Label -> value pairs (original note: "dictionary of error
                labels and values").
        """
        for k, v in args.items():
            self.writer.add_scalar('{}/{}'.format(mode, k), v, epoch)

        self.writer.export_scalars_to_json("{}/tensorboard_all_scalars.json".format(self.checkpoints_dir))

    def histo_summary(self, model, step):
        """Log a histogram of the values of every named parameter of ``model``.

        Args:
            model: Object exposing ``named_parameters()`` yielding
                ``(name, param)`` pairs.
            step: Global step value attached to each histogram.
        """

        for name, param in model.named_parameters():
            self.writer.add_histogram(name, param, step)

    def close(self):
        """Close the underlying summary writer."""
        self.writer.close()