MMDet / mmdetection /demo /video_gpuaccel_demo.py
Saurabh1105's picture
MMdet Model for Image Segmentation
6c9ac8f
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from typing import Tuple
import cv2
import mmcv
import numpy as np
import torch
import torch.nn as nn
from mmcv.transforms import Compose
from mmengine.utils import track_iter_progress
from mmdet.apis import init_detector
from mmdet.registry import VISUALIZERS
from mmdet.structures import DetDataSample
try:
import ffmpegcv
except ImportError:
raise ImportError(
'Please install ffmpegcv with:\n\n pip install ffmpegcv')
def parse_args():
    """Parse the command-line options of the GPU-accelerated video demo.

    The three positional arguments are the input video, the model config and
    the checkpoint. The optional flags select the inference device, the bbox
    score threshold, the output file, on-screen display, the NVIDIA decoder
    and the display interval.

    Returns:
        argparse.Namespace: The options parsed from ``sys.argv``.
    """
    cli = argparse.ArgumentParser(
        description='MMDetection video demo with GPU acceleration')

    # Required inputs, registered in the order they appear on the command
    # line.
    positionals = (
        ('video', 'Video file'),
        ('config', 'Config file'),
        ('checkpoint', 'Checkpoint file'),
    )
    for arg_name, arg_help in positionals:
        cli.add_argument(arg_name, help=arg_help)

    # Optional switches.
    cli.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    cli.add_argument(
        '--score-thr', type=float, default=0.3, help='Bbox score threshold')
    cli.add_argument('--out', type=str, help='Output video file')
    cli.add_argument('--show', action='store_true', help='Show video')
    cli.add_argument(
        '--nvdecode', action='store_true', help='Use NVIDIA decoder')
    cli.add_argument(
        '--wait-time',
        type=float,
        default=1,
        help='The interval of show (s), 0 is block')

    return cli.parse_args()
def prefetch_batch_input_shape(model: nn.Module,
                               ori_wh: Tuple[int, int]) -> Tuple[int, int]:
    """Probe the test pipeline for the network input shape of a video.

    A blank frame with the video's original size is pushed through the
    model's test pipeline and data preprocessor, and the
    ``batch_input_shape`` recorded on the resulting data sample is returned.

    The model's config is not modified: the pipeline config is copied before
    its first step is switched to ``LoadImageFromNDArray``.

    Args:
        model (nn.Module): Detector exposing ``cfg`` and
            ``data_preprocessor``.
        ori_wh (Tuple[int, int]): Original frame size as ``(width, height)``.

    Returns:
        Tuple[int, int]: The ``batch_input_shape`` reported by the data
        preprocessor. ``main`` reverses it to obtain ``(w, h)`` and
        ``pack_data`` compares it with ``frame.shape[:2]``, i.e. it is used
        as ``(height, width)``.
    """
    import copy

    w, h = ori_wh

    # Work on a private copy so that probing the shape has no side effect on
    # the config shared with the rest of the program.
    pipeline_cfg = copy.deepcopy(model.cfg.test_dataloader.dataset.pipeline)
    # The probe image below is an in-memory array, so the first pipeline
    # step is switched to the ndarray loader.
    pipeline_cfg[0]['type'] = 'LoadImageFromNDArray'
    test_pipeline = Compose(pipeline_cfg)

    # Blank frame with the same size as the video frames.
    data = {'img': np.zeros((h, w, 3), dtype=np.uint8), 'img_id': 0}
    data = test_pipeline(data)
    _, data_sample = model.data_preprocessor([data], False)
    return data_sample[0].batch_input_shape
