""" Implements the TransFuser vision backbone. """ import timm import torch import torch.nn.functional as F from torch import nn from torch.utils.checkpoint import checkpoint from navsim.agents.backbones.internimage import InternImage from navsim.agents.backbones.swin import SwinTransformerBEVFT from navsim.agents.backbones.vov import VoVNet from navsim.agents.transfuser.transfuser_backbone import GPT from navsim.agents.utils.vit import DAViT class TransfuserBackboneConv(nn.Module): """ Multi-scale Fusion Transformer for image + LiDAR feature fusion """ def __init__(self, config): super().__init__() self.config = config self.backbone_type = config.backbone_type if config.backbone_type == 'intern': self.image_encoder = InternImage(init_cfg=dict(type='Pretrained', checkpoint=config.intern_ckpt ), frozen_stages=2) # scale_4_c = 2560 vit_channels = 2560 self.image_encoder.init_weights() elif config.backbone_type == 'vov': self.image_encoder = VoVNet( spec_name='V-99-eSE', out_features=['stage4', 'stage5'], norm_eval=True, with_cp=True, init_cfg=dict( type='Pretrained', checkpoint=config.vov_ckpt, prefix='img_backbone.' ) ) # scale_4_c = 1024 vit_channels = 1024 self.image_encoder.init_weights() elif config.backbone_type == 'swin': self.image_encoder = SwinTransformerBEVFT( with_cp=True, convert_weights=False, depths=[2,2,18,2], drop_path_rate=0.35, embed_dims=192, init_cfg=dict( checkpoint=config.swin_ckpt, type='Pretrained' ), num_heads=[6,12,24,48], out_indices=[3], patch_norm=True, window_size=[16,16,16,16], use_abs_pos_embed=True, return_stereo_feat=False, output_missing_index_as_none=False ) vit_channels = 1536 else: raise ValueError # self.lateral_3 = nn.Sequential(*[ # nn.Conv2d(vit_channels, # vit_channels, # kernel_size=1), # nn.ReLU(inplace=True) # ]) # self.lateral_4 = nn.Sequential(*[ # nn.Conv2d(scale_4_c, # vit_channels, # kernel_size=1), # nn.ReLU(inplace=True) # ]) # self.fpn_out = nn.Sequential(*[ # nn.Conv2d(vit_channels, # vit_channels, # kernel_size=3, padding=1), # nn.ReLU(inplace=True) # ]) if config.use_ground_plane: in_channels = 2 * config.lidar_seq_len else: in_channels = config.lidar_seq_len self.avgpool_img = nn.AdaptiveAvgPool2d( (self.config.img_vert_anchors, self.config.img_horz_anchors) ) self.lidar_encoder = timm.create_model( config.lidar_architecture, pretrained=False, in_chans=in_channels, features_only=True, ) self.global_pool_lidar = nn.AdaptiveAvgPool2d(output_size=1) self.avgpool_lidar = nn.AdaptiveAvgPool2d( (self.config.lidar_vert_anchors, self.config.lidar_horz_anchors) ) lidar_time_frames = [1, 1, 1, 1] self.global_pool_img = nn.AdaptiveAvgPool2d(output_size=1) start_index = 0 # Some networks have a stem layer if len(self.lidar_encoder.return_layers) > 4: start_index += 1 self.transformers = nn.ModuleList( [ GPT( n_embd=vit_channels, config=config, # lidar_video=self.lidar_video, lidar_time_frames=lidar_time_frames[i], ) for i in range(4) ] ) self.lidar_channel_to_img = nn.ModuleList( [ nn.Conv2d( self.lidar_encoder.feature_info.info[start_index + i]["num_chs"], vit_channels, kernel_size=1, ) for i in range(4) ] ) self.img_channel_to_lidar = nn.ModuleList( [ nn.Conv2d( vit_channels, self.lidar_encoder.feature_info.info[start_index + i]["num_chs"], kernel_size=1, ) for i in range(4) ] ) self.num_features = self.lidar_encoder.feature_info.info[start_index + 3]["num_chs"] # FPN fusion channel = self.config.bev_features_channels self.relu = nn.ReLU(inplace=True) # top down if self.config.detect_boxes or self.config.use_bev_semantic: self.upsample = 
    def top_down(self, x):
        p5 = self.relu(self.c5_conv(x))
        p4 = self.relu(self.up_conv5(self.upsample(p5)))
        p3 = self.relu(self.up_conv4(self.upsample2(p4)))
        return p3

    def forward(self, image, lidar):
        """
        Image + LiDAR feature fusion using transformers.
        :param image: Input camera images
        :param lidar: Input LiDAR BEV rasters
        :return: Tuple of (BEV features for the detection / semantic heads or None,
            fused features, image branch features)
        """
        image_features, lidar_features = image, lidar

        # Generate an iterator for all the layers in the network that one can loop through.
        lidar_layers = iter(self.lidar_encoder.items())

        # Stem layer.
        # In some architectures the stem is not a return layer, so we need to skip it.
        if len(self.lidar_encoder.return_layers) > 4:
            lidar_features = self.forward_layer_block(
                lidar_layers, self.lidar_encoder.return_layers, lidar_features
            )

        # The image branch is run once up front; only its last stage is fused.
        image_features = self.image_encoder(image_features)[-1]

        # Loop through the 4 blocks of the LiDAR network, fusing after each one.
        for i in range(4):
            lidar_features = self.forward_layer_block(
                lidar_layers, self.lidar_encoder.return_layers, lidar_features
            )
            image_features, lidar_features = self.fuse_features(image_features, lidar_features, i)

        if self.config.detect_boxes or self.config.use_bev_semantic:
            x4 = lidar_features

        if self.config.transformer_decoder_join:
            fused_features = lidar_features
        else:
            image_features = self.global_pool_img(image_features)
            image_features = torch.flatten(image_features, 1)
            lidar_features = self.global_pool_lidar(lidar_features)
            lidar_features = torch.flatten(lidar_features, 1)
            if self.config.add_features:
                lidar_features = self.lidar_to_img_features_end(lidar_features)
                fused_features = image_features + lidar_features
            else:
                fused_features = torch.cat((image_features, lidar_features), dim=1)

        if self.config.detect_boxes or self.config.use_bev_semantic:
            features = self.top_down(x4)
        else:
            features = None

        return features, fused_features, image_features
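    # Note on the layer iterator used below: with `features_only=True`, TIMM
    # wraps the backbone so that iterating `self.lidar_encoder.items()` yields
    # (name, module) pairs in forward order, and `return_layers` names the
    # modules whose outputs are feature stages. Advancing the iterator until a
    # name in `return_layers` is hit therefore runs exactly one stage.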
    def forward_layer_block(self, layers, return_layers, features, if_ckpt=False):
        """
        Run one forward pass through a block of layers from a TIMM neural network and return
        the result. Advances the whole network by just one block.
        :param layers: Iterator starting at the current layer block
        :param return_layers: TIMM dictionary describing at which intermediate layers features are returned
        :param features: Input features
        :param if_ckpt: If True, run each layer under gradient checkpointing to save memory
        :return: Processed features
        """
        for name, module in layers:
            if if_ckpt:
                features = checkpoint(module, features)
            else:
                features = module(features)
            if name in return_layers:
                break
        return features

    def fuse_features(self, image_features, lidar_features, layer_idx):
        """
        Perform a TransFuser feature fusion block using a Transformer module.
        :param image_features: Features from the image branch
        :param lidar_features: Features from the LiDAR branch
        :param layer_idx: Transformer layer index
        :return: image_features and lidar_features with added features from the other branch
        """
        # Pool both branches to fixed-size token grids before attention.
        image_embd_layer = self.avgpool_img(image_features)
        lidar_embd_layer = self.avgpool_lidar(lidar_features)
        # Align the LiDAR channel width with the image width for the transformer.
        lidar_embd_layer = self.lidar_channel_to_img[layer_idx](lidar_embd_layer)

        image_features_layer, lidar_features_layer = self.transformers[layer_idx](
            image_embd_layer, lidar_embd_layer
        )
        # Map the fused LiDAR tokens back to the LiDAR channel width.
        lidar_features_layer = self.img_channel_to_lidar[layer_idx](lidar_features_layer)

        # Upsample the fused tokens to the full feature resolution and add them residually.
        image_features_layer = F.interpolate(
            image_features_layer,
            size=(image_features.shape[2], image_features.shape[3]),
            mode="bilinear",
            align_corners=False,
        )
        lidar_features_layer = F.interpolate(
            lidar_features_layer,
            size=(lidar_features.shape[2], lidar_features.shape[3]),
            mode="bilinear",
            align_corners=False,
        )
        image_features = image_features + image_features_layer
        lidar_features = lidar_features + lidar_features_layer

        return image_features, lidar_features
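
# ---------------------------------------------------------------------------
# Minimal sketch of the fusion pattern implemented by `fuse_features`, with a
# hypothetical identity step standing in for the GPT transformer (the real
# module attends jointly over both token grids). It only demonstrates the
# shape bookkeeping: pool each branch to a fixed token grid, translate LiDAR
# channels to/from the image width, upsample back, and add residually. All
# tensor sizes here are illustrative assumptions, not values from the config.
if __name__ == "__main__":
    img = torch.randn(2, 1536, 32, 64)  # image features (B, C_img, H_i, W_i)
    lid = torch.randn(2, 72, 64, 64)    # LiDAR BEV features (B, C_lidar, H_l, W_l)

    pool_img = nn.AdaptiveAvgPool2d((8, 16))  # img_vert/horz_anchors stand-ins
    pool_lid = nn.AdaptiveAvgPool2d((8, 8))   # lidar_vert/horz_anchors stand-ins
    lid_to_img = nn.Conv2d(72, 1536, kernel_size=1)
    img_to_lid = nn.Conv2d(1536, 72, kernel_size=1)

    img_tok = pool_img(img)
    lid_tok = lid_to_img(pool_lid(lid))
    # <- the fusion transformer would mix img_tok and lid_tok here
    lid_tok = img_to_lid(lid_tok)

    img = img + F.interpolate(img_tok, size=img.shape[2:], mode="bilinear", align_corners=False)
    lid = lid + F.interpolate(lid_tok, size=lid.shape[2:], mode="bilinear", align_corners=False)
    print(img.shape, lid.shape)  # -> (2, 1536, 32, 64) and (2, 72, 64, 64)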