# Source: longsim-base/backups/configs/experiments/ablate_state_tokens.yaml
# Uploaded by gzzyyxy with huggingface_hub (commit d37e5d1, verified).
# Config format schema: this YAML supports validating case sources from different datasets.
# Shared settings, defined once and anchored as `time_info`. Other sections of
# this file pull them in with the `<<: *time_info` merge key.
# NOTE(review): the nesting below was reconstructed from key names because the
# original indentation was lost — confirm against the consumer's schema.
time_info: &time_info
  num_historical_steps: 11
  num_future_steps: 80
  use_intention: true
  token_size: 2048
  # Switches for the individual prediction targets.
  predict_motion: true
  predict_state: true
  predict_map: true
  predict_occ: true
  # Integer id assigned to each state token.
  state_token:
    invalid: 0
    valid: 1
    enter: 2
    exit: 3
  pl2seed_radius: 75.0
  # Presumably the switch this ablation config exists for (the file is named
  # ablate_state_tokens) — confirm how the consumer interprets it.
  disable_state_tokens: true
  grid_range: 150.0  # twice pl2seed_radius
  grid_interval: 3.0
  angle_interval: 3.0
  seed_size: 1
  buffer_size: 128
  max_num: 32
# Data-loading settings: batch sizes, loader options and dataset locations.
# Keys written as `null` carry no value in this file.
Dataset:
  root: null
  train_batch_size: 1
  val_batch_size: 1
  test_batch_size: 1
  shuffle: true
  num_workers: 1
  pin_memory: true
  persistent_workers: true
  train_raw_dir: 'data/waymo_processed/training'
  val_raw_dir: 'data/waymo_processed/validation'
  # Same directory as val_raw_dir.
  test_raw_dir: 'data/waymo_processed/validation'
  # NOTE(review): this validation split sits under the training/ directory —
  # confirm the path is intended.
  val_tfrecords_splitted: 'data/waymo_processed/training/validation_tfrecords_splitted'
  transform: 'WaymoTargetBuilder'
  train_processed_dir: null
  val_processed_dir: null
  test_processed_dir: null
  dataset: 'scalable'
  # Merge in the shared keys anchored as `time_info` (shallow merge; keys
  # written explicitly in this mapping take precedence).
  <<: *time_info
# Training-run settings: hardware, run length and checkpoint locations.
# Keys written as `null` carry no value in this file.
Trainer:
  # Quoted so the value is unmistakably a string despite ending in "false".
  strategy: 'ddp_find_unused_parameters_false'
  accelerator: 'gpu'
  devices: 1
  max_epochs: 32
  save_ckpt_path: null
  num_nodes: 1
  mode: null
  ckpt_path: null
  precision: 32
  accumulate_grad_batches: 1
  overfit_epochs: 6000
# Model hyper-parameters, optimisation schedule, validation rollout options,
# decoder settings and loss weighting.
# NOTE(review): the original indentation was lost. `decoder` and `loss_weight`
# are placed as siblings under `Model`, and every key after `loss_weight:` is
# placed inside it, based on key names only — confirm against the code that
# reads this config.
Model:
  predictor: 'smart'
  decoder_type: 'agent_decoder'  # choose from ['agent_decoder', 'occ_decoder']
  dataset: 'waymo'
  input_dim: 2
  hidden_dim: 128
  output_dim: 2
  output_head: false
  num_heads: 8
  # Merge in the shared keys anchored as `time_info` (shallow merge; keys
  # written explicitly in this mapping take precedence).
  <<: *time_info
  head_dim: 16  # num_heads * head_dim = 128 = hidden_dim
  dropout: 0.1
  num_freq_bands: 64
  lr: 0.0005
  warmup_steps: 0
  total_steps: 32  # same value as Trainer.max_epochs
  predict_map_token: false
  # Validation rollout options.
  num_recurrent_steps_val: 300
  val_open_loop: false
  val_close_loop: true
  val_insert: false
  n_rollout_close_val: 1
  decoder:
    # The decoder receives the shared `time_info` keys as well.
    <<: *time_info
    num_map_layers: 3
    num_agent_layers: 6
    a2a_radius: 60
    pl2pl_radius: 10
    pl2a_radius: 30
    a2sa_radius: 10
    pl2sa_radius: 10
    time_span: 60
  loss_weight:
    token_cls_loss: 1
    map_token_loss: 1
    state_cls_loss: 10
    type_cls_loss: 5
    pos_cls_loss: 1
    head_cls_loss: 1
    offset_reg_loss: 5
    shape_reg_loss: 0.2
    pos_reg_loss: 10
    state_weight: [0.1, 0.1, 0.8]  # invalid, valid, exit
    seed_state_weight: [0.1, 0.9]  # invalid, enter
    seed_type_weight: [0.8, 0.1, 0.1]
    agent_occ_pos_weight: 100
    pt_occ_pos_weight: 5
    agent_occ_loss: 10
    pt_occ_loss: 10