# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os
import os.path as osp
from argparse import ArgumentParser
from mmengine.config import Config, DictAction
from mmengine.logging import MMLogger, print_log
from mmengine.registry import RUNNERS
from mmengine.runner import Runner
from mmdet.testing import replace_to_ceph
from mmdet.utils import register_all_modules, replace_cfg_vals
def parse_args():
parser = ArgumentParser()
    parser.add_argument('config', help='benchmark model list config file path')
parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--ceph', action='store_true', help='use ceph as the storage backend')
    parser.add_argument(
        '--save-ckpt',
        action='store_true',
        help='whether to save checkpoints during the fast run')
parser.add_argument(
'--amp',
action='store_true',
default=False,
help='enable automatic-mixed-precision training')
parser.add_argument(
'--auto-scale-lr',
action='store_true',
help='enable automatically scaling LR.')
parser.add_argument(
'--resume',
action='store_true',
help='resume from the latest checkpoint in the work_dir automatically')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
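    # Example override (key names below are illustrative, not from this repo):
    #   --cfg-options optim_wrapper.optimizer.lr=0.01 model.backbone.depth=101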
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()
    # make sure LOCAL_RANK is set for the PyTorch distributed launcher
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
return args
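# Typical invocation (file names are illustrative):
#   python .dev_scripts/benchmark_train.py .dev_scripts/benchmark_train_models.py \
#       --work-dir ./benchmark_work_dirs --amp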
# TODO: Need to refactor train.py so that it can be reused.
def fast_train_model(config_name, args, logger=None):
cfg = Config.fromfile(config_name)
cfg = replace_cfg_vals(cfg)
cfg.launcher = args.launcher
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
# work_dir is determined in this priority: CLI > segment in file > filename
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
cfg.work_dir = osp.join(args.work_dir,
osp.splitext(osp.basename(config_name))[0])
elif cfg.get('work_dir', None) is None:
# use config filename as default work_dir if cfg.work_dir is None
cfg.work_dir = osp.join('./work_dirs',
osp.splitext(osp.basename(config_name))[0])
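    # Inject a fast-stop hook so that each model is only smoke-tested for a
    # few iterations/epochs instead of being trained to completion.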
ckpt_hook = cfg.default_hooks.checkpoint
by_epoch = ckpt_hook.get('by_epoch', True)
fast_stop_hook = dict(type='FastStopTrainingHook')
fast_stop_hook['by_epoch'] = by_epoch
if args.save_ckpt:
if by_epoch:
interval = 1
stop_iter_or_epoch = 2
else:
interval = 4
stop_iter_or_epoch = 10
fast_stop_hook['stop_iter_or_epoch'] = stop_iter_or_epoch
fast_stop_hook['save_ckpt'] = True
ckpt_hook.interval = interval
if 'custom_hooks' in cfg:
cfg.custom_hooks.append(fast_stop_hook)
else:
custom_hooks = [fast_stop_hook]
cfg.custom_hooks = custom_hooks
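    # The resulting hook entry looks like, e.g.:
    #   dict(type='FastStopTrainingHook', by_epoch=True,
    #        stop_iter_or_epoch=2, save_ckpt=True)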
    # TODO: temporary workaround; drop the fixed visualizer name so that
    # building several runners in one process does not clash on the
    # globally registered visualizer instance
if 'visualizer' in cfg:
if 'name' in cfg.visualizer:
del cfg.visualizer.name
# enable automatic-mixed-precision training
    if args.amp:
optim_wrapper = cfg.optim_wrapper.type
if optim_wrapper == 'AmpOptimWrapper':
print_log(
'AMP training is already enabled in your config.',
logger='current',
level=logging.WARNING)
else:
assert optim_wrapper == 'OptimWrapper', (
'`--amp` is only supported when the optimizer wrapper type is '
f'`OptimWrapper` but got {optim_wrapper}.')
cfg.optim_wrapper.type = 'AmpOptimWrapper'
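            # 'dynamic' turns on mmengine's dynamic loss scaling, where the
            # scale factor is adjusted automatically instead of being fixed.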
cfg.optim_wrapper.loss_scale = 'dynamic'
# enable automatically scaling LR
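    # A typical config provides, e.g.:
    #   auto_scale_lr = dict(enable=False, base_batch_size=16)
    # and both keys must be present for `--auto-scale-lr` to take effect.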
if args.auto_scale_lr:
if 'auto_scale_lr' in cfg and \
'enable' in cfg.auto_scale_lr and \
'base_batch_size' in cfg.auto_scale_lr:
cfg.auto_scale_lr.enable = True
else:
            raise RuntimeError('Cannot find "auto_scale_lr" or '
                               '"auto_scale_lr.enable" or '
                               '"auto_scale_lr.base_batch_size" in your'
                               ' configuration file.')
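    # Optionally rewrite the config's storage backends so that data and
    # checkpoints are read from / written to ceph.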
if args.ceph:
replace_to_ceph(cfg)
cfg.resume = args.resume
# build the runner from config
if 'runner_type' not in cfg:
# build the default runner
runner = Runner.from_cfg(cfg)
else:
# build customized runner from the registry
# if 'runner_type' is set in the cfg
runner = RUNNERS.build(cfg)
runner.train()
# Smoke-test that the training code of each model in the benchmark config runs
def main(args):
# register all modules in mmdet into the registries
register_all_modules(init_default_scope=False)
config = Config.fromfile(args.config)
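    # The benchmark config maps model names to one or more info dicts, each
    # carrying at least a 'config' path, e.g. (illustrative):
    #   atss = dict(config='configs/atss/atss_r50_fpn_1x_coco.py')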
    # test all models listed in the benchmark config
logger = MMLogger.get_instance(
name='MMLogger',
log_file='benchmark_train.log',
log_level=logging.ERROR)
for model_key in config:
model_infos = config[model_key]
if not isinstance(model_infos, list):
model_infos = [model_infos]
for model_info in model_infos:
print('processing: ', model_info['config'], flush=True)
config_name = model_info['config'].strip()
try:
fast_train_model(config_name, args, logger)
except RuntimeError as e:
# quick exit is the normal exit message
if 'quick exit' not in repr(e):
                    logger.error(f'{config_name}: {repr(e)}')
except Exception as e:
                logger.error(f'{config_name}: {repr(e)}')
if __name__ == '__main__':
args = parse_args()
main(args)