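"""Localize query images against a database of RGB-D scans (InLoc-style dataset layout).

For every query image, keypoint matches to the retrieved database images are
lifted to 3D by interpolating the scan point maps (``XYZcut``), transformed
into a common reference frame with the per-scan alignment files, and passed to
pycolmap's absolute pose estimation (PnP + RANSAC).

Example invocation (module path and file names are illustrative only):

    python -m hloc.localize_inloc --dataset_dir datasets/inloc \\
        --retrieval pairs.txt --features feats.h5 --matches matches.h5 \\
        --results poses.txt --skip_matches 20
"""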
import argparse
import pickle
from pathlib import Path

import cv2
import h5py
import numpy as np
import pycolmap
import torch
from scipy.io import loadmat
from tqdm import tqdm

from . import logger
from .utils.parsers import names_to_pair, parse_retrieval


def interpolate_scan(scan, kp):
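    """Look up 3D points for 2D keypoints by sampling the scan point map.

    Args:
        scan: (H, W, 3) array of per-pixel 3D coordinates (may contain NaNs).
        kp: (N, 2) array of keypoint coordinates in pixels.
    Returns:
        kp3d: (N, 3) array of interpolated 3D points.
        valid: (N,) boolean mask of keypoints with a valid (non-NaN) 3D point.
    """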
    h, w, c = scan.shape
    # Normalize keypoint coordinates to [-1, 1], as expected by grid_sample.
    kp = kp / np.array([[w - 1, h - 1]]) * 2 - 1
    assert np.all(kp > -1) and np.all(kp < 1)
    scan = torch.from_numpy(scan).permute(2, 0, 1)[None]
    kp = torch.from_numpy(kp)[None, None]
    grid_sample = torch.nn.functional.grid_sample

    # To maximize the number of points that have depth:
    # do bilinear interpolation first and then nearest for the remaining points.
    interp_lin = grid_sample(scan, kp, align_corners=True, mode="bilinear")[0, :, 0]
    interp_nn = grid_sample(scan, kp, align_corners=True, mode="nearest")[0, :, 0]
    interp = torch.where(torch.isnan(interp_lin), interp_nn, interp_lin)
    valid = ~torch.any(torch.isnan(interp), 0)

    kp3d = interp.T.numpy()
    valid = valid.numpy()
    return kp3d, valid


def get_scan_pose(dataset_dir, rpath):
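    """Load the transformation aligning a database scan to the global frame.

    The transform is read from the per-scan alignment file under
    ``database/alignments`` and returned as a 4x4 matrix.
    """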
    split_image_rpath = rpath.split("/")
    floor_name = split_image_rpath[-3]
    scan_id = split_image_rpath[-2]
    image_name = split_image_rpath[-1]
    building_name = image_name[:3]

    path = Path(
        dataset_dir,
        "database/alignments",
        floor_name,
        f"transformations/{building_name}_trans_{scan_id}.txt",
    )
    with open(path) as f:
        raw_lines = f.readlines()

    # Lines 8-11 of the alignment file (0-indexed 7-10) contain the rows of
    # the 4x4 scan-to-global transformation after GICP refinement.
    P_after_GICP = np.array(
        [
            np.fromstring(raw_lines[7], sep=" "),
            np.fromstring(raw_lines[8], sep=" "),
            np.fromstring(raw_lines[9], sep=" "),
            np.fromstring(raw_lines[10], sep=" "),
        ]
    )

    return P_after_GICP


def pose_from_cluster(
    dataset_dir, q, retrieved, feature_file, match_file, skip=None
):
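    """Estimate the absolute pose of a query from its retrieved database images.

    2D-2D matches to each retrieved image are lifted to 3D via the scan point
    maps, accumulated into a single set of 2D-3D correspondences, and solved
    with pycolmap's absolute pose estimation under a fixed pinhole camera.

    Args:
        skip: if set, ignore retrieved images with fewer than ``skip`` matches.
    """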
    height, width = cv2.imread(str(dataset_dir / q)).shape[:2]
    cx = 0.5 * width
    cy = 0.5 * height
    # Fixed focal length in pixels: a 28 mm (35 mm-equivalent) lens on a
    # 4032 px-wide image, i.e. 4032 * 28 / 36.
    focal_length = 4032.0 * 28.0 / 36.0

    all_mkpq = []
    all_mkpr = []
    all_mkp3d = []
    all_indices = []
    kpq = feature_file[q]["keypoints"].__array__()
    num_matches = 0

    for i, r in enumerate(retrieved):
        kpr = feature_file[r]["keypoints"].__array__()
        pair = names_to_pair(q, r)
        m = match_file[pair]["matches0"].__array__()
        v = m > -1

        if skip and (np.count_nonzero(v) < skip):
            continue

        mkpq, mkpr = kpq[v], kpr[m[v]]
        num_matches += len(mkpq)

        scan_r = loadmat(Path(dataset_dir, r + ".mat"))["XYZcut"]
        mkp3d, valid = interpolate_scan(scan_r, mkpr)
        Tr = get_scan_pose(dataset_dir, r)
        mkp3d = (Tr[:3, :3] @ mkp3d.T + Tr[:3, -1:]).T

        all_mkpq.append(mkpq[valid])
        all_mkpr.append(mkpr[valid])
        all_mkp3d.append(mkp3d[valid])
        all_indices.append(np.full(np.count_nonzero(valid), i))

    all_mkpq = np.concatenate(all_mkpq, 0)
    all_mkpr = np.concatenate(all_mkpr, 0)
    all_mkp3d = np.concatenate(all_mkp3d, 0)
    all_indices = np.concatenate(all_indices, 0)

    cfg = {
        "model": "SIMPLE_PINHOLE",
        "width": width,
        "height": height,
        "params": [focal_length, cx, cy],
    }
    # Note: this assumes a pycolmap version whose absolute_pose_estimation
    # takes (points2D, points3D, camera_dict, max_error_px) and returns a dict
    # with "qvec"/"tvec"; newer pycolmap releases changed this interface.
    ret = pycolmap.absolute_pose_estimation(all_mkpq, all_mkp3d, cfg, 48.00)
    ret["cfg"] = cfg
    return ret, all_mkpq, all_mkpr, all_mkp3d, all_indices, num_matches


def main(dataset_dir, retrieval, features, matches, results, skip_matches=None):
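    """Localize all queries listed in the retrieval file and write the results.

    Poses are written as ``<image name> <qvec> <tvec>`` lines to ``results``,
    and per-query localization details are pickled to ``<results>_logs.pkl``.
    """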
    assert retrieval.exists(), retrieval
    assert features.exists(), features
    assert matches.exists(), matches

    retrieval_dict = parse_retrieval(retrieval)
    queries = list(retrieval_dict.keys())

    feature_file = h5py.File(features, "r", libver="latest")
    match_file = h5py.File(matches, "r", libver="latest")

    poses = {}
    logs = {
        "features": features,
        "matches": matches,
        "retrieval": retrieval,
        "loc": {},
    }
    logger.info("Starting localization...")
    for q in tqdm(queries):
        db = retrieval_dict[q]
        ret, mkpq, mkpr, mkp3d, indices, num_matches = pose_from_cluster(
            dataset_dir, q, db, feature_file, match_file, skip_matches
        )

        poses[q] = (ret["qvec"], ret["tvec"])
        logs["loc"][q] = {
            "db": db,
            "PnP_ret": ret,
            "keypoints_query": mkpq,
            "keypoints_db": mkpr,
            "3d_points": mkp3d,
            "indices_db": indices,
            "num_matches": num_matches,
        }

    logger.info(f"Writing poses to {results}...")
    with open(results, "w") as f:
        for q in queries:
            qvec, tvec = poses[q]
            qvec = " ".join(map(str, qvec))
            tvec = " ".join(map(str, tvec))
            name = q.split("/")[-1]
            f.write(f"{name} {qvec} {tvec}\n")

    logs_path = f"{results}_logs.pkl"
    logger.info(f"Writing logs to {logs_path}...")
    with open(logs_path, "wb") as f:
        pickle.dump(logs, f)
    logger.info("Done!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_dir", type=Path, required=True)
    parser.add_argument("--retrieval", type=Path, required=True)
    parser.add_argument("--features", type=Path, required=True)
    parser.add_argument("--matches", type=Path, required=True)
    parser.add_argument("--results", type=Path, required=True)
    parser.add_argument("--skip_matches", type=int)
    args = parser.parse_args()
    main(**args.__dict__)