# TTP/mmdet/models/trackers/sort_tracker.py
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple
import numpy as np
import torch
from mmengine.structures import InstanceData
try:
import motmetrics
from motmetrics.lap import linear_sum_assignment
except ImportError:
motmetrics = None
from torch import Tensor
from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures import DetDataSample
from mmdet.structures.bbox import bbox_overlaps, bbox_xyxy_to_cxcyah
from mmdet.utils import OptConfigType
from ..utils import imrenormalize
from .base_tracker import BaseTracker


@MODELS.register_module()
class SORTTracker(BaseTracker):
"""Tracker for SORT/DeepSORT.
Args:
obj_score_thr (float, optional): Threshold to filter the objects.
Defaults to 0.3.
motion (dict): Configuration of motion. Defaults to None.
reid (dict, optional): Configuration for the ReID model.
- num_samples (int, optional): Number of samples to calculate the
feature embeddings of a track. Default to 10.
- image_scale (tuple, optional): Input scale of the ReID model.
Default to (256, 128).
- img_norm_cfg (dict, optional): Configuration to normalize the
input. Default to None.
- match_score_thr (float, optional): Similarity threshold for the
matching process. Default to 2.0.
match_iou_thr (float, optional): Threshold of the IoU matching process.
Defaults to 0.7.
num_tentatives (int, optional): Number of continuous frames to confirm
a track. Defaults to 3.
"""
def __init__(self,
motion: Optional[dict] = None,
obj_score_thr: float = 0.3,
reid: dict = dict(
num_samples=10,
img_scale=(256, 128),
img_norm_cfg=None,
match_score_thr=2.0),
match_iou_thr: float = 0.7,
num_tentatives: int = 3,
**kwargs):
        if motmetrics is None:
            raise RuntimeError('motmetrics is not installed, '
                               'please install it by: pip install motmetrics')
super().__init__(**kwargs)
        assert motion is not None, 'SORT/DeepSORT needs a KalmanFilter'
        self.motion = TASK_UTILS.build(motion)
self.obj_score_thr = obj_score_thr
self.reid = reid
self.match_iou_thr = match_iou_thr
        self.num_tentatives = num_tentatives

@property
def confirmed_ids(self) -> List:
"""Confirmed ids in the tracker."""
ids = [id for id, track in self.tracks.items() if not track.tentative]
        return ids

def init_track(self, id: int, obj: Tuple[Tensor]) -> None:
"""Initialize a track."""
super().init_track(id, obj)
self.tracks[id].tentative = True
bbox = bbox_xyxy_to_cxcyah(self.tracks[id].bboxes[-1]) # size = (1, 4)
assert bbox.ndim == 2 and bbox.shape[0] == 1
bbox = bbox.squeeze(0).cpu().numpy()
self.tracks[id].mean, self.tracks[id].covariance = self.kf.initiate(
            bbox)

def update_track(self, id: int, obj: Tuple[Tensor]) -> None:
"""Update a track."""
super().update_track(id, obj)
if self.tracks[id].tentative:
if len(self.tracks[id]['bboxes']) >= self.num_tentatives:
self.tracks[id].tentative = False
bbox = bbox_xyxy_to_cxcyah(self.tracks[id].bboxes[-1]) # size = (1, 4)
assert bbox.ndim == 2 and bbox.shape[0] == 1
bbox = bbox.squeeze(0).cpu().numpy()
self.tracks[id].mean, self.tracks[id].covariance = self.kf.update(
            self.tracks[id].mean, self.tracks[id].covariance, bbox)

def pop_invalid_tracks(self, frame_id: int) -> None:
"""Pop out invalid tracks."""
invalid_ids = []
for k, v in self.tracks.items():
            # case 1: disappeared frames >= self.num_frames_retain
case1 = frame_id - v['frame_ids'][-1] >= self.num_frames_retain
# case2: tentative tracks but not matched in this frame
case2 = v.tentative and v['frame_ids'][-1] != frame_id
if case1 or case2:
invalid_ids.append(k)
for invalid_id in invalid_ids:
            self.tracks.pop(invalid_id)

def track(self,
model: torch.nn.Module,
img: Tensor,
data_sample: DetDataSample,
data_preprocessor: OptConfigType = None,
rescale: bool = False,
**kwargs) -> InstanceData:
"""Tracking forward function.
Args:
model (nn.Module): MOT model.
img (Tensor): of shape (T, C, H, W) encoding input image.
Typically these should be mean centered and std scaled.
The T denotes the number of key images and usually is 1 in
SORT method.
data_sample (:obj:`TrackDataSample`): The data sample.
It includes information such as `pred_det_instances`.
data_preprocessor (dict or ConfigDict, optional): The pre-process
config of :class:`TrackDataPreprocessor`. it usually includes,
``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
rescale (bool, optional): If True, the bounding boxes should be
rescaled to fit the original scale of the image. Defaults to
False.
Returns:
:obj:`InstanceData`: Tracking results of the input images.
Each InstanceData usually contains ``bboxes``, ``labels``,
``scores`` and ``instances_id``.
"""
metainfo = data_sample.metainfo
bboxes = data_sample.pred_instances.bboxes
labels = data_sample.pred_instances.labels
scores = data_sample.pred_instances.scores
frame_id = metainfo.get('frame_id', -1)
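        # frame 0 marks the start of a new video, so clear all previous tracks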
if frame_id == 0:
self.reset()
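        # alias the motion model as ``self.kf`` so ``init_track`` and
        # ``update_track`` can call its ``initiate``/``update`` methods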
if not hasattr(self, 'kf'):
self.kf = self.motion
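        # if the ReID branch expects a different normalization than the
        # detector, re-normalize the input image accordingly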
if self.with_reid:
if self.reid.get('img_norm_cfg', False):
img_norm_cfg = dict(
mean=data_preprocessor['mean'],
std=data_preprocessor['std'],
to_bgr=data_preprocessor['rgb_to_bgr'])
reid_img = imrenormalize(img, img_norm_cfg,
self.reid['img_norm_cfg'])
else:
reid_img = img.clone()
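        # drop low-confidence detections before association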
valid_inds = scores > self.obj_score_thr
bboxes = bboxes[valid_inds]
labels = labels[valid_inds]
scores = scores[valid_inds]
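        # no existing tracks (or no detections left): every kept detection
        # starts a new track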
if self.empty or bboxes.size(0) == 0:
num_new_tracks = bboxes.size(0)
ids = torch.arange(
self.num_tracks,
self.num_tracks + num_new_tracks,
dtype=torch.long).to(bboxes.device)
self.num_tracks += num_new_tracks
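            # still compute embeddings so the new tracks can be matched
            # against detections in later frames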
if self.with_reid:
crops = self.crop_imgs(reid_img, metainfo, bboxes.clone(),
rescale)
if crops.size(0) > 0:
embeds = model.reid(crops, mode='tensor')
else:
embeds = crops.new_zeros((0, model.reid.head.out_channels))
else:
ids = torch.full((bboxes.size(0), ), -1,
dtype=torch.long).to(bboxes.device)
        # motion: Kalman predict on the existing tracks; ``costs`` holds the
        # gating distances between tracks and the current detections
self.tracks, costs = self.motion.track(self.tracks,
bbox_xyxy_to_cxcyah(bboxes))
active_ids = self.confirmed_ids
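        # extract ReID embeddings of the current detections and associate
        # the confirmed tracks first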
if self.with_reid:
crops = self.crop_imgs(reid_img, metainfo, bboxes.clone(),
rescale)
embeds = model.reid(crops, mode='tensor')
            # reid: match confirmed tracks to detections by embedding distance
if len(active_ids) > 0:
track_embeds = self.get(
'embeds',
active_ids,
self.reid.get('num_samples', None),
behavior='mean')
reid_dists = torch.cdist(track_embeds, embeds)
                # support multi-class association: a prohibitive cost (1e6)
                # forbids matches across different classes
track_labels = torch.tensor([
self.tracks[id]['labels'][-1] for id in active_ids
]).to(bboxes.device)
cate_match = labels[None, :] == track_labels[:, None]
cate_cost = (1 - cate_match.int()) * 1e6
reid_dists = (reid_dists + cate_cost).cpu().numpy()
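                    # motion gating: pairs whose Kalman gating cost is
                    # non-finite are excluded from the assignment below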
                    track_inds = [list(self.ids).index(_) for _ in active_ids]
                    reid_dists[~np.isfinite(costs[track_inds, :])] = np.nan
row, col = linear_sum_assignment(reid_dists)
for r, c in zip(row, col):
dist = reid_dists[r, c]
if not np.isfinite(dist):
continue
if dist <= self.reid['match_score_thr']:
ids[c] = active_ids[r]
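        # second stage: IoU association for the remaining last-frame tracks
        # (tentative ones, plus confirmed ones left unmatched by ReID)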
active_ids = [
id for id in self.ids if id not in ids
and self.tracks[id].frame_ids[-1] == frame_id - 1
]
if len(active_ids) > 0:
active_dets = torch.nonzero(ids == -1).squeeze(1)
track_bboxes = self.get('bboxes', active_ids)
ious = bbox_overlaps(track_bboxes, bboxes[active_dets])
            # support multi-class association: a prohibitive cost (1e6)
            # forbids matches across different classes
track_labels = torch.tensor([
self.tracks[id]['labels'][-1] for id in active_ids
]).to(bboxes.device)
cate_match = labels[None, active_dets] == track_labels[:, None]
cate_cost = (1 - cate_match.int()) * 1e6
dists = (1 - ious + cate_cost).cpu().numpy()
row, col = linear_sum_assignment(dists)
for r, c in zip(row, col):
dist = dists[r, c]
if dist < 1 - self.match_iou_thr:
ids[active_dets[c]] = active_ids[r]
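        # detections that are still unmatched spawn new track ids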
new_track_inds = ids == -1
ids[new_track_inds] = torch.arange(
self.num_tracks,
self.num_tracks + new_track_inds.sum(),
dtype=torch.long).to(bboxes.device)
self.num_tracks += new_track_inds.sum()
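        # write the results into the track buffer: new ids trigger
        # ``init_track``, existing ids trigger ``update_track``, and
        # stale tracks are popped via ``pop_invalid_tracks``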
self.update(
ids=ids,
bboxes=bboxes,
scores=scores,
labels=labels,
embeds=embeds if self.with_reid else None,
frame_ids=frame_id)
# update pred_track_instances
pred_track_instances = InstanceData()
pred_track_instances.bboxes = bboxes
pred_track_instances.labels = labels
pred_track_instances.scores = scores
pred_track_instances.instances_id = ids
return pred_track_instances
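
# A minimal usage sketch (hypothetical config values; the actual ``motion``
# and ``reid`` settings come from the surrounding MOT config, e.g. DeepSORT):
#
#   tracker = MODELS.build(
#       dict(
#           type='SORTTracker',
#           motion=dict(type='KalmanFilter'),
#           obj_score_thr=0.3,
#           match_iou_thr=0.7,
#           num_tentatives=3))
#   pred_track_instances = tracker.track(
#       model, img, data_sample, data_preprocessor, rescale=True)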