# navsim_ours/navsim/agents/dreamer/dreamer_network_cond.py
# Uploaded by lkllkl via huggingface_hub (commit da2e2ac, 2.3 kB).
from torch.utils.checkpoint import checkpoint as ckpt
from functools import partial
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint as ckpt
from navsim.agents.dreamer.backbone import Backbone
from navsim.agents.dreamer.hydra_dreamer_config import HydraDreamerConfig
from navsim.agents.utils.layers import Mlp, NestedTensorBlock as Block
class DreamerNetworkCondition(nn.Module):
    """Fuse three input frames into a single token sequence.

    A trainable backbone (``siamese_vit``) encodes each of the three frames
    with shared weights. The per-frame feature maps are concatenated along
    the channel axis and projected back to ``img_feat_c`` channels with a 1x1
    conv. The result is flattened into tokens and passed through
    ``config.decoder_blocks`` ``Block`` (NestedTensorBlock) layers.

    A second backbone (``fixed_vit``) is built with gradients disabled and is
    not used in :meth:`forward`. The original note says it is initialised from
    a planning Hydra model and provides the latent ground truth; presumably it
    is consumed by a loss elsewhere (verify against callers).
    """

    def __init__(self, config: HydraDreamerConfig):
        super().__init__()
        # Fixed ViT -> initialised from a planning Hydra model, provides latent GT.
        # NOTE(review): requires_grad_(False) only stops gradient updates. A
        # later .train() call on this module still switches fixed_vit to
        # training mode, so any dropout/BatchNorm inside Backbone would act as
        # in training. Confirm whether fixed_vit should be pinned to eval().
        self.fixed_vit = Backbone(config)
        self.fixed_vit.requires_grad_(False)

        # Trainable backbone applied with shared weights to every frame.
        self.siamese_vit = Backbone(config)
        feat_c = self.siamese_vit.img_feat_c

        # 1x1 conv fusing the channel-concatenated features of the 3 frames
        # back to a single feature map with feat_c channels.
        self.proj = nn.Conv2d(feat_c * 3, feat_c, kernel_size=1)

        # NOTE(review): num_heads=16 presumably requires feat_c to be
        # divisible by 16 -- confirm against Block's attention implementation.
        self.decoder_blocks = nn.ModuleList(
            [
                Block(
                    dim=feat_c,
                    num_heads=16,
                    mlp_ratio=4,
                    qkv_bias=True,
                    ffn_bias=True,
                    proj_bias=True,
                    drop_path=0.0,
                    norm_layer=partial(nn.LayerNorm, eps=1e-6),
                    act_layer=nn.GELU,
                    ffn_layer=Mlp,
                    init_values=1.0,
                )
                for _ in range(config.decoder_blocks)
            ]
        )

    def forward(self, features):
        """Encode the three frames and fuse them into tokens.

        Args:
            features: mapping holding image tensors under ``'img_3'``,
                ``'img_2'`` and ``'img_1'``. Each is (B, C, H, W), and all
                three must share one shape because they are stacked together.

        Returns:
            dict with key ``'pred'``: a (B, H' * W', C') token tensor, where
            (C', H', W') is the backbone's output feature-map shape. This
            assumes every decoder block preserves the token shape (TODO
            confirm).
        """
        # TODO: OCC COND
        result = {}

        # Stack frames in the order img_3, img_2, img_1 -> (B, N, C, H, W).
        frames = torch.stack(
            [features['img_3'], features['img_2'], features['img_1']], dim=1
        )
        B, N, C_IMG, H_IMG, W_IMG = frames.shape

        # Fold the frame axis into the batch so the shared backbone sees B*N images.
        siamese_feats = self.siamese_vit(frames.reshape(B * N, C_IMG, H_IMG, W_IMG))
        _, C, H, W = siamese_feats.shape

        # Merge frames into channels in frame-major order. This is the same
        # layout as concatenating the per-frame maps along dim=1. reshape
        # (not view) tolerates a non-contiguous backbone output.
        fused = self.proj(siamese_feats.reshape(B, N * C, H, W))

        # (B, C, H, W) -> (B, H*W, C) token sequence.
        x = fused.flatten(2).transpose(1, 2)

        # Activation checkpointing only saves memory when an autograd graph is
        # being built; under no_grad it is pure overhead, so run blocks directly.
        use_ckpt = torch.is_grad_enabled()
        for blk in self.decoder_blocks:
            x = ckpt(blk, x, use_reentrant=False) if use_ckpt else blk(x)

        result['pred'] = x
        return result