Spaces:
Runtime error
Runtime error
| import argparse | |
def str2bool(v):
    """Convert a command-line token to a bool, for use as an argparse ``type=``.

    Matching is case-insensitive and ignores surrounding whitespace.

    Args:
        v: A bool (returned unchanged) or a value whose string form is one
            of the recognised tokens.

    Returns:
        True for "true", "1", "yes", "y", "t", "on";
        False for "false", "0", "no", "n", "f", "off".

    Raises:
        argparse.ArgumentTypeError: If the token is not recognised. Raising
            (rather than silently returning False) lets argparse report a
            typo such as ``--data_aug Ture`` instead of quietly disabling
            the option.
    """
    if isinstance(v, bool):
        return v

    token = str(v).strip().lower()
    if token in ("true", "1", "yes", "y", "t", "on"):
        return True
    if token in ("false", "0", "no", "n", "f", "off"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false, 1/0, yes/no), got %r" % (v,)
    )
# Single parser that every option in this module is registered on.
parser = argparse.ArgumentParser()

# Argument groups created through add_argument_group(), in creation order.
arg_lists = []
def add_argument_group(name):
    """Create a titled argument group on the shared parser and record it.

    Args:
        name: Title handed to ``parser.add_argument_group``.

    Returns:
        The newly created group, which has also been appended to
        ``arg_lists``.
    """
    group = parser.add_argument_group(name)
    arg_lists.append(group)
    return group
# -----------------------------------------------------------------------------
# Network
net_arg = add_argument_group("Network")
net_arg.add_argument(
    "--model_name",
    type=str,
    default="SGM",
    help="model for training",
)
net_arg.add_argument(
    "--config_path",
    type=str,
    default="configs/sgm.yaml",
    help="config path for model",
)
# -----------------------------------------------------------------------------
# Data
data_arg = add_argument_group("Data")
data_arg.add_argument(
    "--rawdata_path",
    type=str,
    default="rawdata",
    help="path for rawdata",
)
data_arg.add_argument(
    "--dataset_path",
    type=str,
    default="dataset",
    help="path for dataset",
)
data_arg.add_argument(
    "--desc_path",
    type=str,
    default="desc",
    help="path for descriptor(kpt) dir",
)
data_arg.add_argument(
    "--num_kpt",
    type=int,
    default=1000,
    help="number of kpt for training",
)
data_arg.add_argument(
    "--input_normalize",
    type=str,
    default="img",
    help="normalize type for input kpt, img or intrinsic",
)
# Boolean flag parsed from text via str2bool (e.g. "--data_aug false").
data_arg.add_argument(
    "--data_aug",
    type=str2bool,
    default=True,
    help="apply kpt coordinate homography augmentation",
)
data_arg.add_argument(
    "--desc_suffix",
    type=str,
    default="suffix",
    help="desc file suffix",
)
# -----------------------------------------------------------------------------
# Loss
loss_arg = add_argument_group("loss")
loss_arg.add_argument(
    "--momentum",
    type=float,
    default=0.9,
    help="momentum",
)
loss_arg.add_argument(
    "--seed_loss_weight",
    type=float,
    default=250,
    help="confidence loss weight for sgm",
)
loss_arg.add_argument(
    "--mid_loss_weight",
    type=float,
    default=1,
    help="midseeding loss weight for sgm",
)
loss_arg.add_argument(
    "--inlier_th",
    type=float,
    default=5e-3,
    help="inlier threshold for epipolar distance (for sgm and visualization)",
)
# -----------------------------------------------------------------------------
# Training
train_arg = add_argument_group("Train")
train_arg.add_argument(
    "--train_lr",
    type=float,
    default=1e-4,
    help="learning rate",
)
train_arg.add_argument(
    "--train_batch_size",
    type=int,
    default=16,
    help="batch size",
)
train_arg.add_argument(
    "--gpu_id",
    type=str,
    default="0",
    help="id(s) for CUDA_VISIBLE_DEVICES",
)
train_arg.add_argument(
    "--train_iter",
    type=int,
    default=1000000,
    help="training iterations to perform",
)
train_arg.add_argument(
    "--log_base",
    type=str,
    default="./log/",
    help="log path",
)
train_arg.add_argument(
    "--val_intv",
    type=int,
    default=20000,
    help="validation interval",
)
# NOTE(review): the help text says "summary" although the flag is named
# save_intv -- confirm which is intended before changing either.
train_arg.add_argument(
    "--save_intv",
    type=int,
    default=1000,
    help="summary interval",
)
train_arg.add_argument(
    "--log_intv",
    type=int,
    default=100,
    help="log interval",
)
train_arg.add_argument(
    "--decay_rate",
    type=float,
    default=0.999996,
    help="lr decay rate",
)
# NOTE(review): parsed as float while the default is an int literal, so the
# attribute's type depends on whether the flag was given -- confirm intended.
train_arg.add_argument(
    "--decay_iter",
    type=float,
    default=300000,
    help="lr decay iter",
)
train_arg.add_argument(
    "--local_rank",
    type=int,
    default=0,
    help="local rank for ddp",
)
train_arg.add_argument(
    "--train_vis_folder",
    type=str,
    default=".",
    help="visualization folder during training",
)
# -----------------------------------------------------------------------------
# Visualization
vis_arg = add_argument_group("Visualization")
vis_arg.add_argument(
    "--tqdm_width",
    type=int,
    default=79,
    help="width of the tqdm bar",
)
def get_config():
    """Parse the command line with the shared parser.

    ``parse_known_args`` is used, so arguments the parser does not recognise
    are handed back to the caller instead of being rejected.

    Returns:
        A ``(config, unparsed)`` pair: the populated namespace and the list
        of leftover argument strings.
    """
    return parser.parse_known_args()
def print_usage():
    """Print the shared parser's brief usage message."""
    parser.print_usage()
| # | |
| # config.py ends here | |