import functools
import random
import math

from PIL import Image
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms

from datasets import register
from utils import to_pixel_samples


def _coord_rgb_cell(img_hr, sample_q):
    """Flatten *img_hr* into per-pixel (coord, rgb) samples plus cell sizes.

    Args:
        img_hr: HR image tensor of shape (C, H, W).
        sample_q: if not None, keep only this many randomly chosen pixels
            (sampled without replacement).

    Returns:
        (hr_coord, hr_rgb, cell) where ``cell`` holds the height/width of one
        HR pixel expressed in the [-1, 1] coordinate system used by LIIF.
    """
    hr_coord, hr_rgb = to_pixel_samples(img_hr.contiguous())
    if sample_q is not None:
        sample_lst = np.random.choice(len(hr_coord), sample_q, replace=False)
        hr_coord = hr_coord[sample_lst]
        hr_rgb = hr_rgb[sample_lst]
    cell = torch.ones_like(hr_coord)
    cell[:, 0] *= 2 / img_hr.shape[-2]
    cell[:, 1] *= 2 / img_hr.shape[-1]
    return hr_coord, hr_rgb, cell


@register('liff_test_warp')
class LIIFTestWarp(Dataset):
    """Test-time wrapper that rescales the HR image to a fixed grid.

    The HR target is resized to ``int(scale_ratio * 32)`` on each side; the
    LR input is passed through unchanged from the wrapped dataset.
    """

    def __init__(self, dataset, scale_ratio, val_mode=False, sample_q=None):
        self.dataset = dataset
        self.scale_ratio = scale_ratio
        self.val_mode = val_mode
        self.sample_q = sample_q
        print('hr_scale: ', int(scale_ratio*32))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img_lr, img_hr = self.dataset[idx]
        # Guarantee a minimum HR size before rescaling to the target grid.
        if img_hr.shape[-1] < 256:
            img_hr = transforms.Resize([256, 256])(img_hr)
        # BUG FIX: Resize requires integer sizes; the original passed
        # self.scale_ratio * 32 directly, which raises for float ratios
        # (the print() in __init__ already shows integer intent).
        hr_size = int(self.scale_ratio * 32)
        img_hr = transforms.Resize([hr_size, hr_size])(img_hr)
        hr_coord, hr_rgb, cell = _coord_rgb_cell(img_hr, self.sample_q)
        return {
            'inp': img_lr,
            'coord': hr_coord,
            'cell': cell,
            'gt': hr_rgb,
        }


@register('sr-implicit-paired')
class SRImplicitPaired(Dataset):
    """Wrapper for datasets yielding pre-paired (LR, HR) images.

    Assumes an integer scale factor between the pair. Optionally crops
    aligned patches (``inp_size``) and applies flip/transpose augmentation.
    """

    def __init__(self, dataset, inp_size=None, augment=False, sample_q=None):
        self.dataset = dataset
        self.inp_size = inp_size
        self.augment = augment
        self.sample_q = sample_q

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img_lr, img_hr = self.dataset[idx]
        if img_hr.shape[-1] < 256:
            img_hr = transforms.Resize([256, 256])(img_hr)

        s = img_hr.shape[-2] // img_lr.shape[-2]  # assume int scale
        if self.inp_size is None:
            # Use the full LR image; trim HR so shapes align exactly.
            h_lr, w_lr = img_lr.shape[-2:]
            img_hr = img_hr[:, :h_lr * s, :w_lr * s]
            crop_lr, crop_hr = img_lr, img_hr
        else:
            # Random aligned square crops: LR patch and the matching HR patch.
            w_lr = self.inp_size
            x0 = random.randint(0, img_lr.shape[-2] - w_lr)
            y0 = random.randint(0, img_lr.shape[-1] - w_lr)
            crop_lr = img_lr[:, x0: x0 + w_lr, y0: y0 + w_lr]
            w_hr = w_lr * s
            x1 = x0 * s
            y1 = y0 * s
            crop_hr = img_hr[:, x1: x1 + w_hr, y1: y1 + w_hr]

        if self.augment:
            hflip = random.random() < 0.5
            vflip = random.random() < 0.5
            dflip = random.random() < 0.5

            def augment(x):
                if hflip:
                    x = x.flip(-2)
                if vflip:
                    x = x.flip(-1)
                if dflip:
                    x = x.transpose(-2, -1)
                return x

            # Same random flips applied to both so the pair stays aligned.
            crop_lr = augment(crop_lr)
            crop_hr = augment(crop_hr)

        hr_coord, hr_rgb, cell = _coord_rgb_cell(crop_hr, self.sample_q)
        return {
            'inp': crop_lr,
            'coord': hr_coord,
            'cell': cell,
            'gt': hr_rgb,
        }


