|
from __future__ import annotations
|
|
|
|
from typing import Any, List, Dict
|
|
|
|
import torch
|
|
import torch.optim as optim
|
|
import copy
|
|
from torch.optim import Optimizer
|
|
from torch.optim.lr_scheduler import LRScheduler
|
|
import torch.nn as nn
|
|
from det_map.data.datasets.dataclasses import SensorConfig, Scene
|
|
from det_map.data.datasets.feature_builders import LiDARCameraFeatureBuilder
|
|
from navsim.agents.abstract_agent import AbstractAgent
|
|
from navsim.planning.training.abstract_feature_target_builder import AbstractFeatureBuilder, AbstractTargetBuilder
|
|
|
|
from det_map.det.dal.mmdet3d.models.utils.grid_mask import GridMask
|
|
|
|
import torch.nn.functional as F
|
|
|
|
from det_map.det.dal.mmdet3d.ops import Voxelization, DynamicScatter
|
|
from det_map.det.dal.mmdet3d.models import builder
|
|
from mmcv.utils import TORCH_VERSION, digit_version
|
|
|
|
|
|
class MapModel(nn.Module):
    """Image (optionally image + LiDAR) model built from mmdet3d-style configs.

    Sub-modules are constructed through ``builder`` from the config objects
    passed to ``__init__``; a config left as ``None`` (or otherwise falsy)
    leaves the corresponding attribute unset. ``forward`` extracts image
    features (plus voxelized LiDAR features when ``modality == 'fusion'``) and
    passes them to ``pts_bbox_head``.
    """

    def __init__(
        self,
        use_grid_mask=False,
        pts_voxel_layer=None,
        pts_voxel_encoder=None,
        pts_middle_encoder=None,
        pts_fusion_layer=None,
        img_backbone=None,
        pts_backbone=None,
        img_neck=None,
        pts_neck=None,
        pts_bbox_head=None,
        img_roi_head=None,
        img_rpn_head=None,
        train_cfg=None,
        test_cfg=None,
        pretrained=None,
        video_test_mode=False,
        modality='vision',
        lidar_encoder=None,
        lr=None,
    ):
        super().__init__()

        # Always constructed; only applied in extract_img_feat when
        # ``use_grid_mask`` is True.
        self.grid_mask = GridMask(
            True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7)
        if pts_voxel_layer:
            self.pts_voxel_layer = Voxelization(**pts_voxel_layer)
        if pts_voxel_encoder:
            self.pts_voxel_encoder = builder.build_voxel_encoder(
                pts_voxel_encoder)
        if pts_middle_encoder:
            self.pts_middle_encoder = builder.build_middle_encoder(
                pts_middle_encoder)
        if pts_backbone:
            self.pts_backbone = builder.build_backbone(pts_backbone)
        if pts_fusion_layer:
            self.pts_fusion_layer = builder.build_fusion_layer(
                pts_fusion_layer)
        if pts_neck is not None:
            self.pts_neck = builder.build_neck(pts_neck)
        if pts_bbox_head:
            # The head is always built with train/test cfgs set to None,
            # independent of the ``train_cfg``/``test_cfg`` arguments.
            # NOTE(review): ``update`` mutates the caller's config in place.
            pts_bbox_head.update(train_cfg=None)
            pts_bbox_head.update(test_cfg=None)
            self.pts_bbox_head = builder.build_head(pts_bbox_head)

        if img_backbone:
            self.img_backbone = builder.build_backbone(img_backbone)
        if img_neck is not None:
            self.img_neck = builder.build_neck(img_neck)
        if img_rpn_head is not None:
            self.img_rpn_head = builder.build_head(img_rpn_head)
        if img_roi_head is not None:
            self.img_roi_head = builder.build_head(img_roi_head)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        # ``pretrained`` is only validated here; no weights are loaded.
        if pretrained is not None and not isinstance(pretrained, dict):
            raise ValueError(
                f'pretrained should be a dict, got {type(pretrained)}')

        self.use_grid_mask = use_grid_mask
        self.fp16_enabled = False

        self.video_test_mode = video_test_mode
        self.prev_frame_info = {
            'prev_bev': None,
            'scene_token': None,
            'prev_pos': 0,
            'prev_angle': 0,
        }

        self.modality = modality
        if self.modality == 'fusion' and lidar_encoder is not None:
            # A positive ``max_num_points`` selects hard voxelization;
            # otherwise points are scattered dynamically.
            if lidar_encoder["voxelize"].get("max_num_points", -1) > 0:
                voxelize_module = Voxelization(**lidar_encoder["voxelize"])
            else:
                voxelize_module = DynamicScatter(**lidar_encoder["voxelize"])
            self.lidar_modal_extractor = nn.ModuleDict(
                {
                    "voxelize": voxelize_module,
                    "backbone": builder.build_middle_encoder(lidar_encoder["backbone"]),
                }
            )
            self.voxelize_reduce = lidar_encoder.get("voxelize_reduce", True)

        self._lr = lr

    @property
    def with_img_neck(self):
        """True when an image neck was built (``img_neck`` config given)."""
        return getattr(self, 'img_neck', None) is not None

    def extract_img_feat(self, img, img_metas=None, len_queue=None):
        """Extract multi-level image features.

        Args:
            img: ``None``, a 5-D tensor ``(B, N, C, H, W)`` (N images per
                sample, presumably cameras), or a 4-D tensor treated as
                ``(B, C, H, W)``.
            img_metas: unused; kept for interface compatibility.
            len_queue: unused; kept for interface compatibility.

        Returns:
            ``None`` when ``img`` is ``None``; otherwise a list with one tensor
            per feature level, each shaped ``(B, N, C_l, H_l, W_l)`` (N == 1
            for 4-D input).
        """
        if img is None:
            return None

        B = img.size(0)
        if img.dim() == 5:
            # Fold the image axis into the batch axis for the 2-D backbone.
            # ``reshape`` keeps a single-image batch 4-D (a blanket squeeze
            # would also drop N == 1) and leaves the caller's tensor intact.
            B, N, C, H, W = img.size()
            img = img.reshape(B * N, C, H, W)
        if self.use_grid_mask:
            img = self.grid_mask(img)

        img_feats = self.img_backbone(img)
        if isinstance(img_feats, dict):
            img_feats = list(img_feats.values())
        if self.with_img_neck:
            img_feats = self.img_neck(img_feats)

        # Every level has its own channel count / spatial size, so unflatten
        # each one using its own shape.
        return [
            feat.view(B, feat.size(0) // B, *feat.shape[1:])
            for feat in img_feats
        ]

    @torch.no_grad()
    def voxelize(self, points):
        """Voxelize each sample's point cloud and concatenate across the batch.

        Args:
            points: iterable of per-sample point tensors, each accepted by
                ``self.lidar_modal_extractor["voxelize"]``.

        Returns:
            ``(feats, coords, sizes)``. ``coords`` has the sample index
            prepended as column 0. ``sizes`` is the concatenated per-voxel
            counts when the voxelizer returns them, else an empty list. When
            counts exist and ``voxelize_reduce`` is set, ``feats`` is averaged
            over dim 1 using those counts.
        """
        feats, coords, sizes = [], [], []
        for k, res in enumerate(points):
            ret = self.lidar_modal_extractor["voxelize"](res)
            if len(ret) == 3:
                # Hard voxelization: (features, coords, num_points).
                f, c, n = ret
            else:
                assert len(ret) == 2
                f, c = ret
                n = None
            feats.append(f)
            # Prepend the sample index so coords stay unique across the batch.
            coords.append(F.pad(c, (1, 0), mode="constant", value=k))
            if n is not None:
                sizes.append(n)

        feats = torch.cat(feats, dim=0)
        coords = torch.cat(coords, dim=0)
        if len(sizes) > 0:
            sizes = torch.cat(sizes, dim=0)
            if self.voxelize_reduce:
                feats = feats.sum(dim=1, keepdim=False) / sizes.type_as(feats).view(
                    -1, 1
                )
                feats = feats.contiguous()

        return feats, coords, sizes

    def extract_lidar_feat(self, points):
        """Voxelize ``points`` and encode them with the LiDAR backbone.

        The batch size is the number of samples in ``points``. Deriving it
        from the last voxel's sample index would undercount when trailing
        samples produce no voxels, and would fail when none do.
        """
        feats, coords, sizes = self.voxelize(points)
        batch_size = len(points)
        lidar_feat = self.lidar_modal_extractor["backbone"](feats, coords, batch_size)
        return lidar_feat

    def forward(self, feature_dict=None, points=None, img_metas=None) -> Dict[str, torch.Tensor]:
        """Run the image (and optional LiDAR) branches, then the head.

        Args:
            feature_dict: mapping whose ``'image'`` entry is indexed as
                ``img[:, -1, ...]``, i.e. at least 2-D with the sequence on
                dim 1 — presumably ``(B, T, N, C, H, W)``; TODO confirm.
                It is also forwarded unchanged to the head.
            points: per-sample point clouds; required when
                ``modality == 'fusion'``.
            img_metas: forwarded to ``extract_img_feat`` (unused there).

        Returns:
            Whatever ``self.pts_bbox_head`` returns.

        Raises:
            ValueError: if ``modality == 'fusion'`` and ``points`` is None.
        """
        lidar_feat = None
        if self.modality == 'fusion':
            if points is None:
                raise ValueError("points must be provided when modality == 'fusion'")
            lidar_feat = self.extract_lidar_feat(points)

        img = feature_dict['image']
        len_queue = img.size(1)
        # Only the last entry along dim 1 (presumably the current frame) is used.
        img = img[:, -1, ...]
        img_feats = self.extract_img_feat(img, img_metas, len_queue=len_queue)

        outs = self.pts_bbox_head(
            img_feats, lidar_feat, feature_dict, None)
        return outs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|