"""Distributed (DDP) training script: builds the data loaders described by a YAML
config, trains the model with an L1 loss, and periodically evaluates PSNR/SSIM on
the validation split."""
import argparse
import json
import os
from functools import partial

import torch
import torch.distributed as dist
import torch.nn as nn
import yaml
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR
from torch.utils.data import DataLoader
from tqdm import tqdm

import datasets
import models
import utils


def make_data_loader(spec, tag='', local_rank=0):
    if spec is None:
        return None

    dataset = datasets.make(spec['dataset'])
    dataset = datasets.make(spec['wrapper'], args={'dataset': dataset})

    if local_rank == 0:
        print('{} dataset: size={}'.format(tag, len(dataset)))
        for k, v in dataset[0].items():
            if torch.is_tensor(v):
                print('  {}: shape={}'.format(k, v.shape))
            elif isinstance(v, str):
                pass
            elif isinstance(v, dict):
                for k0, v0 in v.items():
                    if hasattr(v0, 'shape'):
                        print('  {}: shape={}'.format(k0, v0.shape))
            else:
                raise NotImplementedError

    sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=(tag == 'train'))
    loader = DataLoader(dataset, batch_size=spec['batch_size'],
                        num_workers=spec['num_workers'], pin_memory=True, sampler=sampler)
    return loader


def make_data_loaders(config, local_rank):
    train_loader = make_data_loader(config.get('train_dataset'), tag='train', local_rank=local_rank)
    val_loader = make_data_loader(config.get('val_dataset'), tag='val', local_rank=local_rank)
    return train_loader, val_loader


def prepare_training(config, local_rank):
    if config.get('resume') is not None:
        # Resume: restore model/optimizer state and continue from the saved epoch.
        sv_file = torch.load(config['resume'], map_location='cpu')
        model = models.make(sv_file['model'], load_sd=True).cuda(local_rank)
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = DDP(model, device_ids=[local_rank])
        optimizer = utils.make_optimizer(
            model.parameters(), sv_file['optimizer'], load_sd=True)
        epoch_start = sv_file['epoch'] + 1
        if config.get('multi_step_lr') is None:
            lr_scheduler = None
        else:
            lr_scheduler = MultiStepLR(optimizer, **config['multi_step_lr'])
            # Fast-forward the scheduler to the resumed epoch.
            for _ in range(epoch_start - 1):
                lr_scheduler.step()
    else:
        model = models.make(config['model']).cuda(local_rank)
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = DDP(model, device_ids=[local_rank])
        optimizer = utils.make_optimizer(
            model.parameters(), config['optimizer'])
        epoch_start = 1
        lr_scheduler = config.get('lr_scheduler')
        lr_scheduler_name = lr_scheduler.pop('name')
        if lr_scheduler_name == 'MultiStepLR':
            lr_scheduler = MultiStepLR(optimizer, **lr_scheduler)
        elif lr_scheduler_name == 'CosineAnnealingLR':
            lr_scheduler = CosineAnnealingLR(optimizer, **lr_scheduler)
        elif lr_scheduler_name == 'CosineAnnealingWarmUpLR':
            lr_scheduler = utils.warm_up_cosine_lr_scheduler(optimizer, **lr_scheduler)

    if local_rank == 0:
        print('model: #params={}'.format(utils.compute_num_params(model, text=True)))
    return model, optimizer, epoch_start, lr_scheduler


def reduce_mean(tensor, nprocs):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= nprocs
    return rt


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def return_avg(self):
        return self.avg

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


def train(train_loader, model, optimizer, local_rank):
    model = model.train()
    loss_fn = nn.L1Loss().cuda(local_rank)
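    # Normalization constants come from the module-level `config` set in `__main__`:
    # inputs and targets are shifted by data_norm[...]['sub'] and scaled by
    # data_norm[...]['div'] before the forward pass and the L1 loss.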
    train_losses = AverageMeter('Loss', ':.4e')

    data_norm = config['data_norm']
    t = data_norm['img']
    img_sub = torch.FloatTensor(t['sub']).view(1, -1, 1, 1).cuda(local_rank)
    img_div = torch.FloatTensor(t['div']).view(1, -1, 1, 1).cuda(local_rank)
    t = data_norm['gt']
    gt_sub = torch.FloatTensor(t['sub']).view(1, 1, -1).cuda(local_rank)
    gt_div = torch.FloatTensor(t['div']).view(1, 1, -1).cuda(local_rank)

    if local_rank == 0:
        pbar = tqdm(total=len(train_loader), desc='train', leave=False)
    for i, batch in enumerate(train_loader):
        if local_rank == 0:
            pbar.update(1)
        # The loader yields a dict of sub-batches; train on one picked at random each step.
        keys = list(batch.keys())
        batch = batch[keys[torch.randint(0, len(keys), [])]]
        for k, v in batch.items():
            if torch.is_tensor(v):
                batch[k] = v.cuda(local_rank, non_blocking=True)

        img = (batch['img'] - img_sub) / img_div
        gt = (batch['gt'] - gt_sub) / gt_div
        pred = model(img, gt.shape[-2:])

        if isinstance(pred, tuple):
            # Auxiliary output weighted 0.2, final output weighted 1.0.
            loss = 0.2 * loss_fn(pred[0], gt) + loss_fn(pred[1], gt)
        elif isinstance(pred, list):
            # Deep supervision: later outputs weighted more heavily,
            # normalized by the sum of weights 1 + 2 + ... + len(pred).
            losses = [loss_fn(x, gt) for x in pred]
            losses = [x * (idx + 1) for idx, x in enumerate(losses)]
            loss = sum(losses) / ((1 + len(losses)) * len(losses) / 2)
        else:
            loss = loss_fn(pred, gt)

        torch.distributed.barrier()
        reduced_loss = reduce_mean(loss, dist.get_world_size())
        train_losses.update(reduced_loss.item(), img.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if local_rank == 0:
        pbar.close()
    return train_losses.avg


def eval_psnr(loader, class_names, model, local_rank, data_norm=None, eval_type=None,
              eval_bsize=None, verbose=False, crop_border=4):
    crop_border = int(crop_border) if crop_border else crop_border
    if local_rank == 0:
        print('crop border: ', crop_border)
    model = model.eval()

    if data_norm is None:
        data_norm = {
            'img': {'sub': [0], 'div': [1]},
            'gt': {'sub': [0], 'div': [1]}
        }
    t = data_norm['img']
    img_sub = torch.FloatTensor(t['sub']).view(1, -1, 1, 1).cuda(local_rank)
    img_div = torch.FloatTensor(t['div']).view(1, -1, 1, 1).cuda(local_rank)
    t = data_norm['gt']
    gt_sub = torch.FloatTensor(t['sub']).view(1, -1, 1, 1).cuda(local_rank)
    gt_div = torch.FloatTensor(t['div']).view(1, -1, 1, 1).cuda(local_rank)

    if eval_type is None:
        metric_fn = [utils.calculate_psnr_pt, utils.calculate_ssim_pt]
    elif eval_type == 'psnr+ssim':
        metric_fn = [utils.calculate_psnr_pt, utils.calculate_ssim_pt]
    elif eval_type.startswith('div2k'):
        # NOTE: these branches return a single callable, while the metric loop below
        # expects the two-element [psnr, ssim] list built by the default paths.
        scale = int(eval_type.split('-')[1])
        metric_fn = partial(utils.calc_psnr, dataset='div2k', scale=scale)
    elif eval_type.startswith('benchmark'):
        scale = int(eval_type.split('-')[1])
        metric_fn = partial(utils.calc_psnr, dataset='benchmark', scale=scale)
    else:
        raise NotImplementedError

    val_res_psnr = AverageMeter('psnr', ':.4f')
    val_res_ssim = AverageMeter('ssim', ':.4f')

    if local_rank == 0:
        pbar = tqdm(total=len(loader), desc='val', leave=False)
    for batch in loader:
        if local_rank == 0:
            pbar.update(1)
        for k, v in batch.items():
            if torch.is_tensor(v):
                batch[k] = v.cuda(local_rank, non_blocking=True)

        img = (batch['img'] - img_sub) / img_div
        with torch.no_grad():
            pred = model(img, batch['gt'].shape[-2:])
        if isinstance(pred, list):
            pred = pred[-1]
        pred = pred * gt_div + gt_sub

        res_psnr = metric_fn[0](
            pred,
            batch['gt'],
            crop_border=crop_border
        ).mean()
        res_ssim = metric_fn[1](
            pred,
            batch['gt'],
            crop_border=crop_border
        ).mean()

        torch.distributed.barrier()
        reduced_val_res_psnr = reduce_mean(res_psnr, dist.get_world_size())
        reduced_val_res_ssim = reduce_mean(res_ssim, dist.get_world_size())
        val_res_psnr.update(reduced_val_res_psnr.item(), img.size(0))
        val_res_ssim.update(reduced_val_res_ssim.item(), img.size(0))
        if verbose and local_rank == 0:
            pbar.set_description(
                'val psnr: {:.4f} ssim: {:.4f}'.format(val_res_psnr.avg, val_res_ssim.avg))

    if local_rank == 0:
        pbar.close()
    return val_res_psnr.avg, val_res_ssim.avg


def main(config, save_path):
    # torch.backends.cudnn.benchmark = True
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = dist.get_world_size()
    print(f'rank: {rank} local_rank: {local_rank} world_size: {world_size}')

    if local_rank == 0:
        log, writer = utils.set_save_path(save_path)
        with open(os.path.join(save_path, 'config.yaml'), 'w') as f:
            yaml.dump(config, f, sort_keys=False)

    train_loader, val_loader = make_data_loaders(config, local_rank)
    if config.get('data_norm') is None:
        config['data_norm'] = {
            'img': {'sub': [0], 'div': [1]},
            'gt': {'sub': [0], 'div': [1]}
        }

    model, optimizer, epoch_start, lr_scheduler = prepare_training(config, local_rank)

    epoch_max = config['epoch_max']
    epoch_val_interval = config.get('epoch_val_interval')
    epoch_save_interval = config.get('epoch_save_interval')
    max_val_v = -1e18

    timer = utils.Timer()
    for epoch in range(epoch_start, epoch_max + 1):
        t_epoch_start = timer.t()
        train_loader.sampler.set_epoch(epoch)
        train_loss = train(train_loader, model, optimizer, local_rank)
        if lr_scheduler is not None:
            lr_scheduler.step()

        if rank == 0:
            log_info = ['epoch {}/{}'.format(epoch, epoch_max)]
            log_info.append('train: loss={:.4f}'.format(train_loss))
            writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)
            writer.add_scalars('loss', {'train': train_loss}, epoch)

        # Checkpoint payload: model/optimizer specs plus their state dicts.
        model_ = model.module
        model_spec = config['model']
        model_spec['sd'] = model_.state_dict()
        optimizer_spec = config['optimizer']
        optimizer_spec['sd'] = optimizer.state_dict()
        sv_file = {
            'model': model_spec,
            'optimizer': optimizer_spec,
            'epoch': epoch
        }

        if rank == 0:
            torch.save(sv_file, os.path.join(save_path, 'epoch-last.pth'))

        if (epoch_save_interval is not None) and (epoch % epoch_save_interval == 0):
            if rank == 0:
                torch.save(sv_file, os.path.join(save_path, 'epoch-{}.pth'.format(epoch)))

        if (epoch_val_interval is not None) and (epoch % epoch_val_interval == 0):
            file_names = json.load(open(config['val_dataset']['dataset']['args']['split_file']))['test']
            class_names = list(set([os.path.basename(os.path.dirname(x)) for x in file_names]))
            val_res_psnr, val_res_ssim = eval_psnr(val_loader, class_names, model_, local_rank,
                                                   data_norm=config['data_norm'],
                                                   eval_type=config.get('eval_type'),
                                                   eval_bsize=config.get('eval_bsize'),
                                                   crop_border=4)
            if rank == 0:
                log_info.append('val: psnr={:.4f}'.format(val_res_psnr))
                writer.add_scalars('psnr', {'val': val_res_psnr}, epoch)
            if val_res_psnr > max_val_v:
                max_val_v = val_res_psnr
                if rank == 0:
                    torch.save(sv_file, os.path.join(save_path, 'epoch-best.pth'))

        t = timer.t()
        if rank == 0:
            prog = (epoch - epoch_start + 1) / (epoch_max - epoch_start + 1)
            t_epoch = utils.time_text(t - t_epoch_start)
            t_elapsed, t_all = utils.time_text(t), utils.time_text(t / prog)
            log_info.append('{} {}/{}'.format(t_epoch, t_elapsed, t_all))
            log(', '.join(log_info))
            writer.flush()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='configs/train_1x-5x_INR_funsr.yaml')
    parser.add_argument('--name', default='EXP20221216_11')
    parser.add_argument('--tag', default=None)
    args = parser.parse_args()

    with open(args.config, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
        print('config loaded.')

    save_name = args.name
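    # Fall back to a name derived from the config filename when --name is not given.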
    if save_name is None:
        save_name = '_' + args.config.split('/')[-1][:-len('.yaml')]
    if args.tag is not None:
        save_name += '_' + args.tag
    save_path = os.path.join('./checkpoints', save_name)

    main(config, save_path)
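
# A minimal launch sketch, assuming torchrun on a single node (adjust
# --nproc_per_node to the number of GPUs and replace train.py with the path to
# this script). torchrun provides the LOCAL_RANK variable and the env://
# rendezvous that dist.init_process_group("nccl") expects:
#
#   torchrun --nproc_per_node=4 train.py \
#       --config configs/train_1x-5x_INR_funsr.yaml --name EXP20221216_11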