from __future__ import annotations

from typing import Any, List, Dict

import torch
import torch.optim as optim
import copy
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
import torch.nn as nn

from det_map.data.datasets.dataclasses import SensorConfig, Scene
from det_map.data.datasets.feature_builders import LiDARCameraFeatureBuilder
from navsim.agents.abstract_agent import AbstractAgent
from navsim.planning.training.abstract_feature_target_builder import AbstractFeatureBuilder, AbstractTargetBuilder
from det_map.det.dal.mmdet3d.models.utils.grid_mask import GridMask
import torch.nn.functional as F
from det_map.det.dal.mmdet3d.ops import Voxelization, DynamicScatter
from det_map.det.dal.mmdet3d.models import builder
from mmcv.utils import TORCH_VERSION, digit_version


class MapModel(nn.Module):
    def __init__(
        self,
        use_grid_mask=False,
        pts_voxel_layer=None,
        pts_voxel_encoder=None,
        pts_middle_encoder=None,
        pts_fusion_layer=None,
        img_backbone=None,
        pts_backbone=None,
        img_neck=None,
        pts_neck=None,
        pts_bbox_head=None,
        img_roi_head=None,
        img_rpn_head=None,
        train_cfg=None,
        test_cfg=None,
        pretrained=None,
        video_test_mode=False,
        modality='vision',
        lidar_encoder=None,
        lr=None,
    ):
        super().__init__()
        # self.pipelines = pipelines
        self.grid_mask = GridMask(
            True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7)
        if pts_voxel_layer:
            self.pts_voxel_layer = Voxelization(**pts_voxel_layer)
        if pts_voxel_encoder:
            self.pts_voxel_encoder = builder.build_voxel_encoder(
                pts_voxel_encoder)
        if pts_middle_encoder:
            self.pts_middle_encoder = builder.build_middle_encoder(
                pts_middle_encoder)
        if pts_backbone:
            self.pts_backbone = builder.build_backbone(pts_backbone)
        if pts_fusion_layer:
            self.pts_fusion_layer = builder.build_fusion_layer(
                pts_fusion_layer)
        if pts_neck is not None:
            self.pts_neck = builder.build_neck(pts_neck)
        if pts_bbox_head:
            pts_train_cfg = None
            pts_bbox_head.update(train_cfg=pts_train_cfg)
            pts_test_cfg = None
            pts_bbox_head.update(test_cfg=pts_test_cfg)
            self.pts_bbox_head = builder.build_head(pts_bbox_head)
        if img_backbone:
            self.img_backbone = builder.build_backbone(img_backbone)
        if img_neck is not None:
            self.img_neck = builder.build_neck(img_neck)
        if img_rpn_head is not None:
            self.img_rpn_head = builder.build_head(img_rpn_head)
        if img_roi_head is not None:
            self.img_roi_head = builder.build_head(img_roi_head)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        if pretrained is None:
            img_pretrained = None
            pts_pretrained = None
        elif isinstance(pretrained, dict):
            img_pretrained = pretrained.get('img', None)
            pts_pretrained = pretrained.get('pts', None)
        else:
            raise ValueError(
                f'pretrained should be a dict, got {type(pretrained)}')

        self.use_grid_mask = use_grid_mask
        self.fp16_enabled = False

        # temporal
        self.video_test_mode = video_test_mode
        self.prev_frame_info = {
            'prev_bev': None,
            'scene_token': None,
            'prev_pos': 0,
            'prev_angle': 0,
        }

        self.modality = modality
        if self.modality == 'fusion' and lidar_encoder is not None:
            if lidar_encoder["voxelize"].get("max_num_points", -1) > 0:
                voxelize_module = Voxelization(**lidar_encoder["voxelize"])
            else:
                voxelize_module = DynamicScatter(**lidar_encoder["voxelize"])
            self.lidar_modal_extractor = nn.ModuleDict(
                {
                    "voxelize": voxelize_module,
                    "backbone": builder.build_middle_encoder(lidar_encoder["backbone"]),
                }
            )
            self.voxelize_reduce = lidar_encoder.get("voxelize_reduce", True)

        self._lr = lr

    def extract_img_feat(self, img, img_metas=None, len_queue=None):
        """Extract features of images."""
        B = img.size(0)
        if img is not None:
            # input_shape = img.shape[-2:]
            # # update real input shape of each single img
            # for img_meta in img_metas:
            #     img_meta.update(input_shape=input_shape)

            if img.dim() == 5 and img.size(0) == 1:
                img.squeeze_()
            elif img.dim() == 5 and img.size(0) > 1:
                B, N, C, H, W = img.size()
                img = img.reshape(B * N, C, H, W)
            if self.use_grid_mask:
                img = self.grid_mask(img)

            img_feats = self.img_backbone(img)
            if isinstance(img_feats, dict):
                img_feats = list(img_feats.values())
        else:
            return None

        self.with_img_neck = True
        if self.with_img_neck:
            img_feats = self.img_neck(img_feats)

        # reshape per-camera features back to (batch, num_cams, C, H, W)
        BN, C, H, W = img_feats[0].shape
        return [tmp.view(B, BN // B, C, H, W) for tmp in img_feats]

    @torch.no_grad()
    def voxelize(self, points):
        feats, coords, sizes = [], [], []
        for k, res in enumerate(points):
            ret = self.lidar_modal_extractor["voxelize"](res)
            if len(ret) == 3:
                # hard voxelize
                f, c, n = ret
            else:
                assert len(ret) == 2
                f, c = ret
                n = None
            feats.append(f)
            coords.append(F.pad(c, (1, 0), mode="constant", value=k))
            if n is not None:
                sizes.append(n)

        feats = torch.cat(feats, dim=0)
        coords = torch.cat(coords, dim=0)
        if len(sizes) > 0:
            sizes = torch.cat(sizes, dim=0)
            if self.voxelize_reduce:
                feats = feats.sum(dim=1, keepdim=False) / sizes.type_as(feats).view(
                    -1, 1
                )
                feats = feats.contiguous()

        return feats, coords, sizes

    def extract_lidar_feat(self, points):
        feats, coords, sizes = self.voxelize(points)
        # voxel_features = self.lidar_modal_extractor["voxel_encoder"](feats, sizes, coords)
        batch_size = coords[-1, 0] + 1
        lidar_feat = self.lidar_modal_extractor["backbone"](feats, coords, batch_size)
        return lidar_feat

    def forward(self, feature_dict=None, points=None, img_metas=None) -> Dict[str, torch.Tensor]:
        lidar_feat = None
        if self.modality == 'fusion':
            # LiDAR sweeps may also be taken from feature_dict['lidars_warped'];
            # each sample's sweeps are concatenated before voxelization.
            points_input = [torch.cat(tmp, 0) for tmp in points]
            lidar_feat = self.extract_lidar_feat(points_input)

        img = feature_dict['image']
        len_queue = img.size(1)
        img = img[:, -1, ...]
        img_feats = self.extract_img_feat(img, img_metas, len_queue=len_queue)
        outs = self.pts_bbox_head(img_feats, lidar_feat, feature_dict, None)
        return outs


# class MyLightningModule(pl.LightningModule):
#     def __init__(
#         self,
#         agent: AbstractAgent,
#     ):
#         super().__init__()
#         self.agent = agent
#
#     def _step(
#         self,
#         batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]],
#         logging_prefix: str,
#     ):
#         features, targets = batch
#         prediction = self.agent.forward(features)
#         loss = self.agent.compute_loss(features, targets, prediction)
#         self.log(f"{logging_prefix}/loss", loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
#         return loss
#
#     def training_step(
#         self,
#         batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]],
#         batch_idx: int
#     ):
#         return self._step(batch, "train")
#
#     def validation_step(
#         self,
#         batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]],
#         batch_idx: int
#     ):
#         return self._step(batch, "val")
#
#     def configure_optimizers(self):
#         optimizer = self.agent.get_optimizers()
#         # apply gradient clipping
#         if 'grad_clip' in self.optimizer_config:
#             grad_clip = self.optimizer_config['grad_clip']
#             max_norm = grad_clip.get('max_norm', 1.0)
#             norm_type = grad_clip.get('norm_type', 2)
#             optimizer = optim.Adam(self.parameters(), lr=1e-3)
#             return {
#                 'optimizer': optimizer,
#                 'clip_grad_norm': max_norm,
#                 'clip_grad_value': None,  # 'clip_grad_value' can be used to clip gradients by absolute value
#             }
#         else:
#             return optimizer
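
# --- Usage sketch (illustrative only) ----------------------------------------
# A minimal, hedged example of how MapModel is expected to be driven. The
# config dicts below are hypothetical placeholders, not this project's real
# configs; feature_dict is assumed to come from LiDARCameraFeatureBuilder with
# an 'image' tensor shaped [batch, queue_len, num_cams, 3, H, W], which matches
# how forward() slices the last queue frame before extract_img_feat().
#
# if __name__ == "__main__":
#     model = MapModel(
#         use_grid_mask=True,
#         img_backbone=dict(...),    # hypothetical mmcv-style backbone config
#         img_neck=dict(...),        # hypothetical mmcv-style neck config
#         pts_bbox_head=dict(...),   # hypothetical map head config
#         modality='vision',
#     )
#     feature_dict = {'image': torch.rand(1, 2, 6, 3, 256, 704)}
#     outs = model(feature_dict=feature_dict)  # dict of head outputs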