from functools import partial

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint as ckpt

from navsim.agents.dreamer.backbone import Backbone
from navsim.agents.dreamer.hydra_dreamer_config import HydraDreamerConfig
from navsim.agents.utils.layers import Mlp, NestedTensorBlock as Block
class DreamerNetwork(nn.Module):
    """Dreamer head: encodes three past frames with a shared (siamese) ViT,
    fuses them, and decodes a predicted latent feature map. A frozen teacher
    ViT, initialized from a planning Hydra model, provides the latent ground
    truth."""

    def __init__(self, config: HydraDreamerConfig):
        super().__init__()
        # Fixed ViT: initialized from a planning Hydra model, provides the
        # latent ground truth. Frozen so only the siamese branch is trained.
        self.fixed_vit = Backbone(config)
        self.fixed_vit.requires_grad_(False)
        # Trainable encoder shared across the three input frames.
        self.siamese_vit = Backbone(config)
        # 1x1 conv fusing the channel-concatenated per-frame feature maps
        # back down to a single feature map.
        self.proj = nn.Conv2d(
            self.siamese_vit.img_feat_c * 3,
            self.siamese_vit.img_feat_c,
            kernel_size=1,
        )
        # Transformer decoder operating on the fused feature tokens.
        self.decoder_blocks = nn.ModuleList([
            Block(
                dim=self.siamese_vit.img_feat_c,
                num_heads=16,
                mlp_ratio=4,
                qkv_bias=True,
                ffn_bias=True,
                proj_bias=True,
                drop_path=0.0,
                norm_layer=partial(nn.LayerNorm, eps=1e-6),
                act_layer=nn.GELU,
                ffn_layer=Mlp,
                init_values=1.0,
            ) for _ in range(config.decoder_blocks)
        ])
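
    # NOTE: fixed_vit is never called inside forward(); per the comment above
    # it supplies the latent ground truth, presumably in the training step.
    # Hypothetical sketch -- the target frame name and loss are assumptions:
    #   with torch.no_grad():
    #       latent_gt = self.fixed_vit(next_img)          # (B, C, H, W)
    #   target = latent_gt.flatten(2).permute(0, 2, 1)    # (B, H*W, C)
    #   loss = torch.nn.functional.mse_loss(result['pred'], target)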
    def forward(self, features):
        # TODO: 1. condition -- traj discriminator
        # TODO: 2. long-term
        result = {}
        # Three input frames, each (B, C, H, W), ordered oldest to newest.
        img_3, img_2, img_1 = features['img_3'], features['img_2'], features['img_1']
        B, C_IMG, H_IMG, W_IMG = img_3.shape
        # Stack the frames along a new time axis, then fold time into the
        # batch so the siamese ViT processes all frames in one pass.
        img_batched = torch.cat([
            img_3[:, None],
            img_2[:, None],
            img_1[:, None],
        ], dim=1).view(-1, C_IMG, H_IMG, W_IMG)
        BN = img_batched.shape[0]
        N = BN // B  # number of frames (3)
        siamese_feats = self.siamese_vit(img_batched)
        _, C, H, W = siamese_feats.shape
        siamese_feats = siamese_feats.view(B, N, C, H, W)
        # Fuse the per-frame features via channel concatenation + 1x1 conv.
        x = self.proj(torch.cat([
            siamese_feats[:, 0],
            siamese_feats[:, 1],
            siamese_feats[:, 2],
        ], dim=1))
        # Flatten the spatial map into a token sequence: (B, H*W, C).
        x = x.view(B, C, -1).permute(0, 2, 1)
        # Gradient checkpointing keeps activation memory low through the
        # decoder; use_reentrant is passed explicitly per current PyTorch
        # recommendations.
        for blk in self.decoder_blocks:
            x = ckpt(blk, x, use_reentrant=False)
        result['pred'] = x
        return result
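
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): forward()
# expects a dict of three frames keyed 'img_3'/'img_2'/'img_1', each shaped
# (B, C, H, W). The image resolution and default-constructible config below
# are assumptions.
# ---------------------------------------------------------------------------
# config = HydraDreamerConfig()
# model = DreamerNetwork(config)
# features = {
#     'img_3': torch.randn(2, 3, 256, 1024),
#     'img_2': torch.randn(2, 3, 256, 1024),
#     'img_1': torch.randn(2, 3, 256, 1024),
# }
# out = model(features)
# out['pred']  # (B, H*W, C) predicted latent tokens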