# Copyright (c) Meta Platforms, Inc. and affiliates.

import argparse
from pathlib import Path
from typing import Optional, Tuple

from omegaconf import OmegaConf, DictConfig

from .. import logger
from ..data import KittiDataModule
from .run import evaluate

# Single-frame evaluation adds no overrides on top of the user configuration.
default_cfg_single = OmegaConf.create({})
# For the sequential evaluation, we need to center the map around the GT location,
# since random offsets would accumulate and leave only the GT location with a valid mask.
# This should not have much impact on the results.
default_cfg_sequential = OmegaConf.create(
    {
        "data": {
            "mask_radius": KittiDataModule.default_cfg["max_init_error"],
            "prior_range_rotation": KittiDataModule.default_cfg[
                "max_init_error_rotation"
            ]
            + 1,
            "max_init_error": 0,
            "max_init_error_rotation": 0,
        },
        "chunking": {
            "max_length": 100,  # about 10s?
        },
    }
)


def run(
    split: str,
    experiment: str,
    cfg: Optional[DictConfig] = None,
    sequential: bool = False,
    thresholds: Tuple[int, ...] = (1, 3, 5),
    **kwargs,
):
    """Evaluate an experiment on a KITTI split and log recall metrics.

    The given configuration is merged on top of the mode-specific defaults
    (``default_cfg_sequential`` when ``sequential`` is set, otherwise
    ``default_cfg_single``), so user-provided values take precedence.

    Args:
        split: Name of the dataset split, forwarded to ``evaluate``.
        experiment: Experiment identifier, forwarded to ``evaluate``.
        cfg: Configuration overrides. ``None`` or an empty config means
            "defaults only"; a plain ``dict`` is also accepted and converted.
        sequential: Whether to run the sequential evaluation; this also adds
            the sequential error metrics to the logged recalls.
        thresholds: Error thresholds at which the recall is computed and
            logged (the log message reports them as "m/°").
        **kwargs: Extra keyword arguments forwarded unchanged to ``evaluate``.

    Returns:
        The metrics object returned by ``evaluate``.
    """
    # Falsy configs (None or empty) collapse to an empty dict before merging.
    cfg = cfg or {}
    if isinstance(cfg, dict):
        cfg = OmegaConf.create(cfg)
    default = default_cfg_sequential if sequential else default_cfg_single
    # Later arguments win in OmegaConf.merge: user values override defaults.
    cfg = OmegaConf.merge(default, cfg)
    dataset = KittiDataModule(cfg.get("data", {}))

    metrics = evaluate(
        experiment,
        cfg,
        dataset,
        split=split,
        sequential=sequential,
        viz_kwargs=dict(show_dir_error=True, show_masked_prob=False),
        **kwargs,
    )

    keys = ["directional_error", "yaw_max_error"]
    if sequential:
        keys += ["directional_seq_error", "yaw_seq_error"]
    for k in keys:
        rec = metrics[k].recall(thresholds).double().numpy().round(2).tolist()
        logger.info("Recall %s: %s at %s m/°", k, rec, thresholds)
    return metrics


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--experiment", type=str, required=True)
    parser.add_argument(
        "--split", type=str, default="test", choices=["test", "val", "train"]
    )
    parser.add_argument("--sequential", action="store_true")
    parser.add_argument("--output_dir", type=Path)
    parser.add_argument("--num", type=int)
    # Remaining positional arguments are OmegaConf dotlist overrides (key=value).
    parser.add_argument("dotlist", nargs="*")
    args = parser.parse_args()
    cfg = OmegaConf.from_cli(args.dotlist)
    run(
        args.split,
        args.experiment,
        cfg,
        args.sequential,
        output_dir=args.output_dir,
        num=args.num,
    )