import argparse
from pathlib import Path

import torch

import data
import models
from util import util


def _str2bool(value):
    """Convert a command-line token such as ``"true"`` / ``"false"`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is ``True`` for
    every non-empty string -- including ``"False"`` -- so boolean options that
    take a value need an explicit converter.

    Args:
        value: The raw command-line string (or an already-converted bool).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling; argparse reports this as a normal usage error.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("1", "true", "t", "yes", "y"):
        return True
    if lowered in ("0", "false", "f", "no", "n"):
        return False
    raise argparse.ArgumentTypeError("boolean value expected, got %r" % (value,))


class BaseOptions:
    """Model hyper-parameter / option definitions shared by every run mode.

    The train, test and inference option classes all inherit from this class.
    ``self.isTrain`` is read by ``gather_options`` and ``parse`` but is never
    assigned here, so subclasses must set it.
    """

    def __init__(self):
        """Reset the instance: mark the option parser as not yet initialized."""
        self.initialized = False

    def initialize(self, parser):
        """Register the options common to all run modes on ``parser``.

        Args:
            parser: The ``argparse.ArgumentParser`` to add arguments to.

        Returns:
            The same parser with the common arguments added.
        """
        # basic parameters
        parser.add_argument("--dataroot", type=str, default="./datasets/horse2zebra/", help="path to images (should have subfolders trainA, trainB, valA, valB, etc)")
        parser.add_argument("--name", type=str, default="horse2zebra", help="name of the experiment.")
        parser.add_argument("--gpu_ids", type=str, default="0", help="gpu ids: e.g. 0 0,1,2 -1 for CPU")
        parser.add_argument("--checkpoints_dir", type=str, default="./checkpoints", help="models_saved")

        # model parameters
        parser.add_argument("--model", type=str, default="cycle_gan")
        parser.add_argument("--input_nc", type=int, default=3, help="# input image channels: 3 for RGB and 1 for grayscale")
        parser.add_argument("--output_nc", type=int, default=3, help="# output image channels: 3 for RGB and 1 for grayscale")
        parser.add_argument("--ngf", type=int, default=64, help="# of gen filters in the last conv layer")
        parser.add_argument("--ndf", type=int, default=64, help="# of discrim filters in the first conv layer")
        parser.add_argument(
            "--netD",
            type=str,
            default="basic",
            help="[basic | n_layers | pixel]. basic: a 70x70 PatchGAN. n_layers: allows you to specify the layers in the discriminator",
        )
        parser.add_argument("--netG", type=str, default="resnet_9blocks", help="[resnet_9blocks | resnet_6blocks | unet_256 | unet_128]")
        parser.add_argument("--n_layers_D", type=int, default=3, help="only used if netD==n_layers")
        parser.add_argument("--norm", type=str, default="instance", help="instance normalization or batch normalization [instance | batch | none]")
        parser.add_argument("--init_type", type=str, default="normal", help="network initialization [normal | xavier | kaiming | orthogonal]")
        parser.add_argument("--init_gain", type=float, default=0.02, help="scaling factor for normal, xavier and orthogonal.")
        # type=bool would turn "--no_dropout False" into True; _str2bool parses
        # the value properly. nargs="?" + const=True also lets the bare flag
        # "--no_dropout" work, and the default stays True as before.
        parser.add_argument("--no_dropout", type=_str2bool, nargs="?", const=True, default=True, help="no dropout for the generator")

        # dataset parameters
        parser.add_argument("--dataset_mode", type=str, default="unaligned")
        parser.add_argument("--direction", type=str, default="AtoB", help="AtoB or BtoA")
        parser.add_argument("--serial_batches", action="store_true", help="if true, takes images in order to make batches, otherwise takes them randomly")
        parser.add_argument("--num_threads", default=8, type=int, help="# threads for loading data")
        parser.add_argument("--batch_size", type=int, default=1)
        parser.add_argument("--load_size", type=int, default=286, help="scale images to this size")
        parser.add_argument("--crop_size", type=int, default=256, help="then crop to this size")
        parser.add_argument(
            "--max_dataset_size",
            type=int,
            default=float("inf"),
            help="Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.",
        )
        parser.add_argument("--preprocess", type=str, default="resize_and_crop", help="[resize_and_crop | crop | scale_width | scale_width_and_crop | none] img preprocess")
        parser.add_argument("--no_flip", action="store_true", help="if specified, do not flip the images for data augmentation")
        parser.add_argument("--display_winsize", type=int, default=256, help="display window size for both visdom and HTML")

        # additional parameters
        parser.add_argument("--epoch", type=str, default="latest", help="which epoch to load? set to latest to use latest cached model")
        parser.add_argument(
            "--load_iter",
            type=int,
            default=0,
            help="which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]",
        )
        parser.add_argument("--verbose", action="store_true", help="if specified, print more debugging information")
        parser.add_argument("--suffix", default="", type=str, help="customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}")

        # wandb parameters
        parser.add_argument("--use_wandb", action="store_true", help="wandb logging")
        parser.add_argument("--wandb_project_name", type=str, default="CycleGAN", help="wandb project name")

        self.initialized = True
        return parser

    def gather_options(self):
        """Build the option parser (only once) and parse the command line.

        On the first call, the common options from ``initialize`` are
        registered. The parser is then extended with the model-specific options
        (``models.get_option_setter(opt.model)``) and the dataset-specific
        options (``data.get_option_setter(opt.dataset_mode)``), and stored on
        ``self.parser``. Later calls reuse that parser, because registering the
        same arguments again would make argparse raise a conflict error.

        Unknown command-line arguments are ignored rather than rejected (see
        the comment at the return statement).

        Returns:
            The parsed ``argparse.Namespace``.
        """
        if self.initialized and getattr(self, "parser", None) is not None:
            return self.parser.parse_known_args()[0]

        parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser = self.initialize(parser)

        # get the basic options
        opt, _ = parser.parse_known_args()

        # modify model-related parser options
        model_option_setter = models.get_option_setter(opt.model)
        parser = model_option_setter(parser, self.isTrain)
        opt, _ = parser.parse_known_args()  # parse again with new defaults

        # modify dataset-related parser options
        dataset_option_setter = data.get_option_setter(opt.dataset_mode)
        parser = dataset_option_setter(parser, self.isTrain)

        # save and return the parser
        self.parser = parser
        # parse_known_args instead of parse_args: when launched through gradio,
        # the process command line carries the server's own arguments, which
        # this parser does not define. parse_args would then fail with e.g.
        #   uvicorn: error: unrecognized arguments: app:demo.app --reload --port 7860 ...
        return parser.parse_known_args()[0]

    def print_options(self, opt):
        """Print all options and save them to a text file.

        Each option is printed with its current value. When the value differs
        from the parser default, the default is shown as well. The same text is
        written to ``<checkpoints_dir>/<name>/<phase>_opt.txt``.

        Args:
            opt: Parsed options namespace. It must provide ``checkpoints_dir``,
                ``name`` and ``phase``. ``phase`` is not defined in this class;
                presumably a subclass adds it.
        """
        message = ""
        message += "----------------- Options ---------------\n"
        for k, v in sorted(vars(opt).items()):
            comment = ""
            default = self.parser.get_default(k)
            if v != default:
                comment = "\t[default: %s]" % str(default)
            message += "{:>25}: {:<30}{}\n".format(str(k), str(v), comment)
        message += "----------------- End -------------------"
        print(message)

        # save it to the disk
        expr_dir = Path(opt.checkpoints_dir, opt.name)
        util.mkdirs(expr_dir)
        file_name = Path(expr_dir, "{}_opt.txt".format(opt.phase))
        # Explicit UTF-8 so values with non-ASCII characters (e.g. dataset
        # paths) are written the same way regardless of the platform locale.
        with open(file_name, "wt", encoding="utf-8") as opt_file:
            opt_file.write(message)
            opt_file.write("\n")

    def parse(self):
        """Parse options, apply the name suffix, save the options and select the GPU.

        Returns:
            The parsed ``argparse.Namespace``, also stored on ``self.opt``. It
            has ``isTrain`` set, and ``gpu_ids`` is converted to a list of
            non-negative ints (an empty list means no GPU is selected).
        """
        opt = self.gather_options()
        opt.isTrain = self.isTrain  # train or test

        # The suffix template is formatted with the option values and appended
        # to the experiment name as "_<formatted suffix>".
        if opt.suffix:
            opt.name = opt.name + "_" + opt.suffix.format(**vars(opt))

        self.print_options(opt)

        # "0,1,2" -> [0, 1, 2]. Negative ids (e.g. "-1") are dropped, and empty
        # tokens (an empty string or a trailing comma) are skipped instead of
        # crashing int().
        ids = [int(token) for token in opt.gpu_ids.split(",") if token.strip()]
        opt.gpu_ids = [gpu_id for gpu_id in ids if gpu_id >= 0]
        if opt.gpu_ids:
            torch.cuda.set_device(opt.gpu_ids[0])

        self.opt = opt
        return self.opt