import os
import time
import shutil
import torch
import cv2
import torch.optim as optim
import numpy as np
from glob import glob
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from utils.image_processing import denormalize_input, preprocess_images, resize_image
from losses import LossSummary, AnimeGanLoss, to_gray_scale
from utils import load_checkpoint, save_checkpoint, read_image
from utils.common import set_lr
from color_transfer import color_transfer_pytorch


def transfer_color_and_rescale(src, target):
    """Transfer color from `src` onto `target`, then rescale to [-1, 1].

    Assumes `color_transfer_pytorch` returns values in [0, 1] (per the
    original inline comment); the affine map x / 0.5 - 1 sends [0, 1] to [-1, 1].
    """
    out = color_transfer_pytorch(src, target)  # [0, 1]
    out = (out / 0.5) - 1
    return out


def gaussian_noise():
    """Sample a single value from a normal distribution with mean 0 and std 0.1.

    NOTE(review): the result is a 0-dim tensor, so adding it to an image shifts
    every pixel by the same amount rather than adding per-pixel noise.
    Confirm that this is the intended augmentation.
    """
    gaussian_mean = torch.tensor(0.0)
    gaussian_std = torch.tensor(0.1)
    return torch.normal(gaussian_mean, gaussian_std)


def convert_to_readable(seconds):
    """Format a duration in seconds as ``HH:MM:SS``.

    Hours are not wrapped at 24 (unlike ``time.strftime('%H', ...)``), so long
    ETAs such as 30 hours are shown as ``30:00:00`` instead of ``06:00:00``.
    Fractional seconds are truncated; negative durations are clamped to 0.
    """
    total = max(int(seconds), 0)
    hours, rem = divmod(total, 3600)
    minutes, secs = divmod(rem, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"


def revert_to_np_image(image_tensor):
    """Convert a CHW image tensor into an HWC numpy image for ``cv2.imwrite``."""
    image = image_tensor.cpu().numpy()  # CHW
    image = image.transpose(1, 2, 0)  # CHW -> HWC
    image = denormalize_input(image, dtype=np.int16)
    # Reverse channel order because cv2.imwrite expects BGR
    # (presumably the tensor is RGB -- TODO confirm against read_image).
    return image[..., ::-1]


def save_generated_images(images: torch.Tensor, save_dir: str):
    """Save generated images `(*, 3, H, W)` in range [-1, 1] to `save_dir` as G{i}.jpg."""
    os.makedirs(save_dir, exist_ok=True)
    images = images.clone().detach().cpu().numpy()
    images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
    n_images = len(images)
    for i in range(n_images):
        img = images[i]
        img = denormalize_input(img, dtype=np.int16)
        img = img[..., ::-1]  # reverse channel order for cv2.imwrite
        cv2.imwrite(os.path.join(save_dir, f"G{i}.jpg"), img)


def _unwrap_model(model):
    """Return the inner module of a DistributedDataParallel wrapper, else `model` itself."""
    if isinstance(model, DistributedDataParallel):
        return model.module
    return model


class DDPTrainer:
    """Mixin that sets up DistributedDataParallel and AMP for `Trainer`.

    Relies on attributes set by the subclass: `cfg`, `logger`, `G`, `D`.
    """

    def _init_distributed(self):
        """Init the NCCL process group, convert BN to SyncBatchNorm and move models to this rank's GPU."""
        if self.cfg.ddp:
            self.logger.info("Setting up DDP")
            torch.distributed.init_process_group(
                backend="nccl",
                rank=self.cfg.local_rank,
                world_size=self.cfg.world_size,
            )
            # init_process_group returns None; passing None as the process
            # group makes convert_sync_batchnorm use the default group.
            self.pg = None
            self.G = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.G, self.pg)
            self.D = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.D, self.pg)
            torch.cuda.set_device(self.cfg.local_rank)
            self.G.cuda(self.cfg.local_rank)
            self.D.cuda(self.cfg.local_rank)
            self.logger.info("Setting up DDP Done")

    def _init_amp(self, enabled=False):
        """Create one GradScaler per optimizer and, in DDP mode, wrap G and D in DDP."""
        self.scaler_g = GradScaler(enabled=enabled)
        self.scaler_d = GradScaler(enabled=enabled)
        if self.cfg.ddp:
            self.G = DistributedDataParallel(
                self.G,
                device_ids=[self.cfg.local_rank],
                output_device=self.cfg.local_rank,
                find_unused_parameters=False,
            )

            self.D = DistributedDataParallel(
                self.D,
                device_ids=[self.cfg.local_rank],
                output_device=self.cfg.local_rank,
                find_unused_parameters=False,
            )
            self.logger.info("Set DistributedDataParallel")


