Spaces:
Running
Running
| # Copyright 2019-present NAVER Corp. | |
| # CC BY-NC-SA 3.0 | |
| # Available only for non-commercial use | |
| import os, pdb | |
| import torch | |
| import torch.optim as optim | |
| from tools import common, trainer | |
| from tools.dataloader import * | |
| from nets.patchnet import * | |
| from nets.losses import * | |
# ---------------------------------------------------------------------------
# Default configuration, expressed as Python *expression strings*.
# The __main__ section substitutes the back-quoted placeholders
# (`data`, `sampler`, `N`), strips newlines and passes the result to eval().
# ---------------------------------------------------------------------------

# Network constructor expression.
default_net = "Quad_L2Net_ConfCFS()"

# Synthetic pairs built from a local image folder.
toy_db_debug = """SyntheticPairDataset(
ImgFolder('imgs'),
'RandomScale(256,1024,can_upscale=True)',
'RandomTilting(0.5), PixelNoise(25)')"""

# Synthetic pairs built from `web_images`
# (presumably provided by the `datasets` star-import -- confirm).
db_web_images = """SyntheticPairDataset(
web_images,
'RandomScale(256,1024,can_upscale=True)',
'RandomTilting(0.5), PixelNoise(25)')"""

# Synthetic pairs built from `aachen_db_images`.
db_aachen_images = """SyntheticPairDataset(
aachen_db_images,
'RandomScale(256,1024,can_upscale=True)',
'RandomTilting(0.5), PixelNoise(25)')"""

# Transformed pairs built from `aachen_style_transfer_pairs`.
db_aachen_style_transfer = """TransformedPairs(
aachen_style_transfer_pairs,
'RandomScale(256,1024,can_upscale=True), RandomTilting(0.5), PixelNoise(25)')"""

# Used as-is, without any wrapping dataset expression.
db_aachen_flow = "aachen_flow_pairs"

# One-letter keys accepted by --train-data, mapped to dataset expressions.
data_sources = {
    "D": toy_db_debug,
    "W": db_web_images,
    "A": db_aachen_images,
    "F": db_aachen_flow,
    "S": db_aachen_style_transfer,
}

# `data` is replaced by the comma-joined dataset expressions that were selected.
default_dataloader = """PairLoader(CatPairDataset(`data`),
scale = 'RandomScale(256,1024,can_upscale=True)',
distort = 'ColorJitter(0.2,0.2,0.2,0.1)',
crop = 'RandomCrop(192)')"""

# Substituted for the `sampler` placeholder of the loss expression.
default_sampler = """NghSampler2(ngh=7, subq=-8, subd=1, pos_d=3, neg_d=5, border=16,
subd_neg=-8,maxpool_pos=True)"""

# `sampler` and `N` are filled in from --sampler and --N before evaluation.
default_loss = """MultiLoss(
1, ReliabilityLoss(`sampler`, base=0.5, nq=20),
1, CosimLoss(N=`N`),
1, PeakyLoss(N=`N`))"""
class MyTrainer(trainer.Trainer):
    """Trainer specialisation used to train the network.

    Only `forward_backward` is overloaded here; it describes how a batch is
    pushed through the network and how back-propagation is triggered.
    """

    def forward_backward(self, inputs):
        """Run the network and the loss on one batch, back-propagating if enabled.

        `inputs` is a dict-like batch. Its "img1" and "img2" entries are
        removed (popped) and fed to the network as an image pair; every
        remaining entry is forwarded to the loss together with the network
        outputs.

        Returns the `(loss, details)` pair produced by `self.loss_func`.
        """
        image_pair = [inputs.pop("img1"), inputs.pop("img2")]
        net_output = self.net(imgs=image_pair)

        # Remaining batch entries + network outputs become the loss kwargs
        # (network outputs win on key collisions).
        loss, details = self.loss_func(**dict(inputs, **net_output))

        # Skip the backward pass when autograd is turned off.
        if not torch.is_grad_enabled():
            return loss, details

        loss.backward()
        return loss, details
