import torch import numpy as np from mmcv.runner.base_module import BaseModule, ModuleList, Sequential import torch.nn as nn from mmcv.cnn.bricks.registry import (ATTENTION, TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE) from det_map.det.dal.mmdet3d.ops.bev_pool_v2.bev_pool import bev_pool_v2 from mmcv.runner import force_fp32, auto_fp16 from torch.cuda.amp.autocast_mode import autocast from mmcv.cnn import build_conv_layer from mmdet.models.backbones.resnet import BasicBlock, Bottleneck import torch.nn.functional as F from torch.utils.checkpoint import checkpoint from det_map.det.dal.mmdet3d.models.bevformer_modules.encoder import BEVFormerEncoder def gen_dx_bx(xbound, ybound, zbound): dx = torch.Tensor([row[2] for row in [xbound, ybound, zbound]]) bx = torch.Tensor([row[0] + row[2] / 2.0 for row in [xbound, ybound, zbound]]) nx = torch.Tensor( [int((row[1] - row[0]) / row[2]) for row in [xbound, ybound, zbound]] ) return dx, bx, nx @TRANSFORMER_LAYER_SEQUENCE.register_module() class BaseTransform(BaseModule): def __init__( self, in_channels, out_channels, feat_down_sample, pc_range, voxel_size, dbound, ): super(BaseTransform, self).__init__() self.in_channels = in_channels self.feat_down_sample = feat_down_sample # self.image_size = image_size # self.feature_size = feature_size self.xbound = [pc_range[0],pc_range[3], voxel_size[0]] self.ybound = [pc_range[1],pc_range[4], voxel_size[1]] self.zbound = [pc_range[2],pc_range[5], voxel_size[2]] self.dbound = dbound dx, bx, nx = gen_dx_bx(self.xbound, self.ybound, self.zbound) self.dx = nn.Parameter(dx, requires_grad=False) self.bx = nn.Parameter(bx, requires_grad=False) self.nx = nn.Parameter(nx, requires_grad=False) self.C = out_channels self.frustum = None self.D = int((dbound[1] - dbound[0]) / dbound[2]) # self.frustum = self.create_frustum() # self.D = self.frustum.shape[0] self.fp16_enabled = False @force_fp32() def create_frustum(self,fH,fW,img_metas): # iH, iW = self.image_size # fH, fW = self.feature_size iH = img_metas[0]['img_shape'][0][0] iW = img_metas[0]['img_shape'][0][1] assert iH // self.feat_down_sample == fH # import pdb;pdb.set_trace() ds = ( torch.arange(*self.dbound, dtype=torch.float) .view(-1, 1, 1) .expand(-1, fH, fW) ) D, _, _ = ds.shape xs = ( torch.linspace(0, iW - 1, fW, dtype=torch.float) .view(1, 1, fW) .expand(D, fH, fW) ) ys = ( torch.linspace(0, iH - 1, fH, dtype=torch.float) .view(1, fH, 1) .expand(D, fH, fW) ) frustum = torch.stack((xs, ys, ds), -1) # return nn.Parameter(frustum, requires_grad=False) return frustum @force_fp32() def get_geometry_v1( self, fH, fW, rots, trans, intrins, post_rots, post_trans, lidar2ego_rots, lidar2ego_trans, img_metas, **kwargs, ): B, N, _ = trans.shape device = trans.device if self.frustum == None: self.frustum = self.create_frustum(fH,fW,img_metas) self.frustum = self.frustum.to(device) # self.D = self.frustum.shape[0] # undo post-transformation # B x N x D x H x W x 3 points = self.frustum - post_trans.view(B, N, 1, 1, 1, 3) points = ( torch.inverse(post_rots) .view(B, N, 1, 1, 1, 3, 3) .matmul(points.unsqueeze(-1)) ) # cam_to_ego points = torch.cat( ( points[:, :, :, :, :, :2] * points[:, :, :, :, :, 2:3], points[:, :, :, :, :, 2:3], ), 5, ) combine = rots.matmul(torch.inverse(intrins)) points = combine.view(B, N, 1, 1, 1, 3, 3).matmul(points).squeeze(-1) points += trans.view(B, N, 1, 1, 1, 3) # ego_to_lidar points -= lidar2ego_trans.view(B, 1, 1, 1, 1, 3) points = ( torch.inverse(lidar2ego_rots) .view(B, 1, 1, 1, 1, 3, 3) .matmul(points.unsqueeze(-1)) .squeeze(-1) ) if "extra_rots" in kwargs: extra_rots = kwargs["extra_rots"] points = ( extra_rots.view(B, 1, 1, 1, 1, 3, 3) .repeat(1, N, 1, 1, 1, 1, 1) .matmul(points.unsqueeze(-1)) .squeeze(-1) ) if "extra_trans" in kwargs: extra_trans = kwargs["extra_trans"] points += extra_trans.view(B, 1, 1, 1, 1, 3).repeat(1, N, 1, 1, 1, 1) return points @force_fp32() def get_geometry( self, fH, fW, lidar2img, img_metas, ): B, N, _, _ = lidar2img.shape device = lidar2img.device # import pdb;pdb.set_trace() if self.frustum == None: self.frustum = self.create_frustum(fH,fW,img_metas) self.frustum = self.frustum.to(device) # self.D = self.frustum.shape[0] points = self.frustum.view(1,1,self.D, fH, fW, 3) \ .repeat(B,N,1,1,1,1) lidar2img = lidar2img.view(B,N,1,1,1,4,4) # img2lidar = torch.inverse(lidar2img) points = torch.cat( (points, torch.ones_like(points[..., :1])), -1) points = torch.linalg.solve(lidar2img.to(torch.float32), points.unsqueeze(-1).to(torch.float32)).squeeze(-1) # points = torch.matmul(img2lidar.to(torch.float32), # points.unsqueeze(-1).to(torch.float32)).squeeze(-1) # import pdb;pdb.set_trace() eps = 1e-5 points = points[..., 0:3] / torch.maximum( points[..., 3:4], torch.ones_like(points[..., 3:4]) * eps) return points def get_cam_feats(self, x): raise NotImplementedError def get_mlp_input(self, sensor2ego, intrin, post_rot, post_tran, bda): raise NotImplementedError @force_fp32() def bev_pool(self, geom_feats, x): B, N, D, H, W, C = x.shape Nprime = B * N * D * H * W # flatten x x = x.reshape(Nprime, C) # flatten indices geom_feats = ((geom_feats - (self.bx - self.dx / 2.0)) / self.dx).long() geom_feats = geom_feats.view(Nprime, 3) batch_ix = torch.cat( [ torch.full([Nprime // B, 1], ix, device=x.device, dtype=torch.long) for ix in range(B) ] ) geom_feats = torch.cat((geom_feats, batch_ix), 1) # filter out points that are outside box kept = ( (geom_feats[:, 0] >= 0) & (geom_feats[:, 0] < self.nx[0]) & (geom_feats[:, 1] >= 0) & (geom_feats[:, 1] < self.nx[1]) & (geom_feats[:, 2] >= 0) & (geom_feats[:, 2] < self.nx[2]) ) x = x[kept] geom_feats = geom_feats[kept] x = bev_pool(x, geom_feats, B, self.nx[2], self.nx[0], self.nx[1]) # collapse Z final = torch.cat(x.unbind(dim=2), 1) return final @force_fp32() def forward( self, images, img_metas ): B, N, C, fH, fW = images.shape lidar2img = [] camera2ego = [] camera_intrinsics = [] img_aug_matrix = [] lidar2ego = [] for img_meta in img_metas: lidar2img.append(img_meta['lidar2img']) camera2ego.append(img_meta['camera2ego']) camera_intrinsics.append(img_meta['camera_intrinsics']) img_aug_matrix.append(img_meta['img_aug_matrix']) lidar2ego.append(img_meta['lidar2ego']) lidar2img = np.asarray(lidar2img) lidar2img = images.new_tensor(lidar2img) # (B, N, 4, 4) camera2ego = np.asarray(camera2ego) camera2ego = images.new_tensor(camera2ego) # (B, N, 4, 4) camera_intrinsics = np.asarray(camera_intrinsics) camera_intrinsics = images.new_tensor(camera_intrinsics) # (B, N, 4, 4) img_aug_matrix = np.asarray(img_aug_matrix) img_aug_matrix = images.new_tensor(img_aug_matrix) # (B, N, 4, 4) lidar2ego = np.asarray(lidar2ego) lidar2ego = images.new_tensor(lidar2ego) # (B, N, 4, 4) # import pdb;pdb.set_trace() # lidar2cam = torch.linalg.solve(camera2ego, lidar2ego.view(B,1,4,4).repeat(1,N,1,1)) # lidar2oriimg = torch.matmul(camera_intrinsics,lidar2cam) # mylidar2img = torch.matmul(img_aug_matrix,lidar2oriimg) rots = camera2ego[..., :3, :3] trans = camera2ego[..., :3, 3] intrins = camera_intrinsics[..., :3, :3] post_rots = img_aug_matrix[..., :3, :3] post_trans = img_aug_matrix[..., :3, 3] lidar2ego_rots = lidar2ego[..., :3, :3] lidar2ego_trans = lidar2ego[..., :3, 3] # tmpgeom = self.get_geometry( # fH, # fW, # mylidar2img, # img_metas, # ) geom = self.get_geometry_v1( fH, fW, rots, trans, intrins, post_rots, post_trans, lidar2ego_rots, lidar2ego_trans, img_metas ) mlp_input = self.get_mlp_input(camera2ego, camera_intrinsics, post_rots, post_trans) x, depth = self.get_cam_feats(images, mlp_input) x = self.bev_pool(geom, x) # import pdb;pdb.set_trace() x = x.permute(0,1,3,2).contiguous() return x, depth @TRANSFORMER_LAYER_SEQUENCE.register_module() class BaseTransformV2(BaseModule): def __init__( self, input_size, in_channels, out_channels, feat_down_sample, pc_range, voxel_size, dbound, sid=False, ): super(BaseTransformV2, self).__init__() self.mlp_input = nn.Parameter( torch.randn(22) ) self.in_channels = in_channels self.feat_down_sample = feat_down_sample # self.image_size = image_size # self.feature_size = feature_size xbound = [pc_range[0],pc_range[3], voxel_size[0]] ybound = [pc_range[1],pc_range[4], voxel_size[1]] zbound = [pc_range[2],pc_range[5], voxel_size[2]] grid_config = [xbound, ybound, zbound] self.create_grid_infos(*grid_config) self.dbound = dbound self.sid = sid self.frustum = self.create_frustum(dbound, input_size, feat_down_sample) self.C = out_channels self.D = round((dbound[1] - dbound[0]) / dbound[2]) self.fp16_enabled = False def create_grid_infos(self, x, y, z, **kwargs): """Generate the grid information including the lower bound, interval, and size. Args: x (tuple(float)): Config of grid alone x axis in format of (lower_bound, upper_bound, interval). y (tuple(float)): Config of grid alone y axis in format of (lower_bound, upper_bound, interval). z (tuple(float)): Config of grid alone z axis in format of (lower_bound, upper_bound, interval). **kwargs: Container for other potential parameters """ self.grid_lower_bound = torch.Tensor([cfg[0] for cfg in [x, y, z]]) self.grid_interval = torch.Tensor([cfg[2] for cfg in [x, y, z]]) self.grid_size = torch.Tensor([(cfg[1] - cfg[0]) / cfg[2] for cfg in [x, y, z]]) # @force_fp32() def create_frustum(self, depth_cfg, input_size, downsample): """Generate the frustum template for each image. Args: depth_cfg (tuple(float)): Config of grid alone depth axis in format of (lower_bound, upper_bound, interval). `input_size` (tuple(int)): Size of input images in format of (height, width). downsample (int): Down sample scale factor from the input size to the feature size. """ H_in, W_in = input_size H_feat, W_feat = H_in // downsample, W_in // downsample d = torch.arange(*depth_cfg, dtype=torch.float)\ .view(-1, 1, 1).expand(-1, H_feat, W_feat) self.D = d.shape[0] if self.sid: d_sid = torch.arange(self.D).float() depth_cfg_t = torch.tensor(depth_cfg).float() d_sid = torch.exp(torch.log(depth_cfg_t[0]) + d_sid / (self.D-1) * torch.log((depth_cfg_t[1]-1) / depth_cfg_t[0])) d = d_sid.view(-1, 1, 1).expand(-1, H_feat, W_feat) x = torch.linspace(0, W_in - 1, W_feat, dtype=torch.float)\ .view(1, 1, W_feat).expand(self.D, H_feat, W_feat) y = torch.linspace(0, H_in - 1, H_feat, dtype=torch.float)\ .view(1, H_feat, 1).expand(self.D, H_feat, W_feat) # D x H x W x 3 return torch.stack((x, y, d), -1) @force_fp32() def get_geometry_v1( self, fH, fW, rots, trans, intrins, post_rots, post_trans, lidar2ego_rots, lidar2ego_trans, img_metas, **kwargs, ): B, N, _ = trans.shape device = trans.device # if self.frustum == None: # self.frustum = self.create_frustum(fH,fW,img_metas) # self.frustum = self.frustum.to(device) # # self.D = self.frustum.shape[0] # undo post-transformation # B x N x D x H x W x 3 points = self.frustum.to(device)- post_trans.view(B, N, 1, 1, 1, 3) points = ( torch.inverse(post_rots) .view(B, N, 1, 1, 1, 3, 3) .matmul(points.unsqueeze(-1)) ) # cam_to_ego points = torch.cat( ( points[:, :, :, :, :, :2] * points[:, :, :, :, :, 2:3], points[:, :, :, :, :, 2:3], ), 5, ) combine = rots.matmul(torch.inverse(intrins)) points = combine.view(B, N, 1, 1, 1, 3, 3).matmul(points).squeeze(-1) points += trans.view(B, N, 1, 1, 1, 3) # ego_to_lidar points -= lidar2ego_trans.view(B, 1, 1, 1, 1, 3) points = ( torch.inverse(lidar2ego_rots) .view(B, 1, 1, 1, 1, 3, 3) .matmul(points.unsqueeze(-1)) .squeeze(-1) ) if "extra_rots" in kwargs: extra_rots = kwargs["extra_rots"] points = ( extra_rots.view(B, 1, 1, 1, 1, 3, 3) .repeat(1, N, 1, 1, 1, 1, 1) .matmul(points.unsqueeze(-1)) .squeeze(-1) ) if "extra_trans" in kwargs: extra_trans = kwargs["extra_trans"] points += extra_trans.view(B, 1, 1, 1, 1, 3).repeat(1, N, 1, 1, 1, 1) return points @force_fp32() def get_geometry( self, fH, fW, lidar2img, img_metas, ): B, N, _, _ = lidar2img.shape device = lidar2img.device if self.frustum == None: self.frustum = self.create_frustum(fH,fW,img_metas) self.frustum = self.frustum.to(device) # self.D = self.frustum.shape[0] points = self.frustum.view(1,1,self.D, fH, fW, 3) \ .repeat(B,N,1,1,1,1) lidar2img = lidar2img.view(B,N,1,1,1,4,4) # img2lidar = torch.inverse(lidar2img) points = torch.cat( (points, torch.ones_like(points[..., :1])), -1) points = torch.linalg.solve(lidar2img.to(torch.float32), points.unsqueeze(-1).to(torch.float32)).squeeze(-1) # points = torch.matmul(img2lidar.to(torch.float32), # points.unsqueeze(-1).to(torch.float32)).squeeze(-1) eps = 1e-5 points = points[..., 0:3] / torch.maximum( points[..., 3:4], torch.ones_like(points[..., 3:4]) * eps) return points def get_cam_feats(self, x): raise NotImplementedError def get_mlp_input(self, sensor2ego, intrin, post_rot, post_tran, bda): raise NotImplementedError def voxel_pooling_prepare_v2(self, coor): """Data preparation for voxel pooling. Args: coor (torch.tensor): Coordinate of points in the lidar space in shape (B, N, D, H, W, 3). Returns: tuple[torch.tensor]: Rank of the voxel that a point is belong to in shape (N_Points); Reserved index of points in the depth space in shape (N_Points). Reserved index of points in the feature space in shape (N_Points). """ B, N, D, H, W, _ = coor.shape num_points = B * N * D * H * W # record the index of selected points for acceleration purpose ranks_depth = torch.range( 0, num_points - 1, dtype=torch.int, device=coor.device) ranks_feat = torch.range( 0, num_points // D - 1, dtype=torch.int, device=coor.device) ranks_feat = ranks_feat.reshape(B, N, 1, H, W) ranks_feat = ranks_feat.expand(B, N, D, H, W).flatten() # convert coordinate into the voxel space coor = ((coor - self.grid_lower_bound.to(coor)) / self.grid_interval.to(coor)) coor = coor.long().view(num_points, 3) batch_idx = torch.range(0, B - 1).reshape(B, 1). \ expand(B, num_points // B).reshape(num_points, 1).to(coor) coor = torch.cat((coor, batch_idx), 1) # filter out points that are outside box kept = (coor[:, 0] >= 0) & (coor[:, 0] < self.grid_size[0]) & \ (coor[:, 1] >= 0) & (coor[:, 1] < self.grid_size[1]) & \ (coor[:, 2] >= 0) & (coor[:, 2] < self.grid_size[2]) if len(kept) == 0: return None, None, None, None, None coor, ranks_depth, ranks_feat = \ coor[kept], ranks_depth[kept], ranks_feat[kept] # get tensors from the same voxel next to each other ranks_bev = coor[:, 3] * ( self.grid_size[2] * self.grid_size[1] * self.grid_size[0]) ranks_bev += coor[:, 2] * (self.grid_size[1] * self.grid_size[0]) ranks_bev += coor[:, 1] * self.grid_size[0] + coor[:, 0] order = ranks_bev.argsort() ranks_bev, ranks_depth, ranks_feat = \ ranks_bev[order], ranks_depth[order], ranks_feat[order] kept = torch.ones( ranks_bev.shape[0], device=ranks_bev.device, dtype=torch.bool) kept[1:] = ranks_bev[1:] != ranks_bev[:-1] interval_starts = torch.where(kept)[0].int() if len(interval_starts) == 0: return None, None, None, None, None interval_lengths = torch.zeros_like(interval_starts) interval_lengths[:-1] = interval_starts[1:] - interval_starts[:-1] interval_lengths[-1] = ranks_bev.shape[0] - interval_starts[-1] return ranks_bev.int().contiguous(), ranks_depth.int().contiguous( ), ranks_feat.int().contiguous(), interval_starts.int().contiguous( ), interval_lengths.int().contiguous() @force_fp32() def voxel_pooling_v2(self, coor, depth, feat): ranks_bev, ranks_depth, ranks_feat, \ interval_starts, interval_lengths = \ self.voxel_pooling_prepare_v2(coor) if ranks_feat is None: print('warning ---> no points within the predefined ' 'bev receptive field') dummy = torch.zeros(size=[ feat.shape[0], feat.shape[2], int(self.grid_size[2]), int(self.grid_size[0]), int(self.grid_size[1]) ]).to(feat) dummy = torch.cat(dummy.unbind(dim=2), 1) return dummy feat = feat.permute(0, 1, 3, 4, 2) bev_feat_shape = (depth.shape[0], int(self.grid_size[2]), int(self.grid_size[1]), int(self.grid_size[0]), feat.shape[-1]) # (B, Z, Y, X, C) bev_feat = bev_pool_v2(depth, feat, ranks_depth, ranks_feat, ranks_bev, bev_feat_shape, interval_starts, interval_lengths) # collapse Z # if self.collapse_z: bev_feat = torch.cat(bev_feat.unbind(dim=2), 1) return bev_feat @force_fp32() def bev_pool(self, geom_feats, x): B, N, D, H, W, C = x.shape Nprime = B * N * D * H * W # flatten x x = x.reshape(Nprime, C) # flatten indices geom_feats = ((geom_feats - (self.bx - self.dx / 2.0)) / self.dx).long() geom_feats = geom_feats.view(Nprime, 3) batch_ix = torch.cat( [ torch.full([Nprime // B, 1], ix, device=x.device, dtype=torch.long) for ix in range(B) ] ) geom_feats = torch.cat((geom_feats, batch_ix), 1) # filter out points that are outside box kept = ( (geom_feats[:, 0] >= 0) & (geom_feats[:, 0] < self.nx[0]) & (geom_feats[:, 1] >= 0) & (geom_feats[:, 1] < self.nx[1]) & (geom_feats[:, 2] >= 0) & (geom_feats[:, 2] < self.nx[2]) ) x = x[kept] geom_feats = geom_feats[kept] x = bev_pool(x, geom_feats, B, self.nx[2], self.nx[0], self.nx[1]) # collapse Z final = torch.cat(x.unbind(dim=2), 1) return final @force_fp32() def forward( self, images, img_metas ): B, N, C, fH, fW = images.shape rots = img_metas['sensor2lidar_rotation'][:, -1] trans = img_metas['sensor2lidar_translation'][:, -1] intrins = img_metas['intrinsics'][:, -1] post_rots = img_metas['post_rot'][:, -1] post_trans = img_metas['post_tran'][:, -1] lidar2ego = torch.eye(4, device=post_trans.device, dtype=post_rots.dtype) lidar2ego = lidar2ego[None, None].repeat(B, 1, 1, 1) lidar2ego_rots = lidar2ego[..., :3, :3] lidar2ego_trans = lidar2ego[..., :3, 3] coor = self.get_geometry_v1( fH, fW, rots, trans, intrins, post_rots, post_trans, lidar2ego_rots, lidar2ego_trans, img_metas ) sensor2ego = torch.zeros((B, N, 4, 4), dtype=rots.dtype, device=rots.device) sensor2ego[:, :, :3, :3] = rots sensor2ego[:, :, :3, 3] = trans sensor2ego[:, :, -1, -1] = 1.0 # mlp_input = self.get_mlp_input(sensor2ego, intrins, post_rots, post_trans) tran_feat, depth = self.get_cam_feats(images, self.mlp_input.data[None, None].repeat(B, N, 1)) bev_feat = self.voxel_pooling_v2( coor, depth, tran_feat) return bev_feat, depth class Mlp(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, drop=0.0): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.drop1 = nn.Dropout(drop) self.fc2 = nn.Linear(hidden_features, out_features) self.drop2 = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop1(x) x = self.fc2(x) x = self.drop2(x) return x class SELayer(nn.Module): def __init__(self, channels, act_layer=nn.ReLU, gate_layer=nn.Sigmoid): super().__init__() self.conv_reduce = nn.Conv2d(channels, channels, 1, bias=True) self.act1 = act_layer() self.conv_expand = nn.Conv2d(channels, channels, 1, bias=True) self.gate = gate_layer() def forward(self, x, x_se): x_se = self.conv_reduce(x_se) x_se = self.act1(x_se) x_se = self.conv_expand(x_se) return x * self.gate(x_se) class DepthNet(nn.Module): def __init__(self, in_channels, mid_channels, context_channels, depth_channels, use_dcn=True, use_aspp=True, with_cp=False, aspp_mid_channels=-1, only_depth=False): super(DepthNet, self).__init__() self.reduce_conv = nn.Sequential( nn.Conv2d( in_channels, mid_channels, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(mid_channels), nn.ReLU(inplace=True), ) self.only_depth = only_depth or context_channels == 0 if not self.only_depth: self.context_conv = nn.Conv2d( mid_channels, context_channels, kernel_size=1, stride=1, padding=0) self.context_mlp = Mlp(22, mid_channels, mid_channels) self.context_se = SELayer(mid_channels) # NOTE: add camera-aware self.bn = nn.BatchNorm1d(22) self.depth_mlp = Mlp(22, mid_channels, mid_channels) self.depth_se = SELayer(mid_channels) # NOTE: add camera-aware depth_conv_list = [ BasicBlock(mid_channels, mid_channels), BasicBlock(mid_channels, mid_channels), BasicBlock(mid_channels, mid_channels), ] if use_aspp: if aspp_mid_channels<0: aspp_mid_channels = mid_channels depth_conv_list.append(ASPP(mid_channels, aspp_mid_channels)) if use_dcn: depth_conv_list.append( build_conv_layer( cfg=dict( type='DCN', in_channels=mid_channels, out_channels=mid_channels, kernel_size=3, padding=1, groups=4, im2col_step=128, ))) depth_conv_list.append( nn.Conv2d( mid_channels, depth_channels, kernel_size=1, stride=1, padding=0)) self.depth_conv = nn.Sequential(*depth_conv_list) self.with_cp = with_cp def forward(self, x, mlp_input): mlp_input = self.bn(mlp_input.reshape(-1, mlp_input.shape[-1])) x = self.reduce_conv(x) if not self.only_depth: context_se = self.context_mlp(mlp_input)[..., None, None] context = self.context_se(x, context_se) context = self.context_conv(context) depth_se = self.depth_mlp(mlp_input)[..., None, None] depth = self.depth_se(x, depth_se) if self.with_cp: depth = checkpoint(self.depth_conv, depth) else: depth = self.depth_conv(depth) if not self.only_depth: return torch.cat([depth, context], dim=1) else: return depth @TRANSFORMER_LAYER_SEQUENCE.register_module() class BEVFormerEncoderDepth(BEVFormerEncoder): def __init__(self, *args, in_channels=256, out_channels=256, feat_down_sample=32, loss_depth_weight = 3.0, depthnet_cfg=dict(),grid_config=None,**kwargs): super(BEVFormerEncoderDepth, self).__init__(*args, **kwargs) self.fp16_enabled = False self.loss_depth_weight = loss_depth_weight self.feat_down_sample = feat_down_sample self.grid_config = grid_config self.D = int((grid_config['depth'][1] - grid_config['depth'][0]) / grid_config['depth'][2]) self.depth_net = DepthNet(in_channels, in_channels, 0, self.D, **depthnet_cfg) @auto_fp16() def forward(self, bev_query, key, value, *args, mlvl_feats=None, bev_h=None, bev_w=None, bev_pos=None, spatial_shapes=None, level_start_index=None, valid_ratios=None, prev_bev=None, shift=0., **kwargs): """Forward function for `TransformerDecoder`. Args: bev_query (Tensor): Input BEV query with shape `(num_query, bs, embed_dims)`. key & value (Tensor): Input multi-cameta features with shape (num_cam, num_value, bs, embed_dims) reference_points (Tensor): The reference points of offset. has shape (bs, num_query, 4) when as_two_stage, otherwise has shape ((bs, num_query, 2). valid_ratios (Tensor): The radios of valid points on the feature map, has shape (bs, num_levels, 2) Returns: Tensor: Results with shape [1, num_query, bs, embed_dims] when return_intermediate is `False`, otherwise it has shape [num_layers, num_query, bs, embed_dims]. """ bev_embed = super().forward( bev_query, key, value, bev_h=bev_h, bev_w=bev_w, bev_pos=bev_pos, spatial_shapes=spatial_shapes, level_start_index=level_start_index, prev_bev=prev_bev, shift=shift, **kwargs) # import ipdb; ipdb.set_trace() images = mlvl_feats[0] img_metas = kwargs['img_metas'] B, N, C, fH, fW = images.shape lidar2img = [] camera2ego = [] camera_intrinsics = [] img_aug_matrix = [] lidar2ego = [] for img_meta in img_metas: lidar2img.append(img_meta['lidar2img']) camera2ego.append(img_meta['camera2ego']) camera_intrinsics.append(img_meta['camera_intrinsics']) img_aug_matrix.append(img_meta['img_aug_matrix']) lidar2ego.append(img_meta['lidar2ego']) lidar2img = np.asarray(lidar2img) lidar2img = images.new_tensor(lidar2img) # (B, N, 4, 4) camera2ego = np.asarray(camera2ego) camera2ego = images.new_tensor(camera2ego) # (B, N, 4, 4) camera_intrinsics = np.asarray(camera_intrinsics) camera_intrinsics = images.new_tensor(camera_intrinsics) # (B, N, 4, 4) img_aug_matrix = np.asarray(img_aug_matrix) img_aug_matrix = images.new_tensor(img_aug_matrix) # (B, N, 4, 4) lidar2ego = np.asarray(lidar2ego) lidar2ego = images.new_tensor(lidar2ego) # (B, N, 4, 4) rots = camera2ego[..., :3, :3] trans = camera2ego[..., :3, 3] intrins = camera_intrinsics[..., :3, :3] post_rots = img_aug_matrix[..., :3, :3] post_trans = img_aug_matrix[..., :3, 3] lidar2ego_rots = lidar2ego[..., :3, :3] lidar2ego_trans = lidar2ego[..., :3, 3] mlp_input = self.get_mlp_input(camera2ego, camera_intrinsics, post_rots, post_trans) depth = self.get_cam_feats(images, mlp_input) ret_dict = dict( bev=bev_embed['bev'], depth=depth, ) # import ipdb; ipdb.set_trace() return ret_dict @force_fp32() def get_cam_feats(self, x, mlp_input): B, N, C, fH, fW = x.shape x = x.view(B * N, C, fH, fW) x = self.depth_net(x, mlp_input) depth = x[:, : self.D].softmax(dim=1) depth = depth.view(B, N, self.D, fH, fW) return depth def get_downsampled_gt_depth(self, gt_depths): """ Input: gt_depths: [B, N, H, W] Output: gt_depths: [B*N*h*w, d] """ B, N, H, W = gt_depths.shape gt_depths = gt_depths.view(B * N, H // self.feat_down_sample, self.feat_down_sample, W // self.feat_down_sample, self.feat_down_sample, 1) gt_depths = gt_depths.permute(0, 1, 3, 5, 2, 4).contiguous() gt_depths = gt_depths.view(-1, self.feat_down_sample * self.feat_down_sample) # 把gt_depth做feat_down_sample倍数的采样 gt_depths_tmp = torch.where(gt_depths == 0.0, 1e5 * torch.ones_like(gt_depths), gt_depths) # 因为深度很稀疏,大部分的点都是0,所以把0变成10000,下一步取-1维度上的最小就是深度的值 gt_depths = torch.min(gt_depths_tmp, dim=-1).values gt_depths = gt_depths.view(B * N, H // self.feat_down_sample, W // self.feat_down_sample) gt_depths = ( gt_depths - (self.grid_config['depth'][0] - self.grid_config['depth'][2])) / self.grid_config['depth'][2] gt_depths = torch.where((gt_depths < self.D + 1) & (gt_depths >= 0.0), gt_depths, torch.zeros_like(gt_depths)) gt_depths = F.one_hot( gt_depths.long(), num_classes=self.D + 1).view(-1, self.D + 1)[:, 1:] return gt_depths.float() @force_fp32() def get_depth_loss(self, depth_labels, depth_preds): # import pdb;pdb.set_trace() if depth_preds is None: return 0 depth_labels = self.get_downsampled_gt_depth(depth_labels) depth_preds = depth_preds.permute(0, 1, 3, 4, 2).contiguous().view(-1, self.D) # fg_mask = torch.max(depth_labels, dim=1).values > 0.0 # 只计算有深度的前景的深度loss # import pdb;pdb.set_trace() fg_mask = depth_labels > 0.0 # 只计算有深度的前景的深度loss depth_labels = depth_labels[fg_mask] depth_preds = depth_preds[fg_mask] with autocast(enabled=False): depth_loss = F.binary_cross_entropy( depth_preds, depth_labels, reduction='none', ).sum() / max(1.0, fg_mask.sum()) # if depth_loss <= 0.: # import pdb;pdb.set_trace() return self.loss_depth_weight * depth_loss def get_mlp_input(self, sensor2ego, intrin, post_rot, post_tran): B, N, _, _ = sensor2ego.shape mlp_input = torch.stack([ intrin[:, :, 0, 0], intrin[:, :, 1, 1], intrin[:, :, 0, 2], intrin[:, :, 1, 2], post_rot[:, :, 0, 0], post_rot[:, :, 0, 1], post_tran[:, :, 0], post_rot[:, :, 1, 0], post_rot[:, :, 1, 1], post_tran[:, :, 1], ], dim=-1) sensor2ego = sensor2ego[:,:,:3,:].reshape(B, N, -1) mlp_input = torch.cat([mlp_input, sensor2ego], dim=-1) return mlp_input @TRANSFORMER_LAYER_SEQUENCE.register_module() class LSSTransform(BaseTransform): def __init__( self, in_channels, out_channels, feat_down_sample, pc_range, voxel_size, dbound, downsample=1, loss_depth_weight = 3.0, depthnet_cfg=dict(), grid_config=None, ): super(LSSTransform, self).__init__( in_channels=in_channels, out_channels=out_channels, feat_down_sample=feat_down_sample, pc_range=pc_range, voxel_size=voxel_size, dbound=dbound, ) # import pdb;pdb.set_trace() self.loss_depth_weight = loss_depth_weight self.grid_config = grid_config self.depth_net = DepthNet(in_channels, in_channels, self.C, self.D, **depthnet_cfg) if downsample > 1: assert downsample == 2, downsample self.downsample = nn.Sequential( nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(True), nn.Conv2d( out_channels, out_channels, 3, stride=downsample, padding=1, bias=False, ), nn.BatchNorm2d(out_channels), nn.ReLU(True), nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(True), ) else: self.downsample = nn.Identity() @force_fp32() def get_cam_feats(self, x, mlp_input): B, N, C, fH, fW = x.shape x = x.view(B * N, C, fH, fW) x = self.depth_net(x, mlp_input) depth = x[:, : self.D].softmax(dim=1) x = depth.unsqueeze(1) * x[:, self.D : (self.D + self.C)].unsqueeze(2) x = x.view(B, N, self.C, self.D, fH, fW) x = x.permute(0, 1, 3, 4, 5, 2) depth = depth.view(B, N, self.D, fH, fW) return x, depth def forward(self, images, img_metas): x, depth = super().forward(images, img_metas) x = self.downsample(x) ret_dict = dict( bev=x, depth=depth, ) return ret_dict def get_downsampled_gt_depth(self, gt_depths): """ Input: gt_depths: [B, N, H, W] Output: gt_depths: [B*N*h*w, d] """ B, N, H, W = gt_depths.shape gt_depths = gt_depths.view(B * N, H // self.feat_down_sample, self.feat_down_sample, W // self.feat_down_sample, self.feat_down_sample, 1) gt_depths = gt_depths.permute(0, 1, 3, 5, 2, 4).contiguous() gt_depths = gt_depths.view(-1, self.feat_down_sample * self.feat_down_sample) # 把gt_depth做feat_down_sample倍数的采样 gt_depths_tmp = torch.where(gt_depths == 0.0, 1e5 * torch.ones_like(gt_depths), gt_depths) # 因为深度很稀疏,大部分的点都是0,所以把0变成10000,下一步取-1维度上的最小就是深度的值 gt_depths = torch.min(gt_depths_tmp, dim=-1).values gt_depths = gt_depths.view(B * N, H // self.feat_down_sample, W // self.feat_down_sample) gt_depths = ( gt_depths - (self.grid_config['depth'][0] - self.grid_config['depth'][2])) / self.grid_config['depth'][2] gt_depths = torch.where((gt_depths < self.D + 1) & (gt_depths >= 0.0), gt_depths, torch.zeros_like(gt_depths)) gt_depths = F.one_hot( gt_depths.long(), num_classes=self.D + 1).view(-1, self.D + 1)[:, 1:] return gt_depths.float() @force_fp32() def get_depth_loss(self, depth_labels, depth_preds): # import pdb;pdb.set_trace() if depth_preds is None: return 0 depth_labels = self.get_downsampled_gt_depth(depth_labels) depth_preds = depth_preds.permute(0, 1, 3, 4, 2).contiguous().view(-1, self.D) # fg_mask = torch.max(depth_labels, dim=1).values > 0.0 # 只计算有深度的前景的深度loss # import pdb;pdb.set_trace() fg_mask = depth_labels > 0.0 # 只计算有深度的前景的深度loss depth_labels = depth_labels[fg_mask] depth_preds = depth_preds[fg_mask] with autocast(enabled=False): depth_loss = F.binary_cross_entropy( depth_preds, depth_labels, reduction='none', ).sum() / max(1.0, fg_mask.sum()) # if depth_loss <= 0.: # import pdb;pdb.set_trace() return self.loss_depth_weight * depth_loss def get_mlp_input(self, sensor2ego, intrin, post_rot, post_tran): B, N, _, _ = sensor2ego.shape mlp_input = torch.stack([ intrin[:, :, 0, 0], intrin[:, :, 1, 1], intrin[:, :, 0, 2], intrin[:, :, 1, 2], post_rot[:, :, 0, 0], post_rot[:, :, 0, 1], post_tran[:, :, 0], post_rot[:, :, 1, 0], post_rot[:, :, 1, 1], post_tran[:, :, 1], ], dim=-1) sensor2ego = sensor2ego[:,:,:3,:].reshape(B, N, -1) mlp_input = torch.cat([mlp_input, sensor2ego], dim=-1) return mlp_input @TRANSFORMER_LAYER_SEQUENCE.register_module() class LSSTransformV2(BaseTransformV2): def __init__( self, input_size, in_channels, out_channels, feat_down_sample, pc_range, voxel_size, dbound, downsample=1, loss_depth_weight = 3.0, depthnet_cfg=dict(), grid_config = None, sid=False, ): super(LSSTransformV2, self).__init__( input_size=input_size, in_channels=in_channels, out_channels=out_channels, feat_down_sample=feat_down_sample, pc_range=pc_range, voxel_size=voxel_size, dbound=dbound, sid=sid, ) self.loss_depth_weight = loss_depth_weight self.grid_config = grid_config self.depth_net = DepthNet(self.in_channels, self.in_channels, self.C, self.D, **depthnet_cfg) if downsample > 1: assert downsample == 2, downsample self.downsample = nn.Sequential( nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(True), nn.Conv2d( out_channels, out_channels, 3, stride=downsample, padding=1, bias=False, ), nn.BatchNorm2d(out_channels), nn.ReLU(True), nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(True), ) else: self.downsample = nn.Identity() @force_fp32() def get_cam_feats(self, x, mlp_input): B, N, C, fH, fW = x.shape x = x.view(B * N, C, fH, fW) x = self.depth_net(x, mlp_input) depth = x[:, : self.D].softmax(dim=1) tran_feat = x[:, self.D : (self.D + self.C)] tran_feat = tran_feat.view(B, N, self.C, fH, fW) # x = x.permute(0, 1, 3, 4, 5, 2) depth = depth.view(B, N, self.D, fH, fW) return tran_feat, depth def forward(self, images, img_metas): x, depth = super().forward(images, img_metas) x = self.downsample(x) ret_dict = dict( bev=x, depth=depth, ) return ret_dict def get_downsampled_gt_depth(self, gt_depths): """ Input: gt_depths: [B, N, H, W] Output: gt_depths: [B*N*h*w, d] """ B, N, H, W = gt_depths.shape gt_depths = gt_depths.view(B * N, H // self.feat_down_sample, self.feat_down_sample, W // self.feat_down_sample, self.feat_down_sample, 1) gt_depths = gt_depths.permute(0, 1, 3, 5, 2, 4).contiguous() gt_depths = gt_depths.view(-1, self.feat_down_sample * self.feat_down_sample) # 把gt_depth做feat_down_sample倍数的采样 gt_depths_tmp = torch.where(gt_depths == 0.0, 1e5 * torch.ones_like(gt_depths), gt_depths) # 因为深度很稀疏,大部分的点都是0,所以把0变成10000,下一步取-1维度上的最小就是深度的值 gt_depths = torch.min(gt_depths_tmp, dim=-1).values gt_depths = gt_depths.view(B * N, H // self.feat_down_sample, W // self.feat_down_sample) gt_depths = ( gt_depths - (self.grid_config['depth'][0] - self.grid_config['depth'][2])) / self.grid_config['depth'][2] gt_depths = torch.where((gt_depths < self.D + 1) & (gt_depths >= 0.0), gt_depths, torch.zeros_like(gt_depths)) gt_depths = F.one_hot( gt_depths.long(), num_classes=self.D + 1).view(-1, self.D + 1)[:, 1:] return gt_depths.float() @force_fp32() def get_depth_loss(self, depth_labels, depth_preds): # import pdb;pdb.set_trace() if depth_preds is None: return 0 depth_labels = self.get_downsampled_gt_depth(depth_labels) depth_preds = depth_preds.permute(0, 1, 3, 4, 2).contiguous().view(-1, self.D) # fg_mask = torch.max(depth_labels, dim=1).values > 0.0 # 只计算有深度的前景的深度loss # import pdb;pdb.set_trace() fg_mask = depth_labels > 0.0 # 只计算有深度的前景的深度loss depth_labels = depth_labels[fg_mask] depth_preds = depth_preds[fg_mask] with autocast(enabled=False): depth_loss = F.binary_cross_entropy( depth_preds, depth_labels, reduction='none', ).sum() / max(1.0, fg_mask.sum()) # if depth_loss <= 0.: # import pdb;pdb.set_trace() return self.loss_depth_weight * depth_loss def get_mlp_input(self, sensor2ego, intrin, post_rot, post_tran): B, N, _, _ = sensor2ego.shape mlp_input = torch.stack([ intrin[:, :, 0, 0], intrin[:, :, 1, 1], intrin[:, :, 0, 2], intrin[:, :, 1, 2], post_rot[:, :, 0, 0], post_rot[:, :, 0, 1], post_tran[:, :, 0], post_rot[:, :, 1, 0], post_rot[:, :, 1, 1], post_tran[:, :, 1], ], dim=-1) sensor2ego = sensor2ego[:,:,:3,:].reshape(B, N, -1) mlp_input = torch.cat([mlp_input, sensor2ego], dim=-1) return mlp_input class _ASPPModule(nn.Module): def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm): super(_ASPPModule, self).__init__() self.atrous_conv = nn.Conv2d( inplanes, planes, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation, bias=False) self.bn = BatchNorm(planes) self.relu = nn.ReLU() self._init_weight() def forward(self, x): x = self.atrous_conv(x) x = self.bn(x) return self.relu(x) def _init_weight(self): for m in self.modules(): if isinstance(m, nn.Conv2d): torch.nn.init.kaiming_normal_(m.weight) elif isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_() class ASPP(nn.Module): def __init__(self, inplanes, mid_channels=256, BatchNorm=nn.BatchNorm2d): super(ASPP, self).__init__() dilations = [1, 6, 12, 18] self.aspp1 = _ASPPModule( inplanes, mid_channels, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm) self.aspp2 = _ASPPModule( inplanes, mid_channels, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm) self.aspp3 = _ASPPModule( inplanes, mid_channels, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm) self.aspp4 = _ASPPModule( inplanes, mid_channels, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm) self.global_avg_pool = nn.Sequential( nn.AdaptiveAvgPool2d((1, 1)), nn.Conv2d(inplanes, mid_channels, 1, stride=1, bias=False), BatchNorm(mid_channels), nn.ReLU(), ) self.conv1 = nn.Conv2d( int(mid_channels * 5), inplanes, 1, bias=False) self.bn1 = BatchNorm(inplanes) self.relu = nn.ReLU() self.dropout = nn.Dropout(0.5) self._init_weight() def forward(self, x): x1 = self.aspp1(x) x2 = self.aspp2(x) x3 = self.aspp3(x) x4 = self.aspp4(x) x5 = self.global_avg_pool(x) x5 = F.interpolate( x5, size=x4.size()[2:], mode='bilinear', align_corners=True) x = torch.cat((x1, x2, x3, x4, x5), dim=1) x = self.conv1(x) x = self.bn1(x) x = self.relu(x) return self.dropout(x) def _init_weight(self): for m in self.modules(): if isinstance(m, nn.Conv2d): torch.nn.init.kaiming_normal_(m.weight) elif isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_()