|
"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets. |
|
|
|
It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in |
|
subclasses. |
|
|
|
""" |
|
import random |
|
import numpy as np |
|
import torch.utils.data as data |
|
import torch |
|
from PIL import Image |
|
import torchvision.transforms as transforms |
|
from abc import ABC, abstractmethod |
|
|
|
|
|
class BaseDataset(data.Dataset, ABC):
    """Abstract base class (ABC) that every dataset in this project derives from.

    A concrete subclass supplies the following four functions:
        -- <__init__>:                    set up the dataset; call BaseDataset.__init__(self, opt) first.
        -- <__len__>:                     report how many items the dataset holds.
        -- <__getitem__>:                 fetch a single data point.
        -- <modify_commandline_options>:  (optionally) register dataset-specific options and defaults.
    """

    def __init__(self, opt):
        """Store the experiment options on the instance.

        Parameters:
            opt (Option class) -- holds the experiment flags; its ``dataroot`` attribute
                                  is copied to ``self.root``.
        """
        self.opt = opt
        self.root = opt.dataroot

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Hook for adding dataset-specific command-line options.

        The base implementation is only a placeholder: it hands the parser back untouched.

        Parameters:
            parser          -- the original option parser
            is_train (bool) -- whether this is the training phase or the test phase

        Returns:
            the (here unmodified) parser.
        """
        return parser

    @abstractmethod
    def __len__(self):
        """Return the total number of images in the dataset (0 in this base version)."""
        return 0

    @abstractmethod
    def __getitem__(self, index):
        """Return one data point together with its metadata.

        Parameters:
            index -- integer used to select the data point

        Returns:
            subclasses return a dictionary mapping names to the data and its metadata;
            this base version returns None.
        """
        return None
|
|
|
|
|
def get_params(opt, size):
    """Draw the random crop position and flip decision for one image.

    Parameters:
        opt  -- options object; reads ``preprocess``, ``load_size`` and ``crop_size``
        size -- (width, height) of the source image

    Returns:
        a dictionary with "crop_pos" -> (x, y), the top-left corner of the crop window,
        and "flip" -> bool, whether the image should be mirrored.
    """
    w, h = size

    # Work out the image size after the resize step selected by opt.preprocess.
    if opt.preprocess == "resize_and_crop":
        new_w = new_h = opt.load_size
    elif opt.preprocess == "scale_width_and_crop":
        new_w = opt.load_size
        new_h = opt.load_size * h // w
    else:
        new_w, new_h = w, h

    # The built-in max clamps the upper bound at 0 when the image is not larger
    # than the crop, and keeps the bound a plain Python int for random.randint.
    x = random.randint(0, max(0, new_w - opt.crop_size))
    y = random.randint(0, max(0, new_h - opt.crop_size))

    flip = random.random() > 0.5

    return {"crop_pos": (x, y), "flip": flip}
|
|
|
|
|
def get_transform(
    opt,
    params=None,
    grayscale=False,
    convert=True,
    method=transforms.InterpolationMode.BICUBIC,
):
    """Assemble the image preprocessing pipeline described by ``opt``.

    Steps are appended in this order: optional grayscale conversion, resizing
    (or width scaling), cropping, size snapping when preprocessing is "none",
    horizontal flipping, and finally tensor conversion with normalization.

    Parameters:
        opt              -- options object; reads ``preprocess``, ``load_size``, ``crop_size``,
                            ``no_flip`` and, when ``convert`` is set, ``isTrain``
        params           -- optional dict with "crop_pos" and "flip" entries that fixes the crop
                            and flip; when None, random crop / random flip transforms are used
        grayscale (bool) -- convert the image to a single channel first
        convert (bool)   -- append ToTensor, training-time noise and Normalize
        method           -- interpolation mode used by every resizing step

    Returns:
        a ``transforms.Compose`` wrapping the assembled steps.
    """
    fixed = params is not None
    pipeline = []

    if grayscale:
        pipeline.append(transforms.Grayscale(1))

    # Resizing: a "resize" mode takes precedence over a "scale_width" mode.
    if "resize" in opt.preprocess:
        pipeline.append(transforms.Resize([opt.load_size, opt.load_size], method))
    elif "scale_width" in opt.preprocess:

        def scale_to_width(img):
            return __scale_width(img, opt.load_size, opt.crop_size, method)

        pipeline.append(transforms.Lambda(scale_to_width))

    # Cropping: at the supplied position when params is given, random otherwise.
    if "crop" in opt.preprocess:
        if fixed:

            def crop_at_position(img):
                return __crop(img, params["crop_pos"], opt.crop_size)

            pipeline.append(transforms.Lambda(crop_at_position))
        else:
            pipeline.append(transforms.RandomCrop(opt.crop_size))

    # Without any other preprocessing, still snap both sides to a multiple of 4.
    if opt.preprocess == "none":

        def snap_to_multiple(img):
            return __make_power_2(img, base=4, method=method)

        pipeline.append(transforms.Lambda(snap_to_multiple))

    # Flipping: random when no params are given; otherwise only if params asks for it.
    if not opt.no_flip:
        if not fixed:
            pipeline.append(transforms.RandomHorizontalFlip())
        elif params["flip"]:

            def mirror(img):
                return __flip(img, params["flip"])

            pipeline.append(transforms.Lambda(mirror))

    if convert:
        pipeline.append(transforms.ToTensor())
        # Noise is injected only while training, between ToTensor and Normalize.
        if opt.isTrain:
            pipeline.append(GaussionNoise())
        stats = (0.5,) if grayscale else (0.5, 0.5, 0.5)
        pipeline.append(transforms.Normalize(stats, stats))

    return transforms.Compose(pipeline)
