import os
import time

import torch as th
import numpy as np
import torchvision.datasets as dset
import torchvision.transforms as transforms
import imageio
import ttools

import rendering

BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
DATA = os.path.join(BASE_DIR, "data")

LOG = ttools.get_logger(__name__)

class QuickDrawImageDataset(th.utils.data.Dataset):
    """Rasterized 28x28 QuickDraw "cat" bitmaps, resized to a square image.

    Args:
        imsize(int): spatial extent in pixels of the (square) output images.
        train(bool): unused; kept for interface consistency with the other
            datasets in this module.
    """

    BASE_DATA_URL = \
        "https://console.cloud.google.com/storage/browser/_details/quickdraw_dataset/full/numpy_bitmap/cat.npy"

    def __init__(self, imsize, train=True):
        super(QuickDrawImageDataset, self).__init__()
        file = os.path.join(DATA, "cat.npy")
        self.imsize = imsize
        if not os.path.exists(file):
            msg = ("Dataset file %s does not exist, please download"
                   " it from %s" % (file,
                                    QuickDrawImageDataset.BASE_DATA_URL))
            LOG.error(msg)
            raise RuntimeError(msg)
        self.data = np.load(file, allow_pickle=True, encoding="latin1")
    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        im = np.reshape(self.data[idx], (1, 1, 28, 28))
        im = th.from_numpy(im).float() / 255.0
        im = th.nn.functional.interpolate(im, size=(self.imsize, self.imsize))
        # Make sure we are in [0, 1], then remap to [-1, 1]
        im = th.clamp(im, 0, 1)
        im -= 0.5
        im /= 0.5
        return im.squeeze(0)
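
# A minimal usage sketch (assumes "data/cat.npy" has been downloaded from
# BASE_DATA_URL above); each item is a (1, imsize, imsize) tensor in [-1, 1]:
#
#   dataset = QuickDrawImageDataset(64)
#   loader = th.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)
#   batch = next(iter(loader))  # shape (8, 1, 64, 64)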

class QuickDrawDataset(th.utils.data.Dataset):
    """Vector QuickDraw sketches in the stroke-3 format used by Sketch-RNN.

    Args:
        dataset(str): name of the sketch file to load (e.g. "cat.npz").
        mode(str): one of "train", "test" or "valid".
        max_seq_length(int): sequences longer than this are discarded.
        spatial_limit(int): maximum spatial extent in pixels; larger offsets
            are clipped to this value.
    """

    BASE_DATA_URL = \
        "https://storage.cloud.google.com/quickdraw_dataset/sketchrnn"

    def __init__(self, dataset, mode="train",
                 max_seq_length=250,
                 spatial_limit=1000):
        super(QuickDrawDataset, self).__init__()
        file = os.path.join(DATA, "sketchrnn_"+dataset)
        remote = os.path.join(QuickDrawDataset.BASE_DATA_URL, dataset)

        self.max_seq_length = max_seq_length
        self.spatial_limit = spatial_limit

        if mode not in ["train", "test", "valid"]:
            raise ValueError("Allowed data modes are 'train', 'test' and"
                             " 'valid'.")

        if not os.path.exists(file):
            msg = ("Dataset file %s does not exist, please download"
                   " it from %s" % (file, remote))
            LOG.error(msg)
            raise RuntimeError(msg)
        data = np.load(file, allow_pickle=True, encoding="latin1")[mode]
        data = self.purify(data)
        data = self.normalize(data)

        # Length of the longest sequence in the dataset
        self.nmax = max([len(seq) for seq in data])
        self.sketches = data
    def __repr__(self):
        return "Dataset with %d sequences of max length %d" % \
            (len(self.sketches), self.nmax)

    def __len__(self):
        return len(self.sketches)
    def __getitem__(self, idx):
        """Return the idx-th stroke in 5-D format, padded to length (nmax+2).

        The first and last elements of the sequence are fixed to the "start-"
        and "end-of-sequence" tokens.

        Each step holds (dx, dy) plus a one-hot encoding of the pen state:
            1 0 0: pen touching paper until the next point
            0 1 0: pen lifted from paper after the current point
            0 0 1: drawing has ended, the next points (including the current
                   one) will not be drawn
        """
        sample_data = self.sketches[idx]

        # Allow two extra slots for the start/end-of-sequence tokens
        sample = np.zeros((self.nmax+2, 5), dtype=np.float32)
        n = sample_data.shape[0]

        # Normalize dx, dy
        deltas = sample_data[:, :2]

        # Absolute coordinates
        positions = deltas[..., :2].cumsum(0)
        maxi = np.abs(positions).max() + 1e-8
        deltas = deltas / (1.1 * maxi)  # leave some margin at the edges

        # Fill in the dx, dy coordinates
        sample[1:n+1, :2] = deltas

        # On-paper indicator: 0 means touching paper in the stroke-3 format,
        # so flip it
        sample[1:n+1, 2] = 1 - sample_data[:, 2]

        # Off-paper indicator, complement of the previous flag
        sample[1:n+1, 3] = 1 - sample[1:n+1, 2]

        # Fill the remainder with end-of-sequence tokens
        sample[n+1:, 4] = 1

        # Start-of-sequence token
        sample[0] = [0, 0, 1, 0, 0]

        return sample
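
    # As a concrete illustration of the format above: a two-point stroke-3
    # sequence [[dx0, dy0, 0], [dx1, dy1, 1]] (pen down, then lifted) becomes,
    # after rescaling of the deltas and with nmax = 2:
    #
    #   [[  0,   0, 1, 0, 0],   # start-of-sequence token
    #    [dx0, dy0, 1, 0, 0],   # pen on paper until the next point
    #    [dx1, dy1, 0, 1, 0],   # pen lifted after this point
    #    [  0,   0, 0, 0, 1]]   # end-of-sequence padding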
    def purify(self, strokes):
        """Removes sequences that are too long and clips large spatial
        gaps."""
        data = []
        for seq in strokes:
            if seq.shape[0] <= self.max_seq_length:
                # and seq.shape[0] > 10:
                # Limit large spatial gaps
                seq = np.minimum(seq, self.spatial_limit)
                seq = np.maximum(seq, -self.spatial_limit)
                seq = np.array(seq, dtype=np.float32)
                data.append(seq)
        return data
    def calculate_normalizing_scale_factor(self, strokes):
        """Calculate the normalizing factor explained in the appendix of
        sketch-rnn: the standard deviation of all (dx, dy) offsets, pooled
        over the whole dataset."""
        data = []
        for stroke in strokes:
            for pt in stroke:
                data.append(pt[0])
                data.append(pt[1])
        data = np.array(data)
        return np.std(data)
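
    # For instance, for strokes = [np.array([[1., 2., 0.], [3., 4., 0.]])]
    # the pooled offsets are [1, 2, 3, 4], so the factor is
    # np.std([1., 2., 3., 4.]) ~= 1.118.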
    def normalize(self, strokes):
        """Normalize the entire dataset (delta_x, delta_y) by the scaling
        factor."""
        data = []
        scale_factor = self.calculate_normalizing_scale_factor(strokes)
        for seq in strokes:
            seq[:, 0:2] /= scale_factor
            data.append(seq)
        return data
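
# A usage sketch (assumes "data/sketchrnn_cat.npz" was downloaded from
# BASE_DATA_URL above; the "cat.npz" name is illustrative):
#
#   dataset = QuickDrawDataset("cat.npz", mode="train")
#   print(dataset)       # "Dataset with N sequences of max length nmax"
#   sample = dataset[0]  # (nmax+2, 5) float32 array in the 5-D format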

class FixedLengthQuickDrawDataset(QuickDrawDataset):
    """A variant of the QuickDraw dataset where the strokes are represented
    as a fixed-length sequence of triplets (dx, dy, opacity), with opacity
    in {0, 1}.
    """
    def __init__(self, *args, canvas_size=64, **kwargs):
        super(FixedLengthQuickDrawDataset, self).__init__(*args, **kwargs)
        self.canvas_size = canvas_size

    def __getitem__(self, idx):
        sample = super(FixedLengthQuickDrawDataset, self).__getitem__(idx)

        # We construct a stroke opacity variable from the pen-down state;
        # dx, dy remain unchanged
        strokes = sample[:, :3]

        # Render the ground-truth image
        # start = time.time()
        im = rendering.opacityStroke2diffvg(
            th.from_numpy(strokes).unsqueeze(0), canvas_size=self.canvas_size,
            relative=True, debug=False)
        im = im.squeeze(0).numpy()
        # elapsed = (time.time() - start)*1000
        # print("item %d pipeline gt rendering took %.2fms" % (idx, elapsed))

        return strokes, im
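
# A usage sketch (assumes the same "data/sketchrnn_cat.npz" file as above,
# and that the project's rendering.opacityStroke2diffvg is importable);
# each item is a (strokes, image) pair:
#
#   dataset = FixedLengthQuickDrawDataset("cat.npz", mode="valid",
#                                         canvas_size=64)
#   strokes, im = dataset[0]  # strokes: (nmax+2, 3); im: rendered canvas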

class MNISTDataset(th.utils.data.Dataset):
    def __init__(self, imsize, train=True):
        super(MNISTDataset, self).__init__()
        self.mnist = dset.MNIST(root=os.path.join(DATA, "mnist"),
                                train=train,
                                download=True,
                                transform=transforms.Compose([
                                    transforms.Resize((imsize, imsize)),
                                    transforms.ToTensor(),
                                ]))

    def __len__(self):
        return len(self.mnist)

    def __getitem__(self, idx):
        im, label = self.mnist[idx]

        # Make sure the data uses the [0, 1] range
        im -= im.min()
        im /= im.max() + 1e-8

        # Bring it to [-1, 1]
        im -= 0.5
        im /= 0.5

        return im
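
if __name__ == "__main__":
    # Minimal smoke test (assumes network access: torchvision downloads
    # MNIST on first use; the QuickDraw datasets above require a manual
    # download, see the URLs on their BASE_DATA_URL attributes).
    dataset = MNISTDataset(32)
    im = dataset[0]
    print("MNIST sample:", tuple(im.shape),
          "min=%.2f max=%.2f" % (im.min().item(), im.max().item()))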