# Copyright (c) OpenMMLab. All rights reserved.
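"""MMDetection video demo that offloads decoding and resizing to ffmpegcv.

Example invocation (the script, config, and checkpoint names below are
placeholders; substitute your own):

    python video_gpuaccel_demo.py demo.mp4 \
        configs/rtmdet/rtmdet_tiny_8xb32-300e_coco.py \
        checkpoint.pth --nvdecode --out result.mp4
"""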
import argparse
from typing import Tuple

import cv2
import mmcv
import numpy as np
import torch
import torch.nn as nn
from mmcv.transforms import Compose
from mmengine.utils import track_iter_progress

from mmdet.apis import init_detector
from mmdet.registry import VISUALIZERS
from mmdet.structures import DetDataSample

try:
    import ffmpegcv
except ImportError:
    raise ImportError(
        'Please install ffmpegcv with:\n\n    pip install ffmpegcv')


def parse_args():
    parser = argparse.ArgumentParser(
        description='MMDetection video demo with GPU acceleration')
    parser.add_argument('video', help='Video file')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--score-thr', type=float, default=0.3, help='Bbox score threshold')
    parser.add_argument('--out', type=str, help='Output video file')
    parser.add_argument('--show', action='store_true', help='Show video')
    parser.add_argument(
        '--nvdecode', action='store_true', help='Use NVIDIA decoder')
    parser.add_argument(
        '--wait-time',
        type=float,
        default=1,
        help='Interval between displayed frames; 0 blocks until a keypress')
    args = parser.parse_args()
    return args


def prefetch_batch_input_shape(model: nn.Module,
                               ori_wh: Tuple[int, int]) -> Tuple[int, int]:
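    """Determine the padded input shape the test pipeline produces.

    A dummy frame at the video's original size is pushed through the test
    pipeline and data preprocessor once, so the final ``batch_input_shape``
    (after Resize and Pad) is known before decoding starts.
    """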
    cfg = model.cfg
    w, h = ori_wh
    cfg.test_dataloader.dataset.pipeline[0].type = 'LoadImageFromNDArray'
    test_pipeline = Compose(cfg.test_dataloader.dataset.pipeline)
    data = {'img': np.zeros((h, w, 3), dtype=np.uint8), 'img_id': 0}
    data = test_pipeline(data)
    # the data preprocessor expects a collated dict of batched inputs/samples
    data['inputs'] = [data['inputs']]
    data['data_samples'] = [data['data_samples']]
    data_sample = model.data_preprocessor(data, False)['data_samples']
    batch_input_shape = data_sample[0].batch_input_shape
    return batch_input_shape


def pack_data(frame_resize: np.ndarray, batch_input_shape: Tuple[int, int],
              ori_shape: Tuple[int, int]) -> dict:
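    """Pack one pre-resized frame into the collated dict the model expects.

    ffmpegcv has already resized and padded the frame, so the CPU-side test
    pipeline is bypassed and the metainfo it would normally record (shapes,
    scale factor) is filled in by hand.
    """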
    assert frame_resize.shape[:2] == batch_input_shape
    data_sample = DetDataSample()
    data_sample.set_metainfo({
        'img_shape': batch_input_shape,
        'ori_shape': ori_shape,
        'scale_factor': (batch_input_shape[0] / ori_shape[0],
                         batch_input_shape[1] / ori_shape[1])
    })
    frame_resize = torch.from_numpy(frame_resize).permute((2, 0, 1))
    data = {'inputs': [frame_resize], 'data_samples': [data_sample]}
    return data


def main():
    args = parse_args()
    assert args.out or args.show, \
        ('Please specify at least one operation (save/show the '
         'video) with the argument "--out" or "--show"')

    model = init_detector(args.config, args.checkpoint, device=args.device)

    # init visualizer
    visualizer = VISUALIZERS.build(model.cfg.visualizer)
    # the dataset_meta is loaded from the checkpoint and
    # then passed to the model in init_detector
    visualizer.dataset_meta = model.dataset_meta

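    # decode on the NVIDIA GPU (NVDEC) when --nvdecode is set; otherwise
    # decode on the CPU, still inside the ffmpeg subprocess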
    if args.nvdecode:
        VideoCapture = ffmpegcv.VideoCaptureNV
    else:
        VideoCapture = ffmpegcv.VideoCapture
    video_origin = VideoCapture(args.video)

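    # probe the padded network input size once, then open a second reader
    # that has ffmpeg resize every decoded frame to that size (keep-ratio,
    # content aligned top-left, mirroring the test pipeline's Resize + Pad)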
    batch_input_shape = prefetch_batch_input_shape(
        model, (video_origin.width, video_origin.height))
    ori_shape = (video_origin.height, video_origin.width)
    resize_wh = batch_input_shape[::-1]  # (h, w) -> (w, h) for ffmpegcv
    video_resize = VideoCapture(
        args.video,
        resize=resize_wh,
        resize_keepratio=True,
        resize_keepratioalign='topleft')

    video_writer = None
    if args.out:
        video_writer = ffmpegcv.VideoWriter(args.out, fps=video_origin.fps)

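    # run both streams in lockstep: resized frames feed the model while the
    # original frames are kept for visualization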
    with torch.no_grad():
        for frame_resize, frame_origin in zip(
                track_iter_progress(video_resize), video_origin):
            data = pack_data(frame_resize, batch_input_shape, ori_shape)
            result = model.test_step(data)[0]

            visualizer.add_datasample(
                name='video',
                image=frame_origin,
                data_sample=result,
                draw_gt=False,
                show=False,
                pred_score_thr=args.score_thr)

            frame_mask = visualizer.get_image()

            if args.show:
                cv2.namedWindow('video', 0)  # 0 == cv2.WINDOW_NORMAL
                mmcv.imshow(frame_mask, 'video', args.wait_time)
            if args.out:
                video_writer.write(frame_mask)

    if video_writer:
        video_writer.release()
    video_origin.release()
    video_resize.release()

    cv2.destroyAllWindows()


if __name__ == '__main__':
    main()