Spaces:
Running
Running
| # Copyright 2019-present NAVER Corp. | |
| # CC BY-NC-SA 3.0 | |
| # Available only for non-commercial use | |
| from PIL import Image | |
| from tools import common | |
| from tools.dataloader import norm_RGB | |
| from nets.patchnet import * | |
| from os import path | |
| from extract import load_network, NonMaxSuppression, extract_multiscale | |
| # Kapture is a pivot file format, based on text and binary files, used to describe SfM (Structure From Motion) | |
| # and more generally sensor-acquired data | |
| # it can be installed with | |
| # pip install kapture | |
| # for more information check out https://github.com/naver/kapture | |
| import kapture | |
| from kapture.io.records import get_image_fullpath | |
| from kapture.io.csv import kapture_from_dir | |
| from kapture.io.csv import ( | |
| get_feature_csv_fullpath, | |
| keypoints_to_file, | |
| descriptors_to_file, | |
| ) | |
| from kapture.io.features import ( | |
| get_keypoints_fullpath, | |
| keypoints_check_dir, | |
| image_keypoints_to_file, | |
| ) | |
| from kapture.io.features import ( | |
| get_descriptors_fullpath, | |
| descriptors_check_dir, | |
| image_descriptors_to_file, | |
| ) | |
| from kapture.io.csv import get_all_tar_handlers | |
def extract_kapture_keypoints(args):
    """
    Extract r2d2 keypoints and descriptors and write them directly in the kapture format.

    The kapture dataset at ``args.kapture_root`` is loaded, the r2d2 network at
    ``args.model`` is run on every image listed in ``records_camera`` and, for
    each image, the selected keypoints and descriptors are written under the
    ``args.keypoints_type`` / ``args.descriptors_type`` feature names.

    If both feature types already exist in the dataset, only images missing
    from either of them are processed (resume mode).

    Fields read from ``args``: ``kapture_root``, ``model``, ``keypoints_type``,
    ``descriptors_type``, ``gpu``, ``reliability_thr``, ``repeatability_thr``,
    ``scale_f``, ``min_scale``, ``max_scale``, ``min_size``, ``max_size`` and
    ``top_k``.

    Side effect: when ``args.keypoints_type`` or ``args.descriptors_type`` is
    ``None`` it is overwritten in place with the model file name without its
    extension.

    Returns ``None``. Raises ``AssertionError`` when the dataset has no
    ``records_camera`` or when already-stored feature metadata does not match
    what the network produced.
    """
    print("extract_kapture_keypoints...")
    with get_all_tar_handlers(
        args.kapture_root,
        mode={
            kapture.Keypoints: "a",
            kapture.Descriptors: "a",
            kapture.GlobalFeatures: "r",
            kapture.Matches: "r",
        },
    ) as tar_handlers:
        kdata = kapture_from_dir(
            args.kapture_root,
            None,
            skip_list=[
                kapture.GlobalFeatures,
                kapture.Matches,
                kapture.Points3d,
                kapture.Observations,
            ],
            tar_handlers=tar_handlers,
        )
        assert kdata.records_camera is not None
        image_list = [
            filename for _, _, filename in kapture.flatten(kdata.records_camera)
        ]

        # Default feature type names: model file name without extension.
        if args.keypoints_type is None:
            args.keypoints_type = path.splitext(path.basename(args.model))[0]
            print(f"keypoints_type set to {args.keypoints_type}")
        if args.descriptors_type is None:
            args.descriptors_type = path.splitext(path.basename(args.model))[0]
            print(f"descriptors_type set to {args.descriptors_type}")

        # Resume mode: keep only the images missing from either feature type.
        if (
            kdata.keypoints is not None
            and args.keypoints_type in kdata.keypoints
            and kdata.descriptors is not None
            and args.descriptors_type in kdata.descriptors
        ):
            print(
                "detected already computed features of same keypoints_type/descriptors_type, resuming extraction..."
            )
            image_list = [
                name
                for name in image_list
                if name not in kdata.keypoints[args.keypoints_type]
                or name not in kdata.descriptors[args.descriptors_type]
            ]

        if not image_list:
            print("All features were already extracted")
            return
        print(f"Extracting r2d2 features for {len(image_list)} images")

        iscuda = common.torch_set_gpu(args.gpu)

        # load the network...
        net = load_network(args.model)
        if iscuda:
            net = net.cuda()

        # create the non-maxima detector
        detector = NonMaxSuppression(
            rel_thr=args.reliability_thr, rep_thr=args.repeatability_thr
        )

        if kdata.keypoints is None:
            kdata.keypoints = {}
        if kdata.descriptors is None:
            kdata.descriptors = {}

        # Metadata of already-existing feature types; None means "not created yet".
        if args.keypoints_type not in kdata.keypoints:
            keypoints_dtype = None
            keypoints_dsize = None
        else:
            keypoints_dtype = kdata.keypoints[args.keypoints_type].dtype
            keypoints_dsize = kdata.keypoints[args.keypoints_type].dsize
        if args.descriptors_type not in kdata.descriptors:
            descriptors_dtype = None
            descriptors_dsize = None
        else:
            descriptors_dtype = kdata.descriptors[args.descriptors_type].dtype
            descriptors_dsize = kdata.descriptors[args.descriptors_type].dsize

        for image_name in image_list:
            img_path = get_image_fullpath(args.kapture_root, image_name)
            print(f"\nExtracting features for {img_path}")
            # Context manager releases the file handle as soon as the RGB
            # copy has been made, instead of relying on garbage collection.
            with Image.open(img_path) as image_file:
                img = image_file.convert("RGB")
            img = norm_RGB(img)[None]
            if iscuda:
                img = img.cuda()

            # extract keypoints/descriptors for a single image
            xys, desc, scores = extract_multiscale(
                net,
                img,
                detector,
                scale_f=args.scale_f,
                min_scale=args.min_scale,
                max_scale=args.max_scale,
                min_size=args.min_size,
                max_size=args.max_size,
                verbose=True,
            )
            xys = xys.cpu().numpy()
            desc = desc.cpu().numpy()
            scores = scores.cpu().numpy()
            # argsort is ascending, so the tail holds the best scores.
            # With top_k == 0, `-0 or None` is None and every keypoint is kept.
            idxs = scores.argsort()[-args.top_k or None :]
            xys = xys[idxs]
            desc = desc[idxs]

            if keypoints_dtype is None or descriptors_dtype is None:
                # First write for this feature type pair: derive the metadata
                # from the actual network output and persist the config files.
                # NOTE(review): both records are re-created even when only one
                # of the two types was missing -- confirm this is intended.
                keypoints_dtype = xys.dtype
                descriptors_dtype = desc.dtype
                keypoints_dsize = xys.shape[1]
                descriptors_dsize = desc.shape[1]
                kdata.keypoints[args.keypoints_type] = kapture.Keypoints(
                    "r2d2", keypoints_dtype, keypoints_dsize
                )
                kdata.descriptors[args.descriptors_type] = kapture.Descriptors(
                    "r2d2",
                    descriptors_dtype,
                    descriptors_dsize,
                    args.keypoints_type,
                    "L2",
                )
                keypoints_config_absolute_path = get_feature_csv_fullpath(
                    kapture.Keypoints, args.keypoints_type, args.kapture_root
                )
                descriptors_config_absolute_path = get_feature_csv_fullpath(
                    kapture.Descriptors, args.descriptors_type, args.kapture_root
                )
                keypoints_to_file(
                    keypoints_config_absolute_path, kdata.keypoints[args.keypoints_type]
                )
                descriptors_to_file(
                    descriptors_config_absolute_path,
                    kdata.descriptors[args.descriptors_type],
                )
            else:
                # Existing feature types must match what the network produced.
                assert kdata.keypoints[args.keypoints_type].dtype == xys.dtype
                assert kdata.descriptors[args.descriptors_type].dtype == desc.dtype
                assert kdata.keypoints[args.keypoints_type].dsize == xys.shape[1]
                assert kdata.descriptors[args.descriptors_type].dsize == desc.shape[1]
                assert (
                    kdata.descriptors[args.descriptors_type].keypoints_type
                    == args.keypoints_type
                )
                assert kdata.descriptors[args.descriptors_type].metric_type == "L2"

            keypoints_fullpath = get_keypoints_fullpath(
                args.keypoints_type, args.kapture_root, image_name, tar_handlers
            )
            print(f"Saving {xys.shape[0]} keypoints to {keypoints_fullpath}")
            image_keypoints_to_file(keypoints_fullpath, xys)
            kdata.keypoints[args.keypoints_type].add(image_name)

            descriptors_fullpath = get_descriptors_fullpath(
                args.descriptors_type, args.kapture_root, image_name, tar_handlers
            )
            print(f"Saving {desc.shape[0]} descriptors to {descriptors_fullpath}")
            image_descriptors_to_file(descriptors_fullpath, desc)
            kdata.descriptors[args.descriptors_type].add(image_name)

        # Final consistency check between the in-memory records and the files on disk.
        if not keypoints_check_dir(
            kdata.keypoints[args.keypoints_type],
            args.keypoints_type,
            args.kapture_root,
            tar_handlers,
        ) or not descriptors_check_dir(
            kdata.descriptors[args.descriptors_type],
            args.descriptors_type,
            args.kapture_root,
            tar_handlers,
        ):
            print(
                "local feature extraction ended successfully but not all files were saved"
            )
