|
import functools |
|
import random |
|
import math |
|
from PIL import Image |
|
import numpy as np |
|
import torch |
|
from einops import rearrange |
|
from torch.utils.data import Dataset |
|
from torchvision import transforms |
|
from datasets import register |
|
from utils import to_pixel_samples, to_coordinates |
|
|
|
|
|
def resize_fn(img, size):
    """Resize an image tensor with bicubic interpolation via a PIL round trip.

    The tensor is converted to a PIL image (``transforms.ToPILImage``),
    resized with ``transforms.Resize`` and converted back to a float tensor
    (``transforms.ToTensor``). ``ToPILImage`` converts float tensors to
    8-bit pixels, so the result is quantised.

    Args:
        img: image tensor accepted by ``transforms.ToPILImage`` (presumably
            CHW with values in [0, 1] -- TODO confirm against callers).
        size: target size forwarded to ``transforms.Resize``. An int resizes
            the shorter edge to ``size`` while keeping the aspect ratio; an
            ``(h, w)`` pair resizes to exactly that shape.

    Returns:
        The resized image as a tensor produced by ``transforms.ToTensor``.
    """
    # Use the InterpolationMode enum: passing PIL's integer constant
    # (Image.BICUBIC) is deprecated by torchvision and triggers a warning.
    resize = transforms.Resize(
        size, interpolation=transforms.InterpolationMode.BICUBIC)
    return transforms.ToTensor()(resize(transforms.ToPILImage()(img)))
|
|
|
|
|
@register('rs_sr_warp')
class RSSRWarp(Dataset):
    """Wrap a paired dataset for arbitrary-scale super-resolution.

    Each item of the wrapped ``dataset`` is unpacked as ``(img_lr, img_hr)``.
    The items are presumably image tensors in C, H, W layout -- TODO confirm
    against the wrapped dataset.

    Processing steps:
        * The HR image is optionally rescaled.
        * The HR and LR images are optionally flipped together.
        * The HR image is optionally resized to ``gt_resize``.
        * The HR image is flattened into per-pixel coordinates and values.
        * The pixels are optionally subsampled.

    Args:
        dataset: indexable dataset yielding ``(img_lr, img_hr)`` pairs.
        size_min: lower bound of the HR size drawn during training.
        size_max: upper bound of the HR size. It is also the fixed HR size
            used in ``val_mode``. If only one bound is given, it is used for
            both bounds (fixed size). If neither is given, the HR image keeps
            its original size.
        augment: apply random flips (training only), identically to LR and HR.
        gt_resize: if not None, size passed to ``resize_fn`` for a second HR
            resize after augmentation.
        sample_q: if not None, number of HR pixels sampled without replacement.
        val_mode: evaluation mode. Uses the fixed size ``size_max`` and no
            augmentation.
    """

    def __init__(self, dataset, size_min=None, size_max=None,
                 augment=False, gt_resize=None, sample_q=None, val_mode=False):
        self.dataset = dataset
        # A single bound means a fixed HR size, so fill the missing bound from
        # the other one. After this, both are None or both are set.
        if size_max is None:
            size_max = size_min
        if size_min is None:
            size_min = size_max
        self.size_min = size_min
        self.size_max = size_max
        self.augment = augment
        self.gt_resize = gt_resize
        self.sample_q = sample_q
        self.val_mode = val_mode

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``{'inp': img_lr, 'coord': hr_coord, 'gt': hr_rgb}`` for ``idx``."""
        img_lr, img_hr = self.dataset[idx]

        # Rescale the HR target. Training draws a random size in
        # [size_min, size_max]; validation uses size_max. Both bounds are None
        # together (see __init__), in which case the original size is kept.
        if self.size_max is not None:
            if self.val_mode:
                w_hr = self.size_max
            else:
                p = random.random()
                w_hr = round(self.size_min + (self.size_max - self.size_min) * p)
            img_hr = resize_fn(img_hr, w_hr)

        # Random flips along the last and second-to-last axes. They are applied
        # to LR and HR together so the pair stays aligned.
        if self.augment and not self.val_mode:
            if random.random() < 0.5:
                img_lr = img_lr.flip(-1)
                img_hr = img_hr.flip(-1)
            if random.random() < 0.5:
                img_lr = img_lr.flip(-2)
                img_hr = img_hr.flip(-2)

        if self.gt_resize is not None:
            img_hr = resize_fn(img_hr, self.gt_resize)

        # Flatten the HR image into one coordinate and one value row per pixel.
        hr_coord = to_coordinates(size=img_hr.shape[-2:], return_map=False)
        hr_rgb = rearrange(img_hr, 'C H W -> (H W) C')

        if self.sample_q is not None:
            # Sampling without replacement; np.random.choice raises ValueError
            # if sample_q exceeds the number of HR pixels.
            sample_lst = np.random.choice(len(hr_coord), self.sample_q, replace=False)
            hr_coord = hr_coord[sample_lst]
            hr_rgb = hr_rgb[sample_lst]

        return {
            'inp': img_lr,
            'coord': hr_coord,
            'gt': hr_rgb
        }
|
|