Spaces:
Running
Running
| import os | |
| from torch.multiprocessing import Process, Manager, set_start_method, Pool | |
| import functools | |
| import argparse | |
| import yaml | |
| import numpy as np | |
| import sys | |
| import cv2 | |
| from tqdm import trange | |
| set_start_method("spawn", force=True) | |
| ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) | |
| sys.path.insert(0, ROOT_DIR) | |
| from components import load_component | |
| from utils import evaluation_utils, metrics | |
# Command-line interface. Parsing happens at module level on purpose: with the
# "spawn" start method (set at import time), child processes re-import this
# module, and the stage handlers read the global ``args`` (for example
# ``args.num_process_match``) inside those children.
parser = argparse.ArgumentParser(description="dump eval data.")
_ARG_SPECS = (
    # (flag, type, default)
    ("--config_path", str, "configs/eval/scannet_eval_sgm.yaml"),
    ("--num_process_match", int, 4),
    ("--num_process_eval", int, 4),
    ("--vis_folder", str, None),
)
for _flag, _kind, _default in _ARG_SPECS:
    parser.add_argument(_flag, type=_kind, default=_default)
args = parser.parse_args()
def feed_match(info, matcher):
    """Run ``matcher`` on a single image pair.

    Args:
        info: mapping that provides "x1", "x2", "desc1", "desc2", "img1" and
            "img2". Only the first two entries of each image's ``.shape``
            are used.
        matcher: object exposing ``run(test_data)`` that returns two values
            ``(corr1, corr2)``.

    Returns:
        ``[corr1, corr2]`` as a two-element list.
    """

    def _flipped_size(image_key):
        # shape[:2] is (rows, cols); flipping yields (cols, rows), presumably
        # (width, height) for an image array -- confirm against the matcher.
        return np.flip(np.asarray(info[image_key].shape[:2]))

    payload = {key: info[key] for key in ("x1", "x2", "desc1", "desc2")}
    payload["size1"] = _flipped_size("img1")
    payload["size2"] = _flipped_size("img2")

    corr1, corr2 = matcher.run(payload)
    return [corr1, corr2]
def reader_handler(config, read_que):
    """Producer stage: read every sample and push it onto ``read_que``.

    Args:
        config: reader section of the YAML config; ``config["name"]`` selects
            the component built by ``load_component("reader", ...)``.
        read_que: queue consumed by ``match_handler``.

    Each index in ``range(len(reader))`` is read with ``reader.run(index)`` and
    queued in order. The sentinel string ``"over"`` is always queued last,
    even when loading or reading raises, so the downstream stages terminate
    instead of blocking forever on ``get()``. The exception still propagates.
    """
    try:
        reader = load_component("reader", config["name"], config)
        for index in range(len(reader)):
            read_que.put(reader.run(index))
    finally:
        read_que.put("over")
def _forward_matched_batch(pool, match_func, batch, match_que):
    """Match ``batch`` in parallel, attach correspondences, forward each item."""
    results = pool.map(match_func, batch)
    for cur_item, cur_result in zip(batch, results):
        cur_item["corr1"], cur_item["corr2"] = cur_result[0], cur_result[1]
        match_que.put(cur_item)


def match_handler(config, read_que, match_que):
    """Matching stage: consume ``read_que``, match items, feed ``match_que``.

    Items are buffered and matched in batches of ``args.num_process_match``
    using a worker ``Pool``. Each item is annotated with "corr1"/"corr2" from
    ``feed_match`` before being forwarded. On the ``"over"`` sentinel, the
    remaining partial batch is flushed.

    Args:
        config: matcher section of the YAML config; ``config["name"]`` selects
            the component built by ``load_component("matcher", ...)``.
        read_que: queue filled by ``reader_handler``.
        match_que: queue consumed by ``evaluate_handler``.

    ``"over"`` is always forwarded last, even on error, so the evaluator does
    not block forever. The pool is closed and joined on every exit path.
    """
    try:
        matcher = load_component("matcher", config["name"], config)
        match_func = functools.partial(feed_match, matcher=matcher)
        pool = Pool(args.num_process_match)
        try:
            batch = []
            while True:
                item = read_que.get()
                if item == "over":
                    break
                batch.append(item)
                if len(batch) == args.num_process_match:
                    _forward_matched_batch(pool, match_func, batch, match_que)
                    batch = []
            if batch:
                _forward_matched_batch(pool, match_func, batch, match_que)
        finally:
            pool.close()
            pool.join()
    finally:
        match_que.put("over")
