gzzyyxy committed on
Commit
c1a7f73
·
verified ·
1 Parent(s): 0d5ab94

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. backups/configs/experiments/ablate_grid_tokens.yaml +106 -0
  2. backups/configs/ours_long_term.yaml +105 -0
  3. backups/configs/ours_standard.yaml +101 -0
  4. backups/configs/ours_standard_decode_occ.yaml +100 -0
  5. backups/configs/pretrain_scalable_map.yaml +97 -0
  6. backups/configs/smart.yaml +70 -0
  7. backups/data_preprocess.py +916 -0
  8. backups/dev/datasets/preprocess.py +761 -0
  9. backups/dev/datasets/scalable_dataset.py +276 -0
  10. backups/dev/metrics/box_utils.py +113 -0
  11. backups/dev/metrics/compute_metrics.py +1812 -0
  12. backups/dev/metrics/geometry_utils.py +137 -0
  13. backups/dev/metrics/interact_features.py +220 -0
  14. backups/dev/metrics/map_features.py +349 -0
  15. backups/dev/metrics/placement_features.py +48 -0
  16. backups/dev/metrics/protos/long_metrics_pb2.py +648 -0
  17. backups/dev/metrics/protos/map_pb2.py +1070 -0
  18. backups/dev/metrics/protos/scenario_pb2.py +454 -0
  19. backups/dev/metrics/trajectory_features.py +52 -0
  20. backups/dev/metrics/val_close_long_metrics.json +24 -0
  21. backups/dev/model/smart.py +1100 -0
  22. backups/dev/modules/agent_decoder.py +0 -0
  23. backups/dev/modules/attr_tokenizer.py +109 -0
  24. backups/dev/modules/debug.py +1439 -0
  25. backups/dev/modules/layers.py +371 -0
  26. backups/dev/modules/map_decoder.py +130 -0
  27. backups/dev/modules/occ_decoder.py +927 -0
  28. backups/dev/modules/smart_decoder.py +137 -0
  29. backups/dev/utils/cluster_reader.py +45 -0
  30. backups/dev/utils/func.py +260 -0
  31. backups/dev/utils/graph.py +89 -0
  32. backups/dev/utils/metrics.py +692 -0
  33. backups/dev/utils/visualization.py +1145 -0
  34. backups/environment.yml +326 -0
  35. backups/run.py +181 -0
  36. backups/scripts/aggregate_log_metric_features.sh +16 -0
  37. backups/scripts/c128.sh +13 -0
  38. backups/scripts/c64.sh +13 -0
  39. backups/scripts/compute_metrics.sh +13 -0
  40. backups/scripts/data_preprocess.sh +12 -0
  41. backups/scripts/data_preprocess_loop.sh +23 -0
  42. backups/scripts/debug.py +17 -0
  43. backups/scripts/debug_map.py +204 -0
  44. backups/scripts/g2.sh +37 -0
  45. backups/scripts/g4.sh +35 -0
  46. backups/scripts/g8.sh +44 -0
  47. backups/scripts/hf_model.py +111 -0
  48. backups/scripts/pretrain_map.sh +27 -0
  49. backups/scripts/run_eval.sh +20 -0
  50. backups/scripts/run_train.sh +20 -0
backups/configs/experiments/ablate_grid_tokens.yaml ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Config format schema number; this YAML supports validation cases sourced from different datasets
2
+ time_info: &time_info
3
+ num_historical_steps: 11
4
+ num_future_steps: 80
5
+ use_intention: True
6
+ token_size: 2048
7
+ predict_motion: True
8
+ predict_state: True
9
+ predict_map: True
10
+ predict_occ: True
11
+ state_token:
12
+ invalid: 0
13
+ valid: 1
14
+ enter: 2
15
+ exit: 3
16
+ pl2seed_radius: 75.
17
+ disable_grid_token: True
18
+ grid_range: 150. # 2 times of pl2seed_radius
19
+ grid_interval: 3.
20
+ angle_interval: 3.
21
+ seed_size: 1
22
+ buffer_size: 128
23
+ max_num: 32
24
+
25
+ Dataset:
26
+ root:
27
+ train_batch_size: 1
28
+ val_batch_size: 1
29
+ test_batch_size: 1
30
+ shuffle: True
31
+ num_workers: 1
32
+ pin_memory: True
33
+ persistent_workers: True
34
+ train_raw_dir: 'data/waymo_processed/training'
35
+ val_raw_dir: 'data/waymo_processed/validation'
36
+ test_raw_dir: 'data/waymo_processed/validation'
37
+ val_tfrecords_splitted: data/waymo_processed/training/validation_tfrecords_splitted
38
+ transform: WaymoTargetBuilder
39
+ train_processed_dir:
40
+ val_processed_dir:
41
+ test_processed_dir:
42
+ dataset: 'scalable'
43
+ <<: *time_info
44
+
45
+ Trainer:
46
+ strategy: ddp_find_unused_parameters_false
47
+ accelerator: 'gpu'
48
+ devices: 1
49
+ max_epochs: 32
50
+ save_ckpt_path:
51
+ num_nodes: 1
52
+ mode:
53
+ ckpt_path:
54
+ precision: 32
55
+ accumulate_grad_batches: 1
56
+ overfit_epochs: 6000
57
+
58
+ Model:
59
+ predictor: 'smart'
60
+ decoder_type: 'agent_decoder' # choose from ['agent_decoder', 'occ_decoder']
61
+ dataset: 'waymo'
62
+ input_dim: 2
63
+ hidden_dim: 128
64
+ output_dim: 2
65
+ output_head: False
66
+ num_heads: 8
67
+ <<: *time_info
68
+ head_dim: 16
69
+ dropout: 0.1
70
+ num_freq_bands: 64
71
+ lr: 0.0005
72
+ warmup_steps: 0
73
+ total_steps: 32
74
+ predict_map_token: False
75
+ num_recurrent_steps_val: 300
76
+ val_open_loop: False
77
+ val_close_loop: True
78
+ val_insert: False
79
+ n_rollout_close_val: 1
80
+ decoder:
81
+ <<: *time_info
82
+ num_map_layers: 3
83
+ num_agent_layers: 6
84
+ a2a_radius: 60
85
+ pl2pl_radius: 10
86
+ pl2a_radius: 30
87
+ a2sa_radius: 10
88
+ pl2sa_radius: 10
89
+ time_span: 60
90
+ loss_weight:
91
+ token_cls_loss: 1
92
+ map_token_loss: 1
93
+ state_cls_loss: 10
94
+ type_cls_loss: 5
95
+ pos_cls_loss: 1
96
+ head_cls_loss: 1
97
+ offset_reg_loss: 5
98
+ shape_reg_loss: .2
99
+ pos_reg_loss: 10
100
+ state_weight: [0.1, 0.1, 0.8] # invalid, valid, exit
101
+ seed_state_weight: [0.1, 0.9] # invalid, enter
102
+ seed_type_weight: [0.8, 0.1, 0.1]
103
+ agent_occ_pos_weight: 100
104
+ pt_occ_pos_weight: 5
105
+ agent_occ_loss: 10
106
+ pt_occ_loss: 10
backups/configs/ours_long_term.yaml ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Config format schema number; this YAML supports validation cases sourced from different datasets
2
+ time_info: &time_info
3
+ num_historical_steps: 11
4
+ num_future_steps: 80
5
+ use_intention: True
6
+ token_size: 2048
7
+ predict_motion: True
8
+ predict_state: True
9
+ predict_map: True
10
+ predict_occ: True
11
+ state_token:
12
+ invalid: 0
13
+ valid: 1
14
+ enter: 2
15
+ exit: 3
16
+ pl2seed_radius: 75.
17
+ grid_range: 150. # 2 times of pl2seed_radius
18
+ grid_interval: 3.
19
+ angle_interval: 3.
20
+ seed_size: 1
21
+ buffer_size: 128
22
+ max_num: 32
23
+
24
+ Dataset:
25
+ root:
26
+ train_batch_size: 1
27
+ val_batch_size: 1
28
+ test_batch_size: 1
29
+ shuffle: True
30
+ num_workers: 1
31
+ pin_memory: True
32
+ persistent_workers: True
33
+ train_raw_dir: 'data/waymo_processed/training'
34
+ val_raw_dir: 'data/waymo_processed/validation'
35
+ test_raw_dir: 'data/waymo_processed/validation'
36
+ val_tfrecords_splitted: data/waymo_processed/training/validation_tfrecords_splitted
37
+ transform: WaymoTargetBuilder
38
+ train_processed_dir:
39
+ val_processed_dir:
40
+ test_processed_dir:
41
+ dataset: 'scalable'
42
+ <<: *time_info
43
+
44
+ Trainer:
45
+ strategy: ddp_find_unused_parameters_false
46
+ accelerator: 'gpu'
47
+ devices: 1
48
+ max_epochs: 32
49
+ save_ckpt_path:
50
+ num_nodes: 1
51
+ mode:
52
+ ckpt_path:
53
+ precision: 32
54
+ accumulate_grad_batches: 1
55
+ overfit_epochs: 6000
56
+
57
+ Model:
58
+ predictor: 'smart'
59
+ decoder_type: 'agent_decoder' # choose from ['agent_decoder', 'occ_decoder']
60
+ dataset: 'waymo'
61
+ input_dim: 2
62
+ hidden_dim: 128
63
+ output_dim: 2
64
+ output_head: False
65
+ num_heads: 8
66
+ <<: *time_info
67
+ head_dim: 16
68
+ dropout: 0.1
69
+ num_freq_bands: 64
70
+ lr: 0.0005
71
+ warmup_steps: 0
72
+ total_steps: 32
73
+ predict_map_token: False
74
+ num_recurrent_steps_val: 300
75
+ val_open_loop: False
76
+ val_close_loop: True
77
+ val_insert: False
78
+ n_rollout_close_val: 1
79
+ decoder:
80
+ <<: *time_info
81
+ num_map_layers: 3
82
+ num_agent_layers: 6
83
+ a2a_radius: 60
84
+ pl2pl_radius: 10
85
+ pl2a_radius: 30
86
+ a2sa_radius: 10
87
+ pl2sa_radius: 10
88
+ time_span: 60
89
+ loss_weight:
90
+ token_cls_loss: 1
91
+ map_token_loss: 1
92
+ state_cls_loss: 10
93
+ type_cls_loss: 5
94
+ pos_cls_loss: 1
95
+ head_cls_loss: 1
96
+ offset_reg_loss: 5
97
+ shape_reg_loss: .2
98
+ pos_reg_loss: 10
99
+ state_weight: [0.1, 0.1, 0.8] # invalid, valid, exit
100
+ seed_state_weight: [0.1, 0.9] # invalid, enter
101
+ seed_type_weight: [0.8, 0.1, 0.1]
102
+ agent_occ_pos_weight: 100
103
+ pt_occ_pos_weight: 5
104
+ agent_occ_loss: 10
105
+ pt_occ_loss: 10
backups/configs/ours_standard.yaml ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Config format schema number; this YAML supports validation cases sourced from different datasets
2
+ time_info: &time_info
3
+ num_historical_steps: 11
4
+ num_future_steps: 80
5
+ use_intention: True
6
+ token_size: 2048
7
+ predict_motion: True
8
+ predict_state: True
9
+ predict_map: False
10
+ predict_occ: True
11
+ state_token:
12
+ invalid: 0
13
+ valid: 1
14
+ enter: 2
15
+ exit: 3
16
+ pl2seed_radius: 75.
17
+ grid_range: 150. # 2 times of pl2seed_radius
18
+ grid_interval: 3.
19
+ angle_interval: 3.
20
+ seed_size: 1
21
+ buffer_size: 128
22
+
23
+ Dataset:
24
+ root:
25
+ train_batch_size: 1
26
+ val_batch_size: 1
27
+ test_batch_size: 1
28
+ shuffle: True
29
+ num_workers: 1
30
+ pin_memory: True
31
+ persistent_workers: True
32
+ train_raw_dir: ["data/waymo_processed/training"]
33
+ val_raw_dir: ["data/waymo_processed/validation"]
34
+ test_raw_dir: ["data/waymo_processed/validation"]
35
+ transform: WaymoTargetBuilder
36
+ train_processed_dir:
37
+ val_processed_dir:
38
+ test_processed_dir:
39
+ dataset: "scalable"
40
+ <<: *time_info
41
+
42
+ Trainer:
43
+ strategy: ddp_find_unused_parameters_false
44
+ accelerator: "gpu"
45
+ devices: 1
46
+ max_epochs: 32
47
+ overfit_epochs: 6000
48
+ save_ckpt_path:
49
+ num_nodes: 1
50
+ mode:
51
+ ckpt_path:
52
+ precision: 32
53
+ accumulate_grad_batches: 1
54
+
55
+ Model:
56
+ predictor: "smart"
57
+ decoder_type: "agent_decoder" # choose from ['agent_decoder', 'occ_decoder']
58
+ dataset: "waymo"
59
+ input_dim: 2
60
+ hidden_dim: 128
61
+ output_dim: 2
62
+ output_head: False
63
+ num_heads: 8
64
+ <<: *time_info
65
+ head_dim: 16
66
+ dropout: 0.1
67
+ num_freq_bands: 64
68
+ lr: 0.0005
69
+ warmup_steps: 0
70
+ total_steps: 32
71
+ predict_map_token: False
72
+ num_recurrent_steps_val: -1
73
+ val_open_loop: True
74
+ val_close_loop: False
75
+ val_insert: False
76
+ decoder:
77
+ <<: *time_info
78
+ num_map_layers: 3
79
+ num_agent_layers: 6
80
+ a2a_radius: 60
81
+ pl2pl_radius: 10
82
+ pl2a_radius: 30
83
+ a2sa_radius: 10
84
+ pl2sa_radius: 10
85
+ time_span: 60
86
+ loss_weight:
87
+ token_cls_loss: 1
88
+ map_token_loss: 1
89
+ state_cls_loss: 10
90
+ type_cls_loss: 5
91
+ pos_cls_loss: 1
92
+ head_cls_loss: 1
93
+ offset_reg_loss: 5
94
+ shape_reg_loss: .2
95
+ state_weight: [0.1, 0.1, 0.8] # invalid, valid, exit
96
+ seed_state_weight: [0.1, 0.9] # invalid, enter
97
+ seed_type_weight: [0.8, 0.1, 0.1]
98
+ agent_occ_pos_weight: 100
99
+ pt_occ_pos_weight: 5
100
+ agent_occ_loss: 10
101
+ pt_occ_loss: 10
backups/configs/ours_standard_decode_occ.yaml ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Config format schema number; this YAML supports validation cases sourced from different datasets
2
+ time_info: &time_info
3
+ num_historical_steps: 11
4
+ num_future_steps: 80
5
+ use_intention: True
6
+ token_size: 2048
7
+ predict_motion: False
8
+ predict_state: False
9
+ predict_map: False
10
+ predict_occ: True
11
+ state_token:
12
+ invalid: 0
13
+ valid: 1
14
+ enter: 2
15
+ exit: 3
16
+ pl2seed_radius: 75.
17
+ grid_range: 150. # 2 times of pl2seed_radius
18
+ grid_interval: 3.
19
+ angle_interval: 3.
20
+ seed_size: 1
21
+ buffer_size: 128
22
+
23
+ Dataset:
24
+ root:
25
+ train_batch_size: 1
26
+ val_batch_size: 1
27
+ test_batch_size: 1
28
+ shuffle: True
29
+ num_workers: 1
30
+ pin_memory: True
31
+ persistent_workers: True
32
+ train_raw_dir: ["data/waymo_processed/training"]
33
+ val_raw_dir: ["data/waymo_processed/validation"]
34
+ test_raw_dir: ["data/waymo_processed/validation"]
35
+ transform: WaymoTargetBuilder
36
+ train_processed_dir:
37
+ val_processed_dir:
38
+ test_processed_dir:
39
+ dataset: "scalable"
40
+ <<: *time_info
41
+
42
+ Trainer:
43
+ strategy: ddp_find_unused_parameters_false
44
+ accelerator: "gpu"
45
+ devices: 1
46
+ max_epochs: 32
47
+ overfit_epochs: 6000
48
+ save_ckpt_path:
49
+ num_nodes: 1
50
+ mode:
51
+ ckpt_path:
52
+ precision: 32
53
+ accumulate_grad_batches: 1
54
+
55
+ Model:
56
+ predictor: "smart"
57
+ decoder_type: "occ_decoder" # choose from ['agent_decoder', 'occ_decoder']
58
+ dataset: "waymo"
59
+ input_dim: 2
60
+ hidden_dim: 128
61
+ output_dim: 2
62
+ output_head: False
63
+ num_heads: 8
64
+ <<: *time_info
65
+ head_dim: 16
66
+ dropout: 0.1
67
+ num_freq_bands: 64
68
+ lr: 0.0005
69
+ warmup_steps: 0
70
+ total_steps: 32
71
+ predict_map_token: False
72
+ num_recurrent_steps_val: -1
73
+ val_open_loop: True
74
+ val_closed_loop: False
75
+ decoder:
76
+ <<: *time_info
77
+ num_map_layers: 3
78
+ num_agent_layers: 6
79
+ a2a_radius: 60
80
+ pl2pl_radius: 10
81
+ pl2a_radius: 30
82
+ a2sa_radius: 10
83
+ pl2sa_radius: 10
84
+ time_span: 60
85
+ loss_weight:
86
+ token_cls_loss: 1
87
+ map_token_loss: 1
88
+ state_cls_loss: 10
89
+ type_cls_loss: 5
90
+ pos_cls_loss: 1
91
+ head_cls_loss: 1
92
+ offset_reg_loss: 5
93
+ shape_reg_loss: .2
94
+ state_weight: [0.1, 0.1, 0.8] # invalid, valid, exit
95
+ seed_state_weight: [0.1, 0.9] # invalid, enter
96
+ seed_type_weight: [0.8, 0.1, 0.1]
97
+ agent_occ_pos_weight: 100
98
+ pt_occ_pos_weight: 5
99
+ agent_occ_loss: 10
100
+ pt_occ_loss: 10
backups/configs/pretrain_scalable_map.yaml ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Config format schema number; this YAML supports validation cases sourced from different datasets
2
+ time_info: &time_info
3
+ num_historical_steps: 11
4
+ num_future_steps: 80
5
+ use_intention: True
6
+ token_size: 2048
7
+ predict_motion: False
8
+ predict_state: False
9
+ predict_map: True
10
+ predict_occ: False
11
+ state_token:
12
+ invalid: 0
13
+ valid: 1
14
+ enter: 2
15
+ exit: 3
16
+ pl2seed_radius: 75.
17
+ grid_range: 150. # 2 times of pl2seed_radius
18
+ grid_interval: 3.
19
+ angle_interval: 3.
20
+ seed_size: 1
21
+ buffer_size: 32
22
+
23
+ Dataset:
24
+ root:
25
+ train_batch_size: 1
26
+ val_batch_size: 1
27
+ test_batch_size: 1
28
+ shuffle: True
29
+ num_workers: 1
30
+ pin_memory: True
31
+ persistent_workers: True
32
+ train_raw_dir: ["data/waymo_processed/training"]
33
+ val_raw_dir: ["data/waymo_processed/validation"]
34
+ test_raw_dir: ["data/waymo_processed/validation"]
35
+ transform: WaymoTargetBuilder
36
+ train_processed_dir:
37
+ val_processed_dir:
38
+ test_processed_dir:
39
+ dataset: "scalable"
40
+ <<: *time_info
41
+
42
+ Trainer:
43
+ strategy: ddp_find_unused_parameters_false
44
+ accelerator: "gpu"
45
+ devices: 1
46
+ max_epochs: 32
47
+ overfit_epochs: 6000
48
+ save_ckpt_path:
49
+ num_nodes: 1
50
+ mode:
51
+ ckpt_path:
52
+ precision: 32
53
+ accumulate_grad_batches: 1
54
+
55
+ Model:
56
+ mode: "train"
57
+ predictor: "smart"
58
+ decoder_type: "agent_decoder"
59
+ dataset: "waymo"
60
+ input_dim: 2
61
+ hidden_dim: 128
62
+ output_dim: 2
63
+ output_head: False
64
+ num_heads: 8
65
+ <<: *time_info
66
+ head_dim: 16
67
+ dropout: 0.1
68
+ num_freq_bands: 64
69
+ lr: 0.0005
70
+ warmup_steps: 0
71
+ total_steps: 32
72
+ predict_map_token: False
73
+ decoder:
74
+ <<: *time_info
75
+ num_map_layers: 3
76
+ num_agent_layers: 6
77
+ a2a_radius: 60
78
+ pl2pl_radius: 10
79
+ pl2a_radius: 30
80
+ a2sa_radius: 10
81
+ pl2sa_radius: 10
82
+ time_span: 60
83
+ loss_weight:
84
+ token_cls_loss: 1
85
+ map_token_loss: 1
86
+ state_cls_loss: 10
87
+ type_cls_loss: 5
88
+ pos_cls_loss: 1
89
+ head_cls_loss: 1
90
+ shape_reg_loss: .2
91
+ state_weight: [0.1, 0.1, 0.8] # invalid, valid, exit
92
+ seed_state_weight: [0.1, 0.9] # invalid, enter
93
+ seed_type_weight: [0.8, 0.1, 0.1]
94
+ agent_occ_pos_weight: 100
95
+ pt_occ_pos_weight: 5
96
+ agent_occ_loss: 10
97
+ pt_occ_loss: 10
backups/configs/smart.yaml ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Config format schema number; this YAML supports validation cases sourced from different datasets
2
+ time_info: &time_info
3
+ num_historical_steps: 11
4
+ num_future_steps: 80
5
+ use_intention: True
6
+ token_size: 2048
7
+ disable_invalid: True
8
+ use_special_motion_token: False
9
+ use_state_token: False
10
+ only_state: False
11
+
12
+ Dataset:
13
+ root:
14
+ train_batch_size: 1
15
+ val_batch_size: 1
16
+ test_batch_size: 1
17
+ shuffle: True
18
+ num_workers: 1
19
+ pin_memory: True
20
+ persistent_workers: True
21
+ train_raw_dir: ["data/waymo_processed/training"]
22
+ val_raw_dir: ["data/waymo_processed/validation"]
23
+ test_raw_dir: ["data/waymo_processed/validation"]
24
+ transform: WaymoTargetBuilder
25
+ train_processed_dir:
26
+ val_processed_dir:
27
+ test_processed_dir:
28
+ dataset: "scalable"
29
+ <<: *time_info
30
+
31
+ Trainer:
32
+ strategy: ddp_find_unused_parameters_false
33
+ accelerator: "gpu"
34
+ devices: 1
35
+ max_epochs: 32
36
+ overfit_epochs: 5000
37
+ save_ckpt_path:
38
+ num_nodes: 1
39
+ mode:
40
+ ckpt_path:
41
+ precision: 32
42
+ accumulate_grad_batches: 1
43
+
44
+ Model:
45
+ mode: "train"
46
+ predictor: "smart"
47
+ dataset: "waymo"
48
+ input_dim: 2
49
+ hidden_dim: 128
50
+ output_dim: 2
51
+ output_head: False
52
+ num_heads: 8
53
+ <<: *time_info
54
+ head_dim: 16
55
+ dropout: 0.1
56
+ num_freq_bands: 64
57
+ lr: 0.0005
58
+ warmup_steps: 0
59
+ total_steps: 32
60
+ decoder:
61
+ <<: *time_info
62
+ num_map_layers: 3
63
+ num_agent_layers: 6
64
+ a2a_radius: 60
65
+ pl2pl_radius: 10
66
+ pl2a_radius: 30
67
+ time_span: 30
68
+ loss_weight:
69
+ token_cls_loss: 1
70
+ state_cls_loss: 5
backups/data_preprocess.py ADDED
@@ -0,0 +1,916 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Not a contribution
2
+ # Changes made by NVIDIA CORPORATION & AFFILIATES enabling <CAT-K> or otherwise documented as
3
+ # NVIDIA-proprietary are not a contribution and subject to the following terms and conditions:
4
+ # SPDX-FileCopyrightText: Copyright (c) <year> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
5
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
6
+ #
7
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
8
+ # property and proprietary rights in and to this material, related
9
+ # documentation and any modifications thereto. Any use, reproduction,
10
+ # disclosure or distribution of this material and related documentation
11
+ # without an express license agreement from NVIDIA CORPORATION or
12
+ # its affiliates is strictly prohibited.
13
+
14
+ import signal
15
+ import multiprocessing
16
+ import os
17
+ import numpy as np
18
+ import pandas as pd
19
+ import tensorflow as tf
20
+ import torch
21
+ import pickle
22
+ import easydict
23
+ from functools import partial
24
+ from scipy.interpolate import interp1d
25
+ from argparse import ArgumentParser
26
+ from tqdm import tqdm
27
+ from typing import Any, Dict, List, Optional
28
+ from waymo_open_dataset.protos import scenario_pb2
29
+
30
+
31
+ MIN_VALID_STEPS = 15
32
+
33
+
34
+ _polygon_types = ['VEHICLE', 'BIKE', 'BUS', 'PEDESTRIAN']
35
+ _polygon_light_type = ['LANE_STATE_STOP', 'LANE_STATE_GO', 'LANE_STATE_CAUTION', 'LANE_STATE_UNKNOWN']
36
+ _point_types = ['DASH_SOLID_YELLOW', 'DASH_SOLID_WHITE', 'DASHED_WHITE', 'DASHED_YELLOW',
37
+ 'DOUBLE_SOLID_YELLOW', 'DOUBLE_SOLID_WHITE', 'DOUBLE_DASH_YELLOW', 'DOUBLE_DASH_WHITE',
38
+ 'SOLID_YELLOW', 'SOLID_WHITE', 'SOLID_DASH_WHITE', 'SOLID_DASH_YELLOW', 'EDGE',
39
+ 'NONE', 'UNKNOWN', 'CROSSWALK', 'CENTERLINE']
40
+ _polygon_to_polygon_types = ['NONE', 'PRED', 'SUCC', 'LEFT', 'RIGHT']
41
+
42
+
43
+ Lane_type_hash = {
44
+ 4: "BIKE",
45
+ 3: "VEHICLE",
46
+ 2: "VEHICLE",
47
+ 1: "BUS"
48
+ }
49
+
50
+ boundary_type_hash = {
51
+ 5: "UNKNOWN",
52
+ 6: "DASHED_WHITE",
53
+ 7: "SOLID_WHITE",
54
+ 8: "DOUBLE_DASH_WHITE",
55
+ 9: "DASHED_YELLOW",
56
+ 10: "DOUBLE_DASH_YELLOW",
57
+ 11: "SOLID_YELLOW",
58
+ 12: "DOUBLE_SOLID_YELLOW",
59
+ 13: "DASH_SOLID_YELLOW",
60
+ 14: "UNKNOWN",
61
+ 15: "EDGE",
62
+ 16: "EDGE"
63
+ }
64
+
65
+
66
def safe_list_index(ls: List[Any], elem: Any) -> Optional[int]:
    """Return the index of *elem* in *ls*, or None when it is absent."""
    # `in` and list.index use the same identity-or-equality comparison,
    # so this LBYL form matches the EAFP try/except variant exactly.
    if elem in ls:
        return ls.index(elem)
    return None
71
+
72
+
73
+ # def get_agent_features(df: pd.DataFrame, av_id, num_historical_steps=11, dim=3, num_steps=91) -> Dict[str, Any]:
74
+ # if args.disable_invalid: # filter out agents that are unseen during the historical time steps
75
+ # historical_df = df[df['timestep'] == num_historical_steps-1] # extract the timestep==10 (current)
76
+ # agent_ids = list(historical_df['track_id'].unique()) # these agents are seen at timestep==10 (current)
77
+ # df = df[df['track_id'].isin(agent_ids)] # remove other agents
78
+ # else:
79
+ # agent_ids = list(df['track_id'].unique())
80
+
81
+ # num_agents = len(agent_ids)
82
+ # # initialization
83
+ # valid_mask = torch.zeros(num_agents, num_steps, dtype=torch.bool)
84
+ # current_valid_mask = torch.zeros(num_agents, dtype=torch.bool)
85
+ # predict_mask = torch.zeros(num_agents, num_steps, dtype=torch.bool)
86
+ # agent_id: List[Optional[str]] = [None] * num_agents
87
+ # agent_type = torch.zeros(num_agents, dtype=torch.uint8)
88
+ # agent_category = torch.zeros(num_agents, dtype=torch.uint8)
89
+ # position = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)
90
+ # heading = torch.zeros(num_agents, num_steps, dtype=torch.float)
91
+ # velocity = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)
92
+ # shape = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)
93
+
94
+ # for track_id, track_df in df.groupby('track_id'):
95
+ # agent_idx = agent_ids.index(track_id)
96
+ # all_agent_steps = track_df['timestep'].values
97
+ # valid_agent_steps = all_agent_steps[track_df['validity'].astype(np.bool_)].astype(np.int32)
98
+ # valid_mask[agent_idx, valid_agent_steps] = True
99
+ # current_valid_mask[agent_idx] = valid_mask[agent_idx, num_historical_steps - 1] # current timestep 10
100
+ # if args.disable_invalid:
101
+ # predict_mask[agent_idx, valid_agent_steps] = True
102
+ # else:
103
+ # predict_mask[agent_idx] = True
104
+ # predict_mask[agent_idx, :num_historical_steps] = False
105
+ # if not current_valid_mask[agent_idx]:
106
+ # predict_mask[agent_idx, num_historical_steps:] = False
107
+
108
+ # # TODO: why using vector_repr?
109
+ # if vector_repr: # a time step t is valid only when both t and t-1 are valid
110
+ # valid_mask[agent_idx, 1 : num_historical_steps] = (
111
+ # valid_mask[agent_idx, : num_historical_steps - 1] &
112
+ # valid_mask[agent_idx, 1 : num_historical_steps])
113
+ # valid_mask[agent_idx, 0] = False
114
+
115
+ # agent_id[agent_idx] = track_id
116
+ # agent_type[agent_idx] = _agent_types.index(track_df['object_type'].values[0])
117
+ # agent_category[agent_idx] = track_df['object_category'].values[0]
118
+ # position[agent_idx, valid_agent_steps, :3] = torch.from_numpy(np.stack([track_df['position_x'].values[valid_agent_steps],
119
+ # track_df['position_y'].values[valid_agent_steps],
120
+ # track_df['position_z'].values[valid_agent_steps]],
121
+ # axis=-1)).float()
122
+ # heading[agent_idx, valid_agent_steps] = torch.from_numpy(track_df['heading'].values[valid_agent_steps]).float()
123
+ # velocity[agent_idx, valid_agent_steps, :2] = torch.from_numpy(np.stack([track_df['velocity_x'].values[valid_agent_steps],
124
+ # track_df['velocity_y'].values[valid_agent_steps]],
125
+ # axis=-1)).float()
126
+ # shape[agent_idx, valid_agent_steps, :3] = torch.from_numpy(np.stack([track_df['length'].values[valid_agent_steps],
127
+ # track_df['width'].values[valid_agent_steps],
128
+ # track_df["height"].values[valid_agent_steps]],
129
+ # axis=-1)).float()
130
+ # av_idx = agent_id.index(av_id)
131
+ # if split == 'test':
132
+ # predict_mask[current_valid_mask
133
+ # | (agent_category == 2)
134
+ # | (agent_category == 3), num_historical_steps:] = True
135
+
136
+ # return {
137
+ # 'num_nodes': num_agents,
138
+ # 'av_index': av_idx,
139
+ # 'valid_mask': valid_mask,
140
+ # 'predict_mask': predict_mask,
141
+ # 'id': agent_id,
142
+ # 'type': agent_type,
143
+ # 'category': agent_category,
144
+ # 'position': position,
145
+ # 'heading': heading,
146
+ # 'velocity': velocity,
147
+ # 'shape': shape
148
+ # }
149
+
150
+
151
+ def get_agent_features(track_infos: Dict[str, np.ndarray], av_id: int, num_historical_steps: int, num_steps: int) -> Dict[str, Any]:
152
+
153
+ agent_idx_to_add = []
154
+ for i in range(len(track_infos['object_id'])):
155
+ is_visible = track_infos['valid'][i, num_historical_steps - 1]
156
+ valid_steps = np.where(track_infos['valid'][i])[0]
157
+ valid_start, valid_end = valid_steps[0], valid_steps[-1]
158
+ is_valid = (valid_end - valid_start + 1) >= MIN_VALID_STEPS
159
+
160
+ if (is_visible or not args.disable_invalid) and is_valid:
161
+ agent_idx_to_add.append(i)
162
+
163
+ num_agents = len(agent_idx_to_add)
164
+ out_dict = {
165
+ 'num_nodes': num_agents,
166
+ 'valid_mask': torch.zeros(num_agents, num_steps, dtype=torch.bool),
167
+ 'role': torch.zeros(num_agents, 3, dtype=torch.bool),
168
+ 'id': torch.zeros(num_agents, dtype=torch.int64) - 1,
169
+ 'type': torch.zeros(num_agents, dtype=torch.uint8),
170
+ 'category': torch.zeros(num_agents, dtype=torch.uint8),
171
+ 'position': torch.zeros(num_agents, num_steps, 3, dtype=torch.float),
172
+ 'heading': torch.zeros(num_agents, num_steps, dtype=torch.float),
173
+ 'velocity': torch.zeros(num_agents, num_steps, 2, dtype=torch.float),
174
+ 'shape': torch.zeros(num_agents, num_steps, 3, dtype=torch.float),
175
+ }
176
+
177
+ for i, idx in enumerate(agent_idx_to_add):
178
+
179
+ out_dict['role'][i] = torch.from_numpy(track_infos['role'][idx])
180
+ out_dict['id'][i] = track_infos['object_id'][idx]
181
+ out_dict['type'][i] = track_infos['object_type'][idx]
182
+ out_dict['category'][i] = idx in track_infos['tracks_to_predict']
183
+
184
+ valid = track_infos["valid"][idx] # [n_step]
185
+ states = track_infos["states"][idx]
186
+
187
+ object_shape = states[:, 3:6] # [n_step, 3], length, width, height
188
+ object_shape = object_shape[valid].mean(axis=0) # [3]
189
+ out_dict["shape"][i] = torch.from_numpy(object_shape)
190
+
191
+ valid_steps = np.where(valid)[0]
192
+ position = states[:, :3] # [n_step, dim], x, y, z
193
+ velocity = states[:, 7:9] # [n_step, 2], vx, vy
194
+ heading = states[:, 6] # [n_step], heading
195
+
196
+ # valid.sum() should > 1:
197
+ t_start, t_end = valid_steps[0], valid_steps[-1]
198
+ f_pos = interp1d(valid_steps, position[valid], axis=0)
199
+ f_vel = interp1d(valid_steps, velocity[valid], axis=0)
200
+ f_yaw = interp1d(valid_steps, np.unwrap(heading[valid], axis=0), axis=0)
201
+ t_in = np.arange(t_start, t_end + 1)
202
+ out_dict["valid_mask"][i, t_start : t_end + 1] = True
203
+ out_dict["position"][i, t_start : t_end + 1] = torch.from_numpy(f_pos(t_in))
204
+ out_dict["velocity"][i, t_start : t_end + 1] = torch.from_numpy(f_vel(t_in))
205
+ out_dict["heading"][i, t_start : t_end + 1] = torch.from_numpy(f_yaw(t_in))
206
+
207
+ out_dict['av_idx'] = out_dict['id'].tolist().index(av_id)
208
+
209
+ return out_dict
210
+
211
+
212
+ def get_map_features(map_infos, tf_current_light, dim=3):
213
+ lane_segments = map_infos['lane']
214
+ all_polylines = map_infos["all_polylines"]
215
+ crosswalks = map_infos['crosswalk']
216
+ road_edges = map_infos['road_edge']
217
+ road_lines = map_infos['road_line']
218
+ lane_segment_ids = [info["id"] for info in lane_segments]
219
+ cross_walk_ids = [info["id"] for info in crosswalks]
220
+ road_edge_ids = [info["id"] for info in road_edges]
221
+ road_line_ids = [info["id"] for info in road_lines]
222
+ polygon_ids = lane_segment_ids + road_edge_ids + road_line_ids + cross_walk_ids
223
+ num_polygons = len(lane_segment_ids) + len(road_edge_ids) + len(road_line_ids) + len(cross_walk_ids)
224
+
225
+ # initialization
226
+ polygon_type = torch.zeros(num_polygons, dtype=torch.uint8)
227
+ polygon_light_type = torch.ones(num_polygons, dtype=torch.uint8) * 3
228
+
229
+ # list of (num_of_segments,), each element has shape of (num_of_points_of_current_segment - 1, dim)
230
+ point_position: List[Optional[torch.Tensor]] = [None] * num_polygons
231
+ point_orientation: List[Optional[torch.Tensor]] = [None] * num_polygons
232
+ point_magnitude: List[Optional[torch.Tensor]] = [None] * num_polygons
233
+ point_height: List[Optional[torch.Tensor]] = [None] * num_polygons
234
+ point_type: List[Optional[torch.Tensor]] = [None] * num_polygons
235
+
236
+ for lane_segment in lane_segments:
237
+ lane_segment = easydict.EasyDict(lane_segment)
238
+ lane_segment_idx = polygon_ids.index(lane_segment.id)
239
+ polyline_index = lane_segment.polyline_index # (start index of point in current scenario, end index of point in current scenario)
240
+ centerline = all_polylines[polyline_index[0] : polyline_index[1], :] # (num_of_points_of_current_segment, 5)
241
+ centerline = torch.from_numpy(centerline).float()
242
+ polygon_type[lane_segment_idx] = _polygon_types.index(Lane_type_hash[lane_segment.type])
243
+
244
+ res = tf_current_light[tf_current_light["lane_id"] == str(lane_segment.id)]
245
+ if len(res) != 0:
246
+ polygon_light_type[lane_segment_idx] = _polygon_light_type.index(res["state"].item())
247
+
248
+ point_position[lane_segment_idx] = torch.cat([centerline[:-1, :dim]], dim=0) # (num_of_points_of_current_segment - 1, 3)
249
+ center_vectors = centerline[1:] - centerline[:-1] # (num_of_points_of_current_segment - 1, 5)
250
+ point_orientation[lane_segment_idx] = torch.cat([torch.atan2(center_vectors[:, 1], center_vectors[:, 0])], dim=0) # (num_of_points_of_current_segment - 1,)
251
+ point_magnitude[lane_segment_idx] = torch.norm(torch.cat([center_vectors[:, :2]], dim=0), p=2, dim=-1) # (num_of_points_of_current_segment - 1,)
252
+ point_height[lane_segment_idx] = torch.cat([center_vectors[:, 2]], dim=0) # (num_of_points_of_current_segment - 1,)
253
+ center_type = _point_types.index('CENTERLINE')
254
+ point_type[lane_segment_idx] = torch.cat(
255
+ [torch.full((len(center_vectors),), center_type, dtype=torch.uint8)], dim=0)
256
+
257
+ for lane_segment in road_edges:
258
+ lane_segment = easydict.EasyDict(lane_segment)
259
+ lane_segment_idx = polygon_ids.index(lane_segment.id)
260
+ polyline_index = lane_segment.polyline_index
261
+ centerline = all_polylines[polyline_index[0] : polyline_index[1], :]
262
+ centerline = torch.from_numpy(centerline).float()
263
+ polygon_type[lane_segment_idx] = _polygon_types.index("VEHICLE")
264
+
265
+ point_position[lane_segment_idx] = torch.cat([centerline[:-1, :dim]], dim=0)
266
+ center_vectors = centerline[1:] - centerline[:-1]
267
+ point_orientation[lane_segment_idx] = torch.cat([torch.atan2(center_vectors[:, 1], center_vectors[:, 0])], dim=0)
268
+ point_magnitude[lane_segment_idx] = torch.norm(torch.cat([center_vectors[:, :2]], dim=0), p=2, dim=-1)
269
+ point_height[lane_segment_idx] = torch.cat([center_vectors[:, 2]], dim=0)
270
+ center_type = _point_types.index('EDGE')
271
+ point_type[lane_segment_idx] = torch.cat(
272
+ [torch.full((len(center_vectors),), center_type, dtype=torch.uint8)], dim=0)
273
+
274
+ for lane_segment in road_lines:
275
+ lane_segment = easydict.EasyDict(lane_segment)
276
+ lane_segment_idx = polygon_ids.index(lane_segment.id)
277
+ polyline_index = lane_segment.polyline_index
278
+ centerline = all_polylines[polyline_index[0] : polyline_index[1], :]
279
+ centerline = torch.from_numpy(centerline).float()
280
+
281
+ polygon_type[lane_segment_idx] = _polygon_types.index("VEHICLE")
282
+
283
+ point_position[lane_segment_idx] = torch.cat([centerline[:-1, :dim]], dim=0)
284
+ center_vectors = centerline[1:] - centerline[:-1]
285
+ point_orientation[lane_segment_idx] = torch.cat([torch.atan2(center_vectors[:, 1], center_vectors[:, 0])], dim=0)
286
+ point_magnitude[lane_segment_idx] = torch.norm(torch.cat([center_vectors[:, :2]], dim=0), p=2, dim=-1)
287
+ point_height[lane_segment_idx] = torch.cat([center_vectors[:, 2]], dim=0)
288
+ center_type = _point_types.index(boundary_type_hash[lane_segment.type])
289
+ point_type[lane_segment_idx] = torch.cat(
290
+ [torch.full((len(center_vectors),), center_type, dtype=torch.uint8)], dim=0)
291
+
292
+ for crosswalk in crosswalks:
293
+ crosswalk = easydict.EasyDict(crosswalk)
294
+ lane_segment_idx = polygon_ids.index(crosswalk.id)
295
+ polyline_index = crosswalk.polyline_index
296
+ centerline = all_polylines[polyline_index[0] : polyline_index[1], :]
297
+ centerline = torch.from_numpy(centerline).float()
298
+
299
+ polygon_type[lane_segment_idx] = _polygon_types.index("PEDESTRIAN")
300
+
301
+ point_position[lane_segment_idx] = torch.cat([centerline[:-1, :dim]], dim=0)
302
+ center_vectors = centerline[1:] - centerline[:-1]
303
+ point_orientation[lane_segment_idx] = torch.cat([torch.atan2(center_vectors[:, 1], center_vectors[:, 0])], dim=0)
304
+ point_magnitude[lane_segment_idx] = torch.norm(torch.cat([center_vectors[:, :2]], dim=0), p=2, dim=-1)
305
+ point_height[lane_segment_idx] = torch.cat([center_vectors[:, 2]], dim=0)
306
+ center_type = _point_types.index("CROSSWALK")
307
+ point_type[lane_segment_idx] = torch.cat(
308
+ [torch.full((len(center_vectors),), center_type, dtype=torch.uint8)], dim=0)
309
+
310
+ # (num_of_segments,), each element represents the number of points of the segment
311
+ num_points = torch.tensor([point.size(0) for point in point_position], dtype=torch.long)
312
+ # (2, total_num_of_points_of_all_segments), store the point index of segment and its corresponding segment index
313
+ # e.g. a scenario has 203 segments, and totally 14039 points:
314
+ # tensor([[ 0, 1, 2, ..., 14927, 14928, 14929],
315
+ # [ 0, 0, 0, ..., 202, 202, 202]]) => polygon_ids.index(lane_segment.id)
316
+ point_to_polygon_edge_index = torch.stack(
317
+ [torch.arange(num_points.sum(), dtype=torch.long),
318
+ torch.arange(num_polygons, dtype=torch.long).repeat_interleave(num_points)], dim=0)
319
+ # list of (num_of_lane_segments,)
320
+ polygon_to_polygon_edge_index = []
321
+ # list of (num_of_lane_segments,)
322
+ polygon_to_polygon_type = []
323
+ for lane_segment in lane_segments:
324
+ lane_segment = easydict.EasyDict(lane_segment)
325
+ lane_segment_idx = polygon_ids.index(lane_segment.id)
326
+ pred_inds = []
327
+ for pred in lane_segment.entry_lanes:
328
+ pred_idx = safe_list_index(polygon_ids, pred)
329
+ if pred_idx is not None:
330
+ pred_inds.append(pred_idx)
331
+ if len(pred_inds) != 0:
332
+ polygon_to_polygon_edge_index.append(
333
+ torch.stack([torch.tensor(pred_inds, dtype=torch.long),
334
+ torch.full((len(pred_inds),), lane_segment_idx, dtype=torch.long)], dim=0))
335
+ polygon_to_polygon_type.append(
336
+ torch.full((len(pred_inds),), _polygon_to_polygon_types.index('PRED'), dtype=torch.uint8))
337
+ succ_inds = []
338
+ for succ in lane_segment.exit_lanes:
339
+ succ_idx = safe_list_index(polygon_ids, succ)
340
+ if succ_idx is not None:
341
+ succ_inds.append(succ_idx)
342
+ if len(succ_inds) != 0:
343
+ polygon_to_polygon_edge_index.append(
344
+ torch.stack([torch.tensor(succ_inds, dtype=torch.long),
345
+ torch.full((len(succ_inds),), lane_segment_idx, dtype=torch.long)], dim=0))
346
+ polygon_to_polygon_type.append(
347
+ torch.full((len(succ_inds),), _polygon_to_polygon_types.index('SUCC'), dtype=torch.uint8))
348
+ if len(lane_segment.left_neighbors) != 0:
349
+ left_neighbor_ids = lane_segment.left_neighbors
350
+ for left_neighbor_id in left_neighbor_ids:
351
+ left_idx = safe_list_index(polygon_ids, left_neighbor_id)
352
+ if left_idx is not None:
353
+ polygon_to_polygon_edge_index.append(
354
+ torch.tensor([[left_idx], [lane_segment_idx]], dtype=torch.long))
355
+ polygon_to_polygon_type.append(
356
+ torch.tensor([_polygon_to_polygon_types.index('LEFT')], dtype=torch.uint8))
357
+ if len(lane_segment.right_neighbors) != 0:
358
+ right_neighbor_ids = lane_segment.right_neighbors
359
+ for right_neighbor_id in right_neighbor_ids:
360
+ right_idx = safe_list_index(polygon_ids, right_neighbor_id)
361
+ if right_idx is not None:
362
+ polygon_to_polygon_edge_index.append(
363
+ torch.tensor([[right_idx], [lane_segment_idx]], dtype=torch.long))
364
+ polygon_to_polygon_type.append(
365
+ torch.tensor([_polygon_to_polygon_types.index('RIGHT')], dtype=torch.uint8))
366
+ if len(polygon_to_polygon_edge_index) != 0:
367
+ polygon_to_polygon_edge_index = torch.cat(polygon_to_polygon_edge_index, dim=1)
368
+ polygon_to_polygon_type = torch.cat(polygon_to_polygon_type, dim=0)
369
+ else:
370
+ polygon_to_polygon_edge_index = torch.tensor([[], []], dtype=torch.long)
371
+ polygon_to_polygon_type = torch.tensor([], dtype=torch.uint8)
372
+
373
+ map_data = {
374
+ 'map_polygon': {},
375
+ 'map_point': {},
376
+ ('map_point', 'to', 'map_polygon'): {},
377
+ ('map_polygon', 'to', 'map_polygon'): {},
378
+ }
379
+ map_data['map_polygon']['num_nodes'] = num_polygons # int, number of map segments in the scenario
380
+ map_data['map_polygon']['type'] = polygon_type # (num_polygons,) type of each polygon
381
+ map_data['map_polygon']['light_type'] = polygon_light_type # (num_polygons,) light type of each polygon, 3 means unknown
382
+ if len(num_points) == 0:
383
+ map_data['map_point']['num_nodes'] = 0
384
+ map_data['map_point']['position'] = torch.tensor([], dtype=torch.float)
385
+ map_data['map_point']['orientation'] = torch.tensor([], dtype=torch.float)
386
+ map_data['map_point']['magnitude'] = torch.tensor([], dtype=torch.float)
387
+ if dim == 3:
388
+ map_data['map_point']['height'] = torch.tensor([], dtype=torch.float)
389
+ map_data['map_point']['type'] = torch.tensor([], dtype=torch.uint8)
390
+ map_data['map_point']['side'] = torch.tensor([], dtype=torch.uint8)
391
+ else:
392
+ map_data['map_point']['num_nodes'] = num_points.sum().item() # int, number of total points of all segments in the scenario
393
+ map_data['map_point']['position'] = torch.cat(point_position, dim=0) # (num_of_total_points_of_all_segments, 3)
394
+ map_data['map_point']['orientation'] = torch.cat(point_orientation, dim=0) # (num_of_total_points_of_all_segments,)
395
+ map_data['map_point']['magnitude'] = torch.cat(point_magnitude, dim=0) # (num_of_total_points_of_all_segments,)
396
+ if dim == 3:
397
+ map_data['map_point']['height'] = torch.cat(point_height, dim=0) # (num_of_total_points_of_all_segments,)
398
+ map_data['map_point']['type'] = torch.cat(point_type, dim=0) # (num_of_total_points_of_all_segments,) type of point => `_point_types`
399
+ map_data['map_point', 'to', 'map_polygon']['edge_index'] = point_to_polygon_edge_index # (2, num_of_total_points_of_all_segments)
400
+ map_data['map_polygon', 'to', 'map_polygon']['edge_index'] = polygon_to_polygon_edge_index
401
+ map_data['map_polygon', 'to', 'map_polygon']['type'] = polygon_to_polygon_type
402
+
403
+ if int(os.getenv('DEBUG_MAP', 1)):
404
+ import matplotlib.pyplot as plt
405
+ plt.axis('equal')
406
+ plt.scatter(map_data['map_point']['position'][:, 0],
407
+ map_data['map_point']['position'][:, 1], s=0.2, c='black', edgecolors='none')
408
+ plt.savefig("debug.png", dpi=600)
409
+
410
+ return map_data
411
+
412
+
413
+ # def process_agent(track_info, tracks_to_predict, scenario_id, start_timestamp, end_timestamp):
414
+
415
+ # agents_array = track_info["states"].transpose(1, 0, 2) # (num_timesteps, num_agents, 10) e.g. (91, 15, 10)
416
+ # object_id = np.array(track_info["object_id"]) # (num_agents,) global id of each agent
417
+ # object_type = track_info["object_type"] # (num_agents,) type of each agent, e.g. 'TYPE_VEHICLE'
418
+ # id_hash = {object_id[o_idx]: object_type[o_idx] for o_idx in range(len(object_id))}
419
+
420
+ # def type_hash(x):
421
+ # tp = id_hash[x]
422
+ # type_re_hash = {
423
+ # "TYPE_VEHICLE": "vehicle",
424
+ # "TYPE_PEDESTRIAN": "pedestrian",
425
+ # "TYPE_CYCLIST": "cyclist",
426
+ # "TYPE_OTHER": "background",
427
+ # "TYPE_UNSET": "background"
428
+ # }
429
+ # return type_re_hash[tp]
430
+
431
+ # columns = ['observed', 'track_id', 'object_type', 'object_category', 'timestep',
432
+ # 'position_x', 'position_y', 'position_z', 'length', 'width', 'height', 'heading', 'velocity_x', 'velocity_y',
433
+ # 'scenario_id', 'start_timestamp', 'end_timestamp', 'num_timestamps',
434
+ # 'focal_track_id', 'city', 'validity']
435
+
436
+ # # (num_timesteps, num_agents, 10) e.g. (91, 15, 10)
437
+ # new_columns = np.ones((agents_array.shape[0], agents_array.shape[1], 11))
438
+ # new_columns[:11, :, 0] = True # observed, 10 timesteps
439
+ # new_columns[11:, :, 0] = False # not observed (current + future)
440
+ # for index in range(new_columns.shape[0]):
441
+ # new_columns[index, :, 4] = int(index) # timestep (0 ~ 90)
442
+ # new_columns[..., 1] = object_id
443
+ # new_columns[..., 2] = object_id
444
+ # new_columns[:, tracks_to_predict['track_index'], 3] = 3
445
+ # new_columns[..., 5] = 11
446
+ # new_columns[..., 6] = int(start_timestamp) # 0
447
+ # new_columns[..., 7] = int(end_timestamp) # 91
448
+ # new_columns[..., 8] = int(91) # 91
449
+ # new_columns[..., 9] = object_id
450
+ # new_columns[..., 10] = 10086
451
+ # new_columns = new_columns
452
+ # new_agents_array = np.concatenate([new_columns, agents_array], axis=-1) # (num_timesteps, num_agents, 21) e.g. (91, 15, 21)
453
+ # # filter out the invalid timestep of agents, reshape to (num_valid_of_timesteps_of_all_agents, 21) e.g. (91, 15, 21) -> (1137, 21)
454
+ # if args.disable_invalid:
455
+ # new_agents_array = new_agents_array[new_agents_array[..., -1] == 1.0].reshape(-1, new_agents_array.shape[-1])
456
+ # else:
457
+ # agent_valid_mask = new_agents_array[..., -1] # (num_timesteps, num_agents)
458
+ # agent_mask = np.sum(agent_valid_mask, axis=0) > MIN_VALID_STEPS # NOTE: 10 is a empirical parameter
459
+ # new_agents_array = new_agents_array[:, agent_mask]
460
+ # new_agents_array = new_agents_array.reshape(-1, new_agents_array.shape[-1]) # (91, 15, 21) -> (1365, 21)
461
+ # new_agents_array = new_agents_array[..., [0, 1, 2, 3, 4, 11, 12, 13, 14, 15, 16, 17, 18, 19, 5, 6, 7, 8, 9, 10, 20]]
462
+ # new_agents_array = pd.DataFrame(data=new_agents_array, columns=columns)
463
+ # new_agents_array["object_type"] = new_agents_array["object_type"].apply(func=type_hash)
464
+ # new_agents_array["start_timestamp"] = new_agents_array["start_timestamp"].astype(int)
465
+ # new_agents_array["end_timestamp"] = new_agents_array["end_timestamp"].astype(int)
466
+ # new_agents_array["num_timestamps"] = new_agents_array["num_timestamps"].astype(int)
467
+ # new_agents_array["scenario_id"] = scenario_id
468
+
469
+ # return new_agents_array
470
+
471
+
472
def process_dynamic_map(dynamic_map_infos):
    """Flatten per-timestep traffic-light observations into one DataFrame.

    Args:
        dynamic_map_infos: dict with keys 'lane_id' and 'state'; each value is
            a list (one entry per timestep) of arrays shaped (1, num_signals).

    Returns:
        pd.DataFrame with string columns ['lane_id', 'time_step', 'state'],
        where every *_ARROW_*/_FLASHING_* variant is collapsed onto the plain
        LANE_STATE_{STOP,GO,CAUTION,UNKNOWN} label.
    """
    lane_ids = dynamic_map_infos["lane_id"]
    per_step = []
    for step, lane_id in enumerate(lane_ids):
        step_arr = np.ones_like(lane_id) * step
        state = dynamic_map_infos["state"][step]
        # Rows: lane_id / time_step / state, one column per observed signal.
        per_step.append(np.concatenate([lane_id, step_arr, state], axis=0))

    stacked = np.concatenate(per_step, axis=1).transpose(1, 0)
    tf_lights = pd.DataFrame(data=stacked, columns=["lane_id", "time_step", "state"])
    for col in ("time_step", "lane_id", "state"):
        tf_lights[col] = tf_lights[col].astype("str")

    # Collapse arrow/flashing variants onto the four canonical states.
    # Order matters only in that each row matches exactly one token.
    for token, canonical in (
        ("STOP", "LANE_STATE_STOP"),
        ("GO", "LANE_STATE_GO"),
        ("CAUTION", "LANE_STATE_CAUTION"),
        ("UNKNOWN", "LANE_STATE_UNKNOWN"),
    ):
        tf_lights.loc[tf_lights["state"].str.contains(token), ["state"]] = canonical

    return tf_lights
498
+
499
+
500
# Global integer codes assigned to each map polyline category; the chosen
# code is stored in column 3 of every row of `all_polylines`.
# NOTE(review): codes 4, 5 and 14 are gaps between categories — presumably
# reserved to align with the proto-enum offsets (+1 lanes, +5 road lines,
# +14 road edges) used in decode_map_features_from_proto; confirm.
polyline_type = {
    # for lane
    'TYPE_UNDEFINED': -1,
    'TYPE_FREEWAY': 1,
    'TYPE_SURFACE_STREET': 2,
    'TYPE_BIKE_LANE': 3,

    # for roadline
    'TYPE_UNKNOWN': -1,
    'TYPE_BROKEN_SINGLE_WHITE': 6,
    'TYPE_SOLID_SINGLE_WHITE': 7,
    'TYPE_SOLID_DOUBLE_WHITE': 8,
    'TYPE_BROKEN_SINGLE_YELLOW': 9,
    'TYPE_BROKEN_DOUBLE_YELLOW': 10,
    'TYPE_SOLID_SINGLE_YELLOW': 11,
    'TYPE_SOLID_DOUBLE_YELLOW': 12,
    'TYPE_PASSING_DOUBLE_YELLOW': 13,

    # for roadedge
    'TYPE_ROAD_EDGE_BOUNDARY': 15,
    'TYPE_ROAD_EDGE_MEDIAN': 16,

    # for stopsign
    'TYPE_STOP_SIGN': 17,

    # for crosswalk
    'TYPE_CROSSWALK': 18,

    # for speed bump
    'TYPE_SPEED_BUMP': 19
}
531
+
532
# Mapping from the proto object-type enum value to its name.
# Note: decode_tracks_from_proto stores `object_type - 1`, so downstream
# arrays use vehicle=0, pedestrian=1, cyclist=2, other=3.
object_type = {
    0: 'TYPE_UNSET',
    1: 'TYPE_VEHICLE',
    2: 'TYPE_PEDESTRIAN',
    3: 'TYPE_CYCLIST',
    4: 'TYPE_OTHER'
}
539
+
540
+
541
def decode_tracks_from_proto(scenario):
    """Extract per-agent trajectories and role flags from a Scenario proto.

    Args:
        scenario: a Waymo Scenario proto (needs .tracks, .sdc_track_index,
            .tracks_to_predict and .objects_of_interest).

    Returns:
        dict with:
            object_id: (n_agent,) int64 proto track ids
            object_type: (n_agent,) uint8, proto enum shifted by -1
            states: (n_agent, n_step, 9) float32
                [x, y, z, length, width, height, heading, vx, vy]
            valid: (n_agent, n_step) bool per-step validity
            role: (n_agent, 3) bool flags [is_sdc, is_interest, is_to_predict]
            tracks_to_predict: (n_pred,) int64 track indices
    """
    sdc_track_index = scenario.sdc_track_index
    predict_indices = [t.track_index for t in scenario.tracks_to_predict]
    interest_ids = list(scenario.objects_of_interest)

    object_ids, object_types, states, valids, roles = [], [], [], [], []

    for track_idx, track in enumerate(scenario.tracks):
        # One row per step; heading is normalized to [-pi, pi) by the proto,
        # velocity is in m/s.
        track_states = [
            [s.center_x, s.center_y, s.center_z,
             s.length, s.width, s.height,
             s.heading, s.velocity_x, s.velocity_y]
            for s in track.states
        ]
        track_valid = [s.valid for s in track.states]

        object_ids.append(track.id)
        object_types.append(track.object_type - 1)
        states.append(np.array(track_states, dtype=np.float32))
        valids.append(np.array(track_valid))
        roles.append([
            track_idx == sdc_track_index,   # ego_vehicle=0
            track.id in interest_ids,       # interest=1
            track_idx in predict_indices,   # predict=2
        ])

    return {
        'object_id': np.array(object_ids, dtype=np.int64),
        'object_type': np.array(object_types, dtype=np.uint8),
        'states': np.array(states, dtype=np.float32),
        'valid': np.array(valids, dtype=np.bool_),
        'role': np.array(roles, dtype=np.bool_),
        'tracks_to_predict': np.array(predict_indices, dtype=np.int64),
    }
600
+
601
+
602
+ from collections import defaultdict
603
+
604
def decode_map_features_from_proto(map_features):
    """Decode the static map features of a Waymo Scenario proto.

    Args:
        map_features: iterable of MapFeature protos (scenario.map_features).

    Returns:
        dict with per-category info lists ('lane', 'road_line', 'road_edge',
        'stop_sign', 'crosswalk', 'speed_bump'), a 'lane_dict' keyed by lane
        id, a 'lane2other_dict' mapping lane id -> related feature ids, and
        'all_polylines': (num_points, 5) float32 rows of
        [x, y, z, global_type, feature_id]. Every kept feature info carries
        'polyline_index' = (start, end), a half-open span into
        'all_polylines'. Features whose polyline has <= 1 point are dropped.
    """
    map_infos = {
        'lane': [],
        'road_line': [],
        'road_edge': [],
        'stop_sign': [],
        'crosswalk': [],
        'speed_bump': [],
        'lane_dict': {},
        'lane2other_dict': {}
    }
    polylines = []

    point_cnt = 0
    lane2other_dict = defaultdict(list)

    for cur_data in map_features:
        cur_info = {'id': cur_data.id}

        if cur_data.lane.ByteSize() > 0:
            cur_info['speed_limit_mph'] = cur_data.lane.speed_limit_mph
            # +1 shifts the proto enum so 1: freeway, 2: surface_street, 3: bike_lane.
            cur_info['type'] = cur_data.lane.type + 1
            cur_info['left_neighbors'] = [lane.feature_id for lane in cur_data.lane.left_neighbors]
            cur_info['right_neighbors'] = [lane.feature_id for lane in cur_data.lane.right_neighbors]

            cur_info['interpolating'] = cur_data.lane.interpolating
            cur_info['entry_lanes'] = list(cur_data.lane.entry_lanes)
            cur_info['exit_lanes'] = list(cur_data.lane.exit_lanes)

            # Boundary types share the road-line code space (+5 offset).
            cur_info['left_boundary_type'] = [x.boundary_type + 5 for x in cur_data.lane.left_boundaries]
            cur_info['right_boundary_type'] = [x.boundary_type + 5 for x in cur_data.lane.right_boundaries]
            cur_info['left_boundary'] = [x.boundary_feature_id for x in cur_data.lane.left_boundaries]
            cur_info['right_boundary'] = [x.boundary_feature_id for x in cur_data.lane.right_boundaries]
            cur_info['left_boundary_start_index'] = [lane.lane_start_index for lane in cur_data.lane.left_boundaries]
            cur_info['left_boundary_end_index'] = [lane.lane_end_index for lane in cur_data.lane.left_boundaries]
            cur_info['right_boundary_start_index'] = [lane.lane_start_index for lane in cur_data.lane.right_boundaries]
            cur_info['right_boundary_end_index'] = [lane.lane_end_index for lane in cur_data.lane.right_boundaries]

            lane2other_dict[cur_data.id].extend(cur_info['left_boundary'])
            lane2other_dict[cur_data.id].extend(cur_info['right_boundary'])

            global_type = cur_info['type']
            cur_polyline = np.stack(
                [np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in cur_data.lane.polyline],
                axis=0)
            if cur_polyline.shape[0] <= 1:  # degenerate polyline: drop feature
                continue
            map_infos['lane'].append(cur_info)
            map_infos['lane_dict'][cur_data.id] = cur_info

        elif cur_data.road_line.ByteSize() > 0:
            cur_info['type'] = cur_data.road_line.type + 5  # road-line code space

            global_type = cur_info['type']
            cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in
                                     cur_data.road_line.polyline], axis=0)
            if cur_polyline.shape[0] <= 1:
                continue
            map_infos['road_line'].append(cur_info)

        elif cur_data.road_edge.ByteSize() > 0:
            cur_info['type'] = cur_data.road_edge.type + 14  # road-edge code space

            global_type = cur_info['type']
            cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in
                                     cur_data.road_edge.polyline], axis=0)
            if cur_polyline.shape[0] <= 1:
                continue
            map_infos['road_edge'].append(cur_info)

        elif cur_data.stop_sign.ByteSize() > 0:
            cur_info['lane_ids'] = list(cur_data.stop_sign.lane)
            for i in cur_info['lane_ids']:
                lane2other_dict[i].append(cur_data.id)
            point = cur_data.stop_sign.position
            cur_info['position'] = np.array([point.x, point.y, point.z])

            global_type = polyline_type['TYPE_STOP_SIGN']
            cur_polyline = np.array([point.x, point.y, point.z, global_type, cur_data.id]).reshape(1, 5)
            # NOTE(review): a stop sign is a single point, so this guard always
            # fires and 'stop_sign' entries never get a polyline_index —
            # behavior kept as-is, confirm whether intended.
            if cur_polyline.shape[0] <= 1:
                continue
            map_infos['stop_sign'].append(cur_info)
        elif cur_data.crosswalk.ByteSize() > 0:
            global_type = polyline_type['TYPE_CROSSWALK']
            cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in
                                     cur_data.crosswalk.polygon], axis=0)
            if cur_polyline.shape[0] <= 1:
                continue
            map_infos['crosswalk'].append(cur_info)

        elif cur_data.speed_bump.ByteSize() > 0:
            global_type = polyline_type['TYPE_SPEED_BUMP']
            cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in
                                     cur_data.speed_bump.polygon], axis=0)
            if cur_polyline.shape[0] <= 1:
                continue
            map_infos['speed_bump'].append(cur_info)

        else:
            continue

        polylines.append(cur_polyline)
        # Half-open (start, end) span of this feature's rows in all_polylines.
        cur_info['polyline_index'] = (point_cnt, point_cnt + len(cur_polyline))
        point_cnt += len(cur_polyline)

    if polylines:
        polylines = np.concatenate(polylines, axis=0).astype(np.float32)
    else:
        # Fix: np.concatenate raises ValueError on an empty list; a scenario
        # with no usable map features now yields an empty (0, 5) array.
        polylines = np.zeros((0, 5), dtype=np.float32)
    map_infos['all_polylines'] = polylines  # (num_of_total_points_in_current_scenario, 5)
    map_infos['lane2other_dict'] = lane2other_dict
    return map_infos
719
+
720
+
721
def decode_dynamic_map_states_from_proto(dynamic_map_states):
    """Decode traffic-signal observations, one list entry per timestep.

    Args:
        dynamic_map_states: scenario.dynamic_map_states; each entry holds the
            signals observed at that timestep (.lane_states).

    Returns:
        dict with lists (length num_timesteps) under 'lane_id', 'state' and
        'stop_point'; each element is an array with a leading batch dim of 1:
        (1, n_signals), (1, n_signals) and (1, n_signals, 3) respectively.
    """
    signal_state = {
        0: 'LANE_STATE_UNKNOWN',
        # States for traffic signals with arrows.
        1: 'LANE_STATE_ARROW_STOP',
        2: 'LANE_STATE_ARROW_CAUTION',
        3: 'LANE_STATE_ARROW_GO',
        # Standard round traffic signals.
        4: 'LANE_STATE_STOP',
        5: 'LANE_STATE_CAUTION',
        6: 'LANE_STATE_GO',
        # Flashing light signals.
        7: 'LANE_STATE_FLASHING_STOP',
        8: 'LANE_STATE_FLASHING_CAUTION'
    }

    dynamic_map_infos = {
        'lane_id': [],
        'state': [],
        'stop_point': []
    }
    for frame in dynamic_map_states:  # one frame per timestep
        lane_ids, states, stop_points = [], [], []
        for sig in frame.lane_states:  # observed signals at this step
            lane_ids.append(sig.lane)
            states.append(signal_state[sig.state])
            stop_points.append([sig.stop_point.x, sig.stop_point.y, sig.stop_point.z])

        dynamic_map_infos['lane_id'].append(np.array([lane_ids]))
        dynamic_map_infos['state'].append(np.array([states]))
        dynamic_map_infos['stop_point'].append(np.array([stop_points]))

    return dynamic_map_infos
755
+
756
+
757
+ # def process_single_data(scenario):
758
+ # info = {}
759
+ # info['scenario_id'] = scenario.scenario_id
760
+ # info['timestamps_seconds'] = list(scenario.timestamps_seconds) # list of int of shape (91)
761
+ # info['current_time_index'] = scenario.current_time_index # int, 10
762
+ # info['sdc_track_index'] = scenario.sdc_track_index # int
763
+ # info['objects_of_interest'] = list(scenario.objects_of_interest) # list, could be empty list
764
+
765
+ # info['tracks_to_predict'] = {
766
+ # 'track_index': [cur_pred.track_index for cur_pred in scenario.tracks_to_predict],
767
+ # 'difficulty': [cur_pred.difficulty for cur_pred in scenario.tracks_to_predict]
768
+ # } # for training: suggestion of objects to train on, for val/test: need to be predicted
769
+
770
+ # # decode tracks data
771
+ # track_infos = decode_tracks_from_proto(scenario.tracks)
772
+ # info['tracks_to_predict']['object_type'] = [track_infos['object_type'][cur_idx] for cur_idx in
773
+ # info['tracks_to_predict']['track_index']]
774
+ # # decode map related data
775
+ # map_infos = decode_map_features_from_proto(scenario.map_features)
776
+ # dynamic_map_infos = decode_dynamic_map_states_from_proto(scenario.dynamic_map_states)
777
+
778
+ # save_infos = {
779
+ # 'track_infos': track_infos,
780
+ # 'map_infos': map_infos,
781
+ # 'dynamic_map_infos': dynamic_map_infos,
782
+ # }
783
+ # save_infos.update(info)
784
+ # return save_infos
785
+
786
+
787
def wm2argo(file, input_dir, output_dir, existing_files=None, output_dir_tfrecords_splitted=None):
    """Convert one raw Waymo TFRecord shard into per-scenario pickle files.

    Each scenario in the shard is decoded (tracks, static map, traffic
    lights), converted into the model's dict format and written to
    `<output_dir>/<scenario_id>.pkl`.

    Args:
        file: shard filename inside `input_dir`.
        input_dir: directory containing raw TFRecord shards.
        output_dir: destination directory for per-scenario pickles.
        existing_files: iterable of already-written pickle filenames; matching
            scenarios are skipped. Defaults to skipping nothing.
        output_dir_tfrecords_splitted: when set, also re-emit each scenario's
            raw record to `<dir>/<scenario_id>.tfrecords`.
    """
    # Fix: mutable default argument ([]) was shared across calls; use a
    # None sentinel instead.
    if existing_files is None:
        existing_files = []

    file_path = os.path.join(input_dir, file)
    dataset = tf.data.TFRecordDataset(file_path, compression_type='', num_parallel_reads=3)

    for cnt, tf_data in tqdm(enumerate(dataset), leave=False, desc=f'Process {file}...'):

        scenario = scenario_pb2.Scenario()
        scenario.ParseFromString(bytearray(tf_data.numpy()))
        scenario_id = scenario.scenario_id
        tqdm.write(f"idx: {cnt}, scenario_id: {scenario_id} of {file}")

        if f'{scenario_id}.pkl' not in existing_files:

            map_infos = decode_map_features_from_proto(scenario.map_features)
            track_infos = decode_tracks_from_proto(scenario)
            dynamic_map_infos = decode_dynamic_map_states_from_proto(scenario.dynamic_map_states)
            sdc_track_index = scenario.sdc_track_index  # int
            av_id = track_infos['object_id'][sdc_track_index]

            current_time_index = scenario.current_time_index  # 10 (history) + 1 (current) + 80 (future)
            tf_lights = process_dynamic_map(dynamic_map_infos)
            # Fix: 'time_step' is cast to str in process_dynamic_map, so the
            # previous comparison against the int index never matched and
            # every scenario silently lost its traffic-light state. Compare
            # as str (same convention as the 'lane_id' lookup downstream).
            tf_current_light = tf_lights.loc[tf_lights["time_step"] == str(current_time_index)]
            map_data = get_map_features(map_infos, tf_current_light)

            data = dict()
            data.update(map_data)
            data['scenario_id'] = scenario_id
            data['agent'] = get_agent_features(track_infos, av_id, num_historical_steps=current_time_index + 1, num_steps=91)

            with open(os.path.join(output_dir, f'{scenario_id}.pkl'), "wb+") as f:
                pickle.dump(data, f)

        if output_dir_tfrecords_splitted is not None:
            tf_file = os.path.join(output_dir_tfrecords_splitted, f'{scenario_id}.tfrecords')
            if not os.path.exists(tf_file):
                with tf.io.TFRecordWriter(tf_file) as file_writer:
                    file_writer.write(tf_data.numpy())
828
+
829
def batch_process9s_transformer(input_dir, output_dir, split, num_workers=2):
    """Process every shard of a dataset split with a multiprocessing pool.

    Args:
        input_dir: root directory of raw shards; `<input_dir>/<split>` is read.
        output_dir: root output directory; `<output_dir>/<split>` is written.
        split: 'training' / 'validation' / 'test'.
        num_workers: size of the worker pool.
    """
    # Workers inherit SIGINT-ignore so Ctrl-C handling stays in the parent.
    # NOTE(review): this also makes the parent ignore SIGINT, so the
    # KeyboardInterrupt handler below may never fire — confirm intent.
    signal.signal(signal.SIGINT, signal.SIG_IGN)

    output_dir_tfrecords_splitted = None
    if split == "validation":
        output_dir_tfrecords_splitted = os.path.join(output_dir, 'validation_tfrecords_splitted')
        os.makedirs(output_dir_tfrecords_splitted, exist_ok=True)

    input_dir = os.path.join(input_dir, split)
    output_dir = os.path.join(output_dir, split)
    os.makedirs(output_dir, exist_ok=True)

    packages = sorted(os.listdir(input_dir))
    existing_files = sorted(os.listdir(output_dir))
    func = partial(
        wm2argo,
        output_dir=output_dir,
        input_dir=input_dir,
        existing_files=existing_files,
        output_dir_tfrecords_splitted=output_dir_tfrecords_splitted
    )
    p = None
    try:
        with multiprocessing.Pool(num_workers, maxtasksperchild=10) as p:
            # Consume the iterator so tqdm tracks completion; results unused.
            list(tqdm(p.imap_unordered(func, packages), total=len(packages)))
    except KeyboardInterrupt:
        # Fix: previously `p.terminate()` raised NameError when the interrupt
        # arrived before the pool was created.
        if p is not None:
            p.terminate()
            p.join()
856
+
857
+
858
def generate_meta_infos(data_dir):
    """Scan processed per-scenario pickles and write `meta_infos.json`.

    For every existing split directory under `data_dir`, records each
    scenario's agent count keyed by scenario id; unreadable pickles are
    logged and skipped (best-effort by design).

    Args:
        data_dir: root of the processed dataset
            (`<data_dir>/<split>/<scenario_id>.pkl`).
    """
    import json

    meta_infos = dict()

    for split in tqdm(['training', 'validation', 'test'], leave=False):
        split_dir = os.path.join(data_dir, split)
        if not os.path.exists(split_dir):
            continue

        split_infos = dict()
        files = os.listdir(split_dir)
        for file in tqdm(files, leave=False):
            try:
                # Fix: the previous `pickle.load(open(...))` never closed the
                # file handle; `with` guarantees closure even on failure.
                with open(os.path.join(split_dir, file), 'rb') as f:
                    data = pickle.load(f)
            except Exception as e:
                tqdm.write(f'Failed to load scenario {file} due to {e}')
                continue
            scenario_infos = dict(num_agents=data['agent']['num_nodes'])
            scenario_id = data['scenario_id']
            split_infos[scenario_id] = scenario_infos

        meta_infos[split] = split_infos

    with open(os.path.join(data_dir, 'meta_infos.json'), 'w', encoding='utf-8') as f:
        json.dump(meta_infos, f, indent=4)
883
+
884
+
885
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument('--input_dir', type=str, default='data/waymo/')
    parser.add_argument('--output_dir', type=str, default='data/waymo_processed/')
    parser.add_argument('--split', type=str, default='validation')
    parser.add_argument('--no_batch', action='store_true')
    parser.add_argument('--disable_invalid', action="store_true")
    parser.add_argument('--generate_meta_infos', action="store_true")
    args = parser.parse_args()

    if args.generate_meta_infos:
        generate_meta_infos(args.output_dir)

    elif args.no_batch:
        # Sequential (single-process) conversion of one split.
        output_dir_tfrecords_splitted = None
        if args.split == "validation":
            output_dir_tfrecords_splitted = os.path.join(args.output_dir, 'validation_tfrecords_splitted')
            os.makedirs(output_dir_tfrecords_splitted, exist_ok=True)

        input_dir = os.path.join(args.input_dir, args.split)
        output_dir = os.path.join(args.output_dir, args.split)
        os.makedirs(output_dir, exist_ok=True)

        files = sorted(os.listdir(input_dir))
        os.makedirs(args.output_dir, exist_ok=True)
        for file in tqdm(files, leave=False, desc=f'Process {args.split}...'):
            # Fix: the 4th positional argument previously bound to
            # `existing_files`, so the tfrecords directory never reached its
            # parameter and scenarios could be skipped by spurious substring
            # matches against the path. Pass it by keyword.
            wm2argo(file, input_dir, output_dir,
                    output_dir_tfrecords_splitted=output_dir_tfrecords_splitted)

    else:
        # Parallel conversion of one split.
        batch_process9s_transformer(args.input_dir, args.output_dir, args.split, num_workers=96)
backups/dev/datasets/preprocess.py ADDED
@@ -0,0 +1,761 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import math
4
+ import pickle
5
+ import numpy as np
6
+ from torch import nn
7
+ from typing import Dict, Sequence
8
+ from scipy.interpolate import interp1d
9
+ from scipy.spatial.distance import euclidean
10
+ from dev.utils.func import wrap_angle
11
+
12
+
13
# Number of 10 Hz timesteps covered by one motion token.
SHIFT = 5
# Nominal bounding-box size per agent class as [length, width, height] (meters).
# NOTE(review): key 'pedstrain' looks like a typo for 'pedestrian' — kept as-is
# because other modules may look it up under this exact spelling; confirm before renaming.
AGENT_SHAPE = {
    'vehicle': [4.3, 1.8, 1.],
    'pedstrain': [0.5, 0.5, 1.],
    'cyclist': [1.9, 0.5, 1.],
}
# Agent-class vocabulary; 'seed' is presumably a special placeholder class — verify against users.
AGENT_TYPE = ['veh', 'ped', 'cyc', 'seed']
# Per-token agent-state vocabulary (order matches the state indices used below).
AGENT_STATE = ['invalid', 'valid', 'enter', 'exit']
21
+
22
+
23
@torch.no_grad()
def cal_polygon_contour(pos, head, width_length) -> torch.Tensor:  # [n_agent, n_step, n_target, 4, 2]
    """Return the four corners of an oriented box at each pose.

    Corner order along the second-to-last axis is:
    left-front, right-front, right-back, left-back.
    """
    center_x, center_y = pos[..., 0], pos[..., 1]          # [n_agent, n_step, n_target]
    half_w = 0.5 * width_length[..., 0]                    # [n_agent, 1, 1]
    half_l = 0.5 * width_length[..., 1]
    cos_h, sin_h = head.cos(), head.sin()

    # Longitudinal (along heading) and lateral (left of heading) half-extent vectors.
    lon_x, lon_y = half_l * cos_h, half_l * sin_h
    lat_x, lat_y = -half_w * sin_h, half_w * cos_h

    # Signs of (longitudinal, lateral) offsets for each corner, in the
    # canonical order: LF, RF, RB, LB.
    corner_signs = ((1, 1), (1, -1), (-1, -1), (-1, 1))
    corners = [
        torch.stack(
            (center_x + s_lon * lon_x + s_lat * lat_x,
             center_y + s_lon * lon_y + s_lat * lat_y),
            dim=-1,
        )
        for s_lon, s_lat in corner_signs
    ]
    return torch.stack(corners, dim=-2)
56
+
57
+
58
def interplating_polyline(polylines, heading, distance=0.5, split_distace=5):
    """Up-sample a polyline to a fixed spacing and cut it into short segments.

    The input polyline is first split into continuous groups wherever there is
    a large heading jump together with a positional gap (or a very large gap
    alone). Each group is then resampled at ``distance`` meters and windowed
    into segments of ``split_distace`` meters (3 points per segment after
    subsampling), with a per-point heading recomputed from finite differences.

    Args:
        polylines: array of points, shape (n, >=2); only x/y are used.
        heading: per-point heading array of length n.
        distance: resampling step in meters.
        split_distace: segment length in meters (param name kept for
            backward compatibility with existing callers).

    Returns:
        Tensor of shape (n_segment, 3, 3) with (x, y, heading) per point,
        or None if no group yields a segment.
    """
    # Calculate the cumulative distance along the path, up-sample the polyline to 0.5 meter
    dist_along_path_list = [[0]]
    polylines_list = [[polylines[0]]]
    for i in range(1, polylines.shape[0]):
        euclidean_dist = euclidean(polylines[i, :2], polylines[i - 1, :2])
        # |heading[i] - heading[i-1]| folded by pi.
        # BUGFIX: the original used heading[1] (a fixed index) instead of
        # heading[i] inside min(...), comparing every step against step 1.
        heading_diff = min(abs(max(heading[i], heading[i - 1]) - min(heading[i], heading[i - 1])),
                           abs(max(heading[i], heading[i - 1]) - min(heading[i], heading[i - 1]) + math.pi))
        # Start a new group on sharp turns with a gap, or on very large gaps.
        if heading_diff > math.pi / 4 and euclidean_dist > 3:
            dist_along_path_list.append([0])
            polylines_list.append([polylines[i]])
        elif heading_diff > math.pi / 8 and euclidean_dist > 3:
            dist_along_path_list.append([0])
            polylines_list.append([polylines[i]])
        elif heading_diff > 0.1 and euclidean_dist > 3:
            dist_along_path_list.append([0])
            polylines_list.append([polylines[i]])
        elif euclidean_dist > 10:
            dist_along_path_list.append([0])
            polylines_list.append([polylines[i]])
        else:
            dist_along_path_list[-1].append(dist_along_path_list[-1][-1] + euclidean_dist)
            polylines_list[-1].append(polylines[i])

    new_x_list = []
    new_y_list = []
    multi_polylines_list = []
    for idx in range(len(dist_along_path_list)):
        if len(dist_along_path_list[idx]) < 2:
            continue
        dist_along_path = np.array(dist_along_path_list[idx])
        polylines_cur = np.array(polylines_list[idx])
        # Create interpolation functions for x and y coordinates
        fx = interp1d(dist_along_path, polylines_cur[:, 0])
        fy = interp1d(dist_along_path, polylines_cur[:, 1])

        # Create an array of distances at which to interpolate (always
        # including the exact endpoint).
        new_dist_along_path = np.arange(0, dist_along_path[-1], distance)
        new_dist_along_path = np.concatenate([new_dist_along_path, dist_along_path[[-1]]])
        # Use the interpolation functions to generate new x and y coordinates
        new_x = fx(new_dist_along_path)
        new_y = fy(new_dist_along_path)
        new_x_list.append(new_x)
        new_y_list.append(new_y)

        # Combine the new x and y coordinates into a single array
        new_polylines = np.vstack((new_x, new_y)).T
        polyline_size = int(split_distace / distance)
        if new_polylines.shape[0] >= (polyline_size + 1):
            padding_size = (new_polylines.shape[0] - (polyline_size + 1)) % polyline_size
            final_index = (new_polylines.shape[0] - (polyline_size + 1)) // polyline_size + 1
        else:
            padding_size = new_polylines.shape[0]
            final_index = 0
        multi_polylines = None
        new_polylines = torch.from_numpy(new_polylines)
        # Per-point heading from finite differences (last point repeats).
        new_heading = torch.atan2(new_polylines[1:, 1] - new_polylines[:-1, 1],
                                  new_polylines[1:, 0] - new_polylines[:-1, 0])
        new_heading = torch.cat([new_heading, new_heading[-1:]], -1)[..., None]
        new_polylines = torch.cat([new_polylines, new_heading], -1)
        if new_polylines.shape[0] >= (polyline_size + 1):
            # Window into overlapping-endpoint segments, then subsample to 3 points.
            multi_polylines = new_polylines.unfold(dimension=0, size=polyline_size + 1, step=polyline_size)
            multi_polylines = multi_polylines.transpose(1, 2)
            multi_polylines = multi_polylines[:, ::5, :]
        if padding_size >= 3:
            # Keep the leftover tail as one extra 3-point segment.
            last_polyline = new_polylines[final_index * polyline_size:]
            last_polyline = last_polyline[torch.linspace(0, last_polyline.shape[0] - 1, steps=3).long()]
            if multi_polylines is not None:
                multi_polylines = torch.cat([multi_polylines, last_polyline.unsqueeze(0)], dim=0)
            else:
                multi_polylines = last_polyline.unsqueeze(0)
        if multi_polylines is None:
            continue
        multi_polylines_list.append(multi_polylines)
    if len(multi_polylines_list) > 0:
        multi_polylines_list = torch.cat(multi_polylines_list, dim=0)
    else:
        multi_polylines_list = None
    return multi_polylines_list
140
+
141
+
142
+ # def interplating_polyline(polylines, heading, distance=0.5, split_distance=5, device='cpu'):
143
+ # dist_along_path_list = [[0]]
144
+ # polylines_list = [[polylines[0]]]
145
+ # for i in range(1, polylines.shape[0]):
146
+ # euclidean_dist = torch.norm(polylines[i, :2] - polylines[i - 1, :2])
147
+ # heading_diff = min(abs(max(heading[i], heading[i - 1]) - min(heading[1], heading[i - 1])),
148
+ # abs(max(heading[i], heading[i - 1]) - min(heading[1], heading[i - 1]) + torch.pi))
149
+ # if heading_diff > torch.pi / 4 and euclidean_dist > 3:
150
+ # dist_along_path_list.append([0])
151
+ # polylines_list.append([polylines[i]])
152
+ # elif heading_diff > torch.pi / 8 and euclidean_dist > 3:
153
+ # dist_along_path_list.append([0])
154
+ # polylines_list.append([polylines[i]])
155
+ # elif heading_diff > 0.1 and euclidean_dist > 3:
156
+ # dist_along_path_list.append([0])
157
+ # polylines_list.append([polylines[i]])
158
+ # elif euclidean_dist > 10:
159
+ # dist_along_path_list.append([0])
160
+ # polylines_list.append([polylines[i]])
161
+ # else:
162
+ # dist_along_path_list[-1].append(dist_along_path_list[-1][-1] + euclidean_dist)
163
+ # polylines_list[-1].append(polylines[i])
164
+
165
+ # new_x_list = []
166
+ # new_y_list = []
167
+ # multi_polylines_list = []
168
+
169
+ # for idx in range(len(dist_along_path_list)):
170
+ # if len(dist_along_path_list[idx]) < 2:
171
+ # continue
172
+
173
+ # dist_along_path = torch.tensor(dist_along_path_list[idx], device=device)
174
+ # polylines_cur = torch.stack(polylines_list[idx])
175
+
176
+ # new_dist_along_path = torch.arange(0, dist_along_path[-1], distance)
177
+ # new_dist_along_path = torch.cat([new_dist_along_path, dist_along_path[[-1]]])
178
+
179
+ # new_x = torch.interp(new_dist_along_path, dist_along_path, polylines_cur[:, 0])
180
+ # new_y = torch.interp(new_dist_along_path, dist_along_path, polylines_cur[:, 1])
181
+
182
+ # new_x_list.append(new_x)
183
+ # new_y_list.append(new_y)
184
+
185
+ # new_polylines = torch.stack((new_x, new_y), dim=-1)
186
+
187
+ # polyline_size = int(split_distance / distance)
188
+ # if new_polylines.shape[0] >= (polyline_size + 1):
189
+ # padding_size = (new_polylines.shape[0] - (polyline_size + 1)) % polyline_size
190
+ # final_index = (new_polylines.shape[0] - (polyline_size + 1)) // polyline_size + 1
191
+ # else:
192
+ # padding_size = new_polylines.shape[0]
193
+ # final_index = 0
194
+
195
+ # multi_polylines = None
196
+ # new_heading = torch.atan2(new_polylines[1:, 1] - new_polylines[:-1, 1],
197
+ # new_polylines[1:, 0] - new_polylines[:-1, 0])
198
+ # new_heading = torch.cat([new_heading, new_heading[-1:]], -1)[..., None]
199
+ # new_polylines = torch.cat([new_polylines, new_heading], -1)
200
+
201
+ # if new_polylines.shape[0] >= (polyline_size + 1):
202
+ # multi_polylines = new_polylines.unfold(dimension=0, size=polyline_size + 1, step=polyline_size)
203
+ # multi_polylines = multi_polylines.transpose(1, 2)
204
+ # multi_polylines = multi_polylines[:, ::5, :]
205
+
206
+ # if padding_size >= 3:
207
+ # last_polyline = new_polylines[final_index * polyline_size:]
208
+ # last_polyline = last_polyline[torch.linspace(0, last_polyline.shape[0] - 1, steps=3).long()]
209
+ # if multi_polylines is not None:
210
+ # multi_polylines = torch.cat([multi_polylines, last_polyline.unsqueeze(0)], dim=0)
211
+ # else:
212
+ # multi_polylines = last_polyline.unsqueeze(0)
213
+
214
+ # if multi_polylines is None:
215
+ # continue
216
+ # multi_polylines_list.append(multi_polylines)
217
+
218
+ # if len(multi_polylines_list) > 0:
219
+ # multi_polylines_list = torch.cat(multi_polylines_list, dim=0)
220
+ # else:
221
+ # multi_polylines_list = None
222
+
223
+ # return multi_polylines_list
224
+
225
+
226
def average_distance_vectorized(point_set1, centroids):
    """Mean point-wise distance between every item in ``point_set1`` and
    every centroid.

    Both inputs are (n, n_point, dim) arrays; returns an
    (len(point_set1), len(centroids)) matrix of average distances.
    """
    # Broadcast to (n_set1, n_centroid, n_point, dim), take per-point norms,
    # then reduce over the point axis.
    pairwise = np.linalg.norm(point_set1[:, None] - centroids[None, :], axis=-1)
    return pairwise.mean(axis=2)
229
+
230
+
231
def assign_clusters(sub_X, centroids):
    """Assign each item of ``sub_X`` to the index of its nearest centroid
    (by average point-wise distance)."""
    distances = average_distance_vectorized(sub_X, centroids)
    return distances.argmin(axis=1)
234
+
235
+
236
+ class TokenProcessor(nn.Module):
237
+
238
    def __init__(self, token_size,
                 training: bool=False,
                 predict_motion: bool=False,
                 predict_state: bool=False,
                 predict_map: bool=False,
                 state_token: Dict[str, int]=None, **kwargs):
        """Configure the tokenizer and load the token vocabularies from disk.

        Args:
            token_size: Size of the motion-token vocabulary; special token
                indices (BOS/EOS/invalid) are appended after it.
            training: Stored, but see NOTE below — it is overwritten later.
            predict_motion: Whether motion tokens are predicted downstream.
            predict_state: Whether agent-state tokens are predicted; also
                controls ``disable_invalid`` (invalid agents are kept only
                when state prediction is enabled).
            predict_map: Whether map tokens are predicted downstream.
            state_token: Mapping from state name ('invalid'/'valid'/'enter'/
                'exit') to integer id. Required — it is indexed unconditionally.
            **kwargs: Only 'pl2seed_radius' is read.
        """
        super().__init__()

        # Pre-built token vocabularies shipped two directories above this module.
        module_dir = os.path.dirname(os.path.dirname(__file__))
        self.agent_token_path = os.path.join(module_dir, f'tokens/agent_vocab_555_s2.pkl')
        self.map_token_traj_path = os.path.join(module_dir, 'tokens/map_traj_token5.pkl')
        assert os.path.exists(self.agent_token_path), f"File {self.agent_token_path} not found."
        assert os.path.exists(self.map_token_traj_path), f"File {self.map_token_traj_path} not found."

        self.training = training
        self.token_size = token_size
        # Invalid agents are only modeled when state prediction is on.
        self.disable_invalid = not predict_state
        self.predict_motion = predict_motion
        self.predict_state = predict_state
        self.predict_map = predict_map

        # define new special tokens (appended after the motion vocabulary)
        self.bos_token_index = token_size
        self.eos_token_index = token_size + 1
        self.invalid_token_index = token_size + 2
        self.special_token_index = []
        self._init_token()  # registers vocabulary buffers

        # define agent states
        self.invalid_state = int(state_token['invalid'])
        self.valid_state = int(state_token['valid'])
        self.enter_state = int(state_token['enter'])
        self.exit_state = int(state_token['exit'])

        self.pl2seed_radius = kwargs.get('pl2seed_radius', None)

        self.noise = False    # enables top-k token resampling in _match_agent_token
        self.disturb = False
        self.shift = 5        # 10 Hz steps per motion token
        # NOTE(review): this unconditionally overwrites the ``training``
        # argument stored above, so self.training is always False at the end
        # of __init__ — confirm this is intentional.
        self.training = False
        self.current_step = 10  # index of the "current" observation step

        # debugging
        self.debug_data = None
282
+
283
+ def forward(self, data):
284
+ """
285
+ Each pkl data represents a extracted scenario from raw tfrecord data
286
+ """
287
+ data['agent']['av_index'] = data['agent']['av_idx']
288
+ data = self._tokenize_agent(data)
289
+ # data = self._tokenize_map(data)
290
+ del data['city']
291
+ if 'polygon_is_intersection' in data['map_polygon']:
292
+ del data['map_polygon']['polygon_is_intersection']
293
+ if 'route_type' in data['map_polygon']:
294
+ del data['map_polygon']['route_type']
295
+
296
+ av_index = int(data['agent']['av_idx'])
297
+ data['ego_pos'] = data['agent']['token_pos'][[av_index]]
298
+ data['ego_heading'] = data['agent']['token_heading'][[av_index]]
299
+
300
+ return data
301
+
302
+ def _init_token(self):
303
+
304
+ agent_token_data = pickle.load(open(self.agent_token_path, 'rb'))
305
+ for agent_type, token in agent_token_data['token_all'].items():
306
+ token = torch.tensor(token, dtype=torch.float32)
307
+ self.register_buffer(f'agent_token_all_{agent_type}', token, persistent=False) # [n_token, 6, 4, 2]
308
+
309
+ map_token_traj = pickle.load(open(self.map_token_traj_path, 'rb'))['traj_src']
310
+ map_token_traj = torch.tensor(map_token_traj, dtype=torch.float32)
311
+ self.register_buffer('map_token_traj_src', map_token_traj, persistent=False) # [n_token, 11 * 2]
312
+
313
+ # self.trajectory_token = agent_token_data['token'] # (token_size, 4, 2)
314
+ # self.trajectory_token_all = agent_token_data['token_all'] # (token_size, shift + 1, 4, 2)
315
+ # self.map_token = {'traj_src': map_token_traj['traj_src']}
316
+
317
+ @staticmethod
318
+ def clean_heading(valid: torch.Tensor, heading: torch.Tensor) -> torch.Tensor:
319
+ valid_pairs = valid[:, :-1] & valid[:, 1:]
320
+ for i in range(heading.shape[1] - 1):
321
+ heading_diff = torch.abs(wrap_angle(heading[:, i] - heading[:, i + 1]))
322
+ change_needed = (heading_diff > 1.5) & valid_pairs[:, i]
323
+ heading[:, i + 1][change_needed] = heading[:, i][change_needed]
324
+ return heading
325
+
326
+ def _extrapolate_agent_to_prev_token_step(self, valid, pos, heading, vel) -> Sequence[torch.Tensor]:
327
+ # [n_agent], max will give the first True step
328
+ first_valid_step = torch.max(valid, dim=1).indices
329
+
330
+ for i, t in enumerate(first_valid_step): # extrapolate to previous 5th step.
331
+ n_step_to_extrapolate = t % self.shift
332
+ if (t == self.current_step) and (not valid[i, self.current_step - self.shift]):
333
+ # such that at least one token is valid in the history.
334
+ n_step_to_extrapolate = self.shift
335
+
336
+ if n_step_to_extrapolate > 0:
337
+ vel[i, t - n_step_to_extrapolate : t] = vel[i, t]
338
+ valid[i, t - n_step_to_extrapolate : t] = True
339
+ heading[i, t - n_step_to_extrapolate : t] = heading[i, t]
340
+
341
+ for j in range(n_step_to_extrapolate):
342
+ pos[i, t - j - 1] = pos[i, t - j] - vel[i, t] * 0.1
343
+
344
+ return valid, pos, heading, vel
345
+
346
+ def _get_agent_shape(self, agent_type_masks: dict) -> torch.Tensor:
347
+ agent_shape = 0.
348
+ for type, type_mask in agent_type_masks.items():
349
+ if type == 'veh': width = 2.; length = 4.8
350
+ if type == 'ped': width = 1.; length = 2.
351
+ if type == 'cyc': width = 1.; length = 1.
352
+ agent_shape += torch.stack([width * type_mask, length * type_mask], dim=-1)
353
+
354
+ return agent_shape
355
+
356
+ def _get_token_traj_all(self, agent_type_masks: dict) -> torch.Tensor:
357
+ token_traj_all = 0.
358
+ for type, type_mask in agent_type_masks.items():
359
+ token_traj_all += type_mask[:, None, None, None, None] * (
360
+ getattr(self, f'agent_token_all_{type}').unsqueeze(0)
361
+ )
362
+ return token_traj_all
363
+
364
    def _tokenize_agent(self, data):
        """Convert raw 10 Hz agent trajectories into token sequences.

        Reads trajectories from ``data['agent']`` and writes the tokenization
        results back into the same dict: motion-token indices (``token_idx``),
        state indices (``state_idx``), token contours/poses/headings, validity
        masks, per-scenario mean heights and the per-type token vocabularies.
        Returns the mutated ``data``.
        """
        # get raw data
        valid_mask = data['agent']['valid_mask']  # [n_agent, n_step]
        agent_heading = data['agent']['heading']  # [n_agent, n_step]
        agent_pos = data['agent']['position'][..., :2].contiguous()  # [n_agent, n_step, 2]
        agent_vel = data['agent']['velocity']  # [n_agent, n_step, 2]
        agent_type = data['agent']['type']
        agent_category = data['agent']['category']

        n_agent, n_all_step = valid_mask.shape

        agent_type_masks = {
            "veh": agent_type == 0,
            "ped": agent_type == 1,
            "cyc": agent_type == 2,
        }
        # Remove single-step heading outliers before token matching.
        agent_heading = self.clean_heading(valid_mask, agent_heading)
        agent_shape = self._get_agent_shape(agent_type_masks)
        token_traj_all = self._get_token_traj_all(agent_type_masks)
        # Back-fill so each agent's first valid step lies on a token boundary.
        valid_mask, agent_pos, agent_heading, agent_vel = self._extrapolate_agent_to_prev_token_step(
            valid_mask, agent_pos, agent_heading, agent_vel
        )
        token_traj = token_traj_all[:, :, -1, ...]
        data['agent']['token_traj_all'] = token_traj_all  # [n_agent, n_token, 6, 4, 2]
        data['agent']['token_traj'] = token_traj  # [n_agent, n_token, 4, 2]

        # A token step is valid only if both of its endpoint steps are valid.
        valid_mask_shift = valid_mask.unfold(1, self.shift + 1, self.shift)
        token_valid_mask = valid_mask_shift[:, :, 0] * valid_mask_shift[:, :, -1]

        # index: [n_agent, n_step] contour: [n_agent, n_step, 4, 2]
        token_index, token_contour, token_all = self._match_agent_token(
            valid_mask, agent_pos, agent_heading, agent_shape, token_traj, None  # token_traj_all
        )

        traj_pos = traj_heading = None
        if len(token_all) > 0:
            traj_pos = token_all.mean(dim=3)  # [n_agent, n_step, 6, 2]
            diff_xy = token_all[..., 0, :] - token_all[..., 3, :]
            traj_heading = torch.arctan2(diff_xy[..., 1], diff_xy[..., 0])
        # Token pose = contour centroid; heading from the left edge
        # (corner 0 = left-front, corner 3 = left-back, see cal_polygon_contour).
        token_pos = token_contour.mean(dim=2)  # [n_agent, n_step, 2]
        diff_xy = token_contour[:, :, 0, :] - token_contour[:, :, 3, :]
        token_heading = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0])

        # ! compute agent states: first/last valid token steps mark enter/exit
        bos_index = torch.argmax(token_valid_mask.long(), dim=1)
        eos_index = token_valid_mask.shape[1] - 1 - torch.argmax(torch.flip(token_valid_mask.long(), dims=[1]), dim=1)
        state_index = torch.ones_like(token_index)  # init with all valid
        step_index = torch.arange(state_index.shape[1])[None].repeat(state_index.shape[0], 1).to(token_index.device)
        state_index[step_index == bos_index[:, None]] = self.enter_state
        state_index[step_index == eos_index[:, None]] = self.exit_state
        state_index[(step_index < bos_index[:, None]) | (step_index > eos_index[:, None])] = self.invalid_state
        # ! IMPORTANT: if the last step is exit token, should convert it back to valid token
        state_index[state_index[:, -1] == self.exit_state, -1] = self.valid_state

        # update token attributions according to state tokens
        token_valid_mask[state_index == self.enter_state] = False
        token_pos[state_index == self.invalid_state] = 0.
        token_heading[state_index == self.invalid_state] = 0.
        for i in range(self.shift, agent_pos.shape[1], self.shift):
            # At an enter token, use the raw position instead of the matched one.
            is_bos = state_index[:, i // self.shift - 1] == self.enter_state
            token_pos[is_bos, i // self.shift - 1] = agent_pos[is_bos, i].clone()
        # Sentinel motion-token ids: -1 for invalid steps, -2 for enter steps.
        token_index[state_index == self.invalid_state] = -1
        token_index[state_index == self.enter_state] = -2

        raw_token_valid_mask = token_valid_mask.clone()
        if not self.disable_invalid:
            # With state prediction enabled, all agents/steps are modeled.
            token_valid_mask = torch.ones_like(token_valid_mask).bool()

        # reset agent shapes: propagate the first non-zero shape to all steps
        for i in range(n_agent):
            bos_shape_index = torch.nonzero(torch.all(data['agent']['shape'][i] != 0., dim=-1))[0]
            data['agent']['shape'][i, :] = data['agent']['shape'][i, bos_shape_index]
            if torch.any(torch.all(data['agent']['shape'][i] == 0., dim=-1)):
                raise ValueError(f"Found invalid shape values.")

        # compute mean height values for each scenario
        raw_height = data['agent']['position'][:, self.current_step, 2]
        valid_height = raw_token_valid_mask[:, 1].bool()
        veh_mean_z = raw_height[agent_type_masks['veh'] & valid_height].mean()
        ped_mean_z = raw_height[agent_type_masks['ped'] & valid_height].mean().nan_to_num_(veh_mean_z)  # FIXME: hard code
        cyc_mean_z = raw_height[agent_type_masks['cyc'] & valid_height].mean().nan_to_num_(veh_mean_z)

        # output
        data['agent']['token_idx'] = token_index
        data['agent']['state_idx'] = state_index
        data['agent']['token_contour'] = token_contour
        data['agent']['traj_pos'] = traj_pos
        data['agent']['traj_heading'] = traj_heading
        data['agent']['token_pos'] = token_pos
        data['agent']['token_heading'] = token_heading
        data['agent']['agent_valid_mask'] = token_valid_mask  # (a, t)
        data['agent']['raw_agent_valid_mask'] = raw_token_valid_mask
        data['agent']['raw_height'] = dict(veh=veh_mean_z,
                                           ped=ped_mean_z,
                                           cyc=cyc_mean_z)
        for type in ['veh', 'ped', 'cyc']:
            data['agent'][f'trajectory_token_{type}'] = getattr(
                self, f'agent_token_all_{type}')  # [n_token, 6, 4, 2]

        return data
551
+
552
    def _match_agent_token(self, valid_mask, pos, heading, shape, token_traj, token_traj_all=None):
        """Greedily match every ``shift``-step segment of each agent's
        trajectory to its nearest vocabulary token.

        For each token step the vocabulary contours are rotated/translated into
        the world frame of the previously matched pose, and the token whose
        contour is closest (summed corner distance) to the agent's actual
        contour is selected. The matched pose then becomes the reference for
        the next step, so matching errors do not accumulate against raw data.

        Parameters:
            valid_mask (torch.Tensor): Validity mask for agents over time. Shape: (n_agent, n_step)
            pos (torch.Tensor): Positions of agents at each time step. Shape: (n_agent, n_step, 2)
            heading (torch.Tensor): Headings of agents at each time step. Shape: (n_agent, n_step)
            shape (torch.Tensor): Per-agent [width, length]. Shape: (n_agent, 2)
            token_traj (torch.Tensor): Token end contours. Shape: (n_agent, n_token, 4, 2)
            token_traj_all (torch.Tensor): Full token trajectories (optional).
                Shape: (n_agent, n_token_all, n_contour, 4, 2)

        Returns:
            tuple: (token_index [n_agent, n_step // shift],
                    token_contour [n_agent, n_step // shift, 4, 2],
                    token_all — full matched trajectories, or [] when
                    ``token_traj_all`` is None)
        """

        n_agent, n_step = valid_mask.shape

        _, n_token, token_contour_dim, feat_dim = token_traj.shape

        token_index_list = []
        token_contour_list = []
        token_all = []

        prev_heading = heading[:, 0]
        prev_pos = pos[:, 0]
        prev_token_idx = None
        for i in range(self.shift, n_step, self.shift):  # [5, 10, 15, ..., 90]
            # A segment is usable only if both endpoints are valid.
            _valid_mask = valid_mask[:, i - self.shift] & valid_mask[:, i]
            _invalid_mask = ~_valid_mask

            # transformation: rotate token contours by the previous matched
            # heading and translate to the previous matched position
            theta = prev_heading
            cos, sin = theta.cos(), theta.sin()
            rot_mat = theta.new_zeros(n_agent, 2, 2)
            rot_mat[:, 0, 0] = cos
            rot_mat[:, 0, 1] = sin
            rot_mat[:, 1, 0] = -sin
            rot_mat[:, 1, 1] = cos
            agent_token_world = torch.bmm(token_traj.flatten(1, 2), rot_mat).reshape(*token_traj.shape)
            agent_token_world += prev_pos[:, None, None, :]

            cur_contour = cal_polygon_contour(pos[:, i], heading[:, i], shape)  # [n_agent, 4, 2]
            # Nearest token by summed corner-to-corner distance.
            agent_token_index = torch.argmin(
                torch.norm(agent_token_world - cur_contour[:, None, ...], dim=-1).sum(-1), dim=-1
            )
            agent_token_contour = agent_token_world[torch.arange(n_agent), agent_token_index]  # [n_agent, 4, 2]

            # except for the first timestep TODO
            # Optional augmentation: resample among the 5 nearest tokens
            # (only active when self.noise is True).
            if prev_token_idx is not None and self.noise:
                same_idx = prev_token_idx == agent_token_index
                same_idx[:] = True
                topk_indices = np.argsort(
                    np.mean(np.sqrt(np.sum((cur_contour[:, None, ...] - agent_token_world.numpy()) ** 2, axis=-1)),
                            axis=2), axis=-1)[:, :5]
                sample_topk = np.random.choice(range(0, topk_indices.shape[1]), topk_indices.shape[0])
                agent_token_index[same_idx] = \
                    torch.from_numpy(topk_indices[np.arange(topk_indices.shape[0]), sample_topk])[same_idx]

            # update prev_heading: use the matched contour's heading where the
            # segment is valid, otherwise fall back to the raw heading
            prev_heading = heading[:, i].clone()
            diff_xy = agent_token_contour[:, 0] - agent_token_contour[:, 3]
            prev_heading[_valid_mask] = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])[_valid_mask]

            # update prev_pos: matched contour centroid where valid
            prev_pos = pos[:, i].clone()
            prev_pos[_valid_mask] = agent_token_contour.mean(dim=1)[_valid_mask]

            prev_token_idx = agent_token_index
            token_index_list.append(agent_token_index)
            token_contour_list.append(agent_token_contour)

            # calculate tokenized trajectory (full 6-step contours)
            if token_traj_all is not None:
                agent_token_all_world = torch.bmm(token_traj_all.flatten(1, 3), rot_mat).reshape(*token_traj_all.shape)
                agent_token_all_world += prev_pos[:, None, None, None, :]
                agent_token_all = agent_token_all_world[torch.arange(n_agent), agent_token_index]  # [n_agent, 6, 4, 2]
                token_all.append(agent_token_all)

        token_index = torch.stack(token_index_list, dim=1)  # [n_agent, n_step]
        token_contour = torch.stack(token_contour_list, dim=1)  # [n_agent, n_step, 4, 2]
        if len(token_all) > 0:
            token_all = torch.stack(token_all, dim=1)  # [n_agent, n_step, 6, 4, 2]

        # sanity check
        assert tuple(token_index.shape) == (n_agent, n_step // self.shift), \
            f'Invalid token_index shape, got {token_index.shape}'
        assert tuple(token_contour.shape) == (n_agent, n_step // self.shift, token_contour_dim, feat_dim), \
            f'Invalid token_contour shape, got {token_contour.shape}'

        return token_index, token_contour, token_all
692
+
693
+ @staticmethod
694
+ def _tokenize_map(data):
695
+
696
+ data['map_polygon']['type'] = data['map_polygon']['type'].to(torch.uint8)
697
+ data['map_point']['type'] = data['map_point']['type'].to(torch.uint8)
698
+ pt2pl = data[('map_point', 'to', 'map_polygon')]['edge_index']
699
+ pt_type = data['map_point']['type'].to(torch.uint8)
700
+ pt_side = torch.zeros_like(pt_type)
701
+ pt_pos = data['map_point']['position'][:, :2]
702
+ data['map_point']['orientation'] = wrap_angle(data['map_point']['orientation'])
703
+ pt_heading = data['map_point']['orientation']
704
+ split_polyline_type = []
705
+ split_polyline_pos = []
706
+ split_polyline_theta = []
707
+ split_polyline_side = []
708
+ pl_idx_list = []
709
+ split_polygon_type = []
710
+ data['map_point']['type'].unique()
711
+
712
+ for i in sorted(np.unique(pt2pl[1])): # number of polygons in the scenario
713
+ index = pt2pl[0, pt2pl[1] == i] # index of points which belongs to i-th polygon
714
+ polygon_type = data['map_polygon']["type"][i]
715
+ cur_side = pt_side[index]
716
+ cur_type = pt_type[index]
717
+ cur_pos = pt_pos[index]
718
+ cur_heading = pt_heading[index]
719
+
720
+ for side_val in np.unique(cur_side):
721
+ for type_val in np.unique(cur_type):
722
+ if type_val == 13:
723
+ continue
724
+ indices = np.where((cur_side == side_val) & (cur_type == type_val))[0]
725
+ if len(indices) <= 2:
726
+ continue
727
+ split_polyline = interplating_polyline(cur_pos[indices].numpy(), cur_heading[indices].numpy())
728
+ if split_polyline is None:
729
+ continue
730
+ new_cur_type = cur_type[indices][0]
731
+ new_cur_side = cur_side[indices][0]
732
+ map_polygon_type = polygon_type.repeat(split_polyline.shape[0])
733
+ new_cur_type = new_cur_type.repeat(split_polyline.shape[0])
734
+ new_cur_side = new_cur_side.repeat(split_polyline.shape[0])
735
+ cur_pl_idx = torch.Tensor([i])
736
+ new_cur_pl_idx = cur_pl_idx.repeat(split_polyline.shape[0])
737
+ split_polyline_pos.append(split_polyline[..., :2])
738
+ split_polyline_theta.append(split_polyline[..., 2])
739
+ split_polyline_type.append(new_cur_type)
740
+ split_polyline_side.append(new_cur_side)
741
+ pl_idx_list.append(new_cur_pl_idx)
742
+ split_polygon_type.append(map_polygon_type)
743
+
744
+ split_polyline_pos = torch.cat(split_polyline_pos, dim=0)
745
+ split_polyline_theta = torch.cat(split_polyline_theta, dim=0)
746
+ split_polyline_type = torch.cat(split_polyline_type, dim=0)
747
+ split_polyline_side = torch.cat(split_polyline_side, dim=0)
748
+ split_polygon_type = torch.cat(split_polygon_type, dim=0)
749
+ pl_idx_list = torch.cat(pl_idx_list, dim=0)
750
+
751
+ data['map_save'] = {}
752
+ data['pt_token'] = {}
753
+ data['map_save']['traj_pos'] = split_polyline_pos
754
+ data['map_save']['traj_theta'] = split_polyline_theta[:, 0] # torch.arctan2(vec[:, 1], vec[:, 0])
755
+ data['map_save']['pl_idx_list'] = pl_idx_list
756
+ data['pt_token']['type'] = split_polyline_type
757
+ data['pt_token']['side'] = split_polyline_side
758
+ data['pt_token']['pl_type'] = split_polygon_type
759
+ data['pt_token']['num_nodes'] = split_polyline_pos.shape[0]
760
+
761
+ return data
backups/dev/datasets/scalable_dataset.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ import torch
4
+ import json
5
+ import pytorch_lightning as pl
6
+ import pandas as pd
7
+ from tqdm import tqdm
8
+ from torch_geometric.data import HeteroData, Dataset
9
+ from torch_geometric.transforms import BaseTransform
10
+ from torch_geometric.loader import DataLoader
11
+ from typing import Callable, Dict, List, Optional
12
+
13
+ from dev.datasets.preprocess import TokenProcessor
14
+
15
+
16
class MultiDataset(Dataset):
    """Dataset of preprocessed Waymo scenarios, one pickle file per scenario.

    Filters the raw file list against a meta-info table so that only
    scenarios with 8 <= num_agents < buffer_size are kept.
    """

    def __init__(self,
                 split: str,
                 raw_dir: Optional[str] = None,  # directory containing per-scenario .pkl files
                 transform: Optional[Callable] = None,
                 tfrecord_dir: Optional[str] = None,
                 token_size=512,
                 predict_motion: bool=False,
                 predict_state: bool=False,
                 predict_map: bool=False,
                 # state_token: Dict[str, int]=None,
                 # pl2seed_radius: float=None,
                 buffer_size: int=128,
                 logger=None) -> None:

        # invalid agents are only kept when agent states are being predicted
        self.disable_invalid = not predict_state
        self.predict_motion = predict_motion
        self.predict_state = predict_state
        self.predict_map = predict_map
        self.logger = logger
        if split not in ('train', 'val', 'test'):
            raise ValueError(f'{split} is not a valid split')
        self.training = split == 'train'
        self.buffer_size = buffer_size
        self._tfrecord_dir = tfrecord_dir
        self.logger.debug('Starting loading dataset')

        raw_dir = os.path.expanduser(os.path.normpath(raw_dir))
        self._raw_files = sorted(os.listdir(raw_dir))

        # for debugging: OVERFIT=1 restricts the dataset to a single scenario
        if int(os.getenv('OVERFIT', 0)):
            # if self.training:
            #     # scenario_id = ['74ad7b76d5906d39', '13596229fd8cdb7e', '1d73db1fc42be3bf', '1351ea8b8333ddcb']
            #     self._raw_files = ['74ad7b76d5906d39.pkl'] + self._raw_files[:9]
            # else:
            #     self._raw_files = self._raw_files[:10]
            self._raw_files = self._raw_files[:1]
            # self._raw_files = ['1002fdc9826fc6d1.pkl']

        # load meta infos and do filter
        # NOTE(review): hard-coded absolute path — only valid on the original author's machine
        json_path = '/u/xiuyu/work/dev4/data/waymo_processed/meta_infos.json'
        label = 'training' if split == 'train' else ('validation' if split == 'val' else split)
        self.meta_infos = json.load(open(json_path, 'r', encoding='utf-8'))[label]
        self.logger.debug(f"Loaded meta infos from {json_path}")
        self.available_scenarios = list(self.meta_infos.keys())
        # self._raw_files = list(tqdm(filter(lambda fn: (
        #     scenario_id := fn.removesuffix('.pkl') in self.available_scenarios and
        #     8 <= self.meta_infos[scenario_id]['num_agents'] < self.buffer_size
        # ), self._raw_files), leave=False))
        # vectorized filtering via pandas: keep scenarios with 8 <= num_agents < buffer_size
        df = pd.DataFrame.from_dict(self.meta_infos, orient='index')
        available_scenarios_set = set(self.available_scenarios)
        df_filtered = df[(df.index.isin(available_scenarios_set)) & (df['num_agents'] >= 8) & (df['num_agents'] < self.buffer_size)]
        valid_scenarios = set(df_filtered.index)
        self._raw_files = list(tqdm(filter(lambda fn: fn.removesuffix('.pkl') in valid_scenarios, self._raw_files), leave=False))
        if len(self._raw_files) <= 0:
            raise RuntimeError(f'Invalid number of data {len(self._raw_files)}!')
        self._raw_paths = list(map(lambda fn: os.path.join(raw_dir, fn), self._raw_files))

        self.logger.debug(f"The number of {split} dataset is {len(self._raw_paths)}")
        self.logger.debug(f"The buffer size is {self.buffer_size}")
        # self.token_processor = TokenProcessor(token_size,
        #                                       training=self.training,
        #                                       predict_motion=self.predict_motion,
        #                                       predict_state=self.predict_state,
        #                                       predict_map=self.predict_map,
        #                                       state_token=state_token,
        #                                       pl2seed_radius=pl2seed_radius)  # 2048
        self.logger.debug(f"The used token size is {token_size}.")
        super().__init__(transform=transform, pre_transform=None, pre_filter=None)

    def len(self) -> int:
        # Number of scenario files remaining after meta-info filtering.
        return len(self._raw_paths)

    def get(self, idx: int):
        """
        Load pkl file (each represents a 91s scenario for waymo dataset)
        """
        with open(self._raw_paths[idx], 'rb') as handle:
            data = pickle.load(handle)

        # attach the path of the matching tfrecord shard when available
        if self._tfrecord_dir is not None:
            data['tfrecord_path'] = os.path.join(self._tfrecord_dir, f"{data['scenario_id']}.tfrecords")

        # data = self.token_processor.preprocess(data)
        return data
102
+
103
+
104
class WaymoTargetBuilder(BaseTransform):
    """Per-scenario transform: selects the agents to train on and tokenizes the map."""

    def __init__(self,
                 num_historical_steps: int,
                 num_future_steps: int,
                 max_num: int,
                 training: bool=False) -> None:

        self.max_num = max_num  # upper bound on the number of trained agents
        self.num_historical_steps = num_historical_steps
        self.num_future_steps = num_future_steps
        self.step_current = num_historical_steps - 1  # index of the "current" timestep
        self.training = training

    def _score_trained_agents(self, data):
        """Populate data['agent']['train_mask'] and prune far-away detections in place."""
        agent = data['agent']
        positions = agent['position']
        av_index = torch.where(agent['role'][:, 0])[0].item()
        dist_to_av = torch.norm(positions - positions[av_index], dim=-1)

        # we do not believe the perception out of range of 150 meters
        agent['valid_mask'] &= dist_to_av < 150

        # we do not predict vehicle too far away from ego car
        role_train_mask = agent['role'].any(-1)
        future_valid_steps = agent['valid_mask'][:, self.step_current + 1:].sum(-1)
        extra_train_mask = (dist_to_av[:, self.step_current] < 100) & (future_valid_steps >= 5)

        train_mask = extra_train_mask | role_train_mask
        if train_mask.sum() > self.max_num:  # too many vehicle
            # keep all role agents, then fill the remaining budget with a
            # random subset of the extra candidates
            candidates = torch.where(extra_train_mask & ~role_train_mask)[0]
            budget = self.max_num - role_train_mask.sum()
            chosen = candidates[torch.randperm(candidates.size(0))[:budget]]
            agent['train_mask'] = role_train_mask
            agent['train_mask'][chosen] = True
        else:
            agent['train_mask'] = train_mask  # [n_agent]

        return data

    def __call__(self, data) -> 'HeteroData':

        if self.training:
            self._score_trained_agents(data)

        data = TokenProcessor._tokenize_map(data)
        # data keys: dict_keys(['scenario_id', 'agent', 'map_polygon', 'map_point', ('map_point', 'to', 'map_polygon'), ('map_polygon', 'to', 'map_polygon'), 'map_save', 'pt_token'])
        return HeteroData(data)
153
+
154
+
155
class MultiDataModule(pl.LightningDataModule):
    """Lightning data module wiring MultiDataset + WaymoTargetBuilder into loaders."""

    # registry of available per-sample transforms, selected by name
    transforms = {
        'WaymoTargetBuilder': WaymoTargetBuilder,
    }

    # registry of available dataset classes, selected by name
    dataset = {
        'scalable': MultiDataset,
    }

    def __init__(self,
                 root: str,
                 train_batch_size: int,
                 val_batch_size: int,
                 test_batch_size: int,
                 shuffle: bool = False,
                 num_workers: int = 0,
                 pin_memory: bool = True,
                 persistent_workers: bool = True,
                 train_raw_dir: Optional[str] = None,
                 val_raw_dir: Optional[str] = None,
                 test_raw_dir: Optional[str] = None,
                 train_processed_dir: Optional[str] = None,
                 val_processed_dir: Optional[str] = None,
                 test_processed_dir: Optional[str] = None,
                 val_tfrecords_splitted: Optional[str] = None,
                 transform: Optional[str] = None,
                 dataset: Optional[str] = None,
                 num_historical_steps: int = 50,
                 num_future_steps: int = 60,
                 processor='ntp',
                 token_size=512,
                 predict_motion: bool=False,
                 predict_state: bool=False,
                 predict_map: bool=False,
                 state_token: Dict[str, int]=None,
                 pl2seed_radius: float=None,
                 max_num: int=32,
                 buffer_size: int=256,
                 logger=None,
                 **kwargs) -> None:

        super(MultiDataModule, self).__init__()
        self.root = root
        self.dataset_class = dataset
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.test_batch_size = test_batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        # persistent workers are only meaningful with at least one worker
        self.persistent_workers = persistent_workers and num_workers > 0
        self.train_raw_dir = train_raw_dir
        self.val_raw_dir = val_raw_dir
        self.test_raw_dir = test_raw_dir
        self.train_processed_dir = train_processed_dir
        self.val_processed_dir = val_processed_dir
        self.test_processed_dir = test_processed_dir
        self.val_tfrecords_splitted = val_tfrecords_splitted
        self.processor = processor
        self.token_size = token_size
        self.predict_motion = predict_motion
        self.predict_state = predict_state
        self.predict_map = predict_map
        self.state_token = state_token
        self.pl2seed_radius = pl2seed_radius
        self.buffer_size = buffer_size
        self.logger = logger

        # agent scoring (train_mask) is only enabled for the training transform
        self.train_transform = MultiDataModule.transforms[transform](num_historical_steps,
                                                                     num_future_steps,
                                                                     max_num=max_num,
                                                                     training=True)
        self.val_transform = MultiDataModule.transforms[transform](num_historical_steps,
                                                                   num_future_steps,
                                                                   max_num=max_num,
                                                                   training=False)

    def setup(self, stage: Optional[str] = None) -> None:
        """Instantiate the datasets needed for the given Lightning stage."""
        general_params = dict(token_size=self.token_size,
                              predict_motion=self.predict_motion,
                              predict_state=self.predict_state,
                              predict_map=self.predict_map,
                              buffer_size=self.buffer_size,
                              logger=self.logger)

        if stage == 'fit' or stage is None:
            self.train_dataset = MultiDataModule.dataset[self.dataset_class](split='train',
                                                                             raw_dir=self.train_raw_dir,
                                                                             transform=self.train_transform,
                                                                             **general_params)
            self.val_dataset = MultiDataModule.dataset[self.dataset_class](split='val',
                                                                           raw_dir=self.val_raw_dir,
                                                                           transform=self.val_transform,
                                                                           tfrecord_dir=self.val_tfrecords_splitted,
                                                                           **general_params)
        if stage == 'validate':
            self.val_dataset = MultiDataModule.dataset[self.dataset_class](split='val',
                                                                           raw_dir=self.val_raw_dir,
                                                                           transform=self.val_transform,
                                                                           tfrecord_dir=self.val_tfrecords_splitted,
                                                                           **general_params)
        if stage == 'test':
            self.test_dataset = MultiDataModule.dataset[self.dataset_class](split='test',
                                                                            raw_dir=self.test_raw_dir,
                                                                            transform=self.val_transform,
                                                                            tfrecord_dir=self.val_tfrecords_splitted,
                                                                            **general_params)

    def train_dataloader(self):
        # Training loader; shuffling is controlled by the module-level flag.
        return DataLoader(self.train_dataset, batch_size=self.train_batch_size, shuffle=self.shuffle,
                          num_workers=self.num_workers, pin_memory=self.pin_memory,
                          persistent_workers=self.persistent_workers)

    def val_dataloader(self):
        # Validation loader; never shuffled.
        return DataLoader(self.val_dataset, batch_size=self.val_batch_size, shuffle=False,
                          num_workers=self.num_workers, pin_memory=self.pin_memory,
                          persistent_workers=self.persistent_workers)

    def test_dataloader(self):
        # Test loader; never shuffled.
        return DataLoader(self.test_dataset, batch_size=self.test_batch_size, shuffle=False,
                          num_workers=self.num_workers, pin_memory=self.pin_memory,
                          persistent_workers=self.persistent_workers)
backups/dev/metrics/box_utils.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import Tensor
3
+
4
+
5
def get_yaw_rotation_2d(yaw):
    """
    Build a 2D rotation matrix for each yaw angle.

    Args:
        yaw: torch.Tensor, rotation angle in radians. Can be any shape except empty.

    Returns:
        rotation: torch.Tensor with shape [..., 2, 2], where `...` matches input shape.
    """
    c = torch.cos(yaw)
    s = torch.sin(yaw)

    top_row = torch.stack((c, -s), dim=-1)
    bottom_row = torch.stack((s, c), dim=-1)
    return torch.stack((top_row, bottom_row), dim=-2)  # [..., 2, 2]
24
+
25
+
26
def get_yaw_rotation(yaw):
    """
    Build a 3D rotation matrix about the Z-axis for each yaw angle.

    Args:
        yaw: torch.Tensor of any shape, representing yaw angles in radians.

    Returns:
        rotation: torch.Tensor of shape [input_shape, 3, 3], representing the rotation matrices.
    """
    c = torch.cos(yaw)
    s = torch.sin(yaw)
    one = torch.ones_like(yaw)
    zero = torch.zeros_like(yaw)

    rows = (
        torch.stack((c, -s, zero), dim=-1),
        torch.stack((s, c, zero), dim=-1),
        torch.stack((zero, zero, one), dim=-1),
    )
    return torch.stack(rows, dim=-2)
46
+
47
+
48
def get_transform(rotation, translation):
    """
    Combine an NxN rotation and an N-dim translation into an (N+1)x(N+1) homogeneous transform.

    Args:
        rotation: torch.Tensor of shape [..., N, N], representing rotation matrices.
        translation: torch.Tensor of shape [..., N], representing translation vectors.
            This must have the same dtype as rotation.

    Returns:
        transform: torch.Tensor of shape [..., (N+1), (N+1)], with the same dtype as rotation.
    """
    # Upper part: [R | t], shape [..., N, N+1].
    upper = torch.cat((rotation, translation.unsqueeze(-1)), dim=-1)

    # Bottom row: [0, ..., 0, 1], shape [..., N+1].
    zeros = torch.zeros_like(translation)
    bottom = torch.cat((zeros, torch.ones_like(zeros[..., :1])), dim=-1)

    # Stack into the full homogeneous matrix [..., N+1, N+1].
    return torch.cat((upper, bottom.unsqueeze(-2)), dim=-2)
75
+
76
+
77
def get_upright_3d_box_corners(boxes: Tensor):
    """
    Return the 8 corner points of upright (z-axis aligned) 3D bounding boxes.

    Args:
        boxes: torch.Tensor [N, 7]. The inner dims are [center{x,y,z}, length, width,
            height, heading].

    Returns:
        corners: torch.Tensor [N, 8, 3].
    """
    cx, cy, cz, length, width, height, heading = boxes.unbind(dim=-1)

    # Yaw rotation about the z-axis, built inline: [N, 3, 3].
    c, s = torch.cos(heading), torch.sin(heading)
    one, zero = torch.ones_like(heading), torch.zeros_like(heading)
    rotation = torch.stack([
        torch.stack([c, -s, zero], dim=-1),
        torch.stack([s, c, zero], dim=-1),
        torch.stack([zero, zero, one], dim=-1),
    ], dim=-2)

    center = torch.stack([cx, cy, cz], dim=-1)  # [N, 3]
    half_l, half_w, half_h = length * 0.5, width * 0.5, height * 0.5

    # Corner sign pattern: bottom face first, then top face, matching the
    # original corner ordering.
    signs = ((1, 1, -1), (-1, 1, -1), (-1, -1, -1), (1, -1, -1),
             (1, 1, 1), (-1, 1, 1), (-1, -1, 1), (1, -1, 1))
    local = torch.stack(
        [torch.stack([sx * half_l, sy * half_w, sz * half_h], dim=-1)
         for sx, sy, sz in signs],
        dim=1)  # [N, 8, 3]

    # Rotate into the world frame and translate to the box center.
    return torch.einsum('n i j, n k j -> n k i', rotation, local) + center.unsqueeze(1)
backups/dev/metrics/compute_metrics.py ADDED
@@ -0,0 +1,1812 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ! Metrics Calculation
2
+ import concurrent.futures
3
+ import os
4
+ import torch
5
+ import tensorflow as tf
6
+ import collections
7
+ import dataclasses
8
+ import fnmatch
9
+ import json
10
+ import pandas as pd
11
+ import pickle
12
+ import copy
13
+ import concurrent
14
+ import multiprocessing
15
+ from torch_geometric.utils import degree
16
+ from functools import partial
17
+ from dataclasses import dataclass, field
18
+ from tqdm import tqdm
19
+ from argparse import ArgumentParser
20
+ from torch import Tensor
21
+ from google.protobuf import text_format
22
+ from torchmetrics import Metric
23
+ from typing import Optional, Sequence, List, Dict
24
+
25
+ from waymo_open_dataset.utils.sim_agents import submission_specs
26
+
27
+ from dev.utils.visualization import safe_run
28
+ from dev.utils.func import CONSOLE
29
+ from dev.datasets.scalable_dataset import WaymoTargetBuilder
30
+ from dev.datasets.preprocess import TokenProcessor, SHIFT, AGENT_STATE
31
+ from dev.metrics import trajectory_features, interact_features, map_features, placement_features
32
+ from dev.metrics.protos import scenario_pb2, long_metrics_pb2
33
+
34
+
35
# Metric feature names grouped by evaluation bucket; iterated when computing
# and aggregating the per-feature metrics below.
_METRIC_FIELD_NAMES_BY_BUCKET = {
    'kinematic': [
        'linear_speed', 'linear_acceleration',
        'angular_speed', 'angular_acceleration',
    ],
    'interactive': [
        'distance_to_nearest_object', 'collision_indication',
        'time_to_collision',
    ],
    'map_based': [
        # 'distance_to_road_edge', 'offroad_indication'
    ],
    'placement_based': [
        'num_placement', 'num_removement',
        'distance_placement', 'distance_removement',
    ]
}
# Flat list of all metric field names, preserving bucket order above.
_METRIC_FIELD_NAMES = (
    _METRIC_FIELD_NAMES_BY_BUCKET['kinematic'] +
    _METRIC_FIELD_NAMES_BY_BUCKET['interactive'] +
    _METRIC_FIELD_NAMES_BY_BUCKET['map_based'] +
    _METRIC_FIELD_NAMES_BY_BUCKET['placement_based']
)
58
+
59
+
60
+ """ Help Functions """
61
+
62
+ def _arg_gather(tensor: Tensor, reference_tensor: Tensor) -> Tensor:
63
+ """Finds corresponding indices in `tensor` for each element in `reference_tensor`.
64
+
65
+ Args:
66
+ tensor: A 1D tensor without repetitions.
67
+ reference_tensor: A 1D tensor containing items from `tensor`.
68
+
69
+ Returns:
70
+ A tensor of indices such that `tensor[indices] == reference_tensor`.
71
+ """
72
+ assert tensor.ndim == 1, "tensor must be 1D"
73
+ assert reference_tensor.ndim == 1, "reference_tensor must be 1D"
74
+
75
+ # Create the comparison matrix
76
+ bit_mask = tensor[None, :] == reference_tensor[:, None] # Shape: [len(reference_tensor), len(tensor)]
77
+
78
+ # Count the matches along `tensor` dimension
79
+ bit_mask_sum = bit_mask.int().sum(dim=1)
80
+
81
+ if (bit_mask_sum < 1).any():
82
+ raise ValueError(
83
+ 'Some items in `reference_tensor` are missing from `tensor`: '
84
+ f'\n{reference_tensor} \nvs. \n{tensor}.'
85
+ )
86
+
87
+ if (bit_mask_sum > 1).any():
88
+ raise ValueError('Some items in `tensor` are repeated.')
89
+
90
+ # Compute indices
91
+ indices = torch.matmul(bit_mask.int(), torch.arange(tensor.shape[0], dtype=torch.int32))
92
+ return indices
93
+
94
+
95
def is_valid_sim_agent(track: scenario_pb2.Track) -> bool: # type: ignore
    """Checks if the object needs to be resimulated as a sim agent.

    For the Sim Agents challenge, every object that is valid at the
    `current_time_index` step (here hardcoded to 10) needs to be resimulated.

    Args:
        track: A track proto for a single object.

    Returns:
        A boolean flag, True if the object needs to be resimulated, False otherwise.
    """
    current_state = track.states[submission_specs.CURRENT_TIME_INDEX]
    return current_state.valid
108
+
109
+
110
def get_sim_agent_ids(
    scenario: scenario_pb2.Scenario) -> Sequence[int]: # type: ignore
  """Returns the list of object IDs that needs to be resimulated.

  Internally calls `is_valid_sim_agent` to verify the simulation criteria,
  i.e. is the object valid at `current_time_index`.

  Args:
    scenario: The Scenario proto containing the data.

  Returns:
    A list of int IDs, containing all the objects that need to be simulated.
  """
  # Comprehension form of the original accumulation loop.
  return [track.id for track in scenario.tracks if is_valid_sim_agent(track)]
128
+
129
+
130
def get_evaluation_agent_ids(
    scenario: 'scenario_pb2.Scenario') -> Sequence[int]: # type: ignore
  """Returns the sorted IDs of the evaluated objects: the AV plus every
  `tracks_to_predict` object."""
  # Start with the AV object, then merge in the `tracks_to_predict` objects.
  object_ids = {scenario.tracks[scenario.sdc_track_index].id}
  object_ids.update(
      scenario.tracks[required_prediction.track_index].id
      for required_prediction in scenario.tracks_to_predict)
  return sorted(object_ids)
138
+
139
+
140
+ """ Base Data Classes s"""
141
+
142
+ @dataclass(frozen=True)
143
+ class ObjectTrajectories:
144
+
145
+ x: Tensor
146
+ y: Tensor
147
+ z: Tensor
148
+ heading: Tensor
149
+ length: Tensor
150
+ width: Tensor
151
+ height: Tensor
152
+ valid: Tensor
153
+ object_id: Tensor
154
+ object_type: Tensor
155
+
156
+ state: Optional[Tensor] = None
157
+ token_pos: Optional[Tensor] = None
158
+ token_heading: Optional[Tensor] = None
159
+ token_valid: Optional[Tensor] = None
160
+ processed_object_id: Optional[Tensor] = None
161
+ av_id: Optional[int] = None
162
+ processed_av_id: Optional[int] = None
163
+
164
+ def slice_time(self, start_index: int = 0, end_index: Optional[int] = None):
165
+ return ObjectTrajectories(
166
+ x=self.x[..., start_index:end_index],
167
+ y=self.y[..., start_index:end_index],
168
+ z=self.z[..., start_index:end_index],
169
+ heading=self.heading[..., start_index:end_index],
170
+ length=self.length[..., start_index:end_index],
171
+ width=self.width[..., start_index:end_index],
172
+ height=self.height[..., start_index:end_index],
173
+ valid=self.valid[..., start_index:end_index],
174
+ object_id=self.object_id,
175
+ object_type=self.object_type,
176
+
177
+ # these properties can only come from processed file
178
+ state=self.state,
179
+ token_pos=self.token_pos,
180
+ token_heading=self.token_heading,
181
+ token_valid=self.token_valid,
182
+ processed_object_id=self.processed_object_id,
183
+ av_id=self.av_id,
184
+ processed_av_id=self.processed_av_id,
185
+ )
186
+
187
+ def gather_objects(self, object_indices: Tensor):
188
+ assert object_indices.ndim == 1, "object_indices must be 1D"
189
+ return ObjectTrajectories(
190
+ x=torch.index_select(self.x, dim=-2, index=object_indices),
191
+ y=torch.index_select(self.y, dim=-2, index=object_indices),
192
+ z=torch.index_select(self.z, dim=-2, index=object_indices),
193
+ heading=torch.index_select(self.heading, dim=-2, index=object_indices),
194
+ length=torch.index_select(self.length, dim=-2, index=object_indices),
195
+ width=torch.index_select(self.width, dim=-2, index=object_indices),
196
+ height=torch.index_select(self.height, dim=-2, index=object_indices),
197
+ valid=torch.index_select(self.valid, dim=-2, index=object_indices),
198
+ object_id=torch.index_select(self.object_id, dim=-1, index=object_indices),
199
+ object_type=torch.index_select(self.object_type, dim=-1, index=object_indices),
200
+
201
+ # these properties can only come from processed file
202
+ state=self.state,
203
+ token_pos=self.token_pos,
204
+ token_heading=self.token_heading,
205
+ token_valid=self.token_valid,
206
+ processed_object_id=self.processed_object_id,
207
+ av_id=self.av_id,
208
+ processed_av_id=self.processed_av_id,
209
+ )
210
+
211
+ def gather_objects_by_id(self, object_ids: tf.Tensor):
212
+ indices = _arg_gather(self.object_id, object_ids)
213
+ return self.gather_objects(indices)
214
+
215
+ @classmethod
216
+ def _get_init_dict_from_processed(cls, scenario: dict):
217
+ """Load from processed pkl data"""
218
+ position = scenario['agent']['position']
219
+ heading = scenario['agent']['heading']
220
+ shape = scenario['agent']['shape']
221
+ object_ids = scenario['agent']['id']
222
+ object_types = scenario['agent']['type']
223
+ valid = scenario['agent']['valid_mask']
224
+
225
+ init_dict = dict(x=position[..., 0],
226
+ y=position[..., 1],
227
+ z=position[..., 2],
228
+ heading=heading,
229
+ length=shape[..., 0],
230
+ width=shape[..., 1],
231
+ height=shape[..., 2],
232
+ valid=valid,
233
+ object_ids=object_ids,
234
+ object_types=object_types)
235
+
236
+ return init_dict
237
+
238
+ @classmethod
239
+ def _get_init_dict_from_raw(cls,
240
+ scenario: scenario_pb2.Scenario): # type: ignore
241
+
242
+ """Load from tfrecords data"""
243
+ states, dimensions, objects = [], [], []
244
+ for track in scenario.tracks: # n_object
245
+ # Iterate over a single object's states.
246
+ track_states, track_dimensions = [], []
247
+ for state in track.states: # n_timestep
248
+ track_states.append((state.center_x, state.center_y, state.center_z,
249
+ state.heading, state.valid))
250
+ track_dimensions.append((state.length, state.width, state.height))
251
+ # Adds to the global states.
252
+ states.append(list(zip(*track_states)))
253
+ dimensions.append(list(zip(*track_dimensions)))
254
+ objects.append((track.id, track.object_type))
255
+
256
+ # Unpack and convert to tf tensors.
257
+ x, y, z, heading, valid = [torch.tensor(s) for s in zip(*states)]
258
+ length, width, height = [torch.tensor(s) for s in zip(*dimensions)]
259
+ object_ids, object_types = [torch.tensor(s) for s in zip(*objects)]
260
+
261
+ av_id = object_ids[scenario.sdc_track_index]
262
+
263
+ init_dict = dict(x=x, y=y, z=z,
264
+ heading=heading,
265
+ length=length,
266
+ width=width,
267
+ height=height,
268
+ valid=valid,
269
+ object_id=object_ids,
270
+ object_type=object_types,
271
+ av_id=int(av_id))
272
+
273
+ return init_dict
274
+
275
+ @classmethod
276
+ def from_scenario(cls,
277
+ scenario: scenario_pb2.Scenario, # type: ignore
278
+ processed_scenario: Optional[dict]=None,
279
+ from_where: str='raw'):
280
+
281
+ if from_where == 'raw':
282
+ init_dict = cls._get_init_dict_from_raw(scenario)
283
+ elif from_where == 'processed':
284
+ assert processed_scenario is not None, f'`processed_scenario` should be given!'
285
+ init_dict = cls._get_init_dict_from_processed(processed_scenario)
286
+ else:
287
+ raise RuntimeError(f'Invalid from {from_where}')
288
+
289
+ if processed_scenario is not None:
290
+ init_dict.update(state=processed_scenario['agent']['state_idx'],
291
+ token_pos=processed_scenario['agent']['token_pos'],
292
+ token_heading=processed_scenario['agent']['token_heading'],
293
+ token_valid=processed_scenario['agent']['raw_agent_valid_mask'],
294
+ processed_object_id=processed_scenario['agent']['id'],
295
+ processed_av_id=int(processed_scenario['agent']['id'][
296
+ processed_scenario['agent']['av_idx']
297
+ ]),
298
+ )
299
+
300
+ return cls(**init_dict)
301
+
302
+
303
@dataclass
class ScenarioRollouts:
    """Container for all simulated rollouts of one logged scenario."""
    scenario_id: Optional[str] = None  # WOMD scenario string id
    joint_scenes: List[ObjectTrajectories] = field(default_factory=list)  # one ObjectTrajectories per rollout
307
+
308
+
309
+ """ Conversion Methods """
310
+
311
def scenario_to_trajectories(
    scenario: scenario_pb2.Scenario, # type: ignore
    processed_scenario: Optional[dict]=None,
    from_where: Optional[str]='raw',
    remove_history: Optional[bool]=False
) -> ObjectTrajectories:
    """Converts a WOMD Scenario proto into the `ObjectTrajectories`.

    Returns:
        A `ObjectTrajectories` with trajectories copied from data.
    """
    all_trajectories = ObjectTrajectories.from_scenario(
        scenario,
        processed_scenario,
        from_where,
    )
    # Keep only the agents that must be simulated for this scenario.
    sim_agent_ids = get_sim_agent_ids(scenario)
    kept = all_trajectories.gather_objects_by_id(torch.tensor(sim_agent_ids))

    if not remove_history:
        return kept

    # Drop the history prefix so only the simulated horizon remains,
    # then sanity-check the remaining step count.
    kept = kept.slice_time(submission_specs.CURRENT_TIME_INDEX + 1)  # 10 + 1
    if kept.valid.shape[-1] != submission_specs.N_SIMULATION_STEPS:  # 80 simulated steps
        raise ValueError(
            'The Scenario used does not include the right number of time steps. '
            f'Expected: {submission_specs.N_SIMULATION_STEPS}, '
            f'Actual: {kept.valid.shape[-1]}.')

    return kept
341
+
342
+
343
def _unbatch(src: Tensor, batch: Tensor, dim: int = 0) -> List[Tensor]:
    """Split a batched tensor into per-graph chunks using a PyG batch vector."""
    chunk_sizes = degree(batch, dtype=torch.long).tolist()
    return src.split(chunk_sizes, dim)
346
+
347
+
348
def get_scenario_id_int_tensor(scenario_id: List[str], device: torch.device=torch.device('cpu')) -> torch.Tensor:
    """Encode scenario-id strings as fixed-width integer tensors.

    Each string is encoded as the ordinals of its characters, padded with
    -1 up to a fixed width of 16 (the max scenario-id string length).

    Returns:
        An int32 tensor of shape [n_scenario, 16] on `device`.
    """
    rows = []
    for sid in scenario_id:
        codes = [-1] * 16  # max_len of scenario_id string is 16
        for pos, ch in enumerate(sid):
            codes[pos] = ord(ch)
        rows.append(torch.tensor(codes, dtype=torch.int32, device=device))
    return torch.stack(rows, dim=0)  # [n_scenario, 16]
358
+
359
+
360
def output_to_rollouts(scenario: dict) -> List[ScenarioRollouts]: # n_scenario
    """Convert a batched model-output dict into per-scenario `ScenarioRollouts`.

    Expected entries of `scenario` (shapes per the producing model):
        scenario_id: [n_scenario, n_str_length]
        agent_id:    [n_agent, n_rollout]
        agent_batch: [n_agent]
        pred_traj:   [n_agent, n_rollout, n_step, 2]
        pred_z:      [n_agent, n_rollout, n_step]
        pred_head:   [n_agent, n_rollout, n_step]
        pred_shape:  [n_agent, n_rollout, 3]
        pred_type:   [n_agent, n_rollout]
        pred_state:  [n_agent, n_rollout, n_step] (optional; zeros if absent)
    """
    scenario_id = scenario['scenario_id']
    # -1 marks "no AV" for rollouts that lack one.
    av_id = (
        scenario['av_id'] if 'av_id' in scenario else -1
    )
    agent_id = scenario['agent_id']
    agent_batch = scenario['agent_batch']
    pred_traj = scenario['pred_traj']
    pred_z = scenario['pred_z']
    pred_head = scenario['pred_head']
    pred_shape = scenario['pred_shape']
    pred_type = scenario['pred_type']
    pred_state = (
        scenario['pred_state'] if 'pred_state' in scenario else
        torch.zeros_like(pred_z).long()
    )
    pred_valid = scenario['pred_valid']
    token_pos = scenario['token_pos']
    token_head = scenario['token_head']

    scenario_id = scenario_id.cpu().numpy()
    n_agent, n_rollout, n_step, _ = pred_traj.shape

    # Split the flat agent axis into per-scenario lists.
    agent_id = _unbatch(agent_id, agent_batch)
    pred_traj = _unbatch(pred_traj, agent_batch)
    pred_z = _unbatch(pred_z, agent_batch)
    pred_head = _unbatch(pred_head, agent_batch)
    pred_shape = _unbatch(pred_shape, agent_batch)
    pred_type = _unbatch(pred_type, agent_batch)
    pred_state = _unbatch(pred_state, agent_batch)
    pred_valid = _unbatch(pred_valid, agent_batch)
    token_pos = _unbatch(token_pos, agent_batch)
    token_head = _unbatch(token_head, agent_batch)

    agent_id = [x.cpu() for x in agent_id]
    pred_traj = [x.cpu() for x in pred_traj]
    pred_z = [x.cpu() for x in pred_z]
    pred_head = [x.cpu() for x in pred_head]
    # Broadcast per-agent shapes [n_agent_i, n_rollout, 3] across time steps
    # -> [n_agent_i, n_rollout, n_step, 3].
    pred_shape = [x[:, :, None].repeat(1, 1, n_step, 1).cpu() for x in pred_shape]
    # Bug fix: pred_type is [n_agent_i, n_rollout] (no trailing feature dim),
    # so x[:, :, None] is 3-D; repeating with FOUR factors made torch prepend
    # a dummy dimension and yielded [1, n_agent_i, n_rollout * n_step, 1]
    # instead of the intended [n_agent_i, n_rollout, n_step]. Use three
    # repeat factors so the later [:, i_rollout] indexing gives [n_agent_i, n_step].
    pred_type = [x[:, :, None].repeat(1, 1, n_step).cpu() for x in pred_type]
    pred_state = [x.cpu() for x in pred_state]
    pred_valid = [x.cpu() for x in pred_valid]
    token_pos = [x.cpu() for x in token_pos]
    token_head = [x.cpu() for x in token_head]

    n_scenario = scenario_id.shape[0]
    scenario_rollouts = []
    for i_scenario in range(n_scenario):
        joint_scenes = []
        for i_rollout in range(n_rollout):  # 1
            joint_scenes.append(
                ObjectTrajectories(
                    x=pred_traj[i_scenario][:, i_rollout, :, 0],
                    y=pred_traj[i_scenario][:, i_rollout, :, 1],
                    z=pred_z[i_scenario][:, i_rollout],
                    heading=pred_head[i_scenario][:, i_rollout],
                    length=pred_shape[i_scenario][:, i_rollout, :, 0],
                    width=pred_shape[i_scenario][:, i_rollout, :, 1],
                    height=pred_shape[i_scenario][:, i_rollout, :, 2],
                    valid=pred_valid[i_scenario][:, i_rollout],
                    state=pred_state[i_scenario][:, i_rollout],
                    object_id=agent_id[i_scenario][:, i_rollout],
                    processed_object_id=agent_id[i_scenario][:, i_rollout],
                    object_type=pred_type[i_scenario][:, i_rollout],
                    token_pos=token_pos[i_scenario][:, i_rollout, :, :2],
                    token_heading=token_head[i_scenario][:, i_rollout],
                    av_id=av_id,
                    processed_av_id=av_id,
                )
            )

        # Decode the int-encoded scenario id (padding is <= 0).
        _str_scenario_id = "".join([chr(x) for x in scenario_id[i_scenario] if x > 0])
        scenario_rollouts.append(
            ScenarioRollouts(
                joint_scenes=joint_scenes, scenario_id=_str_scenario_id
            )
        )

    return scenario_rollouts
465
+
466
+
467
+ """ Compute Metric Features """
468
+
469
def _compute_metametric(
    config: long_metrics_pb2.SimAgentMetricsConfig, # type: ignore
    metrics: long_metrics_pb2.SimAgentMetrics, # type: ignore
):
    """Computes the meta-metric aggregation."""
    # Weighted sum over every metric field: each field's likelihood score is
    # scaled by the weight configured for it.
    return sum(
        (
            getattr(config, field_name).metametric_weight
            * getattr(metrics, field_name + '_likelihood')
            for field_name in _METRIC_FIELD_NAMES
        ),
        0.0,  # float start value, matching the original accumulator
    )
481
+
482
+
483
@dataclasses.dataclass(frozen=True)
class MetricFeatures:
    """Per-rollout features from which sim-agent metrics are computed.

    Time-series tensors are laid out as (n_objects, n_steps); the placement
    counters are per-scenario series stored at a SHIFT-times coarser time
    resolution (see `unfold`).
    """

    object_id: Tensor
    valid: Tensor
    linear_speed: Tensor
    linear_acceleration: Tensor
    angular_speed: Tensor
    angular_acceleration: Tensor
    distance_to_nearest_object: Tensor
    collision_per_step: Tensor
    time_to_collision: Tensor
    distance_to_road_edge: Tensor
    offroad_per_step: Tensor
    num_placement: Tensor
    num_removement: Tensor
    distance_placement: Tensor
    distance_removement: Tensor

    @classmethod
    def from_file(cls, file_path: str):
        """Load features from a pickled dict.

        Fields missing from the pickle are filled with None, tolerating
        partial feature files.

        Raises:
            FileNotFoundError: if `file_path` does not exist.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f'Not found file {file_path}')

        with open(file_path, 'rb') as f:
            feat_dict = pickle.load(f)

        # Idiom fix: replace the manual present/absent loop (which also
        # shadowed the name `field`) with a dict comprehension; `.get`
        # returns None for missing keys, matching the old behavior.
        field_names = [f.name for f in dataclasses.fields(cls)]
        return cls(**{name: feat_dict.get(name) for name in field_names})

    def unfold(self, size: int, step: int):
        """Slide a window of `size` with stride `step` over the time axis.

        Placement features are sampled SHIFT times more coarsely than the
        10Hz series, so their window size/stride are scaled down by SHIFT.
        """
        return MetricFeatures(
            object_id=self.object_id,
            valid=self.valid.unfold(1, size, step),
            linear_speed=self.linear_speed.unfold(1, size, step),
            linear_acceleration=self.linear_acceleration.unfold(1, size, step),
            angular_speed=self.angular_speed.unfold(1, size, step),
            angular_acceleration=self.angular_acceleration.unfold(1, size, step),
            distance_to_nearest_object=self.distance_to_nearest_object.unfold(1, size, step),
            collision_per_step=self.collision_per_step.unfold(1, size, step),
            time_to_collision=self.time_to_collision.unfold(1, size, step),
            distance_to_road_edge=self.distance_to_road_edge.unfold(1, size, step),
            offroad_per_step=self.offroad_per_step.unfold(1, size, step),
            num_placement=self.num_placement.unfold(1, size // SHIFT, step // SHIFT),
            num_removement=self.num_removement.unfold(1, size // SHIFT, step // SHIFT),
            distance_placement=self.distance_placement.unfold(1, size // SHIFT, step // SHIFT),
            distance_removement=self.distance_removement.unfold(1, size // SHIFT, step // SHIFT),
        )
540
+
541
+
542
def compute_metric_features(
    simulate_trajectories: ObjectTrajectories,
    evaluate_agent_ids: Optional[Tensor]=None,
    scenario_log: Optional[scenario_pb2.Scenario]=None, # type: ignore
) -> MetricFeatures:
    """Computes the `MetricFeatures` for one simulated rollout.

    Args:
        simulate_trajectories: Trajectories of all simulated objects.
        evaluate_agent_ids: Optional 1D tensor of object ids; when given,
            kinematic features are computed only for those objects.
        scenario_log: Optional logged scenario proto; when given, its road
            edges are used for the distance-to-road-edge/offroad features.

    Returns:
        A `MetricFeatures` with the history steps stripped from every
        time-series feature (only the simulation horizon remains).
    """
    if evaluate_agent_ids is not None:
        evaluate_trajectories = simulate_trajectories.gather_objects_by_id(
            evaluate_agent_ids
        )
    else:
        evaluate_trajectories = simulate_trajectories

    # valid mask, restricted to post-history (simulated) steps
    validity_mask = evaluate_trajectories.valid
    validity_mask = validity_mask[:, submission_specs.CURRENT_TIME_INDEX + 1:]

    # ! Kinematics-related features, i.e. speed and acceleration, this needs
    # history steps to be prepended to make the first evaluate step valid.
    # Resulting `linear_speed` and others: (n_object_to_evaluate, n_future_step)
    linear_speed, linear_accel, angular_speed, angular_accel = (
        trajectory_features.compute_kinematic_features(
            evaluate_trajectories.x,
            evaluate_trajectories.y,
            evaluate_trajectories.z,
            evaluate_trajectories.heading,
            seconds_per_step=submission_specs.STEP_DURATION_SECONDS))
    # Removes the data corresponding to the history time interval.
    linear_speed, linear_accel, angular_speed, angular_accel = (
        map(lambda t: t[:, submission_specs.CURRENT_TIME_INDEX + 1:],
            [linear_speed, linear_accel, angular_speed, angular_accel])
    )

    # ! Distances to nearest objects.
    # evaluate_object_mask = torch.any(
    #     evaluate_agent_ids[:, None] == simulated_trajectories.object_id, axis=0
    # )
    # NOTE(review): every object is marked as evaluated here, so interaction
    # features cover all simulated objects regardless of
    # `evaluate_agent_ids` — confirm this is intended.
    evaluate_object_mask = torch.ones(len(simulate_trajectories.object_id)).bool()
    distances_to_objects = interact_features.compute_distance_to_nearest_object(
        center_x=simulate_trajectories.x,
        center_y=simulate_trajectories.y,
        center_z=simulate_trajectories.z,
        length=simulate_trajectories.length,
        width=simulate_trajectories.width,
        height=simulate_trajectories.height,
        heading=simulate_trajectories.heading,
        valid=simulate_trajectories.valid,
        evaluated_object_mask=evaluate_object_mask,
    )
    distances_to_objects = (
        distances_to_objects[:, submission_specs.CURRENT_TIME_INDEX + 1:])
    # A step counts as colliding when the nearest-object distance drops
    # below the collision threshold.
    is_colliding_per_step = torch.lt(
        distances_to_objects, interact_features.COLLISION_DISTANCE_THRESHOLD)

    # ! Time to collision
    times_to_collision = (
        interact_features.compute_time_to_collision_with_object_in_front(
            center_x=simulate_trajectories.x,
            center_y=simulate_trajectories.y,
            length=simulate_trajectories.length,
            width=simulate_trajectories.width,
            heading=simulate_trajectories.heading,
            valid=simulate_trajectories.valid,
            evaluated_object_mask=evaluate_object_mask,
            seconds_per_step=submission_specs.STEP_DURATION_SECONDS,
        )
    )
    times_to_collision = times_to_collision[:, submission_specs.CURRENT_TIME_INDEX + 1:]

    # ! Distance to road edge
    # NOTE(review): when `scenario_log` is None these remain uninitialized
    # `empty_like` buffers (arbitrary values) — confirm that downstream code
    # ignores them in that case.
    distances_to_road_edge = torch.empty_like(distances_to_objects)
    is_offroad_per_step = torch.empty_like(is_colliding_per_step)
    if scenario_log is not None:
        # Collect all road-edge polylines from the logged map.
        road_edges = []
        for map_feature in scenario_log.map_features:
            if map_feature.HasField('road_edge'):
                road_edges.append(map_feature.road_edge.polyline)
        distances_to_road_edge = map_features.compute_distance_to_road_edge(
            center_x=simulate_trajectories.x,
            center_y=simulate_trajectories.y,
            center_z=simulate_trajectories.z,
            length=simulate_trajectories.length,
            width=simulate_trajectories.width,
            height=simulate_trajectories.height,
            heading=simulate_trajectories.heading,
            valid=simulate_trajectories.valid,
            evaluated_object_mask=evaluate_object_mask,
            road_edge_polylines=road_edges,
        )
        distances_to_road_edge = distances_to_road_edge[:, submission_specs.CURRENT_TIME_INDEX + 1:]
        is_offroad_per_step = torch.gt(
            distances_to_road_edge, map_features.OFFROAD_DISTANCE_THRESHOLD
        )

    # ! Placement
    # An av_id of -1 on both fields means no AV is present, so placement /
    # removement features are all-zero placeholders.
    if simulate_trajectories.av_id == simulate_trajectories.processed_av_id == -1:
        n_agent, n_step_10hz = linear_speed.shape
        num_placement = torch.zeros((n_step_10hz // SHIFT,))
        num_removement = torch.zeros((n_step_10hz // SHIFT,))
        distance_placement = torch.zeros((n_agent, n_step_10hz // SHIFT))
        distance_removement = torch.zeros((n_agent, n_step_10hz // SHIFT))

    else:
        assert simulate_trajectories.av_id == simulate_trajectories.processed_av_id, \
            f"Got duplicated av_id: {simulate_trajectories.av_id} and {simulate_trajectories.processed_av_id}"
        # Placement features live on the token (SHIFT-downsampled) timeline,
        # hence the `// SHIFT` history trimming below.
        num_placement, num_removement = (
            placement_features.compute_num_placement(
                state=simulate_trajectories.state,
                valid=simulate_trajectories.token_valid,
                av_id=simulate_trajectories.processed_av_id,
                object_id=simulate_trajectories.processed_object_id,
                agent_state=AGENT_STATE,
            )
        )
        num_placement = num_placement[submission_specs.CURRENT_TIME_INDEX // SHIFT:]
        num_removement = num_removement[submission_specs.CURRENT_TIME_INDEX // SHIFT:]
        distance_placement, distance_removement = (
            placement_features.compute_distance_placement(
                position=simulate_trajectories.token_pos,
                state=simulate_trajectories.state,
                valid=simulate_trajectories.valid,
                av_id=simulate_trajectories.processed_av_id,
                object_id=simulate_trajectories.processed_object_id,
                agent_state=AGENT_STATE,
            )
        )
        distance_placement = distance_placement[:, submission_specs.CURRENT_TIME_INDEX // SHIFT:]
        distance_removement = distance_removement[:, submission_specs.CURRENT_TIME_INDEX // SHIFT:]
        # distance_placement = distance_placement[distance_placement > 0]
        # distance_removement = distance_removement[distance_removement > 0]

    return MetricFeatures(
        object_id=simulate_trajectories.object_id,
        valid=validity_mask,
        # kinematic
        linear_speed=linear_speed,
        linear_acceleration=linear_accel,
        angular_speed=angular_speed,
        angular_acceleration=angular_accel,
        # interact
        distance_to_nearest_object=distances_to_objects,
        collision_per_step=is_colliding_per_step,
        time_to_collision=times_to_collision,
        # map
        distance_to_road_edge=distances_to_road_edge,
        offroad_per_step=is_offroad_per_step,
        # placement: add a leading per-scenario dimension
        num_placement=num_placement[None, ...],
        num_removement=num_removement[None, ...],
        distance_placement=distance_placement,
        distance_removement=distance_removement,
    )
700
+
701
+
702
@dataclass(frozen=True)
class LogDistributions:
    """Pre-computed distributions of the logged (ground-truth) features.

    Each field holds the data needed to score the corresponding simulated
    feature against the log (used by `log_likelihood_estimate_timeseries`).
    """

    linear_speed: Tensor
    linear_acceleration: Tensor
    angular_speed: Tensor
    angular_acceleration: Tensor
    distance_to_nearest_object: Tensor
    collision_indication: Tensor
    time_to_collision: Tensor
    distance_to_road_edge: Tensor
    num_placement: Tensor
    num_removement: Tensor
    distance_placement: Tensor
    distance_removement: Tensor
    offroad_indication: Optional[Tensor] = None  # only available when road edges were provided
718
+
719
+
720
+ """ Compute Metrics """
721
+
722
+ def _assert_and_return_batch_size(
723
+ log_samples: Tensor,
724
+ sim_samples: Tensor
725
+ ) -> int:
726
+ """Asserts consistency in the tensor shapes and returns batch size.
727
+
728
+ Args:
729
+ log_samples: A tensor of shape (batch_size, log_sample_size).
730
+ sim_samples: A tensor of shape (batch_size, sim_sample_size).
731
+
732
+ Returns:
733
+ The `batch_size`.
734
+ """
735
+ assert log_samples.shape[0] == sim_samples.shape[0], "Log and Sim batch sizes must be equal."
736
+ return log_samples.shape[0]
737
+
738
+
739
+ def _reduce_average_with_validity(
740
+ tensor: Tensor, validity: Tensor) -> Tensor:
741
+ """Returns the tensor's average, only selecting valid items.
742
+
743
+ Args:
744
+ tensor: A float tensor of any shape.
745
+ validity: A boolean tensor of the same shape as `tensor`.
746
+
747
+ Returns:
748
+ A float tensor of shape (1,), containing the average of the valid elements
749
+ of `tensor`.
750
+ """
751
+ if tensor.shape != validity.shape:
752
+ raise ValueError('Shapes of `tensor` and `validity` must be the same.'
753
+ f'(Actual: {tensor.shape}, {validity.shape}).')
754
+ cond_sum = torch.sum(torch.where(validity, tensor, torch.zeros_like(tensor)))
755
+ valid_sum = torch.sum(validity)
756
+ if valid_sum == 0:
757
+ return torch.tensor(0.)
758
+ return cond_sum / valid_sum
759
+
760
+
761
def histogram_estimate(
    config: long_metrics_pb2.SimAgentMetricsConfig.HistogramEstimate, # type: ignore
    log_samples: Tensor,
    sim_samples: Tensor,
) -> Tensor:
    """Computes log-likelihoods of samples based on histograms.

    Args:
      config: A `HistogramEstimate` config providing `min_val`, `max_val`,
        `num_bins` and `additive_smoothing_pseudocount`.
      log_samples: A tensor of shape (batch_size, log_sample_size),
        containing `log_sample_size` samples from `batch_size` independent
        populations.
      sim_samples: A tensor of shape (batch_size, sim_sample_size),
        containing `sim_sample_size` samples from `batch_size` independent
        populations.

    Returns:
      A tensor of shape (batch_size, log_sample_size), where each element (i, k)
      is the log likelihood of the log sample (i, k) under the sim distribution
      (i).
    """
    batch_size = _assert_and_return_batch_size(log_samples, sim_samples)

    # We generate `num_bins`+1 edges for the histogram buckets.
    edges = torch.linspace(
        config.min_val, config.max_val, config.num_bins + 1
    ).float()

    # Clip the samples to avoid errors with histograms.
    log_samples = torch.clamp(log_samples, config.min_val, config.max_val)
    sim_samples = torch.clamp(sim_samples, config.min_val, config.max_val)

    # Create the categorical distribution for simulation. `tfp.histogram` returns
    # a tensor of shape (num_bins, batch_size), so we need to transpose to conform
    # to `tfp.distribution.Categorical`, which requires `probs` to be
    # (batch_size, num_bins).
    # NOTE(review): this assumes `torch.histogram` is compatible with
    # `torch.vmap` in the installed torch version — confirm, as histogram ops
    # are not batched in all releases.
    sim_counts = torch.vmap(lambda x: torch.histogram(x, bins=edges).hist)(sim_samples)
    # Additive smoothing avoids zero-probability bins (and -inf log-probs).
    sim_counts += config.additive_smoothing_pseudocount
    distributions = torch.distributions.Categorical(probs=sim_counts)

    # Generate the counts for the log distribution. We reshape the log samples to
    # (batch_size * log_sample_size, 1), so every log sample is independently
    # scored.
    log_values_flat = log_samples.reshape(-1, 1)
    # Shape of log_counts: (batch_size * log_sample_size, num_bins).
    log_counts = torch.vmap(lambda x: torch.histogram(x, bins=edges).hist)(log_values_flat)
    # Identify which bin each sample belongs to and get the log probability of
    # that bin under the sim distribution. With a single sample per row, the
    # argmax of its one-hot count vector is exactly its bin index.
    max_log_bin = log_counts.argmax(dim=-1)
    batched_max_log_bin = max_log_bin.reshape(batch_size, -1)

    # Since we have defined the categorical distribution to have `batch_size`
    # independent populations, tfp expects this `batch_size` to be in the last
    # dimension of the tensor, so transpose the log bins to
    # (log_sample_size, batch_size).
    log_likelihood = distributions.log_prob(batched_max_log_bin.transpose(0, 1))

    # Return log likelihood in the shape (batch_size, log_sample_size)
    return log_likelihood.transpose(0, 1)
820
+
821
+
822
def log_likelihood_estimate_timeseries(
    field: str,
    feature_config: long_metrics_pb2.SimAgentMetricsConfig.FeatureConfig, # type: ignore
    sim_values: Tensor,
    log_distributions: torch.distributions.Categorical,
    estimate_method: str='histogram',
) -> Tensor:
    """Computes the log-likelihood estimates for a time-series simulated feature.

    Args:
      field: Name of the feature (used only in the assertion message).
      feature_config: A time-series compatible `FeatureConfig`.
      sim_values: A float Tensor with shape (n_objects / n_scenarios, n_segments, n_steps).
      log_distributions: Categorical distribution(s) of the logged feature,
        with `probs` of shape (batch_size, n_bins).
      estimate_method: 'histogram' (default) or 'bernoulli' for boolean
        indications. NOTE(review): any other value leaves `config` unbound
        and raises NameError below — confirm callers only pass these two.

    Returns:
      A tensor of shape (n_objects, n_segments, n_steps) containing the
      log-likelihood of each simulated value under the logged distribution
      of the same feature.
    """
    assert sim_values.ndim == 3, f'Expect sim_values.ndim==3, got {sim_values.ndim}, shape {sim_values.shape} for {field}'

    sim_values_flat = sim_values.reshape(-1, 1) # [n_objects * n_segments * n_steps]

    # if not feature_config.independent_timesteps:
    #     # If time steps needs to be considered independent, reshape:
    #     # - `sim_values` as (n_objects, n_rollouts * n_steps)
    #     # - `log_values` as (n_objects, n_steps)
    #     # If values in time are instead to be compared per-step, reshape:
    #     # - `sim_values` as (n_objects * n_steps, n_rollouts)
    #     # - `log_values` as (n_objects * n_steps, 1)
    #     sim_values = sim_values.reshape(-1, 1) # n_rollouts=1

    # if feature_config.independent_timesteps:
    #     sim_values = sim_values.permute(1, 0, 2).reshape(n_objects, n_rollouts * n_steps)
    # else:
    #     sim_values = sim_values.permute(1, 2, 0).reshape(n_objects * n_steps, n_rollouts)
    #     log_values = log_values.reshape(n_objects * n_steps, 1)

    # ! calculate distributions for simulate features
    if estimate_method == 'histogram':
        config = feature_config.histogram
    elif estimate_method == 'bernoulli':
        # A boolean feature is treated as a 2-bin histogram over {-0.5, 0.5}.
        config = (
            long_metrics_pb2.SimAgentMetricsConfig.HistogramEstimate(
                min_val=-0.5, max_val=0.5, num_bins=2,
                additive_smoothing_pseudocount=feature_config.bernoulli.additive_smoothing_pseudocount
            )
        )
        sim_values_flat = sim_values_flat.float() # cast torch.bool to torch.float32

    # We generate `num_bins`+1 edges for the histogram buckets.
    edges = torch.linspace(
        config.min_val, config.max_val, config.num_bins + 1
    ).float()

    # NOTE(review): assumes `torch.histogram` supports `torch.vmap` batching
    # in the installed torch version — confirm.
    sim_counts = torch.vmap(lambda x: torch.histogram(x, bins=edges).hist)(sim_values_flat) # [batch_size, num_bins]
    # Identify which bin each sample belongs to and get the log probability of
    # that bin under the sim distribution. One sample per row, so the argmax
    # of the one-hot count vector is the sample's bin index.
    max_sim_bin = sim_counts.argmax(dim=-1)
    batched_max_sim_bin = max_sim_bin.reshape(1, -1) # `batch_size` = 1, follows the log distributions

    sim_likelihood = log_distributions.log_prob(batched_max_sim_bin.transpose(0, 1)).flatten()
    return sim_likelihood.reshape(*sim_values.shape) # [n_objects, n_segments, n_steps]
885
+
886
+
887
+ def compute_scenario_metrics_for_bundle(
888
+ config: long_metrics_pb2.SimAgentMetricsConfig, # type: ignore
889
+ log_distributions: LogDistributions,
890
+ scenario_log: Optional[scenario_pb2.Scenario], # type: ignore
891
+ scenario_rollouts: ScenarioRollouts,
892
+ ) -> long_metrics_pb2.SimAgentMetrics: # type: ignore
893
+
894
+ features_fields = [field.name for field in dataclasses.fields(MetricFeatures)]
895
+ features_fields.remove('object_id')
896
+
897
+ # ! compute simluation features
898
+ # CONSOLE.log('[on yellow] Compute sim features [/]')
899
+ sim_features = collections.defaultdict(list)
900
+ for simulate_trajectories in tqdm(scenario_rollouts.joint_scenes, leave=False, desc='rollouts ...'): # n_rollout=1
901
+ rollout_features = compute_metric_features(
902
+ simulate_trajectories,
903
+ evaluate_agent_ids=None,
904
+ scenario_log=scenario_log
905
+ )
906
+
907
+ for field in features_fields:
908
+ sim_features[field].append(getattr(rollout_features, field))
909
+
910
+ for field in features_fields:
911
+ if sim_features[field][0] is not None:
912
+ sim_features[field] = torch.concat(sim_features[field], dim=0) # n_rollout for dim=0
913
+
914
+ sim_features = MetricFeatures(
915
+ **sim_features, object_id=None,
916
+ )
917
+ # after unfold: linear_speed shape [n_agent, n_window, window_size],
918
+ # num_placement shape [n_scenario=1, n_window, window_size]
919
+ flattened_sim_features = copy.deepcopy(sim_features)
920
+ sim_features = sim_features.unfold(size=submission_specs.N_SIMULATION_STEPS, step=SHIFT)
921
+ # CONSOLE.log(f'sim linear_speed feature: {sim_features.linear_speed.shape}')
922
+ # CONSOLE.log(f'sim num_placement feature: {sim_features.num_placement.shape}')
923
+
924
+ ## ! compute metrics
925
+
926
+ # ! kinematics-related metrics
927
+ linear_speed_log_likelihood = log_likelihood_estimate_timeseries(
928
+ field='linear_speed',
929
+ feature_config=config.linear_speed,
930
+ sim_values=sim_features.linear_speed,
931
+ log_distributions=log_distributions.linear_speed,
932
+ )
933
+ angular_speed_log_likelihood = log_likelihood_estimate_timeseries(
934
+ field='angular_speed',
935
+ feature_config=config.angular_speed,
936
+ sim_values=sim_features.angular_speed,
937
+ log_distributions=log_distributions.angular_speed,
938
+ )
939
+ speed_validity, acceleration_validity = (
940
+ trajectory_features.compute_kinematic_validity(flattened_sim_features.valid)
941
+ )
942
+ speed_validity = speed_validity.unfold(1, size=submission_specs.N_SIMULATION_STEPS, step=SHIFT)
943
+ acceleration_validity = acceleration_validity.unfold(1, size=submission_specs.N_SIMULATION_STEPS, step=SHIFT)
944
+ linear_speed_likelihood = torch.exp(_reduce_average_with_validity(
945
+ linear_speed_log_likelihood, speed_validity))
946
+ angular_speed_likelihood = torch.exp(_reduce_average_with_validity(
947
+ angular_speed_log_likelihood, speed_validity))
948
+ # CONSOLE.log(f'linear_speed_likelihood: {linear_speed_likelihood}')
949
+ # CONSOLE.log(f'angular_speed_likelihood: {angular_speed_likelihood}')
950
+
951
+ linear_accel_log_likelihood = log_likelihood_estimate_timeseries(
952
+ field='linear_acceleration',
953
+ feature_config=config.linear_acceleration,
954
+ sim_values=sim_features.linear_acceleration,
955
+ log_distributions=log_distributions.linear_acceleration,
956
+ )
957
+ angular_accel_log_likelihood = log_likelihood_estimate_timeseries(
958
+ field='angular_acceleration',
959
+ feature_config=config.angular_acceleration,
960
+ sim_values=sim_features.angular_acceleration,
961
+ log_distributions=log_distributions.angular_acceleration,
962
+ )
963
+ linear_accel_likelihood = torch.exp(_reduce_average_with_validity(
964
+ linear_accel_log_likelihood, acceleration_validity))
965
+ angular_accel_likelihood = torch.exp(_reduce_average_with_validity(
966
+ angular_accel_log_likelihood, acceleration_validity))
967
+ # CONSOLE.log(f'linear_accel_likelihood: {linear_accel_likelihood}')
968
+ # CONSOLE.log(f'angular_accel_likelihood: {angular_accel_likelihood}')
969
+
970
+ # ! collision and distance to other objects.
971
+
972
+ sim_collision_indication = torch.any(
973
+ torch.where(sim_features.valid, sim_features.collision_per_step, False),
974
+ dim=2)[..., None] # add a dummy time dimension
975
+ collision_score = log_likelihood_estimate_timeseries(
976
+ field='collision_indication',
977
+ feature_config=config.collision_indication,
978
+ sim_values=sim_collision_indication,
979
+ log_distributions=log_distributions.collision_indication,
980
+ estimate_method='bernoulli',
981
+ )
982
+ collision_likelihood = torch.exp(torch.mean(collision_score))
983
+
984
+ distance_to_objects_log_likelihodd = log_likelihood_estimate_timeseries(
985
+ field='distance_to_nearest_object',
986
+ feature_config=config.distance_to_nearest_object,
987
+ sim_values=sim_features.distance_to_nearest_object,
988
+ log_distributions=log_distributions.distance_to_nearest_object,
989
+ )
990
+ distance_to_objects_likelihodd = torch.exp(_reduce_average_with_validity(
991
+ distance_to_objects_log_likelihodd, sim_features.valid))
992
+ # CONSOLE.log(f'distance_to_objects_likelihodd: {distance_to_objects_likelihodd}')
993
+
994
+ ttc_log_likelihood = log_likelihood_estimate_timeseries(
995
+ field='time_to_collision',
996
+ feature_config=config.time_to_collision,
997
+ sim_values=sim_features.time_to_collision,
998
+ log_distributions=log_distributions.time_to_collision,
999
+ )
1000
+ ttc_likelihood = torch.exp(_reduce_average_with_validity(
1001
+ ttc_log_likelihood, sim_features.valid))
1002
+ # CONSOLE.log(f'ttc_likelihood: {ttc_likelihood}')
1003
+
1004
+ # ! offroad and distance to road edge.
1005
+
1006
+ # distance_to_road_edge_log_likelihood = log_likelihood_estimate_timeseries(
1007
+ # field='distance_to_road_edge',
1008
+ # sim_values=sim_features.distance_to_road_edge,
1009
+ # log_distributions=log_distributions.distance_to_road_edge,
1010
+ # )
1011
+ # distance_to_road_edge_likelihood = torch.exp(_reduce_average_with_validity(
1012
+ # distance_to_road_edge_log_likelihood, sim_features.valid))
1013
+ # CONSOLE.log(f'distance_to_road_edge_likelihood: {distance_to_road_edge_likelihood}')
1014
+
1015
+ # ! placement
1016
+
1017
+ num_placement_log_likelihood = log_likelihood_estimate_timeseries(
1018
+ field='num_placement',
1019
+ feature_config=config.num_placement,
1020
+ sim_values=sim_features.num_placement.float(),
1021
+ log_distributions=log_distributions.num_placement,
1022
+ )
1023
+ num_placement_likelihood = torch.exp(torch.mean(num_placement_log_likelihood))
1024
+ num_removement_log_likelihood = log_likelihood_estimate_timeseries(
1025
+ field='num_removement',
1026
+ feature_config=config.num_removement,
1027
+ sim_values=sim_features.num_removement.float(),
1028
+ log_distributions=log_distributions.num_removement,
1029
+ )
1030
+ num_removement_likelihood = torch.exp(torch.mean(num_removement_log_likelihood))
1031
+ # CONSOLE.log(f'num_placement_likelihood: {num_placement_likelihood}')
1032
+ # CONSOLE.log(f'num_removement_likelihood: {num_removement_likelihood}')
1033
+
1034
+ # tensor([[0.0013, 0.0078, 0.0194, 0.0373, 0.0628, 0.0938, 0.1232, 0.1470, 0.1701,
1035
+ # 0.3371]])
1036
+ # tensor([[0.0201, 0.0570, 0.0689, 0.0839, 0.1029, 0.1172, 0.1282, 0.1286, 0.1237,
1037
+ # 0.1695]])
1038
+ distance_placement_log_likelihood = log_likelihood_estimate_timeseries(
1039
+ field='distance_placement',
1040
+ feature_config=config.distance_placement,
1041
+ sim_values=sim_features.distance_placement,
1042
+ log_distributions=log_distributions.distance_placement,
1043
+ )
1044
+ distance_placement_validity = (
1045
+ (sim_features.distance_placement > config.distance_placement.histogram.min_val) &
1046
+ (sim_features.distance_placement < config.distance_placement.histogram.max_val)
1047
+ )
1048
+ distance_placement_likelihood = torch.exp(_reduce_average_with_validity(
1049
+ distance_placement_log_likelihood, distance_placement_validity))
1050
+ distance_removement_log_likelihood = log_likelihood_estimate_timeseries(
1051
+ field='distance_removement',
1052
+ feature_config=config.distance_removement,
1053
+ sim_values=sim_features.distance_removement,
1054
+ log_distributions=log_distributions.distance_removement,
1055
+ )
1056
+ distance_removement_validity = (
1057
+ (sim_features.distance_removement > config.distance_removement.histogram.min_val) &
1058
+ (sim_features.distance_removement < config.distance_removement.histogram.max_val)
1059
+ )
1060
+ distance_removement_likelihood = torch.exp(_reduce_average_with_validity(
1061
+ distance_removement_log_likelihood, distance_removement_validity))
1062
+
1063
+ # ==== Simulated collision and offroad rates ====
1064
+ simulated_collision_rate = torch.sum(
1065
+ sim_collision_indication.long()
1066
+ ) / torch.sum(torch.ones_like(sim_collision_indication).long())
1067
+ # simulated_offroad_rate = tf.reduce_sum(
1068
+ # # `sim_offroad_indication` shape: (n_samples, n_objects).
1069
+ # tf.cast(sim_offroad_indication, tf.int32)
1070
+ # ) / tf.reduce_sum(tf.ones_like(sim_offroad_indication, dtype=tf.int32))
1071
+
1072
+ # ==== Meta metric ====
1073
+ likelihood_metrics = {
1074
+ 'linear_speed_likelihood': float(linear_speed_likelihood.numpy()),
1075
+ 'linear_acceleration_likelihood': float(linear_accel_likelihood.numpy()),
1076
+ 'angular_speed_likelihood': float(angular_speed_likelihood.numpy()),
1077
+ 'angular_acceleration_likelihood': float(angular_accel_likelihood.numpy()),
1078
+ 'distance_to_nearest_object_likelihood': float(distance_to_objects_likelihodd.numpy()),
1079
+ 'collision_indication_likelihood': float(collision_likelihood.numpy()),
1080
+ 'time_to_collision_likelihood': float(ttc_likelihood.numpy()),
1081
+ # 'distance_to_road_edge_likelihood': float(distance_road_edge_likelihood.numpy()),
1082
+ # 'offroad_indication_likelihood': float(offroad_likelihood.numpy()),
1083
+ 'num_placement_likelihood': float(num_placement_likelihood.numpy()),
1084
+ 'num_removement_likelihood': float(num_removement_likelihood.numpy()),
1085
+ 'distance_placement_likelihood': float(distance_placement_likelihood.numpy()),
1086
+ 'distance_removement_likelihood': float(distance_removement_likelihood.numpy()),
1087
+ }
1088
+
1089
+ metametric = _compute_metametric(
1090
+ config, long_metrics_pb2.SimAgentMetrics(**likelihood_metrics)
1091
+ )
1092
+ # CONSOLE.log(f'metametric: {metametric}')
1093
+
1094
+ return long_metrics_pb2.SimAgentMetrics(
1095
+ scenario_id=scenario_rollouts.scenario_id,
1096
+ metametric=metametric,
1097
+ simulated_collision_rate=float(simulated_collision_rate.numpy()),
1098
+ # simulated_offroad_rate=simulated_offroad_rate.numpy(),
1099
+ **likelihood_metrics,
1100
+ )
1101
+
1102
+
1103
+ """ Log Features """
1104
+
1105
def _get_log_distributions(
    field: str,
    feature_config: long_metrics_pb2.SimAgentMetricsConfig.FeatureConfig,  # type: ignore
    log_values: Tensor,
    estimate_method: str = 'histogram',
) -> torch.distributions.Categorical:
    """Builds a histogram distribution from logged (ground-truth) feature values.

    Args:
      field: Name of the feature (used for error messages and to select the
        placement/removement out-of-range filtering below).
      feature_config: A time-series compatible `FeatureConfig` providing the
        histogram (or bernoulli) estimate settings.
      log_values: A float Tensor with shape (n_objects, n_steps).
      estimate_method: Either 'histogram' or 'bernoulli'. For 'bernoulli' the
        (boolean) samples are binned into two buckets around 0/1.

    Returns:
      A `torch.distributions.Categorical` over the histogram bins (batch shape
      (1,)), smoothed by the configured additive pseudocount.

    Raises:
      ValueError: If `estimate_method` is not one of the supported methods.
    """
    assert log_values.ndim == 2, f'Expect log_values.ndim==2, got {log_values.ndim}, shape {log_values.shape} for {field}'

    # [n_objects, n_steps] -> [n_objects * n_steps]
    log_samples = log_values.reshape(-1)

    # ! estimate
    if estimate_method == 'histogram':
        config = feature_config.histogram
    elif estimate_method == 'bernoulli':
        # Treat booleans as a 2-bin histogram centered on {0, 1}.
        config = (
            long_metrics_pb2.SimAgentMetricsConfig.HistogramEstimate(
                min_val=-0.5, max_val=0.5, num_bins=2,
                additive_smoothing_pseudocount=feature_config.bernoulli.additive_smoothing_pseudocount
            )
        )
        log_samples = log_samples.float()  # cast torch.bool to torch.float32
    else:
        # Previously an unknown method fell through and raised a confusing
        # NameError on `config`; fail loudly instead.
        raise ValueError(f'Unsupported estimate_method {estimate_method!r} for {field}')

    # We generate `num_bins`+1 edges for the histogram buckets.
    edges = torch.linspace(
        config.min_val, config.max_val, config.num_bins + 1
    ).float()

    # Placement distances outside the configured range are undefined samples,
    # not boundary values, so drop them rather than clamp them into the edges.
    if field in ('distance_placement', 'distance_removement'):
        log_samples = log_samples[(log_samples > config.min_val) & (log_samples < config.max_val)]

    # Clip the samples to avoid errors with histograms. Nonetheless, the min/max
    # values should be configured to never hit this condition in practice.
    log_samples = torch.clamp(log_samples, config.min_val, config.max_val)

    # Bin counts (shape [1, num_bins]) plus additive smoothing; `Categorical`
    # normalizes the (unnormalized) counts into probabilities.
    log_counts = torch.histogram(log_samples, bins=edges).hist.unsqueeze(dim=0)
    log_counts += config.additive_smoothing_pseudocount
    distributions = torch.distributions.Categorical(probs=log_counts)

    return distributions
1161
+
1162
+
1163
class LongMetric(Metric):
    """Accumulator for long-horizon sim-agent realism metrics.

    Ground-truth ("log") metric features are loaded once at construction and
    turned into histogram distributions; `update()` folds per-scenario
    simulated metrics into running sums, and `compute()` reports per-scenario
    averages aggregated into WOSAC-style metric buckets.
    """

    # Aggregated ground-truth features loaded from `log_features_dir`.
    log_features: MetricFeatures

    def __init__(
        self,
        prefix: str='',
        log_features_dir: str='data/waymo_processed/log_features/',
        config_path: str='dev/metrics/metric_config.textproto',
    ) -> None:
        super().__init__()
        self.prefix = prefix
        self.metrics_config = self.load_metrics_config(config_path)

        # When True, `_compute_scenario_metrics` also parses the original
        # scenario tfrecord for each rollout.
        self.use_log = False

        self.field_names = [
            "metametric",
            "average_displacement_error",
            "min_average_displacement_error",
            "linear_speed_likelihood",
            "linear_acceleration_likelihood",
            "angular_speed_likelihood",
            "angular_acceleration_likelihood",
            'distance_to_nearest_object_likelihood',
            'collision_indication_likelihood',
            'time_to_collision_likelihood',
            # 'distance_to_road_edge_likelihood',
            # 'offroad_indication_likelihood',
            'simulated_collision_rate',
            # 'simulated_offroad_rate',
            'num_placement_likelihood',
            'num_removement_likelihood',
            'distance_placement_likelihood',
            'distance_removement_likelihood',
        ]
        # Each field is a running sum reduced across processes; the counters
        # below provide the denominators used in `compute()`.
        for k in self.field_names:
            self.add_state(k, default=torch.tensor(0.), dist_reduce_fx='sum')
        self.add_state('scenario_counter', default=torch.tensor(0.), dist_reduce_fx='sum')
        self.add_state('placement_valid_scenario_counter', default=torch.tensor(0.), dist_reduce_fx='sum')
        self.add_state('removement_valid_scenario_counter', default=torch.tensor(0.), dist_reduce_fx='sum')

        # get log features (build the aggregate file on first use)
        log_features_path = os.path.join(log_features_dir, 'total_features.pkl')
        if not os.path.exists(log_features_path):
            CONSOLE.log(f'[on yellow] Log features does not exist, loading now ... [/]')
            log_features = aggregate_log_metric_features(log_features_dir)
        else:
            log_features = MetricFeatures.from_file(log_features_path)
            CONSOLE.log(f'Loaded log features from {log_features_path}')
        self.log_features = log_features

        self._compute_distributions()
        CONSOLE.log(f"Calculated log distributions:\n{self.log_distributions}")

    def _compute_distributions(self):
        """Builds the reference (log) histogram distributions once."""
        self.log_distributions = LogDistributions(
            linear_speed = _get_log_distributions('linear_speed',
                self.metrics_config.linear_speed, self.log_features.linear_speed,
            ),
            linear_acceleration = _get_log_distributions('linear_acceleration',
                self.metrics_config.linear_acceleration, self.log_features.linear_acceleration,
            ),
            angular_speed = _get_log_distributions('angular_speed',
                self.metrics_config.angular_speed, self.log_features.angular_speed,
            ),
            angular_acceleration = _get_log_distributions('angular_acceleration',
                self.metrics_config.angular_acceleration, self.log_features.angular_acceleration,
            ),
            distance_to_nearest_object = _get_log_distributions('distance_to_nearest_object',
                self.metrics_config.distance_to_nearest_object, self.log_features.distance_to_nearest_object,
            ),
            collision_indication = _get_log_distributions('collision_indication',
                self.metrics_config.collision_indication,
                # An object "collided" if any valid step had a collision.
                log_collision_indication := torch.any(
                    torch.where(self.log_features.valid, self.log_features.collision_per_step, False), dim=1
                )[..., None],  # add a dummy time dimension
                estimate_method = 'bernoulli',
            ),
            time_to_collision = _get_log_distributions('time_to_collision',
                self.metrics_config.time_to_collision, self.log_features.time_to_collision,
            ),
            distance_to_road_edge = _get_log_distributions('distance_to_road_edge',
                self.metrics_config.distance_to_road_edge, self.log_features.distance_to_road_edge,
            ),
            # dist_offroad_indication = _get_log_distributions(
            #     'offroad_indication',
            #     self.metrics_config.offroad_indication,
            #     log_offroad_indication := torch.any(
            #         torch.where(self.log_features.valid, self.log_features.offroad_per_step, False), dim=1
            #     ),
            # ),
            num_placement = _get_log_distributions('num_placement',
                self.metrics_config.num_placement, self.log_features.num_placement.float(),
            ),
            num_removement = _get_log_distributions('num_removement',
                self.metrics_config.num_removement, self.log_features.num_removement.float(),
            ),
            # Only positive distances are meaningful samples for placement.
            distance_placement = _get_log_distributions('distance_placement',
                self.metrics_config.distance_placement, (
                    self.log_features.distance_placement[self.log_features.distance_placement > 0])[None, ...],
            ),
            distance_removement = _get_log_distributions('distance_removement',
                self.metrics_config.distance_removement, (
                    self.log_features.distance_removement[self.log_features.distance_removement > 0])[None, ...],
            ),
        )

    def _compute_scenario_metrics(
        self,
        scenario_file: Optional[str],
        scenario_rollout: ScenarioRollouts,
    ) -> long_metrics_pb2.SimAgentMetrics:  # type: ignore
        """Computes metrics for one rollout, optionally loading its log scenario."""
        scenario_log = None
        if self.use_log and scenario_file is not None:
            if not os.path.exists(scenario_file):
                raise FileNotFoundError(f"Not found file {scenario_file}")
            scenario_log = scenario_pb2.Scenario()
            # A per-scenario tfrecord holds a single record.
            for data in tf.data.TFRecordDataset([scenario_file], compression_type=''):
                scenario_log.ParseFromString(bytes(data.numpy()))
                break

        return compute_scenario_metrics_for_bundle(
            self.metrics_config, self.log_distributions, scenario_log, scenario_rollout
        )

    def compute_metrics(self, outputs: dict) -> List[long_metrics_pb2.SimAgentMetrics]:  # type: ignore
        """
        `outputs` is a dict directly generated by predict models:
        >>> outputs = dict(
        >>>     scenario_id=get_scenario_id_int_tensor(data['scenario_id'], device),
        >>>     agent_id=agent_id,
        >>>     agent_batch=agent_batch,
        >>>     pred_traj=pred_traj,
        >>>     pred_z=pred_z,
        >>>     pred_head=pred_head,
        >>>     pred_shape=pred_shape,
        >>>     pred_type=pred_type,
        >>>     pred_state=pred_state,
        >>> )
        """

        scenario_rollouts = output_to_rollouts(outputs)
        log_paths: List[str] = outputs['tfrecord_path']

        pool_scenario_metrics = []
        for _scenario_file, _scenario_rollout in tqdm(
                zip(log_paths, scenario_rollouts), leave=False, desc='scenarios ...'):  # n_scenarios
            pool_scenario_metrics.append(
                self._compute_scenario_metrics(
                    _scenario_file, _scenario_rollout,
                )
            )

        return pool_scenario_metrics

    def update(
        self,
        outputs: Optional[dict]=None,
        metrics: Optional[List[long_metrics_pb2.SimAgentMetrics]]=None  # type: ignore
    ) -> None:
        """Accumulates per-scenario metrics (computing them from `outputs` if needed)."""
        if metrics is None:
            assert outputs is not None, f'`outputs` should not be None!'
            metrics = self.compute_metrics(outputs)

        for scenario_metrics in metrics:
            self.scenario_counter += 1

            self.metametric += scenario_metrics.metametric
            self.average_displacement_error += (
                scenario_metrics.average_displacement_error
            )
            self.min_average_displacement_error += (
                scenario_metrics.min_average_displacement_error
            )
            self.linear_speed_likelihood += scenario_metrics.linear_speed_likelihood
            self.linear_acceleration_likelihood += (
                scenario_metrics.linear_acceleration_likelihood
            )
            self.angular_speed_likelihood += scenario_metrics.angular_speed_likelihood
            self.angular_acceleration_likelihood += (
                scenario_metrics.angular_acceleration_likelihood
            )
            self.distance_to_nearest_object_likelihood += (
                scenario_metrics.distance_to_nearest_object_likelihood
            )
            self.collision_indication_likelihood += (
                scenario_metrics.collision_indication_likelihood
            )
            self.time_to_collision_likelihood += (
                scenario_metrics.time_to_collision_likelihood
            )
            # self.distance_to_road_edge_likelihood += (
            #     scenario_metrics.distance_to_road_edge_likelihood
            # )
            # self.offroad_indication_likelihood += (
            #     scenario_metrics.offroad_indication_likelihood
            # )
            self.simulated_collision_rate += scenario_metrics.simulated_collision_rate
            # self.simulated_offroad_rate += scenario_metrics.simulated_offroad_rate

            self.num_placement_likelihood += (
                scenario_metrics.num_placement_likelihood
            )
            self.num_removement_likelihood += (
                scenario_metrics.num_removement_likelihood
            )
            self.distance_placement_likelihood += (
                scenario_metrics.distance_placement_likelihood
            )
            self.distance_removement_likelihood += (
                scenario_metrics.distance_removement_likelihood
            )

            # Placement/removement distances are undefined for scenarios with
            # no placed/removed agents; track how many scenarios were valid so
            # `compute()` can average over the right denominator.
            if scenario_metrics.distance_placement_likelihood > 0:
                self.placement_valid_scenario_counter += 1

            if scenario_metrics.distance_removement_likelihood > 0:
                self.removement_valid_scenario_counter += 1

    def compute(self) -> Dict[str, Tensor]:
        """Returns averaged metrics and bucket aggregates, keyed by `self.prefix`."""
        metrics_dict = {}
        for k in self.field_names:
            # BUG FIX: the special cases previously compared against
            # 'distance_placement'/'distance_removement', which never match the
            # '..._likelihood' field names, so these two metrics were wrongly
            # normalized by `scenario_counter`.
            if k == 'distance_placement_likelihood':
                metrics_dict[k] = getattr(self, k) / self.placement_valid_scenario_counter
            elif k == 'distance_removement_likelihood':
                metrics_dict[k] = getattr(self, k) / self.removement_valid_scenario_counter
            else:
                metrics_dict[k] = getattr(self, k) / self.scenario_counter

        mean_metrics = long_metrics_pb2.SimAgentMetrics(
            scenario_id='', **metrics_dict,
        )
        final_metrics = self.aggregate_metrics_to_buckets(
            self.metrics_config, mean_metrics
        )
        CONSOLE.log(f'final_metrics:\n{final_metrics}')

        out_dict = {
            f"{self.prefix}/wosac/realism_meta_metric": final_metrics.realism_meta_metric,
            f"{self.prefix}/wosac/kinematic_metrics": final_metrics.kinematic_metrics,
            f"{self.prefix}/wosac/interactive_metrics": final_metrics.interactive_metrics,
            f"{self.prefix}/wosac/map_based_metrics": final_metrics.map_based_metrics,
            f"{self.prefix}/wosac/placement_based_metrics": final_metrics.placement_based_metrics,
            f"{self.prefix}/wosac/min_ade": final_metrics.min_ade,
            f"{self.prefix}/wosac/scenario_counter": int(self.scenario_counter),
        }
        for k in self.field_names:
            out_dict[f"{self.prefix}/wosac_likelihood/{k}"] = float(metrics_dict[k])

        return out_dict

    @staticmethod
    def aggregate_metrics_to_buckets(
        config: long_metrics_pb2.SimAgentMetricsConfig,  # type: ignore
        metrics: long_metrics_pb2.SimAgentMetrics  # type: ignore
    ) -> long_metrics_pb2.SimAgentsBucketedMetrics:  # type: ignore
        """Aggregates metrics into buckets for better readability."""
        bucketed_metrics = {}
        for bucket_name, fields_in_bucket in _METRIC_FIELD_NAMES_BY_BUCKET.items():
            weighted_metric, weights_sum = 0.0, 0.0
            for field_name in fields_in_bucket:
                likelihood_field_name = field_name + '_likelihood'
                weight = getattr(config, field_name).metametric_weight
                metric_score = getattr(metrics, likelihood_field_name)
                weighted_metric += weight * metric_score
                weights_sum += weight
            if weights_sum == 0:
                weights_sum = 1  # FIXME: hack!!!
                # raise ValueError('The bucket\'s weight sum is zero. Check your metrics'
                #                  ' config.')
            bucketed_metrics[bucket_name] = weighted_metric / weights_sum

        return long_metrics_pb2.SimAgentsBucketedMetrics(
            realism_meta_metric=metrics.metametric,
            kinematic_metrics=bucketed_metrics['kinematic'],
            interactive_metrics=bucketed_metrics['interactive'],
            map_based_metrics=bucketed_metrics['map_based'],
            placement_based_metrics=bucketed_metrics['placement_based'],
            min_ade=metrics.min_average_displacement_error,
            simulated_collision_rate=metrics.simulated_collision_rate,
            simulated_offroad_rate=metrics.simulated_offroad_rate,
        )

    @staticmethod
    def load_metrics_config(config_path: str = 'dev/metrics/metric_config.textproto',
                            ) -> long_metrics_pb2.SimAgentMetricsConfig:  # type: ignore
        """Parses the textproto metrics configuration from disk."""
        config = long_metrics_pb2.SimAgentMetricsConfig()
        with open(config_path, 'r') as f:
            text_format.Parse(f.read(), config)
        return config

    def dumps(self, dir):
        """Computes the current metrics and writes them to a timestamped JSON file."""
        from datetime import datetime

        timestamp = datetime.now().strftime("%m_%d_%H%M%S")

        results = self.compute()
        path = os.path.join(dir, f'{self.prefix}_{timestamp}.json')
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=4)

        CONSOLE.log(f'Saved results to [bold][yellow]{path}')
1468
+
1469
+
1470
+ """ Preprocess Methods """
1471
+
1472
def _dump_log_metric_features(
    pkl_dir: str,
    tfrecords_dir: str,
    save_dir: str,
    transform: WaymoTargetBuilder,
    token_processor: TokenProcessor,
    scenario_id: str,
):
    """Computes and pickles log (ground-truth) metric features for one scenario.

    Best-effort: any failure is logged and the scenario is skipped so batch
    processing can continue.
    """
    try:
        tqdm.write(f'Processing scenario {scenario_id}')

        # Skip scenarios that were already processed.
        out_path = os.path.join(save_dir, f'{scenario_id}.pkl')
        if os.path.exists(out_path):
            return

        # Resolve and validate the ground-truth inputs.
        gt_pkl_path = os.path.join(pkl_dir, f'{scenario_id}.pkl')
        if not os.path.exists(gt_pkl_path):
            raise FileNotFoundError(f"Not found file {gt_pkl_path}")
        record_path = os.path.join(tfrecords_dir, f'{scenario_id}.tfrecords')
        if not os.path.exists(record_path):
            raise FileNotFoundError(f"Not found file {record_path}")

        # Parse the (single-record) per-scenario tfrecord.
        scenario_log = scenario_pb2.Scenario()
        for record in tf.data.TFRecordDataset([record_path], compression_type=''):
            scenario_log.ParseFromString(bytes(record.numpy()))
            break

        with open(gt_pkl_path, 'rb') as handle:
            log_data = pickle.load(handle)

        # Preprocess: mark trained agents (`train_mask`), then tokenize.
        log_data = transform._score_trained_agents(log_data)
        log_data = token_processor._tokenize_agent(log_data)

        # Convert to trajectories and compute the metric features.
        # NOTE: `evaluation_agent_ids` are intentionally not used here.
        trajectories = scenario_to_trajectories(scenario_log, processed_scenario=log_data)
        log_features = compute_metric_features(
            trajectories, evaluate_agent_ids=None,
        )

        with open(out_path, 'wb') as handle:
            pickle.dump(log_features, handle)

    except Exception as e:
        CONSOLE.log(f'[on red] Failed to process scenario {scenario_id} due to {e}.[/]')
        return
1528
+
1529
+
1530
def dump_log_metric_features(log_dir, save_dir):
    """Sequentially dumps log metric features for every eligible validation scenario.

    Eligibility: scenario listed in `meta_infos.json` with 8 <= num_agents < 128.
    """
    buffer_size = 128

    # Validate the expected folder layout.
    pkl_dir = os.path.join(log_dir, 'validation')
    if not os.path.exists(pkl_dir):
        raise RuntimeError(f'Not found folder {pkl_dir}')
    tfrecords_dir = os.path.join(log_dir, 'validation_tfrecords_splitted')
    if not os.path.exists(tfrecords_dir):
        raise RuntimeError(f'Not found folder {tfrecords_dir}')

    # Filter scenarios by agent count using the meta-info table.
    candidates = list(fnmatch.filter(os.listdir(pkl_dir), '*.pkl'))
    json_path = os.path.join(log_dir, 'meta_infos.json')
    with open(json_path, 'r', encoding='utf-8') as fh:
        meta_infos = json.load(fh)['validation']
    CONSOLE.log(f"Loaded meta infos from {json_path}")
    meta_df = pd.DataFrame.from_dict(meta_infos, orient='index')
    known = set(meta_infos.keys())
    eligible = meta_df[
        meta_df.index.isin(known)
        & (meta_df['num_agents'] >= 8)
        & (meta_df['num_agents'] < buffer_size)
    ]
    valid_ids = set(eligible.index)
    scenario_ids = [
        name.removesuffix('.pkl')
        for name in tqdm(candidates, leave=False)
        if name.removesuffix('.pkl') in valid_ids
    ]
    CONSOLE.log(f'Loaded {len(scenario_ids)} scenarios from validation split.')

    # Preprocessing helpers shared by all scenarios.
    transform = WaymoTargetBuilder(num_historical_steps=11,
                                   num_future_steps=80,
                                   max_num=32)
    token_processor = TokenProcessor(token_size=2048,
                                     state_token={'invalid': 0, 'valid': 1, 'enter': 2, 'exit': 3},
                                     pl2seed_radius=75)

    dump_one = partial(
        _dump_log_metric_features, pkl_dir, tfrecords_dir, save_dir, transform, token_processor)

    for scenario_id in tqdm(scenario_ids, leave=False, desc='scenarios ...'):
        dump_one(scenario_id)
1571
+
1572
+
1573
def batch_dump_log_metric_features(log_dir, save_dir, num_workers=64):
    """Dumps log metric features for eligible validation scenarios in parallel.

    Same selection logic as `dump_log_metric_features` (8 <= num_agents < 128),
    but fans the work out over a multiprocessing pool.
    """
    buffer_size = 128

    # Validate the expected folder layout.
    pkl_dir = os.path.join(log_dir, 'validation')
    if not os.path.exists(pkl_dir):
        raise RuntimeError(f'Not found folder {pkl_dir}')
    tfrecords_dir = os.path.join(log_dir, 'validation_tfrecords_splitted')
    if not os.path.exists(tfrecords_dir):
        raise RuntimeError(f'Not found folder {tfrecords_dir}')

    # Filter scenarios by agent count using the meta-info table.
    candidates = list(fnmatch.filter(os.listdir(pkl_dir), '*.pkl'))
    json_path = os.path.join(log_dir, 'meta_infos.json')
    with open(json_path, 'r', encoding='utf-8') as fh:
        meta_infos = json.load(fh)['validation']
    CONSOLE.log(f"Loaded meta infos from {json_path}")
    meta_df = pd.DataFrame.from_dict(meta_infos, orient='index')
    known = set(meta_infos.keys())
    eligible = meta_df[
        meta_df.index.isin(known)
        & (meta_df['num_agents'] >= 8)
        & (meta_df['num_agents'] < buffer_size)
    ]
    valid_ids = set(eligible.index)
    scenario_ids = [
        name.removesuffix('.pkl')
        for name in tqdm(candidates, leave=False)
        if name.removesuffix('.pkl') in valid_ids
    ]
    CONSOLE.log(f'Loaded {len(scenario_ids)} scenarios from validation split.')

    # Preprocessing helpers shared by all workers.
    transform = WaymoTargetBuilder(num_historical_steps=11,
                                   num_future_steps=80,
                                   max_num=32)
    token_processor = TokenProcessor(token_size=2048,
                                     state_token={'invalid': 0, 'valid': 1, 'enter': 2, 'exit': 3},
                                     pl2seed_radius=75)

    dump_one = partial(
        _dump_log_metric_features, pkl_dir, tfrecords_dir, save_dir, transform, token_processor)

    with multiprocessing.Pool(num_workers) as pool:
        # `imap_unordered` keeps the progress bar responsive; `list` drains it.
        list(tqdm(pool.imap_unordered(dump_one, scenario_ids), total=len(scenario_ids)))
1613
+
1614
+
1615
def aggregate_log_metric_features(load_dir):
    """Concatenates per-scenario log feature pickles into one aggregate file.

    Reads every `*.pkl` in `load_dir` (except a previous aggregate), stacks
    each `MetricFeatures` field along dim 0, writes the result to
    `total_features.pkl`, and returns it as a `MetricFeatures`.
    """
    feature_files = list(fnmatch.filter(os.listdir(load_dir), '*.pkl'))
    if 'total_features.pkl' in feature_files:
        feature_files.remove('total_features.pkl')
    CONSOLE.log(f'Loaded {len(feature_files)} scenarios from dumpped log metric features')

    # `object_id` is per-scenario and not aggregatable.
    field_names = [f.name for f in dataclasses.fields(MetricFeatures)]
    field_names.remove('object_id')

    # Collect each field's per-scenario tensors.
    total_features = collections.defaultdict(list)
    for fname in tqdm(feature_files, leave=False, desc='scenario ...'):
        with open(os.path.join(load_dir, fname), 'rb') as fh:
            scenario_features = pickle.load(fh)
        for field in field_names:
            total_features[field].append(getattr(scenario_features, field))

    # Concatenate along the agent/scenario dimension; fields that are None in
    # the first scenario are left as-is.
    features_info = dict()
    for field in (pbar := tqdm(field_names, leave=False)):
        pbar.set_postfix(f=field)
        if total_features[field][0] is not None:
            total_features[field] = torch.concat(total_features[field], dim=0)
            features_info[field] = total_features[field].shape
    CONSOLE.log(f'Aggregated log features:\n{features_info}')

    # NOTE(review): this pickles the raw field dict, not a MetricFeatures
    # instance — verify `MetricFeatures.from_file` accepts this layout.
    save_path = os.path.join(load_dir, 'total_features.pkl')
    with open(save_path, 'wb') as fh:
        pickle.dump(total_features, fh)
    CONSOLE.log(f'Saved total features to [green]{save_path}.[/]')

    return MetricFeatures(**total_features, object_id=None)
1651
+
1652
+
1653
def _compute_metrics(
    metric: LongMetric,
    load_dir: str,
    verbose: bool,
    rollouts_file: str,
) -> List[long_metrics_pb2.SimAgentMetrics]:  # type: ignore
    """Loads one pickled rollouts dict and scores it with `metric`."""
    if verbose:
        print(f'Processing {rollouts_file}')

    rollouts_path = os.path.join(load_dir, rollouts_file)
    with open(rollouts_path, 'rb') as fh:
        rollouts = pickle.load(fh)

    return metric.compute_metrics(rollouts)
1668
+
1669
+
1670
def compute_metrics(load_dir, rollouts_files):
    """Sequentially scores every rollouts file and logs running metrics."""
    log_every_n_steps = 100

    metric = LongMetric('val_close_long')
    CONSOLE.log(f'metrics config:\n{metric.metrics_config}')

    for step, rollouts_file in enumerate(
            tqdm(rollouts_files, leave=False, desc='Rollouts files ...')):
        # ! compute metrics and update
        metric.update(
            metrics=_compute_metrics(metric, load_dir, verbose=False, rollouts_file=rollouts_file)
        )

        # Periodic progress report (includes step 0).
        if step % log_every_n_steps == 0:
            CONSOLE.log(f'Step={step}:\n{metric.compute()}')

    CONSOLE.log(f'[bold][yellow] Compute metrics completed!')
    CONSOLE.log(f'[bold][yellow] Final metrics: [/]\n {metric.compute()}')
1692
+
1693
+
1694
def batch_compute_metrics(load_dir, rollouts_files, num_workers, save_dir=None):
    """Scores rollouts files in a process pool, aggregating results in a thread.

    Worker processes each load+score one rollouts file; completed results are
    pushed onto a queue and folded into `metric` by a single collector thread,
    so `metric.update` is never called concurrently by multiple threads.

    Args:
        load_dir: Directory containing the pickled rollouts files.
        rollouts_files: File names (relative to `load_dir`) to process.
        num_workers: Number of worker processes.
        save_dir: Where to dump intermediate/final JSON results; defaults to
            `load_dir`.
    """
    from queue import Queue
    from threading import Thread

    if save_dir is None:
        save_dir = load_dir

    results_buffer = Queue()

    log_every_n_steps = 20

    metric = LongMetric('val_close_long')
    CONSOLE.log(f'metrics config:\n{metric.metrics_config}')

    def _collect_result():
        # Consume results until the `None` sentinel arrives.
        while True:
            r = results_buffer.get()
            if r is None:
                break
            metric.update(metrics=r)
            results_buffer.task_done()

    collector = Thread(target=_collect_result, daemon=True)
    collector.start()

    # verbose=True: workers print the file they are processing.
    partial_compute_metrics = partial(_compute_metrics, metric, load_dir, True)

    # ! compute metrics in batch
    with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor:
        # results = list(executor.map(partial_compute_metrics, rollouts_files))
        futures = [executor.submit(partial_compute_metrics, rollouts_file) for rollouts_file in rollouts_files]
        # results = [f.result() for f in concurrent.futures.as_completed(futures)]

        for i, future in tqdm(enumerate(concurrent.futures.as_completed(futures)), total=len(futures), leave=False):
            results_buffer.put(future.result())

            # NOTE(review): this periodic compute/dump runs on the main thread
            # while the collector thread may still be applying updates — the
            # intermediate numbers can lag the queue; confirm this is intended.
            if i % log_every_n_steps == 0:
                CONSOLE.log(f'Step={i}:\n{metric.compute()}')
                metric.dumps(save_dir)

    # Signal the collector to finish, then wait for it to drain the queue.
    results_buffer.put(None)
    collector.join()

    CONSOLE.log(f'[bold][yellow] Compute metrics completed!')
    CONSOLE.log(f'[bold][yellow] Final metrics: [/]\n {metric.compute()}')

    # save results to disk
    metric.dumps(save_dir)
1742
+
1743
+
1744
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument('--dump_log', action='store_true')
    parser.add_argument('--dump_sim', action='store_true')
    parser.add_argument('--aggregate_log', action='store_true')
    parser.add_argument('--num_workers', type=int, default=32)
    parser.add_argument('--compute_metric', action='store_true')
    parser.add_argument('--log_dir', type=str, default='data/waymo_processed/')
    parser.add_argument('--sim_dir', type=str, default=None, required=False)
    parser.add_argument('--save_dir', type=str, default='results', required=False)
    parser.add_argument('--no_batch', action='store_true')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--debug_batch', action='store_true')
    args = parser.parse_args()

    if args.dump_log:
        # Dump per-scenario log metric features (optionally in parallel).
        save_dir = os.path.join(args.log_dir, 'log_features')
        os.makedirs(save_dir, exist_ok=True)

        if args.no_batch or args.debug:
            dump_log_metric_features(args.log_dir, save_dir)
        else:
            # Fix: forward --num_workers (previously ignored; the hard-coded
            # default of 64 was always used).
            batch_dump_log_metric_features(args.log_dir, save_dir,
                                           num_workers=args.num_workers)

    elif args.aggregate_log:
        # Merge per-scenario feature pickles into total_features.pkl.
        load_dir = os.path.join(args.log_dir, 'log_features')
        aggregate_log_metric_features(load_dir)

    elif args.compute_metric:
        # Score previously dumped simulation rollouts.
        assert args.sim_dir is not None and os.path.exists(args.sim_dir), \
            f'Folder {args.sim_dir} does not exist!'
        rollouts_files = list(sorted(fnmatch.filter(os.listdir(args.sim_dir), 'idx_*_rollouts.pkl')))
        CONSOLE.log(f'Found {len(rollouts_files)} rollouts files.')

        os.makedirs(args.save_dir, exist_ok=True)
        if args.no_batch:
            compute_metrics(args.sim_dir, rollouts_files)
        else:
            multiprocessing.set_start_method('spawn', force=True)
            batch_compute_metrics(args.sim_dir, rollouts_files, args.num_workers, save_dir=args.save_dir)

    elif args.debug:
        # ! for debugging: score a single hard-coded rollouts file.
        debug_path = 'output/scalable_smart_long/validation_catk/idx_0_0_rollouts.pkl'

        with open(debug_path, 'rb') as f:
            rollouts = pickle.load(f)
        metric = LongMetric('debug')
        CONSOLE.log(f'metrics config: {metric.metrics_config}')

        metric.update(outputs=rollouts)
        CONSOLE.log(f'metrics:\n{metric.compute()}')

    elif args.debug_batch:
        # Stress-test the batch path with the same file repeated 1000 times.
        rollouts_files = ['idx_0_rollouts.pkl'] * 1000
        CONSOLE.log(f'Found {len(rollouts_files)} rollouts files.')

        sim_dir = 'dev/metrics/'

        os.makedirs(args.save_dir, exist_ok=True)
        multiprocessing.set_start_method('spawn', force=True)
        # Fix: use the local `sim_dir`; `args.sim_dir` defaults to None here
        # and would crash in os.listdir/open inside the workers.
        batch_compute_metrics(sim_dir, rollouts_files, args.num_workers, save_dir=args.save_dir)
backups/dev/metrics/geometry_utils.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from torch import Tensor
4
+ from typing import Tuple
5
+
6
+
7
+ NUM_VERTICES_IN_BOX = 4
8
+
9
+
10
def minkowski_sum_of_box_and_box_points(box1_points: Tensor,
                                        box2_points: Tensor) -> Tensor:
    """Batched Minkowski sum of two boxes (counter-clockwise corners in xy).

    The corner lists of the two boxes are merged edge-by-edge, starting from
    each box's downmost edge; pairwise sums of the aligned corner sequences
    give the counter-clockwise corners of the Minkowski-sum polygon.

    Args:
      box1_points: (num_boxes, 4, 2) counter-clockwise xy corners.
      box2_points: (num_boxes, 4, 2) counter-clockwise xy corners.

    Returns:
      (num_boxes, 8, 2) counter-clockwise corners of the Minkowski sum.
    """
    # Fix: allocate the corner-ordering index tensors on the same device as
    # the inputs; the original CPU-default tensors broke CUDA inputs.
    device = box1_points.device
    point_order_1 = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3], dtype=torch.long, device=device)
    point_order_2 = torch.tensor([0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device)

    box1_start_idx, downmost_box1_edge_direction = _get_downmost_edge_in_box(
        box1_points)
    box2_start_idx, downmost_box2_edge_direction = _get_downmost_edge_in_box(
        box2_points)

    # Relative orientation of the two downmost edges decides which box's
    # corner sequence is advanced first during the merge.
    condition = (cross_product_2d(downmost_box1_edge_direction, downmost_box2_edge_direction) >= 0.)
    condition = condition.repeat(1, 8)

    box1_point_order = torch.where(condition, point_order_2, point_order_1)
    box1_point_order = (box1_point_order + box1_start_idx) % NUM_VERTICES_IN_BOX
    ordered_box1_points = torch.gather(
        box1_points, 1, box1_point_order.unsqueeze(-1).expand(-1, -1, 2))

    box2_point_order = torch.where(condition, point_order_1, point_order_2)
    box2_point_order = (box2_point_order + box2_start_idx) % NUM_VERTICES_IN_BOX
    ordered_box2_points = torch.gather(
        box2_points, 1, box2_point_order.unsqueeze(-1).expand(-1, -1, 2))

    minkowski_sum = ordered_box1_points + ordered_box2_points

    return minkowski_sum
37
+
38
+
39
def signed_distance_from_point_to_convex_polygon(query_points: Tensor, polygon_points: Tensor) -> Tensor:
    """Finds the signed distances from query points to convex polygons.

    Args:
      query_points: (num_polygons, 2) xy query point for each polygon.
      polygon_points: (num_polygons, num_vertices, 2) counter-clockwise
        polygon vertices.

    Returns:
      (num_polygons,) signed distances; negative means the query point lies
      inside its polygon.
    """
    tangent_unit_vectors, normal_unit_vectors, edge_lengths = _get_edge_info(
        polygon_points)

    query_points = query_points.unsqueeze(1)
    vertices_to_query_vectors = query_points - polygon_points
    vertices_distances = torch.norm(vertices_to_query_vectors, dim=-1)

    # Signed perpendicular distance to each edge; <= 0 on the inner side.
    edge_signed_perp_distances = torch.sum(-normal_unit_vectors * vertices_to_query_vectors, dim=-1)

    # Inside iff the point is on the inner side of every edge.
    is_inside = torch.all(edge_signed_perp_distances <= 0, dim=-1)

    projection_along_tangent = torch.sum(tangent_unit_vectors * vertices_to_query_vectors, dim=-1)
    projection_along_tangent_proportion = projection_along_tangent / edge_lengths

    is_projection_on_edge = (projection_along_tangent_proportion >= 0.) & (
        projection_along_tangent_proportion <= 1.)

    edge_perp_distances = edge_signed_perp_distances.abs()
    # Fix: build the +inf fill with `torch.full_like` so it matches the data's
    # device and dtype; the original `torch.tensor(np.inf)` was a CPU scalar
    # tensor and failed for CUDA inputs.
    edge_distances = torch.where(
        is_projection_on_edge, edge_perp_distances,
        torch.full_like(edge_perp_distances, float('inf')))

    # Closest feature is either an edge (when the projection lands on it) or
    # a vertex.
    edge_and_vertex_distance = torch.cat([edge_distances, vertices_distances], dim=-1)
    min_distance = torch.min(edge_and_vertex_distance, dim=-1)[0]

    signed_distances = torch.where(is_inside, -min_distance, min_distance)

    return signed_distances
67
+
68
+
69
def _get_downmost_edge_in_box(box: Tensor) -> Tuple[Tensor, Tensor]:
    """Finds the downmost (lowest y-coordinate) edge in the box.

    Args:
      box: (num_boxes, 4, 2) counter-clockwise xy corners.

    Returns:
      Tuple of
        downmost_vertex_idx: (num_boxes, 1) index of the edge's start vertex,
        downmost_edge_direction: (num_boxes, 1, 2) unit direction of the edge.
    """
    downmost_vertex_idx = torch.argmin(box[..., 1], dim=-1, keepdim=True)

    edge_start_vertex = torch.gather(box, 1, downmost_vertex_idx.unsqueeze(-1).expand(-1, -1, 2))
    edge_end_idx = (downmost_vertex_idx + 1) % NUM_VERTICES_IN_BOX
    edge_end_vertex = torch.gather(box, 1, edge_end_idx.unsqueeze(-1).expand(-1, -1, 2))

    downmost_edge = edge_end_vertex - edge_start_vertex
    downmost_edge_length = torch.norm(downmost_edge, dim=-1, keepdim=True)
    # Fix: guard against zero-length (degenerate) edges producing NaNs,
    # consistent with the epsilon guard used in `_get_edge_info`.
    eps = torch.finfo(downmost_edge_length.dtype).eps
    downmost_edge_direction = downmost_edge / (downmost_edge_length + eps)

    return downmost_vertex_idx, downmost_edge_direction
82
+
83
+
84
def cross_product_2d(a: Tensor, b: Tensor) -> Tensor:
    """Returns the z-component of the 2D cross product over the last (xy) dim."""
    ax, ay = a[..., 0], a[..., 1]
    bx, by = b[..., 0], b[..., 1]
    return ax * by - ay * bx
86
+
87
+
88
def dot_product_2d(a: Tensor, b: Tensor) -> Tensor:
    """Returns the dot product over the last (xy) dimension of `a` and `b`."""
    componentwise = a * b
    return componentwise[..., 0] + componentwise[..., 1]
90
+
91
+
92
+ def _get_edge_info(polygon_points: Tensor):
93
+ """
94
+ Computes properties about the edges of a polygon.
95
+
96
+ Args:
97
+ polygon_points: Tensor containing the vertices of each polygon, with
98
+ shape (num_polygons, num_points_per_polygon, 2). Each polygon is assumed
99
+ to have an equal number of vertices.
100
+
101
+ Returns:
102
+ tangent_unit_vectors: A unit vector in (x,y) with the same direction as
103
+ the tangent to the edge. Shape: (num_polygons, num_points_per_polygon, 2).
104
+ normal_unit_vectors: A unit vector in (x,y) with the same direction as
105
+ the normal to the edge.
106
+ Shape: (num_polygons, num_points_per_polygon, 2).
107
+ edge_lengths: Lengths of the edges.
108
+ Shape (num_polygons, num_points_per_polygon).
109
+ """
110
+ # Shift the polygon points by 1 position to get the edges.
111
+ first_point_in_polygon = polygon_points[:, 0:1, :] # Shape: (num_polygons, 1, 2)
112
+ shifted_polygon_points = torch.cat([polygon_points[:, 1:, :], first_point_in_polygon], dim=1)
113
+ # Shape: (num_polygons, num_points_per_polygon, 2)
114
+
115
+ edge_vectors = shifted_polygon_points - polygon_points # Shape: (num_polygons, num_points_per_polygon, 2)
116
+ edge_lengths = torch.norm(edge_vectors, dim=-1) # Shape: (num_polygons, num_points_per_polygon)
117
+
118
+ # Avoid division by zero by adding a small epsilon
119
+ eps = torch.finfo(edge_lengths.dtype).eps
120
+ tangent_unit_vectors = edge_vectors / (edge_lengths[..., None] + eps) # Shape: (num_polygons, num_points_per_polygon, 2)
121
+
122
+ normal_unit_vectors = torch.stack(
123
+ [-tangent_unit_vectors[..., 1], tangent_unit_vectors[..., 0]], dim=-1
124
+ ) # Shape: (num_polygons, num_points_per_polygon, 2)
125
+
126
+ return tangent_unit_vectors, normal_unit_vectors, edge_lengths
127
+
128
+
129
def rotate_2d_points(xys: Tensor, rotation_yaws: Tensor) -> Tensor:
    """Rotates `xys` counterclockwise by `rotation_yaws` radians.

    Args:
      xys: (..., 2) xy coordinates.
      rotation_yaws: rotation angles, broadcastable against `xys[..., 0]`.

    Returns:
      (..., 2) rotated xy coordinates.
    """
    cos_yaws = torch.cos(rotation_yaws)
    sin_yaws = torch.sin(rotation_yaws)

    rotated_x = cos_yaws * xys[..., 0] - sin_yaws * xys[..., 1]
    rotated_y = sin_yaws * xys[..., 0] + cos_yaws * xys[..., 1]

    # Fix: use torch's native `dim` kwarg; the original passed numpy-style
    # `axis`, which only works through the numpy-compatibility alias.
    return torch.stack([rotated_x, rotated_y], dim=-1)
backups/dev/metrics/interact_features.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import Tensor
4
+
5
+ from dev.metrics import box_utils
6
+ from dev.metrics import geometry_utils
7
+ from dev.metrics import trajectory_features
8
+
9
+
10
+ EXTREMELY_LARGE_DISTANCE = 1e10
11
+ COLLISION_DISTANCE_THRESHOLD = 0.0
12
+ CORNER_ROUNDING_FACTOR = 0.7
13
+ MAX_HEADING_DIFF = math.radians(75.0)
14
+ MAX_HEADING_DIFF_FOR_SMALL_OVERLAP = math.radians(10.0)
15
+ SMALL_OVERLAP_THRESHOLD = 0.5
16
+ MAXIMUM_TIME_TO_COLLISION = 5.0
17
+
18
+
19
def compute_distance_to_nearest_object(
    center_x: Tensor,
    center_y: Tensor,
    center_z: Tensor,
    length: Tensor,
    width: Tensor,
    height: Tensor,
    heading: Tensor,
    valid: Tensor,
    evaluated_object_mask: Tensor,
    corner_rounding_factor: float = CORNER_ROUNDING_FACTOR,
) -> Tensor:
    """Computes the distance to nearest object for each of the evaluated objects.

    All per-box component tensors have shape (num_objects, num_steps).

    Args:
      center_x, center_y, center_z: Box center coordinates.
      length, width, height: Box extents.
      heading: Box yaw angles.
      valid: Boolean validity mask per object and step.
      evaluated_object_mask: (num_objects,) boolean mask selecting the objects
        for which distances are reported.
      corner_rounding_factor: Fraction of the smaller horizontal extent used to
        emulate rounded corners by shrinking boxes before measuring.

    Returns:
      (num_evaluated_objects, num_steps) signed distance from each evaluated
      object to the nearest other valid object.
    """
    boxes = torch.stack([center_x, center_y, center_z, length, width, height, heading], dim=-1)
    num_objects, num_steps, num_features = boxes.shape

    # Shrink both horizontal extents to approximate rounded corners; the
    # shrunk amount is added back onto the distances below.
    shrinking_distance = (torch.minimum(boxes[:, :, 3], boxes[:, :, 4]) * corner_rounding_factor / 2.)

    boxes = torch.cat([
        boxes[:, :, :3],
        boxes[:, :, 3:4] - 2.0 * shrinking_distance[..., None],
        boxes[:, :, 4:5] - 2.0 * shrinking_distance[..., None],
        boxes[:, :, 5:]
    ], dim=2)

    boxes = boxes.reshape(num_objects * num_steps, num_features)

    # Bottom 4 corners in xy for every box.
    box_corners = box_utils.get_upright_3d_box_corners(boxes)[:, :4, :2]
    box_corners = box_corners.reshape(num_objects, num_steps, 4, 2)

    # Evaluated objects are placed first so row/column i of the pairwise
    # matrices below refers to evaluated object i.
    eval_corners = box_corners[evaluated_object_mask]
    num_eval_objects = eval_corners.shape[0]
    other_corners = box_corners[~evaluated_object_mask]
    all_corners = torch.cat([eval_corners, other_corners], dim=0)

    eval_corners = eval_corners.unsqueeze(1).expand(num_eval_objects, num_objects, num_steps, 4, 2)
    all_corners = all_corners.unsqueeze(0).expand(num_eval_objects, num_objects, num_steps, 4, 2)

    eval_corners = eval_corners.reshape(num_eval_objects * num_objects * num_steps, 4, 2)
    all_corners = all_corners.reshape(num_eval_objects * num_objects * num_steps, 4, 2)

    # The distance between two convex shapes equals the distance from the
    # origin to the Minkowski sum of one shape and the negation of the other.
    neg_all_corners = -1.0 * all_corners
    minkowski_sum = geometry_utils.minkowski_sum_of_box_and_box_points(
        box1_points=eval_corners, box2_points=neg_all_corners,
    )

    assert minkowski_sum.shape[1:] == (8, 2), f"Shape mismatch: {minkowski_sum.shape}, expected (*, 8, 2)"
    signed_distances_flat = (
        geometry_utils.signed_distance_from_point_to_convex_polygon(
            query_points=torch.zeros_like(minkowski_sum[:, 0, :]),
            polygon_points=minkowski_sum,
        )
    )

    signed_distances = signed_distances_flat.reshape(num_eval_objects, num_objects, num_steps)

    # Add the corner-rounding shrinkage back for both boxes in each pair.
    eval_shrinking_distance = shrinking_distance[evaluated_object_mask]
    other_shrinking_distance = shrinking_distance[~evaluated_object_mask]
    all_shrinking_distance = torch.cat([eval_shrinking_distance, other_shrinking_distance], dim=0)

    signed_distances -= eval_shrinking_distance.unsqueeze(1)
    signed_distances -= all_shrinking_distance.unsqueeze(0)

    # Mask self-distances. Fix: create the mask on the same device as the
    # data; the original CPU-default `torch.eye` broke CUDA inputs.
    self_mask = torch.eye(num_eval_objects, num_objects, dtype=torch.float32,
                          device=signed_distances.device)[:, :, None]
    signed_distances = signed_distances + self_mask * EXTREMELY_LARGE_DISTANCE

    eval_validity = valid[evaluated_object_mask]
    other_validity = valid[~evaluated_object_mask]
    all_validity = torch.cat([eval_validity, other_validity], dim=0)

    valid_mask = eval_validity.unsqueeze(1) & all_validity.unsqueeze(0)

    # Invalid pairs get a huge distance so they never win the minimum below.
    signed_distances = torch.where(valid_mask, signed_distances, EXTREMELY_LARGE_DISTANCE)

    return torch.min(signed_distances, dim=1).values
94
+
95
+
96
def compute_time_to_collision_with_object_in_front(
    *,
    center_x: Tensor,
    center_y: Tensor,
    length: Tensor,
    width: Tensor,
    heading: Tensor,
    valid: Tensor,
    evaluated_object_mask: Tensor,
    seconds_per_step: float,
) -> Tensor:
    """Computes the time-to-collision of the evaluated objects.

    For each evaluated object the nearest object it is "following" (see
    `_get_object_following_mask`) is identified, and the time-to-collision is
    the longitudinal gap divided by the closing speed, capped at
    `MAXIMUM_TIME_TO_COLLISION`.

    Args:
      center_x, center_y: (num_objects, num_steps) box center coordinates.
      length, width: (num_objects, num_steps) box extents.
      heading: (num_objects, num_steps) box yaw angles.
      valid: (num_objects, num_steps) boolean validity mask.
      evaluated_object_mask: (num_objects,) boolean mask of evaluated objects.
      seconds_per_step: Duration of one step, used to derive speeds.

    Returns:
      (num_evaluated_objects, num_steps) time-to-collision values.
    """
    # `speed` shape: (num_objects, num_steps)
    speed = trajectory_features.compute_kinematic_features(
        x=center_x,
        y=center_y,
        z=torch.zeros_like(center_x),
        heading=heading,
        seconds_per_step=seconds_per_step,
    )[0]

    boxes = torch.stack([center_x, center_y, length, width, heading, speed], dim=-1)
    boxes = boxes.permute(1, 0, 2)  # (num_steps, num_objects, 6)
    valid = valid.permute(1, 0)

    eval_boxes = boxes[:, evaluated_object_mask]
    ego_xy, ego_sizes, ego_yaw, ego_speed = torch.split(eval_boxes, [2, 2, 1, 1], dim=-1)
    other_xy, other_sizes, other_yaw, _ = torch.split(boxes, [2, 2, 1, 1], dim=-1)

    yaw_diff = torch.abs(other_yaw[:, None] - ego_yaw[:, :, None])
    yaw_diff_cos = torch.cos(yaw_diff)
    yaw_diff_sin = torch.sin(yaw_diff)

    # Half-extent of the other box projected onto the ego's longitudinal and
    # lateral axes.
    other_long_offset = geometry_utils.dot_product_2d(
        other_sizes[:, None] / 2.0, torch.abs(torch.cat([yaw_diff_cos, yaw_diff_sin], dim=-1))
    )
    other_lat_offset = geometry_utils.dot_product_2d(
        other_sizes[:, None] / 2.0, torch.abs(torch.cat([yaw_diff_sin, yaw_diff_cos], dim=-1))
    )

    other_relative_xy = geometry_utils.rotate_2d_points(
        (other_xy[:, None] - ego_xy[:, :, None]), -ego_yaw
    )

    long_distance = (
        other_relative_xy[..., 0] - ego_sizes[:, :, None, 0] / 2.0 - other_long_offset
    )
    lat_overlap = (
        torch.abs(other_relative_xy[..., 1]) - ego_sizes[:, :, None, 1] / 2.0 - other_lat_offset
    )

    following_mask = _get_object_following_mask(
        long_distance, lat_overlap, yaw_diff[..., 0]
    )
    valid_mask = valid[:, None] & following_mask

    # Pairs that are invalid or not followed get a huge distance so argmin
    # below never selects them.
    masked_long_distance = (
        long_distance + (1.0 - valid_mask.float()) * EXTREMELY_LARGE_DISTANCE
    )

    box_ahead_index = masked_long_distance.argmin(dim=-1)
    distance_to_box_ahead = torch.gather(
        masked_long_distance, -1, box_ahead_index.unsqueeze(-1)
    ).squeeze(-1)

    speed_broadcast = speed.T[:, None, :].expand_as(masked_long_distance)
    box_ahead_speed = torch.gather(speed_broadcast, -1, box_ahead_index.unsqueeze(-1)).squeeze(-1)

    rel_speed = ego_speed[..., 0] - box_ahead_speed
    # Fix: cap with `.clamp(max=...)` using a python float; the original
    # `torch.minimum(..., torch.tensor(const))` allocated a CPU scalar tensor
    # and failed for CUDA inputs.
    time_to_collision = torch.where(
        rel_speed > 0.0,
        (distance_to_box_ahead / rel_speed).clamp(max=MAXIMUM_TIME_TO_COLLISION),
        MAXIMUM_TIME_TO_COLLISION,
    )

    return time_to_collision.T
173
+
174
+
175
def _get_object_following_mask(
    longitudinal_distance: Tensor,
    lateral_overlap: Tensor,
    yaw_diff: Tensor,
) -> Tensor:
    """Checks whether objects satisfy criteria for following another object.

    An object on which the criteria are applied is called "ego object" here to
    disambiguate it from the other objects acting as obstacles. An ego object
    follows another object when the other box is ahead of it, the headings are
    aligned, and the other box laterally overlaps the ego's trail; when the
    overlap is small, a much stricter heading alignment is required.

    Args:
      longitudinal_distance: Float Tensor (batch_dim, num_egos, num_others)
        with longitudinal distances from each ego box's front to every other
        box.
      lateral_overlap: Float Tensor (batch_dim, num_egos, num_others) with the
        signed lateral penetration of other boxes over the ego trails
        (negative means the boxes actually overlap laterally).
      yaw_diff: Float Tensor (batch_dim, num_egos, num_others) with absolute
        yaw differences between egos and other boxes.

    Returns:
      Boolean Tensor (batch_dim, num_egos, num_others) indicating for each ego
      box whether it is following each other box.
    """
    # Other box must be ahead of the ego box's front.
    is_ahead = longitudinal_distance > 0.0
    # Headings must be roughly aligned.
    is_aligned = yaw_diff <= MAX_HEADING_DIFF
    # Other box must be directly ahead, i.e. laterally overlapping the trail.
    has_lateral_overlap = lateral_overlap < 0.0
    # With only a small overlap, require strict heading alignment instead.
    passes_overlap_rule = torch.logical_or(
        lateral_overlap < -SMALL_OVERLAP_THRESHOLD,
        yaw_diff <= MAX_HEADING_DIFF_FOR_SMALL_OVERLAP,
    )
    return is_ahead & is_aligned & has_lateral_overlap & passes_overlap_rule
backups/dev/metrics/map_features.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import Tensor
3
+ from typing import Optional, Sequence
4
+
5
+ from dev.metrics import box_utils
6
+ from dev.metrics import geometry_utils
7
+ from dev.metrics.protos import map_pb2
8
+
9
+ # Constant distance to apply when distances are invalid. This will avoid the
10
+ # propagation of nans and should be reduced out when taking the minimum anyway.
11
+ EXTREMELY_LARGE_DISTANCE = 1e10
12
+ # Off-road threshold, i.e. smallest distance away from the road edge that is
13
+ # considered to be a off-road.
14
+ OFFROAD_DISTANCE_THRESHOLD = 0.0
15
+
16
+ # How close the start and end point of a map feature need to be for the feature
17
+ # to be considered cyclic, in m^2.
18
+ _CYCLIC_MAP_FEATURE_TOLERANCE_M2 = 1.0
19
+ # Scaling factor for vertical distances used when finding the closest segment to
20
+ # a query point. This prevents wrong associations in cases with under- and
21
+ # over-passes.
22
+ _Z_STRETCH_FACTOR = 3.0
23
+
24
+ _Polyline = Sequence[map_pb2.MapPoint]
25
+
26
+
27
def compute_distance_to_road_edge(
    *,
    center_x: Tensor,
    center_y: Tensor,
    center_z: Tensor,
    length: Tensor,
    width: Tensor,
    height: Tensor,
    heading: Tensor,
    valid: Tensor,
    evaluated_object_mask: Tensor,
    road_edge_polylines: Sequence[_Polyline],
) -> Tensor:
    """Computes the distance to the road edge for each of the evaluated objects.

    All per-box component tensors have shape (num_objects, num_steps). The
    distance of a box is the signed road-edge distance of its most off-road
    bottom corner (positive means off-road).

    Raises:
      ValueError: If `road_edge_polylines` is empty.
    """
    if not road_edge_polylines:
        raise ValueError('Missing road edges.')

    # Same component ordering convention as `box_utils`.
    stacked_boxes = torch.stack(
        [center_x, center_y, center_z, length, width, height, heading], dim=-1)
    num_objects, num_steps, num_features = stacked_boxes.shape
    flat_boxes = stacked_boxes.reshape(num_objects * num_steps, num_features)

    # xyz coordinates of the four bottom corners of every box.
    corners = box_utils.get_upright_3d_box_corners(flat_boxes)[:, :4]
    corners = corners.reshape(num_objects, num_steps, 4, 3)

    # Restrict to the evaluation set and flatten into query points.
    eval_corners = corners[evaluated_object_mask]
    num_eval_objects = eval_corners.shape[0]
    query_points = eval_corners.reshape(-1, 3)

    # Tensorize road edges once for all queries.
    polylines_tensor = _tensorize_polylines(road_edge_polylines)
    is_polyline_cyclic = _check_polyline_cycles(road_edge_polylines)

    corner_distances = _compute_signed_distance_to_polylines(
        xyzs=query_points, polylines=polylines_tensor,
        is_polyline_cyclic=is_polyline_cyclic, z_stretch=_Z_STRETCH_FACTOR
    ).reshape(num_eval_objects, num_steps, 4)

    # A box is as off-road as its most off-road corner.
    signed_distances = corner_distances.max(dim=-1).values

    # Invalid boxes are forced deeply on-road so they never register off-road.
    eval_validity = valid[evaluated_object_mask]
    return torch.where(eval_validity, signed_distances, -EXTREMELY_LARGE_DISTANCE)
80
+
81
+
82
+ def _tensorize_polylines(polylines):
83
+ """Stacks a sequence of polylines into a tensor.
84
+
85
+ Args:
86
+ polylines: A sequence of Polyline objects.
87
+
88
+ Returns:
89
+ A float tensor with shape (num_polylines, max_length, 4) containing xyz
90
+ coordinates and a validity flag for all points in the polylines. Polylines
91
+ are padded with zeros up to the length of the longest one.
92
+ """
93
+ polyline_tensors = []
94
+ max_length = 0
95
+
96
+ for polyline in polylines:
97
+ # Skip degenerate polylines.
98
+ if len(polyline) < 2:
99
+ continue
100
+ max_length = max(max_length, len(polyline))
101
+ polyline_tensors.append(
102
+ torch.tensor(
103
+ [[map_point.x, map_point.y, map_point.z, 1.0] for map_point in polyline],
104
+ dtype=torch.float32
105
+ )
106
+ )
107
+
108
+ # Pad and stack polylines
109
+ padded_polylines = [
110
+ torch.cat([p, torch.zeros((max_length - p.shape[0], 4), dtype=torch.float32)], dim=0)
111
+ for p in polyline_tensors
112
+ ]
113
+
114
+ return torch.stack(padded_polylines, dim=0)
115
+
116
+
117
def _check_polyline_cycles(polylines):
    """Checks if given polylines are cyclic and returns the result as a tensor.

    A polyline counts as cyclic when the squared distance between its first
    and last points is below `_CYCLIC_MAP_FEATURE_TOLERANCE_M2`. Polylines
    with fewer than two points are skipped, mirroring `_tensorize_polylines`.

    Args:
      polylines: A sequence of Polyline objects.

    Returns:
      A bool tensor with shape (num_polylines) indicating whether each
      polyline is cyclic.
    """
    cycle_flags = []
    for polyline in polylines:
        # Skip degenerate polylines.
        if len(polyline) < 2:
            continue
        start, end = polyline[0], polyline[-1]
        endpoint_gap = torch.tensor(
            [start.x - end.x, start.y - end.y, start.z - end.z],
            dtype=torch.float32,
        )
        cycle_flags.append(
            torch.sum(endpoint_gap * endpoint_gap) < _CYCLIC_MAP_FEATURE_TOLERANCE_M2
        )
    return torch.stack(cycle_flags, dim=0)
137
+
138
+
139
def _compute_signed_distance_to_polylines(
    xyzs: Tensor,
    polylines: Tensor,
    is_polyline_cyclic: Optional[Tensor] = None,
    z_stretch: float = 1.0,
) -> Tensor:
    """Computes the signed distance to the 2D boundary defined by polylines.

    Negative distances correspond to being inside the boundary (e.g. on the
    road), positive distances to being outside (e.g. off-road).

    The polylines should be oriented such that port side is inside the boundary
    and starboard is outside, a.k.a counterclockwise winding order.

    The altitudes i.e. the z-coordinates of query points and polyline segments
    are used to pair each query point with the most relevant segment, that is
    closest and at the right altitude. The distances returned are 2D distances in
    the xy plane.

    Args:
      xyzs: A float Tensor of shape (num_points, 3) containing xyz coordinates of
        query points.
      polylines: Tensor with shape (num_polylines, num_segments+1, 4) containing
        sequences of xyz coordinates and validity, representing start and end
        points of consecutive segments.
      is_polyline_cyclic: A boolean Tensor with shape (num_polylines) indicating
        whether each polyline is cyclic. If None, all polylines are considered
        non-cyclic.
      z_stretch: Factor by which to scale distances over the z axis. This can be
        done to ensure edge points from the wrong level (e.g. overpasses) are not
        selected. Defaults to 1.0 (no stretching).

    Returns:
      A tensor of shape (num_points), containing the signed 2D distance from
      queried points to the nearest polyline.
    """
    num_points = xyzs.shape[0]
    assert xyzs.shape == (num_points, 3), f"Expected shape ({num_points}, 3), but got {xyzs.shape}"
    num_polylines = polylines.shape[0]
    num_segments = polylines.shape[1] - 1
    assert polylines.shape == (num_polylines, num_segments + 1, 4), \
        f"Expected shape ({num_polylines}, {num_segments + 1}, 4), but got {polylines.shape}"

    # shape: (num_polylines, num_segments+1)
    is_point_valid = polylines[:, :, 3].bool()
    # A segment is valid only when both of its endpoints are valid.
    # shape: (num_polylines, num_segments)
    is_segment_valid = is_point_valid[:, :-1] & is_point_valid[:, 1:]

    if is_polyline_cyclic is None:
        is_polyline_cyclic = torch.zeros(num_polylines, dtype=torch.bool)
    else:
        assert is_polyline_cyclic.shape == (num_polylines,), \
            f"Expected shape ({num_polylines},), but got {is_polyline_cyclic.shape}"

    # Get distance to each segment.
    # shape: (num_points, num_polylines, num_segments, 3)
    xyz_starts = polylines[None, :, :-1, :3]
    xyz_ends = polylines[None, :, 1:, :3]
    start_to_point = xyzs[:, None, None, :3] - xyz_starts
    start_to_end = xyz_ends - xyz_starts

    # Relative coordinate of point projection on segment.
    # shape: (num_points, num_polylines, num_segments)
    numerator = geometry_utils.dot_product_2d(
        start_to_point[..., :2], start_to_end[..., :2]
    )
    denominator = geometry_utils.dot_product_2d(
        start_to_end[..., :2], start_to_end[..., :2]
    )
    # Guard against zero-length segments (denominator == 0).
    rel_t = torch.where(denominator != 0, numerator / denominator, torch.zeros_like(numerator))

    # Negative if point is on port side of segment, positive if point on
    # starboard side of segment.
    # shape: (num_points, num_polylines, num_segments)
    n = torch.sign(
        geometry_utils.cross_product_2d(
            start_to_point[..., :2], start_to_end[..., :2]
        )
    )

    # Compute the absolute 3D distance to segment.
    # The vertical component is scaled by `z-stretch` to increase the separation
    # between different road altitudes.
    # shape: (num_points, num_polylines, num_segments, 3)
    segment_to_point = start_to_point - (
        start_to_end * torch.clamp(rel_t, 0.0, 1.0)[..., None]
    )
    stretch_vector = torch.tensor([1.0, 1.0, z_stretch], dtype=torch.float32)
    distance_to_segment_3d = torch.norm(
        segment_to_point * stretch_vector[None, None, None],
        dim=-1,
    )

    # Absolute planar distance to segment.
    # shape: (num_points, num_polylines, num_segments)
    distance_to_segment_2d = torch.norm(segment_to_point[..., :2], dim=-1)

    # Padded start-to-end segments: wrap one segment around at each end so the
    # local convexity at every vertex (including the first/last) can be tested.
    # shape: (num_points, num_polylines, num_segments+2, 2)
    start_to_end_padded = torch.cat(
        [
            start_to_end[:, :, -1:, :2],
            start_to_end[..., :2],
            start_to_end[:, :, :1, :2],
        ],
        dim=-2,
    )

    # True where the corner between consecutive segments turns left (convex
    # from the inside of a counterclockwise boundary).
    # shape: (num_points, num_polylines, num_segments+1)
    is_locally_convex = torch.gt(
        geometry_utils.cross_product_2d(
            start_to_end_padded[:, :, :-1], start_to_end_padded[:, :, 1:]
        ),
        0.,
    )

    # Get shifted versions of `n` and `is_segment_valid`. If the polyline is
    # cyclic, the tensors are rolled, else they are padded with their edge value.
    # shape: (num_points, num_polylines, num_segments)
    n_prior = torch.cat(
        [
            torch.where(
                is_polyline_cyclic[None, :, None],
                n[:, :, -1:],
                n[:, :, :1],
            ),
            n[:, :, :-1],
        ],
        dim=-1,
    )
    n_next = torch.cat(
        [
            n[:, :, 1:],
            torch.where(
                is_polyline_cyclic[None, :, None],
                n[:, :, :1],
                n[:, :, -1:],
            ),
        ],
        dim=-1,
    )
    # shape: (num_polylines, num_segments)
    is_prior_segment_valid = torch.cat(
        [
            torch.where(
                is_polyline_cyclic[:, None],
                is_segment_valid[:, -1:],
                is_segment_valid[:, :1],
            ),
            is_segment_valid[:, :-1],
        ],
        dim=-1,
    )
    is_next_segment_valid = torch.cat(
        [
            is_segment_valid[:, 1:],
            torch.where(
                is_polyline_cyclic[:, None],
                is_segment_valid[:, :1],
                is_segment_valid[:, -1:],
            ),
        ],
        dim=-1,
    )

    # When the projection falls before/after a segment, the sign is decided by
    # this segment together with its neighbor: take the max (min) of the two
    # side indicators at convex (concave) corners.
    # shape: (num_points, num_polylines, num_segments)
    sign_if_before = torch.where(
        is_locally_convex[:, :, :-1],
        torch.maximum(n, n_prior),
        torch.minimum(n, n_prior),
    )
    sign_if_after = torch.where(
        is_locally_convex[:, :, 1:], torch.maximum(n, n_next), torch.minimum(n, n_next)
    )

    # shape: (num_points, num_polylines, num_segments)
    sign_to_segment = torch.where(
        (rel_t < 0.0) & is_prior_segment_valid,
        sign_if_before,
        torch.where((rel_t > 1.0) & is_next_segment_valid, sign_if_after, n),
    )

    # Flatten polylines together.
    # shape: (num_points, all_segments)
    distance_to_segment_3d = distance_to_segment_3d.view(num_points, num_polylines * num_segments)
    distance_to_segment_2d = distance_to_segment_2d.view(num_points, num_polylines * num_segments)
    sign_to_segment = sign_to_segment.view(num_points, num_polylines * num_segments)

    # Mask out invalid segments.
    # shape: (all_segments)
    is_segment_valid = is_segment_valid.view(num_polylines * num_segments)
    # shape: (num_points, all_segments)
    distance_to_segment_3d = torch.where(
        is_segment_valid[None],
        distance_to_segment_3d,
        EXTREMELY_LARGE_DISTANCE,
    )
    distance_to_segment_2d = torch.where(
        is_segment_valid[None],
        distance_to_segment_2d,
        EXTREMELY_LARGE_DISTANCE,
    )

    # Get closest segment according to absolute 3D distance and return the
    # corresponding signed 2D distance.
    # shape: (num_points)
    closest_segment_index = torch.argmin(distance_to_segment_3d, dim=-1)
    distance_sign = torch.gather(sign_to_segment, 1, closest_segment_index.unsqueeze(-1)).squeeze(-1)
    distance_2d = torch.gather(distance_to_segment_2d, 1, closest_segment_index.unsqueeze(-1)).squeeze(-1)

    return distance_sign * distance_2d
backups/dev/metrics/placement_features.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
from torch import Tensor
from typing import List, Optional, Sequence, Tuple
4
+
5
+
6
def compute_num_placement(
    valid: Tensor, # [n_agent, n_step]
    state: Tensor, # [n_agent, n_step]
    av_id: int,
    object_id: Tensor,
    agent_state: List[str],
) -> Tuple[Tensor, Tensor]:
    """Counts per-step agent placements ('enter') and removements ('exit').

    Args:
      valid: [n_agent, n_step] validity mask (currently unused; kept for
        interface compatibility with `compute_distance_placement`).
      state: [n_agent, n_step] integer state codes indexing into `agent_state`.
      av_id: Object id of the AV; its row is excluded from the counts.
      object_id: [n_agent] object ids aligned with the rows of `state`.
      agent_state: State names; must contain 'enter' and 'exit'.

    Returns:
      Tuple (num_bos, num_eos) of [n_step] tensors with per-step counts of
      entering and exiting agents.
    """
    enter_state = agent_state.index('enter')
    exit_state = agent_state.index('exit')

    av_index = object_id.tolist().index(av_id)
    # Fix: operate on a copy so the caller's `state` tensor is not mutated
    # as a side effect.
    state = state.clone()
    state[av_index] = -1  # we do not incorporate the sdc

    is_bos = state == enter_state
    is_eos = state == exit_state

    num_bos = torch.sum(is_bos, dim=0)
    num_eos = torch.sum(is_eos, dim=0)

    return num_bos, num_eos
27
+
28
+
29
def compute_distance_placement(
    position: Tensor,
    state: Tensor,
    valid: Tensor,
    av_id: int,
    object_id: Tensor,
    agent_state: List[str],
) -> Tuple[Tensor, Tensor]:
    """Distance from the AV at which agents enter (placement) or exit (removement).

    Args:
      position: [n_agent, n_step, d] positions; the Euclidean norm is taken
        over the last dimension.  # NOTE(review): d presumably 2 or 3 — confirm with caller
      state: [n_agent, n_step] integer state codes indexing into `agent_state`.
      valid: [n_agent, n_step] validity mask (currently unused; kept for
        interface compatibility with `compute_num_placement`).
      av_id: Object id of the AV; its row is excluded from the statistics.
      object_id: [n_agent] object ids aligned with the rows of `state`.
      agent_state: State names; must contain 'enter' and 'exit'.

    Returns:
      Tuple (bos_distance, eos_distance) of [n_agent, n_step] tensors that are
      zero except at enter/exit steps, where they hold the distance to the AV.
    """
    enter_state = agent_state.index('enter')
    exit_state = agent_state.index('exit')

    av_index = object_id.tolist().index(av_id)
    # Fix: operate on a copy so the caller's `state` tensor is not mutated
    # as a side effect.
    state = state.clone()
    state[av_index] = -1  # we do not incorporate the sdc
    distance = torch.norm(position - position[av_index : av_index + 1], p=2, dim=-1)

    bos_distance = distance * (state == enter_state)
    eos_distance = distance * (state == exit_state)

    return bos_distance, eos_distance
backups/dev/metrics/protos/long_metrics_pb2.py ADDED
@@ -0,0 +1,648 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # source: dev/metrics/protos/long_metrics.proto
4
+
5
+ from google.protobuf import descriptor as _descriptor
6
+ from google.protobuf import message as _message
7
+ from google.protobuf import reflection as _reflection
8
+ from google.protobuf import symbol_database as _symbol_database
9
+ # @@protoc_insertion_point(imports)
10
+
11
+ _sym_db = _symbol_database.Default()
12
+
13
+
14
+
15
+
16
# File descriptor for dev/metrics/protos/long_metrics.proto (package
# `long_metric`, proto2). `serialized_pb` holds the compiled proto schema;
# generated by protoc — do not edit by hand.
DESCRIPTOR = _descriptor.FileDescriptor(
  name='dev/metrics/protos/long_metrics.proto',
  package='long_metric',
  syntax='proto2',
  serialized_options=None,
  create_key=_descriptor._internal_create_key,
  serialized_pb=b'\n%dev/metrics/protos/long_metrics.proto\x12\x0blong_metric\"\xb4\x0c\n\x15SimAgentMetricsConfig\x12\x46\n\x0clinear_speed\x18\x01 \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12M\n\x13linear_acceleration\x18\x02 \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12G\n\rangular_speed\x18\x03 \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12N\n\x14\x61ngular_acceleration\x18\x04 \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12T\n\x1a\x64istance_to_nearest_object\x18\x05 \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12N\n\x14\x63ollision_indication\x18\x06 \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12K\n\x11time_to_collision\x18\x07 \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12O\n\x15\x64istance_to_road_edge\x18\x08 \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12L\n\x12offroad_indication\x18\t \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12G\n\rnum_placement\x18\n \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12H\n\x0enum_removement\x18\x0b \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12L\n\x12\x64istance_placement\x18\x0c \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12M\n\x13\x64istance_removement\x18\r \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x1a\xc0\x02\n\rFeatureConfig\x12I\n\thistogram\x18\x01 \x01(\x0b\x32\x34.long_metric.SimAgentMetricsConfig.HistogramEstimateH\x00\x12R\n\x0ekernel_density\x18\x02 \x01(\x0b\x32\x38.long_metric.SimAgentMetricsConfig.KernelDensityEstimateH\x00\x12I\n\tbernoulli\x18\x03 \x01(\x0b\x32\x34.long_metric.SimAgentMetricsConfig.BernoulliEstimateH\x00\x12\x1d\n\x15independent_timesteps\x18\x04 \x01(\x08\x12\x19\n\x11metametric_weight\x18\x05 \x01(\x02\x42\x0b\n\testimator\x1av\n\x11HistogramEstimate\x12\x0f\n\x07min_val\x18\x01 \x01(\x02\x12\x0f\n\x07max_val\x18\x02 \x01(\x02\x12\x10\n\x08num_bins\x18\x03 \x01(\x05\x12-\n\x1e\x61\x64\x64itive_smoothing_pseudocount\x18\x04 \x01(\x02:\x05\x30.001\x1a*\n\x15KernelDensityEstimate\x12\x11\n\tbandwidth\x18\x01 \x01(\x02\x1a\x42\n\x11\x42\x65rnoulliEstimate\x12-\n\x1e\x61\x64\x64itive_smoothing_pseudocount\x18\x04 \x01(\x02:\x05\x30.001\"\xbf\x05\n\x0fSimAgentMetrics\x12\x13\n\x0bscenario_id\x18\x01 \x01(\t\x12\x12\n\nmetametric\x18\x02 \x01(\x02\x12\"\n\x1a\x61verage_displacement_error\x18\x03 \x01(\x02\x12&\n\x1emin_average_displacement_error\x18\x13 \x01(\x02\x12\x1f\n\x17linear_speed_likelihood\x18\x04 \x01(\x02\x12&\n\x1elinear_acceleration_likelihood\x18\x05 \x01(\x02\x12 \n\x18\x61ngular_speed_likelihood\x18\x06 \x01(\x02\x12\'\n\x1f\x61ngular_acceleration_likelihood\x18\x07 \x01(\x02\x12-\n%distance_to_nearest_object_likelihood\x18\x08 \x01(\x02\x12\'\n\x1f\x63ollision_indication_likelihood\x18\t \x01(\x02\x12$\n\x1ctime_to_collision_likelihood\x18\n \x01(\x02\x12(\n distance_to_road_edge_likelihood\x18\x0b \x01(\x02\x12%\n\x1doffroad_indication_likelihood\x18\x0c \x01(\x02\x12 \n\x18num_placement_likelihood\x18\r \x01(\x02\x12!\n\x19num_removement_likelihood\x18\x0e \x01(\x02\x12%\n\x1d\x64istance_placement_likelihood\x18\x0f \x01(\x02\x12&\n\x1e\x64istance_removement_likelihood\x18\x10 \x01(\x02\x12 \n\x18simulated_collision_rate\x18\x11 \x01(\x02\x12\x1e\n\x16simulated_offroad_rate\x18\x12 \x01(\x02\"\xfe\x01\n\x18SimAgentsBucketedMetrics\x12\x1b\n\x13realism_meta_metric\x18\x01 \x01(\x02\x12\x19\n\x11kinematic_metrics\x18\x02 \x01(\x02\x12\x1b\n\x13interactive_metrics\x18\x05 \x01(\x02\x12\x19\n\x11map_based_metrics\x18\x06 \x01(\x02\x12\x1f\n\x17placement_based_metrics\x18\x07 \x01(\x02\x12\x0f\n\x07min_ade\x18\x08 \x01(\x02\x12 \n\x18simulated_collision_rate\x18\t \x01(\x02\x12\x1e\n\x16simulated_offroad_rate\x18\n \x01(\x02'
)
24
+
25
+
26
+
27
+
28
# Descriptor for the nested message SimAgentMetricsConfig.FeatureConfig.
# Carries one estimator choice via the `estimator` oneof
# (histogram | kernel_density | bernoulli) plus two scalar options.
# Generated by protoc — do not edit by hand.
_SIMAGENTMETRICSCONFIG_FEATURECONFIG = _descriptor.Descriptor(
  name='FeatureConfig',
  full_name='long_metric.SimAgentMetricsConfig.FeatureConfig',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  create_key=_descriptor._internal_create_key,
  fields=[
    _descriptor.FieldDescriptor(
      name='histogram', full_name='long_metric.SimAgentMetricsConfig.FeatureConfig.histogram', index=0,
      number=1, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='kernel_density', full_name='long_metric.SimAgentMetricsConfig.FeatureConfig.kernel_density', index=1,
      number=2, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='bernoulli', full_name='long_metric.SimAgentMetricsConfig.FeatureConfig.bernoulli', index=2,
      number=3, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='independent_timesteps', full_name='long_metric.SimAgentMetricsConfig.FeatureConfig.independent_timesteps', index=3,
      number=4, type=8, cpp_type=7, label=1,
      has_default_value=False, default_value=False,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='metametric_weight', full_name='long_metric.SimAgentMetricsConfig.FeatureConfig.metametric_weight', index=4,
      number=5, type=2, cpp_type=6, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  serialized_options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
    _descriptor.OneofDescriptor(
      name='estimator', full_name='long_metric.SimAgentMetricsConfig.FeatureConfig.estimator',
      index=0, containing_type=None,
      create_key=_descriptor._internal_create_key,
      fields=[]),
  ],
  serialized_start=1091,
  serialized_end=1411,
)
91
+
92
# Descriptor for SimAgentMetricsConfig.HistogramEstimate: histogram estimator
# options (min_val/max_val/num_bins) with an additive-smoothing pseudocount
# defaulting to 0.001. Generated by protoc — do not edit by hand.
_SIMAGENTMETRICSCONFIG_HISTOGRAMESTIMATE = _descriptor.Descriptor(
  name='HistogramEstimate',
  full_name='long_metric.SimAgentMetricsConfig.HistogramEstimate',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  create_key=_descriptor._internal_create_key,
  fields=[
    _descriptor.FieldDescriptor(
      name='min_val', full_name='long_metric.SimAgentMetricsConfig.HistogramEstimate.min_val', index=0,
      number=1, type=2, cpp_type=6, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='max_val', full_name='long_metric.SimAgentMetricsConfig.HistogramEstimate.max_val', index=1,
      number=2, type=2, cpp_type=6, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='num_bins', full_name='long_metric.SimAgentMetricsConfig.HistogramEstimate.num_bins', index=2,
      number=3, type=5, cpp_type=1, label=1,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='additive_smoothing_pseudocount', full_name='long_metric.SimAgentMetricsConfig.HistogramEstimate.additive_smoothing_pseudocount', index=3,
      number=4, type=2, cpp_type=6, label=1,
      has_default_value=True, default_value=float(0.001),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  serialized_options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=1413,
  serialized_end=1531,
)
143
+
144
# Descriptor for SimAgentMetricsConfig.KernelDensityEstimate: KDE estimator
# with a single float `bandwidth` option. Generated by protoc — do not edit.
_SIMAGENTMETRICSCONFIG_KERNELDENSITYESTIMATE = _descriptor.Descriptor(
  name='KernelDensityEstimate',
  full_name='long_metric.SimAgentMetricsConfig.KernelDensityEstimate',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  create_key=_descriptor._internal_create_key,
  fields=[
    _descriptor.FieldDescriptor(
      name='bandwidth', full_name='long_metric.SimAgentMetricsConfig.KernelDensityEstimate.bandwidth', index=0,
      number=1, type=2, cpp_type=6, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  serialized_options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=1533,
  serialized_end=1575,
)
174
+
175
# Descriptor for SimAgentMetricsConfig.BernoulliEstimate: Bernoulli estimator
# with an additive-smoothing pseudocount defaulting to 0.001.
# Generated by protoc — do not edit by hand.
_SIMAGENTMETRICSCONFIG_BERNOULLIESTIMATE = _descriptor.Descriptor(
  name='BernoulliEstimate',
  full_name='long_metric.SimAgentMetricsConfig.BernoulliEstimate',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  create_key=_descriptor._internal_create_key,
  fields=[
    _descriptor.FieldDescriptor(
      name='additive_smoothing_pseudocount', full_name='long_metric.SimAgentMetricsConfig.BernoulliEstimate.additive_smoothing_pseudocount', index=0,
      number=4, type=2, cpp_type=6, label=1,
      has_default_value=True, default_value=float(0.001),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  serialized_options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=1577,
  serialized_end=1643,
)
205
+
206
# Descriptor for the top-level SimAgentMetricsConfig message: one
# FeatureConfig sub-message per measured feature (kinematic, interaction,
# map, and placement features). Generated by protoc — do not edit by hand.
_SIMAGENTMETRICSCONFIG = _descriptor.Descriptor(
  name='SimAgentMetricsConfig',
  full_name='long_metric.SimAgentMetricsConfig',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  create_key=_descriptor._internal_create_key,
  fields=[
    _descriptor.FieldDescriptor(
      name='linear_speed', full_name='long_metric.SimAgentMetricsConfig.linear_speed', index=0,
      number=1, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='linear_acceleration', full_name='long_metric.SimAgentMetricsConfig.linear_acceleration', index=1,
      number=2, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='angular_speed', full_name='long_metric.SimAgentMetricsConfig.angular_speed', index=2,
      number=3, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='angular_acceleration', full_name='long_metric.SimAgentMetricsConfig.angular_acceleration', index=3,
      number=4, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='distance_to_nearest_object', full_name='long_metric.SimAgentMetricsConfig.distance_to_nearest_object', index=4,
      number=5, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='collision_indication', full_name='long_metric.SimAgentMetricsConfig.collision_indication', index=5,
      number=6, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='time_to_collision', full_name='long_metric.SimAgentMetricsConfig.time_to_collision', index=6,
      number=7, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='distance_to_road_edge', full_name='long_metric.SimAgentMetricsConfig.distance_to_road_edge', index=7,
      number=8, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='offroad_indication', full_name='long_metric.SimAgentMetricsConfig.offroad_indication', index=8,
      number=9, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='num_placement', full_name='long_metric.SimAgentMetricsConfig.num_placement', index=9,
      number=10, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='num_removement', full_name='long_metric.SimAgentMetricsConfig.num_removement', index=10,
      number=11, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='distance_placement', full_name='long_metric.SimAgentMetricsConfig.distance_placement', index=11,
      number=12, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='distance_removement', full_name='long_metric.SimAgentMetricsConfig.distance_removement', index=12,
      number=13, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
  ],
  extensions=[
  ],
  nested_types=[_SIMAGENTMETRICSCONFIG_FEATURECONFIG, _SIMAGENTMETRICSCONFIG_HISTOGRAMESTIMATE, _SIMAGENTMETRICSCONFIG_KERNELDENSITYESTIMATE, _SIMAGENTMETRICSCONFIG_BERNOULLIESTIMATE, ],
  enum_types=[
  ],
  serialized_options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=55,
  serialized_end=1643,
)
320
+
321
+
322
# Descriptor for SimAgentMetrics: per-scenario results — an id, the combined
# metametric, displacement errors, one likelihood per feature, and the
# simulated collision/offroad rates. Generated by protoc — do not edit.
_SIMAGENTMETRICS = _descriptor.Descriptor(
  name='SimAgentMetrics',
  full_name='long_metric.SimAgentMetrics',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  create_key=_descriptor._internal_create_key,
  fields=[
    _descriptor.FieldDescriptor(
      name='scenario_id', full_name='long_metric.SimAgentMetrics.scenario_id', index=0,
      number=1, type=9, cpp_type=9, label=1,
      has_default_value=False, default_value=b"".decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='metametric', full_name='long_metric.SimAgentMetrics.metametric', index=1,
      number=2, type=2, cpp_type=6, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='average_displacement_error', full_name='long_metric.SimAgentMetrics.average_displacement_error', index=2,
      number=3, type=2, cpp_type=6, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='min_average_displacement_error', full_name='long_metric.SimAgentMetrics.min_average_displacement_error', index=3,
      number=19, type=2, cpp_type=6, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='linear_speed_likelihood', full_name='long_metric.SimAgentMetrics.linear_speed_likelihood', index=4,
      number=4, type=2, cpp_type=6, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='linear_acceleration_likelihood', full_name='long_metric.SimAgentMetrics.linear_acceleration_likelihood', index=5,
      number=5, type=2, cpp_type=6, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='angular_speed_likelihood', full_name='long_metric.SimAgentMetrics.angular_speed_likelihood', index=6,
      number=6, type=2, cpp_type=6, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='angular_acceleration_likelihood', full_name='long_metric.SimAgentMetrics.angular_acceleration_likelihood', index=7,
      number=7, type=2, cpp_type=6, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='distance_to_nearest_object_likelihood', full_name='long_metric.SimAgentMetrics.distance_to_nearest_object_likelihood', index=8,
      number=8, type=2, cpp_type=6, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='collision_indication_likelihood', full_name='long_metric.SimAgentMetrics.collision_indication_likelihood', index=9,
      number=9, type=2, cpp_type=6, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='time_to_collision_likelihood', full_name='long_metric.SimAgentMetrics.time_to_collision_likelihood', index=10,
      number=10, type=2, cpp_type=6, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='distance_to_road_edge_likelihood', full_name='long_metric.SimAgentMetrics.distance_to_road_edge_likelihood', index=11,
      number=11, type=2, cpp_type=6, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='offroad_indication_likelihood', full_name='long_metric.SimAgentMetrics.offroad_indication_likelihood', index=12,
      number=12, type=2, cpp_type=6, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='num_placement_likelihood', full_name='long_metric.SimAgentMetrics.num_placement_likelihood', index=13,
      number=13, type=2, cpp_type=6, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='num_removement_likelihood', full_name='long_metric.SimAgentMetrics.num_removement_likelihood', index=14,
      number=14, type=2, cpp_type=6, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='distance_placement_likelihood', full_name='long_metric.SimAgentMetrics.distance_placement_likelihood', index=15,
      number=15, type=2, cpp_type=6, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='distance_removement_likelihood', full_name='long_metric.SimAgentMetrics.distance_removement_likelihood', index=16,
      number=16, type=2, cpp_type=6, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='simulated_collision_rate', full_name='long_metric.SimAgentMetrics.simulated_collision_rate', index=17,
      number=17, type=2, cpp_type=6, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='simulated_offroad_rate', full_name='long_metric.SimAgentMetrics.simulated_offroad_rate', index=18,
      number=18, type=2, cpp_type=6, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  serialized_options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=1646,
  serialized_end=2349,
)
478
+
479
+
480
# Descriptor for SimAgentsBucketedMetrics: metrics aggregated into buckets
# (kinematic / interactive / map-based / placement-based) plus the realism
# meta-metric, min ADE, and simulated collision/offroad rates.
# Generated by protoc — do not edit by hand.
_SIMAGENTSBUCKETEDMETRICS = _descriptor.Descriptor(
  name='SimAgentsBucketedMetrics',
  full_name='long_metric.SimAgentsBucketedMetrics',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  create_key=_descriptor._internal_create_key,
  fields=[
    _descriptor.FieldDescriptor(
      name='realism_meta_metric', full_name='long_metric.SimAgentsBucketedMetrics.realism_meta_metric', index=0,
      number=1, type=2, cpp_type=6, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='kinematic_metrics', full_name='long_metric.SimAgentsBucketedMetrics.kinematic_metrics', index=1,
      number=2, type=2, cpp_type=6, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='interactive_metrics', full_name='long_metric.SimAgentsBucketedMetrics.interactive_metrics', index=2,
      number=5, type=2, cpp_type=6, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='map_based_metrics', full_name='long_metric.SimAgentsBucketedMetrics.map_based_metrics', index=3,
      number=6, type=2, cpp_type=6, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='placement_based_metrics', full_name='long_metric.SimAgentsBucketedMetrics.placement_based_metrics', index=4,
      number=7, type=2, cpp_type=6, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='min_ade', full_name='long_metric.SimAgentsBucketedMetrics.min_ade', index=5,
      number=8, type=2, cpp_type=6, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='simulated_collision_rate', full_name='long_metric.SimAgentsBucketedMetrics.simulated_collision_rate', index=6,
      number=9, type=2, cpp_type=6, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='simulated_offroad_rate', full_name='long_metric.SimAgentsBucketedMetrics.simulated_offroad_rate', index=7,
      number=10, type=2, cpp_type=6, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  serialized_options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=2352,
  serialized_end=2606,
)
559
+
560
+ _SIMAGENTMETRICSCONFIG_FEATURECONFIG.fields_by_name['histogram'].message_type = _SIMAGENTMETRICSCONFIG_HISTOGRAMESTIMATE
561
+ _SIMAGENTMETRICSCONFIG_FEATURECONFIG.fields_by_name['kernel_density'].message_type = _SIMAGENTMETRICSCONFIG_KERNELDENSITYESTIMATE
562
+ _SIMAGENTMETRICSCONFIG_FEATURECONFIG.fields_by_name['bernoulli'].message_type = _SIMAGENTMETRICSCONFIG_BERNOULLIESTIMATE
563
+ _SIMAGENTMETRICSCONFIG_FEATURECONFIG.containing_type = _SIMAGENTMETRICSCONFIG
564
+ _SIMAGENTMETRICSCONFIG_FEATURECONFIG.oneofs_by_name['estimator'].fields.append(
565
+ _SIMAGENTMETRICSCONFIG_FEATURECONFIG.fields_by_name['histogram'])
566
+ _SIMAGENTMETRICSCONFIG_FEATURECONFIG.fields_by_name['histogram'].containing_oneof = _SIMAGENTMETRICSCONFIG_FEATURECONFIG.oneofs_by_name['estimator']
567
+ _SIMAGENTMETRICSCONFIG_FEATURECONFIG.oneofs_by_name['estimator'].fields.append(
568
+ _SIMAGENTMETRICSCONFIG_FEATURECONFIG.fields_by_name['kernel_density'])
569
+ _SIMAGENTMETRICSCONFIG_FEATURECONFIG.fields_by_name['kernel_density'].containing_oneof = _SIMAGENTMETRICSCONFIG_FEATURECONFIG.oneofs_by_name['estimator']
570
+ _SIMAGENTMETRICSCONFIG_FEATURECONFIG.oneofs_by_name['estimator'].fields.append(
571
+ _SIMAGENTMETRICSCONFIG_FEATURECONFIG.fields_by_name['bernoulli'])
572
+ _SIMAGENTMETRICSCONFIG_FEATURECONFIG.fields_by_name['bernoulli'].containing_oneof = _SIMAGENTMETRICSCONFIG_FEATURECONFIG.oneofs_by_name['estimator']
573
+ _SIMAGENTMETRICSCONFIG_HISTOGRAMESTIMATE.containing_type = _SIMAGENTMETRICSCONFIG
574
+ _SIMAGENTMETRICSCONFIG_KERNELDENSITYESTIMATE.containing_type = _SIMAGENTMETRICSCONFIG
575
+ _SIMAGENTMETRICSCONFIG_BERNOULLIESTIMATE.containing_type = _SIMAGENTMETRICSCONFIG
576
+ _SIMAGENTMETRICSCONFIG.fields_by_name['linear_speed'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG
577
+ _SIMAGENTMETRICSCONFIG.fields_by_name['linear_acceleration'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG
578
+ _SIMAGENTMETRICSCONFIG.fields_by_name['angular_speed'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG
579
+ _SIMAGENTMETRICSCONFIG.fields_by_name['angular_acceleration'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG
580
+ _SIMAGENTMETRICSCONFIG.fields_by_name['distance_to_nearest_object'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG
581
+ _SIMAGENTMETRICSCONFIG.fields_by_name['collision_indication'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG
582
+ _SIMAGENTMETRICSCONFIG.fields_by_name['time_to_collision'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG
583
+ _SIMAGENTMETRICSCONFIG.fields_by_name['distance_to_road_edge'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG
584
+ _SIMAGENTMETRICSCONFIG.fields_by_name['offroad_indication'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG
585
+ _SIMAGENTMETRICSCONFIG.fields_by_name['num_placement'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG
586
+ _SIMAGENTMETRICSCONFIG.fields_by_name['num_removement'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG
587
+ _SIMAGENTMETRICSCONFIG.fields_by_name['distance_placement'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG
588
+ _SIMAGENTMETRICSCONFIG.fields_by_name['distance_removement'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG
589
+ DESCRIPTOR.message_types_by_name['SimAgentMetricsConfig'] = _SIMAGENTMETRICSCONFIG
590
+ DESCRIPTOR.message_types_by_name['SimAgentMetrics'] = _SIMAGENTMETRICS
591
+ DESCRIPTOR.message_types_by_name['SimAgentsBucketedMetrics'] = _SIMAGENTSBUCKETEDMETRICS
592
+ _sym_db.RegisterFileDescriptor(DESCRIPTOR)
593
+
594
+ SimAgentMetricsConfig = _reflection.GeneratedProtocolMessageType('SimAgentMetricsConfig', (_message.Message,), {
595
+
596
+ 'FeatureConfig' : _reflection.GeneratedProtocolMessageType('FeatureConfig', (_message.Message,), {
597
+ 'DESCRIPTOR' : _SIMAGENTMETRICSCONFIG_FEATURECONFIG,
598
+ '__module__' : 'dev.metrics.protos.long_metrics_pb2'
599
+ # @@protoc_insertion_point(class_scope:long_metric.SimAgentMetricsConfig.FeatureConfig)
600
+ })
601
+ ,
602
+
603
+ 'HistogramEstimate' : _reflection.GeneratedProtocolMessageType('HistogramEstimate', (_message.Message,), {
604
+ 'DESCRIPTOR' : _SIMAGENTMETRICSCONFIG_HISTOGRAMESTIMATE,
605
+ '__module__' : 'dev.metrics.protos.long_metrics_pb2'
606
+ # @@protoc_insertion_point(class_scope:long_metric.SimAgentMetricsConfig.HistogramEstimate)
607
+ })
608
+ ,
609
+
610
+ 'KernelDensityEstimate' : _reflection.GeneratedProtocolMessageType('KernelDensityEstimate', (_message.Message,), {
611
+ 'DESCRIPTOR' : _SIMAGENTMETRICSCONFIG_KERNELDENSITYESTIMATE,
612
+ '__module__' : 'dev.metrics.protos.long_metrics_pb2'
613
+ # @@protoc_insertion_point(class_scope:long_metric.SimAgentMetricsConfig.KernelDensityEstimate)
614
+ })
615
+ ,
616
+
617
+ 'BernoulliEstimate' : _reflection.GeneratedProtocolMessageType('BernoulliEstimate', (_message.Message,), {
618
+ 'DESCRIPTOR' : _SIMAGENTMETRICSCONFIG_BERNOULLIESTIMATE,
619
+ '__module__' : 'dev.metrics.protos.long_metrics_pb2'
620
+ # @@protoc_insertion_point(class_scope:long_metric.SimAgentMetricsConfig.BernoulliEstimate)
621
+ })
622
+ ,
623
+ 'DESCRIPTOR' : _SIMAGENTMETRICSCONFIG,
624
+ '__module__' : 'dev.metrics.protos.long_metrics_pb2'
625
+ # @@protoc_insertion_point(class_scope:long_metric.SimAgentMetricsConfig)
626
+ })
627
+ _sym_db.RegisterMessage(SimAgentMetricsConfig)
628
+ _sym_db.RegisterMessage(SimAgentMetricsConfig.FeatureConfig)
629
+ _sym_db.RegisterMessage(SimAgentMetricsConfig.HistogramEstimate)
630
+ _sym_db.RegisterMessage(SimAgentMetricsConfig.KernelDensityEstimate)
631
+ _sym_db.RegisterMessage(SimAgentMetricsConfig.BernoulliEstimate)
632
+
633
+ SimAgentMetrics = _reflection.GeneratedProtocolMessageType('SimAgentMetrics', (_message.Message,), {
634
+ 'DESCRIPTOR' : _SIMAGENTMETRICS,
635
+ '__module__' : 'dev.metrics.protos.long_metrics_pb2'
636
+ # @@protoc_insertion_point(class_scope:long_metric.SimAgentMetrics)
637
+ })
638
+ _sym_db.RegisterMessage(SimAgentMetrics)
639
+
640
+ SimAgentsBucketedMetrics = _reflection.GeneratedProtocolMessageType('SimAgentsBucketedMetrics', (_message.Message,), {
641
+ 'DESCRIPTOR' : _SIMAGENTSBUCKETEDMETRICS,
642
+ '__module__' : 'dev.metrics.protos.long_metrics_pb2'
643
+ # @@protoc_insertion_point(class_scope:long_metric.SimAgentsBucketedMetrics)
644
+ })
645
+ _sym_db.RegisterMessage(SimAgentsBucketedMetrics)
646
+
647
+
648
+ # @@protoc_insertion_point(module_scope)
backups/dev/metrics/protos/map_pb2.py ADDED
@@ -0,0 +1,1070 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # source: dev/metrics/protos/map.proto
4
+
5
+ from google.protobuf import descriptor as _descriptor
6
+ from google.protobuf import message as _message
7
+ from google.protobuf import reflection as _reflection
8
+ from google.protobuf import symbol_database as _symbol_database
9
+ # @@protoc_insertion_point(imports)
10
+
11
+ _sym_db = _symbol_database.Default()
12
+
13
+
14
+
15
+
16
+ DESCRIPTOR = _descriptor.FileDescriptor(
17
+ name='dev/metrics/protos/map.proto',
18
+ package='long_metric',
19
+ syntax='proto2',
20
+ serialized_options=None,
21
+ create_key=_descriptor._internal_create_key,
22
+ serialized_pb=b'\n\x1c\x64\x65v/metrics/protos/map.proto\x12\x0blong_metric\"g\n\x03Map\x12-\n\x0cmap_features\x18\x01 \x03(\x0b\x32\x17.long_metric.MapFeature\x12\x31\n\x0e\x64ynamic_states\x18\x02 \x03(\x0b\x32\x19.long_metric.DynamicState\"c\n\x0c\x44ynamicState\x12\x19\n\x11timestamp_seconds\x18\x01 \x01(\x01\x12\x38\n\x0blane_states\x18\x02 \x03(\x0b\x32#.long_metric.TrafficSignalLaneState\"\xfe\x02\n\x16TrafficSignalLaneState\x12\x0c\n\x04lane\x18\x01 \x01(\x03\x12\x38\n\x05state\x18\x02 \x01(\x0e\x32).long_metric.TrafficSignalLaneState.State\x12)\n\nstop_point\x18\x03 \x01(\x0b\x32\x15.long_metric.MapPoint\"\xf0\x01\n\x05State\x12\x16\n\x12LANE_STATE_UNKNOWN\x10\x00\x12\x19\n\x15LANE_STATE_ARROW_STOP\x10\x01\x12\x1c\n\x18LANE_STATE_ARROW_CAUTION\x10\x02\x12\x17\n\x13LANE_STATE_ARROW_GO\x10\x03\x12\x13\n\x0fLANE_STATE_STOP\x10\x04\x12\x16\n\x12LANE_STATE_CAUTION\x10\x05\x12\x11\n\rLANE_STATE_GO\x10\x06\x12\x1c\n\x18LANE_STATE_FLASHING_STOP\x10\x07\x12\x1f\n\x1bLANE_STATE_FLASHING_CAUTION\x10\x08\"\xdb\x02\n\nMapFeature\x12\n\n\x02id\x18\x01 \x01(\x03\x12\'\n\x04lane\x18\x03 \x01(\x0b\x32\x17.long_metric.LaneCenterH\x00\x12*\n\troad_line\x18\x04 \x01(\x0b\x32\x15.long_metric.RoadLineH\x00\x12*\n\troad_edge\x18\x05 \x01(\x0b\x32\x15.long_metric.RoadEdgeH\x00\x12*\n\tstop_sign\x18\x07 \x01(\x0b\x32\x15.long_metric.StopSignH\x00\x12+\n\tcrosswalk\x18\x08 \x01(\x0b\x32\x16.long_metric.CrosswalkH\x00\x12,\n\nspeed_bump\x18\t \x01(\x0b\x32\x16.long_metric.SpeedBumpH\x00\x12)\n\x08\x64riveway\x18\n \x01(\x0b\x32\x15.long_metric.DrivewayH\x00\x42\x0e\n\x0c\x66\x65\x61ture_data\"+\n\x08MapPoint\x12\t\n\x01x\x18\x01 \x01(\x01\x12\t\n\x01y\x18\x02 \x01(\x01\x12\t\n\x01z\x18\x03 \x01(\x01\"\x9b\x01\n\x0f\x42oundarySegment\x12\x18\n\x10lane_start_index\x18\x01 \x01(\x05\x12\x16\n\x0elane_end_index\x18\x02 \x01(\x05\x12\x1b\n\x13\x62oundary_feature_id\x18\x03 \x01(\x03\x12\x39\n\rboundary_type\x18\x04 
\x01(\x0e\x32\".long_metric.RoadLine.RoadLineType\"\xc0\x01\n\x0cLaneNeighbor\x12\x12\n\nfeature_id\x18\x01 \x01(\x03\x12\x18\n\x10self_start_index\x18\x02 \x01(\x05\x12\x16\n\x0eself_end_index\x18\x03 \x01(\x05\x12\x1c\n\x14neighbor_start_index\x18\x04 \x01(\x05\x12\x1a\n\x12neighbor_end_index\x18\x05 \x01(\x05\x12\x30\n\nboundaries\x18\x06 \x03(\x0b\x32\x1c.long_metric.BoundarySegment\"\xfb\x03\n\nLaneCenter\x12\x17\n\x0fspeed_limit_mph\x18\x01 \x01(\x01\x12.\n\x04type\x18\x02 \x01(\x0e\x32 .long_metric.LaneCenter.LaneType\x12\x15\n\rinterpolating\x18\x03 \x01(\x08\x12\'\n\x08polyline\x18\x08 \x03(\x0b\x32\x15.long_metric.MapPoint\x12\x17\n\x0b\x65ntry_lanes\x18\t \x03(\x03\x42\x02\x10\x01\x12\x16\n\nexit_lanes\x18\n \x03(\x03\x42\x02\x10\x01\x12\x35\n\x0fleft_boundaries\x18\r \x03(\x0b\x32\x1c.long_metric.BoundarySegment\x12\x36\n\x10right_boundaries\x18\x0e \x03(\x0b\x32\x1c.long_metric.BoundarySegment\x12\x31\n\x0eleft_neighbors\x18\x0b \x03(\x0b\x32\x19.long_metric.LaneNeighbor\x12\x32\n\x0fright_neighbors\x18\x0c \x03(\x0b\x32\x19.long_metric.LaneNeighbor\"]\n\x08LaneType\x12\x12\n\x0eTYPE_UNDEFINED\x10\x00\x12\x10\n\x0cTYPE_FREEWAY\x10\x01\x12\x17\n\x13TYPE_SURFACE_STREET\x10\x02\x12\x12\n\x0eTYPE_BIKE_LANE\x10\x03\"\xbf\x01\n\x08RoadEdge\x12\x30\n\x04type\x18\x01 \x01(\x0e\x32\".long_metric.RoadEdge.RoadEdgeType\x12\'\n\x08polyline\x18\x02 \x03(\x0b\x32\x15.long_metric.MapPoint\"X\n\x0cRoadEdgeType\x12\x10\n\x0cTYPE_UNKNOWN\x10\x00\x12\x1b\n\x17TYPE_ROAD_EDGE_BOUNDARY\x10\x01\x12\x19\n\x15TYPE_ROAD_EDGE_MEDIAN\x10\x02\"\xfa\x02\n\x08RoadLine\x12\x30\n\x04type\x18\x01 \x01(\x0e\x32\".long_metric.RoadLine.RoadLineType\x12\'\n\x08polyline\x18\x02 
\x03(\x0b\x32\x15.long_metric.MapPoint\"\x92\x02\n\x0cRoadLineType\x12\x10\n\x0cTYPE_UNKNOWN\x10\x00\x12\x1c\n\x18TYPE_BROKEN_SINGLE_WHITE\x10\x01\x12\x1b\n\x17TYPE_SOLID_SINGLE_WHITE\x10\x02\x12\x1b\n\x17TYPE_SOLID_DOUBLE_WHITE\x10\x03\x12\x1d\n\x19TYPE_BROKEN_SINGLE_YELLOW\x10\x04\x12\x1d\n\x19TYPE_BROKEN_DOUBLE_YELLOW\x10\x05\x12\x1c\n\x18TYPE_SOLID_SINGLE_YELLOW\x10\x06\x12\x1c\n\x18TYPE_SOLID_DOUBLE_YELLOW\x10\x07\x12\x1e\n\x1aTYPE_PASSING_DOUBLE_YELLOW\x10\x08\"A\n\x08StopSign\x12\x0c\n\x04lane\x18\x01 \x03(\x03\x12\'\n\x08position\x18\x02 \x01(\x0b\x32\x15.long_metric.MapPoint\"3\n\tCrosswalk\x12&\n\x07polygon\x18\x01 \x03(\x0b\x32\x15.long_metric.MapPoint\"3\n\tSpeedBump\x12&\n\x07polygon\x18\x01 \x03(\x0b\x32\x15.long_metric.MapPoint\"2\n\x08\x44riveway\x12&\n\x07polygon\x18\x01 \x03(\x0b\x32\x15.long_metric.MapPoint'
23
+ )
24
+
25
+
26
+
27
+ _TRAFFICSIGNALLANESTATE_STATE = _descriptor.EnumDescriptor(
28
+ name='State',
29
+ full_name='long_metric.TrafficSignalLaneState.State',
30
+ filename=None,
31
+ file=DESCRIPTOR,
32
+ create_key=_descriptor._internal_create_key,
33
+ values=[
34
+ _descriptor.EnumValueDescriptor(
35
+ name='LANE_STATE_UNKNOWN', index=0, number=0,
36
+ serialized_options=None,
37
+ type=None,
38
+ create_key=_descriptor._internal_create_key),
39
+ _descriptor.EnumValueDescriptor(
40
+ name='LANE_STATE_ARROW_STOP', index=1, number=1,
41
+ serialized_options=None,
42
+ type=None,
43
+ create_key=_descriptor._internal_create_key),
44
+ _descriptor.EnumValueDescriptor(
45
+ name='LANE_STATE_ARROW_CAUTION', index=2, number=2,
46
+ serialized_options=None,
47
+ type=None,
48
+ create_key=_descriptor._internal_create_key),
49
+ _descriptor.EnumValueDescriptor(
50
+ name='LANE_STATE_ARROW_GO', index=3, number=3,
51
+ serialized_options=None,
52
+ type=None,
53
+ create_key=_descriptor._internal_create_key),
54
+ _descriptor.EnumValueDescriptor(
55
+ name='LANE_STATE_STOP', index=4, number=4,
56
+ serialized_options=None,
57
+ type=None,
58
+ create_key=_descriptor._internal_create_key),
59
+ _descriptor.EnumValueDescriptor(
60
+ name='LANE_STATE_CAUTION', index=5, number=5,
61
+ serialized_options=None,
62
+ type=None,
63
+ create_key=_descriptor._internal_create_key),
64
+ _descriptor.EnumValueDescriptor(
65
+ name='LANE_STATE_GO', index=6, number=6,
66
+ serialized_options=None,
67
+ type=None,
68
+ create_key=_descriptor._internal_create_key),
69
+ _descriptor.EnumValueDescriptor(
70
+ name='LANE_STATE_FLASHING_STOP', index=7, number=7,
71
+ serialized_options=None,
72
+ type=None,
73
+ create_key=_descriptor._internal_create_key),
74
+ _descriptor.EnumValueDescriptor(
75
+ name='LANE_STATE_FLASHING_CAUTION', index=8, number=8,
76
+ serialized_options=None,
77
+ type=None,
78
+ create_key=_descriptor._internal_create_key),
79
+ ],
80
+ containing_type=None,
81
+ serialized_options=None,
82
+ serialized_start=394,
83
+ serialized_end=634,
84
+ )
85
+ _sym_db.RegisterEnumDescriptor(_TRAFFICSIGNALLANESTATE_STATE)
86
+
87
+ _LANECENTER_LANETYPE = _descriptor.EnumDescriptor(
88
+ name='LaneType',
89
+ full_name='long_metric.LaneCenter.LaneType',
90
+ filename=None,
91
+ file=DESCRIPTOR,
92
+ create_key=_descriptor._internal_create_key,
93
+ values=[
94
+ _descriptor.EnumValueDescriptor(
95
+ name='TYPE_UNDEFINED', index=0, number=0,
96
+ serialized_options=None,
97
+ type=None,
98
+ create_key=_descriptor._internal_create_key),
99
+ _descriptor.EnumValueDescriptor(
100
+ name='TYPE_FREEWAY', index=1, number=1,
101
+ serialized_options=None,
102
+ type=None,
103
+ create_key=_descriptor._internal_create_key),
104
+ _descriptor.EnumValueDescriptor(
105
+ name='TYPE_SURFACE_STREET', index=2, number=2,
106
+ serialized_options=None,
107
+ type=None,
108
+ create_key=_descriptor._internal_create_key),
109
+ _descriptor.EnumValueDescriptor(
110
+ name='TYPE_BIKE_LANE', index=3, number=3,
111
+ serialized_options=None,
112
+ type=None,
113
+ create_key=_descriptor._internal_create_key),
114
+ ],
115
+ containing_type=None,
116
+ serialized_options=None,
117
+ serialized_start=1799,
118
+ serialized_end=1892,
119
+ )
120
+ _sym_db.RegisterEnumDescriptor(_LANECENTER_LANETYPE)
121
+
122
+ _ROADEDGE_ROADEDGETYPE = _descriptor.EnumDescriptor(
123
+ name='RoadEdgeType',
124
+ full_name='long_metric.RoadEdge.RoadEdgeType',
125
+ filename=None,
126
+ file=DESCRIPTOR,
127
+ create_key=_descriptor._internal_create_key,
128
+ values=[
129
+ _descriptor.EnumValueDescriptor(
130
+ name='TYPE_UNKNOWN', index=0, number=0,
131
+ serialized_options=None,
132
+ type=None,
133
+ create_key=_descriptor._internal_create_key),
134
+ _descriptor.EnumValueDescriptor(
135
+ name='TYPE_ROAD_EDGE_BOUNDARY', index=1, number=1,
136
+ serialized_options=None,
137
+ type=None,
138
+ create_key=_descriptor._internal_create_key),
139
+ _descriptor.EnumValueDescriptor(
140
+ name='TYPE_ROAD_EDGE_MEDIAN', index=2, number=2,
141
+ serialized_options=None,
142
+ type=None,
143
+ create_key=_descriptor._internal_create_key),
144
+ ],
145
+ containing_type=None,
146
+ serialized_options=None,
147
+ serialized_start=1998,
148
+ serialized_end=2086,
149
+ )
150
+ _sym_db.RegisterEnumDescriptor(_ROADEDGE_ROADEDGETYPE)
151
+
152
+ _ROADLINE_ROADLINETYPE = _descriptor.EnumDescriptor(
153
+ name='RoadLineType',
154
+ full_name='long_metric.RoadLine.RoadLineType',
155
+ filename=None,
156
+ file=DESCRIPTOR,
157
+ create_key=_descriptor._internal_create_key,
158
+ values=[
159
+ _descriptor.EnumValueDescriptor(
160
+ name='TYPE_UNKNOWN', index=0, number=0,
161
+ serialized_options=None,
162
+ type=None,
163
+ create_key=_descriptor._internal_create_key),
164
+ _descriptor.EnumValueDescriptor(
165
+ name='TYPE_BROKEN_SINGLE_WHITE', index=1, number=1,
166
+ serialized_options=None,
167
+ type=None,
168
+ create_key=_descriptor._internal_create_key),
169
+ _descriptor.EnumValueDescriptor(
170
+ name='TYPE_SOLID_SINGLE_WHITE', index=2, number=2,
171
+ serialized_options=None,
172
+ type=None,
173
+ create_key=_descriptor._internal_create_key),
174
+ _descriptor.EnumValueDescriptor(
175
+ name='TYPE_SOLID_DOUBLE_WHITE', index=3, number=3,
176
+ serialized_options=None,
177
+ type=None,
178
+ create_key=_descriptor._internal_create_key),
179
+ _descriptor.EnumValueDescriptor(
180
+ name='TYPE_BROKEN_SINGLE_YELLOW', index=4, number=4,
181
+ serialized_options=None,
182
+ type=None,
183
+ create_key=_descriptor._internal_create_key),
184
+ _descriptor.EnumValueDescriptor(
185
+ name='TYPE_BROKEN_DOUBLE_YELLOW', index=5, number=5,
186
+ serialized_options=None,
187
+ type=None,
188
+ create_key=_descriptor._internal_create_key),
189
+ _descriptor.EnumValueDescriptor(
190
+ name='TYPE_SOLID_SINGLE_YELLOW', index=6, number=6,
191
+ serialized_options=None,
192
+ type=None,
193
+ create_key=_descriptor._internal_create_key),
194
+ _descriptor.EnumValueDescriptor(
195
+ name='TYPE_SOLID_DOUBLE_YELLOW', index=7, number=7,
196
+ serialized_options=None,
197
+ type=None,
198
+ create_key=_descriptor._internal_create_key),
199
+ _descriptor.EnumValueDescriptor(
200
+ name='TYPE_PASSING_DOUBLE_YELLOW', index=8, number=8,
201
+ serialized_options=None,
202
+ type=None,
203
+ create_key=_descriptor._internal_create_key),
204
+ ],
205
+ containing_type=None,
206
+ serialized_options=None,
207
+ serialized_start=2193,
208
+ serialized_end=2467,
209
+ )
210
+ _sym_db.RegisterEnumDescriptor(_ROADLINE_ROADLINETYPE)
211
+
212
+
213
+ _MAP = _descriptor.Descriptor(
214
+ name='Map',
215
+ full_name='long_metric.Map',
216
+ filename=None,
217
+ file=DESCRIPTOR,
218
+ containing_type=None,
219
+ create_key=_descriptor._internal_create_key,
220
+ fields=[
221
+ _descriptor.FieldDescriptor(
222
+ name='map_features', full_name='long_metric.Map.map_features', index=0,
223
+ number=1, type=11, cpp_type=10, label=3,
224
+ has_default_value=False, default_value=[],
225
+ message_type=None, enum_type=None, containing_type=None,
226
+ is_extension=False, extension_scope=None,
227
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
228
+ _descriptor.FieldDescriptor(
229
+ name='dynamic_states', full_name='long_metric.Map.dynamic_states', index=1,
230
+ number=2, type=11, cpp_type=10, label=3,
231
+ has_default_value=False, default_value=[],
232
+ message_type=None, enum_type=None, containing_type=None,
233
+ is_extension=False, extension_scope=None,
234
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
235
+ ],
236
+ extensions=[
237
+ ],
238
+ nested_types=[],
239
+ enum_types=[
240
+ ],
241
+ serialized_options=None,
242
+ is_extendable=False,
243
+ syntax='proto2',
244
+ extension_ranges=[],
245
+ oneofs=[
246
+ ],
247
+ serialized_start=45,
248
+ serialized_end=148,
249
+ )
250
+
251
+
252
+ _DYNAMICSTATE = _descriptor.Descriptor(
253
+ name='DynamicState',
254
+ full_name='long_metric.DynamicState',
255
+ filename=None,
256
+ file=DESCRIPTOR,
257
+ containing_type=None,
258
+ create_key=_descriptor._internal_create_key,
259
+ fields=[
260
+ _descriptor.FieldDescriptor(
261
+ name='timestamp_seconds', full_name='long_metric.DynamicState.timestamp_seconds', index=0,
262
+ number=1, type=1, cpp_type=5, label=1,
263
+ has_default_value=False, default_value=float(0),
264
+ message_type=None, enum_type=None, containing_type=None,
265
+ is_extension=False, extension_scope=None,
266
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
267
+ _descriptor.FieldDescriptor(
268
+ name='lane_states', full_name='long_metric.DynamicState.lane_states', index=1,
269
+ number=2, type=11, cpp_type=10, label=3,
270
+ has_default_value=False, default_value=[],
271
+ message_type=None, enum_type=None, containing_type=None,
272
+ is_extension=False, extension_scope=None,
273
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
274
+ ],
275
+ extensions=[
276
+ ],
277
+ nested_types=[],
278
+ enum_types=[
279
+ ],
280
+ serialized_options=None,
281
+ is_extendable=False,
282
+ syntax='proto2',
283
+ extension_ranges=[],
284
+ oneofs=[
285
+ ],
286
+ serialized_start=150,
287
+ serialized_end=249,
288
+ )
289
+
290
+
291
+ _TRAFFICSIGNALLANESTATE = _descriptor.Descriptor(
292
+ name='TrafficSignalLaneState',
293
+ full_name='long_metric.TrafficSignalLaneState',
294
+ filename=None,
295
+ file=DESCRIPTOR,
296
+ containing_type=None,
297
+ create_key=_descriptor._internal_create_key,
298
+ fields=[
299
+ _descriptor.FieldDescriptor(
300
+ name='lane', full_name='long_metric.TrafficSignalLaneState.lane', index=0,
301
+ number=1, type=3, cpp_type=2, label=1,
302
+ has_default_value=False, default_value=0,
303
+ message_type=None, enum_type=None, containing_type=None,
304
+ is_extension=False, extension_scope=None,
305
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
306
+ _descriptor.FieldDescriptor(
307
+ name='state', full_name='long_metric.TrafficSignalLaneState.state', index=1,
308
+ number=2, type=14, cpp_type=8, label=1,
309
+ has_default_value=False, default_value=0,
310
+ message_type=None, enum_type=None, containing_type=None,
311
+ is_extension=False, extension_scope=None,
312
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
313
+ _descriptor.FieldDescriptor(
314
+ name='stop_point', full_name='long_metric.TrafficSignalLaneState.stop_point', index=2,
315
+ number=3, type=11, cpp_type=10, label=1,
316
+ has_default_value=False, default_value=None,
317
+ message_type=None, enum_type=None, containing_type=None,
318
+ is_extension=False, extension_scope=None,
319
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
320
+ ],
321
+ extensions=[
322
+ ],
323
+ nested_types=[],
324
+ enum_types=[
325
+ _TRAFFICSIGNALLANESTATE_STATE,
326
+ ],
327
+ serialized_options=None,
328
+ is_extendable=False,
329
+ syntax='proto2',
330
+ extension_ranges=[],
331
+ oneofs=[
332
+ ],
333
+ serialized_start=252,
334
+ serialized_end=634,
335
+ )
336
+
337
+
338
+ _MAPFEATURE = _descriptor.Descriptor(
339
+ name='MapFeature',
340
+ full_name='long_metric.MapFeature',
341
+ filename=None,
342
+ file=DESCRIPTOR,
343
+ containing_type=None,
344
+ create_key=_descriptor._internal_create_key,
345
+ fields=[
346
+ _descriptor.FieldDescriptor(
347
+ name='id', full_name='long_metric.MapFeature.id', index=0,
348
+ number=1, type=3, cpp_type=2, label=1,
349
+ has_default_value=False, default_value=0,
350
+ message_type=None, enum_type=None, containing_type=None,
351
+ is_extension=False, extension_scope=None,
352
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
353
+ _descriptor.FieldDescriptor(
354
+ name='lane', full_name='long_metric.MapFeature.lane', index=1,
355
+ number=3, type=11, cpp_type=10, label=1,
356
+ has_default_value=False, default_value=None,
357
+ message_type=None, enum_type=None, containing_type=None,
358
+ is_extension=False, extension_scope=None,
359
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
360
+ _descriptor.FieldDescriptor(
361
+ name='road_line', full_name='long_metric.MapFeature.road_line', index=2,
362
+ number=4, type=11, cpp_type=10, label=1,
363
+ has_default_value=False, default_value=None,
364
+ message_type=None, enum_type=None, containing_type=None,
365
+ is_extension=False, extension_scope=None,
366
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
367
+ _descriptor.FieldDescriptor(
368
+ name='road_edge', full_name='long_metric.MapFeature.road_edge', index=3,
369
+ number=5, type=11, cpp_type=10, label=1,
370
+ has_default_value=False, default_value=None,
371
+ message_type=None, enum_type=None, containing_type=None,
372
+ is_extension=False, extension_scope=None,
373
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
374
+ _descriptor.FieldDescriptor(
375
+ name='stop_sign', full_name='long_metric.MapFeature.stop_sign', index=4,
376
+ number=7, type=11, cpp_type=10, label=1,
377
+ has_default_value=False, default_value=None,
378
+ message_type=None, enum_type=None, containing_type=None,
379
+ is_extension=False, extension_scope=None,
380
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
381
+ _descriptor.FieldDescriptor(
382
+ name='crosswalk', full_name='long_metric.MapFeature.crosswalk', index=5,
383
+ number=8, type=11, cpp_type=10, label=1,
384
+ has_default_value=False, default_value=None,
385
+ message_type=None, enum_type=None, containing_type=None,
386
+ is_extension=False, extension_scope=None,
387
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
388
+ _descriptor.FieldDescriptor(
389
+ name='speed_bump', full_name='long_metric.MapFeature.speed_bump', index=6,
390
+ number=9, type=11, cpp_type=10, label=1,
391
+ has_default_value=False, default_value=None,
392
+ message_type=None, enum_type=None, containing_type=None,
393
+ is_extension=False, extension_scope=None,
394
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
395
+ _descriptor.FieldDescriptor(
396
+ name='driveway', full_name='long_metric.MapFeature.driveway', index=7,
397
+ number=10, type=11, cpp_type=10, label=1,
398
+ has_default_value=False, default_value=None,
399
+ message_type=None, enum_type=None, containing_type=None,
400
+ is_extension=False, extension_scope=None,
401
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
402
+ ],
403
+ extensions=[
404
+ ],
405
+ nested_types=[],
406
+ enum_types=[
407
+ ],
408
+ serialized_options=None,
409
+ is_extendable=False,
410
+ syntax='proto2',
411
+ extension_ranges=[],
412
+ oneofs=[
413
+ _descriptor.OneofDescriptor(
414
+ name='feature_data', full_name='long_metric.MapFeature.feature_data',
415
+ index=0, containing_type=None,
416
+ create_key=_descriptor._internal_create_key,
417
+ fields=[]),
418
+ ],
419
+ serialized_start=637,
420
+ serialized_end=984,
421
+ )
422
+
423
+
424
+ _MAPPOINT = _descriptor.Descriptor(
425
+ name='MapPoint',
426
+ full_name='long_metric.MapPoint',
427
+ filename=None,
428
+ file=DESCRIPTOR,
429
+ containing_type=None,
430
+ create_key=_descriptor._internal_create_key,
431
+ fields=[
432
+ _descriptor.FieldDescriptor(
433
+ name='x', full_name='long_metric.MapPoint.x', index=0,
434
+ number=1, type=1, cpp_type=5, label=1,
435
+ has_default_value=False, default_value=float(0),
436
+ message_type=None, enum_type=None, containing_type=None,
437
+ is_extension=False, extension_scope=None,
438
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
439
+ _descriptor.FieldDescriptor(
440
+ name='y', full_name='long_metric.MapPoint.y', index=1,
441
+ number=2, type=1, cpp_type=5, label=1,
442
+ has_default_value=False, default_value=float(0),
443
+ message_type=None, enum_type=None, containing_type=None,
444
+ is_extension=False, extension_scope=None,
445
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
446
+ _descriptor.FieldDescriptor(
447
+ name='z', full_name='long_metric.MapPoint.z', index=2,
448
+ number=3, type=1, cpp_type=5, label=1,
449
+ has_default_value=False, default_value=float(0),
450
+ message_type=None, enum_type=None, containing_type=None,
451
+ is_extension=False, extension_scope=None,
452
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
453
+ ],
454
+ extensions=[
455
+ ],
456
+ nested_types=[],
457
+ enum_types=[
458
+ ],
459
+ serialized_options=None,
460
+ is_extendable=False,
461
+ syntax='proto2',
462
+ extension_ranges=[],
463
+ oneofs=[
464
+ ],
465
+ serialized_start=986,
466
+ serialized_end=1029,
467
+ )
468
+
469
+
470
+ _BOUNDARYSEGMENT = _descriptor.Descriptor(
471
+ name='BoundarySegment',
472
+ full_name='long_metric.BoundarySegment',
473
+ filename=None,
474
+ file=DESCRIPTOR,
475
+ containing_type=None,
476
+ create_key=_descriptor._internal_create_key,
477
+ fields=[
478
+ _descriptor.FieldDescriptor(
479
+ name='lane_start_index', full_name='long_metric.BoundarySegment.lane_start_index', index=0,
480
+ number=1, type=5, cpp_type=1, label=1,
481
+ has_default_value=False, default_value=0,
482
+ message_type=None, enum_type=None, containing_type=None,
483
+ is_extension=False, extension_scope=None,
484
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
485
+ _descriptor.FieldDescriptor(
486
+ name='lane_end_index', full_name='long_metric.BoundarySegment.lane_end_index', index=1,
487
+ number=2, type=5, cpp_type=1, label=1,
488
+ has_default_value=False, default_value=0,
489
+ message_type=None, enum_type=None, containing_type=None,
490
+ is_extension=False, extension_scope=None,
491
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
492
+ _descriptor.FieldDescriptor(
493
+ name='boundary_feature_id', full_name='long_metric.BoundarySegment.boundary_feature_id', index=2,
494
+ number=3, type=3, cpp_type=2, label=1,
495
+ has_default_value=False, default_value=0,
496
+ message_type=None, enum_type=None, containing_type=None,
497
+ is_extension=False, extension_scope=None,
498
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
499
+ _descriptor.FieldDescriptor(
500
+ name='boundary_type', full_name='long_metric.BoundarySegment.boundary_type', index=3,
501
+ number=4, type=14, cpp_type=8, label=1,
502
+ has_default_value=False, default_value=0,
503
+ message_type=None, enum_type=None, containing_type=None,
504
+ is_extension=False, extension_scope=None,
505
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
506
+ ],
507
+ extensions=[
508
+ ],
509
+ nested_types=[],
510
+ enum_types=[
511
+ ],
512
+ serialized_options=None,
513
+ is_extendable=False,
514
+ syntax='proto2',
515
+ extension_ranges=[],
516
+ oneofs=[
517
+ ],
518
+ serialized_start=1032,
519
+ serialized_end=1187,
520
+ )
521
+
522
+
523
+ _LANENEIGHBOR = _descriptor.Descriptor(
524
+ name='LaneNeighbor',
525
+ full_name='long_metric.LaneNeighbor',
526
+ filename=None,
527
+ file=DESCRIPTOR,
528
+ containing_type=None,
529
+ create_key=_descriptor._internal_create_key,
530
+ fields=[
531
+ _descriptor.FieldDescriptor(
532
+ name='feature_id', full_name='long_metric.LaneNeighbor.feature_id', index=0,
533
+ number=1, type=3, cpp_type=2, label=1,
534
+ has_default_value=False, default_value=0,
535
+ message_type=None, enum_type=None, containing_type=None,
536
+ is_extension=False, extension_scope=None,
537
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
538
+ _descriptor.FieldDescriptor(
539
+ name='self_start_index', full_name='long_metric.LaneNeighbor.self_start_index', index=1,
540
+ number=2, type=5, cpp_type=1, label=1,
541
+ has_default_value=False, default_value=0,
542
+ message_type=None, enum_type=None, containing_type=None,
543
+ is_extension=False, extension_scope=None,
544
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
545
+ _descriptor.FieldDescriptor(
546
+ name='self_end_index', full_name='long_metric.LaneNeighbor.self_end_index', index=2,
547
+ number=3, type=5, cpp_type=1, label=1,
548
+ has_default_value=False, default_value=0,
549
+ message_type=None, enum_type=None, containing_type=None,
550
+ is_extension=False, extension_scope=None,
551
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
552
+ _descriptor.FieldDescriptor(
553
+ name='neighbor_start_index', full_name='long_metric.LaneNeighbor.neighbor_start_index', index=3,
554
+ number=4, type=5, cpp_type=1, label=1,
555
+ has_default_value=False, default_value=0,
556
+ message_type=None, enum_type=None, containing_type=None,
557
+ is_extension=False, extension_scope=None,
558
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
559
+ _descriptor.FieldDescriptor(
560
+ name='neighbor_end_index', full_name='long_metric.LaneNeighbor.neighbor_end_index', index=4,
561
+ number=5, type=5, cpp_type=1, label=1,
562
+ has_default_value=False, default_value=0,
563
+ message_type=None, enum_type=None, containing_type=None,
564
+ is_extension=False, extension_scope=None,
565
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
566
+ _descriptor.FieldDescriptor(
567
+ name='boundaries', full_name='long_metric.LaneNeighbor.boundaries', index=5,
568
+ number=6, type=11, cpp_type=10, label=3,
569
+ has_default_value=False, default_value=[],
570
+ message_type=None, enum_type=None, containing_type=None,
571
+ is_extension=False, extension_scope=None,
572
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
573
+ ],
574
+ extensions=[
575
+ ],
576
+ nested_types=[],
577
+ enum_types=[
578
+ ],
579
+ serialized_options=None,
580
+ is_extendable=False,
581
+ syntax='proto2',
582
+ extension_ranges=[],
583
+ oneofs=[
584
+ ],
585
+ serialized_start=1190,
586
+ serialized_end=1382,
587
+ )
588
+
589
+
590
+ _LANECENTER = _descriptor.Descriptor(
591
+ name='LaneCenter',
592
+ full_name='long_metric.LaneCenter',
593
+ filename=None,
594
+ file=DESCRIPTOR,
595
+ containing_type=None,
596
+ create_key=_descriptor._internal_create_key,
597
+ fields=[
598
+ _descriptor.FieldDescriptor(
599
+ name='speed_limit_mph', full_name='long_metric.LaneCenter.speed_limit_mph', index=0,
600
+ number=1, type=1, cpp_type=5, label=1,
601
+ has_default_value=False, default_value=float(0),
602
+ message_type=None, enum_type=None, containing_type=None,
603
+ is_extension=False, extension_scope=None,
604
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
605
+ _descriptor.FieldDescriptor(
606
+ name='type', full_name='long_metric.LaneCenter.type', index=1,
607
+ number=2, type=14, cpp_type=8, label=1,
608
+ has_default_value=False, default_value=0,
609
+ message_type=None, enum_type=None, containing_type=None,
610
+ is_extension=False, extension_scope=None,
611
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
612
+ _descriptor.FieldDescriptor(
613
+ name='interpolating', full_name='long_metric.LaneCenter.interpolating', index=2,
614
+ number=3, type=8, cpp_type=7, label=1,
615
+ has_default_value=False, default_value=False,
616
+ message_type=None, enum_type=None, containing_type=None,
617
+ is_extension=False, extension_scope=None,
618
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
619
+ _descriptor.FieldDescriptor(
620
+ name='polyline', full_name='long_metric.LaneCenter.polyline', index=3,
621
+ number=8, type=11, cpp_type=10, label=3,
622
+ has_default_value=False, default_value=[],
623
+ message_type=None, enum_type=None, containing_type=None,
624
+ is_extension=False, extension_scope=None,
625
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
626
+ _descriptor.FieldDescriptor(
627
+ name='entry_lanes', full_name='long_metric.LaneCenter.entry_lanes', index=4,
628
+ number=9, type=3, cpp_type=2, label=3,
629
+ has_default_value=False, default_value=[],
630
+ message_type=None, enum_type=None, containing_type=None,
631
+ is_extension=False, extension_scope=None,
632
+ serialized_options=b'\020\001', file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
633
+ _descriptor.FieldDescriptor(
634
+ name='exit_lanes', full_name='long_metric.LaneCenter.exit_lanes', index=5,
635
+ number=10, type=3, cpp_type=2, label=3,
636
+ has_default_value=False, default_value=[],
637
+ message_type=None, enum_type=None, containing_type=None,
638
+ is_extension=False, extension_scope=None,
639
+ serialized_options=b'\020\001', file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
640
+ _descriptor.FieldDescriptor(
641
+ name='left_boundaries', full_name='long_metric.LaneCenter.left_boundaries', index=6,
642
+ number=13, type=11, cpp_type=10, label=3,
643
+ has_default_value=False, default_value=[],
644
+ message_type=None, enum_type=None, containing_type=None,
645
+ is_extension=False, extension_scope=None,
646
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
647
+ _descriptor.FieldDescriptor(
648
+ name='right_boundaries', full_name='long_metric.LaneCenter.right_boundaries', index=7,
649
+ number=14, type=11, cpp_type=10, label=3,
650
+ has_default_value=False, default_value=[],
651
+ message_type=None, enum_type=None, containing_type=None,
652
+ is_extension=False, extension_scope=None,
653
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
654
+ _descriptor.FieldDescriptor(
655
+ name='left_neighbors', full_name='long_metric.LaneCenter.left_neighbors', index=8,
656
+ number=11, type=11, cpp_type=10, label=3,
657
+ has_default_value=False, default_value=[],
658
+ message_type=None, enum_type=None, containing_type=None,
659
+ is_extension=False, extension_scope=None,
660
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
661
+ _descriptor.FieldDescriptor(
662
+ name='right_neighbors', full_name='long_metric.LaneCenter.right_neighbors', index=9,
663
+ number=12, type=11, cpp_type=10, label=3,
664
+ has_default_value=False, default_value=[],
665
+ message_type=None, enum_type=None, containing_type=None,
666
+ is_extension=False, extension_scope=None,
667
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
668
+ ],
669
+ extensions=[
670
+ ],
671
+ nested_types=[],
672
+ enum_types=[
673
+ _LANECENTER_LANETYPE,
674
+ ],
675
+ serialized_options=None,
676
+ is_extendable=False,
677
+ syntax='proto2',
678
+ extension_ranges=[],
679
+ oneofs=[
680
+ ],
681
+ serialized_start=1385,
682
+ serialized_end=1892,
683
+ )
684
+
685
+
686
+ _ROADEDGE = _descriptor.Descriptor(
687
+ name='RoadEdge',
688
+ full_name='long_metric.RoadEdge',
689
+ filename=None,
690
+ file=DESCRIPTOR,
691
+ containing_type=None,
692
+ create_key=_descriptor._internal_create_key,
693
+ fields=[
694
+ _descriptor.FieldDescriptor(
695
+ name='type', full_name='long_metric.RoadEdge.type', index=0,
696
+ number=1, type=14, cpp_type=8, label=1,
697
+ has_default_value=False, default_value=0,
698
+ message_type=None, enum_type=None, containing_type=None,
699
+ is_extension=False, extension_scope=None,
700
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
701
+ _descriptor.FieldDescriptor(
702
+ name='polyline', full_name='long_metric.RoadEdge.polyline', index=1,
703
+ number=2, type=11, cpp_type=10, label=3,
704
+ has_default_value=False, default_value=[],
705
+ message_type=None, enum_type=None, containing_type=None,
706
+ is_extension=False, extension_scope=None,
707
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
708
+ ],
709
+ extensions=[
710
+ ],
711
+ nested_types=[],
712
+ enum_types=[
713
+ _ROADEDGE_ROADEDGETYPE,
714
+ ],
715
+ serialized_options=None,
716
+ is_extendable=False,
717
+ syntax='proto2',
718
+ extension_ranges=[],
719
+ oneofs=[
720
+ ],
721
+ serialized_start=1895,
722
+ serialized_end=2086,
723
+ )
724
+
725
+
726
+ _ROADLINE = _descriptor.Descriptor(
727
+ name='RoadLine',
728
+ full_name='long_metric.RoadLine',
729
+ filename=None,
730
+ file=DESCRIPTOR,
731
+ containing_type=None,
732
+ create_key=_descriptor._internal_create_key,
733
+ fields=[
734
+ _descriptor.FieldDescriptor(
735
+ name='type', full_name='long_metric.RoadLine.type', index=0,
736
+ number=1, type=14, cpp_type=8, label=1,
737
+ has_default_value=False, default_value=0,
738
+ message_type=None, enum_type=None, containing_type=None,
739
+ is_extension=False, extension_scope=None,
740
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
741
+ _descriptor.FieldDescriptor(
742
+ name='polyline', full_name='long_metric.RoadLine.polyline', index=1,
743
+ number=2, type=11, cpp_type=10, label=3,
744
+ has_default_value=False, default_value=[],
745
+ message_type=None, enum_type=None, containing_type=None,
746
+ is_extension=False, extension_scope=None,
747
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
748
+ ],
749
+ extensions=[
750
+ ],
751
+ nested_types=[],
752
+ enum_types=[
753
+ _ROADLINE_ROADLINETYPE,
754
+ ],
755
+ serialized_options=None,
756
+ is_extendable=False,
757
+ syntax='proto2',
758
+ extension_ranges=[],
759
+ oneofs=[
760
+ ],
761
+ serialized_start=2089,
762
+ serialized_end=2467,
763
+ )
764
+
765
+
766
+ _STOPSIGN = _descriptor.Descriptor(
767
+ name='StopSign',
768
+ full_name='long_metric.StopSign',
769
+ filename=None,
770
+ file=DESCRIPTOR,
771
+ containing_type=None,
772
+ create_key=_descriptor._internal_create_key,
773
+ fields=[
774
+ _descriptor.FieldDescriptor(
775
+ name='lane', full_name='long_metric.StopSign.lane', index=0,
776
+ number=1, type=3, cpp_type=2, label=3,
777
+ has_default_value=False, default_value=[],
778
+ message_type=None, enum_type=None, containing_type=None,
779
+ is_extension=False, extension_scope=None,
780
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
781
+ _descriptor.FieldDescriptor(
782
+ name='position', full_name='long_metric.StopSign.position', index=1,
783
+ number=2, type=11, cpp_type=10, label=1,
784
+ has_default_value=False, default_value=None,
785
+ message_type=None, enum_type=None, containing_type=None,
786
+ is_extension=False, extension_scope=None,
787
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
788
+ ],
789
+ extensions=[
790
+ ],
791
+ nested_types=[],
792
+ enum_types=[
793
+ ],
794
+ serialized_options=None,
795
+ is_extendable=False,
796
+ syntax='proto2',
797
+ extension_ranges=[],
798
+ oneofs=[
799
+ ],
800
+ serialized_start=2469,
801
+ serialized_end=2534,
802
+ )
803
+
804
+
805
+ _CROSSWALK = _descriptor.Descriptor(
806
+ name='Crosswalk',
807
+ full_name='long_metric.Crosswalk',
808
+ filename=None,
809
+ file=DESCRIPTOR,
810
+ containing_type=None,
811
+ create_key=_descriptor._internal_create_key,
812
+ fields=[
813
+ _descriptor.FieldDescriptor(
814
+ name='polygon', full_name='long_metric.Crosswalk.polygon', index=0,
815
+ number=1, type=11, cpp_type=10, label=3,
816
+ has_default_value=False, default_value=[],
817
+ message_type=None, enum_type=None, containing_type=None,
818
+ is_extension=False, extension_scope=None,
819
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
820
+ ],
821
+ extensions=[
822
+ ],
823
+ nested_types=[],
824
+ enum_types=[
825
+ ],
826
+ serialized_options=None,
827
+ is_extendable=False,
828
+ syntax='proto2',
829
+ extension_ranges=[],
830
+ oneofs=[
831
+ ],
832
+ serialized_start=2536,
833
+ serialized_end=2587,
834
+ )
835
+
836
+
837
+ _SPEEDBUMP = _descriptor.Descriptor(
838
+ name='SpeedBump',
839
+ full_name='long_metric.SpeedBump',
840
+ filename=None,
841
+ file=DESCRIPTOR,
842
+ containing_type=None,
843
+ create_key=_descriptor._internal_create_key,
844
+ fields=[
845
+ _descriptor.FieldDescriptor(
846
+ name='polygon', full_name='long_metric.SpeedBump.polygon', index=0,
847
+ number=1, type=11, cpp_type=10, label=3,
848
+ has_default_value=False, default_value=[],
849
+ message_type=None, enum_type=None, containing_type=None,
850
+ is_extension=False, extension_scope=None,
851
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
852
+ ],
853
+ extensions=[
854
+ ],
855
+ nested_types=[],
856
+ enum_types=[
857
+ ],
858
+ serialized_options=None,
859
+ is_extendable=False,
860
+ syntax='proto2',
861
+ extension_ranges=[],
862
+ oneofs=[
863
+ ],
864
+ serialized_start=2589,
865
+ serialized_end=2640,
866
+ )
867
+
868
+
869
+ _DRIVEWAY = _descriptor.Descriptor(
870
+ name='Driveway',
871
+ full_name='long_metric.Driveway',
872
+ filename=None,
873
+ file=DESCRIPTOR,
874
+ containing_type=None,
875
+ create_key=_descriptor._internal_create_key,
876
+ fields=[
877
+ _descriptor.FieldDescriptor(
878
+ name='polygon', full_name='long_metric.Driveway.polygon', index=0,
879
+ number=1, type=11, cpp_type=10, label=3,
880
+ has_default_value=False, default_value=[],
881
+ message_type=None, enum_type=None, containing_type=None,
882
+ is_extension=False, extension_scope=None,
883
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
884
+ ],
885
+ extensions=[
886
+ ],
887
+ nested_types=[],
888
+ enum_types=[
889
+ ],
890
+ serialized_options=None,
891
+ is_extendable=False,
892
+ syntax='proto2',
893
+ extension_ranges=[],
894
+ oneofs=[
895
+ ],
896
+ serialized_start=2642,
897
+ serialized_end=2692,
898
+ )
899
+
900
+ _MAP.fields_by_name['map_features'].message_type = _MAPFEATURE
901
+ _MAP.fields_by_name['dynamic_states'].message_type = _DYNAMICSTATE
902
+ _DYNAMICSTATE.fields_by_name['lane_states'].message_type = _TRAFFICSIGNALLANESTATE
903
+ _TRAFFICSIGNALLANESTATE.fields_by_name['state'].enum_type = _TRAFFICSIGNALLANESTATE_STATE
904
+ _TRAFFICSIGNALLANESTATE.fields_by_name['stop_point'].message_type = _MAPPOINT
905
+ _TRAFFICSIGNALLANESTATE_STATE.containing_type = _TRAFFICSIGNALLANESTATE
906
+ _MAPFEATURE.fields_by_name['lane'].message_type = _LANECENTER
907
+ _MAPFEATURE.fields_by_name['road_line'].message_type = _ROADLINE
908
+ _MAPFEATURE.fields_by_name['road_edge'].message_type = _ROADEDGE
909
+ _MAPFEATURE.fields_by_name['stop_sign'].message_type = _STOPSIGN
910
+ _MAPFEATURE.fields_by_name['crosswalk'].message_type = _CROSSWALK
911
+ _MAPFEATURE.fields_by_name['speed_bump'].message_type = _SPEEDBUMP
912
+ _MAPFEATURE.fields_by_name['driveway'].message_type = _DRIVEWAY
913
+ _MAPFEATURE.oneofs_by_name['feature_data'].fields.append(
914
+ _MAPFEATURE.fields_by_name['lane'])
915
+ _MAPFEATURE.fields_by_name['lane'].containing_oneof = _MAPFEATURE.oneofs_by_name['feature_data']
916
+ _MAPFEATURE.oneofs_by_name['feature_data'].fields.append(
917
+ _MAPFEATURE.fields_by_name['road_line'])
918
+ _MAPFEATURE.fields_by_name['road_line'].containing_oneof = _MAPFEATURE.oneofs_by_name['feature_data']
919
+ _MAPFEATURE.oneofs_by_name['feature_data'].fields.append(
920
+ _MAPFEATURE.fields_by_name['road_edge'])
921
+ _MAPFEATURE.fields_by_name['road_edge'].containing_oneof = _MAPFEATURE.oneofs_by_name['feature_data']
922
+ _MAPFEATURE.oneofs_by_name['feature_data'].fields.append(
923
+ _MAPFEATURE.fields_by_name['stop_sign'])
924
+ _MAPFEATURE.fields_by_name['stop_sign'].containing_oneof = _MAPFEATURE.oneofs_by_name['feature_data']
925
+ _MAPFEATURE.oneofs_by_name['feature_data'].fields.append(
926
+ _MAPFEATURE.fields_by_name['crosswalk'])
927
+ _MAPFEATURE.fields_by_name['crosswalk'].containing_oneof = _MAPFEATURE.oneofs_by_name['feature_data']
928
+ _MAPFEATURE.oneofs_by_name['feature_data'].fields.append(
929
+ _MAPFEATURE.fields_by_name['speed_bump'])
930
+ _MAPFEATURE.fields_by_name['speed_bump'].containing_oneof = _MAPFEATURE.oneofs_by_name['feature_data']
931
+ _MAPFEATURE.oneofs_by_name['feature_data'].fields.append(
932
+ _MAPFEATURE.fields_by_name['driveway'])
933
+ _MAPFEATURE.fields_by_name['driveway'].containing_oneof = _MAPFEATURE.oneofs_by_name['feature_data']
934
+ _BOUNDARYSEGMENT.fields_by_name['boundary_type'].enum_type = _ROADLINE_ROADLINETYPE
935
+ _LANENEIGHBOR.fields_by_name['boundaries'].message_type = _BOUNDARYSEGMENT
936
+ _LANECENTER.fields_by_name['type'].enum_type = _LANECENTER_LANETYPE
937
+ _LANECENTER.fields_by_name['polyline'].message_type = _MAPPOINT
938
+ _LANECENTER.fields_by_name['left_boundaries'].message_type = _BOUNDARYSEGMENT
939
+ _LANECENTER.fields_by_name['right_boundaries'].message_type = _BOUNDARYSEGMENT
940
+ _LANECENTER.fields_by_name['left_neighbors'].message_type = _LANENEIGHBOR
941
+ _LANECENTER.fields_by_name['right_neighbors'].message_type = _LANENEIGHBOR
942
+ _LANECENTER_LANETYPE.containing_type = _LANECENTER
943
+ _ROADEDGE.fields_by_name['type'].enum_type = _ROADEDGE_ROADEDGETYPE
944
+ _ROADEDGE.fields_by_name['polyline'].message_type = _MAPPOINT
945
+ _ROADEDGE_ROADEDGETYPE.containing_type = _ROADEDGE
946
+ _ROADLINE.fields_by_name['type'].enum_type = _ROADLINE_ROADLINETYPE
947
+ _ROADLINE.fields_by_name['polyline'].message_type = _MAPPOINT
948
+ _ROADLINE_ROADLINETYPE.containing_type = _ROADLINE
949
+ _STOPSIGN.fields_by_name['position'].message_type = _MAPPOINT
950
+ _CROSSWALK.fields_by_name['polygon'].message_type = _MAPPOINT
951
+ _SPEEDBUMP.fields_by_name['polygon'].message_type = _MAPPOINT
952
+ _DRIVEWAY.fields_by_name['polygon'].message_type = _MAPPOINT
953
+ DESCRIPTOR.message_types_by_name['Map'] = _MAP
954
+ DESCRIPTOR.message_types_by_name['DynamicState'] = _DYNAMICSTATE
955
+ DESCRIPTOR.message_types_by_name['TrafficSignalLaneState'] = _TRAFFICSIGNALLANESTATE
956
+ DESCRIPTOR.message_types_by_name['MapFeature'] = _MAPFEATURE
957
+ DESCRIPTOR.message_types_by_name['MapPoint'] = _MAPPOINT
958
+ DESCRIPTOR.message_types_by_name['BoundarySegment'] = _BOUNDARYSEGMENT
959
+ DESCRIPTOR.message_types_by_name['LaneNeighbor'] = _LANENEIGHBOR
960
+ DESCRIPTOR.message_types_by_name['LaneCenter'] = _LANECENTER
961
+ DESCRIPTOR.message_types_by_name['RoadEdge'] = _ROADEDGE
962
+ DESCRIPTOR.message_types_by_name['RoadLine'] = _ROADLINE
963
+ DESCRIPTOR.message_types_by_name['StopSign'] = _STOPSIGN
964
+ DESCRIPTOR.message_types_by_name['Crosswalk'] = _CROSSWALK
965
+ DESCRIPTOR.message_types_by_name['SpeedBump'] = _SPEEDBUMP
966
+ DESCRIPTOR.message_types_by_name['Driveway'] = _DRIVEWAY
967
+ _sym_db.RegisterFileDescriptor(DESCRIPTOR)
968
+
969
+ Map = _reflection.GeneratedProtocolMessageType('Map', (_message.Message,), {
970
+ 'DESCRIPTOR' : _MAP,
971
+ '__module__' : 'dev.metrics.protos.map_pb2'
972
+ # @@protoc_insertion_point(class_scope:long_metric.Map)
973
+ })
974
+ _sym_db.RegisterMessage(Map)
975
+
976
+ DynamicState = _reflection.GeneratedProtocolMessageType('DynamicState', (_message.Message,), {
977
+ 'DESCRIPTOR' : _DYNAMICSTATE,
978
+ '__module__' : 'dev.metrics.protos.map_pb2'
979
+ # @@protoc_insertion_point(class_scope:long_metric.DynamicState)
980
+ })
981
+ _sym_db.RegisterMessage(DynamicState)
982
+
983
+ TrafficSignalLaneState = _reflection.GeneratedProtocolMessageType('TrafficSignalLaneState', (_message.Message,), {
984
+ 'DESCRIPTOR' : _TRAFFICSIGNALLANESTATE,
985
+ '__module__' : 'dev.metrics.protos.map_pb2'
986
+ # @@protoc_insertion_point(class_scope:long_metric.TrafficSignalLaneState)
987
+ })
988
+ _sym_db.RegisterMessage(TrafficSignalLaneState)
989
+
990
+ MapFeature = _reflection.GeneratedProtocolMessageType('MapFeature', (_message.Message,), {
991
+ 'DESCRIPTOR' : _MAPFEATURE,
992
+ '__module__' : 'dev.metrics.protos.map_pb2'
993
+ # @@protoc_insertion_point(class_scope:long_metric.MapFeature)
994
+ })
995
+ _sym_db.RegisterMessage(MapFeature)
996
+
997
+ MapPoint = _reflection.GeneratedProtocolMessageType('MapPoint', (_message.Message,), {
998
+ 'DESCRIPTOR' : _MAPPOINT,
999
+ '__module__' : 'dev.metrics.protos.map_pb2'
1000
+ # @@protoc_insertion_point(class_scope:long_metric.MapPoint)
1001
+ })
1002
+ _sym_db.RegisterMessage(MapPoint)
1003
+
1004
+ BoundarySegment = _reflection.GeneratedProtocolMessageType('BoundarySegment', (_message.Message,), {
1005
+ 'DESCRIPTOR' : _BOUNDARYSEGMENT,
1006
+ '__module__' : 'dev.metrics.protos.map_pb2'
1007
+ # @@protoc_insertion_point(class_scope:long_metric.BoundarySegment)
1008
+ })
1009
+ _sym_db.RegisterMessage(BoundarySegment)
1010
+
1011
+ LaneNeighbor = _reflection.GeneratedProtocolMessageType('LaneNeighbor', (_message.Message,), {
1012
+ 'DESCRIPTOR' : _LANENEIGHBOR,
1013
+ '__module__' : 'dev.metrics.protos.map_pb2'
1014
+ # @@protoc_insertion_point(class_scope:long_metric.LaneNeighbor)
1015
+ })
1016
+ _sym_db.RegisterMessage(LaneNeighbor)
1017
+
1018
+ LaneCenter = _reflection.GeneratedProtocolMessageType('LaneCenter', (_message.Message,), {
1019
+ 'DESCRIPTOR' : _LANECENTER,
1020
+ '__module__' : 'dev.metrics.protos.map_pb2'
1021
+ # @@protoc_insertion_point(class_scope:long_metric.LaneCenter)
1022
+ })
1023
+ _sym_db.RegisterMessage(LaneCenter)
1024
+
1025
+ RoadEdge = _reflection.GeneratedProtocolMessageType('RoadEdge', (_message.Message,), {
1026
+ 'DESCRIPTOR' : _ROADEDGE,
1027
+ '__module__' : 'dev.metrics.protos.map_pb2'
1028
+ # @@protoc_insertion_point(class_scope:long_metric.RoadEdge)
1029
+ })
1030
+ _sym_db.RegisterMessage(RoadEdge)
1031
+
1032
+ RoadLine = _reflection.GeneratedProtocolMessageType('RoadLine', (_message.Message,), {
1033
+ 'DESCRIPTOR' : _ROADLINE,
1034
+ '__module__' : 'dev.metrics.protos.map_pb2'
1035
+ # @@protoc_insertion_point(class_scope:long_metric.RoadLine)
1036
+ })
1037
+ _sym_db.RegisterMessage(RoadLine)
1038
+
1039
+ StopSign = _reflection.GeneratedProtocolMessageType('StopSign', (_message.Message,), {
1040
+ 'DESCRIPTOR' : _STOPSIGN,
1041
+ '__module__' : 'dev.metrics.protos.map_pb2'
1042
+ # @@protoc_insertion_point(class_scope:long_metric.StopSign)
1043
+ })
1044
+ _sym_db.RegisterMessage(StopSign)
1045
+
1046
+ Crosswalk = _reflection.GeneratedProtocolMessageType('Crosswalk', (_message.Message,), {
1047
+ 'DESCRIPTOR' : _CROSSWALK,
1048
+ '__module__' : 'dev.metrics.protos.map_pb2'
1049
+ # @@protoc_insertion_point(class_scope:long_metric.Crosswalk)
1050
+ })
1051
+ _sym_db.RegisterMessage(Crosswalk)
1052
+
1053
+ SpeedBump = _reflection.GeneratedProtocolMessageType('SpeedBump', (_message.Message,), {
1054
+ 'DESCRIPTOR' : _SPEEDBUMP,
1055
+ '__module__' : 'dev.metrics.protos.map_pb2'
1056
+ # @@protoc_insertion_point(class_scope:long_metric.SpeedBump)
1057
+ })
1058
+ _sym_db.RegisterMessage(SpeedBump)
1059
+
1060
+ Driveway = _reflection.GeneratedProtocolMessageType('Driveway', (_message.Message,), {
1061
+ 'DESCRIPTOR' : _DRIVEWAY,
1062
+ '__module__' : 'dev.metrics.protos.map_pb2'
1063
+ # @@protoc_insertion_point(class_scope:long_metric.Driveway)
1064
+ })
1065
+ _sym_db.RegisterMessage(Driveway)
1066
+
1067
+
1068
+ _LANECENTER.fields_by_name['entry_lanes']._options = None
1069
+ _LANECENTER.fields_by_name['exit_lanes']._options = None
1070
+ # @@protoc_insertion_point(module_scope)
backups/dev/metrics/protos/scenario_pb2.py ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # source: dev/metrics/protos/scenario.proto
4
+
5
+ from google.protobuf import descriptor as _descriptor
6
+ from google.protobuf import message as _message
7
+ from google.protobuf import reflection as _reflection
8
+ from google.protobuf import symbol_database as _symbol_database
9
+ # @@protoc_insertion_point(imports)
10
+
11
+ _sym_db = _symbol_database.Default()
12
+
13
+
14
+ from dev.metrics.protos import map_pb2 as dev_dot_metrics_dot_protos_dot_map__pb2
15
+
16
+
17
+ DESCRIPTOR = _descriptor.FileDescriptor(
18
+ name='dev/metrics/protos/scenario.proto',
19
+ package='long_metric',
20
+ syntax='proto2',
21
+ serialized_options=None,
22
+ create_key=_descriptor._internal_create_key,
23
+ serialized_pb=b'\n!dev/metrics/protos/scenario.proto\x12\x0blong_metric\x1a\x1c\x64\x65v/metrics/protos/map.proto\"\xba\x01\n\x0bObjectState\x12\x10\n\x08\x63\x65nter_x\x18\x02 \x01(\x01\x12\x10\n\x08\x63\x65nter_y\x18\x03 \x01(\x01\x12\x10\n\x08\x63\x65nter_z\x18\x04 \x01(\x01\x12\x0e\n\x06length\x18\x05 \x01(\x02\x12\r\n\x05width\x18\x06 \x01(\x02\x12\x0e\n\x06height\x18\x07 \x01(\x02\x12\x0f\n\x07heading\x18\x08 \x01(\x02\x12\x12\n\nvelocity_x\x18\t \x01(\x02\x12\x12\n\nvelocity_y\x18\n \x01(\x02\x12\r\n\x05valid\x18\x0b \x01(\x08\"\xd8\x01\n\x05Track\x12\n\n\x02id\x18\x01 \x01(\x05\x12\x32\n\x0bobject_type\x18\x02 \x01(\x0e\x32\x1d.long_metric.Track.ObjectType\x12(\n\x06states\x18\x03 \x03(\x0b\x32\x18.long_metric.ObjectState\"e\n\nObjectType\x12\x0e\n\nTYPE_UNSET\x10\x00\x12\x10\n\x0cTYPE_VEHICLE\x10\x01\x12\x13\n\x0fTYPE_PEDESTRIAN\x10\x02\x12\x10\n\x0cTYPE_CYCLIST\x10\x03\x12\x0e\n\nTYPE_OTHER\x10\x04\"K\n\x0f\x44ynamicMapState\x12\x38\n\x0blane_states\x18\x01 \x03(\x0b\x32#.long_metric.TrafficSignalLaneState\"\xa5\x01\n\x12RequiredPrediction\x12\x13\n\x0btrack_index\x18\x01 \x01(\x05\x12\x43\n\ndifficulty\x18\x02 \x01(\x0e\x32/.long_metric.RequiredPrediction.DifficultyLevel\"5\n\x0f\x44ifficultyLevel\x12\x08\n\x04NONE\x10\x00\x12\x0b\n\x07LEVEL_1\x10\x01\x12\x0b\n\x07LEVEL_2\x10\x02\"\xdc\x02\n\x08Scenario\x12\x13\n\x0bscenario_id\x18\x05 \x01(\t\x12\x1a\n\x12timestamps_seconds\x18\x01 \x03(\x01\x12\x1a\n\x12\x63urrent_time_index\x18\n \x01(\x05\x12\"\n\x06tracks\x18\x02 \x03(\x0b\x32\x12.long_metric.Track\x12\x38\n\x12\x64ynamic_map_states\x18\x07 \x03(\x0b\x32\x1c.long_metric.DynamicMapState\x12-\n\x0cmap_features\x18\x08 \x03(\x0b\x32\x17.long_metric.MapFeature\x12\x17\n\x0fsdc_track_index\x18\x06 \x01(\x05\x12\x1b\n\x13objects_of_interest\x18\x04 \x03(\x05\x12:\n\x11tracks_to_predict\x18\x0b \x03(\x0b\x32\x1f.long_metric.RequiredPredictionJ\x04\x08\t\x10\n'
24
+ ,
25
+ dependencies=[dev_dot_metrics_dot_protos_dot_map__pb2.DESCRIPTOR,])
26
+
27
+
28
+
29
+ _TRACK_OBJECTTYPE = _descriptor.EnumDescriptor(
30
+ name='ObjectType',
31
+ full_name='long_metric.Track.ObjectType',
32
+ filename=None,
33
+ file=DESCRIPTOR,
34
+ create_key=_descriptor._internal_create_key,
35
+ values=[
36
+ _descriptor.EnumValueDescriptor(
37
+ name='TYPE_UNSET', index=0, number=0,
38
+ serialized_options=None,
39
+ type=None,
40
+ create_key=_descriptor._internal_create_key),
41
+ _descriptor.EnumValueDescriptor(
42
+ name='TYPE_VEHICLE', index=1, number=1,
43
+ serialized_options=None,
44
+ type=None,
45
+ create_key=_descriptor._internal_create_key),
46
+ _descriptor.EnumValueDescriptor(
47
+ name='TYPE_PEDESTRIAN', index=2, number=2,
48
+ serialized_options=None,
49
+ type=None,
50
+ create_key=_descriptor._internal_create_key),
51
+ _descriptor.EnumValueDescriptor(
52
+ name='TYPE_CYCLIST', index=3, number=3,
53
+ serialized_options=None,
54
+ type=None,
55
+ create_key=_descriptor._internal_create_key),
56
+ _descriptor.EnumValueDescriptor(
57
+ name='TYPE_OTHER', index=4, number=4,
58
+ serialized_options=None,
59
+ type=None,
60
+ create_key=_descriptor._internal_create_key),
61
+ ],
62
+ containing_type=None,
63
+ serialized_options=None,
64
+ serialized_start=385,
65
+ serialized_end=486,
66
+ )
67
+ _sym_db.RegisterEnumDescriptor(_TRACK_OBJECTTYPE)
68
+
69
+ _REQUIREDPREDICTION_DIFFICULTYLEVEL = _descriptor.EnumDescriptor(
70
+ name='DifficultyLevel',
71
+ full_name='long_metric.RequiredPrediction.DifficultyLevel',
72
+ filename=None,
73
+ file=DESCRIPTOR,
74
+ create_key=_descriptor._internal_create_key,
75
+ values=[
76
+ _descriptor.EnumValueDescriptor(
77
+ name='NONE', index=0, number=0,
78
+ serialized_options=None,
79
+ type=None,
80
+ create_key=_descriptor._internal_create_key),
81
+ _descriptor.EnumValueDescriptor(
82
+ name='LEVEL_1', index=1, number=1,
83
+ serialized_options=None,
84
+ type=None,
85
+ create_key=_descriptor._internal_create_key),
86
+ _descriptor.EnumValueDescriptor(
87
+ name='LEVEL_2', index=2, number=2,
88
+ serialized_options=None,
89
+ type=None,
90
+ create_key=_descriptor._internal_create_key),
91
+ ],
92
+ containing_type=None,
93
+ serialized_options=None,
94
+ serialized_start=678,
95
+ serialized_end=731,
96
+ )
97
+ _sym_db.RegisterEnumDescriptor(_REQUIREDPREDICTION_DIFFICULTYLEVEL)
98
+
99
+
100
+ _OBJECTSTATE = _descriptor.Descriptor(
101
+ name='ObjectState',
102
+ full_name='long_metric.ObjectState',
103
+ filename=None,
104
+ file=DESCRIPTOR,
105
+ containing_type=None,
106
+ create_key=_descriptor._internal_create_key,
107
+ fields=[
108
+ _descriptor.FieldDescriptor(
109
+ name='center_x', full_name='long_metric.ObjectState.center_x', index=0,
110
+ number=2, type=1, cpp_type=5, label=1,
111
+ has_default_value=False, default_value=float(0),
112
+ message_type=None, enum_type=None, containing_type=None,
113
+ is_extension=False, extension_scope=None,
114
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
115
+ _descriptor.FieldDescriptor(
116
+ name='center_y', full_name='long_metric.ObjectState.center_y', index=1,
117
+ number=3, type=1, cpp_type=5, label=1,
118
+ has_default_value=False, default_value=float(0),
119
+ message_type=None, enum_type=None, containing_type=None,
120
+ is_extension=False, extension_scope=None,
121
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
122
+ _descriptor.FieldDescriptor(
123
+ name='center_z', full_name='long_metric.ObjectState.center_z', index=2,
124
+ number=4, type=1, cpp_type=5, label=1,
125
+ has_default_value=False, default_value=float(0),
126
+ message_type=None, enum_type=None, containing_type=None,
127
+ is_extension=False, extension_scope=None,
128
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
129
+ _descriptor.FieldDescriptor(
130
+ name='length', full_name='long_metric.ObjectState.length', index=3,
131
+ number=5, type=2, cpp_type=6, label=1,
132
+ has_default_value=False, default_value=float(0),
133
+ message_type=None, enum_type=None, containing_type=None,
134
+ is_extension=False, extension_scope=None,
135
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
136
+ _descriptor.FieldDescriptor(
137
+ name='width', full_name='long_metric.ObjectState.width', index=4,
138
+ number=6, type=2, cpp_type=6, label=1,
139
+ has_default_value=False, default_value=float(0),
140
+ message_type=None, enum_type=None, containing_type=None,
141
+ is_extension=False, extension_scope=None,
142
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
143
+ _descriptor.FieldDescriptor(
144
+ name='height', full_name='long_metric.ObjectState.height', index=5,
145
+ number=7, type=2, cpp_type=6, label=1,
146
+ has_default_value=False, default_value=float(0),
147
+ message_type=None, enum_type=None, containing_type=None,
148
+ is_extension=False, extension_scope=None,
149
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
150
+ _descriptor.FieldDescriptor(
151
+ name='heading', full_name='long_metric.ObjectState.heading', index=6,
152
+ number=8, type=2, cpp_type=6, label=1,
153
+ has_default_value=False, default_value=float(0),
154
+ message_type=None, enum_type=None, containing_type=None,
155
+ is_extension=False, extension_scope=None,
156
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
157
+ _descriptor.FieldDescriptor(
158
+ name='velocity_x', full_name='long_metric.ObjectState.velocity_x', index=7,
159
+ number=9, type=2, cpp_type=6, label=1,
160
+ has_default_value=False, default_value=float(0),
161
+ message_type=None, enum_type=None, containing_type=None,
162
+ is_extension=False, extension_scope=None,
163
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
164
+ _descriptor.FieldDescriptor(
165
+ name='velocity_y', full_name='long_metric.ObjectState.velocity_y', index=8,
166
+ number=10, type=2, cpp_type=6, label=1,
167
+ has_default_value=False, default_value=float(0),
168
+ message_type=None, enum_type=None, containing_type=None,
169
+ is_extension=False, extension_scope=None,
170
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
171
+ _descriptor.FieldDescriptor(
172
+ name='valid', full_name='long_metric.ObjectState.valid', index=9,
173
+ number=11, type=8, cpp_type=7, label=1,
174
+ has_default_value=False, default_value=False,
175
+ message_type=None, enum_type=None, containing_type=None,
176
+ is_extension=False, extension_scope=None,
177
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
178
+ ],
179
+ extensions=[
180
+ ],
181
+ nested_types=[],
182
+ enum_types=[
183
+ ],
184
+ serialized_options=None,
185
+ is_extendable=False,
186
+ syntax='proto2',
187
+ extension_ranges=[],
188
+ oneofs=[
189
+ ],
190
+ serialized_start=81,
191
+ serialized_end=267,
192
+ )
193
+
194
+
195
+ _TRACK = _descriptor.Descriptor(
196
+ name='Track',
197
+ full_name='long_metric.Track',
198
+ filename=None,
199
+ file=DESCRIPTOR,
200
+ containing_type=None,
201
+ create_key=_descriptor._internal_create_key,
202
+ fields=[
203
+ _descriptor.FieldDescriptor(
204
+ name='id', full_name='long_metric.Track.id', index=0,
205
+ number=1, type=5, cpp_type=1, label=1,
206
+ has_default_value=False, default_value=0,
207
+ message_type=None, enum_type=None, containing_type=None,
208
+ is_extension=False, extension_scope=None,
209
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
210
+ _descriptor.FieldDescriptor(
211
+ name='object_type', full_name='long_metric.Track.object_type', index=1,
212
+ number=2, type=14, cpp_type=8, label=1,
213
+ has_default_value=False, default_value=0,
214
+ message_type=None, enum_type=None, containing_type=None,
215
+ is_extension=False, extension_scope=None,
216
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
217
+ _descriptor.FieldDescriptor(
218
+ name='states', full_name='long_metric.Track.states', index=2,
219
+ number=3, type=11, cpp_type=10, label=3,
220
+ has_default_value=False, default_value=[],
221
+ message_type=None, enum_type=None, containing_type=None,
222
+ is_extension=False, extension_scope=None,
223
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
224
+ ],
225
+ extensions=[
226
+ ],
227
+ nested_types=[],
228
+ enum_types=[
229
+ _TRACK_OBJECTTYPE,
230
+ ],
231
+ serialized_options=None,
232
+ is_extendable=False,
233
+ syntax='proto2',
234
+ extension_ranges=[],
235
+ oneofs=[
236
+ ],
237
+ serialized_start=270,
238
+ serialized_end=486,
239
+ )
240
+
241
+
242
+ _DYNAMICMAPSTATE = _descriptor.Descriptor(
243
+ name='DynamicMapState',
244
+ full_name='long_metric.DynamicMapState',
245
+ filename=None,
246
+ file=DESCRIPTOR,
247
+ containing_type=None,
248
+ create_key=_descriptor._internal_create_key,
249
+ fields=[
250
+ _descriptor.FieldDescriptor(
251
+ name='lane_states', full_name='long_metric.DynamicMapState.lane_states', index=0,
252
+ number=1, type=11, cpp_type=10, label=3,
253
+ has_default_value=False, default_value=[],
254
+ message_type=None, enum_type=None, containing_type=None,
255
+ is_extension=False, extension_scope=None,
256
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
257
+ ],
258
+ extensions=[
259
+ ],
260
+ nested_types=[],
261
+ enum_types=[
262
+ ],
263
+ serialized_options=None,
264
+ is_extendable=False,
265
+ syntax='proto2',
266
+ extension_ranges=[],
267
+ oneofs=[
268
+ ],
269
+ serialized_start=488,
270
+ serialized_end=563,
271
+ )
272
+
273
+
274
+ _REQUIREDPREDICTION = _descriptor.Descriptor(
275
+ name='RequiredPrediction',
276
+ full_name='long_metric.RequiredPrediction',
277
+ filename=None,
278
+ file=DESCRIPTOR,
279
+ containing_type=None,
280
+ create_key=_descriptor._internal_create_key,
281
+ fields=[
282
+ _descriptor.FieldDescriptor(
283
+ name='track_index', full_name='long_metric.RequiredPrediction.track_index', index=0,
284
+ number=1, type=5, cpp_type=1, label=1,
285
+ has_default_value=False, default_value=0,
286
+ message_type=None, enum_type=None, containing_type=None,
287
+ is_extension=False, extension_scope=None,
288
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
289
+ _descriptor.FieldDescriptor(
290
+ name='difficulty', full_name='long_metric.RequiredPrediction.difficulty', index=1,
291
+ number=2, type=14, cpp_type=8, label=1,
292
+ has_default_value=False, default_value=0,
293
+ message_type=None, enum_type=None, containing_type=None,
294
+ is_extension=False, extension_scope=None,
295
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
296
+ ],
297
+ extensions=[
298
+ ],
299
+ nested_types=[],
300
+ enum_types=[
301
+ _REQUIREDPREDICTION_DIFFICULTYLEVEL,
302
+ ],
303
+ serialized_options=None,
304
+ is_extendable=False,
305
+ syntax='proto2',
306
+ extension_ranges=[],
307
+ oneofs=[
308
+ ],
309
+ serialized_start=566,
310
+ serialized_end=731,
311
+ )
312
+
313
+
314
+ _SCENARIO = _descriptor.Descriptor(
315
+ name='Scenario',
316
+ full_name='long_metric.Scenario',
317
+ filename=None,
318
+ file=DESCRIPTOR,
319
+ containing_type=None,
320
+ create_key=_descriptor._internal_create_key,
321
+ fields=[
322
+ _descriptor.FieldDescriptor(
323
+ name='scenario_id', full_name='long_metric.Scenario.scenario_id', index=0,
324
+ number=5, type=9, cpp_type=9, label=1,
325
+ has_default_value=False, default_value=b"".decode('utf-8'),
326
+ message_type=None, enum_type=None, containing_type=None,
327
+ is_extension=False, extension_scope=None,
328
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
329
+ _descriptor.FieldDescriptor(
330
+ name='timestamps_seconds', full_name='long_metric.Scenario.timestamps_seconds', index=1,
331
+ number=1, type=1, cpp_type=5, label=3,
332
+ has_default_value=False, default_value=[],
333
+ message_type=None, enum_type=None, containing_type=None,
334
+ is_extension=False, extension_scope=None,
335
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
336
+ _descriptor.FieldDescriptor(
337
+ name='current_time_index', full_name='long_metric.Scenario.current_time_index', index=2,
338
+ number=10, type=5, cpp_type=1, label=1,
339
+ has_default_value=False, default_value=0,
340
+ message_type=None, enum_type=None, containing_type=None,
341
+ is_extension=False, extension_scope=None,
342
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
343
+ _descriptor.FieldDescriptor(
344
+ name='tracks', full_name='long_metric.Scenario.tracks', index=3,
345
+ number=2, type=11, cpp_type=10, label=3,
346
+ has_default_value=False, default_value=[],
347
+ message_type=None, enum_type=None, containing_type=None,
348
+ is_extension=False, extension_scope=None,
349
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
350
+ _descriptor.FieldDescriptor(
351
+ name='dynamic_map_states', full_name='long_metric.Scenario.dynamic_map_states', index=4,
352
+ number=7, type=11, cpp_type=10, label=3,
353
+ has_default_value=False, default_value=[],
354
+ message_type=None, enum_type=None, containing_type=None,
355
+ is_extension=False, extension_scope=None,
356
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
357
+ _descriptor.FieldDescriptor(
358
+ name='map_features', full_name='long_metric.Scenario.map_features', index=5,
359
+ number=8, type=11, cpp_type=10, label=3,
360
+ has_default_value=False, default_value=[],
361
+ message_type=None, enum_type=None, containing_type=None,
362
+ is_extension=False, extension_scope=None,
363
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
364
+ _descriptor.FieldDescriptor(
365
+ name='sdc_track_index', full_name='long_metric.Scenario.sdc_track_index', index=6,
366
+ number=6, type=5, cpp_type=1, label=1,
367
+ has_default_value=False, default_value=0,
368
+ message_type=None, enum_type=None, containing_type=None,
369
+ is_extension=False, extension_scope=None,
370
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
371
+ _descriptor.FieldDescriptor(
372
+ name='objects_of_interest', full_name='long_metric.Scenario.objects_of_interest', index=7,
373
+ number=4, type=5, cpp_type=1, label=3,
374
+ has_default_value=False, default_value=[],
375
+ message_type=None, enum_type=None, containing_type=None,
376
+ is_extension=False, extension_scope=None,
377
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
378
+ _descriptor.FieldDescriptor(
379
+ name='tracks_to_predict', full_name='long_metric.Scenario.tracks_to_predict', index=8,
380
+ number=11, type=11, cpp_type=10, label=3,
381
+ has_default_value=False, default_value=[],
382
+ message_type=None, enum_type=None, containing_type=None,
383
+ is_extension=False, extension_scope=None,
384
+ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
385
+ ],
386
+ extensions=[
387
+ ],
388
+ nested_types=[],
389
+ enum_types=[
390
+ ],
391
+ serialized_options=None,
392
+ is_extendable=False,
393
+ syntax='proto2',
394
+ extension_ranges=[],
395
+ oneofs=[
396
+ ],
397
+ serialized_start=734,
398
+ serialized_end=1082,
399
+ )
400
+
401
+ _TRACK.fields_by_name['object_type'].enum_type = _TRACK_OBJECTTYPE
402
+ _TRACK.fields_by_name['states'].message_type = _OBJECTSTATE
403
+ _TRACK_OBJECTTYPE.containing_type = _TRACK
404
+ _DYNAMICMAPSTATE.fields_by_name['lane_states'].message_type = dev_dot_metrics_dot_protos_dot_map__pb2._TRAFFICSIGNALLANESTATE
405
+ _REQUIREDPREDICTION.fields_by_name['difficulty'].enum_type = _REQUIREDPREDICTION_DIFFICULTYLEVEL
406
+ _REQUIREDPREDICTION_DIFFICULTYLEVEL.containing_type = _REQUIREDPREDICTION
407
+ _SCENARIO.fields_by_name['tracks'].message_type = _TRACK
408
+ _SCENARIO.fields_by_name['dynamic_map_states'].message_type = _DYNAMICMAPSTATE
409
+ _SCENARIO.fields_by_name['map_features'].message_type = dev_dot_metrics_dot_protos_dot_map__pb2._MAPFEATURE
410
+ _SCENARIO.fields_by_name['tracks_to_predict'].message_type = _REQUIREDPREDICTION
411
+ DESCRIPTOR.message_types_by_name['ObjectState'] = _OBJECTSTATE
412
+ DESCRIPTOR.message_types_by_name['Track'] = _TRACK
413
+ DESCRIPTOR.message_types_by_name['DynamicMapState'] = _DYNAMICMAPSTATE
414
+ DESCRIPTOR.message_types_by_name['RequiredPrediction'] = _REQUIREDPREDICTION
415
+ DESCRIPTOR.message_types_by_name['Scenario'] = _SCENARIO
416
+ _sym_db.RegisterFileDescriptor(DESCRIPTOR)
417
+
418
+ ObjectState = _reflection.GeneratedProtocolMessageType('ObjectState', (_message.Message,), {
419
+ 'DESCRIPTOR' : _OBJECTSTATE,
420
+ '__module__' : 'dev.metrics.protos.scenario_pb2'
421
+ # @@protoc_insertion_point(class_scope:long_metric.ObjectState)
422
+ })
423
+ _sym_db.RegisterMessage(ObjectState)
424
+
425
+ Track = _reflection.GeneratedProtocolMessageType('Track', (_message.Message,), {
426
+ 'DESCRIPTOR' : _TRACK,
427
+ '__module__' : 'dev.metrics.protos.scenario_pb2'
428
+ # @@protoc_insertion_point(class_scope:long_metric.Track)
429
+ })
430
+ _sym_db.RegisterMessage(Track)
431
+
432
+ DynamicMapState = _reflection.GeneratedProtocolMessageType('DynamicMapState', (_message.Message,), {
433
+ 'DESCRIPTOR' : _DYNAMICMAPSTATE,
434
+ '__module__' : 'dev.metrics.protos.scenario_pb2'
435
+ # @@protoc_insertion_point(class_scope:long_metric.DynamicMapState)
436
+ })
437
+ _sym_db.RegisterMessage(DynamicMapState)
438
+
439
+ RequiredPrediction = _reflection.GeneratedProtocolMessageType('RequiredPrediction', (_message.Message,), {
440
+ 'DESCRIPTOR' : _REQUIREDPREDICTION,
441
+ '__module__' : 'dev.metrics.protos.scenario_pb2'
442
+ # @@protoc_insertion_point(class_scope:long_metric.RequiredPrediction)
443
+ })
444
+ _sym_db.RegisterMessage(RequiredPrediction)
445
+
446
+ Scenario = _reflection.GeneratedProtocolMessageType('Scenario', (_message.Message,), {
447
+ 'DESCRIPTOR' : _SCENARIO,
448
+ '__module__' : 'dev.metrics.protos.scenario_pb2'
449
+ # @@protoc_insertion_point(class_scope:long_metric.Scenario)
450
+ })
451
+ _sym_db.RegisterMessage(Scenario)
452
+
453
+
454
+ # @@protoc_insertion_point(module_scope)
backups/dev/metrics/trajectory_features.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from torch import Tensor
4
+ from typing import Tuple
5
+
6
+
7
+ def _wrap_angle(angle: Tensor) -> Tensor:
8
+ return (angle + np.pi) % (2 * np.pi) - np.pi
9
+
10
+
11
def central_diff(t: Tensor, pad_value: float) -> Tensor:
    """Central difference along the last dimension.

    out[..., i] = (t[..., i+1] - t[..., i-1]) / 2 for interior indices; the
    first and last positions (where no central stencil exists) are filled
    with `pad_value`. Output has the same shape, dtype and device as `t`.
    """
    interior = (t[..., 2:] - t[..., :-2]) * 0.5
    edge = t.new_full((*t.shape[:-1], 1), pad_value)
    return torch.cat((edge, interior, edge), dim=-1)
16
+
17
+
18
def central_logical_and(t: Tensor, pad_value: bool) -> Tensor:
    """Central AND along the last dimension.

    out[..., i] = t[..., i-1] & t[..., i+1] for interior indices; both ends
    are padded with `pad_value`. The result is always a bool tensor on the
    same device as `t`.
    """
    edge = torch.full((*t.shape[:-1], 1), pad_value, dtype=torch.bool, device=t.device)
    interior = torch.logical_and(t[..., :-2], t[..., 2:])
    return torch.cat((edge, interior, edge), dim=-1)
23
+
24
+
25
def compute_displacement_error(x, y, z, ref_x, ref_y, ref_z) -> Tensor:
    """Euclidean (L2) distance between points (x, y, z) and (ref_x, ref_y, ref_z).

    All six arguments are broadcast-compatible tensors; the norm is taken over
    the stacked xyz components, so the output matches the input shape.
    """
    delta = torch.stack([x - ref_x, y - ref_y, z - ref_z], dim=-1)
    return torch.norm(delta, p=2, dim=-1)
30
+
31
+
32
def compute_kinematic_features(
    x: Tensor,
    y: Tensor,
    z: Tensor,
    heading: Tensor,
    seconds_per_step: float
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Derive speed/acceleration features from trajectories via central differences.

    Returns (linear_speed, linear_accel, angular_speed, angular_accel), each
    shaped like `heading`; positions without a valid central stencil are NaN
    (propagated from the NaN padding of `central_diff`).
    """
    positions = torch.stack([x, y, z], dim=0)
    step_displacement = central_diff(positions, pad_value=np.nan)
    linear_speed = torch.norm(step_displacement, p=2, dim=0) / seconds_per_step
    linear_accel = central_diff(linear_speed, pad_value=np.nan) / seconds_per_step
    # Doubling before wrapping and halving afterwards keeps per-step heading
    # deltas continuous across the +/-pi discontinuity.
    dheading_step = _wrap_angle(central_diff(heading, pad_value=np.nan) * 2) / 2
    angular_speed = dheading_step / seconds_per_step
    d2heading_step = _wrap_angle(central_diff(dheading_step, pad_value=np.nan) * 2) / 2
    angular_accel = d2heading_step / (seconds_per_step ** 2)
    return linear_speed, linear_accel, angular_speed, angular_accel
47
+
48
+
49
def compute_kinematic_validity(valid: Tensor) -> Tuple[Tensor, Tensor]:
    """Validity masks matching the central-difference kinematic features.

    A first-order (speed) sample is valid only when both stencil neighbours
    are valid; a second-order (acceleration) sample additionally requires its
    own neighbours' speed validity.
    """
    first_order = central_logical_and(valid, pad_value=False)
    second_order = central_logical_and(first_order, pad_value=False)
    return first_order, second_order
backups/dev/metrics/val_close_long_metrics.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "val_close_long/wosac/realism_meta_metric": 0.6187323331832886,
3
+ "val_close_long/wosac/kinematic_metrics": 0.6323384046554565,
4
+ "val_close_long/wosac/interactive_metrics": 0.5528579354286194,
5
+ "val_close_long/wosac/map_based_metrics": 0.0,
6
+ "val_close_long/wosac/placement_based_metrics": 0.6086956858634949,
7
+ "val_close_long/wosac/min_ade": 0.0,
8
+ "val_close_long/wosac/scenario_counter": 61,
9
+ "val_close_long/wosac_likelihood/metametric": 0.6187323331832886,
10
+ "val_close_long/wosac_likelihood/average_displacement_error": 0.0,
11
+ "val_close_long/wosac_likelihood/min_average_displacement_error": 0.0,
12
+ "val_close_long/wosac_likelihood/linear_speed_likelihood": 0.11858943104743958,
13
+ "val_close_long/wosac_likelihood/linear_acceleration_likelihood": 0.6093839406967163,
14
+ "val_close_long/wosac_likelihood/angular_speed_likelihood": 0.8988037705421448,
15
+ "val_close_long/wosac_likelihood/angular_acceleration_likelihood": 0.9025763869285583,
16
+ "val_close_long/wosac_likelihood/distance_to_nearest_object_likelihood": 0.10390616208314896,
17
+ "val_close_long/wosac_likelihood/collision_indication_likelihood": 0.6108496785163879,
18
+ "val_close_long/wosac_likelihood/time_to_collision_likelihood": 0.8568302989006042,
19
+ "val_close_long/wosac_likelihood/simulated_collision_rate": 0.030917881056666374,
20
+ "val_close_long/wosac_likelihood/num_placement_likelihood": 0.7245867848396301,
21
+ "val_close_long/wosac_likelihood/num_removement_likelihood": 0.6228984594345093,
22
+ "val_close_long/wosac_likelihood/distance_placement_likelihood": 1.0,
23
+ "val_close_long/wosac_likelihood/distance_removement_likelihood": 0.08729743212461472
24
+ }
backups/dev/model/smart.py ADDED
@@ -0,0 +1,1100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import contextlib
3
+ import pytorch_lightning as pl
4
+ import math
5
+ import numpy as np
6
+ import pickle
7
+ import random
8
+ import torch
9
+ import torch.nn as nn
10
+ from tqdm import tqdm
11
+ from torch_geometric.data import Batch
12
+ from torch_geometric.data import HeteroData
13
+ from torch.optim.lr_scheduler import LambdaLR
14
+ from collections import defaultdict
15
+
16
+ from dev.utils.func import angle_between_2d_vectors
17
+ from dev.modules.layers import OccLoss
18
+ from dev.modules.attr_tokenizer import Attr_Tokenizer
19
+ from dev.modules.smart_decoder import SMARTDecoder
20
+ from dev.datasets.preprocess import TokenProcessor
21
+ from dev.metrics.compute_metrics import *
22
+ from dev.utils.metrics import *
23
+ from dev.utils.visualization import *
24
+
25
+
26
+ class SMART(pl.LightningModule):
27
+
28
    def __init__(self, model_config, save_path: os.PathLike="", logger=None, **kwargs) -> None:
        """Build the SMART Lightning module from a parsed model config.

        Args:
            model_config: namespace-like config; top-level fields plus a
                `decoder` sub-config (radii, layer counts, token/seed sizes).
            save_path: directory for validation/rollout artifacts.
            logger: optional external logger stored as `self.local_logger`.
            **kwargs: reads `max_epochs` (default 0) and `t` (default 2).
        """
        super(SMART, self).__init__()
        self.save_hyperparameters()
        # --- scalar hyper-parameters copied off the config --------------------
        self.model_config = model_config
        self.warmup_steps = model_config.warmup_steps
        self.lr = model_config.lr
        self.total_steps = model_config.total_steps
        self.dataset = model_config.dataset
        self.input_dim = model_config.input_dim
        self.hidden_dim = model_config.hidden_dim
        self.output_dim = model_config.output_dim
        self.output_head = model_config.output_head
        self.num_historical_steps = model_config.num_historical_steps
        self.num_future_steps = model_config.decoder.num_future_steps
        self.num_freq_bands = model_config.num_freq_bands
        self.save_path = save_path
        self.vis_map = False
        self.noise = True
        self.local_logger = logger
        self.max_epochs = kwargs.get('max_epochs', 0)
        # Package root (two levels above this file) anchors the token bank path.
        module_dir = os.path.dirname(os.path.dirname(__file__))

        self.map_token_traj_path = os.path.join(module_dir, 'tokens/map_traj_token5.pkl')
        self.init_map_token()

        # --- prediction-head switches ----------------------------------------
        self.predict_motion = model_config.predict_motion
        self.predict_state = model_config.predict_state
        self.predict_map = model_config.predict_map
        self.predict_occ = model_config.predict_occ
        self.pl2seed_radius = model_config.decoder.pl2seed_radius
        self.token_size = model_config.decoder.token_size

        # if `disable_grid_token` is True, then we process all locations as
        # the continuous values. Besides, no occupancy grid input.
        # Also, no need to predict the xy offset.
        self.disable_grid_token = getattr(model_config, 'disable_grid_token') \
            if hasattr(model_config, 'disable_grid_token') else False
        self.use_grid_token = not self.disable_grid_token
        if self.disable_grid_token:
            # Occupancy prediction depends on the grid representation.
            self.predict_occ = False

        # Tokenizes raw scenario tensors into discrete motion/state tokens.
        self.token_processer = TokenProcessor(self.token_size,
                                              training=self.training,
                                              predict_motion=self.predict_motion,
                                              predict_state=self.predict_state,
                                              predict_map=self.predict_map,
                                              state_token=model_config.state_token,
                                              pl2seed_radius=self.pl2seed_radius)

        # Discretizes continuous position/heading attributes onto a polar grid.
        self.attr_tokenizer = Attr_Tokenizer(grid_range=self.model_config.grid_range,
                                             grid_interval=self.model_config.grid_interval,
                                             radius=model_config.decoder.pl2seed_radius,
                                             angle_interval=self.model_config.angle_interval)

        # state tokens
        self.invalid_state = int(self.model_config.state_token['invalid'])
        self.valid_state = int(self.model_config.state_token['valid'])
        self.enter_state = int(self.model_config.state_token['enter'])
        self.exit_state = int(self.model_config.state_token['exit'])

        self.seed_size = int(model_config.decoder.seed_size)

        # NOTE: named `encoder` but is the full SMARTDecoder stack.
        self.encoder = SMARTDecoder(
            decoder_type=model_config.decoder_type,
            dataset=model_config.dataset,
            input_dim=model_config.input_dim,
            hidden_dim=model_config.hidden_dim,
            num_historical_steps=model_config.num_historical_steps,
            num_freq_bands=model_config.num_freq_bands,
            num_heads=model_config.num_heads,
            head_dim=model_config.head_dim,
            dropout=model_config.dropout,
            num_map_layers=model_config.decoder.num_map_layers,
            num_agent_layers=model_config.decoder.num_agent_layers,
            pl2pl_radius=model_config.decoder.pl2pl_radius,
            pl2a_radius=model_config.decoder.pl2a_radius,
            pl2seed_radius=model_config.decoder.pl2seed_radius,
            a2a_radius=model_config.decoder.a2a_radius,
            a2sa_radius=model_config.decoder.a2sa_radius,
            pl2sa_radius=model_config.decoder.pl2sa_radius,
            time_span=model_config.decoder.time_span,
            map_token={'traj_src': self.map_token['traj_src']},
            token_size=self.token_size,
            attr_tokenizer=self.attr_tokenizer,
            predict_motion=self.predict_motion,
            predict_state=self.predict_state,
            predict_map=self.predict_map,
            predict_occ=self.predict_occ,
            state_token=model_config.state_token,
            use_grid_token=self.use_grid_token,
            seed_size=self.seed_size,
            buffer_size=model_config.decoder.buffer_size,
            num_recurrent_steps_val=model_config.num_recurrent_steps_val,
            loss_weight=model_config.loss_weight,
            logger=logger,
        )
        # --- metrics ----------------------------------------------------------
        self.minADE = minADE(max_guesses=1)
        self.minFDE = minFDE(max_guesses=1)
        self.TokenCls = TokenCls(max_guesses=1)
        self.StateCls = TokenCls(max_guesses=1)
        self.StateAccuracy = StateAccuracy(state_token=self.model_config.state_token)
        self.GridOverlapRate = GridOverlapRate(num_step=18,
                                               state_token=self.model_config.state_token,
                                               seed_size=self.encoder.agent_encoder.num_seed_feature)
        # self.NumInsertAccuracy = NumInsertAccuracy()
        self.loss_weight = model_config.loss_weight

        # --- losses (heads are built lazily per prediction switch) -----------
        self.token_cls_loss = nn.CrossEntropyLoss(label_smoothing=0.1)
        if self.predict_map:
            self.map_token_loss = nn.CrossEntropyLoss(label_smoothing=0.1)
        if self.predict_state:
            # Class-weighted CE: state labels are heavily imbalanced.
            self.state_cls_loss = nn.CrossEntropyLoss(
                torch.tensor(self.loss_weight['state_weight']))
            self.state_cls_loss_seed = nn.CrossEntropyLoss(
                torch.tensor(self.loss_weight['seed_state_weight']))
            self.type_cls_loss_seed = nn.CrossEntropyLoss(
                torch.tensor(self.loss_weight['seed_type_weight']))
            self.pos_cls_loss_seed = nn.CrossEntropyLoss(label_smoothing=0.1)
            self.head_cls_loss_seed = nn.CrossEntropyLoss()
            self.offset_reg_loss_seed = nn.MSELoss()
            self.shape_reg_loss_seed = nn.MSELoss()
            self.pos_reg_loss_seed = nn.MSELoss()
        if self.predict_occ:
            self.occ_cls_loss = nn.CrossEntropyLoss()
            # pos_weight compensates for sparse positive occupancy cells.
            self.agent_occ_loss_seed = nn.BCEWithLogitsLoss(
                pos_weight=torch.tensor([self.loss_weight['agent_occ_pos_weight']]))
            self.pt_occ_loss_seed = nn.BCEWithLogitsLoss(
                pos_weight=torch.tensor([self.loss_weight['pt_occ_pos_weight']]))
            # self.agent_occ_loss_seed = OccLoss()
            # self.pt_occ_loss_seed = OccLoss()
            # self.agent_occ_loss_seed = nn.BCEWithLogitsLoss()
            # self.pt_occ_loss_seed = nn.BCEWithLogitsLoss()
        self.rollout_num = 1

        # --- validation configuration ----------------------------------------
        self.val_open_loop = model_config.val_open_loop
        self.val_close_loop = model_config.val_close_loop
        self.val_insert = model_config.val_insert or bool(os.getenv('VAL_INSERT'))
        self.n_rollout_close_val = model_config.n_rollout_close_val
        self.t = kwargs.get('t', 2)

        # for validation / test (see `set()` which toggles these per run mode)
        self._mode = 'training'
        self._long_metrics = None
        self._online_metric = False
        # NOTE(review): attribute name is misspelled ('reuslts') but is used
        # consistently across this file; renaming must be done file-wide.
        self._save_validate_reuslts = False
        self._plot_rollouts = False
174
+
175
+ def set(self, mode: str = 'train'):
176
+ self._mode = mode
177
+
178
+ if mode == 'validation':
179
+ self._online_metric = True
180
+ self._save_validate_reuslts = True
181
+ self._long_metrics = LongMetric('val_close_long')
182
+
183
+ elif mode == 'test':
184
+ self._save_validate_reuslts = True
185
+
186
+ elif mode == 'plot_rollouts':
187
+ self._plot_rollouts = True
188
+
189
+ def init_map_token(self):
190
+ self.argmin_sample_len = 3
191
+ map_token_traj = pickle.load(open(self.map_token_traj_path, 'rb'))
192
+ self.map_token = {'traj_src': map_token_traj['traj_src'], }
193
+ traj_end_theta = np.arctan2(self.map_token['traj_src'][:, -1, 1]-self.map_token['traj_src'][:, -2, 1],
194
+ self.map_token['traj_src'][:, -1, 0]-self.map_token['traj_src'][:, -2, 0])
195
+ indices = torch.linspace(0, self.map_token['traj_src'].shape[1]-1, steps=self.argmin_sample_len).long()
196
+ self.map_token['sample_pt'] = torch.from_numpy(self.map_token['traj_src'][:, indices]).to(torch.float)
197
+ self.map_token['traj_end_theta'] = torch.from_numpy(traj_end_theta).to(torch.float)
198
+ self.map_token['traj_src'] = torch.from_numpy(self.map_token['traj_src']).to(torch.float)
199
+
200
+ def get_agent_inputs(self, data: HeteroData):
201
+ res = self.encoder.get_agent_inputs(data)
202
+ return res
203
+
204
+ def forward(self, data: HeteroData):
205
+ res = self.encoder(data)
206
+ return res
207
+
208
+ def maybe_autocast(self, dtype=torch.float16):
209
+ enable_autocast = self.device != torch.device("cpu")
210
+
211
+ if enable_autocast:
212
+ return torch.cuda.amp.autocast(dtype=dtype)
213
+ else:
214
+ return contextlib.nullcontext()
215
+
216
def check_inputs(self, data: 'HeteroData'):
    """Sanity-check tokenized agent inputs via diagnostic metrics.

    Feeds ground-truth state indices / validity masks into the
    ``StateAccuracy`` tracker and the state/grid tokens into the
    ``GridOverlapRate`` tracker, then prints both summaries.
    """
    inputs = self.get_agent_inputs(data)

    # Clone targets and masks so metric updates cannot alias cached inputs.
    gt_token_idx = inputs['next_token_idx_gt']
    gt_state_idx = inputs['next_state_idx_gt'].clone()
    token_eval_mask = inputs['next_token_eval_mask'].clone()
    agent_valid_mask = inputs['raw_agent_valid_mask'].clone()

    self.StateAccuracy.update(state_idx=gt_state_idx,
                              valid_mask=agent_valid_mask)

    self.GridOverlapRate.update(state_token=inputs['state_token'],
                                grid_index=inputs['grid_index'])

    print(self.StateAccuracy)
    print(self.GridOverlapRate)
235
+
236
def training_step(self,
                  data,
                  batch_idx):
    """One training iteration: tokenize the scene, run the model, sum losses.

    Pipeline: tokenize agents/map -> forward pass -> accumulate the loss
    terms that are enabled by the ``predict_motion`` / ``predict_state`` /
    ``predict_occ`` / ``predict_map`` flags, logging each term. Rarely
    (p ~ 4e-5, or always under DEBUG) also dumps diagnostic plots.

    Args:
        data: a (possibly batched) heterogeneous scenario graph.
        batch_idx: Lightning batch index (unused here).

    Returns:
        Scalar total loss.
    """

    # --- data preparation: tokenization and map matching ---
    data = self.token_processer(data)

    data = self.match_token_map(data)
    data = self.sample_pt_pred(data)

    # find map tokens for entering agents
    data = self._fetch_enterings(data)

    # Per-scenario element counts derived from the CSR-style ptr arrays.
    data['batch_size_a'] = data['agent']['ptr'][1:] - data['agent']['ptr'][:-1]
    data['batch_size_pl'] = data['pt_token']['ptr'][1:] - data['pt_token']['ptr'][:-1]
    if isinstance(data, Batch):
        # Shift per-scenario ego indices into global (batched) indexing.
        data['agent']['av_index'] += data['agent']['ptr'][:-1]

    # Debug hook: only run the input sanity checks, skip optimization.
    if int(os.getenv("CHECK_INPUTS", 0)):
        return self.check_inputs(data)

    pred = self(data)

    loss = 0

    log_params = dict(prog_bar=True, on_step=True, on_epoch=True, batch_size=1, sync_dist=True)

    # --- standalone occupancy-decoder mode: only occ losses, early return ---
    if pred.get('occ_decoder', False):

        agent_occ = pred['agent_occ']
        agent_occ_gt = pred['agent_occ_gt']
        agent_occ_eval_mask = pred['agent_occ_eval_mask']
        pt_occ = pred['pt_occ']
        pt_occ_gt = pred['pt_occ_gt']
        pt_occ_eval_mask = pred['pt_occ_eval_mask']

        agent_occ_cls_loss = self.occ_cls_loss(agent_occ[agent_occ_eval_mask],
                                               agent_occ_gt[agent_occ_eval_mask])
        pt_occ_cls_loss = self.occ_cls_loss(pt_occ[pt_occ_eval_mask],
                                            pt_occ_gt[pt_occ_eval_mask])
        self.log('agent_occ_cls_loss', agent_occ_cls_loss, **log_params)
        self.log('pt_occ_cls_loss', pt_occ_cls_loss, **log_params)
        loss = loss + agent_occ_cls_loss + pt_occ_cls_loss

        # plot
        # plot_scenario_ids = ['74ad7b76d5906d39', '1351ea8b8333ddcb', '1352066cc3c0508d', '135436833ce5b9e7', '13570a32432d449', '13577c32a81336fb']
        if random.random() < 4e-5 or os.getenv('DEBUG'):
            num_step = pred['num_step']
            num_agent = pred['num_agent']
            num_pt = pred['num_pt']
            with torch.no_grad():
                # Reshape flat (step*agent, ...) logits to (agent, step, ...).
                agent_occ = agent_occ.reshape(num_step, num_agent, -1).transpose(0, 1)
                agent_occ_gt = agent_occ_gt.reshape(num_step, num_agent).transpose(0, 1)
                # -1 marks "no target"; map it to the center cell so one_hot is valid.
                agent_occ_gt[agent_occ_gt == -1] = self.encoder.agent_encoder.grid_size // 2
                agent_occ_gt = torch.nn.functional.one_hot(agent_occ_gt, num_classes=self.encoder.agent_encoder.grid_size)
                agent_occ = self.attr_tokenizer.pad_square(agent_occ.softmax(-1).detach().cpu().numpy())[0]
                agent_occ_gt = self.attr_tokenizer.pad_square(agent_occ_gt.detach().cpu().numpy())[0]
                plot_occ_grid(pred['scenario_id'][0],
                              agent_occ,
                              gt_occ=agent_occ_gt,
                              mode='agent',
                              save_path=self.save_path,
                              prefix=f'training_{self.global_step:06d}_')
                pt_occ = pt_occ.reshape(num_step, num_pt, -1).transpose(0, 1)
                pt_occ_gt = pt_occ_gt.reshape(num_step, num_pt).transpose(0, 1)
                pt_occ_gt[pt_occ_gt == -1] = self.encoder.agent_encoder.grid_size // 2
                pt_occ_gt = torch.nn.functional.one_hot(pt_occ_gt, num_classes=self.encoder.agent_encoder.grid_size)
                pt_occ = self.attr_tokenizer.pad_square(pt_occ.sigmoid().detach().cpu().numpy())[0]
                pt_occ_gt = self.attr_tokenizer.pad_square(pt_occ_gt.detach().cpu().numpy())[0]
                plot_occ_grid(pred['scenario_id'][0],
                              pt_occ,
                              gt_occ=pt_occ_gt,
                              mode='pt',
                              save_path=self.save_path,
                              prefix=f'training_{self.global_step:06d}_')

        return loss

    train_mask = data['agent']['train_mask']
    # remove_ina_mask = data['agent']['remove_ina_mask']

    # motion token loss
    if self.predict_motion:

        next_token_idx = pred['next_token_idx']
        next_token_prob = pred['next_token_prob'] # (a, t, token_size)
        next_token_idx_gt = pred['next_token_idx_gt'] # (a, t)
        next_token_eval_mask = pred['next_token_eval_mask'] # (a, t)
        # Restrict evaluation to agents flagged for training.
        next_token_eval_mask &= train_mask[:, None]

        token_cls_loss = self.token_cls_loss(next_token_prob[next_token_eval_mask],
                                             next_token_idx_gt[next_token_eval_mask]) * self.loss_weight['token_cls_loss']
        self.log('token_cls_loss', token_cls_loss, **log_params)

        loss = loss + token_cls_loss

        # record motion predict precision of certain timesteps of centain type of agents
        # (per-step loss s0..s9 for the 10 steps following each agent's "enter" event)
        with torch.no_grad():
            agent_state_idx_gt = data['agent']['state_idx']
            index = torch.nonzero(agent_state_idx_gt == self.enter_state)
            for i in range(10):
                # Advance one timestep past the enter event and drop
                # indices that fall off the end of the sequence.
                index[:, 1] += 1
                index = index[index[:, 1] < agent_state_idx_gt.shape[1] - 1]
                prob = next_token_prob[index[:, 0], index[:, 1]]
                gt = next_token_idx_gt[index[:, 0], index[:, 1]]
                mask = next_token_eval_mask[index[:, 0], index[:, 1]]
                step_token_cls_loss = self.token_cls_loss(prob[mask], gt[mask])
                self.log(f's{i}', step_token_cls_loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=1, sync_dist=True)

    # state token loss
    if self.predict_state:

        next_state_idx = pred['next_state_idx']
        next_state_prob = pred['next_state_prob']
        next_state_idx_gt = pred['next_state_idx_gt']
        next_state_eval_mask = pred['next_state_eval_mask'] # (num_agent, num_timestep)

        state_cls_loss = self.state_cls_loss(next_state_prob[next_state_eval_mask],
                                             next_state_idx_gt[next_state_eval_mask]) * self.loss_weight['state_cls_loss']
        if torch.isnan(state_cls_loss):
            # Diagnostic dump only; the NaN still propagates into `loss`.
            print("Found NaN in state_cls_loss!!!")
            print(next_state_prob.shape)
            print(next_state_idx_gt.shape)
            print(next_state_eval_mask.shape)
            print(next_state_idx_gt[next_state_eval_mask].shape)
        self.log('state_cls_loss', state_cls_loss, **log_params)

        loss = loss + state_cls_loss

        # --- seed (newly inserted) agent heads: state / type / shape ---
        next_state_idx_seed = pred['next_state_idx_seed']
        next_state_prob_seed = pred['next_state_prob_seed']
        next_state_idx_gt_seed = pred['next_state_idx_gt_seed']
        next_type_prob_seed = pred['next_type_prob_seed']
        next_type_idx_gt_seed = pred['next_type_idx_gt_seed']
        next_shape_seed = pred['next_shape_seed']
        next_shape_gt_seed = pred['next_shape_gt_seed']
        next_state_eval_mask_seed = pred['next_state_eval_mask_seed']
        next_attr_eval_mask_seed = pred['next_attr_eval_mask_seed']

        # when num_seed_gt=0 loss term will be NaN
        state_cls_loss_seed = self.state_cls_loss_seed(next_state_prob_seed[next_state_eval_mask_seed],
                                                       next_state_idx_gt_seed[next_state_eval_mask_seed]) * self.loss_weight['state_cls_loss']
        # nan_to_num guards the empty-target case mentioned above.
        state_cls_loss_seed = torch.nan_to_num(state_cls_loss_seed)
        self.log('seed_state_cls_loss', state_cls_loss_seed, **log_params)

        type_cls_loss_seed = self.type_cls_loss_seed(next_type_prob_seed[next_attr_eval_mask_seed],
                                                     next_type_idx_gt_seed[next_attr_eval_mask_seed]) * self.loss_weight['type_cls_loss']
        shape_reg_loss_seed = self.shape_reg_loss_seed(next_shape_seed[next_attr_eval_mask_seed],
                                                       next_shape_gt_seed[next_attr_eval_mask_seed]) * self.loss_weight['shape_reg_loss']
        type_cls_loss_seed = torch.nan_to_num(type_cls_loss_seed)
        shape_reg_loss_seed = torch.nan_to_num(shape_reg_loss_seed)
        self.log('seed_type_cls_loss', type_cls_loss_seed, **log_params)
        self.log('seed_shape_reg_loss', shape_reg_loss_seed, **log_params)

        loss = loss + state_cls_loss_seed + type_cls_loss_seed + shape_reg_loss_seed

        # --- seed agent placement: position (grid or continuous) + heading ---
        next_head_rel_prob_seed = pred['next_head_rel_prob_seed']
        next_head_rel_index_gt_seed = pred['next_head_rel_index_gt_seed']
        next_offset_xy_seed = pred['next_offset_xy_seed']
        next_offset_xy_gt_seed = pred['next_offset_xy_gt_seed']
        next_head_eval_mask_seed = pred['next_head_eval_mask_seed']

        if self.use_grid_token:
            # Discrete placement: classify the grid cell + regress the
            # continuous offset within it.
            next_pos_rel_prob_seed = pred['next_pos_rel_prob_seed']
            next_pos_rel_index_gt_seed = pred['next_pos_rel_index_gt_seed']

            pos_cls_loss_seed = self.pos_cls_loss_seed(next_pos_rel_prob_seed[next_attr_eval_mask_seed],
                                                       next_pos_rel_index_gt_seed[next_attr_eval_mask_seed]) * self.loss_weight['pos_cls_loss']
            offset_reg_loss_seed = self.offset_reg_loss_seed(next_offset_xy_seed[next_head_eval_mask_seed],
                                                             next_offset_xy_gt_seed[next_head_eval_mask_seed]) * self.loss_weight['offset_reg_loss']
            pos_cls_loss_seed = torch.nan_to_num(pos_cls_loss_seed)
            self.log('seed_pos_cls_loss', pos_cls_loss_seed, **log_params)
            self.log('seed_offset_reg_loss', offset_reg_loss_seed, **log_params)

            loss = loss + pos_cls_loss_seed + offset_reg_loss_seed

        else:
            # Continuous placement: directly regress the relative xy.
            next_pos_rel_xy_seed = pred['next_pos_rel_xy_seed']
            next_pos_rel_xy_gt_seed = pred['next_pos_rel_xy_gt_seed']
            pos_reg_loss_seed = self.pos_reg_loss_seed(next_pos_rel_xy_seed[next_attr_eval_mask_seed],
                                                       next_pos_rel_xy_gt_seed[next_attr_eval_mask_seed]) * self.loss_weight['pos_reg_loss']
            pos_reg_loss_seed = torch.nan_to_num(pos_reg_loss_seed)
            self.log('seed_pos_reg_loss', pos_reg_loss_seed, **log_params)
            loss = loss + pos_reg_loss_seed

        head_cls_loss_seed = self.head_cls_loss_seed(next_head_rel_prob_seed[next_head_eval_mask_seed],
                                                     next_head_rel_index_gt_seed[next_head_eval_mask_seed]) * self.loss_weight['head_cls_loss']
        self.log('seed_head_cls_loss', head_cls_loss_seed, **log_params)

        loss = loss + head_cls_loss_seed

        # plot_scenario_ids = ['74ad7b76d5906d39', '1351ea8b8333ddcb', '1352066cc3c0508d', '135436833ce5b9e7', '13570a32432d449', '13577c32a81336fb']
        if random.random() < 4e-5 or int(os.getenv('DEBUG', 0)):
            with torch.no_grad():
                # plot probability of inserting new agent (agent-timestep)
                raw_next_state_prob_seed = pred['raw_next_state_prob_seed']
                plot_prob_seed(pred['scenario_id'][0],
                               torch.softmax(raw_next_state_prob_seed, dim=-1
                                             )[..., -1].detach().cpu().numpy(),
                               self.save_path,
                               prefix=f'training_{self.global_step:06d}_',
                               indices=pred['target_indices'].cpu().numpy())

                # plot heatmap of inserting new agent
                if self.use_grid_token:
                    next_pos_rel_prob_seed = pred['next_pos_rel_prob_seed']
                    if next_pos_rel_prob_seed.shape[0] > 0:
                        next_pos_rel_prob_seed = torch.softmax(next_pos_rel_prob_seed, dim=-1).detach().cpu().numpy()
                        indices = next_pos_rel_index_gt_seed.detach().cpu().numpy()
                        mask = next_attr_eval_mask_seed.detach().cpu().numpy().astype(np.bool_)
                        # Mark entries excluded from evaluation with -1.
                        indices[~mask] = -1
                        prob, indices = self.attr_tokenizer.pad_square(next_pos_rel_prob_seed, indices)
                        plot_insert_grid(pred['scenario_id'][0],
                                         prob,
                                         indices=indices,
                                         save_path=self.save_path,
                                         prefix=f'training_{self.global_step:06d}_')

    # occupancy-grid auxiliary losses
    if self.predict_occ:

        neighbor_agent_grid_idx = pred['neighbor_agent_grid_idx']
        neighbor_agent_grid_index_gt = pred['neighbor_agent_grid_index_gt']
        neighbor_agent_grid_index_eval_mask = pred['neighbor_agent_grid_index_eval_mask']
        neighbor_pt_grid_idx = pred['neighbor_pt_grid_idx']
        neighbor_pt_grid_index_gt = pred['neighbor_pt_grid_index_gt']
        neighbor_pt_grid_index_eval_mask = pred['neighbor_pt_grid_index_eval_mask']

        # NOTE(review): these two losses are computed but NOT added to the
        # total (the additions below are commented out) — confirm intent.
        neighbor_agent_grid_cls_loss = self.occ_cls_loss(neighbor_agent_grid_idx[neighbor_agent_grid_index_eval_mask],
                                                         neighbor_agent_grid_index_gt[neighbor_agent_grid_index_eval_mask])
        neighbor_pt_grid_cls_loss = self.occ_cls_loss(neighbor_pt_grid_idx[neighbor_pt_grid_index_eval_mask],
                                                      neighbor_pt_grid_index_gt[neighbor_pt_grid_index_eval_mask])
        # self.log('neighbor_agent_grid_cls_loss', neighbor_agent_grid_cls_loss, **log_params)
        # self.log('neighbor_pt_grid_cls_loss', neighbor_pt_grid_cls_loss, **log_params)
        # loss = loss + neighbor_agent_grid_cls_loss + neighbor_pt_grid_cls_loss

        grid_agent_occ_seed = pred['grid_agent_occ_seed']
        grid_pt_occ_seed = pred['grid_pt_occ_seed']
        grid_agent_occ_gt_seed = pred['grid_agent_occ_gt_seed'].float()
        grid_pt_occ_gt_seed = pred['grid_pt_occ_gt_seed'].float()
        grid_agent_occ_eval_mask_seed = pred['grid_agent_occ_eval_mask_seed']
        grid_pt_occ_eval_mask_seed = pred['grid_pt_occ_eval_mask_seed']

        # plot_scenario_ids = ['74ad7b76d5906d39', '1351ea8b8333ddcb', '1352066cc3c0508d', '135436833ce5b9e7', '13570a32432d449', '13577c32a81336fb']
        if random.random() < 4e-5 or os.getenv('DEBUG'):
            with torch.no_grad():
                agent_occ = self.attr_tokenizer.pad_square(grid_agent_occ_seed.sigmoid().detach().cpu().numpy())[0]
                agent_occ_gt = self.attr_tokenizer.pad_square(grid_agent_occ_gt_seed.detach().cpu().numpy())[0]
                plot_occ_grid(pred['scenario_id'][0],
                              agent_occ,
                              gt_occ=agent_occ_gt,
                              mode='agent',
                              save_path=self.save_path,
                              prefix=f'training_{self.global_step:06d}_')
                pt_occ = self.attr_tokenizer.pad_square(grid_pt_occ_seed.sigmoid().detach().cpu().numpy())[0]
                pt_occ_gt = self.attr_tokenizer.pad_square(grid_pt_occ_gt_seed.detach().cpu().numpy())[0]
                plot_occ_grid(pred['scenario_id'][0],
                              pt_occ,
                              gt_occ=pt_occ_gt,
                              mode='pt',
                              save_path=self.save_path,
                              prefix=f'training_{self.global_step:06d}_')

        # -1 marks "no label"; treat as unoccupied, then validate range.
        grid_agent_occ_gt_seed[grid_agent_occ_gt_seed == -1] = 0
        if grid_agent_occ_gt_seed.min() < 0 or grid_agent_occ_gt_seed.max() > 1 or \
            grid_pt_occ_gt_seed.min() < 0 or grid_pt_occ_gt_seed.max() > 1:
            raise RuntimeError("Occurred invalid values in occ gt")

        agent_occ_loss = self.agent_occ_loss_seed(grid_agent_occ_seed[grid_agent_occ_eval_mask_seed],
                                                  grid_agent_occ_gt_seed[grid_agent_occ_eval_mask_seed]) * self.loss_weight['agent_occ_loss']
        pt_occ_loss = self.pt_occ_loss_seed(grid_pt_occ_seed[grid_pt_occ_eval_mask_seed],
                                            grid_pt_occ_gt_seed[grid_pt_occ_eval_mask_seed]) * self.loss_weight['pt_occ_loss']

        self.log('agent_occ_loss', agent_occ_loss, **log_params)
        self.log('pt_occ_loss', pt_occ_loss, **log_params)
        loss = loss + agent_occ_loss + pt_occ_loss

    # Verbose per-agent console dump, gated by the LOG_TRAIN env var.
    # NOTE(review): this block reads locals (next_token_idx, next_state_idx_seed,
    # ...) that are only bound inside the corresponding predict_* branches;
    # with LOG_TRAIN set and one flag off it can raise NameError — confirm.
    if os.getenv('LOG_TRAIN', False) and (self.predict_motion or self.predict_state):
        for a in range(next_token_idx.shape[0]):
            print(f"agent: {a}")
            if self.predict_motion:
                print(f"pred motion: {next_token_idx[a, :, 0].tolist()}, \ngt motion: {next_token_idx_gt[a, :].tolist()}")
                print(f"train mask: {next_token_eval_mask[a].long().tolist()}")
            if self.predict_state:
                print(f"pred state: {next_state_idx[a, :, 0].tolist()}, \ngt state: {next_state_idx_gt[a, :].tolist()}")
                print(f"train mask: {next_state_eval_mask[a].long().tolist()}")
        num_sa = next_state_idx_seed[..., 0].sum(dim=-1).bool().sum()
        for sa in range(num_sa):
            print(f"seed agent: {sa}")
            print(f"seed pred state: {next_state_idx_seed[sa, :, 0].tolist()}, \ngt seed state: {next_state_idx_gt_seed[sa, :].tolist()}")
            # if sa < next_pos_rel_seed.shape[0]:
            #     print(f"pred pos: {next_pos_rel_seed[sa, :, 0].tolist()}, \ngt pos: {next_pos_rel_gt_seed[sa, :, 0].tolist()}")
            #     print(f"pred head: {next_head_rel_seed[sa].tolist()}, \ngt head: {next_head_rel_gt_seed[sa].tolist()}")
            # print(f"seed train mask: {next_state_eval_mask_seed[sa].long().tolist()}")

    # map token loss
    if self.predict_map:

        map_next_token_prob = pred['map_next_token_prob']
        map_next_token_idx_gt = pred['map_next_token_idx_gt']
        map_next_token_eval_mask = pred['map_next_token_eval_mask']

        map_token_loss = self.map_token_loss(map_next_token_prob[map_next_token_eval_mask], map_next_token_idx_gt[map_next_token_eval_mask]) * self.loss_weight['map_token_loss']
        self.log('map_token_loss', map_token_loss, prog_bar=True, on_step=True, on_epoch=True, batch_size=1)
        loss = loss + map_token_loss

    # GPU memory telemetry (GiB).
    # NOTE(review): hard-codes 'cuda:0' — under multi-GPU every rank reports
    # device 0's memory; consider self.device instead.
    allocated = torch.cuda.memory_allocated(device='cuda:0') / (1024 ** 3)
    reserved = torch.cuda.memory_reserved(device='cuda:0') / (1024 ** 3)
    self.log('allocated', allocated, **log_params)
    self.log('reserved', reserved, **log_params)

    return loss
546
+
547
def validation_step(self,
                    data,
                    batch_idx):
    """One validation iteration: optional open-loop losses + closed-loop rollouts.

    Behavior depends on the configured mode and several env-var switches:
    during training it subsamples validation heavily; in dedicated
    validation/test modes it rolls out the policy ``n_rollout_close_val``
    times, optionally plots the rollouts, saves them to disk, and/or feeds
    them to the long-horizon metrics.
    """

    self.debug = int(os.getenv('DEBUG', 0))

    # ! validation in training process
    # Skip most batches unless we are at a milestone epoch or DEBUG is set
    # (keeps in-training validation cheap).
    if (
        self._mode == 'training' and (
            self.current_epoch not in [5, 10, 20, 25, self.max_epochs] or random.random() > 5e-4) and
        not self.debug
    ):
        self.val_open_loop = False
        self.val_close_loop = False
        return

    if int(os.getenv('NO_VAL', 0)) or int(os.getenv("CHECK_INPUTS", 0)):
        return

    # ! check if save exists
    # Resume-friendly: skip batches whose outputs were already written.
    if not self._plot_rollouts:
        rollouts_path = os.path.join(self.save_path, f'idx_{self.trainer.global_rank}_{batch_idx}_rollouts.pkl')
        if os.path.exists(rollouts_path):
            tqdm.write(f'Skipped batch {batch_idx}')
            return
    else:
        rollouts_path = os.path.join(self.save_path, f'{data["scenario_id"][0]}.gif')
        if os.path.exists(rollouts_path):
            tqdm.write(f'Skipped scenario {data["scenario_id"][0]}')
            return

    # ! data preparation
    data = self.token_processer(data)

    data = self.match_token_map(data)
    data = self.sample_pt_pred(data)

    # find map tokens for entering agents
    data = self._fetch_enterings(data)

    # Per-scenario element counts derived from the CSR-style ptr arrays.
    data['batch_size_a'] = data['agent']['ptr'][1:] - data['agent']['ptr'][:-1]
    data['batch_size_pl'] = data['pt_token']['ptr'][1:] - data['pt_token']['ptr'][:-1]
    if isinstance(data, Batch):
        # Shift per-scenario ego indices into global (batched) indexing.
        data['agent']['av_index'] += data['agent']['ptr'][:-1]

    if int(os.getenv('NEAREST_POS', 0)):
        pred = self.encoder.predict_nearest_pos(data, rank=self.local_rank)
        return

    # if self.insert_agent:
    #     pred = self.encoder.insert_agent(data)
    #     return

    # ! open-loop validation
    # Teacher-forced forward pass; accumulates unweighted token/state losses.
    if self.val_open_loop or int(os.getenv('OPEN_LOOP', 0)):

        pred = self(data)

        # pred['next_state_prob_seed'] = torch.softmax(pred['next_state_prob_seed'], dim=-1)[..., -1]
        # plot_prob_seed(pred, self.save_path, suffix=f'_training')

        loss = 0

        if self.predict_motion:

            # motion token
            next_token_idx = pred['next_token_idx']
            next_token_idx_gt = pred['next_token_idx_gt'] # (num_agent, num_step, 10)
            next_token_prob = pred['next_token_prob']
            next_token_eval_mask = pred['next_token_eval_mask']

            token_cls_loss = self.token_cls_loss(next_token_prob[next_token_eval_mask], next_token_idx_gt[next_token_eval_mask])
            loss = loss + token_cls_loss

        if self.predict_state:

            # state token
            next_state_idx = pred['next_state_idx']
            next_state_idx_gt = pred['next_state_idx_gt']
            next_state_prob = pred['next_state_prob']
            next_state_eval_mask = pred['next_state_eval_mask']

            state_cls_loss = self.state_cls_loss(next_state_prob[next_state_eval_mask], next_state_idx_gt[next_state_eval_mask])
            loss = loss + state_cls_loss

            # seed state token
            next_state_idx_seed = pred['next_state_idx_seed']
            next_state_idx_gt_seed = pred['next_state_idx_gt_seed']

        if self.predict_occ:

            grid_agent_occ_seed = pred['grid_agent_occ_seed']
            grid_pt_occ_seed = pred['grid_pt_occ_seed']
            grid_agent_occ_gt_seed = pred['grid_agent_occ_gt_seed'].float()
            grid_pt_occ_gt_seed = pred['grid_pt_occ_gt_seed'].float()

            # Dump predicted vs ground-truth occupancy grids.
            agent_occ = self.attr_tokenizer.pad_square(grid_agent_occ_seed.sigmoid().detach().cpu().numpy())[0]
            agent_occ_gt = self.attr_tokenizer.pad_square(grid_agent_occ_gt_seed.detach().cpu().numpy())[0]
            plot_occ_grid(pred['scenario_id'][0],
                          agent_occ,
                          gt_occ=agent_occ_gt,
                          mode='agent',
                          save_path=self.save_path,
                          prefix=f'eval_')
            pt_occ = self.attr_tokenizer.pad_square(grid_pt_occ_seed.sigmoid().detach().cpu().numpy())[0]
            pt_occ_gt = self.attr_tokenizer.pad_square(grid_pt_occ_gt_seed.detach().cpu().numpy())[0]
            plot_occ_grid(pred['scenario_id'][0],
                          pt_occ,
                          gt_occ=pt_occ_gt,
                          mode='pt',
                          save_path=self.save_path,
                          prefix=f'eval_')

        self.log('val_loss', loss, prog_bar=True, on_step=True, on_epoch=True, batch_size=1, sync_dist=True)

    # Insertion-only evaluation path (ends the step early).
    if self.val_insert:

        pred = self(data)

        next_state_idx_seed = pred['next_state_idx_seed']
        next_state_idx_gt_seed = pred['next_state_idx_gt_seed']

        self.NumInsertAccuracy.update(next_state_idx_seed=next_state_idx_seed,
                                      next_state_idx_gt_seed=next_state_idx_gt_seed)

        return

    # ! close-loop validation
    # Autoregressive rollouts of the learned policy from the same scene.
    if self.val_close_loop and (self.predict_motion or self.predict_state):

        rollouts = []
        for _ in tqdm(range(self.n_rollout_close_val), leave=False, desc='Rollout ...'):
            rollout = self.encoder.inference(data.clone())
            rollouts.append(rollout)

        # Plots below use the LAST rollout only.
        av_index = int(rollout['ego_index'])
        scenario_id = rollout['scenario_id'][0]

        # motion tokens
        if self.predict_motion:

            if self._plot_rollouts: # only plot gifs for last 2 epochs for efficiency
                plot_val(data, rollout, av_index, self.save_path, pl2seed_radius=self.pl2seed_radius, attr_tokenizer=self.attr_tokenizer)

            # next_token_idx = pred['next_token_idx'][..., None]
            # next_token_idx_gt = pred['next_token_idx_gt'][:, 2:] # hard code 2=11//5
            # next_token_eval_mask = pred['next_token_eval_mask'][:, 2:]

            # gt_traj = pred['gt_traj']
            # pred_traj = pred['pred_traj']
            # pred_head = pred['pred_head']

            # self.TokenCls.update(pred=next_token_idx[next_token_eval_mask], target=next_token_idx_gt[next_token_eval_mask],
            #                      valid_mask=next_token_eval_mask[next_token_eval_mask])
            # self.log('val_token_cls_acc', self.TokenCls, prog_bar=True, on_step=True, on_epoch=True, batch_size=1, sync_dist=True)

            # remove the agents which are unseen at current step
            # eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps - 1]

            # self.minADE.update(pred=pred_traj[eval_mask], target=gt_traj[eval_mask], valid_mask=valid_mask[eval_mask])
            # self.minFDE.update(pred=pred_traj[eval_mask], target=gt_traj[eval_mask], valid_mask=valid_mask[eval_mask])
            # print('ade: ', self.minADE.compute(), 'fde: ', self.minFDE.compute())

            # self.log('val_minADE', self.minADE, prog_bar=True, on_step=True, on_epoch=True, batch_size=1)
            # self.log('val_minFDE', self.minFDE, prog_bar=True, on_step=True, on_epoch=True, batch_size=1)

        # state tokens
        if self.predict_state:

            if self.use_grid_token:
                next_pos_rel_prob_seed = rollout['next_pos_rel_prob_seed'].cpu().numpy() # (s, t, grid_size)
                prob, _ = self.attr_tokenizer.pad_square(next_pos_rel_prob_seed)

            if self._plot_rollouts:
                if self.use_grid_token:
                    plot_insert_grid(scenario_id,
                                     prob,
                                     save_path=self.save_path,
                                     prefix=f'inference_')
                plot_prob_seed(scenario_id,
                               rollout['next_state_prob_seed'].cpu().numpy(),
                               self.save_path,
                               prefix=f'inference_')

            next_state_idx = rollout['next_state_idx'][..., None]
            # next_state_idx_gt = rollout['next_state_idx_gt'][:, 2:]
            # next_state_eval_mask = rollout['next_state_eval_mask'][:, 2:]

            # self.StateCls.update(pred=next_state_idx[next_token_eval_mask], target=next_state_idx_gt[next_token_eval_mask],
            #                      valid_mask=next_token_eval_mask[next_token_eval_mask])
            # self.log('val_state_cls_acc', self.TokenCls, prog_bar=True, on_step=True, on_epoch=True, batch_size=1, sync_dist=True)

            self.StateAccuracy.update(state_idx=next_state_idx[..., 0])
            self.log('valid_accuracy', self.StateAccuracy.compute()['valid'], prog_bar=True, on_step=True, on_epoch=True, batch_size=1)
            self.log('invalid_accuracy', self.StateAccuracy.compute()['invalid'], prog_bar=True, on_step=True, on_epoch=True, batch_size=1)
            self.local_logger.info(rollout['log_message'])
            # print(rollout['log_message'])
            # print(self.StateAccuracy)

        if self.predict_occ:

            grid_agent_occ_seed = rollout['grid_agent_occ_seed']
            grid_pt_occ_seed = rollout['grid_pt_occ_seed']
            grid_agent_occ_gt_seed = rollout['grid_agent_occ_gt_seed']

            # NOTE(review): sigmoid is applied to the GT grid here, unlike the
            # training path which plots the raw GT — confirm this is intended.
            agent_occ = self.attr_tokenizer.pad_square(grid_agent_occ_seed.sigmoid().cpu().numpy())[0]
            agent_occ_gt = self.attr_tokenizer.pad_square(grid_agent_occ_gt_seed.sigmoid().cpu().numpy())[0]
            if self._plot_rollouts:
                plot_occ_grid(scenario_id,
                              agent_occ,
                              gt_occ=agent_occ_gt,
                              mode='agent',
                              save_path=self.save_path,
                              prefix=f'inference_')

        # NOTE(review): attribute name '_save_validate_reuslts' is misspelled
        # ('reuslts'); kept as-is to match where it is assigned.
        if self._online_metric or self._save_validate_reuslts:

            # ! format results
            # Stack the per-rollout outputs into (n_agent, n_rollout, ...) tensors.
            pred_valid, token_pos, token_head = [], [], []
            pred_traj, pred_head, pred_z = [], [], []
            pred_shape, pred_type, pred_state = [], [], []
            agent_id = []
            for rollout in rollouts:
                pred_valid.append(rollout['pred_valid'])
                token_pos.append(rollout['pos_a'])
                token_head.append(rollout['head_a'])
                pred_traj.append(rollout['pred_traj'])
                pred_head.append(rollout['pred_head'])
                pred_z.append(rollout['pred_z'])
                pred_shape.append(rollout['eval_shape'])
                pred_type.append(rollout['pred_type'])
                pred_state.append(rollout['next_state_idx'])
                agent_id.append(rollout['agent_id'])

            pred_valid = torch.stack(pred_valid, dim=1)
            token_pos = torch.stack(token_pos, dim=1)
            token_head = torch.stack(token_head, dim=1)
            pred_traj = torch.stack(pred_traj, dim=1) # (n_agent, n_rollout, n_step, 2)
            pred_head = torch.stack(pred_head, dim=1)
            pred_z = torch.stack(pred_z, dim=1)
            pred_shape = torch.stack(pred_shape, dim=1) # [n_agent, n_rollout, 3]
            pred_type = torch.stack(pred_type, dim=1) # [n_agent, n_rollout]
            pred_state = torch.stack(pred_state, dim=1) # [n_agent, n_rollout, n_step // shift]
            agent_id = torch.stack(agent_id, dim=1) # [n_agent, n_rollout]

            # Single-scenario assumption: all agents map to batch entry 0.
            agent_batch = torch.zeros((pred_traj.shape[0]), dtype=torch.long)
            rollouts = dict(
                _scenario_id=data['scenario_id'],
                scenario_id=get_scenario_id_int_tensor(data['scenario_id']),
                av_id=int(rollouts[0]['agent_id'][rollouts[0]['ego_index']]), # NOTE: hard code!!!
                agent_id=agent_id.cpu(),
                agent_batch=agent_batch.cpu(),
                pred_traj=pred_traj.cpu(),
                pred_z=pred_z.cpu(),
                pred_head=pred_head.cpu(),
                pred_shape=pred_shape.cpu(),
                pred_type=pred_type.cpu(),
                pred_state=pred_state.cpu(),
                pred_valid=pred_valid.cpu(),
                token_pos=token_pos.cpu(),
                token_head=token_head.cpu(),
                tfrecord_path=data['tfrecord_path'],
            )

            if self._save_validate_reuslts:
                with open(rollouts_path, 'wb') as f:
                    pickle.dump(rollouts, f)

            if self._online_metric:
                self._long_metrics.update(rollouts)
817
+
818
def on_validation_start(self):
    """Prepare fresh accumulators for the upcoming validation epoch."""
    # Per-batch metric values, grouped by metric name.
    self.batch_metric = defaultdict(list)
    # Rollouts collected across validation batches.
    self.scenario_rollouts = []
821
+
822
def on_validation_epoch_end(self):
    """Log aggregated close-loop metrics and reset all validation trackers."""
    if not self.val_close_loop:
        return

    # Long-horizon rollout metrics are aggregated across ranks but only
    # written to the experiment logger from rank 0.
    if self._long_metrics is not None:
        epoch_long_metrics = self._long_metrics.compute()
        if self.global_rank == 0:
            epoch_long_metrics['epoch'] = self.current_epoch
            self.logger.log_metrics(epoch_long_metrics)

        self._long_metrics.reset()

    for tracker in (self.minADE, self.minFDE, self.StateAccuracy):
        tracker.reset()
836
+
837
+ def configure_optimizers(self):
838
+ optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
839
+
840
+ def lr_lambda(current_step):
841
+ if current_step + 1 < self.warmup_steps:
842
+ return float(current_step + 1) / float(max(1, self.warmup_steps))
843
+ return max(
844
+ 0.0, 0.5 * (1.0 + math.cos(math.pi * (current_step - self.warmup_steps) / float(max(1, self.total_steps - self.warmup_steps))))
845
+ )
846
+
847
+ lr_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
848
+ return [optimizer], [lr_scheduler]
849
+
850
def load_state_from_file(self, filename, to_cpu=False):
    """Load model weights from a checkpoint, skipping mismatched keys.

    Filters the checkpoint's ``state_dict`` down to keys that exist in the
    current model with identical shapes, then loads non-strictly so partial
    checkpoints (e.g. from a different configuration) remain usable.

    Args:
        filename: Path to a ``torch.save``-style checkpoint containing at
            least a ``'state_dict'`` entry.
        to_cpu: If True, map all tensors onto the CPU while loading.

    Returns:
        Tuple ``(it, epoch)`` read from the checkpoint; defaults to
        ``(0.0, -1)`` when those keys are absent.

    Raises:
        FileNotFoundError: If ``filename`` does not exist.
    """
    logger = self.local_logger

    if not os.path.isfile(filename):
        # Include the offending path so the failure is actionable
        # (previously raised a bare FileNotFoundError with no message).
        raise FileNotFoundError(filename)

    logger.info('==> Loading parameters from checkpoint %s to %s' % (filename, 'CPU' if to_cpu else 'GPU'))
    loc_type = torch.device('cpu') if to_cpu else None
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(filename, map_location=loc_type)

    version = checkpoint.get("version", None)
    if version is not None:
        logger.info('==> Checkpoint trained from version: %s' % version)

    model_state_disk = checkpoint['state_dict']
    logger.info(f'The number of disk ckpt keys: {len(model_state_disk)}')

    model_state = self.state_dict()
    # Keep only keys present in the live model with matching shapes.
    model_state_disk_filter = {}
    for key, val in model_state_disk.items():
        if key in model_state and val.shape == model_state[key].shape:
            model_state_disk_filter[key] = val
        elif key not in model_state:
            print(f'Ignore key in disk (not found in model): {key}, shape={val.shape}')
        else:
            print(f'Ignore key in disk (shape does not match): {key}, load_shape={val.shape}, model_shape={model_state[key].shape}')

    model_state_disk = model_state_disk_filter
    missing_keys, unexpected_keys = self.load_state_dict(model_state_disk, strict=False)

    logger.info(f'Missing keys: {missing_keys}')
    logger.info(f'The number of missing keys: {len(missing_keys)}')
    logger.info(f'The number of unexpected keys: {len(unexpected_keys)}')
    logger.info('==> Done (total keys %d)' % (len(model_state)))

    epoch = checkpoint.get('epoch', -1)
    it = checkpoint.get('it', 0.0)

    return it, epoch
891
+
892
    def match_token_map(self, data):
        """Quantize each map polyline into its nearest discrete map token.

        Each polyline is moved into its own local frame (anchored at its first
        point and rotated by its heading), matched against the token
        vocabulary's sample points, and the resulting per-point-token
        position / orientation / polygon-assignment tensors are written back
        into ``data['pt_token']``.

        Returns the mutated ``data`` object.
        """
        # assumes traj_pos is (pl_num, L, 2) and traj_theta is (pl_num,) — TODO confirm against preprocess
        traj_pos = data['map_save']['traj_pos'].to(torch.float)
        traj_theta = data['map_save']['traj_theta'].to(torch.float)
        pl_idx_list = data['map_save']['pl_idx_list']
        token_sample_pt = self.map_token['sample_pt'].to(traj_pos.device)
        token_src = self.map_token['traj_src'].to(traj_pos.device)
        max_traj_len = self.map_token['traj_src'].shape[1]
        pl_num = traj_pos.shape[0]

        # Token anchor = first point of the polyline; orientation = its heading.
        pt_token_pos = traj_pos[:, 0, :].clone()
        pt_token_orientation = traj_theta.clone()
        # Rotate every polyline into its local frame before template matching.
        cos, sin = traj_theta.cos(), traj_theta.sin()
        rot_mat = traj_theta.new_zeros(pl_num, 2, 2)
        rot_mat[..., 0, 0] = cos
        rot_mat[..., 0, 1] = -sin
        rot_mat[..., 1, 0] = sin
        rot_mat[..., 1, 1] = cos
        traj_pos_local = torch.bmm((traj_pos - traj_pos[:, 0:1]), rot_mat.view(-1, 2, 2))
        # Sum of squared point-wise distances to every vocabulary token.
        distance = torch.sum((token_sample_pt[None] - traj_pos_local.unsqueeze(1)) ** 2, dim=(-2, -1))
        pt_token_id = torch.argmin(distance, dim=1)

        if self.noise:
            # Augmentation: instead of the argmin, sample uniformly from the
            # 8 nearest vocabulary tokens.
            topk_indices = torch.argsort(torch.sum((token_sample_pt[None] - traj_pos_local.unsqueeze(1))**2, dim=(-2, -1)), dim=1)[:, :8]
            sample_topk = torch.randint(0, topk_indices.shape[-1], size=(topk_indices.shape[0], 1), device=topk_indices.device)
            pt_token_id = torch.gather(topk_indices, 1, sample_topk).squeeze(-1)

        # cos, sin = traj_theta.cos(), traj_theta.sin()
        # rot_mat = traj_theta.new_zeros(pl_num, 2, 2)
        # rot_mat[..., 0, 0] = cos
        # rot_mat[..., 0, 1] = sin
        # rot_mat[..., 1, 0] = -sin
        # rot_mat[..., 1, 1] = cos
        # token_src_world = torch.bmm(token_src[None, ...].repeat(pl_num, 1, 1, 1).reshape(pl_num, -1, 2),
        #                             rot_mat.view(-1, 2, 2)).reshape(pl_num, token_src.shape[0], max_traj_len, 2) + traj_pos[:, None, [0], :]
        # token_src_world_select = token_src_world.view(-1, 1024, 11, 2)[torch.arange(pt_token_id.view(-1).shape[0]), pt_token_id.view(-1)].view(pl_num, max_traj_len, 2)

        # Map every point token to its parent polygon id (edge list form).
        pl_idx_full = pl_idx_list.clone()
        token2pl = torch.stack([torch.arange(len(pl_idx_list), device=traj_pos.device), pl_idx_full.long()])
        # Count how many point tokens each polygon has per side
        # (0 = left, 1 = right, 2 = center).
        count_nums = []
        for pl in pl_idx_full.unique():
            pt = token2pl[0, token2pl[1, :] == pl]
            left_side = (data['pt_token']['side'][pt] == 0).sum()
            right_side = (data['pt_token']['side'][pt] == 1).sum()
            center_side = (data['pt_token']['side'][pt] == 2).sum()
            count_nums.append(torch.Tensor([left_side, right_side, center_side]))
        count_nums = torch.stack(count_nums, dim=0)
        num_polyline = int(count_nums.max().item())
        # traj_mask[p, s, i] is True for the first count_nums[p, s] slots, i.e.
        # valid (non-padding) point tokens of polygon p on side s.
        traj_mask = torch.zeros((int(len(pl_idx_full.unique())), 3, num_polyline), dtype=bool)
        idx_matrix = torch.arange(traj_mask.size(2)).unsqueeze(0).unsqueeze(0)
        idx_matrix = idx_matrix.expand(traj_mask.size(0), traj_mask.size(1), -1)
        counts_num_expanded = count_nums.unsqueeze(-1)
        mask_update = idx_matrix < counts_num_expanded
        traj_mask[mask_update] = True

        data['pt_token']['traj_mask'] = traj_mask
        # Lift 2D token positions to 3D with a zero z-coordinate.
        data['pt_token']['position'] = torch.cat([pt_token_pos, torch.zeros((data['pt_token']['num_nodes'], 1),
                                                                            device=traj_pos.device, dtype=torch.float)], dim=-1)
        data['pt_token']['orientation'] = pt_token_orientation
        data['pt_token']['height'] = data['pt_token']['position'][:, -1]
        data[('pt_token', 'to', 'map_polygon')] = {}
        data[('pt_token', 'to', 'map_polygon')]['edge_index'] = token2pl # (2, num_points)
        data['pt_token']['token_idx'] = pt_token_id

        # data['pt_token']['batch'] = torch.zeros(data['pt_token']['num_nodes'], device=traj_pos.device).long()
        # data['pt_token']['ptr'] = torch.tensor([0, data['pt_token']['num_nodes']], device=traj_pos.device).long()

        return data
959
+
960
    def sample_pt_pred(self, data):
        """Randomly mask map point tokens and build the masked-prediction masks.

        Roughly one third of the non-initial slots along every (polygon, side)
        row are hidden; the model must predict a hidden token from its
        immediate predecessor. Flattened (valid / predict / target) masks are
        stored in ``data['pt_token']``.

        Returns the mutated ``data`` object.
        """
        traj_mask = data['pt_token']['traj_mask']
        # Candidate indices 1..L-1 for every row; slot 0 is never masked so
        # each row keeps its starting token.
        raw_pt_index = torch.arange(1, traj_mask.shape[2]).repeat(traj_mask.shape[0], traj_mask.shape[1], 1)
        # Draw (L-1)//3 indices per row via a global random permutation.
        # NOTE(review): sampling is global, so a row may receive duplicate
        # indices — presumably harmless because scatter_ is idempotent here.
        masked_pt_index = raw_pt_index.view(-1)[torch.randperm(raw_pt_index.numel())[:traj_mask.shape[0] * traj_mask.shape[1] * ((traj_mask.shape[2] - 1) // 3)].reshape(traj_mask.shape[0], traj_mask.shape[1], (traj_mask.shape[2] - 1) // 3)]
        masked_pt_index = torch.sort(masked_pt_index, -1)[0]
        # Valid = still-visible tokens (masked slots removed).
        pt_valid_mask = traj_mask.clone()
        pt_valid_mask.scatter_(2, masked_pt_index, False)
        pt_pred_mask = traj_mask.clone()
        pt_pred_mask.scatter_(2, masked_pt_index, False)
        # Keep only the positions directly preceding a masked slot: those are
        # the positions from which a prediction is made.
        tmp_mask = pt_pred_mask.clone()
        tmp_mask[:, :, :] = True
        tmp_mask.scatter_(2, masked_pt_index-1, False)
        pt_pred_mask.masked_fill_(tmp_mask, False)
        # ...and only where the successor is a real (non-padding) token.
        pt_pred_mask = pt_pred_mask * torch.roll(traj_mask, shifts=-1, dims=2)
        # Target slots are the masked successors of the prediction slots.
        pt_target_mask = torch.roll(pt_pred_mask, shifts=1, dims=2)

        # Flatten all three masks to the order of real tokens only.
        data['pt_token']['pt_valid_mask'] = pt_valid_mask[traj_mask]
        data['pt_token']['pt_pred_mask'] = pt_pred_mask[traj_mask]
        data['pt_token']['pt_target_mask'] = pt_target_mask[traj_mask]

        return data
981
+
982
    def _fetch_enterings(self, data: HeteroData, plot: bool=False):
        """Tokenize agent positions/headings relative to the ego vehicle.

        For every graph in the batch and every time step, encodes each
        in-range, valid agent's position into a grid-token index (plus a
        continuous offset), each heading into an angle-bin index, and records
        bos/in-range masks and an angular ordering of newly-entering agents.
        When ``self.predict_occ`` is set, map point tokens are grid-encoded
        the same way. All results are written into ``data['agent']``.

        Returns the mutated ``data`` object.
        """
        data['agent']['grid_token_idx'] = torch.zeros_like(data['agent']['state_idx']).long()
        data['agent']['grid_offset_xy'] = torch.zeros_like(data['agent']['token_pos'])
        data['agent']['heading_token_idx'] = torch.zeros_like(data['agent']['state_idx']).long()
        data['agent']['sort_indices'] = torch.zeros_like(data['agent']['state_idx']).long()
        data['agent']['inrange_mask'] = torch.zeros_like(data['agent']['state_idx']).bool()
        data['agent']['bos_mask'] = torch.zeros_like(data['agent']['state_idx']).bool()

        data['agent']['pos_xy'] = torch.zeros_like(data['agent']['token_pos'])
        if self.predict_occ:
            num_step = data['agent']['state_idx'].shape[1]
            data['agent']['pt_grid_token_idx'] = torch.zeros_like(data['pt_token']['token_idx'])[None].repeat(num_step, 1).long()

        for b in range(data.num_graphs):
            av_index = int(data['agent']['av_index'][b])
            agent_batch_mask = data['agent']['batch'] == b
            pt_batch_mask = data['pt_token']['batch'] == b
            pt_token_idx = data['pt_token']['token_idx'][pt_batch_mask]
            pt_pos = data['pt_token']['position'][pt_batch_mask]
            agent_token_pos = data['agent']['token_pos'][agent_batch_mask]
            agent_token_heading = data['agent']['token_heading'][agent_batch_mask]
            state_idx = data['agent']['state_idx'][agent_batch_mask]
            ego_pos = agent_token_pos[av_index] # NOTE: `av_index` will be added by `ptr` later
            ego_heading = agent_token_heading[av_index]

            # -1 marks "no grid token" (invalid or out of range).
            grid_token_idx = torch.full(state_idx.shape, -1, device=state_idx.device)
            offset_xy = torch.zeros_like(agent_token_pos)
            sort_indices = torch.zeros_like(grid_token_idx)
            pt_grid_token_idx = torch.full((state_idx.shape[1], *pt_token_idx.shape), -1, device=pt_token_idx.device)

            pos_xy = torch.zeros((*state_idx.shape, 2), device=state_idx.device)

            is_bos = []
            is_inrange = []
            for t in range(agent_token_pos.shape[1]): # num_step

                # tokenize position
                is_bos_t = state_idx[:, t] == self.enter_state
                is_invalid_t = state_idx[:, t] == self.invalid_state
                is_inrange_t = ((agent_token_pos[:, t] - ego_pos[[t]]) ** 2).sum(-1).sqrt() <= self.pl2seed_radius
                # Grid-encode only valid, in-range agents in the ego frame.
                grid_index_t, offset_xy_t = self.attr_tokenizer.encode_pos(x=agent_token_pos[~is_invalid_t & is_inrange_t, t],
                                                                           y=ego_pos[[t]],
                                                                           theta_y=ego_heading[[t]])
                grid_token_idx[~is_invalid_t & is_inrange_t, t] = grid_index_t
                offset_xy[~is_invalid_t & is_inrange_t, t] = offset_xy_t

                pos_xy[~is_invalid_t & is_inrange_t, t] = agent_token_pos[~is_invalid_t & is_inrange_t, t] - ego_pos[[t]]

                # Order newly-entering (bos) agents by their bearing from the
                # ego heading; all non-entering agents sort to the end (inf)
                # and are redirected to the ego index.
                # distance = ((agent_token_pos[:, t] - ego_pos[[t]]) ** 2).sum(-1).sqrt()
                head_vector = torch.stack([ego_heading[[t]].cos(), ego_heading[[t]].sin()], dim=-1)
                distance = angle_between_2d_vectors(ctr_vector=head_vector,
                                                    nbr_vector=agent_token_pos[:, t] - ego_pos[[t]])
                # distance = torch.rand(agent_token_pos.shape[0], device=agent_token_pos.device)
                distance[~(is_bos_t & is_inrange_t)] = torch.inf
                sort_dist, sort_indice = distance.sort()
                sort_indice[torch.isinf(sort_dist)] = av_index
                sort_indices[:, t] = sort_indice

                is_bos.append(is_bos_t)
                is_inrange.append(is_inrange_t)

                # tokenize pt token
                if self.predict_occ:
                    is_inrange_t = ((pt_pos[:, :2] - ego_pos[None, t]) ** 2).sum(-1).sqrt() <= self.pl2seed_radius
                    grid_index_t, _ = self.attr_tokenizer.encode_pos(x=pt_pos[is_inrange_t, :2],
                                                                     y=ego_pos[[t]],
                                                                     theta_y=ego_heading[[t]])

                    pt_grid_token_idx[t, is_inrange_t] = grid_index_t

            # tokenize heading (relative to the ego heading at each step)
            rel_heading = agent_token_heading - ego_heading[None, ...]
            heading_token_idx = self.attr_tokenizer.encode_heading(rel_heading)

            data['agent']['grid_token_idx'][agent_batch_mask] = grid_token_idx
            data['agent']['grid_offset_xy'][agent_batch_mask] = offset_xy
            data['agent']['pos_xy'][agent_batch_mask] = pos_xy
            data['agent']['heading_token_idx'][agent_batch_mask] = heading_token_idx
            data['agent']['sort_indices'][agent_batch_mask] = sort_indices
            data['agent']['inrange_mask'][agent_batch_mask] = torch.stack(is_inrange, dim=1)
            data['agent']['bos_mask'][agent_batch_mask] = torch.stack(is_bos, dim=1)
            if self.predict_occ:
                data['agent']['pt_grid_token_idx'][:, pt_batch_mask] = pt_grid_token_idx

            # NOTE(review): this force-disables plotting and silently overrides
            # the `plot` parameter — remove this line to honor the argument.
            plot = False
            if plot:
                scenario_id = data['scenario_id'][b]
                dummy_prob = np.zeros((ego_pos.shape[0], self.attr_tokenizer.grid.shape[0])) + .5
                indices = grid_token_idx[:, 1:][state_idx[:, 1:] == self.enter_state].reshape(-1).cpu().numpy()
                dummy_prob, indices = self.attr_tokenizer.pad_square(dummy_prob, indices)
                # plot_insert_grid(scenario_id, dummy_prob,
                #                  self.attr_tokenizer.grid.cpu().numpy(),
                #                  ego_pos.cpu().numpy(),
                #                  None,
                #                  save_path=os.path.join(self.save_path, 'vis'),
                #                  indices=indices[np.newaxis, ...],
                #                  inference=True,
                #                  all_t_in_one=True)

                enter_index = [grid_token_idx[:, i][state_idx[:, i] == self.enter_state].tolist()
                               for i in range(agent_token_pos.shape[1])]
                agent_labels = [[f'A{i}'] * agent_token_pos.shape[1] for i in range(agent_token_pos.shape[0])]
                plot_scenario(scenario_id,
                              data['map_point']['position'].cpu().numpy(),
                              agent_token_pos.cpu().numpy(),
                              agent_token_heading.cpu().numpy(),
                              state_idx.cpu().numpy(),
                              types=list(map(lambda i: self.encoder.agent_encoder.agent_type[i],
                                             data['agent']['type'].tolist())),
                              av_index=av_index,
                              pl2seed_radius=self.pl2seed_radius,
                              attr_tokenizer=self.attr_tokenizer,
                              enter_index=enter_index,
                              save_gif=False,
                              save_path=os.path.join(self.save_path, 'vis'),
                              agent_labels=agent_labels,
                              tokenized=True)

        return data
backups/dev/modules/agent_decoder.py ADDED
The diff for this file is too large to render. See raw diff
 
backups/dev/modules/attr_tokenizer.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from dev.utils.func import wrap_angle, angle_between_2d_vectors
5
+
6
+
7
class Attr_Tokenizer(nn.Module):
    """Discretizes continuous agent attributes into token indices.

    Positions are quantized onto a square grid of spacing ``grid_interval``
    cropped to a circle of radius ``radius`` (the grid frame is oriented at
    ``self.heading`` = +pi/2); headings are quantized into uniform bins of
    ``angle_interval`` degrees. ``decode_*`` methods invert the encodings up
    to quantization error.
    """

    def __init__(self, grid_range, grid_interval, radius, angle_interval):
        super().__init__()
        self.grid_range = grid_range          # side length of the square grid (meters)
        self.grid_interval = grid_interval    # spacing between neighboring cells
        self.radius = radius                  # keep only cells within this distance
        self.angle_interval = angle_interval  # heading bin width in degrees
        self.heading = torch.pi / 2           # reference orientation of the grid frame
        self._prepare_grid()

        self.grid_size = self.grid.shape[0]
        self.angle_size = int(360. / self.angle_interval)

        # The middle cell of the (odd-sized) kept grid must be the origin.
        assert torch.all(self.grid[self.grid_size // 2] == 0.)

    def _prepare_grid(self):
        """Build the circular grid and per-cell distance/direction buffers."""
        num_grid = int(self.grid_range / self.grid_interval) + 1 # Do not use '//'

        x = torch.linspace(0, num_grid - 1, steps=num_grid)
        y = torch.linspace(0, num_grid - 1, steps=num_grid)
        grid_x, grid_y = torch.meshgrid(x, y, indexing='xy')
        grid = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=-1) # (n^2, 2)
        # Flip row order so indices run top-to-bottom, then center at origin.
        grid = grid.reshape(num_grid, num_grid, 2).flip(dims=[0]).reshape(-1, 2)
        grid = (grid - x.shape[0] // 2) * self.grid_interval

        # Keep only cells inside the circle of `radius` (the extra clauses are
        # redundant but kept for clarity of intent).
        distance = (grid ** 2).sum(-1).sqrt()
        square_mask = ((distance <= self.radius) & (distance >= 0.)) | (distance == 0.)
        self.register_buffer('grid', grid[square_mask])
        self.register_buffer('dist', torch.norm(self.grid, p=2, dim=-1))
        head_vector = torch.stack([torch.tensor(self.heading).cos(), torch.tensor(self.heading).sin()])
        self.register_buffer('dir', angle_between_2d_vectors(ctr_vector=head_vector.unsqueeze(0),
                                                             nbr_vector=self.grid)) # (-pi, pi]

        self.num_grid = num_grid
        self.square_mask = square_mask.numpy()

    def _apply_rot(self, x, theta):
        """Rotate batched 2D points `x` (b, l, 2) by per-batch angles `theta` (b,)."""
        cos, sin = theta.cos(), theta.sin()
        rot_mat = torch.zeros((theta.shape[0], 2, 2), device=theta.device)
        rot_mat[:, 0, 0] = cos
        rot_mat[:, 0, 1] = sin
        rot_mat[:, 1, 0] = -sin
        rot_mat[:, 1, 1] = cos
        x = torch.bmm(x, rot_mat)
        return x

    def pad_square(self, prob, indices=None):
        """Scatter circle-grid probabilities back onto the full square grid.

        `prob` has shape (..., m) over kept (in-circle) cells; the result has
        shape (..., n^2). Optional `indices` into the circular grid are
        remapped to square-grid indices (-1 stays a sentinel).
        """
        pad_prob = np.zeros((*prob.shape[:-1], self.square_mask.shape[0]))
        pad_prob[..., self.square_mask] = prob

        square_indices = np.arange(self.square_mask.shape[0])
        circle_indices = np.concatenate([square_indices[self.square_mask], [-1]])
        if indices is not None:
            indices = circle_indices[indices]

        return pad_prob, indices

    def get_grid(self, x, theta=None):
        """Return the grid cell centers around each point in `x`, optionally
        rotated to heading `theta`."""
        x = x.reshape(-1, 2)
        grid = self.grid[None, ...].to(x.device)
        if theta is not None:
            grid = self._apply_rot(grid, (theta - self.heading).expand(x.shape[0]))
        return x[:, None] + grid

    def encode_pos(self, x, y, theta_y=None):
        """Encode positions `x` relative to anchor `y` (heading `theta_y`).

        Returns (grid index, continuous offset from the chosen cell center).
        """
        assert x.dim() == y.dim() and x.shape[-1] == 2 and y.shape[-1] == 2, \
            f"Invalid input shape x: {x.shape}, y: {y.shape}."
        centered_x = x - y
        if theta_y is not None:
            # Rotate into the anchor's frame.
            centered_x = self._apply_rot(centered_x[:, None], -(theta_y - self.heading).expand(x.shape[0]))[:, 0]
        distance = ((centered_x[:, None] - self.grid.to(x.device)[None, ...]) ** 2).sum(-1).sqrt()
        index = torch.argmin(distance, dim=-1)

        grid_xy = self.grid[index]
        offset_xy = centered_x - grid_xy

        return index.long(), offset_xy

    def decode_pos(self, index, y=None, theta_y=None):
        """Decode grid indices back to positions, optionally re-anchored at
        `y` with heading `theta_y`. Inverse of :meth:`encode_pos` up to the
        quantization offset."""
        assert torch.all((index >= 0) & (index < self.grid_size))
        centered_x = self.grid.to(index.device)[index.long()]
        if y is not None:
            if theta_y is not None:
                centered_x = self._apply_rot(centered_x[:, None], (theta_y - self.heading).expand(centered_x.shape[0]))[:, 0]
            x = centered_x + y
            return x.float()
        return centered_x.float()

    def encode_heading(self, heading):
        """Map heading angles (radians) to bin indices in [0, angle_size)."""
        heading = (wrap_angle(heading) + torch.pi) / (2 * torch.pi) * 360
        # `wrap_angle` maps to (-pi, pi], so `heading` can be exactly 360.0,
        # which would yield the out-of-range index `angle_size` and break the
        # assert in `decode_heading`; clamp to the last valid bin.
        index = (heading // self.angle_interval).clamp_(max=self.angle_size - 1)
        return index.long()

    def decode_heading(self, index):
        """Map bin indices back to heading angles (radians, bin left edge)."""
        assert torch.all(index >= 0) and torch.all(index < (360 / self.angle_interval))
        angles = index * self.angle_interval - 180
        angles = angles / 360 * (2 * torch.pi)
        return angles.float()
backups/dev/modules/debug.py ADDED
@@ -0,0 +1,1439 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Mapping, Optional
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch_cluster import radius, radius_graph
6
+ from torch_geometric.data import HeteroData, Batch
7
+ from torch_geometric.utils import dense_to_sparse, subgraph
8
+
9
+ from dev.modules.layers import *
10
+ from dev.modules.map_decoder import discretize_neighboring
11
+ from dev.utils.geometry import angle_between_2d_vectors, wrap_angle
12
+ from dev.utils.weight_init import weight_init
13
+
14
+
15
def cal_polygon_contour(x, y, theta, width, length):
    """Return the four corners of an oriented rectangle (bounding box).

    Args:
        x, y: center of the box.
        theta: heading angle in radians (length axis direction).
        width: extent perpendicular to the heading.
        length: extent along the heading.

    Returns:
        List of (x, y) tuples in order:
        [left_front, right_front, right_back, left_back].
    """
    # Hoist the trig calls: the original recomputed cos/sin eight times.
    cos_t = math.cos(theta)
    sin_t = math.sin(theta)
    half_length = 0.5 * length
    half_width = 0.5 * width
    # Half-extent vectors along the heading (l*) and across it (w*).
    lx, ly = half_length * cos_t, half_length * sin_t
    wx, wy = half_width * sin_t, half_width * cos_t

    left_front = (x + lx - wx, y + ly + wy)
    right_front = (x + lx + wx, y + ly - wy)
    right_back = (x - lx + wx, y - ly - wy)
    left_back = (x - lx - wx, y - ly + wy)

    return [left_front, right_front, right_back, left_back]
34
+
35
+
36
+ class SMARTAgentDecoder(nn.Module):
37
+
38
+ def __init__(self,
39
+ dataset: str,
40
+ input_dim: int,
41
+ hidden_dim: int,
42
+ num_historical_steps: int,
43
+ num_interaction_steps: int,
44
+ time_span: Optional[int],
45
+ pl2a_radius: float,
46
+ pl2seed_radius: float,
47
+ a2a_radius: float,
48
+ num_freq_bands: int,
49
+ num_layers: int,
50
+ num_heads: int,
51
+ head_dim: int,
52
+ dropout: float,
53
+ token_data: Dict,
54
+ token_size: int,
55
+ special_token_index: list=[],
56
+ predict_motion: bool=False,
57
+ predict_state: bool=False,
58
+ predict_map: bool=False,
59
+ state_token: Dict[str, int]=None,
60
+ seed_size: int=5) -> None:
61
+
62
+ super(SMARTAgentDecoder, self).__init__()
63
+ self.dataset = dataset
64
+ self.input_dim = input_dim
65
+ self.hidden_dim = hidden_dim
66
+ self.num_historical_steps = num_historical_steps
67
+ self.num_interaction_steps = num_interaction_steps
68
+ self.time_span = time_span if time_span is not None else num_historical_steps
69
+ self.pl2a_radius = pl2a_radius
70
+ self.pl2seed_radius = pl2seed_radius
71
+ self.a2a_radius = a2a_radius
72
+ self.num_freq_bands = num_freq_bands
73
+ self.num_layers = num_layers
74
+ self.num_heads = num_heads
75
+ self.head_dim = head_dim
76
+ self.dropout = dropout
77
+ self.special_token_index = special_token_index
78
+ self.predict_motion = predict_motion
79
+ self.predict_state = predict_state
80
+ self.predict_map = predict_map
81
+
82
+ # state tokens
83
+ self.state_type = list(state_token.keys())
84
+ self.state_token = state_token
85
+ self.invalid_state = int(state_token['invalid'])
86
+ self.valid_state = int(state_token['valid'])
87
+ self.enter_state = int(state_token['enter'])
88
+ self.exit_state = int(state_token['exit'])
89
+
90
+ self.seed_state_type = ['invalid', 'enter']
91
+ self.valid_state_type = ['invalid', 'valid', 'exit']
92
+
93
+ input_dim_x_a = 2
94
+ input_dim_r_t = 4
95
+ input_dim_r_pt2a = 3
96
+ input_dim_r_a2a = 3
97
+ input_dim_token = 8 # tokens: (token_size, 4, 2)
98
+
99
+ self.seed_size = seed_size
100
+
101
+ self.all_agent_type = ['veh', 'ped', 'cyc', 'background', 'invalid', 'seed']
102
+ self.seed_agent_type = ['veh', 'ped', 'cyc', 'seed']
103
+ self.type_a_emb = nn.Embedding(len(self.all_agent_type), hidden_dim)
104
+ self.shape_emb = MLPLayer(3, hidden_dim, hidden_dim)
105
+ if self.predict_state:
106
+ self.state_a_emb = nn.Embedding(len(self.state_type), hidden_dim)
107
+ self.invalid_shape_value = .1
108
+ self.motion_gap = 1.
109
+ self.heading_gap = 1.
110
+
111
+ self.x_a_emb = FourierEmbedding(input_dim=input_dim_x_a, hidden_dim=hidden_dim, num_freq_bands=num_freq_bands)
112
+ self.r_t_emb = FourierEmbedding(input_dim=input_dim_r_t, hidden_dim=hidden_dim, num_freq_bands=num_freq_bands)
113
+ self.r_pt2a_emb = FourierEmbedding(input_dim=input_dim_r_pt2a, hidden_dim=hidden_dim,
114
+ num_freq_bands=num_freq_bands)
115
+ self.r_a2a_emb = FourierEmbedding(input_dim=input_dim_r_a2a, hidden_dim=hidden_dim,
116
+ num_freq_bands=num_freq_bands)
117
+ self.token_emb_veh = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)
118
+ self.token_emb_ped = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)
119
+ self.token_emb_cyc = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)
120
+ self.no_token_emb = nn.Embedding(1, hidden_dim)
121
+ self.bos_token_emb = nn.Embedding(1, hidden_dim)
122
+ # FIXME: do we need this???
123
+ self.token_emb_offset = MLPEmbedding(input_dim=2, hidden_dim=hidden_dim)
124
+
125
+ num_inputs = 2
126
+ if self.predict_state:
127
+ num_inputs = 3
128
+ self.fusion_emb = MLPEmbedding(input_dim=self.hidden_dim * num_inputs, hidden_dim=self.hidden_dim)
129
+
130
+ self.t_attn_layers = nn.ModuleList(
131
+ [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
132
+ bipartite=False, has_pos_emb=True) for _ in range(num_layers)]
133
+ )
134
+ self.pt2a_attn_layers = nn.ModuleList(
135
+ [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
136
+ bipartite=True, has_pos_emb=True) for _ in range(num_layers)]
137
+ )
138
+ self.a2a_attn_layers = nn.ModuleList(
139
+ [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
140
+ bipartite=False, has_pos_emb=True) for _ in range(num_layers)]
141
+ )
142
+ self.token_size = token_size # 2048
143
+ # agent motion prediction head
144
+ self.token_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
145
+ output_dim=self.token_size)
146
+ # agent state prediction head
147
+ if self.predict_state:
148
+
149
+ self.seed_feature = nn.Embedding(self.seed_size, self.hidden_dim)
150
+
151
+ self.state_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
152
+ output_dim=len(self.valid_state_type))
153
+
154
+ self.seed_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
155
+ output_dim=hidden_dim)
156
+
157
+ self.seed_state_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
158
+ output_dim=len(self.seed_state_type))
159
+ self.seed_type_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
160
+ output_dim=len(self.seed_agent_type))
161
+ # entering token prediction
162
+ # FIXME: this is just under test!!!
163
+ # self.bos_pl_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
164
+ # output_dim=200)
165
+ # self.bos_offset_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
166
+ # output_dim=2601)
167
+ self.trajectory_token = token_data['token'] # dict('veh', 'ped', 'cyc') (2048, 4, 2)
168
+ self.trajectory_token_traj = token_data['traj'] # (2048, 6, 3)
169
+ self.trajectory_token_all = token_data['token_all'] # (2048, 6, 4, 2)
170
+ self.apply(weight_init)
171
+
172
+ self.shift = 5
173
+ self.beam_size = 5
174
+ self.hist_mask = True
175
+ self.temporal_attn_to_invalid = True
176
+ self.temporal_attn_seed = False
177
+
178
+ # FIXME: This is just under test!!!
179
+ # self.mapping_network = MappingNetwork(z_dim=hidden_dim, w_dim=hidden_dim, num_layers=num_layers)
180
+
181
+ def transform_rel(self, token_traj, prev_pos, prev_heading=None):
182
+ if prev_heading is None:
183
+ diff_xy = prev_pos[:, :, -1, :] - prev_pos[:, :, -2, :]
184
+ prev_heading = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0])
185
+
186
+ num_agent, num_step, traj_num, traj_dim = token_traj.shape
187
+ cos, sin = prev_heading.cos(), prev_heading.sin()
188
+ rot_mat = torch.zeros((num_agent, num_step, 2, 2), device=prev_heading.device)
189
+ rot_mat[:, :, 0, 0] = cos
190
+ rot_mat[:, :, 0, 1] = -sin
191
+ rot_mat[:, :, 1, 0] = sin
192
+ rot_mat[:, :, 1, 1] = cos
193
+ agent_diff_rel = torch.bmm(token_traj.view(-1, traj_num, 2), rot_mat.view(-1, 2, 2)).view(num_agent, num_step, traj_num, traj_dim)
194
+ agent_pred_rel = agent_diff_rel + prev_pos[:, :, -1:, :]
195
+ return agent_pred_rel
196
+
197
    def agent_token_embedding(self, data, agent_token_index, agent_state, pos_a, head_a, inference=False,
                              filter_mask=None, av_index=None):
        """Build per-agent, per-step fused feature embeddings.

        Combines motion-token embeddings (per agent type, with extra learned
        bos/invalid embeddings), Fourier-embedded motion features, and state
        embeddings, then appends a learned seed-agent feature row.

        Side effect: caches `self.agent_token_emb_{veh,ped,cyc}` (vocabulary
        embeddings extended by bos and invalid rows) for later reuse.

        Returns (feat_a, head_vector_a), plus token trajectories / raw token
        embeddings / categorical embeddings when `inference` is True.
        """
        if filter_mask is None:
            filter_mask = torch.ones_like(agent_state[:, 2], dtype=torch.bool)

        num_agent, num_step, traj_dim = pos_a.shape # traj_dim=2
        agent_type = data['agent']['type'][filter_mask]
        veh_mask = (agent_type == 0)
        ped_mask = (agent_type == 1)
        cyc_mask = (agent_type == 2)

        # set the position of invalid agents to the position of ego agent
        # note here we only set invalid steps BEFORE the bos token!
        # is_invalid = agent_state == self.invalid_state
        # is_bos = agent_state == self.enter_state
        # bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0))
        # bos_mask = torch.arange(num_step).expand(num_agent, -1).to(agent_state.device) < bos_index[:, None]
        # is_invalid[~bos_mask] = False

        # ego_pos_a = pos_a[av_index].clone()
        # ego_head_vector_a = head_vector_a[av_index].clone()
        # pos_a[is_invalid] = ego_pos_a[None, :].repeat(pos_a.shape[0], 1, 1)[is_invalid]
        # head_vector_a[is_invalid] = ego_head_vector_a[None, :].repeat(head_vector_a.shape[0], 1, 1)[is_invalid]

        motion_vector_a, head_vector_a = self.build_vector_a(pos_a, head_a, agent_state)

        # Embed the per-type motion token vocabularies (flattened 4x2 contours).
        trajectory_token_veh = torch.from_numpy(self.trajectory_token['veh']).clone().to(pos_a.device).to(torch.float)
        trajectory_token_ped = torch.from_numpy(self.trajectory_token['ped']).clone().to(pos_a.device).to(torch.float)
        trajectory_token_cyc = torch.from_numpy(self.trajectory_token['cyc']).clone().to(pos_a.device).to(torch.float)
        self.agent_token_emb_veh = self.token_emb_veh(trajectory_token_veh.view(trajectory_token_veh.shape[0], -1)) # (token_size, 8)
        self.agent_token_emb_ped = self.token_emb_ped(trajectory_token_ped.view(trajectory_token_ped.shape[0], -1))
        self.agent_token_emb_cyc = self.token_emb_cyc(trajectory_token_cyc.view(trajectory_token_cyc.shape[0], -1))

        # add bos token embedding
        self.agent_token_emb_veh = torch.cat([self.agent_token_emb_veh, self.bos_token_emb(torch.zeros(1, device=pos_a.device).long())])
        self.agent_token_emb_ped = torch.cat([self.agent_token_emb_ped, self.bos_token_emb(torch.zeros(1, device=pos_a.device).long())])
        self.agent_token_emb_cyc = torch.cat([self.agent_token_emb_cyc, self.bos_token_emb(torch.zeros(1, device=pos_a.device).long())])

        # add invalid token embedding
        self.agent_token_emb_veh = torch.cat([self.agent_token_emb_veh, self.no_token_emb(torch.zeros(1, device=pos_a.device).long())])
        self.agent_token_emb_ped = torch.cat([self.agent_token_emb_ped, self.no_token_emb(torch.zeros(1, device=pos_a.device).long())])
        self.agent_token_emb_cyc = torch.cat([self.agent_token_emb_cyc, self.no_token_emb(torch.zeros(1, device=pos_a.device).long())])

        if inference:
            # Full rollout trajectories per token (shift steps + final contour).
            agent_token_traj_all = torch.zeros((num_agent, self.token_size, self.shift + 1, 4, 2), device=pos_a.device)
            trajectory_token_all_veh = torch.from_numpy(self.trajectory_token_all['veh']).clone().to(pos_a.device).to(torch.float)
            trajectory_token_all_ped = torch.from_numpy(self.trajectory_token_all['ped']).clone().to(pos_a.device).to(torch.float)
            trajectory_token_all_cyc = torch.from_numpy(self.trajectory_token_all['cyc']).clone().to(pos_a.device).to(torch.float)
            agent_token_traj_all[veh_mask] = torch.cat(
                [trajectory_token_all_veh[:, :self.shift], trajectory_token_veh[:, None, ...]], dim=1)
            agent_token_traj_all[ped_mask] = torch.cat(
                [trajectory_token_all_ped[:, :self.shift], trajectory_token_ped[:, None, ...]], dim=1)
            agent_token_traj_all[cyc_mask] = torch.cat(
                [trajectory_token_all_cyc[:, :self.shift], trajectory_token_cyc[:, None, ...]], dim=1)

        # additional token embeddings are already added -> -1: invalid, -2: bos
        agent_token_emb = torch.zeros((num_agent, num_step, self.hidden_dim), device=pos_a.device)
        agent_token_emb[veh_mask] = self.agent_token_emb_veh[agent_token_index[veh_mask]]
        agent_token_emb[ped_mask] = self.agent_token_emb_ped[agent_token_index[ped_mask]]
        agent_token_emb[cyc_mask] = self.agent_token_emb_cyc[agent_token_index[cyc_mask]]

        # 'vehicle', 'pedestrian', 'cyclist', 'background'
        # NOTE(review): the `!= enter_state` clause is redundant given the
        # `== invalid_state` test (a state can't be both) — confirm intent.
        is_invalid = (agent_state == self.invalid_state) & (agent_state != self.enter_state)
        # At invalid steps, override type with the 'invalid' category and use
        # a small constant shape so the embedding carries no real geometry.
        agent_types = data['agent']['type'][filter_mask].long().repeat_interleave(repeats=num_step, dim=0)
        agent_types[is_invalid.reshape(-1)] = self.all_agent_type.index('invalid')
        agent_shapes = data['agent']['shape'][filter_mask, self.num_historical_steps - 1, :].repeat_interleave(repeats=num_step, dim=0)
        agent_shapes[is_invalid.reshape(-1)] = self.invalid_shape_value

        categorical_embs = [self.type_a_emb(agent_types), self.shape_emb(agent_shapes)]
        # Continuous features: speed magnitude and direction relative to heading.
        feature_a = torch.stack(
            [torch.norm(motion_vector_a[:, :, :2], p=2, dim=-1),
             angle_between_2d_vectors(ctr_vector=head_vector_a, nbr_vector=motion_vector_a[:, :, :2]),
             ], dim=-1) # (num_agent, num_shifted_step, 2)

        x_a = self.x_a_emb(continuous_inputs=feature_a.view(-1, feature_a.size(-1)),
                           categorical_embs=categorical_embs)
        x_a = x_a.view(-1, num_step, self.hidden_dim) # (num_agent, num_step, hidden_dim)

        # Fuse token, motion and state embeddings into one feature per step.
        s_a = self.state_a_emb(agent_state.reshape(-1).long()).reshape(num_agent, num_step, self.hidden_dim)
        feat_a = torch.cat((agent_token_emb, x_a, s_a), dim=-1) # (num_agent, num_step, hidden_dim * 3)
        feat_a = self.fusion_emb(feat_a) # (num_agent, num_step, hidden_dim)

        # seed agent feature (anchored on the ego agent's motion/heading)
        motion_vector_seed = motion_vector_a[av_index : av_index + 1]
        head_vector_seed = head_vector_a[av_index : av_index + 1]
        feat_seed = self.build_invalid_agent_feature(num_step, pos_a.device, type_index=self.all_agent_type.index('seed'),
                                                     motion_vector=motion_vector_seed, head_vector=head_vector_seed)

        # replace the features of steps before bos of valid agents with the corresponding invalid agent features
        # is_bos = agent_state == self.enter_state
        # is_eos = agent_state == self.exit_state
        # bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0))
        # eos_index = torch.where(is_eos.any(dim=1), torch.argmax(is_eos.long(), dim=1), torch.tensor(num_step - 1))
        # is_before_bos = torch.arange(num_step).expand(num_agent, -1).to(agent_state.device) < bos_index[:, None]
        # is_after_eos = torch.arange(num_step).expand(num_agent, -1).to(agent_state.device) > eos_index[:, None] + 1
        # feat_ina = self.build_invalid_agent_feature(num_step, pos_a.device)
        # feat_a[is_before_bos | is_after_eos] = feat_ina.repeat(num_agent, 1, 1)[is_before_bos | is_after_eos]

        # print("train")
        # is_bos = agent_state == self.enter_state
        # is_eos = agent_state == self.exit_state
        # bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0))
        # eos_index = torch.where(is_eos.any(dim=1), torch.argmax(is_eos.long(), dim=1), torch.tensor(num_step - 1))
        # mask = torch.arange(num_step).expand(num_agent, -1).to(agent_state.device)
        # mask = (mask >= bos_index[:, None]) & (mask <= eos_index[:, None] + 1)
        # is_invalid[mask] = False
        # print(feat_a.sum(dim=-1)[is_invalid])

        feat_a = torch.cat([feat_a, feat_seed], dim=0) # (num_agent + 1, num_step, hidden_dim)

        # feat_a_sum = feat_a.sum(dim=-1)
        # for a in range(num_agent):
        #     print(f"agent {a}:")
        #     print(f"state: {agent_state[a, :]}")
        #     print(f"feat_a_sum: {feat_a_sum[a, :]}")
        # exit(1)

        if inference:
            return feat_a, head_vector_a, agent_token_traj_all, agent_token_emb, categorical_embs
        else:
            return feat_a, head_vector_a
+
320
def build_vector_a(self, pos_a, head_a, state_a):
    """Derive per-step motion vectors and heading unit vectors for each agent.

    Args:
        pos_a: (num_agent, num_step, input_dim) token positions.
        head_a: (num_agent, num_step) token headings in radians.
        state_a: (num_agent, num_step) per-step state indices.

    Returns:
        (motion_vector_a, head_vector_a): displacement per step (with sentinel
        values at enter/exit transitions) and (cos, sin) heading vectors.
    """
    n_agents = pos_a.shape[0]

    # First step has no predecessor, so its displacement is zero; the rest
    # are consecutive position deltas.
    step_deltas = pos_a[:, 1:] - pos_a[:, :-1]
    leading_zeros = pos_a.new_zeros(n_agents, 1, self.input_dim)
    motion_vector_a = torch.cat([leading_zeros, step_deltas], dim=1)

    # Steps where an agent enters the scene get a sentinel motion value ...
    entering = state_a == self.enter_state
    motion_vector_a[entering] = self.motion_gap

    # ... and the step right after an exit gets the negated sentinel.
    after_exit = state_a.roll(shifts=1, dims=1) == self.exit_state
    after_exit[:, 0] = False  # roll wraps around; step 0 has no predecessor
    motion_vector_a[after_exit] = -self.motion_gap

    head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1)

    return motion_vector_a, head_vector_a
337
+
338
def build_invalid_agent_feature(self, num_step, device, motion_vector=None, head_vector=None, type_index=None, shape_value=None):
    """Construct the fused feature sequence of a single placeholder ("invalid") agent.

    All inputs default to neutral values: zero motion, zero heading, the
    'invalid' agent type, and the invalid shape fill value. Callers may
    override kinematics/type (e.g. the seed agent passes its own vectors).

    Returns:
        Tensor of shape (1, num_step, hidden_dim).
    """
    # Dedicated "no token" embedding, repeated across all steps.
    token_emb = self.no_token_emb(torch.zeros(1, device=device).long())[:, None].repeat(1, num_step, 1)

    # Default kinematics: stationary, heading zero.
    if motion_vector is None or head_vector is None:
        motion_vector = torch.zeros((1, num_step, 2), device=device)
        zero_heading = torch.zeros(1, device=device)
        head_vector = torch.stack(
            [torch.cos(zero_heading), torch.sin(zero_heading)], dim=-1)[:, None, :].repeat(1, num_step, 1)

    continuous_feature = torch.stack(
        [torch.norm(motion_vector[:, :, :2], p=2, dim=-1),
         angle_between_2d_vectors(ctr_vector=head_vector, nbr_vector=motion_vector[:, :, :2]),
         ], dim=-1)

    if type_index is None:
        type_index = self.all_agent_type.index('invalid')
    if shape_value is None:
        shape_value = torch.full((1, 3), self.invalid_shape_value, device=device)

    categorical_embs = [self.type_a_emb(torch.tensor([type_index], device=device)),
                        self.shape_emb(shape_value)]
    x_ina = self.x_a_emb(continuous_inputs=continuous_feature.view(-1, continuous_feature.size(-1)),
                         categorical_embs=categorical_embs)
    x_ina = x_ina.view(-1, num_step, self.hidden_dim)  # (1, num_step, hidden_dim)

    # State embedding for the invalid state, one copy per step.
    # NOTE: `repeat` (not `expand`) so the result owns its own memory.
    s_ina = self.state_a_emb(torch.tensor([self.invalid_state], device=device))[:, None].repeat(1, num_step, 1)

    fused = torch.cat((token_emb, x_ina, s_ina), dim=-1)
    return self.fusion_emb(fused)  # (1, num_step, hidden_dim)
367
+
368
def build_temporal_edge(self, pos_a, head_a, head_vector_a, state_a, mask, inference_mask=None, av_index=None):
    """Build the temporal attention graph (each agent attending to its own past).

    A "seed" row mirroring the ego agent (row `av_index`) is appended, so node
    indices live on a flattened (num_agent + 1) * num_step grid.

    Returns:
        (edge_index_t, r_t): temporal edge indices and their relative embeddings.
    """
    num_agent = pos_a.shape[0]
    hist_mask = mask.clone()

    if not self.temporal_attn_to_invalid:
        hist_mask[state_a == self.invalid_state] = False

    # Snapshot the ego agent's trajectory; it doubles as the seed agent row.
    ego_pos_a = pos_a[av_index].clone()  # (num_step, 2)
    ego_head_a = head_a[av_index].clone()
    ego_head_vector_a = head_vector_a[av_index].clone()
    ego_state_a = state_a[av_index].clone()

    # Append the seed agent as one extra row.
    pos_a = torch.cat([pos_a, ego_pos_a[None]], dim=0)
    head_a = torch.cat([head_a, ego_head_a[None]], dim=0)
    state_a = torch.cat([state_a, ego_state_a[None]], dim=0)
    head_vector_a = torch.cat([head_vector_a, ego_head_vector_a[None]], dim=0)
    hist_mask = torch.cat([hist_mask, torch.ones_like(hist_mask[0:1])], dim=0).bool()
    if not self.temporal_attn_seed:
        hist_mask[-1:] = False
        if inference_mask is not None:
            inference_mask[-1:] = False  # NOTE: mutates the caller's tensor in place

    pos_t = pos_a.reshape(-1, self.input_dim)  # (num_agent * num_step, ...)
    head_t = head_a.reshape(-1)
    head_vector_t = head_vector_a.reshape(-1, 2)

    # Agents that never emit motion tokens before (shortly before) their bos
    # step should not be attended to; mask those early steps out.
    is_bos = state_a == self.enter_state
    is_bos[-1] = False  # the seed row has no bos of its own
    bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0))
    motion_predict_start_index = torch.clamp(bos_index - self.time_span / self.shift + 1, min=0)
    step_grid = torch.arange(hist_mask.shape[1]).expand(hist_mask.shape[0], -1).to(hist_mask.device)
    motion_predict_mask = step_grid >= motion_predict_start_index[:, None]
    hist_mask[~motion_predict_mask] = False

    if self.hist_mask and self.training:
        # Temporal dropout: hide 10 random timesteps per (non-seed) agent.
        hist_mask[
            torch.arange(mask.shape[0]).unsqueeze(1), torch.randint(0, mask.shape[1], (num_agent, 10))] = False
        mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1)
    elif inference_mask is not None:
        mask_t = hist_mask.unsqueeze(2) & inference_mask.unsqueeze(1)
    else:
        mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1)

    # Keep only causal edges within the attention window (time_span / shift).
    edge_index_t = dense_to_sparse(mask_t)[0]
    step_gap = edge_index_t[1] - edge_index_t[0]
    edge_index_t = edge_index_t[:, (step_gap > 0) & (step_gap <= self.time_span / self.shift)]
    rel_pos_t = pos_t[edge_index_t[0]] - pos_t[edge_index_t[1]]
    rel_head_t = wrap_angle(head_t[edge_index_t[0]] - head_t[edge_index_t[1]])

    # Edges touching an invalid step get sentinel relative features.
    is_invalid = state_a == self.invalid_state
    is_invalid_t = is_invalid.reshape(-1)
    rel_pos_t[is_invalid_t[edge_index_t[0]]] = -self.motion_gap
    rel_pos_t[is_invalid_t[edge_index_t[1]]] = self.motion_gap
    rel_head_t[is_invalid_t[edge_index_t[0]]] = -self.heading_gap
    rel_head_t[is_invalid_t[edge_index_t[1]]] = self.heading_gap

    r_t = torch.stack(
        [torch.norm(rel_pos_t[:, :2], p=2, dim=-1),
         angle_between_2d_vectors(ctr_vector=head_vector_t[edge_index_t[1]], nbr_vector=rel_pos_t[:, :2]),
         rel_head_t,
         edge_index_t[0] - edge_index_t[1]], dim=-1)
    r_t = self.r_t_emb(continuous_inputs=r_t, categorical_embs=None)

    return edge_index_t, r_t
468
+
469
def build_interaction_edge(self, pos_a, head_a, head_vector_a, state_a, batch_s, mask_a, inference_mask=None, av_index=None):
    """Build the agent-to-agent interaction graph for one window.

    The ego row is duplicated as a "seed" agent; agents within
    `pl2seed_radius` connect (unilaterally) to the seed, while ordinary
    agents within `a2a_radius` connect bilaterally.

    Returns:
        (edge_index_a2a, r_a2a): edge indices over the step-major flattened
        node grid and their relative-feature embeddings.
    """
    num_agent, num_step, _ = pos_a.shape

    # Mirror the ego agent as an extra seed row.
    pos_a = torch.cat([pos_a, pos_a[av_index][None]], dim=0)
    head_a = torch.cat([head_a, head_a[av_index][None]], dim=0)
    state_a = torch.cat([state_a, state_a[av_index][None]], dim=0)
    head_vector_a = torch.cat([head_vector_a, head_vector_a[av_index][None]], dim=0)

    # Flatten to step-major order so indices line up with batch_s.
    pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
    head_s = head_a.transpose(0, 1).reshape(-1)
    head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2)
    if inference_mask is not None:
        mask_a = mask_a & inference_mask
    mask_s = mask_a.transpose(0, 1).reshape(-1)

    # Directed edges: nearby agents -> seed agent (per step).
    mask_seed = state_a[av_index] != self.invalid_state
    pos_seed = pos_a[av_index]
    edge_index_seed2a = radius(x=pos_seed[:, :2], y=pos_s[:, :2], r=self.pl2seed_radius,
                               batch_x=torch.arange(num_step).to(pos_s.device), batch_y=batch_s, max_num_neighbors=300)
    edge_index_seed2a = edge_index_seed2a[:, mask_s[edge_index_seed2a[0]] & mask_seed[edge_index_seed2a[1]]]

    # Remap the seed endpoint to its global step-major node index
    # (must remain a unilateral connection).
    edge_index_seed2a[1, :] = (edge_index_seed2a[1, :] + 1) * (num_agent + 1) - 1

    # Bilateral agent-to-agent edges, restricted to mask-valid nodes.
    edge_index_a2a = radius_graph(x=pos_s[:, :2], r=self.a2a_radius, batch=batch_s, loop=False,
                                  max_num_neighbors=300)
    edge_index_a2a = subgraph(subset=mask_s, edge_index=edge_index_a2a)[0]

    # Merge in the seed-directed edges.
    edge_index_a2a = torch.cat([edge_index_a2a, edge_index_seed2a], dim=-1)

    rel_pos_a2a = pos_s[edge_index_a2a[0]] - pos_s[edge_index_a2a[1]]
    rel_head_a2a = wrap_angle(head_s[edge_index_a2a[0]] - head_s[edge_index_a2a[1]])

    # Edges touching invalid steps get sentinel relative features.
    is_invalid = state_a == self.invalid_state
    is_invalid_s = is_invalid.transpose(0, 1).reshape(-1)
    rel_pos_a2a[is_invalid_s[edge_index_a2a[0]]] = -self.motion_gap
    rel_pos_a2a[is_invalid_s[edge_index_a2a[1]]] = self.motion_gap
    rel_head_a2a[is_invalid_s[edge_index_a2a[0]]] = -self.heading_gap
    rel_head_a2a[is_invalid_s[edge_index_a2a[1]]] = self.heading_gap

    r_a2a = torch.stack(
        [torch.norm(rel_pos_a2a[:, :2], p=2, dim=-1),
         angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_a2a[1]], nbr_vector=rel_pos_a2a[:, :2]),
         rel_head_a2a], dim=-1)
    r_a2a = self.r_a2a_emb(continuous_inputs=r_a2a, categorical_embs=None)

    return edge_index_a2a, r_a2a
553
+
554
def build_map2agent_edge(self, data, num_step, pos_a, head_a, head_vector_a, state_a, batch_s, batch_pl,
                         mask, inference_mask=None, av_index=None):
    """Build the directed map-token-to-agent graph for one window.

    NOTE(review): the `num_step` parameter is immediately shadowed by
    `pos_a.shape`; it is kept only for signature compatibility with callers.

    Returns:
        (edge_index_pl2a, r_pl2a): pl->agent edge indices (agent side in the
        step-major flattened grid, seed row included) and their embeddings.
    """
    num_agent, num_step, _ = pos_a.shape

    # Valid agent-side targets for pl->agent edges, step-major.
    mask_pl2a = mask.clone()
    if inference_mask is not None:
        mask_pl2a = mask_pl2a & inference_mask
    mask_pl2a = mask_pl2a.transpose(0, 1).reshape(-1)

    # Mirror the ego agent as an extra seed row.
    pos_a = torch.cat([pos_a, pos_a[av_index][None]], dim=0)
    state_a = torch.cat([state_a, state_a[av_index][None]], dim=0)
    head_a = torch.cat([head_a, head_a[av_index][None]], dim=0)
    head_vector_a = torch.cat([head_vector_a, head_vector_a[av_index][None]], dim=0)

    pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
    head_s = head_a.transpose(0, 1).reshape(-1)
    head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2)

    # Tile the static map tokens once per step. `repeat` (not
    # `repeat_interleave`) keeps them step-major, matching batch_pl.
    ori_pos_pl = data['pt_token']['position'][:, :self.input_dim].contiguous()
    ori_orient_pl = data['pt_token']['orientation'].contiguous()
    pos_pl = ori_pos_pl.repeat(num_step, 1)
    orient_pl = ori_orient_pl.repeat(num_step)

    # pl -> seed edges within pl2seed_radius.
    mask_seed = state_a[av_index] != self.invalid_state
    pos_seed = pos_a[av_index]
    edge_index_pl2seed = radius(x=pos_seed[:, :2], y=pos_pl[:, :2], r=self.pl2seed_radius,
                                batch_x=torch.arange(num_step).to(pos_s.device), batch_y=batch_pl, max_num_neighbors=600)
    edge_index_pl2seed = edge_index_pl2seed[:, mask_seed[edge_index_pl2seed[1]]]

    # Remap the seed endpoint to its global step-major node index.
    edge_index_pl2seed[1, :] = (edge_index_pl2seed[1, :] + 1) * (num_agent + 1) - 1

    # pl -> agent edges within pl2a_radius
    # (edge_index_pl2a[0]: pl token; edge_index_pl2a[1]: agent token);
    # drop edges pointing at motion-invalid agents.
    edge_index_pl2a = radius(x=pos_s[:, :2], y=pos_pl[:, :2], r=self.pl2a_radius,
                             batch_x=batch_s, batch_y=batch_pl, max_num_neighbors=300)
    edge_index_pl2a = edge_index_pl2a[:, mask_pl2a[edge_index_pl2a[1]]]

    # Merge in the seed-directed map edges.
    edge_index_pl2a = torch.cat([edge_index_pl2a, edge_index_pl2seed], dim=-1)

    rel_pos_pl2a = pos_pl[edge_index_pl2a[0]] - pos_s[edge_index_pl2a[1]]
    rel_orient_pl2a = wrap_angle(orient_pl[edge_index_pl2a[0]] - head_s[edge_index_pl2a[1]])

    # Edges onto invalid agent steps get sentinel relative features.
    is_invalid = state_a == self.invalid_state
    is_invalid_s = is_invalid.transpose(0, 1).reshape(-1)
    rel_pos_pl2a[is_invalid_s[edge_index_pl2a[1]]] = self.motion_gap
    rel_orient_pl2a[is_invalid_s[edge_index_pl2a[1]]] = self.heading_gap

    r_pl2a = torch.stack(
        [torch.norm(rel_pos_pl2a[:, :2], p=2, dim=-1),
         angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_pl2a[1]], nbr_vector=rel_pos_pl2a[:, :2]),
         rel_orient_pl2a], dim=-1)
    r_pl2a = self.r_pt2a_emb(continuous_inputs=r_pl2a, categorical_embs=None)

    return edge_index_pl2a, r_pl2a
629
+
630
def get_inputs(self, data: 'HeteroData') -> Dict[str, torch.Tensor]:
    """Assemble next-step ground-truth targets and evaluation masks.

    Args:
        data: heterogeneous graph batch; only `data['agent']` fields are read
            (`token_pos`, `token_heading`, `category`, `token_idx`,
            `state_idx`, `raw_agent_valid_mask`).

    Returns:
        Dict with positions/headings, next-token (and, when
        ``self.predict_state`` is set, next-state) targets obtained by
        rolling each sequence left by one step, plus the supervision mask.

    Raises:
        RuntimeError: if a negative (invalid) token index would be supervised.
    """
    pos_a = data['agent']['token_pos']
    head_a = data['agent']['token_heading']
    agent_category = data['agent']['category']
    agent_token_index = data['agent']['token_idx']
    agent_state_index = data['agent']['state_idx']
    mask = data['agent']['raw_agent_valid_mask'].clone()

    # Next-step targets: shift every sequence left by one step.
    next_token_index_gt = agent_token_index.roll(shifts=-1, dims=1)

    if self.predict_state:
        next_state_index_gt = agent_state_index.roll(shifts=-1, dims=1)

        # A step is evaluable when it and its predecessor are both valid;
        # bos steps (state == 2) are forced on, and everything from an
        # unreachable eos step (state == 3) onward is dropped.
        # NOTE(review): state literals 2/3 presumably match
        # enter_state/exit_state — confirm against the class constants.
        next_token_eval_mask = mask.clone()
        next_token_eval_mask = next_token_eval_mask * next_token_eval_mask.roll(shifts=1, dims=1)
        bos_token_index = torch.nonzero(agent_state_index == 2)
        eos_token_index = torch.nonzero(agent_state_index == 3)
        next_token_eval_mask[bos_token_index[:, 0], bos_token_index[:, 1]] = 1
        for eos_token_index_ in eos_token_index:
            if not next_token_eval_mask[eos_token_index_[0], eos_token_index_[1]]:
                next_token_eval_mask[eos_token_index_[0], eos_token_index_[1]:] = 0
        next_token_eval_mask = next_token_eval_mask.roll(shifts=-1, dims=1)
        # TODO: next_state_eval_mask !!!
    else:
        # BUGFIX: the original set agent_state_index to None and then
        # unconditionally called .roll() on it (AttributeError), and left
        # next_token_eval_mask undefined (NameError). Provide a sane
        # fallback: no state targets, and evaluate a step when it and its
        # successor are both valid.
        next_state_index_gt = None
        next_token_eval_mask = mask & mask.roll(shifts=-1, dims=1)

    # No invalid (negative) motion token may be supervised.
    if next_token_eval_mask.any() and next_token_index_gt[next_token_eval_mask].min() < 0:
        raise RuntimeError('negative token index selected for supervision')

    # The last step only serves as input for the next window.
    next_token_eval_mask[:, -1] = False

    return {'token_pos': pos_a,
            'token_heading': head_a,
            'agent_category': agent_category,
            'next_token_idx_gt': next_token_index_gt,
            'next_state_idx_gt': next_state_index_gt,
            'next_token_eval_mask': next_token_eval_mask,
            'raw_agent_valid_mask': data['agent']['raw_agent_valid_mask'],
            }
671
+
672
def forward(self,
            data: HeteroData,
            map_enc: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Training-time forward pass of the agent decoder.

    Embeds agent token sequences (plus one appended "seed" agent mirroring
    the ego), runs the factorized temporal / map / interaction attention
    stack, and produces next-token, next-state and seed-agent predictions
    together with their ground-truth targets and supervision masks.
    """
    pos_a = data['agent']['token_pos'].clone()        # (num_agent, num_shifted_step, 2)
    head_a = data['agent']['token_heading'].clone()   # (num_agent, num_shifted_step)
    num_agent, num_step, traj_dim = pos_a.shape       # e.g. (50, 18, 2)
    agent_category = data['agent']['category'].clone()      # (num_agent,)
    agent_token_index = data['agent']['token_idx'].clone()  # (num_agent, num_step)
    agent_state_index = data['agent']['state_idx'].clone()  # (num_agent, num_step)
    agent_type_index = data['agent']['type'].clone()        # (num_agent, num_step)
    agent_enter_pl_token_idx = None
    agent_enter_offset_token_idx = None

    device = pos_a.device

    # The seed head can emit at most `seed_size` newly entering agents per step.
    seed_step_mask = agent_state_index[:, 1:] == self.enter_state
    if torch.any(seed_step_mask.sum(dim=0) > self.seed_size):
        print(agent_state_index)
        print(agent_state_index.shape)
        print(seed_step_mask.long())
        print(seed_step_mask.sum(dim=0))
        raise RuntimeError(f"Seed size {self.seed_size} is too small.")

    av_index = int(data['agent']['av_index'])

    if not self.predict_state:
        # NOTE(review): later code rolls/compares agent_state_index
        # unconditionally, so this configuration looks unsupported here.
        agent_state_index = None

    feat_a, head_vector_a = self.agent_token_embedding(data, agent_token_index, agent_state_index, pos_a, head_a, av_index=av_index)

    # --- attention masks --------------------------------------------------
    mask = data['agent']['raw_agent_valid_mask'].clone()
    temporal_mask = mask.clone()
    interact_mask = mask.clone()
    if self.predict_state:

        agent_enter_offset_token_idx = data['agent']['neighbor_token_idx']
        agent_enter_pl_token_idx = data['agent']['map_bos_token_idx']
        agent_enter_pl_token_id = data['agent']['map_bos_token_id']

        is_bos = agent_state_index == self.enter_state
        is_eos = agent_state_index == self.exit_state
        bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0))
        eos_index = torch.where(is_eos.any(dim=1), torch.argmax(is_eos.long(), dim=1), torch.tensor(num_step - 1))  # not `-1`

        # Within an agent's lifespan keep the raw validity; elsewhere allow attention.
        temporal_mask = torch.ones_like(mask)
        motion_mask = torch.arange(mask.shape[1]).expand(mask.shape[0], -1).to(device)
        motion_mask = (motion_mask > bos_index[:, None]) & (motion_mask <= eos_index[:, None])
        temporal_mask[motion_mask] = mask[motion_mask]

        interact_mask[agent_state_index == self.enter_state] = True
    interact_mask = torch.cat([interact_mask, torch.ones_like(interact_mask[:1])]).bool()  # seed placeholder row

    edge_index_t, r_t = self.build_temporal_edge(pos_a, head_a, head_vector_a, agent_state_index, temporal_mask,
                                                 av_index=av_index)

    # Step-major batch vectors; +1 is the placeholder for the seed agent.
    batch_s = torch.arange(num_step, device=device).repeat_interleave(data['agent']['num_nodes'] + 1)
    batch_pl = torch.arange(num_step, device=device).repeat_interleave(data['pt_token']['num_nodes'])

    edge_index_a2a, r_a2a = self.build_interaction_edge(pos_a, head_a, head_vector_a, agent_state_index, batch_s,
                                                        interact_mask, av_index=av_index)

    # Map attention is restricted to category-3 agents (seed row forced to 3).
    agent_category = torch.cat([agent_category, torch.full(agent_category[-1:].shape, 3, device=device)])
    interact_mask[agent_category != 3] = False
    edge_index_pl2a, r_pl2a = self.build_map2agent_edge(data, num_step, pos_a, head_a, head_vector_a,
                                                        agent_state_index, batch_s, batch_pl, interact_mask, av_index=av_index)

    # --- factorized attention stack ---------------------------------------
    for i in range(self.num_layers):

        feat_a = feat_a.reshape(-1, self.hidden_dim)  # (num_agent, num_step, hidden_dim) -> (seq_len, hidden_dim)
        feat_a = self.t_attn_layers[i](feat_a, r_t, edge_index_t)

        feat_a = feat_a.reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim)
        feat_a = self.pt2a_attn_layers[i]((
            map_enc['x_pt'].repeat_interleave(repeats=num_step, dim=0).reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(
                -1, self.hidden_dim), feat_a), r_pl2a, edge_index_pl2a)

        feat_a = self.a2a_attn_layers[i](feat_a, r_a2a, edge_index_a2a)
        feat_a = feat_a.reshape(num_step, -1, self.hidden_dim).transpose(0, 1)

    # --- next motion token -------------------------------------------------
    next_token_prob = self.token_predict_head(feat_a[:-1])  # (num_agent, num_step, token_size)
    next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1)
    _, next_token_idx = torch.topk(next_token_prob_softmax, k=10, dim=-1)  # (num_agent, num_step, 10)

    next_token_index_gt = agent_token_index.roll(shifts=-1, dims=1)

    # --- next state token --------------------------------------------------
    next_state_prob = self.state_predict_head(feat_a[:-1])
    next_state_idx = next_state_prob.softmax(dim=-1).argmax(dim=-1, keepdim=True)  # (num_agent, num_step, 1)

    next_state_index_gt = agent_state_index.roll(shifts=-1, dims=1)  # (invalid, valid, exit)

    # --- seed agent heads (extra last row) ---------------------------------
    feat_seed = self.seed_head(feat_a[-1:]) + self.seed_feature.weight[:, None]
    next_state_prob_seed = self.seed_state_predict_head(feat_seed)
    next_state_idx_seed = next_state_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True)  # (self.seed_size, num_step, 1)

    next_type_prob_seed = self.seed_type_predict_head(feat_seed)
    next_type_idx_seed = next_type_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True)

    next_type_index_gt = agent_type_index[:, None].expand(-1, num_step).roll(shifts=-1, dims=1)

    # --- supervision masks -------------------------------------------------
    bos_token_index = torch.nonzero(agent_state_index == self.enter_state)
    eos_token_index = torch.nonzero(agent_state_index == self.exit_state)

    # Motion tokens: a step counts when it and both neighbors are valid;
    # bos steps are forced on, eos steps are excluded.
    next_token_eval_mask = mask.clone()
    next_token_eval_mask = next_token_eval_mask * next_token_eval_mask.roll(shifts=-1, dims=1) * next_token_eval_mask.roll(shifts=1, dims=1)
    for bos_idx in bos_token_index:
        next_token_eval_mask[bos_idx[0], bos_idx[1] : bos_idx[1] + 1] = 1
        next_token_eval_mask[bos_idx[0], bos_idx[1] + 1 : bos_idx[1] + 2] = \
            mask[bos_idx[0], bos_idx[1] + 2 : bos_idx[1] + 3]
    next_token_eval_mask[eos_token_index[:, 0], eos_token_index[:, 1]] = 0

    # State tokens: additionally supervise everything after eos and nothing before bos.
    next_state_eval_mask = mask.clone()
    next_state_eval_mask = next_state_eval_mask * next_state_eval_mask.roll(shifts=-1, dims=1) * next_state_eval_mask.roll(shifts=1, dims=1)
    for bos_idx in bos_token_index:
        next_state_eval_mask[bos_idx[0], :bos_idx[1]] = 0
        next_state_eval_mask[bos_idx[0], bos_idx[1] : bos_idx[1] + 1] = 1
        next_state_eval_mask[bos_idx[0], bos_idx[1] + 1 : bos_idx[1] + 2] = \
            mask[bos_idx[0], bos_idx[1] + 2 : bos_idx[1] + 3]
    for eos_idx in eos_token_index:
        next_state_eval_mask[eos_idx[0], eos_idx[1] + 1:] = 1
        next_state_eval_mask[eos_idx[0], eos_idx[1] : eos_idx[1] + 1] = \
            mask[eos_idx[0], eos_idx[1] - 1 : eos_idx[1]]

    # --- seed supervision targets ------------------------------------------
    next_bos_token_index = torch.nonzero(next_state_index_gt == self.enter_state)
    next_bos_token_index = next_bos_token_index[next_bos_token_index[:, 1] < num_step - 1]

    next_state_index_gt_seed = torch.full((self.seed_size, num_step), self.seed_state_type.index('invalid'), device=next_state_index_gt.device)
    next_type_index_gt_seed = torch.full((self.seed_size, num_step), self.seed_agent_type.index('seed'), device=next_state_index_gt.device)
    next_eval_mask_seed = torch.ones_like(next_state_index_gt_seed)

    # Assign each entering agent to the next free seed slot of its step.
    num_seed = torch.zeros(num_step, device=next_state_index_gt.device).long()
    for nb_idx in next_bos_token_index:
        if num_seed[nb_idx[1]] < self.seed_size:
            next_state_index_gt_seed[num_seed[nb_idx[1]], nb_idx[1]] = self.seed_state_type.index('enter')
            next_type_index_gt_seed[num_seed[nb_idx[1]], nb_idx[1]] = next_type_index_gt[nb_idx[0], nb_idx[1]]
            num_seed[nb_idx[1]] += 1

    # The last timestep is the beginning of the next sequence (input only).
    next_token_eval_mask[:, -1] = 0
    next_state_eval_mask[:, -1] = 0
    next_eval_mask_seed[:, -1] = 0

    # No invalid motion token may be supervised.
    if (next_token_index_gt[next_token_eval_mask] < 0).any():
        raise RuntimeError()

    next_state_index_gt[next_state_index_gt == self.exit_state] = self.valid_state_type.index('exit')

    return {'x_a': feat_a,
            # motion token
            'next_token_idx': next_token_idx,
            'next_token_prob': next_token_prob,
            'next_token_idx_gt': next_token_index_gt,
            'next_token_eval_mask': next_token_eval_mask.bool(),
            # state token
            'next_state_idx': next_state_idx,
            'next_state_prob': next_state_prob,
            'next_state_idx_gt': next_state_index_gt,
            'next_state_eval_mask': next_state_eval_mask.bool(),
            # seed agent
            'next_state_idx_seed': next_state_idx_seed,
            'next_state_prob_seed': next_state_prob_seed,
            'next_state_idx_gt_seed': next_state_index_gt_seed,
            'next_type_idx_seed': next_type_idx_seed,
            'next_type_prob_seed': next_type_prob_seed,
            'next_type_idx_gt_seed': next_type_index_gt_seed,
            'next_eval_mask_seed': next_eval_mask_seed.bool(),
            }
890
+
891
    def inference(self,
                  data: HeteroData,
                  map_enc: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Autoregressively roll out agent motion/state tokens over the future horizon.

        Starting from the last historical token step, the method repeatedly runs the
        temporal / map-to-agent / agent-to-agent attention stack, samples the next
        motion token (top-k + multinomial) and the next state token per agent, and
        decodes tokens back to positions/headings. A special "seed" agent slot is
        appended to the agent axis each step; its state/type heads decide whether new
        agents enter the scene, in which case ground-truth enter poses from ``data``
        are spliced in and all per-agent buffers are grown.

        Args:
            data: heterogeneous scene graph with ``data['agent']`` token/state
                attributes and ``data['pt_token']`` map tokens.
            map_enc: encoder output; ``map_enc['x_pt']`` holds map token features.

        Returns:
            Dict with predicted trajectories/headings/states, sampled token indices,
            their ground-truth counterparts, and evaluation masks (see return dict).
        """
        # Keep only agents that are valid (or exiting) at the last historical token step;
        # agents that enter later are re-inserted through the seed mechanism below.
        start_state_idx = data['agent']['state_idx'][:, (self.num_historical_steps - 1) // self.shift]
        filter_mask = (start_state_idx == self.valid_state) | (start_state_idx == self.exit_state)
        seed_step_mask = data['agent']['state_idx'][:, (self.num_historical_steps - 1) // self.shift:] == self.enter_state
        # Per future token step: indices of agents whose ground truth enters at that step.
        seed_agent_index_per_step = [torch.nonzero(seed_step_mask[:, t]).squeeze(dim=-1) for t in range(seed_step_mask.shape[1])]
        if torch.any(seed_step_mask.sum(dim=0) > self.seed_size):
            raise RuntimeError(f"Seed size {self.seed_size} is too small.")

        # num_historical_steps=11
        eval_mask = data['agent']['valid_mask'][filter_mask, self.num_historical_steps - 1]

        if self.predict_state:
            # When states are predicted, every filtered agent takes part in evaluation.
            eval_mask = torch.ones_like(eval_mask).bool()

        # agent attributes
        pos_a = data['agent']['token_pos'][filter_mask].clone()  # (num_agent, num_step, 2)
        state_a = data['agent']['state_idx'][filter_mask].clone()  # (num_agent, num_step)
        head_a = data['agent']['token_heading'][filter_mask].clone()  # (num_agent, num_step)
        gt_traj = data['agent']['position'][filter_mask, self.num_historical_steps:, :self.input_dim].contiguous()
        num_agent, num_step, traj_dim = pos_a.shape

        # Re-map the AV index into the filtered agent ordering.
        av_index = int(data['agent']['av_index'])
        av_index -= (~filter_mask[:av_index]).sum()

        # map attributes
        pos_pl = data['pt_token']['position'][:, :2].clone()  # (num_pl, 2)

        # make future steps to zero
        pos_a[:, (self.num_historical_steps - 1) // self.shift:] = 0
        state_a[:, (self.num_historical_steps - 1) // self.shift:] = 0
        head_a[:, (self.num_historical_steps - 1) // self.shift:] = 0

        agent_valid_mask = data['agent']['raw_agent_valid_mask'][filter_mask].clone()  # token_valid_mask
        agent_valid_mask[:, (self.num_historical_steps - 1) // self.shift:] = True
        agent_valid_mask[~eval_mask] = False
        agent_token_index = data['agent']['token_idx'][filter_mask]
        agent_state_index = data['agent']['state_idx'][filter_mask]
        agent_type = data['agent']['type'][filter_mask]
        agent_category = data['agent']['category'][filter_mask]

        # Token/state/categorical embeddings; the returned feat_a carries one extra
        # trailing row for the seed agent, split off right below.
        feat_a, head_vector_a, agent_token_traj_all, agent_token_emb, categorical_embs = self.agent_token_embedding(data,
                                                                                                                    agent_token_index,
                                                                                                                    agent_state_index,
                                                                                                                    pos_a,
                                                                                                                    head_a,
                                                                                                                    inference=True,
                                                                                                                    filter_mask=filter_mask,
                                                                                                                    av_index=av_index,
                                                                                                                    )
        feat_seed = feat_a[-1:]
        feat_a = feat_a[:-1]

        agent_type = data["agent"]["type"][filter_mask]
        veh_mask = agent_type == 0
        cyc_mask = agent_type == 2
        ped_mask = agent_type == 1

        # self.num_recurrent_steps_val = 91 - 11 = 80
        self.num_recurrent_steps_val = data["agent"]['position'].shape[1] - self.num_historical_steps
        pred_traj = torch.zeros(pos_a.shape[0], self.num_recurrent_steps_val, 2, device=feat_a.device)  # (num_agent, 80, 2)
        pred_head = torch.zeros(pos_a.shape[0], self.num_recurrent_steps_val, device=feat_a.device)
        pred_type = agent_type.clone()
        pred_state = torch.zeros(pos_a.shape[0], self.num_recurrent_steps_val, device=feat_a.device)
        pred_prob = torch.zeros(pos_a.shape[0], self.num_recurrent_steps_val // self.shift, device=feat_a.device)  # (num_agent, 80 // 5 = 16)
        next_token_idx_list = []
        next_state_idx_list = []
        next_bos_pl_idx_list = []
        next_bos_offset_idx_list = []
        # Per-layer feature caches so earlier steps are not recomputed on later rollouts.
        feat_a_t_dict = {}
        feat_sa_t_dict = {}

        # build masks (init)
        mask = agent_valid_mask.clone()
        temporal_mask = mask.clone()
        interact_mask = mask.clone()
        if self.predict_state:

            # find bos and eos index
            is_bos = agent_state_index == self.enter_state
            is_eos = agent_state_index == self.exit_state
            bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0))
            eos_index = torch.where(is_eos.any(dim=1), torch.argmax(is_eos.long(), dim=1), torch.tensor(num_step - 1))

            # Temporal attention only trusts validity inside each agent's (bos, eos] span.
            temporal_mask = torch.ones_like(mask)
            motion_mask = torch.arange(mask.shape[1]).expand(mask.shape[0], mask.shape[1]).to(mask.device)
            motion_mask = (motion_mask > bos_index[:, None]) & (motion_mask <= eos_index[:, None])
            motion_mask[:, self.num_historical_steps // self.shift:] = False
            temporal_mask[motion_mask] = mask[motion_mask]

            # Interaction edges exclude steps outside the motion span, except enter steps.
            interact_mask = torch.ones_like(mask)
            non_motion_mask = ~motion_mask
            non_motion_mask[:, self.num_historical_steps // self.shift:] = False
            interact_mask[non_motion_mask] = False
            interact_mask[agent_state_index == self.enter_state] = True

            temporal_mask[:, (self.num_historical_steps - 1) // self.shift:] = True
            interact_mask[:, (self.num_historical_steps - 1) // self.shift:] = True

        # mapping network
        # z = torch.randn(num_agent, self.hidden_dim).to(feat_a.device)
        # w = self.mapping_network(z)

        # we only need to predict 16 next tokens
        for t in range(self.num_recurrent_steps_val // self.shift):

            # feat_a = feat_a + w[:, None]
            num_agent = pos_a.shape[0]

            if t == 0:
                # First rollout step: attend over all historical steps at once.
                inference_mask = temporal_mask.clone()
                inference_mask = torch.cat([inference_mask, torch.ones_like(inference_mask[-1:])])
                inference_mask[:, (self.num_historical_steps - 1) // self.shift + t:] = False
            else:
                # Later steps: only re-infer the sliding interaction window ending at t.
                inference_mask = torch.zeros_like(temporal_mask)
                inference_mask = torch.cat([inference_mask, torch.zeros_like(inference_mask[-1:])])
                inference_mask[:, max((self.num_historical_steps - 1) // self.shift + t - (self.num_interaction_steps // self.shift), 0) :
                               (self.num_historical_steps - 1) // self.shift + t] = True

            interact_mask = torch.cat([interact_mask, torch.ones_like(interact_mask[:1])]).bool()  # placeholder

            edge_index_t, r_t = self.build_temporal_edge(pos_a, head_a, head_vector_a, state_a, temporal_mask, inference_mask,
                                                         av_index=av_index)

            # +1: placeholder for seed agent
            batch_s = torch.arange(num_step, device=pos_a.device).repeat_interleave(num_agent + 1)
            batch_pl = torch.arange(num_step, device=pos_a.device).repeat_interleave(data['pt_token']['num_nodes'])

            # In the inference stage, we only infer the current stage for recurrent
            edge_index_a2a, r_a2a = self.build_interaction_edge(pos_a, head_a, head_vector_a, state_a, batch_s,
                                                                interact_mask, inference_mask, av_index=av_index)
            edge_index_pl2a, r_pl2a = self.build_map2agent_edge(data, num_step, pos_a, head_a, head_vector_a, state_a, batch_s, batch_pl,
                                                                interact_mask, inference_mask, av_index=av_index)
            # Drop the seed placeholder row again.
            interact_mask = interact_mask[:-1]

            # if t > 0:
            #     feat_a_sum = feat_a.sum(dim=-1)
            #     for a in range(pos_a.shape[0]):
            #         t_1 = (self.num_historical_steps - 1) // self.shift + t - 1
            #         print(f"agent {a} t_1 {t_1}")
            #         print(f"token: {next_token_idx[a]}")
            #         print(f"state: {next_state_idx[a]}")
            #         print(f"feat_a_sum: {feat_a_sum[a, t_1]}")

            for i in range(self.num_layers):

                # Reuse cached features from the previous rollout step for this layer.
                if (i in feat_a_t_dict) and (i in feat_sa_t_dict):
                    feat_a = feat_a_t_dict[i]
                    feat_seed = feat_sa_t_dict[i]

                # Seed agent rides along as the last row of the agent axis.
                feat_a = torch.cat([feat_a, feat_seed], dim=0)

                feat_a = feat_a.reshape(-1, self.hidden_dim)
                feat_a = self.t_attn_layers[i](feat_a, r_t, edge_index_t)

                # Re-layout to step-major for the map->agent and agent->agent passes.
                feat_a = feat_a.reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim)
                feat_a = self.pt2a_attn_layers[i]((
                    map_enc['x_pt'].repeat_interleave(repeats=num_step, dim=0).reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(
                        -1, self.hidden_dim), feat_a), r_pl2a, edge_index_pl2a)

                feat_a = self.a2a_attn_layers[i](feat_a, r_a2a, edge_index_a2a)
                feat_a = feat_a.reshape(num_step, -1, self.hidden_dim).transpose(0, 1)

                feat_seed = feat_a[-1:]  # (1, num_step, hidden_dim)
                feat_a = feat_a[:-1]  # (num_agent, num_step, hidden_dim)

                if t == 0:
                    feat_a_t_dict[i + 1] = feat_a
                    feat_sa_t_dict[i + 1] = feat_seed
                else:
                    # update agent features at current step
                    n = feat_a_t_dict[i + 1].shape[0]
                    feat_a_t_dict[i + 1][:n, (self.num_historical_steps - 1) // self.shift - 1 + t] = feat_a[:n, (self.num_historical_steps - 1) // self.shift - 1 + t]
                    # add newly inserted agent features (only when t changed)
                    if feat_a.shape[0] > n:
                        m = feat_a.shape[0] - n
                        feat_a_t_dict[i + 1] = torch.cat([feat_a_t_dict[i + 1], feat_a[-m:]])
                    # update seed agent features at current step
                    feat_sa_t_dict[i + 1][:, (self.num_historical_steps - 1) // self.shift - 1 + t] = feat_seed[:, (self.num_historical_steps - 1) // self.shift - 1 + t]

            # next motion token
            next_token_prob = self.token_predict_head(feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t])
            next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1)
            topk_token_prob, next_token_idx = torch.topk(next_token_prob_softmax, k=self.beam_size, dim=-1)  # both (num_agent, beam_size) e.g. (31, 5)

            # next state token
            next_state_prob = self.state_predict_head(feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t])
            next_state_idx = next_state_prob.softmax(dim=-1).argmax(dim=-1)
            # Remap head-local 'exit' class id to the global exit state id.
            next_state_idx[next_state_idx == self.valid_state_type.index('exit')] = self.exit_state

            # seed agent
            feat_seed = self.seed_head(feat_seed) + self.seed_feature.weight[:, None]
            next_state_prob_seed = self.seed_state_predict_head(feat_seed[:, (self.num_historical_steps - 1) // self.shift - 1 + t])
            next_state_idx_seed = next_state_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True)
            next_state_idx_seed[next_state_idx_seed == self.seed_state_type.index('enter')] = self.enter_state

            next_type_prob_seed = self.seed_type_predict_head(feat_seed[:, (self.num_historical_steps - 1) // self.shift - 1 + t])
            next_type_idx_seed = next_type_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True)

            # print(f"t: {t}")
            # print(next_type_idx_seed[..., 0].tolist())

            # bos pl prediction
            # next_bos_pl_prob = self.bos_pl_predict_head(feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t])
            # next_bos_pl_prob_softmax = torch.softmax(next_bos_pl_prob, dim=-1)
            # next_bos_pl_idx = torch.argmax(next_bos_pl_prob_softmax, dim=-1)

            # bos offset prediction
            # next_bos_offset_prob = self.bos_offset_predict_head(feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t])
            # next_bos_offset_prob_softmax = torch.softmax(next_bos_offset_prob, dim=-1)
            # next_bos_offset_idx = torch.argmax(next_bos_offset_prob_softmax, dim=-1)

            # convert the predicted token to a 0.5s (6 timesteps) trajectory
            expanded_token_index = next_token_idx[..., None, None, None].expand(-1, -1, 6, 4, 2)
            next_token_traj = torch.gather(agent_token_traj_all, 1, expanded_token_index)  # (num_agent, beam_size, 6, 4, 2)

            # apply rotation and translation on 'next_token_traj'
            theta = head_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t]
            cos, sin = theta.cos(), theta.sin()
            rot_mat = torch.zeros((num_agent, 2, 2), device=theta.device)
            rot_mat[:, 0, 0] = cos
            rot_mat[:, 0, 1] = sin
            rot_mat[:, 1, 0] = -sin
            rot_mat[:, 1, 1] = cos
            agent_diff_rel = torch.bmm(next_token_traj.view(-1, 4, 2),
                                       rot_mat[:, None, None, ...].repeat(1, self.beam_size, self.shift + 1, 1, 1).view(
                                           -1, 2, 2)).view(num_agent, self.beam_size, self.shift + 1, 4, 2)
            agent_pred_rel = agent_diff_rel + pos_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t, :][:, None, None, None, ...]

            # sample 1 most probable index of top beam_size tokens, (num_agent, beam_size) -> (num_agent, 1)
            # then sample the agent_pred_rel, (num_agent, beam_size, 6, 4, 2) -> (num_agent, 6, 4, 2)
            sample_token_index = torch.multinomial(topk_token_prob, 1).to(agent_pred_rel.device)
            next_token_idx = next_token_idx.gather(dim=1, index=sample_token_index).squeeze(-1)
            agent_pred_rel = agent_pred_rel.gather(dim=1,
                                                   index=sample_token_index[..., None, None, None].expand(-1, -1, 6, 4,
                                                                                                          2))[:, 0, ...]

            # get predicted position and heading of current shifted timesteps
            # Position = mean of the 4 box corners; heading from the front-rear corner diff.
            diff_xy = agent_pred_rel[:, 1:, 0, :] - agent_pred_rel[:, 1:, 3, :]
            pred_traj[:num_agent, t * 5 : (t + 1) * 5] = agent_pred_rel[:, 1:, ...].clone().mean(dim=2)
            pred_head[:num_agent, t * 5 : (t + 1) * 5] = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0])
            pred_state[:num_agent, t * 5 : (t + 1) * 5] = next_state_idx[:, None].repeat(1, 5)
            # pred_prob[:num_agent, t] = topk_token_prob.gather(dim=-1, index=sample_token_index)[:, 0]  # (num_agent, beam_size) -> (num_agent,)

            # update pos/head/state of current step
            pos_a[:, (self.num_historical_steps - 1) // self.shift + t] = agent_pred_rel[:, -1, ...].clone().mean(dim=1)
            diff_xy = agent_pred_rel[:, -1, 0, :] - agent_pred_rel[:, -1, 3, :]
            theta = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])
            head_a[:, (self.num_historical_steps - 1) // self.shift + t] = theta
            state_a[:, (self.num_historical_steps - 1) // self.shift + t] = next_state_idx

            # the case that the current predicted state token is invalid/exit
            is_eos = next_state_idx == self.exit_state
            is_invalid = next_state_idx == self.invalid_state

            # Invalidated agents get neutral pose/embeddings so they stop influencing others.
            next_token_idx[is_invalid] = -1
            pos_a[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = 0.
            head_a[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = 0.

            mask[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = False  # to handle those newly-added agents
            interact_mask[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = False

            agent_token_emb[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = self.no_token_emb(torch.zeros(1, device=pos_a.device).long())

            type_emb = categorical_embs[0].reshape(num_agent, num_step, -1)
            shape_emb = categorical_embs[1].reshape(num_agent, num_step, -1)
            type_emb[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = self.type_a_emb(torch.tensor(self.all_agent_type.index('invalid'), device=pos_a.device).long())
            shape_emb[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = self.shape_emb(torch.full((1, 3), self.invalid_shape_value, device=pos_a.device))
            categorical_embs = [type_emb.reshape(-1, self.hidden_dim), shape_emb.reshape(-1, self.hidden_dim)]

            # FIXME: need to discuss!!!
            # if is_eos.any():

            #     pos_a[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = 0.
            #     head_a[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = 0.
            #     mask[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = False  # to handle those newly-added agents
            #     interact_mask[torch.cat([is_eos, torch.zeros(1, device=is_eos.device).bool()]), (self.num_historical_steps - 1) // self.shift + t + 1:] = False

            #     agent_token_emb[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = self.no_token_emb(torch.zeros(1, device=pos_a.device).long())

            #     type_emb = categorical_embs[0].reshape(num_agent, num_step, -1)
            #     shape_emb = categorical_embs[1].reshape(num_agent, num_step, -1)
            #     type_emb[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = self.type_a_emb(torch.tensor(self.all_agent_type.index('invalid'), device=pos_a.device).long())
            #     shape_emb[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = self.shape_emb(torch.full((1, 3), self.invalid_shape_value, device=pos_a.device))
            #     categorical_embs = [type_emb.reshape(-1, self.hidden_dim), shape_emb.reshape(-1, self.hidden_dim)]

            # for sa in range(next_state_idx_seed.shape[0]):
            #     if next_state_idx_seed[sa] == self.enter_state:
            #         print(f"agent {sa} is entering at step {t}")

            # insert new agents (from seed agent)
            # NOTE(review): entry slots are taken from ground-truth enter steps
            # (seed_agent_index_per_step), capped by how many entries the seed head
            # predicted — confirm this teacher-forced pairing is intended at eval time.
            seed_agent_index_cur_step = seed_agent_index_per_step[t]
            num_new_agent = min(len(seed_agent_index_cur_step), next_state_idx_seed.bool().sum())
            new_agent_mask = next_state_idx_seed.bool()
            next_state_idx_seed = next_state_idx_seed[new_agent_mask]
            next_state_idx_seed = next_state_idx_seed[:num_new_agent]
            next_type_idx_seed = next_type_idx_seed[new_agent_mask]
            next_type_idx_seed = next_type_idx_seed[:num_new_agent]
            selected_agent_index_cur_step = seed_agent_index_cur_step[:num_new_agent]
            agent_token_index = torch.cat([agent_token_index, data['agent']['token_idx'][selected_agent_index_cur_step]])
            agent_state_index = torch.cat([agent_state_index, data['agent']['state_idx'][selected_agent_index_cur_step]])
            agent_category = torch.cat([agent_category, data['agent']['category'][selected_agent_index_cur_step]])
            agent_valid_mask = torch.cat([agent_valid_mask, data['agent']['raw_agent_valid_mask'][selected_agent_index_cur_step]])
            gt_traj = torch.cat([gt_traj, data['agent']['position'][selected_agent_index_cur_step, self.num_historical_steps:, :self.input_dim]])

            # FIXME: under test!!! bos token index is -2
            next_state_idx = torch.cat([next_state_idx, next_state_idx_seed], dim=0).long()
            next_token_idx = torch.cat([next_token_idx, torch.zeros(num_new_agent, device=next_token_idx.device) - 2], dim=0).long()
            mask = torch.cat([mask, torch.ones(num_new_agent, num_step, device=mask.device)], dim=0).bool()
            temporal_mask = torch.cat([temporal_mask, torch.ones(num_new_agent, num_step, device=temporal_mask.device)], dim=0).bool()
            interact_mask = torch.cat([interact_mask, torch.ones(num_new_agent, num_step, device=interact_mask.device)], dim=0).bool()

            # new_pos_a = ego_pos_a[None].repeat(num_new_agent, 1, 1)
            # new_head_a = ego_head_a[None].repeat(num_new_agent, 1)
            new_pos_a = torch.zeros(num_new_agent, num_step, 2, device=pos_a.device)
            new_head_a = torch.zeros(num_new_agent, num_step, device=pos_a.device)
            new_state_a = torch.zeros(num_new_agent, num_step, device=state_a.device)
            new_shape_a = torch.full((num_new_agent, num_step, 3), self.invalid_shape_value, device=pos_a.device)
            new_type_a = torch.full((num_new_agent, num_step), self.all_agent_type.index('invalid'), device=pos_a.device)

            if num_new_agent > 0:
                # New agents are initialized at their ground-truth enter pose.
                gt_bos_pos_a = data['agent']['position'][seed_agent_index_cur_step[:num_new_agent], (self.num_historical_steps - 1) // self.shift + t]
                new_pos_a[:, (self.num_historical_steps - 1) // self.shift + t] = gt_bos_pos_a[:, :2].clone()
                pos_a = torch.cat([pos_a, new_pos_a], dim=0)

                gt_bos_head_a = data['agent']['heading'][seed_agent_index_cur_step[:num_new_agent], (self.num_historical_steps - 1) // self.shift + t]
                new_head_a[:, (self.num_historical_steps - 1) // self.shift + t] = gt_bos_head_a.clone()
                head_a = torch.cat([head_a, new_head_a], dim=0)

                gt_bos_shape_a = data['agent']['shape'][seed_agent_index_cur_step[:num_new_agent], self.num_historical_steps - 1]
                gt_bos_type_a = data['agent']['type'][seed_agent_index_cur_step[:num_new_agent]]
                new_shape_a[:, (self.num_historical_steps - 1) // self.shift + t:] = gt_bos_shape_a.clone()[:, None]
                new_type_a[:, (self.num_historical_steps - 1) // self.shift + t:] = gt_bos_type_a.clone()[:, None]
                # new_type_a[:, (self.num_historical_steps - 1) // self.shift + t] = next_type_idx_seed
                pred_type = torch.cat([pred_type, new_type_a[:, (self.num_historical_steps - 1) // self.shift + t]])

                new_state_a[:, (self.num_historical_steps - 1) // self.shift + t] = self.enter_state
                state_a = torch.cat([state_a, new_state_a], dim=0)

                # Steps before (and including) the enter step are masked out for new agents.
                mask[-num_new_agent:, :(self.num_historical_steps - 1) // self.shift + t + 1] = 0
                interact_mask[-num_new_agent:, :(self.num_historical_steps - 1) // self.shift + t] = 0

                # update all steps
                new_pred_traj = torch.zeros(num_new_agent, self.num_recurrent_steps_val, 2, device=pos_a.device)
                new_pred_traj[:, t * 5 : (t + 1) * 5] = new_pos_a[:, (self.num_historical_steps - 1) // self.shift + t][:, None].repeat(1, 5, 1)
                pred_traj = torch.cat([pred_traj, new_pred_traj], dim=0)

                new_pred_head = torch.zeros(num_new_agent, self.num_recurrent_steps_val, device=pos_a.device)
                new_pred_head[:, t * 5 : (t + 1) * 5] = new_head_a[:, (self.num_historical_steps - 1) // self.shift + t][:, None].repeat(1, 5)
                pred_head = torch.cat([pred_head, new_pred_head], dim=0)

                new_pred_state = torch.zeros(num_new_agent, self.num_recurrent_steps_val, device=pos_a.device)
                new_pred_state[:, t * 5 : (t + 1) * 5] = next_state_idx_seed[:, None].repeat(1, 5)
                pred_state = torch.cat([pred_state, new_pred_state], dim=0)

            # handle the position/heading of bos token
            # bos_pl_pos = pos_pl[next_bos_pl_idx[is_bos].long()]
            # bos_offset_pos = discretize_neighboring(neighbor_index=next_bos_offset_idx[is_bos])
            # pos_a[is_bos, (self.num_historical_steps - 1) // self.shift + t] += (bos_pl_pos + bos_offset_pos)
            # # headings before bos token remains 0 which align with training process
            # head_a[is_bos, (self.num_historical_steps - 1) // self.shift + t] += 0.

            # add new agents token embeddings
            agent_token_emb = torch.cat([agent_token_emb, self.no_token_emb(torch.zeros(1, device=pos_a.device).long())[None, :].repeat(num_new_agent, num_step, 1)])
            veh_mask = torch.cat([veh_mask, next_type_idx_seed == self.seed_agent_type.index('veh')])
            ped_mask = torch.cat([ped_mask, next_type_idx_seed == self.seed_agent_type.index('ped')])
            cyc_mask = torch.cat([cyc_mask, next_type_idx_seed == self.seed_agent_type.index('cyc')])

            # add new agents trajectory embeddings
            trajectory_token_veh = torch.from_numpy(self.trajectory_token['veh']).clone().to(pos_a.device).to(torch.float)
            trajectory_token_ped = torch.from_numpy(self.trajectory_token['ped']).clone().to(pos_a.device).to(torch.float)
            trajectory_token_cyc = torch.from_numpy(self.trajectory_token['cyc']).clone().to(pos_a.device).to(torch.float)

            new_agent_token_traj_all = torch.zeros((num_new_agent, self.token_size, self.shift + 1, 4, 2), device=pos_a.device)
            trajectory_token_all_veh = torch.from_numpy(self.trajectory_token_all['veh']).clone().to(pos_a.device).to(torch.float)
            trajectory_token_all_ped = torch.from_numpy(self.trajectory_token_all['ped']).clone().to(pos_a.device).to(torch.float)
            trajectory_token_all_cyc = torch.from_numpy(self.trajectory_token_all['cyc']).clone().to(pos_a.device).to(torch.float)
            new_agent_token_traj_all[next_type_idx_seed == 0] = torch.cat(
                [trajectory_token_all_veh[:, :self.shift], trajectory_token_veh[:, None, ...]], dim=1)
            new_agent_token_traj_all[next_type_idx_seed == 1] = torch.cat(
                [trajectory_token_all_ped[:, :self.shift], trajectory_token_ped[:, None, ...]], dim=1)
            new_agent_token_traj_all[next_type_idx_seed == 2] = torch.cat(
                [trajectory_token_all_cyc[:, :self.shift], trajectory_token_cyc[:, None, ...]], dim=1)

            agent_token_traj_all = torch.cat([agent_token_traj_all, new_agent_token_traj_all], dim=0)

            # add new agents categorical embeddings
            new_categorical_embs = [self.type_a_emb(new_type_a.reshape(-1).long()), self.shape_emb(new_shape_a.reshape(-1, 3))]
            categorical_embs = [torch.cat([categorical_embs[0], new_categorical_embs[0]], dim=0),
                                torch.cat([categorical_embs[1], new_categorical_embs[1]], dim=0)]

            # update token embeddings of current step
            agent_token_emb[veh_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_veh[
                next_token_idx[veh_mask]]
            agent_token_emb[ped_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_ped[
                next_token_idx[ped_mask]]
            agent_token_emb[cyc_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_cyc[
                next_token_idx[cyc_mask]]

            # Rebuild motion/heading vectors from the updated poses and re-fuse features.
            motion_vector_a, head_vector_a = self.build_vector_a(pos_a, head_a, state_a)

            motion_vector_a[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0.
            head_vector_a[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0.
            x_a = torch.stack(
                [torch.norm(motion_vector_a[:, :, :2], p=2, dim=-1),
                 angle_between_2d_vectors(ctr_vector=head_vector_a, nbr_vector=motion_vector_a[:, :, :2])], dim=-1)

            x_b = x_a.clone()
            x_a = self.x_a_emb(continuous_inputs=x_a.view(-1, x_a.size(-1)),
                               categorical_embs=categorical_embs)
            x_a = x_a.view(-1, num_step, self.hidden_dim)

            s_a = self.state_a_emb(state_a.reshape(-1).long()).reshape(num_agent + num_new_agent, num_step, self.hidden_dim)
            feat_a = torch.cat((agent_token_emb, x_a, s_a), dim=-1)
            feat_a = self.fusion_emb(feat_a)

            # if t >= 15:
            #     print(f"inference {t}")
            #     is_invalid = state_a == self.invalid_state
            #     is_bos = state_a == self.enter_state
            #     is_eos = state_a == self.exit_state
            #     bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0))
            #     eos_index = torch.where(is_eos.any(dim=1), torch.argmax(is_eos.long(), dim=1), torch.tensor(num_step - 1))
            #     mask = torch.arange(num_step).expand(num_agent + num_new_agent, -1).to(state_a.device)
            #     mask = (mask >= bos_index[:, None]) & (mask <= eos_index[:, None] + 1)
            #     is_invalid[mask] = False
            #     is_invalid[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = False
            #     print(pos_a[:, :((self.num_historical_steps - 1) // self.shift + 1 + t)])
            #     print(state_a[:, :((self.num_historical_steps - 1) // self.shift + 1 + t)])
            #     print(pos_a[is_invalid][:, 0])
            #     print(head_a[is_invalid])
            #     print(categorical_embs[0].sum(dim=-1)[is_invalid.reshape(-1)])
            #     print(categorical_embs[1].sum(dim=-1)[is_invalid.reshape(-1)])
            #     print(motion_vector_a[is_invalid][:, 0])
            #     print(head_vector_a[is_invalid][:, 0])
            #     print(x_b.sum(dim=-1)[is_invalid])
            #     print(x_a.sum(dim=-1)[is_invalid])
            #     for a in range(state_a.shape[0]):
            #         print(f"agent: {a}")
            #         print(state_a[a])
            #         print(is_invalid[a].long())
            #         print(pos_a[a, :, 0])
            #         print(motion_vector_a[a, :, 0])
            #     print(s_a.sum(dim=-1)[is_invalid])
            #     print(feat_a.sum(dim=-1)[is_invalid])

            # replace the features of steps before bos of valid agents with the corresponding seed agent features
            # is_bos = state_a == self.enter_state
            # bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(num_step))
            # before_bos_mask = torch.arange(num_step).expand(num_agent + num_new_agent, -1).to(state_a.device) < bos_index[:, None]
            # feat_a[before_bos_mask] = feat_seed.repeat(num_agent + num_new_agent, 1, 1)[before_bos_mask]

            # build seed agent features
            # The seed feature is rebuilt each step from the AV's motion/heading vectors.
            motion_vector_seed = motion_vector_a[av_index : av_index + 1]
            head_vector_seed = head_vector_a[av_index : av_index + 1]
            feat_seed = self.build_invalid_agent_feature(num_step, pos_a.device, type_index=self.all_agent_type.index('seed'),
                                                         motion_vector=motion_vector_seed, head_vector=head_vector_seed)
            # print(f"inference {t}")
            # print(feat_seed.sum(dim=-1))

            next_token_idx_list.append(next_token_idx[:, None])
            next_state_idx_list.append(next_state_idx[:, None])
            # next_bos_pl_idx_list.append(next_bos_pl_idx[:, None])
            # next_bos_offset_idx_list.append(next_bos_offset_idx[:, None])

        # TODO: check this
        # agent_valid_mask[agent_category != 3] = False

        # print("inference")
        # is_invalid = state_a == self.invalid_state
        # is_bos = state_a == self.enter_state
        # is_eos = state_a == self.exit_state
        # bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0))
        # eos_index = torch.where(is_eos.any(dim=1), torch.argmax(is_eos.long(), dim=1), torch.tensor(num_step - 1))
        # mask = torch.arange(num_step).expand(num_agent, -1).to(state_a.device)
        # mask = (mask >= bos_index[:, None]) & (mask <= eos_index[:, None] + 1)
        # is_invalid[mask] = False
        # print(feat_a.sum(dim=-1)[is_invalid])
        # print(pos_a[is_invalid][: 0])
        # print(head_a[is_invalid])
        # exit(1)

        # Pad per-step index lists so every step has a row for late-inserted agents
        # (-1 motion token / 0 state before they exist).
        num_agent = pos_a.shape[0]
        for i in range(len(next_token_idx_list)):
            next_token_idx_list[i] = torch.cat([next_token_idx_list[i], torch.zeros(num_agent - next_token_idx_list[i].shape[0], 1, device=next_token_idx_list[i].device) - 1], dim=0).long()
            next_state_idx_list[i] = torch.cat([next_state_idx_list[i], torch.zeros(num_agent - next_state_idx_list[i].shape[0], 1, device=next_state_idx_list[i].device)], dim=0).long()

        # eval mask
        next_token_eval_mask = agent_valid_mask.clone()
        next_state_eval_mask = agent_valid_mask.clone()
        bos_token_index = torch.nonzero(agent_state_index == self.enter_state)
        eos_token_index = torch.nonzero(agent_state_index == self.exit_state)

        next_token_eval_mask[bos_token_index[:, 0], bos_token_index[:, 1]] = 1

        # State supervision extends one step past each bos and from each eos onwards.
        for bos_token_index_i in bos_token_index:
            next_state_eval_mask[bos_token_index_i[0], :bos_token_index_i[1] + 2] = 1
        for eos_token_index_i in eos_token_index:
            next_state_eval_mask[eos_token_index_i[0], eos_token_index_i[1]:] = 1

        # add history attributes
        num_agent = pred_traj.shape[0]
        num_init_agent = filter_mask.sum()

        # NOTE(review): the history padding is concatenated at the END of dim=1,
        # yet the history values below are written into the FIRST
        # num_historical_steps-1 slots (overwriting early rollout outputs) —
        # confirm the intended time layout of pred_traj/pred_head/pred_state.
        pred_traj = torch.cat([pred_traj, torch.zeros(num_agent, self.num_historical_steps - 1, *(pred_traj.shape[2:]), device=pred_traj.device)], dim=1)
        pred_head = torch.cat([pred_head, torch.zeros(num_agent, self.num_historical_steps - 1, *(pred_head.shape[2:]), device=pred_head.device)], dim=1)
        pred_state = torch.cat([pred_state, torch.zeros(num_agent, self.num_historical_steps - 1, *(pred_state.shape[2:]), device=pred_state.device)], dim=1)

        pred_state[:num_init_agent, :self.num_historical_steps - 1] = data['agent']['state_idx'][filter_mask, :(self.num_historical_steps - 1) // self.shift].repeat_interleave(repeats=self.shift, dim=1)

        # Decode historical tokens back to trajectories for the initial agents.
        historical_token_idx = data['agent']['token_idx'][filter_mask, :(self.num_historical_steps - 1) // self.shift]
        historical_token_idx[historical_token_idx < 0] = 0
        historical_token_traj_all = torch.gather(agent_token_traj_all, 1, historical_token_idx[..., None, None, None].expand(-1, -1, 6, 4, 2))
        init_theta = head_a[:num_init_agent, 0]
        cos, sin = init_theta.cos(), init_theta.sin()
        rot_mat = torch.zeros((num_init_agent, 2, 2), device=init_theta.device)
        rot_mat[:, 0, 0] = cos
        rot_mat[:, 0, 1] = sin
        rot_mat[:, 1, 0] = -sin
        rot_mat[:, 1, 1] = cos
        historical_token_traj_all = torch.bmm(historical_token_traj_all.view(-1, 4, 2),
                                              rot_mat[:, None, None, ...].repeat(1, (self.num_historical_steps - 1) // self.shift, self.shift + 1, 1, 1).view(
                                                  -1, 2, 2)).view(num_init_agent, (self.num_historical_steps - 1) // self.shift, self.shift + 1, 4, 2)
        historical_token_traj_all = historical_token_traj_all + pos_a[:num_init_agent, 0, :][:, None, None, None, ...]
        pred_traj[:num_init_agent, :self.num_historical_steps - 1] = historical_token_traj_all[:, :, 1:, ...].clone().mean(dim=3).reshape(num_init_agent, -1, 2)
        diff_xy = historical_token_traj_all[..., 1:, 0, :] - historical_token_traj_all[..., 1:, 3, :]
        pred_head[:num_init_agent, :self.num_historical_steps - 1] = torch.arctan2(diff_xy[..., 1], diff_xy[..., 0]).reshape(num_init_agent, -1)

        return {
            'av_index': av_index,
            'valid_mask': agent_valid_mask[:, self.num_historical_steps:],
            'pos_a': pos_a[:, (self.num_historical_steps - 1) // self.shift:],
            'head_a': head_a[:, (self.num_historical_steps - 1) // self.shift:],
            'gt_traj': gt_traj,
            'pred_traj': pred_traj,
            'pred_head': pred_head,
            'pred_type': list(map(lambda i: self.seed_agent_type[i], pred_type.tolist())),
            'pred_state': pred_state,
            'next_token_idx': torch.cat(next_token_idx_list, dim=-1),  # (num_agent, num_step)
            'next_token_idx_gt': agent_token_index,
            'next_state_idx': torch.cat(next_state_idx_list, dim=-1) if len(next_state_idx_list) > 0 else None,
            'next_state_idx_gt': agent_state_index,
            'next_token_eval_mask': next_token_eval_mask,
            'next_state_eval_mask': next_state_eval_mask,
            # 'next_bos_pl_idx': torch.cat(next_bos_pl_idx_list, dim=-1),
            # 'next_bos_offset_idx': torch.cat(next_bos_offset_idx_list, dim=-1),
        }
backups/dev/modules/layers.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from typing import List, Optional, Tuple, Union
7
+ from torch_geometric.nn.conv import MessagePassing
8
+ from torch_geometric.utils import softmax
9
+
10
+ from dev.utils.func import weight_init
11
+
12
+
13
+ __all__ = ['AttentionLayer', 'FourierEmbedding', 'MLPEmbedding', 'MLPLayer', 'MappingNetwork']
14
+
15
+
16
class AttentionLayer(MessagePassing):
    """Pre-norm transformer-style attention layer over a sparse edge set.

    Built on PyG ``MessagePassing`` with ``aggr='add'``: attention scores are
    computed per edge of ``edge_index`` and normalized with a segment softmax
    over each destination node's incoming edges. Supports self-attention
    (``x`` is a single tensor) and bipartite cross-attention (``x`` is a
    ``(x_src, x_dst)`` tuple), with an optional per-edge relative positional
    embedding ``r``. The aggregated messages are fused with the destination
    features through a sigmoid gate in :meth:`update`, followed by a
    feed-forward block; both sub-blocks use pre-LayerNorm and residuals.
    """

    def __init__(self,
                 hidden_dim: int,
                 num_heads: int,
                 head_dim: int,
                 dropout: float,
                 bipartite: bool,
                 has_pos_emb: bool,
                 **kwargs) -> None:
        super(AttentionLayer, self).__init__(aggr='add', node_dim=0, **kwargs)
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.has_pos_emb = has_pos_emb
        # 1/sqrt(d_k) scaling for dot-product attention.
        self.scale = head_dim ** -0.5

        self.to_q = nn.Linear(hidden_dim, head_dim * num_heads)
        self.to_k = nn.Linear(hidden_dim, head_dim * num_heads, bias=False)
        self.to_v = nn.Linear(hidden_dim, head_dim * num_heads)
        if has_pos_emb:
            # Extra key/value projections for the per-edge positional embedding r.
            self.to_k_r = nn.Linear(hidden_dim, head_dim * num_heads, bias=False)
            self.to_v_r = nn.Linear(hidden_dim, head_dim * num_heads)
        # to_s / to_g implement the gated residual used in update().
        self.to_s = nn.Linear(hidden_dim, head_dim * num_heads)
        self.to_g = nn.Linear(head_dim * num_heads + hidden_dim, head_dim * num_heads)
        self.to_out = nn.Linear(head_dim * num_heads, hidden_dim)
        self.attn_drop = nn.Dropout(dropout)
        self.ff_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 4, hidden_dim),
        )
        if bipartite:
            # Distinct source/destination node sets -> separate pre-norms.
            self.attn_prenorm_x_src = nn.LayerNorm(hidden_dim)
            self.attn_prenorm_x_dst = nn.LayerNorm(hidden_dim)
        else:
            # Homogeneous self-attention: share one pre-norm for src and dst.
            self.attn_prenorm_x_src = nn.LayerNorm(hidden_dim)
            self.attn_prenorm_x_dst = self.attn_prenorm_x_src
        if has_pos_emb:
            self.attn_prenorm_r = nn.LayerNorm(hidden_dim)
        self.attn_postnorm = nn.LayerNorm(hidden_dim)
        self.ff_prenorm = nn.LayerNorm(hidden_dim)
        self.ff_postnorm = nn.LayerNorm(hidden_dim)
        self.apply(weight_init)

    def forward(self,
                x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                r: Optional[torch.Tensor],
                edge_index: torch.Tensor) -> torch.Tensor:
        """Run attention + feed-forward over the given edges.

        Args:
            x: node features; a single tensor for self-attention or a
                ``(x_src, x_dst)`` tuple for bipartite attention.
            r: optional per-edge relative positional embedding (used only
                when ``has_pos_emb`` is True).
            edge_index: (2, E) COO edges, row 0 = source, row 1 = destination.

        Returns:
            Updated destination-node features, same shape as ``x_dst``.
        """
        if isinstance(x, torch.Tensor):
            x_src = x_dst = self.attn_prenorm_x_src(x)
        else:
            x_src, x_dst = x
            x_src = self.attn_prenorm_x_src(x_src)
            x_dst = self.attn_prenorm_x_dst(x_dst)
            # Keep the *un-normalized* destination features for the residual.
            x = x[1]
        if self.has_pos_emb and r is not None:
            r = self.attn_prenorm_r(r)
        x = x + self.attn_postnorm(self._attn_block(x_src, x_dst, r, edge_index))
        x = x + self.ff_postnorm(self._ff_block(self.ff_prenorm(x)))
        return x

    def message(self,
                q_i: torch.Tensor,
                k_j: torch.Tensor,
                v_j: torch.Tensor,
                r: Optional[torch.Tensor],
                index: torch.Tensor,
                ptr: Optional[torch.Tensor]) -> torch.Tensor:
        """Per-edge attention message; softmax is segmented by destination."""
        if self.has_pos_emb and r is not None:
            # Inject the relative positional embedding into keys and values.
            k_j = k_j + self.to_k_r(r).view(-1, self.num_heads, self.head_dim)
            v_j = v_j + self.to_v_r(r).view(-1, self.num_heads, self.head_dim)
        sim = (q_i * k_j).sum(dim=-1) * self.scale
        attn = softmax(sim, index, ptr)
        # Cache per-edge attention (summed over heads, detached) for inspection.
        self.attention_weight = attn.sum(-1).detach()
        attn = self.attn_drop(attn)
        return v_j * attn.unsqueeze(-1)

    def update(self,
               inputs: torch.Tensor,
               x_dst: torch.Tensor) -> torch.Tensor:
        """Gated fusion of aggregated messages with destination features."""
        inputs = inputs.view(-1, self.num_heads * self.head_dim)
        # g in (0, 1): how far to move the aggregate toward to_s(x_dst).
        g = torch.sigmoid(self.to_g(torch.cat([inputs, x_dst], dim=-1)))
        return inputs + g * (self.to_s(x_dst) - inputs)

    def _attn_block(self,
                    x_src: torch.Tensor,
                    x_dst: torch.Tensor,
                    r: Optional[torch.Tensor],
                    edge_index: torch.Tensor) -> torch.Tensor:
        """Multi-head attention: project Q/K/V, propagate over edges, project out."""
        q = self.to_q(x_dst).view(-1, self.num_heads, self.head_dim)
        k = self.to_k(x_src).view(-1, self.num_heads, self.head_dim)
        v = self.to_v(x_src).view(-1, self.num_heads, self.head_dim)
        agg = self.propagate(edge_index=edge_index, x_dst=x_dst, q=q, k=k, v=v, r=r)
        return self.to_out(agg)

    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        """Position-wise feed-forward sub-block."""
        return self.ff_mlp(x)
114
+
115
+
116
class FourierEmbedding(nn.Module):
    """Embeds continuous scalars with learnable sinusoidal (Fourier) features.

    Each of the ``input_dim`` scalar channels is expanded into
    ``2 * num_freq_bands + 1`` features (cos, sin, and the raw value), passed
    through its own small MLP, and the per-channel embeddings are summed.
    Optional categorical embeddings are added before the output head.
    """

    def __init__(self,
                 input_dim: int,
                 hidden_dim: int,
                 num_freq_bands: int) -> None:
        super(FourierEmbedding, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # One learnable frequency vector per input channel (None for 0-dim input).
        self.freqs = nn.Embedding(input_dim, num_freq_bands) if input_dim != 0 else None
        self.mlps = nn.ModuleList(
            [nn.Sequential(
                nn.Linear(num_freq_bands * 2 + 1, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, hidden_dim),
            )
                for _ in range(input_dim)])
        self.to_out = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.apply(weight_init)

    def forward(self,
                continuous_inputs: Optional[torch.Tensor] = None,
                categorical_embs: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
        """Embed continuous and/or categorical inputs; at least one is required."""
        if continuous_inputs is None:
            if categorical_embs is None:
                raise ValueError('Both continuous_inputs and categorical_embs are None')
            # Categorical-only path: sum the provided embeddings.
            return self.to_out(torch.stack(categorical_embs).sum(dim=0))

        # NOTE: learnable sinusoidal embeddings can be unstable on noisy data.
        phases = continuous_inputs.unsqueeze(-1) * self.freqs.weight * 2 * math.pi
        feats = torch.cat([phases.cos(), phases.sin(), continuous_inputs.unsqueeze(-1)], dim=-1)
        # One MLP per scalar channel, then sum across channels.
        per_channel = [mlp(feats[:, i]) for i, mlp in enumerate(self.mlps)]
        out = torch.stack(per_channel).sum(dim=0)
        if categorical_embs is not None:
            out = out + torch.stack(categorical_embs).sum(dim=0)
        return self.to_out(out)
161
+
162
+
163
class MLPEmbedding(nn.Module):
    """Embeds a continuous feature vector with a fixed 3-layer MLP.

    Optional categorical embeddings are summed and added to (or, when no
    continuous input is given, returned in place of) the MLP output.
    """

    def __init__(self,
                 input_dim: int,
                 hidden_dim: int) -> None:
        super(MLPEmbedding, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.LayerNorm(128),
            nn.ReLU(inplace=True),
            nn.Linear(128, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim))
        self.apply(weight_init)

    def forward(self,
                continuous_inputs: Optional[torch.Tensor] = None,
                categorical_embs: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
        """Embed inputs; at least one of the two arguments must be provided."""
        cat_sum = None
        if categorical_embs is not None:
            cat_sum = torch.stack(categorical_embs).sum(dim=0)
        if continuous_inputs is None:
            if cat_sum is None:
                raise ValueError('Both continuous_inputs and categorical_embs are None')
            return cat_sum
        out = self.mlp(continuous_inputs)
        return out if cat_sum is None else out + cat_sum
193
+
194
+
195
class MLPLayer(nn.Module):
    """Two-layer MLP head: Linear -> LayerNorm -> ReLU -> Linear."""

    def __init__(self,
                 input_dim: int,
                 hidden_dim: int=None,
                 output_dim: int=None) -> None:
        super(MLPLayer, self).__init__()

        # Default the hidden width to the output width when unspecified.
        hidden_dim = output_dim if hidden_dim is None else hidden_dim

        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.mlp = nn.Sequential(*layers)
        self.apply(weight_init)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to the last dimension of ``x``."""
        return self.mlp(x)
216
+
217
+
218
class MappingNetwork(nn.Module):
    """StyleGAN-style mapping network: z -> w through stacked Linear+LeakyReLU.

    All hidden layers have width ``layer_dim`` (defaulting to ``w_dim`` when
    falsy); the final layer projects to ``w_dim``.
    """

    def __init__(self, z_dim, w_dim, layer_dim=None, num_layers=8):
        super().__init__()

        hidden = layer_dim if layer_dim else w_dim
        dims = [z_dim] + [hidden] * (num_layers - 1) + [w_dim]

        blocks = []
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            blocks.append(nn.Linear(d_in, d_out))
            blocks.append(nn.LeakyReLU())
        self.layers = nn.Sequential(*blocks)

    def forward(self, z):
        """Map latent ``z`` to intermediate latent ``w``."""
        return self.layers(z)
237
+
238
+
239
+ # class FocalLoss:
240
+ # def __init__(self, alpha: float=.25, gamma: float=2):
241
+ # self.alpha = alpha
242
+ # self.gamma = gamma
243
+
244
+ # def __call__(self, inputs, targets):
245
+ # prob = inputs.sigmoid()
246
+ # ce_loss = F.binary_cross_entropy_with_logits(inputs, targets.float(), reduction='none')
247
+ # p_t = prob * targets + (1 - prob) * (1 - targets)
248
+ # loss = ce_loss * ((1 - p_t) ** self.gamma)
249
+
250
+ # if self.alpha >= 0:
251
+ # alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
252
+ # loss = alpha_t * loss
253
+
254
+ # return loss.mean()
255
+
256
+
257
class FocalLoss(nn.Module):
    """Focal Loss, as described in https://arxiv.org/abs/1708.02002.

    It is essentially an enhancement to cross entropy loss and is
    useful for classification tasks when there is a large class imbalance.
    x is expected to contain raw, unnormalized scores for each class.
    y is expected to contain class labels.

    Shape:
        - x: (batch_size, C) or (batch_size, C, d1, d2, ..., dK), K > 0.
        - y: (batch_size,) or (batch_size, d1, d2, ..., dK), K > 0.
    """

    def __init__(
        self,
        alpha: Optional[torch.Tensor] = None,
        gamma: float = 0.0,
        reduction: str = "mean",
        ignore_index: int = -100,
    ):
        """Constructor.

        Args:
            alpha (Tensor, optional): Weights for each class. Defaults to None.
            gamma (float, optional): A constant, as described in the paper.
                Defaults to 0.
            reduction (str, optional): 'mean', 'sum' or 'none'.
                Defaults to 'mean'.
            ignore_index (int, optional): class label to ignore.
                Defaults to -100.
        """
        if reduction not in ("mean", "sum", "none"):
            raise ValueError('Reduction must be one of: "mean", "sum", "none".')

        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ignore_index = ignore_index
        self.reduction = reduction

        # alpha (per-class weights) is folded into the NLL term below, so the
        # focal term only has to supply the (1 - pt)^gamma factor.
        self.nll_loss = nn.NLLLoss(
            weight=alpha, reduction="none", ignore_index=ignore_index
        )

    def __repr__(self):
        arg_keys = ["alpha", "gamma", "ignore_index", "reduction"]
        arg_vals = [self.__dict__[k] for k in arg_keys]
        arg_strs = [f"{k}={v}" for k, v in zip(arg_keys, arg_vals)]
        arg_str = ", ".join(arg_strs)
        return f"{type(self).__name__}({arg_str})"

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute the focal loss of logits ``x`` against labels ``y``.

        Returns a scalar for 'mean'/'sum' reduction, a per-sample vector for
        'none'. When every label equals ``ignore_index``, returns a zero
        scalar tensor on ``x``'s device.
        """
        if x.ndim > 2:
            # (N, C, d1, d2, ..., dK) --> (N * d1 * ... * dK, C)
            c = x.shape[1]
            x = x.permute(0, *range(2, x.ndim), 1).reshape(-1, c)
            # (N, d1, d2, ..., dK) --> (N * d1 * ... * dK,)
            y = y.view(-1)

        unignored_mask = y != self.ignore_index
        y = y[unignored_mask]
        if len(y) == 0:
            # Fix: return a zero *tensor* on x's device/dtype. The original
            # returned the Python float 0.0, which breaks callers that call
            # .backward() or mix this loss with other tensor losses.
            return x.new_zeros(())
        x = x[unignored_mask]

        # compute weighted cross entropy term: -alpha * log(pt)
        # (alpha is already part of self.nll_loss)
        log_p = F.log_softmax(x, dim=-1)
        ce = self.nll_loss(log_p, y)

        # log-prob of the true class per row. Fix: use gather instead of the
        # original log_p[torch.arange(len(x)), y], whose index tensor was
        # allocated on the CPU and therefore failed for CUDA inputs.
        log_pt = log_p.gather(1, y.unsqueeze(1)).squeeze(1)

        # compute focal term: (1 - pt)^gamma
        pt = log_pt.exp()
        focal_term = (1 - pt) ** self.gamma

        # the full loss: -alpha * ((1 - pt)^gamma) * log(pt)
        loss = focal_term * ce

        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()

        return loss
341
+
342
+
343
class OccLoss(nn.Module):
    """Geometric scale-invariant occupancy loss (geo_scal_loss).

    Treats ``pred`` as occupancy logits and ``target == 1`` as occupied.
    Computes soft precision, recall and specificity over the (optionally
    masked) elements, and penalizes each for deviating from 1 via binary
    cross entropy.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred, target, mask=None):
        """Return the summed BCE of soft precision, recall and specificity."""
        occ_prob = torch.sigmoid(pred)
        free_prob = 1 - occ_prob

        if mask is None:
            # No mask supplied: score every element.
            mask = torch.ones_like(target).bool()

        occ_gt = (target == 1)[mask].float()
        occ_prob = occ_prob[mask]
        free_prob = free_prob[mask]

        # Soft true-positive mass and the derived soft confusion ratios.
        tp = (occ_gt * occ_prob).sum()
        precision = tp / occ_prob.sum()
        recall = tp / occ_gt.sum()
        specificity = ((1 - occ_gt) * free_prob).sum() / (1 - occ_gt).sum()

        loss = F.binary_cross_entropy(precision, torch.ones_like(precision))
        loss = loss + F.binary_cross_entropy(recall, torch.ones_like(recall))
        loss = loss + F.binary_cross_entropy(specificity, torch.ones_like(specificity))
        return loss
backups/dev/modules/map_decoder.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch_cluster import radius_graph
5
+ from torch_geometric.data import Batch
6
+ from torch_geometric.data import HeteroData
7
+ from torch_geometric.utils import subgraph
8
+
9
+ from dev.modules.layers import MLPLayer, AttentionLayer, FourierEmbedding, MLPEmbedding
10
+ from dev.utils.func import weight_init, wrap_angle, angle_between_2d_vectors
11
+
12
+
13
class SMARTMapDecoder(nn.Module):
    """Decoder over discretized map ("pt") tokens.

    Each map token is embedded from its source geometry (via ``map_token``)
    plus categorical attributes (point type, polygon type, traffic-light
    state), contextualized with relative-position-aware pt2pt attention
    within ``pl2pl_radius``, and a classification head predicts the token
    index at masked positions (masked-token prediction over a 1024-way map
    vocabulary).
    """

    def __init__(self,
                 dataset: str,
                 input_dim: int,
                 hidden_dim: int,
                 num_historical_steps: int,
                 pl2pl_radius: float,
                 num_freq_bands: int,
                 num_layers: int,
                 num_heads: int,
                 head_dim: int,
                 dropout: float,
                 map_token) -> None:

        super(SMARTMapDecoder, self).__init__()
        self.dataset = dataset
        self.input_dim = input_dim  # 2 (x, y) or 3 (x, y, z)
        self.hidden_dim = hidden_dim
        self.num_historical_steps = num_historical_steps
        self.pl2pl_radius = pl2pl_radius  # neighborhood radius for pt2pt edges
        self.num_freq_bands = num_freq_bands
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.dropout = dropout

        # Relative pt2pt edge descriptor: [distance, relative direction angle,
        # (dz when 3D), relative orientation].
        if input_dim == 2:
            input_dim_r_pt2pt = 3
        elif input_dim == 3:
            input_dim_r_pt2pt = 4
        else:
            raise ValueError('{} is not a valid dimension'.format(input_dim))

        # Categorical attribute embeddings. Vocabulary sizes (17/4/4/4) are
        # dataset-specific -- presumably the Waymo map schema; TODO confirm.
        self.type_pt_emb = nn.Embedding(17, hidden_dim)
        self.side_pt_emb = nn.Embedding(4, hidden_dim)
        self.polygon_type_emb = nn.Embedding(4, hidden_dim)
        self.light_pl_emb = nn.Embedding(4, hidden_dim)

        self.r_pt2pt_emb = FourierEmbedding(input_dim=input_dim_r_pt2pt, hidden_dim=hidden_dim,
                                            num_freq_bands=num_freq_bands)
        self.pt2pt_layers = nn.ModuleList(
            [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
                            bipartite=False, has_pos_emb=True) for _ in range(num_layers)]
        )
        # Size of the map-token vocabulary predicted by the head.
        self.token_size = 1024
        self.token_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
                                           output_dim=self.token_size)
        # Flattened per-token source geometry width; must match
        # map_token['traj_src'] flattened per row -- TODO confirm.
        input_dim_token = 22
        self.token_emb = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)
        self.map_token = map_token
        self.apply(weight_init)
        # When True, pt2pt edges are restricted to valid tokens in forward().
        self.mask_pt = False

    def maybe_autocast(self, dtype=torch.float32):
        # CUDA autocast context pinned to the given dtype.
        return torch.cuda.amp.autocast(dtype=dtype)

    def forward(self, data: HeteroData) -> Dict[str, torch.Tensor]:
        """Encode map tokens and predict token indices at masked positions.

        Expects preprocessing to have set, per pt token: ``pt_valid_mask``
        (context tokens), ``pt_pred_mask`` (positions to predict) and
        ``pt_target_mask`` (positions holding the ground-truth indices).
        """
        pt_valid_mask = data['pt_token']['pt_valid_mask']
        pt_pred_mask = data['pt_token']['pt_pred_mask']
        pt_target_mask = data['pt_token']['pt_target_mask']
        mask_s = pt_valid_mask

        pos_pt = data['pt_token']['position'][:, :self.input_dim].contiguous()
        orient_pt = data['pt_token']['orientation'].contiguous()
        orient_vector_pt = torch.stack([orient_pt.cos(), orient_pt.sin()], dim=-1)
        # Embed the shared token vocabulary once, then gather per-token rows.
        token_sample_pt = self.map_token['traj_src'].to(pos_pt.device).to(torch.float)
        pt_token_emb_src = self.token_emb(token_sample_pt.view(token_sample_pt.shape[0], -1))
        pt_token_emb = pt_token_emb_src[data['pt_token']['token_idx']]

        x_pt = pt_token_emb

        # Traffic-light state is stored per polygon; route it to each token
        # through the token->polygon edge index.
        token2pl = data[('pt_token', 'to', 'map_polygon')]['edge_index']
        token_light_type = data['map_polygon']['light_type'][token2pl[1]]
        x_pt_categorical_embs = [self.type_pt_emb(data['pt_token']['type'].long()),
                                 self.polygon_type_emb(data['pt_token']['pl_type'].long()),
                                 self.light_pl_emb(token_light_type.long()),]
        x_pt = x_pt + torch.stack(x_pt_categorical_embs).sum(dim=0)
        # Radius graph over token positions, batched per scene when available.
        edge_index_pt2pt = radius_graph(x=pos_pt[:, :2], r=self.pl2pl_radius,
                                        batch=data['pt_token']['batch'] if isinstance(data, Batch) else None,
                                        loop=False, max_num_neighbors=100)
        if self.mask_pt:
            # Optionally drop edges touching invalid tokens.
            edge_index_pt2pt = subgraph(subset=mask_s, edge_index=edge_index_pt2pt)[0]
        rel_pos_pt2pt = pos_pt[edge_index_pt2pt[0]] - pos_pt[edge_index_pt2pt[1]]
        rel_orient_pt2pt = wrap_angle(orient_pt[edge_index_pt2pt[0]] - orient_pt[edge_index_pt2pt[1]])
        if self.input_dim == 2:
            r_pt2pt = torch.stack(
                [torch.norm(rel_pos_pt2pt[:, :2], p=2, dim=-1),
                 angle_between_2d_vectors(ctr_vector=orient_vector_pt[edge_index_pt2pt[1]],
                                          nbr_vector=rel_pos_pt2pt[:, :2]),
                 rel_orient_pt2pt], dim=-1)
        elif self.input_dim == 3:
            r_pt2pt = torch.stack(
                [torch.norm(rel_pos_pt2pt[:, :2], p=2, dim=-1),
                 angle_between_2d_vectors(ctr_vector=orient_vector_pt[edge_index_pt2pt[1]],
                                          nbr_vector=rel_pos_pt2pt[:, :2]),
                 rel_pos_pt2pt[:, -1],
                 rel_orient_pt2pt], dim=-1)
        else:
            raise ValueError('{} is not a valid dimension'.format(self.input_dim))

        # layers
        r_pt2pt = self.r_pt2pt_emb(continuous_inputs=r_pt2pt, categorical_embs=None)
        for i in range(self.num_layers):
            x_pt = self.pt2pt_layers[i](x_pt, r_pt2pt, edge_index_pt2pt)

        # Predict the vocabulary distribution at masked positions; keep the
        # top-10 candidate indices per position.
        next_token_prob = self.token_predict_head(x_pt[pt_pred_mask])
        next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1)
        _, next_token_idx = torch.topk(next_token_prob_softmax, k=10, dim=-1)
        next_token_index_gt = data['pt_token']['token_idx'][pt_target_mask]

        return {
            'x_pt': x_pt,
            'map_next_token_idx': next_token_idx,
            'map_next_token_prob': next_token_prob,
            'map_next_token_idx_gt': next_token_index_gt,
            # NOTE(review): pt_pred_mask[pt_pred_mask] is an all-True vector
            # sized to the number of predicted tokens -- verify intent.
            'map_next_token_eval_mask': pt_pred_mask[pt_pred_mask]
        }
+ }
backups/dev/modules/occ_decoder.py ADDED
@@ -0,0 +1,927 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from typing import Dict, Mapping, Optional, Literal
7
+ from torch_cluster import radius, radius_graph
8
+ from torch_geometric.data import HeteroData, Batch
9
+ from torch_geometric.utils import dense_to_sparse, subgraph
10
+ from scipy.optimize import linear_sum_assignment
11
+
12
+ from dev.modules.attr_tokenizer import Attr_Tokenizer
13
+ from dev.modules.layers import *
14
+ from dev.utils.visualization import *
15
+ from dev.utils.func import angle_between_2d_vectors, wrap_angle, weight_init
16
+
17
+
18
+ class SMARTOccDecoder(nn.Module):
19
+
20
+ def __init__(self,
21
+ dataset: str,
22
+ input_dim: int,
23
+ hidden_dim: int,
24
+ num_historical_steps: int,
25
+ time_span: Optional[int],
26
+ pl2a_radius: float,
27
+ pl2seed_radius: float,
28
+ a2a_radius: float,
29
+ a2sa_radius: float,
30
+ pl2sa_radius: float,
31
+ num_freq_bands: int,
32
+ num_layers: int,
33
+ num_heads: int,
34
+ head_dim: int,
35
+ dropout: float,
36
+ token_data: Dict,
37
+ token_size: int,
38
+ special_token_index: list=[],
39
+ attr_tokenizer: Attr_Tokenizer=None,
40
+ predict_motion: bool=False,
41
+ predict_state: bool=False,
42
+ predict_map: bool=False,
43
+ predict_occ: bool=False,
44
+ state_token: Dict[str, int]=None,
45
+ seed_size: int=5,
46
+ buffer_size: int=32,
47
+ loss_weight: dict=None,
48
+ logger=None) -> None:
49
+
50
+ super(SMARTOccDecoder, self).__init__()
51
+ self.dataset = dataset
52
+ self.input_dim = input_dim
53
+ self.hidden_dim = hidden_dim
54
+ self.num_historical_steps = num_historical_steps
55
+ self.time_span = time_span if time_span is not None else num_historical_steps
56
+ self.pl2a_radius = pl2a_radius
57
+ self.pl2seed_radius = pl2seed_radius
58
+ self.a2a_radius = a2a_radius
59
+ self.a2sa_radius = a2sa_radius
60
+ self.pl2sa_radius = pl2sa_radius
61
+ self.num_freq_bands = num_freq_bands
62
+ self.num_layers = num_layers
63
+ self.num_heads = num_heads
64
+ self.head_dim = head_dim
65
+ self.dropout = dropout
66
+ self.special_token_index = special_token_index
67
+ self.predict_motion = predict_motion
68
+ self.predict_state = predict_state
69
+ self.predict_map = predict_map
70
+ self.predict_occ = predict_occ
71
+ self.loss_weight = loss_weight
72
+ self.logger = logger
73
+
74
+ self.attr_tokenizer = attr_tokenizer
75
+
76
+ # state tokens
77
+ self.state_type = list(state_token.keys())
78
+ self.state_token = state_token
79
+ self.invalid_state = int(state_token['invalid'])
80
+ self.valid_state = int(state_token['valid'])
81
+ self.enter_state = int(state_token['enter'])
82
+ self.exit_state = int(state_token['exit'])
83
+
84
+ self.seed_state_type = ['invalid', 'enter']
85
+ self.valid_state_type = ['invalid', 'valid', 'exit']
86
+
87
+ input_dim_r_pt2a = 3
88
+ input_dim_r_a2a = 3
89
+
90
+ self.seed_size = seed_size
91
+ self.buffer_size = buffer_size
92
+
93
+ self.agent_type = ['veh', 'ped', 'cyc', 'seed']
94
+ self.type_a_emb = nn.Embedding(len(self.agent_type), hidden_dim)
95
+ self.shape_emb = MLPEmbedding(input_dim=3, hidden_dim=hidden_dim)
96
+ self.state_a_emb = nn.Embedding(len(self.state_type), hidden_dim)
97
+ self.motion_gap = 1.
98
+ self.heading_gap = 1.
99
+ self.invalid_shape_value = .1
100
+ self.invalid_motion_value = -2.
101
+ self.invalid_head_value = -2.
102
+
103
+ self.r_pt2a_emb = FourierEmbedding(input_dim=input_dim_r_pt2a, hidden_dim=hidden_dim,
104
+ num_freq_bands=num_freq_bands)
105
+ self.r_a2a_emb = FourierEmbedding(input_dim=input_dim_r_a2a, hidden_dim=hidden_dim,
106
+ num_freq_bands=num_freq_bands)
107
+
108
+ self.token_size = token_size # 2048
109
+ self.grid_size = self.attr_tokenizer.grid_size
110
+ self.angle_size = self.attr_tokenizer.angle_size
111
+ self.agent_limit = 3
112
+ self.pt_limit = 10
113
+ self.grid_agent_occ_head = MLPLayer(input_dim=hidden_dim, hidden_dim=self.grid_size,
114
+ output_dim=self.agent_limit * self.grid_size)
115
+ self.grid_pt_occ_head = MLPLayer(input_dim=hidden_dim, hidden_dim=self.grid_size,
116
+ output_dim=self.pt_limit * self.grid_size)
117
+
118
+ # self.num_seed_feature = 1
119
+ # self.num_seed_feature = self.seed_size
120
+ self.num_seed_feature = 10
121
+
122
+ self.trajectory_token = token_data['token'] # dict('veh', 'ped', 'cyc') (2048, 4, 2)
123
+ self.trajectory_token_traj = token_data['traj'] # (2048, 6, 3)
124
+ self.trajectory_token_all = token_data['token_all'] # (2048, 6, 4, 2)
125
+ self.apply(weight_init)
126
+
127
+ self.shift = 5
128
+ self.beam_size = 5
129
+ self.hist_mask = True
130
+ self.temporal_attn_to_invalid = False
131
+ self.use_rel = False
132
+
133
+ # seed agent
134
+ self.temporal_attn_seed = False
135
+ self.seed_attn_to_av = True
136
+ self.seed_use_ego_motion = False
137
+
138
+ def transform_rel(self, token_traj, prev_pos, prev_heading=None):
139
+ if prev_heading is None:
140
+ diff_xy = prev_pos[:, :, -1, :] - prev_pos[:, :, -2, :]
141
+ prev_heading = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0])
142
+
143
+ num_agent, num_step, traj_num, traj_dim = token_traj.shape
144
+ cos, sin = prev_heading.cos(), prev_heading.sin()
145
+ rot_mat = torch.zeros((num_agent, num_step, 2, 2), device=prev_heading.device)
146
+ rot_mat[:, :, 0, 0] = cos
147
+ rot_mat[:, :, 0, 1] = -sin
148
+ rot_mat[:, :, 1, 0] = sin
149
+ rot_mat[:, :, 1, 1] = cos
150
+ agent_diff_rel = torch.bmm(token_traj.view(-1, traj_num, 2), rot_mat.view(-1, 2, 2)).view(num_agent, num_step, traj_num, traj_dim)
151
+ agent_pred_rel = agent_diff_rel + prev_pos[:, :, -1:, :]
152
+ return agent_pred_rel
153
+
154
+ def _agent_token_embedding(self, data, agent_token_index, agent_state, agent_offset_token_idx, pos_a, head_a,
155
+ inference=False, filter_mask=None, av_index=None):
156
+
157
+ if filter_mask is None:
158
+ filter_mask = torch.ones_like(agent_state[:, 2], dtype=torch.bool)
159
+
160
+ num_agent, num_step, traj_dim = pos_a.shape # traj_dim=2
161
+ agent_type = data['agent']['type'][filter_mask]
162
+ veh_mask = (agent_type == 0)
163
+ ped_mask = (agent_type == 1)
164
+ cyc_mask = (agent_type == 2)
165
+
166
+ motion_vector_a, head_vector_a = self._build_vector_a(pos_a, head_a, agent_state)
167
+
168
+ trajectory_token_veh = torch.from_numpy(self.trajectory_token['veh']).clone().to(pos_a.device).to(torch.float)
169
+ trajectory_token_ped = torch.from_numpy(self.trajectory_token['ped']).clone().to(pos_a.device).to(torch.float)
170
+ trajectory_token_cyc = torch.from_numpy(self.trajectory_token['cyc']).clone().to(pos_a.device).to(torch.float)
171
+ self.agent_token_emb_veh = self.token_emb_veh(trajectory_token_veh.view(trajectory_token_veh.shape[0], -1)) # (token_size, 8)
172
+ self.agent_token_emb_ped = self.token_emb_ped(trajectory_token_ped.view(trajectory_token_ped.shape[0], -1))
173
+ self.agent_token_emb_cyc = self.token_emb_cyc(trajectory_token_cyc.view(trajectory_token_cyc.shape[0], -1))
174
+
175
+ # add bos token embedding
176
+ self.agent_token_emb_veh = torch.cat([self.agent_token_emb_veh, self.bos_token_emb(torch.zeros(1, device=pos_a.device).long())])
177
+ self.agent_token_emb_ped = torch.cat([self.agent_token_emb_ped, self.bos_token_emb(torch.zeros(1, device=pos_a.device).long())])
178
+ self.agent_token_emb_cyc = torch.cat([self.agent_token_emb_cyc, self.bos_token_emb(torch.zeros(1, device=pos_a.device).long())])
179
+
180
+ # add invalid token embedding
181
+ self.agent_token_emb_veh = torch.cat([self.agent_token_emb_veh, self.no_token_emb(torch.zeros(1, device=pos_a.device).long())])
182
+ self.agent_token_emb_ped = torch.cat([self.agent_token_emb_ped, self.no_token_emb(torch.zeros(1, device=pos_a.device).long())])
183
+ self.agent_token_emb_cyc = torch.cat([self.agent_token_emb_cyc, self.no_token_emb(torch.zeros(1, device=pos_a.device).long())])
184
+
185
+ # self.grid_token_emb = self.token_emb_grid(torch.stack([self.attr_tokenizer.dist,
186
+ # self.attr_tokenizer.dir], dim=-1).to(pos_a.device))
187
+ self.grid_token_emb = self.token_emb_grid(self.attr_tokenizer.grid)
188
+ self.grid_token_emb = torch.cat([self.grid_token_emb, self.invalid_offset_token_emb(torch.zeros(1, device=pos_a.device).long())])
189
+
190
+ if inference:
191
+ agent_token_traj_all = torch.zeros((num_agent, self.token_size, self.shift + 1, 4, 2), device=pos_a.device)
192
+ trajectory_token_all_veh = torch.from_numpy(self.trajectory_token_all['veh']).clone().to(pos_a.device).to(torch.float)
193
+ trajectory_token_all_ped = torch.from_numpy(self.trajectory_token_all['ped']).clone().to(pos_a.device).to(torch.float)
194
+ trajectory_token_all_cyc = torch.from_numpy(self.trajectory_token_all['cyc']).clone().to(pos_a.device).to(torch.float)
195
+ agent_token_traj_all[veh_mask] = torch.cat(
196
+ [trajectory_token_all_veh[:, :self.shift], trajectory_token_veh[:, None, ...]], dim=1)
197
+ agent_token_traj_all[ped_mask] = torch.cat(
198
+ [trajectory_token_all_ped[:, :self.shift], trajectory_token_ped[:, None, ...]], dim=1)
199
+ agent_token_traj_all[cyc_mask] = torch.cat(
200
+ [trajectory_token_all_cyc[:, :self.shift], trajectory_token_cyc[:, None, ...]], dim=1)
201
+
202
+ # additional token embeddings are already added -> -1: invalid, -2: bos
203
+ agent_token_emb = torch.zeros((num_agent, num_step, self.hidden_dim), device=pos_a.device)
204
+ agent_token_emb[veh_mask] = self.agent_token_emb_veh[agent_token_index[veh_mask]]
205
+ agent_token_emb[ped_mask] = self.agent_token_emb_ped[agent_token_index[ped_mask]]
206
+ agent_token_emb[cyc_mask] = self.agent_token_emb_cyc[agent_token_index[cyc_mask]]
207
+
208
+ offset_token_emb = self.grid_token_emb[agent_offset_token_idx]
209
+
210
+ # 'vehicle', 'pedestrian', 'cyclist', 'background'
211
+ is_invalid = agent_state == self.invalid_state
212
+ agent_types = data['agent']['type'].clone()[filter_mask].long().repeat_interleave(repeats=num_step, dim=0)
213
+ agent_types[is_invalid.reshape(-1)] = self.agent_type.index('seed')
214
+ agent_shapes = data['agent']['shape'].clone()[filter_mask, self.num_historical_steps - 1, :].repeat_interleave(repeats=num_step, dim=0)
215
+ agent_shapes[is_invalid.reshape(-1)] = self.invalid_shape_value
216
+
217
+ # TODO: fix ego_pos in inference mode
218
+ offset_pos = pos_a - pos_a[av_index].repeat_interleave(repeats=data['batch_size_a'], dim=0)
219
+ feat_a, categorical_embs = self._build_agent_feature(num_step, pos_a.device,
220
+ motion_vector_a,
221
+ head_vector_a,
222
+ agent_token_emb,
223
+ offset_token_emb,
224
+ offset_pos=offset_pos,
225
+ type=agent_types,
226
+ shape=agent_shapes,
227
+ state=agent_state,
228
+ n=num_agent)
229
+
230
+ if inference:
231
+ return feat_a, agent_token_traj_all, agent_token_emb, categorical_embs
232
+ else:
233
+ # seed agent feature
234
+ if self.seed_use_ego_motion:
235
+ motion_vector_seed = motion_vector_a[av_index].repeat_interleave(repeats=self.num_seed_feature, dim=0)
236
+ head_vector_seed = head_vector_a[av_index].repeat_interleave(repeats=self.num_seed_feature, dim=0)
237
+ else:
238
+ motion_vector_seed = head_vector_seed = None
239
+ feat_seed, _ = self._build_agent_feature(num_step, pos_a.device,
240
+ motion_vector_seed,
241
+ head_vector_seed,
242
+ state_index=self.invalid_state,
243
+ n=data.num_graphs * self.num_seed_feature)
244
+
245
+ feat_a = torch.cat([feat_a, feat_seed], dim=0) # (a + n, t, d)
246
+
247
+ return feat_a
248
+
249
+ def _build_vector_a(self, pos_a, head_a, state_a):
250
+ num_agent = pos_a.shape[0]
251
+
252
+ motion_vector_a = torch.cat([pos_a.new_zeros(num_agent, 1, self.input_dim),
253
+ pos_a[:, 1:] - pos_a[:, :-1]], dim=1)
254
+
255
+ motion_vector_a[state_a == self.invalid_state] = self.invalid_motion_value
256
+
257
+ # invalid -> valid
258
+ is_last_invalid = (state_a.roll(shifts=1, dims=1) == self.invalid_state) & (state_a != self.invalid_state)
259
+ is_last_invalid[:, 0] = state_a[:, 0] == self.enter_state
260
+ motion_vector_a[is_last_invalid] = self.motion_gap
261
+
262
+ # valid -> invalid
263
+ is_last_valid = (state_a.roll(shifts=1, dims=1) != self.invalid_state) & (state_a == self.invalid_state)
264
+ is_last_valid[:, 0] = False
265
+ motion_vector_a[is_last_valid] = -self.motion_gap
266
+
267
+ head_a[state_a == self.invalid_state] == self.invalid_head_value
268
+ head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1)
269
+
270
+ return motion_vector_a, head_vector_a
271
+
272
    def _build_agent_feature(self, num_step, device,
                             motion_vector=None,
                             head_vector=None,
                             agent_token_emb=None,
                             agent_grid_emb=None,
                             offset_pos=None,
                             type=None,
                             shape=None,
                             categorical_embs_a=None,
                             state=None,
                             state_index=None,
                             n=1):
        """Fuse token, motion, state and grid embeddings into one agent feature.

        Every argument except ``num_step``/``device`` is optional; missing
        inputs are replaced with learned "no token" / seed / invalid defaults,
        which is how the seed-agent feature is built (all defaults, ``n`` seeds).

        Args:
            num_step: number of token time steps.
            device: device for any default tensors created here.
            motion_vector / head_vector: per-step kinematics; if either is
                missing, both are rebuilt from all-zero positions/headings.
            agent_token_emb: (n, num_step, hidden) motion-token embeddings.
            agent_grid_emb: (n, num_step, hidden) grid-token embeddings.
            offset_pos: offset to ego, currently unused in the feature stack
                (the norm term below is commented out).
            type / shape: per-agent category and bounding-box size (note these
                parameter names shadow builtins but are part of the keyword API).
            categorical_embs_a: precomputed [type_emb, shape_emb] list.
            state: (n, num_step) state indices; if None, ``state_index`` must
                be given and is broadcast to all agents/steps.
            state_index: scalar fallback state used when ``state`` is None.
            n: number of agents the defaults should cover.

        Returns:
            Tuple of fused features (n, num_step, hidden_dim) and the
            categorical embedding list actually used.
        """
        if agent_token_emb is None:
            # "no token" placeholder for every step...
            agent_token_emb = self.no_token_emb(torch.zeros(1, device=device).long())[:, None].repeat(n, num_step, 1)
            if state is not None:
                # ...except entering steps, which get the BOS token embedding.
                agent_token_emb[state == self.enter_state] = self.bos_token_emb(torch.zeros(1, device=device).long())

        if agent_grid_emb is None:
            # Default to the center cell of the grid vocabulary.
            agent_grid_emb = self.grid_token_emb[None, None, self.grid_size // 2].repeat(n, num_step, 1)

        if motion_vector is None or head_vector is None:
            # No kinematics provided: derive them from all-zero pose tensors.
            pos_a = torch.zeros((n, num_step, 2), device=device)
            head_a = torch.zeros((n, num_step), device=device)
            if state is None:
                state = torch.full((n, num_step), self.invalid_state, device=device)
            motion_vector, head_vector = self._build_vector_a(pos_a, head_a, state)

        if offset_pos is None:
            offset_pos = torch.zeros_like(motion_vector)

        # Continuous features: speed and heading-relative motion direction.
        feature_a = torch.stack(
            [torch.norm(motion_vector[:, :, :2], p=2, dim=-1),
             angle_between_2d_vectors(ctr_vector=head_vector, nbr_vector=motion_vector[:, :, :2]),
             # torch.norm(offset_pos[:, :, :2], p=2, dim=-1),
             ], dim=-1)

        if categorical_embs_a is None:
            if type is None:
                type = torch.tensor([self.agent_type.index('seed')], device=device)
            if shape is None:
                shape = torch.full((1, 3), self.invalid_shape_value, device=device)

            categorical_embs_a = [self.type_a_emb(type.reshape(-1)), self.shape_emb(shape.reshape(-1, shape.shape[-1]))]

        x_a = self.x_a_emb(continuous_inputs=feature_a.view(-1, feature_a.size(-1)),
                           categorical_embs=categorical_embs_a)
        x_a = x_a.view(-1, num_step, self.hidden_dim)  # (a, t, d)

        if state is None:
            assert state_index is not None, f"state index need to be set when state tensor is None!"
            # Materialize a full (n, num_step, 1) tensor; `expand` would alias
            # memory and break the in-place-safe reshape below.
            state = torch.tensor([state_index], device=device)[:, None].repeat(n, num_step, 1)  # do not use `expand`
        s_a = self.state_a_emb(state.reshape(-1).long()).reshape(n, num_step, self.hidden_dim)

        # Concatenate all modalities and project back to hidden_dim.
        feat_a = torch.cat((agent_token_emb, x_a, s_a, agent_grid_emb), dim=-1)
        feat_a = self.fusion_emb(feat_a)  # (a, t, d)

        return feat_a, categorical_embs_a
330
+
331
+ def _pad_feat(self, num_graph, av_index, *feats, num_seed_feature=None):
332
+
333
+ if num_seed_feature is None:
334
+ num_seed_feature = self.num_seed_feature
335
+
336
+ padded_feats = tuple()
337
+ for i in range(len(feats)):
338
+ padded_feats += (torch.cat([feats[i], feats[i][av_index].repeat_interleave(
339
+ repeats=num_seed_feature, dim=0)],
340
+ dim=0
341
+ ),)
342
+
343
+ pad_mask = torch.ones(*padded_feats[0].shape[:2], device=feats[0].device).bool() # (a, t)
344
+ pad_mask[-num_graph * num_seed_feature:] = False
345
+
346
+ return padded_feats + (pad_mask,)
347
+
348
+ def _build_seed_feat(self, data, pos_a, head_a, state_a, head_vector_a, mask, sort_indices, av_index):
349
+ seed_mask = sort_indices != av_index.repeat_interleave(repeats=data['batch_size_a'], dim=0)[:, None]
350
+ # TODO: fix batch_size!!!
351
+ print(mask.shape, sort_indices.shape, seed_mask.shape)
352
+ mask[-data.num_graphs * self.num_seed_feature:] = seed_mask[:self.num_seed_feature]
353
+
354
+ insert_pos_a = torch.gather(pos_a, dim=0, index=sort_indices[:self.num_seed_feature, :, None].expand(-1, -1, pos_a.shape[-1]))
355
+ pos_a[mask] = insert_pos_a[mask[-self.num_seed_feature:]]
356
+
357
+ state_a[-data.num_graphs * self.num_seed_feature:] = self.enter_state
358
+
359
+ return pos_a, head_a, state_a, head_vector_a, mask
360
+
361
+ def _build_temporal_edge(self, data, pos_a, head_a, state_a, head_vector_a, mask, inference_mask=None):
362
+
363
+ num_graph = data.num_graphs
364
+ num_agent = pos_a.shape[0]
365
+ hist_mask = mask.clone()
366
+
367
+ if not self.temporal_attn_to_invalid:
368
+ is_bos = state_a == self.enter_state
369
+ bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0))
370
+ history_invalid_mask = torch.arange(mask.shape[1]).expand(mask.shape[0], mask.shape[1]).to(mask.device)
371
+ history_invalid_mask = (history_invalid_mask < bos_index[:, None])
372
+ hist_mask[history_invalid_mask] = False
373
+
374
+ if not self.temporal_attn_seed:
375
+ hist_mask[-num_graph * self.num_seed_feature:] = False
376
+ if inference_mask is not None:
377
+ inference_mask[-num_graph * self.num_seed_feature:] = False
378
+ else:
379
+ # WARNING: if use temporal attn to seed
380
+ # we need to fix the pos/head of seed!!!
381
+ raise RuntimeError("Wrong settings!")
382
+
383
+ pos_t = pos_a.reshape(-1, self.input_dim) # (num_agent * num_step, ...)
384
+ head_t = head_a.reshape(-1)
385
+ head_vector_t = head_vector_a.reshape(-1, 2)
386
+
387
+ # for those invalid agents won't predict any motion token, we don't attend to them
388
+ is_bos = state_a == self.enter_state
389
+ is_bos[-num_graph * self.num_seed_feature:] = False
390
+ bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0))
391
+ motion_predict_start_index = torch.clamp(bos_index - self.time_span / self.shift + 1, min=0)
392
+ motion_predict_mask = torch.arange(hist_mask.shape[1]).expand(hist_mask.shape[0], -1).to(hist_mask.device)
393
+ motion_predict_mask = motion_predict_mask >= motion_predict_start_index[:, None]
394
+ hist_mask[~motion_predict_mask] = False
395
+
396
+ if self.hist_mask and self.training:
397
+ hist_mask[
398
+ torch.arange(mask.shape[0]).unsqueeze(1), torch.randint(0, mask.shape[1], (num_agent, 10))] = False
399
+ mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1)
400
+ elif inference_mask is not None:
401
+ mask_t = hist_mask.unsqueeze(2) & inference_mask.unsqueeze(1)
402
+ else:
403
+ mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1)
404
+
405
+ # mask_t: (num_agent, 18, 18), edge_index_t: (2, num_edge)
406
+ edge_index_t = dense_to_sparse(mask_t)[0]
407
+ edge_index_t = edge_index_t[:, (edge_index_t[1] - edge_index_t[0] > 0) &
408
+ (edge_index_t[1] - edge_index_t[0] <= self.time_span / self.shift)]
409
+ rel_pos_t = pos_t[edge_index_t[0]] - pos_t[edge_index_t[1]]
410
+ rel_head_t = wrap_angle(head_t[edge_index_t[0]] - head_t[edge_index_t[1]])
411
+
412
+ # handle the invalid steps
413
+ is_invalid = state_a == self.invalid_state
414
+ is_invalid_t = is_invalid.reshape(-1)
415
+
416
+ rel_pos_t[is_invalid_t[edge_index_t[0]] & ~is_invalid_t[edge_index_t[1]]] = -self.motion_gap
417
+ rel_pos_t[~is_invalid_t[edge_index_t[0]] & is_invalid_t[edge_index_t[1]]] = self.motion_gap
418
+ rel_head_t[is_invalid_t[edge_index_t[0]] & ~is_invalid_t[edge_index_t[1]]] = -self.heading_gap
419
+ rel_head_t[~is_invalid_t[edge_index_t[1]] & is_invalid_t[edge_index_t[1]]] = self.heading_gap
420
+
421
+ rel_pos_t[is_invalid_t[edge_index_t[0]] & is_invalid_t[edge_index_t[1]]] = self.invalid_motion_value
422
+ rel_head_t[is_invalid_t[edge_index_t[0]] & is_invalid_t[edge_index_t[1]]] = self.invalid_head_value
423
+
424
+ r_t = torch.stack(
425
+ [torch.norm(rel_pos_t[:, :2], p=2, dim=-1),
426
+ angle_between_2d_vectors(ctr_vector=head_vector_t[edge_index_t[1]], nbr_vector=rel_pos_t[:, :2]),
427
+ rel_head_t,
428
+ edge_index_t[0] - edge_index_t[1]], dim=-1)
429
+ r_t = self.r_t_emb(continuous_inputs=r_t, categorical_embs=None)
430
+
431
+ return edge_index_t, r_t
432
+
433
+ def _build_interaction_edge(self, data, pos_a, head_a, state_a, head_vector_a, batch_s, mask, pad_mask=None, inference_mask=None,
434
+ av_index=None, seq_mask=None, seq_index=None, grid_index_a=None, **plot_kwargs):
435
+ num_graph = data.num_graphs
436
+ num_agent, num_step, _ = pos_a.shape
437
+ is_training = inference_mask is None
438
+
439
+ mask_a = mask.clone()
440
+
441
+ if pad_mask is None:
442
+ pad_mask = torch.ones_like(state_a).bool()
443
+
444
+ pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
445
+ head_s = head_a.transpose(0, 1).reshape(-1)
446
+ head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2)
447
+ pad_mask_s = pad_mask.transpose(0, 1).reshape(-1)
448
+ if inference_mask is not None:
449
+ mask_a = mask_a & inference_mask
450
+ mask_s = mask_a.transpose(0, 1).reshape(-1)
451
+
452
+ # build agent2agent bilateral connection
453
+ edge_index_a2a = radius_graph(x=pos_s[:, :2], r=self.a2a_radius, batch=batch_s, loop=False,
454
+ max_num_neighbors=300)
455
+ edge_index_a2a = subgraph(subset=mask_s & pad_mask_s, edge_index=edge_index_a2a)[0]
456
+
457
+ if os.getenv('PLOT_EDGE', False):
458
+ plot_interact_edge(edge_index_a2a, data['scenario_id'], data['batch_size_a'].cpu(), self.num_seed_feature, num_step,
459
+ av_index=av_index, **plot_kwargs)
460
+
461
+ rel_pos_a2a = pos_s[edge_index_a2a[0]] - pos_s[edge_index_a2a[1]]
462
+ rel_head_a2a = wrap_angle(head_s[edge_index_a2a[0]] - head_s[edge_index_a2a[1]])
463
+
464
+ # handle the invalid steps
465
+ is_invalid = state_a == self.invalid_state
466
+ is_invalid_s = is_invalid.transpose(0, 1).reshape(-1)
467
+
468
+ rel_pos_a2a[is_invalid_s[edge_index_a2a[0]] & ~is_invalid_s[edge_index_a2a[1]]] = -self.motion_gap
469
+ rel_pos_a2a[~is_invalid_s[edge_index_a2a[0]] & is_invalid_s[edge_index_a2a[1]]] = self.motion_gap
470
+ rel_head_a2a[is_invalid_s[edge_index_a2a[0]] & ~is_invalid_s[edge_index_a2a[1]]] = -self.heading_gap
471
+ rel_head_a2a[~is_invalid_s[edge_index_a2a[1]] & is_invalid_s[edge_index_a2a[1]]] = self.heading_gap
472
+
473
+ rel_pos_a2a[is_invalid_s[edge_index_a2a[0]] & is_invalid_s[edge_index_a2a[1]]] = self.invalid_motion_value
474
+ rel_head_a2a[is_invalid_s[edge_index_a2a[0]] & is_invalid_s[edge_index_a2a[1]]] = self.invalid_head_value
475
+
476
+ r_a2a = torch.stack(
477
+ [torch.norm(rel_pos_a2a[:, :2], p=2, dim=-1),
478
+ angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_a2a[1]], nbr_vector=rel_pos_a2a[:, :2]),
479
+ rel_head_a2a,
480
+ torch.zeros_like(edge_index_a2a[0])], dim=-1)
481
+ r_a2a = self.r_a2a_emb(continuous_inputs=r_a2a, categorical_embs=None)
482
+
483
+ # add the edges which connect seed agents
484
+ if is_training:
485
+ mask_av = torch.ones_like(mask_a).bool()
486
+ if not self.seed_attn_to_av:
487
+ mask_av[av_index] = False
488
+ mask_a &= mask_av
489
+ edge_index_seed2a, r_seed2a = self._build_a2sa_edge(data, pos_a, head_a, head_vector_a, batch_s,
490
+ mask_a.clone(), ~pad_mask.clone(), inference_mask=inference_mask,
491
+ r=self.pl2seed_radius, max_num_neighbors=300,
492
+ seq_mask=seq_mask, seq_index=seq_index, grid_index_a=grid_index_a, mode='grid')
493
+
494
+ if os.getenv('PLOT_EDGE', False):
495
+ plot_interact_edge(edge_index_seed2a, data['scenario_id'], data['batch_size_a'].cpu(), self.num_seed_feature, num_step,
496
+ 'interact_edge_map_seed', av_index=av_index, **plot_kwargs)
497
+
498
+ edge_index_a2a = torch.cat([edge_index_a2a, edge_index_seed2a], dim=-1)
499
+ r_a2a = torch.cat([r_a2a, r_seed2a])
500
+
501
+ return edge_index_a2a, r_a2a, (edge_index_a2a.shape[1], edge_index_seed2a.shape[1]) #, nearest_dict
502
+
503
+ return edge_index_a2a, r_a2a
504
+
505
    def _build_map2agent_edge(self, data, pos_a, head_a, state_a, head_vector_a, batch_s, batch_pl,
                              mask, pad_mask=None, inference_mask=None, av_index=None, **kwargs):
        """Build directed map-token-to-agent edges and their edge features.

        During training additionally builds map-to-seed edges and returns the
        per-group edge counts; at inference only the plain map2agent edges.
        """
        num_graph = data.num_graphs
        num_agent, num_step, _ = pos_a.shape
        is_training = inference_mask is None

        mask_pl2a = mask.clone()

        if pad_mask is None:
            pad_mask = torch.ones_like(state_a).bool()

        # Flatten to step-major (step * agent) node ordering.
        pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
        head_s = head_a.transpose(0, 1).reshape(-1)
        head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2)
        pad_mask_s = pad_mask.transpose(0, 1).reshape(-1)
        if inference_mask is not None:
            mask_pl2a = mask_pl2a & inference_mask
        mask_pl2a = mask_pl2a.transpose(0, 1).reshape(-1)

        ori_pos_pl = data['pt_token']['position'][:, :self.input_dim].contiguous()
        ori_orient_pl = data['pt_token']['orientation'].contiguous()
        # Tile map tokens once per step so indices align with the step-major
        # agent flattening above.
        pos_pl = ori_pos_pl.repeat(num_step, 1)  # not `repeat_interleave`
        orient_pl = ori_orient_pl.repeat(num_step)

        # build map2agent directed graph
        # edge_index_pl2a = radius(x=pos_s[:, :2], y=pos_pl[:, :2], r=self.pl2a_radius,
        #                          batch_x=batch_s, batch_y=batch_pl, max_num_neighbors=300)
        # Query from agents so each agent keeps at most 5 nearby map tokens.
        edge_index_pl2a = radius(x=pos_pl[:, :2], y=pos_s[:, :2], r=self.pl2a_radius,
                                 batch_x=batch_pl, batch_y=batch_s, max_num_neighbors=5)
        edge_index_pl2a = edge_index_pl2a[[1, 0]]
        edge_index_pl2a = edge_index_pl2a[:, mask_pl2a[edge_index_pl2a[1]] &
                                             pad_mask_s[edge_index_pl2a[1]]]

        rel_pos_pl2a = pos_pl[edge_index_pl2a[0]] - pos_s[edge_index_pl2a[1]]
        rel_orient_pl2a = wrap_angle(orient_pl[edge_index_pl2a[0]] - head_s[edge_index_pl2a[1]])

        # handle the invalid steps
        is_invalid = state_a == self.invalid_state
        is_invalid_s = is_invalid.transpose(0, 1).reshape(-1)
        rel_pos_pl2a[is_invalid_s[edge_index_pl2a[1]]] = self.motion_gap
        rel_orient_pl2a[is_invalid_s[edge_index_pl2a[1]]] = self.heading_gap

        r_pl2a = torch.stack(
            [torch.norm(rel_pos_pl2a[:, :2], p=2, dim=-1),
             angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_pl2a[1]], nbr_vector=rel_pos_pl2a[:, :2]),
             rel_orient_pl2a], dim=-1)
        r_pl2a = self.r_pt2a_emb(continuous_inputs=r_pl2a, categorical_embs=None)

        # add the edges which connect seed agents
        if is_training:
            edge_index_pl2seed, r_pl2seed = self._build_map2sa_edge(data, pos_a, head_a, head_vector_a, batch_s, batch_pl,
                                                                    ~pad_mask.clone(), inference_mask=inference_mask,
                                                                    r=self.pl2seed_radius, max_num_neighbors=2048, mode='grid')

            # sanity check
            # pl2a_index = torch.zeros(pos_a.shape[0], num_step)
            # pl2a_r = torch.zeros(pos_a.shape[0], num_step)
            # for src_index in torch.unique(edge_index_pl2seed[1]):
            #     src_row = src_index % pos_a.shape[0]
            #     src_col = src_index // pos_a.shape[0]
            #     pl2a_index[src_row, src_col] = edge_index_pl2seed[0, edge_index_pl2seed[1] == src_index].sum()
            #     pl2a_r[src_row, src_col] = r_pl2seed[edge_index_pl2seed[1] == src_index].sum()
            # print(pl2a_index)
            # print(pl2a_r)
            # exit(1)

            if os.getenv('PLOT_EDGE', False):
                plot_interact_edge(edge_index_pl2seed, data['scenario_id'], data['batch_size_a'].cpu(), self.num_seed_feature, num_step,
                                   'interact_edge_map_seed', av_index=av_index)

            edge_index_pl2a = torch.cat([edge_index_pl2a, edge_index_pl2seed], dim=-1)
            r_pl2a = torch.cat([r_pl2a, r_pl2seed])

            return edge_index_pl2a, r_pl2a, (edge_index_pl2a.shape[1], edge_index_pl2seed.shape[1])

        return edge_index_pl2a, r_pl2a
581
+
582
    def _build_a2sa_edge(self, data, pos_a, head_a, head_vector_a, batch_s, mask_a, mask_sa,
                         inference_mask=None, r=None, max_num_neighbors=8, seq_mask=None, seq_index=None,
                         grid_index_a=None, mode: Literal['grid', 'heading']='heading', **plot_kwargs):
        """Build directed edges from ordinary agents to seed agents within radius ``r``.

        ``mask_sa`` selects the seed targets; edge targets are remapped back
        to global (step-major) node indices before returning.
        NOTE(review): `grid_index_a` and `mode` are accepted but never read in
        this body — presumably reserved for a variant; confirm before removal.
        """
        num_agent, num_step, _ = pos_a.shape
        is_training = inference_mask is None

        # Flatten to step-major (step * agent) node ordering.
        pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
        head_s = head_a.transpose(0, 1).reshape(-1)
        head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2)
        if inference_mask is not None:
            mask_a = mask_a & inference_mask
            mask_sa = mask_sa & inference_mask
        mask_s = mask_a.transpose(0, 1).reshape(-1)
        mask_s_sa = mask_sa.transpose(0, 1).reshape(-1)

        # build seed_agent2agent unilateral connection
        assert r is not None, "r needs to be specified!"
        # edge_index_a2sa = radius(x=pos_s[mask_s_sa, :2], y=pos_s[:, :2], r=r,
        #                          batch_x=batch_s[mask_s_sa], batch_y=batch_s, max_num_neighbors=max_num_neighbors)
        edge_index_a2sa = radius(x=pos_s[:, :2], y=pos_s[mask_s_sa, :2], r=r,
                                 batch_x=batch_s, batch_y=batch_s[mask_s_sa], max_num_neighbors=max_num_neighbors)
        edge_index_a2sa = edge_index_a2sa[[1, 0]]
        # Sources must be valid non-seed agents.
        edge_index_a2sa = edge_index_a2sa[:, ~mask_s_sa[edge_index_a2sa[0]] & mask_s[edge_index_a2sa[0]]]

        # only for seed agent sequence training
        if seq_mask is not None:
            edge_mask = seq_mask[edge_index_a2sa[1]]
            edge_mask = torch.gather(edge_mask, dim=1, index=edge_index_a2sa[0, :, None] % num_agent)[:, 0]
            edge_index_a2sa = edge_index_a2sa[:, edge_mask]

        if seq_index is None:
            seq_index = torch.zeros(num_agent, device=pos_a.device).long()
        if seq_index.dim() == 1:
            seq_index = seq_index[:, None].repeat(1, num_step)
        seq_index = seq_index.transpose(0, 1).reshape(-1)
        assert seq_index.shape[0] == pos_s.shape[0], f"Inconsistent lenght {seq_index.shape[0]} and {pos_s.shape[0]}!"

        # convert to global index (targets were indices into the masked subset)
        all_index = torch.arange(pos_s.shape[0], device=pos_a.device).long()
        sa_index = all_index[mask_s_sa]
        edge_index_a2sa[1] = sa_index[edge_index_a2sa[1]]

        # plot edge index TODO: now only support bs=1
        if os.getenv('PLOT_EDGE_INFERENCE', False) and not is_training:
            num_agent, num_step, _ = pos_a.shape
            # plot_interact_edge(edge_index_a2sa, data['scenario_id'], data['batch_size_a'].cpu(), 1, num_step,
            #                    'interact_a2sa_edge_map', **plot_kwargs)
            plot_interact_edge(edge_index_a2sa, data['scenario_id'], torch.tensor([num_agent - 1]), 1, num_step,
                               f"interact_a2sa_edge_map_infer_{plot_kwargs['tag']}", **plot_kwargs)

        rel_pos_a2sa = pos_s[edge_index_a2sa[0]] - pos_s[edge_index_a2sa[1]]
        rel_head_a2sa = wrap_angle(head_s[edge_index_a2sa[0]] - head_s[edge_index_a2sa[1]])

        r_a2sa = torch.stack(
            [torch.norm(rel_pos_a2sa[:, :2], p=2, dim=-1),
             angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_a2sa[1]], nbr_vector=rel_pos_a2sa[:, :2]),
             rel_head_a2sa,
             seq_index[edge_index_a2sa[0]] - seq_index[edge_index_a2sa[1]]], dim=-1)
        r_a2sa = self.r_a2sa_emb(continuous_inputs=r_a2sa, categorical_embs=None)

        return edge_index_a2sa, r_a2sa
644
+
645
    def _build_map2sa_edge(self, data, pos_a, head_a, head_vector_a, batch_s, batch_pl,
                           mask_sa, inference_mask=None, r=None, max_num_neighbors=32, mode: Literal['grid', 'heading']='heading'):
        """Build directed edges from map tokens to seed agents within radius ``r``.

        NOTE(review): `mode` is accepted but never read in this body —
        presumably reserved for a variant; confirm before removal.
        """
        _, num_step, _ = pos_a.shape

        mask_pl2sa = torch.ones_like(mask_sa).bool()
        if inference_mask is not None:
            mask_pl2sa = mask_pl2sa & inference_mask
        mask_pl2sa = mask_pl2sa.transpose(0, 1).reshape(-1)
        mask_s_sa = mask_sa.transpose(0, 1).reshape(-1)

        # Flatten to step-major (step * agent) node ordering.
        pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
        head_s = head_a.transpose(0, 1).reshape(-1)
        head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2)

        ori_pos_pl = data['pt_token']['position'][:, :self.input_dim].contiguous()
        ori_orient_pl = data['pt_token']['orientation'].contiguous()
        # Tile map tokens once per step to match the step-major flattening.
        pos_pl = ori_pos_pl.repeat(num_step, 1)  # not `repeat_interleave`
        orient_pl = ori_orient_pl.repeat(num_step)

        # build map2agent directed graph
        assert r is not None, "r needs to be specified!"
        # edge_index_pl2sa = radius(x=pos_s[mask_s_sa, :2], y=pos_pl[:, :2], r=r,
        #                           batch_x=batch_s[mask_s_sa], batch_y=batch_pl, max_num_neighbors=max_num_neighbors)
        edge_index_pl2sa = radius(x=pos_pl[:, :2], y=pos_s[mask_s_sa, :2], r=r,
                                  batch_x=batch_pl, batch_y=batch_s[mask_s_sa], max_num_neighbors=max_num_neighbors)
        edge_index_pl2sa = edge_index_pl2sa[[1, 0]]
        edge_index_pl2sa = edge_index_pl2sa[:, mask_pl2sa[mask_s_sa][edge_index_pl2sa[1]]]

        # convert to global index (targets were indices into the masked subset)
        all_index = torch.arange(pos_s.shape[0], device=pos_a.device).long()
        sa_index = all_index[mask_s_sa]
        edge_index_pl2sa[1] = sa_index[edge_index_pl2sa[1]]

        # plot edge map
        # if os.getenv('PLOT_EDGE', False):
        #     plot_map_edge(edge_index_pl2sa, pos_s[:, :2], data, save_path='map2sa_edge_map')

        rel_pos_pl2sa = pos_pl[edge_index_pl2sa[0]] - pos_s[edge_index_pl2sa[1]]
        rel_orient_pl2sa = wrap_angle(orient_pl[edge_index_pl2sa[0]] - head_s[edge_index_pl2sa[1]])

        r_pl2sa = torch.stack(
            [torch.norm(rel_pos_pl2sa[:, :2], p=2, dim=-1),
             angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_pl2sa[1]], nbr_vector=rel_pos_pl2sa[:, :2]),
             rel_orient_pl2sa], dim=-1)
        r_pl2sa = self.r_pt2sa_emb(continuous_inputs=r_pl2sa, categorical_embs=None)

        return edge_index_pl2sa, r_pl2sa
693
+
694
    def _build_sa2sa_edge(self, data, pos_a, head_a, state_a, head_vector_a, batch_s, mask, inference_mask=None, **plot_kwargs):
        """Build seed-to-seed edges from a dense pairwise mask.

        NOTE(review): ``pos_t`` is flattened after a transpose while
        ``head_t``/``head_vector_t`` are flattened without one — the node
        orderings therefore differ. This looks inconsistent with the other
        edge builders; confirm whether it is intentional.
        """
        num_agent = pos_a.shape[0]

        pos_t = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
        head_t = head_a.reshape(-1)
        head_vector_t = head_vector_a.reshape(-1, 2)

        if inference_mask is not None:
            mask_t = mask.unsqueeze(2) & inference_mask.unsqueeze(1)
        else:
            mask_t = mask.unsqueeze(2) & mask.unsqueeze(1)

        edge_index_sa2sa = dense_to_sparse(mask_t)[0]
        # Keep only forward edges (target index strictly greater than source).
        edge_index_sa2sa = edge_index_sa2sa[:, edge_index_sa2sa[1] - edge_index_sa2sa[0] > 0]
        rel_pos_t = pos_t[edge_index_sa2sa[0]] - pos_t[edge_index_sa2sa[1]]
        rel_head_t = wrap_angle(head_t[edge_index_sa2sa[0]] - head_t[edge_index_sa2sa[1]])

        r_t = torch.stack(
            [torch.norm(rel_pos_t[:, :2], p=2, dim=-1),
             angle_between_2d_vectors(ctr_vector=head_vector_t[edge_index_sa2sa[1]], nbr_vector=rel_pos_t[:, :2]),
             rel_head_t,
             edge_index_sa2sa[0] - edge_index_sa2sa[1]], dim=-1)
        r_sa2sa = self.r_sa2sa_emb(continuous_inputs=r_t, categorical_embs=None)

        return edge_index_sa2sa, r_sa2sa
720
+
721
    def _build_seq(self, device, num_agent, num_step, av_index, sort_indices):
        """Build the seed-sequence attention mask and per-agent sequence ranks.

        Args:
            sort_indices (torch.Tensor): shape (num_agent, num_step); agent
                indices in insertion order per step.

        Returns:
            seq_mask: (num_step * num_seed_feature, num_agent + num_seed_feature)
                bool mask — seed s may attend only to agents inserted before it.
            seq_index: (num_agent + num_seed_feature, num_step) sequence rank
                per agent (0 for the AV and un-inserted agents).
        """
        # sort_indices = sort_indices[:self.num_seed_feature]
        seq_mask = torch.ones(self.num_seed_feature, num_step, num_agent + self.num_seed_feature, device=device).bool()
        # Seeds never attend to other seed slots.
        seq_mask[..., -self.num_seed_feature:] = False
        for t in range(num_step):
            for s in range(self.num_seed_feature):
                # Hide agents at insertion rank >= s (not yet inserted for seed s).
                seq_mask[s, t, sort_indices[s:, t].flatten().long()] = False
        if self.seed_attn_to_av:
            seq_mask[..., av_index] = True
        seq_mask = seq_mask.transpose(0, 1).reshape(-1, num_agent + self.num_seed_feature)

        # Rank 0 for ordinary agents, 1..num_seed_feature for the seed slots.
        seq_index = torch.cat([torch.zeros(num_agent), torch.arange(self.num_seed_feature) + 1]).to(device)
        seq_index = seq_index[:, None].repeat(1, num_step)
        for t in range(num_step):
            for s in range(self.num_seed_feature):
                seq_index[sort_indices[s : s + 1, t].flatten().long(), t] = s + 1
        seq_index[av_index] = 0

        return seq_mask, seq_index
744
+
745
    def _build_occ_gt(self, data, seq_mask, pos_rel_index_gt, pos_rel_index_gt_seed, mask_seed,
                      edge_index=None, mode='edge_index'):
        """Build ground-truth agent/map occupancy grids, stored into ``data``.

        Mutates ``data['agent']['agent_occ']`` and ``data['agent']['map_occ']``
        in place; returns nothing.

        Args:
            seq_mask (torch.Tensor): shape (num_step * num_seed_feature, num_agent + self.num_seed_feature)
            pos_rel_index_gt (torch.Tensor): shape (num_agent, num_step)
            pos_rel_index_gt_seed (torch.Tensor): shape (num_seed, num_step)
            mask_seed: per-seed, per-step validity of the seed GT indices.
            edge_index: required when ``mode == 'edge_index'``; edges whose
                targets are seed slots, in flattened (col * num_agent + row) form.
            mode: 'edge_index' decodes occupancy from edges, otherwise from
                ``seq_mask``.
        """
        num_agent = data['agent']['state_idx'].shape[0] + self.num_seed_feature
        num_step = data['agent']['state_idx'].shape[1]
        data['agent']['agent_occ'] = torch.zeros(data.num_graphs * self.num_seed_feature, num_step, self.attr_tokenizer.grid_size,
                                                 device=data['agent']['state_idx'].device).long()
        data['agent']['map_occ'] = torch.zeros(data.num_graphs, num_step, self.attr_tokenizer.grid_size,
                                               device=data['agent']['state_idx'].device).long()

        if mode == 'edge_index':

            assert edge_index is not None, f"Need edge_index input!"
            for src_index in torch.unique(edge_index[1]):
                # decode src: flattened index -> (seed row, step col)
                src_row = src_index % num_agent - (num_agent - self.num_seed_feature)
                src_col = src_index // num_agent
                # decode tgt
                tgt_indexes = edge_index[0, edge_index[1] == src_index]
                tgt_rows = tgt_indexes % num_agent
                tgt_cols = tgt_indexes // num_agent
                assert tgt_rows.max() < num_agent - self.num_seed_feature, f"Invalid {tgt_rows}"
                # All targets of one seed node must come from the same step.
                assert torch.unique(tgt_cols).shape[0] == 1 and torch.unique(tgt_cols)[0] == src_col
                data['agent']['agent_occ'][src_row, src_col, pos_rel_index_gt[tgt_rows, tgt_cols]] = 1

        else:

            # Reshape to (num_seed_feature, num_step, num_agent), dropping seed columns.
            seq_mask = seq_mask.reshape(num_step, self.num_seed_feature, -1).transpose(0, 1)[..., :-self.num_seed_feature]
            for s in range(self.num_seed_feature):
                for t in range(num_step):
                    index = pos_rel_index_gt[seq_mask[s, t], t]
                    data['agent']['agent_occ'][s, t, index[index != -1]] = 1
                    if t > 0 and s < pos_rel_index_gt_seed.shape[0] and mask_seed[s, t - 1]:  # insert agents
                        # -1 marks the cell where this seed agent will be inserted.
                        data['agent']['agent_occ'][s, t, pos_rel_index_gt_seed[s, t - 1]] = -1

        # TODO: fix batch_size!!!
        pt_grid_token_idx = data['agent']['pt_grid_token_idx']  # (t, num_pt)
        for t in range(num_step):
            data['agent']['map_occ'][:, t, pt_grid_token_idx[t][pt_grid_token_idx[t] != -1]] = 1
        data['agent']['map_occ'] = data['agent']['map_occ'].repeat_interleave(repeats=self.num_seed_feature, dim=0)
791
    def forward(self,
                data: HeteroData,
                map_enc: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Occupancy-decoder forward pass.

        Encodes ego-relative agent and map-token descriptors, pools a random
        subset (``agent_limit`` agents, ``pt_limit`` map tokens) into per-step
        global features, and predicts grid-occupancy logits against the
        tokenized grid-index ground truth.

        Returns a dict of logits, GT indices and evaluation masks consumed by
        the loss/metric code (``map_enc`` is currently unused here).
        """
        pos_a = data['agent']['token_pos'].clone()  # (a, t, 2)
        head_a = data['agent']['token_heading'].clone()  # (a, t)
        num_agent, num_step, traj_dim = pos_a.shape  # e.g. (50, 18, 2)
        num_pt = data['pt_token']['position'].shape[0]
        agent_category = data['agent']['category'].clone()  # (a,)
        agent_shape = data['agent']['shape'][:, self.num_historical_steps - 1].clone()  # (a, 3)
        agent_token_index = data['agent']['token_idx'].clone()  # (a, t)
        agent_state_index = data['agent']['state_idx'].clone()
        agent_type_index = data['agent']['type'].clone()

        av_index = data['agent']['av_index'].long()
        ego_pos = pos_a[av_index]
        ego_head = head_a[av_index]

        # NOTE: _build_vector_a also masks invalid entries of head_a in place.
        _, head_vector_a = self._build_vector_a(pos_a, head_a, agent_state_index)

        agent_grid_token_idx = data['agent']['grid_token_idx']
        agent_grid_offset_xy = data['agent']['grid_offset_xy']
        agent_head_token_idx = data['agent']['heading_token_idx']
        sort_indices = data['agent']['sort_indices']
        pt_grid_token_idx = data['agent']['pt_grid_token_idx']

        ori_pos_pl = data['pt_token']['position'][:, :self.input_dim].contiguous()
        ori_orient_pl = data['pt_token']['orientation'].contiguous()
        pos_pl = ori_pos_pl.repeat(num_step, 1)
        orient_pl = ori_orient_pl.repeat(num_step)

        # build relative 3d descriptors (step-major flattening)
        pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
        head_s = head_a.transpose(0, 1).reshape(-1)

        # agent poses relative to the ego vehicle of the same graph
        ego_pos_a = ego_pos.repeat_interleave(repeats=data['batch_size_a'], dim=0)
        ego_head_a = ego_head.repeat_interleave(repeats=data['batch_size_a'], dim=0)
        ego_pos_s = ego_pos_a.transpose(0, 1).reshape(-1, self.input_dim)
        ego_head_s = ego_head_a.transpose(0, 1).reshape(-1)
        rel_pos_a2a = pos_s - ego_pos_s
        rel_head_a2a = head_s - ego_head_s

        # map-token poses relative to the ego vehicle of the same graph
        ego_pos_pl = ego_pos.repeat_interleave(repeats=data['batch_size_pl'], dim=0)
        ego_head_pl = ego_head.repeat_interleave(repeats=data['batch_size_pl'], dim=0)
        ego_pos_s = ego_pos_pl.transpose(0, 1).reshape(-1, self.input_dim)
        ego_head_s = ego_head_pl.transpose(0, 1).reshape(-1)
        rel_pos_pl2a = pos_pl - ego_pos_s
        rel_head_pl2a = orient_pl - ego_head_s

        # relative encodings
        ego_head_vector_a = head_vector_a[av_index].repeat_interleave(repeats=data['batch_size_a'], dim=0)
        ego_head_vector_s = ego_head_vector_a.transpose(0, 1).reshape(-1, 2)
        r_a2a = torch.stack(
            [torch.norm(rel_pos_a2a[:, :2], p=2, dim=-1),
             angle_between_2d_vectors(ctr_vector=ego_head_vector_s, nbr_vector=rel_pos_a2a[:, :2]),
             rel_head_a2a], dim=-1)
        r_a2a = self.r_a2a_emb(continuous_inputs=r_a2a, categorical_embs=None)  # [N, hidden_dim]

        ego_head_vector_a = head_vector_a[av_index].repeat_interleave(repeats=data['batch_size_pl'], dim=0)
        ego_head_vector_s = ego_head_vector_a.transpose(0, 1).reshape(-1, 2)
        r_pl2a = torch.stack(
            [torch.norm(rel_pos_pl2a[:, :2], p=2, dim=-1),
             angle_between_2d_vectors(ctr_vector=ego_head_vector_s, nbr_vector=rel_pos_pl2a[:, :2]),
             rel_head_pl2a], dim=-1)
        r_pl2a = self.r_pt2a_emb(continuous_inputs=r_pl2a, categorical_embs=None)  # [M, d]

        # subsample a random set of agents / map tokens to bound the cost
        r_a2a = r_a2a.reshape(num_step, num_agent, -1).transpose(0, 1)
        r_pl2a = r_pl2a.reshape(num_step, num_pt, -1).transpose(0, 1)
        select_agent = torch.randperm(num_agent)[:self.agent_limit]
        select_pt = torch.randperm(num_pt)[:self.pt_limit]
        r_a2a = r_a2a[select_agent]
        r_pl2a = r_pl2a[select_pt]

        # aggregate to global feature
        r_a2a = r_a2a.mean(dim=0)  # [t, d]
        r_pl2a = r_pl2a.mean(dim=0)

        # decode grid index of neighbor agents
        agent_occ = self.grid_agent_occ_head(r_a2a)  # [t, grid_size]
        pt_occ = self.grid_pt_occ_head(r_pl2a)

        # 1. (alternative GT scheme: dense per-step occupancy — kept for reference)
        # agent_occ_gt = torch.zeros_like(agent_occ).long()
        # pt_occ_gt = torch.zeros_like(pt_occ).long()

        # for t in range(num_step):
        #     agent_occ_gt[t, agent_grid_token_idx[:, t][agent_grid_token_idx[:, t] != -1]] = 1
        #     pt_occ_gt[t, pt_grid_token_idx[t][pt_grid_token_idx[t] != -1]] = 1

        # agent_occ_gt[:, self.grid_size // 2] = 0
        # pt_occ_gt[:, self.grid_size // 2] = 0

        # agent_occ_eval_mask = torch.ones_like(agent_occ_gt)
        # agent_occ_eval_mask[0] = 0
        # agent_occ_eval_mask[:, self.grid_size // 2] = 0
        # pt_occ_eval_mask = torch.ones_like(pt_occ_gt)
        # pt_occ_eval_mask[0] = 0
        # pt_occ_eval_mask[:, self.grid_size // 2] = 0

        # 2. (alternative GT scheme: random per-token sampling — kept for reference)
        # agent_occ_gt = agent_grid_token_idx.transpose(0, 1).reshape(-1)
        # pt_occ_gt = pt_grid_token_idx.reshape(-1)

        # agent_occ_eval_mask = torch.zeros_like(agent_occ_gt)
        # agent_occ_eval_mask[torch.randperm(agent_occ_gt.shape[0])[:(num_step * 10)]] = 1
        # agent_occ_eval_mask[agent_occ_gt == -1] = 0

        # pt_occ_eval_mask = torch.zeros_like(pt_occ_gt)
        # pt_occ_eval_mask[torch.randperm(pt_occ_gt.shape[0])[:(num_step * 300)]] = 1
        # pt_occ_eval_mask[pt_occ_gt == -1] = 0

        # 3. per-selected-entity grid-index classification (active scheme)
        agent_occ = agent_occ.reshape(num_step, self.agent_limit, -1)
        pt_occ = pt_occ.reshape(num_step, self.pt_limit, -1)
        agent_occ_gt = agent_grid_token_idx[select_agent].transpose(0, 1)
        pt_occ_gt = pt_grid_token_idx[:, select_pt]
        agent_occ_eval_mask = agent_occ_gt != -1
        pt_occ_eval_mask = pt_occ_gt != -1

        # trim logits when fewer than agent_limit/pt_limit entities exist
        agent_occ = agent_occ[:, :agent_occ_gt.shape[1]]
        pt_occ = pt_occ[:, :pt_occ_gt.shape[1]]

        return {'occ_decoder': True,
                'num_step': num_step,
                'num_agent': self.agent_limit,  # num_agent
                'num_pt': self.pt_limit,  # num_pt
                'agent_occ': agent_occ,
                'agent_occ_gt': agent_occ_gt,
                'agent_occ_eval_mask': agent_occ_eval_mask.bool(),
                'pt_occ': pt_occ,
                'pt_occ_gt': pt_occ_gt,
                'pt_occ_eval_mask': pt_occ_eval_mask.bool(),
                }
924
+
925
+ def inference(self, *args, **kwargs):
926
+ return self(*args, **kwargs)
927
+
backups/dev/modules/smart_decoder.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch_geometric.data import HeteroData
5
+ from dev.modules.attr_tokenizer import Attr_Tokenizer
6
+ from dev.modules.agent_decoder import SMARTAgentDecoder
7
+ from dev.modules.occ_decoder import SMARTOccDecoder
8
+ from dev.modules.map_decoder import SMARTMapDecoder
9
+
10
+
11
+ DECODER = {'agent_decoder': SMARTAgentDecoder,
12
+ 'occ_decoder': SMARTOccDecoder}
13
+
14
+
15
class SMARTDecoder(nn.Module):
    """Composite SMART decoder: a map encoder feeding an agent-level or
    occupancy-level decoder selected by ``decoder_type``.

    The map encoder runs first and its output (``map_enc``) is passed to the
    agent encoder. The ``predict_motion``/``predict_state``/``predict_occ``
    flags gate whether the agent encoder runs at all during forward/inference.
    """

    def __init__(self,
                 decoder_type: str,
                 dataset: str,
                 input_dim: int,
                 hidden_dim: int,
                 num_historical_steps: int,
                 pl2pl_radius: float,
                 time_span: Optional[int],
                 pl2a_radius: float,
                 pl2seed_radius: float,
                 a2a_radius: float,
                 a2sa_radius: float,
                 pl2sa_radius: float,
                 num_freq_bands: int,
                 num_map_layers: int,
                 num_agent_layers: int,
                 num_heads: int,
                 head_dim: int,
                 dropout: float,
                 map_token: Dict,
                 token_size=512,
                 attr_tokenizer: Attr_Tokenizer=None,
                 predict_motion: bool=False,
                 predict_state: bool=False,
                 predict_map: bool=False,
                 predict_occ: bool=False,
                 use_grid_token: bool=False,
                 state_token: Dict[str, int]=None,
                 seed_size: int=5,
                 buffer_size: int=32,
                 num_recurrent_steps_val: int=-1,
                 loss_weight: dict=None,
                 logger=None) -> None:
        super().__init__()

        # Map branch: encodes map tokens; shared by every decoder type.
        self.map_encoder = SMARTMapDecoder(
            dataset=dataset,
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            num_historical_steps=num_historical_steps,
            pl2pl_radius=pl2pl_radius,
            num_freq_bands=num_freq_bands,
            num_layers=num_map_layers,
            num_heads=num_heads,
            head_dim=head_dim,
            dropout=dropout,
            map_token=map_token,
        )

        # Fail fast on an unknown decoder type before constructing anything else.
        assert decoder_type in DECODER, f"Unsupported decoder type: {decoder_type}"
        self.agent_encoder = DECODER[decoder_type](
            dataset=dataset,
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            num_historical_steps=num_historical_steps,
            time_span=time_span,
            pl2a_radius=pl2a_radius,
            pl2seed_radius=pl2seed_radius,
            a2a_radius=a2a_radius,
            a2sa_radius=a2sa_radius,
            pl2sa_radius=pl2sa_radius,
            num_freq_bands=num_freq_bands,
            num_layers=num_agent_layers,
            num_heads=num_heads,
            head_dim=head_dim,
            dropout=dropout,
            token_size=token_size,
            attr_tokenizer=attr_tokenizer,
            predict_motion=predict_motion,
            predict_state=predict_state,
            predict_map=predict_map,
            predict_occ=predict_occ,
            state_token=state_token,
            use_grid_token=use_grid_token,
            seed_size=seed_size,
            buffer_size=buffer_size,
            num_recurrent_steps_val=num_recurrent_steps_val,
            loss_weight=loss_weight,
            logger=logger,
        )
        self.map_enc = None
        self.predict_motion = predict_motion
        self.predict_state = predict_state
        self.predict_map = predict_map
        self.predict_occ = predict_occ
        # Raw data fields forwarded verbatim alongside the encoder outputs.
        self.data_keys = ["agent_valid_mask", "category", "valid_mask", "av_index", "scenario_id", "shape"]

    def get_agent_inputs(self, data: HeteroData) -> Dict[str, torch.Tensor]:
        """Delegate input preparation to the agent encoder."""
        return self.agent_encoder.get_inputs(data)

    def forward(self, data: HeteroData) -> Dict[str, torch.Tensor]:
        """Training forward pass: map encoding, then (optionally) agent encoding."""
        map_enc = self.map_encoder(data)

        agent_enc = {}
        if self.predict_motion or self.predict_state or self.predict_occ:
            agent_enc = self.agent_encoder(data, map_enc)

        return {**map_enc, **agent_enc, **{k: data[k] for k in self.data_keys}}

    def inference(self, data: HeteroData) -> Dict[str, torch.Tensor]:
        """Inference pass: same flow as ``forward`` but via the agent encoder's
        autoregressive ``inference`` path."""
        map_enc = self.map_encoder(data)

        agent_enc = {}
        if self.predict_motion or self.predict_state or self.predict_occ:
            agent_enc = self.agent_encoder.inference(data, map_enc)

        return {**map_enc, **agent_enc, **{k: data[k] for k in self.data_keys}}

    def inference_no_map(self, data: HeteroData, map_enc) -> Dict[str, torch.Tensor]:
        """Inference with a precomputed ``map_enc`` (skips the map encoder)."""
        agent_enc = self.agent_encoder.inference(data, map_enc)
        return {**map_enc, **agent_enc}

    def insert_agent(self, data: HeteroData) -> Dict[str, torch.Tensor]:
        """Run the agent-insertion path of the agent encoder."""
        map_enc = self.map_encoder(data)
        agent_enc = self.agent_encoder.insert(data, map_enc)
        return {**map_enc, **agent_enc, **{k: data[k] for k in self.data_keys}}

    def predict_nearest_pos(self, data: HeteroData, rank) -> None:
        # NOTE: the delegate call's result is intentionally discarded, so the
        # return annotation is None (the original annotated Dict but returned
        # nothing).
        map_enc = self.map_encoder(data)
        self.agent_encoder.predict_nearest_pos(data, map_enc, rank)
backups/dev/utils/cluster_reader.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import pickle
3
+ import pandas as pd
4
+ import json
5
+
6
+
7
class LoadScenarioFromCeph:
    """Thin convenience wrapper around a petrel-oss ``Client`` for listing,
    reading and writing scenario artifacts stored on Ceph."""

    def __init__(self):
        # Imported lazily so the petrel dependency is only needed when used.
        from petrel_client.client import Client
        self.file_client = Client('~/petreloss.conf')

    def list(self, dir_path):
        """Return the entries under ``dir_path`` as a list."""
        return list(self.file_client.list(dir_path))

    def save(self, data, url):
        """Pickle ``data`` and upload the bytes to ``url``."""
        payload = pickle.dumps(data)
        self.file_client.put(url, payload)

    def read_correct_csv(self, scenario_path):
        """Read a remote CSV using pandas' pure-python engine."""
        text = self.file_client.get(scenario_path).decode('utf-8')
        return pd.read_csv(io.StringIO(text), engine="python")

    def contains(self, url):
        """Return True when ``url`` exists in the store."""
        return self.file_client.contains(url)

    def read_string(self, csv_url):
        """Read a whitespace-separated table from a remote file."""
        text = str(self.file_client.get(csv_url), 'utf-8')
        return pd.read_csv(io.StringIO(text), sep='\s+', low_memory=False)

    def read(self, scenario_path):
        """Download and unpickle a scenario object."""
        with io.BytesIO(self.file_client.get(scenario_path)) as buf:
            loaded = pickle.load(buf)
        return loaded

    def read_json(self, path):
        """Download and parse a JSON document."""
        with io.BytesIO(self.file_client.get(path)) as buf:
            parsed = json.load(buf)
        return parsed

    def read_csv(self, scenario_path):
        """Download and unpickle an object.

        NOTE(review): despite the name this expects pickled bytes, not CSV
        text — confirm against callers before renaming.
        """
        return pickle.loads(self.file_client.get(scenario_path))

    def read_model(self, model_path):
        # NOTE(review): appears unfinished upstream — the payload is fetched
        # but never deserialized; behavior preserved as-is.
        with io.BytesIO(self.file_client.get(model_path)) as f:
            pass
backups/dev/utils/func.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import time
3
+ import os
4
+ import yaml
5
+ import easydict
6
+ import math
7
+ import torch
8
+ import torch.nn as nn
9
+ from rich.console import Console
10
+ from typing import Any, List, Optional, Mapping
11
+
12
+ from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
13
+
14
+
15
+ CONSOLE = Console(width=128)
16
+
17
+
18
def check_nan_inf(t, s):
    """Assert that tensor *t* contains no inf and no nan; *s* labels the error."""
    has_inf = torch.isinf(t).any()
    assert not has_inf, f"{s} is inf, {t}"
    has_nan = torch.isnan(t).any()
    assert not has_nan, f"{s} is nan, {t}"
21
+
22
+
23
def safe_list_index(ls: List[Any], elem: Any) -> Optional[int]:
    """Return the first index of *elem* in *ls*, or ``None`` when absent."""
    if elem in ls:
        return ls.index(elem)
    return None
28
+
29
+
30
def angle_between_2d_vectors(
        ctr_vector: torch.Tensor,
        nbr_vector: torch.Tensor) -> torch.Tensor:
    """Signed angle (radians) from ``ctr_vector`` to ``nbr_vector`` in the plane.

    Only the first two components of each vector are used.
    """
    cross = ctr_vector[..., 0] * nbr_vector[..., 1] - ctr_vector[..., 1] * nbr_vector[..., 0]
    dot = (ctr_vector[..., :2] * nbr_vector[..., :2]).sum(dim=-1)
    return torch.atan2(cross, dot)
35
+
36
+
37
def angle_between_3d_vectors(
        ctr_vector: torch.Tensor,
        nbr_vector: torch.Tensor) -> torch.Tensor:
    """Unsigned angle (radians) between two 3D vectors.

    Uses atan2(|cross|, dot), which is numerically stable near 0 and pi.
    """
    cross_norm = torch.cross(ctr_vector, nbr_vector, dim=-1).norm(p=2, dim=-1)
    dot = (ctr_vector * nbr_vector).sum(dim=-1)
    return torch.atan2(cross_norm, dot)
42
+
43
+
44
def side_to_directed_lineseg(
        query_point: torch.Tensor,
        start_point: torch.Tensor,
        end_point: torch.Tensor) -> str:
    """Classify ``query_point`` relative to the directed segment start->end.

    Returns 'LEFT', 'RIGHT', or 'CENTER' (collinear), based on the sign of the
    2D cross product of (end - start) and (query - start).
    """
    seg_x = end_point[0] - start_point[0]
    seg_y = end_point[1] - start_point[1]
    rel_x = query_point[0] - start_point[0]
    rel_y = query_point[1] - start_point[1]
    cross = seg_x * rel_y - seg_y * rel_x
    if cross > 0:
        return 'LEFT'
    if cross < 0:
        return 'RIGHT'
    return 'CENTER'
56
+
57
+
58
def wrap_angle(
        angle: torch.Tensor,
        min_val: float = -math.pi,
        max_val: float = math.pi) -> torch.Tensor:
    """Wrap *angle* into the half-open interval [min_val, max_val)."""
    period = max_val - min_val
    return min_val + (angle + max_val) % period
63
+
64
+
65
def load_config_act(path):
    """Load a YAML config file and wrap it in an EasyDict for attribute access."""
    with open(path, 'r') as f:
        raw = yaml.load(f, Loader=yaml.FullLoader)
    return easydict.EasyDict(raw)
70
+
71
+
72
def load_config_init(path):
    """Load the init-stage YAML config named *path* from ``init/configs``."""
    full_path = os.path.join('init/configs', f'{path}.yaml')
    with open(full_path, 'r') as f:
        cfg = yaml.load(f, Loader=yaml.FullLoader)
    return cfg
78
+
79
+
80
class Logging:
    """Small helper that configures console + file logging under a local
    ``logs`` directory next to this module."""

    def make_log_dir(self, dirname='logs'):
        """Create (if needed) and return the log directory next to this module."""
        now_dir = os.path.dirname(__file__)
        path = os.path.normpath(os.path.join(now_dir, dirname))
        if not os.path.exists(path):
            os.mkdir(path)
        return path

    def get_log_filename(self):
        """Return a timestamped ``.log`` file path inside the log directory."""
        filename = "{}.log".format(time.strftime("%Y-%m-%d-%H%M%S", time.localtime()))
        return os.path.normpath(os.path.join(self.make_log_dir(), filename))

    def _attach_handlers(self, logger):
        # Shared handler setup (previously duplicated in log/add_log).
        # Attach exactly one stream handler and one file handler, and only if
        # the logger has none yet, so repeated calls never duplicate output.
        if not logger.handlers:
            fmt = logging.Formatter(
                "%(asctime)s-%(levelname)s-%(filename)s-Line:%(lineno)d-Message:%(message)s")
            sh = logging.StreamHandler()
            fh = logging.FileHandler(filename=self.get_log_filename(), mode='a', encoding="utf-8")
            sh.setFormatter(fmt=fmt)
            fh.setFormatter(fmt=fmt)
            logger.addHandler(sh)
            logger.addHandler(fh)

    def log(self, level='DEBUG', name="simagent"):
        """Return a configured logger with the given *name* and *level*.

        :param level: name of a ``logging`` level attribute (e.g. 'DEBUG').
        :param name: logger name passed to ``logging.getLogger``.
        """
        logger = logging.getLogger(name)
        logger.setLevel(getattr(logging, level))
        self._attach_handlers(logger)
        return logger

    def add_log(self, logger, level='DEBUG'):
        """Configure an existing *logger* in place and return it."""
        logger.setLevel(getattr(logging, level))
        self._attach_handlers(logger)
        return logger
122
+
123
+
124
# Adapted from 'CatK'
class RankedLogger(logging.LoggerAdapter):
    """A multi-GPU-friendly python command line logger."""

    def __init__(
        self,
        name: str = __name__,
        rank_zero_only: bool = False,
        extra: Optional[Mapping[str, object]] = None,
    ) -> None:
        """Initializes a multi-GPU-friendly python command line logger that logs on all processes
        with their rank prefixed in the log message.

        :param name: The name of the logger. Default is ``__name__``.
        :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`.
        :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
        """
        logger = logging.getLogger(name)
        super().__init__(logger=logger, extra=extra)
        # NOTE: this instance attribute intentionally shares its name with the
        # imported `rank_zero_only` *function*; inside `log` below the bare
        # name still resolves to the module-level function.
        self.rank_zero_only = rank_zero_only

    def log(
        self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs
    ) -> None:
        """Delegate a log call to the underlying logger, after prefixing its message with the rank
        of the process it's being logged from. If `'rank'` is provided, then the log will only
        occur on that rank/process.

        :param level: The level to log at. Look at `logging.__init__.py` for more information.
        :param msg: The message to log.
        :param rank: The rank to log at.
        :param args: Additional args to pass to the underlying logging function.
        :param kwargs: Any additional keyword args to pass to the underlying logging function.
        """
        if self.isEnabledFor(level):
            msg, kwargs = self.process(msg, kwargs)
            # `rank_zero_only` here is the function imported from
            # lightning_utilities; the launcher is expected to attach the
            # current process rank to it as an attribute before logging.
            current_rank = getattr(rank_zero_only, "rank", None)
            if current_rank is None:
                raise RuntimeError(
                    "The `rank_zero_only.rank` needs to be set before use"
                )
            msg = rank_prefixed_message(msg, current_rank)
            if self.rank_zero_only:
                # Forced rank-zero mode: everyone else stays silent.
                if current_rank == 0:
                    self.logger.log(level, msg, *args, **kwargs)
            else:
                # Default: all ranks log, unless a specific `rank` was requested.
                if rank is None:
                    self.logger.log(level, msg, *args, **kwargs)
                elif current_rank == rank:
                    self.logger.log(level, msg, *args, **kwargs)
174
+
175
+
176
+
177
def weight_init(m: nn.Module) -> None:
    """Initialize the weights of module *m* according to its layer type.

    Intended for use with ``module.apply(weight_init)``: linear/conv layers get
    Xavier(-style) uniform weights with zero bias, norm layers get ones/zeros,
    embeddings get a small normal, and attention/recurrent layers are handled
    parameter-by-parameter.
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        # Manual Xavier-uniform bound that accounts for grouped convolutions.
        fan_in = m.in_channels / m.groups
        fan_out = m.out_channels / m.groups
        bound = (6.0 / (fan_in + fan_out)) ** 0.5
        nn.init.uniform_(m.weight, -bound, bound)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Embedding):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)
    elif isinstance(m, nn.MultiheadAttention):
        if m.in_proj_weight is not None:
            # Packed q/k/v projection: Xavier-uniform over the joint matrix.
            fan_in = m.embed_dim
            fan_out = m.embed_dim
            bound = (6.0 / (fan_in + fan_out)) ** 0.5
            nn.init.uniform_(m.in_proj_weight, -bound, bound)
        else:
            # Separate q/k/v projections (used when kdim/vdim differ).
            nn.init.xavier_uniform_(m.q_proj_weight)
            nn.init.xavier_uniform_(m.k_proj_weight)
            nn.init.xavier_uniform_(m.v_proj_weight)
        if m.in_proj_bias is not None:
            nn.init.zeros_(m.in_proj_bias)
        nn.init.xavier_uniform_(m.out_proj.weight)
        if m.out_proj.bias is not None:
            nn.init.zeros_(m.out_proj.bias)
        if m.bias_k is not None:
            nn.init.normal_(m.bias_k, mean=0.0, std=0.02)
        if m.bias_v is not None:
            nn.init.normal_(m.bias_v, mean=0.0, std=0.02)
    elif isinstance(m, (nn.LSTM, nn.LSTMCell)):
        for name, param in m.named_parameters():
            if 'weight_ih' in name:
                # Input-hidden weights: Xavier per gate chunk.
                for ih in param.chunk(4, 0):
                    nn.init.xavier_uniform_(ih)
            elif 'weight_hh' in name:
                # Hidden-hidden weights: orthogonal per gate chunk.
                for hh in param.chunk(4, 0):
                    nn.init.orthogonal_(hh)
            elif 'weight_hr' in name:
                nn.init.xavier_uniform_(param)
            elif 'bias_ih' in name:
                nn.init.zeros_(param)
            elif 'bias_hh' in name:
                nn.init.zeros_(param)
                # Second gate chunk of the hidden bias set to ones — the
                # common forget-gate-bias trick to ease early training.
                nn.init.ones_(param.chunk(4, 0)[1])
    elif isinstance(m, (nn.GRU, nn.GRUCell)):
        for name, param in m.named_parameters():
            if 'weight_ih' in name:
                for ih in param.chunk(3, 0):
                    nn.init.xavier_uniform_(ih)
            elif 'weight_hh' in name:
                for hh in param.chunk(3, 0):
                    nn.init.orthogonal_(hh)
            elif 'bias_ih' in name:
                nn.init.zeros_(param)
            elif 'bias_hh' in name:
                nn.init.zeros_(param)
243
+
244
+
245
def pos2posemb(pos, num_pos_feats=128, temperature=10000):
    """Sinusoidal positional embedding for continuous positions.

    Each of the last-dimension components of *pos* is expanded into
    ``num_pos_feats`` interleaved sin/cos features, and the per-component
    embeddings are concatenated along the last dimension.
    """
    scaled = pos * (2 * math.pi)
    freq = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device)
    freq = temperature ** (2 * (freq // 2) / num_pos_feats)

    num_components = scaled.shape[-1]
    embeddings = []
    for i in range(num_components):
        comp = scaled[..., i, None] / freq
        # Interleave sin on even indices and cos on odd indices.
        comp = torch.stack((comp[..., 0::2].sin(), comp[..., 1::2].cos()), dim=-1).flatten(-2)
        embeddings.append(comp)

    return torch.cat(embeddings, dim=-1)
backups/dev/utils/graph.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ from typing import List, Optional, Tuple, Union
4
+ from torch_geometric.utils import coalesce
5
+ from torch_geometric.utils import degree
6
+
7
+
8
def add_edges(
        from_edge_index: torch.Tensor,
        to_edge_index: torch.Tensor,
        from_edge_attr: Optional[torch.Tensor] = None,
        to_edge_attr: Optional[torch.Tensor] = None,
        replace: bool = True) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Merge the ``from`` edge set into the ``to`` edge set.

    With ``replace=True``, destination edges that also appear in the source
    set are dropped and the source edges (with their attributes) win; with
    ``replace=False``, duplicated source edges are dropped and the existing
    destination edges win.
    """
    from_edge_index = from_edge_index.to(device=to_edge_index.device, dtype=to_edge_index.dtype)
    # Pairwise duplicate detection: dup[i, j] is True when to-edge i equals
    # from-edge j (same source node and same destination node).
    same_src = to_edge_index[0].unsqueeze(-1) == from_edge_index[0].unsqueeze(0)
    same_dst = to_edge_index[1].unsqueeze(-1) == from_edge_index[1].unsqueeze(0)
    dup = same_src & same_dst
    if replace:
        drop_to = dup.any(dim=1)
        if from_edge_attr is not None and to_edge_attr is not None:
            from_edge_attr = from_edge_attr.to(device=to_edge_attr.device, dtype=to_edge_attr.dtype)
            to_edge_attr = torch.cat([to_edge_attr[~drop_to], from_edge_attr], dim=0)
        to_edge_index = torch.cat([to_edge_index[:, ~drop_to], from_edge_index], dim=1)
    else:
        drop_from = dup.any(dim=0)
        if from_edge_attr is not None and to_edge_attr is not None:
            from_edge_attr = from_edge_attr.to(device=to_edge_attr.device, dtype=to_edge_attr.dtype)
            to_edge_attr = torch.cat([to_edge_attr, from_edge_attr[~drop_from]], dim=0)
        to_edge_index = torch.cat([to_edge_index, from_edge_index[:, ~drop_from]], dim=1)
    return to_edge_index, to_edge_attr
30
+
31
+
32
def merge_edges(
        edge_indices: List[torch.Tensor],
        edge_attrs: Optional[List[torch.Tensor]] = None,
        reduce: str = 'add') -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Concatenate several edge lists and coalesce duplicate edges, combining
    their attributes with the given *reduce* op."""
    edge_index = torch.cat(edge_indices, dim=1)
    edge_attr = torch.cat(edge_attrs, dim=0) if edge_attrs is not None else None
    return coalesce(edge_index=edge_index, edge_attr=edge_attr, reduce=reduce)
42
+
43
+
44
def complete_graph(
        num_nodes: Union[int, Tuple[int, int]],
        ptr: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        loop: bool = False,
        device: Optional[Union[torch.device, str]] = None) -> torch.Tensor:
    """Build the (2, E) edge index of a complete graph.

    :param num_nodes: a single node count (homogeneous graph) or a
        ``(num_src, num_dst)`` pair for a bipartite graph.
    :param ptr: optional CSR-style batch pointer(s); when given, a complete
        graph is built independently inside each batch segment.
    :param loop: keep self-loops (only meaningful for the homogeneous case).
    :param device: device on which to allocate the index tensors.
    """
    if ptr is None:
        if isinstance(num_nodes, int):
            num_src, num_dst = num_nodes, num_nodes
        else:
            num_src, num_dst = num_nodes
        # All (src, dst) pairs, transposed to the (2, E) edge-index layout.
        edge_index = torch.cartesian_prod(torch.arange(num_src, dtype=torch.long, device=device),
                                          torch.arange(num_dst, dtype=torch.long, device=device)).t()
    else:
        if isinstance(ptr, torch.Tensor):
            ptr_src, ptr_dst = ptr, ptr
            num_src_batch = num_dst_batch = ptr[1:] - ptr[:-1]
        else:
            ptr_src, ptr_dst = ptr
            num_src_batch = ptr_src[1:] - ptr_src[:-1]
            num_dst_batch = ptr_dst[1:] - ptr_dst[:-1]
        # Per-segment complete graphs, shifted into global node ids by adding
        # the segment's (src_offset, dst_offset) row; `p` has shape (2,) and
        # broadcasts over the (num_src*num_dst, 2) pair list.
        edge_index = torch.cat(
            [torch.cartesian_prod(torch.arange(num_src, dtype=torch.long, device=device),
                                  torch.arange(num_dst, dtype=torch.long, device=device)) + p
             for num_src, num_dst, p in zip(num_src_batch, num_dst_batch, torch.stack([ptr_src, ptr_dst], dim=1))],
            dim=0)
        edge_index = edge_index.t()
    if isinstance(num_nodes, int) and not loop:
        # Drop self-loops in the homogeneous case unless explicitly requested.
        edge_index = edge_index[:, edge_index[0] != edge_index[1]]
    return edge_index.contiguous()
73
+
74
+
75
def bipartite_dense_to_sparse(adj: torch.Tensor) -> torch.Tensor:
    """Convert a dense bipartite adjacency matrix to a (2, E) edge index.

    Accepts either a 2D matrix or a batched 3D tensor; in the batched case the
    node ids are offset so each batch element occupies its own index range.
    """
    nz = adj.nonzero(as_tuple=True)
    if len(nz) == 3:
        # Batched input: shift row/col ids by batch * side-size.
        src = nz[0] * adj.size(1) + nz[1]
        dst = nz[0] * adj.size(2) + nz[2]
        nz = (src, dst)
    return torch.stack(nz, dim=0)
82
+
83
+
84
def unbatch(
        src: torch.Tensor,
        batch: torch.Tensor,
        dim: int = 0) -> List[torch.Tensor]:
    """Split *src* along *dim* into per-graph chunks sized by the *batch*
    assignment vector (one entry per node)."""
    chunk_sizes = degree(batch, dtype=torch.long).tolist()
    return src.split(chunk_sizes, dim)
backups/dev/utils/metrics.py ADDED
@@ -0,0 +1,692 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import itertools
4
+ import multiprocessing as mp
5
+ import torch.nn.functional as F
6
+ from pathlib import Path
7
+ from torch.nn import CrossEntropyLoss
8
+ from torch_scatter import gather_csr
9
+ from torch_scatter import segment_csr
10
+ from torchmetrics import Metric
11
+ from typing import Optional, Tuple, Dict, List
12
+
13
+
14
+ __all__ = ['minADE', 'minFDE', 'TokenCls', 'StateAccuracy', 'GridOverlapRate']
15
+
16
+
17
class CustomCrossEntropyLoss(CrossEntropyLoss):
    """Cross-entropy with manually applied label smoothing and a configurable
    reduction ('mean', 'sum', or anything else for per-sample losses)."""

    def __init__(self, label_smoothing=0.0, reduction='mean'):
        super(CustomCrossEntropyLoss, self).__init__()
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(self, input, target):
        n_classes = input.size(1)
        log_probs = F.log_softmax(input, dim=1)

        # Build the smoothed target distribution without tracking gradients:
        # one-hot scaled by (1 - eps) plus a uniform eps / C floor.
        with torch.no_grad():
            one_hot = torch.zeros_like(log_probs).scatter_(1, target.unsqueeze(1), 1)
            smoothed = one_hot * (1 - self.label_smoothing) + self.label_smoothing / n_classes

        per_sample = -torch.sum(log_probs * smoothed, dim=1)

        if self.reduction == 'sum':
            return per_sample.sum()
        if self.reduction == 'mean':
            return per_sample.mean()
        return per_sample
41
+
42
+
43
def topk(
        max_guesses: int,
        pred: torch.Tensor,
        prob: Optional[torch.Tensor] = None,
        ptr: Optional[torch.Tensor] = None,
        joint: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
    """Select the top ``max_guesses`` prediction modes with renormalized probs.

    :param max_guesses: number of modes to keep (clamped to ``pred.size(1)``).
    :param pred: per-node, per-mode predictions of shape (num_nodes, num_modes, ...).
    :param prob: optional per-mode scores of shape (num_nodes, num_modes).
    :param ptr: optional CSR batch pointer; with ``joint=True`` the mode
        ranking is shared within each batch segment.
    :param joint: rank modes jointly (averaged over nodes/segments) rather
        than independently per node.
    :return: ``(pred_topk, prob_topk)``; probabilities sum to one over the
        kept modes.
    """
    max_guesses = min(max_guesses, pred.size(1))
    if max_guesses == pred.size(1):
        # Keeping every mode: only (re)normalize the probabilities.
        if prob is not None:
            prob = prob / prob.sum(dim=-1, keepdim=True)
        else:
            # No scores supplied -> uniform distribution over modes.
            prob = pred.new_ones((pred.size(0), max_guesses)) / max_guesses
        return pred, prob
    else:
        if prob is not None:
            if joint:
                if ptr is None:
                    # One shared mode ranking: average the normalized probs
                    # over all nodes, then broadcast the chosen indices.
                    inds_topk = torch.topk((prob / prob.sum(dim=-1, keepdim=True)).mean(dim=0, keepdim=True),
                                           k=max_guesses, dim=-1, largest=True, sorted=True)[1]
                    inds_topk = inds_topk.repeat(pred.size(0), 1)
                else:
                    # Average within each batch segment, then expand each
                    # segment's indices back to its nodes.
                    inds_topk = torch.topk(segment_csr(src=prob / prob.sum(dim=-1, keepdim=True), indptr=ptr,
                                                       reduce='mean'),
                                           k=max_guesses, dim=-1, largest=True, sorted=True)[1]
                    inds_topk = gather_csr(src=inds_topk, indptr=ptr)
            else:
                inds_topk = torch.topk(prob, k=max_guesses, dim=-1, largest=True, sorted=True)[1]
            pred_topk = pred[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]
            prob_topk = prob[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]
            # Renormalize over the surviving modes.
            prob_topk = prob_topk / prob_topk.sum(dim=-1, keepdim=True)
        else:
            # Without scores: keep the first modes with a uniform distribution.
            pred_topk = pred[:, :max_guesses]
            prob_topk = pred.new_ones((pred.size(0), max_guesses)) / max_guesses
        return pred_topk, prob_topk
77
+
78
+
79
def topkind(
        max_guesses: int,
        pred: torch.Tensor,
        prob: Optional[torch.Tensor] = None,
        ptr: Optional[torch.Tensor] = None,
        joint: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Variant of ``topk`` that additionally returns the selected mode indices.

    :return: ``(pred_topk, prob_topk, inds_topk)``; ``inds_topk`` is ``None``
        when every mode is kept (no selection happens).
    """
    max_guesses = min(max_guesses, pred.size(1))
    if max_guesses == pred.size(1):
        if prob is not None:
            prob = prob / prob.sum(dim=-1, keepdim=True)
        else:
            prob = pred.new_ones((pred.size(0), max_guesses)) / max_guesses
        return pred, prob, None
    else:
        if prob is not None:
            if joint:
                if ptr is None:
                    # One shared ranking across all nodes.
                    inds_topk = torch.topk((prob / prob.sum(dim=-1, keepdim=True)).mean(dim=0, keepdim=True),
                                           k=max_guesses, dim=-1, largest=True, sorted=True)[1]
                    inds_topk = inds_topk.repeat(pred.size(0), 1)
                else:
                    # Per-segment ranking, expanded back to the nodes.
                    inds_topk = torch.topk(segment_csr(src=prob / prob.sum(dim=-1, keepdim=True), indptr=ptr,
                                                       reduce='mean'),
                                           k=max_guesses, dim=-1, largest=True, sorted=True)[1]
                    inds_topk = gather_csr(src=inds_topk, indptr=ptr)
            else:
                inds_topk = torch.topk(prob, k=max_guesses, dim=-1, largest=True, sorted=True)[1]
            pred_topk = pred[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]
            prob_topk = prob[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]
            prob_topk = prob_topk / prob_topk.sum(dim=-1, keepdim=True)
        else:
            pred_topk = pred[:, :max_guesses]
            prob_topk = pred.new_ones((pred.size(0), max_guesses)) / max_guesses
            # BUGFIX: the original left `inds_topk` undefined on this path and
            # raised NameError at the return. Keeping the first modes
            # corresponds to indices [0, max_guesses) for every row.
            inds_topk = torch.arange(max_guesses, device=pred.device).unsqueeze(0).repeat(pred.size(0), 1)
        return pred_topk, prob_topk, inds_topk
113
+
114
+
115
def valid_filter(
        pred: torch.Tensor,
        target: torch.Tensor,
        prob: Optional[torch.Tensor] = None,
        valid_mask: Optional[torch.Tensor] = None,
        ptr: Optional[torch.Tensor] = None,
        keep_invalid_final_step: bool = True) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
                                                       torch.Tensor, torch.Tensor]:
    """Drop rows (agents) without valid ground truth and rebuild the batch ptr.

    :param pred: predictions, first dimension indexes agents.
    :param target: ground truth; its trailing dim is sliced off when building
        the default all-valid mask.
    :param prob: optional per-agent mode probabilities, filtered alongside.
    :param valid_mask: per-agent, per-step validity; defaults to all-valid.
    :param ptr: optional CSR batch pointer, recomputed for the kept rows.
    :param keep_invalid_final_step: keep agents valid at *any* step (True) or
        only those valid at the final step (False).
    :return: filtered ``(pred, target, prob, valid_mask, ptr)``.
    """
    if valid_mask is None:
        valid_mask = target.new_ones(target.size()[:-1], dtype=torch.bool)
    if keep_invalid_final_step:
        filter_mask = valid_mask.any(dim=-1)
    else:
        filter_mask = valid_mask[:, -1]
    pred = pred[filter_mask]
    target = target[filter_mask]
    if prob is not None:
        prob = prob[filter_mask]
    valid_mask = valid_mask[filter_mask]
    if ptr is not None:
        # Recompute per-batch node counts over the kept rows, then rebuild the
        # pointer via an in-place cumulative sum into ptr[1:].
        num_nodes_batch = segment_csr(src=filter_mask.long(), indptr=ptr, reduce='sum')
        ptr = num_nodes_batch.new_zeros((num_nodes_batch.size(0) + 1,))
        torch.cumsum(num_nodes_batch, dim=0, out=ptr[1:])
    else:
        # Single-batch fallback: one segment spanning all kept rows.
        ptr = target.new_tensor([0, target.size(0)])
    return pred, target, prob, valid_mask, ptr
141
+
142
+
143
def new_batch_nms(pred_trajs, dist_thresh, num_ret_modes=6):
    """Greedy NMS over trajectory modes, scored by endpoint density.

    Unlike ``batch_nms`` there is no external score input: each mode's score
    is the fraction of modes whose endpoints lie within ``dist_thresh`` of it.

    Args:
        pred_trajs: (batch_size, num_modes, num_timestamps, num_feat_dim)
            candidate trajectories; the last timestep is used as the goal.
        dist_thresh (float): endpoint distance for both scoring and suppression.
        num_ret_modes (int, optional): number of modes to keep. Defaults to 6.

    Returns:
        ret_trajs: (batch_size, num_ret_modes, num_timestamps, num_feat_dim)
        ret_scores: (batch_size, num_ret_modes)
        ret_idxs: (batch_size, num_ret_modes) indices into the input modes.
    """
    batch_size, num_modes, num_timestamps, num_feat_dim = pred_trajs.shape
    pred_goals = pred_trajs[:, :, -1, :]
    # Density score: how many endpoints (including itself) fall within
    # dist_thresh of each mode's endpoint, normalized by num_modes.
    dist = (pred_goals[:, :, None, 0:2] - pred_goals[:, None, :, 0:2]).norm(dim=-1)
    nearby_neighbor = dist < dist_thresh
    pred_scores = nearby_neighbor.sum(dim=-1) / num_modes

    # Sort modes by score (descending) once, then run suppression in that order.
    sorted_idxs = pred_scores.argsort(dim=-1, descending=True)
    bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes)
    sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs]
    sorted_pred_trajs = pred_trajs[bs_idxs_full, sorted_idxs]  # (batch_size, num_modes, num_timestamps, num_feat_dim)
    sorted_pred_goals = sorted_pred_trajs[:, :, -1, :]  # (batch_size, num_modes, num_feat_dim)

    dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)
    point_cover_mask = (dist < dist_thresh)

    point_val = sorted_pred_scores.clone()  # (batch_size, N)
    point_val_selected = torch.zeros_like(point_val)  # (batch_size, N)

    ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long()
    ret_trajs = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes, num_timestamps, num_feat_dim)
    ret_scores = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes)
    bs_idxs = torch.arange(batch_size).type_as(ret_idxs)

    for k in range(num_ret_modes):
        # Pick the best-scoring mode not yet suppressed or selected.
        cur_idx = point_val.argmax(dim=-1)  # (batch_size)
        ret_idxs[:, k] = cur_idx

        # Zero out every mode covered by the chosen one, and mark the chosen
        # mode itself with -1 so it can never be selected again.
        new_cover_mask = point_cover_mask[bs_idxs, cur_idx]  # (batch_size, N)
        point_val = point_val * (~new_cover_mask).float()  # (batch_size, N)
        point_val_selected[bs_idxs, cur_idx] = -1
        point_val += point_val_selected

        ret_trajs[:, k] = sorted_pred_trajs[bs_idxs, cur_idx]
        ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx]

    bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes)

    # Map the selections from sorted order back to original mode indices.
    ret_idxs = sorted_idxs[bs_idxs, ret_idxs]
    return ret_trajs, ret_scores, ret_idxs
196
+
197
+
198
def batch_nms(pred_trajs, pred_scores,
              dist_thresh, num_ret_modes=6,
              mode='static', speed=None):
    """Greedy NMS over trajectory modes using externally supplied scores.

    Args:
        pred_trajs: (batch_size, num_modes, num_timestamps, num_feat_dim)
            candidate trajectories; the last timestep is used as the goal.
        pred_scores (batch_size, num_modes): per-mode scores to rank by.
        dist_thresh (float): suppression radius between endpoints
            (used when ``mode != 'speed'``).
        num_ret_modes (int, optional): number of modes to keep. Defaults to 6.
        mode: 'speed' uses separate longitudinal/lateral thresholds;
            anything else uses a single Euclidean endpoint threshold.
        speed: accepted but not referenced anywhere in this function.
            NOTE(review): presumably intended to scale the 'speed'-mode
            thresholds — confirm upstream before relying on it.

    Returns:
        ret_trajs: (batch_size, num_ret_modes, num_timestamps, num_feat_dim)
        ret_scores: (batch_size, num_ret_modes)
        ret_idxs: (batch_size, num_ret_modes) indices into the input modes.
    """
    batch_size, num_modes, num_timestamps, num_feat_dim = pred_trajs.shape

    # Sort modes by score (descending) once, then suppress in that order.
    sorted_idxs = pred_scores.argsort(dim=-1, descending=True)
    bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes)
    sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs]
    sorted_pred_trajs = pred_trajs[bs_idxs_full, sorted_idxs]  # (batch_size, num_modes, num_timestamps, num_feat_dim)
    sorted_pred_goals = sorted_pred_trajs[:, :, -1, :]  # (batch_size, num_modes, num_feat_dim)

    if mode == "speed":
        # Anisotropic suppression: fixed 4 m longitudinal / 0.5 m lateral
        # thresholds on the endpoint coordinates (axis 0 = lon, axis 1 = lat).
        scale = torch.ones(batch_size).to(sorted_pred_goals.device)
        lon_dist_thresh = 4 * scale
        lat_dist_thresh = 0.5 * scale
        lon_dist = (sorted_pred_goals[:, :, None, [0]] - sorted_pred_goals[:, None, :, [0]]).norm(dim=-1)
        lat_dist = (sorted_pred_goals[:, :, None, [1]] - sorted_pred_goals[:, None, :, [1]]).norm(dim=-1)
        point_cover_mask = (lon_dist < lon_dist_thresh[:, None, None]) & (lat_dist < lat_dist_thresh[:, None, None])
    else:
        # Isotropic suppression on the 2D endpoint distance.
        dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)
        point_cover_mask = (dist < dist_thresh)

    point_val = sorted_pred_scores.clone()  # (batch_size, N)
    point_val_selected = torch.zeros_like(point_val)  # (batch_size, N)

    ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long()
    ret_trajs = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes, num_timestamps, num_feat_dim)
    ret_scores = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes)
    bs_idxs = torch.arange(batch_size).type_as(ret_idxs)

    for k in range(num_ret_modes):
        # Pick the best-scoring mode not yet suppressed or selected.
        cur_idx = point_val.argmax(dim=-1)  # (batch_size)
        ret_idxs[:, k] = cur_idx

        # Zero out every mode the chosen one covers, and mark the chosen mode
        # itself with -1 so it cannot be selected again.
        new_cover_mask = point_cover_mask[bs_idxs, cur_idx]  # (batch_size, N)
        point_val = point_val * (~new_cover_mask).float()  # (batch_size, N)
        point_val_selected[bs_idxs, cur_idx] = -1
        point_val += point_val_selected

        ret_trajs[:, k] = sorted_pred_trajs[bs_idxs, cur_idx]
        ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx]

    bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes)

    # Map the selections from sorted order back to original mode indices.
    ret_idxs = sorted_idxs[bs_idxs, ret_idxs]
    return ret_trajs, ret_scores, ret_idxs
257
+
258
+
259
def batch_nms_token(pred_trajs, pred_scores,
                    dist_thresh, num_ret_modes=6,
                    mode='static', speed=None):
    """Greedy batched NMS over predicted goal points.

    Repeatedly picks the highest-scoring remaining mode and suppresses all
    modes whose goal lies within a distance threshold of it.

    Args:
        pred_trajs: (batch_size, num_modes, num_feat_dim) goal points; the
            first two feature channels are interpreted as x/y.
        pred_scores: (batch_size, num_modes) confidence per mode.
        dist_thresh (float): suppression radius, used when mode != 'nearby'.
        num_ret_modes (int, optional): number of modes to keep. Defaults to 6.
        mode (str): 'nearby' derives a per-point radius from the distance to
            the 5th-nearest neighbour (requires num_modes >= 5); any other
            value uses the fixed dist_thresh.
        speed: unused; kept for signature compatibility with batch_nms.

    Returns:
        ret_goals (batch_size, num_ret_modes, num_feat_dim)
        ret_scores (batch_size, num_ret_modes)
        ret_idxs (batch_size, num_ret_modes): indices into the caller's
            original mode ordering.
    """
    batch_size, num_modes, num_feat_dim = pred_trajs.shape

    # Sort modes by descending score so the greedy loop can argmax cheaply.
    sorted_idxs = pred_scores.argsort(dim=-1, descending=True)
    bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes)
    sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs]
    sorted_pred_goals = pred_trajs[bs_idxs_full, sorted_idxs]  # (batch_size, num_modes, num_feat_dim)

    # Pairwise x/y distances between all goals: (batch_size, N, N).
    # (Hoisted: both branches below need the same matrix.)
    dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)
    if mode == "nearby":
        # Adaptive radius: each point suppresses up to its 5th-nearest neighbour.
        values, _ = torch.topk(dist, 5, dim=-1, largest=False)
        point_cover_mask = dist < values[..., -1][..., None]
    else:
        point_cover_mask = (dist < dist_thresh)

    point_val = sorted_pred_scores.clone()            # remaining scores (batch_size, N)
    point_val_selected = torch.zeros_like(point_val)  # -1 marks already-picked modes

    ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long()
    ret_goals = sorted_pred_goals.new_zeros(batch_size, num_ret_modes, num_feat_dim)
    ret_scores = sorted_pred_goals.new_zeros(batch_size, num_ret_modes)
    bs_idxs = torch.arange(batch_size).type_as(ret_idxs)

    for k in range(num_ret_modes):
        cur_idx = point_val.argmax(dim=-1)  # best remaining mode per batch element
        ret_idxs[:, k] = cur_idx

        # Zero out every mode covered by the newly selected one, and pin the
        # selected mode itself to -1 so it can never be re-picked.
        new_cover_mask = point_cover_mask[bs_idxs, cur_idx]  # (batch_size, N)
        point_val = point_val * (~new_cover_mask).float()
        point_val_selected[bs_idxs, cur_idx] = -1
        point_val += point_val_selected

        ret_goals[:, k] = sorted_pred_goals[bs_idxs, cur_idx]
        ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx]

    # Map indices in score-sorted order back to the caller's mode order.
    bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes)
    ret_idxs = sorted_idxs[bs_idxs, ret_idxs]
    return ret_goals, ret_scores, ret_idxs
314
+
315
+
316
class TokenCls(Metric):
    """Top-k token classification accuracy, averaged over valid positions."""

    def __init__(self,
                 max_guesses: int = 6,
                 **kwargs) -> None:
        super(TokenCls, self).__init__(**kwargs)
        # Running numerator (correct hits) and denominator (valid positions).
        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
        self.max_guesses = max_guesses

    def update(self,
               pred: torch.Tensor,
               target: torch.Tensor,
               valid_mask: Optional[torch.Tensor] = None) -> None:
        # A position counts as correct when the target appears among the
        # first max_guesses predictions; invalid positions contribute nothing.
        top_preds = pred[:, :self.max_guesses]
        hits = (top_preds == target.unsqueeze(-1)).any(dim=1)
        self.sum += (hits * valid_mask).sum()
        self.count += valid_mask.sum()

    def compute(self) -> torch.Tensor:
        return self.sum / self.count
337
+
338
+
339
class minMultiFDE(Metric):
    """Minimum final-displacement error over the top-k predicted modes.

    Accumulates, per sample, the smallest L2 distance between any of the
    top ``max_guesses`` predicted endpoints and the ground-truth position
    at the last valid timestep.
    """

    def __init__(self,
                 max_guesses: int = 6,
                 **kwargs) -> None:
        super(minMultiFDE, self).__init__(**kwargs)
        # Running numerator (sum of per-sample min FDEs) and sample count.
        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
        self.max_guesses = max_guesses

    def update(self,
               pred: torch.Tensor,
               target: torch.Tensor,
               prob: Optional[torch.Tensor] = None,
               valid_mask: Optional[torch.Tensor] = None,
               keep_invalid_final_step: bool = True) -> None:
        # valid_filter / topk are project helpers defined elsewhere in this
        # module; presumably they filter out fully-invalid samples and select
        # the max_guesses highest-probability modes -- TODO confirm.
        pred, target, prob, valid_mask, _ = valid_filter(pred, target, prob, valid_mask, None, keep_invalid_final_step)
        pred_topk, _ = topk(self.max_guesses, pred, prob)
        # Index of the last valid timestep per sample: multiplying the mask by
        # 1..T and taking argmax picks the highest position where mask == 1.
        inds_last = (valid_mask * torch.arange(1, valid_mask.size(-1) + 1, device=self.device)).argmax(dim=-1)
        # L2 endpoint error of each mode, minimized over modes, summed over samples.
        self.sum += torch.norm(pred_topk[torch.arange(pred.size(0)), :, inds_last] -
                               target[torch.arange(pred.size(0)), inds_last].unsqueeze(-2),
                               p=2, dim=-1).min(dim=-1)[0].sum()
        self.count += pred.size(0)

    def compute(self) -> torch.Tensor:
        return self.sum / self.count
365
+
366
+
367
class minFDE(Metric):
    """Final-displacement error of a single trajectory at a fixed horizon.

    Unlike minMultiFDE this operates on one trajectory per sample and
    measures displacement near the ``eval_timestep`` horizon.
    """

    def __init__(self,
                 max_guesses: int = 6,
                 **kwargs) -> None:
        super(minFDE, self).__init__(**kwargs)
        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
        self.max_guesses = max_guesses
        # Evaluation horizon in timesteps (clipped to the prediction length).
        self.eval_timestep = 70

    def update(self,
               pred: torch.Tensor,
               target: torch.Tensor,
               prob: Optional[torch.Tensor] = None,
               valid_mask: Optional[torch.Tensor] = None,
               keep_invalid_final_step: bool = True) -> None:
        # prob / keep_invalid_final_step are accepted for signature
        # compatibility with the multi-modal metrics but are unused here.
        eval_timestep = min(self.eval_timestep, pred.shape[1]) - 1
        # NOTE(review): displacement is taken at index eval_timestep-1, i.e.
        # two steps before the nominal horizon (min(...)-1 then a further -1
        # in the slice) -- confirm the double decrement is intentional.
        self.sum += ((torch.norm(pred[:, eval_timestep-1:eval_timestep] - target[:, eval_timestep-1:eval_timestep], p=2, dim=-1) *
                      valid_mask[:, eval_timestep-1].unsqueeze(1)).sum(dim=-1)).sum()
        self.count += valid_mask[:, eval_timestep-1].sum()

    def compute(self) -> torch.Tensor:
        return self.sum / self.count
391
+
392
+
393
class minMultiADE(Metric):
    """Minimum average-displacement error over the top-k predicted modes.

    The "best" mode can be chosen either by endpoint error ('FDE') or by
    summed displacement ('ADE'); its masked mean displacement is accumulated.
    """

    def __init__(self,
                 max_guesses: int = 6,
                 **kwargs) -> None:
        super(minMultiADE, self).__init__(**kwargs)
        # Running numerator (sum of per-sample min ADEs) and sample count.
        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
        self.max_guesses = max_guesses

    def update(self,
               pred: torch.Tensor,
               target: torch.Tensor,
               prob: Optional[torch.Tensor] = None,
               valid_mask: Optional[torch.Tensor] = None,
               keep_invalid_final_step: bool = True,
               min_criterion: str = 'FDE') -> None:
        # valid_filter / topk are project helpers defined elsewhere in this module.
        pred, target, prob, valid_mask, _ = valid_filter(pred, target, prob, valid_mask, None, keep_invalid_final_step)
        pred_topk, _ = topk(self.max_guesses, pred, prob)
        if min_criterion == 'FDE':
            # Pick the mode with the smallest endpoint error at the last valid
            # step, then average its displacement over all valid steps.
            inds_last = (valid_mask * torch.arange(1, valid_mask.size(-1) + 1, device=self.device)).argmax(dim=-1)
            inds_best = torch.norm(
                pred_topk[torch.arange(pred.size(0)), :, inds_last] -
                target[torch.arange(pred.size(0)), inds_last].unsqueeze(-2), p=2, dim=-1).argmin(dim=-1)
            self.sum += ((torch.norm(pred_topk[torch.arange(pred.size(0)), inds_best] - target, p=2, dim=-1) *
                          valid_mask).sum(dim=-1) / valid_mask.sum(dim=-1)).sum()
        elif min_criterion == 'ADE':
            # Directly take the mode with the smallest masked displacement sum.
            self.sum += ((torch.norm(pred_topk - target.unsqueeze(1), p=2, dim=-1) *
                          valid_mask.unsqueeze(1)).sum(dim=-1).min(dim=-1)[0] / valid_mask.sum(dim=-1)).sum()
        else:
            raise ValueError('{} is not a valid criterion'.format(min_criterion))
        self.count += pred.size(0)

    def compute(self) -> torch.Tensor:
        return self.sum / self.count
428
+
429
+
430
class minADE(Metric):
    """Average-displacement error over a fixed evaluation horizon.

    Sums per-step L2 displacement over the first ``eval_timestep`` steps
    (masked by validity) and normalizes by the full prediction length.
    """

    def __init__(self,
                 max_guesses: int = 6,
                 **kwargs) -> None:
        super(minADE, self).__init__(**kwargs)
        # Running numerator (summed per-sample ADEs) and sample count.
        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
        self.max_guesses = max_guesses
        # Evaluation horizon in timesteps (clipped to the prediction length).
        self.eval_timestep = 70

    def update(self,
               pred: torch.Tensor,
               target: torch.Tensor,
               prob: Optional[torch.Tensor] = None,
               valid_mask: Optional[torch.Tensor] = None,
               keep_invalid_final_step: bool = True,
               min_criterion: str = 'ADE') -> None:
        # prob / keep_invalid_final_step / min_criterion are accepted for
        # signature compatibility with the multi-modal metrics but unused
        # here. (An earlier multi-modal implementation was removed; see
        # minMultiADE for that behavior.)
        eval_timestep = min(self.eval_timestep, pred.shape[1])
        # NOTE(review): the masked displacement sum is normalized by the full
        # prediction length (pred.shape[1]) rather than the number of valid
        # steps within the horizon -- confirm this is intentional.
        self.sum += ((torch.norm(pred[:, :eval_timestep] - target[:, :eval_timestep], p=2, dim=-1) * valid_mask[:, :eval_timestep]).sum(dim=-1) / pred.shape[1]).sum()
        # Count only samples with at least one valid step inside the horizon.
        self.count += valid_mask[:, :eval_timestep].any(dim=-1).sum()

    def compute(self) -> torch.Tensor:
        return self.sum / self.count
468
+
469
+
470
class AverageMeter(Metric):
    """Tracks the running mean of every element passed to ``update``."""

    def __init__(self, **kwargs) -> None:
        super(AverageMeter, self).__init__(**kwargs)
        # Total of all observed values and number of elements seen so far;
        # both are summed across processes on synchronization.
        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')

    def update(self, val: torch.Tensor) -> None:
        # Accumulate both the total and the element count so compute()
        # yields the global mean across all updates.
        total = val.sum()
        self.sum += total
        self.count += val.numel()

    def compute(self) -> torch.Tensor:
        return self.sum / self.count
483
+
484
+
485
class StateAccuracy(Metric):
    """Consistency checks for per-agent state-token sequences.

    Given per-step state tokens (invalid/valid/enter/exit), verifies that
    steps before the first 'enter' and after the first 'exit' are invalid,
    and that steps strictly in between are valid. When ``valid_mask`` is
    supplied it additionally cross-checks the tokenization against the mask.
    """

    def __init__(self, state_token: Dict[str, int], **kwargs) -> None:
        super().__init__(**kwargs)
        # Integer ids for the four agent states.
        self.invalid_state = int(state_token['invalid'])
        self.valid_state = int(state_token['valid'])
        self.enter_state = int(state_token['enter'])
        self.exit_state = int(state_token['exit'])

        # Counters: matches vs totals inside / outside the [enter, exit] span.
        self.add_state('valid', default=torch.tensor(0), dist_reduce_fx='sum')
        self.add_state('valid_count', default=torch.tensor(0), dist_reduce_fx='sum')
        self.add_state('invalid', default=torch.tensor(0), dist_reduce_fx='sum')
        self.add_state('invalid_count', default=torch.tensor(0), dist_reduce_fx='sum')

    def update(self,
               state_idx: torch.Tensor,
               valid_mask: Optional[torch.Tensor] = None) -> None:

        # state_idx: (num_agent, num_step) integer state tokens.
        num_agent, num_step = state_idx.shape

        # check the evaluation outputs
        for a in range(num_agent):
            bos_idx = torch.where(state_idx[a] == self.enter_state)[0]
            eos_idx = torch.where(state_idx[a] == self.exit_state)[0]
            bos = 0
            eos = num_step - 1
            if len(bos_idx) > 0:
                # Steps before the first 'enter' should be invalid.
                bos = bos_idx[0]
                self.invalid += (state_idx[a, :bos] == self.invalid_state).sum()
                self.invalid_count += len(state_idx[a, :bos])
            if len(eos_idx) > 0:
                # Steps after the first 'exit' should be invalid.
                eos = eos_idx[0]
                self.invalid += (state_idx[a, eos + 1:] == self.invalid_state).sum()
                self.invalid_count += len(state_idx[a, eos + 1:])
            # Steps strictly between enter and exit should be valid.
            self.valid += (state_idx[a, bos + 1 : eos] == self.valid_state).sum()
            self.valid_count += len(state_idx[a, bos + 1 : eos])

        # check the tokenization
        if valid_mask is not None:

            # Shift tokens one step right so token t aligns with mask t
            # (presumably tokens describe the next step -- TODO confirm).
            state_idx = state_idx.roll(shifts=1, dims=1)

            for a in range(num_agent):
                bos_idx = torch.where(state_idx[a] == self.enter_state)[0]
                eos_idx = torch.where(state_idx[a] == self.exit_state)[0]
                bos = 0
                eos = num_step - 1
                if len(bos_idx) > 0:
                    bos = bos_idx[0]
                    self.invalid += (valid_mask[a, :bos] == 0).sum()
                    self.invalid_count += len(valid_mask[a, :bos])
                if len(eos_idx) > 0:
                    # NOTE(review): uses the LAST exit here (eos_idx[-1]),
                    # unlike the first pass above which uses the first exit.
                    eos = eos_idx[-1]
                    self.invalid += (valid_mask[a, eos + 1:] != 0).sum()
                    self.invalid_count += len(valid_mask[a, eos + 1:])
                # Count token/mask disagreements within the span, split by
                # whether the mask says the step is invalid (0) or valid (1).
                self.invalid += (((state_idx[a, bos : eos + 1] > 0) != valid_mask[a, bos : eos + 1])[valid_mask[a, bos : eos + 1] == 0]).sum()
                self.invalid_count += (valid_mask[a, bos : eos + 1] == 0).sum()
                self.valid += (((state_idx[a, bos : eos + 1] > 0) != valid_mask[a, bos : eos + 1])[valid_mask[a, bos : eos + 1] == 1]).sum()
                self.valid_count += (valid_mask[a, bos : eos + 1] == 1).sum()

    def compute(self) -> Dict[str, torch.Tensor]:
        # Ratios of matches to totals for the in-span / out-of-span checks.
        return {'valid': self.valid / self.valid_count,
                'invalid': self.invalid / self.invalid_count,
                }

    def __repr__(self):
        head = "Results of " + self.__class__.__name__
        results = self.compute()
        body = [
            "valid: {}".format(results['valid']),
            "invalid: {}".format(results['invalid']),
        ]
        _repr_indent = 4
        lines = [head] + [" " * _repr_indent + line for line in body]
        return "\n".join(lines)
560
+
561
+
562
class GridOverlapRate(Metric):
    """Statistics on newly-inserted agents landing in occupied grid cells.

    For each timestep, counts agents in range, agents being inserted
    ('enter' state), insertions that collide with an already-occupied grid
    cell, and whether the number of insertions reaches the seed capacity.
    """

    def __init__(self, num_step, state_token, seed_size, **kwargs) -> None:
        super().__init__(**kwargs)
        self.num_step = num_step
        self.enter_state = int(state_token['enter'])
        self.seed_size = seed_size
        # Per-timestep counters, summed across processes on synchronization.
        self.add_state('num_overlap_t', default=torch.zeros(num_step).long(), dist_reduce_fx='sum')
        self.add_state('num_insert_agent_t', default=torch.zeros(num_step).long(), dist_reduce_fx='sum')
        self.add_state('num_total_agent_t', default=torch.zeros(num_step).long(), dist_reduce_fx='sum')
        self.add_state('num_exceed_seed_t', default=torch.zeros(num_step).long(), dist_reduce_fx='sum')

    def update(self,
               state_token: torch.Tensor,
               grid_index: torch.Tensor) -> None:

        # state_token, grid_index: (num_agent, num_step); grid index -1
        # marks agents outside the grid range.
        for t in range(self.num_step):
            inrange_mask_t = grid_index[:, t] != -1
            insert_mask_t = (state_token[:, t] == self.enter_state) & inrange_mask_t
            self.num_total_agent_t[t] += inrange_mask_t.sum()
            self.num_insert_agent_t[t] += insert_mask_t.sum()
            self.num_exceed_seed_t[t] += int(insert_mask_t.sum() >= self.seed_size)

            # Cells held by non-inserting agents; each insertion into an
            # occupied cell (including a cell claimed by a previous insertion
            # in the same step) counts as one overlap.
            occupied_grids = set(grid_index[:, t][(grid_index[:, t] != -1) & (state_token[:, t] != self.enter_state)].tolist())
            to_inserted_grids = grid_index[:, t][(grid_index[:, t] != -1) & (state_token[:, t] == self.enter_state)].tolist()
            while to_inserted_grids:
                grid_index_t_i = to_inserted_grids.pop()
                if grid_index_t_i in occupied_grids:
                    self.num_overlap_t[t] += 1
                occupied_grids.add(grid_index_t_i)

    def compute(self) -> Dict[str, torch.Tensor]:
        # Timesteps with zero insertions would yield NaN; map those to 0.
        overlap_rate_t = self.num_overlap_t / self.num_insert_agent_t
        overlap_rate_t.nan_to_num_()
        return {'num_overlap_t': self.num_overlap_t,
                'num_insert_agent_t': self.num_insert_agent_t,
                'num_total_agent_t': self.num_total_agent_t,
                'overlap_rate_t': overlap_rate_t,
                'num_exceed_seed_t': self.num_exceed_seed_t,
                }

    def __repr__(self):
        head = "Results of " + self.__class__.__name__
        results = self.compute()
        body = [
            "num_overlap_t: {}".format(results['num_overlap_t'].tolist()),
            "num_insert_agent_t: {}".format(results['num_insert_agent_t'].tolist()),
            "num_total_agent_t: {}".format(results['num_total_agent_t'].tolist()),
            "overlap_rate_t: {}".format(results['overlap_rate_t'].tolist()),
            "num_exceed_seed_t: {}".format(results['num_exceed_seed_t'].tolist()),
        ]
        _repr_indent = 4
        lines = [head] + [" " * _repr_indent + line for line in body]
        return "\n".join(lines)
616
+
617
+
618
class NumInsertAccuracy(Metric):
    """Consistency checks for per-agent state-token sequences.

    NOTE(review): this class body is a verbatim copy of ``StateAccuracy``
    above despite its name suggesting an insertion-count metric -- confirm
    the intent and deduplicate (e.g. by subclassing or aliasing).
    """

    def __init__(self, state_token: Dict[str, int], **kwargs) -> None:
        super().__init__(**kwargs)
        # Integer ids for the four agent states.
        self.invalid_state = int(state_token['invalid'])
        self.valid_state = int(state_token['valid'])
        self.enter_state = int(state_token['enter'])
        self.exit_state = int(state_token['exit'])

        # Counters: matches vs totals inside / outside the [enter, exit] span.
        self.add_state('valid', default=torch.tensor(0), dist_reduce_fx='sum')
        self.add_state('valid_count', default=torch.tensor(0), dist_reduce_fx='sum')
        self.add_state('invalid', default=torch.tensor(0), dist_reduce_fx='sum')
        self.add_state('invalid_count', default=torch.tensor(0), dist_reduce_fx='sum')

    def update(self,
               state_idx: torch.Tensor,
               valid_mask: Optional[torch.Tensor] = None) -> None:

        num_agent, num_step = state_idx.shape

        # check the evaluation outputs
        for a in range(num_agent):
            bos_idx = torch.where(state_idx[a] == self.enter_state)[0]
            eos_idx = torch.where(state_idx[a] == self.exit_state)[0]
            bos = 0
            eos = num_step - 1
            if len(bos_idx) > 0:
                # Steps before the first 'enter' should be invalid.
                bos = bos_idx[0]
                self.invalid += (state_idx[a, :bos] == self.invalid_state).sum()
                self.invalid_count += len(state_idx[a, :bos])
            if len(eos_idx) > 0:
                # Steps after the first 'exit' should be invalid.
                eos = eos_idx[0]
                self.invalid += (state_idx[a, eos + 1:] == self.invalid_state).sum()
                self.invalid_count += len(state_idx[a, eos + 1:])
            # Steps strictly between enter and exit should be valid.
            self.valid += (state_idx[a, bos + 1 : eos] == self.valid_state).sum()
            self.valid_count += len(state_idx[a, bos + 1 : eos])

        # check the tokenization
        if valid_mask is not None:

            # Shift tokens one step right so token t aligns with mask t.
            state_idx = state_idx.roll(shifts=1, dims=1)

            for a in range(num_agent):
                bos_idx = torch.where(state_idx[a] == self.enter_state)[0]
                eos_idx = torch.where(state_idx[a] == self.exit_state)[0]
                bos = 0
                eos = num_step - 1
                if len(bos_idx) > 0:
                    bos = bos_idx[0]
                    self.invalid += (valid_mask[a, :bos] == 0).sum()
                    self.invalid_count += len(valid_mask[a, :bos])
                if len(eos_idx) > 0:
                    # NOTE(review): uses the LAST exit (eos_idx[-1]) here,
                    # unlike the first pass above which uses the first exit.
                    eos = eos_idx[-1]
                    self.invalid += (valid_mask[a, eos + 1:] != 0).sum()
                    self.invalid_count += len(valid_mask[a, eos + 1:])
                # Count token/mask disagreements within the span, split by
                # whether the mask says the step is invalid (0) or valid (1).
                self.invalid += (((state_idx[a, bos : eos + 1] > 0) != valid_mask[a, bos : eos + 1])[valid_mask[a, bos : eos + 1] == 0]).sum()
                self.invalid_count += (valid_mask[a, bos : eos + 1] == 0).sum()
                self.valid += (((state_idx[a, bos : eos + 1] > 0) != valid_mask[a, bos : eos + 1])[valid_mask[a, bos : eos + 1] == 1]).sum()
                self.valid_count += (valid_mask[a, bos : eos + 1] == 1).sum()

    def compute(self) -> Dict[str, torch.Tensor]:
        return {'valid': self.valid / self.valid_count,
                'invalid': self.invalid / self.invalid_count,
                }

    def __repr__(self):
        head = "Results of " + self.__class__.__name__
        results = self.compute()
        body = [
            "valid: {}".format(results['valid']),
            "invalid: {}".format(results['invalid']),
        ]
        _repr_indent = 4
        lines = [head] + [" " * _repr_indent + line for line in body]
        return "\n".join(lines)
backups/dev/utils/visualization.py ADDED
@@ -0,0 +1,1145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ import torch
4
+ import pickle
5
+ import matplotlib.pyplot as plt
6
+ import tensorflow as tf
7
+ import numpy as np
8
+ import numpy.typing as npt
9
+ import fnmatch
10
+ import seaborn as sns
11
+ import matplotlib.axes as Axes
12
+ import matplotlib.transforms as mtransforms
13
+ from PIL import Image
14
+ from functools import wraps
15
+ from typing import Sequence, Union, Optional
16
+ from tqdm import tqdm
17
+ from typing import List, Literal
18
+ from argparse import ArgumentParser
19
+ from scipy.ndimage.filters import gaussian_filter
20
+ from matplotlib.patches import FancyBboxPatch, Polygon, Rectangle, Circle
21
+ from matplotlib.collections import LineCollection
22
+ from torch_geometric.data import HeteroData, Dataset
23
+ from waymo_open_dataset.protos import scenario_pb2
24
+
25
+ from dev.utils.func import CONSOLE
26
+ from dev.modules.attr_tokenizer import Attr_Tokenizer
27
+ from dev.datasets.preprocess import TokenProcessor, cal_polygon_contour, AGENT_TYPE
28
+ from dev.datasets.scalable_dataset import WaymoTargetBuilder
29
+
30
+
31
+ __all__ = ['plot_occ_grid', 'plot_interact_edge', 'plot_map_edge', 'plot_insert_grid', 'plot_binary_map',
32
+ 'plot_map_token', 'plot_prob_seed', 'plot_scenario', 'get_heatmap', 'draw_heatmap', 'plot_val', 'plot_tokenize']
33
+
34
+
35
def safe_run(func):
    """Decorator that swallows exceptions unless the DEBUG env var is set.

    By default any exception raised by ``func`` is printed and ``None`` is
    returned, so a failing (e.g. plotting) call cannot crash the caller.
    With ``DEBUG`` set to a non-zero integer the function runs unguarded.
    """
    if int(os.getenv('DEBUG', 0)):
        @wraps(func)
        def passthrough(*args, **kwargs):
            return func(*args, **kwargs)
        return passthrough

    @wraps(func)
    def guarded(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as exc:
            print(exc)
            return None
    return guarded
53
+
54
+
55
@safe_run
def plot_occ_grid(scenario_id, occ, gt_occ=None, save_path='', mode='agent', prefix=''):
    """Save a (3 x 5) panel of occupancy-grid heatmaps for one scenario.

    Args:
        scenario_id: identifier used in the output filename.
        occ: predicted occupancy; indexed as occ[i, j] with a flattened n*n
            grid in the last dimension (rows = seed/agent index, columns =
            timestep).
        gt_occ: optional ground truth of the same layout; cells equal to 1
            are outlined in blue, cells equal to -1 (insertions) in red.
        save_path: output directory (created if missing).
        mode: tag appended to the output filename.
        prefix: filename prefix.
    """

    def generate_box_edges(matrix, find_value=1):
        # Return the four border segments of every cell matching find_value.
        y, x = np.where(matrix == find_value)
        edges = []

        for xi, yi in zip(x, y):
            edges.append([(xi - 0.5, yi - 0.5), (xi + 0.5, yi - 0.5)])
            edges.append([(xi + 0.5, yi - 0.5), (xi + 0.5, yi + 0.5)])
            edges.append([(xi + 0.5, yi + 0.5), (xi - 0.5, yi + 0.5)])
            edges.append([(xi - 0.5, yi + 0.5), (xi - 0.5, yi - 0.5)])

        return edges

    os.makedirs(save_path, exist_ok=True)
    n = int(math.sqrt(occ.shape[-1]))  # grid side length (last dim is flattened n*n)

    plot_n = 3  # number of rows (seed/agent indices) to show
    plot_t = 5  # number of columns (timesteps) to show

    occ_list = []
    for i in range(plot_n):
        for j in range(plot_t):
            occ_list.append(occ[i, j].reshape(n, n))

    occ_gt_list = []
    if gt_occ is not None:
        for i in range(plot_n):
            for j in range(plot_t):
                occ_gt_list.append(gt_occ[i, j].reshape(n, n))

    row_labels = [f'n={n}' for n in range(plot_n)]
    col_labels = [f't={t}' for t in range(plot_t)]

    fig, axes = plt.subplots(plot_n, plot_t, figsize=(9, 6))
    plt.subplots_adjust(wspace=0.1, hspace=0.1)

    for i, ax in enumerate(axes.flat):
        # NOTE: do not set vmin and vmax!
        ax.imshow(occ_list[i], cmap='viridis', interpolation='nearest')
        ax.axis('off')

        if occ_gt_list:
            # Overlay ground-truth occupancy (blue) and insertions (red).
            gt_edges = generate_box_edges(occ_gt_list[i])
            gts = LineCollection(gt_edges, colors='blue', linewidths=0.5)
            ax.add_collection(gts)
            insert_edges = generate_box_edges(occ_gt_list[i], find_value=-1)
            inserts = LineCollection(insert_edges, colors='red', linewidths=0.5)
            ax.add_collection(inserts)

        # Black frame around each grid.
        ax.add_patch(plt.Rectangle((-0.5, -0.5), occ_list[i].shape[1], occ_list[i].shape[0],
                                   linewidth=2, edgecolor='black', facecolor='none'))

    for i, ax in enumerate(axes[:, 0]):
        ax.annotate(row_labels[i], xy=(-0.1, 0.5), xycoords="axes fraction",
                    fontsize=12, ha="right", va="center", rotation=0)

    for j, ax in enumerate(axes[0, :]):
        ax.annotate(col_labels[j], xy=(0.5, 1.05), xycoords="axes fraction",
                    fontsize=12, ha="center", va="bottom")

    plt.savefig(os.path.join(save_path, f'{prefix}{scenario_id}_occ_{mode}.png'), dpi=500, bbox_inches='tight')
    plt.close()
119
+
120
+
121
@safe_run
def plot_interact_edge(edge_index, scenario_ids, batch_sizes, num_seed, num_step, save_path='interact_edge_map',
                       **kwargs):
    """Dump one binary source/target attention map per interaction-edge source.

    Args:
        edge_index: (2, num_edges); row 0 = targets, row 1 = sources, both
            indexed over a flattened (step x (batched agents + seeds)) layout.
        scenario_ids: one id per scenario in the batch.
        batch_sizes: (num_batch,) number of regular agents per scenario.
        num_seed: seed agents appended per scenario (assumed constant).
        num_step: number of timesteps in the flattened layout.
        save_path: output directory for the plot_binary_map images.
        **kwargs: may carry 'av_index', 'is_bos', 't', 'n'; remaining keys
            are forwarded to plot_binary_map.
    """

    num_batch = len(scenario_ids)
    # Scenario id of every node in the flattened layout: all regular agents
    # first, then all seed agents, repeated for every timestep.
    batches = torch.cat([
        torch.arange(num_batch).repeat_interleave(repeats=batch_sizes, dim=0),
        torch.arange(num_batch).repeat_interleave(repeats=num_seed, dim=0),
    ], dim=0).repeat(num_step).numpy()

    num_agent = batch_sizes.sum() + num_seed * num_batch
    batch_sizes = torch.nn.functional.pad(batch_sizes, (1, 0), mode='constant', value=0)
    ptr = torch.cumsum(batch_sizes, dim=0)  # prefix offsets of regular agents
    # assume different scenarios and timesteps share the same number of seed agents
    ptr_seed = torch.tensor(np.array([0] + [num_seed] * num_batch), device=ptr.device)

    all_av_index = None
    if 'av_index' in kwargs:
        # Convert global AV indices into per-scenario row indices.
        all_av_index = kwargs.pop('av_index').cpu() - ptr[:-1]

    is_bos = np.zeros((batch_sizes.sum(), num_step)).astype(np.bool_)
    if 'is_bos' in kwargs:
        is_bos = kwargs.pop('is_bos').cpu().numpy()

    src_index = torch.unique(edge_index[1])
    for idx, src in enumerate(tqdm(src_index)):

        src_batch = batches[src]

        # Decompose the flattened node index into (row within scenario, step).
        src_row = src % num_agent
        if src_row // batch_sizes.sum() > 0:
            # Seed rows come after the scenario's regular agents.
            seed_row = src_row % batch_sizes.sum() - ptr_seed[src_batch]
            src_row = batch_sizes[src_batch + 1] + seed_row
        else:
            src_row = src_row - ptr[src_batch]

        src_col = src // (num_agent)
        src_mask = np.zeros((batch_sizes[src_batch + 1] + num_seed, num_step))
        src_mask[src_row, src_col] = 1

        tgt_mask = np.zeros((src_mask.shape[0], num_step))
        tgt_index = edge_index[0, edge_index[1] == src]
        for tgt in tgt_index:

            tgt_batch = batches[tgt]

            # Same row/step decomposition for each target node.
            tgt_row = tgt % num_agent
            if tgt_row // batch_sizes.sum() > 0:
                seed_row = tgt_row % batch_sizes.sum() - ptr_seed[tgt_batch]
                tgt_row = batch_sizes[tgt_batch + 1] + seed_row
            else:
                tgt_row = tgt_row - ptr[tgt_batch]

            tgt_col = tgt // num_agent
            tgt_mask[tgt_row, tgt_col] = 1
            # Edges must never cross scenarios.
            assert tgt_batch == src_batch

        # Debug aid: a source is expected to attend within a single timestep.
        selected_step = tgt_mask.sum(axis=0) > 0
        if selected_step.sum() > 1:
            print(f"\nidx={idx}", src.item(), src_row.item(), src_col.item())
            print(selected_step)
            print(edge_index[:, edge_index[1] == src].tolist())

        if all_av_index is not None:
            kwargs['av_index'] = int(all_av_index[src_batch])

        t = kwargs.get('t', src_col)
        n = kwargs.get('n', 0)
        is_bos_batch = is_bos[ptr[src_batch] : ptr[src_batch + 1]]
        plot_binary_map(src_mask, tgt_mask, save_path, suffix=f'_{scenario_ids[src_batch]}_{t:02d}_{n:02d}_{idx:04d}',
                        is_bos=is_bos_batch, **kwargs)
192
+
193
+
194
@safe_run
def plot_map_edge(edge_index, pos_a, data, save_path='map_edge_map'):
    """Save, per agent, a plot of the map tokens connected to it by edges.

    Args:
        edge_index: (2, num_edges); row 0 = map-token indices (flattened over
            time), row 1 = agent indices.
        pos_a: agent positions; pos_a[i] gives the (x, y) for agent i.
        data: heterogeneous graph holding 'map_point' and 'pt_token' stores
            with 'position' / 'orientation' tensors.
        save_path: output directory (created if missing).
    """

    map_points = data['map_point']['position'][:, :2].cpu().numpy()
    token_pos = data['pt_token']['position'][:, :2].cpu().numpy()
    token_heading = data['pt_token']['orientation'].cpu().numpy()
    num_pt = token_pos.shape[0]

    agent_index = torch.unique(edge_index[1])
    for i in tqdm(agent_index):
        xy = pos_a[i].cpu().numpy()
        pt_index = edge_index[0, edge_index[1] == i].cpu().numpy()
        # Fold the flattened (time x token) index back to a token id.
        pt_index = pt_index % num_pt

        plt.subplots_adjust(left=0.3, right=0.7, top=0.7, bottom=0.3)
        _, ax = plt.subplots()
        ax.set_axis_off()

        # plot_map_token is a sibling helper defined elsewhere in this module.
        plot_map_token(ax, map_points, token_pos[pt_index], token_heading[pt_index], colors='blue')

        # The agent itself as a small red dot.
        ax.scatter(xy[0], xy[1], s=0.5, c='red', edgecolors='none')

        os.makedirs(save_path, exist_ok=True)
        plt.savefig(os.path.join(save_path, f'map_{i}.png'), dpi=600, bbox_inches='tight')
        plt.close()
219
+
220
+
221
def get_heatmap(x, y, prob, s=3, bins=1000):
    """Build a Gaussian-smoothed, weighted 2-D histogram heatmap.

    Args:
        x, y: sample coordinates.
        prob: per-sample weights.
        s: Gaussian smoothing sigma, in bin units.
        bins: number of histogram bins per axis.

    Returns:
        (heatmap, extent): the transposed smoothed histogram and the
        [xmin, xmax, ymin, ymax] extent suitable for ``plt.imshow``.
    """
    hist, x_edges, y_edges = np.histogram2d(
        x, y, bins=bins, weights=prob, density=True)
    smoothed = gaussian_filter(hist, sigma=s)
    extent = [x_edges[0], x_edges[-1], y_edges[0], y_edges[-1]]
    return smoothed.T, extent
228
+
229
+
230
@safe_run
def draw_heatmap(vector, vector_prob, gt_idx):
    """Draw line segments colored by probability; ground-truth ones in blue.

    Args:
        vector: (N, >=4) array; columns 0..3 are the segment endpoints
            (x0, y0, x1, y1).
        vector_prob: per-segment probabilities (torch tensor; moved to CPU).
        gt_idx: indices of ground-truth segments, drawn in blue.

    Returns:
        The matplotlib ``plt`` module with the figure drawn on it.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    vector_prob = vector_prob.cpu().numpy()

    blue = (0, 0, 1)
    for idx in range(vector.shape[0]):
        if idx in gt_idx:
            seg_color = blue
        else:
            # Higher probability -> darker red; floored at pure red.
            fade = max(0, 0.9 - vector_prob[idx])
            seg_color = (0.9, fade, fade)

        x0, y0, x1, y1 = vector[idx, :4]
        ax.plot((x0, x1), (y0, y1), color=seg_color, linewidth=2)

    return plt
247
+
248
+
249
@safe_run
def plot_insert_grid(scenario_id, prob, grid, ego_pos, map, save_path='', prefix='', inference=False, indices=None,
                     all_t_in_one=False):

    """Save per-timestep insertion-probability heatmaps relative to the ego.

    NOTE(review): this definition is shadowed by a second
    ``plot_insert_grid`` defined later in this module, making it dead code
    as written -- rename or remove one of the two. The ``map`` parameter
    shadows the builtin and, like ``grid`` and ``inference``, is currently
    unused; all are kept for interface compatibility.

    prob: float array of shape (num_step, num_grid)
    grid: float array of shape (num_grid, 2)
    """

    os.makedirs(save_path, exist_ok=True)

    n = int(math.sqrt(prob.shape[1]))  # grid side length

    # grid = grid[:, np.newaxis] + ego_pos[np.newaxis, ...]
    for t in range(ego_pos.shape[0]):

        plt.subplots_adjust(left=0.3, right=0.7, top=0.7, bottom=0.3)
        _, ax = plt.subplots()

        # plot probability
        prob_t = prob[t].reshape(n, n)
        plt.imshow(prob_t, cmap='viridis', interpolation='nearest')

        if indices is not None:
            indice = indices[t]

            # Normalize a scalar selection to a list for uniform handling.
            if isinstance(indice, (int, float, np.int_)):
                indice = [indice]

            # Outline each selected grid cell in red (-1 means "no cell").
            for _indice in indice:
                if _indice == -1: continue

                row = _indice // n
                col = _indice % n

                rect = Rectangle((col - 0.5, row - 0.5), 1, 1, edgecolor='red', facecolor='none', lw=2)
                ax.add_patch(rect)

        ax.grid(False)
        ax.set_aspect('equal', adjustable='box')

        plt.title('Prob of Rel Position Grid')
        plt.savefig(os.path.join(save_path, f'{prefix}{scenario_id}_heat_map_{t}.png'), dpi=300, bbox_inches='tight')
        plt.close()

        # Optionally stop after the first timestep.
        if all_t_in_one:
            break
296
+
297
+
298
@safe_run
def plot_insert_grid(scenario_id, prob, indices=None, save_path='', prefix='', inference=False):

    """Save a (3 x 5) panel of insertion-probability grids.

    NOTE(review): this redefinition shadows the earlier ``plot_insert_grid``
    in this module. ``inference`` is currently unused.

    prob: float array of shape (num_seed, num_step, num_grid)
    indices: optional array of shape (num_seed, num_step) -- selected grid
        cell per (seed, step), outlined in red.
    """

    os.makedirs(save_path, exist_ok=True)

    n = int(math.sqrt(prob.shape[-1]))  # grid side length

    plot_n = 3  # rows: seed agents
    plot_t = 5  # columns: timesteps

    prob_list = []
    for i in range(plot_n):
        for j in range(plot_t):
            prob_list.append(prob[i, j].reshape(n, n))

    indice_list = []
    if indices is not None:
        for i in range(plot_n):
            for j in range(plot_t):
                indice_list.append(indices[i, j])

    row_labels = [f'n={n}' for n in range(plot_n)]
    col_labels = [f't={t}' for t in range(plot_t)]

    fig, axes = plt.subplots(plot_n, plot_t, figsize=(9, 6))
    fig.suptitle('Prob of Insert Position Grid')
    plt.subplots_adjust(wspace=0.1, hspace=0.1)

    for i, ax in enumerate(axes.flat):
        ax.imshow(prob_list[i], cmap='viridis', interpolation='nearest')
        ax.axis('off')

        if indice_list:
            # Outline the selected cell in red.
            row = indice_list[i] // n
            col = indice_list[i] % n
            rect = Rectangle((col - .5, row - .5), 1, 1, edgecolor='red', facecolor='none', lw=2)
            ax.add_patch(rect)

        # Black frame around each grid.
        ax.add_patch(plt.Rectangle((-0.5, -0.5), prob_list[i].shape[1], prob_list[i].shape[0],
                                   linewidth=2, edgecolor='black', facecolor='none'))

    for i, ax in enumerate(axes[:, 0]):
        ax.annotate(row_labels[i], xy=(-0.1, 0.5), xycoords="axes fraction",
                    fontsize=12, ha="right", va="center", rotation=0)

    for j, ax in enumerate(axes[0, :]):
        ax.annotate(col_labels[j], xy=(0.5, 1.05), xycoords="axes fraction",
                    fontsize=12, ha="center", va="bottom")

    # NOTE(review): these two calls only affect the last subplot iterated
    # above -- possibly meant to apply to every axis.
    ax.grid(False)
    ax.set_aspect('equal', adjustable='box')

    plt.savefig(os.path.join(save_path, f'{prefix}{scenario_id}_insert_map.png'), dpi=500, bbox_inches='tight')
    plt.close()
357
+
358
+
359
@safe_run
def plot_binary_map(src_mask, tgt_mask, save_path='', suffix='', av_index=None, is_bos=None, **kwargs):
    """Plot two binary masks side by side (source in green, target in orange).

    Args:
        src_mask: 2-D binary array shown in the left panel.
        tgt_mask: 2-D binary array shown in the right panel; assumed to have the
            same shape as src_mask (tick ranges are taken from src_mask).
        save_path: output directory (created if missing).
        suffix: appended to the output file name `map{suffix}.png`.
        av_index: optional row to outline in red in both panels (the AV's row).
        is_bos: optional 2-D boolean array; each True cell is outlined in blue
            in both panels.
        **kwargs: optional 't' and 'n' values, rendered into the title.
    """
    from matplotlib.colors import ListedColormap
    os.makedirs(save_path, exist_ok=True)

    fig, axes = plt.subplots(1, 2, figsize=(10, 8))

    # Build the title from whichever of t / n were supplied.
    title = []
    if kwargs.get('t', None) is not None:
        t = kwargs['t']
        title.append(f't={t}')

    if kwargs.get('n', None) is not None:
        n = kwargs['n']
        title.append(f'n={n}')

    plt.title(' '.join(title))

    # Binary colormaps: 0 -> white, 1 -> green/orange.
    cmap = ListedColormap(['white', 'green'])
    axes[0].imshow(src_mask, cmap=cmap, interpolation='nearest')

    cmap = ListedColormap(['white', 'orange'])
    axes[1].imshow(tgt_mask, cmap=cmap, interpolation='nearest')

    if av_index is not None:
        # Highlight the whole AV row across both panels.
        rect = Rectangle((-0.5, av_index - 0.5), src_mask.shape[1], 1, edgecolor='red', facecolor='none', lw=2)
        axes[0].add_patch(rect)
        rect = Rectangle((-0.5, av_index - 0.5), tgt_mask.shape[1], 1, edgecolor='red', facecolor='none', lw=2)
        axes[1].add_patch(rect)

    if is_bos is not None:
        # Outline every begin-of-sequence cell in blue in both panels.
        rows, cols = np.where(is_bos)
        for row, col in zip(rows, cols):
            rect = Rectangle((col - 0.5, row - 0.5), 1, 1, edgecolor='blue', facecolor='none', lw=1)
            axes[0].add_patch(rect)
            rect = Rectangle((col - 0.5, row - 0.5), 1, 1, edgecolor='blue', facecolor='none', lw=1)
            axes[1].add_patch(rect)

    for ax in axes:
        ax.set_xticks(range(src_mask.shape[1] + 1), minor=False)
        ax.set_yticks(range(src_mask.shape[0] + 1), minor=False)
        ax.grid(which='major', color='gray', linestyle='--', linewidth=0.5)

    plt.savefig(os.path.join(save_path, f'map{suffix}.png'), dpi=300, bbox_inches='tight')
    plt.close()
+
406
+
407
@safe_run
def plot_prob_seed(scenario_id, prob, save_path, prefix='', indices=None):
    """Render the seed-probability matrix as a heat map and save it to disk.

    Args:
        scenario_id: identifier used in the output file name.
        prob: 2-D probability array (rows x columns) shown via imshow.
        save_path: output directory (created if missing).
        prefix: file-name prefix.
        indices: optional array whose column j lists row indices to outline in
            red at column j; entries equal to -1 are skipped.
    """
    os.makedirs(save_path, exist_ok=True)

    plt.figure(figsize=(8, 5))
    plt.imshow(prob, cmap='viridis', aspect='auto')
    plt.colorbar()

    plt.title('Seed Probability')

    if indices is not None:
        num_cols = indices.shape[1]
        for col_idx in range(num_cols):
            for row_idx in indices[:, col_idx]:
                # -1 is the sentinel for "no selection at this slot".
                if row_idx == -1:
                    continue
                highlight = Rectangle((col_idx - 0.5, row_idx - 0.5), 1, 1,
                                      edgecolor='red', facecolor='none', lw=2)
                plt.gca().add_patch(highlight)

    plt.tight_layout()
    out_file = os.path.join(save_path, f'{prefix}{scenario_id}_prob_seed.png')
    plt.savefig(out_file, dpi=300, bbox_inches='tight')
    plt.close()
+
432
+
433
@safe_run
def plot_raw():
    """Debug helper: scan raw Waymo scenario TFRecords, draw map features and
    agent tracks, and save an image for each scenario containing invalid tracks.

    Reads from a hard-coded dataset path; writes PNGs into the current
    directory. Intended for manual inspection only.
    """
    plt.figure(figsize=(30, 30))
    plt.rcParams['axes.facecolor']='white'

    data_path = '/u/xiuyu/work/dev4/data/waymo/scenario/training'
    os.makedirs("data/vis/raw/0/", exist_ok=True)
    file_list = os.listdir(data_path)

    for cnt_file, file in enumerate(file_list):
        file_path = os.path.join(data_path, file)
        dataset = tf.data.TFRecordDataset(file_path, compression_type='')
        for scenario_idx, data in enumerate(dataset):
            scenario = scenario_pb2.Scenario()
            scenario.ParseFromString(bytearray(data.numpy()))
            tqdm.write(f"scenario id: {scenario.scenario_id}")

            # draw maps
            for i in range(len(scenario.map_features)):

                # draw lanes (proto sub-messages stringify to '' when unset)
                if str(scenario.map_features[i].lane) != '':
                    line_x = [z.x for z in scenario.map_features[i].lane.polyline]
                    line_y = [z.y for z in scenario.map_features[i].lane.polyline]
                    plt.scatter(line_x, line_y, c='g', s=5)
                    plt.text(line_x[0], line_y[0], str(scenario.map_features[i].id), fontdict={'family': 'serif', 'size': 20, 'color': 'green'})

                # draw road_edge
                if str(scenario.map_features[i].road_edge) != '':
                    road_edge_x = [polyline.x for polyline in scenario.map_features[i].road_edge.polyline]
                    road_edge_y = [polyline.y for polyline in scenario.map_features[i].road_edge.polyline]
                    plt.scatter(road_edge_x, road_edge_y)
                    plt.text(road_edge_x[0], road_edge_y[0], scenario.map_features[i].road_edge.type, fontdict={'family': 'serif', 'size': 20, 'color': 'black'})
                    if scenario.map_features[i].road_edge.type == 2:
                        plt.scatter(road_edge_x, road_edge_y, c='k')
                    elif scenario.map_features[i].road_edge.type == 3:
                        plt.scatter(road_edge_x, road_edge_y, c='purple')
                        print(scenario.map_features[i].road_edge)
                    else:
                        plt.scatter(road_edge_x, road_edge_y, c='k')

                # draw road_line (color keyed on road_line.type)
                if str(scenario.map_features[i].road_line) != '':
                    road_line_x = [j.x for j in scenario.map_features[i].road_line.polyline]
                    road_line_y = [j.y for j in scenario.map_features[i].road_line.polyline]
                    if scenario.map_features[i].road_line.type == 7:
                        plt.plot(road_line_x, road_line_y, c='y')
                    elif scenario.map_features[i].road_line.type == 8:
                        plt.plot(road_line_x, road_line_y, c='y')
                    elif scenario.map_features[i].road_line.type == 6:
                        plt.plot(road_line_x, road_line_y, c='y')
                    elif scenario.map_features[i].road_line.type == 1:
                        # Dashed line: draw 5-point segments every 7 points.
                        # NOTE: the loop variable deliberately kept distinct from the
                        # outer feature index to avoid clobbering it.
                        for seg in range(int(len(road_line_x) / 7)):
                            plt.plot(road_line_x[seg * 7 : 5 + seg * 7], road_line_y[seg * 7 : 5 + seg * 7], color='w')
                    elif scenario.map_features[i].road_line.type == 2:
                        plt.plot(road_line_x, road_line_y, c='w')
                    else:
                        plt.plot(road_line_x, road_line_y, c='w')

            # draw tracks; valid states in red/blue, invalid states in magenta
            scenario_has_invalid_tracks = False
            for i in range(len(scenario.tracks)):
                traj_x = [center.center_x for center in scenario.tracks[i].states]
                traj_y = [center.center_y for center in scenario.tracks[i].states]
                head = [center.heading for center in scenario.tracks[i].states]
                valid = [center.valid for center in scenario.tracks[i].states]
                print(valid)
                if i == scenario.sdc_track_index:
                    plt.scatter(traj_x[0], traj_y[0], s=140, c='r', marker='s')
                    plt.scatter([x for x, v in zip(traj_x, valid) if v],
                                [y for y, v in zip(traj_y, valid) if v], s=14, c='r')
                    plt.scatter([x for x, v in zip(traj_x, valid) if not v],
                                [y for y, v in zip(traj_y, valid) if not v], s=14, c='m')
                else:
                    plt.scatter(traj_x[0], traj_y[0], s=140, c='k', marker='s')
                    plt.scatter([x for x, v in zip(traj_x, valid) if v],
                                [y for y, v in zip(traj_y, valid) if v], s=14, c='b')
                    plt.scatter([x for x, v in zip(traj_x, valid) if not v],
                                [y for y, v in zip(traj_y, valid) if not v], s=14, c='m')
                if valid.count(False) > 0:
                    scenario_has_invalid_tracks = True
            if scenario_has_invalid_tracks:
                plt.savefig(f"scenario_{scenario_idx}_{scenario.scenario_id}.png")
                plt.clf()
                # BUG FIX: removed a committed `breakpoint()` left over from
                # interactive debugging — it would halt any non-interactive run.
            # NOTE(review): only the first scenario of each file is processed;
            # confirm this early break is intentional.
            break
+
520
+
521
# (fill, edge) hex color pairs indexed by agent color class (1..6) in plot_gif;
# the edge color is a darker shade of the fill.
colors = [
    ('#1f77b4', '#1a5a8a'), # blue
    ('#2ca02c', '#217721'), # green
    ('#ff7f0e', '#cc660b'), # orange
    ('#9467bd', '#6f4a91'), # purple
    ('#d62728', '#a31d1d'), # red
    ('#000000', '#000000'), # black
]
+
530
@safe_run
def plot_gif():
    """Debug helper: render each processed Waymo scenario frame-by-frame and
    assemble the frames into an animated GIF per scenario.

    Reads pickled scenarios from a hard-coded directory; writes frame PNGs and
    GIFs under data/vis/processed/0/. Skips scenarios whose GIF already exists.
    """
    data_path = "/u/xiuyu/work/dev4/data/waymo_processed/training"
    os.makedirs("data/vis/processed/0/gif", exist_ok=True)
    file_list = os.listdir(data_path)

    for scenario_idx, file in tqdm(enumerate(file_list), leave=False, desc="Scenario"):

        fig, ax = plt.subplots()
        ax.set_axis_off()

        file_path = os.path.join(data_path, file)
        # NOTE(review): file handle is never closed; prefer a `with open(...)`.
        data = pickle.load(open(file_path, "rb"))
        scenario_id = data['scenario_id']

        save_path = os.path.join("data/vis/processed/0/gif",
                                 f"scenario_{scenario_idx}_{scenario_id}.gif")
        if os.path.exists(save_path):
            tqdm.write(f"Skipped {save_path}.")
            continue

        # draw maps
        ax.scatter(data['map_point']['position'][:, 0],
                   data['map_point']['position'][:, 1], s=0.2, c='black', edgecolors='none')

        # draw agents
        agent_data = data['agent']
        av_index = agent_data['av_index']
        position = agent_data['position'] # (num_agent, 91, 3)
        heading = agent_data['heading'] # (num_agent, 91)
        shape = agent_data['shape'] # (num_agent, 91, 3)
        category = agent_data['category'] # (num_agent,)
        # A position of exactly (0, 0) marks a padded/invalid timestep.
        valid_mask = (position[..., 0] != 0) & (position[..., 1] != 0) # (num_agent, 91)

        num_agent = valid_mask.shape[0]
        num_timestep = position.shape[1]
        is_av = np.arange(num_agent) == int(av_index)

        # Color classes by validity pattern:
        # blue  = valid for the whole episode
        # green = enters mid-episode, orange = exits mid-episode
        # purple = other partial validity; AV overrides everything.
        is_blue = valid_mask.sum(axis=1) == num_timestep
        is_green = ~valid_mask[:, 0] & valid_mask[:, -1]
        is_orange = valid_mask[:, 0] & ~valid_mask[:, -1]
        is_purple = (valid_mask.sum(axis=1) != num_timestep
                     ) & (~is_green) & (~is_orange)
        agent_colors = np.zeros((num_agent,))
        agent_colors[is_blue] = 1
        agent_colors[is_green] = 2
        agent_colors[is_orange] = 3
        agent_colors[is_purple] = 4
        agent_colors[is_av] = 5

        # Override recorded extents with fixed per-category box sizes.
        veh_mask = category == 1
        ped_mask = category == 2
        cyc_mask = category == 3
        shape[veh_mask, :, 1] = 1.8
        shape[veh_mask, :, 0] = 1.8
        shape[ped_mask, :, 1] = 0.5
        shape[ped_mask, :, 0] = 0.5
        shape[cyc_mask, :, 1] = 1.0
        shape[cyc_mask, :, 0] = 1.0

        fig_paths = []
        for tid in tqdm(range(num_timestep), leave=False, desc="Timestep"):
            current_valid_mask = valid_mask[:, tid]
            xs = position[current_valid_mask, tid, 0]
            ys = position[current_valid_mask, tid, 1]
            widths = shape[current_valid_mask, tid, 1]
            lengths = shape[current_valid_mask, tid, 0]
            angles = heading[current_valid_mask, tid]
            current_agent_colors = agent_colors[current_valid_mask]

            drawn_agents = []
            contours = cal_polygon_contour(xs, ys, angles, widths, lengths) # (num_agent, 4, 2)
            contours = np.concatenate([contours, contours[:, 0:1]], axis=1) # (num_agent, 5, 2)
            for x, y, width, length, angle, color_type in zip(
                    xs, ys, widths, lengths, angles, current_agent_colors):
                agent = plt.Rectangle((x, y), width, length, angle=((angle + np.pi / 2) / np.pi * 360) % 360,
                                      linewidth=0.2,
                                      facecolor=colors[int(color_type) - 1][0],
                                      edgecolor=colors[int(color_type) - 1][1])
                ax.add_patch(agent)
                drawn_agents.append(agent)
            plt.gca().set_aspect('equal', adjustable='box')
            # for contour, color_type in zip(contours, agent_colors):
            #     drawn_agent = ax.plot(contour[:, 0], contour[:, 1])
            #     drawn_agents.append(drawn_agent)

            fig_path = os.path.join("data/vis/processed/0/",
                                    f"scenario_{scenario_idx}_{scenario_id}_{tid}.png")
            plt.savefig(fig_path, dpi=600)
            fig_paths.append(fig_path)

            # Remove this frame's agent patches so the next frame reuses the
            # same figure with only the static map drawn.
            for drawn_agent in drawn_agents:
                drawn_agent.remove()

        plt.close()

        # generate gif
        import imageio.v2 as imageio
        images = []
        for fig_path in tqdm(fig_paths, leave=False, desc="Generate gif ..."):
            images.append(imageio.imread(fig_path))
        imageio.mimsave(save_path, images, duration=0.1)
+
633
+
634
@safe_run
def plot_map_token(ax: Axes, map_points: npt.NDArray, token_pos: npt.NDArray, token_heading: npt.NDArray, colors: Union[str, npt.NDArray]=None):
    """Overlay map tokens on a map scatter: one arrow per token showing its
    heading, plus a small blue dot at each token position.

    Args:
        ax: target axes.
        map_points: (N, >=2) map point coordinates, drawn via plot_map.
        token_pos: (M, >=2) token positions.
        token_heading: (M,) token headings in radians.
        colors: per-token arrow colors; randomized when None.
    """
    plot_map(ax, map_points)

    xs = token_pos[:, 0]
    ys = token_pos[:, 1]
    dx = np.cos(token_heading)
    dy = np.sin(token_heading)

    if colors is None:
        # One random RGB triple per token.
        colors = np.random.rand(xs.shape[0], 3)

    ax.quiver(xs, ys, dx, dy, angles='xy', scale_units='xy', scale=0.2, color=colors, width=0.005,
              headwidth=0.2, headlength=2)
    ax.scatter(xs, ys, color='blue', s=0.2, edgecolors='none')
    ax.axis("equal")
+
650
+
651
@safe_run
def plot_map(ax: Axes, map_points: npt.NDArray, color='black'):
    """Scatter map points onto `ax` and clamp the view to their bounding box.

    Args:
        ax: target axes.
        map_points: (N, >=2) array; columns 0/1 are x/y.
        color: scatter color for all points.
    """
    px = map_points[:, 0]
    py = map_points[:, 1]
    ax.scatter(px, py, s=0.2, c=color, edgecolors='none')

    # Tight axis limits around the map extent.
    ax.set_xlim(np.min(px), np.max(px))
    ax.set_ylim(np.min(py), np.max(py))
+
662
+
663
@safe_run
def plot_agent(ax: Axes, xy: Sequence[float], heading: float, type: str, state, is_av: bool=False,
               pl2seed_radius: float=25., attr_tokenizer: Attr_Tokenizer=None, enter_index: list=[], **kwargs):
    """Draw a single agent as an oriented box with a small center marker; for
    the AV, optionally draw the seeding radius and the relative-position grid.

    Args:
        ax: target axes.
        xy: (x, y) agent center.
        heading: agent heading in radians.
        type: one of 'veh' / 'ped' / 'cyc'; sets box size.
        state: unused here; kept for signature compatibility with callers.
        is_av: whether this agent is the autonomous vehicle.
        pl2seed_radius: radius of the dashed circle drawn around the AV.
        attr_tokenizer: provides the relative grid drawn around the AV.
        enter_index: grid cells to mark with a red 'x' (AV only).
        **kwargs: patch styling (facecolor, edgecolor, ...) plus optional
            'label' text drawn next to the agent.

    Returns:
        (patch, center_patch): the body box and the center polygon.
    """
    # Per-type box dimensions (meters) and center-marker size.
    if type == 'veh':
        length = 4.3
        width = 1.8
        size = 1.0
    elif type == 'ped':
        length = 0.5
        width = 0.5
        size = 0.1
    elif type == 'cyc':
        length = 1.9
        width = 0.5
        size = 0.3
    else:
        raise ValueError(f"Unsupported agent type {type}")

    if kwargs.get('label', None) is not None:
        ax.text(
            xy[0] + 1.5, xy[1] + 1.5,
            kwargs['label'], fontsize=2, color="darkred", ha="center", va="center"
        )

    # Box is built at the origin then rotated/translated into place.
    patch = FancyBboxPatch([-length / 2, -width / 2], length, width, linewidth=.2, **kwargs)
    transform = (
        mtransforms.Affine2D().rotate(heading).translate(xy[0], xy[1])
        + ax.transData
    )
    patch.set_transform(transform)

    # Clear the label so it is not passed to the Polygon patch below.
    kwargs['label'] = None
    angles = [0, 2 * np.pi / 3, np.pi, 4 * np.pi / 3]
    pts = np.stack([size * np.cos(angles), size * np.sin(angles)], axis=-1)
    center_patch = Polygon(pts, zorder=10., linewidth=.2, **kwargs)
    center_patch.set_transform(transform)

    ax.add_patch(patch)
    ax.add_patch(center_patch)

    if is_av:

        if attr_tokenizer is not None:

            # Dashed circle showing the polyline-to-seed radius around the AV.
            circle_patch = Circle(
                (xy[0], xy[1]), pl2seed_radius, linewidth=0.5, edgecolor='gray', linestyle='--', facecolor='none'
            )
            ax.add_patch(circle_patch)

            # Relative-position grid anchored at the AV; first point is ahead
            # of the AV, last point behind it.
            grid = attr_tokenizer.get_grid(torch.tensor(np.array(xy)).float(),
                                           torch.tensor(np.array([heading])).float()).numpy()[0] # (num_grid, 2)
            ax.scatter(grid[:, 0], grid[:, 1], s=0.3, c='blue', edgecolors='none')
            ax.text(grid[0, 0], grid[0, 1], 'Front', fontsize=2, color='darkred', ha='center', va='center')
            ax.text(grid[-1, 0], grid[-1, 1], 'Back', fontsize=2, color='darkred', ha='center', va='center')

            if enter_index:
                # Mark grid cells where an agent entered the scene.
                for i in enter_index:
                    ax.plot(grid[int(i), 0], grid[int(i), 1], marker='x', color='red', markersize=1)

    return patch, center_patch
+
725
+
726
@safe_run
def plot_all(map, xs, ys, angles, types, colors, is_avs, pl2seed_radius: float=25.,
             attr_tokenizer: Attr_Tokenizer=None, enter_index: list=None, labels: list=None, **kwargs):
    """Draw one frame: the map plus every agent box, optionally saving a PNG.

    Args:
        map: map point array forwarded to plot_map.
        xs, ys, angles: per-agent positions and headings.
        types: per-agent type strings ('veh' / 'ped' / 'cyc').
        colors: per-agent face colors.
        is_avs: per-agent AV flags.
        pl2seed_radius, attr_tokenizer, enter_index: forwarded to plot_agent.
        labels: optional per-agent text labels.
        **kwargs: optional 'save_path' to write the figure.

    Returns:
        The axes the frame was drawn on.
    """
    # BUG FIX: subplots_adjust was previously called *before* plt.subplots(),
    # so it adjusted (or implicitly created) a throwaway figure. Create the
    # figure first, then adjust it. Also replaced the mutable default
    # arguments ([] for enter_index/labels) with None sentinels.
    if enter_index is None:
        enter_index = []

    fig, ax = plt.subplots()
    fig.subplots_adjust(left=0.3, right=0.7, top=0.7, bottom=0.3)
    ax.set_axis_off()

    plot_map(ax, map)

    if not labels:
        labels = [None] * xs.shape[0]

    for x, y, angle, type, color, label, is_av in zip(xs, ys, angles, types, colors, labels, is_avs):
        assert type in ('veh', 'ped', 'cyc'), f"Unsupported type {type}."
        plot_agent(ax, [x, y], angle.item(), type, None, is_av, facecolor=color, edgecolor='k', label=label,
                   pl2seed_radius=pl2seed_radius, attr_tokenizer=attr_tokenizer, enter_index=enter_index)

    ax.grid(False)
    ax.set_aspect('equal', adjustable='box')

    # ax.legend(loc='best', frameon=True)

    if kwargs.get('save_path', None):
        plt.savefig(kwargs['save_path'], dpi=600, bbox_inches="tight")

    plt.close()

    return ax
+
756
+
757
@safe_run
def plot_file(gt_folder: str,
              folder: Optional[str] = None,
              files: Optional[str] = None):
    """Plot every rollout stored in pickled prediction files against ground truth.

    Exactly one of `folder` (a directory of `idx_*_rollouts.pkl` files) or
    `files` (a single pickle path) should be given; plots are written next to
    the input folder under `<folder>_plots`.

    Args:
        gt_folder: root of the processed ground-truth data; scenarios are read
            from `<gt_folder>/validation/<scenario_id>.pkl`.
        folder: directory containing rollout pickles.
        files: a single rollout pickle path (used when `folder` is None).
    """
    from dev.metrics.compute_metrics import _unbatch

    if files is None:
        assert os.path.exists(folder), f'Path {folder} does not exist.'
        files = list(fnmatch.filter(os.listdir(folder), 'idx_*_rollouts.pkl'))
        CONSOLE.log(f'Found {len(files)} rollouts files from {folder}.')


    if folder is None:
        assert os.path.exists(files), f'Path {files} does not exist.'
        folder = os.path.dirname(files)
        files = [files]

    parent, folder_name = os.path.split(folder.rstrip(os.sep))
    save_path = os.path.join(parent, f'{folder_name}_plots')

    for file in (pbar := tqdm(files, leave=False, desc='Plotting files ...')):
        pbar.set_postfix(file=file)

        with open(os.path.join(folder, file), 'rb') as f:
            preds = pickle.load(f)

        # Split batched predictions back into per-scenario tensors.
        scenario_ids = preds['_scenario_id']
        agent_batch = preds['agent_batch']
        agent_id = _unbatch(preds['agent_id'], agent_batch)
        preds_traj = _unbatch(preds['pred_traj'], agent_batch)
        preds_head = _unbatch(preds['pred_head'], agent_batch)
        preds_type = _unbatch(preds['pred_type'], agent_batch)
        preds_state = _unbatch(preds['pred_state'], agent_batch)
        preds_valid = _unbatch(preds['pred_valid'], agent_batch)

        for i, scenario_id in enumerate(scenario_ids):
            # NOTE(review): rollout count is read from element 0, not element i —
            # assumes all scenarios share the same number of rollouts; confirm.
            n_rollouts = preds_traj[0].shape[1]

            for j in range(n_rollouts): # 1
                pred = dict(scenario_id=[scenario_id],
                            pred_traj=preds_traj[i][:, j],
                            pred_head=preds_head[i][:, j],
                            pred_state=preds_state[i][:, j],
                            pred_type=preds_type[i][:, j],
                            )
                av_index = agent_id[i][:, 0].tolist().index(preds['av_id']) # NOTE: hard code!!!

                data_path = os.path.join(gt_folder, 'validation', f'{scenario_id}.pkl')
                with open(data_path, 'rb') as f:
                    data = pickle.load(f)
                plot_val(data, pred, av_index=av_index, save_path=save_path)
+
810
@safe_run
def plot_val(data: Union[dict, str], pred: dict, av_index: int, save_path: str, suffix: str='',
             pl2seed_radius: float=75., attr_tokenizer=None, **kwargs):
    """Plot one predicted rollout on top of a ground-truth scenario's map.

    Args:
        data: a loaded scenario dict, or the path to its pickle file.
        pred: prediction dict with keys 'scenario_id', 'pred_traj', 'pred_head',
            'pred_state', 'pred_type' and optionally 'agent_labels'.
        av_index: row index of the AV among the predicted agents.
        save_path: output directory for the rendered frames/GIF.
        suffix: file-name suffix forwarded to plot_scenario.
        pl2seed_radius: seeding radius forwarded to plot_scenario.
        attr_tokenizer: optional tokenizer forwarded to plot_scenario.
    """
    if isinstance(data, str):
        assert data.endswith('.pkl'), f'Got invalid data path {data}.'
        assert os.path.exists(data), f'Path {data} does not exist.'
        with open(data, 'rb') as f:
            data = pickle.load(f)

    map_point = data['map_point']['position'].cpu().numpy()

    scenario_id = pred['scenario_id'][0]
    pred_traj = pred['pred_traj'].cpu().numpy() # (num_agent, num_future_step, 2)
    # Map integer type ids to 'veh'/'ped'/'cyc' strings.
    pred_type = list(map(lambda i: AGENT_TYPE[i], pred['pred_type'].tolist()))
    pred_state = pred['pred_state'].cpu().numpy()
    pred_head = pred['pred_head'].cpu().numpy()
    ids = np.arange(pred_traj.shape[0])

    if 'agent_labels' in pred:
        kwargs.update(agent_labels=pred['agent_labels'])

    plot_scenario(scenario_id, map_point, pred_traj, pred_head, pred_state, pred_type,
                  av_index=av_index, ids=ids, save_path=save_path, suffix=suffix,
                  pl2seed_radius=pl2seed_radius, attr_tokenizer=attr_tokenizer, **kwargs)
+
836
+
837
@safe_run
def plot_scenario(scenario_id: str,
                  map_data: npt.NDArray,
                  traj: npt.NDArray,
                  heading: npt.NDArray,
                  state: npt.NDArray,
                  types: List[str],
                  av_index: int,
                  color_type: Literal['state', 'type', 'seed']='seed',
                  state_type: List[str]=['invalid', 'valid', 'enter', 'exit'],
                  plot_enter: bool=False,
                  suffix: str='',
                  pl2seed_radius: float=25.,
                  attr_tokenizer: Attr_Tokenizer=None,
                  enter_index: List[list] = [],
                  save_gif: bool=True,
                  tokenized: bool=False,
                  agent_labels: List[List[Optional[str]]] = [],
                  **kwargs):
    """Render a full scenario frame-by-frame, coloring agents by state or by
    seed status, and optionally assemble the frames into a GIF.

    Frames go to `<save_path>/<scenario_id>_<run_id>/`; each call allocates the
    next run_id and deletes the previous run's frames.

    Args:
        scenario_id: used in output paths.
        map_data: map points forwarded to plot_all.
        traj: (num_agent, num_timestep, >=2) agent positions.
        heading: (num_agent, num_timestep) headings in radians.
        state: per-agent, per-step state indices into `state_type`; one state
            spans `shift` trajectory steps.
        types: per-agent type strings ('veh' / 'ped' / 'cyc').
        av_index: row index of the AV.
        color_type: coloring scheme; 'state' colors by per-step state,
            'seed' colors by whether the agent existed / entered / exited.
        state_type: ordered state names defining the state-index encoding.
        plot_enter: if False, agents in the 'enter' state are hidden.
        suffix: unused here; accepted for caller compatibility.
        pl2seed_radius, attr_tokenizer, enter_index: forwarded to plot_all.
        save_gif: whether to build a GIF from the saved frames.
        tokenized: if True, use tokenized-step spacing (shift=1) instead of
            raw-step spacing (shift=5).
        agent_labels: optional per-agent, per-state-step label strings.
        **kwargs: 'save_path' (output root) and 'ids' (per-agent ids).
    """
    num_historical_steps = 11
    shift = 5
    num_agent, num_timestep = traj.shape[:2]

    # Tokenized trajectories are already downsampled, so one state per step.
    if tokenized:
        num_historical_steps = 2
        shift = 1

    if 'save_path' in kwargs and kwargs['save_path'] != '':
        os.makedirs(kwargs['save_path'], exist_ok=True)
        # Allocate the next run id by scanning existing `<scenario_id>_NNN` dirs.
        save_id = int(max([0] + list(map(lambda fname: int(fname.split("_")[-1]),
                                         filter(lambda fname: fname.startswith(scenario_id)
                                                and os.path.isdir(os.path.join(kwargs['save_path'], fname)),
                                                os.listdir(kwargs['save_path'])))))) + 1
        os.makedirs(f"{kwargs['save_path']}/{scenario_id}_{str(save_id).zfill(3)}", exist_ok=True)

        # Best-effort cleanup of the previous run's frame directory.
        if save_id > 1:
            try:
                import shutil
                shutil.rmtree(f"{kwargs['save_path']}/{scenario_id}_{str(save_id - 1).zfill(3)}")
            except:
                pass

    # Agents are drawn only while not 'invalid' (and, unless plot_enter,
    # not in the 'enter' state either).
    visible_mask = state != state_type.index('invalid')
    if not plot_enter:
        visible_mask &= (state != state_type.index('enter'))

    # Index of the last visible step per agent (argmax over the reversed mask).
    last_valid_step = visible_mask.shape[1] - 1 - torch.argmax(torch.Tensor(visible_mask).flip(dims=[1]).long(), dim=1)
    ids = None
    if 'ids' in kwargs:
        ids = kwargs['ids']
        last_valid_step = {int(ids[i]): int(last_valid_step[i]) for i in range(len(ids))}

    # agent colors
    agent_colors = np.zeros((num_agent, num_timestep, 3))

    agent_palette = sns.color_palette('husl', n_colors=7)
    state_colors = {state: np.array(agent_palette[i]) for i, state in enumerate(state_type)}
    seed_colors = {seed: np.array(agent_palette[i]) for i, seed in enumerate(['existing', 'entered', 'exited'])}

    if color_type == 'state':
        # One state covers `shift` consecutive trajectory steps.
        for t in range(state.shape[1]):
            agent_colors[state[:, t] == state_type.index('invalid'), t * shift : (t + 1) * shift] = state_colors['invalid']
            agent_colors[state[:, t] == state_type.index('valid'), t * shift : (t + 1) * shift] = state_colors['valid']
            agent_colors[state[:, t] == state_type.index('enter'), t * shift : (t + 1) * shift] = state_colors['enter']
            agent_colors[state[:, t] == state_type.index('exit'), t * shift : (t + 1) * shift] = state_colors['exit']

    if color_type == 'seed':
        agent_colors[:, :] = seed_colors['existing']
        is_exited = np.any(state[:, num_historical_steps - 1:] == state_type.index('exit'), axis=-1)
        is_entered = np.any(state[:, num_historical_steps - 1:] == state_type.index('enter'), axis=-1)
        is_entered[av_index + 1:] = True # NOTE: hard code, need improvment
        agent_colors[is_exited, :] = seed_colors['exited']
        agent_colors[is_entered, :] = seed_colors['entered']

    # The AV always gets the last palette color.
    agent_colors[av_index, :] = np.array(agent_palette[-1])
    is_av = np.zeros_like(state[:, 0]).astype(np.bool_)
    is_av[av_index] = True

    # draw agents
    fig_paths = []
    for tid in tqdm(range(num_timestep), leave=False, desc="Plot ..."):
        mask_t = visible_mask[:, tid]
        xs = traj[mask_t, tid, 0]
        ys = traj[mask_t, tid, 1]
        angles = heading[mask_t, tid]
        colors = agent_colors[mask_t, tid]
        types_t = [types[i] for i, mask in enumerate(mask_t) if mask]
        if ids is not None:
            ids_t = ids[mask_t]
        is_av_t = is_av[mask_t]
        enter_index_t = enter_index[tid] if enter_index else None
        labels = []
        if agent_labels:
            # Labels are stored per state-step, hence the // shift.
            labels = [agent_labels[i][tid // shift] for i in range(len(agent_labels)) if mask_t[i]]

        fig_path = None
        if 'save_path' in kwargs:
            save_path = kwargs['save_path']
            fig_path = os.path.join(f"{save_path}/{scenario_id}_{str(save_id).zfill(3)}", f"{tid}.png")
            fig_paths.append(fig_path)

        plot_all(map_data, xs, ys, angles, types_t, colors=colors, save_path=fig_path, is_avs=is_av_t,
                 pl2seed_radius=pl2seed_radius, attr_tokenizer=attr_tokenizer, enter_index=enter_index_t, labels=labels)

    # generate gif
    if fig_paths and save_gif:
        os.makedirs(os.path.join(save_path, 'gifs'), exist_ok=True)
        images = []
        gif_path = f"{save_path}/gifs/{scenario_id}_{str(save_id).zfill(3)}.gif"
        for fig_path in tqdm(fig_paths, leave=False, desc="Generate gif ..."):
            images.append(Image.open(fig_path))
        try:
            images[0].save(gif_path, save_all=True, append_images=images[1:], duration=100, loop=0)
            tqdm.write(f"Saved gif at {gif_path}")
            # Best-effort cleanup: drop this run's frames and the previous gif.
            try:
                import shutil
                shutil.rmtree(f"{save_path}/{scenario_id}_{str(save_id).zfill(3)}")
                os.remove(f"{save_path}/gifs/{scenario_id}_{str(save_id - 1).zfill(3)}.gif")
            except:
                pass
        except Exception as e:
            tqdm.write(f"{e}! Failed to save gif at {gif_path}")
+
961
+
962
def match_token_map(data):
    """Assign each map polyline in `data` its nearest map-trajectory token and
    attach the resulting token fields to `data['pt_token']`.

    Loads the token vocabulary from a hard-coded pickle path, matches each
    polyline (in its own local frame) to the closest token by summed squared
    distance over sampled points, and records token ids, positions,
    orientations, side-count masks, and the token->polygon edge index.

    Args:
        data: scenario dict with 'map_save' (traj_pos, traj_theta, pl_idx_list)
            and 'pt_token' entries.

    Returns:
        The same `data` dict, mutated in place with the token fields added.
    """
    # init map token
    argmin_sample_len = 3
    # NOTE(review): hard-coded token vocabulary path; breaks outside this host.
    map_token_traj_path = '/u/xiuyu/work/dev4/dev/tokens/map_traj_token5.pkl'

    map_token_traj = pickle.load(open(map_token_traj_path, 'rb'))
    map_token = {'traj_src': map_token_traj['traj_src'], }
    # Heading at the end of each token trajectory (last two points).
    traj_end_theta = np.arctan2(map_token['traj_src'][:, -1, 1] - map_token['traj_src'][:, -2, 1],
                                map_token['traj_src'][:, -1, 0] - map_token['traj_src'][:, -2, 0])
    # Sample `argmin_sample_len` evenly spaced points per token for matching.
    indices = torch.linspace(0, map_token['traj_src'].shape[1]-1, steps=argmin_sample_len).long()
    map_token['sample_pt'] = torch.from_numpy(map_token['traj_src'][:, indices]).to(torch.float)
    map_token['traj_end_theta'] = torch.from_numpy(traj_end_theta).to(torch.float)
    map_token['traj_src'] = torch.from_numpy(map_token['traj_src']).to(torch.float)

    traj_pos = data['map_save']['traj_pos'].to(torch.float)
    traj_theta = data['map_save']['traj_theta'].to(torch.float)
    pl_idx_list = data['map_save']['pl_idx_list']
    token_sample_pt = map_token['sample_pt'].to(traj_pos.device)
    token_src = map_token['traj_src'].to(traj_pos.device)
    max_traj_len = map_token['traj_src'].shape[1]
    pl_num = traj_pos.shape[0]

    pt_token_pos = traj_pos[:, 0, :].clone()
    pt_token_orientation = traj_theta.clone()
    # Rotate each polyline into its own local frame (origin at first point,
    # x-axis along traj_theta) before comparing against token samples.
    cos, sin = traj_theta.cos(), traj_theta.sin()
    rot_mat = traj_theta.new_zeros(pl_num, 2, 2)
    rot_mat[..., 0, 0] = cos
    rot_mat[..., 0, 1] = -sin
    rot_mat[..., 1, 0] = sin
    rot_mat[..., 1, 1] = cos
    traj_pos_local = torch.bmm((traj_pos - traj_pos[:, 0:1]), rot_mat.view(-1, 2, 2))
    # Nearest token by summed squared distance over the sampled points.
    distance = torch.sum((token_sample_pt[None] - traj_pos_local.unsqueeze(1)) ** 2, dim=(-2, -1))
    pt_token_id = torch.argmin(distance, dim=1)

    # Optional stochastic matching: pick randomly among the 8 nearest tokens.
    noise = False
    if noise:
        topk_indices = torch.argsort(torch.sum((token_sample_pt[None] - traj_pos_local.unsqueeze(1)) ** 2, dim=(-2, -1)), dim=1)[:, :8]
        sample_topk = torch.randint(0, topk_indices.shape[-1], size=(topk_indices.shape[0], 1), device=topk_indices.device)
        pt_token_id = torch.gather(topk_indices, 1, sample_topk).squeeze(-1)

    # cos, sin = traj_theta.cos(), traj_theta.sin()
    # rot_mat = traj_theta.new_zeros(pl_num, 2, 2)
    # rot_mat[..., 0, 0] = cos
    # rot_mat[..., 0, 1] = sin
    # rot_mat[..., 1, 0] = -sin
    # rot_mat[..., 1, 1] = cos
    # token_src_world = torch.bmm(token_src[None, ...].repeat(pl_num, 1, 1, 1).reshape(pl_num, -1, 2),
    #                             rot_mat.view(-1, 2, 2)).reshape(pl_num, token_src.shape[0], max_traj_len, 2) + traj_pos[:, None, [0], :]
    # token_src_world_select = token_src_world.view(-1, 1024, 11, 2)[torch.arange(pt_token_id.view(-1).shape[0]), pt_token_id.view(-1)].view(pl_num, max_traj_len, 2)

    # Token -> polygon mapping, and per-polygon counts of left/right/center
    # side points used to build a padded boolean mask.
    pl_idx_full = pl_idx_list.clone()
    token2pl = torch.stack([torch.arange(len(pl_idx_list), device=traj_pos.device), pl_idx_full.long()])
    count_nums = []
    for pl in pl_idx_full.unique():
        pt = token2pl[0, token2pl[1, :] == pl]
        left_side = (data['pt_token']['side'][pt] == 0).sum()
        right_side = (data['pt_token']['side'][pt] == 1).sum()
        center_side = (data['pt_token']['side'][pt] == 2).sum()
        count_nums.append(torch.Tensor([left_side, right_side, center_side]))
    count_nums = torch.stack(count_nums, dim=0)
    num_polyline = int(count_nums.max().item())
    # traj_mask[p, s, k] is True for the first count_nums[p, s] slots.
    traj_mask = torch.zeros((int(len(pl_idx_full.unique())), 3, num_polyline), dtype=bool)
    idx_matrix = torch.arange(traj_mask.size(2)).unsqueeze(0).unsqueeze(0)
    idx_matrix = idx_matrix.expand(traj_mask.size(0), traj_mask.size(1), -1)
    counts_num_expanded = count_nums.unsqueeze(-1)
    mask_update = idx_matrix < counts_num_expanded
    traj_mask[mask_update] = True

    data['pt_token']['traj_mask'] = traj_mask
    # Token positions are stored as 3-D with a zero z-coordinate.
    data['pt_token']['position'] = torch.cat([pt_token_pos, torch.zeros((data['pt_token']['num_nodes'], 1),
                                                                       device=traj_pos.device, dtype=torch.float)], dim=-1)
    data['pt_token']['orientation'] = pt_token_orientation
    data['pt_token']['height'] = data['pt_token']['position'][:, -1]
    data[('pt_token', 'to', 'map_polygon')] = {}
    data[('pt_token', 'to', 'map_polygon')]['edge_index'] = token2pl # (2, num_points)
    data['pt_token']['token_idx'] = pt_token_id
    return data
+
1041
+
1042
@safe_run
def plot_tokenize(data, save_path: str):
    """Tokenize a scenario and plot the reconstructed agent trajectories.

    Runs the WaymoTargetBuilder transform and the TokenProcessor on `data`,
    then renders the tokenized agents (colored by state) via plot_scenario.

    Args:
        data: raw scenario dict (as loaded from the processed pickle).
        save_path: output directory for the rendered frames/GIF.
    """
    # One agent state spans `shift` raw trajectory steps.
    shift = 5
    token_size = 2048
    pl2seed_radius = 75

    # transformation
    transform = WaymoTargetBuilder(num_historical_steps=11,
                                   num_future_steps=80,
                                   max_num=32,
                                   training=False)

    grid_range = 150.
    grid_interval = 3.
    angle_interval = 3.
    attr_tokenizer = Attr_Tokenizer(grid_range=grid_range,
                                    grid_interval=grid_interval,
                                    radius=pl2seed_radius,
                                    angle_interval=angle_interval)

    # tokenization
    token_processor = TokenProcessor(token_size,
                                     training=False,
                                     predict_motion=True,
                                     predict_state=True,
                                     predict_map=True,
                                     state_token={'invalid': 0, 'valid': 1, 'enter': 2, 'exit': 3},
                                     pl2seed_radius=pl2seed_radius)
    CONSOLE.log(f"Loaded token processor with token_size: {token_size}")

    # preprocess
    data: HeteroData = transform(data)
    tokenized_data = token_processor(data)
    CONSOLE.log(f"Keys in tokenized data:\n{tokenized_data.keys()}")

    # plot
    agent_data = tokenized_data['agent']
    map_data = tokenized_data['map_point']
    # CONSOLE.log(f"Keys in agent data:\n{agent_data.keys()}")

    av_index = agent_data['av_index']
    raw_traj = agent_data['position'][..., :2].contiguous() # [n_agent, n_step, 2]
    raw_heading = agent_data['heading'] # [n_agent, n_step]

    # Rebuild dense per-step trajectories from the tokenized segments: keep the
    # first raw step, then append the 5 interpolated steps of each token.
    traj = agent_data['traj_pos'][..., :2].contiguous() # [n_agent, n_step, 6, 2]
    traj = traj[:, :, 1:, :].flatten(1, 2)
    traj = torch.cat([raw_traj[:, :1], traj], dim=1)
    heading = agent_data['traj_heading'] # [n_agent, n_step, 6]
    heading = heading[:, :, 1:].flatten(1, 2)
    heading = torch.cat([raw_heading[:, :1], heading], dim=1)

    # Expand per-token states to per-step states, prepending an invalid step.
    agent_state = agent_data['state_idx'].repeat_interleave(repeats=shift, dim=-1)
    agent_state = torch.cat([torch.zeros_like(agent_state[:, :1]), agent_state], dim=1)
    agent_type = agent_data['type']
    ids = np.arange(raw_traj.shape[0])

    plot_scenario(scenario_id=tokenized_data['scenario_id'],
                  map_data=tokenized_data['map_point']['position'].numpy(),
                  traj=raw_traj.numpy(),
                  heading=raw_heading.numpy(),
                  state=agent_state.numpy(),
                  types=list(map(lambda i: AGENT_TYPE[i], agent_type.tolist())),
                  av_index=av_index,
                  ids=ids,
                  save_path=save_path,
                  pl2seed_radius=pl2seed_radius,
                  attr_tokenizer=attr_tokenizer,
                  color_type='state',
                  )
+
1113
+
1114
if __name__ == "__main__":
    # CLI entry point: either plot a tokenized ground-truth scenario
    # (--plot_tokenize) or plot generated rollout files (--plot_file).
    parser = ArgumentParser()
    parser.add_argument('--data_path', type=str, default='/u/xiuyu/work/dev4/data/waymo_processed')
    parser.add_argument('--tfrecord_dir', type=str, default='validation_tfrecords_splitted')
    # plot tokenized data
    parser.add_argument('--save_folder', type=str, default='plot_gt')
    parser.add_argument('--split', type=str, default='validation')
    parser.add_argument('--scenario_id', type=str, default=None)
    parser.add_argument('--plot_tokenize', action='store_true')
    # plot generated rollouts
    parser.add_argument('--plot_file', action='store_true')
    parser.add_argument('--folder_path', type=str, default=None)
    parser.add_argument('--file_path', type=str, default=None)
    args = parser.parse_args()

    if args.plot_tokenize:

        # BUG FIX: --scenario_id was parsed but ignored (the id was always
        # hard-coded). Honor the flag, falling back to the previous default.
        scenario_id = args.scenario_id or "74ad7b76d5906d39"
        # scenario_id = "1d60300bc06f4801"
        data_path = os.path.join(args.data_path, args.split, f"{scenario_id}.pkl")
        data = pickle.load(open(data_path, "rb"))
        data['tfrecord_path'] = os.path.join(args.tfrecord_dir, f'{scenario_id}.tfrecords')
        CONSOLE.log(f"Loaded scenario {scenario_id}")

        save_path = os.path.join(args.data_path, args.save_folder, args.split)
        os.makedirs(save_path, exist_ok=True)

        plot_tokenize(data, save_path)

    if args.plot_file:

        plot_file(args.data_path, folder=args.folder_path, files=args.file_path)
backups/environment.yml ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: traj
2
+ channels:
3
+ - conda-forge
4
+ - defaults
5
+ dependencies:
6
+ - _libgcc_mutex=0.1=conda_forge
7
+ - _openmp_mutex=4.5=2_gnu
8
+ - ca-certificates=2025.1.31=hbcca054_0
9
+ - ld_impl_linux-64=2.38=h1181459_1
10
+ - libffi=3.4.4=h6a678d5_1
11
+ - libgcc=14.2.0=h77fa898_1
12
+ - libgcc-ng=14.2.0=h69a702a_1
13
+ - libgomp=14.2.0=h77fa898_1
14
+ - libstdcxx-ng=11.2.0=h1234567_1
15
+ - ncdu=1.16=h0f457ee_0
16
+ - ncurses=6.4=h6a678d5_0
17
+ - openssl=3.4.0=h7b32b05_1
18
+ - pip=24.2=py39h06a4308_0
19
+ - python=3.9.19=h955ad1f_1
20
+ - readline=8.2=h5eee18b_0
21
+ - sqlite=3.45.3=h5eee18b_0
22
+ - tk=8.6.14=h39e8969_0
23
+ - wheel=0.43.0=py39h06a4308_0
24
+ - xz=5.4.6=h5eee18b_1
25
+ - zlib=1.2.13=h5eee18b_1
26
+ - pip:
27
+ - absl-py==1.4.0
28
+ - addict==2.4.0
29
+ - aiohappyeyeballs==2.4.0
30
+ - aiohttp==3.10.5
31
+ - aiosignal==1.3.1
32
+ - anyio==4.4.0
33
+ - appdirs==1.4.4
34
+ - argon2-cffi==23.1.0
35
+ - argon2-cffi-bindings==21.2.0
36
+ - array-record==0.5.1
37
+ - arrow==1.3.0
38
+ - asttokens==2.4.1
39
+ - astunparse==1.6.3
40
+ - async-lru==2.0.4
41
+ - async-timeout==4.0.3
42
+ - attrs==24.2.0
43
+ - av==12.3.0
44
+ - babel==2.16.0
45
+ - beautifulsoup4==4.12.3
46
+ - bidict==0.23.1
47
+ - bleach==6.1.0
48
+ - blinker==1.8.2
49
+ - cachetools==5.5.0
50
+ - certifi==2024.7.4
51
+ - cffi==1.17.0
52
+ - chardet==5.2.0
53
+ - charset-normalizer==3.3.2
54
+ - click==8.1.7
55
+ - cloudpickle==3.0.0
56
+ - colorlog==6.8.2
57
+ - comet-ml==3.45.0
58
+ - comm==0.2.2
59
+ - configargparse==1.7
60
+ - configobj==5.0.8
61
+ - contourpy==1.3.0
62
+ - cryptography==43.0.0
63
+ - cycler==0.12.1
64
+ - dacite==1.8.1
65
+ - dash==2.17.1
66
+ - dash-core-components==2.0.0
67
+ - dash-html-components==2.0.0
68
+ - dash-table==5.0.0
69
+ - dask==2023.3.1
70
+ - dataclass-array==1.5.1
71
+ - debugpy==1.8.5
72
+ - decorator==5.1.1
73
+ - defusedxml==0.7.1
74
+ - descartes==1.1.0
75
+ - dm-tree==0.1.8
76
+ - docker-pycreds==0.4.0
77
+ - docstring-parser==0.16
78
+ - dulwich==0.22.1
79
+ - easydict==1.13
80
+ - einops==0.8.0
81
+ - einsum==0.3.0
82
+ - embreex==2.17.7.post5
83
+ - etils==1.5.2
84
+ - eval-type-backport==0.2.0
85
+ - everett==3.1.0
86
+ - exceptiongroup==1.2.2
87
+ - executing==2.0.1
88
+ - fastjsonschema==2.20.0
89
+ - filelock==3.15.4
90
+ - fire==0.6.0
91
+ - flask==3.0.3
92
+ - flatbuffers==24.3.25
93
+ - fonttools==4.53.1
94
+ - fqdn==1.5.1
95
+ - frozenlist==1.4.1
96
+ - fsspec==2024.6.1
97
+ - gast==0.4.0
98
+ - gdown==5.2.0
99
+ - gitdb==4.0.11
100
+ - gitpython==3.1.43
101
+ - google-auth==2.16.2
102
+ - google-auth-oauthlib==1.0.0
103
+ - google-pasta==0.2.0
104
+ - grpcio==1.66.1
105
+ - h11==0.14.0
106
+ - h5py==3.11.0
107
+ - httpcore==1.0.5
108
+ - httpx==0.27.2
109
+ - idna==3.8
110
+ - imageio==2.35.1
111
+ - immutabledict==2.2.0
112
+ - importlib-metadata==8.4.0
113
+ - importlib-resources==6.4.4
114
+ - ipykernel==6.29.5
115
+ - ipython==8.18.1
116
+ - ipywidgets==8.1.5
117
+ - isoduration==20.11.0
118
+ - itsdangerous==2.2.0
119
+ - jax==0.4.30
120
+ - jaxlib==0.4.30
121
+ - jaxtyping==0.2.33
122
+ - jedi==0.19.1
123
+ - jinja2==3.1.4
124
+ - joblib==1.4.2
125
+ - json5==0.9.25
126
+ - jsonpointer==3.0.0
127
+ - jsonschema==4.23.0
128
+ - jsonschema-specifications==2023.12.1
129
+ - jupyter-client==8.6.2
130
+ - jupyter-core==5.7.2
131
+ - jupyter-events==0.10.0
132
+ - jupyter-lsp==2.2.5
133
+ - jupyter-server==2.14.2
134
+ - jupyter-server-terminals==0.5.3
135
+ - jupyterlab==4.2.5
136
+ - jupyterlab-pygments==0.3.0
137
+ - jupyterlab-server==2.27.3
138
+ - jupyterlab-widgets==3.0.13
139
+ - keras==2.12.0
140
+ - kiwisolver==1.4.5
141
+ - lark==1.2.2
142
+ - lazy-loader==0.4
143
+ - libclang==18.1.1
144
+ - lightning-utilities==0.11.6
145
+ - locket==1.0.0
146
+ - lxml==5.3.0
147
+ - manifold3d==2.5.1
148
+ - mapbox-earcut==1.0.2
149
+ - markdown==3.7
150
+ - markdown-it-py==3.0.0
151
+ - markupsafe==2.1.5
152
+ - matplotlib==3.9.2
153
+ - matplotlib-inline==0.1.7
154
+ - mdurl==0.1.2
155
+ - mediapy==1.2.2
156
+ - mistune==3.0.2
157
+ - ml-dtypes==0.4.0
158
+ - mpmath==1.3.0
159
+ - msgpack==1.0.8
160
+ - msgpack-numpy==0.4.8
161
+ - multidict==6.0.5
162
+ - namex==0.0.8
163
+ - nbclient==0.10.0
164
+ - nbconvert==7.16.4
165
+ - nbformat==5.10.4
166
+ - nerfacc==0.5.2
167
+ - nerfstudio==0.3.4
168
+ - nest-asyncio==1.6.0
169
+ - networkx==3.2.1
170
+ - ninja==1.11.1.1
171
+ - nodeenv==1.9.1
172
+ - notebook-shim==0.2.4
173
+ - numpy==1.23.0
174
+ - nuscenes-devkit==1.1.11
175
+ - nvidia-cublas-cu12==12.1.3.1
176
+ - nvidia-cuda-cupti-cu12==12.1.105
177
+ - nvidia-cuda-nvrtc-cu12==12.1.105
178
+ - nvidia-cuda-runtime-cu12==12.1.105
179
+ - nvidia-cudnn-cu12==9.1.0.70
180
+ - nvidia-cufft-cu12==11.0.2.54
181
+ - nvidia-curand-cu12==10.3.2.106
182
+ - nvidia-cusolver-cu12==11.4.5.107
183
+ - nvidia-cusparse-cu12==12.1.0.106
184
+ - nvidia-nccl-cu12==2.20.5
185
+ - nvidia-nvjitlink-cu12==12.6.20
186
+ - nvidia-nvtx-cu12==12.1.105
187
+ - oauthlib==3.2.2
188
+ - open3d==0.18.0
189
+ - opencv-python==4.6.0.66
190
+ - openexr==1.3.9
191
+ - opt-einsum==3.3.0
192
+ - optree==0.12.1
193
+ - overrides==7.7.0
194
+ - packaging==24.1
195
+ - pandas==1.5.3
196
+ - pandocfilters==1.5.1
197
+ - parso==0.8.4
198
+ - partd==1.4.2
199
+ - pexpect==4.9.0
200
+ - pillow==9.2.0
201
+ - platformdirs==4.2.2
202
+ - plotly==5.13.1
203
+ - prometheus-client==0.20.0
204
+ - promise==2.3
205
+ - prompt-toolkit==3.0.47
206
+ - protobuf==3.20.3
207
+ - psutil==6.0.0
208
+ - ptyprocess==0.7.0
209
+ - pure-eval==0.2.3
210
+ - pyarrow==10.0.0
211
+ - pyasn1==0.6.0
212
+ - pyasn1-modules==0.4.0
213
+ - pycocotools==2.0.8
214
+ - pycollada==0.8
215
+ - pycparser==2.22
216
+ - pygments==2.18.0
217
+ - pyliblzfse==0.4.1
218
+ - pymeshlab==2023.12.post1
219
+ - pyngrok==7.2.0
220
+ - pyparsing==3.1.4
221
+ - pyquaternion==0.9.9
222
+ - pysocks==1.7.1
223
+ - python-box==6.1.0
224
+ - python-dateutil==2.9.0.post0
225
+ - python-engineio==4.9.1
226
+ - python-json-logger==2.0.7
227
+ - python-socketio==5.11.3
228
+ - pytorch-lightning==2.4.0
229
+ - pytz==2024.1
230
+ - pywavelets==1.6.0
231
+ - pyyaml==6.0.2
232
+ - pyzmq==26.2.0
233
+ - rawpy==0.22.0
234
+ - referencing==0.35.1
235
+ - requests==2.32.3
236
+ - requests-oauthlib==2.0.0
237
+ - requests-toolbelt==1.0.0
238
+ - retrying==1.3.4
239
+ - rfc3339-validator==0.1.4
240
+ - rfc3986-validator==0.1.1
241
+ - rich==13.8.0
242
+ - rpds-py==0.20.0
243
+ - rsa==4.9
244
+ - rtree==1.3.0
245
+ - scikit-image==0.20.0
246
+ - scikit-learn==1.2.2
247
+ - scipy==1.9.1
248
+ - seaborn==0.13.2
249
+ - semantic-version==2.10.0
250
+ - send2trash==1.8.3
251
+ - sentry-sdk==2.13.0
252
+ - setproctitle==1.3.3
253
+ - setuptools==67.6.0
254
+ - shapely==1.8.5.post1
255
+ - shtab==1.7.1
256
+ - simple-websocket==1.0.0
257
+ - simplejson==3.19.3
258
+ - six==1.16.0
259
+ - smmap==5.0.1
260
+ - sniffio==1.3.1
261
+ - soupsieve==2.6
262
+ - splines==0.3.0
263
+ - stack-data==0.6.3
264
+ - svg-path==6.3
265
+ - sympy==1.13.2
266
+ - tenacity==9.0.0
267
+ - tensorboard==2.12.3
268
+ - tensorboard-data-server==0.7.2
269
+ - tensorflow==2.12.0
270
+ - tensorflow-addons==0.23.0
271
+ - tensorflow-datasets==4.9.3
272
+ - tensorflow-estimator==2.12.0
273
+ - tensorflow-graphics==2021.12.3
274
+ - tensorflow-io-gcs-filesystem==0.37.1
275
+ - tensorflow-metadata==1.15.0
276
+ - tensorflow-probability==0.19.0
277
+ - termcolor==2.4.0
278
+ - terminado==0.18.1
279
+ - threadpoolctl==3.5.0
280
+ - tifffile==2024.8.28
281
+ - timm==0.6.7
282
+ - tinycss2==1.3.0
283
+ - toml==0.10.2
284
+ - tomli==2.0.1
285
+ - toolz==0.12.1
286
+ - torch==2.4.0
287
+ - torch-cluster==1.6.3+pt24cu121
288
+ - torch-fidelity==0.3.0
289
+ - torch-geometric==2.5.3
290
+ - torch-scatter==2.1.2+pt24cu121
291
+ - torch-sparse==0.6.18+pt24cu121
292
+ - torchmetrics==1.4.1
293
+ - torchvision==0.19.0
294
+ - tornado==6.4.1
295
+ - tqdm==4.66.5
296
+ - traitlets==5.14.3
297
+ - trimesh==4.4.7
298
+ - triton==3.0.0
299
+ - typeguard==2.13.3
300
+ - types-python-dateutil==2.9.0.20240821
301
+ - typing-extensions==4.12.2
302
+ - tyro==0.8.10
303
+ - tzdata==2024.1
304
+ - uri-template==1.3.0
305
+ - urllib3==2.2.2
306
+ - vhacdx==0.0.8.post1
307
+ - viser==0.1.3
308
+ - visu3d==1.5.1
309
+ - wandb==0.17.8
310
+ - waymo-open-dataset-tf-2-12-0==1.6.4
311
+ - wcwidth==0.2.13
312
+ - webcolors==24.8.0
313
+ - webencodings==0.5.1
314
+ - websocket-client==1.8.0
315
+ - websockets==13.0.1
316
+ - werkzeug==3.0.4
317
+ - widgetsnbextension==4.0.13
318
+ - wrapt==1.14.1
319
+ - wsproto==1.2.0
320
+ - wurlitzer==3.1.1
321
+ - xatlas==0.0.9
322
+ - xxhash==3.5.0
323
+ - yarl==1.9.11
324
+ - yourdfpy==0.0.56
325
+ - zipp==3.20.1
326
+ prefix: /u/xiuyu/anaconda3/envs/traffic
backups/run.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytorch_lightning as pl
2
+ import os
3
+ import shutil
4
+ import fnmatch
5
+ import torch
6
+ from argparse import ArgumentParser
7
+ from pytorch_lightning.callbacks import LearningRateMonitor
8
+ from pytorch_lightning.callbacks import ModelCheckpoint
9
+ from pytorch_lightning.strategies import DDPStrategy
10
+ from pytorch_lightning.loggers import WandbLogger
11
+
12
+ from dev.utils.func import RankedLogger, load_config_act, CONSOLE
13
+ from dev.datasets.scalable_dataset import MultiDataModule
14
+ from dev.model.smart import SMART
15
+
16
+
17
def backup(source_dir, backup_dir, exclude_patterns=None, include_patterns=None):
    """
    Back up the source directory (code and configs) to a backup directory.

    Args:
        source_dir: Root directory to snapshot.
        backup_dir: Destination directory. If it already exists the backup
            is skipped (assumed to be a previous run's snapshot).
        exclude_patterns: fnmatch patterns of directory basenames to skip.
            Defaults to the project's standard exclusion list.
        include_patterns: fnmatch patterns of file names to copy.
            Defaults to code/config extensions.
    """
    # FIX: previously these pattern lists were free globals defined only
    # inside the __main__ guard, so the function raised NameError when
    # imported standalone. The defaults mirror those globals exactly.
    if exclude_patterns is None:
        exclude_patterns = ['*output*', '*logs', 'wandb', 'data', '*debug*',
                            '*backup*', 'interact_*', '*edge_map*', '__pycache__']
    if include_patterns is None:
        include_patterns = ['*.py', '*.json', '*.yaml', '*.yml', '*.sh']

    if os.path.exists(backup_dir):
        return  # never overwrite an existing snapshot
    os.makedirs(backup_dir, exist_ok=False)

    # Helper: does this path's basename match any exclude pattern?
    def should_exclude(path):
        return any(fnmatch.fnmatch(os.path.basename(path), pattern)
                   for pattern in exclude_patterns)

    # Iterate through the files and directories in source_dir
    for root, dirs, files in os.walk(source_dir):
        # Prune excluded directories in place so os.walk does not descend.
        dirs[:] = [d for d in dirs if not should_exclude(d)]

        # Determine the relative path and destination path
        rel_path = os.path.relpath(root, source_dir)
        dest_dir = os.path.join(backup_dir, rel_path)
        os.makedirs(dest_dir, exist_ok=True)

        # Copy files matching any include pattern (copy2 preserves metadata).
        for file in files:
            if any(fnmatch.fnmatch(file, pattern) for pattern in include_patterns):
                shutil.copy2(os.path.join(root, file), os.path.join(dest_dir, file))

    logger.info(f"Backup completed. Files saved to: {backup_dir}")
49
+
50
+
51
if __name__ == '__main__':
    pl.seed_everything(2024, workers=True)
    torch.set_printoptions(precision=3)

    parser = ArgumentParser()
    parser.add_argument('--config', type=str, default='configs/ours_long_term.yaml')
    parser.add_argument('--pretrain_ckpt', type=str, default=None,
                        help='Path to any pretrained model, will only load its parameters.'
                        )
    parser.add_argument('--ckpt_path', type=str, default=None,
                        help='Path to any trained model, will load all the states.'
                        )
    parser.add_argument('--save_ckpt_path', type=str, default='output/debug',
                        help='Path to save the checkpoints in training mode'
                        )
    parser.add_argument('--save_path', type=str, default=None,
                        help='Path to save the inference results in validation and test mode.'
                        )
    parser.add_argument('--wandb', action='store_true',
                        help='Whether to use wandb logger in training.'
                        )
    parser.add_argument('--devices', type=int, default=1)
    parser.add_argument('--train', action='store_true')
    parser.add_argument('--validate', action='store_true')
    parser.add_argument('--test', action='store_true')
    parser.add_argument('--plot_rollouts', action='store_true')
    args = parser.parse_args()

    # Exactly one action flag must be supplied.
    if not (args.train or args.validate or args.test or args.plot_rollouts):
        raise RuntimeError("Got invalid action, should be one of ['train', 'validate', 'test', 'plot_rollouts']")

    # ! setup logger
    logger = RankedLogger(__name__, rank_zero_only=True)

    # ! backup codes
    exclude_patterns = ['*output*', '*logs', 'wandb', 'data', '*debug*', '*backup*', 'interact_*', '*edge_map*', '__pycache__']
    include_patterns = ['*.py', '*.json', '*.yaml', '*.yml', '*.sh']
    backup(os.getcwd(), os.path.join(args.save_ckpt_path, 'backups'))

    config = load_config_act(args.config)

    wandb_logger = None
    if args.wandb and not int(os.getenv('DEBUG', 0)):
        # squeue -O username,state,nodelist,gres,minmemory,numcpus,name
        wandb_logger = WandbLogger(project='simagent')

    trainer_config = config.Trainer
    max_epochs = trainer_config.max_epochs

    # ! setup datamodule and model
    datamodule = MultiDataModule(**vars(config.Dataset), logger=logger)
    model = SMART(config.Model, save_path=args.save_ckpt_path, logger=logger, max_epochs=max_epochs)
    if args.pretrain_ckpt:
        model.load_state_from_file(filename=args.pretrain_ckpt)
    strategy = DDPStrategy(find_unused_parameters=True, gradient_as_bucket_view=True)
    logger.info(f'Build model: {model.__class__.__name__} datamodule: {datamodule.__class__.__name__}')

    # ! checkpoint configuration
    every_n_epochs = 1
    if int(os.getenv('OVERFIT', 0)):
        max_epochs = trainer_config.overfit_epochs
        every_n_epochs = 100

    if int(os.getenv('CHECK_INPUTS', 0)):
        max_epochs = 1

    check_val_every_n_epoch = 1  # save checkpoints for each epoch
    model_checkpoint = ModelCheckpoint(dirpath=args.save_ckpt_path,
                                       filename='{epoch:02d}',
                                       save_top_k=5,
                                       monitor='epoch',
                                       mode='max',
                                       save_last=True,
                                       every_n_train_steps=1000,
                                       save_on_train_epoch_end=True)

    # ! setup trainer
    lr_monitor = LearningRateMonitor(logging_interval='epoch')
    trainer = pl.Trainer(accelerator=trainer_config.accelerator, devices=args.devices if args.devices is not None else trainer_config.devices,
                         strategy=strategy, logger=wandb_logger,
                         accumulate_grad_batches=trainer_config.accumulate_grad_batches,
                         num_nodes=trainer_config.num_nodes,
                         callbacks=[model_checkpoint, lr_monitor],
                         max_epochs=max_epochs,
                         num_sanity_val_steps=0,
                         check_val_every_n_epoch=check_val_every_n_epoch,
                         log_every_n_steps=1,
                         gradient_clip_val=0.5)
    logger.info(f'Build trainer: {trainer.__class__.__name__}')

    # ! run
    if args.train:

        logger.info('Start training ...')
        trainer.fit(model, datamodule, ckpt_path=args.ckpt_path)

    # NOTE: here both validation and test process use validation split data
    # for validation, we enable the online metric calculation with results dumping
    # for test, we disable it and only dump the inference results.
    else:

        if args.save_path is not None:
            save_path = args.save_path
        else:
            assert args.ckpt_path is not None and os.path.exists(args.ckpt_path), \
                f'Path {args.ckpt_path} not exists!'
            save_path = os.path.join(os.path.dirname(args.ckpt_path), 'validation')
        os.makedirs(save_path, exist_ok=True)
        CONSOLE.log(f'Results will be saved to [yellow]{save_path}[/]')

        model.save_path = save_path

        if not args.ckpt_path:
            CONSOLE.log('[yellow] Warning: no checkpoint will be loaded in validation! [/]')

        if args.validate:

            CONSOLE.log('[on blue] Start validating ... [/]')
            model.set(mode='validation')

        elif args.test:

            # FIX: typo 'Sart' -> 'Start'
            CONSOLE.log('[on blue] Start testing ... [/]')
            model.set(mode='test')

        elif args.plot_rollouts:

            # FIX: typo 'Sart' -> 'Start'
            CONSOLE.log('[on blue] Start generating ... [/]')
            model.set(mode='plot_rollouts')

        trainer.validate(model, datamodule, ckpt_path=args.ckpt_path)
backups/scripts/aggregate_log_metric_features.sh ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Dump all log metric features, then aggregate them, via compute_metrics.py.

export TF_CPP_MIN_LOG_LEVEL='2'
export PYTHONPATH='.'

# dump all features
echo 'Start dump all log features ...'
python dev/metrics/compute_metrics.py --dump_log --no_batch

# Give the filesystem a moment to settle before aggregating.
sleep 20

# aggregate features
echo 'Start aggregate log features ...'
python dev/metrics/compute_metrics.py --aggregate_log

# FIX: the closing quote was missing, which is a shell syntax error at EOF.
echo 'Done!'
backups/scripts/c128.sh ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ #SBATCH --job-name c128 # Job name
4
+ ### Logging
5
+ #SBATCH --output=%j.out # Stdout (%j expands to jobId)
6
+ #SBATCH --error=%j.err # Stderr (%j expands to jobId)
7
+ ### Node info
8
+ #SBATCH --nodes=1 # Single node or multi node
9
+ #SBATCH --time 100:00:00 # Max time (hh:mm:ss)
10
+ #SBATCH --gres=gpu:0 # GPUs per node
11
+ #SBATCH --mem=128G # Recommend 32G per GPU
12
+ #SBATCH --ntasks-per-node=1 # Tasks per node
13
+ #SBATCH --cpus-per-task=64 # Recommend 8 per GPU
backups/scripts/c64.sh ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# SLURM directive-only header: CPU-only allocation, companion of c128.sh.

# FIX: job name said 'c128' in c64.sh (copy-paste inconsistency).
#SBATCH --job-name c64               # Job name
### Logging
#SBATCH --output=%j.out              # Stdout (%j expands to jobId)
#SBATCH --error=%j.err               # Stderr (%j expands to jobId)
### Node info
#SBATCH --nodes=1                    # Single node or multi node
#SBATCH --time 100:00:00             # Max time (hh:mm:ss)
#SBATCH --gres=gpu:0                 # GPUs per node
#SBATCH --mem=128G                   # Recommend 32G per GPU
#SBATCH --ntasks-per-node=1          # Tasks per node
#SBATCH --cpus-per-task=64           # Recommend 8 per GPU
backups/scripts/compute_metrics.sh ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Run metric computation: compute_metrics.sh <num_workers> <sim_dir> [extra args...]

export TORCH_LOGS='0'
export TF_CPP_MIN_LOG_LEVEL='2'
export PYTHONPATH='.'

NUM_WORKERS=$1
SIM_DIR=$2

echo 'Start running ...'
# ${@:3} forwards any additional CLI arguments straight to the python script.
python dev/metrics/compute_metrics.py --compute_metric --num_workers "$NUM_WORKERS" --sim_dir "$SIM_DIR" ${@:3}

# FIX: the closing quote was missing, which is a shell syntax error at EOF.
echo 'Done!'
backups/scripts/data_preprocess.sh ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # env
4
+ source ~/anaconda3/etc/profile.d/conda.sh
5
+ conda config --append envs_dirs ~/.conda/envs
6
+ conda activate traj
7
+
8
+ echo "Starting running..."
9
+
10
+ # multi-GPU training
11
+ cd ~/work/dev6/thirdparty/dev4
12
+ PYTHONPATH='..':$PYTHONPATH python3 data_preprocess.py --split validation
backups/scripts/data_preprocess_loop.sh ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Repeatedly (re)launch the training-split preprocessing job, killing and
# restarting it every 20 minutes — presumably a workaround for a worker
# that stalls or leaks over time (TODO confirm).

RED='\033[0;31m'
NC='\033[0m'

cd /u/xiuyu/work/dev4/

# On Ctrl-C, kill this script's entire process group (including the worker).
trap "echo -e \"${RED}Stopping script...${NC}\"; kill -- -$$" SIGINT

while true; do
    echo -e "${RED}Start running ...${NC}"
    # setsid starts the worker in its own session/process group so the
    # whole group can be killed as one unit below.
    PYTHONPATH='.':$PYTHONPATH setsid python data_preprocess.py --split training &
    PID=$!

    # Let the worker run for 20 minutes.
    sleep 1200

    echo -e "${RED}Sending SIGINT to process group $PID...${NC}"
    # Resolve the worker's process-group id, then signal the whole group.
    PGID=$(ps -o pgid= -p $PID | tail -n 1 | tr -d ' ')
    kill -- -$PGID
    wait $PID

    sleep 10
done
backups/scripts/debug.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Standalone GPU-load / debugging probe: fits a toy embedding + linear head
# to a fixed random binary target, forever.
import torch
from torch import nn
import torch.optim as optim

# Use the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Toy model: a 180x128 embedding table whose rows are pushed through a
# linear head to reproduce a fixed random binary matrix gt (180 x 2048).
embedding = nn.Embedding(180, 128).to(device)
gt = torch.randint(0, 2, (180, 2048)).to(device)
head = nn.Linear(128, 2048).to(device)
# NOTE(review): only the two weight tensors are optimized; head.bias is left
# out of the optimizer — confirm this is intentional for the probe.
optimizer = optim.Adam([embedding.weight, head.weight])

# Endless optimization loop (no stopping criterion) — keeps the device busy.
while True:
    pred = head(embedding.weight).sigmoid()
    loss = nn.MSELoss()(pred, gt.float())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ import torch
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ from tqdm import tqdm
7
+ from argparse import ArgumentParser
8
+ from dev.datasets.preprocess import TokenProcessor
9
+ from dev.transforms.target_builder import WaymoTargetBuilder
10
+
11
+
12
+ colors = [
13
+ ('#1f77b4', '#1a5a8a'), # blue
14
+ ('#2ca02c', '#217721'), # green
15
+ ('#ff7f0e', '#cc660b'), # orange
16
+ ('#9467bd', '#6f4a91'), # purple
17
+ ('#d62728', '#a31d1d'), # red
18
+ ('#000000', '#000000'), # black
19
+ ]
20
+
21
+
22
def draw_map(tokenize_data, token_processor: TokenProcessor, index, posfix):
    """Render per-timestep frames of tokenized agent trajectories over the map
    and assemble them into a GIF.

    NOTE(review): this function also reads the module-level globals `data`
    (for the scenario id) and `args` (for the --smart flag) — confirm those
    are set before calling. It appears to be the implementation invoked as
    `draw_tokenize` in main(); verify the naming.

    Args:
        tokenize_data: tokenized scenario dict with 'map_point' and 'agent' entries.
        token_processor: provides the vehicle trajectory token vocabularies.
        index: indices of the agents to draw.
        posfix: suffix for the output GIF file name.
    """
    print("Drawing raw data ...")
    shift = 5          # one motion token spans 5 raw timesteps
    token_size = 2048  # vocabulary size; ids >= token_size are BOS/EOS/invalid

    traj_token = token_processor.trajectory_token["veh"]
    traj_token_all = token_processor.trajectory_token_all["veh"]

    plt.subplots_adjust(left=0.3, right=0.7, top=0.7, bottom=0.3)
    fig, ax = plt.subplots()
    ax.set_axis_off()

    # NOTE(review): uses global `data`, not the `tokenize_data` parameter.
    scenario_id = data['scenario_id']
    # Draw all map points as a faint black background.
    ax.scatter(tokenize_data["map_point"]["position"][:, 0],
               tokenize_data["map_point"]["position"][:, 1], s=0.2, c='black', edgecolors='none')

    index = np.array(index).astype(np.int32)
    agent_data = tokenize_data["agent"]
    token_index = agent_data["token_idx"][index]
    token_valid_mask = agent_data["agent_valid_mask"][index]

    num_agent, num_token = token_index.shape
    # Look up each token's contour: 4 corner points (x, y) per token, and the
    # 6 intermediate poses per token in the *_all variant.
    tokens = traj_token[token_index.view(-1)].reshape(num_agent, num_token, 4, 2)
    tokens_all = traj_token_all[token_index.view(-1)].reshape(num_agent, num_token, 6, 4, 2)

    position = agent_data['position'][index, :, :2]  # (num_agent, 91, 2)
    heading = agent_data['heading'][index]  # (num_agent, 91)
    # A (0, 0) position is treated as padding/invalid.
    valid_mask = (position[..., 0] != 0) & (position[..., 1] != 0)  # (num_agent, 91)
    # TODO: fix this
    # Expand per-token validity to per-timestep validity (each token covers
    # `shift` timesteps); the SMART variant uses the mask, ours uses the
    # "invalid" token id (token_size + 2).
    if args.smart:
        for shifted_tid in range(token_valid_mask.shape[1]):
            valid_mask[:, shifted_tid * shift : (shifted_tid + 1) * shift] = token_valid_mask[:, shifted_tid : shifted_tid + 1].repeat(1, shift)
    else:
        for shifted_tid in range(token_index.shape[1]):
            valid_mask[:, shifted_tid * shift : (shifted_tid + 1) * shift] = token_index[:, shifted_tid : shifted_tid + 1] != token_size + 2
    # Last valid timestep per agent, found by arg-maxing the reversed mask.
    last_valid_step = valid_mask.shape[1] - 1 - torch.argmax(valid_mask.flip(dims=[1]).long(), dim=1)
    last_valid_step = {int(index[i]): int(last_valid_step[i]) for i in range(len(index))}

    _, token_num, token_contour_dim, feat_dim = tokens.shape
    tokens_src = tokens.reshape(num_agent, token_num * token_contour_dim, feat_dim)
    tokens_all_src = tokens_all.reshape(num_agent, token_num * 6 * token_contour_dim, feat_dim)
    # Pose used to place each token in world frame (updated every token).
    prev_heading = heading[:, 0]
    prev_pos = position[:, 0]

    fig_paths = []
    # Color class per agent per timestep: 1=valid, 2=BOS, 3=EOS, 4=invalid.
    agent_colors = np.zeros((num_agent, position.shape[1]))
    # Fixed 3m x 3m box footprint for every agent at every step.
    shape = np.zeros((num_agent, position.shape[1], 2)) + 3.
    for tid in tqdm(range(shift, position.shape[1], shift), leave=False, desc="Token ..."):
        # Rotate the local token contours into world frame using the pose of
        # the previous token boundary.
        cos, sin = prev_heading.cos(), prev_heading.sin()
        rot_mat = prev_heading.new_zeros(num_agent, 2, 2)
        rot_mat[:, 0, 0] = cos
        rot_mat[:, 0, 1] = sin
        rot_mat[:, 1, 0] = -sin
        rot_mat[:, 1, 1] = cos
        tokens_world = torch.bmm(torch.from_numpy(tokens_src).float(), rot_mat).reshape(num_agent,
                                                                                        token_num,
                                                                                        token_contour_dim,
                                                                                        feat_dim)
        tokens_all_world = torch.bmm(torch.from_numpy(tokens_all_src).float(), rot_mat).reshape(num_agent,
                                                                                                token_num,
                                                                                                6,
                                                                                                token_contour_dim,
                                                                                                feat_dim)
        tokens_world += prev_pos[:, None, None, :2]
        tokens_all_world += prev_pos[:, None, None, None, :2]
        tokens_select = tokens_world[:, tid // shift - 1]  # (num_agent, token_contour_dim, feat_dim)
        tokens_all_select = tokens_all_world[:, tid // shift - 1]  # (num_agent, 6, token_contour_dim, feat_dim)

        diff_xy = tokens_select[:, 0, :] - tokens_select[:, 3, :]
        # Advance the reference pose from ground truth (the token-derived
        # updates below are intentionally commented out).
        prev_heading = heading[:, tid].clone()
        # prev_heading[valid_mask[:, tid - shift]] = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])[
        #     valid_mask[:, tid - shift]]
        prev_pos = position[:, tid].clone()
        # prev_pos[valid_mask[:, tid - shift]] = tokens_select.mean(dim=1)[valid_mask[:, tid - shift]]

        # NOTE tokens_pos equals to tokens_all_pos[:, -1]
        tokens_pos = tokens_select.mean(dim=1)  # (num_agent, 2)
        tokens_all_pos = tokens_all_select.mean(dim=2)  # (num_agent, 6, 2)

        # colors
        cur_token_index = token_index[:, tid // shift - 1]
        is_bos = cur_token_index == token_size
        is_eos = cur_token_index == token_size + 1
        is_invalid = cur_token_index == token_size + 2
        is_valid = ~is_bos & ~is_eos & ~is_invalid
        agent_colors[is_valid, tid - shift : tid] = 1
        agent_colors[is_bos, tid - shift : tid] = 2
        agent_colors[is_eos, tid - shift : tid] = 3
        agent_colors[is_invalid, tid - shift : tid] = 4

        # One rendered frame per raw timestep inside this token.
        for i in tqdm(range(shift), leave=False, desc="Timestep ..."):
            global_tid = tid - shift + i
            cur_valid_mask = valid_mask[:, tid - shift]  # only when the last tokenized timestep is valid the current shifts trajectory is valid
            xs = tokens_all_pos[cur_valid_mask, i, 0]
            ys = tokens_all_pos[cur_valid_mask, i, 1]
            widths = shape[cur_valid_mask, global_tid, 1]
            lengths = shape[cur_valid_mask, global_tid, 0]
            angles = heading[cur_valid_mask, global_tid]
            cur_agent_colors = agent_colors[cur_valid_mask, global_tid]
            current_index = index[cur_valid_mask]

            drawn_agents = []
            drawn_texts = []
            for x, y, width, length, angle, color_type, id in zip(
                xs, ys, widths, lengths, angles, cur_agent_colors, current_index):
                # NOTE(review): x < 3000 skips agents — looks like a scene-
                # specific crop; confirm against the plotted scenario extent.
                if x < 3000: continue
                agent = plt.Rectangle((x, y), width, length, # angle=((angle + np.pi / 2) / np.pi * 360) % 360,
                                      linewidth=0.2,
                                      facecolor=colors[int(color_type) - 1][0],
                                      edgecolor=colors[int(color_type) - 1][1])
                ax.add_patch(agent)
                text = plt.text(x-4, y-4, f"{str(id)}:{str(global_tid)}", fontdict={'family': 'serif', 'size': 3, 'color': 'red'})

                # Patches are removed after saving, except at each agent's
                # final valid step (left on the canvas permanently).
                if global_tid != last_valid_step[id]:
                    drawn_agents.append(agent)
                    drawn_texts.append(text)

                # draw timestep to be tokenized
                if global_tid % shift == 0:
                    tokenize_agent = plt.Rectangle((x, y), width, length, # angle=((angle + np.pi / 2) / np.pi * 360) % 360,
                                                   linewidth=0.2, fill=False,
                                                   edgecolor=colors[int(color_type) - 1][1])
                    ax.add_patch(tokenize_agent)

            plt.gca().set_aspect('equal', adjustable='box')

            fig_path = f"debug/tokenize/steps/{scenario_id}_{global_tid}.png"
            plt.savefig(fig_path, dpi=600, bbox_inches="tight")
            fig_paths.append(fig_path)

            # Clear this frame's transient patches before drawing the next.
            for drawn_agent, drawn_text in zip(drawn_agents, drawn_texts):
                drawn_agent.remove()
                drawn_text.remove()

    plt.close()

    # generate gif
    import imageio.v2 as imageio
    images = []
    for fig_path in tqdm(fig_paths, leave=False, desc="Generate gif ..."):
        images.append(imageio.imread(fig_path))
    imageio.mimsave(f"debug/tokenize/{scenario_id}_tokenize_{posfix}.gif", images, duration=0.1)
164
+
165
+
166
def main(data):
    """Drive the tokenization debug visualizations for one scenario.

    NOTE(review): `draw_raw` and `draw_tokenize` are not defined in this
    file — `draw_tokenize` shares the signature of `draw_map` above, which
    is never called; confirm whether `draw_map` was renamed. Also reads the
    module-level `args` for the --smart flag.
    """

    token_size = 2048  # trajectory token vocabulary size

    os.makedirs("debug/tokenize/steps/", exist_ok=True)
    scenario_id = data["scenario_id"]

    # Hand-picked agent indices to visualize for this debug scenario.
    selected_agents_index = [1, 21, 35, 36, 46]

    # raw data
    if not os.path.exists(f"debug/tokenize/{scenario_id}_raw.gif"):
        draw_raw(data, selected_agents_index)

    # tokenization
    token_processor = TokenProcessor(token_size, disable_invalid=args.smart)
    print(f"Loaded token processor with token_size: {token_size}")
    data = token_processor.preprocess(data)

    # tokenzied data
    posfix = "smart" if args.smart else "ours"
    # if not os.path.exists(f"debug/tokenize/{scenario_id}_tokenize_{posfix}.gif"):
    draw_tokenize(data, token_processor, selected_agents_index, posfix)

    # Run the target builder as a final sanity check of the pipeline.
    target_builder = WaymoTargetBuilder(num_historical_steps=11, num_future_steps=80)
    data = target_builder(data)
191
+
192
+
193
if __name__ == "__main__":
    parser = ArgumentParser(description="Testing script parameters")
    # --smart switches the tokenizer to the SMART baseline behaviour
    # (invalid tokens disabled); see TokenProcessor(disable_invalid=...).
    parser.add_argument("--smart", action="store_true")
    parser.add_argument("--data_path", type=str, default="/u/xiuyu/work/dev4/data/waymo_processed/training")
    args = parser.parse_args()

    # Hard-coded debug scenario; edit here to inspect a different one.
    scenario_id = "74ad7b76d5906d39"
    data_path = os.path.join(args.data_path, f"{scenario_id}.pkl")
    data = pickle.load(open(data_path, "rb"))
    print(f"Loaded scenario {scenario_id}")

    main(data)
backups/scripts/g2.sh ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ #SBATCH --job-name g2 # Job name
4
+ ### Logging
5
+ #SBATCH --output=%j.out # Stdout (%j expands to jobId)
6
+ #SBATCH --error=%j.err # Stderr (%j expands to jobId)
7
+ ### Node info
8
+ #SBATCH --nodes=1 # Single node or multi node
9
+ #SBATCH --nodelist=sota-2
10
+ #SBATCH --time 24:00:00 # Max time (hh:mm:ss)
11
+ #SBATCH --gres=gpu:2 # GPUs per node
12
+ #SBATCH --mem=96G # Recommend 32G per GPU
13
+ #SBATCH --ntasks-per-node=1 # Tasks per node
14
+ #SBATCH --cpus-per-task=16 # Recommend 8 per GPU
15
+
16
+ export NCCL_DEBUG=INFO
17
+ export REQUESTS_CA_BUNDLE="/etc/ssl/certs/ca-certificates.crt"
18
+ export HTTPS_PROXY="https://192.168.0.10:443/"
19
+ export https_proxy="https://192.168.0.10:443/"
20
+
21
+ export TEST_VAL_TRAIN=False
22
+ export TEST_VAL_PRED=True
23
+ export WANDB=True
24
+
25
+ sleep 86400
26
+
27
+ cd /u/xiuyu/work/dev4
28
+ PYTHONPATH=".":$PYTHONPATH python3 train.py \
29
+ --devices 2 \
30
+ --config configs/train/train_scalable_with_state.yaml \
31
+ --save_ckpt_path output/seed_1k_pure_seed_150_3_emb_head_3_debug \
32
+ --pretrain_ckpt output/ours_map_pretrain/epoch=31.ckpt
33
+
34
+ PYTHONPATH=".":$PYTHONPATH python val.py \
35
+ --config configs/validation/val_scalable_with_state.yaml \
36
+ --save_path output/seed_debug \
37
+ --pretrain_ckpt output/seed_1k_pure_seed_150_3_emb_head_3/last.ckpt
backups/scripts/g4.sh ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ #SBATCH --job-name g4 # Job name
4
+ ### Logging
5
+ #SBATCH --output=%j.out # Stdout (%j expands to jobId)
6
+ #SBATCH --error=%j.err # Stderr (%j expands to jobId)
7
+ ### Node info
8
+ #SBATCH --nodes=1 # Single node or multi node
9
+ #SBATCH --nodelist=sota-1
10
+ #SBATCH --time 72:00:00 # Max time (hh:mm:ss)
11
+ #SBATCH --gres=gpu:4 # GPUs per node
12
+ #SBATCH --mem=128G # Recommend 32G per GPU
13
+ #SBATCH --ntasks-per-node=1 # Tasks per node
14
+ #SBATCH --cpus-per-task=32 # Recommend 8 per GPU
15
+
16
+ export REQUESTS_CA_BUNDLE="/etc/ssl/certs/ca-certificates.crt"
17
+ export HTTPS_PROXY="https://192.168.0.10:443/"
18
+ export https_proxy="https://192.168.0.10:443/"
19
+
20
+ export TEST_VAL_TRAIN=0
21
+ export TEST_VAL_PRED=1
22
+ export WANDB=1
23
+
24
+ sleep 604800
25
+
26
+ cd /u/xiuyu/work/dev4
27
+ PYTHONPATH=".":$PYTHONPATH python3 train.py \
28
+ --devices 4 \
29
+ --config configs/train/train_scalable_with_state.yaml \
30
+ --save_ckpt_path output/seq_1k_10_150_3_3_encode_occ_separate_offsets \
31
+ --pretrain_ckpt output/pretrain_scalable_map/epoch=31.ckpt
32
+
33
+ PYTHONPATH=".":$PYTHONPATH python val.py \
34
+ --config configs/ours_long_term.yaml \
35
+ --ckpt_path output/seq_5k_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long/epoch=31.ckpt
backups/scripts/g8.sh ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash

#SBATCH --job-name g8               # Job name
### Logging
#SBATCH --output=%j.out             # Stdout (%j expands to jobId)
#SBATCH --error=%j.err              # Stderr (%j expands to jobId)
### Node info
#SBATCH --nodes=1                   # Single node or multi node
#SBATCH --nodelist=sota-6
#SBATCH --time 120:00:00            # Max time (hh:mm:ss)
#SBATCH --gres=gpu:8                # GPUs per node
#SBATCH --mem=256G                  # Recommend 32G per GPU
#SBATCH --ntasks-per-node=1         # Tasks per node
#SBATCH --cpus-per-task=32          # Recommend 8 per GPU

# TLS / proxy setup for the cluster's outbound gateway.
export REQUESTS_CA_BUNDLE="/etc/ssl/certs/ca-certificates.crt"
export HTTPS_PROXY="https://192.168.0.10:443/"
export https_proxy="https://192.168.0.10:443/"

# Experiment switches read by train.py / val.py.
export TEST_VAL_TRAIN=0
export TEST_VAL_PRED=1
export WANDB=1

# BUG FIX: the original script ran `sleep 864000` (10 days) here, exceeding the
# 120h --time limit above — Slurm killed the job before any work was done.
# The sleep has been removed.

# Abort if the project directory is missing instead of running in the wrong cwd.
cd /u/xiuyu/work/dev4 || exit 1

# The commands below run sequentially; each starts only after the previous
# one exits. NOTE(review): the second/third train invocations look like
# leftover experiment variants (one writes to output2/debug) — confirm they
# should all run in one job.
PYTHONPATH=".":$PYTHONPATH python3 train.py \
    --devices 8 \
    --config configs/train/train_scalable_long_term.yaml \
    --save_ckpt_path output/seq_5k_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long \
    --pretrain_ckpt output/pretrain_scalable_map/epoch=31.ckpt

PYTHONPATH=".":$PYTHONPATH python3 train.py \
    --devices 8 \
    --config configs/ours_long_term.yaml \
    --save_ckpt_path output2/seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid

PYTHONPATH=".":$PYTHONPATH python3 train.py \
    --config configs/ours_long_term.yaml \
    --save_ckpt_path output2/debug

PYTHONPATH=".":$PYTHONPATH python val.py \
    --config configs/ours_long_term.yaml \
    --ckpt_path output2/bug_seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long/last.ckpt
backups/scripts/hf_model.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ from huggingface_hub import upload_folder, upload_file, hf_hub_download
4
+ from rich.console import Console
5
+ from rich.panel import Panel
6
+ from rich import box, style
7
+ from rich.table import Table
8
+
9
+ CONSOLE = Console(width=120)
10
+
11
+
12
def upload():
    """Upload a folder and/or a single file to the Hugging Face Hub.

    Reads the module-level ``args`` (``repo_id``, ``folder_path``,
    ``file_path``), ``token`` and ``ignore_patterns``. Prints a rich summary
    panel on success; logs and re-raises any exception from
    ``huggingface_hub`` on failure.
    """
    if args.folder_path:
        try:
            # token=None is the huggingface_hub default (falls back to the
            # locally cached credentials), so one call covers both the
            # explicit-token and implicit-token cases — no need for the
            # duplicated if/else branches.
            upload_folder(repo_id=args.repo_id, folder_path=args.folder_path,
                          ignore_patterns=ignore_patterns, token=token)
            table = Table(title=None, show_header=False, box=box.MINIMAL, title_style=style.Style(bold=True))
            table.add_row(f"Model id {args.repo_id}", str(args.folder_path))
            CONSOLE.print(Panel(table, title="[bold][green]:tada: Upload completed DO NOT forget specify the model id in methods! :tada:[/bold]", expand=False))
        except Exception as e:
            CONSOLE.print(f"[bold][yellow]:tada: Upload failed due to {e}.")
            raise e

    if args.file_path:
        try:
            # Same consolidation as above: passing token=None equals omitting it.
            upload_file(
                path_or_fileobj=args.file_path,
                path_in_repo=os.path.basename(args.file_path),
                repo_id=args.repo_id,
                repo_type='model',
                token=token,
            )
            table = Table(title=None, show_header=False, box=box.MINIMAL, title_style=style.Style(bold=True))
            table.add_row(f"Model id {args.repo_id}", str(args.file_path))
            CONSOLE.print(Panel(table, title="[bold][green]:tada: Upload completed! :tada:[/bold]", expand=False))
        except Exception as e:
            CONSOLE.print(f"[bold][yellow]:tada: Upload failed due to {e}.")
            raise e
54
+
55
+
56
def download():
    """Download ``args.file_path`` from the Hub repo ``args.repo_id``.

    Optionally copies the cached file into ``args.save_path`` (preserving the
    in-repo relative path). Returns the path of the file in the local
    huggingface cache. Logs and re-raises any exception on failure.
    """
    import shutil  # hoisted from the copy branch so the whole body shares it

    try:
        # token=None is the library default (uses cached credentials), so a
        # single call replaces the duplicated if/else on `token`.
        ckpt_path = hf_hub_download(
            repo_id=args.repo_id,
            filename=args.file_path,
            token=token,
        )
        table = Table(title=None, show_header=False, box=box.MINIMAL, title_style=style.Style(bold=True))
        table.add_row(f"Model id {args.repo_id}", str(args.file_path))
        CONSOLE.print(Panel(table, title=f"[bold][green]:tada: Download completed to {ckpt_path}! :tada:[/bold]", expand=False))

        if args.save_path is not None:
            target = os.path.join(args.save_path, args.file_path)
            # BUG FIX: file_path may contain subdirectories (e.g. "ckpts/last.ckpt");
            # the original only created save_path itself, so the copy failed for
            # nested paths. Create the full parent directory of the target.
            os.makedirs(os.path.dirname(target) or ".", exist_ok=True)
            shutil.copy(ckpt_path, target)

    except Exception as e:
        CONSOLE.print(f"[bold][yellow]:tada: Download failed due to {e}.")
        raise e

    return ckpt_path
84
+
85
+
86
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--repo_id", type=str, default=None, required=True)
    parser.add_argument("--upload", action="store_true")
    parser.add_argument("--download", action="store_true")
    parser.add_argument("--folder_path", type=str, default=None, required=False)
    parser.add_argument("--file_path", type=str, default=None, required=False)
    parser.add_argument("--save_path", type=str, default=None, required=False)
    parser.add_argument("--token", type=str, default=None, required=False)
    args = parser.parse_args()

    # Explicit --token wins; otherwise fall back to the `hf_token` env var
    # (None lets huggingface_hub use its locally cached credentials).
    token = args.token or os.getenv("hf_token", None)
    # Training-state artifacts that should never be pushed to the Hub.
    ignore_patterns = ["**/optimizer.bin", "**/random_states*", "**/scaler.pt", "**/scheduler.bin"]

    if not (args.folder_path or args.file_path):
        raise RuntimeError('Choose either folder path or file path please!')
    # BUG FIX: --download addresses a single file inside the repo; without
    # this check a folder-only invocation passed validation and then crashed
    # inside hf_hub_download(filename=None).
    if args.download and not args.file_path:
        raise RuntimeError('Download requires --file_path!')

    if len(args.repo_id.split('/')) != 2:
        raise RuntimeError(f'Invalid repo_id: {args.repo_id}, please use in [use-id]/[repo-name] format')
    CONSOLE.log(f"Use repo: [bold][yellow] {args.repo_id}")

    if args.upload:
        upload()

    if args.download:
        download()
backups/scripts/pretrain_map.sh ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash

#SBATCH --job-name YOUR_JOB_NAME    # Job name
### Logging
#SBATCH --output=job_out/%j.out     # Stdout (%j expands to jobId)
#SBATCH --error=job_out/%j.err      # Stderr (%j expands to jobId)
### Node info
#SBATCH --nodes=1                   # Single node or multi node
#SBATCH --nodelist=sota-6
#SBATCH --time 20:00:00             # Max time (hh:mm:ss)
#SBATCH --gres=gpu:4                # GPUs per node
#SBATCH --mem=256G                  # Recommend 32G per GPU
#SBATCH --ntasks-per-node=4         # Tasks per node
#SBATCH --cpus-per-task=256         # Recommend 8 per GPU
### Whatever your job needs to do

# BUG FIX: `mkdir -p job_out` originally appeared ABOVE the #SBATCH block.
# sbatch stops parsing directives at the first non-comment command, so every
# directive above was silently ignored. Directives must stay at the very top.
# Note: Slurm opens the --output/--error files before this script body runs,
# so job_out must already exist at *submission* time; this mkdir is only a
# safety net for subsequent runs.
mkdir -p job_out

# NOTE(review): 4 tasks x 256 cpus-per-task = 1024 CPUs requested, which
# contradicts the "8 per GPU" guidance above — confirm against the node size.

# TLS / proxy setup for the cluster's outbound gateway.
export REQUESTS_CA_BUNDLE="/etc/ssl/certs/ca-certificates.crt"
export HTTPS_PROXY="https://192.168.0.10:443/"
export https_proxy="https://192.168.0.10:443/"

export TEST_VAL_PRED=True
export WANDB=True

# Abort if the project directory is missing instead of running in the wrong cwd.
cd /u/xiuyu/work/dev4 || exit 1
PYTHONPATH=".":$PYTHONPATH python3 train.py --config configs/train/pretrain_scalable_map.yaml --save_ckpt_path output/ours_map_pretrain
backups/scripts/run_eval.sh ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#! /bin/bash

# env
export REQUESTS_CA_BUNDLE="/etc/ssl/certs/ca-certificates.crt"
export HTTPS_PROXY="https://192.168.0.10:443/"
export https_proxy="https://192.168.0.10:443/"

export WANDB=1

# args
# Fail fast with a usage hint instead of passing an empty --devices to run.py.
if [ -z "$1" ]; then
    echo "usage: $0 DEVICES [extra run.py args...]" >&2
    exit 1
fi
DEVICES=$1
CONFIG='configs/ours_long_term.yaml'
# CKPT_PATH='output/scalable_smart_long/last.ckpt'
CKPT_PATH='output2/seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long/last.ckpt'

# run
# Variables are quoted to prevent word splitting; "${@:2}" forwards any extra
# arguments to run.py with their original quoting intact.
PYTHONPATH=".":$PYTHONPATH python3 run.py \
    --devices "$DEVICES" \
    --config "$CONFIG" \
    --ckpt_path "$CKPT_PATH" "${@:2}"
backups/scripts/run_train.sh ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#! /bin/bash

# env
export REQUESTS_CA_BUNDLE="/etc/ssl/certs/ca-certificates.crt"
export HTTPS_PROXY="https://192.168.0.10:443/"
export https_proxy="https://192.168.0.10:443/"

export WANDB=1

# args
# Fail fast with a usage hint instead of passing an empty --devices to run.py.
if [ -z "$1" ]; then
    echo "usage: $0 DEVICES" >&2
    exit 1
fi
DEVICES=$1
CONFIG='configs/ours_long_term.yaml'
SAVE_CKPT_PATH='output/scalable_smart_long'

# run
# Variables are quoted to prevent word splitting if paths ever contain spaces.
PYTHONPATH=".":$PYTHONPATH python3 run.py \
    --train \
    --devices "$DEVICES" \
    --config "$CONFIG" \
    --save_ckpt_path "$SAVE_CKPT_PATH"