if __name__ == "__main__":
    # Command-line entry point: build the data loader, network, loss and
    # optimizer from the arguments, train for the requested number of epochs
    # and save the resulting weights.
    #
    # NOTE(security): --data-loader, --net, --loss and --sampler are Python
    # expressions that are passed to eval(). Only run this script with
    # command lines you trust.
    import argparse

    parser = argparse.ArgumentParser("Train R2D2")

    parser.add_argument("--data-loader", type=str, default=default_dataloader)
    parser.add_argument(
        "--train-data",
        type=str,
        default=list("WASF"),
        nargs="+",
        # Sorted list rather than a set so that --help output is deterministic.
        choices=sorted(data_sources),
    )
    parser.add_argument(
        "--net", type=str, default=default_net, help="network architecture"
    )

    parser.add_argument(
        "--pretrained", type=str, default="", help="pretrained model path"
    )
    parser.add_argument(
        "--save-path", type=str, required=True, help="model save_path path"
    )

    parser.add_argument("--loss", type=str, default=default_loss, help="loss function")
    parser.add_argument(
        "--sampler", type=str, default=default_sampler, help="AP sampler"
    )
    parser.add_argument(
        "--N", type=int, default=16, help="patch size for repeatability"
    )

    parser.add_argument(
        "--epochs", type=int, default=25, help="number of training epochs"
    )
    parser.add_argument("--batch-size", "--bs", type=int, default=8, help="batch size")
    # Must be parsed as a float: argparse only applies `type` to values given
    # on the command line, so with type=str a user-supplied "--lr 1e-3" would
    # reach the optimizer as the string "1e-3".
    parser.add_argument("--learning-rate", "--lr", type=float, default=1e-4)
    parser.add_argument("--weight-decay", "--wd", type=float, default=5e-4)

    parser.add_argument(
        "--threads", type=int, default=8, help="number of worker threads"
    )
    parser.add_argument("--gpu", type=int, nargs="+", default=[0], help="-1 for CPU")

    args = parser.parse_args()

    # `iscuda` is used below as a truthy flag for "train on GPU".
    iscuda = common.torch_set_gpu(args.gpu)
    # Presumably creates the parent directory of the save path -- see tools.common.
    common.mkdir_for(args.save_path)

    # Create data loader
    # The star-import must happen at module scope: the dataset expressions
    # evaluated below refer to names it brings in.
    from datasets import *

    db = [data_sources[key] for key in args.train_data]
    # Splice the selected dataset expressions into the loader template.
    db = eval(args.data_loader.replace("`data`", ",".join(db)).replace("\n", ""))
    print("Training image database =", db)
    loader = threaded_loader(db, iscuda, args.threads, args.batch_size, shuffle=True)

    # create network
    print("\n>> Creating net = " + args.net)
    net = eval(args.net)
    print(f" ( Model size: {common.model_size(net)/1000:.0f}K parameters )")

    # initialization
    if args.pretrained:
        # Identity map_location: keep every storage where it was deserialized.
        checkpoint = torch.load(args.pretrained, map_location=lambda a, b: a)
        net.load_pretrained(checkpoint["state_dict"])

    # create losses
    loss = args.loss.replace("`sampler`", args.sampler).replace("`N`", str(args.N))
    print("\n>> Creating loss = " + loss)
    loss = eval(loss.replace("\n", ""))

    # create optimizer (only over parameters that require gradients)
    optimizer = optim.Adam(
        [p for p in net.parameters() if p.requires_grad],
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
    )

    train = MyTrainer(net, loader, loss, optimizer)
    if iscuda:
        train = train.cuda()

    # Training loop #
    for epoch in range(args.epochs):
        print(f"\n>> Starting epoch {epoch}...")
        train()

    # The network expression is stored next to the weights.
    print(f"\n>> Saving model to {args.save_path}")
    torch.save({"net": args.net, "state_dict": net.state_dict()}, args.save_path)