|
|
|
|
|
def __transforms2pil_resize(method):
    """Translate a torchvision ``InterpolationMode`` into the matching PIL resampling constant.

    Only BILINEAR, BICUBIC, NEAREST and LANCZOS are known; any other value raises KeyError.
    """
    modes = transforms.InterpolationMode
    pil_filter_for = {
        modes.BILINEAR: Image.BILINEAR,
        modes.BICUBIC: Image.BICUBIC,
        modes.NEAREST: Image.NEAREST,
        modes.LANCZOS: Image.LANCZOS,
    }
    return pil_filter_for[method]
|
|
|
|
|
def __make_power_2(img, base, method=transforms.InterpolationMode.BICUBIC):
    """Resize ``img`` so that both of its sides are multiples of ``base``.

    Each side is snapped to the nearest multiple of ``base`` using Python's
    built-in ``round`` (exact ties go to the even multiple, not always up).
    A side that would snap to zero is kept at one full multiple instead.

    Parameters:
        img    -- image exposing ``size`` as (width, height) and a ``resize`` method
        base   -- the multiple both sides must end up on
        method -- torchvision interpolation mode used when resizing

    Returns:
        ``img`` itself when it already has the target size, otherwise a resized image.
    """
    method = __transforms2pil_resize(method)
    ow, oh = img.size
    # A side shorter than base / 2 rounds to 0, which is not a valid size to
    # resize to; clamp it so the smallest possible result is base x base.
    h = max(int(round(oh / base) * base), base)
    w = max(int(round(ow / base) * base), base)
    if h == oh and w == ow:
        return img

    __print_size_warning(ow, oh, w, h)
    return img.resize((w, h), method)
|
|
|
|
|
def __scale_width(
    img, target_size, crop_size, method=transforms.InterpolationMode.BICUBIC
):
    """Resize ``img`` to width ``target_size``, keeping the height at least ``crop_size``.

    The height follows the original aspect ratio unless that would fall below
    ``crop_size``. An image that already has the target width and is tall enough
    is returned unchanged.
    """
    resample = __transforms2pil_resize(method)
    cur_w, cur_h = img.size
    needs_resize = cur_w != target_size or cur_h < crop_size
    if not needs_resize:
        return img
    proportional_h = target_size * cur_h / cur_w
    new_h = int(max(proportional_h, crop_size))
    return img.resize((target_size, new_h), resample)
|
|
|
|
|
def __crop(img, pos, size):
    """Cut a ``size`` x ``size`` window out of ``img`` with its top-left corner at ``pos``.

    An image that does not exceed ``size`` in either dimension is returned untouched.
    """
    width, height = img.size
    left, top = pos
    if width <= size and height <= size:
        return img
    return img.crop((left, top, left + size, top + size))
|
|
|
|
|
def __flip(img, flip):
    """Mirror ``img`` left-to-right when ``flip`` is truthy; otherwise return it unchanged."""
    return img.transpose(Image.FLIP_LEFT_RIGHT) if flip else img
|
|
|
|
|
def _gaussion_noise(img):
    """Return ``img`` plus zero-mean Gaussian noise scaled by 0.1.

    Parameters:
        img (torch.Tensor) -- input tensor; it is not modified in place

    Returns:
        a new tensor of the same shape holding ``img + 0.1 * noise``.
    """
    # Sample on img's own device so the addition also works for tensors that
    # do not live on the CPU.
    noise = torch.randn(img.shape, device=img.device)
    return img + noise * 0.1
|
|
|
|
|
def __print_size_warning(ow, oh, w, h):
    """Report once that an image was resized to a multiple of 4.

    The first call prints the original size (ow, oh) and the adjusted size (w, h)
    and records that on the function object itself; later calls do nothing.
    """
    if hasattr(__print_size_warning, "has_printed"):
        return
    message = (
        "The image size needs to be a multiple of 4. "
        "The loaded image size was (%d, %d), so it was adjusted to "
        "(%d, %d). This adjustment will be done to all images "
        "whose sizes are not multiples of 4" % (ow, oh, w, h)
    )
    print(message)
    __print_size_warning.has_printed = True
|
|
|
|
|
class GaussionNoise:
    """Callable transform that adds zero-mean Gaussian noise scaled by 0.1 to a tensor."""

    def __init__(self) -> None:
        pass

    def __call__(self, img):
        """Return a new tensor equal to ``img`` plus noise; ``img`` itself is left untouched.

        Parameters:
            img (torch.Tensor) -- tensor to perturb

        Returns:
            a tensor of the same shape holding ``img + 0.1 * noise``.
        """
        # Sample on img's own device so the addition also works for tensors
        # that do not live on the CPU.
        noise = torch.randn(img.shape, device=img.device)
        return img + noise * 0.1

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"
|
|