Spaces:
Runtime error
Runtime error
File size: 9,621 Bytes
ec0fdfd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 |
import torch
import torchvision.transforms as transforms
import torch.utils.data as data
from util import task
from .image_folder import make_dataset
import random
import numpy as np
import copy
import skimage.morphology as sm
from PIL import Image, ImageFile, ImageOps
# PIL flag: try to decode truncated/partially written image files instead of
# failing, so a single damaged file does not abort a data-loading run.
ImageFile.LOAD_TRUNCATED_IMAGES = True
######################################################################################
# Create the dataloader
######################################################################################
class CreateDataset(data.Dataset):
    """Image-completion dataset that pairs every image with a mask.

    ``__getitem__`` returns a dict with keys:

    * ``img_org``  -- the image after ``self.transform`` as a (C, H, W) tensor,
      zero-padded on the right/bottom to ``opt.fine_size`` when
      ``opt.batch_size > 1``;
    * ``img``      -- the same image passed through ``self.transform_fixed``
      (long side scaled to ``opt.fixed_size``), zero-padded to
      ``opt.fixed_size`` when ``opt.batch_size > 1``;
    * ``img_path`` -- path of the source image file;
    * ``mask``     -- mask produced by ``_load_mask`` for ``img_org``
      (optionally inverted, see ``opt.reverse_mask``);
    * ``pad_mask`` -- tensor shaped like ``img_org``: 1 over the real image
      area, 0 over the padding.
    """

    def __init__(self, opt):
        self.opt = opt
        self.img_paths, self.img_size = make_dataset(opt.img_file)
        if opt.mask_file != 'none':  # load the external mask files (mask_type 3) for training and testing
            self.mask_paths, self.mask_size = make_dataset(opt.mask_file)
        # convert=False: this pipeline yields a PIL image so it can still be fed
        # to ``transform_fixed`` before tensor conversion.
        self.transform = get_transform(opt, convert=False, augment=False)
        # Deterministic pipeline for the fixed-size input: long side scaled to
        # ``opt.fixed_size``, no random flip. Built on a copy so ``opt`` is untouched.
        fixed_opt = copy.deepcopy(opt)
        fixed_opt.preprocess = 'scale_longside'
        fixed_opt.load_size = fixed_opt.fixed_size
        fixed_opt.no_flip = True
        self.transform_fixed = get_transform(fixed_opt, convert=True, augment=False)

    def __len__(self):
        """return the total number of examples in the dataset"""
        return self.img_size

    def __getitem__(self, item):
        """Return one data point (see the class docstring for the dict keys)."""
        img_org, img, img_path = self._load_img(item)
        # Real (unpadded) spatial size, read locally before any padding is applied.
        _, height, width = img_org.shape
        if self.opt.batch_size > 1:  # pad every sample to the same size for batch training
            # torchvision 4-tuple padding order is (left, top, right, bottom).
            img_org = transforms.functional.pad(
                img_org, (0, 0, self.opt.fine_size - width, self.opt.fine_size - height))
            img = transforms.functional.pad(
                img, (0, 0, self.opt.fixed_size - img.size(-1), self.opt.fixed_size - img.size(-2)))
        # 1 over the real image area, 0 over any padding added above.
        pad_mask = torch.zeros_like(img_org)
        pad_mask[:, :height, :width] = 1
        # The mask is generated for the (possibly padded) image so the shapes match.
        mask, _ = self._load_mask(item, img_org)
        if self.opt.reverse_mask:
            if self.opt.isTrain:
                # Invert the mask for roughly 20% of training samples.
                mask = 1 - mask if random.random() > 0.8 else mask
            else:
                mask = 1 - mask
        return {'img_org': img_org, 'img': img, 'img_path': img_path, 'mask': mask, 'pad_mask': pad_mask}

    def name(self):
        return ""

    def _load_img(self, item):
        """Load image ``item`` (wrapped modulo the dataset size) and preprocess it.

        Returns ``(img_org, img, img_path)``: ``img_org`` is the tensor of the
        image after ``self.transform`` and ``img`` is the tensor produced by
        ``self.transform_fixed``. Also stores the ``img_org`` tensor size in
        ``self.img_c``, ``self.img_w`` and ``self.img_h``. Note that ``img_w``
        holds dim 1 (height) and ``img_h`` holds dim 2 (width); these attributes
        are kept only for backward compatibility.
        """
        img_path = self.img_paths[item % self.img_size]
        # ``with`` closes the underlying file even if decoding fails.
        with Image.open(img_path) as img_file:
            img_pil = img_file.convert('RGB')
        img_org = self.transform(img_pil)
        img = self.transform_fixed(img_org)
        img_org = transforms.ToTensor()(img_org)
        img_pil.close()
        self.img_c, self.img_w, self.img_h = img_org.size()
        return img_org, img, img_path

    def _mask_dilation(self, mask):
        """Grayscale-erode ``mask`` with a random square footprint of 3-24 pixels.

        Erosion takes the local minimum, so bright regions shrink and dark
        regions grow. Returns a ``uint8`` numpy array.
        """
        mask = np.array(mask)
        pixel = np.random.randint(3, 25)  # upper bound is exclusive
        mask = sm.erosion(mask, sm.square(pixel)).astype(np.uint8)
        return mask

    def _load_mask(self, item, img):
        """Build or load the mask for ``img`` (a (C, H, W) tensor).

        ``opt.mask_type`` selects the source (a list means one is picked at random):
          0 -- centre mask; during training it is replaced by a random regular
               mask about 70% of the time
          1 -- random regular mask
          2 -- random irregular mask
          3 -- external mask image from ``opt.mask_file``, returned as a
               (1, H, W) float tensor

        Returns:
            ``(mask, mask_type)``.

        Raises:
            NotImplementedError: for any other ``mask_type``.
        """
        _, h, w = img.size()
        if isinstance(self.opt.mask_type, list):
            mask_type = self.opt.mask_type[random.randint(0, len(self.opt.mask_type) - 1)]
        else:
            mask_type = self.opt.mask_type

        if mask_type == 0:  # center mask
            if random.random() > 0.3 and self.opt.isTrain:
                return task.random_regular_mask(img), mask_type  # random regular mask
            return task.center_mask(img), mask_type
        elif mask_type == 1:  # random regular mask
            return task.random_regular_mask(img), mask_type
        elif mask_type == 2:  # random irregular mask
            return task.random_irregular_mask(img), mask_type
        elif mask_type == 3:
            # external mask from "Image Inpainting for Irregular Holes Using Partial Convolutions (ECCV18)"
            if self.opt.isTrain:
                mask_index = random.randint(0, self.mask_size - 1)
                mask_transform = transforms.Compose(
                    [
                        transforms.RandomHorizontalFlip(),
                        transforms.RandomRotation(10),
                        transforms.RandomCrop([self.opt.fine_size + 64, self.opt.fine_size + 64]),
                        transforms.Resize([h, w]),
                    ]
                )
            else:
                # Wrap around so evaluating more images than there are masks
                # reuses masks instead of raising IndexError.
                mask_index = item % self.mask_size
                mask_transform = transforms.Compose(
                    [
                        transforms.Resize([h, w]),
                    ]
                )
            with Image.open(self.mask_paths[mask_index]) as mask_file:
                mask_pil = mask_file.convert('L')
            mask = mask_transform(mask_pil)
            mask_pil.close()
            if self.opt.isTrain:
                # NOTE(review): this branch yields eroded raw grey levels (0-255,
                # bright pixels high), whereas the eval branch yields a {0, 1} mask
                # that is 1 where the pixel is darker than 128 -- confirm the
                # training code expects this scale/polarity.
                mask = self._mask_dilation(mask)
            else:
                mask = np.array(mask) < 128
            mask = torch.tensor(mask).view(1, h, w).float()
            return mask, mask_type
        else:
            raise NotImplementedError('mask type [%s] is not implemented' % str(mask_type))
