# (source-view residue) File size: 8,102 bytes; commit 58da73e.
import argparse
from pathlib import Path
import torch
import data
import models
from util import util
class BaseOptions:
    """Model hyper-parameter / command-line option definitions.

    The option sets used for training, testing and inference all inherit from
    this class. ``gather_options`` and ``parse`` read ``self.isTrain``, which
    this class never sets, so subclasses must define it.
    """

    def __init__(self):
        """Reset the class: mark the parser as not yet initialized."""
        self.initialized = False

    @staticmethod
    def _str2bool(value):
        """Convert a command-line string such as "true" / "False" / "1" / "0" to a bool.

        ``type=bool`` cannot be used for this: ``bool("False")`` is ``True``, so
        every explicit value would silently become ``True``.

        Raises:
            argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
        """
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ("true", "t", "yes", "y", "1"):
            return True
        if text in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))

    def initialize(self, parser):
        """Register the options shared by every mode on ``parser`` and return it."""
        # basic parameters
        parser.add_argument("--dataroot", type=str, default="./datasets/horse2zebra/", help="path to images (should have subfolders trainA, trainB, valA, valB, etc)")
        parser.add_argument("--name", type=str, default="horse2zebra", help="name of the experiment.")
        parser.add_argument("--gpu_ids", type=str, default="0", help="gpu ids: e.g. 0 0,1,2 -1 for CPU")
        parser.add_argument("--checkpoints_dir", type=str, default="./checkpoints", help="models_saved")
        # model parameters
        parser.add_argument("--model", type=str, default="cycle_gan")
        parser.add_argument("--input_nc", type=int, default=3, help="# input image channels: 3 for RGB and 1 for grayscale")
        parser.add_argument("--output_nc", type=int, default=3, help="# output image channels: 3 for RGB and 1 for grayscale")
        parser.add_argument("--ngf", type=int, default=64, help="# of gen filters in the last conv layer")
        parser.add_argument("--ndf", type=int, default=64, help="# of discrim filters in the first conv layer")
        parser.add_argument(
            "--netD",
            type=str,
            default="basic",
            help="[basic | n_layers | pixel]. basic: a 70x70 PatchGAN. n_layers: allows you to specify the layers in the discriminator",
        )
        parser.add_argument("--netG", type=str, default="resnet_9blocks", help="[resnet_9blocks | resnet_6blocks | unet_256 | unet_128]")
        parser.add_argument("--n_layers_D", type=int, default=3, help="only used if netD==n_layers")
        parser.add_argument("--norm", type=str, default="instance", help="instance normalization or batch normalization [instance | batch | none]")
        parser.add_argument("--init_type", type=str, default="normal", help="network initialization [normal | xavier | kaiming | orthogonal]")
        parser.add_argument("--init_gain", type=float, default=0.02, help="scaling factor for normal, xavier and orthogonal.")
        # ``type=bool`` turned every explicit value (even "False") into True. Parse the
        # string properly instead; nargs="?" + const=True also accepts the bare flag
        # form ``--no_dropout``. The default is unchanged (True).
        parser.add_argument("--no_dropout", type=self._str2bool, nargs="?", const=True, default=True, help="no dropout for the generator")
        # dataset parameters
        parser.add_argument("--dataset_mode", type=str, default="unaligned")
        parser.add_argument("--direction", type=str, default="AtoB", help="AtoB or BtoA")
        parser.add_argument("--serial_batches", action="store_true", help="if true, takes images in order to make batches, otherwise takes them randomly")
        parser.add_argument("--num_threads", default=8, type=int, help="# threads for loading data")
        parser.add_argument("--batch_size", type=int, default=1)
        parser.add_argument("--load_size", type=int, default=286, help="scale images to this size")
        parser.add_argument("--crop_size", type=int, default=256, help="then crop to this size")
        parser.add_argument(
            "--max_dataset_size",
            type=int,
            default=float("inf"),
            help="Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.",
        )
        parser.add_argument("--preprocess", type=str, default="resize_and_crop", help="[resize_and_crop | crop | scale_width | scale_width_and_crop | none] img preprocess")
        parser.add_argument("--no_flip", action="store_true", help="if specified, do not flip the images for data augmentation")
        parser.add_argument("--display_winsize", type=int, default=256, help="display window size for both visdom and HTML")
        # additional parameters
        parser.add_argument("--epoch", type=str, default="latest", help="which epoch to load? set to latest to use latest cached model")
        # Integer default (was the string "0"): argparse converted the parsed value to
        # int 0 but get_default() still returned "0", so print_options always reported
        # this option as changed from its default.
        parser.add_argument(
            "--load_iter", type=int, default=0, help="which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]"
        )
        parser.add_argument("--verbose", action="store_true", help="if specified, print more debugging information")
        parser.add_argument("--suffix", default="", type=str, help="customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}")
        # wandb parameters
        parser.add_argument("--use_wandb", action="store_true", help="wandb logging")
        parser.add_argument("--wandb_project_name", type=str, default="CycleGAN", help="wandb project name")
        self.initialized = True
        return parser

    def gather_options(self):
        """Build the full parser (basic + model-specific + dataset-specific options) and parse.

        The basic options come from ``initialize``. Additional model- and
        dataset-specific options are added by the setters returned by
        ``models.get_option_setter`` / ``data.get_option_setter`` (the
        ``modify_commandline_options`` functions of the model and dataset
        classes). The complete parser is stored on ``self.parser``.

        Returns:
            argparse.Namespace of parsed options; unrecognised command-line
            arguments are ignored.
        """
        if self.initialized and getattr(self, "parser", None) is not None:
            # Already built by a previous call. Previously this path left ``parser``
            # unbound (UnboundLocalError); re-registering options would conflict,
            # so reuse the complete parser instead.
            return self.parser.parse_known_args()[0]

        parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser = self.initialize(parser)

        # Parse the basic options first: the chosen --model decides which extra options exist.
        opt, _ = parser.parse_known_args()
        model_option_setter = models.get_option_setter(opt.model)
        parser = model_option_setter(parser, self.isTrain)

        # Parse again with the model's new defaults; --dataset_mode decides the dataset options.
        opt, _ = parser.parse_known_args()
        dataset_option_setter = data.get_option_setter(opt.dataset_mode)
        parser = dataset_option_setter(parser, self.isTrain)

        self.parser = parser
        # parse_known_args() rather than parse_args(): when launched through gradio /
        # uvicorn the process command line carries extra arguments (e.g.
        # "app:demo.app --reload --port 7860 ...") that this parser does not define,
        # and parse_args() would abort with "unrecognized arguments".
        return parser.parse_known_args()[0]

    def print_options(self, opt):
        """Print every option (plus its default when it differs) and save the same text.

        The text is written to ``<checkpoints_dir>/<name>/<phase>_opt.txt``; the
        directory is created if needed.
        """
        lines = ["----------------- Options ---------------"]
        for key, value in sorted(vars(opt).items()):
            default = self.parser.get_default(key)
            comment = "\t[default: %s]" % str(default) if value != default else ""
            lines.append("{:>25}: {:<30}{}".format(str(key), str(value), comment))
        lines.append("----------------- End -------------------")
        message = "\n".join(lines)
        print(message)

        # save it to the disk
        expr_dir = Path(opt.checkpoints_dir, opt.name)
        util.mkdirs(expr_dir)
        file_name = Path(expr_dir, "{}_opt.txt".format(opt.phase))
        # Explicit encoding: option values (e.g. paths) may be non-ASCII, and the
        # platform default encoding is not guaranteed to handle them.
        with open(file_name, "wt", encoding="utf-8") as opt_file:
            opt_file.write(message)
            opt_file.write("\n")

    def parse(self):
        """Parse the options, apply the experiment-name suffix, and select the GPU device.

        Returns:
            The parsed options namespace (also stored on ``self.opt``). ``gpu_ids``
            is converted from a comma-separated string to a list of non-negative
            ints (negative ids such as -1 mean CPU and are dropped).
        """
        opt = self.gather_options()
        opt.isTrain = self.isTrain  # train or test

        # Optional suffix, formatted with the option values, e.g. "{model}_{netG}".
        if opt.suffix:
            opt.name = opt.name + "_" + opt.suffix.format(**vars(opt))

        self.print_options(opt)

        # Empty tokens (e.g. a trailing comma in "0,") are skipped instead of crashing int().
        ids = [int(token) for token in opt.gpu_ids.split(",") if token.strip()]
        opt.gpu_ids = [gpu_id for gpu_id in ids if gpu_id >= 0]
        if opt.gpu_ids:
            torch.cuda.set_device(opt.gpu_ids[0])

        self.opt = opt
        return self.opt