|
|
|
import argparse |
|
from typing import Tuple |
|
|
|
import cv2 |
|
import mmcv |
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
from mmcv.transforms import Compose |
|
from mmengine.utils import track_iter_progress |
|
|
|
from mmdet.apis import init_detector |
|
from mmdet.registry import VISUALIZERS |
|
from mmdet.structures import DetDataSample |
|
|
|
# ffmpegcv provides the video readers/writer used by this script, so a
# missing install is reported immediately together with the install command.
try:
    import ffmpegcv
except ImportError as exc:
    # Chain explicitly so the original import failure stays visible as the
    # direct cause in the traceback.
    raise ImportError(
        'Please install ffmpegcv with:\n\n pip install ffmpegcv') from exc
|
|
|
|
|
def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Three positionals are required, in this order: ``video``, ``config``
    and ``checkpoint``. The optional flags are ``--device``,
    ``--score-thr``, ``--out``, ``--show``, ``--nvdecode`` and
    ``--wait-time``.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    cli = argparse.ArgumentParser(
        description='MMDetection video demo with GPU acceleration')

    # Positionals are registered in the order they must appear on the
    # command line.
    positional_specs = (
        ('video', 'Video file'),
        ('config', 'Config file'),
        ('checkpoint', 'Checkpoint file'),
    )
    for dest, help_text in positional_specs:
        cli.add_argument(dest, help=help_text)

    # Optional flags; the registration order is also the order used in
    # the generated ``--help`` output.
    optional_specs = (
        ('--device', {
            'default': 'cuda:0',
            'help': 'Device used for inference',
        }),
        ('--score-thr', {
            'type': float,
            'default': 0.3,
            'help': 'Bbox score threshold',
        }),
        ('--out', {
            'type': str,
            'help': 'Output video file',
        }),
        ('--show', {
            'action': 'store_true',
            'help': 'Show video',
        }),
        ('--nvdecode', {
            'action': 'store_true',
            'help': 'Use NVIDIA decoder',
        }),
        ('--wait-time', {
            'type': float,
            'default': 1,
            'help': 'The interval of show (s), 0 is block',
        }),
    )
    for flag, options in optional_specs:
        cli.add_argument(flag, **options)

    return cli.parse_args()
|
|
|
|
|
def prefetch_batch_input_shape(model: nn.Module,
                               ori_wh: Tuple[int, int]) -> Tuple[int, int]:
    """Find the input shape the model's test pipeline yields for a frame size.

    A blank image of the given size is pushed through the test pipeline and
    the model's data preprocessor, and the ``batch_input_shape`` recorded on
    the resulting data sample is returned.

    Args:
        model: Detector exposing ``cfg`` and ``data_preprocessor``.
        ori_wh: ``(width, height)`` of the original video frames.

    Returns:
        The ``batch_input_shape`` reported by the data preprocessor. The
        caller in this file reverses it to obtain a ``(width, height)``
        resize target, i.e. it is treated as ``(height, width)``.
    """
    # Function-scope import: nothing else in the file needs ``copy``.
    import copy

    w, h = ori_wh

    # Work on a private copy of the pipeline config. Overwriting the loader
    # type directly on ``model.cfg`` would silently change the shared config
    # for every later user of the model.
    pipeline_cfg = copy.deepcopy(model.cfg.test_dataloader.dataset.pipeline)
    # Frames arrive as in-memory arrays, so the first (loading) step is
    # switched to the ndarray loader.
    pipeline_cfg[0].type = 'LoadImageFromNDArray'
    test_pipeline = Compose(pipeline_cfg)

    # A dummy black frame is enough: only the resulting shape is needed.
    data = {'img': np.zeros((h, w, 3), dtype=np.uint8), 'img_id': 0}
    data = test_pipeline(data)
    _, data_sample = model.data_preprocessor([data], False)
    return data_sample[0].batch_input_shape
