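"""Training entry point for AnimeGAN (v1 / v2 / v3).

Example single-GPU run (illustrative; the paths below are just the argparse
defaults and may need adapting to your dataset layout):

    python train.py --model v2 --real_image_dir dataset/train_photo \
        --anime_image_dir dataset/Hayao --batch_size 8 --amp
"""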
import torch
import argparse
import os
from models.anime_gan import GeneratorV1
from models.anime_gan_v2 import GeneratorV2
from models.anime_gan_v3 import GeneratorV3
from models.anime_gan import Discriminator
from datasets import AnimeDataSet
from utils.common import load_checkpoint
from trainer import Trainer
from utils.logger import get_logger


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--real_image_dir', type=str, default='dataset/train_photo')
    parser.add_argument('--anime_image_dir', type=str, default='dataset/Hayao')
    parser.add_argument('--test_image_dir', type=str, default='dataset/test/HR_photo')
    parser.add_argument('--model', type=str, default='v1', help="AnimeGAN version, can be {'v1', 'v2', 'v3'}")
    parser.add_argument('--epochs', type=int, default=70)
    parser.add_argument('--init_epochs', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--exp_dir', type=str, default='runs', help="Experiment directory")
    parser.add_argument('--gan_loss', type=str, default='lsgan', help='lsgan / hinge / bce')
    parser.add_argument('--resume', action='store_true', help="Resume training from the checkpoints in the experiment directory")
    parser.add_argument('--resume_G_init', type=str, default='False', help="Path to a G checkpoint from the init (content) phase, or 'False'")
    parser.add_argument('--resume_G', type=str, default='False', help="Path to a G checkpoint, or 'False'")
    parser.add_argument('--resume_D', type=str, default='False', help="Path to a D checkpoint, or 'False'")
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--use_sn', action='store_true')
    parser.add_argument('--cache', action='store_true', help="Turn on disk cache")
    parser.add_argument('--amp', action='store_true', help="Turn on Automatic Mixed Precision")
    parser.add_argument('--save_interval', type=int, default=1)
    parser.add_argument('--debug_samples', type=int, default=0)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--imgsz', type=int, nargs="+", default=[256],
                        help="Training image size(s); if several values are given, the image size increases after a proportion of the epochs")
    parser.add_argument('--resize_method', type=str, default="crop",
                        help="Method used to resize photos that are larger than imgsz")
    # Loss stuff
    parser.add_argument('--lr_g', type=float, default=2e-5)
    parser.add_argument('--lr_d', type=float, default=4e-5)
    parser.add_argument('--init_lr', type=float, default=1e-4)
    parser.add_argument('--wadvg', type=float, default=300.0, help='Adversarial loss weight for G')
    parser.add_argument('--wadvd', type=float, default=300.0, help='Adversarial loss weight for D')
    parser.add_argument(
        '--gray_adv', action='store_true',
        help="If given, compute the adversarial loss on grayscale images instead of RGB to reduce the color influence of the anime style")
    # Loss weight VGG19
    parser.add_argument('--wcon', type=float, default=1.5, help='Content loss weight') #  1.5 for Hayao, 2.0 for Paprika, 1.2 for Shinkai
    parser.add_argument('--wgra', type=float, default=5.0, help='Gram loss weight') # 2.5 for Hayao, 0.6 for Paprika, 2.0 for Shinkai
    parser.add_argument('--wcol', type=float, default=30.0, help='Color loss weight') # 15. for Hayao, 50. for Paprika, 10. for Shinkai
    parser.add_argument('--wtvar', type=float, default=1.0, help='Total variation loss') # 1. for Hayao, 0.1 for Paprika, 1. for Shinkai
    parser.add_argument('--d_layers', type=int, default=2, help='Discriminator conv layers')
    parser.add_argument('--d_noise', action='store_true')

    # DDP
    parser.add_argument('--ddp', action='store_true')
    parser.add_argument("--local-rank", default=0, type=int)
    parser.add_argument("--world-size", default=2, type=int)

    return parser.parse_args()


def check_params(args):
    # e.g. dataset/train_photo + dataset/Hayao -> train_photo_Hayao
    args.dataset = f"{os.path.basename(args.real_image_dir)}_{os.path.basename(args.anime_image_dir)}"
    assert args.model in {'v1', 'v2', 'v3'}, f'{args.model} is not supported'
    assert args.gan_loss in {'lsgan', 'hinge', 'bce'}, f'{args.gan_loss} is not supported'


def main(args, logger):
    check_params(args)

    if not torch.cuda.is_available():
        logger.info("CUDA not found, use CPU")
        # Just for debugging purpose, set to minimum config
        # to avoid 🔥 the computer...
        args.device = 'cpu'
        args.debug_samples = 10
        args.batch_size = 2
    else:
        logger.info(f"Use GPU: {torch.cuda.get_device_name(0)}")

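    # Discriminator normalization: the v2 setup uses layer norm, v1/v3 keep instance norm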
    norm_type = "instance"
    if args.model == 'v1':
        G = GeneratorV1(args.dataset)
    elif args.model == 'v2':
        G = GeneratorV2(args.dataset)
        norm_type = "layer"
    elif args.model == 'v3':
        G = GeneratorV3(args.dataset)

    D = Discriminator(
        args.dataset,
        num_layers=args.d_layers,
        use_sn=args.use_sn,
        norm_type=norm_type,
    )

    start_e = 0
    start_e_init = 0

    trainer = Trainer(
        generator=G,
        discriminator=D,
        config=args,
        logger=logger,
    )

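    # Resume priority: an init-phase G checkpoint, then an explicit G + D pair,
    # then whatever the trainer saved in its own checkpoint paths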
    if args.resume_G_init.lower() != 'false':
        start_e_init = load_checkpoint(G, args.resume_G_init) + 1
        if args.local_rank == 0:
            logger.info(f"G content weight loaded from {args.resume_G_init}")
    elif args.resume_G.lower() != 'false' and args.resume_D.lower() != 'false':
        # Both G and D checkpoints must be provided
        try:
            start_e = load_checkpoint(G, args.resume_G)
            if args.local_rank == 0:
                logger.info(f"G weight loaded from {args.resume_G}")
            load_checkpoint(D, args.resume_D)
            if args.local_rank == 0:
                logger.info(f"D weight loaded from {args.resume_D}")
            # Both weights are loaded, so skip the G initialization phase
            args.init_epochs = 0

        except Exception as e:
            logger.info(f'Could not load checkpoint, training from scratch: {e}')
    elif args.resume:
        # Try to load from working dir
        logger.info(f"Loading weight from {trainer.checkpoint_path_G}")
        start_e = load_checkpoint(G, trainer.checkpoint_path_G)
        logger.info(f"Loading weight from {trainer.checkpoint_path_D}")
        load_checkpoint(D, trainer.checkpoint_path_D)
        args.init_epochs = 0
    
    dataset = AnimeDataSet(
        args.anime_image_dir,
        args.real_image_dir,
        args.debug_samples,
        args.cache,
        imgsz=args.imgsz,
        resize_method=args.resize_method,
    )
    if args.local_rank == 0:
        logger.info(f"Start from epoch {start_e}, {start_e_init}")
    trainer.train(dataset, start_e, start_e_init)

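# A multi-GPU run would typically go through the PyTorch distributed launcher,
# e.g. (illustrative sketch; assumes 2 GPUs and that Trainer initializes the
# process group from the --ddp / --local-rank / --world-size options):
#
#   python -m torch.distributed.launch --nproc_per_node=2 train.py --ddp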
if __name__ == '__main__':
    args = parse_args()
    real_name = os.path.basename(args.real_image_dir)
    anime_name = os.path.basename(args.anime_image_dir)
    args.exp_dir = f"{args.exp_dir}_{real_name}_{anime_name}"

    os.makedirs(args.exp_dir, exist_ok=True)
    logger = get_logger(os.path.join(args.exp_dir, "train.log"))

    if args.local_rank == 0:
        logger.info("# ==== Train Config ==== #")
        for arg in vars(args):
            logger.info(f"{arg} {getattr(args, arg)}")
        logger.info("==========================")

    main(args, logger)