# Traly's picture
# init
# 193c713
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from .hparams import hparams
from .indexed_datasets import IndexedDataset
from .matlab_resize import imresize
class SRDataSet(Dataset):
    """Super-resolution dataset read from an ``IndexedDataset`` binary split.

    Each sample contains:

    * the stored high-resolution image,
    * a low-resolution copy obtained by downscaling it by ``hparams['sr_scale']``,
    * that low-resolution copy upscaled back by the same factor.

    All three are normalized with ``to_tensor_norm``.

    Hyper-parameters read from ``hparams``: ``binary_data_dir``,
    ``data_interp`` (``'bilinear'`` or ``'bicubic'``), ``data_augmentation``,
    ``sr_scale`` and, for the ``'valid'`` split, ``eval_batch_size`` and
    ``valid_steps``.
    """

    def __init__(self, prefix='train'):
        """Prepare the split stored at ``{hparams['binary_data_dir']}/{prefix}``.

        Args:
            prefix: name of the split to load (e.g. ``'train'`` or ``'valid'``).

        Raises:
            ValueError: if ``hparams['data_interp']`` is neither
                ``'bilinear'`` nor ``'bicubic'``.
        """
        self.hparams = hparams
        self.data_dir = hparams['binary_data_dir']
        self.prefix = prefix
        # Opened here only to count the stored items. The reader actually used
        # for sample access is created lazily in _get_item (presumably so each
        # DataLoader worker opens its own handle -- confirm).
        self.len = len(IndexedDataset(f'{self.data_dir}/{self.prefix}'))
        # ToTensor maps uint8 arrays to [0, 1] (divides by 255) and leaves
        # float arrays unscaled. Normalize with mean = std = 0.5 on 3 channels
        # then maps [0, 1] to [-1, 1].
        self.to_tensor_norm = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        # Explicit check rather than `assert` so it still runs under `python -O`.
        if hparams['data_interp'] not in ('bilinear', 'bicubic'):
            raise ValueError(
                "hparams['data_interp'] must be 'bilinear' or 'bicubic', "
                f"got {hparams['data_interp']!r}"
            )
        # Not read anywhere in this class; presumably consumed by subclasses
        # that override pre_process -- verify.
        self.data_augmentation = hparams['data_augmentation']
        self.indexed_ds = None
        if self.prefix == 'valid':
            # Validation uses a fixed sample budget, but never more samples
            # than are stored: indices >= the stored count are out of range.
            self.len = min(
                self.len, hparams['eval_batch_size'] * hparams['valid_steps']
            )

    def _get_item(self, index):
        """Return the raw stored record at ``index``, opening the file on first use."""
        if self.indexed_ds is None:
            self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}')
        return self.indexed_ds[index]

    def __getitem__(self, index):
        """Build the HR / LR / upsampled-LR sample for ``index``.

        Returns:
            dict with keys ``'img_hr'``, ``'img_lr'`` and ``'img_lr_up'``
            (float32 tensors produced by ``to_tensor_norm``), plus
            ``'item_name'`` copied from the stored record.
        """
        item = self._get_item(index)
        hp = self.hparams
        img_hr = Image.fromarray(np.uint8(item['img']))
        img_hr = self.pre_process(img_hr)  # PIL
        img_hr = np.asarray(img_hr)  # np.uint8 [H, W, C]
        img_lr = imresize(
            img_hr, 1 / hp['sr_scale'], method=hp['data_interp']
        )  # np.uint8 [H, W, C]
        # img_lr_up is a float array, so ToTensor will not rescale it. Divide by
        # 255 (not 256) so it lands on the same [0, 1] scale that ToTensor gives
        # the uint8 img_hr / img_lr; 256 made it systematically slightly darker.
        img_lr_up = imresize(img_lr / 255, hp['sr_scale'])  # np.float [H, W, C]
        img_hr, img_lr, img_lr_up = [
            self.to_tensor_norm(x).float() for x in (img_hr, img_lr, img_lr_up)
        ]
        return {
            'img_hr': img_hr,
            'img_lr': img_lr,
            'img_lr_up': img_lr_up,
            'item_name': item['item_name'],
        }

    def pre_process(self, img_hr):
        """Hook applied to the HR ``PIL.Image`` before downscaling; identity here."""
        return img_hr

    def __len__(self):
        """Return the number of samples exposed for this split."""
        return self.len