File size: 2,263 Bytes
c1a7f73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# Config format schema. This YAML supports validation cases sourced from different datasets.
# Shared settings. This mapping is merged (shallowly) into Dataset, Model and
# Model.decoder via `<<: *time_info`, so every key below is visible in all
# three sections. Explicit keys in those sections take precedence.
time_info: &time_info
  num_historical_steps: 11
  num_future_steps: 80
  use_intention: true
  token_size: 2048
  predict_motion: true
  predict_state: true
  predict_map: false
  predict_occ: true
  # Integer id assigned to each state token.
  state_token:
    invalid: 0
    valid: 1
    enter: 2
    exit: 3
  pl2seed_radius: 75.0  # NOTE(review): units presumably metres — confirm
  grid_range: 150.0  # 2 x pl2seed_radius
  grid_interval: 3.0
  angle_interval: 3.0  # NOTE(review): units presumably degrees — confirm
  seed_size: 1
  buffer_size: 128

Dataset:
  # Pulls in every time_info key; explicit keys below take precedence.
  <<: *time_info
  root: null
  train_batch_size: 1
  val_batch_size: 1
  test_batch_size: 1
  shuffle: true
  num_workers: 1
  pin_memory: true
  persistent_workers: true
  train_raw_dir: ["data/waymo_processed/training"]
  val_raw_dir: ["data/waymo_processed/validation"]
  # NOTE(review): the test split reads the validation directory — confirm intended.
  test_raw_dir: ["data/waymo_processed/validation"]
  transform: WaymoTargetBuilder
  # Processed-data directories are left unset (null) in this config.
  train_processed_dir: null
  val_processed_dir: null
  test_processed_dir: null
  dataset: "scalable"

Trainer:
  strategy: ddp_find_unused_parameters_false
  accelerator: "gpu"
  devices: 1
  max_epochs: 32
  overfit_epochs: 6000
  # Left unset (null) here; presumably supplied per run — confirm with the caller.
  save_ckpt_path: null
  num_nodes: 1
  mode: null
  ckpt_path: null
  precision: 32
  accumulate_grad_batches: 1

Model:
  # Pulls in every time_info key (shallow merge); explicit keys here win.
  <<: *time_info
  predictor: "smart"
  decoder_type: "agent_decoder"  # choose from ['agent_decoder', 'occ_decoder']
  dataset: "waymo"
  input_dim: 2
  hidden_dim: 128
  output_dim: 2
  output_head: false
  num_heads: 8
  head_dim: 16  # equals hidden_dim / num_heads with the values above
  dropout: 0.1
  num_freq_bands: 64
  lr: 0.0005
  warmup_steps: 0
  total_steps: 32
  predict_map_token: false
  num_recurrent_steps_val: -1
  val_open_loop: true
  val_close_loop: false
  val_insert: false
  decoder:
    # The decoder sub-config also merges every time_info key.
    <<: *time_info
    num_map_layers: 3
    num_agent_layers: 6
    a2a_radius: 60
    pl2pl_radius: 10
    pl2a_radius: 30
    a2sa_radius: 10
    pl2sa_radius: 10
    time_span: 60
  # Loss weighting. List entries are labelled by their inline comments.
  loss_weight:
    token_cls_loss: 1
    map_token_loss: 1
    state_cls_loss: 10
    type_cls_loss: 5
    pos_cls_loss: 1
    head_cls_loss: 1
    offset_reg_loss: 5
    shape_reg_loss: 0.2
    # NOTE(review): 3 entries, while time_info.state_token defines 4 states
    # ("enter" is omitted per the comment) — confirm this is intended.
    state_weight: [0.1, 0.1, 0.8]  # invalid, valid, exit
    seed_state_weight: [0.1, 0.9]  # invalid, enter
    seed_type_weight: [0.8, 0.1, 0.1]
    agent_occ_pos_weight: 100
    pt_occ_pos_weight: 5
    agent_occ_loss: 10
    pt_occ_loss: 10