def dataloader(opt):
    """Wrap a ``CreateDataset`` built from ``opt`` in a torch ``DataLoader``.

    Batches hold ``opt.batch_size`` samples, are shuffled unless
    ``opt.no_shuffle`` is set, are loaded by ``int(opt.nThreads)`` workers,
    and the final incomplete batch is dropped.
    """
    source = CreateDataset(opt)
    loader_options = dict(
        batch_size=opt.batch_size,
        shuffle=not opt.no_shuffle,
        num_workers=int(opt.nThreads),
        drop_last=True,
    )
    return data.DataLoader(source, **loader_options)
######################################################################################
# Basic image preprocess function
######################################################################################
def _make_power_2(img, power, method=Image.BICUBIC):
    """Resize ``img`` so that both sides are multiples of ``2 ** power``.

    Each side is rounded to the nearest multiple (Python ``round``, so exact
    ties round half to even) and is never smaller than one multiple. The input
    image object is returned as-is when it already has the target size.
    """
    step = 2 ** power
    current = img.size
    target = tuple(int(max(1, round(side / step)) * step) for side in current)
    if target == tuple(current):
        return img
    return img.resize(target, method)
def _random_zoom(img, target_width, method=Image.BICUBIC):
    """Rescale width and height by independent random factors drawn from [0.8, 1.0).

    After rescaling, each side is at least ``target_width`` pixels.
    """
    factor_w, factor_h = np.random.uniform(0.8, 1.0, size=[2])
    width, height = img.size
    new_width = int(round(max(target_width, width * factor_w)))
    new_height = int(round(max(target_width, height * factor_h)))
    return img.resize((new_width, new_height), method)
