M3000j's picture
Upload folder using huggingface_hub
31726e5 verified
import os
import time
import torch as th
import numpy as np
import torchvision.datasets as dset
import torchvision.transforms as transforms
import imageio
import ttools
import rendering
# Parent of the directory that holds this module. The path is left
# un-normalized, i.e. it still ends with the os.pardir component.
BASE_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    os.pardir,
)

# Directory in which the dataset files are looked up.
DATA = os.path.join(BASE_DIR, "data")

# Module-level logger provided by ttools.
LOG = ttools.get_logger(__name__)
class QuickDrawImageDataset(th.utils.data.Dataset):
    """QuickDraw "cat" bitmaps served as single-channel image tensors.

    The bitmaps are read from ``cat.npy`` inside the module-level ``DATA``
    directory. Every item is reshaped to 28x28, resized to
    ``imsize`` x ``imsize`` and mapped to the [-1, 1] range.

    Args:
        imsize(int): side length, in pixels, of the returned square images.
        train(bool): accepted for interface compatibility but currently
            unused; the whole file is always served.

    Raises:
        RuntimeError: if ``cat.npy`` is not present in the data directory.
    """

    BASE_DATA_URL = \
        "https://console.cloud.google.com/storage/browser/_details/quickdraw_dataset/full/numpy_bitmap/cat.npy"

    def __init__(self, imsize, train=True):
        super(QuickDrawImageDataset, self).__init__()

        path = os.path.join(DATA, "cat.npy")
        self.imsize = imsize

        if not os.path.exists(path):
            # The two literals must be joined inside one parenthesized
            # expression: as separate statements the second one is a stray
            # "%"-format with too many arguments and raises TypeError.
            msg = ("Dataset file %s does not exist, please download"
                   " it from %s" % (path, QuickDrawImageDataset.BASE_DATA_URL))
            LOG.error(msg)
            raise RuntimeError(msg)

        # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary
        # objects; only point this at files from a trusted source.
        self.data = np.load(path, allow_pickle=True, encoding="latin1")

    def __len__(self):
        """Return the number of bitmaps in the loaded array."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return the idx-th bitmap as a float tensor of shape
        (1, imsize, imsize) with values in [-1, 1]."""
        # Add batch and channel axes so `interpolate` accepts the image
        im = np.reshape(self.data[idx], (1, 1, 28, 28))
        im = th.from_numpy(im).float() / 255.0
        im = th.nn.functional.interpolate(im, size=(self.imsize, self.imsize))

        # Bring it to [-1, 1]
        im = th.clamp(im, 0, 1)
        im -= 0.5
        im /= 0.5

        # Drop the batch axis, keep the channel axis
        return im.squeeze(0)
class QuickDrawDataset(th.utils.data.Dataset):
    """QuickDraw sketch-rnn stroke sequences, served in padded 5-D format.

    The data is read from ``sketchrnn_<dataset>`` inside the module-level
    ``DATA`` directory, filtered with `purify` and rescaled with
    `normalize`.

    Args:
        dataset(str): name appended to ``sketchrnn_`` to form the file name.
        mode(str): split to load, one of 'train', 'test' or 'valid'.
        max_seq_length(int): sequences longer than this are discarded.
        spatial_limit(int): maximum spatial extent in pixels; values are
            clamped to [-spatial_limit, spatial_limit].

    Raises:
        ValueError: if `mode` is not one of the supported splits.
        RuntimeError: if the dataset file does not exist.
    """

    BASE_DATA_URL = \
        "https://storage.cloud.google.com/quickdraw_dataset/sketchrnn"

    def __init__(self, dataset, mode="train",
                 max_seq_length=250,
                 spatial_limit=1000):
        super(QuickDrawDataset, self).__init__()

        # Validate arguments before touching the filesystem. This must be a
        # `raise`: returning a value from __init__ is itself a TypeError.
        if mode not in ["train", "test", "valid"]:
            raise ValueError("Only allowed data mode are 'train' and 'test',"
                             " 'valid'.")

        self.max_seq_length = max_seq_length
        self.spatial_limit = spatial_limit

        path = os.path.join(DATA, "sketchrnn_"+dataset)
        # URLs always use forward slashes, so do not build them with os.path
        remote = "%s/%s" % (QuickDrawDataset.BASE_DATA_URL, dataset)

        if not os.path.exists(path):
            # Parenthesized so both literals form one format string
            msg = ("Dataset file %s does not exist, please download"
                   " it from %s" % (path, remote))
            LOG.error(msg)
            raise RuntimeError(msg)

        # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary
        # objects; only point this at files from a trusted source.
        # The context manager closes the archive once the split is read.
        with np.load(path, allow_pickle=True, encoding="latin1") as archive:
            data = archive[mode]

        data = self.purify(data)
        data = self.normalize(data)

        # Length of longest sequence in the dataset (0 if nothing survived
        # the filtering)
        self.nmax = max((len(seq) for seq in data), default=0)

        self.sketches = data

    def __repr__(self):
        return "Dataset with %d sequences of max length %d" % \
            (len(self.sketches), self.nmax)

    def __len__(self):
        """Return the number of sequences kept in the dataset."""
        return len(self.sketches)

    def __getitem__(self, idx):
        """Return the idx-th stroke in 5-D format, padded to length (Nmax+2).
        The first and last element of the sequence are fixed to "start-" and
        "end-of-sequence" token.

        dx, dy, + 3 numbers for one-hot encoding of state:
        1 0 0: pen touching paper till next point
        0 1 0: pen lifted from paper after current point
        0 0 1: drawing has ended, next points (including current will not be
               drawn)

        Returns:
            np.ndarray of shape (nmax+2, 5) and dtype float32.
        """
        sample_data = self.sketches[idx]

        # Allow two extra slots for start/end of sequence tokens
        sample = np.zeros((self.nmax+2, 5), dtype=np.float32)

        n = sample_data.shape[0]

        # An empty sequence has no extent to normalize by; it is returned as
        # a start token followed by end-of-sequence tokens only.
        if n > 0:
            # normalize dx, dy
            deltas = sample_data[:, :2]
            # Absolute coordinates
            positions = deltas.cumsum(0)
            maxi = np.abs(positions).max() + 1e-8
            deltas = deltas / (1.1 * maxi)  # leave some margin on edges

            # fill in dx, dy coordinates
            sample[1:n+1, :2] = deltas

            # on paper indicator: 0 means touching paper in the 3d format,
            # flip it
            sample[1:n+1, 2] = 1 - sample_data[:, 2]

            # off-paper indicator, complement of previous flag
            sample[1:n+1, 3] = 1 - sample[1:n+1, 2]

        # fill with end of sequence tokens for the remainder
        sample[n+1:, 4] = 1

        # Start of sequence token
        sample[0] = [0, 0, 1, 0, 0]

        return sample

    def purify(self, strokes):
        """Remove too long sequences and clamp large gaps.

        Args:
            strokes: iterable of arrays, one per sequence.

        Returns:
            list of float32 arrays: the sequences with at most
            `max_seq_length` rows, with every value clamped to
            [-spatial_limit, spatial_limit].
        """
        data = []
        for seq in strokes:
            if seq.shape[0] > self.max_seq_length:
                continue
            # Limit large spatial gaps
            seq = np.minimum(seq, self.spatial_limit)
            seq = np.maximum(seq, -self.spatial_limit)
            data.append(np.array(seq, dtype=np.float32))
        return data

    def calculate_normalizing_scale_factor(self, strokes):
        """Calculate the normalizing factor explained in appendix of
        sketch-rnn.

        Returns:
            The standard deviation of all (dx, dy) values pooled over every
            sequence, or NaN when there is no point at all.
        """
        # Pool the first two columns of every non-empty sequence; row-major
        # ravel keeps the x0, y0, x1, y1, ... ordering.
        chunks = [np.asarray(seq)[:, :2].ravel()
                  for seq in strokes if len(seq)]
        if not chunks:
            return np.nan
        return np.std(np.concatenate(chunks))

    def normalize(self, strokes):
        """Normalize entire dataset (delta_x, delta_y) by the scaling
        factor.

        The sequences are rescaled in place; the returned list holds the
        same array objects that were passed in.
        """
        scale_factor = self.calculate_normalizing_scale_factor(strokes)
        data = []
        for seq in strokes:
            seq[:, 0:2] /= scale_factor
            data.append(seq)
        return data
