|
import argparse |
|
from pathlib import Path |
|
|
|
import torch |
|
|
|
import data |
|
import models |
|
from util import util |
|
|
|
|
|
class BaseOptions:
    """Model hyperparameter settings shared by all option sets.

    Defines the command-line options common to training, testing and
    inference. This class reads ``self.isTrain`` (in ``gather_options`` and
    ``parse``) and ``opt.phase`` (in ``print_options``) but never defines
    them -- presumably the train/test subclasses do; verify against callers.
    """

    def __init__(self):
        """Reset the class: mark it as not yet initialized.

        The option classes for training, testing and inference all inherit
        from this class.
        """
        self.initialized = False

    @staticmethod
    def _str2bool(value):
        """Convert a command-line string to a bool (argparse ``type=`` helper).

        ``type=bool`` is an argparse pitfall: ``bool("False")`` is True, so
        every non-empty value would parse as True. This accepts the usual
        true/false spellings instead.

        Raises:
            argparse.ArgumentTypeError: if ``value`` is not a recognized
                boolean spelling.
        """
        if isinstance(value, bool):
            return value
        lowered = str(value).strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))

    def initialize(self, parser):
        """Register the options shared by every phase on ``parser``.

        Args:
            parser: an ``argparse.ArgumentParser`` to extend.

        Returns:
            The same parser, with the common options added.
        """
        # Basic experiment / storage parameters.
        parser.add_argument("--dataroot", type=str, default="./datasets/horse2zebra/", help="path to images (should have subfolders trainA, trainB, valA, valB, etc)")
        parser.add_argument("--name", type=str, default="horse2zebra", help="name of the experiment.")
        parser.add_argument("--gpu_ids", type=str, default="0", help="gpu ids: e.g. 0 0,1,2 -1 for CPU")
        parser.add_argument("--checkpoints_dir", type=str, default="./checkpoints", help="models_saved")

        # Model parameters.
        parser.add_argument("--model", type=str, default="cycle_gan")
        parser.add_argument("--input_nc", type=int, default=3, help="# input image channels: 3 for RGB and 1 for grayscale")
        parser.add_argument("--output_nc", type=int, default=3, help="# output image channels: 3 for RGB and 1 for grayscale")
        parser.add_argument("--ngf", type=int, default=64, help="# of gen filters in the last conv layer")
        parser.add_argument("--ndf", type=int, default=64, help="# of discrim filters in the first conv layer")
        parser.add_argument(
            "--netD",
            type=str,
            default="basic",
            help="[basic | n_layers | pixel]. basic: a 70x70 PatchGAN. n_layers: allows you to specify the layers in the discriminator",
        )
        parser.add_argument("--netG", type=str, default="resnet_9blocks", help="[resnet_9blocks | resnet_6blocks | unet_256 | unet_128]")
        parser.add_argument("--n_layers_D", type=int, default=3, help="only used if netD==n_layers")
        parser.add_argument("--norm", type=str, default="instance", help="instance normalization or batch normalization [instance | batch | none]")
        parser.add_argument("--init_type", type=str, default="normal", help="network initialization [normal | xavier | kaiming | orthogonal]")
        parser.add_argument("--init_gain", type=float, default=0.02, help="scaling factor for normal, xavier and orthogonal.")
        # Defaults to True. Accepts an explicit value ("--no_dropout False") or
        # a bare flag ("--no_dropout", which means True). type=bool would turn
        # any non-empty string, including "False", into True.
        parser.add_argument(
            "--no_dropout",
            type=self._str2bool,
            nargs="?",
            const=True,
            default=True,
            help="no dropout for the generator",
        )

        # Dataset parameters.
        parser.add_argument("--dataset_mode", type=str, default="unaligned")
        parser.add_argument("--direction", type=str, default="AtoB", help="AtoB or BtoA")
        parser.add_argument("--serial_batches", action="store_true", help="if true, takes images in order to make batches, otherwise takes them randomly")
        parser.add_argument("--num_threads", default=8, type=int, help="# threads for loading data")
        parser.add_argument("--batch_size", type=int, default=1)
        parser.add_argument("--load_size", type=int, default=286, help="scale images to this size")
        parser.add_argument("--crop_size", type=int, default=256, help="then crop to this size")
        parser.add_argument(
            "--max_dataset_size",
            type=int,
            default=float("inf"),
            help="Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.",
        )
        parser.add_argument("--preprocess", type=str, default="resize_and_crop", help="[resize_and_crop | crop | scale_width | scale_width_and_crop | none] img preprocess")
        parser.add_argument("--no_flip", action="store_true", help="if specified, do not flip the images for data augmentation")
        parser.add_argument("--display_winsize", type=int, default=256, help="display window size for both visdom and HTML")

        # Additional parameters.
        parser.add_argument("--epoch", type=str, default="latest", help="which epoch to load? set to latest to use latest cached model")
        # An int default (not the string "0") keeps get_default() equal to the
        # parsed value, so print_options does not flag a spurious difference.
        parser.add_argument(
            "--load_iter", type=int, default=0, help="which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]"
        )
        parser.add_argument("--verbose", action="store_true", help="if specified, print more debugging information")
        parser.add_argument("--suffix", default="", type=str, help="customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}")

        # Logging parameters.
        parser.add_argument("--use_wandb", action="store_true", help="wandb logging")
        parser.add_argument("--wandb_project_name", type=str, default="CycleGAN", help="wandb project name")

        self.initialized = True
        return parser

    def gather_options(self):
        """Build the parser (once) and return the parsed option namespace.

        The basic options are registered first. Model-specific options come
        from ``models.get_option_setter(opt.model)`` and dataset-specific
        options from ``data.get_option_setter(opt.dataset_mode)``. Those are
        the ``modify_commandline_options`` hooks of the model and dataset
        classes. Unknown command-line arguments are ignored
        (``parse_known_args``).

        Returns:
            argparse.Namespace with all recognized options.
        """
        if self.initialized and getattr(self, "parser", None) is not None:
            # The full parser (with model/dataset options) already exists.
            # Re-registering those options would raise an argparse conflict,
            # so just re-parse.
            return self.parser.parse_known_args()[0]

        parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser = self.initialize(parser)

        # Parse the basic options first to find out which model was chosen.
        opt, _ = parser.parse_known_args()

        # Add model-specific options. self.isTrain is expected to be set by
        # a subclass (not defined here).
        model_option_setter = models.get_option_setter(opt.model)
        parser = model_option_setter(parser, self.isTrain)
        opt, _ = parser.parse_known_args()  # re-parse with the new options

        # Add dataset-specific options.
        dataset_option_setter = data.get_option_setter(opt.dataset_mode)
        parser = dataset_option_setter(parser, self.isTrain)

        # Keep the final parser; print_options reads defaults from it.
        self.parser = parser
        return parser.parse_known_args()[0]

    def print_options(self, opt):
        """Print the options and save them to disk.

        1. Print every option, also showing its default when the value differs.
        2. Save the same text to ``<checkpoints_dir>/<name>/<phase>_opt.txt``
           (``opt.phase`` is not defined in this class; presumably a subclass
           adds it).
        """
        lines = ["----------------- Options ---------------"]
        for key, value in sorted(vars(opt).items()):
            default = self.parser.get_default(key)
            comment = "\t[default: %s]" % str(default) if value != default else ""
            lines.append("{:>25}: {:<30}{}".format(str(key), str(value), comment))
        lines.append("----------------- End -------------------")
        message = "\n".join(lines)
        print(message)

        # Save to disk.
        expr_dir = Path(opt.checkpoints_dir, opt.name)
        util.mkdirs(expr_dir)
        file_name = expr_dir / "{}_opt.txt".format(opt.phase)
        # Explicit encoding: option values such as paths may be non-ASCII.
        with open(file_name, "wt", encoding="utf-8") as opt_file:
            opt_file.write(message)
            opt_file.write("\n")

    def parse(self):
        """Parse options, apply the experiment-name suffix, and set the GPU device.

        Returns:
            The parsed option namespace (also stored as ``self.opt``).
            ``gpu_ids`` is converted to a list of non-negative ints, so an
            empty list means CPU.
        """
        opt = self.gather_options()
        opt.isTrain = self.isTrain  # set by a subclass (train vs. test)

        # Name suffix, e.g. "{model}_{netG}_size{load_size}", formatted with
        # the option values.
        if opt.suffix:
            opt.name = opt.name + "_" + opt.suffix.format(**vars(opt))

        self.print_options(opt)

        # Parse the GPU ids. Empty items (e.g. "" or a trailing comma) are
        # skipped; negative ids (e.g. -1) mean CPU.
        ids = [int(item) for item in opt.gpu_ids.split(",") if item.strip()]
        opt.gpu_ids = [gpu_id for gpu_id in ids if gpu_id >= 0]
        if opt.gpu_ids:
            torch.cuda.set_device(opt.gpu_ids[0])

        self.opt = opt
        return self.opt
|
|