"""
Implements the TransFuser vision backbone.
"""

import math
import copy

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from torchvision.transforms import Resize

import timm
from timm.models.vision_transformer import VisionTransformer

from navsim.agents.backbones.eva import EVAViT
from navsim.agents.backbones.vov import VoVNet
from navsim.agents.transfuser.transfuser_backbone import GPT
from navsim.agents.utils.vit import DAViT


class TransfuserBackboneMoeUlt32(nn.Module):
    """
    Multi-scale fusion transformer for image + LiDAR feature fusion.
    """
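
    # Illustrative usage sketch (variable names are hypothetical; `config` must
    # provide the attributes read in __init__ and forward below):
    #
    #     backbone = TransfuserBackboneMoeUlt32(config)
    #     bev_features, fused_features, image_features = backbone(camera_image, lidar_bev)
    #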

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Fixed (height, width) resolution fed to the ViT-based image experts;
        # the camera input is resized to this before those branches run.
        img_vit_size = (256, 256 * 4)
        self.resize = Resize(img_vit_size)

        # Mixture of image-encoder experts: a DAViT, two EVA ViTs, and a VoVNet.
        self.image_encoder = nn.ModuleDict({
            'davit': DAViT(ckpt=config.vit_ckpt),
            'sptrvit': EVAViT(
                img_size=img_vit_size[0],
                patch_size=16,
                window_size=16,
                global_window_size=img_vit_size[0] // 16,
                in_chans=3,
                embed_dim=1024,
                depth=24,
                num_heads=16,
                mlp_ratio=4 * 2 / 3,
                window_block_indexes=(
                    list(range(0, 2)) + list(range(3, 5)) + list(range(6, 8))
                    + list(range(9, 11)) + list(range(12, 14)) + list(range(15, 17))
                    + list(range(18, 20)) + list(range(21, 23))
                ),
                qkv_bias=True,
                drop_path_rate=0.3,
                with_cp=True,
                flash_attn=False,
                xformers_attn=True,
            ),
            'mapvit': EVAViT(
                img_size=img_vit_size[0],
                patch_size=16,
                window_size=16,
                global_window_size=img_vit_size[0] // 16,
                in_chans=3,
                embed_dim=1024,
                depth=24,
                num_heads=16,
                mlp_ratio=4 * 2 / 3,
                window_block_indexes=(
                    list(range(0, 2)) + list(range(3, 5)) + list(range(6, 8))
                    + list(range(9, 11)) + list(range(12, 14)) + list(range(15, 17))
                    + list(range(18, 20)) + list(range(21, 23))
                ),
                qkv_bias=True,
                drop_path_rate=0.3,
                with_cp=True,
                flash_attn=False,
                xformers_attn=True,
            ),
            'vov': VoVNet(
                spec_name='V-99-eSE',
                out_features=['stage4', 'stage5'],
                norm_eval=True,
                with_cp=True,
                init_cfg=dict(
                    type='Pretrained',
                    checkpoint=config.vov_ckpt,
                    prefix='img_backbone.',
                ),
            ),
        })
        self.image_encoder['sptrvit'].init_weights(config.sptr_ckpt)
        self.image_encoder['mapvit'].init_weights(config.map_ckpt)
        self.image_encoder['vov'].init_weights()

        # 1x1 projection that fuses the concatenated expert features (4 x 1024 channels) back to 1024.
        self.moe_proj = nn.Sequential(nn.Conv2d(1024 * 4, 1024, 1))

        if config.use_ground_plane:
            in_channels = 2 * config.lidar_seq_len
        else:
            in_channels = config.lidar_seq_len

        self.avgpool_img = nn.AdaptiveAvgPool2d(
            (self.config.img_vert_anchors, self.config.img_horz_anchors)
        )

        self.lidar_encoder = timm.create_model(
            config.lidar_architecture,
            pretrained=False,
            in_chans=in_channels,
            features_only=True,
        )
        self.global_pool_lidar = nn.AdaptiveAvgPool2d(output_size=1)
        self.avgpool_lidar = nn.AdaptiveAvgPool2d(
            (self.config.lidar_vert_anchors, self.config.lidar_horz_anchors)
        )
        lidar_time_frames = [1, 1, 1, 1]

        self.global_pool_img = nn.AdaptiveAvgPool2d(output_size=1)
        start_index = 0

        vit_channels = 1024
        # Some TIMM backbones also expose a stem feature map; skip it so the four
        # fusion stages line up with the four main stages.
        if len(self.lidar_encoder.return_layers) > 4:
            start_index += 1

        self.transformers = nn.ModuleList(
            [
                GPT(
                    n_embd=vit_channels,
                    config=config,
                    lidar_time_frames=lidar_time_frames[i],
                )
                for i in range(4)
            ]
        )
        self.lidar_channel_to_img = nn.ModuleList(
            [
                nn.Conv2d(
                    self.lidar_encoder.feature_info.info[start_index + i]["num_chs"],
                    vit_channels,
                    kernel_size=1,
                )
                for i in range(4)
            ]
        )
        self.img_channel_to_lidar = nn.ModuleList(
            [
                nn.Conv2d(
                    vit_channels,
                    self.lidar_encoder.feature_info.info[start_index + i]["num_chs"],
                    kernel_size=1,
                )
                for i in range(4)
            ]
        )

        self.num_features = self.lidar_encoder.feature_info.info[start_index + 3]["num_chs"]

        channel = self.config.bev_features_channels
        self.relu = nn.ReLU(inplace=True)

        if self.config.detect_boxes or self.config.use_bev_semantic:
            self.upsample = nn.Upsample(
                scale_factor=self.config.bev_upsample_factor, mode="bilinear", align_corners=False
            )
            self.upsample2 = nn.Upsample(
                size=(
                    self.config.lidar_resolution_height // self.config.bev_down_sample_factor,
                    self.config.lidar_resolution_width // self.config.bev_down_sample_factor,
                ),
                mode="bilinear",
                align_corners=False,
            )

        # Top-down convolutions used by top_down() when BEV heads are enabled.
        self.up_conv5 = nn.Conv2d(channel, channel, (3, 3), padding=1)
        self.up_conv4 = nn.Conv2d(channel, channel, (3, 3), padding=1)

        # Lateral 1x1 conv from the last LiDAR stage into the BEV channel width.
        self.c5_conv = nn.Conv2d(
            self.lidar_encoder.feature_info.info[start_index + 3]["num_chs"], channel, (1, 1)
        )

    def top_down(self, x):
        """FPN-style top-down path that upsamples the fused BEV features for the dense heads."""
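        # p5: coarse BEV map from the last LiDAR stage (after the lateral c5_conv);
        # p4: upsampled by config.bev_upsample_factor; p3: resized to
        # lidar_resolution / bev_down_sample_factor for the detection / semantic heads.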
        p5 = self.relu(self.c5_conv(x))
        p4 = self.relu(self.up_conv5(self.upsample(p5)))
        p3 = self.relu(self.up_conv4(self.upsample2(p4)))

        return p3

    def forward(self, image, lidar):
        """
        Image + LiDAR feature fusion using transformers.
        Args:
            image (tensor): input camera image
            lidar (tensor): input LiDAR BEV raster
        Returns:
            BEV features for the dense heads (or None), the fused features, and the image features
        """
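        # High-level flow:
        #   1. optionally run the LiDAR stem block so four stages remain
        #   2. encode the image with all four experts and project to one 1024-channel map
        #   3. alternate LiDAR stages with transformer-based image/LiDAR fusion (4 stages)
        #   4. pool/join the two branches into the fused feature vector
        #   5. optionally decode dense BEV features via the top-down path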
        image_features, lidar_features = image, lidar

        lidar_layers = iter(self.lidar_encoder.items())

        if len(self.lidar_encoder.return_layers) > 4:
            lidar_features = self.forward_layer_block(
                lidar_layers, self.lidar_encoder.return_layers, lidar_features
            )

        # Image experts: VoVNet runs on the raw input, the ViT experts on the resized input.
        vov_features = self.image_encoder['vov'](image_features)[-1]

        image_features_resized = self.resize(image_features)

        sptr_features = self.image_encoder['sptrvit'](image_features_resized)[0]
        map_features = self.image_encoder['mapvit'](image_features_resized)[0]
        davit_features = self.image_encoder['davit'](image_features_resized)[0]
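
        # Shape sketch for illustration (the 1x1 projection below implies each expert
        # contributes a 1024-channel map after pooling; exact channel counts depend on
        # the expert implementations):
        #   avgpool_img(expert_i) -> (B, 1024, img_vert_anchors, img_horz_anchors)
        #   torch.cat(dim=1)      -> (B, 4096, img_vert_anchors, img_horz_anchors)
        #   moe_proj (1x1 conv)   -> (B, 1024, img_vert_anchors, img_horz_anchors)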
        final_features = torch.cat([
            self.avgpool_img(vov_features),
            self.avgpool_img(sptr_features),
            self.avgpool_img(map_features),
            self.avgpool_img(davit_features)
        ], dim=1)

        image_features = self.moe_proj(final_features)

        for i in range(4):
            lidar_features = self.forward_layer_block(
                lidar_layers, self.lidar_encoder.return_layers, lidar_features
            )

            image_features, lidar_features = self.fuse_features(image_features, lidar_features, i)

        if self.config.detect_boxes or self.config.use_bev_semantic:
            x4 = lidar_features

        if self.config.transformer_decoder_join:
            fused_features = lidar_features
        else:
            image_features = self.global_pool_img(image_features)
            image_features = torch.flatten(image_features, 1)
            lidar_features = self.global_pool_lidar(lidar_features)
            lidar_features = torch.flatten(lidar_features, 1)

            if self.config.add_features:
                # Note: lidar_to_img_features_end is not created in this __init__ and must
                # be provided elsewhere when config.add_features is enabled.
                lidar_features = self.lidar_to_img_features_end(lidar_features)
                fused_features = image_features + lidar_features
            else:
                fused_features = torch.cat((image_features, lidar_features), dim=1)

        if self.config.detect_boxes or self.config.use_bev_semantic:
            features = self.top_down(x4)
        else:
            features = None

        return features, fused_features, image_features

    def forward_layer_block(self, layers, return_layers, features, if_ckpt=False):
        """
        Run one forward pass through a block of layers from a TIMM network and return the result.
        Advances the whole network by just one block.
        :param layers: Iterator starting at the current layer block
        :param return_layers: TIMM dictionary describing at which intermediate layers features are returned
        :param features: Input features
        :param if_ckpt: If True, run each layer under gradient checkpointing to save activation memory
        :return: Processed features
        """
        for name, module in layers:
            if if_ckpt:
                features = checkpoint(module, features)
            else:
                features = module(features)
            if name in return_layers:
                break
        return features

    def fuse_features(self, image_features, lidar_features, layer_idx):
        """
        Perform a TransFuser feature fusion block using a Transformer module.
        :param image_features: Features from the image branch
        :param lidar_features: Features from the LiDAR branch
        :param layer_idx: Transformer layer index
        :return: image_features and lidar_features with added features from the other branch
        """
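        # Fusion steps as implemented below: pool the LiDAR features to the anchor grid,
        # match channel widths with 1x1 convs, run the GPT cross-fusion transformer on both
        # token maps, then upsample each result back to its branch's resolution and add it
        # as a residual.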
        image_embd_layer = image_features
        lidar_embd_layer = self.avgpool_lidar(lidar_features)

        lidar_embd_layer = self.lidar_channel_to_img[layer_idx](lidar_embd_layer)

        image_features_layer, lidar_features_layer = self.transformers[layer_idx](
            image_embd_layer, lidar_embd_layer
        )
        lidar_features_layer = self.img_channel_to_lidar[layer_idx](lidar_features_layer)

        image_features_layer = F.interpolate(
            image_features_layer,
            size=(image_features.shape[2], image_features.shape[3]),
            mode="bilinear",
            align_corners=False,
        )
        lidar_features_layer = F.interpolate(
            lidar_features_layer,
            size=(lidar_features.shape[2], lidar_features.shape[3]),
            mode="bilinear",
            align_corners=False,
        )

        image_features = image_features + image_features_layer
        lidar_features = lidar_features + lidar_features_layer

        return image_features, lidar_features