File size: 2,611 Bytes
124ba77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# Copyright (c) Meta Platforms, Inc. and affiliates.

import argparse
from pathlib import Path
from typing import Optional, Tuple

from omegaconf import OmegaConf, DictConfig

from .. import logger
from ..data import KittiDataModule
from .run import evaluate


default_cfg_single = OmegaConf.create({})
# For the sequential evaluation, we need to center the map around the GT location,
# since random offsets would accumulate and leave only the GT location with a valid mask.
# This should not have much impact on the results.
default_cfg_sequential = OmegaConf.create(
    {
        "data": {
            "mask_radius": KittiDataModule.default_cfg["max_init_error"],
            "prior_range_rotation": KittiDataModule.default_cfg[
                "max_init_error_rotation"
            ]
            + 1,
            "max_init_error": 0,
            "max_init_error_rotation": 0,
        },
        "chunking": {
            "max_length": 100,  # about 10s?
        },
    }
)


def run(
    split: str,
    experiment: str,
    cfg: Optional[DictConfig] = None,
    sequential: bool = False,
    thresholds: Tuple[int] = (1, 3, 5),
    **kwargs,
):
    cfg = cfg or {}
    if isinstance(cfg, dict):
        cfg = OmegaConf.create(cfg)
    default = default_cfg_sequential if sequential else default_cfg_single
    cfg = OmegaConf.merge(default, cfg)
    dataset = KittiDataModule(cfg.get("data", {}))

    metrics = evaluate(
        experiment,
        cfg,
        dataset,
        split=split,
        sequential=sequential,
        viz_kwargs=dict(show_dir_error=True, show_masked_prob=False),
        **kwargs,
    )

    keys = ["directional_error", "yaw_max_error"]
    if sequential:
        keys += ["directional_seq_error", "yaw_seq_error"]
    for k in keys:
        rec = metrics[k].recall(thresholds).double().numpy().round(2).tolist()
        logger.info("Recall %s: %s at %s m/°", k, rec, thresholds)
    return metrics


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--experiment", type=str, required=True)
    parser.add_argument(
        "--split", type=str, default="test", choices=["test", "val", "train"]
    )
    parser.add_argument("--sequential", action="store_true")
    parser.add_argument("--output_dir", type=Path)
    parser.add_argument("--num", type=int)
    parser.add_argument("dotlist", nargs="*")
    args = parser.parse_args()
    cfg = OmegaConf.from_cli(args.dotlist)
    run(
        args.split,
        args.experiment,
        cfg,
        args.sequential,
        output_dir=args.output_dir,
        num=args.num,
    )