# FunSR / datasets / rs_super_warp.py
# Last change: commit 02c5426 ("add") by KyanChen
import functools
import random
import math
from PIL import Image
import numpy as np
import torch
from einops import rearrange
from torch.utils.data import Dataset
from torchvision import transforms
from datasets import register
from utils import to_pixel_samples, to_coordinates
def resize_fn(img, size):
    """Resize an image tensor with bicubic interpolation through PIL.

    The tensor is converted to a PIL image, resized with
    ``transforms.Resize`` and converted back with ``ToTensor``. For float
    input the PIL round-trip presumably quantises values to 8 bits per
    channel -- confirm if precision matters.

    Args:
        img: Image tensor accepted by ``transforms.ToPILImage``.
        size: Target size as accepted by ``transforms.Resize``. An int sets
            the shorter side and keeps the aspect ratio; an ``(h, w)`` pair
            sets both sides.

    Returns:
        The resized image as the tensor produced by ``transforms.ToTensor``.
    """
    # InterpolationMode is the documented type for `interpolation`. PIL's
    # integer constant is a legacy path that some torchvision releases warn
    # about. For PIL input both select the same bicubic resampling filter.
    resize = transforms.Resize(
        size, interpolation=transforms.InterpolationMode.BICUBIC)
    return transforms.ToTensor()(resize(transforms.ToPILImage()(img)))
@register('rs_sr_warp')
class RSSRWarp(Dataset):
    """Wrap a paired (LR, HR) dataset for arbitrary-scale super-resolution.

    For each item the HR image is rescaled with ``resize_fn``:

    * in training, to a size drawn uniformly from ``[size_min, size_max]``;
    * in ``val_mode``, to ``size_max``.

    LR and HR are then optionally flipped together. Finally the HR image is
    flattened into per-pixel coordinate / RGB query pairs.

    Args:
        dataset: Indexable dataset whose items unpack as ``(img_lr, img_hr)``.
            HR is flattened as ``C H W`` below; LR is presumably also
            ``(C, H, W)`` -- TODO confirm.
        size_min: Smallest HR size sampled in training. Required unless
            ``val_mode`` is true.
        size_max: Largest HR size sampled in training, and the fixed HR size
            in ``val_mode``. Defaults to ``size_min``.
        augment: If true (and not ``val_mode``), independently flip LR and HR
            along the last axis and along the second-to-last axis, each with
            probability 0.5.
        gt_resize: If given, resize the HR image once more to this size after
            augmentation.
        sample_q: If given, keep only this many randomly chosen
            (coord, rgb) pairs, sampled without replacement.
        val_mode: Use the fixed HR size ``size_max`` and disable augmentation.

    Raises:
        ValueError: If a size required by the selected mode is missing.
    """

    def __init__(self, dataset, size_min=None, size_max=None,
                 augment=False, gt_resize=None, sample_q=None, val_mode=False):
        if size_max is None:
            size_max = size_min
        # Fail fast here. Otherwise a missing size only surfaces as an opaque
        # error inside __getitem__ on the first sample.
        if size_max is None:
            raise ValueError("size_max (or size_min) must be provided")
        if size_min is None and not val_mode:
            raise ValueError("size_min must be provided when val_mode is False")
        self.dataset = dataset
        self.size_min = size_min
        self.size_max = size_max
        self.augment = augment
        self.gt_resize = gt_resize
        self.sample_q = sample_q
        self.val_mode = val_mode

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``{'inp': img_lr, 'coord': hr_coord, 'gt': hr_rgb}`` for ``idx``."""
        img_lr, img_hr = self.dataset[idx]

        if self.val_mode:
            # Evaluation uses a fixed HR size.
            img_hr = resize_fn(img_hr, self.size_max)
        else:
            # Training draws the HR size uniformly from [size_min, size_max].
            p = random.random()
            w_hr = round(self.size_min + (self.size_max - self.size_min) * p)
            img_hr = resize_fn(img_hr, w_hr)

        if self.augment and not self.val_mode:
            # Apply the same flips to LR and HR so the pair stays aligned.
            if random.random() < 0.5:
                img_lr = img_lr.flip(-1)
                img_hr = img_hr.flip(-1)
            if random.random() < 0.5:
                img_lr = img_lr.flip(-2)
                img_hr = img_hr.flip(-2)

        if self.gt_resize is not None:
            img_hr = resize_fn(img_hr, self.gt_resize)

        # One coordinate per HR pixel. This assumes to_coordinates uses the
        # same row-major order as the '(H W)' flattening below -- TODO confirm.
        hr_coord = to_coordinates(size=img_hr.shape[-2:], return_map=False)
        hr_rgb = rearrange(img_hr, 'C H W -> (H W) C')

        if self.sample_q is not None:
            # Random subset of query pixels, drawn without replacement.
            # numpy raises ValueError if sample_q exceeds the pixel count.
            sample_lst = np.random.choice(
                len(hr_coord), self.sample_q, replace=False)
            hr_coord = hr_coord[sample_lst]
            hr_rgb = hr_rgb[sample_lst]

        return {
            'inp': img_lr,
            'coord': hr_coord,
            'gt': hr_rgb,
        }