Spaces:
Running
on
Zero
Running
on
Zero
# pylint: disable=duplicate-code | |
"""RetinaNet COCO training example.""" | |
from __future__ import annotations | |
from torch.optim.lr_scheduler import LinearLR, MultiStepLR | |
from torch.optim.sgd import SGD | |
from vis4d.config import class_config | |
from vis4d.config.typing import ExperimentConfig, ExperimentParameters | |
from vis4d.data.io.hdf5 import HDF5Backend | |
from vis4d.engine.callbacks import EvaluatorCallback, VisualizerCallback | |
from vis4d.engine.connectors import ( | |
CallbackConnector, | |
DataConnector, | |
LossConnector, | |
) | |
from vis4d.engine.loss_module import LossModule | |
from vis4d.eval.coco import COCODetectEvaluator | |
from vis4d.model.detect.retinanet import RetinaNet | |
from vis4d.op.box.encoder import DeltaXYWHBBoxEncoder | |
from vis4d.op.detect.retinanet import ( | |
RetinaNetHeadLoss, | |
get_default_anchor_generator, | |
) | |
from vis4d.vis.image import BoundingBoxVisualizer | |
from vis4d.zoo.base import ( | |
get_default_callbacks_cfg, | |
get_default_cfg, | |
get_default_pl_trainer_cfg, | |
get_lr_scheduler_cfg, | |
get_optimizer_cfg, | |
) | |
from vis4d.zoo.base.data_connectors import ( | |
CONN_BBOX_2D_VIS, | |
CONN_BOX_LOSS_2D, | |
CONN_IMAGES_TEST, | |
CONN_IMAGES_TRAIN, | |
) | |
from vis4d.zoo.base.datasets.coco import ( | |
CONN_COCO_BBOX_EVAL, | |
get_coco_detection_cfg, | |
) | |
def get_config() -> ExperimentConfig:
    """Returns the RetinaNet config for the COCO detection task.

    This is an example that shows how to set up a training experiment for the
    COCO detection task.

    Note that the high level params are exposed in ``config.params``. This
    allows to easily change them from the command line. Since the number of
    epochs is stored in ``config.params.num_epochs``, override it there (not
    via ``config.num_epochs``, which this config never reads).

    E.g.:
    >>> python -m vis4d.engine.run fit --config vis4d/zoo/retinanet/retinanet_rcnn_coco.py --config.params.num_epochs 100 --config.params.lr 0.001

    Note that the MultiStepLR milestones below are fixed epoch indices and are
    not derived from ``params.num_epochs``; adjust them too when changing the
    schedule length.

    Returns:
        ExperimentConfig: The configuration
    """
    ######################################################
    ##                  General Config                  ##
    ######################################################
    config = get_default_cfg(exp_name="retinanet_r50_fpn_coco")

    # High level hyper parameters. They are attached to ``config.params`` and
    # referenced by the sections below, so overriding them on the command line
    # changes every place they are used.
    params = ExperimentParameters()
    params.samples_per_gpu = 2
    params.workers_per_gpu = 2
    params.lr = 0.01
    params.num_epochs = 12
    params.num_classes = 80
    config.params = params

    ######################################################
    ##          Datasets with augmentations             ##
    ######################################################
    data_root = "data/coco"
    train_split = "train2017"
    test_split = "val2017"

    # NOTE(review): data is read through HDF5Backend, which presumably
    # requires the COCO images to be converted to HDF5 first — confirm.
    data_backend = class_config(HDF5Backend)

    config.data = get_coco_detection_cfg(
        data_root=data_root,
        train_split=train_split,
        test_split=test_split,
        data_backend=data_backend,
        samples_per_gpu=params.samples_per_gpu,
        workers_per_gpu=params.workers_per_gpu,
    )

    ######################################################
    ##                   MODEL & LOSS                   ##
    ######################################################
    config.model = class_config(
        RetinaNet,
        num_classes=params.num_classes,
        # weights="mmdet",
    )

    # Box delta encoder with zero means and unit stds for all four deltas.
    box_encoder = class_config(
        DeltaXYWHBBoxEncoder,
        target_means=(0.0, 0.0, 0.0, 0.0),
        target_stds=(1.0, 1.0, 1.0, 1.0),
    )
    anchor_generator = class_config(get_default_anchor_generator)
    retina_loss = class_config(
        RetinaNetHeadLoss,
        box_encoder=box_encoder,
        anchor_generator=anchor_generator,
    )

    # The loss receives its inputs via a LossConnector using the
    # CONN_BOX_LOSS_2D key mapping.
    config.loss = class_config(
        LossModule,
        losses={
            "loss": retina_loss,
            "connector": class_config(
                LossConnector, key_mapping=CONN_BOX_LOSS_2D
            ),
        },
    )

    ######################################################
    ##                    OPTIMIZERS                    ##
    ######################################################
    config.optimizers = [
        get_optimizer_cfg(
            optimizer=class_config(
                SGD, lr=params.lr, momentum=0.9, weight_decay=0.0001
            ),
            lr_schedulers=[
                # Linear warmup: lr ramps from 0.001 * lr to lr over 500
                # iterations (step-based, hence epoch_based=False).
                get_lr_scheduler_cfg(
                    class_config(
                        LinearLR, start_factor=0.001, total_iters=500
                    ),
                    end=500,
                    epoch_based=False,
                ),
                # Decay lr by 10x at milestones 8 and 11 (presumably epochs,
                # the get_lr_scheduler_cfg default — confirm). These are
                # hard-coded and do NOT scale with params.num_epochs.
                get_lr_scheduler_cfg(
                    class_config(MultiStepLR, milestones=[8, 11], gamma=0.1),
                ),
            ],
        )
    ]

    ######################################################
    ##                  DATA CONNECTOR                  ##
    ######################################################
    config.train_data_connector = class_config(
        DataConnector,
        key_mapping=CONN_IMAGES_TRAIN,
    )

    config.test_data_connector = class_config(
        DataConnector,
        key_mapping=CONN_IMAGES_TEST,
    )

    ######################################################
    ##                     CALLBACKS                    ##
    ######################################################
    # Logger (default callbacks)
    callbacks = get_default_callbacks_cfg()

    # Visualizer: bounding boxes written under the experiment output dir.
    callbacks.append(
        class_config(
            VisualizerCallback,
            visualizer=class_config(BoundingBoxVisualizer, vis_freq=100),
            output_dir=config.output_dir,
            test_connector=class_config(
                CallbackConnector,
                key_mapping=CONN_BBOX_2D_VIS,
            ),
        )
    )

    # Evaluator: COCO detection metrics on the test split.
    callbacks.append(
        class_config(
            EvaluatorCallback,
            evaluator=class_config(
                COCODetectEvaluator,
                data_root=data_root,
                split=test_split,
            ),
            metrics_to_eval=["Det"],
            test_connector=class_config(
                CallbackConnector,
                key_mapping=CONN_COCO_BBOX_EVAL,
            ),
        )
    )

    config.callbacks = callbacks

    ######################################################
    ##                     PL CLI                       ##
    ######################################################
    # PL Trainer args; training length is driven by params.num_epochs.
    pl_trainer = get_default_pl_trainer_cfg(config)
    pl_trainer.max_epochs = params.num_epochs
    config.pl_trainer = pl_trainer

    # NOTE(review): value_mode() presumably resolves the params references
    # into concrete values for the returned config — confirm.
    return config.value_mode()