Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,347 Bytes
9b33fca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
# pylint: disable=duplicate-code
"""YOLOX COCO."""
from __future__ import annotations
from lightning.pytorch.callbacks import ModelCheckpoint
from vis4d.config import class_config
from vis4d.config.typing import ExperimentConfig, ExperimentParameters
from vis4d.data.const import CommonKeys as K
from vis4d.data.io.hdf5 import HDF5Backend
from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback
from vis4d.engine.connectors import CallbackConnector, DataConnector
from vis4d.eval.coco import COCODetectEvaluator
from vis4d.vis.image import BoundingBoxVisualizer
from vis4d.zoo.base import (
get_default_callbacks_cfg,
get_default_cfg,
get_default_pl_trainer_cfg,
)
from vis4d.zoo.base.data_connectors import CONN_BBOX_2D_TEST, CONN_BBOX_2D_VIS
from vis4d.zoo.base.models.yolox import (
get_yolox_callbacks_cfg,
get_yolox_cfg,
get_yolox_optimizers_cfg,
)
from vis4d.zoo.yolox.data import CONN_COCO_BBOX_EVAL, get_coco_yolox_cfg
# Key mapping for the training DataConnector built in ``get_config``: the
# connector's ``images`` entry is read from the batch's ``K.images`` key.
CONN_BBOX_2D_TRAIN = {
    "images": K.images,
}
def get_config() -> ExperimentConfig:
    """Build the YOLOX-S experiment config for COCO 2017 object detection.

    Trains for 300 epochs on ``train2017`` and evaluates on ``val2017`` using
    data stored under ``data/coco`` and read through an HDF5 backend. Adds the
    default logging callbacks, the YOLOX-specific callbacks, a bounding-box
    visualizer and a COCO detection evaluator, and configures a PyTorch
    Lightning trainer with periodic checkpointing and W&B enabled.

    Returns:
        ExperimentConfig: The populated experiment configuration, converted to
            value mode via ``config.value_mode()``.
    """
    # --- General config -------------------------------------------------
    config = get_default_cfg(exp_name="yolox_s_300e_coco")
    config.checkpoint_period = 15
    config.check_val_every_n_epoch = 10

    # --- High level hyper-parameters ------------------------------------
    params = ExperimentParameters()
    for field_name, field_value in (
        ("samples_per_gpu", 8),
        ("workers_per_gpu", 4),
        ("lr", 0.01),
        ("num_epochs", 300),
        ("num_classes", 80),
    ):
        setattr(params, field_name, field_value)
    config.params = params

    # --- Datasets with augmentations ------------------------------------
    # Root and evaluation split are reused by the COCO evaluator below.
    coco_root = "data/coco"
    val_split = "val2017"
    config.data = get_coco_yolox_cfg(
        data_root=coco_root,
        train_split="train2017",
        test_split=val_split,
        data_backend=class_config(HDF5Backend),
        samples_per_gpu=params.samples_per_gpu,
        workers_per_gpu=params.workers_per_gpu,
    )

    # --- Model & loss ---------------------------------------------------
    config.model, config.loss = get_yolox_cfg(params.num_classes, "small")

    # --- Optimizers -----------------------------------------------------
    # ``num_last_epochs`` is also used to derive the switch epoch passed to
    # the YOLOX callbacks further down.
    warmup_epochs = 5
    num_last_epochs = 15
    config.optimizers = get_yolox_optimizers_cfg(
        params.lr, params.num_epochs, warmup_epochs, num_last_epochs
    )

    # --- Data connectors ------------------------------------------------
    config.train_data_connector = class_config(
        DataConnector, key_mapping=CONN_BBOX_2D_TRAIN
    )
    config.test_data_connector = class_config(
        DataConnector, key_mapping=CONN_BBOX_2D_TEST
    )

    # --- Callbacks ------------------------------------------------------
    # Default logging callbacks first, then the YOLOX-specific ones.
    callbacks = get_default_callbacks_cfg(
        refresh_rate=config.log_every_n_steps
    )
    switch_epoch = params.num_epochs - num_last_epochs
    callbacks.extend(get_yolox_callbacks_cfg(switch_epoch=switch_epoch))

    # Bounding-box visualizer (images are handled in BGR channel order).
    bbox_visualizer = class_config(
        BoundingBoxVisualizer, vis_freq=100, image_mode="BGR"
    )
    visualizer_callback = class_config(
        VisualizerCallback,
        visualizer=bbox_visualizer,
        output_dir=config.output_dir,
        test_connector=class_config(
            CallbackConnector, key_mapping=CONN_BBOX_2D_VIS
        ),
    )
    callbacks.append(visualizer_callback)

    # COCO detection evaluator on the same root/split as the test data.
    coco_evaluator = class_config(
        COCODetectEvaluator, data_root=coco_root, split=val_split
    )
    evaluator_callback = class_config(
        EvaluatorCallback,
        evaluator=coco_evaluator,
        metrics_to_eval=["Det"],
        test_connector=class_config(
            CallbackConnector, key_mapping=CONN_COCO_BBOX_EVAL
        ),
    )
    callbacks.append(evaluator_callback)
    config.callbacks = callbacks

    # --- PyTorch Lightning trainer --------------------------------------
    pl_trainer = get_default_pl_trainer_cfg(config)
    pl_trainer.max_epochs = params.num_epochs
    pl_trainer.check_val_every_n_epoch = config.check_val_every_n_epoch

    # Checkpoints are ranked by the ``step`` value in "max" mode. This
    # presumably keeps the three most recent periodic checkpoints, plus
    # ``last`` via ``save_last``. Verify against the Lightning version in use.
    checkpoint_dir = config.get_ref("output_dir") + "/checkpoints"
    pl_trainer.checkpoint_callback = class_config(
        ModelCheckpoint,
        dirpath=checkpoint_dir,
        verbose=True,
        save_last=True,
        save_on_train_epoch_end=True,
        every_n_epochs=config.checkpoint_period,
        save_top_k=3,
        mode="max",
        monitor="step",
    )
    pl_trainer.wandb = True
    config.pl_trainer = pl_trainer

    return config.value_mode()
|