def _evaluate_batch(pool, evaluator, batch):
    """Run ``evaluator.run`` over ``batch`` in parallel and record the results."""
    for cur_res in pool.map(evaluator.run, batch):
        evaluator.res_inqueue(cur_res)


def _dump_visualization(item, inlier_th):
    """Save a match visualization of ``item`` as ``<index>.png`` in ``args.vis_folder``.

    Correspondences are normalized with each image's intrinsics ("K1"/"K2").
    They are classified with ``metrics.compute_epi_inlier`` against ``item["e"]``
    using threshold ``inlier_th``, then drawn with ``evaluation_utils.draw_match``.
    """
    corr1_norm = evaluation_utils.normalize_intrinsic(item["corr1"], item["K1"])
    corr2_norm = evaluation_utils.normalize_intrinsic(item["corr2"], item["K2"])
    inlier_mask = metrics.compute_epi_inlier(
        corr1_norm, corr2_norm, item["e"], inlier_th
    )
    display = evaluation_utils.draw_match(
        item["img1"], item["img2"], item["corr1"], item["corr2"], inlier_mask
    )
    cv2.imwrite(os.path.join(args.vis_folder, str(item["index"]) + ".png"), display)


def evaluate_handler(config, match_que):
    """Evaluation stage: consume matched items and report the metrics.

    Reads at most ``config["num_pair"]`` items from ``match_que``, stopping
    early on the ``"over"`` sentinel. Items are evaluated in batches of
    ``args.num_process_eval`` via ``evaluator.run`` in a worker ``Pool``, and
    each result is recorded with ``evaluator.res_inqueue``. When
    ``args.vis_folder`` is set, a visualization is written per item.
    ``evaluator.parse()`` is called at the end.

    Args:
        config: evaluator section of the YAML config. Uses "name" and
            "num_pair", plus "inlier_th" when visualizing.
        match_que: queue filled by ``match_handler``.
    """
    evaluator = load_component("evaluator", config["name"], config)
    seen_over = False
    pool = Pool(args.num_process_eval)
    try:
        batch = []
        for _ in trange(config["num_pair"]):
            item = match_que.get()
            if item == "over":
                seen_over = True
                break
            batch.append(item)
            if len(batch) == args.num_process_eval:
                _evaluate_batch(pool, evaluator, batch)
                batch = []
            if args.vis_folder is not None:
                _dump_visualization(item, config["inlier_th"])
        # Without this flush, items left over when the loop stops after
        # exactly num_pair reads (fewer than num_process_eval) were dropped.
        if batch:
            _evaluate_batch(pool, evaluator, batch)
    finally:
        pool.close()
        pool.join()
    evaluator.parse()
    # If we stopped before the sentinel, keep consuming until "over" so the
    # upstream stages cannot block forever on a full bounded queue.
    while not seen_over:
        seen_over = match_que.get() == "over"
if __name__ == "__main__":
    # yaml.load without an explicit Loader raises TypeError on PyYAML >= 6.
    # safe_load assumes the config uses only standard YAML tags -- confirm if
    # a config ever relies on python/* tags.
    with open(args.config_path, "r") as f:
        config = yaml.safe_load(f)
    if args.vis_folder is not None:
        # makedirs also creates missing parent directories.
        os.makedirs(args.vis_folder, exist_ok=True)

    # A single manager process hosts both inter-stage queues; it is shut
    # down when the with-block exits, after every stage has been joined.
    with Manager() as manager:
        read_que = manager.Queue(maxsize=100)
        match_que = manager.Queue(maxsize=100)
        stages = [
            Process(target=reader_handler, args=(config["reader"], read_que)),
            Process(
                target=match_handler,
                args=(config["matcher"], read_que, match_que),
            ),
            Process(
                target=evaluate_handler, args=(config["evaluator"], match_que)
            ),
        ]
        for stage in stages:
            stage.start()
        for stage in stages:
            stage.join()