def resize_fn(img, size):
    """Resize a tensor image via PIL bicubic and return it as a tensor."""
    return transforms.ToTensor()(
        transforms.Resize(size, Image.BICUBIC)(
            transforms.ToPILImage()(img)))


@register('sr-implicit-downsampled')
class SRImplicitDownsampled(Dataset):
    """Wrapper that synthesizes LR inputs by bicubic-downsampling HR images.

    A scale is drawn uniformly from [scale_min, scale_max] per sample.
    """

    def __init__(self, dataset, inp_size=None, scale_min=1, scale_max=None,
                 augment=False, sample_q=None):
        self.dataset = dataset
        self.inp_size = inp_size
        self.scale_min = scale_min
        if scale_max is None:
            scale_max = scale_min
        self.scale_max = scale_max
        self.augment = augment
        self.sample_q = sample_q

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img = self.dataset[idx]
        s = random.uniform(self.scale_min, self.scale_max)

        if self.inp_size is None:
            # Full image: trim HR so LR dims * s round back exactly.
            h_lr = math.floor(img.shape[-2] / s + 1e-9)
            w_lr = math.floor(img.shape[-1] / s + 1e-9)
            img = img[:, :round(h_lr * s), :round(w_lr * s)]  # assume round int
            img_down = resize_fn(img, (h_lr, w_lr))
            crop_lr, crop_hr = img_down, img
        else:
            # Random HR crop of side round(inp_size * s), downsampled to inp_size.
            w_lr = self.inp_size
            w_hr = round(w_lr * s)
            x0 = random.randint(0, img.shape[-2] - w_hr)
            y0 = random.randint(0, img.shape[-1] - w_hr)
            crop_hr = img[:, x0: x0 + w_hr, y0: y0 + w_hr]
            crop_lr = resize_fn(crop_hr, w_lr)

        if self.augment:
            hflip = random.random() < 0.5
            vflip = random.random() < 0.5
            dflip = random.random() < 0.5

            def augment(x):
                if hflip:
                    x = x.flip(-2)
                if vflip:
                    x = x.flip(-1)
                if dflip:
                    x = x.transpose(-2, -1)
                return x

            crop_lr = augment(crop_lr)
            crop_hr = augment(crop_hr)

        hr_coord, hr_rgb, cell = _coord_rgb_cell(crop_hr, self.sample_q)
        return {
            'inp': crop_lr,
            'coord': hr_coord,
            'cell': cell,
            'gt': hr_rgb,
        }


@register('sr-implicit-uniform-varied')
class SRImplicitUniformVaried(Dataset):
    """Paired wrapper with a per-sample random HR target size.

    The HR side is resized to a width drawn uniformly in
    [size_min, size_max]; optional horizontal-flip augmentation.
    """

    def __init__(self, dataset, size_min, size_max=None,
                 augment=False, gt_resize=None, sample_q=None):
        self.dataset = dataset
        self.size_min = size_min
        if size_max is None:
            size_max = size_min
        self.size_max = size_max
        self.augment = augment
        self.gt_resize = gt_resize
        self.sample_q = sample_q

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img_lr, img_hr = self.dataset[idx]
        # p = idx / (len(self.dataset) - 1)
        p = random.random()
        w_hr = round(self.size_min + (self.size_max - self.size_min) * p)
        img_hr = resize_fn(img_hr, w_hr)

        if self.augment:
            if random.random() < 0.5:
                img_lr = img_lr.flip(-1)
                img_hr = img_hr.flip(-1)

        if self.gt_resize is not None:
            img_hr = resize_fn(img_hr, self.gt_resize)

        hr_coord, hr_rgb, cell = _coord_rgb_cell(img_hr, self.sample_q)
        return {
            'inp': img_lr,
            'coord': hr_coord,
            'cell': cell,
            'gt': hr_rgb,
        }