import os
import cv2
import time
import random
import datetime
import argparse
import numpy as np
from tqdm import tqdm
from piq import ssim, psnr
from itertools import cycle

import torch
import torch.nn as nn
from torch.utils import data
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

from utils import dict2string, mkdir, get_lr, torch2cvimg, second2hours
from loaders import docres_loader
from models import restormer_arch
def seed_torch(seed=1029):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # torch.use_deterministic_algorithms(True)
# seed_torch()
def getBasecoord(h, w):
    base_coord0 = np.tile(np.arange(h).reshape(h, 1), (1, w)).astype(np.float32)
    base_coord1 = np.tile(np.arange(w).reshape(1, w), (h, 1)).astype(np.float32)
    base_coord = np.concatenate((np.expand_dims(base_coord1, -1), np.expand_dims(base_coord0, -1)), -1)
    return base_coord
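# Illustration: getBasecoord(2, 3) returns a (2, 3, 2) float32 grid with
# base_coord[y, x] == [x, y]. The helper is unused in this file; presumably a
# dewarping consumer adds a predicted offset field to this identity grid to
# obtain absolute sampling coordinates.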
def train(args):
    ## DDP init
    dist.init_process_group(backend='nccl', init_method='env://', timeout=datetime.timedelta(seconds=36000))
    torch.cuda.set_device(args.local_rank)
    device = torch.device('cuda', args.local_rank)
    torch.cuda.manual_seed_all(42)
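    # NOTE: init_method='env://' expects the launcher to provide RANK, WORLD_SIZE,
    # MASTER_ADDR and MASTER_PORT (see the launch example at the bottom of this
    # file); timeout=36000 s raises the default collective timeout to 10 hours.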
    ### Log file
    mkdir(args.logdir)
    mkdir(os.path.join(args.logdir, args.experiment_name))
    log_file_path = os.path.join(args.logdir, args.experiment_name, 'log.txt')
    with open(log_file_path, 'a') as log_file:
        log_file.write('\n--------------- ' + args.experiment_name + ' ---------------\n')
    ### Setup tensorboard for visualization
    if args.tboard:
        writer = SummaryWriter(os.path.join(args.logdir, args.experiment_name, 'runs'), args.experiment_name)
    ### Setup Dataloader
    datasets_setting = [
        {'task': 'deblurring', 'ratio': 1, 'im_path': '/home/jiaxin/Training_Data/DocRes_data/train/deblurring/', 'json_paths': ['/home/jiaxin/Training_Data/DocRes_data/train/deblurring/tdd/train.json']},
        {'task': 'dewarping', 'ratio': 1, 'im_path': '/home/jiaxin/Training_Data/DocRes_data/train/dewarping/', 'json_paths': ['/home/jiaxin/Training_Data/DocRes_data/train/dewarping/doc3d/train_1_19.json']},
        {'task': 'binarization', 'ratio': 1, 'im_path': '/home/jiaxin/Training_Data/DocRes_data/train/binarization/', 'json_paths': ['/home/jiaxin/Training_Data/DocRes_data/train/binarization/train.json']},
        {'task': 'deshadowing', 'ratio': 1, 'im_path': '/home/jiaxin/Training_Data/DocRes_data/train/deshadowing/', 'json_paths': ['/home/jiaxin/Training_Data/DocRes_data/train/deshadowing/train.json']},
        {'task': 'appearance', 'ratio': 1, 'im_path': '/home/jiaxin/Training_Data/DocRes_data/train/appearance/', 'json_paths': ['/home/jiaxin/Training_Data/DocRes_data/train/appearance/trainv2.json']},
    ]
    ratios = [dataset_setting['ratio'] for dataset_setting in datasets_setting]
    datasets = [docres_loader.DocResTrainDataset(dataset=dataset_setting, img_size=args.im_size) for dataset_setting in datasets_setting]
    trainloaders = []
    for i in range(len(datasets)):
        loader = data.DataLoader(dataset=datasets[i], sampler=DistributedSampler(datasets[i]), batch_size=args.batch_size, num_workers=2, pin_memory=True, drop_last=True)
        trainloaders.append({'task': datasets_setting[i], 'loader': loader, 'iter_loader': iter(loader)})
    ### test loader (smoke-test snippet, kept disabled)
    # for i in tqdm(range(args.total_iter)):
    #     loader_index = random.choices(list(range(len(trainloaders))), ratios)[0]
    #     in_im, gt_im = next(trainloaders[loader_index]['iter_loader'])
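    # Task mixing: each iteration draws one loader via
    #   random.choices(range(len(trainloaders)), weights=ratios)[0]
    # so a task with ratio 2 would be sampled twice as often as one with ratio 1
    # (all ratios are 1 here, i.e. uniform over the five tasks).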
    ### Setup Model
    model = restormer_arch.Restormer(
        inp_channels=6,
        out_channels=3,
        dim=48,
        num_blocks=[2, 3, 3, 4],
        num_refinement_blocks=4,
        heads=[1, 2, 4, 8],
        ffn_expansion_factor=2.66,
        bias=False,
        LayerNorm_type='WithBias',
        dual_pixel_task=True,
    )
    model = DDP(model.cuda(), device_ids=[args.local_rank], output_device=args.local_rank)
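    # inp_channels=6 suggests the 3-channel input image is concatenated with a
    # 3-channel task prompt before entering the backbone (consistent with the
    # DocRes design; the exact prompt construction lives in docres_loader and
    # is assumed here rather than shown).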
    ### Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.l_rate, weight_decay=5e-4)
    ### LR Scheduler
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.total_iter, eta_min=1e-6, last_epoch=-1)
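    # sched.step() is called once per iteration below, so the LR follows the
    # cosine annealing schedule
    #   lr_t = eta_min + (l_rate - eta_min) * (1 + cos(pi * t / total_iter)) / 2
    # decaying from args.l_rate (2e-4 by default) to 1e-6 over the full run.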
    ### load checkpoint
    iter_start = 0
    if args.resume is not None:
        print("Loading model and optimizer from checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume, map_location=device)
        model.load_state_dict(checkpoint['model_state'], strict=False)
        if 'optimizer_state' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state'])
        iter_start = checkpoint['iters']
        print("Loaded checkpoint '{}' (iter {})".format(args.resume, iter_start))
    ###----------------------------------------- Training -----------------------------------------
    ## initialize
    scaler = torch.cuda.amp.GradScaler()
    loss_dict = {}
    l2 = nn.MSELoss()
    l1 = nn.L1Loss()
    ce = nn.CrossEntropyLoss()
    bce = nn.BCEWithLogitsLoss()
    m = nn.Sigmoid()
    best = 0
    best_ce = 999
    ## iterate from iter_start (0, or the resumed iteration) to total_iter
    for iters in range(iter_start, args.total_iter):
        start_time = time.time()
        ## pick a task loader, weighted by the per-task ratios
        loader_index = random.choices(list(range(len(trainloaders))), ratios)[0]
        try:
            in_im, gt_im = next(trainloaders[loader_index]['iter_loader'])
        except StopIteration:
            ## this task's loader is exhausted: restart its iterator
            trainloaders[loader_index]['iter_loader'] = iter(trainloaders[loader_index]['loader'])
            in_im, gt_im = next(trainloaders[loader_index]['iter_loader'])
        in_im = in_im.float().cuda()
        gt_im = gt_im.float().cuda()
        binarization_loss, appearance_loss, dewarping_loss, deblurring_loss, deshadowing_loss = 0, 0, 0, 0, 0
        task = trainloaders[loader_index]['task']['task']
        with torch.cuda.amp.autocast():
            pred_im = model(in_im, task)
            if task == 'binarization':
                ## two-class logits vs. integer label map -> cross-entropy
                gt_im = gt_im.long()
                binarization_loss = ce(pred_im[:, :2, :, :], gt_im[:, 0, :, :])
                loss = binarization_loss
            elif task == 'dewarping':
                ## supervise the first two output channels (2-channel map) with L1
                dewarping_loss = l1(pred_im[:, :2, :, :], gt_im[:, :2, :, :])
                loss = dewarping_loss
            elif task == 'appearance':
                appearance_loss = l1(pred_im, gt_im)
                loss = appearance_loss
            elif task == 'deblurring':
                deblurring_loss = l1(pred_im, gt_im)
                loss = deblurring_loss
            elif task == 'deshadowing':
                deshadowing_loss = l1(pred_im, gt_im)
                loss = deshadowing_loss
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
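        # AMP bookkeeping: scaler.scale() multiplies the loss before backward so
        # fp16 gradients don't underflow; scaler.step() unscales them and skips
        # the update if any are inf/NaN; scaler.update() adapts the scale factor.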
        loss_dict['dew_loss'] = dewarping_loss.item() if isinstance(dewarping_loss, torch.Tensor) else 0
        loss_dict['app_loss'] = appearance_loss.item() if isinstance(appearance_loss, torch.Tensor) else 0
        loss_dict['des_loss'] = deshadowing_loss.item() if isinstance(deshadowing_loss, torch.Tensor) else 0
        loss_dict['deb_loss'] = deblurring_loss.item() if isinstance(deblurring_loss, torch.Tensor) else 0
        loss_dict['bin_loss'] = binarization_loss.item() if isinstance(binarization_loss, torch.Tensor) else 0
        end_time = time.time()
        duration = end_time - start_time
        ## log every 10 iterations (console, tensorboard, and log file)
        if (iters + 1) % 10 == 0:
            log_line = 'iters [{}/{}] -- '.format(iters + 1, args.total_iter) + dict2string(loss_dict) + ' --lr {:6f}'.format(get_lr(optimizer)) + ' -- time {}'.format(second2hours(duration * (args.total_iter - iters)))
            print(log_line)
            if args.tboard:
                for key, value in loss_dict.items():
                    writer.add_scalar('Train ' + key + '/Iterations', value, iters + 1)
            with open(log_file_path, 'a') as f:
                f.write(log_line + '\n')
        ## checkpoint every 5000 iterations
        if (iters + 1) % 5000 == 0:
            state = {'iters': iters + 1,
                     'model_state': model.state_dict(),
                     'optimizer_state': optimizer.state_dict()}
            mkdir(os.path.join(args.logdir, args.experiment_name))
            if dist.get_rank() == 0:
                torch.save(state, os.path.join(args.logdir, args.experiment_name, "{}.pkl".format(iters + 1)))
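            # Only rank 0 writes the file so ranks don't race on the same path;
            # every rank still assembles `state`, which is cheap. The keys
            # ('iters', 'model_state', 'optimizer_state') match the resume path.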
        sched.step()
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Hyperparams')
    parser.add_argument('--im_size', nargs='?', type=int, default=256,
                        help='Height/width of the input image')
    parser.add_argument('--total_iter', nargs='?', type=int, default=100000,
                        help='Total number of training iterations')
    parser.add_argument('--batch_size', nargs='?', type=int, default=10,
                        help='Batch size per GPU')
    parser.add_argument('--l_rate', nargs='?', type=float, default=2e-4,
                        help='Learning rate')
    parser.add_argument('--resume', nargs='?', type=str, default=None,
                        help='Path to a previously saved checkpoint to restart from')
    parser.add_argument('--logdir', nargs='?', type=str, default='./checkpoints/',
                        help='Path to store logs and checkpoints')
    parser.add_argument('--tboard', dest='tboard', action='store_true',
                        help='Enable visualization(s) on tensorboard | False by default')
    parser.add_argument('--local_rank', type=int, default=0, metavar='N')
    parser.add_argument('--experiment_name', nargs='?', type=str, default='experiment_name',
                        help='Name of this experiment')
    parser.set_defaults(tboard=False)
    args = parser.parse_args()
    train(args)
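# Launch note: this script parses --local_rank, matching the legacy
# torch.distributed.launch launcher, e.g. (illustrative values):
#   python -m torch.distributed.launch --nproc_per_node=4 train.py --experiment_name docres_run --tboard
# With torchrun you would instead read LOCAL_RANK from the environment.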