|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
from mmdet.datasets.builder import PIPELINES |
|
from mmcv.parallel import DataContainer as DC |
|
from mmdet3d.core.points import BasePoints |
|
from mmdet.datasets.pipelines import to_tensor |
|
from mmdet3d.datasets.pipelines import DefaultFormatBundle |
|
|
|
@PIPELINES.register_module()
class PETRFormatBundle3D(DefaultFormatBundle):
    """Formatting bundle that wraps PETR inputs and encodes GT labels.

    Calling the transform on a ``results`` dict performs the steps below and
    then hands the dict to the parent ``DefaultFormatBundle.__call__``.

    - points: (1)take ``.tensor`` of the ``BasePoints``, (2)to DataContainer
    - each key in ``collect_keys``: (1)to numpy array (float64 for
      'timestamp' / 'img_timestamp', float32 otherwise), (2)to tensor,
      (3)to DataContainer (stack=True, pad_dims=None)
    - lane_pts: (1)to float32 array, (2)to tensor,
      (3)to DataContainer (cpu_only=False)
    - voxels, coors, voxel_centers, num_points: (1)to tensor,
      (2)to DataContainer (stack=False)
    - input_ids, vlm_labels: to DataContainer (stack=False), no conversion
    - gt_* fields: filtered by 'gt_bboxes_3d_mask' / 'gt_bboxes_mask' when
      those masks are present (only if ``with_gt``)
    - gt_names, gt_names_3d: mapped to integer 'gt_labels' /
      'gt_labels_3d' (only if ``with_gt`` and ``with_label``)

    Args:
        class_names: Sequence of class names; the label assigned to a name
            is its position as returned by ``class_names.index(name)``.
        collect_keys: Keys of ``results`` that are converted to stacked
            DataContainers. Every listed key must exist in ``results``.
        with_gt: Whether to process ground-truth fields. Defaults to True.
        with_label: Whether to derive label arrays from the name fields.
            Only consulted when ``with_gt`` is true. Defaults to True.
    """

    def __init__(self, class_names, collect_keys, with_gt=True, with_label=True):
        super().__init__()
        self.class_names = class_names
        self.collect_keys = collect_keys
        self.with_gt = with_gt
        self.with_label = with_label

    @staticmethod
    def _wrap_lane_pts(lane_pts):
        """Return ``lane_pts`` as a float32 tensor inside a DataContainer."""
        lane_array = np.array(lane_pts, dtype=np.float32)
        return DC(to_tensor(lane_array), cpu_only=False)

    def _labels_from_names(self, names):
        """Map each name to its index in ``self.class_names`` (int64 array)."""
        return np.array(
            [self.class_names.index(name) for name in names],
            dtype=np.int64)

    @staticmethod
    def _apply_gt_masks(results):
        """Filter ground-truth fields in place with the masks in ``results``."""
        if 'gt_bboxes_3d_mask' in results:
            keep_3d = results['gt_bboxes_3d_mask']
            # 'gt_bboxes_3d' is indexed unconditionally, so it has to be
            # present whenever the 3D mask is.
            results['gt_bboxes_3d'] = results['gt_bboxes_3d'][keep_3d]
            for field in ('gt_names_3d', 'centers2d', 'depths'):
                if field in results:
                    results[field] = results[field][keep_3d]

        if 'gt_bboxes_mask' in results:
            keep_2d = results['gt_bboxes_mask']
            if 'gt_bboxes' in results:
                results['gt_bboxes'] = results['gt_bboxes'][keep_2d]
            # NOTE(review): 'gt_names' is indexed without a presence check,
            # so it must accompany 'gt_bboxes_mask'.
            results['gt_names'] = results['gt_names'][keep_2d]

    def _encode_labels(self, results):
        """Add 'gt_labels' / 'gt_labels_3d' derived from the name fields."""
        if 'gt_names' in results:
            names = results['gt_names']
            if len(names) == 0:
                # No annotations: emit empty label arrays (two distinct
                # objects, one per key).
                results['gt_labels'] = np.array([], dtype=np.int64)
                results['attr_labels'] = np.array([], dtype=np.int64)
            elif isinstance(names[0], list):
                # List of name lists: one label array per inner list.
                results['gt_labels'] = [
                    self._labels_from_names(group) for group in names
                ]
            else:
                results['gt_labels'] = self._labels_from_names(names)

        if 'gt_names_3d' in results:
            results['gt_labels_3d'] = self._labels_from_names(
                results['gt_names_3d'])

    def __call__(self, results):
        """Call function to transform and format common fields in results.

        Args:
            results (dict): Result dict contains the data to convert.
                It is modified in place before being passed to the parent
                bundle.

        Returns:
            dict: The result dict returned by the parent bundle's
                ``__call__`` after the conversions described on the class.
        """
        if 'points' in results:
            cloud = results['points']
            assert isinstance(cloud, BasePoints)
            results['points'] = DC(cloud.tensor)

        for key in self.collect_keys:
            # Timestamps are kept in float64; everything else is float32.
            is_time_key = key in ('timestamp', 'img_timestamp')
            target_dtype = np.float64 if is_time_key else np.float32
            as_array = np.array(results[key], dtype=target_dtype)
            results[key] = DC(to_tensor(as_array), stack=True, pad_dims=None)

        if 'lane_pts' in results:
            results['lane_pts'] = self._wrap_lane_pts(results['lane_pts'])

        for key in ('voxels', 'coors', 'voxel_centers', 'num_points'):
            if key in results:
                results[key] = DC(to_tensor(results[key]), stack=False)

        for key in ('input_ids', 'vlm_labels'):
            if key in results:
                # Wrapped as-is, without tensor conversion.
                results[key] = DC(results[key], stack=False)

        if self.with_gt:
            # Lane points stored under 'ann_info' replace any top-level
            # 'lane_pts' entry that was wrapped above.
            if 'ann_info' in results and 'lane_pts' in results['ann_info']:
                results['lane_pts'] = self._wrap_lane_pts(
                    results['ann_info']['lane_pts'])

            self._apply_gt_masks(results)

            if self.with_label:
                self._encode_labels(results)

        return super().__call__(results)

    def __repr__(self):
        """str: Return a string that describes the module."""
        return (
            f'{self.__class__.__name__}(class_names={self.class_names}, '
            f'collect_keys={self.collect_keys}, with_gt={self.with_gt}, '
            f'with_label={self.with_label})'
        )