Spaces:
Runtime error
Runtime error
| # --------------------------------------------------------------- | |
| # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # This work is licensed under the NVIDIA Source Code License | |
| # for Denoising Diffusion GAN. To view a copy of this license, see the LICENSE file. | |
| # --------------------------------------------------------------- | |
| import torch.utils.data as data | |
| import numpy as np | |
| import lmdb | |
| import os | |
| import io | |
| from PIL import Image | |
def num_samples(dataset, train):
    """Return the hard-coded number of samples in one split of a dataset.

    Args:
        dataset: Dataset name; only ``'celeba'`` is recognised.
        train: Truthy selects the training split, falsy the validation split.

    Returns:
        27000 for the CelebA training split, 3000 for its validation split.

    Raises:
        NotImplementedError: if ``dataset`` is not a recognised name.
    """
    if dataset != 'celeba':
        raise NotImplementedError('dataset %s is unknown' % dataset)
    if train:
        return 27000
    return 3000
class LMDBDataset(data.Dataset):
    """Image dataset backed by an LMDB environment, one record per index.

    Records are looked up under the key ``str(index).encode()``. Each item is
    returned as ``(img, target)``:

    - ``img`` is a PIL RGB image, or whatever ``transform`` returns.
    - ``target`` is the constant placeholder ``[0]``. No label is read from
      the database (presumably the data is unlabeled; confirm with callers).

    Args:
        root: Directory containing ``train.lmdb`` and/or ``validation.lmdb``.
        name: Dataset name passed to ``num_samples`` to compute ``__len__``.
        train: Open ``train.lmdb`` when truthy, ``validation.lmdb`` otherwise.
        transform: Optional callable applied to the PIL image.
        is_encoded: If True, each record is an encoded image file decoded with
            ``PIL.Image.open``. If False, each record is raw ``uint8`` pixel
            data of a square RGB image (``size * size * 3`` bytes).
    """

    def __init__(self, root, name='', train=True, transform=None, is_encoded=False):
        self.train = train
        self.name = name
        self.transform = transform
        split_file = 'train.lmdb' if self.train else 'validation.lmdb'
        lmdb_path = os.path.join(root, split_file)
        # Read-only, lock-free handle. Per the py-lmdb docs, max_readers bounds
        # the number of simultaneous read transactions, so 1 assumes
        # single-threaded access per handle.
        # NOTE(review): confirm this holds for how DataLoader workers use it.
        self.data_lmdb = lmdb.open(lmdb_path, readonly=True, max_readers=1,
                                   lock=False, readahead=False, meminit=False)
        self.is_encoded = is_encoded

    def __getitem__(self, index):
        """Return ``(img, target)`` for the record stored under ``str(index)``.

        Raises:
            IndexError: if the database holds no record for ``index``. This
                also lets plain iteration over the dataset stop cleanly.
        """
        target = [0]  # constant placeholder target, identical for every item
        with self.data_lmdb.begin(write=False, buffers=True) as txn:
            # With buffers=True, `raw` points into the LMDB map and is only
            # valid while this transaction is open. Everything below therefore
            # copies out of it before the `with` block exits.
            raw = txn.get(str(index).encode())
            if raw is None:
                raise IndexError('no record for index %r in LMDB dataset' % (index,))
            if self.is_encoded:
                # BytesIO copies the buffer; convert() forces a full decode.
                img = Image.open(io.BytesIO(raw))
                img = img.convert('RGB')
            else:
                # np.array (unlike np.asarray) copies, so the pixels never
                # alias the transaction-scoped buffer.
                pixels = np.array(raw, dtype=np.uint8)
                # Assume the data is a square RGB image:
                # len(pixels) == size * size * 3.
                size = int(np.sqrt(len(pixels) / 3))
                pixels = np.reshape(pixels, (size, size, 3))
                # A (H, W, 3) uint8 array is inferred as RGB. The `mode`
                # argument of fromarray is deprecated in recent Pillow.
                img = Image.fromarray(pixels)
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        """Return the hard-coded split size reported by ``num_samples``."""
        return num_samples(self.name, self.train)