import logging |
import re |
import tempfile |
from argparse import ArgumentParser |
from collections import OrderedDict |
from functools import partial |
from pathlib import Path |
import numpy as np |
import pandas as pd |
import torch |
from mmengine import Config, DictAction |
from mmengine.analysis import get_model_complexity_info |
from mmengine.analysis.print_helper import _format_size |
from mmengine.fileio import FileClient |
from mmengine.logging import MMLogger |
from mmengine.model import revert_sync_batchnorm |
from mmengine.runner import Runner |
from modelindex.load_model_index import load |
from rich.console import Console |
from rich.table import Table |
from rich.text import Text |
from tqdm import tqdm |
from mmdet.registry import MODELS |
from mmdet.utils import register_all_modules |
# Repository root: two directory levels above this script's absolute path.
MMDET_ROOT = Path(__file__).absolute().parents[1]

# Shared rich console; ``show_summary`` prints the result table through it.
console = Console()
def parse_args():
    """Declare the script's command-line options and parse ``sys.argv``.

    Returns:
        argparse.Namespace: The parsed options. ``--shape`` is a list of
        ints, ``--cfg-options`` is handled by ``DictAction`` and the
        ``--flops`` / ``--flops-str`` switches are booleans.
    """
    cli = ArgumentParser(description='Valid all models in model-index.yml')

    # (flag, keyword arguments) pairs, registered in this exact order so the
    # generated ``--help`` output keeps the same layout.
    option_specs = (
        ('--shape',
         dict(
             type=int,
             nargs='+',
             default=[1280, 800],
             help='input image size')),
        ('--checkpoint_root',
         dict(help='Checkpoint file root path. If set, load checkpoint '
              'before test.')),
        ('--img', dict(default='demo/demo.jpg', help='Image file')),
        ('--models', dict(nargs='+', help='models name to inference')),
        ('--batch-size',
         dict(
             type=int,
             default=1,
             help='The batch size during the inference.')),
        ('--flops',
         dict(action='store_true', help='Get Flops and Params of models')),
        ('--flops-str',
         dict(
             action='store_true',
             help='Output FLOPs and params counts in a string form.')),
        ('--cfg-options',
         dict(
             nargs='+',
             action=DictAction,
             help='override some settings in the used config, the key-value '
             'pair in xxx=yyy format will be merged into config file. If the '
             'value to be overwritten is a list, it should be like '
             'key="[a,b]" or key=a,b It also allows nested list/tuple '
             'values, e.g. key="[(a,b),(c,d)]" Note that the quotation '
             'marks are necessary and that no white space is allowed.')),
        ('--size_divisor',
         dict(
             type=int,
             default=32,
             help='Pad the input image, the minimum size that is divisible '
             'by size_divisor, -1 means do not pad the image.')),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()
def _load_flops_cfg(config_file, cfg_options):
    """Load ``config_file`` with the overrides used for complexity counting.

    Configs that define ``head_norm_cfg`` get that entry and the
    ``norm_cfg`` of the RoI bbox/mask heads replaced by a ``SyncBN``
    setting. User supplied ``cfg_options`` are merged last.
    """
    cfg = Config.fromfile(config_file)
    if hasattr(cfg, 'head_norm_cfg'):
        cfg['head_norm_cfg'] = dict(type='SyncBN', requires_grad=True)
        cfg['model']['roi_head']['bbox_head']['norm_cfg'] = dict(
            type='SyncBN', requires_grad=True)
        cfg['model']['roi_head']['mask_head']['norm_cfg'] = dict(
            type='SyncBN', requires_grad=True)
    if cfg_options is not None:
        cfg.merge_from_dict(cfg_options)
    return cfg


def _build_eval_model(cfg):
    """Build ``cfg.model``, move it to CUDA if available, set eval mode."""
    model = MODELS.build(cfg.model)
    if torch.cuda.is_available():
        model = model.cuda()
    # The SyncBN layers are converted back before the complexity analysis.
    model = revert_sync_batchnorm(model)
    model.eval()
    return model


def inference(config_file, checkpoint, work_dir, args, exp_name):
    """Load one config and, if requested, measure its FLOPs and parameters.

    The measurement is first attempted with a random input tensor. If that
    raises, it is retried with the first batch of the config's validation
    dataloader.

    Args:
        config_file (pathlib.Path): Config file to load; its ``stem`` is
            reported as the model name.
        checkpoint (str | None): Value stored in ``cfg.load_from``.
        work_dir (str): Value stored in ``cfg.work_dir``.
        args (argparse.Namespace): Must provide ``cfg_options``, ``flops``,
            ``shape``, ``size_divisor`` and ``flops_str``.
        exp_name (str): Value stored in ``cfg.experiment_name``.

    Returns:
        dict: Always contains ``'model'``. With ``args.flops`` set it also
        contains ``'resolution'``, ``'Get Types'`` (``'direct'`` or
        ``'dataloader'``), ``'flops'`` and ``'params'``.

    Raises:
        ValueError: If ``args.flops`` is set and ``args.shape`` holds more
            than two values.
    """
    logger = MMLogger.get_instance(name='MMLogger')
    logger.warning('if you want test flops, please make sure torch>=1.12')

    # Loading the config here means a broken config file raises even when
    # no FLOPs are requested.
    cfg = Config.fromfile(config_file)
    cfg.work_dir = work_dir
    cfg.load_from = checkpoint
    cfg.log_level = 'WARN'
    cfg.experiment_name = exp_name
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    result = {'model': config_file.stem}
    if not args.flops:
        return result

    if len(args.shape) == 1:
        h = w = args.shape[0]
    elif len(args.shape) == 2:
        h, w = args.shape
    else:
        raise ValueError('invalid input shape')

    # Round both sides up to the next multiple of the divisor.
    divisor = args.size_divisor
    if divisor > 0:
        h = int(np.ceil(h / divisor)) * divisor
        w = int(np.ceil(w / divisor)) * divisor

    input_shape = (3, h, w)
    result['resolution'] = input_shape

    try:
        cfg = _load_flops_cfg(config_file, args.cfg_options)
        model = _build_eval_model(cfg)
        dummy_input = torch.rand(1, *input_shape)
        if torch.cuda.is_available():
            dummy_input = dummy_input.cuda()
        outputs = get_model_complexity_info(
            model,
            input_shape, (dummy_input, ),
            show_table=False,
            show_arch=False)
        result['Get Types'] = 'direct'
    except Exception:
        # ``Exception`` rather than a bare ``except`` so that Ctrl-C and
        # ``SystemExit`` abort the run instead of starting the fallback.
        logger.warning('Direct get flops failed, try to get flops with data')
        # The fallback honours ``--cfg-options`` as well, so both paths
        # measure the same configuration.
        cfg = _load_flops_cfg(config_file, args.cfg_options)
        data_loader = Runner.build_dataloader(cfg.val_dataloader)
        data_batch = next(iter(data_loader))
        model = _build_eval_model(cfg)
        _forward = model.forward
        data = model.data_preprocessor(data_batch)
        del data_loader
        # Bind the data samples so the analysis can call forward with the
        # input tensor only.
        model.forward = partial(_forward, data_samples=data['data_samples'])
        outputs = get_model_complexity_info(
            model,
            input_shape,
            data['inputs'],
            show_table=False,
            show_arch=False)
        result['Get Types'] = 'dataloader'

    flops = outputs['flops']
    params = outputs['params']
    if args.flops_str:
        flops = _format_size(flops)
        params = _format_size(params)

    result['flops'] = flops
    result['params'] = params
    return result
def show_summary(summary_data, args):
    """Print the per-model results as a table and export them to CSV.

    Args:
        summary_data (dict): Maps a model name to its result dict. Every
            result needs a ``'valid'`` entry; ``'PASS'`` results may carry
            ``'resolution'`` and, when ``args.flops`` is set, must carry
            ``'flops'`` and ``'params'``.
        args (argparse.Namespace): Only ``args.flops`` is read; it decides
            whether the Flops/Params columns are shown.

    The table is printed on the module-level ``console`` and written to
    ``./mmdetection_flops.csv``.
    """
    table = Table(title='Validation Benchmark Regression Summary')
    table.add_column('Model')
    table.add_column('Validation')
    table.add_column('Resolution (c, h, w)')
    if args.flops:
        table.add_column('Flops', justify='right', width=11)
        table.add_column('Params', justify='right')

    for model_name, summary in summary_data.items():
        row = [model_name]
        valid = summary['valid']
        color = 'green' if valid == 'PASS' else 'red'
        row.append(f'[{color}]{valid}[/{color}]')
        if valid == 'PASS':
            # 'resolution' is only present when FLOPs were requested; fall
            # back to a placeholder instead of raising KeyError.
            row.append(str(summary.get('resolution', '-')))
            if args.flops:
                row.append(str(summary['flops']))
                row.append(str(summary['params']))
        table.add_row(*row)

    console.print(table)

    # Strip the rich markup (colour tags) before exporting the cells.
    table_data = {
        x.header: [Text.from_markup(y).plain for y in x.cells]
        for x in table.columns
    }
    table_pd = pd.DataFrame(table_data)
    table_pd.to_csv('./mmdetection_flops.csv')
def _select_models(models, name_patterns):
    """Return the models whose name matches any of ``name_patterns``.

    ``+`` is replaced by ``_`` in both the patterns and the model names
    before matching, and the returned dict is keyed by the replaced name.
    """
    patterns = [
        re.compile(pattern.replace('+', '_')) for pattern in name_patterns
    ]
    selected = {}
    for name, info in models.items():
        name = name.replace('+', '_')
        if any(pattern.match(name) for pattern in patterns):
            selected[name] = info
    return selected


def _resolve_checkpoint(model_name, weights, checkpoint_root):
    """Map a model's weights URL to a checkpoint under ``checkpoint_root``.

    Returns the checkpoint location as a string, or ``None`` when no root
    is configured, the model has no weights entry, or the file is missing.
    """
    if checkpoint_root is None or not weights:
        return None

    http_prefix = 'https://download.openmmlab.com/mmdetection/'
    relative = weights[len(http_prefix):]
    if 's3://' in checkpoint_root:
        # Imported lazily: only needed for s3 checkpoint roots.
        from petrel_client.common.exception import AccessDeniedError
        file_client = FileClient.infer_client(uri=checkpoint_root)
        checkpoint = file_client.join_path(checkpoint_root, relative)
        try:
            exists = file_client.exists(checkpoint)
        except AccessDeniedError:
            exists = False
    else:
        checkpoint = Path(checkpoint_root) / relative
        exists = checkpoint.exists()

    if exists:
        return str(checkpoint)
    print(f'WARNING: {model_name}: {checkpoint} not found.')
    return None


def main(args):
    """Run ``inference`` for every selected model and print a summary.

    Models are read from ``model-index.yml`` under ``MMDET_ROOT`` and can be
    restricted with ``args.models`` (regular expressions matched against
    the model names). Models without a config, or whose config file does
    not exist, are skipped. A model whose ``inference`` call raises is
    recorded as ``FAIL`` and its traceback is logged.

    Args:
        args (argparse.Namespace): Options returned by ``parse_args``.
    """
    register_all_modules()
    model_index_file = MMDET_ROOT / 'model-index.yml'
    model_index = load(str(model_index_file))
    model_index.build_models_with_collections()
    models = OrderedDict({model.name: model for model in model_index.models})

    logger = MMLogger(
        'validation',
        logger_name='validation',
        log_file='benchmark_test_image.log',
        log_level=logging.INFO)

    if args.models:
        filter_models = _select_models(models, args.models)
        if not filter_models:
            print('No model found, please specify models in:')
            print('\n'.join(models.keys()))
            return
        models = filter_models

    summary_data = {}
    # The context manager removes the directory even if the loop raises.
    with tempfile.TemporaryDirectory() as tmpdir:
        for model_name, model_info in tqdm(models.items()):
            if model_info.config is None:
                continue

            model_info.config = model_info.config.replace('%2B', '+')
            config = Path(model_info.config)
            # Check the same path that is handed to ``inference`` below.
            config_path = MMDET_ROOT / config
            if not config_path.exists():
                logger.error(f'{model_name}: {config} not found.')
                continue

            logger.info(f'Processing: {model_name}')

            checkpoint = _resolve_checkpoint(model_name, model_info.weights,
                                             args.checkpoint_root)

            try:
                result = inference(config_path, checkpoint, tmpdir, args,
                                   model_name)
                result['valid'] = 'PASS'
            except Exception:
                import traceback
                logger.error(f'"{config}" :\n{traceback.format_exc()}')
                result = {'valid': 'FAIL'}

            summary_data[model_name] = result

    show_summary(summary_data, args)
if __name__ == '__main__':
    # Script entry point: parse the command line and run the benchmark.
    main(parse_args())