import functools
import os.path
import random
import math

import torchvision.transforms
from PIL import Image

import numpy as np
import torch
from einops import rearrange
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms import InterpolationMode

from datasets import register
import torchvision.transforms
from utils import to_pixel_samples, to_coordinates


def resize_fn(img, size):
    """Resize ``img`` with bicubic interpolation via a PIL round trip.

    The input is converted to a PIL image, resized with
    ``transforms.Resize(size, ...)`` and converted back with
    ``transforms.ToTensor()``.

    Args:
        img: Image accepted by ``transforms.ToPILImage`` (presumably a
            CxHxW tensor -- confirm against callers).
        size: Target size forwarded unchanged to ``transforms.Resize``.

    Returns:
        The resized image as produced by ``transforms.ToTensor()``.
    """
    # Use the torchvision enum (as the dataset class below does) rather than
    # the Pillow integer constant ``Image.BICUBIC``: integer interpolation
    # arguments emit deprecation warnings on several torchvision releases.
    to_pil = transforms.ToPILImage()
    resize = transforms.Resize(size, InterpolationMode.BICUBIC)
    to_tensor = transforms.ToTensor()
    return to_tensor(resize(to_pil(img)))


@register('cnn_fixed_scale_sr_warp')
class CNNFixedScaleSRWarp(Dataset):
    """Dataset wrapper producing (low-res, high-res) pairs at a fixed scale.

    Each item of the wrapped ``dataset`` is unpacked as
    ``(image, file_path)``. The high-resolution target is a square crop of
    side ``hr_size``; the low-resolution input is a bicubic downscale to
    ``patch_size``.
    """

    def __init__(self, dataset, scale_ratio, patch_size=48,
                 augment=False, val_mode=False, test_mode=False,
                 vis_continuous=False):
        """Store the configuration.

        Args:
            dataset: Indexable source whose items unpack to
                ``(image, file_path)``.
            scale_ratio: Ratio between the HR crop side and ``patch_size``.
            patch_size: Size passed to ``transforms.Resize`` for the LR image.
            augment: Enable random horizontal/vertical flips (ignored when
                ``test_mode`` is set).
            val_mode: Stored on the instance; not read anywhere in this class.
            test_mode: Use a deterministic center crop instead of a random
                crop, and disable augmentation.
            vis_continuous: Build the LR image from a center crop of side
                ``4 * patch_size`` of the original image instead of from the
                HR crop.
        """
        self.dataset = dataset
        self.augment = augment
        self.scale_ratio = scale_ratio
        # int() truncates, so non-integer products are rounded toward zero.
        self.hr_size = int(patch_size * scale_ratio)
        self.test_mode = test_mode
        self.val_mode = val_mode
        self.patch_size = patch_size
        self.vis_continuous = vis_continuous

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Build one sample.

        Returns:
            dict with keys ``'img'`` (LR image), ``'gt'`` (HR crop),
            ``'class_name'`` (name of the file's parent directory) and
            ``'filename'`` (base name up to its first ``'.'``).
        """
        img_hr, file_name = self.dataset[idx]
        class_name = os.path.basename(os.path.dirname(file_name))
        # Everything after the FIRST dot is dropped, not only the extension.
        file_name = os.path.basename(file_name).split('.')[0]

        to_lr = transforms.Resize(self.patch_size, InterpolationMode.BICUBIC)

        if self.vis_continuous:
            # LR comes from a fixed 4x-patch center crop of the *uncropped*
            # image, independently of the HR crop taken below.
            img_lr = to_lr(
                transforms.CenterCrop(4 * self.patch_size)(img_hr))

        # img_hr: 3xHxW
        if self.test_mode:
            img_hr = transforms.CenterCrop(self.hr_size)(img_hr)
        else:
            img_hr = transforms.RandomCrop(self.hr_size)(img_hr)

        if not self.vis_continuous:
            # Default path: LR is the bicubic downscale of the HR crop.
            img_lr = to_lr(img_hr)

        if self.augment and not self.test_mode:
            # Apply the same flip to both images so they stay aligned.
            if random.random() < 0.5:
                img_lr = img_lr.flip(-1)
                img_hr = img_hr.flip(-1)
            if random.random() < 0.5:
                img_lr = img_lr.flip(-2)
                img_hr = img_hr.flip(-2)

        return {
            'img': img_lr,
            'gt': img_hr,
            'class_name': class_name,
            'filename': file_name
        }