# navsim_ours/det_map/map/map_model.py
from __future__ import annotations

from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F

from det_map.det.dal.mmdet3d.models import builder
from det_map.det.dal.mmdet3d.models.utils.grid_mask import GridMask
from det_map.det.dal.mmdet3d.ops import Voxelization, DynamicScatter


class MapModel(nn.Module):
    """Map perception model with optional camera, LiDAR, and fusion branches,
    assembled from mmdet3d-style config dicts."""

def __init__(
self,
use_grid_mask=False,
pts_voxel_layer=None,
pts_voxel_encoder=None,
pts_middle_encoder=None,
pts_fusion_layer=None,
img_backbone=None,
pts_backbone=None,
img_neck=None,
pts_neck=None,
pts_bbox_head=None,
img_roi_head=None,
img_rpn_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None,
video_test_mode=False,
modality='vision',
lidar_encoder=None,
lr=None,
):
super().__init__()
# self.pipelines = pipelines
self.grid_mask = GridMask(
True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7)
if pts_voxel_layer:
self.pts_voxel_layer = Voxelization(**pts_voxel_layer)
if pts_voxel_encoder:
self.pts_voxel_encoder = builder.build_voxel_encoder(
pts_voxel_encoder)
if pts_middle_encoder:
self.pts_middle_encoder = builder.build_middle_encoder(
pts_middle_encoder)
if pts_backbone:
self.pts_backbone = builder.build_backbone(pts_backbone)
if pts_fusion_layer:
self.pts_fusion_layer = builder.build_fusion_layer(
pts_fusion_layer)
if pts_neck is not None:
self.pts_neck = builder.build_neck(pts_neck)
        if pts_bbox_head:
            # the head is built without train/test cfgs in this standalone module
            pts_bbox_head.update(train_cfg=None, test_cfg=None)
            self.pts_bbox_head = builder.build_head(pts_bbox_head)
if img_backbone:
self.img_backbone = builder.build_backbone(img_backbone)
if img_neck is not None:
self.img_neck = builder.build_neck(img_neck)
if img_rpn_head is not None:
self.img_rpn_head = builder.build_head(img_rpn_head)
if img_roi_head is not None:
self.img_roi_head = builder.build_head(img_roi_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
        # `pretrained` is parsed for config compatibility but not consumed
        # further in this module
        if pretrained is None:
            img_pretrained = None
            pts_pretrained = None
        elif isinstance(pretrained, dict):
            img_pretrained = pretrained.get('img', None)
            pts_pretrained = pretrained.get('pts', None)
        else:
            raise ValueError(
                f'pretrained should be a dict, got {type(pretrained)}')
self.use_grid_mask = use_grid_mask
self.fp16_enabled = False
# temporal
self.video_test_mode = video_test_mode
self.prev_frame_info = {
'prev_bev': None,
'scene_token': None,
'prev_pos': 0,
'prev_angle': 0,
}
self.modality = modality
if self.modality == 'fusion' and lidar_encoder is not None:
if lidar_encoder["voxelize"].get("max_num_points", -1) > 0:
voxelize_module = Voxelization(**lidar_encoder["voxelize"])
else:
voxelize_module = DynamicScatter(**lidar_encoder["voxelize"])
self.lidar_modal_extractor = nn.ModuleDict(
{
"voxelize": voxelize_module,
"backbone": builder.build_middle_encoder(lidar_encoder["backbone"]),
}
)
self.voxelize_reduce = lidar_encoder.get("voxelize_reduce", True)
self._lr = lr
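
    # Hedged example (illustrative only, not a config shipped with this repo):
    # the fusion branch expects `lidar_encoder` roughly of the form
    #     lidar_encoder = dict(
    #         voxelize=dict(max_num_points=10, ...),  # > 0 -> hard Voxelization,
    #                                                 # else DynamicScatter
    #         backbone=dict(...),                     # built via build_middle_encoder
    #         voxelize_reduce=True,
    #     )
    # with the `voxelize` kwargs following mmdet3d's Voxelization / DynamicScatter.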
    def extract_img_feat(self, img, img_metas=None, len_queue=None):
        """Extract per-camera image features with the backbone and neck."""
        if img is None:
            return None
        B = img.size(0)
        if img.dim() == 5 and img.size(0) == 1:
            img.squeeze_()
        elif img.dim() == 5 and img.size(0) > 1:
            # flatten the camera axis so the 2D backbone sees (B * N) images
            B, N, C, H, W = img.size()
            img = img.reshape(B * N, C, H, W)
        if self.use_grid_mask:
            img = self.grid_mask(img)
        img_feats = self.img_backbone(img)
        if isinstance(img_feats, dict):
            img_feats = list(img_feats.values())
        if getattr(self, 'img_neck', None) is not None:
            img_feats = self.img_neck(img_feats)
        BN, C, H, W = img_feats[0].shape
        # fold the flattened camera axis back into a per-sample view
        return [tmp.view(B, BN // B, C, H, W) for tmp in img_feats]
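
    # Shape walk-through (illustrative numbers, assumed rather than taken from
    # any config): img (B=2, N=6, 3, H, W) is flattened to (12, 3, H, W) for
    # the backbone; each neck level comes back as (12, C, H', W') and is
    # returned re-folded as (2, 6, C, H', W').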
    @torch.no_grad()
    def voxelize(self, points):
        """Voxelize per-sample point clouds; coords gain a leading batch index."""
        feats, coords, sizes = [], [], []
for k, res in enumerate(points):
ret = self.lidar_modal_extractor["voxelize"](res)
if len(ret) == 3:
# hard voxelize
f, c, n = ret
else:
assert len(ret) == 2
f, c = ret
n = None
feats.append(f)
coords.append(F.pad(c, (1, 0), mode="constant", value=k))
if n is not None:
sizes.append(n)
feats = torch.cat(feats, dim=0)
coords = torch.cat(coords, dim=0)
if len(sizes) > 0:
sizes = torch.cat(sizes, dim=0)
if self.voxelize_reduce:
feats = feats.sum(dim=1, keepdim=False) / sizes.type_as(feats).view(
-1, 1
)
feats = feats.contiguous()
return feats, coords, sizes
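
    # Note on the two voxelizer return conventions handled above: hard
    # voxelization (Voxelization) yields (features, coords, num_points_per_voxel),
    # while DynamicScatter yields only (features, coords); the F.pad call
    # prepends the sample index k as an extra leading coordinate column.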
    def extract_lidar_feat(self, points):
        feats, coords, sizes = self.voxelize(points)
        # coords[:, 0] holds the batch index, so the last entry gives batch size
        batch_size = coords[-1, 0] + 1
        lidar_feat = self.lidar_modal_extractor["backbone"](feats, coords, batch_size)
        return lidar_feat
    def forward(self, feature_dict=None, points=None, img_metas=None) -> Dict[str, torch.Tensor]:
        lidar_feat = None
        if self.modality == 'fusion':
            # each sample may carry several point-cloud sweeps; concatenate
            # them per sample before voxelization
            points_input = [torch.cat(tmp, 0) for tmp in points]
            lidar_feat = self.extract_lidar_feat(points_input)
        img = feature_dict['image']
        len_queue = img.size(1)
        img = img[:, -1, ...]  # keep only the most recent frame in the queue
        img_feats = self.extract_img_feat(img, img_metas, len_queue=len_queue)
        outs = self.pts_bbox_head(
            img_feats, lidar_feat, feature_dict, None)
        return outs
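
# --- Hedged demo (not part of the original model; illustrative only) ---------
# A minimal, self-contained sketch of the multi-camera reshape round-trip that
# `extract_img_feat` relies on. All shapes are toy assumptions.
def _demo_multicam_reshape():
    B, N, C, H, W = 2, 6, 3, 32, 32                 # assumed toy shapes
    img = torch.randn(B, N, C, H, W)
    flat = img.reshape(B * N, C, H, W)              # what the backbone consumes
    C_out = 8                                       # pretend backbone channels
    feat = torch.randn(B * N, C_out, H, W)          # mock backbone/neck output
    per_cam = feat.view(B, flat.size(0) // B, C_out, H, W)
    assert per_cam.shape == (B, N, C_out, H, W)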
# class MyLightningModule(pl.LightningModule):
# def __init__(
# self,
# agent: AbstractAgent,
# ):
# super().__init__()
# self.agent = agent
# def _step(
# self,
# batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]],
# logging_prefix: str,
# ):
# features, targets = batch
# prediction = self.agent.forward(features)
# loss = self.agent.compute_loss(features, targets, prediction)
# self.log(f"{logging_prefix}/loss", loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
# return loss
# def training_step(
# self,
# batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]],
# batch_idx: int
# ):
# return self._step(batch, "train")
# def validation_step(
# self,
# batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]],
# batch_idx: int
# ):
# return self._step(batch, "val")
#     def configure_optimizers(self):
#         optimizer = self.agent.get_optimizers()
#         # Apply gradient clipping.
#         if 'grad_clip' in self.optimizer_config:
#             grad_clip = self.optimizer_config['grad_clip']
#             max_norm = grad_clip.get('max_norm', 1.0)
#             norm_type = grad_clip.get('norm_type', 2)
#             optimizer = optim.Adam(self.parameters(), lr=1e-3)
#             return {
#                 'optimizer': optimizer,
#                 'clip_grad_norm': max_norm,
#                 # 'clip_grad_value' can be used instead to bound the
#                 # absolute value of gradients
#                 'clip_grad_value': None,
#             }
#         else:
#             return optimizer
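
# A minimal sketch (an assumption, not code from this repo) of how the gradient
# clipping attempted above is normally expressed in PyTorch Lightning: clipping
# is configured on the Trainer rather than returned from configure_optimizers:
#
#     import pytorch_lightning as pl
#     trainer = pl.Trainer(gradient_clip_val=1.0, gradient_clip_algorithm="norm")
#     trainer.fit(MyLightningModule(agent), train_loader, val_loader)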