File size: 2,320 Bytes
02c5426
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import functools
import os.path
import random
import math

import torchvision.transforms
from PIL import Image
import numpy as np
import torch
from einops import rearrange
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms import InterpolationMode

from datasets import register
import torchvision.transforms
from utils import to_pixel_samples, to_coordinates



def resize_fn(img, size):
    """Resize an image by round-tripping it through PIL with bicubic filtering.

    Args:
        img: image accepted by ``transforms.ToPILImage`` (a CxHxW tensor or an
            HxWxC ndarray).
        size: target size forwarded to ``transforms.Resize``; an int matches
            the shorter edge (aspect ratio kept), an ``(h, w)`` pair is exact.

    Returns:
        The resized image as produced by ``transforms.ToTensor`` (a CxHxW
        float tensor).
    """
    # Use the InterpolationMode enum, as the rest of this module does, rather
    # than the PIL integer constant ``Image.BICUBIC``. For PIL inputs it maps to
    # the same bicubic filter, but avoids the int/``Image.*`` deprecation
    # warnings some torchvision/Pillow releases emit.
    to_pil = transforms.ToPILImage()
    resize = transforms.Resize(size, InterpolationMode.BICUBIC)
    return transforms.ToTensor()(resize(to_pil(img)))


@register('cnn_fixed_scale_sr_warp')
class CNNFixedScaleSRWarp(Dataset):
    """Fixed-scale super-resolution dataset yielding LR/HR patch pairs.

    Wraps ``dataset`` (``dataset[idx]`` must return ``(img_hr, file_name)``).
    Each item crops a square HR patch of side ``int(patch_size * scale_ratio)``
    and bicubically resizes it to a ``patch_size`` LR patch.

    Args:
        dataset: indexable source of ``(img_hr, file_name)`` pairs. ``img_hr``
            is presumably a 3xHxW tensor (the flip augmentation calls
            ``.flip`` on it) -- TODO confirm against the wrapped dataset.
        scale_ratio: HR/LR side-length ratio; may be fractional.
        patch_size: side length of the LR patch.
        augment: apply random horizontal/vertical flips (skipped in test mode).
        val_mode: stored on the instance; not read by this class.
        test_mode: use a deterministic center crop and disable augmentation.
        vis_continuous: build the LR input from a ``4 * patch_size`` center
            crop downsampled to ``patch_size`` (a fixed 4x factor, regardless
            of ``scale_ratio``). The ``gt`` crop is still chosen independently
            at ``hr_size``.
    """

    def __init__(self, dataset, scale_ratio, patch_size=48,
                 augment=False, val_mode=False, test_mode=False,
                 vis_continuous=False):
        self.dataset = dataset
        self.augment = augment
        self.scale_ratio = scale_ratio
        # int() truncates when patch_size * scale_ratio is fractional, so the
        # effective HR/LR ratio can be slightly below scale_ratio.
        self.hr_size = int(patch_size * scale_ratio)
        self.test_mode = test_mode
        self.val_mode = val_mode  # kept for callers; unused in this class
        self.patch_size = patch_size
        self.vis_continuous = vis_continuous

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Build one LR/HR sample.

        Returns:
            dict with ``'img'`` (LR patch), ``'gt'`` (HR patch),
            ``'class_name'`` (name of the source file's parent directory) and
            ``'filename'`` (source file name without its final extension).
        """
        img_hr, file_name = self.dataset[idx]
        class_name = os.path.basename(os.path.dirname(file_name))
        # Strip only the final extension: "a.1.png" -> "a.1". The previous
        # split('.')[0] mapped "a.1.png" and "a.2.png" both to "a" and turned
        # dot-prefixed names into an empty string.
        file_name = os.path.splitext(os.path.basename(file_name))[0]

        if self.vis_continuous:
            # LR comes from the central 4*patch_size region of the *uncropped*
            # image, downsampled to patch_size, independent of scale_ratio.
            img_lr = transforms.Resize(self.patch_size, InterpolationMode.BICUBIC)(
                transforms.CenterCrop(4 * self.patch_size)(img_hr))

        # img_hr: 3xHxW
        if self.test_mode:
            img_hr = transforms.CenterCrop(self.hr_size)(img_hr)
        else:
            img_hr = transforms.RandomCrop(self.hr_size)(img_hr)

        if not self.vis_continuous:
            # LR is the bicubic downsample of exactly the HR patch returned.
            img_lr = transforms.Resize(self.patch_size, InterpolationMode.BICUBIC)(img_hr)

        if self.augment and not self.test_mode:
            # Apply each flip to both images so LR and HR stay aligned.
            if random.random() < 0.5:
                img_lr = img_lr.flip(-1)
                img_hr = img_hr.flip(-1)
            if random.random() < 0.5:
                img_lr = img_lr.flip(-2)
                img_hr = img_hr.flip(-2)

        return {
            'img': img_lr,
            'gt': img_hr,
            'class_name': class_name,
            'filename': file_name
        }