FunSR / datasets /image_folder.py
KyanChen's picture
add
02c5426
raw
history blame
2.66 kB
import os
import json
from PIL import Image
import pickle
import imageio
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from datasets import register
@register('image-folder')
class ImageFolder(Dataset):
def __init__(self, root_path, split_file=None, split_key=None, first_k=None,
repeat=1, cache='none'):
self.repeat = repeat
self.cache = cache
if split_file is None:
filenames = sorted(os.listdir(root_path))
else:
with open(split_file, 'r') as f:
filenames = json.load(f)[split_key]
if first_k is not None:
filenames = filenames[:first_k]
self.files = []
for filename in filenames:
file = os.path.join(root_path, filename)
if cache == 'none':
self.files.append(file)
elif cache == 'bin':
bin_root = os.path.join(os.path.dirname(root_path),
'_bin_' + os.path.basename(root_path))
if not os.path.exists(bin_root):
os.mkdir(bin_root)
print('mkdir', bin_root)
bin_file = os.path.join(
bin_root, filename.split('.')[0] + '.pkl')
if not os.path.exists(bin_file):
with open(bin_file, 'wb') as f:
pickle.dump(imageio.imread(file), f)
print('dump', bin_file)
self.files.append(bin_file)
elif cache == 'in_memory':
self.files.append(transforms.ToTensor()(
Image.open(file).convert('RGB')))
def __len__(self):
return len(self.files) * self.repeat
def __getitem__(self, idx):
x = self.files[idx % len(self.files)]
if self.cache == 'none':
return transforms.ToTensor()(Image.open(x).convert('RGB'))
elif self.cache == 'bin':
with open(x, 'rb') as f:
x = pickle.load(f)
x = np.ascontiguousarray(x.transpose(2, 0, 1))
x = torch.from_numpy(x).float() / 255
return x
elif self.cache == 'in_memory':
return x
@register('paired-image-folders')
class PairedImageFolders(Dataset):
def __init__(self, root_path_1, root_path_2, **kwargs):
self.dataset_1 = ImageFolder(root_path_1, **kwargs)
self.dataset_2 = ImageFolder(root_path_2, **kwargs)
def __len__(self):
return len(self.dataset_1)
def __getitem__(self, idx):
return self.dataset_1[idx], self.dataset_2[idx]