# AnimeGAN training entry point (pretrain G on content, then adversarial training).
import torch
import argparse
import os
from models.anime_gan import GeneratorV1
from models.anime_gan_v2 import GeneratorV2
from models.anime_gan_v3 import GeneratorV3
from models.anime_gan import Discriminator
from datasets import AnimeDataSet
from utils.common import load_checkpoint
from trainer import Trainer
from utils.logger import get_logger
def parse_args(argv=None):
    """Build and parse the training command-line interface.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``,
            in which case argparse reads ``sys.argv[1:]`` — so existing
            callers of ``parse_args()`` are unaffected. Passing an explicit
            list makes the parser testable and embeddable.

    Returns:
        argparse.Namespace with all training options.
    """
    parser = argparse.ArgumentParser()
    # --- Data locations and model selection ---
    parser.add_argument('--real_image_dir', type=str, default='dataset/train_photo')
    parser.add_argument('--anime_image_dir', type=str, default='dataset/Hayao')
    parser.add_argument('--test_image_dir', type=str, default='dataset/test/HR_photo')
    parser.add_argument('--model', type=str, default='v1', help="AnimeGAN version, can be {'v1', 'v2', 'v3'}")
    # --- Schedule / run configuration ---
    parser.add_argument('--epochs', type=int, default=70)
    parser.add_argument('--init_epochs', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--exp_dir', type=str, default='runs', help="Experiment directory")
    parser.add_argument('--gan_loss', type=str, default='lsgan', help='lsgan / hinge / bce')
    # --- Checkpoint resumption ---
    # NOTE: the --resume_* options are string flags compared against 'False'
    # downstream; any other value is treated as a checkpoint path.
    parser.add_argument('--resume', action='store_true', help="Continue from current dir")
    parser.add_argument('--resume_G_init', type=str, default='False')
    parser.add_argument('--resume_G', type=str, default='False')
    parser.add_argument('--resume_D', type=str, default='False')
    # --- Runtime / hardware ---
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--use_sn', action='store_true')
    parser.add_argument('--cache', action='store_true', help="Turn on disk cache")
    parser.add_argument('--amp', action='store_true', help="Turn on Automatic Mixed Precision")
    parser.add_argument('--save_interval', type=int, default=1)
    parser.add_argument('--debug_samples', type=int, default=0)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--imgsz', type=int, nargs="+", default=[256],
                        help="Image sizes, can provide multiple values, image size will increase after a proportion of epochs")
    parser.add_argument('--resize_method', type=str, default="crop",
                        help="Resize image method if origin photo larger than imgsz")
    # --- Loss stuff ---
    parser.add_argument('--lr_g', type=float, default=2e-5)
    parser.add_argument('--lr_d', type=float, default=4e-5)
    parser.add_argument('--init_lr', type=float, default=1e-4)
    parser.add_argument('--wadvg', type=float, default=300.0, help='Adversarial loss weight for G')
    parser.add_argument('--wadvd', type=float, default=300.0, help='Adversarial loss weight for D')
    parser.add_argument(
        '--gray_adv', action='store_true',
        help="If given, train adversarial with gray scale image instead of RGB image to reduce color effect of anime style")
    # --- Loss weight VGG19 ---
    parser.add_argument('--wcon', type=float, default=1.5, help='Content loss weight')  # 1.5 for Hayao, 2.0 for Paprika, 1.2 for Shinkai
    parser.add_argument('--wgra', type=float, default=5.0, help='Gram loss weight')  # 2.5 for Hayao, 0.6 for Paprika, 2.0 for Shinkai
    parser.add_argument('--wcol', type=float, default=30.0, help='Color loss weight')  # 15. for Hayao, 50. for Paprika, 10. for Shinkai
    parser.add_argument('--wtvar', type=float, default=1.0, help='Total variation loss')  # 1. for Hayao, 0.1 for Paprika, 1. for Shinkai
    parser.add_argument('--d_layers', type=int, default=2, help='Discriminator conv layers')
    parser.add_argument('--d_noise', action='store_true')
    # --- DDP ---
    parser.add_argument('--ddp', action='store_true')
    parser.add_argument("--local-rank", default=0, type=int)
    parser.add_argument("--world-size", default=2, type=int)
    return parser.parse_args(argv)
def check_params(args):
    """Derive the composite dataset name and validate the GAN loss choice.

    Sets ``args.dataset`` from the two image-dir basenames, e.g.
    dataset/Hayao + dataset/train_photo -> train_photo_Hayao.

    Raises:
        ValueError: if ``args.gan_loss`` is not one of lsgan / hinge / bce.
    """
    args.dataset = f"{os.path.basename(args.real_image_dir)}_{os.path.basename(args.anime_image_dir)}"
    # Explicit raise instead of assert: asserts are stripped under `python -O`,
    # which would let an unsupported loss slip through silently.
    if args.gan_loss not in {'lsgan', 'hinge', 'bce'}:
        raise ValueError(f'{args.gan_loss} is not supported')
