FunSR / datasets /datasets_loader.py
KyanChen's picture
add
02c5426
import os
import json
from PIL import Image
import pickle
import imageio
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from datasets import register
@register('hr_data_loader')
class HRImgLoader(Dataset):
    """Dataset of high-resolution RGB images listed in a JSON split file.

    Each item is ``(image, file_name)`` where ``image`` is a float tensor of
    shape ``(3, H, W)`` with values in ``[0, 1]``.

    Args:
        root_path: Directory that the split-file entries are relative to.
        split_file: Path to a JSON file whose top-level object maps split
            keys to lists of (relative) image file names.
        split_key: Key in ``split_file`` selecting the list to load.
        first_k: If given, keep only the first ``first_k`` file names.
        cache: Loading strategy:
            * ``'none'``      -- decode the image from disk on every access;
              ``file_name`` is the image path.
            * ``'bin'``       -- decode once, pickle the RGB uint8 pixels into
              a sibling ``_bin_<basename(root_path)>`` directory, and
              unpickle on access; ``file_name`` is the ``.pkl`` path.
            * ``'in_memory'`` -- decode every image up front and keep the
              tensors in RAM; ``file_name`` is the image path.

    Raises:
        ValueError: If ``cache`` is not one of the values above.
    """

    _CACHE_MODES = ('none', 'bin', 'in_memory')

    def __init__(self, root_path, split_file, split_key, first_k=None, cache='none'):
        if cache not in self._CACHE_MODES:
            # An unknown mode used to yield an empty dataset silently.
            raise ValueError(
                f"cache must be one of {self._CACHE_MODES}, got {cache!r}")
        self.cache = cache

        with open(split_file, 'r', encoding='utf-8') as f:
            filenames = json.load(f)[split_key]
        if first_k is not None:
            filenames = filenames[:first_k]

        if cache == 'bin':
            # Loop-invariant: the cache directory sits next to root_path.
            bin_root = os.path.join(os.path.dirname(root_path),
                                    '_bin_' + os.path.basename(root_path))
            if not os.path.exists(bin_root):
                os.makedirs(bin_root, exist_ok=True)
                print('mkdir', bin_root)

        self.files = []       # per-sample payload: image path, .pkl path or tensor
        self.file_names = []  # per-sample name returned by __getitem__
        for filename in filenames:
            file = os.path.join(root_path, filename)
            if cache == 'none':
                self.files.append(file)
                self.file_names.append(file)
            elif cache == 'bin':
                bin_file = self._dump_bin(file, bin_root, filename)
                self.files.append(bin_file)
                self.file_names.append(bin_file)
            else:  # 'in_memory'
                self.files.append(self._load_rgb_tensor(file))
                # self.files holds the tensor here, so the name must be
                # tracked separately (it must never be the tensor itself).
                self.file_names.append(file)

    @staticmethod
    def _load_rgb_tensor(path):
        """Decode ``path`` as RGB and return a ``(3, H, W)`` float tensor in [0, 1]."""
        # The context manager closes the file handle right away instead of
        # leaving it to the garbage collector (matters with many workers).
        with Image.open(path) as img:
            return transforms.ToTensor()(img.convert('RGB'))

    @staticmethod
    def _dump_bin(file, bin_root, filename):
        """Pickle the RGB pixels of ``file`` under ``bin_root`` unless already cached; return the .pkl path."""
        # splitext strips only the final extension; split('.')[0] mapped
        # 'a.b.png' and 'a.c.png' (or './x.png') onto the same cache file.
        bin_file = os.path.join(bin_root, os.path.splitext(filename)[0] + '.pkl')
        if not os.path.exists(bin_file):
            # filename may contain sub-directories that do not exist yet.
            os.makedirs(os.path.dirname(bin_file), exist_ok=True)
            # Convert to RGB so the cached array is (H, W, 3) uint8, matching
            # what the 'none' and 'in_memory' modes produce.
            with Image.open(file) as img:
                pixels = np.array(img.convert('RGB'))
            # Write to a temp file and rename so an interrupted dump never
            # leaves a truncated .pkl that later runs would try to load.
            tmp_file = bin_file + '.tmp'
            with open(tmp_file, 'wb') as f:
                pickle.dump(pixels, f)
            os.replace(tmp_file, bin_file)
            print('dump', bin_file)
        return bin_file

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(image, file_name)`` for sample ``idx``; see the class docstring."""
        x = self.files[idx]
        file_name = self.file_names[idx]
        if self.cache == 'none':
            return self._load_rgb_tensor(x), file_name
        if self.cache == 'bin':
            # NOTE(security): pickle.load can execute arbitrary code; only
            # point this at cache files written by this class.
            with open(x, 'rb') as f:
                pixels = pickle.load(f)
            # A cache file may hold raw imageio output (grayscale (H, W) or
            # with an alpha channel); normalize it to three channels.
            if pixels.ndim == 2:
                pixels = np.stack([pixels] * 3, axis=-1)
            pixels = pixels[..., :3]
            # HWC uint8 -> contiguous CHW float in [0, 1].
            pixels = np.ascontiguousarray(pixels.transpose(2, 0, 1))
            return torch.from_numpy(pixels).float() / 255, file_name
        # 'in_memory': the decoded tensor is stored directly.
        return x, file_name