File size: 2,107 Bytes
02c5426 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import os
import json
from PIL import Image
import pickle
import imageio
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from datasets import register
@register('hr_data_loader')
class HRImgLoader(Dataset):
    """Dataset of RGB images whose names are listed in a JSON split file.

    Each item is an ``(image, file_name)`` tuple, where ``image`` is a float
    tensor of shape ``(3, H, W)`` scaled to ``[0, 1]`` (assuming 8-bit
    source images; TODO confirm for 16-bit inputs in ``'bin'`` mode).

    Args:
        root_path: Directory that the split-file entries are relative to.
        split_file: Path to a JSON file; ``json.load(f)[split_key]`` must be
            a sequence of file names relative to ``root_path``.
        split_key: Which entry of ``split_file`` to load.
        first_k: If given, keep only the first ``first_k`` file names.
        cache: How decoded images are kept between accesses:

            * ``'none'``: decode the image file on every access;
              ``file_name`` is the image path.
            * ``'bin'``: decode once with ``imageio`` into a pickled array
              stored under a sibling ``_bin_<root name>`` directory, then
              load that pickle on access; ``file_name`` is the ``.pkl`` path.
            * ``'in_memory'``: decode every image up front and keep the
              tensors; ``file_name`` is the image path.

    Raises:
        ValueError: If ``cache`` is not one of the modes above.
    """

    _CACHE_MODES = ('none', 'bin', 'in_memory')

    def __init__(self, root_path, split_file, split_key, first_k=None,
                 cache='none'):
        if cache not in self._CACHE_MODES:
            # An unknown mode used to yield a silently empty dataset.
            raise ValueError(
                f'cache must be one of {self._CACHE_MODES}, got {cache!r}')
        self.cache = cache
        with open(split_file, 'r') as f:
            filenames = json.load(f)[split_key]
        if first_k is not None:
            filenames = filenames[:first_k]

        bin_root = None
        if cache == 'bin':
            # normpath so that a trailing slash on root_path still gives a
            # sibling "_bin_<name>" directory rather than "<root>/_bin_".
            root = os.path.normpath(root_path)
            bin_root = os.path.join(os.path.dirname(root),
                                    '_bin_' + os.path.basename(root))
            if not os.path.exists(bin_root):
                os.makedirs(bin_root, exist_ok=True)
                print('mkdir', bin_root)

        # self.files: what __getitem__ loads from (path or tensor).
        # self.file_names: the name returned alongside each image.
        self.files = []
        self.file_names = []
        for filename in filenames:
            file = os.path.join(root_path, filename)
            if cache == 'none':
                self.files.append(file)
                self.file_names.append(file)
            elif cache == 'bin':
                bin_file = self._dump_bin(file, filename, bin_root)
                self.files.append(bin_file)
                self.file_names.append(bin_file)
            else:  # 'in_memory'
                self.files.append(self._load_rgb_tensor(file))
                # Report the source path, not the tensor itself.
                self.file_names.append(file)

    @staticmethod
    def _dump_bin(file, filename, bin_root):
        """Pickle the decoded image at ``file`` under ``bin_root`` once; return the .pkl path."""
        # splitext strips only the final extension; split('.')[0] made
        # "a.v1.png" and "a.v2.png" collide on the same "a.pkl".
        bin_file = os.path.join(
            bin_root, os.path.splitext(filename)[0] + '.pkl')
        if not os.path.exists(bin_file):
            # File names may include sub-directories.
            os.makedirs(os.path.dirname(bin_file), exist_ok=True)
            # Write to a temp file and rename, so an interrupted dump never
            # leaves a truncated .pkl that would be reused on the next run.
            tmp_file = bin_file + '.tmp'
            with open(tmp_file, 'wb') as f:
                pickle.dump(imageio.imread(file), f)
            os.replace(tmp_file, bin_file)
            print('dump', bin_file)
        return bin_file

    @staticmethod
    def _load_rgb_tensor(path):
        """Decode ``path`` with PIL, convert to RGB, return a (3, H, W) float tensor."""
        with Image.open(path) as img:
            return transforms.ToTensor()(img.convert('RGB'))

    @staticmethod
    def _array_to_tensor(arr):
        """Convert an (H, W[, C]) uint8 array to a (3, H, W) float tensor in [0, 1]."""
        # Match PIL's convert('RGB') used by the other modes: replicate
        # grayscale (with or without alpha) and drop an alpha channel.
        if arr.ndim == 2:
            arr = arr[:, :, np.newaxis]
        channels = arr.shape[2]
        if channels in (1, 2):
            arr = np.repeat(arr[:, :, :1], 3, axis=2)
        elif channels == 4:
            arr = arr[:, :, :3]
        arr = np.ascontiguousarray(arr.transpose(2, 0, 1))
        return torch.from_numpy(arr).float() / 255

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(image_tensor, file_name)`` for item ``idx``."""
        x = self.files[idx]
        file_name = self.file_names[idx]
        if self.cache == 'none':
            return self._load_rgb_tensor(x), file_name
        if self.cache == 'bin':
            # NOTE(review): pickle.load can execute arbitrary code; only use
            # cache directories written by this class, never untrusted ones.
            with open(x, 'rb') as f:
                arr = pickle.load(f)
            return self._array_to_tensor(arr), file_name
        return x, file_name  # 'in_memory': already a tensor
|