def pack_data(frame_resize: np.ndarray, batch_input_shape: Tuple[int, int],
              ori_shape: Tuple[int, int]) -> dict:
    """Wrap an already-resized frame into a model input dict.

    Args:
        frame_resize (np.ndarray): Resized frame whose first two dimensions
            must equal ``batch_input_shape``; the last axis is moved to the
            front before it is handed to the model.
        batch_input_shape (Tuple[int, int]): Expected spatial shape of the
            network input.
        ori_shape (Tuple[int, int]): Shape of the original frame, in the same
            axis order as ``batch_input_shape``.

    Returns:
        dict: ``{'inputs': tensor, 'data_sample': DetDataSample}`` where the
        data sample carries ``img_shape``, ``ori_shape`` and
        ``scale_factor`` meta information.

    NOTE(review): ``scale_factor`` is stored in the same axis order as the
    two shapes; confirm this is the order the detector expects.
    """
    assert frame_resize.shape[:2] == batch_input_shape

    sample = DetDataSample()

    # Per-axis ratio between the network input and the original frame.
    ratio_axis0 = batch_input_shape[0] / ori_shape[0]
    ratio_axis1 = batch_input_shape[1] / ori_shape[1]
    meta = dict(
        img_shape=batch_input_shape,
        ori_shape=ori_shape,
        scale_factor=(ratio_axis0, ratio_axis1))
    sample.set_metainfo(meta)

    # Move the last (channel) axis to the front.
    tensor = torch.from_numpy(frame_resize).permute((2, 0, 1))
    return {'inputs': tensor, 'data_sample': sample}
def main():
    """Run the detector on every frame of a video and show and/or save it.

    The input video is opened twice through ffmpegcv: once at its original
    size (used for drawing the results) and once resized by the decoder to
    the network input shape (used for inference). Each resized frame is
    packed with ``pack_data``, passed to ``model.test_step`` and the
    prediction is drawn on the matching original frame.

    All video handles are released even if processing fails part-way.

    Raises:
        AssertionError: If neither ``--out`` nor ``--show`` is given.
    """
    args = parse_args()
    assert args.out or args.show, \
        ('Please specify at least one operation (save/show the '
         'video) with the argument "--out" or "--show"')

    model = init_detector(args.config, args.checkpoint, device=args.device)

    # init visualizer
    visualizer = VISUALIZERS.build(model.cfg.visualizer)
    # the dataset_meta is loaded from the checkpoint and
    # then pass to the model in init_detector
    visualizer.dataset_meta = model.dataset_meta

    if args.nvdecode:
        VideoCapture = ffmpegcv.VideoCaptureNV
    else:
        VideoCapture = ffmpegcv.VideoCapture

    video_origin = VideoCapture(args.video)
    video_resize = None
    video_writer = None
    try:
        batch_input_shape = prefetch_batch_input_shape(
            model, (video_origin.width, video_origin.height))
        ori_shape = (video_origin.height, video_origin.width)
        # The decoder takes the target size as (w, h).
        resize_wh = batch_input_shape[::-1]
        video_resize = VideoCapture(
            args.video,
            resize=resize_wh,
            resize_keepratio=True,
            resize_keepratioalign='topleft')

        if args.out:
            video_writer = ffmpegcv.VideoWriter(
                args.out, fps=video_origin.fps)

        with torch.no_grad():
            # Both captures read the same file, so zipping them pairs each
            # resized frame with its original-size counterpart.
            for frame_resize, frame_origin in zip(
                    track_iter_progress(video_resize), video_origin):
                data = pack_data(frame_resize, batch_input_shape, ori_shape)
                result = model.test_step([data])[0]

                visualizer.add_datasample(
                    name='video',
                    image=frame_origin,
                    data_sample=result,
                    draw_gt=False,
                    show=False,
                    pred_score_thr=args.score_thr)

                frame_mask = visualizer.get_image()

                if args.show:
                    cv2.namedWindow('video', 0)
                    mmcv.imshow(frame_mask, 'video', args.wait_time)
                if video_writer is not None:
                    video_writer.write(frame_mask)
    finally:
        # Release every handle that was opened, whether or not the loop
        # completed.
        if video_writer is not None:
            video_writer.release()
        video_origin.release()
        if video_resize is not None:
            video_resize.release()
        # Windows are only created when --show is set, so only tear them
        # down in that case.
        if args.show:
            cv2.destroyAllWindows()


if __name__ == '__main__':
    main()