"""不同的模型使用不同的数据集 比如有监督模型使用的都是成对的训练数据、无监督模型使用的数据集不必使用成对的数据 This package includes all the modules related to data loading and preprocessing To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. You need to implement four functions: -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). -- <__len__>: return the size of dataset. -- <__getitem__>: get a data point from data loader. -- : (optionally) add dataset-specific options and set default options. Now you can use the dataset class by specifying a flag '--dataset_mode dummy'. See our template dataset class 'template_dataset.py' for more details. """ import pickle import importlib import torch.utils.data from .base_dataset import BaseDataset from .one_dataset import * __all__ = [OneDataset] def find_dataset_by_name(dataset_name: str): """按照数据集名称来寻找所对应的dataset类进行动态导入 Import the module "data/[dataset_name]_dataset.py". In the file, the class called DatasetNameDataset() will be instantiated. It has to be a subclass of BaseDataset, and it is case-insensitive. """ dataset_filename = "data." + dataset_name + "_dataset" datasetlib = importlib.import_module(dataset_filename) dataset = None target_dataset_name = dataset_name.replace("_", "") + "dataset" for name, cls in datasetlib.__dict__.items(): if name.lower() == target_dataset_name.lower() and issubclass(cls, BaseDataset): dataset = cls if dataset is None: raise NotImplementedError(f"In {dataset_filename}.py, there should be a subclass of BaseDataset with class " f"name that matches {target_dataset_name} in lowercase.") return dataset def get_option_setter(dataset_name): """Return the static method of the dataset class.""" dataset_class = find_dataset_by_name(dataset_name) return dataset_class.modify_commandline_options def create_dataset(opt): """Create a dataset given the option. This function wraps the class CustomDatasetDataLoader. This is the main interface between this package and 'train.py'/'test.py' Example: >>> from data import create_dataset >>> dataset = create_dataset(opt) """ data_loader = CustomDatasetDataLoader(opt) dataset = data_loader.load_data() return dataset class CustomDatasetDataLoader: """Wrapper class of Dataset class that performs multi-threading data loading""" def __init__(self, opt): """Initialize this class Step 1: create a dataset instance given the name [dataset_mode] Step 2: create a multi-threading data loader. """ self.opt = opt dataset_file = f"datasets/{opt.name}.pkl" if not Path(dataset_file).exists(): # 判断数据集类型(成对/不成对),得到相应的类包 dataset_class = find_dataset_by_name(opt.dataset_mode) # 传入数据集路径到类包中,得到数据集 self.dataset = dataset_class(opt) # 打包下次直接使用 # 打包后文件也很大,暂时就这样 print("pickle dump dataset...") pickle.dump(self.dataset, open(dataset_file, 'wb')) else: print("pickle load dataset...") self.dataset = pickle.load(open(dataset_file, 'rb')) print("dataset [%s] was created" % type(self.dataset).__name__) self.dataloader = torch.utils.data.DataLoader( self.dataset, batch_size=opt.batch_size, shuffle=not opt.serial_batches, num_workers=int(opt.num_threads), ) def load_data(self): print(f"The number of training images = {len(self)}") return self def __iter__(self): """Return a batch of data""" for i, data in enumerate(self.dataloader): if i * self.opt.batch_size >= self.opt.max_dataset_size: break yield data def __len__(self): """Return the number of data in the dataset""" return min(len(self.dataset), self.opt.max_dataset_size)