def main(args, logger):
    """Build the generator/discriminator pair, optionally restore weights,
    then run the AnimeGAN training loop.

    Args:
        args: parsed CLI namespace (see ``parse_args``); mutated in place
            (``device``, ``batch_size``, ``init_epochs``, ``dataset``).
        logger: logger instance used for progress/config reporting.
    """
    check_params(args)
    if not torch.cuda.is_available():
        logger.info("CUDA not found, use CPU")
        # Just for debugging purpose, set to minimum config
        # to avoid 🔥 the computer...
        args.device = 'cpu'
        args.debug_samples = 10
        args.batch_size = 2
    else:
        logger.info(f"Use GPU: {torch.cuda.get_device_name(0)}")
    # v1 uses instance norm for D; v2 switches it to layer norm; v3 keeps the default.
    norm_type = "instance"
    if args.model == 'v1':
        G = GeneratorV1(args.dataset)
    elif args.model == 'v2':
        G = GeneratorV2(args.dataset)
        norm_type = "layer"
    elif args.model == 'v3':
        G = GeneratorV3(args.dataset)
    # NOTE(review): an unrecognized --model value leaves G unbound and would
    # raise NameError below — the choice is not validated in parse_args/check_params.
    D = Discriminator(
        args.dataset,
        num_layers=args.d_layers,
        use_sn=args.use_sn,
        norm_type=norm_type,
    )
    # Epoch counters: start_e for the GAN phase, start_e_init for the
    # generator-only content-pretrain phase.
    start_e = 0
    start_e_init = 0
    trainer = Trainer(
        generator=G,
        discriminator=D,
        config=args,
        logger=logger,
    )
    # Resume priority: explicit init-phase G checkpoint > explicit G+D pair
    # > --resume from the trainer's working directory. The --resume_* flags
    # are strings; anything other than 'False' is treated as a path.
    if args.resume_G_init.lower() != 'false':
        # +1: the checkpoint records the last finished epoch, so resume at the next.
        start_e_init = load_checkpoint(G, args.resume_G_init) + 1
        if args.local_rank == 0:
            logger.info(f"G content weight loaded from {args.resume_G_init}")
    elif args.resume_G.lower() != 'false' and args.resume_D.lower() != 'false':
        # You should provide both
        try:
            start_e = load_checkpoint(G, args.resume_G)
            if args.local_rank == 0:
                logger.info(f"G weight loaded from {args.resume_G}")
            load_checkpoint(D, args.resume_D)
            if args.local_rank == 0:
                logger.info(f"D weight loaded from {args.resume_D}")
            # If loaded both weight, turn off init G phrase
            args.init_epochs = 0
        except Exception as e:
            # Best-effort resume: on any load failure fall back to scratch training.
            print('Could not load checkpoint, train from scratch', e)
    elif args.resume:
        # Try to load from working dir
        logger.info(f"Loading weight from {trainer.checkpoint_path_G}")
        start_e = load_checkpoint(G, trainer.checkpoint_path_G)
        logger.info(f"Loading weight from {trainer.checkpoint_path_D}")
        load_checkpoint(D, trainer.checkpoint_path_D)
        # Both weights restored, so the G-only pretrain phase is skipped.
        args.init_epochs = 0
    dataset = AnimeDataSet(
        args.anime_image_dir,
        args.real_image_dir,
        args.debug_samples,
        args.cache,
        imgsz=args.imgsz,
        resize_method=args.resize_method,
    )
    if args.local_rank == 0:
        logger.info(f"Start from epoch {start_e}, {start_e_init}")
    trainer.train(dataset, start_e, start_e_init)
if __name__ == '__main__':
    args = parse_args()
    # Suffix the experiment dir with the photo/style dataset basenames,
    # e.g. runs -> runs_train_photo_Hayao, and make sure it exists.
    photo_name = os.path.basename(args.real_image_dir)
    style_name = os.path.basename(args.anime_image_dir)
    args.exp_dir = f"{args.exp_dir}_{photo_name}_{style_name}"
    os.makedirs(args.exp_dir, exist_ok=True)
    logger = get_logger(os.path.join(args.exp_dir, "train.log"))
    # Only the rank-0 process dumps the full training configuration.
    if args.local_rank == 0:
        logger.info("# ==== Train Config ==== #")
        for key, value in vars(args).items():
            logger.info(f"{key} {value}")
        logger.info("==========================")
    main(args, logger)