# CycleGAN/options/base_options.py
import argparse
from pathlib import Path
import torch
import data
import models
from util import util


class BaseOptions:
    """Shared model hyperparameter settings."""

    def __init__(self):
        """Reset the class; indicates that it has not been initialized yet.

        The option classes for training, testing, and inference all inherit from this class.
        """
        self.initialized = False

    def initialize(self, parser):
        """Define options shared by all stages (train/test/inference)."""
        # basic parameters
        parser.add_argument("--dataroot", type=str, default="./datasets/horse2zebra/", help="path to images (should have subfolders trainA, trainB, valA, valB, etc)")
        parser.add_argument("--name", type=str, default="horse2zebra", help="name of the experiment")
        parser.add_argument("--gpu_ids", type=str, default="0", help="gpu ids, e.g. 0 or 0,1,2; use -1 for CPU")
        parser.add_argument("--checkpoints_dir", type=str, default="./checkpoints", help="models are saved here")
        # model parameters
        parser.add_argument("--model", type=str, default="cycle_gan", help="which model to use, e.g. cycle_gan")
        parser.add_argument("--input_nc", type=int, default=3, help="# of input image channels: 3 for RGB and 1 for grayscale")
        parser.add_argument("--output_nc", type=int, default=3, help="# of output image channels: 3 for RGB and 1 for grayscale")
        parser.add_argument("--ngf", type=int, default=64, help="# of gen filters in the last conv layer")
        parser.add_argument("--ndf", type=int, default=64, help="# of discrim filters in the first conv layer")
        parser.add_argument(
            "--netD",
            type=str,
            default="basic",
            help="[basic | n_layers | pixel]. basic: a 70x70 PatchGAN. n_layers: allows you to specify the number of conv layers in the discriminator",
        )
        parser.add_argument("--netG", type=str, default="resnet_9blocks", help="[resnet_9blocks | resnet_6blocks | unet_256 | unet_128]")
        parser.add_argument("--n_layers_D", type=int, default=3, help="only used if netD==n_layers")
        parser.add_argument("--norm", type=str, default="instance", help="instance normalization or batch normalization [instance | batch | none]")
        parser.add_argument("--init_type", type=str, default="normal", help="network initialization [normal | xavier | kaiming | orthogonal]")
        parser.add_argument("--init_gain", type=float, default=0.02, help="scaling factor for normal, xavier and orthogonal")
        # note: with type=bool, any non-empty command-line value parses as True, so dropout stays disabled unless the default is changed in code
        parser.add_argument("--no_dropout", type=bool, default=True, help="no dropout for the generator")
        # dataset parameters
        parser.add_argument("--dataset_mode", type=str, default="unaligned", help="how the dataset is loaded, e.g. unaligned")
        parser.add_argument("--direction", type=str, default="AtoB", help="AtoB or BtoA")
        parser.add_argument("--serial_batches", action="store_true", help="if true, takes images in order to make batches, otherwise takes them randomly")
        parser.add_argument("--num_threads", default=8, type=int, help="# of threads for loading data")
        parser.add_argument("--batch_size", type=int, default=1, help="input batch size")
        parser.add_argument("--load_size", type=int, default=286, help="scale images to this size")
        parser.add_argument("--crop_size", type=int, default=256, help="then crop to this size")
        parser.add_argument(
            "--max_dataset_size",
            type=int,
            default=float("inf"),
            help="Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.",
        )
        parser.add_argument("--preprocess", type=str, default="resize_and_crop", help="image preprocessing [resize_and_crop | crop | scale_width | scale_width_and_crop | none]")
        parser.add_argument("--no_flip", action="store_true", help="if specified, do not flip the images for data augmentation")
        parser.add_argument("--display_winsize", type=int, default=256, help="display window size for both visdom and HTML")
        # additional parameters
        parser.add_argument("--epoch", type=str, default="latest", help="which epoch to load? set to latest to use the latest cached model")
        parser.add_argument(
            "--load_iter", type=int, default=0, help="which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]"
        )
        parser.add_argument("--verbose", action="store_true", help="if specified, print more debugging information")
        parser.add_argument("--suffix", default="", type=str, help="customized suffix: opt.name = opt.name + suffix, e.g. {model}_{netG}_size{load_size}")
        # wandb parameters
        parser.add_argument("--use_wandb", action="store_true", help="enable wandb logging")
        parser.add_argument("--wandb_project_name", type=str, default="CycleGAN", help="wandb project name")
        self.initialized = True
        return parser
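
    # Example invocation using these options (a sketch; it assumes this repository
    # keeps the upstream CycleGAN train.py entry point):
    #
    #   python train.py --dataroot ./datasets/horse2zebra --name horse2zebra \
    #       --model cycle_gan --netG resnet_9blocks --gpu_ids 0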

    def gather_options(self):
        """Initialize our parser with the basic options (only once).

        Add additional model-specific and dataset-specific options.
        These options are defined in the <modify_commandline_options> function
        in model and dataset classes.
        """
        if not self.initialized:  # check if it has been initialized
            parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
            parser = self.initialize(parser)

        # get the basic options
        opt, _ = parser.parse_known_args()

        # modify model-related parser options
        model_name = opt.model
        model_option_setter = models.get_option_setter(model_name)
        parser = model_option_setter(parser, self.isTrain)
        opt, _ = parser.parse_known_args()  # parse again with new defaults

        # modify dataset-related parser options
        dataset_name = opt.dataset_mode
        dataset_option_setter = data.get_option_setter(dataset_name)
        parser = dataset_option_setter(parser, self.isTrain)

        # save and return the parser
        self.parser = parser
        # return parser.parse_args()
        # parse_args() is not used here because it rejects unknown arguments: when running
        # under gradio, uvicorn passes its own command-line arguments that this parser does
        # not define, which would trigger the following error:
        # uvicorn: error: unrecognized arguments: app:demo.app --reload --port 7860 --log-level warning --reload-dir E:\miniconda3\envs\yanguan\lib\site-packages\gradio --reload-dir D:\projects\CycleGAN
        return parser.parse_known_args()[0]
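
    # For reference, each model/dataset class is expected to expose a static hook of
    # roughly this shape (a sketch following the upstream CycleGAN convention; the
    # exact options, e.g. --lambda_A, depend on the concrete class):
    #
    #     @staticmethod
    #     def modify_commandline_options(parser, is_train):
    #         if is_train:
    #             parser.add_argument("--lambda_A", type=float, default=10.0, help="weight for cycle loss (A -> B -> A)")
    #         return parser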

    def print_options(self, opt):
        """Print and save the options.

        1. Print the current options along with their default values (when they differ).
        2. Save the options to a text file: [checkpoints_dir]/[name]/[phase]_opt.txt
        """
        message = ""
        message += "----------------- Options ---------------\n"
        for k, v in sorted(vars(opt).items()):
            comment = ""
            default = self.parser.get_default(k)
            if v != default:
                comment = "\t[default: %s]" % str(default)
            message += "{:>25}: {:<30}{}\n".format(str(k), str(v), comment)
        message += "----------------- End -------------------"
        print(message)

        # save it to the disk
        expr_dir = Path(opt.checkpoints_dir, opt.name)
        util.mkdirs(expr_dir)
        file_name = Path(expr_dir, "{}_opt.txt".format(opt.phase))
        with open(file_name, "wt") as opt_file:
            opt_file.write(message)
            opt_file.write("\n")

    def parse(self):
        """Parse the options, append the optional suffix to the experiment name, and set the GPU device."""
        opt = self.gather_options()
        opt.isTrain = self.isTrain  # train or test

        # process opt.suffix (e.g. "--suffix {model}_{netG}" appends "_cycle_gan_resnet_9blocks" to opt.name)
        if opt.suffix:
            suffix = "_" + opt.suffix.format(**vars(opt))
            opt.name = opt.name + suffix

        self.print_options(opt)

        # set gpu ids
        _ids = [int(str_id) for str_id in opt.gpu_ids.split(",")]
        opt.gpu_ids = [_id for _id in _ids if _id >= 0]
        if len(opt.gpu_ids) > 0:
            torch.cuda.set_device(opt.gpu_ids[0])

        self.opt = opt
        return self.opt
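

if __name__ == "__main__":
    # Minimal usage sketch. In the real project the train/test scripts use
    # TrainOptions/TestOptions subclasses; the class below is hypothetical and only
    # illustrates what a subclass must provide (isTrain and, for print_options, a
    # --phase option). Run from the repo root; pass --gpu_ids -1 on a CPU-only machine.
    class DemoOptions(BaseOptions):
        def initialize(self, parser):
            parser = BaseOptions.initialize(self, parser)
            parser.add_argument("--phase", type=str, default="demo", help="used in the saved options file name")
            self.isTrain = False
            return parser

    opt = DemoOptions().parse()
    print("experiment name:", opt.name, "| gpu_ids:", opt.gpu_ids)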