|
import copy |
|
import functools |
|
import os |
|
import random |
|
import math |
|
from PIL import Image |
|
|
|
import numpy as np |
|
import torch |
|
from einops import rearrange |
|
from torch.utils.data import Dataset |
|
from torchvision import transforms |
|
|
|
from datasets import register |
|
from utils import to_pixel_samples, to_coordinates |
|
|
|
import torchvision.transforms.functional as TF |
|
import random |
|
from typing import Sequence |
|
|
|
|
|
class MyRotateTransform:
    """Randomly rotate the input by one angle drawn from a fixed set.

    Args:
        angles: Candidate rotation angles, passed as-is to ``TF.rotate``.
        p: Probability of applying a rotation, following the same
            convention as torchvision's ``RandomHorizontalFlip(p=...)``.
            ``p=0.0`` never rotates and ``p=1.0`` always rotates.
    """

    def __init__(self, angles: Sequence[int], p=0.5):
        self.angles = angles
        self.p = p

    def __call__(self, x):
        """Return ``x`` rotated by a random angle, or ``x`` itself when skipped."""
        # torch.rand(1) lies in [0, 1), so the rotation is skipped with
        # probability 1 - p.  (The previous `< self.p` test was inverted and
        # treated `p` as the probability of *not* rotating.)
        if torch.rand(1) >= self.p:
            return x
        # The angle is drawn with Python's `random`, not the torch RNG.
        angle = random.choice(self.angles)
        return TF.rotate(x, angle)
|
|
|
|
|
@register('inr_diinn_select_scale_sr_warp')
class INRSelectScaleSRWarp(Dataset):
    """Wrap an image dataset and emit one LR/HR patch pair per scale.

    The wrapped ``dataset`` must yield ``(image, file_name)`` tuples.
    Each item returned by this wrapper is a dict keyed by scale; every
    value is a dict with:

    * ``'img'``: the HR patch resized (bicubic) to ``patch_size``,
    * ``'gt'``: the HR patch itself,
    * ``'class_name'``: name of the directory that contains ``file_name``.

    Args:
        dataset: Underlying dataset of ``(image, file_name)`` pairs.
        scales: Iterable of scale factors; one entry is produced per scale.
        patch_size: Size of the low-resolution patch.
        augment: Apply random flips / right-angle rotations (training only).
        val_mode: Use the deterministic centre crop instead of a random crop.
        test_mode: Same effect as ``val_mode``.
    """

    def __init__(self,
                 dataset, scales, patch_size=48,
                 augment=False,
                 val_mode=False, test_mode=False
                 ):
        super(INRSelectScaleSRWarp, self).__init__()
        self.dataset, self.scales = dataset, scales
        self.patch_size, self.augment = patch_size, augment
        self.test_mode, self.val_mode = test_mode, val_mode

    def __len__(self):
        """Number of items, identical to the wrapped dataset's length."""
        return len(self.dataset)

    def _build_entry(self, source_img, scale, class_name):
        """Create the ``{'img', 'gt', 'class_name'}`` entry for one scale."""
        hr_size = int(self.patch_size * scale)

        if not (self.test_mode or self.val_mode):
            # Training: random crop of a copy, so the source image is untouched.
            gt = transforms.RandomCrop(hr_size)(copy.deepcopy(source_img))
            if self.augment:
                augmentations = (
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.RandomVerticalFlip(p=0.5),
                    MyRotateTransform([90, 180, 270], p=0.5),
                )
                for apply_aug in augmentations:
                    gt = apply_aug(gt)
        else:
            # Evaluation: every scale uses the same centre crop, sized for
            # the largest configured scale.
            hr_size = int(self.patch_size * max(self.scales))
            gt = transforms.CenterCrop(hr_size)(source_img)

        downscale = transforms.Resize(self.patch_size, TF.InterpolationMode.BICUBIC)
        return {'img': downscale(gt), 'gt': gt, 'class_name': class_name}

    def __getitem__(self, idx):
        """Return a dict mapping each scale to its LR/HR patch entry."""
        img_hr_ori, file_name = self.dataset[idx]
        # The class label is the name of the file's parent directory.
        parent_dir = os.path.dirname(file_name)
        class_name = os.path.basename(parent_dir)

        return {
            scale: self._build_entry(img_hr_ori, scale, class_name)
            for scale in self.scales
        }