|
|
|
|
|
import argparse |
|
from pathlib import Path |
|
from typing import Optional, Tuple |
|
|
|
from omegaconf import OmegaConf, DictConfig |
|
|
|
from .. import logger |
|
from ..data import KittiDataModule |
|
from .run import evaluate |
|
|
|
|
|
# Defaults for the single-frame evaluation: an empty config, i.e. no
# overrides on top of what the caller / experiment already provides.
default_cfg_single = OmegaConf.create(dict())

# Defaults for the sequential evaluation. The initial-error magnitudes are
# forced to 0, while `mask_radius` and `prior_range_rotation` are derived from
# KittiDataModule's own default error settings (the rotation range is that
# default plus 1).
# NOTE(review): presumably this avoids random offsets accumulating over a
# sequence — confirm against the sequential evaluation code.
default_cfg_sequential = OmegaConf.create(
    dict(
        data=dict(
            mask_radius=KittiDataModule.default_cfg["max_init_error"],
            prior_range_rotation=(
                KittiDataModule.default_cfg["max_init_error_rotation"] + 1
            ),
            max_init_error=0,
            max_init_error_rotation=0,
        ),
        # Sequences are split into chunks of at most this many items.
        chunking=dict(max_length=100),
    )
)
|
|
|
|
|
def run(
    split: str,
    experiment: str,
    cfg: Optional[DictConfig] = None,
    sequential: bool = False,
    thresholds: Tuple[int, ...] = (1, 3, 5),
    **kwargs,
):
    """Evaluate an experiment on the KITTI data module and log recall metrics.

    Args:
        split: Dataset split to evaluate; forwarded to ``evaluate``.
        experiment: Experiment identifier; forwarded to ``evaluate``.
        cfg: Optional overrides, as a ``DictConfig`` or a plain ``dict``.
            They are merged on top of ``default_cfg_sequential`` when
            ``sequential`` is true, otherwise on top of ``default_cfg_single``.
            User values take precedence over the defaults.
        sequential: Run the sequential evaluation. This also adds the
            ``directional_seq_error`` and ``yaw_seq_error`` recalls to the log.
        thresholds: Thresholds passed to each metric's ``recall``. The log
            message labels them "m/°" (presumably meters for directional
            errors and degrees for yaw errors — confirm).
        **kwargs: Extra keyword arguments forwarded verbatim to ``evaluate``.

    Returns:
        The metrics object returned by ``evaluate``, indexed by metric name.
    """
    cfg = cfg or {}
    if isinstance(cfg, dict):
        cfg = OmegaConf.create(cfg)
    default = default_cfg_sequential if sequential else default_cfg_single
    # OmegaConf.merge gives later arguments priority, so user overrides win.
    cfg = OmegaConf.merge(default, cfg)
    dataset = KittiDataModule(cfg.get("data", {}))

    metrics = evaluate(
        experiment,
        cfg,
        dataset,
        split=split,
        sequential=sequential,
        viz_kwargs=dict(show_dir_error=True, show_masked_prob=False),
        **kwargs,
    )

    keys = ["directional_error", "yaw_max_error"]
    if sequential:
        keys += ["directional_seq_error", "yaw_seq_error"]
    for k in keys:
        # Convert the recall values to a plain list rounded to 2 decimals
        # for readable logging.
        rec = metrics[k].recall(thresholds).double().numpy().round(2).tolist()
        logger.info("Recall %s: %s at %s m/°", k, rec, thresholds)
    return metrics
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the options, turn trailing "key=value" tokens
    # into an OmegaConf override config, then launch the evaluation.
    cli = argparse.ArgumentParser()
    cli.add_argument("--experiment", type=str, required=True)
    cli.add_argument(
        "--split", type=str, default="test", choices=["test", "val", "train"]
    )
    cli.add_argument("--sequential", action="store_true")
    cli.add_argument("--output_dir", type=Path)
    cli.add_argument("--num", type=int)
    cli.add_argument("dotlist", nargs="*")
    opts = cli.parse_args()

    overrides = OmegaConf.from_cli(opts.dotlist)
    # output_dir and num are not parameters of run(); they reach evaluate()
    # through **kwargs.
    run(
        split=opts.split,
        experiment=opts.experiment,
        cfg=overrides,
        sequential=opts.sequential,
        output_dir=opts.output_dir,
        num=opts.num,
    )
|
|