Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,716 Bytes
9b33fca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
"""Type definitions for configuration files."""
from __future__ import annotations
from typing import Any, TypedDict
from ml_collections import ConfigDict, FieldReference
from typing_extensions import NotRequired
from .config_dict import FieldConfigDict
class ParamGroupCfg(TypedDict):
    """Parameter group config.

    Typed dictionary describing a single parameter group. Only
    ``custom_keys`` is declared without ``NotRequired``; every multiplier
    is declared as ``NotRequired``.

    Attributes:
        custom_keys (list[str]): List of custom keys.
        lr_mult (NotRequired[float]): Learning rate multiplier.
        decay_mult (NotRequired[float]): Weight decay multiplier.
        norm_decay_mult (NotRequired[float]): Presumably the weight decay
            multiplier applied to normalization-layer parameters; confirm
            against the code that consumes this config.
        bias_decay_mult (NotRequired[float]): Presumably the weight decay
            multiplier applied to bias parameters; confirm against the
            code that consumes this config.
    """

    custom_keys: list[str]
    lr_mult: NotRequired[float]
    decay_mult: NotRequired[float]
    norm_decay_mult: NotRequired[float]
    bias_decay_mult: NotRequired[float]
class DataConfig(ConfigDict):  # type: ignore
    """Data set configuration of an experiment.

    Holds the training and the test data configuration. Both entries,
    ``train_dataloader`` and ``test_dataloader``, have to be config dicts
    that can be instantiated as a dataloader.

    Attributes:
        train_dataloader (ConfigDict): Config of the dataloader used for
            training.
        test_dataloader (ConfigDict): Config of the dataloader used for
            testing.

    Example:
        >>> from vis4d.config.types import DataConfig
        >>> from vis4d.zoo.base import class_config
        >>> from my_package.data import MyDataLoader
        >>> cfg = DataConfig()
        >>> cfg.train_dataloader = class_config(MyDataLoader, ...)
    """

    train_dataloader: ConfigDict
    test_dataloader: ConfigDict
class LrSchedulerConfig(ConfigDict):  # type: ignore
    """Learning rate scheduler configuration.

    Attributes:
        scheduler (ConfigDict): Config of the learning rate scheduler
            itself.
        begin (int): Begin epoch.
        end (int): End epoch.
        epoch_based (bool): True if the scheduler is epoch based, False if
            it is step based.
        convert_epochs_to_steps (bool): For a step based scheduler,
            whether begin and end are converted to steps automatically
            based on the length of the train dataloader. This lets users
            give the iteration breakpoints as epochs. Defaults to False.
        convert_attributes (list[str] | None): Names of the scheduler
            attributes that should be converted to steps. Defaults to
            None.
    """

    scheduler: ConfigDict
    begin: int
    end: int
    epoch_based: bool
    # Class-level defaults, used when the entry is not set explicitly.
    convert_epochs_to_steps: bool = False
    convert_attributes: list[str] | None = None
class OptimizerConfig(ConfigDict):  # type: ignore
    """Optimizer configuration.

    Attributes:
        optimizer (ConfigDict): Config of the optimizer itself.
        lr_scheduler (list[LrSchedulerConfig] | None): Configs of the
            learning rate schedulers, or None.
        param_groups (list[ParamGroupCfg] | None): Configs of the
            parameter groups, or None.
    """

    optimizer: ConfigDict
    lr_scheduler: list[LrSchedulerConfig] | None
    param_groups: list[ParamGroupCfg] | None
class ExperimentParameters(FieldConfigDict):
    """Experiment parameters.

    Attributes:
        samples_per_gpu (int): How many samples each GPU processes.
        workers_per_gpu (int): How many workers each GPU uses.
    """

    samples_per_gpu: int
    workers_per_gpu: int
class ExperimentConfig(FieldConfigDict):
    """Configuration for an experiment.

    This data object is used to configure an experiment. It contains the
    minimal required configuration to run an experiment. In particular, the
    data, model, optimizers, and loss need to be config dicts that can be
    instantiated as a data set, model, optimizer, and loss function,
    respectively.

    Attributes:
        work_dir (str | FieldReference): The working directory for the
            experiment.
        experiment_name (str | FieldReference): The name of the experiment.
        timestamp (str | FieldReference): The timestamp of the experiment.
        version (str | FieldReference): The version of the experiment.
        output_dir (str | FieldReference): The output directory for the
            experiment.
        seed (int | FieldReference): The random seed for the experiment.
        log_every_n_steps (int | FieldReference): The number of steps after
            which the logs should be written.
        use_tf32 (bool | FieldReference): Whether to use tf32.
        benchmark (bool | FieldReference): Whether to enable benchmarking.
        tf32_matmul_precision (str | FieldReference): Presumably the
            float32 matmul precision mode used together with tf32;
            confirm the accepted values against the code that reads it.
        params (ExperimentParameters): Configuration for the experiment
            parameters.
        data (DataConfig): Configuration for the dataset.
        model (ConfigDict): Configuration for the model.
        loss (ConfigDict): Configuration for the loss function.
        optimizers (list[OptimizerConfig]): Configuration for the
            optimizers.
        train_data_connector (ConfigDict): Configuration for the data
            connector used during training.
        test_data_connector (ConfigDict): Configuration for the data
            connector used during testing.
        callbacks (list[ConfigDict]): Configuration for the callbacks
            which are used in the engine.
    """

    # General
    work_dir: str | FieldReference
    experiment_name: str | FieldReference
    timestamp: str | FieldReference
    version: str | FieldReference
    output_dir: str | FieldReference
    seed: int | FieldReference
    log_every_n_steps: int | FieldReference
    use_tf32: bool | FieldReference
    benchmark: bool | FieldReference
    tf32_matmul_precision: str | FieldReference

    params: ExperimentParameters

    # Data
    data: DataConfig

    # Model
    model: ConfigDict

    # Loss
    loss: ConfigDict

    # Optimizer
    optimizers: list[OptimizerConfig]

    # Data connector
    train_data_connector: ConfigDict
    test_data_connector: ConfigDict

    # Callbacks
    callbacks: list[ConfigDict]
class ParameterSweepConfig(FieldConfigDict):
    """Parameter sweep configuration.

    Configuration object holding the minimal required settings to run a
    parameter sweep.

    Attributes:
        method (str | FieldReference): Name of the sweep method to use
            (e.g. grid).
        sampling_args (list[tuple[str, Any]] | FieldReference): Arguments
            handed to the sweep method. For grid, for example, this is a
            list of tuples of the form (parameter_name, parameter_values).
        suffix (str | FieldReference): Suffix appended to the output
            directory. It is interpreted as a string template and may
            reference the sampling_args,
            e.g. "lr_{lr:.2e}_bs_{batch_size}". The class-level default
            is the empty string.
    """

    method: str | FieldReference
    sampling_args: list[tuple[str, Any]] | FieldReference  # type: ignore
    suffix: str | FieldReference = ""
|