# Config format schema: this YAML supports validating cases sourced from different datasets.
# Shared timing / tokenization settings. Anchored as &time_info and merged
# (via `<<: *time_info`) into Dataset, Model and Model.decoder below, so a
# change here propagates to all three. Merge keys are shallow, and any key
# set explicitly in those mappings overrides the merged-in value.
time_info: &time_info
  num_historical_steps: 11
  num_future_steps: 80
  use_intention: true
  token_size: 2048
  predict_motion: true
  predict_state: true
  predict_map: false
  predict_occ: true
  # Integer id assigned to each state token.
  state_token:
    invalid: 0
    valid: 1
    enter: 2
    exit: 3
  pl2seed_radius: 75.0
  grid_range: 150.0  # twice pl2seed_radius; update both together
  grid_interval: 3.0
  angle_interval: 3.0
  seed_size: 1
  buffer_size: 128

Dataset:
  root: null
  train_batch_size: 1
  val_batch_size: 1
  test_batch_size: 1
  shuffle: true
  num_workers: 1
  pin_memory: true
  persistent_workers: true
  train_raw_dir: ["data/waymo_processed/training"]
  val_raw_dir: ["data/waymo_processed/validation"]
  # NOTE(review): the test split reads the validation directory; confirm
  # this is intentional.
  test_raw_dir: ["data/waymo_processed/validation"]
  transform: WaymoTargetBuilder
  train_processed_dir: null
  val_processed_dir: null
  test_processed_dir: null
  dataset: "scalable"
  <<: *time_info

Trainer:
  strategy: ddp_find_unused_parameters_false
  accelerator: "gpu"
  devices: 1
  max_epochs: 32
  overfit_epochs: 6000
  save_ckpt_path: null
  num_nodes: 1
  mode: null
  ckpt_path: null
  precision: 32
  accumulate_grad_batches: 1

Model:
  predictor: "smart"
  decoder_type: "agent_decoder"  # choose from ['agent_decoder', 'occ_decoder']
  dataset: "waymo"
  input_dim: 2
  hidden_dim: 128
  output_dim: 2
  output_head: false
  num_heads: 8
  <<: *time_info
  head_dim: 16
  dropout: 0.1
  num_freq_bands: 64
  lr: 0.0005
  warmup_steps: 0
  total_steps: 32
  predict_map_token: false
  num_recurrent_steps_val: -1
  val_open_loop: true
  val_close_loop: false
  val_insert: false
  decoder:
    <<: *time_info
    num_map_layers: 3
    num_agent_layers: 6
    a2a_radius: 60
    pl2pl_radius: 10
    pl2a_radius: 30
    a2sa_radius: 10
    pl2sa_radius: 10
    time_span: 60
    loss_weight:
      token_cls_loss: 1
      map_token_loss: 1
      state_cls_loss: 10
      type_cls_loss: 5
      pos_cls_loss: 1
      head_cls_loss: 1
      offset_reg_loss: 5
      shape_reg_loss: 0.2
      # NOTE(review): 3 weights vs. 4 ids in time_info.state_token (no
      # 'enter' weight); presumably the state head excludes 'enter'. Confirm.
      state_weight: [0.1, 0.1, 0.8]  # invalid, valid, exit
      seed_state_weight: [0.1, 0.9]  # invalid, enter
      # Presumably one weight per seed agent type; confirm the ordering.
      seed_type_weight: [0.8, 0.1, 0.1]
      agent_occ_pos_weight: 100
      pt_occ_pos_weight: 5
      agent_occ_loss: 10
      pt_occ_loss: 10