File size: 8,102 Bytes
58da73e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import argparse
from pathlib import Path

import torch

import data
import models
from util import util


class BaseOptions:
    """Command-line options shared by every phase of the model.

    The option classes for training, testing and inference all inherit from
    this class. Subclasses are expected to set ``self.isTrain`` (read by
    ``gather_options`` and ``parse``) and to provide a ``--phase`` option
    (read by ``print_options``) -- presumably in their own ``initialize``
    override; confirm against the subclasses.
    """

    def __init__(self):
        """Reset the object: mark it as not yet initialized."""
        self.initialized = False

    @staticmethod
    def _str2bool(value):
        """Convert a command-line spelling such as "true"/"False"/"1"/"no" to a bool.

        argparse with ``type=bool`` treats every non-empty string (including
        "False") as True, so an explicit converter is required.

        Raises:
            argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
        """
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ("true", "t", "yes", "y", "1"):
            return True
        if text in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    def initialize(self, parser):
        """Register the options common to all phases on ``parser``.

        Args:
            parser: an ``argparse.ArgumentParser`` to add arguments to.

        Returns:
            The same parser, with the common options added.
        """
        # basic parameters
        parser.add_argument("--dataroot", type=str, default="./datasets/horse2zebra/", help="path to images (should have subfolders trainA, trainB, valA, valB, etc)")
        parser.add_argument("--name", type=str, default="horse2zebra", help="name of the experiment.")
        parser.add_argument("--gpu_ids", type=str, default="0", help="gpu ids: e.g. 0  0,1,2 -1 for CPU")
        parser.add_argument("--checkpoints_dir", type=str, default="./checkpoints", help="models_saved")
        # model parameters
        parser.add_argument("--model", type=str, default="cycle_gan")
        parser.add_argument("--input_nc", type=int, default=3, help="# input image channels: 3 for RGB and 1 for grayscale")
        parser.add_argument("--output_nc", type=int, default=3, help="# output image channels: 3 for RGB and 1 for grayscale")
        parser.add_argument("--ngf", type=int, default=64, help="# of gen filters in the last conv layer")
        parser.add_argument("--ndf", type=int, default=64, help="# of discrim filters in the first conv layer")
        parser.add_argument(
            "--netD",
            type=str,
            default="basic",
            help="[basic | n_layers | pixel]. basic: a 70x70 PatchGAN. n_layers: allows you to specify the layers in the discriminator",
        )
        parser.add_argument("--netG", type=str, default="resnet_9blocks", help="[resnet_9blocks | resnet_6blocks | unet_256 | unet_128]")
        parser.add_argument("--n_layers_D", type=int, default=3, help="only used if netD==n_layers")
        parser.add_argument("--norm", type=str, default="instance", help="instance normalization or batch normalization [instance | batch | none]")
        parser.add_argument("--init_type", type=str, default="normal", help="network initialization [normal | xavier | kaiming | orthogonal]")
        parser.add_argument("--init_gain", type=float, default=0.02, help="scaling factor for normal, xavier and orthogonal.")
        # type=bool would turn "--no_dropout False" into True; use a real converter.
        # nargs="?" + const=True also accepts a bare "--no_dropout" flag.
        parser.add_argument(
            "--no_dropout",
            type=self._str2bool,
            nargs="?",
            const=True,
            default=True,
            help="no dropout for the generator",
        )
        # dataset parameters
        parser.add_argument("--dataset_mode", type=str, default="unaligned")
        parser.add_argument("--direction", type=str, default="AtoB", help="AtoB or BtoA")
        parser.add_argument("--serial_batches", action="store_true", help="if true, takes images in order to make batches, otherwise takes them randomly")
        parser.add_argument("--num_threads", default=8, type=int, help="# threads for loading data")
        parser.add_argument("--batch_size", type=int, default=1)
        parser.add_argument("--load_size", type=int, default=286, help="scale images to this size")
        parser.add_argument("--crop_size", type=int, default=256, help="then crop to this size")
        parser.add_argument(
            "--max_dataset_size",
            type=int,
            default=float("inf"),
            help="Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.",
        )
        parser.add_argument("--preprocess", type=str, default="resize_and_crop", help="[resize_and_crop | crop | scale_width | scale_width_and_crop | none] img preprocess")
        parser.add_argument("--no_flip", action="store_true", help="if specified, do not flip the images for data augmentation")
        parser.add_argument("--display_winsize", type=int, default=256, help="display window size for both visdom and HTML")
        # additional parameters
        parser.add_argument("--epoch", type=str, default="latest", help="which epoch to load? set to latest to use latest cached model")
        # Integer default (not "0") so parser.get_default() matches the parsed value
        # and print_options does not report a spurious "[default: 0]" difference.
        parser.add_argument(
            "--load_iter", type=int, default=0, help="which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]"
        )
        parser.add_argument("--verbose", action="store_true", help="if specified, print more debugging information")
        parser.add_argument("--suffix", default="", type=str, help="customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}")
        # wandb parameters
        parser.add_argument("--use_wandb", action="store_true", help="wandb logging")
        parser.add_argument("--wandb_project_name", type=str, default="CycleGAN", help="wandb project name")
        self.initialized = True
        return parser

    def gather_options(self):
        """Build the full parser and parse the command line.

        Registers the basic options, then lets the selected model and dataset
        add their own options through their ``modify_commandline_options``
        hooks (reached via ``models.get_option_setter`` /
        ``data.get_option_setter``). The final parser is stored on
        ``self.parser``.

        Returns:
            The parsed ``argparse.Namespace``. Unknown arguments are ignored.
        """
        # Always start from a fresh parser. The previous code only created one
        # when not yet initialized, leaving ``parser`` unbound (UnboundLocalError)
        # on a second call. Reusing ``self.parser`` is not an option either: it
        # already holds the model/dataset options, and re-adding them would make
        # argparse raise a conflicting-option error.
        parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser = self.initialize(parser)

        # get the basic options
        opt, _ = parser.parse_known_args()

        # modify model-related parser options
        model_option_setter = models.get_option_setter(opt.model)
        parser = model_option_setter(parser, self.isTrain)
        opt, _ = parser.parse_known_args()  # parse again with new defaults

        # modify dataset-related parser options
        dataset_option_setter = data.get_option_setter(opt.dataset_mode)
        parser = dataset_option_setter(parser, self.isTrain)

        # save and return the parser
        self.parser = parser
        # parse_known_args rather than parse_args: when launched under gradio /
        # uvicorn the process command line carries extra arguments (e.g.
        # "app:demo.app --reload --port 7860 ...") that this parser does not
        # know, and parse_args would exit with "unrecognized arguments".
        return parser.parse_known_args()[0]

    def print_options(self, opt):
        """Print all options and save them to disk.

        1. Prints every option with its value, and the default when it differs.
        2. Writes the same text to ``[checkpoints_dir]/[name]/[phase]_opt.txt``.

        Args:
            opt: parsed options namespace; must carry ``checkpoints_dir``,
                ``name`` and ``phase``.
        """
        lines = ["----------------- Options ---------------"]
        for key, value in sorted(vars(opt).items()):
            default = self.parser.get_default(key)
            comment = "\t[default: %s]" % str(default) if value != default else ""
            lines.append("{:>25}: {:<30}{}".format(str(key), str(value), comment))
        lines.append("----------------- End -------------------")
        message = "\n".join(lines)
        print(message)

        # save it to the disk
        expr_dir = Path(opt.checkpoints_dir, opt.name)
        util.mkdirs(expr_dir)
        file_name = Path(expr_dir, "{}_opt.txt".format(opt.phase))
        # Explicit encoding: the platform default (e.g. GBK on Windows) varies.
        with open(file_name, "wt", encoding="utf-8") as opt_file:
            opt_file.write(message + "\n")

    def parse(self):
        """Parse options, apply the name suffix, and select the GPU device.

        Returns:
            The parsed options namespace. ``gpu_ids`` has been converted to a
            list of non-negative ints; an empty list means CPU.
        """
        opt = self.gather_options()
        opt.isTrain = self.isTrain  # train or test

        # process opt.suffix
        if opt.suffix:
            opt.name = opt.name + "_" + opt.suffix.format(**vars(opt))

        self.print_options(opt)

        # set gpu ids; empty tokens (e.g. --gpu_ids "") are skipped instead of crashing int("")
        str_ids = [token for token in opt.gpu_ids.split(",") if token.strip()]
        opt.gpu_ids = [gpu_id for gpu_id in map(int, str_ids) if gpu_id >= 0]
        if opt.gpu_ids:
            torch.cuda.set_device(opt.gpu_ids[0])

        self.opt = opt
        return self.opt