File size: 3,446 Bytes
58da73e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from pathlib import Path

from .base_options import BaseOptions


class TrainOptions(BaseOptions):
    """Command-line options that only apply to training.

    Extends the shared arguments registered by ``BaseOptions.initialize``
    with three groups of training flags: progress visualization/logging,
    checkpoint saving and resuming, and optimizer / learning-rate schedule.
    It also points the ``dataroot`` and ``name`` defaults at horse2zebra.
    """

    def initialize(self, parser):
        """Add the training-only arguments to *parser* and return it.

        ``self.isTrain`` is set before delegating to
        ``BaseOptions.initialize``, so the base class observes training mode
        while registering its own arguments.

        Args:
            parser: the parser to extend. Presumably an
                ``argparse.ArgumentParser``; its concrete type is decided by
                ``BaseOptions`` (not visible here).

        Returns:
            The parser returned by ``BaseOptions.initialize``, with the
            training arguments and default overrides applied.
        """
        self.isTrain = True
        parser = BaseOptions.initialize(self, parser)

        # Visualization of training progress (visdom panel, HTML page, console).
        visualization_specs = (
            ("--display_freq", dict(type=int, default=400,
                                    help="frequency of showing training results on screen")),
            ("--display_ncols", dict(type=int, default=4,
                                     help="if positive, display all images in a single visdom web panel with certain number of images per row.")),
            ("--display_id", dict(type=int, default=1,
                                  help="window id of the web display")),
            ("--display_server", dict(type=str, default="http://localhost",
                                      help="visdom server of the web display")),
            ("--display_env", dict(type=str, default="main",
                                   help="visdom display environment name")),
            ("--display_port", dict(type=int, default=8097,
                                    help="visdom port of the web display")),
            ("--update_html_freq", dict(type=int, default=1000,
                                        help="frequency of saving training results to html")),
            ("--print_freq", dict(type=int, default=100,
                                  help="frequency of showing training results on console")),
            ("--no_html", dict(action="store_true",
                               help="do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/")),
        )

        # Saving and loading network checkpoints.
        checkpoint_specs = (
            ("--save_latest_freq", dict(type=int, default=5000,
                                        help="frequency of saving the latest results")),
            ("--save_epoch_freq", dict(type=int, default=1,
                                       help="frequency of saving checkpoints at the end of epochs")),
            ("--save_by_iter", dict(action="store_true",
                                    help="whether saves model by iteration")),
            ("--continue_train", dict(action="store_true",
                                      help="continue training: load the latest model")),
            ("--epoch_count", dict(type=int, default=1,
                                   help="the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...")),
            ("--phase", dict(type=str, default="train",
                             help="train, val, test, etc")),
        )

        # Optimization: epoch budget, Adam settings, GAN loss, LR schedule.
        optimization_specs = (
            ("--n_epochs", dict(type=int, default=200,
                                help="number of epochs with the initial learning rate")),
            ("--n_epochs_decay", dict(type=int, default=100,
                                      help="number of epochs to linearly decay learning rate to zero")),
            ("--beta1", dict(type=float, default=0.5,
                             help="momentum term of adam")),
            ("--lr", dict(type=float, default=0.0002,
                          help="initial learning rate for adam")),
            ("--gan_mode", dict(type=str, default="lsgan",
                                help="the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss used in the original GAN")),
            ("--pool_size", dict(type=int, default=50,
                                 help="the size of image buffer that stores previously generated images")),
            ("--lr_policy", dict(type=str, default="linear",
                                 help="learning rate policy. [linear | step | plateau | cosine]")),
            ("--lr_decay_iters", dict(type=int, default=50,
                                      help="multiply by a gamma every lr_decay_iters iterations")),
        )

        # Register in the original order so --help output is unchanged.
        for flag, spec in (*visualization_specs, *checkpoint_specs, *optimization_specs):
            parser.add_argument(flag, **spec)

        # Training-time default dataset location and experiment name.
        parser.set_defaults(dataroot="./datasets/horse2zebra/", name="horse2zebra")
        return parser