File size: 2,245 Bytes
9b33fca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""Default runtime configuration for the project."""

from __future__ import annotations

import platform
from datetime import datetime

from ml_collections import ConfigDict

from vis4d.config import class_config
from vis4d.config.typing import ExperimentConfig
from vis4d.engine.callbacks import LoggingCallback


def get_default_cfg(
    exp_name: str, work_dir: str = "vis4d-workspace"
) -> ExperimentConfig:
    """Set default config for the project.

    Args:
        exp_name (str): Experiment name.
        work_dir (str, optional): Working directory. Defaults to
            "vis4d-workspace".

    Returns:
        ExperimentConfig: Config for the project.
    """
    config = ExperimentConfig()

    config.work_dir = work_dir
    config.experiment_name = exp_name

    timestamp = (
        str(datetime.now())
        .split(".", maxsplit=1)[0]
        .replace(" ", "_")
        .replace(":", "-")
    )
    config.timestamp = timestamp
    config.version = timestamp

    if platform.system() == "Windows":
        path_component = "\\"
    else:
        path_component = "/"

    config.output_dir = (
        config.work_dir
        + path_component
        + config.experiment_name
        + path_component
        + config.version
    )

    # Set default value for the following fields
    config.seed = -1
    config.log_every_n_steps = 50
    config.use_tf32 = False
    config.tf32_matmul_precision = "highest"
    config.benchmark = False
    config.compute_flops = False
    config.check_unused_parameters = False

    return config


def get_default_callbacks_cfg(
    epoch_based: bool = True,
    refresh_rate: int = 50,
) -> list[ConfigDict]:
    """Get default callbacks config.

    It will return a list of callbacks config including:
        - LoggingCallback

    Args:
        epoch_based (bool, optional): Whether to use epoch based logging.
        refresh_rate (int, optional): Refresh rate for the logging. Defaults to
            50.

    Returns:
        list[ConfigDict]: List of callbacks config.
    """
    callbacks = []

    # Logger
    callbacks.append(
        class_config(
            LoggingCallback, epoch_based=epoch_based, refresh_rate=refresh_rate
        )
    )

    return callbacks