import copy import torch import torch.nn as nn import numpy as np from torch.nn.init import normal_ from det_map.det.dal.mmdet3d.models.builder import build_fuser import torch.nn.functional as F from mmdet.models.utils.builder import TRANSFORMER from det_map.det.dal.mmdet3d.models.builder import FUSERS from mmcv.cnn import Linear, bias_init_with_prob, xavier_init, constant_init from mmcv.runner.base_module import BaseModule, ModuleList, Sequential from mmcv.cnn.bricks.transformer import build_transformer_layer_sequence, build_positional_encoding from torchvision.transforms.functional import rotate from det_map.det.dal.mmdet3d.models.bevformer_modules.temporal_self_attention import TemporalSelfAttention from det_map.det.dal.mmdet3d.models.bevformer_modules.spatial_cross_attention import MSDeformableAttention3D from det_map.det.dal.mmdet3d.models.bevformer_modules.decoder import CustomMSDeformableAttention from typing import List @FUSERS.register_module() class ConvFuser(nn.Sequential): def __init__(self, in_channels: int, out_channels: int) -> None: self.in_channels = in_channels self.out_channels = out_channels super().__init__( nn.Conv2d(sum(in_channels), out_channels, 3, padding=1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(True), ) def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor: return super().forward(torch.cat(inputs, dim=1)) @TRANSFORMER.register_module() class MapTRPerceptionTransformer(BaseModule): """Implements the Detr3D transformer. Args: as_two_stage (bool): Generate query from encoder features. Default: False. num_feature_levels (int): Number of feature maps from FPN: Default: 4. two_stage_num_proposals (int): Number of proposals when set `as_two_stage` as True. Default: 300. """ def __init__(self, bev_h, bev_w, num_feature_levels=1, num_cams=2, z_cfg=dict( pred_z_flag=False, gt_z_flag=False, ), two_stage_num_proposals=300, fuser=None, encoder=None, decoder=None, embed_dims=256, rotate_prev_bev=True, use_shift=True, use_can_bus=True, can_bus_norm=True, use_cams_embeds=True, rotate_center=[100, 100], modality='vision', feat_down_sample_indice=-1, **kwargs): super(MapTRPerceptionTransformer, self).__init__(**kwargs) if modality == 'fusion': self.fuser = build_fuser(fuser) # self.use_attn_bev = encoder['type'] == 'BEVFormerEncoder' self.use_attn_bev = True self.bev_h = bev_h self.bev_w = bev_w self.bev_embedding = nn.Embedding(self.bev_h * self.bev_w, embed_dims) self.positional_encoding = build_positional_encoding( dict( type='CustomLearnedPositionalEncoding', num_feats=embed_dims // 2, row_num_embed=self.bev_h, col_num_embed=self.bev_w, ) ) self.encoder = build_transformer_layer_sequence(encoder) self.decoder = build_transformer_layer_sequence(decoder) self.embed_dims = embed_dims self.num_feature_levels = num_feature_levels self.num_cams = num_cams self.fp16_enabled = False self.rotate_prev_bev = rotate_prev_bev self.use_shift = use_shift self.use_can_bus = use_can_bus self.can_bus_norm = can_bus_norm self.use_cams_embeds = use_cams_embeds self.two_stage_num_proposals = two_stage_num_proposals self.z_cfg=z_cfg self.init_layers() self.rotate_center = rotate_center self.feat_down_sample_indice = feat_down_sample_indice def init_layers(self): """Initialize layers of the Detr3DTransformer.""" # self.level_embeds = nn.Parameter(torch.Tensor( # self.num_feature_levels, self.embed_dims)) # self.cams_embeds = nn.Parameter( # torch.Tensor(self.num_cams, self.embed_dims)) self.reference_points = nn.Linear(self.embed_dims, 2) if not self.z_cfg['gt_z_flag'] \ else nn.Linear(self.embed_dims, 3) # self.can_bus_mlp = nn.Sequential( # nn.Linear(18, self.embed_dims // 2), # nn.ReLU(inplace=True), # nn.Linear(self.embed_dims // 2, self.embed_dims), # nn.ReLU(inplace=True), # ) # if self.can_bus_norm: # self.can_bus_mlp.add_module('norm', nn.LayerNorm(self.embed_dims)) def init_weights(self): """Initialize the transformer weights.""" for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) for m in self.modules(): if isinstance(m, MSDeformableAttention3D) or isinstance(m, TemporalSelfAttention) \ or isinstance(m, CustomMSDeformableAttention): try: m.init_weight() except AttributeError: m.init_weights() normal_(self.level_embeds) normal_(self.cams_embeds) xavier_init(self.reference_points, distribution='uniform', bias=0.) # xavier_init(self.can_bus_mlp, distribution='uniform', bias=0.) # TODO apply fp16 to this module cause grad_norm NAN # @auto_fp16(apply_to=('mlvl_feats', 'bev_queries', 'prev_bev', 'bev_pos'), out_fp32=True) def attn_bev_encode( self, mlvl_feats, cam_params=None, gt_bboxes_3d=None, pred_img_depth=None, prev_bev=None, bev_mask=None, **kwargs): bs = mlvl_feats[0].size(0) dtype = mlvl_feats[0].dtype feat_flatten = [] spatial_shapes = [] for lvl, feat in enumerate(mlvl_feats): bs, num_cam, c, h, w = feat.shape spatial_shape = (h, w) feat = feat.flatten(3).permute(1, 0, 3, 2) spatial_shapes.append(spatial_shape) feat_flatten.append(feat) feat_flatten = torch.cat(feat_flatten, 2) spatial_shapes = torch.as_tensor( spatial_shapes, dtype=torch.long, device=mlvl_feats[0].device) level_start_index = torch.cat((spatial_shapes.new_zeros( (1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) feat_flatten = feat_flatten.permute(0, 2, 1, 3) # (num_cam, H*W, bs, embed_dims) bev_queries = self.bev_embedding.weight.to(dtype) bev_queries = bev_queries.unsqueeze(1).repeat(1, bs, 1) bev_pos = self.positional_encoding(bs, self.bev_h, self.bev_w, bev_queries.device).to(dtype) bev_pos = bev_pos.flatten(2).permute(2, 0, 1) bev_embed = self.encoder( bev_queries, feat_flatten, feat_flatten, bev_h=self.bev_h, bev_w=self.bev_w, bev_pos=bev_pos, spatial_shapes=spatial_shapes, level_start_index=level_start_index, cam_params=cam_params, gt_bboxes_3d=gt_bboxes_3d, pred_img_depth=pred_img_depth, prev_bev=prev_bev, bev_mask=bev_mask, **kwargs ) return bev_embed def lss_bev_encode( self, mlvl_feats, prev_bev=None, **kwargs): # import ipdb;ipdb.set_trace() # assert len(mlvl_feats) == 1, 'Currently we only use last single level feat in LSS' # import ipdb;ipdb.set_trace() images = mlvl_feats[self.feat_down_sample_indice] img_metas = kwargs['img_metas'] encoder_outputdict = self.encoder(images,img_metas) bev_embed = encoder_outputdict['bev'] depth = encoder_outputdict['depth'] bs, c, _,_ = bev_embed.shape bev_embed = bev_embed.view(bs,c,-1).permute(0,2,1).contiguous() ret_dict = dict( bev=bev_embed, depth=depth ) return ret_dict def get_bev_features( self, mlvl_feats, lidar_feat, bev_queries, bev_h, bev_w, grid_length=[0.512, 0.512], bev_pos=None, prev_bev=None, **kwargs): """ obtain bev features. """ assert self.use_attn_bev if self.use_attn_bev: img_metas = kwargs['img_metas'] rot = img_metas['sensor2lidar_rotation'] B, T, N, _, _ = rot.shape cam_params = (img_metas['sensor2lidar_rotation'][:, -1], img_metas['sensor2lidar_translation'][:, -1], img_metas['intrinsics'][:, -1], img_metas['post_rot'][:, -1], img_metas['post_tran'][:, -1], torch.eye(3, device=rot.device, dtype=rot.dtype)[None].repeat(B, 1, 1) ) bev_embed = self.attn_bev_encode( mlvl_feats, cam_params=cam_params, **kwargs) else: ret_dict = self.lss_bev_encode( mlvl_feats, prev_bev=prev_bev, **kwargs) bev_embed = ret_dict['bev'] depth = ret_dict['depth'] if lidar_feat is not None: bs = mlvl_feats[0].size(0) bev_embed = bev_embed.view(bs, bev_h, bev_w, -1).permute(0,3,1,2).contiguous() lidar_feat = lidar_feat.permute(0,1,3,2).contiguous() # B C H W # lidar_feat = nn.functional.interpolate(lidar_feat, size=(bev_h,bev_w), mode='bicubic', align_corners=False) fused_bev = self.fuser([bev_embed, lidar_feat]) fused_bev = fused_bev.flatten(2).permute(0,2,1).contiguous() bev_embed = fused_bev ret_dict = dict( bev=bev_embed, depth=None ) return ret_dict def format_feats(self, mlvl_feats): bs = mlvl_feats[0].size(0) feat_flatten = [] spatial_shapes = [] for lvl, feat in enumerate(mlvl_feats): # import pdb; pdb.set_trace() bs, num_cam, c, h, w = feat.shape spatial_shape = (h, w) feat = feat.flatten(3).permute(1, 0, 3, 2) if self.use_cams_embeds: feat = feat feat = feat spatial_shapes.append(spatial_shape) feat_flatten.append(feat) feat_flatten = torch.cat(feat_flatten, 2) spatial_shapes = torch.as_tensor( spatial_shapes, dtype=torch.long, device=feat.device) level_start_index = torch.cat((spatial_shapes.new_zeros( (1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) feat_flatten = feat_flatten.permute( 0, 2, 1, 3) # (num_cam, H*W, bs, embed_dims) return feat_flatten, spatial_shapes, level_start_index # TODO apply fp16 to this module cause grad_norm NAN # @auto_fp16(apply_to=('mlvl_feats', 'bev_queries', 'object_query_embed', 'prev_bev', 'bev_pos')) def forward(self, mlvl_feats, lidar_feat, bev_queries, object_query_embed, bev_h, bev_w, grid_length=[0.512, 0.512], bev_pos=None, reg_branches=None, cls_branches=None, prev_bev=None, **kwargs): """Forward function for `Detr3DTransformer`. Args: mlvl_feats (list(Tensor)): Input queries from different level. Each element has shape [bs, num_cams, embed_dims, h, w]. bev_queries (Tensor): (bev_h*bev_w, c) bev_pos (Tensor): (bs, embed_dims, bev_h, bev_w) object_query_embed (Tensor): The query embedding for decoder, with shape [num_query, c]. reg_branches (obj:`nn.ModuleList`): Regression heads for feature maps from each decoder layer. Only would be passed when `with_box_refine` is True. Default to None. Returns: tuple[Tensor]: results of decoder containing the following tensor. - bev_embed: BEV features - inter_states: Outputs from decoder. If return_intermediate_dec is True output has shape \ (num_dec_layers, bs, num_query, embed_dims), else has \ shape (1, bs, num_query, embed_dims). - init_reference_out: The initial value of reference \ points, has shape (bs, num_queries, 4). - inter_references_out: The internal value of reference \ points in decoder, has shape \ (num_dec_layers, bs,num_query, embed_dims) - enc_outputs_class: The classification score of \ proposals generated from \ encoder's feature maps, has shape \ (batch, h*w, num_classes). \ Only would be returned when `as_two_stage` is True, \ otherwise None. - enc_outputs_coord_unact: The regression results \ generated from encoder's feature maps., has shape \ (batch, h*w, 4). Only would \ be returned when `as_two_stage` is True, \ otherwise None. """ ouput_dic = self.get_bev_features( mlvl_feats, lidar_feat, bev_queries, bev_h, bev_w, grid_length=grid_length, bev_pos=bev_pos, prev_bev=prev_bev, **kwargs) # bev_embed shape: bs, bev_h*bev_w, embed_dims bev_embed = ouput_dic['bev'] depth = ouput_dic['depth'] bs = mlvl_feats[0].size(0) query_pos, query = torch.split( object_query_embed, self.embed_dims, dim=1) query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1) query = query.unsqueeze(0).expand(bs, -1, -1) reference_points = self.reference_points(query_pos) reference_points = reference_points.sigmoid() init_reference_out = reference_points query = query.permute(1, 0, 2) query_pos = query_pos.permute(1, 0, 2) bev_embed = bev_embed.permute(1, 0, 2) feat_flatten, feat_spatial_shapes, feat_level_start_index \ = self.format_feats(mlvl_feats) inter_states, inter_references = self.decoder( query=query, key=None, value=bev_embed, query_pos=query_pos, reference_points=reference_points, reg_branches=reg_branches, cls_branches=cls_branches, spatial_shapes=torch.tensor([[bev_h, bev_w]], device=query.device), level_start_index=torch.tensor([0], device=query.device), mlvl_feats=mlvl_feats, feat_flatten=None, feat_spatial_shapes=feat_spatial_shapes, feat_level_start_index=feat_level_start_index, **kwargs) inter_references_out = inter_references return bev_embed, depth, inter_states, init_reference_out, inter_references_out