class FixedLengthQuickDrawDataset(QuickDrawDataset):
    """A variant of the QuickDraw dataset where the strokes are represented as
    a fixed-length sequence of triplets (dx, dy, opacity), where opacity = 0, 1.

    Args:
        *args: positional arguments forwarded to `QuickDrawDataset`.
        canvas_size(int): forwarded as `canvas_size` to the renderer that
            produces the image returned with every sample.
        **kwargs: keyword arguments forwarded to `QuickDrawDataset`.
    """

    def __init__(self, *args, canvas_size=64, **kwargs):
        super(FixedLengthQuickDrawDataset, self).__init__(*args, **kwargs)
        self.canvas_size = canvas_size

    def __getitem__(self, idx):
        """Return the idx-th sample as a `(strokes, image)` pair.

        `strokes` holds the first three columns of the parent's 5-D sample
        and `image` is the renderer output for those strokes, converted to a
        numpy array with its leading batch axis removed.
        """
        sample = super(FixedLengthQuickDrawDataset, self).__getitem__(idx)

        # We construct a stroke opacity variable from the pen down state,
        # dx, dy remain unchanged
        strokes = sample[:, :3]

        # Render the image; the renderer is fed a batch of one sequence.
        # NOTE(review): presumably this rasterizes the strokes onto a
        # canvas_size x canvas_size canvas -- confirm against `rendering`.
        im = rendering.opacityStroke2diffvg(
            th.from_numpy(strokes).unsqueeze(0), canvas_size=self.canvas_size,
            relative=True, debug=False)
        im = im.squeeze(0).numpy()

        return strokes, im
class MNISTDataset(th.utils.data.Dataset):
    """MNIST digits resized to a square and rescaled to [-1, 1].

    The underlying torchvision dataset is stored under ``mnist`` inside the
    module-level ``DATA`` directory and is requested with ``download=True``.
    Only the image is returned for each item; the label is dropped.

    Args:
        imsize(int): side length, in pixels, the digits are resized to.
        train(bool): forwarded to the torchvision dataset as its `train`
            flag.
    """

    def __init__(self, imsize, train=True):
        super(MNISTDataset, self).__init__()
        storage_dir = os.path.join(DATA, "mnist")
        pipeline = transforms.Compose([
            transforms.Resize((imsize, imsize)),
            transforms.ToTensor(),
        ])
        self.mnist = dset.MNIST(root=storage_dir,
                                train=train,
                                download=True,
                                transform=pipeline)

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.mnist)

    def __getitem__(self, idx):
        """Return the idx-th image, rescaled in place to [-1, 1]."""
        im, _ = self.mnist[idx]

        # Stretch the values so that they span the [0, 1] range
        im.sub_(im.min())
        im.div_(im.max() + 1e-8)

        # Then move them to [-1, 1]
        im.sub_(0.5)
        im.div_(0.5)

        return im