import os
import json
import pickle

import imageio
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

from datasets import register


def _load_rgb_tensor(path):
    """Decode an image file with PIL, convert it to RGB and return a tensor.

    The PIL file handle is closed before returning; ``convert('RGB')`` forces
    the pixel data to be read while the file is still open.
    """
    with Image.open(path) as img:
        return transforms.ToTensor()(img.convert('RGB'))


def _array_to_rgb_tensor(arr):
    """Convert an image array read by ``imageio`` into a (3, H, W) float tensor.

    The array is normalised to three channels so that the 'bin' cache yields
    the same channel layout as the PIL ``convert('RGB')`` path: 2-D (grayscale)
    arrays and arrays with fewer than 3 channels are replicated from their
    first channel; extra channels (e.g. alpha) are dropped.

    NOTE(review): values are divided by 255, which assumes 8-bit image data
    (the original implementation made the same assumption) — confirm for
    16-bit inputs.
    """
    if arr.ndim == 2:
        arr = arr[:, :, None]
    if arr.shape[2] < 3:
        arr = np.repeat(arr[:, :, :1], 3, axis=2)
    else:
        arr = arr[:, :, :3]
    # HWC -> CHW; transpose returns a strided view, so make it contiguous
    # before handing the buffer to torch.
    chw = np.ascontiguousarray(arr.transpose(2, 0, 1))
    return torch.from_numpy(chw).float() / 255


@register('hr_data_loader')
class HRImgLoader(Dataset):
    """Dataset of images listed under one key of a JSON split file.

    Args:
        root_path: directory the listed filenames are relative to.
        split_file: path to a JSON file; ``json.load(f)[split_key]`` must be
            the list of filenames to use.
        split_key: key of the split inside ``split_file``.
        first_k: if not None, only the first ``first_k`` filenames are kept.
        cache: how images are stored between accesses:
            - ``'none'``: keep paths, decode with PIL on every access.
            - ``'bin'``: decode once with ``imageio`` and pickle the array into
              a sibling directory ``_bin_<basename(root_path)>``; unpickle on
              access.
            - ``'in_memory'``: decode every image up front and keep the tensors.

    Raises:
        ValueError: if ``cache`` is not one of the modes above.

    ``__getitem__`` returns ``(image, file_name)`` where ``image`` is a
    (3, H, W) float tensor and ``file_name`` is the source image path
    (``'none'`` / ``'in_memory'``) or the pickle path (``'bin'``).
    """

    _CACHE_MODES = ('none', 'bin', 'in_memory')

    def __init__(self, root_path, split_file, split_key, first_k=None,
                 cache='none'):
        if cache not in self._CACHE_MODES:
            raise ValueError(
                f"cache must be one of {self._CACHE_MODES}, got {cache!r}")
        self.cache = cache

        with open(split_file, 'r') as f:
            filenames = json.load(f)[split_key]
        if first_k is not None:
            filenames = filenames[:first_k]

        # self.files holds what __getitem__ loads from (path, pickle path or
        # tensor); self.file_names holds the name returned alongside it.
        self.files = []
        self.file_names = []

        bin_root = os.path.join(os.path.dirname(root_path),
                                '_bin_' + os.path.basename(root_path))

        for filename in filenames:
            file = os.path.join(root_path, filename)
            if cache == 'none':
                self.files.append(file)
                self.file_names.append(file)
            elif cache == 'bin':
                bin_file = self._ensure_bin_cache(file, filename, bin_root)
                self.files.append(bin_file)
                self.file_names.append(bin_file)
            else:  # 'in_memory'
                self.files.append(_load_rgb_tensor(file))
                self.file_names.append(file)

    @staticmethod
    def _ensure_bin_cache(file, filename, bin_root):
        """Create the pickle cache entry for ``file`` if missing; return its path."""
        # splitext only strips the final extension, so names with extra dots
        # (or a leading './') do not collide on the same cache file.
        bin_file = os.path.join(bin_root,
                                os.path.splitext(filename)[0] + '.pkl')
        if os.path.exists(bin_file):
            return bin_file

        # filenames may contain subdirectories; create the full parent path.
        bin_dir = os.path.dirname(bin_file)
        if not os.path.exists(bin_dir):
            os.makedirs(bin_dir, exist_ok=True)
            print('mkdir', bin_dir)

        # Decode before touching the cache file, and write via a temp file +
        # atomic rename, so a failed decode or an interrupted run never leaves
        # a truncated .pkl that later runs would trust.
        arr = imageio.imread(file)
        tmp_file = bin_file + '.tmp'
        with open(tmp_file, 'wb') as f:
            pickle.dump(arr, f)
        os.replace(tmp_file, bin_file)
        print('dump', bin_file)
        return bin_file

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        x = self.files[idx]
        file_name = self.file_names[idx]

        if self.cache == 'none':
            return _load_rgb_tensor(x), file_name
        if self.cache == 'bin':
            # NOTE(security): pickle.load executes arbitrary code from the
            # file; only safe because these files are written by this class.
            with open(x, 'rb') as f:
                arr = pickle.load(f)
            return _array_to_rgb_tensor(arr), file_name
        # 'in_memory': x is already the decoded tensor.
        return x, file_name