Spaces:
Runtime error
Runtime error
| import numpy as np | |
| from mmdet.core import bbox2result | |
| from mmdet.models import TwoStageDetector | |
| from qdtrack.core import track2result | |
| from ..builder import MODELS, build_tracker | |
| from qdtrack.core import imshow_tracks, restore_result | |
| from tracker import BYTETracker | |
class QDTrack(TwoStageDetector):
    """Two-stage detector with a tracking head and an online tracker.

    The detection part (backbone, neck, RPN, RoI head) is inherited from
    ``TwoStageDetector``. At test time the per-frame detections and track
    features are handed to a tracker that assigns instance ids.
    """

    def __init__(self, tracker=None, freeze_detector=False, *args, **kwargs):
        """Build the detector and optionally freeze its detection modules.

        Args:
            tracker: Tracker configuration. It is stored as
                ``self.tracker_cfg``.
                NOTE(review): ``init_tracker`` currently ignores this value.
            freeze_detector (bool): If True, the backbone, neck, RPN head and
                bbox head are put in eval mode and their parameters stop
                requiring gradients.
            *args: Forwarded to ``TwoStageDetector.__init__``.
            **kwargs: Forwarded to ``TwoStageDetector.__init__`` after
                ``prepare_cfg`` has adjusted them.
        """
        # Must run before the parent constructor: it injects the embedding
        # train config into the ``roi_head`` config that the parent consumes.
        self.prepare_cfg(kwargs)
        super().__init__(*args, **kwargs)
        self.tracker_cfg = tracker
        self.freeze_detector = freeze_detector
        if self.freeze_detector:
            self._freeze_detector()

    def _freeze_detector(self):
        """Put the detection modules in eval mode and disable their grads.

        NOTE(review): a later ``.train()`` call on this model switches the
        sub-modules back to training mode (standard ``nn.Module`` behaviour);
        only ``requires_grad = False`` persists. Confirm whether that is
        intended.
        """
        self.detector = [
            self.backbone, self.neck, self.rpn_head, self.roi_head.bbox_head
        ]
        for module in self.detector:
            module.eval()
            for param in module.parameters():
                param.requires_grad = False

    def prepare_cfg(self, kwargs):
        """Copy the ``embed`` train config into the RoI head config.

        When ``kwargs['train_cfg']`` is present and truthy, its ``'embed'``
        entry (or None if absent) is written to
        ``kwargs['roi_head']['track_train_cfg']``. ``kwargs`` is modified in
        place.

        Args:
            kwargs (dict): Constructor keyword arguments of the model.
        """
        if kwargs.get('train_cfg', False):
            kwargs['roi_head']['track_train_cfg'] = kwargs['train_cfg'].get(
                'embed', None)

    def init_tracker(self):
        """Create a fresh tracker and store it as ``self.tracker``.

        A default-constructed ``BYTETracker`` is used; the config-driven
        ``build_tracker(self.tracker_cfg)`` path is intentionally bypassed.
        """
        self.tracker = BYTETracker()

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_match_indices,
                      ref_img,
                      ref_img_metas,
                      ref_gt_bboxes,
                      ref_gt_labels,
                      ref_gt_match_indices,
                      gt_bboxes_ignore=None,
                      gt_masks=None,
                      ref_gt_bboxes_ignore=None,
                      ref_gt_masks=None,
                      **kwargs):
        """Compute training losses for a key image and a reference image.

        The key image goes through the RPN (producing RPN losses and
        proposals); the reference image only gets test-mode RPN proposals.
        Both are then passed to the RoI head, whose losses are merged with
        the RPN losses.

        Args:
            img: Key images.
            img_metas: Meta information of the key images.
            gt_bboxes: Ground-truth boxes of the key images.
            gt_labels: Ground-truth labels of the key images.
            gt_match_indices: Matching indices between key and reference
                ground truths, forwarded to the RoI head.
            ref_img: Reference images.
            ref_img_metas: Meta information of the reference images.
            ref_gt_bboxes: Ground-truth boxes of the reference images.
            ref_gt_labels: Ground-truth labels of the reference images.
            ref_gt_match_indices: Accepted for interface compatibility; not
                used in this method.
            gt_bboxes_ignore: Boxes to ignore in the key images.
            gt_masks: Ground-truth masks of the key images.
            ref_gt_bboxes_ignore: Boxes to ignore in the reference images.
            ref_gt_masks: Accepted for interface compatibility; not used in
                this method.
            **kwargs: Extra arguments forwarded to the RoI head.

        Returns:
            dict: RPN losses updated with the RoI head losses.
        """
        x = self.extract_feat(img)

        losses = dict()

        # RPN forward and loss. Labels are not passed to the RPN.
        proposal_cfg = self.train_cfg.get('rpn_proposal', self.test_cfg.rpn)
        rpn_losses, proposal_list = self.rpn_head.forward_train(
            x,
            img_metas,
            gt_bboxes,
            gt_labels=None,
            gt_bboxes_ignore=gt_bboxes_ignore,
            proposal_cfg=proposal_cfg)
        losses.update(rpn_losses)

        # The reference image contributes proposals only, no RPN loss.
        ref_x = self.extract_feat(ref_img)
        ref_proposals = self.rpn_head.simple_test_rpn(ref_x, ref_img_metas)

        roi_losses = self.roi_head.forward_train(
            x, img_metas, proposal_list, gt_bboxes, gt_labels,
            gt_match_indices, ref_x, ref_img_metas, ref_proposals,
            ref_gt_bboxes, ref_gt_labels, gt_bboxes_ignore, gt_masks,
            ref_gt_bboxes_ignore, **kwargs)
        losses.update(roi_losses)

        return losses

    def simple_test(self, img, img_metas, rescale=False):
        """Detect and track objects in a single frame.

        Args:
            img: Input image(s) passed to ``extract_feat``.
            img_metas (list[dict]): Image meta information. ``frame_id`` is
                read from the first entry and defaults to -1 when missing.
            rescale (bool): Forwarded to the RoI head's ``simple_test``.

        Returns:
            dict: ``bbox_results`` (output of ``bbox2result``) and
            ``track_results`` (output of ``track2result``, or one empty
            ``(0, 6)`` float32 array per class when no track features were
            produced).
        """
        # TODO inherit from a base tracker
        assert self.roi_head.with_track, 'Track head must be implemented.'
        frame_id = img_metas[0].get('frame_id', -1)
        # Reset the tracker at the start of every sequence, and also create
        # it lazily when none exists yet: otherwise a first frame without
        # ``frame_id`` (-1) or a sequence starting at a non-zero frame would
        # fail with AttributeError on ``self.tracker``.
        if frame_id == 0 or getattr(self, 'tracker', None) is None:
            self.init_tracker()

        x = self.extract_feat(img)
        proposal_list = self.rpn_head.simple_test_rpn(x, img_metas)
        det_bboxes, det_labels, track_feats = self.roi_head.simple_test(
            x, img_metas, proposal_list, rescale)

        # NOTE(review): the tracker is updated even when ``track_feats`` is
        # None; its outputs are only consumed below when features exist.
        # Confirm that the tracker tolerates a None feature input.
        bboxes, labels, ids = self.tracker.update(det_bboxes, det_labels,
                                                  frame_id, track_feats)

        num_classes = self.roi_head.bbox_head.num_classes
        bbox_result = bbox2result(det_bboxes, det_labels, num_classes)

        if track_feats is not None:
            track_result = track2result(bboxes, labels, ids, num_classes)
        else:
            track_result = [
                np.zeros((0, 6), dtype=np.float32)
                for _ in range(num_classes)
            ]
        return dict(bbox_results=bbox_result, track_results=track_result)

    def show_result(self,
                    img,
                    result,
                    thickness=1,
                    font_scale=0.5,
                    show=False,
                    out_file=None,
                    wait_time=0,
                    backend='cv2',
                    **kwargs):
        """Visualize tracking results.

        Args:
            img (str | ndarray): Filename of loaded image.
            result (dict): Tracking result.
                The value of key 'track_results' is ndarray with shape (n, 6)
                in [id, tl_x, tl_y, br_x, br_y, score] format.
                The value of key 'bbox_results' is ndarray with shape (n, 5)
                in [tl_x, tl_y, br_x, br_y, score] format.
            thickness (int, optional): Thickness of lines. Defaults to 1.
            font_scale (float, optional): Font scales of texts. Defaults
                to 0.5.
            show (bool, optional): Whether show the visualizations on the
                fly. Defaults to False.
            out_file (str | None, optional): Output filename. Defaults to None.
            wait_time (int, optional): Forwarded to ``imshow_tracks``.
                Defaults to 0.
            backend (str, optional): Backend to draw the bounding boxes,
                options are `cv2` and `plt`. Defaults to 'cv2'.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            ndarray: Visualized image.
        """
        assert isinstance(result, dict)
        track_result = result.get('track_results', None)
        # Only the tracking output is drawn; 'bbox_results' is not used here.
        bboxes, labels, ids = restore_result(track_result, return_ids=True)
        img = imshow_tracks(
            img,
            bboxes,
            labels,
            ids,
            classes=self.CLASSES,
            thickness=thickness,
            font_scale=font_scale,
            show=show,
            out_file=out_file,
            wait_time=wait_time,
            backend=backend)
        return img