def _scale_shortside(img, target_width, method=Image.BICUBIC):
    """Uniformly resize ``img`` so its shorter side becomes ``target_width``.

    The other side is scaled by the same ratio and rounded to an integer.
    """
    width, height = img.size
    ratio = target_width / min(width, height)
    new_size = (round(width * ratio), round(height * ratio))
    return img.resize(new_size, method)
def _scale_longside(img, target_width, method=Image.BICUBIC):
    """Uniformly resize ``img`` so its longer side becomes ``target_width``.

    The other side is scaled by the same ratio and rounded to an integer.
    """
    width, height = img.size
    ratio = target_width / max(width, height)
    new_size = (round(width * ratio), round(height * ratio))
    return img.resize(new_size, method)
def _scale_randomside(img, target_width, method=Image.BICUBIC):
    """Scale either the short side or the long side of ``img`` to ``target_width``.

    Each option is chosen with probability of about one half, using ``random.random()``.
    """
    pick_short = random.random() > 0.5
    scaler = _scale_shortside if pick_short else _scale_longside
    return scaler(img, target_width, method)
def _crop(img, pos=None, size=None):
"""crop the image based on the given pos and size"""
ow, oh = img.size
if size is None:
return img
nw = min(ow, size)
nh = min(oh, size)
if (ow > nw or oh > nh):
if pos is None:
x1 = np.random.randint(0, int(ow-nw)+1)
y1 = np.random.randint(0, int(oh-nh)+1)
else:
x1, y1 = pos
return img.crop((x1, y1, x1 + nw, y1 + nh))
return img
def _pad(img):
    """Pad ``img`` into a square whose side equals its longer side.

    Uses ``ImageOps.pad`` with the image anchored at the top-left corner.
    """
    side = max(img.size)
    return ImageOps.pad(img, (side, side), centering=(0, 0))
def _flip(img, flip):
if flip:
return img.transpose(Image.FLIP_LEFT_RIGHT)
return img
def get_transform(opt, params=None, method=Image.BICUBIC, convert=True, augment=False):
    """Build the torchvision preprocessing pipeline described by ``opt``.

    Args:
        opt: options object. Reads ``preprocess`` (matched by substring:
            'resize' / 'scale_shortside' / 'scale_longside' / 'scale_randomside',
            plus 'zoom', 'crop', 'pad'), ``load_size``, ``fine_size``,
            ``data_powers``, ``no_flip`` and ``isTrain``.
        params: unused; kept for call-site compatibility.
        method: PIL resampling filter used by the custom resize steps.
        convert: when True, append ``ToTensor()``; otherwise the pipeline
            returns a PIL image.
        augment: when True and training, append a ``ColorJitter``.

    Returns:
        A ``transforms.Compose`` of the selected steps.

    The custom steps wrap ``functools.partial`` objects over module-level
    functions instead of local lambdas, so the pipeline (and any dataset
    holding it) can be pickled. ``DataLoader`` workers started with the
    ``spawn`` method need this. Option values are captured when the pipeline
    is built.
    """
    from functools import partial

    transforms_list = []
    if 'resize' in opt.preprocess:
        osize = [opt.load_size, opt.load_size]
        transforms_list.append(transforms.Resize(osize))
    elif 'scale_shortside' in opt.preprocess:
        transforms_list.append(transforms.Lambda(
            partial(_scale_shortside, target_width=opt.load_size, method=method)))
    elif 'scale_longside' in opt.preprocess:
        transforms_list.append(transforms.Lambda(
            partial(_scale_longside, target_width=opt.load_size, method=method)))
    elif "scale_randomside" in opt.preprocess:
        transforms_list.append(transforms.Lambda(
            partial(_scale_randomside, target_width=opt.load_size, method=method)))
    if 'zoom' in opt.preprocess:
        transforms_list.append(transforms.Lambda(
            partial(_random_zoom, target_width=opt.load_size, method=method)))
    if 'crop' in opt.preprocess and opt.isTrain:
        transforms_list.append(transforms.Lambda(partial(_crop, size=opt.fine_size)))
    if 'pad' in opt.preprocess:
        transforms_list.append(transforms.Lambda(_pad))  # padding image to square
    # Always round both sides to a multiple of 2 ** opt.data_powers.
    transforms_list.append(transforms.Lambda(
        partial(_make_power_2, power=opt.data_powers, method=method)))
    if not opt.no_flip and opt.isTrain:
        transforms_list.append(transforms.RandomHorizontalFlip())
    if augment and opt.isTrain:
        transforms_list.append(transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2))
    if convert:
        transforms_list.append(transforms.ToTensor())
    return transforms.Compose(transforms_list)