# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
from torch import Tensor

from mmdet.models.mot import BaseMOTModel
from mmdet.registry import MODELS
from mmdet.structures import TrackSampleList
from mmdet.utils import OptConfigType, OptMultiConfig
class MaskTrackRCNN(BaseMOTModel):
    """Video Instance Segmentation.

    This video instance segmentor is the implementation of `MaskTrack R-CNN
    <https://arxiv.org/abs/1905.04804>`_.

    Args:
        detector (dict): Configuration of detector. Defaults to None.
        track_head (dict): Configuration of track head. Defaults to None.
        tracker (dict): Configuration of tracker. Defaults to None.
        data_preprocessor (dict or ConfigDict, optional): The pre-process
            config of :class:`TrackDataPreprocessor`. It usually includes,
            ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
        init_cfg (dict or list[dict]): Configuration of initialization.
            Defaults to None.
    """

    def __init__(self,
                 detector: Optional[dict] = None,
                 track_head: Optional[dict] = None,
                 tracker: Optional[dict] = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None):
        super().__init__(data_preprocessor, init_cfg)

        if detector is not None:
            self.detector = MODELS.build(detector)
            # The track/mask pipeline pulls ROI features from the detector,
            # so a one-stage detector (no roi_head) cannot be used here.
            assert hasattr(self.detector, 'roi_head'), \
                'MaskTrack R-CNN only supports two stage detectors.'

        if track_head is not None:
            self.track_head = MODELS.build(track_head)
        if tracker is not None:
            self.tracker = MODELS.build(tracker)

    def loss(self, inputs: Tensor, data_samples: TrackSampleList,
             **kwargs) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            inputs (Tensor): of shape (N, T, C, H, W) encoding input images.
                Typically these should be mean centered and std scaled.
                The N denotes batch size. The T denotes the number of frames.
            data_samples (list[:obj:`TrackDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance`.

        Returns:
            dict: A dictionary of loss components.
        """
        assert inputs.dim() == 5, 'The img must be 5D Tensor (N, T, C, H, W).'
        assert inputs.size(1) == 2, \
            'MaskTrackRCNN can only have 1 key frame and 1 reference frame.'

        # Split each track sample into its key frame (supervised) and its
        # reference frame (used for the tracking branch).
        ref_data_samples, key_data_samples = [], []
        key_frame_inds, ref_frame_inds = [], []
        for track_data_sample in data_samples:
            key_data_samples.append(track_data_sample.get_key_frames()[0])
            ref_data_samples.append(track_data_sample.get_ref_frames()[0])
            key_frame_inds.append(track_data_sample.key_frames_inds[0])
            ref_frame_inds.append(track_data_sample.ref_frames_inds[0])

        # Gather the per-sample key/ref images via advanced indexing over
        # (batch, time); ``contiguous`` keeps downstream ops happy.
        key_frame_inds = torch.tensor(key_frame_inds, dtype=torch.int64)
        ref_frame_inds = torch.tensor(ref_frame_inds, dtype=torch.int64)
        batch_inds = torch.arange(len(inputs))
        key_imgs = inputs[batch_inds, key_frame_inds].contiguous()
        ref_imgs = inputs[batch_inds, ref_frame_inds].contiguous()

        x = self.detector.extract_feat(key_imgs)
        ref_x = self.detector.extract_feat(ref_imgs)

        losses = dict()
        # RPN forward and loss
        if self.detector.with_rpn:
            proposal_cfg = self.detector.train_cfg.get(
                'rpn_proposal', self.detector.test_cfg.rpn)
            rpn_losses, rpn_results_list = \
                self.detector.rpn_head.loss_and_predict(
                    x,
                    key_data_samples,
                    proposal_cfg=proposal_cfg,
                    **kwargs)
            # Prefix generic loss keys so they don't collide with roi_head
            # losses. Iterate over a snapshot of the keys: popping and
            # re-inserting entries while iterating the live ``dict_keys``
            # view mutates the dict during iteration, which can skip keys
            # or raise ``RuntimeError``.
            for key in list(rpn_losses.keys()):
                if 'loss' in key and 'rpn' not in key:
                    rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key)
            losses.update(rpn_losses)
        else:
            # TODO: Not support currently, should have a check at Fast R-CNN
            assert key_data_samples[0].get('proposals', None) is not None
            # use pre-defined proposals in InstanceData for the second stage
            # to extract ROI features.
            rpn_results_list = [
                key_data_sample.proposals
                for key_data_sample in key_data_samples
            ]

        losses_detect = self.detector.roi_head.loss(x, rpn_results_list,
                                                    key_data_samples,
                                                    **kwargs)
        losses.update(losses_detect)

        losses_track = self.track_head.loss(x, ref_x, rpn_results_list,
                                            data_samples, **kwargs)
        losses.update(losses_track)

        return losses

    def predict(self,
                inputs: Tensor,
                data_samples: TrackSampleList,
                rescale: bool = True,
                **kwargs) -> TrackSampleList:
        """Test without augmentation.

        Args:
            inputs (Tensor): of shape (N, T, C, H, W) encoding
                input images. The N denotes batch size.
                The T denotes the number of frames in a video.
            data_samples (list[:obj:`TrackDataSample`]): The batch
                data samples. It usually includes information such
                as `video_data_samples`.
            rescale (bool, Optional): If False, then returned bboxes and masks
                will fit the scale of img, otherwise, returned bboxes and masks
                will fit the scale of original image shape. Defaults to True.

        Returns:
            TrackSampleList: Tracking results of the inputs.
        """
        assert inputs.dim() == 5, 'The img must be 5D Tensor (N, T, C, H, W).'
        assert len(data_samples) == 1, \
            'MaskTrackRCNN only support 1 batch size per gpu for now.'

        track_data_sample = data_samples[0]
        video_len = len(track_data_sample)
        # A new video starts at frame 0: drop any state kept from the
        # previous video before tracking this one.
        if track_data_sample[0].frame_id == 0:
            self.tracker.reset()

        for frame_id in range(video_len):
            img_data_sample = track_data_sample[frame_id]
            single_img = inputs[:, frame_id].contiguous()
            x = self.detector.extract_feat(single_img)
            rpn_results_list = self.detector.rpn_head.predict(
                x, [img_data_sample])
            # det_results List[InstanceData]
            det_results = self.detector.roi_head.predict(
                x, rpn_results_list, [img_data_sample], rescale=rescale)
            assert len(det_results) == 1, 'Batch inference is not supported.'
            assert 'masks' in det_results[0], 'There are no mask results.'
            img_data_sample.pred_instances = det_results[0]
            # Associate current detections with existing tracks.
            frame_pred_track_instances = self.tracker.track(
                model=self, feats=x, data_sample=img_data_sample, **kwargs)
            img_data_sample.pred_track_instances = frame_pred_track_instances

        return [track_data_sample]