import random
import math
from PIL import Image

import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms

from datasets import register
from utils import to_pixel_samples
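
# `to_pixel_samples` comes from utils; the wrappers below assume it flattens a
# (C, H, W) tensor into per-pixel (coord, rgb) pairs, with coords at pixel
# centers in [-1, 1]. A minimal sketch of that assumed contract (the helper
# name is hypothetical so it does not shadow the real import; meshgrid's
# `indexing` argument assumes torch >= 1.10):
#
#     def _to_pixel_samples_sketch(img):
#         c, h, w = img.shape
#         ys = ((torch.arange(h).float() + 0.5) / h) * 2 - 1
#         xs = ((torch.arange(w).float() + 0.5) / w) * 2 - 1
#         coord = torch.stack(
#             torch.meshgrid(ys, xs, indexing='ij'), dim=-1).view(-1, 2)
#         rgb = img.view(c, -1).permute(1, 0)
#         return coord, rgb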


@register('liif_test_warp')
class LIIFTestWarp(Dataset):
    """Test-time wrapper: resizes the HR target to a square of side
    scale_ratio * 32 and emits LIIF-style (coord, cell, gt) queries."""

    def __init__(self, dataset, scale_ratio, val_mode=False, sample_q=None):
        self.dataset = dataset
        self.scale_ratio = scale_ratio
        self.val_mode = val_mode
        self.sample_q = sample_q
        print('hr_scale:', int(scale_ratio * 32))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img_lr, img_hr = self.dataset[idx]
        if img_hr.shape[-1] < 256:
            img_hr = transforms.Resize([256, 256])(img_hr)

        # Resize takes integer sizes; scale_ratio may be fractional.
        hr_size = int(self.scale_ratio * 32)
        img_hr = transforms.Resize([hr_size, hr_size])(img_hr)

        hr_coord, hr_rgb = to_pixel_samples(img_hr.contiguous())

        if self.sample_q is not None:
            sample_lst = np.random.choice(
                len(hr_coord), self.sample_q, replace=False)
            hr_coord = hr_coord[sample_lst]
            hr_rgb = hr_rgb[sample_lst]

        # Cell: height/width of one HR pixel in the [-1, 1] coordinate space.
        cell = torch.ones_like(hr_coord)
        cell[:, 0] *= 2 / img_hr.shape[-2]
        cell[:, 1] *= 2 / img_hr.shape[-1]

        return {
            'inp': img_lr,
            'coord': hr_coord,
            'cell': cell,
            'gt': hr_rgb
        }
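
# A hedged usage sketch for LIIFTestWarp (the wrapped `dataset` is assumed to
# yield (img_lr, img_hr) tensor pairs; `paired_set` is illustrative):
#
#     test_set = LIIFTestWarp(paired_set, scale_ratio=4, sample_q=2304)
#     sample = test_set[0]
#     # sample['coord']: (2304, 2) query coords in [-1, 1]
#     # sample['cell']:  (2304, 2) per-query pixel sizes
#     # sample['gt']:    (2304, 3) RGB targets; sample['inp'] is the LR tensor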


@register('sr-implicit-paired')
class SRImplicitPaired(Dataset):
    """Wraps a dataset of (img_lr, img_hr) pairs into LIIF-style samples."""

    def __init__(self, dataset, inp_size=None, augment=False, sample_q=None):
        self.dataset = dataset
        self.inp_size = inp_size
        self.augment = augment
        self.sample_q = sample_q

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img_lr, img_hr = self.dataset[idx]
        if img_hr.shape[-1] < 256:
            img_hr = transforms.Resize([256, 256])(img_hr)

        # Integer scale between the pair; trim HR so the shapes match exactly.
        s = img_hr.shape[-2] // img_lr.shape[-2]
        if self.inp_size is None:
            h_lr, w_lr = img_lr.shape[-2:]
            img_hr = img_hr[:, :h_lr * s, :w_lr * s]
            crop_lr, crop_hr = img_lr, img_hr
        else:
            # Random aligned crops: an inp_size patch from LR and the
            # corresponding (inp_size * s) patch from HR.
            w_lr = self.inp_size
            x0 = random.randint(0, img_lr.shape[-2] - w_lr)
            y0 = random.randint(0, img_lr.shape[-1] - w_lr)
            crop_lr = img_lr[:, x0: x0 + w_lr, y0: y0 + w_lr]
            w_hr = w_lr * s
            x1 = x0 * s
            y1 = y0 * s
            crop_hr = img_hr[:, x1: x1 + w_hr, y1: y1 + w_hr]

        if self.augment:
            hflip = random.random() < 0.5
            vflip = random.random() < 0.5
            dflip = random.random() < 0.5

            # The same random flips/transpose are applied to both crops so
            # the LR/HR pair stays aligned.
            def augment(x):
                if hflip:
                    x = x.flip(-2)
                if vflip:
                    x = x.flip(-1)
                if dflip:
                    x = x.transpose(-2, -1)
                return x

            crop_lr = augment(crop_lr)
            crop_hr = augment(crop_hr)

        hr_coord, hr_rgb = to_pixel_samples(crop_hr.contiguous())

        if self.sample_q is not None:
            sample_lst = np.random.choice(
                len(hr_coord), self.sample_q, replace=False)
            hr_coord = hr_coord[sample_lst]
            hr_rgb = hr_rgb[sample_lst]

        # Cell: height/width of one HR pixel in the [-1, 1] coordinate space.
        cell = torch.ones_like(hr_coord)
        cell[:, 0] *= 2 / crop_hr.shape[-2]
        cell[:, 1] *= 2 / crop_hr.shape[-1]

        return {
            'inp': crop_lr,
            'coord': hr_coord,
            'cell': cell,
            'gt': hr_rgb
        }
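
# Sketch of how a LIIF-style model might consume one sample (`model` and
# `paired_set` are hypothetical, not part of this file):
#
#     d = SRImplicitPaired(paired_set, inp_size=48, augment=True, sample_q=2304)
#     s = d[0]
#     pred = model(s['inp'].unsqueeze(0), s['coord'].unsqueeze(0),
#                  s['cell'].unsqueeze(0))
#     # pred: (1, 2304, 3), to be compared against s['gt']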


def resize_fn(img, size):
    return transforms.ToTensor()(
        transforms.Resize(size, Image.BICUBIC)(
            transforms.ToPILImage()(img)))
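
# Note: the ToPILImage -> Resize -> ToTensor round-trip above quantizes values
# to 8 bits per channel. A float-precision, tensor-only variant is sketched
# below (a minimal sketch, not used by this file; results differ slightly from
# PIL's antialiased bicubic):
#
#     import torch.nn.functional as F
#
#     def resize_fn_tensor(img, size):
#         if isinstance(size, int):
#             size = (size, size)
#         return F.interpolate(img.unsqueeze(0), size=size, mode='bicubic',
#                              align_corners=False).squeeze(0).clamp_(0, 1)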


@register('sr-implicit-downsampled')
class SRImplicitDownsampled(Dataset):
    """Synthesizes LR inputs from single HR images at a random scale drawn
    from [scale_min, scale_max]."""

    def __init__(self, dataset, inp_size=None, scale_min=1, scale_max=None,
                 augment=False, sample_q=None):
        self.dataset = dataset
        self.inp_size = inp_size
        self.scale_min = scale_min
        if scale_max is None:
            scale_max = scale_min
        self.scale_max = scale_max
        self.augment = augment
        self.sample_q = sample_q

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img = self.dataset[idx]
        s = random.uniform(self.scale_min, self.scale_max)

        if self.inp_size is None:
            # Use the whole image: trim HR so its sides correspond to integer
            # LR sides, then bicubic-downsample.
            h_lr = math.floor(img.shape[-2] / s + 1e-9)
            w_lr = math.floor(img.shape[-1] / s + 1e-9)
            img = img[:, :round(h_lr * s), :round(w_lr * s)]
            img_down = resize_fn(img, (h_lr, w_lr))
            crop_lr, crop_hr = img_down, img
        else:
            # Random square HR crop of side round(inp_size * s), downsampled
            # to an inp_size x inp_size input.
            w_lr = self.inp_size
            w_hr = round(w_lr * s)
            x0 = random.randint(0, img.shape[-2] - w_hr)
            y0 = random.randint(0, img.shape[-1] - w_hr)
            crop_hr = img[:, x0: x0 + w_hr, y0: y0 + w_hr]
            crop_lr = resize_fn(crop_hr, w_lr)

        if self.augment:
            hflip = random.random() < 0.5
            vflip = random.random() < 0.5
            dflip = random.random() < 0.5

            def augment(x):
                if hflip:
                    x = x.flip(-2)
                if vflip:
                    x = x.flip(-1)
                if dflip:
                    x = x.transpose(-2, -1)
                return x

            crop_lr = augment(crop_lr)
            crop_hr = augment(crop_hr)

        hr_coord, hr_rgb = to_pixel_samples(crop_hr.contiguous())

        if self.sample_q is not None:
            sample_lst = np.random.choice(
                len(hr_coord), self.sample_q, replace=False)
            hr_coord = hr_coord[sample_lst]
            hr_rgb = hr_rgb[sample_lst]

        # Cell: height/width of one HR pixel in the [-1, 1] coordinate space.
        cell = torch.ones_like(hr_coord)
        cell[:, 0] *= 2 / crop_hr.shape[-2]
        cell[:, 1] *= 2 / crop_hr.shape[-1]

        return {
            'inp': crop_lr,
            'coord': hr_coord,
            'cell': cell,
            'gt': hr_rgb
        }
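
# Hedged training-side sketch (`img_set`, a dataset yielding single HR
# tensors, is an assumption of this wrapper):
#
#     train_set = SRImplicitDownsampled(img_set, inp_size=48,
#                                       scale_min=1, scale_max=4,
#                                       augment=True, sample_q=2304)
#     # Each access draws a random scale s in [1, 4], crops a square HR patch
#     # of side round(48 * s), bicubic-downsamples it to 48x48 as the input,
#     # and samples 2304 HR pixels as supervision.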


@register('sr-implicit-uniform-varied')
class SRImplicitUniformVaried(Dataset):
    """Resizes each HR target to a side length drawn uniformly from
    [size_min, size_max]."""

    def __init__(self, dataset, size_min, size_max=None,
                 augment=False, gt_resize=None, sample_q=None):
        self.dataset = dataset
        self.size_min = size_min
        if size_max is None:
            size_max = size_min
        self.size_max = size_max
        self.augment = augment
        self.gt_resize = gt_resize
        self.sample_q = sample_q

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img_lr, img_hr = self.dataset[idx]

        # Uniformly pick a target side length in [size_min, size_max].
        p = random.random()
        w_hr = round(self.size_min + (self.size_max - self.size_min) * p)
        img_hr = resize_fn(img_hr, w_hr)

        if self.augment:
            if random.random() < 0.5:
                img_lr = img_lr.flip(-1)
                img_hr = img_hr.flip(-1)

        if self.gt_resize is not None:
            img_hr = resize_fn(img_hr, self.gt_resize)

        hr_coord, hr_rgb = to_pixel_samples(img_hr)

        if self.sample_q is not None:
            sample_lst = np.random.choice(
                len(hr_coord), self.sample_q, replace=False)
            hr_coord = hr_coord[sample_lst]
            hr_rgb = hr_rgb[sample_lst]

        # Cell: height/width of one HR pixel in the [-1, 1] coordinate space.
        cell = torch.ones_like(hr_coord)
        cell[:, 0] *= 2 / img_hr.shape[-2]
        cell[:, 1] *= 2 / img_hr.shape[-1]

        return {
            'inp': img_lr,
            'coord': hr_coord,
            'cell': cell,
            'gt': hr_rgb
        }
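
# Collation note: with `inp_size` and `sample_q` set, every sample has
# fixed-size tensors, so the default DataLoader collate works (a minimal
# sketch; loader parameters are illustrative):
#
#     from torch.utils.data import DataLoader
#     loader = DataLoader(train_set, batch_size=16, shuffle=True, num_workers=8)
#     for batch in loader:
#         inp, coord, cell, gt = (batch[k]
#                                 for k in ('inp', 'coord', 'cell', 'gt'))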