def build_arg_parser():
    """
    Build the command-line parser for the r2d2 -> kapture extraction script.

    The help sentence is passed as ``description``. Passing it positionally
    would set ``prog`` instead, which puts the whole sentence in the usage
    line as the program name and leaves the description empty.

    Returns the configured ``argparse.ArgumentParser``.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Extract r2d2 local features for all images in a dataset stored in the kapture format"
    )
    parser.add_argument("--model", type=str, required=True, help="model path")
    parser.add_argument(
        "--keypoints-type",
        default=None,
        help="keypoint type_name, default is filename of model",
    )
    parser.add_argument(
        "--descriptors-type",
        default=None,
        help="descriptors type_name, default is filename of model",
    )
    parser.add_argument(
        "--kapture-root", type=str, required=True, help="path to kapture root directory"
    )
    parser.add_argument("--top-k", type=int, default=5000, help="number of keypoints")
    parser.add_argument("--scale-f", type=float, default=2**0.25)
    parser.add_argument("--min-size", type=int, default=256)
    parser.add_argument("--max-size", type=int, default=1024)
    parser.add_argument("--min-scale", type=float, default=0)
    parser.add_argument("--max-scale", type=float, default=1)
    parser.add_argument("--reliability-thr", type=float, default=0.7)
    parser.add_argument("--repeatability-thr", type=float, default=0.7)
    parser.add_argument(
        "--gpu", type=int, nargs="+", default=[0], help="use -1 for CPU"
    )
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    extract_kapture_keypoints(args)