File size: 2,519 Bytes
c508d7f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
import unittest
import torch
from tensorboardX import SummaryWriter
class EmbeddingTest(unittest.TestCase):
def test_embedding(self):
w = SummaryWriter()
all_features = torch.Tensor([[1, 2, 3], [5, 4, 1], [3, 7, 7]])
all_labels = torch.Tensor([33, 44, 55])
all_images = torch.zeros(3, 3, 5, 5)
w.add_embedding(all_features,
metadata=all_labels,
label_img=all_images,
global_step=2)
dataset_label = ['test'] * 2 + ['train'] * 2
all_labels = list(zip(all_labels, dataset_label))
w.add_embedding(all_features,
metadata=all_labels,
label_img=all_images,
metadata_header=['digit', 'dataset'],
global_step=2)
# assert...
def test_embedding_64(self):
w = SummaryWriter()
all_features = torch.Tensor([[1, 2, 3], [5, 4, 1], [3, 7, 7]])
all_labels = torch.Tensor([33, 44, 55])
all_images = torch.zeros((3, 3, 5, 5), dtype=torch.float64)
w.add_embedding(all_features,
metadata=all_labels,
label_img=all_images,
global_step=2)
dataset_label = ['test'] * 2 + ['train'] * 2
all_labels = list(zip(all_labels, dataset_label))
w.add_embedding(all_features,
metadata=all_labels,
label_img=all_images,
metadata_header=['digit', 'dataset'],
global_step=2)
def test_embedding_square(self):
w = SummaryWriter(comment='sq')
all_features = torch.rand(228,256)
all_images = torch.rand(228, 3, 32, 32)
for i in range(all_images.shape[0]):
all_images[i] *= (float(i)+60)/(all_images.shape[0]+60)
w.add_embedding(all_features,
label_img=all_images,
global_step=2)
def test_embedding_fail(self):
with self.assertRaises(AssertionError):
w = SummaryWriter(comment='shouldfail')
all_features = torch.rand(228,256)
all_images = torch.rand(228, 3, 16, 32)
for i in range(all_images.shape[0]):
all_images[i] *= (float(i)+60)/(all_images.shape[0]+60)
w.add_embedding(all_features,
label_img=all_images,
global_step=2)
|