|
import os |
|
import json |
|
from PIL import Image |
|
|
|
import pickle |
|
import imageio |
|
import numpy as np |
|
import torch |
|
from torch.utils.data import Dataset |
|
from torchvision import transforms |
|
|
|
from datasets import register |
|
|
|
|
|
@register('hr_data_loader') |
|
class HRImgLoader(Dataset):
    """Dataset of high-resolution RGB images listed in a JSON split file.

    ``split_file`` is a JSON file whose ``split_key`` entry is a list of
    image names relative to ``root_path``. Each item is a pair
    ``(image, file_name)``:

    * ``image`` is a float tensor of shape (3, H, W) scaled by 1/255.
    * ``file_name`` is the path of the source image
      (``os.path.join(root_path, name)``), in every cache mode.

    Args:
        root_path: Directory the split-file names are relative to.
        split_file: Path to the JSON split file.
        split_key: Key of the list of file names inside the JSON object.
        first_k: If given, keep only the first ``first_k`` names.
        cache: How decoded images are kept between accesses:
            ``'none'``      -- decode from disk on every access;
            ``'bin'``       -- decode once with imageio and pickle the raw
                               array under a sibling ``_bin_<root>`` directory;
            ``'in_memory'`` -- decode everything up front and keep tensors.

    Raises:
        ValueError: If ``cache`` is not one of the modes above.
    """

    _CACHE_MODES = ('none', 'bin', 'in_memory')

    def __init__(self, root_path, split_file, split_key, first_k=None, cache='none'):
        if cache not in self._CACHE_MODES:
            raise ValueError(
                f"cache must be one of {self._CACHE_MODES}, got {cache!r}")
        self.cache = cache

        with open(split_file, 'r') as f:
            filenames = json.load(f)[split_key]
        if first_k is not None:
            filenames = filenames[:first_k]

        bin_root = None
        if cache == 'bin':
            # Normalise first so a trailing slash on root_path does not give an
            # empty basename (which would put the cache inside the data dir).
            root = os.path.normpath(root_path)
            bin_root = os.path.join(os.path.dirname(root),
                                    '_bin_' + os.path.basename(root))
            if not os.path.exists(bin_root):
                os.makedirs(bin_root, exist_ok=True)
                print('mkdir', bin_root)

        # file_names[i] is always the source image path; files[i] holds what
        # __getitem__ needs for the current cache mode (path, pickle path, or tensor).
        self.file_names = []
        self.files = []
        for filename in filenames:
            file = os.path.join(root_path, filename)
            self.file_names.append(file)
            if cache == 'none':
                self.files.append(file)
            elif cache == 'bin':
                self.files.append(self._ensure_bin_cache(file, filename, bin_root))
            else:  # 'in_memory'
                self.files.append(transforms.ToTensor()(self._load_rgb(file)))

    @staticmethod
    def _ensure_bin_cache(file, filename, bin_root):
        """Return the pickle path for ``filename`` under ``bin_root``, dumping it if missing."""
        # splitext keeps dots inside the stem ("a.1.png" -> "a.1"), so distinct
        # images cannot collide on one cache file as they did with split('.').
        bin_file = os.path.join(bin_root, os.path.splitext(filename)[0] + '.pkl')
        if not os.path.exists(bin_file):
            # Split-file names may contain sub-directories.
            os.makedirs(os.path.dirname(bin_file), exist_ok=True)
            with open(bin_file, 'wb') as f:
                pickle.dump(imageio.imread(file), f)
            print('dump', bin_file)
        return bin_file

    @staticmethod
    def _load_rgb(path):
        """Decode the image at ``path`` as an RGB PIL image and close the file handle."""
        with Image.open(path) as img:
            return img.convert('RGB')

    @staticmethod
    def _array_to_tensor(arr):
        """Convert an (H, W) or (H, W, C) image array to a float (3, H, W) tensor scaled by 1/255."""
        if arr.ndim == 2:
            arr = arr[:, :, None]
        channels = arr.shape[2]
        if channels < 3:
            # Grayscale (optionally with alpha): replicate the first channel,
            # matching PIL's L/LA -> RGB conversion used by the other modes.
            arr = np.repeat(arr[:, :, :1], 3, axis=2)
        elif channels > 3:
            # Drop extra channels (alpha), matching PIL's RGBA -> RGB conversion.
            # NOTE(review): a 4-channel CMYK array would be mis-handled here.
            arr = arr[:, :, :3]
        arr = np.ascontiguousarray(arr.transpose(2, 0, 1))
        return torch.from_numpy(arr).float() / 255

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(image_tensor, source_image_path)`` for item ``idx``."""
        x = self.files[idx]
        file_name = self.file_names[idx]

        if self.cache == 'none':
            return transforms.ToTensor()(self._load_rgb(x)), file_name

        if self.cache == 'bin':
            # Cache files are written by this class; never point this at
            # untrusted pickles, since unpickling can execute arbitrary code.
            with open(x, 'rb') as f:
                arr = pickle.load(f)
            return self._array_to_tensor(arr), file_name

        # 'in_memory': x is the tensor decoded in __init__.
        return x, file_name
|
|
|
|
|
|