# Copyright (c) Meta Platforms, Inc. and affiliates.

"""Evaluate an experiment on the Mapillary data, per frame or sequentially."""

import argparse
from pathlib import Path
from typing import Optional, Tuple

from omegaconf import OmegaConf, DictConfig

from .. import logger
from ..conf import data as conf_data_dir
from ..data import MapillaryDataModule
from .run import evaluate

# Per-split overrides of the *data* config. `run` nests these under the
# "data" key, since only that sub-config is handed to MapillaryDataModule.
split_overrides = {
    "val": {
        "scenes": [
            "sanfrancisco_soma",
            "sanfrancisco_hayes",
            "amsterdam",
            "berlin",
            "lemans",
            "montrouge",
            "toulouse",
            "nantes",
            "vilnius",
            "avignon",
            "helsinki",
            "milan",
            "paris",
        ],
    },
}

# Start from the training data config (mapillary.yaml in the conf/data package)
# and apply evaluation-specific settings on top of it.
data_cfg_train = OmegaConf.load(Path(conf_data_dir.__file__).parent / "mapillary.yaml")
data_cfg = OmegaConf.merge(
    data_cfg_train,
    {
        "return_gps": True,
        "add_map_mask": True,
        "max_init_error": 32,
        # Validation batches of size 1 with num_workers set to 0.
        "loading": {"val": {"batch_size": 1, "num_workers": 0}},
    },
)
default_cfg_single = OmegaConf.create({"data": data_cfg})
# Sequential evaluation uses the same data config plus a chunking section.
default_cfg_sequential = OmegaConf.create(
    {
        **default_cfg_single,
        "chunking": {
            "max_length": 10,
        },
    }
)


def run(
    split: str,
    experiment: str,
    cfg: Optional[DictConfig] = None,
    sequential: bool = False,
    thresholds: Tuple[int, ...] = (1, 3, 5),
    **kwargs,
):
    """Evaluate `experiment` on a Mapillary split and log recall metrics.

    Args:
        split: Key of `split_overrides` (currently only "val").
        experiment: Experiment identifier, forwarded unchanged to `evaluate`.
        cfg: Optional user overrides (DictConfig or plain dict), merged on top
            of the default config for the chosen mode and split.
        sequential: If True, start from the sequential defaults (with a
            `chunking` section) and also report the sequential metric keys.
        thresholds: Recall thresholds; the log message reports them in m/°.
        **kwargs: Extra keyword arguments forwarded to `evaluate`
            (e.g. `output_dir`, `num` from the CLI).

    Returns:
        Whatever `evaluate` returns. Presumably a mapping from metric name to
        an object exposing `.recall(thresholds)` — confirm against `evaluate`.

    Raises:
        KeyError: If `split` is not a key of `split_overrides`.
    """
    cfg = cfg or {}
    if isinstance(cfg, dict):
        cfg = OmegaConf.create(cfg)
    default = default_cfg_sequential if sequential else default_cfg_single
    # The split overrides describe the dataset (e.g. which scenes to load), so
    # they must live under "data": only cfg.data reaches MapillaryDataModule.
    # Merging them at the root would silently leave the dataset unfiltered.
    default = OmegaConf.merge(default, {"data": split_overrides[split]})
    cfg = OmegaConf.merge(default, cfg)
    dataset = MapillaryDataModule(cfg.get("data", {}))
    metrics = evaluate(experiment, cfg, dataset, split, sequential=sequential, **kwargs)

    keys = [
        "xy_max_error",
        "xy_gps_error",
        "yaw_max_error",
    ]
    if sequential:
        keys += [
            "xy_seq_error",
            "xy_gps_seq_error",
            "yaw_seq_error",
            "yaw_gps_seq_error",
        ]
    for k in keys:
        if k not in metrics:
            # Not fatal: report what is available and keep going.
            logger.warning("Key %s not in metrics.", k)
            continue
        rec = metrics[k].recall(thresholds).double().numpy().round(2).tolist()
        logger.info("Recall %s: %s at %s m/°", k, rec, thresholds)
    return metrics


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--experiment", type=str, required=True)
    # Derive the valid splits from the overrides so the two cannot drift apart.
    parser.add_argument(
        "--split", type=str, default="val", choices=list(split_overrides)
    )
    parser.add_argument("--sequential", action="store_true")
    parser.add_argument("--output_dir", type=Path)
    parser.add_argument("--num", type=int)
    # Remaining positional arguments are OmegaConf dotlist overrides (key=value).
    parser.add_argument("dotlist", nargs="*")
    args = parser.parse_args()

    cfg = OmegaConf.from_cli(args.dotlist)
    run(
        args.split,
        args.experiment,
        cfg,
        args.sequential,
        output_dir=args.output_dir,
        num=args.num,
    )