import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

from .hparams import hparams
from .indexed_datasets import IndexedDataset
from .matlab_resize import imresize


class SRDataSet(Dataset):
    """Super-resolution dataset backed by a binarized ``IndexedDataset``.

    Each item is built from a stored high-resolution (HR) image and yields:
      * the HR image,
      * a low-resolution (LR) copy downscaled by ``hparams['sr_scale']``,
      * the LR copy upscaled back by the same factor (``img_lr_up``).

    All three are float32 CHW tensors normalized with mean 0.5 / std 0.5,
    i.e. [0, 1] inputs map to [-1, 1].

    Reads these keys from the global ``hparams`` dict: ``binary_data_dir``,
    ``data_interp``, ``data_augmentation``, ``sr_scale``, and for the
    ``'valid'`` split also ``eval_batch_size`` and ``valid_steps``.
    """

    def __init__(self, prefix='train'):
        """Set up the dataset for one split.

        Args:
            prefix: split name; items are read from
                ``f'{hparams["binary_data_dir"]}/{prefix}'``.

        Raises:
            ValueError: if ``hparams['data_interp']`` is neither
                ``'bilinear'`` nor ``'bicubic'``.
        """
        self.hparams = hparams
        self.data_dir = hparams['binary_data_dir']
        self.prefix = prefix
        # Opened here only to learn the split length; the handle used for
        # reading items is opened lazily in _get_item.
        self.len = len(IndexedDataset(f'{self.data_dir}/{self.prefix}'))
        # ToTensor: uint8 HWC in [0, 255] -> float CHW in [0, 1]
        # (float arrays are converted without rescaling).
        # Normalize(0.5, 0.5): [0, 1] -> [-1, 1].
        self.to_tensor_norm = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if hparams['data_interp'] not in ('bilinear', 'bicubic'):
            raise ValueError(
                f"hparams['data_interp'] must be 'bilinear' or 'bicubic', "
                f"got {hparams['data_interp']!r}")
        # Not used by this class itself; presumably consumed by a subclass's
        # pre_process override — confirm.
        self.data_augmentation = hparams['data_augmentation']
        self.indexed_ds = None
        if self.prefix == 'valid':
            # Validation evaluates a fixed budget of items. Clamp it to the
            # split size so __getitem__ is never asked for an index past the
            # end of the split.
            self.len = min(self.len,
                           hparams['eval_batch_size'] * hparams['valid_steps'])

    def _get_item(self, index):
        """Return the raw stored record at ``index``.

        The ``IndexedDataset`` is opened on first access rather than in
        ``__init__``. Presumably this is so each DataLoader worker process
        opens its own handle instead of inheriting one from the parent —
        confirm.
        """
        if self.indexed_ds is None:
            self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}')
        return self.indexed_ds[index]

    def __getitem__(self, index):
        """Build the HR / LR / upscaled-LR sample for ``index``.

        Returns:
            dict with keys:
              ``'img_hr'``, ``'img_lr'``, ``'img_lr_up'``: float32 CHW
                  tensors, normalized with mean/std 0.5;
              ``'item_name'``: copied unchanged from the stored record.
        """
        item = self._get_item(index)
        hparams = self.hparams
        sr_scale = hparams['sr_scale']

        img_hr = Image.fromarray(np.uint8(item['img']))
        img_hr = self.pre_process(img_hr)  # PIL
        img_hr = np.asarray(img_hr)  # np.uint8 [H, W, C]
        img_lr = imresize(img_hr, 1 / sr_scale,
                          method=hparams['data_interp'])  # np.uint8 [H, W, C]
        # Divide by 255, not 256. ToTensor scales the uint8 HR/LR arrays by
        # 1/255 but leaves float arrays as-is, so this keeps img_lr_up on
        # the same [0, 1] scale as the other two images.
        # The upscale uses imresize's default method, not data_interp.
        img_lr_up = imresize(img_lr / 255, sr_scale)  # np.float [H, W, C]

        # `.float()` casts the float64 tensor produced from img_lr_up to float32.
        img_hr, img_lr, img_lr_up = [
            self.to_tensor_norm(x).float() for x in (img_hr, img_lr, img_lr_up)
        ]
        return {
            'img_hr': img_hr,
            'img_lr': img_lr,
            'img_lr_up': img_lr_up,
            'item_name': item['item_name']
        }

    def pre_process(self, img_hr):
        """Transform the HR PIL image before downscaling.

        Identity by default; presumably overridden by subclasses (e.g. for
        cropping or augmentation) — confirm.
        """
        return img_hr

    def __len__(self):
        """Number of items exposed for this split."""
        return self.len