# history blame
# 5.95 kB
# ------------------------------------------------------------------------
# Modified from mmdetection3d (https://github.com/open-mmlab/mmdetection3d)
# Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------
# Modified by Shihao Wang
# ------------------------------------------------------------------------
import numpy as np
from mmdet.datasets.builder import PIPELINES
from mmcv.parallel import DataContainer as DC
from mmdet3d.core.points import BasePoints
from mmdet.datasets.pipelines import to_tensor
from mmdet3d.datasets.pipelines import DefaultFormatBundle
# NOTE(review): ``PIPELINES`` is imported by this file, but no
# ``@PIPELINES.register_module()`` decorator is visible on this class. Confirm
# whether the class is registered elsewhere before using it from a config.
class PETRFormatBundle3D(DefaultFormatBundle):
    """Formatting bundle for PETR-style 3D pipelines.

    It wraps the 3D-specific entries of ``results`` into
    :class:`DataContainer` objects, filters the ground truth with the masks
    stored in ``results`` and converts class names into integer labels.
    Afterwards the parent ``DefaultFormatBundle`` is invoked on the result.

    According to the original documentation, the common fields are
    formatted as follows (presumably by the parent bundle):

    - img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True)
    - proposals: (1)to tensor, (2)to DataContainer
    - gt_bboxes: (1)to tensor, (2)to DataContainer
    - gt_bboxes_ignore: (1)to tensor, (2)to DataContainer
    - gt_labels: (1)to tensor, (2)to DataContainer

    Args:
        class_names (list[str]): Class names; the position of a name in
            this list is used as its integer label.
        collect_keys (list[str]): Keys of ``results`` that are converted to
            stacked tensors. ``'timestamp'`` and ``'img_timestamp'`` are
            converted as float64, every other key as float32.
        with_gt (bool): Whether to filter the ground truth with the stored
            masks (and to pick up ``lane_pts`` from ``ann_info``).
            Defaults to True.
        with_label (bool): Whether to derive ``gt_labels`` /
            ``gt_labels_3d`` from the name fields. Only takes effect when
            ``with_gt`` is True. Defaults to True.
    """

    # Keys of ``collect_keys`` that are kept in float64 rather than float32
    # (presumably so that large timestamp values do not lose precision).
    _FLOAT64_KEYS = ('timestamp', 'img_timestamp')

    def __init__(self, class_names, collect_keys, with_gt=True, with_label=True):
        super(PETRFormatBundle3D, self).__init__()
        self.class_names = class_names
        self.with_gt = with_gt
        self.with_label = with_label
        self.collect_keys = collect_keys

    def __call__(self, results):
        """Call function to transform and format common fields in results.

        Args:
            results (dict): Result dict contains the data to convert.

        Returns:
            dict: The result dict contains the data that is formatted with
                default bundle.
        """
        # Format 3D data
        if 'points' in results:
            assert isinstance(results['points'], BasePoints)
            results['points'] = DC(results['points'].tensor)

        for key in self.collect_keys:
            # Exactly one conversion per key: float64 for the timestamp
            # keys, float32 for everything else.
            dtype = np.float64 if key in self._FLOAT64_KEYS else np.float32
            results[key] = DC(
                to_tensor(np.array(results[key], dtype=dtype)),
                stack=True,
                pad_dims=None)

        if 'lane_pts' in results:
            results['lane_pts'] = self._wrap_lane_pts(results['lane_pts'])

        for key in ['voxels', 'coors', 'voxel_centers', 'num_points']:
            # Optional fields: skip the ones that are absent.
            if key not in results:
                continue
            results[key] = DC(to_tensor(results[key]), stack=False)

        for key in ['input_ids', 'vlm_labels']:
            # Optional fields, wrapped as-is (no tensor conversion).
            if key not in results:
                continue
            results[key] = DC(results[key], stack=False)

        if self.with_gt:
            # Lane points stored in the annotation take precedence over a
            # top-level ``lane_pts`` entry wrapped above.
            if 'ann_info' in results and 'lane_pts' in results['ann_info']:
                results['lane_pts'] = self._wrap_lane_pts(
                    results['ann_info']['lane_pts'])

            # Clean GT bboxes in the final
            if 'gt_bboxes_3d_mask' in results:
                gt_bboxes_3d_mask = results['gt_bboxes_3d_mask']
                results['gt_bboxes_3d'] = results['gt_bboxes_3d'][
                    gt_bboxes_3d_mask]
                for key in ('gt_names_3d', 'centers2d', 'depths'):
                    if key in results:
                        results[key] = results[key][gt_bboxes_3d_mask]

            if 'gt_bboxes_mask' in results:
                gt_bboxes_mask = results['gt_bboxes_mask']
                if 'gt_bboxes' in results:
                    results['gt_bboxes'] = results['gt_bboxes'][gt_bboxes_mask]
                results['gt_names'] = results['gt_names'][gt_bboxes_mask]

            if self.with_label:
                self._assign_labels(results)

        results = super(PETRFormatBundle3D, self).__call__(results)
        return results

    @staticmethod
    def _wrap_lane_pts(lane_pts):
        """Convert lane points to a float32 tensor inside a DataContainer."""
        return DC(
            to_tensor(np.array(lane_pts, dtype=np.float32)), cpu_only=False)

    def _names_to_labels(self, names):
        """Map class names to an int64 array of indices into class_names.

        Raises:
            ValueError: If a name is not contained in ``self.class_names``.
        """
        return np.array(
            [self.class_names.index(name) for name in names], dtype=np.int64)

    def _assign_labels(self, results):
        """Add ``gt_labels`` / ``gt_labels_3d`` derived from the name fields."""
        if 'gt_names' in results:
            gt_names = results['gt_names']
            # ``len(...) == 0`` instead of truthiness: ``gt_names`` may be a
            # numpy array, whose truth value is ambiguous.
            if len(gt_names) == 0:
                results['gt_labels'] = np.array([], dtype=np.int64)
                results['attr_labels'] = np.array([], dtype=np.int64)
            elif isinstance(gt_names[0], list):
                # gt_labels might be a list of list in multi-view setting
                results['gt_labels'] = [
                    self._names_to_labels(view_names)
                    for view_names in gt_names
                ]
            else:
                results['gt_labels'] = self._names_to_labels(gt_names)

        # we still assume one pipeline for one frame LiDAR
        # thus, the 3D name is list[string]
        if 'gt_names_3d' in results:
            results['gt_labels_3d'] = self._names_to_labels(
                results['gt_names_3d'])

    def __repr__(self):
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f'(class_names={self.class_names}, '
        repr_str += f'collect_keys={self.collect_keys}, with_gt={self.with_gt}, with_label={self.with_label})'
        return repr_str