navsim_ours/navsim/agents/transfuser/transfuser_backbone_moe_ult32.py
"""
Implements the TransFuser vision backbone.
"""
import timm
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint
from torchvision.transforms import Resize

from navsim.agents.backbones.eva import EVAViT
from navsim.agents.backbones.vov import VoVNet
from navsim.agents.transfuser.transfuser_backbone import GPT
from navsim.agents.utils.vit import DAViT


class TransfuserBackboneMoeUlt32(nn.Module):
"""
Multi-scale Fusion Transformer for image + LiDAR feature fusion
"""
def __init__(self, config):
super().__init__()
self.config = config
        # The ViT-L-sized experts operate on a resized image with short side 256 (256 x 1024).
img_vit_size = (256, 256 * 4)
self.resize = Resize(img_vit_size)
self.image_encoder = nn.ModuleDict({
'davit': DAViT(ckpt=config.vit_ckpt),
'sptrvit': EVAViT(
img_size=img_vit_size[0], # img_size for short side
patch_size=16,
window_size=16,
global_window_size=img_vit_size[0] // 16,
                # For a square input set global_window_size=0; otherwise use img_size // 16.
in_chans=3,
embed_dim=1024,
depth=24,
num_heads=16,
mlp_ratio=4 * 2 / 3,
                window_block_indexes=[i for i in range(24) if i % 3 != 2],
qkv_bias=True,
drop_path_rate=0.3,
with_cp=True,
flash_attn=False,
xformers_attn=True
),
'mapvit': EVAViT(
img_size=img_vit_size[0], # img_size for short side
patch_size=16,
window_size=16,
global_window_size=img_vit_size[0] // 16,
                # For a square input set global_window_size=0; otherwise use img_size // 16.
in_chans=3,
embed_dim=1024,
depth=24,
num_heads=16,
mlp_ratio=4 * 2 / 3,
                window_block_indexes=[i for i in range(24) if i % 3 != 2],
qkv_bias=True,
drop_path_rate=0.3,
with_cp=True,
flash_attn=False,
xformers_attn=True
),
'vov': VoVNet(
spec_name='V-99-eSE',
out_features=['stage4', 'stage5'],
norm_eval=True,
with_cp=True,
init_cfg=dict(
type='Pretrained',
checkpoint=config.vov_ckpt,
prefix='img_backbone.'
)
)
})
self.image_encoder['sptrvit'].init_weights(config.sptr_ckpt)
self.image_encoder['mapvit'].init_weights(config.map_ckpt)
self.image_encoder['vov'].init_weights()
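
        # The four expert feature maps (1024 channels each) are later pooled to the shared
        # anchor grid, concatenated along channels, and projected back to 1024 channels.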
        self.moe_proj = nn.Sequential(nn.Conv2d(1024 * 4, 1024, kernel_size=1))
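        # LiDAR BEV input channels: one per history frame, doubled when the ground plane
        # raster is included.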
if config.use_ground_plane:
in_channels = 2 * config.lidar_seq_len
else:
in_channels = config.lidar_seq_len
self.avgpool_img = nn.AdaptiveAvgPool2d(
(self.config.img_vert_anchors, self.config.img_horz_anchors)
)
self.lidar_encoder = timm.create_model(
config.lidar_architecture,
pretrained=False,
in_chans=in_channels,
features_only=True,
)
self.global_pool_lidar = nn.AdaptiveAvgPool2d(output_size=1)
self.avgpool_lidar = nn.AdaptiveAvgPool2d(
(self.config.lidar_vert_anchors, self.config.lidar_horz_anchors)
)
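        # One LiDAR time frame per fusion stage (the lidar-video path of the original
        # TransFuser backbone is unused in this variant).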
lidar_time_frames = [1, 1, 1, 1]
self.global_pool_img = nn.AdaptiveAvgPool2d(output_size=1)
        vit_channels = 1024
        start_index = 0
        # Some encoders also expose the stem as a return layer; skip it when indexing stages.
        if len(self.lidar_encoder.return_layers) > 4:
            start_index += 1
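        # Four GPT-style transformer fusion blocks, one per LiDAR encoder stage.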
self.transformers = nn.ModuleList(
[
GPT(
n_embd=vit_channels,
config=config,
# lidar_video=self.lidar_video,
lidar_time_frames=lidar_time_frames[i],
)
for i in range(4)
]
)
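        # 1x1 convolutions translate between each LiDAR stage's channel count and the
        # 1024-channel ViT feature width around every fusion block.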
self.lidar_channel_to_img = nn.ModuleList(
[
nn.Conv2d(
self.lidar_encoder.feature_info.info[start_index + i]["num_chs"],
vit_channels,
kernel_size=1,
)
for i in range(4)
]
)
self.img_channel_to_lidar = nn.ModuleList(
[
nn.Conv2d(
vit_channels,
self.lidar_encoder.feature_info.info[start_index + i]["num_chs"],
kernel_size=1,
)
for i in range(4)
]
)
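        # Assumption: the original TransFuser backbone defines this projection when features
        # are added instead of concatenated; it is referenced in forward() but was not
        # defined in this file, so a minimal sketch is added here.
        if self.config.add_features and not self.config.transformer_decoder_join:
            self.lidar_to_img_features_end = nn.Linear(
                self.lidar_encoder.feature_info.info[start_index + 3]["num_chs"],
                vit_channels,
            )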
self.num_features = self.lidar_encoder.feature_info.info[start_index + 3]["num_chs"]
# FPN fusion
channel = self.config.bev_features_channels
self.relu = nn.ReLU(inplace=True)
# top down
if self.config.detect_boxes or self.config.use_bev_semantic:
self.upsample = nn.Upsample(
scale_factor=self.config.bev_upsample_factor, mode="bilinear", align_corners=False
)
self.upsample2 = nn.Upsample(
size=(
self.config.lidar_resolution_height // self.config.bev_down_sample_factor,
self.config.lidar_resolution_width // self.config.bev_down_sample_factor,
),
mode="bilinear",
align_corners=False,
)
self.up_conv5 = nn.Conv2d(channel, channel, (3, 3), padding=1)
self.up_conv4 = nn.Conv2d(channel, channel, (3, 3), padding=1)
# lateral
self.c5_conv = nn.Conv2d(
self.lidar_encoder.feature_info.info[start_index + 3]["num_chs"], channel, (1, 1)
)
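
    # Lightweight FPN-style top-down path: a 1x1 lateral conv on the last LiDAR stage,
    # then two upsample + 3x3 conv steps produce the BEV feature map for the
    # detection / BEV-semantic heads.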
def top_down(self, x):
p5 = self.relu(self.c5_conv(x))
p4 = self.relu(self.up_conv5(self.upsample(p5)))
p3 = self.relu(self.up_conv4(self.upsample2(p4)))
return p3

    def forward(self, image, lidar):
"""
Image + LiDAR feature fusion using transformers
Args:
image_list (list): list of input images
lidar_list (list): list of input LiDAR BEV
"""
image_features, lidar_features = image, lidar
# Generate an iterator for all the layers in the network that one can loop through.
lidar_layers = iter(self.lidar_encoder.items())
        # Stem layer.
        # If the encoder exposes the stem as an extra return layer (5 instead of 4),
        # consume it here so the fusion loop below only covers the four stages.
if len(self.lidar_encoder.return_layers) > 4:
lidar_features = self.forward_layer_block(
lidar_layers, self.lidar_encoder.return_layers, lidar_features
)
        # VoVNet stage-5 features on the full-resolution image (stride 32 -> 16 x 64)
        vov_features = self.image_encoder['vov'](image_features)[-1]
        # Resize the 512 x 2048 input to 256 x 1024 for the ViT experts
        image_features_resized = self.resize(image_features)
        # ViT experts on the resized image (the EVA ViTs use patch size 16 -> 16 x 64 maps)
        sptr_features = self.image_encoder['sptrvit'](image_features_resized)[0]
        map_features = self.image_encoder['mapvit'](image_features_resized)[0]
        davit_features = self.image_encoder['davit'](image_features_resized)[0]
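        # Mixture of experts: pool every expert to the shared anchor grid, concatenate
        # along channels (4 x 1024) and project back down to 1024 channels.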
final_features = torch.cat([
self.avgpool_img(vov_features),
self.avgpool_img(sptr_features),
self.avgpool_img(map_features),
self.avgpool_img(davit_features)
], dim=1)
image_features = self.moe_proj(final_features)
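        # Run the four LiDAR encoder stages, fusing image and LiDAR features after each one.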
for i in range(4):
lidar_features = self.forward_layer_block(
lidar_layers, self.lidar_encoder.return_layers, lidar_features
)
image_features, lidar_features = self.fuse_features(image_features, lidar_features, i)
if self.config.detect_boxes or self.config.use_bev_semantic:
x4 = lidar_features
# image_feature_grid = None
# if self.config.use_semantic or self.config.use_depth:
# image_feature_grid = image_features
if self.config.transformer_decoder_join:
fused_features = lidar_features
else:
image_features = self.global_pool_img(image_features)
image_features = torch.flatten(image_features, 1)
lidar_features = self.global_pool_lidar(lidar_features)
lidar_features = torch.flatten(lidar_features, 1)
if self.config.add_features:
lidar_features = self.lidar_to_img_features_end(lidar_features)
fused_features = image_features + lidar_features
else:
fused_features = torch.cat((image_features, lidar_features), dim=1)
if self.config.detect_boxes or self.config.use_bev_semantic:
features = self.top_down(x4)
else:
features = None
return features, fused_features, image_features

    def forward_layer_block(self, layers, return_layers, features, if_ckpt=False):
"""
Run one forward pass to a block of layers from a TIMM neural network and returns the result.
Advances the whole network by just one block
:param layers: Iterator starting at the current layer block
:param return_layers: TIMM dictionary describing at which intermediate layers features are returned.
:param features: Input features
:return: Processed features
"""
for name, module in layers:
if if_ckpt:
features = checkpoint(module, features)
else:
features = module(features)
if name in return_layers:
break
return features

    def fuse_features(self, image_features, lidar_features, layer_idx):
"""
Perform a TransFuser feature fusion block using a Transformer module.
:param image_features: Features from the image branch
:param lidar_features: Features from the LiDAR branch
:param layer_idx: Transformer layer index.
:return: image_features and lidar_features with added features from the other branch.
"""
image_embd_layer = image_features
lidar_embd_layer = self.avgpool_lidar(lidar_features)
lidar_embd_layer = self.lidar_channel_to_img[layer_idx](lidar_embd_layer)
image_features_layer, lidar_features_layer = self.transformers[layer_idx](
image_embd_layer, lidar_embd_layer
)
lidar_features_layer = self.img_channel_to_lidar[layer_idx](lidar_features_layer)
image_features_layer = F.interpolate(
image_features_layer,
size=(image_features.shape[2], image_features.shape[3]),
mode="bilinear",
align_corners=False,
)
lidar_features_layer = F.interpolate(
lidar_features_layer,
size=(lidar_features.shape[2], lidar_features.shape[3]),
mode="bilinear",
align_corners=False,
)
image_features = image_features + image_features_layer
lidar_features = lidar_features + lidar_features_layer
return image_features, lidar_features
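

# Usage sketch (illustrative only, not part of the module). Construction requires the
# agent config used elsewhere in navsim and the expert checkpoints (vit_ckpt, sptr_ckpt,
# map_ckpt, vov_ckpt) on disk; the input sizes below are assumptions taken from the
# comments in forward().
#
#     backbone = TransfuserBackboneMoeUlt32(config)
#     image = torch.zeros(1, 3, 512, 2048)                      # camera input
#     lidar = torch.zeros(1, config.lidar_seq_len, 256, 256)    # BEV raster (256 x 256 assumed)
#     bev_features, fused_features, image_features = backbone(image, lidar)
#     # bev_features is None unless config.detect_boxes or config.use_bev_semantic is set.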