|
from torch.utils.checkpoint import checkpoint as ckpt |
|
from functools import partial |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch.utils.checkpoint import checkpoint as ckpt |
|
|
|
from navsim.agents.dreamer.backbone import Backbone |
|
from navsim.agents.dreamer.hydra_dreamer_config import HydraDreamerConfig |
|
from navsim.agents.utils.layers import Mlp, NestedTensorBlock as Block |
|
|
|
|
|
class DreamerNetwork(nn.Module):
    """Siamese backbone over three frames followed by a transformer decoder.

    Each of the three input images is encoded by the same trainable backbone
    (``siamese_vit``). The resulting feature maps are concatenated along the
    channel axis, fused back to the backbone width with a 1x1 convolution,
    flattened into a token sequence and refined by ``decoder_blocks``.
    """

    # Feature-dict keys read by ``forward``. Their order fixes the channel
    # layout fed to ``self.proj`` (frame 0 channels first, then frame 1, ...).
    _FRAME_KEYS = ("img_3", "img_2", "img_1")

    def __init__(self, config: "HydraDreamerConfig"):
        """Build the backbones, the fusion projection and the decoder stack.

        Args:
            config: agent configuration. It is passed to ``Backbone``, and
                ``config.decoder_blocks`` sets the number of decoder blocks.
        """
        super().__init__()

        # Frozen copy of the backbone. It is not used by ``forward`` below;
        # NOTE(review): presumably consumed elsewhere (e.g. by the training
        # code) — confirm against callers before removing.
        self.fixed_vit = Backbone(config)
        self.fixed_vit.requires_grad_(False)

        # Trainable backbone shared by all frames.
        self.siamese_vit = Backbone(config)
        feat_c = self.siamese_vit.img_feat_c

        # 1x1 conv fusing the channel-concatenated per-frame features back to
        # the backbone width.
        self.proj = nn.Conv2d(
            feat_c * len(self._FRAME_KEYS),
            feat_c,
            kernel_size=1,
        )

        self.decoder_blocks = nn.ModuleList([
            Block(
                dim=feat_c,
                num_heads=16,  # feat_c must be divisible by this
                mlp_ratio=4,
                qkv_bias=True,
                ffn_bias=True,
                proj_bias=True,
                drop_path=0.0,
                norm_layer=partial(nn.LayerNorm, eps=1e-6),
                act_layer=nn.GELU,
                ffn_layer=Mlp,
                init_values=1.0,
            )
            for _ in range(config.decoder_blocks)
        ])

    def forward(self, features):
        """Encode the three frames and decode them into a token sequence.

        Args:
            features: mapping containing ``'img_3'``, ``'img_2'`` and
                ``'img_1'``, each a 4-D image tensor ``(B, C_img, H, W)`` of
                identical shape.

        Returns:
            dict with key ``'pred'`` holding a tensor of shape
            ``(B, H_feat * W_feat, C)``, where ``C`` and the feature-map size
            come from the backbone output.
        """
        frames = [features[key] for key in self._FRAME_KEYS]
        batch = frames[0].shape[0]
        n_frames = len(frames)

        # Run all frames through the shared backbone in a single pass:
        # (B, N, C_img, H, W) -> (B*N, C_img, H, W).
        img_batched = torch.stack(frames, dim=1).flatten(0, 1)
        feats = self.siamese_vit(img_batched)
        _, C, H, W = feats.shape

        # (B*N, C, H, W) -> (B, N*C, H, W): frame-major channel order, i.e. the
        # same as concatenating frame 0, 1, 2 feature maps along channels.
        fused = self.proj(feats.reshape(batch, n_frames * C, H, W))

        # (B, C, H, W) -> (B, H*W, C) token sequence.
        x = fused.flatten(2).transpose(1, 2)

        # Activation checkpointing only saves memory when autograd is
        # recording; skip it otherwise. The non-reentrant variant is the
        # recommended one and does not require the input to require grad.
        use_ckpt = torch.is_grad_enabled()
        for blk in self.decoder_blocks:
            if use_ckpt:
                x = ckpt(blk, x, use_reentrant=False)
            else:
                x = blk(x)

        return {'pred': x}
|
|