# dev/modules/smart_decoder.py
from typing import Dict, Optional
import torch
import torch.nn as nn
from torch_geometric.data import HeteroData
from dev.modules.attr_tokenizer import Attr_Tokenizer
from dev.modules.agent_decoder import SMARTAgentDecoder
from dev.modules.occ_decoder import SMARTOccDecoder
from dev.modules.map_decoder import SMARTMapDecoder

DECODER = {'agent_decoder': SMARTAgentDecoder,
           'occ_decoder': SMARTOccDecoder}


class SMARTDecoder(nn.Module):
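    """Composite decoder that pairs a SMARTMapDecoder with an agent- or occupancy-level decoder."""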
def __init__(self,
decoder_type: str,
dataset: str,
input_dim: int,
hidden_dim: int,
num_historical_steps: int,
pl2pl_radius: float,
time_span: Optional[int],
pl2a_radius: float,
pl2seed_radius: float,
a2a_radius: float,
a2sa_radius: float,
pl2sa_radius: float,
num_freq_bands: int,
num_map_layers: int,
num_agent_layers: int,
num_heads: int,
head_dim: int,
dropout: float,
map_token: Dict,
                 token_size: int = 512,
                 attr_tokenizer: Optional[Attr_Tokenizer] = None,
                 predict_motion: bool = False,
                 predict_state: bool = False,
                 predict_map: bool = False,
                 predict_occ: bool = False,
                 use_grid_token: bool = True,
                 use_head_token: bool = True,
                 use_state_token: bool = True,
                 disable_insertion: bool = False,
                 state_token: Optional[Dict[str, int]] = None,
                 seed_size: int = 5,
                 buffer_size: int = 32,
                 num_recurrent_steps_val: int = -1,
                 loss_weight: Optional[dict] = None,
                 logger=None) -> None:
        super().__init__()
self.map_encoder = SMARTMapDecoder(
dataset=dataset,
input_dim=input_dim,
hidden_dim=hidden_dim,
num_historical_steps=num_historical_steps,
pl2pl_radius=pl2pl_radius,
num_freq_bands=num_freq_bands,
num_layers=num_map_layers,
num_heads=num_heads,
head_dim=head_dim,
dropout=dropout,
map_token=map_token,
)
        assert decoder_type in DECODER, f"Unsupported decoder type: {decoder_type}"
self.agent_encoder = DECODER[decoder_type](
dataset=dataset,
input_dim=input_dim,
hidden_dim=hidden_dim,
num_historical_steps=num_historical_steps,
time_span=time_span,
pl2a_radius=pl2a_radius,
pl2seed_radius=pl2seed_radius,
a2a_radius=a2a_radius,
a2sa_radius=a2sa_radius,
pl2sa_radius=pl2sa_radius,
num_freq_bands=num_freq_bands,
num_layers=num_agent_layers,
num_heads=num_heads,
head_dim=head_dim,
dropout=dropout,
token_size=token_size,
attr_tokenizer=attr_tokenizer,
predict_motion=predict_motion,
predict_state=predict_state,
predict_map=predict_map,
predict_occ=predict_occ,
state_token=state_token,
use_grid_token=use_grid_token,
use_head_token=use_head_token,
use_state_token=use_state_token,
disable_insertion=disable_insertion,
seed_size=seed_size,
buffer_size=buffer_size,
num_recurrent_steps_val=num_recurrent_steps_val,
loss_weight=loss_weight,
logger=logger,
)
self.map_enc = None
self.predict_motion = predict_motion
self.predict_state = predict_state
self.predict_map = predict_map
self.predict_occ = predict_occ
self.data_keys = ["agent_valid_mask", "category", "valid_mask", "av_index", "scenario_id", "shape"]

    def get_agent_inputs(self, data: HeteroData) -> Dict[str, torch.Tensor]:
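        """Delegate to the agent decoder to build its model inputs from the scene data."""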
return self.agent_encoder.get_inputs(data)

    def forward(self, data: HeteroData) -> Dict[str, torch.Tensor]:
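        """Encode the map, then run the agent decoder if any agent-side prediction is enabled."""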
map_enc = self.map_encoder(data)
agent_enc = {}
if self.predict_motion or self.predict_state or self.predict_occ:
agent_enc = self.agent_encoder(data, map_enc)
return {**map_enc, **agent_enc, **{k: data[k] for k in self.data_keys}}

    def inference(self, data: HeteroData) -> Dict[str, torch.Tensor]:
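        """Encode the map and run the agent decoder's inference path."""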
map_enc = self.map_encoder(data)
agent_enc = {}
if self.predict_motion or self.predict_state or self.predict_occ:
agent_enc = self.agent_encoder.inference(data, map_enc)
return {**map_enc, **agent_enc, **{k: data[k] for k in self.data_keys}}

    def inference_no_map(self, data: HeteroData, map_enc) -> Dict[str, torch.Tensor]:
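        """Run agent inference with a precomputed map encoding, skipping the map encoder."""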
agent_enc = self.agent_encoder.inference(data, map_enc)
return {**map_enc, **agent_enc}

    def insert_agent(self, data: HeteroData) -> Dict[str, torch.Tensor]:
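        """Encode the map and run the agent decoder's agent-insertion path."""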
map_enc = self.map_encoder(data)
agent_enc = self.agent_encoder.insert(data, map_enc)
return {**map_enc, **agent_enc, **{k: data[k] for k in self.data_keys}}

    def predict_nearest_pos(self, data: HeteroData, rank) -> None:
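        """Encode the map and invoke the agent decoder's nearest-position prediction."""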
map_enc = self.map_encoder(data)
self.agent_encoder.predict_nearest_pos(data, map_enc, rank)
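

# Minimal usage sketch (illustrative only): the hyper-parameters below are
# hypothetical placeholders, and `map_token_cfg` / `tokenizer` stand in for the
# project's own map-token config and Attr_Tokenizer instance.
#
# decoder = SMARTDecoder(
#     decoder_type='agent_decoder',   # or 'occ_decoder', per the DECODER registry above
#     dataset='waymo',
#     input_dim=2, hidden_dim=128,
#     num_historical_steps=11, time_span=30,
#     pl2pl_radius=150., pl2a_radius=50., pl2seed_radius=75.,
#     a2a_radius=60., a2sa_radius=60., pl2sa_radius=50.,
#     num_freq_bands=64, num_map_layers=3, num_agent_layers=6,
#     num_heads=8, head_dim=16, dropout=0.1,
#     map_token=map_token_cfg, attr_tokenizer=tokenizer,
#     predict_motion=True,
# )
# out = decoder(data)  # data: a torch_geometric.data.HeteroData scene graph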