class Trainer(DDPTrainer):
    """
    Base Trainer class: generator pretraining followed by adversarial G/D training.
    """

    def __init__(
        self,
        generator,
        discriminator,
        config,
        logger,
    ) -> None:
        self.G = generator
        self.D = discriminator
        self.cfg = config
        self.max_norm = 10  # gradient-clipping threshold for both G and D
        self.device_type = 'cuda' if self.cfg.device.startswith('cuda') else 'cpu'
        self.optimizer_g = optim.Adam(self.G.parameters(), lr=self.cfg.lr_g, betas=(0.5, 0.999))
        self.optimizer_d = optim.Adam(self.D.parameters(), lr=self.cfg.lr_d, betas=(0.5, 0.999))
        self.loss_tracker = LossSummary()
        if self.cfg.ddp:
            self.device = torch.device(f"cuda:{self.cfg.local_rank}")
            logger.info(f"---------{self.cfg.local_rank} {self.device}")
        else:
            self.device = torch.device(self.cfg.device)
        self.loss_fn = AnimeGanLoss(self.cfg, self.device, self.cfg.gray_adv)
        self.logger = logger
        self._init_working_dir()
        self._init_distributed()
        self._init_amp(enabled=self.cfg.amp)

    def _init_working_dir(self):
        """Init working directory for saving checkpoints and images."""
        os.makedirs(self.cfg.exp_dir, exist_ok=True)
        Gname = self.G.name
        Dname = self.D.name
        self.checkpoint_path_G_init = os.path.join(self.cfg.exp_dir, f"{Gname}_init.pt")
        self.checkpoint_path_G = os.path.join(self.cfg.exp_dir, f"{Gname}.pt")
        self.checkpoint_path_D = os.path.join(self.cfg.exp_dir, f"{Dname}.pt")
        self.save_image_dir = os.path.join(self.cfg.exp_dir, "generated_images")
        self.example_image_dir = os.path.join(self.cfg.exp_dir, "train_images")
        os.makedirs(self.save_image_dir, exist_ok=True)
        os.makedirs(self.example_image_dir, exist_ok=True)

    def init_weight_G(self, weight: str):
        """Init Generator weight"""
        return load_checkpoint(self.G, weight)

    def init_weight_D(self, weight: str):
        """Init Discriminator weight"""
        return load_checkpoint(self.D, weight)

    def pretrain_generator(self, train_loader, start_epoch):
        """Pretrain the generator to reconstruct its input using content loss only.

        Runs epochs `start_epoch .. cfg.init_epochs - 1` at `cfg.init_lr`, then
        restores the generator learning rate to `cfg.lr_g`. Rank 0 saves the
        init checkpoint and preview images after every epoch.
        """
        total_loss = 0.0
        num_steps = 0
        set_lr(self.optimizer_g, self.cfg.init_lr)
        for epoch in range(start_epoch, self.cfg.init_epochs):
            # DistributedSampler only reshuffles when told the epoch.
            sampler = getattr(train_loader, "sampler", None)
            if hasattr(sampler, "set_epoch"):
                sampler.set_epoch(epoch)

            # Train with content loss only
            pbar = tqdm(train_loader)
            for data in pbar:
                img = data["image"].to(self.device)

                self.optimizer_g.zero_grad()

                with autocast(enabled=self.cfg.amp):
                    fake_img = self.G(img)
                    loss = self.loss_fn.content_loss_vgg(img, fake_img)

                self.scaler_g.scale(loss).backward()
                self.scaler_g.step(self.optimizer_g)
                self.scaler_g.update()

                if self.cfg.ddp:
                    torch.distributed.barrier()

                # Running mean over all pretraining steps so far (O(1) per step).
                total_loss += loss.item()
                num_steps += 1
                avg_content_loss = total_loss / num_steps
                pbar.set_description(f'[Init Training G] content loss: {avg_content_loss:.2f}')

            if self.cfg.local_rank == 0:
                # Only one rank writes, so processes never race on the same file.
                save_checkpoint(self.G, self.checkpoint_path_G_init, self.optimizer_g, epoch)
                self.generate_and_save(self.cfg.test_image_dir, subname='initg')
            self.logger.info(f"Epoch {epoch}/{self.cfg.init_epochs}")

        set_lr(self.optimizer_g, self.cfg.lr_g)

    def train_epoch(self, epoch, train_loader):
        """Run one epoch of alternating D and G updates over `train_loader`."""
        pbar = tqdm(train_loader, total=len(train_loader))
        for data in pbar:
            img = data["image"].to(self.device)
            anime = data["anime"].to(self.device)
            anime_gray = data["anime_gray"].to(self.device)
            anime_smt_gray = data["smooth_gray"].to(self.device)

            # ---------------- TRAIN D ---------------- #
            self.optimizer_d.zero_grad()

            with autocast(enabled=self.cfg.amp):
                # D's loss must not backprop into G (G's grads are zeroed before
                # its own step anyway), so G's graph is not built here. This also
                # avoids editing an autograd-tracked G output, which could break
                # backward if G's last op saved its output.
                with torch.no_grad():
                    fake_img = self.G(img)

                # Separate tensors for D's inputs so the clean `anime_gray`
                # is still what the G losses below receive.
                fake_in = fake_img
                anime_in = anime
                anime_gray_in = anime_gray
                anime_smt_gray_in = anime_smt_gray

                # Add some Gaussian noise to images before feeding to D
                if self.cfg.d_noise:
                    fake_in = fake_in + gaussian_noise()
                    anime_in = anime_in + gaussian_noise()
                    anime_gray_in = anime_gray_in + gaussian_noise()
                    anime_smt_gray_in = anime_smt_gray_in + gaussian_noise()

                if self.cfg.gray_adv:
                    fake_in = to_gray_scale(fake_in)

                fake_d = self.D(fake_in)
                real_anime_d = self.D(anime_in)
                real_anime_gray_d = self.D(anime_gray_in)
                real_anime_smt_gray_d = self.D(anime_smt_gray_in)

                loss_d = self.loss_fn.compute_loss_D(
                    fake_d,
                    real_anime_d,
                    real_anime_gray_d,
                    real_anime_smt_gray_d
                )

            self.scaler_d.scale(loss_d).backward()
            # Unscale before clipping so max_norm applies to the true gradients.
            self.scaler_d.unscale_(self.optimizer_d)
            torch.nn.utils.clip_grad_norm_(self.D.parameters(), max_norm=self.max_norm)
            self.scaler_d.step(self.optimizer_d)
            self.scaler_d.update()
            if self.cfg.ddp:
                torch.distributed.barrier()
            self.loss_tracker.update_loss_D(loss_d)

            # ---------------- TRAIN G ---------------- #
            self.optimizer_g.zero_grad()

            with autocast(enabled=self.cfg.amp):
                fake_img = self.G(img)

                if self.cfg.gray_adv:
                    fake_d = self.D(to_gray_scale(fake_img))
                else:
                    fake_d = self.D(fake_img)

                (
                    adv_loss, con_loss,
                    gra_loss, col_loss,
                    tv_loss
                ) = self.loss_fn.compute_loss_G(
                    fake_img,
                    img,
                    fake_d,
                    anime_gray,
                )
                loss_g = adv_loss + con_loss + gra_loss + col_loss + tv_loss
                if torch.isnan(adv_loss).any():
                    self.logger.info("----------------------------------------------")
                    self.logger.info(fake_d)
                    self.logger.info(adv_loss)
                    self.logger.info("----------------------------------------------")
                    raise ValueError("NAN loss!!")

            self.scaler_g.scale(loss_g).backward()
            # Must unscale with the same scaler that scaled loss_g.
            self.scaler_g.unscale_(self.optimizer_g)
            grad = torch.nn.utils.clip_grad_norm_(self.G.parameters(), max_norm=self.max_norm)
            self.scaler_g.step(self.optimizer_g)
            self.scaler_g.update()
            if self.cfg.ddp:
                torch.distributed.barrier()

            self.loss_tracker.update_loss_G(adv_loss, gra_loss, col_loss, con_loss)
            pbar.set_description(f"{self.loss_tracker.get_loss_description()} - {grad:.3f}")

    def get_train_loader(self, dataset, epoch=0):
        """Build a DataLoader for `dataset`.

        Args:
            dataset: Training dataset.
            epoch: Epoch number passed to ``DistributedSampler.set_epoch`` in DDP
                mode so each epoch gets a different shuffle. Ignored otherwise.
        """
        if self.cfg.ddp:
            train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
            train_sampler.set_epoch(epoch)
        else:
            train_sampler = None
        return DataLoader(
            dataset,
            batch_size=self.cfg.batch_size,
            num_workers=self.cfg.num_workers,
            pin_memory=True,
            shuffle=train_sampler is None,
            sampler=train_sampler,
            drop_last=True,
        )

    def maybe_increase_imgsz(self, epoch, train_dataset):
        """
        Increase image size at specific epochs.

        + The first 50% of epochs train at imgsz[0].
        + In the remaining 50%, the size steps up every
          `int(epochs / 2 / (len(imgsz) - 1))` epochs.

        Args:
            epoch: Current epoch
            train_dataset: Dataset exposing `imgsz` and `set_imgsz`

        Examples:
        ```
        epochs = 100
        imgsz = [256, 352, 416, 512]
        => [(0, 256), (50, 352), (66, 416), (82, 512)]
        ```
        """
        epochs = self.cfg.epochs
        imgsz = self.cfg.imgsz
        num_size_remains = len(imgsz) - 1
        half_epochs = epochs // 2

        if len(imgsz) == 1:
            new_size = imgsz[0]
        elif epoch < half_epochs:
            new_size = imgsz[0]
        else:
            per_epoch_increment = int(half_epochs / num_size_remains)
            found = None
            for i, size in enumerate(imgsz):
                if epoch < half_epochs + per_epoch_increment * i:
                    found = size
                    break
            if not found:
                found = imgsz[-1]
            new_size = found

        self.logger.info(f"Check {imgsz}, {new_size}, {train_dataset.imgsz}")
        if new_size != train_dataset.imgsz:
            train_dataset.set_imgsz(new_size)
            self.logger.info(f"Increase image size to {new_size} at epoch {epoch}")

    def _save_example_inputs(self, dataset, num_samples=3):
        """Write every field of the first `num_samples` dataset items to `example_image_dir`."""
        for i, data in enumerate(dataset):
            for key, image in data.items():
                cv2.imwrite(
                    os.path.join(self.example_image_dir, f"data_{key}_{i}.jpg"),
                    revert_to_np_image(image),
                )
            if i + 1 == num_samples:
                break

    def train(self, train_dataset: Dataset, start_epoch=0, start_epoch_g=0):
        """
        Train Generator and Discriminator.

        Pretrains G first (from `start_epoch_g`), then runs adversarial training
        from `start_epoch` to `cfg.epochs - 1`. Rank 0 saves checkpoints and
        preview images every `cfg.save_interval` epochs.
        """
        self.logger.info(self.device)
        self.G.to(self.device)
        self.D.to(self.device)

        self.pretrain_generator(self.get_train_loader(train_dataset), start_epoch_g)

        if self.cfg.local_rank == 0:
            self.logger.info(f"Start training for {self.cfg.epochs} epochs")
            self._save_example_inputs(train_dataset)

        per_epoch_times = []
        for epoch in range(start_epoch, self.cfg.epochs):
            self.maybe_increase_imgsz(epoch, train_dataset)

            start = time.time()
            self.train_epoch(epoch, self.get_train_loader(train_dataset, epoch=epoch))

            if epoch % self.cfg.save_interval == 0 and self.cfg.local_rank == 0:
                save_checkpoint(self.G, self.checkpoint_path_G, self.optimizer_g, epoch)
                save_checkpoint(self.D, self.checkpoint_path_D, self.optimizer_d, epoch)
                self.generate_and_save(self.cfg.test_image_dir)
                if epoch % 10 == 0:
                    self.copy_results(epoch)

            per_epoch_times.append(time.time() - start)
            # Epochs still to run after this one: epoch + 1 .. cfg.epochs - 1.
            remaining_epochs = self.cfg.epochs - epoch - 1
            eta = convert_to_readable(np.mean(per_epoch_times) * remaining_epochs)
            self.logger.info(f"epoch {epoch}/{self.cfg.epochs}, ETA: {eta}")

    def generate_and_save(
        self,
        image_dir,
        max_imgs=15,
        subname='gen'
    ):
        '''
        Generate and save images.

        Runs G on up to `max_imgs` files from `image_dir` (sorted by path) and
        writes `{subname}_{i}.jpg` side-by-side (input | generated) images to
        `save_image_dir`. G's train/eval mode is restored afterwards.
        '''
        start = time.time()
        # Use the bare module: DDP's forward may run collectives (e.g. buffer
        # broadcast), which would hang when only rank 0 calls this method.
        model = _unwrap_model(self.G)
        was_training = model.training
        model.eval()
        fake_imgs = []
        real_imgs = []
        try:
            image_files = sorted(glob(os.path.join(image_dir, "*")))
            for i, image_file in enumerate(image_files):
                image = read_image(image_file)
                image = resize_image(image)
                real_imgs.append(image.copy())
                image = preprocess_images(image)
                image = image.to(self.device)
                with torch.no_grad():
                    with autocast(enabled=self.cfg.amp):
                        fake_img = model(image)
                fake_img = fake_img.detach().cpu().numpy()
                # Channel first -> channel last
                fake_img = fake_img.transpose(0, 2, 3, 1)
                fake_imgs.append(denormalize_input(fake_img, dtype=np.int16)[0])

                if i + 1 == max_imgs:
                    break
        finally:
            # Restore the previous mode so training does not continue in eval mode.
            model.train(was_training)

        for i, (real_img, fake_img) in enumerate(zip(real_imgs, fake_imgs)):
            img = np.concatenate((real_img, fake_img), axis=1)  # Concatenate across width
            save_path = os.path.join(self.save_image_dir, f'{subname}_{i}.jpg')
            if not cv2.imwrite(save_path, img[..., ::-1]):
                self.logger.info(f"Save generated image failed, {save_path}, {img.shape}")
        elapsed = time.time() - start
        self.logger.info(f"Generated {len(fake_imgs)} images in {elapsed:.3f}s.")

    def copy_results(self, epoch):
        """Copy the current G checkpoint and generated images into `exp_dir/epoch_{epoch}`."""
        copy_dir = os.path.join(self.cfg.exp_dir, f"epoch_{epoch}")
        os.makedirs(copy_dir, exist_ok=True)

        shutil.copy2(
            self.checkpoint_path_G,
            copy_dir
        )

        dest = os.path.join(copy_dir, os.path.basename(self.save_image_dir))
        shutil.copytree(
            self.save_image_dir,
            dest,
            dirs_exist_ok=True
        )