|
|
|
|
|
def pack_data(frame_resize: np.ndarray, batch_input_shape: Tuple[int, int],
              ori_shape: Tuple[int, int]) -> dict:
    """Wrap one pre-resized frame into the dict given to ``model.test_step``.

    Args:
        frame_resize: Frame whose first two dimensions equal
            ``batch_input_shape``; the channel axis is last (it is moved
            to the front below).
        batch_input_shape: Shape the frame was resized to, in the same
            order as ``frame_resize.shape[:2]``.
        ori_shape: Shape of the original frame, in the same order as
            ``batch_input_shape``.

    Returns:
        dict with ``'inputs'`` (channel-first tensor of the frame) and
        ``'data_sample'`` (a ``DetDataSample`` carrying ``img_shape``,
        ``ori_shape`` and ``scale_factor`` meta information).

    Raises:
        AssertionError: If the frame does not match ``batch_input_shape``.
    """
    assert frame_resize.shape[:2] == batch_input_shape, (
        f'frame shape {frame_resize.shape[:2]} does not match '
        f'batch_input_shape {batch_input_shape}')

    # The caller in this file opens the resized stream with
    # ``resize_keepratio=True`` (top-left aligned), so the picture is scaled
    # by ONE factor on both axes and the remainder is padding. That factor is
    # the smaller of the two per-axis ratios; using the per-axis ratios
    # themselves would mis-scale boxes along the padded axis.
    # NOTE(review): relies on ffmpegcv's keep-ratio resize fitting the frame
    # inside the target size — confirm against the ffmpegcv documentation.
    scale = min(batch_input_shape[0] / ori_shape[0],
                batch_input_shape[1] / ori_shape[1])

    data_sample = DetDataSample()
    data_sample.set_metainfo({
        'img_shape': batch_input_shape,
        'ori_shape': ori_shape,
        'scale_factor': (scale, scale)
    })

    # Channel-last array -> channel-first tensor.
    frame_resize = torch.from_numpy(frame_resize).permute((2, 0, 1))
    data = {'inputs': frame_resize, 'data_sample': data_sample}
    return data
|
|
|
|
|
def main():
    """Run the detector over a video and show and/or save the drawn frames.

    The video is opened twice: once at its original size (used for
    drawing) and once resized to the model's batch input shape (used for
    inference). Each resized frame is packed, passed to
    ``model.test_step``, and the prediction is drawn on the matching
    original frame.

    Raises:
        AssertionError: If neither ``--out`` nor ``--show`` was given.
    """
    args = parse_args()
    assert args.out or args.show, \
        ('Please specify at least one operation (save/show the '
         'video) with the argument "--out" or "--show"')

    model = init_detector(args.config, args.checkpoint, device=args.device)

    # Build the visualizer from the model config and give it the dataset
    # meta information of the loaded model.
    visualizer = VISUALIZERS.build(model.cfg.visualizer)
    visualizer.dataset_meta = model.dataset_meta

    if args.nvdecode:
        capture_cls = ffmpegcv.VideoCaptureNV
    else:
        capture_cls = ffmpegcv.VideoCapture

    # Handles start as None so the ``finally`` clause can release exactly
    # those that were opened, even when setup or inference raises midway.
    video_origin = None
    video_resize = None
    video_writer = None
    try:
        video_origin = capture_cls(args.video)

        batch_input_shape = prefetch_batch_input_shape(
            model, (video_origin.width, video_origin.height))
        ori_shape = (video_origin.height, video_origin.width)
        # The reader takes the resize target in the reverse order of
        # ``batch_input_shape``.
        resize_wh = batch_input_shape[::-1]
        video_resize = capture_cls(
            args.video,
            resize=resize_wh,
            resize_keepratio=True,
            resize_keepratioalign='topleft')

        if args.out:
            video_writer = ffmpegcv.VideoWriter(
                args.out, fps=video_origin.fps)

        with torch.no_grad():
            # Both readers decode the same file, so zipping them pairs each
            # resized frame with its original-size counterpart.
            for frame_resize, frame_origin in zip(
                    track_iter_progress(video_resize), video_origin):
                data = pack_data(frame_resize, batch_input_shape, ori_shape)
                result = model.test_step([data])[0]

                visualizer.add_datasample(
                    name='video',
                    image=frame_origin,
                    data_sample=result,
                    draw_gt=False,
                    show=False,
                    pred_score_thr=args.score_thr)
                frame_mask = visualizer.get_image()

                if args.show:
                    cv2.namedWindow('video', 0)
                    # NOTE(review): the CLI help describes --wait-time in
                    # seconds; check the unit mmcv.imshow actually expects.
                    mmcv.imshow(frame_mask, 'video', args.wait_time)
                if args.out:
                    video_writer.write(frame_mask)
    finally:
        # Same release order as before: writer, original reader, resized
        # reader.
        for handle in (video_writer, video_origin, video_resize):
            if handle is not None:
                handle.release()
        # Windows exist only in --show mode. Skipping the call otherwise
        # also avoids touching the GUI backend in save-only runs, where
        # it may be unavailable (e.g. headless OpenCV builds).
        if args.show:
            cv2.destroyAllWindows()
|
|
|
|
|
# Script entry point: run the demo only when executed directly, not on import.
if __name__ == '__main__':
    main()
|
|