File size: 1,161 Bytes
9b33fca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
"""Default runtime configuration for PyTorch Lightning."""

import inspect

from lightning import Trainer

from vis4d.config import FieldConfigDict
from vis4d.config.typing import ExperimentConfig


def get_default_pl_trainer_cfg(config: ExperimentConfig) -> ExperimentConfig:
    """Get PyTorch Lightning Trainer config."""
    pl_trainer = FieldConfigDict()

    # PL Trainer arguments
    for k, v in inspect.signature(Trainer).parameters.items():
        if not k in {"callbacks", "devices", "logger", "strategy"}:
            pl_trainer[k] = v.default

    # PL Trainer settings
    pl_trainer.benchmark = config.benchmark
    pl_trainer.use_distributed_sampler = False
    pl_trainer.num_sanity_val_steps = 0

    # logger
    pl_trainer.enable_progress_bar = False
    pl_trainer.log_every_n_steps = config.log_every_n_steps

    # Default Trainer arguments
    pl_trainer.work_dir = config.work_dir
    pl_trainer.exp_name = config.experiment_name
    pl_trainer.version = config.version
    pl_trainer.find_unused_parameters = False
    pl_trainer.checkpoint_period = 1
    pl_trainer.save_top_k = 1
    pl_trainer.wandb = False

    return pl_trainer