# rcan/data/srdata.py
# Commit da9195c ("init model") by zhengrongzhang; file size 2.53 kB.
import os
import glob
import random
import pickle
from data import common
import numpy as np
import imageio
import torch
import torch.utils.data as data
class SRData(data.Dataset):
    """Paired low-/high-resolution image dataset for super-resolution evaluation.

    Expected on-disk layout (built by ``_set_filesystem`` and ``_scan``)::

        <args.dir_data>/<name>/HR/<stem>.png
        <args.dir_data>/<name>/LR_bicubic/X<s>/<stem>x<s>.png

    where ``<s>`` ranges over ``args.scale``. Only the HR directory is
    globbed; LR paths are derived from the HR file stems and are not checked
    for existence until an item is actually loaded.

    Each item is ``(lr_tensor, hr_tensor, filename)`` for the scale currently
    selected by ``idx_scale`` (see ``set_scale``).
    """

    def __init__(self, args, name='', benchmark=False):
        # Caller configuration. ``split``/``do_eval`` are fixed to evaluation
        # values; ``benchmark`` and ``do_eval`` are stored but not read by any
        # method of this class.
        self.args = args
        self.name = name
        self.split = 'test'
        self.do_eval = True
        self.benchmark = benchmark
        self.input_large = False
        self.scale = args.scale
        self.idx_scale = 0

        # Directories must be resolved first: ``_scan`` reads dir_hr, dir_lr
        # and ext.
        self._set_filesystem(args.dir_data)
        self.images_hr, self.images_lr = self._scan()

    # ---- helpers that locate image files on disk ----------------------

    def _scan(self):
        """Return ``(hr_paths, lr_paths_per_scale)``.

        ``hr_paths`` is the sorted list of HR files ending in ``self.ext[0]``.
        ``lr_paths_per_scale[i]`` holds the LR path for every HR file at
        ``self.scale[i]``, in the same order as ``hr_paths``.
        """
        hr_glob = os.path.join(self.dir_hr, '*' + self.ext[0])
        names_hr = sorted(glob.glob(hr_glob))
        stems = [os.path.splitext(os.path.basename(path))[0] for path in names_hr]
        names_lr = [
            [
                os.path.join(
                    self.dir_lr,
                    'X{}/{}x{}{}'.format(factor, stem, factor, self.ext[1]),
                )
                for stem in stems
            ]
            for factor in self.scale
        ]
        return names_hr, names_lr

    def _set_filesystem(self, dir_data):
        """Derive the dataset root, the HR/LR directories and file extensions."""
        root = os.path.join(dir_data, self.name)
        self.apath = root
        self.dir_hr = os.path.join(root, 'HR')
        self.dir_lr = os.path.join(root, 'LR_bicubic')
        # (HR extension, LR extension)
        self.ext = ('.png', '.png')

    # ---- torch Dataset protocol ---------------------------------------

    def __getitem__(self, idx):
        """Return ``(lr_tensor, hr_tensor, filename)`` for item ``idx``."""
        lr, hr, filename = self._load_file(idx)
        patches = self.get_patch(lr, hr)
        # Channel count (3) and rgb_range (255) are hard-coded here rather
        # than taken from ``self.args``.
        patches = common.set_channel(*patches, n_channels=3)
        tensors = common.np2Tensor(*patches, rgb_range=255)
        return tensors[0], tensors[1], filename

    def __len__(self):
        """Number of HR images found by ``_scan``."""
        return len(self.images_hr)

    def _get_index(self, idx):
        """Map a dataset index to a file-list index (identity in this class)."""
        return idx

    def _load_file(self, idx):
        """Read the HR and LR images for ``idx`` and return ``(lr, hr, filename)``.

        ``filename`` is the HR file's basename with its extension removed.
        The HR image is read before the LR image.
        """
        idx = self._get_index(idx)
        hr_path = self.images_hr[idx]
        lr_path = self.images_lr[self.idx_scale][idx]
        filename = os.path.splitext(os.path.basename(hr_path))[0]
        hr = imageio.imread(hr_path)
        lr = imageio.imread(lr_path)
        return lr, hr, filename

    def get_patch(self, lr, hr):
        """Crop ``hr`` to at most ``scale`` times the height/width of ``lr``.

        The crop keeps the top-left corner. ``lr`` is returned unchanged.
        """
        factor = self.scale[self.idx_scale]
        lr_h, lr_w = lr.shape[:2]
        return lr, hr[:lr_h * factor, :lr_w * factor]

    def set_scale(self, idx_scale):
        """Select which entry of ``self.scale`` subsequent items use.

        When ``input_large`` is set, the requested index is ignored and an
        index is drawn uniformly at random instead.
        """
        if self.input_large:
            self.idx_scale = random.randint(0, len(self.scale) - 1)
        else:
            self.idx_scale = idx_scale