Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/configs/experiments/ablate_grid_tokens.yaml +106 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/configs/ours_long_term.yaml +105 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/configs/ours_standard.yaml +101 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/configs/ours_standard_decode_occ.yaml +100 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/configs/pretrain_scalable_map.yaml +97 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/configs/smart.yaml +70 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/data_preprocess.py +916 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/datasets/preprocess.py +761 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/datasets/scalable_dataset.py +276 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/metrics/box_utils.py +113 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/metrics/compute_metrics.py +1812 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/metrics/geometry_utils.py +137 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/metrics/interact_features.py +220 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/metrics/map_features.py +349 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/metrics/placement_features.py +48 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/metrics/protos/long_metrics_pb2.py +648 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/metrics/protos/map_pb2.py +1070 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/metrics/protos/scenario_pb2.py +454 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/metrics/trajectory_features.py +52 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/metrics/val_close_long_metrics.json +24 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/model/smart.py +1100 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/modules/agent_decoder.py +0 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/modules/attr_tokenizer.py +109 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/modules/debug.py +1439 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/modules/layers.py +371 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/modules/map_decoder.py +130 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/modules/occ_decoder.py +927 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/modules/smart_decoder.py +137 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/utils/cluster_reader.py +45 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/utils/func.py +260 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/utils/graph.py +89 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/utils/metrics.py +692 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/utils/visualization.py +1145 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/environment.yml +326 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/run.py +181 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/scripts/aggregate_log_metric_features.sh +16 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/scripts/c128.sh +13 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/scripts/c64.sh +13 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/scripts/compute_metrics.sh +13 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/scripts/data_preprocess.sh +12 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/scripts/data_preprocess_loop.sh +23 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/scripts/debug.py +17 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/scripts/debug_map.py +204 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/scripts/g2.sh +37 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/scripts/g4.sh +35 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/scripts/g8.sh +44 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/scripts/hf_model.py +111 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/scripts/pretrain_map.sh +27 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/scripts/run_eval.sh +20 -0
- seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/scripts/run_train.sh +20 -0
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/configs/experiments/ablate_grid_tokens.yaml
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Config format schema number, the yaml support to valid case source from different dataset
|
2 |
+
time_info: &time_info
|
3 |
+
num_historical_steps: 11
|
4 |
+
num_future_steps: 80
|
5 |
+
use_intention: True
|
6 |
+
token_size: 2048
|
7 |
+
predict_motion: True
|
8 |
+
predict_state: True
|
9 |
+
predict_map: True
|
10 |
+
predict_occ: True
|
11 |
+
state_token:
|
12 |
+
invalid: 0
|
13 |
+
valid: 1
|
14 |
+
enter: 2
|
15 |
+
exit: 3
|
16 |
+
pl2seed_radius: 75.
|
17 |
+
disable_grid_token: True
|
18 |
+
grid_range: 150. # 2 times of pl2seed_radius
|
19 |
+
grid_interval: 3.
|
20 |
+
angle_interval: 3.
|
21 |
+
seed_size: 1
|
22 |
+
buffer_size: 128
|
23 |
+
max_num: 32
|
24 |
+
|
25 |
+
Dataset:
|
26 |
+
root:
|
27 |
+
train_batch_size: 1
|
28 |
+
val_batch_size: 1
|
29 |
+
test_batch_size: 1
|
30 |
+
shuffle: True
|
31 |
+
num_workers: 1
|
32 |
+
pin_memory: True
|
33 |
+
persistent_workers: True
|
34 |
+
train_raw_dir: 'data/waymo_processed/training'
|
35 |
+
val_raw_dir: 'data/waymo_processed/validation'
|
36 |
+
test_raw_dir: 'data/waymo_processed/validation'
|
37 |
+
val_tfrecords_splitted: data/waymo_processed/training/validation_tfrecords_splitted
|
38 |
+
transform: WaymoTargetBuilder
|
39 |
+
train_processed_dir:
|
40 |
+
val_processed_dir:
|
41 |
+
test_processed_dir:
|
42 |
+
dataset: 'scalable'
|
43 |
+
<<: *time_info
|
44 |
+
|
45 |
+
Trainer:
|
46 |
+
strategy: ddp_find_unused_parameters_false
|
47 |
+
accelerator: 'gpu'
|
48 |
+
devices: 1
|
49 |
+
max_epochs: 32
|
50 |
+
save_ckpt_path:
|
51 |
+
num_nodes: 1
|
52 |
+
mode:
|
53 |
+
ckpt_path:
|
54 |
+
precision: 32
|
55 |
+
accumulate_grad_batches: 1
|
56 |
+
overfit_epochs: 6000
|
57 |
+
|
58 |
+
Model:
|
59 |
+
predictor: 'smart'
|
60 |
+
decoder_type: 'agent_decoder' # choose from ['agent_decoder', 'occ_decoder']
|
61 |
+
dataset: 'waymo'
|
62 |
+
input_dim: 2
|
63 |
+
hidden_dim: 128
|
64 |
+
output_dim: 2
|
65 |
+
output_head: False
|
66 |
+
num_heads: 8
|
67 |
+
<<: *time_info
|
68 |
+
head_dim: 16
|
69 |
+
dropout: 0.1
|
70 |
+
num_freq_bands: 64
|
71 |
+
lr: 0.0005
|
72 |
+
warmup_steps: 0
|
73 |
+
total_steps: 32
|
74 |
+
predict_map_token: False
|
75 |
+
num_recurrent_steps_val: 300
|
76 |
+
val_open_loop: False
|
77 |
+
val_close_loop: True
|
78 |
+
val_insert: False
|
79 |
+
n_rollout_close_val: 1
|
80 |
+
decoder:
|
81 |
+
<<: *time_info
|
82 |
+
num_map_layers: 3
|
83 |
+
num_agent_layers: 6
|
84 |
+
a2a_radius: 60
|
85 |
+
pl2pl_radius: 10
|
86 |
+
pl2a_radius: 30
|
87 |
+
a2sa_radius: 10
|
88 |
+
pl2sa_radius: 10
|
89 |
+
time_span: 60
|
90 |
+
loss_weight:
|
91 |
+
token_cls_loss: 1
|
92 |
+
map_token_loss: 1
|
93 |
+
state_cls_loss: 10
|
94 |
+
type_cls_loss: 5
|
95 |
+
pos_cls_loss: 1
|
96 |
+
head_cls_loss: 1
|
97 |
+
offset_reg_loss: 5
|
98 |
+
shape_reg_loss: .2
|
99 |
+
pos_reg_loss: 10
|
100 |
+
state_weight: [0.1, 0.1, 0.8] # invalid, valid, exit
|
101 |
+
seed_state_weight: [0.1, 0.9] # invalid, enter
|
102 |
+
seed_type_weight: [0.8, 0.1, 0.1]
|
103 |
+
agent_occ_pos_weight: 100
|
104 |
+
pt_occ_pos_weight: 5
|
105 |
+
agent_occ_loss: 10
|
106 |
+
pt_occ_loss: 10
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/configs/ours_long_term.yaml
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Config format schema number, the yaml support to valid case source from different dataset
|
2 |
+
time_info: &time_info
|
3 |
+
num_historical_steps: 11
|
4 |
+
num_future_steps: 80
|
5 |
+
use_intention: True
|
6 |
+
token_size: 2048
|
7 |
+
predict_motion: True
|
8 |
+
predict_state: True
|
9 |
+
predict_map: True
|
10 |
+
predict_occ: True
|
11 |
+
state_token:
|
12 |
+
invalid: 0
|
13 |
+
valid: 1
|
14 |
+
enter: 2
|
15 |
+
exit: 3
|
16 |
+
pl2seed_radius: 75.
|
17 |
+
grid_range: 150. # 2 times of pl2seed_radius
|
18 |
+
grid_interval: 3.
|
19 |
+
angle_interval: 3.
|
20 |
+
seed_size: 1
|
21 |
+
buffer_size: 128
|
22 |
+
max_num: 32
|
23 |
+
|
24 |
+
Dataset:
|
25 |
+
root:
|
26 |
+
train_batch_size: 1
|
27 |
+
val_batch_size: 1
|
28 |
+
test_batch_size: 1
|
29 |
+
shuffle: True
|
30 |
+
num_workers: 1
|
31 |
+
pin_memory: True
|
32 |
+
persistent_workers: True
|
33 |
+
train_raw_dir: 'data/waymo_processed/training'
|
34 |
+
val_raw_dir: 'data/waymo_processed/validation'
|
35 |
+
test_raw_dir: 'data/waymo_processed/validation'
|
36 |
+
val_tfrecords_splitted: data/waymo_processed/training/validation_tfrecords_splitted
|
37 |
+
transform: WaymoTargetBuilder
|
38 |
+
train_processed_dir:
|
39 |
+
val_processed_dir:
|
40 |
+
test_processed_dir:
|
41 |
+
dataset: 'scalable'
|
42 |
+
<<: *time_info
|
43 |
+
|
44 |
+
Trainer:
|
45 |
+
strategy: ddp_find_unused_parameters_false
|
46 |
+
accelerator: 'gpu'
|
47 |
+
devices: 1
|
48 |
+
max_epochs: 32
|
49 |
+
save_ckpt_path:
|
50 |
+
num_nodes: 1
|
51 |
+
mode:
|
52 |
+
ckpt_path:
|
53 |
+
precision: 32
|
54 |
+
accumulate_grad_batches: 1
|
55 |
+
overfit_epochs: 6000
|
56 |
+
|
57 |
+
Model:
|
58 |
+
predictor: 'smart'
|
59 |
+
decoder_type: 'agent_decoder' # choose from ['agent_decoder', 'occ_decoder']
|
60 |
+
dataset: 'waymo'
|
61 |
+
input_dim: 2
|
62 |
+
hidden_dim: 128
|
63 |
+
output_dim: 2
|
64 |
+
output_head: False
|
65 |
+
num_heads: 8
|
66 |
+
<<: *time_info
|
67 |
+
head_dim: 16
|
68 |
+
dropout: 0.1
|
69 |
+
num_freq_bands: 64
|
70 |
+
lr: 0.0005
|
71 |
+
warmup_steps: 0
|
72 |
+
total_steps: 32
|
73 |
+
predict_map_token: False
|
74 |
+
num_recurrent_steps_val: 300
|
75 |
+
val_open_loop: False
|
76 |
+
val_close_loop: True
|
77 |
+
val_insert: False
|
78 |
+
n_rollout_close_val: 1
|
79 |
+
decoder:
|
80 |
+
<<: *time_info
|
81 |
+
num_map_layers: 3
|
82 |
+
num_agent_layers: 6
|
83 |
+
a2a_radius: 60
|
84 |
+
pl2pl_radius: 10
|
85 |
+
pl2a_radius: 30
|
86 |
+
a2sa_radius: 10
|
87 |
+
pl2sa_radius: 10
|
88 |
+
time_span: 60
|
89 |
+
loss_weight:
|
90 |
+
token_cls_loss: 1
|
91 |
+
map_token_loss: 1
|
92 |
+
state_cls_loss: 10
|
93 |
+
type_cls_loss: 5
|
94 |
+
pos_cls_loss: 1
|
95 |
+
head_cls_loss: 1
|
96 |
+
offset_reg_loss: 5
|
97 |
+
shape_reg_loss: .2
|
98 |
+
pos_reg_loss: 10
|
99 |
+
state_weight: [0.1, 0.1, 0.8] # invalid, valid, exit
|
100 |
+
seed_state_weight: [0.1, 0.9] # invalid, enter
|
101 |
+
seed_type_weight: [0.8, 0.1, 0.1]
|
102 |
+
agent_occ_pos_weight: 100
|
103 |
+
pt_occ_pos_weight: 5
|
104 |
+
agent_occ_loss: 10
|
105 |
+
pt_occ_loss: 10
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/configs/ours_standard.yaml
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Config format schema number, the yaml support to valid case source from different dataset
|
2 |
+
time_info: &time_info
|
3 |
+
num_historical_steps: 11
|
4 |
+
num_future_steps: 80
|
5 |
+
use_intention: True
|
6 |
+
token_size: 2048
|
7 |
+
predict_motion: True
|
8 |
+
predict_state: True
|
9 |
+
predict_map: False
|
10 |
+
predict_occ: True
|
11 |
+
state_token:
|
12 |
+
invalid: 0
|
13 |
+
valid: 1
|
14 |
+
enter: 2
|
15 |
+
exit: 3
|
16 |
+
pl2seed_radius: 75.
|
17 |
+
grid_range: 150. # 2 times of pl2seed_radius
|
18 |
+
grid_interval: 3.
|
19 |
+
angle_interval: 3.
|
20 |
+
seed_size: 1
|
21 |
+
buffer_size: 128
|
22 |
+
|
23 |
+
Dataset:
|
24 |
+
root:
|
25 |
+
train_batch_size: 1
|
26 |
+
val_batch_size: 1
|
27 |
+
test_batch_size: 1
|
28 |
+
shuffle: True
|
29 |
+
num_workers: 1
|
30 |
+
pin_memory: True
|
31 |
+
persistent_workers: True
|
32 |
+
train_raw_dir: ["data/waymo_processed/training"]
|
33 |
+
val_raw_dir: ["data/waymo_processed/validation"]
|
34 |
+
test_raw_dir: ["data/waymo_processed/validation"]
|
35 |
+
transform: WaymoTargetBuilder
|
36 |
+
train_processed_dir:
|
37 |
+
val_processed_dir:
|
38 |
+
test_processed_dir:
|
39 |
+
dataset: "scalable"
|
40 |
+
<<: *time_info
|
41 |
+
|
42 |
+
Trainer:
|
43 |
+
strategy: ddp_find_unused_parameters_false
|
44 |
+
accelerator: "gpu"
|
45 |
+
devices: 1
|
46 |
+
max_epochs: 32
|
47 |
+
overfit_epochs: 6000
|
48 |
+
save_ckpt_path:
|
49 |
+
num_nodes: 1
|
50 |
+
mode:
|
51 |
+
ckpt_path:
|
52 |
+
precision: 32
|
53 |
+
accumulate_grad_batches: 1
|
54 |
+
|
55 |
+
Model:
|
56 |
+
predictor: "smart"
|
57 |
+
decoder_type: "agent_decoder" # choose from ['agent_decoder', 'occ_decoder']
|
58 |
+
dataset: "waymo"
|
59 |
+
input_dim: 2
|
60 |
+
hidden_dim: 128
|
61 |
+
output_dim: 2
|
62 |
+
output_head: False
|
63 |
+
num_heads: 8
|
64 |
+
<<: *time_info
|
65 |
+
head_dim: 16
|
66 |
+
dropout: 0.1
|
67 |
+
num_freq_bands: 64
|
68 |
+
lr: 0.0005
|
69 |
+
warmup_steps: 0
|
70 |
+
total_steps: 32
|
71 |
+
predict_map_token: False
|
72 |
+
num_recurrent_steps_val: -1
|
73 |
+
val_open_loop: True
|
74 |
+
val_close_loop: False
|
75 |
+
val_insert: False
|
76 |
+
decoder:
|
77 |
+
<<: *time_info
|
78 |
+
num_map_layers: 3
|
79 |
+
num_agent_layers: 6
|
80 |
+
a2a_radius: 60
|
81 |
+
pl2pl_radius: 10
|
82 |
+
pl2a_radius: 30
|
83 |
+
a2sa_radius: 10
|
84 |
+
pl2sa_radius: 10
|
85 |
+
time_span: 60
|
86 |
+
loss_weight:
|
87 |
+
token_cls_loss: 1
|
88 |
+
map_token_loss: 1
|
89 |
+
state_cls_loss: 10
|
90 |
+
type_cls_loss: 5
|
91 |
+
pos_cls_loss: 1
|
92 |
+
head_cls_loss: 1
|
93 |
+
offset_reg_loss: 5
|
94 |
+
shape_reg_loss: .2
|
95 |
+
state_weight: [0.1, 0.1, 0.8] # invalid, valid, exit
|
96 |
+
seed_state_weight: [0.1, 0.9] # invalid, enter
|
97 |
+
seed_type_weight: [0.8, 0.1, 0.1]
|
98 |
+
agent_occ_pos_weight: 100
|
99 |
+
pt_occ_pos_weight: 5
|
100 |
+
agent_occ_loss: 10
|
101 |
+
pt_occ_loss: 10
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/configs/ours_standard_decode_occ.yaml
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Config format schema number, the yaml support to valid case source from different dataset
|
2 |
+
time_info: &time_info
|
3 |
+
num_historical_steps: 11
|
4 |
+
num_future_steps: 80
|
5 |
+
use_intention: True
|
6 |
+
token_size: 2048
|
7 |
+
predict_motion: False
|
8 |
+
predict_state: False
|
9 |
+
predict_map: False
|
10 |
+
predict_occ: True
|
11 |
+
state_token:
|
12 |
+
invalid: 0
|
13 |
+
valid: 1
|
14 |
+
enter: 2
|
15 |
+
exit: 3
|
16 |
+
pl2seed_radius: 75.
|
17 |
+
grid_range: 150. # 2 times of pl2seed_radius
|
18 |
+
grid_interval: 3.
|
19 |
+
angle_interval: 3.
|
20 |
+
seed_size: 1
|
21 |
+
buffer_size: 128
|
22 |
+
|
23 |
+
Dataset:
|
24 |
+
root:
|
25 |
+
train_batch_size: 1
|
26 |
+
val_batch_size: 1
|
27 |
+
test_batch_size: 1
|
28 |
+
shuffle: True
|
29 |
+
num_workers: 1
|
30 |
+
pin_memory: True
|
31 |
+
persistent_workers: True
|
32 |
+
train_raw_dir: ["data/waymo_processed/training"]
|
33 |
+
val_raw_dir: ["data/waymo_processed/validation"]
|
34 |
+
test_raw_dir: ["data/waymo_processed/validation"]
|
35 |
+
transform: WaymoTargetBuilder
|
36 |
+
train_processed_dir:
|
37 |
+
val_processed_dir:
|
38 |
+
test_processed_dir:
|
39 |
+
dataset: "scalable"
|
40 |
+
<<: *time_info
|
41 |
+
|
42 |
+
Trainer:
|
43 |
+
strategy: ddp_find_unused_parameters_false
|
44 |
+
accelerator: "gpu"
|
45 |
+
devices: 1
|
46 |
+
max_epochs: 32
|
47 |
+
overfit_epochs: 6000
|
48 |
+
save_ckpt_path:
|
49 |
+
num_nodes: 1
|
50 |
+
mode:
|
51 |
+
ckpt_path:
|
52 |
+
precision: 32
|
53 |
+
accumulate_grad_batches: 1
|
54 |
+
|
55 |
+
Model:
|
56 |
+
predictor: "smart"
|
57 |
+
decoder_type: "occ_decoder" # choose from ['agent_decoder', 'occ_decoder']
|
58 |
+
dataset: "waymo"
|
59 |
+
input_dim: 2
|
60 |
+
hidden_dim: 128
|
61 |
+
output_dim: 2
|
62 |
+
output_head: False
|
63 |
+
num_heads: 8
|
64 |
+
<<: *time_info
|
65 |
+
head_dim: 16
|
66 |
+
dropout: 0.1
|
67 |
+
num_freq_bands: 64
|
68 |
+
lr: 0.0005
|
69 |
+
warmup_steps: 0
|
70 |
+
total_steps: 32
|
71 |
+
predict_map_token: False
|
72 |
+
num_recurrent_steps_val: -1
|
73 |
+
val_open_loop: True
|
74 |
+
val_closed_loop: False
|
75 |
+
decoder:
|
76 |
+
<<: *time_info
|
77 |
+
num_map_layers: 3
|
78 |
+
num_agent_layers: 6
|
79 |
+
a2a_radius: 60
|
80 |
+
pl2pl_radius: 10
|
81 |
+
pl2a_radius: 30
|
82 |
+
a2sa_radius: 10
|
83 |
+
pl2sa_radius: 10
|
84 |
+
time_span: 60
|
85 |
+
loss_weight:
|
86 |
+
token_cls_loss: 1
|
87 |
+
map_token_loss: 1
|
88 |
+
state_cls_loss: 10
|
89 |
+
type_cls_loss: 5
|
90 |
+
pos_cls_loss: 1
|
91 |
+
head_cls_loss: 1
|
92 |
+
offset_reg_loss: 5
|
93 |
+
shape_reg_loss: .2
|
94 |
+
state_weight: [0.1, 0.1, 0.8] # invalid, valid, exit
|
95 |
+
seed_state_weight: [0.1, 0.9] # invalid, enter
|
96 |
+
seed_type_weight: [0.8, 0.1, 0.1]
|
97 |
+
agent_occ_pos_weight: 100
|
98 |
+
pt_occ_pos_weight: 5
|
99 |
+
agent_occ_loss: 10
|
100 |
+
pt_occ_loss: 10
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/configs/pretrain_scalable_map.yaml
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Config format schema number, the yaml support to valid case source from different dataset
|
2 |
+
time_info: &time_info
|
3 |
+
num_historical_steps: 11
|
4 |
+
num_future_steps: 80
|
5 |
+
use_intention: True
|
6 |
+
token_size: 2048
|
7 |
+
predict_motion: False
|
8 |
+
predict_state: False
|
9 |
+
predict_map: True
|
10 |
+
predict_occ: False
|
11 |
+
state_token:
|
12 |
+
invalid: 0
|
13 |
+
valid: 1
|
14 |
+
enter: 2
|
15 |
+
exit: 3
|
16 |
+
pl2seed_radius: 75.
|
17 |
+
grid_range: 150. # 2 times of pl2seed_radius
|
18 |
+
grid_interval: 3.
|
19 |
+
angle_interval: 3.
|
20 |
+
seed_size: 1
|
21 |
+
buffer_size: 32
|
22 |
+
|
23 |
+
Dataset:
|
24 |
+
root:
|
25 |
+
train_batch_size: 1
|
26 |
+
val_batch_size: 1
|
27 |
+
test_batch_size: 1
|
28 |
+
shuffle: True
|
29 |
+
num_workers: 1
|
30 |
+
pin_memory: True
|
31 |
+
persistent_workers: True
|
32 |
+
train_raw_dir: ["data/waymo_processed/training"]
|
33 |
+
val_raw_dir: ["data/waymo_processed/validation"]
|
34 |
+
test_raw_dir: ["data/waymo_processed/validation"]
|
35 |
+
transform: WaymoTargetBuilder
|
36 |
+
train_processed_dir:
|
37 |
+
val_processed_dir:
|
38 |
+
test_processed_dir:
|
39 |
+
dataset: "scalable"
|
40 |
+
<<: *time_info
|
41 |
+
|
42 |
+
Trainer:
|
43 |
+
strategy: ddp_find_unused_parameters_false
|
44 |
+
accelerator: "gpu"
|
45 |
+
devices: 1
|
46 |
+
max_epochs: 32
|
47 |
+
overfit_epochs: 6000
|
48 |
+
save_ckpt_path:
|
49 |
+
num_nodes: 1
|
50 |
+
mode:
|
51 |
+
ckpt_path:
|
52 |
+
precision: 32
|
53 |
+
accumulate_grad_batches: 1
|
54 |
+
|
55 |
+
Model:
|
56 |
+
mode: "train"
|
57 |
+
predictor: "smart"
|
58 |
+
decoder_type: "agent_decoder"
|
59 |
+
dataset: "waymo"
|
60 |
+
input_dim: 2
|
61 |
+
hidden_dim: 128
|
62 |
+
output_dim: 2
|
63 |
+
output_head: False
|
64 |
+
num_heads: 8
|
65 |
+
<<: *time_info
|
66 |
+
head_dim: 16
|
67 |
+
dropout: 0.1
|
68 |
+
num_freq_bands: 64
|
69 |
+
lr: 0.0005
|
70 |
+
warmup_steps: 0
|
71 |
+
total_steps: 32
|
72 |
+
predict_map_token: False
|
73 |
+
decoder:
|
74 |
+
<<: *time_info
|
75 |
+
num_map_layers: 3
|
76 |
+
num_agent_layers: 6
|
77 |
+
a2a_radius: 60
|
78 |
+
pl2pl_radius: 10
|
79 |
+
pl2a_radius: 30
|
80 |
+
a2sa_radius: 10
|
81 |
+
pl2sa_radius: 10
|
82 |
+
time_span: 60
|
83 |
+
loss_weight:
|
84 |
+
token_cls_loss: 1
|
85 |
+
map_token_loss: 1
|
86 |
+
state_cls_loss: 10
|
87 |
+
type_cls_loss: 5
|
88 |
+
pos_cls_loss: 1
|
89 |
+
head_cls_loss: 1
|
90 |
+
shape_reg_loss: .2
|
91 |
+
state_weight: [0.1, 0.1, 0.8] # invalid, valid, exit
|
92 |
+
seed_state_weight: [0.1, 0.9] # invalid, enter
|
93 |
+
seed_type_weight: [0.8, 0.1, 0.1]
|
94 |
+
agent_occ_pos_weight: 100
|
95 |
+
pt_occ_pos_weight: 5
|
96 |
+
agent_occ_loss: 10
|
97 |
+
pt_occ_loss: 10
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/configs/smart.yaml
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Config format schema number, the yaml support to valid case source from different dataset
|
2 |
+
time_info: &time_info
|
3 |
+
num_historical_steps: 11
|
4 |
+
num_future_steps: 80
|
5 |
+
use_intention: True
|
6 |
+
token_size: 2048
|
7 |
+
disable_invalid: True
|
8 |
+
use_special_motion_token: False
|
9 |
+
use_state_token: False
|
10 |
+
only_state: False
|
11 |
+
|
12 |
+
Dataset:
|
13 |
+
root:
|
14 |
+
train_batch_size: 1
|
15 |
+
val_batch_size: 1
|
16 |
+
test_batch_size: 1
|
17 |
+
shuffle: True
|
18 |
+
num_workers: 1
|
19 |
+
pin_memory: True
|
20 |
+
persistent_workers: True
|
21 |
+
train_raw_dir: ["data/waymo_processed/training"]
|
22 |
+
val_raw_dir: ["data/waymo_processed/validation"]
|
23 |
+
test_raw_dir: ["data/waymo_processed/validation"]
|
24 |
+
transform: WaymoTargetBuilder
|
25 |
+
train_processed_dir:
|
26 |
+
val_processed_dir:
|
27 |
+
test_processed_dir:
|
28 |
+
dataset: "scalable"
|
29 |
+
<<: *time_info
|
30 |
+
|
31 |
+
Trainer:
|
32 |
+
strategy: ddp_find_unused_parameters_false
|
33 |
+
accelerator: "gpu"
|
34 |
+
devices: 1
|
35 |
+
max_epochs: 32
|
36 |
+
overfit_epochs: 5000
|
37 |
+
save_ckpt_path:
|
38 |
+
num_nodes: 1
|
39 |
+
mode:
|
40 |
+
ckpt_path:
|
41 |
+
precision: 32
|
42 |
+
accumulate_grad_batches: 1
|
43 |
+
|
44 |
+
Model:
|
45 |
+
mode: "train"
|
46 |
+
predictor: "smart"
|
47 |
+
dataset: "waymo"
|
48 |
+
input_dim: 2
|
49 |
+
hidden_dim: 128
|
50 |
+
output_dim: 2
|
51 |
+
output_head: False
|
52 |
+
num_heads: 8
|
53 |
+
<<: *time_info
|
54 |
+
head_dim: 16
|
55 |
+
dropout: 0.1
|
56 |
+
num_freq_bands: 64
|
57 |
+
lr: 0.0005
|
58 |
+
warmup_steps: 0
|
59 |
+
total_steps: 32
|
60 |
+
decoder:
|
61 |
+
<<: *time_info
|
62 |
+
num_map_layers: 3
|
63 |
+
num_agent_layers: 6
|
64 |
+
a2a_radius: 60
|
65 |
+
pl2pl_radius: 10
|
66 |
+
pl2a_radius: 30
|
67 |
+
time_span: 30
|
68 |
+
loss_weight:
|
69 |
+
token_cls_loss: 1
|
70 |
+
state_cls_loss: 5
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/data_preprocess.py
ADDED
@@ -0,0 +1,916 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Not a contribution
|
2 |
+
# Changes made by NVIDIA CORPORATION & AFFILIATES enabling <CAT-K> or otherwise documented as
|
3 |
+
# NVIDIA-proprietary are not a contribution and subject to the following terms and conditions:
|
4 |
+
# SPDX-FileCopyrightText: Copyright (c) <year> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
5 |
+
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
6 |
+
#
|
7 |
+
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
8 |
+
# property and proprietary rights in and to this material, related
|
9 |
+
# documentation and any modifications thereto. Any use, reproduction,
|
10 |
+
# disclosure or distribution of this material and related documentation
|
11 |
+
# without an express license agreement from NVIDIA CORPORATION or
|
12 |
+
# its affiliates is strictly prohibited.
|
13 |
+
|
14 |
+
import signal
|
15 |
+
import multiprocessing
|
16 |
+
import os
|
17 |
+
import numpy as np
|
18 |
+
import pandas as pd
|
19 |
+
import tensorflow as tf
|
20 |
+
import torch
|
21 |
+
import pickle
|
22 |
+
import easydict
|
23 |
+
from functools import partial
|
24 |
+
from scipy.interpolate import interp1d
|
25 |
+
from argparse import ArgumentParser
|
26 |
+
from tqdm import tqdm
|
27 |
+
from typing import Any, Dict, List, Optional
|
28 |
+
from waymo_open_dataset.protos import scenario_pb2
|
29 |
+
|
30 |
+
|
31 |
+
MIN_VALID_STEPS = 15
|
32 |
+
|
33 |
+
|
34 |
+
_polygon_types = ['VEHICLE', 'BIKE', 'BUS', 'PEDESTRIAN']
|
35 |
+
_polygon_light_type = ['LANE_STATE_STOP', 'LANE_STATE_GO', 'LANE_STATE_CAUTION', 'LANE_STATE_UNKNOWN']
|
36 |
+
_point_types = ['DASH_SOLID_YELLOW', 'DASH_SOLID_WHITE', 'DASHED_WHITE', 'DASHED_YELLOW',
|
37 |
+
'DOUBLE_SOLID_YELLOW', 'DOUBLE_SOLID_WHITE', 'DOUBLE_DASH_YELLOW', 'DOUBLE_DASH_WHITE',
|
38 |
+
'SOLID_YELLOW', 'SOLID_WHITE', 'SOLID_DASH_WHITE', 'SOLID_DASH_YELLOW', 'EDGE',
|
39 |
+
'NONE', 'UNKNOWN', 'CROSSWALK', 'CENTERLINE']
|
40 |
+
_polygon_to_polygon_types = ['NONE', 'PRED', 'SUCC', 'LEFT', 'RIGHT']
|
41 |
+
|
42 |
+
|
43 |
+
Lane_type_hash = {
|
44 |
+
4: "BIKE",
|
45 |
+
3: "VEHICLE",
|
46 |
+
2: "VEHICLE",
|
47 |
+
1: "BUS"
|
48 |
+
}
|
49 |
+
|
50 |
+
boundary_type_hash = {
|
51 |
+
5: "UNKNOWN",
|
52 |
+
6: "DASHED_WHITE",
|
53 |
+
7: "SOLID_WHITE",
|
54 |
+
8: "DOUBLE_DASH_WHITE",
|
55 |
+
9: "DASHED_YELLOW",
|
56 |
+
10: "DOUBLE_DASH_YELLOW",
|
57 |
+
11: "SOLID_YELLOW",
|
58 |
+
12: "DOUBLE_SOLID_YELLOW",
|
59 |
+
13: "DASH_SOLID_YELLOW",
|
60 |
+
14: "UNKNOWN",
|
61 |
+
15: "EDGE",
|
62 |
+
16: "EDGE"
|
63 |
+
}
|
64 |
+
|
65 |
+
|
66 |
+
def safe_list_index(ls: List[Any], elem: Any) -> Optional[int]:
|
67 |
+
try:
|
68 |
+
return ls.index(elem)
|
69 |
+
except ValueError:
|
70 |
+
return None
|
71 |
+
|
72 |
+
|
73 |
+
# def get_agent_features(df: pd.DataFrame, av_id, num_historical_steps=11, dim=3, num_steps=91) -> Dict[str, Any]:
|
74 |
+
# if args.disable_invalid: # filter out agents that are unseen during the historical time steps
|
75 |
+
# historical_df = df[df['timestep'] == num_historical_steps-1] # extract the timestep==10 (current)
|
76 |
+
# agent_ids = list(historical_df['track_id'].unique()) # these agents are seen at timestep==10 (current)
|
77 |
+
# df = df[df['track_id'].isin(agent_ids)] # remove other agents
|
78 |
+
# else:
|
79 |
+
# agent_ids = list(df['track_id'].unique())
|
80 |
+
|
81 |
+
# num_agents = len(agent_ids)
|
82 |
+
# # initialization
|
83 |
+
# valid_mask = torch.zeros(num_agents, num_steps, dtype=torch.bool)
|
84 |
+
# current_valid_mask = torch.zeros(num_agents, dtype=torch.bool)
|
85 |
+
# predict_mask = torch.zeros(num_agents, num_steps, dtype=torch.bool)
|
86 |
+
# agent_id: List[Optional[str]] = [None] * num_agents
|
87 |
+
# agent_type = torch.zeros(num_agents, dtype=torch.uint8)
|
88 |
+
# agent_category = torch.zeros(num_agents, dtype=torch.uint8)
|
89 |
+
# position = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)
|
90 |
+
# heading = torch.zeros(num_agents, num_steps, dtype=torch.float)
|
91 |
+
# velocity = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)
|
92 |
+
# shape = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)
|
93 |
+
|
94 |
+
# for track_id, track_df in df.groupby('track_id'):
|
95 |
+
# agent_idx = agent_ids.index(track_id)
|
96 |
+
# all_agent_steps = track_df['timestep'].values
|
97 |
+
# valid_agent_steps = all_agent_steps[track_df['validity'].astype(np.bool_)].astype(np.int32)
|
98 |
+
# valid_mask[agent_idx, valid_agent_steps] = True
|
99 |
+
# current_valid_mask[agent_idx] = valid_mask[agent_idx, num_historical_steps - 1] # current timestep 10
|
100 |
+
# if args.disable_invalid:
|
101 |
+
# predict_mask[agent_idx, valid_agent_steps] = True
|
102 |
+
# else:
|
103 |
+
# predict_mask[agent_idx] = True
|
104 |
+
# predict_mask[agent_idx, :num_historical_steps] = False
|
105 |
+
# if not current_valid_mask[agent_idx]:
|
106 |
+
# predict_mask[agent_idx, num_historical_steps:] = False
|
107 |
+
|
108 |
+
# # TODO: why using vector_repr?
|
109 |
+
# if vector_repr: # a time step t is valid only when both t and t-1 are valid
|
110 |
+
# valid_mask[agent_idx, 1 : num_historical_steps] = (
|
111 |
+
# valid_mask[agent_idx, : num_historical_steps - 1] &
|
112 |
+
# valid_mask[agent_idx, 1 : num_historical_steps])
|
113 |
+
# valid_mask[agent_idx, 0] = False
|
114 |
+
|
115 |
+
# agent_id[agent_idx] = track_id
|
116 |
+
# agent_type[agent_idx] = _agent_types.index(track_df['object_type'].values[0])
|
117 |
+
# agent_category[agent_idx] = track_df['object_category'].values[0]
|
118 |
+
# position[agent_idx, valid_agent_steps, :3] = torch.from_numpy(np.stack([track_df['position_x'].values[valid_agent_steps],
|
119 |
+
# track_df['position_y'].values[valid_agent_steps],
|
120 |
+
# track_df['position_z'].values[valid_agent_steps]],
|
121 |
+
# axis=-1)).float()
|
122 |
+
# heading[agent_idx, valid_agent_steps] = torch.from_numpy(track_df['heading'].values[valid_agent_steps]).float()
|
123 |
+
# velocity[agent_idx, valid_agent_steps, :2] = torch.from_numpy(np.stack([track_df['velocity_x'].values[valid_agent_steps],
|
124 |
+
# track_df['velocity_y'].values[valid_agent_steps]],
|
125 |
+
# axis=-1)).float()
|
126 |
+
# shape[agent_idx, valid_agent_steps, :3] = torch.from_numpy(np.stack([track_df['length'].values[valid_agent_steps],
|
127 |
+
# track_df['width'].values[valid_agent_steps],
|
128 |
+
# track_df["height"].values[valid_agent_steps]],
|
129 |
+
# axis=-1)).float()
|
130 |
+
# av_idx = agent_id.index(av_id)
|
131 |
+
# if split == 'test':
|
132 |
+
# predict_mask[current_valid_mask
|
133 |
+
# | (agent_category == 2)
|
134 |
+
# | (agent_category == 3), num_historical_steps:] = True
|
135 |
+
|
136 |
+
# return {
|
137 |
+
# 'num_nodes': num_agents,
|
138 |
+
# 'av_index': av_idx,
|
139 |
+
# 'valid_mask': valid_mask,
|
140 |
+
# 'predict_mask': predict_mask,
|
141 |
+
# 'id': agent_id,
|
142 |
+
# 'type': agent_type,
|
143 |
+
# 'category': agent_category,
|
144 |
+
# 'position': position,
|
145 |
+
# 'heading': heading,
|
146 |
+
# 'velocity': velocity,
|
147 |
+
# 'shape': shape
|
148 |
+
# }
|
149 |
+
|
150 |
+
|
151 |
+
def get_agent_features(track_infos: Dict[str, np.ndarray], av_id: int, num_historical_steps: int, num_steps: int) -> Dict[str, Any]:
|
152 |
+
|
153 |
+
agent_idx_to_add = []
|
154 |
+
for i in range(len(track_infos['object_id'])):
|
155 |
+
is_visible = track_infos['valid'][i, num_historical_steps - 1]
|
156 |
+
valid_steps = np.where(track_infos['valid'][i])[0]
|
157 |
+
valid_start, valid_end = valid_steps[0], valid_steps[-1]
|
158 |
+
is_valid = (valid_end - valid_start + 1) >= MIN_VALID_STEPS
|
159 |
+
|
160 |
+
if (is_visible or not args.disable_invalid) and is_valid:
|
161 |
+
agent_idx_to_add.append(i)
|
162 |
+
|
163 |
+
num_agents = len(agent_idx_to_add)
|
164 |
+
out_dict = {
|
165 |
+
'num_nodes': num_agents,
|
166 |
+
'valid_mask': torch.zeros(num_agents, num_steps, dtype=torch.bool),
|
167 |
+
'role': torch.zeros(num_agents, 3, dtype=torch.bool),
|
168 |
+
'id': torch.zeros(num_agents, dtype=torch.int64) - 1,
|
169 |
+
'type': torch.zeros(num_agents, dtype=torch.uint8),
|
170 |
+
'category': torch.zeros(num_agents, dtype=torch.uint8),
|
171 |
+
'position': torch.zeros(num_agents, num_steps, 3, dtype=torch.float),
|
172 |
+
'heading': torch.zeros(num_agents, num_steps, dtype=torch.float),
|
173 |
+
'velocity': torch.zeros(num_agents, num_steps, 2, dtype=torch.float),
|
174 |
+
'shape': torch.zeros(num_agents, num_steps, 3, dtype=torch.float),
|
175 |
+
}
|
176 |
+
|
177 |
+
for i, idx in enumerate(agent_idx_to_add):
|
178 |
+
|
179 |
+
out_dict['role'][i] = torch.from_numpy(track_infos['role'][idx])
|
180 |
+
out_dict['id'][i] = track_infos['object_id'][idx]
|
181 |
+
out_dict['type'][i] = track_infos['object_type'][idx]
|
182 |
+
out_dict['category'][i] = idx in track_infos['tracks_to_predict']
|
183 |
+
|
184 |
+
valid = track_infos["valid"][idx] # [n_step]
|
185 |
+
states = track_infos["states"][idx]
|
186 |
+
|
187 |
+
object_shape = states[:, 3:6] # [n_step, 3], length, width, height
|
188 |
+
object_shape = object_shape[valid].mean(axis=0) # [3]
|
189 |
+
out_dict["shape"][i] = torch.from_numpy(object_shape)
|
190 |
+
|
191 |
+
valid_steps = np.where(valid)[0]
|
192 |
+
position = states[:, :3] # [n_step, dim], x, y, z
|
193 |
+
velocity = states[:, 7:9] # [n_step, 2], vx, vy
|
194 |
+
heading = states[:, 6] # [n_step], heading
|
195 |
+
|
196 |
+
# valid.sum() should > 1:
|
197 |
+
t_start, t_end = valid_steps[0], valid_steps[-1]
|
198 |
+
f_pos = interp1d(valid_steps, position[valid], axis=0)
|
199 |
+
f_vel = interp1d(valid_steps, velocity[valid], axis=0)
|
200 |
+
f_yaw = interp1d(valid_steps, np.unwrap(heading[valid], axis=0), axis=0)
|
201 |
+
t_in = np.arange(t_start, t_end + 1)
|
202 |
+
out_dict["valid_mask"][i, t_start : t_end + 1] = True
|
203 |
+
out_dict["position"][i, t_start : t_end + 1] = torch.from_numpy(f_pos(t_in))
|
204 |
+
out_dict["velocity"][i, t_start : t_end + 1] = torch.from_numpy(f_vel(t_in))
|
205 |
+
out_dict["heading"][i, t_start : t_end + 1] = torch.from_numpy(f_yaw(t_in))
|
206 |
+
|
207 |
+
out_dict['av_idx'] = out_dict['id'].tolist().index(av_id)
|
208 |
+
|
209 |
+
return out_dict
|
210 |
+
|
211 |
+
|
212 |
+
def get_map_features(map_infos, tf_current_light, dim=3):
|
213 |
+
lane_segments = map_infos['lane']
|
214 |
+
all_polylines = map_infos["all_polylines"]
|
215 |
+
crosswalks = map_infos['crosswalk']
|
216 |
+
road_edges = map_infos['road_edge']
|
217 |
+
road_lines = map_infos['road_line']
|
218 |
+
lane_segment_ids = [info["id"] for info in lane_segments]
|
219 |
+
cross_walk_ids = [info["id"] for info in crosswalks]
|
220 |
+
road_edge_ids = [info["id"] for info in road_edges]
|
221 |
+
road_line_ids = [info["id"] for info in road_lines]
|
222 |
+
polygon_ids = lane_segment_ids + road_edge_ids + road_line_ids + cross_walk_ids
|
223 |
+
num_polygons = len(lane_segment_ids) + len(road_edge_ids) + len(road_line_ids) + len(cross_walk_ids)
|
224 |
+
|
225 |
+
# initialization
|
226 |
+
polygon_type = torch.zeros(num_polygons, dtype=torch.uint8)
|
227 |
+
polygon_light_type = torch.ones(num_polygons, dtype=torch.uint8) * 3
|
228 |
+
|
229 |
+
# list of (num_of_segments,), each element has shape of (num_of_points_of_current_segment - 1, dim)
|
230 |
+
point_position: List[Optional[torch.Tensor]] = [None] * num_polygons
|
231 |
+
point_orientation: List[Optional[torch.Tensor]] = [None] * num_polygons
|
232 |
+
point_magnitude: List[Optional[torch.Tensor]] = [None] * num_polygons
|
233 |
+
point_height: List[Optional[torch.Tensor]] = [None] * num_polygons
|
234 |
+
point_type: List[Optional[torch.Tensor]] = [None] * num_polygons
|
235 |
+
|
236 |
+
for lane_segment in lane_segments:
|
237 |
+
lane_segment = easydict.EasyDict(lane_segment)
|
238 |
+
lane_segment_idx = polygon_ids.index(lane_segment.id)
|
239 |
+
polyline_index = lane_segment.polyline_index # (start index of point in current scenario, end index of point in current scenario)
|
240 |
+
centerline = all_polylines[polyline_index[0] : polyline_index[1], :] # (num_of_points_of_current_segment, 5)
|
241 |
+
centerline = torch.from_numpy(centerline).float()
|
242 |
+
polygon_type[lane_segment_idx] = _polygon_types.index(Lane_type_hash[lane_segment.type])
|
243 |
+
|
244 |
+
res = tf_current_light[tf_current_light["lane_id"] == str(lane_segment.id)]
|
245 |
+
if len(res) != 0:
|
246 |
+
polygon_light_type[lane_segment_idx] = _polygon_light_type.index(res["state"].item())
|
247 |
+
|
248 |
+
point_position[lane_segment_idx] = torch.cat([centerline[:-1, :dim]], dim=0) # (num_of_points_of_current_segment - 1, 3)
|
249 |
+
center_vectors = centerline[1:] - centerline[:-1] # (num_of_points_of_current_segment - 1, 5)
|
250 |
+
point_orientation[lane_segment_idx] = torch.cat([torch.atan2(center_vectors[:, 1], center_vectors[:, 0])], dim=0) # (num_of_points_of_current_segment - 1,)
|
251 |
+
point_magnitude[lane_segment_idx] = torch.norm(torch.cat([center_vectors[:, :2]], dim=0), p=2, dim=-1) # (num_of_points_of_current_segment - 1,)
|
252 |
+
point_height[lane_segment_idx] = torch.cat([center_vectors[:, 2]], dim=0) # (num_of_points_of_current_segment - 1,)
|
253 |
+
center_type = _point_types.index('CENTERLINE')
|
254 |
+
point_type[lane_segment_idx] = torch.cat(
|
255 |
+
[torch.full((len(center_vectors),), center_type, dtype=torch.uint8)], dim=0)
|
256 |
+
|
257 |
+
for lane_segment in road_edges:
|
258 |
+
lane_segment = easydict.EasyDict(lane_segment)
|
259 |
+
lane_segment_idx = polygon_ids.index(lane_segment.id)
|
260 |
+
polyline_index = lane_segment.polyline_index
|
261 |
+
centerline = all_polylines[polyline_index[0] : polyline_index[1], :]
|
262 |
+
centerline = torch.from_numpy(centerline).float()
|
263 |
+
polygon_type[lane_segment_idx] = _polygon_types.index("VEHICLE")
|
264 |
+
|
265 |
+
point_position[lane_segment_idx] = torch.cat([centerline[:-1, :dim]], dim=0)
|
266 |
+
center_vectors = centerline[1:] - centerline[:-1]
|
267 |
+
point_orientation[lane_segment_idx] = torch.cat([torch.atan2(center_vectors[:, 1], center_vectors[:, 0])], dim=0)
|
268 |
+
point_magnitude[lane_segment_idx] = torch.norm(torch.cat([center_vectors[:, :2]], dim=0), p=2, dim=-1)
|
269 |
+
point_height[lane_segment_idx] = torch.cat([center_vectors[:, 2]], dim=0)
|
270 |
+
center_type = _point_types.index('EDGE')
|
271 |
+
point_type[lane_segment_idx] = torch.cat(
|
272 |
+
[torch.full((len(center_vectors),), center_type, dtype=torch.uint8)], dim=0)
|
273 |
+
|
274 |
+
for lane_segment in road_lines:
|
275 |
+
lane_segment = easydict.EasyDict(lane_segment)
|
276 |
+
lane_segment_idx = polygon_ids.index(lane_segment.id)
|
277 |
+
polyline_index = lane_segment.polyline_index
|
278 |
+
centerline = all_polylines[polyline_index[0] : polyline_index[1], :]
|
279 |
+
centerline = torch.from_numpy(centerline).float()
|
280 |
+
|
281 |
+
polygon_type[lane_segment_idx] = _polygon_types.index("VEHICLE")
|
282 |
+
|
283 |
+
point_position[lane_segment_idx] = torch.cat([centerline[:-1, :dim]], dim=0)
|
284 |
+
center_vectors = centerline[1:] - centerline[:-1]
|
285 |
+
point_orientation[lane_segment_idx] = torch.cat([torch.atan2(center_vectors[:, 1], center_vectors[:, 0])], dim=0)
|
286 |
+
point_magnitude[lane_segment_idx] = torch.norm(torch.cat([center_vectors[:, :2]], dim=0), p=2, dim=-1)
|
287 |
+
point_height[lane_segment_idx] = torch.cat([center_vectors[:, 2]], dim=0)
|
288 |
+
center_type = _point_types.index(boundary_type_hash[lane_segment.type])
|
289 |
+
point_type[lane_segment_idx] = torch.cat(
|
290 |
+
[torch.full((len(center_vectors),), center_type, dtype=torch.uint8)], dim=0)
|
291 |
+
|
292 |
+
for crosswalk in crosswalks:
|
293 |
+
crosswalk = easydict.EasyDict(crosswalk)
|
294 |
+
lane_segment_idx = polygon_ids.index(crosswalk.id)
|
295 |
+
polyline_index = crosswalk.polyline_index
|
296 |
+
centerline = all_polylines[polyline_index[0] : polyline_index[1], :]
|
297 |
+
centerline = torch.from_numpy(centerline).float()
|
298 |
+
|
299 |
+
polygon_type[lane_segment_idx] = _polygon_types.index("PEDESTRIAN")
|
300 |
+
|
301 |
+
point_position[lane_segment_idx] = torch.cat([centerline[:-1, :dim]], dim=0)
|
302 |
+
center_vectors = centerline[1:] - centerline[:-1]
|
303 |
+
point_orientation[lane_segment_idx] = torch.cat([torch.atan2(center_vectors[:, 1], center_vectors[:, 0])], dim=0)
|
304 |
+
point_magnitude[lane_segment_idx] = torch.norm(torch.cat([center_vectors[:, :2]], dim=0), p=2, dim=-1)
|
305 |
+
point_height[lane_segment_idx] = torch.cat([center_vectors[:, 2]], dim=0)
|
306 |
+
center_type = _point_types.index("CROSSWALK")
|
307 |
+
point_type[lane_segment_idx] = torch.cat(
|
308 |
+
[torch.full((len(center_vectors),), center_type, dtype=torch.uint8)], dim=0)
|
309 |
+
|
310 |
+
# (num_of_segments,), each element represents the number of points of the segment
|
311 |
+
num_points = torch.tensor([point.size(0) for point in point_position], dtype=torch.long)
|
312 |
+
# (2, total_num_of_points_of_all_segments), store the point index of segment and its corresponding segment index
|
313 |
+
# e.g. a scenario has 203 segments, and totally 14039 points:
|
314 |
+
# tensor([[ 0, 1, 2, ..., 14927, 14928, 14929],
|
315 |
+
# [ 0, 0, 0, ..., 202, 202, 202]]) => polygon_ids.index(lane_segment.id)
|
316 |
+
point_to_polygon_edge_index = torch.stack(
|
317 |
+
[torch.arange(num_points.sum(), dtype=torch.long),
|
318 |
+
torch.arange(num_polygons, dtype=torch.long).repeat_interleave(num_points)], dim=0)
|
319 |
+
# list of (num_of_lane_segments,)
|
320 |
+
polygon_to_polygon_edge_index = []
|
321 |
+
# list of (num_of_lane_segments,)
|
322 |
+
polygon_to_polygon_type = []
|
323 |
+
for lane_segment in lane_segments:
|
324 |
+
lane_segment = easydict.EasyDict(lane_segment)
|
325 |
+
lane_segment_idx = polygon_ids.index(lane_segment.id)
|
326 |
+
pred_inds = []
|
327 |
+
for pred in lane_segment.entry_lanes:
|
328 |
+
pred_idx = safe_list_index(polygon_ids, pred)
|
329 |
+
if pred_idx is not None:
|
330 |
+
pred_inds.append(pred_idx)
|
331 |
+
if len(pred_inds) != 0:
|
332 |
+
polygon_to_polygon_edge_index.append(
|
333 |
+
torch.stack([torch.tensor(pred_inds, dtype=torch.long),
|
334 |
+
torch.full((len(pred_inds),), lane_segment_idx, dtype=torch.long)], dim=0))
|
335 |
+
polygon_to_polygon_type.append(
|
336 |
+
torch.full((len(pred_inds),), _polygon_to_polygon_types.index('PRED'), dtype=torch.uint8))
|
337 |
+
succ_inds = []
|
338 |
+
for succ in lane_segment.exit_lanes:
|
339 |
+
succ_idx = safe_list_index(polygon_ids, succ)
|
340 |
+
if succ_idx is not None:
|
341 |
+
succ_inds.append(succ_idx)
|
342 |
+
if len(succ_inds) != 0:
|
343 |
+
polygon_to_polygon_edge_index.append(
|
344 |
+
torch.stack([torch.tensor(succ_inds, dtype=torch.long),
|
345 |
+
torch.full((len(succ_inds),), lane_segment_idx, dtype=torch.long)], dim=0))
|
346 |
+
polygon_to_polygon_type.append(
|
347 |
+
torch.full((len(succ_inds),), _polygon_to_polygon_types.index('SUCC'), dtype=torch.uint8))
|
348 |
+
if len(lane_segment.left_neighbors) != 0:
|
349 |
+
left_neighbor_ids = lane_segment.left_neighbors
|
350 |
+
for left_neighbor_id in left_neighbor_ids:
|
351 |
+
left_idx = safe_list_index(polygon_ids, left_neighbor_id)
|
352 |
+
if left_idx is not None:
|
353 |
+
polygon_to_polygon_edge_index.append(
|
354 |
+
torch.tensor([[left_idx], [lane_segment_idx]], dtype=torch.long))
|
355 |
+
polygon_to_polygon_type.append(
|
356 |
+
torch.tensor([_polygon_to_polygon_types.index('LEFT')], dtype=torch.uint8))
|
357 |
+
if len(lane_segment.right_neighbors) != 0:
|
358 |
+
right_neighbor_ids = lane_segment.right_neighbors
|
359 |
+
for right_neighbor_id in right_neighbor_ids:
|
360 |
+
right_idx = safe_list_index(polygon_ids, right_neighbor_id)
|
361 |
+
if right_idx is not None:
|
362 |
+
polygon_to_polygon_edge_index.append(
|
363 |
+
torch.tensor([[right_idx], [lane_segment_idx]], dtype=torch.long))
|
364 |
+
polygon_to_polygon_type.append(
|
365 |
+
torch.tensor([_polygon_to_polygon_types.index('RIGHT')], dtype=torch.uint8))
|
366 |
+
if len(polygon_to_polygon_edge_index) != 0:
|
367 |
+
polygon_to_polygon_edge_index = torch.cat(polygon_to_polygon_edge_index, dim=1)
|
368 |
+
polygon_to_polygon_type = torch.cat(polygon_to_polygon_type, dim=0)
|
369 |
+
else:
|
370 |
+
polygon_to_polygon_edge_index = torch.tensor([[], []], dtype=torch.long)
|
371 |
+
polygon_to_polygon_type = torch.tensor([], dtype=torch.uint8)
|
372 |
+
|
373 |
+
map_data = {
|
374 |
+
'map_polygon': {},
|
375 |
+
'map_point': {},
|
376 |
+
('map_point', 'to', 'map_polygon'): {},
|
377 |
+
('map_polygon', 'to', 'map_polygon'): {},
|
378 |
+
}
|
379 |
+
map_data['map_polygon']['num_nodes'] = num_polygons # int, number of map segments in the scenario
|
380 |
+
map_data['map_polygon']['type'] = polygon_type # (num_polygons,) type of each polygon
|
381 |
+
map_data['map_polygon']['light_type'] = polygon_light_type # (num_polygons,) light type of each polygon, 3 means unknown
|
382 |
+
if len(num_points) == 0:
|
383 |
+
map_data['map_point']['num_nodes'] = 0
|
384 |
+
map_data['map_point']['position'] = torch.tensor([], dtype=torch.float)
|
385 |
+
map_data['map_point']['orientation'] = torch.tensor([], dtype=torch.float)
|
386 |
+
map_data['map_point']['magnitude'] = torch.tensor([], dtype=torch.float)
|
387 |
+
if dim == 3:
|
388 |
+
map_data['map_point']['height'] = torch.tensor([], dtype=torch.float)
|
389 |
+
map_data['map_point']['type'] = torch.tensor([], dtype=torch.uint8)
|
390 |
+
map_data['map_point']['side'] = torch.tensor([], dtype=torch.uint8)
|
391 |
+
else:
|
392 |
+
map_data['map_point']['num_nodes'] = num_points.sum().item() # int, number of total points of all segments in the scenario
|
393 |
+
map_data['map_point']['position'] = torch.cat(point_position, dim=0) # (num_of_total_points_of_all_segments, 3)
|
394 |
+
map_data['map_point']['orientation'] = torch.cat(point_orientation, dim=0) # (num_of_total_points_of_all_segments,)
|
395 |
+
map_data['map_point']['magnitude'] = torch.cat(point_magnitude, dim=0) # (num_of_total_points_of_all_segments,)
|
396 |
+
if dim == 3:
|
397 |
+
map_data['map_point']['height'] = torch.cat(point_height, dim=0) # (num_of_total_points_of_all_segments,)
|
398 |
+
map_data['map_point']['type'] = torch.cat(point_type, dim=0) # (num_of_total_points_of_all_segments,) type of point => `_point_types`
|
399 |
+
map_data['map_point', 'to', 'map_polygon']['edge_index'] = point_to_polygon_edge_index # (2, num_of_total_points_of_all_segments)
|
400 |
+
map_data['map_polygon', 'to', 'map_polygon']['edge_index'] = polygon_to_polygon_edge_index
|
401 |
+
map_data['map_polygon', 'to', 'map_polygon']['type'] = polygon_to_polygon_type
|
402 |
+
|
403 |
+
if int(os.getenv('DEBUG_MAP', 1)):
|
404 |
+
import matplotlib.pyplot as plt
|
405 |
+
plt.axis('equal')
|
406 |
+
plt.scatter(map_data['map_point']['position'][:, 0],
|
407 |
+
map_data['map_point']['position'][:, 1], s=0.2, c='black', edgecolors='none')
|
408 |
+
plt.savefig("debug.png", dpi=600)
|
409 |
+
|
410 |
+
return map_data
|
411 |
+
|
412 |
+
|
413 |
+
# def process_agent(track_info, tracks_to_predict, scenario_id, start_timestamp, end_timestamp):
|
414 |
+
|
415 |
+
# agents_array = track_info["states"].transpose(1, 0, 2) # (num_timesteps, num_agents, 10) e.g. (91, 15, 10)
|
416 |
+
# object_id = np.array(track_info["object_id"]) # (num_agents,) global id of each agent
|
417 |
+
# object_type = track_info["object_type"] # (num_agents,) type of each agent, e.g. 'TYPE_VEHICLE'
|
418 |
+
# id_hash = {object_id[o_idx]: object_type[o_idx] for o_idx in range(len(object_id))}
|
419 |
+
|
420 |
+
# def type_hash(x):
|
421 |
+
# tp = id_hash[x]
|
422 |
+
# type_re_hash = {
|
423 |
+
# "TYPE_VEHICLE": "vehicle",
|
424 |
+
# "TYPE_PEDESTRIAN": "pedestrian",
|
425 |
+
# "TYPE_CYCLIST": "cyclist",
|
426 |
+
# "TYPE_OTHER": "background",
|
427 |
+
# "TYPE_UNSET": "background"
|
428 |
+
# }
|
429 |
+
# return type_re_hash[tp]
|
430 |
+
|
431 |
+
# columns = ['observed', 'track_id', 'object_type', 'object_category', 'timestep',
|
432 |
+
# 'position_x', 'position_y', 'position_z', 'length', 'width', 'height', 'heading', 'velocity_x', 'velocity_y',
|
433 |
+
# 'scenario_id', 'start_timestamp', 'end_timestamp', 'num_timestamps',
|
434 |
+
# 'focal_track_id', 'city', 'validity']
|
435 |
+
|
436 |
+
# # (num_timesteps, num_agents, 10) e.g. (91, 15, 10)
|
437 |
+
# new_columns = np.ones((agents_array.shape[0], agents_array.shape[1], 11))
|
438 |
+
# new_columns[:11, :, 0] = True # observed, 10 timesteps
|
439 |
+
# new_columns[11:, :, 0] = False # not observed (current + future)
|
440 |
+
# for index in range(new_columns.shape[0]):
|
441 |
+
# new_columns[index, :, 4] = int(index) # timestep (0 ~ 90)
|
442 |
+
# new_columns[..., 1] = object_id
|
443 |
+
# new_columns[..., 2] = object_id
|
444 |
+
# new_columns[:, tracks_to_predict['track_index'], 3] = 3
|
445 |
+
# new_columns[..., 5] = 11
|
446 |
+
# new_columns[..., 6] = int(start_timestamp) # 0
|
447 |
+
# new_columns[..., 7] = int(end_timestamp) # 91
|
448 |
+
# new_columns[..., 8] = int(91) # 91
|
449 |
+
# new_columns[..., 9] = object_id
|
450 |
+
# new_columns[..., 10] = 10086
|
451 |
+
# new_columns = new_columns
|
452 |
+
# new_agents_array = np.concatenate([new_columns, agents_array], axis=-1) # (num_timesteps, num_agents, 21) e.g. (91, 15, 21)
|
453 |
+
# # filter out the invalid timestep of agents, reshape to (num_valid_of_timesteps_of_all_agents, 21) e.g. (91, 15, 21) -> (1137, 21)
|
454 |
+
# if args.disable_invalid:
|
455 |
+
# new_agents_array = new_agents_array[new_agents_array[..., -1] == 1.0].reshape(-1, new_agents_array.shape[-1])
|
456 |
+
# else:
|
457 |
+
# agent_valid_mask = new_agents_array[..., -1] # (num_timesteps, num_agents)
|
458 |
+
# agent_mask = np.sum(agent_valid_mask, axis=0) > MIN_VALID_STEPS # NOTE: 10 is a empirical parameter
|
459 |
+
# new_agents_array = new_agents_array[:, agent_mask]
|
460 |
+
# new_agents_array = new_agents_array.reshape(-1, new_agents_array.shape[-1]) # (91, 15, 21) -> (1365, 21)
|
461 |
+
# new_agents_array = new_agents_array[..., [0, 1, 2, 3, 4, 11, 12, 13, 14, 15, 16, 17, 18, 19, 5, 6, 7, 8, 9, 10, 20]]
|
462 |
+
# new_agents_array = pd.DataFrame(data=new_agents_array, columns=columns)
|
463 |
+
# new_agents_array["object_type"] = new_agents_array["object_type"].apply(func=type_hash)
|
464 |
+
# new_agents_array["start_timestamp"] = new_agents_array["start_timestamp"].astype(int)
|
465 |
+
# new_agents_array["end_timestamp"] = new_agents_array["end_timestamp"].astype(int)
|
466 |
+
# new_agents_array["num_timestamps"] = new_agents_array["num_timestamps"].astype(int)
|
467 |
+
# new_agents_array["scenario_id"] = scenario_id
|
468 |
+
|
469 |
+
# return new_agents_array
|
470 |
+
|
471 |
+
|
472 |
+
def process_dynamic_map(dynamic_map_infos):
|
473 |
+
lane_ids = dynamic_map_infos["lane_id"]
|
474 |
+
tf_lights = []
|
475 |
+
for t in range(len(lane_ids)):
|
476 |
+
lane_id = lane_ids[t]
|
477 |
+
time = np.ones_like(lane_id) * t
|
478 |
+
state = dynamic_map_infos["state"][t]
|
479 |
+
tf_light = np.concatenate([lane_id, time, state], axis=0)
|
480 |
+
tf_lights.append(tf_light)
|
481 |
+
tf_lights = np.concatenate(tf_lights, axis=1).transpose(1, 0)
|
482 |
+
tf_lights = pd.DataFrame(data=tf_lights, columns=["lane_id", "time_step", "state"])
|
483 |
+
tf_lights["time_step"] = tf_lights["time_step"].astype("str")
|
484 |
+
tf_lights["lane_id"] = tf_lights["lane_id"].astype("str")
|
485 |
+
tf_lights["state"] = tf_lights["state"].astype("str")
|
486 |
+
tf_lights.loc[tf_lights["state"].str.contains("STOP"), ["state"]] = (
|
487 |
+
"LANE_STATE_STOP"
|
488 |
+
)
|
489 |
+
tf_lights.loc[tf_lights["state"].str.contains("GO"), ["state"]] = "LANE_STATE_GO"
|
490 |
+
tf_lights.loc[tf_lights["state"].str.contains("CAUTION"), ["state"]] = (
|
491 |
+
"LANE_STATE_CAUTION"
|
492 |
+
)
|
493 |
+
tf_lights.loc[tf_lights["state"].str.contains("UNKNOWN"), ["state"]] = (
|
494 |
+
"LANE_STATE_UNKNOWN"
|
495 |
+
)
|
496 |
+
|
497 |
+
return tf_lights
|
498 |
+
|
499 |
+
|
500 |
+
polyline_type = {
|
501 |
+
# for lane
|
502 |
+
'TYPE_UNDEFINED': -1,
|
503 |
+
'TYPE_FREEWAY': 1,
|
504 |
+
'TYPE_SURFACE_STREET': 2,
|
505 |
+
'TYPE_BIKE_LANE': 3,
|
506 |
+
|
507 |
+
# for roadline
|
508 |
+
'TYPE_UNKNOWN': -1,
|
509 |
+
'TYPE_BROKEN_SINGLE_WHITE': 6,
|
510 |
+
'TYPE_SOLID_SINGLE_WHITE': 7,
|
511 |
+
'TYPE_SOLID_DOUBLE_WHITE': 8,
|
512 |
+
'TYPE_BROKEN_SINGLE_YELLOW': 9,
|
513 |
+
'TYPE_BROKEN_DOUBLE_YELLOW': 10,
|
514 |
+
'TYPE_SOLID_SINGLE_YELLOW': 11,
|
515 |
+
'TYPE_SOLID_DOUBLE_YELLOW': 12,
|
516 |
+
'TYPE_PASSING_DOUBLE_YELLOW': 13,
|
517 |
+
|
518 |
+
# for roadedge
|
519 |
+
'TYPE_ROAD_EDGE_BOUNDARY': 15,
|
520 |
+
'TYPE_ROAD_EDGE_MEDIAN': 16,
|
521 |
+
|
522 |
+
# for stopsign
|
523 |
+
'TYPE_STOP_SIGN': 17,
|
524 |
+
|
525 |
+
# for crosswalk
|
526 |
+
'TYPE_CROSSWALK': 18,
|
527 |
+
|
528 |
+
# for speed bump
|
529 |
+
'TYPE_SPEED_BUMP': 19
|
530 |
+
}
|
531 |
+
|
532 |
+
object_type = {
|
533 |
+
0: 'TYPE_UNSET',
|
534 |
+
1: 'TYPE_VEHICLE',
|
535 |
+
2: 'TYPE_PEDESTRIAN',
|
536 |
+
3: 'TYPE_CYCLIST',
|
537 |
+
4: 'TYPE_OTHER'
|
538 |
+
}
|
539 |
+
|
540 |
+
|
541 |
+
def decode_tracks_from_proto(scenario):
|
542 |
+
sdc_track_index = scenario.sdc_track_index
|
543 |
+
track_index_predict = [i.track_index for i in scenario.tracks_to_predict]
|
544 |
+
object_id_interest = [i for i in scenario.objects_of_interest]
|
545 |
+
|
546 |
+
track_infos = {
|
547 |
+
'object_id': [], # {0: unset, 1: vehicle, 2: pedestrian, 3: cyclist, 4: others}
|
548 |
+
'object_type': [],
|
549 |
+
'states': [],
|
550 |
+
'valid': [],
|
551 |
+
'role': [],
|
552 |
+
}
|
553 |
+
|
554 |
+
# tracks mean N number of objects, e.g. len(tracks) = 55
|
555 |
+
# each track has 91 states, e.g. len(tracks[0].states) == 91
|
556 |
+
# each state has 10 attributes: center_x, center_y, center_z, length, ..., velocity_y, valid
|
557 |
+
for i, cur_data in enumerate(scenario.tracks):
|
558 |
+
|
559 |
+
step_state = []
|
560 |
+
step_valid = []
|
561 |
+
|
562 |
+
for s in cur_data.states: # n_steps
|
563 |
+
step_state.append(
|
564 |
+
[
|
565 |
+
s.center_x,
|
566 |
+
s.center_y,
|
567 |
+
s.center_z,
|
568 |
+
s.length,
|
569 |
+
s.width,
|
570 |
+
s.height,
|
571 |
+
s.heading,
|
572 |
+
s.velocity_x,
|
573 |
+
s.velocity_y,
|
574 |
+
]
|
575 |
+
)
|
576 |
+
step_valid.append(s.valid)
|
577 |
+
# This angle is normalized to [-pi, pi). The velocity vector in m/s
|
578 |
+
|
579 |
+
track_infos['object_id'].append(cur_data.id) # id of object in this track
|
580 |
+
track_infos['object_type'].append(cur_data.object_type - 1)
|
581 |
+
track_infos['states'].append(np.array(step_state, dtype=np.float32))
|
582 |
+
track_infos['valid'].append(np.array(step_valid))
|
583 |
+
|
584 |
+
track_infos['role'].append([False, False, False])
|
585 |
+
if i in track_index_predict:
|
586 |
+
track_infos['role'][-1][2] = True # predict=2
|
587 |
+
if cur_data.id in object_id_interest:
|
588 |
+
track_infos['role'][-1][1] = True # interest=1
|
589 |
+
if i == sdc_track_index:
|
590 |
+
track_infos['role'][-1][0] = True # ego_vehicle=0
|
591 |
+
|
592 |
+
track_infos['states'] = np.array(track_infos['states'], dtype=np.float32) # (n_agent, n_step, 9)
|
593 |
+
track_infos['valid'] = np.array(track_infos['valid'], dtype=np.bool_)
|
594 |
+
track_infos['role'] = np.array(track_infos['role'], dtype=np.bool_)
|
595 |
+
track_infos['object_id'] = np.array(track_infos['object_id'], dtype=np.int64)
|
596 |
+
track_infos['object_type'] = np.array(track_infos['object_type'], dtype=np.uint8)
|
597 |
+
track_infos['tracks_to_predict'] = np.array(track_index_predict, dtype=np.int64)
|
598 |
+
|
599 |
+
return track_infos
|
600 |
+
|
601 |
+
|
602 |
+
from collections import defaultdict
|
603 |
+
|
604 |
+
def decode_map_features_from_proto(map_features):
|
605 |
+
map_infos = {
|
606 |
+
'lane': [],
|
607 |
+
'road_line': [],
|
608 |
+
'road_edge': [],
|
609 |
+
'stop_sign': [],
|
610 |
+
'crosswalk': [],
|
611 |
+
'speed_bump': [],
|
612 |
+
'lane_dict': {},
|
613 |
+
'lane2other_dict': {}
|
614 |
+
}
|
615 |
+
polylines = []
|
616 |
+
|
617 |
+
point_cnt = 0
|
618 |
+
lane2other_dict = defaultdict(list)
|
619 |
+
|
620 |
+
for cur_data in map_features:
|
621 |
+
cur_info = {'id': cur_data.id}
|
622 |
+
|
623 |
+
if cur_data.lane.ByteSize() > 0:
|
624 |
+
cur_info['speed_limit_mph'] = cur_data.lane.speed_limit_mph
|
625 |
+
cur_info['type'] = cur_data.lane.type + 1 # 0: undefined, 1: freeway, 2: surface_street, 3: bike_lane
|
626 |
+
cur_info['left_neighbors'] = [lane.feature_id for lane in cur_data.lane.left_neighbors]
|
627 |
+
|
628 |
+
cur_info['right_neighbors'] = [lane.feature_id for lane in cur_data.lane.right_neighbors]
|
629 |
+
|
630 |
+
cur_info['interpolating'] = cur_data.lane.interpolating
|
631 |
+
cur_info['entry_lanes'] = list(cur_data.lane.entry_lanes)
|
632 |
+
cur_info['exit_lanes'] = list(cur_data.lane.exit_lanes)
|
633 |
+
|
634 |
+
cur_info['left_boundary_type'] = [x.boundary_type + 5 for x in cur_data.lane.left_boundaries]
|
635 |
+
cur_info['right_boundary_type'] = [x.boundary_type + 5 for x in cur_data.lane.right_boundaries]
|
636 |
+
|
637 |
+
cur_info['left_boundary'] = [x.boundary_feature_id for x in cur_data.lane.left_boundaries]
|
638 |
+
cur_info['right_boundary'] = [x.boundary_feature_id for x in cur_data.lane.right_boundaries]
|
639 |
+
cur_info['left_boundary_start_index'] = [lane.lane_start_index for lane in cur_data.lane.left_boundaries]
|
640 |
+
cur_info['left_boundary_end_index'] = [lane.lane_end_index for lane in cur_data.lane.left_boundaries]
|
641 |
+
cur_info['right_boundary_start_index'] = [lane.lane_start_index for lane in cur_data.lane.right_boundaries]
|
642 |
+
cur_info['right_boundary_end_index'] = [lane.lane_end_index for lane in cur_data.lane.right_boundaries]
|
643 |
+
|
644 |
+
lane2other_dict[cur_data.id].extend(cur_info['left_boundary'])
|
645 |
+
lane2other_dict[cur_data.id].extend(cur_info['right_boundary'])
|
646 |
+
|
647 |
+
global_type = cur_info['type']
|
648 |
+
cur_polyline = np.stack(
|
649 |
+
[np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in cur_data.lane.polyline],
|
650 |
+
axis=0)
|
651 |
+
cur_polyline = np.concatenate((cur_polyline[:, 0:3], cur_polyline[:, 3:]), axis=-1)
|
652 |
+
if cur_polyline.shape[0] <= 1:
|
653 |
+
continue
|
654 |
+
map_infos['lane'].append(cur_info)
|
655 |
+
map_infos['lane_dict'][cur_data.id] = cur_info
|
656 |
+
|
657 |
+
elif cur_data.road_line.ByteSize() > 0:
|
658 |
+
cur_info['type'] = cur_data.road_line.type + 5
|
659 |
+
|
660 |
+
global_type = cur_info['type']
|
661 |
+
cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in
|
662 |
+
cur_data.road_line.polyline], axis=0)
|
663 |
+
cur_polyline = np.concatenate((cur_polyline[:, 0:3], cur_polyline[:, 3:]), axis=-1)
|
664 |
+
if cur_polyline.shape[0] <= 1:
|
665 |
+
continue
|
666 |
+
map_infos['road_line'].append(cur_info) # (num_points, 5)
|
667 |
+
|
668 |
+
elif cur_data.road_edge.ByteSize() > 0:
|
669 |
+
cur_info['type'] = cur_data.road_edge.type + 14
|
670 |
+
|
671 |
+
global_type = cur_info['type']
|
672 |
+
cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in
|
673 |
+
cur_data.road_edge.polyline], axis=0)
|
674 |
+
cur_polyline = np.concatenate((cur_polyline[:, 0:3], cur_polyline[:, 3:]), axis=-1)
|
675 |
+
if cur_polyline.shape[0] <= 1:
|
676 |
+
continue
|
677 |
+
map_infos['road_edge'].append(cur_info)
|
678 |
+
|
679 |
+
elif cur_data.stop_sign.ByteSize() > 0:
|
680 |
+
cur_info['lane_ids'] = list(cur_data.stop_sign.lane)
|
681 |
+
for i in cur_info['lane_ids']:
|
682 |
+
lane2other_dict[i].append(cur_data.id)
|
683 |
+
point = cur_data.stop_sign.position
|
684 |
+
cur_info['position'] = np.array([point.x, point.y, point.z])
|
685 |
+
|
686 |
+
global_type = polyline_type['TYPE_STOP_SIGN']
|
687 |
+
cur_polyline = np.array([point.x, point.y, point.z, global_type, cur_data.id]).reshape(1, 5)
|
688 |
+
if cur_polyline.shape[0] <= 1:
|
689 |
+
continue
|
690 |
+
map_infos['stop_sign'].append(cur_info)
|
691 |
+
elif cur_data.crosswalk.ByteSize() > 0:
|
692 |
+
global_type = polyline_type['TYPE_CROSSWALK']
|
693 |
+
cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in
|
694 |
+
cur_data.crosswalk.polygon], axis=0)
|
695 |
+
cur_polyline = np.concatenate((cur_polyline[:, 0:3], cur_polyline[:, 3:]), axis=-1)
|
696 |
+
if cur_polyline.shape[0] <= 1:
|
697 |
+
continue
|
698 |
+
map_infos['crosswalk'].append(cur_info)
|
699 |
+
|
700 |
+
elif cur_data.speed_bump.ByteSize() > 0:
|
701 |
+
global_type = polyline_type['TYPE_SPEED_BUMP']
|
702 |
+
cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in
|
703 |
+
cur_data.speed_bump.polygon], axis=0)
|
704 |
+
cur_polyline = np.concatenate((cur_polyline[:, 0:3], cur_polyline[:, 3:]), axis=-1)
|
705 |
+
if cur_polyline.shape[0] <= 1:
|
706 |
+
continue
|
707 |
+
map_infos['speed_bump'].append(cur_info)
|
708 |
+
|
709 |
+
else:
|
710 |
+
continue
|
711 |
+
polylines.append(cur_polyline)
|
712 |
+
cur_info['polyline_index'] = (point_cnt, point_cnt + len(cur_polyline)) # (start index of point in current scenario, end index of point in current scenario)
|
713 |
+
point_cnt += len(cur_polyline)
|
714 |
+
|
715 |
+
polylines = np.concatenate(polylines, axis=0).astype(np.float32)
|
716 |
+
map_infos['all_polylines'] = polylines # (num_of_total_points_in_current_scenario, 5)
|
717 |
+
map_infos['lane2other_dict'] = lane2other_dict
|
718 |
+
return map_infos
|
719 |
+
|
720 |
+
|
721 |
+
def decode_dynamic_map_states_from_proto(dynamic_map_states):
|
722 |
+
|
723 |
+
signal_state = {
|
724 |
+
0: 'LANE_STATE_UNKNOWN',
|
725 |
+
# States for traffic signals with arrows.
|
726 |
+
1: 'LANE_STATE_ARROW_STOP',
|
727 |
+
2: 'LANE_STATE_ARROW_CAUTION',
|
728 |
+
3: 'LANE_STATE_ARROW_GO',
|
729 |
+
# Standard round traffic signals.
|
730 |
+
4: 'LANE_STATE_STOP',
|
731 |
+
5: 'LANE_STATE_CAUTION',
|
732 |
+
6: 'LANE_STATE_GO',
|
733 |
+
# Flashing light signals.
|
734 |
+
7: 'LANE_STATE_FLASHING_STOP',
|
735 |
+
8: 'LANE_STATE_FLASHING_CAUTION'
|
736 |
+
}
|
737 |
+
|
738 |
+
dynamic_map_infos = {
|
739 |
+
'lane_id': [],
|
740 |
+
'state': [],
|
741 |
+
'stop_point': []
|
742 |
+
}
|
743 |
+
for cur_data in dynamic_map_states: # len(dynamic_map_states) = num_timestamp
|
744 |
+
lane_id, state, stop_point = [], [], []
|
745 |
+
for cur_signal in cur_data.lane_states: # (num_observed_signals)
|
746 |
+
lane_id.append(cur_signal.lane)
|
747 |
+
state.append(signal_state[cur_signal.state])
|
748 |
+
stop_point.append([cur_signal.stop_point.x, cur_signal.stop_point.y, cur_signal.stop_point.z])
|
749 |
+
|
750 |
+
dynamic_map_infos['lane_id'].append(np.array([lane_id]))
|
751 |
+
dynamic_map_infos['state'].append(np.array([state]))
|
752 |
+
dynamic_map_infos['stop_point'].append(np.array([stop_point]))
|
753 |
+
|
754 |
+
return dynamic_map_infos
|
755 |
+
|
756 |
+
|
757 |
+
# def process_single_data(scenario):
|
758 |
+
# info = {}
|
759 |
+
# info['scenario_id'] = scenario.scenario_id
|
760 |
+
# info['timestamps_seconds'] = list(scenario.timestamps_seconds) # list of int of shape (91)
|
761 |
+
# info['current_time_index'] = scenario.current_time_index # int, 10
|
762 |
+
# info['sdc_track_index'] = scenario.sdc_track_index # int
|
763 |
+
# info['objects_of_interest'] = list(scenario.objects_of_interest) # list, could be empty list
|
764 |
+
|
765 |
+
# info['tracks_to_predict'] = {
|
766 |
+
# 'track_index': [cur_pred.track_index for cur_pred in scenario.tracks_to_predict],
|
767 |
+
# 'difficulty': [cur_pred.difficulty for cur_pred in scenario.tracks_to_predict]
|
768 |
+
# } # for training: suggestion of objects to train on, for val/test: need to be predicted
|
769 |
+
|
770 |
+
# # decode tracks data
|
771 |
+
# track_infos = decode_tracks_from_proto(scenario.tracks)
|
772 |
+
# info['tracks_to_predict']['object_type'] = [track_infos['object_type'][cur_idx] for cur_idx in
|
773 |
+
# info['tracks_to_predict']['track_index']]
|
774 |
+
# # decode map related data
|
775 |
+
# map_infos = decode_map_features_from_proto(scenario.map_features)
|
776 |
+
# dynamic_map_infos = decode_dynamic_map_states_from_proto(scenario.dynamic_map_states)
|
777 |
+
|
778 |
+
# save_infos = {
|
779 |
+
# 'track_infos': track_infos,
|
780 |
+
# 'map_infos': map_infos,
|
781 |
+
# 'dynamic_map_infos': dynamic_map_infos,
|
782 |
+
# }
|
783 |
+
# save_infos.update(info)
|
784 |
+
# return save_infos
|
785 |
+
|
786 |
+
|
787 |
+
def wm2argo(file, input_dir, output_dir, existing_files=[], output_dir_tfrecords_splitted=None):
|
788 |
+
file_path = os.path.join(input_dir, file)
|
789 |
+
dataset = tf.data.TFRecordDataset(file_path, compression_type='', num_parallel_reads=3)
|
790 |
+
|
791 |
+
for cnt, tf_data in tqdm(enumerate(dataset), leave=False, desc=f'Process {file}...'):
|
792 |
+
|
793 |
+
scenario = scenario_pb2.Scenario()
|
794 |
+
scenario.ParseFromString(bytearray(tf_data.numpy()))
|
795 |
+
scenario_id = scenario.scenario_id
|
796 |
+
tqdm.write(f"idx: {cnt}, scenario_id: {scenario_id} of {file}")
|
797 |
+
|
798 |
+
if f'{scenario_id}.pkl' not in existing_files:
|
799 |
+
|
800 |
+
map_infos = decode_map_features_from_proto(scenario.map_features)
|
801 |
+
track_infos = decode_tracks_from_proto(scenario)
|
802 |
+
dynamic_map_infos = decode_dynamic_map_states_from_proto(scenario.dynamic_map_states)
|
803 |
+
sdc_track_index = scenario.sdc_track_index # int
|
804 |
+
av_id = track_infos['object_id'][sdc_track_index]
|
805 |
+
# if len(track_infos['tracks_to_predict']) < 1:
|
806 |
+
# return
|
807 |
+
|
808 |
+
current_time_index = scenario.current_time_index
|
809 |
+
tf_lights = process_dynamic_map(dynamic_map_infos)
|
810 |
+
tf_current_light = tf_lights.loc[tf_lights["time_step"] == current_time_index] # 10 (history) + 1 (current) + 80 (future)
|
811 |
+
map_data = get_map_features(map_infos, tf_current_light)
|
812 |
+
|
813 |
+
# new_agents_array = process_agent(track_infos, tracks_to_predict, scenario_id, 0, 91) # mtr2argo
|
814 |
+
data = dict()
|
815 |
+
data.update(map_data)
|
816 |
+
data['scenario_id'] = scenario_id
|
817 |
+
data['agent'] = get_agent_features(track_infos, av_id, num_historical_steps=current_time_index + 1, num_steps=91)
|
818 |
+
|
819 |
+
with open(os.path.join(output_dir, f'{scenario_id}.pkl'), "wb+") as f:
|
820 |
+
pickle.dump(data, f)
|
821 |
+
|
822 |
+
if output_dir_tfrecords_splitted is not None:
|
823 |
+
tf_file = os.path.join(output_dir_tfrecords_splitted, f'{scenario_id}.tfrecords')
|
824 |
+
if not os.path.exists(tf_file):
|
825 |
+
with tf.io.TFRecordWriter(tf_file) as file_writer:
|
826 |
+
file_writer.write(tf_data.numpy())
|
827 |
+
|
828 |
+
|
829 |
+
def batch_process9s_transformer(input_dir, output_dir, split, num_workers=2):
|
830 |
+
signal.signal(signal.SIGINT, signal.SIG_IGN)
|
831 |
+
|
832 |
+
output_dir_tfrecords_splitted = None
|
833 |
+
if split == "validation":
|
834 |
+
output_dir_tfrecords_splitted = os.path.join(output_dir, 'validation_tfrecords_splitted')
|
835 |
+
os.makedirs(output_dir_tfrecords_splitted, exist_ok=True)
|
836 |
+
|
837 |
+
input_dir = os.path.join(input_dir, split)
|
838 |
+
output_dir = os.path.join(output_dir, split)
|
839 |
+
os.makedirs(output_dir, exist_ok=True)
|
840 |
+
|
841 |
+
packages = sorted(os.listdir(input_dir))
|
842 |
+
existing_files = sorted(os.listdir(output_dir))
|
843 |
+
func = partial(
|
844 |
+
wm2argo,
|
845 |
+
output_dir=output_dir,
|
846 |
+
input_dir=input_dir,
|
847 |
+
existing_files=existing_files,
|
848 |
+
output_dir_tfrecords_splitted=output_dir_tfrecords_splitted
|
849 |
+
)
|
850 |
+
try:
|
851 |
+
with multiprocessing.Pool(num_workers, maxtasksperchild=10) as p:
|
852 |
+
r = list(tqdm(p.imap_unordered(func, packages), total=len(packages)))
|
853 |
+
except KeyboardInterrupt:
|
854 |
+
p.terminate()
|
855 |
+
p.join()
|
856 |
+
|
857 |
+
|
858 |
+
def generate_meta_infos(data_dir):
|
859 |
+
import json
|
860 |
+
|
861 |
+
meta_infos = dict()
|
862 |
+
|
863 |
+
for split in tqdm(['training', 'validation', 'test'], leave=False):
|
864 |
+
if not os.path.exists(os.path.join(data_dir, split)):
|
865 |
+
continue
|
866 |
+
|
867 |
+
split_infos = dict()
|
868 |
+
files = os.listdir(os.path.join(data_dir, split))
|
869 |
+
for file in tqdm(files, leave=False):
|
870 |
+
try:
|
871 |
+
data = pickle.load(open(os.path.join(data_dir, split, file), 'rb'))
|
872 |
+
except Exception as e:
|
873 |
+
tqdm.write(f'Failed to load scenario {file} due to {e}')
|
874 |
+
continue
|
875 |
+
scenario_infos = dict(num_agents=data['agent']['num_nodes'])
|
876 |
+
scenario_id = data['scenario_id']
|
877 |
+
split_infos[scenario_id] = scenario_infos
|
878 |
+
|
879 |
+
meta_infos[split] = split_infos
|
880 |
+
|
881 |
+
with open(os.path.join(data_dir, 'meta_infos.json'), 'w', encoding='utf-8') as f:
|
882 |
+
json.dump(meta_infos, f, indent=4)
|
883 |
+
|
884 |
+
|
885 |
+
if __name__ == "__main__":
|
886 |
+
parser = ArgumentParser()
|
887 |
+
parser.add_argument('--input_dir', type=str, default='data/waymo/')
|
888 |
+
parser.add_argument('--output_dir', type=str, default='data/waymo_processed/')
|
889 |
+
parser.add_argument('--split', type=str, default='validation')
|
890 |
+
parser.add_argument('--no_batch', action='store_true')
|
891 |
+
parser.add_argument('--disable_invalid', action="store_true")
|
892 |
+
parser.add_argument('--generate_meta_infos', action="store_true")
|
893 |
+
args = parser.parse_args()
|
894 |
+
|
895 |
+
if args.generate_meta_infos:
|
896 |
+
generate_meta_infos(args.output_dir)
|
897 |
+
|
898 |
+
elif args.no_batch:
|
899 |
+
|
900 |
+
output_dir_tfrecords_splitted = None
|
901 |
+
if args.split == "validation":
|
902 |
+
output_dir_tfrecords_splitted = os.path.join(args.output_dir, 'validation_tfrecords_splitted')
|
903 |
+
os.makedirs(output_dir_tfrecords_splitted, exist_ok=True)
|
904 |
+
|
905 |
+
input_dir = os.path.join(args.input_dir, args.split)
|
906 |
+
output_dir = os.path.join(args.output_dir, args.split)
|
907 |
+
os.makedirs(output_dir, exist_ok=True)
|
908 |
+
|
909 |
+
files = sorted(os.listdir(input_dir))
|
910 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
911 |
+
for file in tqdm(files, leave=False, desc=f'Process {args.split}...'):
|
912 |
+
wm2argo(file, input_dir, output_dir, output_dir_tfrecords_splitted)
|
913 |
+
|
914 |
+
else:
|
915 |
+
|
916 |
+
batch_process9s_transformer(args.input_dir, args.output_dir, args.split, num_workers=96)
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/datasets/preprocess.py
ADDED
@@ -0,0 +1,761 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import math
|
4 |
+
import pickle
|
5 |
+
import numpy as np
|
6 |
+
from torch import nn
|
7 |
+
from typing import Dict, Sequence
|
8 |
+
from scipy.interpolate import interp1d
|
9 |
+
from scipy.spatial.distance import euclidean
|
10 |
+
from dev.utils.func import wrap_angle
|
11 |
+
|
12 |
+
|
13 |
+
SHIFT = 5
|
14 |
+
AGENT_SHAPE = {
|
15 |
+
'vehicle': [4.3, 1.8, 1.],
|
16 |
+
'pedstrain': [0.5, 0.5, 1.],
|
17 |
+
'cyclist': [1.9, 0.5, 1.],
|
18 |
+
}
|
19 |
+
AGENT_TYPE = ['veh', 'ped', 'cyc', 'seed']
|
20 |
+
AGENT_STATE = ['invalid', 'valid', 'enter', 'exit']
|
21 |
+
|
22 |
+
|
23 |
+
@torch.no_grad()
|
24 |
+
def cal_polygon_contour(pos, head, width_length) -> torch.Tensor: # [n_agent, n_step, n_target, 4, 2]
|
25 |
+
x, y = pos[..., 0], pos[..., 1] # [n_agent, n_step, n_target]
|
26 |
+
width, length = width_length[..., 0], width_length[..., 1] # [n_agent, 1, 1]
|
27 |
+
|
28 |
+
half_cos = 0.5 * head.cos() # [n_agent, n_step, n_target]
|
29 |
+
half_sin = 0.5 * head.sin() # [n_agent, n_step, n_target]
|
30 |
+
length_cos = length * half_cos # [n_agent, n_step, n_target]
|
31 |
+
length_sin = length * half_sin # [n_agent, n_step, n_target]
|
32 |
+
width_cos = width * half_cos # [n_agent, n_step, n_target]
|
33 |
+
width_sin = width * half_sin # [n_agent, n_step, n_target]
|
34 |
+
|
35 |
+
left_front_x = x + length_cos - width_sin
|
36 |
+
left_front_y = y + length_sin + width_cos
|
37 |
+
left_front = torch.stack((left_front_x, left_front_y), dim=-1)
|
38 |
+
|
39 |
+
right_front_x = x + length_cos + width_sin
|
40 |
+
right_front_y = y + length_sin - width_cos
|
41 |
+
right_front = torch.stack((right_front_x, right_front_y), dim=-1)
|
42 |
+
|
43 |
+
right_back_x = x - length_cos + width_sin
|
44 |
+
right_back_y = y - length_sin - width_cos
|
45 |
+
right_back = torch.stack((right_back_x, right_back_y), dim=-1)
|
46 |
+
|
47 |
+
left_back_x = x - length_cos - width_sin
|
48 |
+
left_back_y = y - length_sin + width_cos
|
49 |
+
left_back = torch.stack((left_back_x, left_back_y), dim=-1)
|
50 |
+
|
51 |
+
polygon_contour = torch.stack(
|
52 |
+
(left_front, right_front, right_back, left_back), dim=-2
|
53 |
+
)
|
54 |
+
|
55 |
+
return polygon_contour
|
56 |
+
|
57 |
+
|
58 |
+
def interplating_polyline(polylines, heading, distance=0.5, split_distace=5):
|
59 |
+
# Calculate the cumulative distance along the path, up-sample the polyline to 0.5 meter
|
60 |
+
dist_along_path_list = [[0]]
|
61 |
+
polylines_list = [[polylines[0]]]
|
62 |
+
for i in range(1, polylines.shape[0]):
|
63 |
+
euclidean_dist = euclidean(polylines[i, :2], polylines[i - 1, :2])
|
64 |
+
heading_diff = min(abs(max(heading[i], heading[i - 1]) - min(heading[1], heading[i - 1])),
|
65 |
+
abs(max(heading[i], heading[i - 1]) - min(heading[1], heading[i - 1]) + math.pi))
|
66 |
+
if heading_diff > math.pi / 4 and euclidean_dist > 3:
|
67 |
+
dist_along_path_list.append([0])
|
68 |
+
polylines_list.append([polylines[i]])
|
69 |
+
elif heading_diff > math.pi / 8 and euclidean_dist > 3:
|
70 |
+
dist_along_path_list.append([0])
|
71 |
+
polylines_list.append([polylines[i]])
|
72 |
+
elif heading_diff > 0.1 and euclidean_dist > 3:
|
73 |
+
dist_along_path_list.append([0])
|
74 |
+
polylines_list.append([polylines[i]])
|
75 |
+
elif euclidean_dist > 10:
|
76 |
+
dist_along_path_list.append([0])
|
77 |
+
polylines_list.append([polylines[i]])
|
78 |
+
else:
|
79 |
+
dist_along_path_list[-1].append(dist_along_path_list[-1][-1] + euclidean_dist)
|
80 |
+
polylines_list[-1].append(polylines[i])
|
81 |
+
# plt.plot(polylines[:, 0], polylines[:, 1])
|
82 |
+
# plt.savefig('tmp.jpg')
|
83 |
+
new_x_list = []
|
84 |
+
new_y_list = []
|
85 |
+
multi_polylines_list = []
|
86 |
+
for idx in range(len(dist_along_path_list)):
|
87 |
+
if len(dist_along_path_list[idx]) < 2:
|
88 |
+
continue
|
89 |
+
dist_along_path = np.array(dist_along_path_list[idx])
|
90 |
+
polylines_cur = np.array(polylines_list[idx])
|
91 |
+
# Create interpolation functions for x and y coordinates
|
92 |
+
fx = interp1d(dist_along_path, polylines_cur[:, 0])
|
93 |
+
fy = interp1d(dist_along_path, polylines_cur[:, 1])
|
94 |
+
# fyaw = interp1d(dist_along_path, heading)
|
95 |
+
|
96 |
+
# Create an array of distances at which to interpolate
|
97 |
+
new_dist_along_path = np.arange(0, dist_along_path[-1], distance)
|
98 |
+
new_dist_along_path = np.concatenate([new_dist_along_path, dist_along_path[[-1]]])
|
99 |
+
# Use the interpolation functions to generate new x and y coordinates
|
100 |
+
new_x = fx(new_dist_along_path)
|
101 |
+
new_y = fy(new_dist_along_path)
|
102 |
+
# new_yaw = fyaw(new_dist_along_path)
|
103 |
+
new_x_list.append(new_x)
|
104 |
+
new_y_list.append(new_y)
|
105 |
+
|
106 |
+
# Combine the new x and y coordinates into a single array
|
107 |
+
new_polylines = np.vstack((new_x, new_y)).T
|
108 |
+
polyline_size = int(split_distace / distance)
|
109 |
+
if new_polylines.shape[0] >= (polyline_size + 1):
|
110 |
+
padding_size = (new_polylines.shape[0] - (polyline_size + 1)) % polyline_size
|
111 |
+
final_index = (new_polylines.shape[0] - (polyline_size + 1)) // polyline_size + 1
|
112 |
+
else:
|
113 |
+
padding_size = new_polylines.shape[0]
|
114 |
+
final_index = 0
|
115 |
+
multi_polylines = None
|
116 |
+
new_polylines = torch.from_numpy(new_polylines)
|
117 |
+
new_heading = torch.atan2(new_polylines[1:, 1] - new_polylines[:-1, 1],
|
118 |
+
new_polylines[1:, 0] - new_polylines[:-1, 0])
|
119 |
+
new_heading = torch.cat([new_heading, new_heading[-1:]], -1)[..., None]
|
120 |
+
new_polylines = torch.cat([new_polylines, new_heading], -1)
|
121 |
+
if new_polylines.shape[0] >= (polyline_size + 1):
|
122 |
+
multi_polylines = new_polylines.unfold(dimension=0, size=polyline_size + 1, step=polyline_size)
|
123 |
+
multi_polylines = multi_polylines.transpose(1, 2)
|
124 |
+
multi_polylines = multi_polylines[:, ::5, :]
|
125 |
+
if padding_size >= 3:
|
126 |
+
last_polyline = new_polylines[final_index * polyline_size:]
|
127 |
+
last_polyline = last_polyline[torch.linspace(0, last_polyline.shape[0] - 1, steps=3).long()]
|
128 |
+
if multi_polylines is not None:
|
129 |
+
multi_polylines = torch.cat([multi_polylines, last_polyline.unsqueeze(0)], dim=0)
|
130 |
+
else:
|
131 |
+
multi_polylines = last_polyline.unsqueeze(0)
|
132 |
+
if multi_polylines is None:
|
133 |
+
continue
|
134 |
+
multi_polylines_list.append(multi_polylines)
|
135 |
+
if len(multi_polylines_list) > 0:
|
136 |
+
multi_polylines_list = torch.cat(multi_polylines_list, dim=0)
|
137 |
+
else:
|
138 |
+
multi_polylines_list = None
|
139 |
+
return multi_polylines_list
|
140 |
+
|
141 |
+
|
142 |
+
# def interplating_polyline(polylines, heading, distance=0.5, split_distance=5, device='cpu'):
|
143 |
+
# dist_along_path_list = [[0]]
|
144 |
+
# polylines_list = [[polylines[0]]]
|
145 |
+
# for i in range(1, polylines.shape[0]):
|
146 |
+
# euclidean_dist = torch.norm(polylines[i, :2] - polylines[i - 1, :2])
|
147 |
+
# heading_diff = min(abs(max(heading[i], heading[i - 1]) - min(heading[1], heading[i - 1])),
|
148 |
+
# abs(max(heading[i], heading[i - 1]) - min(heading[1], heading[i - 1]) + torch.pi))
|
149 |
+
# if heading_diff > torch.pi / 4 and euclidean_dist > 3:
|
150 |
+
# dist_along_path_list.append([0])
|
151 |
+
# polylines_list.append([polylines[i]])
|
152 |
+
# elif heading_diff > torch.pi / 8 and euclidean_dist > 3:
|
153 |
+
# dist_along_path_list.append([0])
|
154 |
+
# polylines_list.append([polylines[i]])
|
155 |
+
# elif heading_diff > 0.1 and euclidean_dist > 3:
|
156 |
+
# dist_along_path_list.append([0])
|
157 |
+
# polylines_list.append([polylines[i]])
|
158 |
+
# elif euclidean_dist > 10:
|
159 |
+
# dist_along_path_list.append([0])
|
160 |
+
# polylines_list.append([polylines[i]])
|
161 |
+
# else:
|
162 |
+
# dist_along_path_list[-1].append(dist_along_path_list[-1][-1] + euclidean_dist)
|
163 |
+
# polylines_list[-1].append(polylines[i])
|
164 |
+
|
165 |
+
# new_x_list = []
|
166 |
+
# new_y_list = []
|
167 |
+
# multi_polylines_list = []
|
168 |
+
|
169 |
+
# for idx in range(len(dist_along_path_list)):
|
170 |
+
# if len(dist_along_path_list[idx]) < 2:
|
171 |
+
# continue
|
172 |
+
|
173 |
+
# dist_along_path = torch.tensor(dist_along_path_list[idx], device=device)
|
174 |
+
# polylines_cur = torch.stack(polylines_list[idx])
|
175 |
+
|
176 |
+
# new_dist_along_path = torch.arange(0, dist_along_path[-1], distance)
|
177 |
+
# new_dist_along_path = torch.cat([new_dist_along_path, dist_along_path[[-1]]])
|
178 |
+
|
179 |
+
# new_x = torch.interp(new_dist_along_path, dist_along_path, polylines_cur[:, 0])
|
180 |
+
# new_y = torch.interp(new_dist_along_path, dist_along_path, polylines_cur[:, 1])
|
181 |
+
|
182 |
+
# new_x_list.append(new_x)
|
183 |
+
# new_y_list.append(new_y)
|
184 |
+
|
185 |
+
# new_polylines = torch.stack((new_x, new_y), dim=-1)
|
186 |
+
|
187 |
+
# polyline_size = int(split_distance / distance)
|
188 |
+
# if new_polylines.shape[0] >= (polyline_size + 1):
|
189 |
+
# padding_size = (new_polylines.shape[0] - (polyline_size + 1)) % polyline_size
|
190 |
+
# final_index = (new_polylines.shape[0] - (polyline_size + 1)) // polyline_size + 1
|
191 |
+
# else:
|
192 |
+
# padding_size = new_polylines.shape[0]
|
193 |
+
# final_index = 0
|
194 |
+
|
195 |
+
# multi_polylines = None
|
196 |
+
# new_heading = torch.atan2(new_polylines[1:, 1] - new_polylines[:-1, 1],
|
197 |
+
# new_polylines[1:, 0] - new_polylines[:-1, 0])
|
198 |
+
# new_heading = torch.cat([new_heading, new_heading[-1:]], -1)[..., None]
|
199 |
+
# new_polylines = torch.cat([new_polylines, new_heading], -1)
|
200 |
+
|
201 |
+
# if new_polylines.shape[0] >= (polyline_size + 1):
|
202 |
+
# multi_polylines = new_polylines.unfold(dimension=0, size=polyline_size + 1, step=polyline_size)
|
203 |
+
# multi_polylines = multi_polylines.transpose(1, 2)
|
204 |
+
# multi_polylines = multi_polylines[:, ::5, :]
|
205 |
+
|
206 |
+
# if padding_size >= 3:
|
207 |
+
# last_polyline = new_polylines[final_index * polyline_size:]
|
208 |
+
# last_polyline = last_polyline[torch.linspace(0, last_polyline.shape[0] - 1, steps=3).long()]
|
209 |
+
# if multi_polylines is not None:
|
210 |
+
# multi_polylines = torch.cat([multi_polylines, last_polyline.unsqueeze(0)], dim=0)
|
211 |
+
# else:
|
212 |
+
# multi_polylines = last_polyline.unsqueeze(0)
|
213 |
+
|
214 |
+
# if multi_polylines is None:
|
215 |
+
# continue
|
216 |
+
# multi_polylines_list.append(multi_polylines)
|
217 |
+
|
218 |
+
# if len(multi_polylines_list) > 0:
|
219 |
+
# multi_polylines_list = torch.cat(multi_polylines_list, dim=0)
|
220 |
+
# else:
|
221 |
+
# multi_polylines_list = None
|
222 |
+
|
223 |
+
# return multi_polylines_list
|
224 |
+
|
225 |
+
|
226 |
+
def average_distance_vectorized(point_set1, centroids):
|
227 |
+
dists = np.sqrt(np.sum((point_set1[:, None, :, :] - centroids[None, :, :, :]) ** 2, axis=-1))
|
228 |
+
return np.mean(dists, axis=2)
|
229 |
+
|
230 |
+
|
231 |
+
def assign_clusters(sub_X, centroids):
|
232 |
+
distances = average_distance_vectorized(sub_X, centroids)
|
233 |
+
return np.argmin(distances, axis=1)
|
234 |
+
|
235 |
+
|
236 |
+
class TokenProcessor(nn.Module):
|
237 |
+
|
238 |
+
def __init__(self, token_size,
|
239 |
+
training: bool=False,
|
240 |
+
predict_motion: bool=False,
|
241 |
+
predict_state: bool=False,
|
242 |
+
predict_map: bool=False,
|
243 |
+
state_token: Dict[str, int]=None, **kwargs):
|
244 |
+
super().__init__()
|
245 |
+
|
246 |
+
module_dir = os.path.dirname(os.path.dirname(__file__))
|
247 |
+
self.agent_token_path = os.path.join(module_dir, f'tokens/agent_vocab_555_s2.pkl')
|
248 |
+
self.map_token_traj_path = os.path.join(module_dir, 'tokens/map_traj_token5.pkl')
|
249 |
+
assert os.path.exists(self.agent_token_path), f"File {self.agent_token_path} not found."
|
250 |
+
assert os.path.exists(self.map_token_traj_path), f"File {self.map_token_traj_path} not found."
|
251 |
+
|
252 |
+
self.training = training
|
253 |
+
self.token_size = token_size
|
254 |
+
self.disable_invalid = not predict_state
|
255 |
+
self.predict_motion = predict_motion
|
256 |
+
self.predict_state = predict_state
|
257 |
+
self.predict_map = predict_map
|
258 |
+
|
259 |
+
# define new special tokens
|
260 |
+
self.bos_token_index = token_size
|
261 |
+
self.eos_token_index = token_size + 1
|
262 |
+
self.invalid_token_index = token_size + 2
|
263 |
+
self.special_token_index = []
|
264 |
+
self._init_token()
|
265 |
+
|
266 |
+
# define agent states
|
267 |
+
self.invalid_state = int(state_token['invalid'])
|
268 |
+
self.valid_state = int(state_token['valid'])
|
269 |
+
self.enter_state = int(state_token['enter'])
|
270 |
+
self.exit_state = int(state_token['exit'])
|
271 |
+
|
272 |
+
self.pl2seed_radius = kwargs.get('pl2seed_radius', None)
|
273 |
+
|
274 |
+
self.noise = False
|
275 |
+
self.disturb = False
|
276 |
+
self.shift = 5
|
277 |
+
self.training = False
|
278 |
+
self.current_step = 10
|
279 |
+
|
280 |
+
# debugging
|
281 |
+
self.debug_data = None
|
282 |
+
|
283 |
+
def forward(self, data):
|
284 |
+
"""
|
285 |
+
Each pkl data represents a extracted scenario from raw tfrecord data
|
286 |
+
"""
|
287 |
+
data['agent']['av_index'] = data['agent']['av_idx']
|
288 |
+
data = self._tokenize_agent(data)
|
289 |
+
# data = self._tokenize_map(data)
|
290 |
+
del data['city']
|
291 |
+
if 'polygon_is_intersection' in data['map_polygon']:
|
292 |
+
del data['map_polygon']['polygon_is_intersection']
|
293 |
+
if 'route_type' in data['map_polygon']:
|
294 |
+
del data['map_polygon']['route_type']
|
295 |
+
|
296 |
+
av_index = int(data['agent']['av_idx'])
|
297 |
+
data['ego_pos'] = data['agent']['token_pos'][[av_index]]
|
298 |
+
data['ego_heading'] = data['agent']['token_heading'][[av_index]]
|
299 |
+
|
300 |
+
return data
|
301 |
+
|
302 |
+
def _init_token(self):
|
303 |
+
|
304 |
+
agent_token_data = pickle.load(open(self.agent_token_path, 'rb'))
|
305 |
+
for agent_type, token in agent_token_data['token_all'].items():
|
306 |
+
token = torch.tensor(token, dtype=torch.float32)
|
307 |
+
self.register_buffer(f'agent_token_all_{agent_type}', token, persistent=False) # [n_token, 6, 4, 2]
|
308 |
+
|
309 |
+
map_token_traj = pickle.load(open(self.map_token_traj_path, 'rb'))['traj_src']
|
310 |
+
map_token_traj = torch.tensor(map_token_traj, dtype=torch.float32)
|
311 |
+
self.register_buffer('map_token_traj_src', map_token_traj, persistent=False) # [n_token, 11 * 2]
|
312 |
+
|
313 |
+
# self.trajectory_token = agent_token_data['token'] # (token_size, 4, 2)
|
314 |
+
# self.trajectory_token_all = agent_token_data['token_all'] # (token_size, shift + 1, 4, 2)
|
315 |
+
# self.map_token = {'traj_src': map_token_traj['traj_src']}
|
316 |
+
|
317 |
+
@staticmethod
|
318 |
+
def clean_heading(valid: torch.Tensor, heading: torch.Tensor) -> torch.Tensor:
|
319 |
+
valid_pairs = valid[:, :-1] & valid[:, 1:]
|
320 |
+
for i in range(heading.shape[1] - 1):
|
321 |
+
heading_diff = torch.abs(wrap_angle(heading[:, i] - heading[:, i + 1]))
|
322 |
+
change_needed = (heading_diff > 1.5) & valid_pairs[:, i]
|
323 |
+
heading[:, i + 1][change_needed] = heading[:, i][change_needed]
|
324 |
+
return heading
|
325 |
+
|
326 |
+
def _extrapolate_agent_to_prev_token_step(self, valid, pos, heading, vel) -> Sequence[torch.Tensor]:
|
327 |
+
# [n_agent], max will give the first True step
|
328 |
+
first_valid_step = torch.max(valid, dim=1).indices
|
329 |
+
|
330 |
+
for i, t in enumerate(first_valid_step): # extrapolate to previous 5th step.
|
331 |
+
n_step_to_extrapolate = t % self.shift
|
332 |
+
if (t == self.current_step) and (not valid[i, self.current_step - self.shift]):
|
333 |
+
# such that at least one token is valid in the history.
|
334 |
+
n_step_to_extrapolate = self.shift
|
335 |
+
|
336 |
+
if n_step_to_extrapolate > 0:
|
337 |
+
vel[i, t - n_step_to_extrapolate : t] = vel[i, t]
|
338 |
+
valid[i, t - n_step_to_extrapolate : t] = True
|
339 |
+
heading[i, t - n_step_to_extrapolate : t] = heading[i, t]
|
340 |
+
|
341 |
+
for j in range(n_step_to_extrapolate):
|
342 |
+
pos[i, t - j - 1] = pos[i, t - j] - vel[i, t] * 0.1
|
343 |
+
|
344 |
+
return valid, pos, heading, vel
|
345 |
+
|
346 |
+
def _get_agent_shape(self, agent_type_masks: dict) -> torch.Tensor:
|
347 |
+
agent_shape = 0.
|
348 |
+
for type, type_mask in agent_type_masks.items():
|
349 |
+
if type == 'veh': width = 2.; length = 4.8
|
350 |
+
if type == 'ped': width = 1.; length = 2.
|
351 |
+
if type == 'cyc': width = 1.; length = 1.
|
352 |
+
agent_shape += torch.stack([width * type_mask, length * type_mask], dim=-1)
|
353 |
+
|
354 |
+
return agent_shape
|
355 |
+
|
356 |
+
def _get_token_traj_all(self, agent_type_masks: dict) -> torch.Tensor:
|
357 |
+
token_traj_all = 0.
|
358 |
+
for type, type_mask in agent_type_masks.items():
|
359 |
+
token_traj_all += type_mask[:, None, None, None, None] * (
|
360 |
+
getattr(self, f'agent_token_all_{type}').unsqueeze(0)
|
361 |
+
)
|
362 |
+
return token_traj_all
|
363 |
+
|
364 |
+
def _tokenize_agent(self, data):
|
365 |
+
|
366 |
+
# get raw data
|
367 |
+
valid_mask = data['agent']['valid_mask'] # [n_agent, n_step]
|
368 |
+
agent_heading = data['agent']['heading'] # [n_agent, n_step]
|
369 |
+
agent_pos = data['agent']['position'][..., :2].contiguous() # [n_agent, n_step, 2]
|
370 |
+
agent_vel = data['agent']['velocity'] # [n_agent, n_step, 2]
|
371 |
+
agent_type = data['agent']['type']
|
372 |
+
agent_category = data['agent']['category']
|
373 |
+
|
374 |
+
n_agent, n_all_step = valid_mask.shape
|
375 |
+
|
376 |
+
agent_type_masks = {
|
377 |
+
"veh": agent_type == 0,
|
378 |
+
"ped": agent_type == 1,
|
379 |
+
"cyc": agent_type == 2,
|
380 |
+
}
|
381 |
+
agent_heading = self.clean_heading(valid_mask, agent_heading)
|
382 |
+
agent_shape = self._get_agent_shape(agent_type_masks)
|
383 |
+
token_traj_all = self._get_token_traj_all(agent_type_masks)
|
384 |
+
valid_mask, agent_pos, agent_heading, agent_vel = self._extrapolate_agent_to_prev_token_step(
|
385 |
+
valid_mask, agent_pos, agent_heading, agent_vel
|
386 |
+
)
|
387 |
+
token_traj = token_traj_all[:, :, -1, ...]
|
388 |
+
data['agent']['token_traj_all'] = token_traj_all # [n_agent, n_token, 6, 4, 2]
|
389 |
+
data['agent']['token_traj'] = token_traj # [n_agent, n_token, 4, 2]
|
390 |
+
|
391 |
+
valid_mask_shift = valid_mask.unfold(1, self.shift + 1, self.shift)
|
392 |
+
token_valid_mask = valid_mask_shift[:, :, 0] * valid_mask_shift[:, :, -1]
|
393 |
+
|
394 |
+
# vehicle_mask = agent_type == 0
|
395 |
+
# cyclist_mask = agent_type == 2
|
396 |
+
# ped_mask = agent_type == 1
|
397 |
+
|
398 |
+
# veh_pos = agent_pos[vehicle_mask, :, :]
|
399 |
+
# veh_valid_mask = valid_mask[vehicle_mask, :]
|
400 |
+
# cyc_pos = agent_pos[cyclist_mask, :, :]
|
401 |
+
# cyc_valid_mask = valid_mask[cyclist_mask, :]
|
402 |
+
# ped_pos = agent_pos[ped_mask, :, :]
|
403 |
+
# ped_valid_mask = valid_mask[ped_mask, :]
|
404 |
+
|
405 |
+
# index: [n_agent, n_step] contour: [n_agent, n_step, 4, 2]
|
406 |
+
token_index, token_contour, token_all = self._match_agent_token(
|
407 |
+
valid_mask, agent_pos, agent_heading, agent_shape, token_traj, None # token_traj_all
|
408 |
+
)
|
409 |
+
|
410 |
+
traj_pos = traj_heading = None
|
411 |
+
if len(token_all) > 0:
|
412 |
+
traj_pos = token_all.mean(dim=3) # [n_agent, n_step, 6, 2]
|
413 |
+
diff_xy = token_all[..., 0, :] - token_all[..., 3, :]
|
414 |
+
traj_heading = torch.arctan2(diff_xy[..., 1], diff_xy[..., 0])
|
415 |
+
token_pos = token_contour.mean(dim=2) # [n_agent, n_step, 2]
|
416 |
+
diff_xy = token_contour[:, :, 0, :] - token_contour[:, :, 3, :]
|
417 |
+
token_heading = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0])
|
418 |
+
|
419 |
+
# token_index: (num_agent, num_timestep // shift) e.g. (49, 18)
|
420 |
+
# token_contour: (num_agent, num_timestep // shift, contour_dim, feat_dim, 2) e.g. (49, 18, 4, 2)
|
421 |
+
# veh_token_index, veh_token_contour = self._match_agent_token(veh_valid_mask, veh_pos, agent_heading[vehicle_mask],
|
422 |
+
# 'veh', agent_shape[vehicle_mask])
|
423 |
+
# ped_token_index, ped_token_contour = self._match_agent_token(ped_valid_mask, ped_pos, agent_heading[ped_mask],
|
424 |
+
# 'ped', agent_shape[ped_mask])
|
425 |
+
# cyc_token_index, cyc_token_contour = self._match_agent_token(cyc_valid_mask, cyc_pos, agent_heading[cyclist_mask],
|
426 |
+
# 'cyc', agent_shape[cyclist_mask])
|
427 |
+
|
428 |
+
# token_index = torch.zeros((agent_pos.shape[0], veh_token_index.shape[1])).to(torch.int64)
|
429 |
+
# token_index[vehicle_mask] = veh_token_index
|
430 |
+
# token_index[ped_mask] = ped_token_index
|
431 |
+
# token_index[cyclist_mask] = cyc_token_index
|
432 |
+
|
433 |
+
# ! compute agent states
|
434 |
+
bos_index = torch.argmax(token_valid_mask.long(), dim=1)
|
435 |
+
eos_index = token_valid_mask.shape[1] - 1 - torch.argmax(torch.flip(token_valid_mask.long(), dims=[1]), dim=1)
|
436 |
+
state_index = torch.ones_like(token_index) # init with all valid
|
437 |
+
step_index = torch.arange(state_index.shape[1])[None].repeat(state_index.shape[0], 1).to(token_index.device)
|
438 |
+
state_index[step_index == bos_index[:, None]] = self.enter_state
|
439 |
+
state_index[step_index == eos_index[:, None]] = self.exit_state
|
440 |
+
state_index[(step_index < bos_index[:, None]) | (step_index > eos_index[:, None])] = self.invalid_state
|
441 |
+
# ! IMPORTANT: if the last step is exit token, should convert it back to valid token
|
442 |
+
state_index[state_index[:, -1] == self.exit_state, -1] = self.valid_state
|
443 |
+
|
444 |
+
# update token attributions according to state tokens
|
445 |
+
token_valid_mask[state_index == self.enter_state] = False
|
446 |
+
token_pos[state_index == self.invalid_state] = 0.
|
447 |
+
token_heading[state_index == self.invalid_state] = 0.
|
448 |
+
for i in range(self.shift, agent_pos.shape[1], self.shift):
|
449 |
+
is_bos = state_index[:, i // self.shift - 1] == self.enter_state
|
450 |
+
token_pos[is_bos, i // self.shift - 1] = agent_pos[is_bos, i].clone()
|
451 |
+
# token_heading[is_bos, i // self.shift - 1] = agent_heading[is_bos, i].clone()
|
452 |
+
token_index[state_index == self.invalid_state] = -1
|
453 |
+
token_index[state_index == self.enter_state] = -2
|
454 |
+
|
455 |
+
# acc_token_valid_step = torch.concat([torch.zeros_like(token_valid_mask[:, :1]),
|
456 |
+
# torch.cumsum(token_valid_mask.int(), dim=1),
|
457 |
+
# torch.zeros_like(token_valid_mask[:, -1:])], dim=1)
|
458 |
+
# state_index = torch.ones_like(token_index) # init with all valid
|
459 |
+
# max_valid_index = torch.argmax(acc_token_valid_step, dim=1)
|
460 |
+
# for step in range(1, acc_token_valid_step.shape[1] - 1):
|
461 |
+
|
462 |
+
# # replace part of motion tokens with special tokens
|
463 |
+
# is_bos = (acc_token_valid_step[:, step] == 0) & (acc_token_valid_step[:, step + 1] == 1)
|
464 |
+
# is_eos = (step == max_valid_index) & (step < acc_token_valid_step.shape[1] - 2) & ~is_bos
|
465 |
+
# is_invalid = ~token_valid_mask[:, step - 1] & ~is_bos & ~is_eos
|
466 |
+
|
467 |
+
# state_index[is_bos, step - 1] = self.enter_state
|
468 |
+
# state_index[is_eos, step - 1] = self.exit_state
|
469 |
+
# state_index[is_invalid, step - 1] = self.invalid_state
|
470 |
+
|
471 |
+
# token_valid_mask[state_index[:, 0] == self.valid_state, 0] = False
|
472 |
+
# state_index[state_index[:, 0] == self.valid_state, 0] = self.enter_state
|
473 |
+
|
474 |
+
# token_contour = torch.zeros((agent_pos.shape[0], veh_token_contour.shape[1],
|
475 |
+
# veh_token_contour.shape[2], veh_token_contour.shape[3]))
|
476 |
+
# token_contour[vehicle_mask] = veh_token_contour
|
477 |
+
# token_contour[ped_mask] = ped_token_contour
|
478 |
+
# token_contour[cyclist_mask] = cyc_token_contour
|
479 |
+
|
480 |
+
raw_token_valid_mask = token_valid_mask.clone()
|
481 |
+
if not self.disable_invalid:
|
482 |
+
token_valid_mask = torch.ones_like(token_valid_mask).bool()
|
483 |
+
|
484 |
+
# apply mask
|
485 |
+
# apply_mask = raw_token_valid_mask.sum(dim=-1) > 2
|
486 |
+
# if self.training and os.getenv('AUG_MASK', False):
|
487 |
+
# aug_mask = torch.randint(0, 2, (raw_token_valid_mask.shape[0],)).to(raw_token_valid_mask).bool()
|
488 |
+
# apply_mask &= aug_mask
|
489 |
+
|
490 |
+
# remove invalid agents which are outside the range of pl2inva_radius
|
491 |
+
# remove_ina_mask = torch.zeros_like(data['agent']['train_mask'])
|
492 |
+
# if self.pl2seed_radius is not None:
|
493 |
+
# num_history_token = 1 if self.training else 2 # NOTE: hard code!!!
|
494 |
+
# av_index = int(data['agent']['av_index'])
|
495 |
+
# is_invalid = torch.any(state_index[:, :num_history_token] == self.invalid_state, dim=-1)
|
496 |
+
# ina_bos_mask = (state_index == self.enter_state) & is_invalid[:, None]
|
497 |
+
# invalid_bos_step = torch.nonzero(ina_bos_mask, as_tuple=False)
|
498 |
+
# av_bos_pos = token_pos[av_index, invalid_bos_step[:, 1]] # (num_invalid_bos, 2)
|
499 |
+
# ina_bos_pos = token_pos[invalid_bos_step[:, 0], invalid_bos_step[:, 1]] # (num_invalid_bos, 2)
|
500 |
+
# distance = torch.sqrt(torch.sum((ina_bos_pos - av_bos_pos) ** 2, dim=-1))
|
501 |
+
# remove_ina_mask = (distance > self.pl2seed_radius) | (distance < 0.)
|
502 |
+
# # apply_mask[invalid_bos_step[remove_ina_mask, 0]] = False
|
503 |
+
|
504 |
+
# data['agent']['remove_ina_mask'] = remove_ina_mask
|
505 |
+
|
506 |
+
# apply_mask[int(data['agent']['av_index'])] = True
|
507 |
+
# data['agent']['num_nodes'] = apply_mask.sum()
|
508 |
+
|
509 |
+
# av_id = data['agent']['id'][data['agent']['av_index']]
|
510 |
+
# data['agent']['id'] = [data['agent']['id'][i] for i in range(len(apply_mask)) if apply_mask[i]]
|
511 |
+
# data['agent']['av_index'] = data['agent']['id'].index(av_id)
|
512 |
+
# data['agent']['id'] = torch.tensor(data['agent']['id'], dtype=torch.long)
|
513 |
+
|
514 |
+
# agent_keys = ['valid_mask', 'predict_mask', 'type', 'category', 'position', 'heading', 'velocity', 'shape']
|
515 |
+
# for key in agent_keys:
|
516 |
+
# if key in data['agent']:
|
517 |
+
# data['agent'][key] = data['agent'][key][apply_mask]
|
518 |
+
|
519 |
+
# reset agent shapes
|
520 |
+
for i in range(n_agent):
|
521 |
+
bos_shape_index = torch.nonzero(torch.all(data['agent']['shape'][i] != 0., dim=-1))[0]
|
522 |
+
data['agent']['shape'][i, :] = data['agent']['shape'][i, bos_shape_index]
|
523 |
+
if torch.any(torch.all(data['agent']['shape'][i] == 0., dim=-1)):
|
524 |
+
raise ValueError(f"Found invalid shape values.")
|
525 |
+
|
526 |
+
# compute mean height values for each scenario
|
527 |
+
raw_height = data['agent']['position'][:, self.current_step, 2]
|
528 |
+
valid_height = raw_token_valid_mask[:, 1].bool()
|
529 |
+
veh_mean_z = raw_height[agent_type_masks['veh'] & valid_height].mean()
|
530 |
+
ped_mean_z = raw_height[agent_type_masks['ped'] & valid_height].mean().nan_to_num_(veh_mean_z) # FIXME: hard code
|
531 |
+
cyc_mean_z = raw_height[agent_type_masks['cyc'] & valid_height].mean().nan_to_num_(veh_mean_z)
|
532 |
+
|
533 |
+
# output
|
534 |
+
data['agent']['token_idx'] = token_index
|
535 |
+
data['agent']['state_idx'] = state_index
|
536 |
+
data['agent']['token_contour'] = token_contour
|
537 |
+
data['agent']['traj_pos'] = traj_pos
|
538 |
+
data['agent']['traj_heading'] = traj_heading
|
539 |
+
data['agent']['token_pos'] = token_pos
|
540 |
+
data['agent']['token_heading'] = token_heading
|
541 |
+
data['agent']['agent_valid_mask'] = token_valid_mask # (a, t)
|
542 |
+
data['agent']['raw_agent_valid_mask'] = raw_token_valid_mask
|
543 |
+
data['agent']['raw_height'] = dict(veh=veh_mean_z,
|
544 |
+
ped=ped_mean_z,
|
545 |
+
cyc=cyc_mean_z)
|
546 |
+
for type in ['veh', 'ped', 'cyc']:
|
547 |
+
data['agent'][f'trajectory_token_{type}'] = getattr(
|
548 |
+
self, f'agent_token_all_{type}') # [n_token, 6, 4, 2]
|
549 |
+
|
550 |
+
return data
|
551 |
+
|
552 |
+
def _match_agent_token(self, valid_mask, pos, heading, shape, token_traj, token_traj_all=None):
|
553 |
+
"""
|
554 |
+
Parameters:
|
555 |
+
valid_mask (torch.Tensor): Validity mask for agents over time. Shape: (n_agent, n_step)
|
556 |
+
pos (torch.Tensor): Positions of agents at each time step. Shape: (n_agent, n_step, 3)
|
557 |
+
heading (torch.Tensor): Headings of agents at each time step. Shape: (n_agent, n_step)
|
558 |
+
shape (torch.Tensor): Shape information of agents. Shape: (n_agent, 3)
|
559 |
+
token_traj (torch.Tensor): Token trajectories for agents. Shape: (n_agent, n_token, 4, 2)
|
560 |
+
token_traj_all (torch.Tensor): Token trajectories for all agents. Shape: (n_agnet, n_token_all, n_contour, 4, 2)
|
561 |
+
|
562 |
+
Returns:
|
563 |
+
tuple: Contains token indices and contours for agents.
|
564 |
+
"""
|
565 |
+
|
566 |
+
n_agent, n_step = valid_mask.shape
|
567 |
+
|
568 |
+
# agent_token_src = self.trajectory_token[category]
|
569 |
+
# if self.shift <= 2:
|
570 |
+
# if category == 'veh':
|
571 |
+
# width = 1.0
|
572 |
+
# length = 2.4
|
573 |
+
# elif category == 'cyc':
|
574 |
+
# width = 0.5
|
575 |
+
# length = 1.5
|
576 |
+
# else:
|
577 |
+
# width = 0.5
|
578 |
+
# length = 0.5
|
579 |
+
# else:
|
580 |
+
# if category == 'veh':
|
581 |
+
# width = 2.0
|
582 |
+
# length = 4.8
|
583 |
+
# elif category == 'cyc':
|
584 |
+
# width = 1.0
|
585 |
+
# length = 2.0
|
586 |
+
# else:
|
587 |
+
# width = 1.0
|
588 |
+
# length = 1.0
|
589 |
+
|
590 |
+
_, n_token, token_contour_dim, feat_dim = token_traj.shape
|
591 |
+
# agent_token_src = agent_token_src.reshape(1, token_num * token_contour_dim, feat_dim).repeat(agent_num, 0)
|
592 |
+
|
593 |
+
token_index_list = []
|
594 |
+
token_contour_list = []
|
595 |
+
token_all = []
|
596 |
+
|
597 |
+
prev_heading = heading[:, 0]
|
598 |
+
prev_pos = pos[:, 0]
|
599 |
+
prev_token_idx = None
|
600 |
+
for i in range(self.shift, n_step, self.shift): # [5, 10, 15, ..., 90]
|
601 |
+
_valid_mask = valid_mask[:, i - self.shift] & valid_mask[:, i]
|
602 |
+
_invalid_mask = ~_valid_mask
|
603 |
+
|
604 |
+
# transformation
|
605 |
+
theta = prev_heading
|
606 |
+
cos, sin = theta.cos(), theta.sin()
|
607 |
+
rot_mat = theta.new_zeros(n_agent, 2, 2)
|
608 |
+
rot_mat[:, 0, 0] = cos
|
609 |
+
rot_mat[:, 0, 1] = sin
|
610 |
+
rot_mat[:, 1, 0] = -sin
|
611 |
+
rot_mat[:, 1, 1] = cos
|
612 |
+
agent_token_world = torch.bmm(token_traj.flatten(1, 2), rot_mat).reshape(*token_traj.shape)
|
613 |
+
agent_token_world += prev_pos[:, None, None, :]
|
614 |
+
|
615 |
+
cur_contour = cal_polygon_contour(pos[:, i], heading[:, i], shape) # [n_agent, 4, 2]
|
616 |
+
agent_token_index = torch.argmin(
|
617 |
+
torch.norm(agent_token_world - cur_contour[:, None, ...], dim=-1).sum(-1), dim=-1
|
618 |
+
)
|
619 |
+
agent_token_contour = agent_token_world[torch.arange(n_agent), agent_token_index] # [n_agent, 4, 2]
|
620 |
+
# agent_token_index = torch.from_numpy(np.argmin(
|
621 |
+
# np.mean(np.sqrt(np.sum((cur_contour[:, None, ...] - agent_token_world.numpy()) ** 2, axis=-1)), axis=2),
|
622 |
+
# axis=-1))
|
623 |
+
|
624 |
+
# except for the first timestep TODO
|
625 |
+
if prev_token_idx is not None and self.noise:
|
626 |
+
same_idx = prev_token_idx == agent_token_index
|
627 |
+
same_idx[:] = True
|
628 |
+
topk_indices = np.argsort(
|
629 |
+
np.mean(np.sqrt(np.sum((cur_contour[:, None, ...] - agent_token_world.numpy()) ** 2, axis=-1)),
|
630 |
+
axis=2), axis=-1)[:, :5]
|
631 |
+
sample_topk = np.random.choice(range(0, topk_indices.shape[1]), topk_indices.shape[0])
|
632 |
+
agent_token_index[same_idx] = \
|
633 |
+
torch.from_numpy(topk_indices[np.arange(topk_indices.shape[0]), sample_topk])[same_idx]
|
634 |
+
|
635 |
+
# update prev_heading
|
636 |
+
prev_heading = heading[:, i].clone()
|
637 |
+
diff_xy = agent_token_contour[:, 0] - agent_token_contour[:, 3]
|
638 |
+
prev_heading[_valid_mask] = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])[_valid_mask]
|
639 |
+
|
640 |
+
# update prev_pos
|
641 |
+
prev_pos = pos[:, i].clone()
|
642 |
+
prev_pos[_valid_mask] = agent_token_contour.mean(dim=1)[_valid_mask]
|
643 |
+
|
644 |
+
prev_token_idx = agent_token_index
|
645 |
+
token_index_list.append(agent_token_index)
|
646 |
+
token_contour_list.append(agent_token_contour)
|
647 |
+
|
648 |
+
# calculate tokenized trajectory
|
649 |
+
if token_traj_all is not None:
|
650 |
+
agent_token_all_world = torch.bmm(token_traj_all.flatten(1, 3), rot_mat).reshape(*token_traj_all.shape)
|
651 |
+
agent_token_all_world += prev_pos[:, None, None, None, :]
|
652 |
+
agent_token_all = agent_token_all_world[torch.arange(n_agent), agent_token_index] # [n_agent, 6, 4, 2]
|
653 |
+
token_all.append(agent_token_all)
|
654 |
+
|
655 |
+
token_index = torch.stack(token_index_list, dim=1) # [n_agent, n_step]
|
656 |
+
token_contour = torch.stack(token_contour_list, dim=1) # [n_agent, n_step, 4, 2]
|
657 |
+
if len(token_all) > 0:
|
658 |
+
token_all = torch.stack(token_all, dim=1) # [n_agent, n_step, 6, 4, 2]
|
659 |
+
|
660 |
+
# sanity check
|
661 |
+
assert tuple(token_index.shape) == (n_agent, n_step // self.shift), \
|
662 |
+
f'Invalid token_index shape, got {token_index.shape}'
|
663 |
+
assert tuple(token_contour.shape )== (n_agent, n_step // self.shift, token_contour_dim, feat_dim), \
|
664 |
+
f'Invalid token_contour shape, got {token_contour.shape}'
|
665 |
+
|
666 |
+
# extra matching
|
667 |
+
# if not self.training:
|
668 |
+
# theta = heading[extra_mask, self.current_step - 1]
|
669 |
+
# prev_pos = pos[extra_mask, self.current_step - 1]
|
670 |
+
# cur_pos = pos[extra_mask, self.current_step]
|
671 |
+
# cur_heading = heading[extra_mask, self.current_step]
|
672 |
+
# cos, sin = theta.cos(), theta.sin()
|
673 |
+
# rot_mat = theta.new_zeros(extra_mask.sum(), 2, 2)
|
674 |
+
# rot_mat[:, 0, 0] = cos
|
675 |
+
# rot_mat[:, 0, 1] = sin
|
676 |
+
# rot_mat[:, 1, 0] = -sin
|
677 |
+
# rot_mat[:, 1, 1] = cos
|
678 |
+
# agent_token_world = torch.bmm(torch.from_numpy(token_last).to(torch.float), rot_mat).reshape(
|
679 |
+
# extra_mask.sum(), token_num, token_contour_dim, feat_dim)
|
680 |
+
# agent_token_world += prev_pos[:, None, None, :]
|
681 |
+
|
682 |
+
# cur_contour = cal_polygon_contour(cur_pos[:, 0], cur_pos[:, 1], cur_heading, width, length)
|
683 |
+
# agent_token_index = torch.from_numpy(np.argmin(
|
684 |
+
# np.mean(np.sqrt(np.sum((cur_contour[:, None, ...] - agent_token_world.numpy()) ** 2, axis=-1)), axis=2),
|
685 |
+
# axis=-1))
|
686 |
+
# token_contour_select = agent_token_world[torch.arange(extra_mask.sum()), agent_token_index]
|
687 |
+
|
688 |
+
# token_index[extra_mask, 1] = agent_token_index
|
689 |
+
# token_contour[extra_mask, 1] = token_contour_select
|
690 |
+
|
691 |
+
return token_index, token_contour, token_all
|
692 |
+
|
693 |
+
@staticmethod
|
694 |
+
def _tokenize_map(data):
|
695 |
+
|
696 |
+
data['map_polygon']['type'] = data['map_polygon']['type'].to(torch.uint8)
|
697 |
+
data['map_point']['type'] = data['map_point']['type'].to(torch.uint8)
|
698 |
+
pt2pl = data[('map_point', 'to', 'map_polygon')]['edge_index']
|
699 |
+
pt_type = data['map_point']['type'].to(torch.uint8)
|
700 |
+
pt_side = torch.zeros_like(pt_type)
|
701 |
+
pt_pos = data['map_point']['position'][:, :2]
|
702 |
+
data['map_point']['orientation'] = wrap_angle(data['map_point']['orientation'])
|
703 |
+
pt_heading = data['map_point']['orientation']
|
704 |
+
split_polyline_type = []
|
705 |
+
split_polyline_pos = []
|
706 |
+
split_polyline_theta = []
|
707 |
+
split_polyline_side = []
|
708 |
+
pl_idx_list = []
|
709 |
+
split_polygon_type = []
|
710 |
+
data['map_point']['type'].unique()
|
711 |
+
|
712 |
+
for i in sorted(np.unique(pt2pl[1])): # number of polygons in the scenario
|
713 |
+
index = pt2pl[0, pt2pl[1] == i] # index of points which belongs to i-th polygon
|
714 |
+
polygon_type = data['map_polygon']["type"][i]
|
715 |
+
cur_side = pt_side[index]
|
716 |
+
cur_type = pt_type[index]
|
717 |
+
cur_pos = pt_pos[index]
|
718 |
+
cur_heading = pt_heading[index]
|
719 |
+
|
720 |
+
for side_val in np.unique(cur_side):
|
721 |
+
for type_val in np.unique(cur_type):
|
722 |
+
if type_val == 13:
|
723 |
+
continue
|
724 |
+
indices = np.where((cur_side == side_val) & (cur_type == type_val))[0]
|
725 |
+
if len(indices) <= 2:
|
726 |
+
continue
|
727 |
+
split_polyline = interplating_polyline(cur_pos[indices].numpy(), cur_heading[indices].numpy())
|
728 |
+
if split_polyline is None:
|
729 |
+
continue
|
730 |
+
new_cur_type = cur_type[indices][0]
|
731 |
+
new_cur_side = cur_side[indices][0]
|
732 |
+
map_polygon_type = polygon_type.repeat(split_polyline.shape[0])
|
733 |
+
new_cur_type = new_cur_type.repeat(split_polyline.shape[0])
|
734 |
+
new_cur_side = new_cur_side.repeat(split_polyline.shape[0])
|
735 |
+
cur_pl_idx = torch.Tensor([i])
|
736 |
+
new_cur_pl_idx = cur_pl_idx.repeat(split_polyline.shape[0])
|
737 |
+
split_polyline_pos.append(split_polyline[..., :2])
|
738 |
+
split_polyline_theta.append(split_polyline[..., 2])
|
739 |
+
split_polyline_type.append(new_cur_type)
|
740 |
+
split_polyline_side.append(new_cur_side)
|
741 |
+
pl_idx_list.append(new_cur_pl_idx)
|
742 |
+
split_polygon_type.append(map_polygon_type)
|
743 |
+
|
744 |
+
split_polyline_pos = torch.cat(split_polyline_pos, dim=0)
|
745 |
+
split_polyline_theta = torch.cat(split_polyline_theta, dim=0)
|
746 |
+
split_polyline_type = torch.cat(split_polyline_type, dim=0)
|
747 |
+
split_polyline_side = torch.cat(split_polyline_side, dim=0)
|
748 |
+
split_polygon_type = torch.cat(split_polygon_type, dim=0)
|
749 |
+
pl_idx_list = torch.cat(pl_idx_list, dim=0)
|
750 |
+
|
751 |
+
data['map_save'] = {}
|
752 |
+
data['pt_token'] = {}
|
753 |
+
data['map_save']['traj_pos'] = split_polyline_pos
|
754 |
+
data['map_save']['traj_theta'] = split_polyline_theta[:, 0] # torch.arctan2(vec[:, 1], vec[:, 0])
|
755 |
+
data['map_save']['pl_idx_list'] = pl_idx_list
|
756 |
+
data['pt_token']['type'] = split_polyline_type
|
757 |
+
data['pt_token']['side'] = split_polyline_side
|
758 |
+
data['pt_token']['pl_type'] = split_polygon_type
|
759 |
+
data['pt_token']['num_nodes'] = split_polyline_pos.shape[0]
|
760 |
+
|
761 |
+
return data
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/datasets/scalable_dataset.py
ADDED
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pickle
|
3 |
+
import torch
|
4 |
+
import json
|
5 |
+
import pytorch_lightning as pl
|
6 |
+
import pandas as pd
|
7 |
+
from tqdm import tqdm
|
8 |
+
from torch_geometric.data import HeteroData, Dataset
|
9 |
+
from torch_geometric.transforms import BaseTransform
|
10 |
+
from torch_geometric.loader import DataLoader
|
11 |
+
from typing import Callable, Dict, List, Optional
|
12 |
+
|
13 |
+
from dev.datasets.preprocess import TokenProcessor
|
14 |
+
|
15 |
+
|
16 |
+
class MultiDataset(Dataset):
|
17 |
+
def __init__(self,
|
18 |
+
split: str,
|
19 |
+
raw_dir: List[str] = None,
|
20 |
+
transform: Optional[Callable] = None,
|
21 |
+
tfrecord_dir: Optional[str] = None,
|
22 |
+
token_size=512,
|
23 |
+
predict_motion: bool=False,
|
24 |
+
predict_state: bool=False,
|
25 |
+
predict_map: bool=False,
|
26 |
+
# state_token: Dict[str, int]=None,
|
27 |
+
# pl2seed_radius: float=None,
|
28 |
+
buffer_size: int=128,
|
29 |
+
logger=None) -> None:
|
30 |
+
|
31 |
+
self.disable_invalid = not predict_state
|
32 |
+
self.predict_motion = predict_motion
|
33 |
+
self.predict_state = predict_state
|
34 |
+
self.predict_map = predict_map
|
35 |
+
self.logger = logger
|
36 |
+
if split not in ('train', 'val', 'test'):
|
37 |
+
raise ValueError(f'{split} is not a valid split')
|
38 |
+
self.training = split == 'train'
|
39 |
+
self.buffer_size = buffer_size
|
40 |
+
self._tfrecord_dir = tfrecord_dir
|
41 |
+
self.logger.debug('Starting loading dataset')
|
42 |
+
|
43 |
+
raw_dir = os.path.expanduser(os.path.normpath(raw_dir))
|
44 |
+
self._raw_files = sorted(os.listdir(raw_dir))
|
45 |
+
|
46 |
+
# for debugging
|
47 |
+
if int(os.getenv('OVERFIT', 0)):
|
48 |
+
# if self.training:
|
49 |
+
# # scenario_id = ['74ad7b76d5906d39', '13596229fd8cdb7e', '1d73db1fc42be3bf', '1351ea8b8333ddcb']
|
50 |
+
# self._raw_files = ['74ad7b76d5906d39.pkl'] + self._raw_files[:9]
|
51 |
+
# else:
|
52 |
+
# self._raw_files = self._raw_files[:10]
|
53 |
+
self._raw_files = self._raw_files[:1]
|
54 |
+
# self._raw_files = ['1002fdc9826fc6d1.pkl']
|
55 |
+
|
56 |
+
# load meta infos and do filter
|
57 |
+
json_path = '/u/xiuyu/work/dev4/data/waymo_processed/meta_infos.json'
|
58 |
+
label = 'training' if split == 'train' else ('validation' if split == 'val' else split)
|
59 |
+
self.meta_infos = json.load(open(json_path, 'r', encoding='utf-8'))[label]
|
60 |
+
self.logger.debug(f"Loaded meta infos from {json_path}")
|
61 |
+
self.available_scenarios = list(self.meta_infos.keys())
|
62 |
+
# self._raw_files = list(tqdm(filter(lambda fn: (
|
63 |
+
# scenario_id := fn.removesuffix('.pkl') in self.available_scenarios and
|
64 |
+
# 8 <= self.meta_infos[scenario_id]['num_agents'] < self.buffer_size
|
65 |
+
# ), self._raw_files), leave=False))
|
66 |
+
df = pd.DataFrame.from_dict(self.meta_infos, orient='index')
|
67 |
+
available_scenarios_set = set(self.available_scenarios)
|
68 |
+
df_filtered = df[(df.index.isin(available_scenarios_set)) & (df['num_agents'] >= 8) & (df['num_agents'] < self.buffer_size)]
|
69 |
+
valid_scenarios = set(df_filtered.index)
|
70 |
+
self._raw_files = list(tqdm(filter(lambda fn: fn.removesuffix('.pkl') in valid_scenarios, self._raw_files), leave=False))
|
71 |
+
if len(self._raw_files) <= 0:
|
72 |
+
raise RuntimeError(f'Invalid number of data {len(self._raw_files)}!')
|
73 |
+
self._raw_paths = list(map(lambda fn: os.path.join(raw_dir, fn), self._raw_files))
|
74 |
+
|
75 |
+
self.logger.debug(f"The number of {split} dataset is {len(self._raw_paths)}")
|
76 |
+
self.logger.debug(f"The buffer size is {self.buffer_size}")
|
77 |
+
# self.token_processor = TokenProcessor(token_size,
|
78 |
+
# training=self.training,
|
79 |
+
# predict_motion=self.predict_motion,
|
80 |
+
# predict_state=self.predict_state,
|
81 |
+
# predict_map=self.predict_map,
|
82 |
+
# state_token=state_token,
|
83 |
+
# pl2seed_radius=pl2seed_radius) # 2048
|
84 |
+
self.logger.debug(f"The used token size is {token_size}.")
|
85 |
+
super().__init__(transform=transform, pre_transform=None, pre_filter=None)
|
86 |
+
|
87 |
+
def len(self) -> int:
|
88 |
+
return len(self._raw_paths)
|
89 |
+
|
90 |
+
def get(self, idx: int):
|
91 |
+
"""
|
92 |
+
Load pkl file (each represents a 91s scenario for waymo dataset)
|
93 |
+
"""
|
94 |
+
with open(self._raw_paths[idx], 'rb') as handle:
|
95 |
+
data = pickle.load(handle)
|
96 |
+
|
97 |
+
if self._tfrecord_dir is not None:
|
98 |
+
data['tfrecord_path'] = os.path.join(self._tfrecord_dir, f"{data['scenario_id']}.tfrecords")
|
99 |
+
|
100 |
+
# data = self.token_processor.preprocess(data)
|
101 |
+
return data
|
102 |
+
|
103 |
+
|
104 |
+
class WaymoTargetBuilder(BaseTransform):
|
105 |
+
|
106 |
+
def __init__(self,
|
107 |
+
num_historical_steps: int,
|
108 |
+
num_future_steps: int,
|
109 |
+
max_num: int,
|
110 |
+
training: bool=False) -> None:
|
111 |
+
|
112 |
+
self.max_num = max_num
|
113 |
+
self.num_historical_steps = num_historical_steps
|
114 |
+
self.num_future_steps = num_future_steps
|
115 |
+
self.step_current = num_historical_steps - 1
|
116 |
+
self.training = training
|
117 |
+
|
118 |
+
def _score_trained_agents(self, data):
|
119 |
+
pos = data['agent']['position']
|
120 |
+
av_index = torch.where(data['agent']['role'][:, 0])[0].item()
|
121 |
+
distance = torch.norm(pos - pos[av_index], dim=-1)
|
122 |
+
|
123 |
+
# we do not believe the perception out of range of 150 meters
|
124 |
+
data['agent']['valid_mask'] &= distance < 150
|
125 |
+
|
126 |
+
# we do not predict vehicle too far away from ego car
|
127 |
+
role_train_mask = data['agent']['role'].any(-1)
|
128 |
+
extra_train_mask = (distance[:, self.step_current] < 100) & (
|
129 |
+
data['agent']['valid_mask'][:, self.step_current + 1 :].sum(-1) >= 5
|
130 |
+
)
|
131 |
+
|
132 |
+
train_mask = extra_train_mask | role_train_mask
|
133 |
+
if train_mask.sum() > self.max_num: # too many vehicle
|
134 |
+
_indices = torch.where(extra_train_mask & ~role_train_mask)[0]
|
135 |
+
selected_indices = _indices[
|
136 |
+
torch.randperm(_indices.size(0))[: self.max_num - role_train_mask.sum()]
|
137 |
+
]
|
138 |
+
data['agent']['train_mask'] = role_train_mask
|
139 |
+
data['agent']['train_mask'][selected_indices] = True
|
140 |
+
else:
|
141 |
+
data['agent']['train_mask'] = train_mask # [n_agent]
|
142 |
+
|
143 |
+
return data
|
144 |
+
|
145 |
+
def __call__(self, data) -> HeteroData:
|
146 |
+
|
147 |
+
if self.training:
|
148 |
+
self._score_trained_agents(data)
|
149 |
+
|
150 |
+
data = TokenProcessor._tokenize_map(data)
|
151 |
+
# data keys: dict_keys(['scenario_id', 'agent', 'map_polygon', 'map_point', ('map_point', 'to', 'map_polygon'), ('map_polygon', 'to', 'map_polygon'), 'map_save', 'pt_token'])
|
152 |
+
return HeteroData(data)
|
153 |
+
|
154 |
+
|
155 |
+
class MultiDataModule(pl.LightningDataModule):
|
156 |
+
transforms = {
|
157 |
+
'WaymoTargetBuilder': WaymoTargetBuilder,
|
158 |
+
}
|
159 |
+
|
160 |
+
dataset = {
|
161 |
+
'scalable': MultiDataset,
|
162 |
+
}
|
163 |
+
|
164 |
+
def __init__(self,
|
165 |
+
root: str,
|
166 |
+
train_batch_size: int,
|
167 |
+
val_batch_size: int,
|
168 |
+
test_batch_size: int,
|
169 |
+
shuffle: bool = False,
|
170 |
+
num_workers: int = 0,
|
171 |
+
pin_memory: bool = True,
|
172 |
+
persistent_workers: bool = True,
|
173 |
+
train_raw_dir: Optional[str] = None,
|
174 |
+
val_raw_dir: Optional[str] = None,
|
175 |
+
test_raw_dir: Optional[str] = None,
|
176 |
+
train_processed_dir: Optional[str] = None,
|
177 |
+
val_processed_dir: Optional[str] = None,
|
178 |
+
test_processed_dir: Optional[str] = None,
|
179 |
+
val_tfrecords_splitted: Optional[str] = None,
|
180 |
+
transform: Optional[str] = None,
|
181 |
+
dataset: Optional[str] = None,
|
182 |
+
num_historical_steps: int = 50,
|
183 |
+
num_future_steps: int = 60,
|
184 |
+
processor='ntp',
|
185 |
+
token_size=512,
|
186 |
+
predict_motion: bool=False,
|
187 |
+
predict_state: bool=False,
|
188 |
+
predict_map: bool=False,
|
189 |
+
state_token: Dict[str, int]=None,
|
190 |
+
pl2seed_radius: float=None,
|
191 |
+
max_num: int=32,
|
192 |
+
buffer_size: int=256,
|
193 |
+
logger=None,
|
194 |
+
**kwargs) -> None:
|
195 |
+
|
196 |
+
super(MultiDataModule, self).__init__()
|
197 |
+
self.root = root
|
198 |
+
self.dataset_class = dataset
|
199 |
+
self.train_batch_size = train_batch_size
|
200 |
+
self.val_batch_size = val_batch_size
|
201 |
+
self.test_batch_size = test_batch_size
|
202 |
+
self.shuffle = shuffle
|
203 |
+
self.num_workers = num_workers
|
204 |
+
self.pin_memory = pin_memory
|
205 |
+
self.persistent_workers = persistent_workers and num_workers > 0
|
206 |
+
self.train_raw_dir = train_raw_dir
|
207 |
+
self.val_raw_dir = val_raw_dir
|
208 |
+
self.test_raw_dir = test_raw_dir
|
209 |
+
self.train_processed_dir = train_processed_dir
|
210 |
+
self.val_processed_dir = val_processed_dir
|
211 |
+
self.test_processed_dir = test_processed_dir
|
212 |
+
self.val_tfrecords_splitted = val_tfrecords_splitted
|
213 |
+
self.processor = processor
|
214 |
+
self.token_size = token_size
|
215 |
+
self.predict_motion = predict_motion
|
216 |
+
self.predict_state = predict_state
|
217 |
+
self.predict_map = predict_map
|
218 |
+
self.state_token = state_token
|
219 |
+
self.pl2seed_radius = pl2seed_radius
|
220 |
+
self.buffer_size = buffer_size
|
221 |
+
self.logger = logger
|
222 |
+
|
223 |
+
self.train_transform = MultiDataModule.transforms[transform](num_historical_steps,
|
224 |
+
num_future_steps,
|
225 |
+
max_num=max_num,
|
226 |
+
training=True)
|
227 |
+
self.val_transform = MultiDataModule.transforms[transform](num_historical_steps,
|
228 |
+
num_future_steps,
|
229 |
+
max_num=max_num,
|
230 |
+
training=False)
|
231 |
+
|
232 |
+
def setup(self, stage: Optional[str] = None) -> None:
|
233 |
+
general_params = dict(token_size=self.token_size,
|
234 |
+
predict_motion=self.predict_motion,
|
235 |
+
predict_state=self.predict_state,
|
236 |
+
predict_map=self.predict_map,
|
237 |
+
buffer_size=self.buffer_size,
|
238 |
+
logger=self.logger)
|
239 |
+
|
240 |
+
if stage == 'fit' or stage is None:
|
241 |
+
self.train_dataset = MultiDataModule.dataset[self.dataset_class](split='train',
|
242 |
+
raw_dir=self.train_raw_dir,
|
243 |
+
transform=self.train_transform,
|
244 |
+
**general_params)
|
245 |
+
self.val_dataset = MultiDataModule.dataset[self.dataset_class](split='val',
|
246 |
+
raw_dir=self.val_raw_dir,
|
247 |
+
transform=self.val_transform,
|
248 |
+
tfrecord_dir=self.val_tfrecords_splitted,
|
249 |
+
**general_params)
|
250 |
+
if stage == 'validate':
|
251 |
+
self.val_dataset = MultiDataModule.dataset[self.dataset_class](split='val',
|
252 |
+
raw_dir=self.val_raw_dir,
|
253 |
+
transform=self.val_transform,
|
254 |
+
tfrecord_dir=self.val_tfrecords_splitted,
|
255 |
+
**general_params)
|
256 |
+
if stage == 'test':
|
257 |
+
self.test_dataset = MultiDataModule.dataset[self.dataset_class](split='test',
|
258 |
+
raw_dir=self.test_raw_dir,
|
259 |
+
transform=self.val_transform,
|
260 |
+
tfrecord_dir=self.val_tfrecords_splitted,
|
261 |
+
**general_params)
|
262 |
+
|
263 |
+
def train_dataloader(self):
|
264 |
+
return DataLoader(self.train_dataset, batch_size=self.train_batch_size, shuffle=self.shuffle,
|
265 |
+
num_workers=self.num_workers, pin_memory=self.pin_memory,
|
266 |
+
persistent_workers=self.persistent_workers)
|
267 |
+
|
268 |
+
def val_dataloader(self):
|
269 |
+
return DataLoader(self.val_dataset, batch_size=self.val_batch_size, shuffle=False,
|
270 |
+
num_workers=self.num_workers, pin_memory=self.pin_memory,
|
271 |
+
persistent_workers=self.persistent_workers)
|
272 |
+
|
273 |
+
def test_dataloader(self):
|
274 |
+
return DataLoader(self.test_dataset, batch_size=self.test_batch_size, shuffle=False,
|
275 |
+
num_workers=self.num_workers, pin_memory=self.pin_memory,
|
276 |
+
persistent_workers=self.persistent_workers)
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/metrics/box_utils.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import Tensor
|
3 |
+
|
4 |
+
|
5 |
+
def get_yaw_rotation_2d(yaw):
|
6 |
+
"""
|
7 |
+
Gets a 2D rotation matrix given a yaw angle.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
yaw: torch.Tensor, rotation angle in radians. Can be any shape except empty.
|
11 |
+
|
12 |
+
Returns:
|
13 |
+
rotation: torch.Tensor with shape [..., 2, 2], where `...` matches input shape.
|
14 |
+
"""
|
15 |
+
cos_yaw = torch.cos(yaw)
|
16 |
+
sin_yaw = torch.sin(yaw)
|
17 |
+
|
18 |
+
rotation = torch.stack([
|
19 |
+
torch.stack([cos_yaw, -sin_yaw], dim=-1),
|
20 |
+
torch.stack([sin_yaw, cos_yaw], dim=-1),
|
21 |
+
], dim=-2) # Shape: [..., 2, 2]
|
22 |
+
|
23 |
+
return rotation
|
24 |
+
|
25 |
+
|
26 |
+
def get_yaw_rotation(yaw):
|
27 |
+
"""
|
28 |
+
Computes a 3D rotation matrix given a yaw angle (rotation around the Z-axis).
|
29 |
+
|
30 |
+
Args:
|
31 |
+
yaw: torch.Tensor of any shape, representing yaw angles in radians.
|
32 |
+
|
33 |
+
Returns:
|
34 |
+
rotation: torch.Tensor of shape [input_shape, 3, 3], representing the rotation matrices.
|
35 |
+
"""
|
36 |
+
cos_yaw = torch.cos(yaw)
|
37 |
+
sin_yaw = torch.sin(yaw)
|
38 |
+
ones = torch.ones_like(yaw)
|
39 |
+
zeros = torch.zeros_like(yaw)
|
40 |
+
|
41 |
+
return torch.stack([
|
42 |
+
torch.stack([cos_yaw, -sin_yaw, zeros], dim=-1),
|
43 |
+
torch.stack([sin_yaw, cos_yaw, zeros], dim=-1),
|
44 |
+
torch.stack([zeros, zeros, ones], dim=-1),
|
45 |
+
], dim=-2)
|
46 |
+
|
47 |
+
|
48 |
+
def get_transform(rotation, translation):
|
49 |
+
"""
|
50 |
+
Combines an NxN rotation matrix and an Nx1 translation vector into an (N+1)x(N+1) transformation matrix.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
rotation: torch.Tensor of shape [..., N, N], representing rotation matrices.
|
54 |
+
translation: torch.Tensor of shape [..., N], representing translation vectors.
|
55 |
+
This must have the same dtype as rotation.
|
56 |
+
|
57 |
+
Returns:
|
58 |
+
transform: torch.Tensor of shape [..., (N+1), (N+1)], representing the transformation matrices.
|
59 |
+
This has the same dtype as rotation.
|
60 |
+
"""
|
61 |
+
# [..., N, 1]
|
62 |
+
translation_n_1 = translation.unsqueeze(-1)
|
63 |
+
|
64 |
+
# [..., N, N+1] - Combine rotation and translation
|
65 |
+
transform = torch.cat([rotation, translation_n_1], dim=-1)
|
66 |
+
|
67 |
+
# [..., N] - Create the last row, which is [0, 0, ..., 0, 1]
|
68 |
+
last_row = torch.zeros_like(translation)
|
69 |
+
last_row = torch.cat([last_row, torch.ones_like(last_row[..., :1])], dim=-1)
|
70 |
+
|
71 |
+
# [..., N+1, N+1] - Append the last row to form the final transformation matrix
|
72 |
+
transform = torch.cat([transform, last_row.unsqueeze(-2)], dim=-2)
|
73 |
+
|
74 |
+
return transform
|
75 |
+
|
76 |
+
|
77 |
+
def get_upright_3d_box_corners(boxes: Tensor):
|
78 |
+
"""
|
79 |
+
Given a set of upright 3D bounding boxes, return its 8 corner points.
|
80 |
+
|
81 |
+
Args:
|
82 |
+
boxes: torch.Tensor [N, 7]. The inner dims are [center{x,y,z}, length, width,
|
83 |
+
height, heading].
|
84 |
+
|
85 |
+
Returns:
|
86 |
+
corners: torch.Tensor [N, 8, 3].
|
87 |
+
"""
|
88 |
+
center_x, center_y, center_z, length, width, height, heading = boxes.unbind(dim=-1)
|
89 |
+
|
90 |
+
# Compute rotation matrix [N, 3, 3]
|
91 |
+
rotation = get_yaw_rotation(heading)
|
92 |
+
|
93 |
+
# Translation [N, 3]
|
94 |
+
translation = torch.stack([center_x, center_y, center_z], dim=-1)
|
95 |
+
|
96 |
+
l2, w2, h2 = length * 0.5, width * 0.5, height * 0.5
|
97 |
+
|
98 |
+
# Define the 8 corners in local coordinates [N, 8, 3]
|
99 |
+
corners_local = torch.stack([
|
100 |
+
torch.stack([ l2, w2, -h2], dim=-1),
|
101 |
+
torch.stack([-l2, w2, -h2], dim=-1),
|
102 |
+
torch.stack([-l2, -w2, -h2], dim=-1),
|
103 |
+
torch.stack([ l2, -w2, -h2], dim=-1),
|
104 |
+
torch.stack([ l2, w2, h2], dim=-1),
|
105 |
+
torch.stack([-l2, w2, h2], dim=-1),
|
106 |
+
torch.stack([-l2, -w2, h2], dim=-1),
|
107 |
+
torch.stack([ l2, -w2, h2], dim=-1),
|
108 |
+
], dim=1) # Shape: [N, 8, 3]
|
109 |
+
|
110 |
+
# Rotate and translate the corners
|
111 |
+
corners = torch.einsum('n i j, n k j -> n k i', rotation, corners_local) + translation.unsqueeze(1)
|
112 |
+
|
113 |
+
return corners
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/metrics/compute_metrics.py
ADDED
@@ -0,0 +1,1812 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ! Metrics Calculation
|
2 |
+
import concurrent.futures
|
3 |
+
import os
|
4 |
+
import torch
|
5 |
+
import tensorflow as tf
|
6 |
+
import collections
|
7 |
+
import dataclasses
|
8 |
+
import fnmatch
|
9 |
+
import json
|
10 |
+
import pandas as pd
|
11 |
+
import pickle
|
12 |
+
import copy
|
13 |
+
import concurrent
|
14 |
+
import multiprocessing
|
15 |
+
from torch_geometric.utils import degree
|
16 |
+
from functools import partial
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from tqdm import tqdm
|
19 |
+
from argparse import ArgumentParser
|
20 |
+
from torch import Tensor
|
21 |
+
from google.protobuf import text_format
|
22 |
+
from torchmetrics import Metric
|
23 |
+
from typing import Optional, Sequence, List, Dict
|
24 |
+
|
25 |
+
from waymo_open_dataset.utils.sim_agents import submission_specs
|
26 |
+
|
27 |
+
from dev.utils.visualization import safe_run
|
28 |
+
from dev.utils.func import CONSOLE
|
29 |
+
from dev.datasets.scalable_dataset import WaymoTargetBuilder
|
30 |
+
from dev.datasets.preprocess import TokenProcessor, SHIFT, AGENT_STATE
|
31 |
+
from dev.metrics import trajectory_features, interact_features, map_features, placement_features
|
32 |
+
from dev.metrics.protos import scenario_pb2, long_metrics_pb2
|
33 |
+
|
34 |
+
|
35 |
+
_METRIC_FIELD_NAMES_BY_BUCKET = {
|
36 |
+
'kinematic': [
|
37 |
+
'linear_speed', 'linear_acceleration',
|
38 |
+
'angular_speed', 'angular_acceleration',
|
39 |
+
],
|
40 |
+
'interactive': [
|
41 |
+
'distance_to_nearest_object', 'collision_indication',
|
42 |
+
'time_to_collision',
|
43 |
+
],
|
44 |
+
'map_based': [
|
45 |
+
# 'distance_to_road_edge', 'offroad_indication'
|
46 |
+
],
|
47 |
+
'placement_based': [
|
48 |
+
'num_placement', 'num_removement',
|
49 |
+
'distance_placement', 'distance_removement',
|
50 |
+
]
|
51 |
+
}
|
52 |
+
_METRIC_FIELD_NAMES = (
|
53 |
+
_METRIC_FIELD_NAMES_BY_BUCKET['kinematic'] +
|
54 |
+
_METRIC_FIELD_NAMES_BY_BUCKET['interactive'] +
|
55 |
+
_METRIC_FIELD_NAMES_BY_BUCKET['map_based'] +
|
56 |
+
_METRIC_FIELD_NAMES_BY_BUCKET['placement_based']
|
57 |
+
)
|
58 |
+
|
59 |
+
|
60 |
+
""" Help Functions """
|
61 |
+
|
62 |
+
def _arg_gather(tensor: Tensor, reference_tensor: Tensor) -> Tensor:
|
63 |
+
"""Finds corresponding indices in `tensor` for each element in `reference_tensor`.
|
64 |
+
|
65 |
+
Args:
|
66 |
+
tensor: A 1D tensor without repetitions.
|
67 |
+
reference_tensor: A 1D tensor containing items from `tensor`.
|
68 |
+
|
69 |
+
Returns:
|
70 |
+
A tensor of indices such that `tensor[indices] == reference_tensor`.
|
71 |
+
"""
|
72 |
+
assert tensor.ndim == 1, "tensor must be 1D"
|
73 |
+
assert reference_tensor.ndim == 1, "reference_tensor must be 1D"
|
74 |
+
|
75 |
+
# Create the comparison matrix
|
76 |
+
bit_mask = tensor[None, :] == reference_tensor[:, None] # Shape: [len(reference_tensor), len(tensor)]
|
77 |
+
|
78 |
+
# Count the matches along `tensor` dimension
|
79 |
+
bit_mask_sum = bit_mask.int().sum(dim=1)
|
80 |
+
|
81 |
+
if (bit_mask_sum < 1).any():
|
82 |
+
raise ValueError(
|
83 |
+
'Some items in `reference_tensor` are missing from `tensor`: '
|
84 |
+
f'\n{reference_tensor} \nvs. \n{tensor}.'
|
85 |
+
)
|
86 |
+
|
87 |
+
if (bit_mask_sum > 1).any():
|
88 |
+
raise ValueError('Some items in `tensor` are repeated.')
|
89 |
+
|
90 |
+
# Compute indices
|
91 |
+
indices = torch.matmul(bit_mask.int(), torch.arange(tensor.shape[0], dtype=torch.int32))
|
92 |
+
return indices
|
93 |
+
|
94 |
+
|
95 |
+
def is_valid_sim_agent(track: scenario_pb2.Track) -> bool: # type: ignore
|
96 |
+
"""Checks if the object needs to be resimulated as a sim agent.
|
97 |
+
|
98 |
+
For the Sim Agents challenge, every object that is valid at the
|
99 |
+
`current_time_index` step (here hardcoded to 10) needs to be resimulated.
|
100 |
+
|
101 |
+
Args:
|
102 |
+
track: A track proto for a single object.
|
103 |
+
|
104 |
+
Returns:
|
105 |
+
A boolean flag, True if the object needs to be resimulated, False otherwise.
|
106 |
+
"""
|
107 |
+
return track.states[submission_specs.CURRENT_TIME_INDEX].valid
|
108 |
+
|
109 |
+
|
110 |
+
def get_sim_agent_ids(
|
111 |
+
scenario: scenario_pb2.Scenario) -> Sequence[int]: # type: ignore
|
112 |
+
"""Returns the list of object IDs that needs to be resimulated.
|
113 |
+
|
114 |
+
Internally calls `is_valid_sim_agent` to verify the simulation criteria,
|
115 |
+
i.e. is the object valid at `current_time_index`.
|
116 |
+
|
117 |
+
Args:
|
118 |
+
scenario: The Scenario proto containing the data.
|
119 |
+
|
120 |
+
Returns:
|
121 |
+
A list of int IDs, containing all the objects that need to be simulated.
|
122 |
+
"""
|
123 |
+
object_ids = []
|
124 |
+
for track in scenario.tracks:
|
125 |
+
if is_valid_sim_agent(track):
|
126 |
+
object_ids.append(track.id)
|
127 |
+
return object_ids
|
128 |
+
|
129 |
+
|
130 |
+
def get_evaluation_agent_ids(
|
131 |
+
scenario: scenario_pb2.Scenario) -> Sequence[int]: # type: ignore
|
132 |
+
# Start with the AV object.
|
133 |
+
object_ids = {scenario.tracks[scenario.sdc_track_index].id}
|
134 |
+
# Add the `tracks_to_predict` objects.
|
135 |
+
for required_prediction in scenario.tracks_to_predict:
|
136 |
+
object_ids.add(scenario.tracks[required_prediction.track_index].id)
|
137 |
+
return sorted(object_ids)
|
138 |
+
|
139 |
+
|
140 |
+
""" Base Data Classes s"""
|
141 |
+
|
142 |
+
@dataclass(frozen=True)
|
143 |
+
class ObjectTrajectories:
|
144 |
+
|
145 |
+
x: Tensor
|
146 |
+
y: Tensor
|
147 |
+
z: Tensor
|
148 |
+
heading: Tensor
|
149 |
+
length: Tensor
|
150 |
+
width: Tensor
|
151 |
+
height: Tensor
|
152 |
+
valid: Tensor
|
153 |
+
object_id: Tensor
|
154 |
+
object_type: Tensor
|
155 |
+
|
156 |
+
state: Optional[Tensor] = None
|
157 |
+
token_pos: Optional[Tensor] = None
|
158 |
+
token_heading: Optional[Tensor] = None
|
159 |
+
token_valid: Optional[Tensor] = None
|
160 |
+
processed_object_id: Optional[Tensor] = None
|
161 |
+
av_id: Optional[int] = None
|
162 |
+
processed_av_id: Optional[int] = None
|
163 |
+
|
164 |
+
def slice_time(self, start_index: int = 0, end_index: Optional[int] = None):
|
165 |
+
return ObjectTrajectories(
|
166 |
+
x=self.x[..., start_index:end_index],
|
167 |
+
y=self.y[..., start_index:end_index],
|
168 |
+
z=self.z[..., start_index:end_index],
|
169 |
+
heading=self.heading[..., start_index:end_index],
|
170 |
+
length=self.length[..., start_index:end_index],
|
171 |
+
width=self.width[..., start_index:end_index],
|
172 |
+
height=self.height[..., start_index:end_index],
|
173 |
+
valid=self.valid[..., start_index:end_index],
|
174 |
+
object_id=self.object_id,
|
175 |
+
object_type=self.object_type,
|
176 |
+
|
177 |
+
# these properties can only come from processed file
|
178 |
+
state=self.state,
|
179 |
+
token_pos=self.token_pos,
|
180 |
+
token_heading=self.token_heading,
|
181 |
+
token_valid=self.token_valid,
|
182 |
+
processed_object_id=self.processed_object_id,
|
183 |
+
av_id=self.av_id,
|
184 |
+
processed_av_id=self.processed_av_id,
|
185 |
+
)
|
186 |
+
|
187 |
+
def gather_objects(self, object_indices: Tensor):
|
188 |
+
assert object_indices.ndim == 1, "object_indices must be 1D"
|
189 |
+
return ObjectTrajectories(
|
190 |
+
x=torch.index_select(self.x, dim=-2, index=object_indices),
|
191 |
+
y=torch.index_select(self.y, dim=-2, index=object_indices),
|
192 |
+
z=torch.index_select(self.z, dim=-2, index=object_indices),
|
193 |
+
heading=torch.index_select(self.heading, dim=-2, index=object_indices),
|
194 |
+
length=torch.index_select(self.length, dim=-2, index=object_indices),
|
195 |
+
width=torch.index_select(self.width, dim=-2, index=object_indices),
|
196 |
+
height=torch.index_select(self.height, dim=-2, index=object_indices),
|
197 |
+
valid=torch.index_select(self.valid, dim=-2, index=object_indices),
|
198 |
+
object_id=torch.index_select(self.object_id, dim=-1, index=object_indices),
|
199 |
+
object_type=torch.index_select(self.object_type, dim=-1, index=object_indices),
|
200 |
+
|
201 |
+
# these properties can only come from processed file
|
202 |
+
state=self.state,
|
203 |
+
token_pos=self.token_pos,
|
204 |
+
token_heading=self.token_heading,
|
205 |
+
token_valid=self.token_valid,
|
206 |
+
processed_object_id=self.processed_object_id,
|
207 |
+
av_id=self.av_id,
|
208 |
+
processed_av_id=self.processed_av_id,
|
209 |
+
)
|
210 |
+
|
211 |
+
def gather_objects_by_id(self, object_ids: tf.Tensor):
|
212 |
+
indices = _arg_gather(self.object_id, object_ids)
|
213 |
+
return self.gather_objects(indices)
|
214 |
+
|
215 |
+
@classmethod
|
216 |
+
def _get_init_dict_from_processed(cls, scenario: dict):
|
217 |
+
"""Load from processed pkl data"""
|
218 |
+
position = scenario['agent']['position']
|
219 |
+
heading = scenario['agent']['heading']
|
220 |
+
shape = scenario['agent']['shape']
|
221 |
+
object_ids = scenario['agent']['id']
|
222 |
+
object_types = scenario['agent']['type']
|
223 |
+
valid = scenario['agent']['valid_mask']
|
224 |
+
|
225 |
+
init_dict = dict(x=position[..., 0],
|
226 |
+
y=position[..., 1],
|
227 |
+
z=position[..., 2],
|
228 |
+
heading=heading,
|
229 |
+
length=shape[..., 0],
|
230 |
+
width=shape[..., 1],
|
231 |
+
height=shape[..., 2],
|
232 |
+
valid=valid,
|
233 |
+
object_ids=object_ids,
|
234 |
+
object_types=object_types)
|
235 |
+
|
236 |
+
return init_dict
|
237 |
+
|
238 |
+
@classmethod
|
239 |
+
def _get_init_dict_from_raw(cls,
|
240 |
+
scenario: scenario_pb2.Scenario): # type: ignore
|
241 |
+
|
242 |
+
"""Load from tfrecords data"""
|
243 |
+
states, dimensions, objects = [], [], []
|
244 |
+
for track in scenario.tracks: # n_object
|
245 |
+
# Iterate over a single object's states.
|
246 |
+
track_states, track_dimensions = [], []
|
247 |
+
for state in track.states: # n_timestep
|
248 |
+
track_states.append((state.center_x, state.center_y, state.center_z,
|
249 |
+
state.heading, state.valid))
|
250 |
+
track_dimensions.append((state.length, state.width, state.height))
|
251 |
+
# Adds to the global states.
|
252 |
+
states.append(list(zip(*track_states)))
|
253 |
+
dimensions.append(list(zip(*track_dimensions)))
|
254 |
+
objects.append((track.id, track.object_type))
|
255 |
+
|
256 |
+
# Unpack and convert to tf tensors.
|
257 |
+
x, y, z, heading, valid = [torch.tensor(s) for s in zip(*states)]
|
258 |
+
length, width, height = [torch.tensor(s) for s in zip(*dimensions)]
|
259 |
+
object_ids, object_types = [torch.tensor(s) for s in zip(*objects)]
|
260 |
+
|
261 |
+
av_id = object_ids[scenario.sdc_track_index]
|
262 |
+
|
263 |
+
init_dict = dict(x=x, y=y, z=z,
|
264 |
+
heading=heading,
|
265 |
+
length=length,
|
266 |
+
width=width,
|
267 |
+
height=height,
|
268 |
+
valid=valid,
|
269 |
+
object_id=object_ids,
|
270 |
+
object_type=object_types,
|
271 |
+
av_id=int(av_id))
|
272 |
+
|
273 |
+
return init_dict
|
274 |
+
|
275 |
+
@classmethod
|
276 |
+
def from_scenario(cls,
|
277 |
+
scenario: scenario_pb2.Scenario, # type: ignore
|
278 |
+
processed_scenario: Optional[dict]=None,
|
279 |
+
from_where: str='raw'):
|
280 |
+
|
281 |
+
if from_where == 'raw':
|
282 |
+
init_dict = cls._get_init_dict_from_raw(scenario)
|
283 |
+
elif from_where == 'processed':
|
284 |
+
assert processed_scenario is not None, f'`processed_scenario` should be given!'
|
285 |
+
init_dict = cls._get_init_dict_from_processed(processed_scenario)
|
286 |
+
else:
|
287 |
+
raise RuntimeError(f'Invalid from {from_where}')
|
288 |
+
|
289 |
+
if processed_scenario is not None:
|
290 |
+
init_dict.update(state=processed_scenario['agent']['state_idx'],
|
291 |
+
token_pos=processed_scenario['agent']['token_pos'],
|
292 |
+
token_heading=processed_scenario['agent']['token_heading'],
|
293 |
+
token_valid=processed_scenario['agent']['raw_agent_valid_mask'],
|
294 |
+
processed_object_id=processed_scenario['agent']['id'],
|
295 |
+
processed_av_id=int(processed_scenario['agent']['id'][
|
296 |
+
processed_scenario['agent']['av_idx']
|
297 |
+
]),
|
298 |
+
)
|
299 |
+
|
300 |
+
return cls(**init_dict)
|
301 |
+
|
302 |
+
|
303 |
+
@dataclass
|
304 |
+
class ScenarioRollouts:
|
305 |
+
scenario_id: Optional[str] = None
|
306 |
+
joint_scenes: List[ObjectTrajectories] = field(default_factory=list)
|
307 |
+
|
308 |
+
|
309 |
+
""" Conversion Methods """
|
310 |
+
|
311 |
+
def scenario_to_trajectories(
|
312 |
+
scenario: scenario_pb2.Scenario, # type: ignore
|
313 |
+
processed_scenario: Optional[dict]=None,
|
314 |
+
from_where: Optional[str]='raw',
|
315 |
+
remove_history: Optional[bool]=False
|
316 |
+
) -> ObjectTrajectories:
|
317 |
+
"""Converts a WOMD Scenario proto into the `ObjectTrajectories`.
|
318 |
+
|
319 |
+
Returns:
|
320 |
+
A `ObjectTrajectories` with trajectories copied from data.
|
321 |
+
"""
|
322 |
+
trajectories = ObjectTrajectories.from_scenario(scenario,
|
323 |
+
processed_scenario,
|
324 |
+
from_where,
|
325 |
+
)
|
326 |
+
# Slice by the required sim agents.
|
327 |
+
sim_agent_ids = get_sim_agent_ids(scenario)
|
328 |
+
# CONSOLE.log(f'sim_agent_ids of log scenario: {sim_agent_ids} total: {len(sim_agent_ids)}')
|
329 |
+
trajectories = trajectories.gather_objects_by_id(torch.tensor(sim_agent_ids))
|
330 |
+
|
331 |
+
if remove_history:
|
332 |
+
# Slice in time to only include steps after `current_time_index`.
|
333 |
+
trajectories = trajectories.slice_time(submission_specs.CURRENT_TIME_INDEX + 1) # 10 + 1
|
334 |
+
if trajectories.valid.shape[-1] != submission_specs.N_SIMULATION_STEPS: # 80 simulated steps
|
335 |
+
raise ValueError(
|
336 |
+
'The Scenario used does not include the right number of time steps. '
|
337 |
+
f'Expected: {submission_specs.N_SIMULATION_STEPS}, '
|
338 |
+
f'Actual: {trajectories.valid.shape[-1]}.')
|
339 |
+
|
340 |
+
return trajectories
|
341 |
+
|
342 |
+
|
343 |
+
def _unbatch(src: Tensor, batch: Tensor, dim: int = 0) -> List[Tensor]:
|
344 |
+
sizes = degree(batch, dtype=torch.long).tolist()
|
345 |
+
return src.split(sizes, dim)
|
346 |
+
|
347 |
+
|
348 |
+
def get_scenario_id_int_tensor(scenario_id: List[str], device: torch.device=torch.device('cpu')) -> torch.Tensor:
|
349 |
+
scenario_id_int_tensor = []
|
350 |
+
for str_id in scenario_id:
|
351 |
+
int_id = [-1] * 16 # max_len of scenario_id string is 16
|
352 |
+
for i, c in enumerate(str_id):
|
353 |
+
int_id[i] = ord(c)
|
354 |
+
scenario_id_int_tensor.append(
|
355 |
+
torch.tensor(int_id, dtype=torch.int32, device=device)
|
356 |
+
)
|
357 |
+
return torch.stack(scenario_id_int_tensor, dim=0) # [n_scenario, 16]
|
358 |
+
|
359 |
+
|
360 |
+
def output_to_rollouts(scenario: dict) -> List[ScenarioRollouts]: # n_scenario
|
361 |
+
# scenario_id: Tensor, # [n_scenario, n_str_length]
|
362 |
+
# agent_id: Tensor, # [n_agent, n_rollout]
|
363 |
+
# agent_batch: Tensor, # [n_agent]
|
364 |
+
# pred_traj: Tensor, # [n_agent, n_rollout, n_step, 2]
|
365 |
+
# pred_z: Tensor, # [n_agent, n_rollout, n_step]
|
366 |
+
# pred_head: Tensor, # [n_agent, n_rollout, n_step]
|
367 |
+
# pred_shape: Tensor, # [n_agent, n_rollout, 3]
|
368 |
+
# pred_type: Tensor, # [n_agent, n_rollout]
|
369 |
+
# pred_state: Tensor, # [n_agent, n_rollout, n_step]
|
370 |
+
scenario_id = scenario['scenario_id']
|
371 |
+
av_id = (
|
372 |
+
scenario['av_id'] if 'av_id' in scenario else -1
|
373 |
+
)
|
374 |
+
agent_id = scenario['agent_id']
|
375 |
+
agent_batch = scenario['agent_batch']
|
376 |
+
pred_traj = scenario['pred_traj']
|
377 |
+
pred_z = scenario['pred_z']
|
378 |
+
pred_head = scenario['pred_head']
|
379 |
+
pred_shape = scenario['pred_shape']
|
380 |
+
pred_type = scenario['pred_type']
|
381 |
+
pred_state = (
|
382 |
+
scenario['pred_state'] if 'pred_state' in scenario else
|
383 |
+
torch.zeros_like(pred_z).long()
|
384 |
+
)
|
385 |
+
pred_valid = scenario['pred_valid']
|
386 |
+
token_pos = scenario['token_pos']
|
387 |
+
token_head = scenario['token_head']
|
388 |
+
|
389 |
+
# CONSOLE.log("Generate scenario rollouts ...")
|
390 |
+
# CONSOLE.log(f'scenario_id: {scenario_id}')
|
391 |
+
# CONSOLE.log(f'agent_id: {agent_id.flatten()} total: {agent_id.shape}')
|
392 |
+
# CONSOLE.log(f'av_id: {av_id}')
|
393 |
+
# CONSOLE.log(f'agent_batch: {agent_batch} total: {agent_batch.shape}')
|
394 |
+
# CONSOLE.log(f'pred_traj: {pred_traj.shape}')
|
395 |
+
# CONSOLE.log(f'pred_z: {pred_z.shape}')
|
396 |
+
# CONSOLE.log(f'pred_head: {pred_head.shape}')
|
397 |
+
# CONSOLE.log(f'pred_shape: {pred_shape.shape}')
|
398 |
+
# CONSOLE.log(f'pred_type: {pred_type.shape}')
|
399 |
+
# CONSOLE.log(f'pred_state: {pred_state.shape}')
|
400 |
+
# CONSOLE.log(f'token_pos: {token_pos.shape}')
|
401 |
+
# CONSOLE.log(f'token_head: {token_head.shape}')
|
402 |
+
|
403 |
+
scenario_id = scenario_id.cpu().numpy()
|
404 |
+
n_agent, n_rollout, n_step, _ = pred_traj.shape
|
405 |
+
agent_id = _unbatch(agent_id, agent_batch)
|
406 |
+
pred_traj = _unbatch(pred_traj, agent_batch)
|
407 |
+
pred_z = _unbatch(pred_z, agent_batch)
|
408 |
+
pred_head = _unbatch(pred_head, agent_batch)
|
409 |
+
pred_shape = _unbatch(pred_shape, agent_batch)
|
410 |
+
pred_type = _unbatch(pred_type, agent_batch)
|
411 |
+
pred_state = _unbatch(pred_state, agent_batch)
|
412 |
+
pred_valid = _unbatch(pred_valid, agent_batch)
|
413 |
+
token_pos = _unbatch(token_pos, agent_batch)
|
414 |
+
token_head = _unbatch(token_head, agent_batch)
|
415 |
+
|
416 |
+
agent_id = [x.cpu() for x in agent_id]
|
417 |
+
pred_traj = [x.cpu() for x in pred_traj]
|
418 |
+
pred_z = [x.cpu() for x in pred_z]
|
419 |
+
pred_head = [x.cpu() for x in pred_head]
|
420 |
+
pred_shape = [x[:, :, None].repeat(1, 1, n_step, 1).cpu() for x in pred_shape]
|
421 |
+
pred_type = [x[:, :, None].repeat(1, 1, n_step, 1).cpu() for x in pred_type]
|
422 |
+
pred_state = [x.cpu() for x in pred_state]
|
423 |
+
pred_valid = [x.cpu() for x in pred_valid]
|
424 |
+
token_pos = [x.cpu() for x in token_pos]
|
425 |
+
token_head = [x.cpu() for x in token_head]
|
426 |
+
|
427 |
+
n_scenario = scenario_id.shape[0]
|
428 |
+
scenario_rollouts = []
|
429 |
+
for i_scenario in range(n_scenario):
|
430 |
+
joint_scenes = []
|
431 |
+
for i_rollout in range(n_rollout): # 1
|
432 |
+
joint_scenes.append(
|
433 |
+
ObjectTrajectories(
|
434 |
+
x=pred_traj[i_scenario][:, i_rollout, :, 0],
|
435 |
+
y=pred_traj[i_scenario][:, i_rollout, :, 1],
|
436 |
+
z=pred_z[i_scenario][:, i_rollout],
|
437 |
+
heading=pred_head[i_scenario][:, i_rollout],
|
438 |
+
length=pred_shape[i_scenario][:, i_rollout, :, 0],
|
439 |
+
width=pred_shape[i_scenario][:, i_rollout, :, 1],
|
440 |
+
height=pred_shape[i_scenario][:, i_rollout, :, 2],
|
441 |
+
valid=pred_valid[i_scenario][:, i_rollout],
|
442 |
+
state=pred_state[i_scenario][:, i_rollout],
|
443 |
+
object_id=agent_id[i_scenario][:, i_rollout],
|
444 |
+
processed_object_id=agent_id[i_scenario][:, i_rollout],
|
445 |
+
object_type=pred_type[i_scenario][:, i_rollout],
|
446 |
+
token_pos=token_pos[i_scenario][:, i_rollout, :, :2],
|
447 |
+
token_heading=token_head[i_scenario][:, i_rollout],
|
448 |
+
av_id=av_id,
|
449 |
+
processed_av_id=av_id,
|
450 |
+
)
|
451 |
+
)
|
452 |
+
|
453 |
+
_str_scenario_id = "".join([chr(x) for x in scenario_id[i_scenario] if x > 0])
|
454 |
+
scenario_rollouts.append(
|
455 |
+
ScenarioRollouts(
|
456 |
+
joint_scenes=joint_scenes, scenario_id=_str_scenario_id
|
457 |
+
)
|
458 |
+
)
|
459 |
+
|
460 |
+
# CONSOLE.log(f'n_scenario: {len(scenario_rollouts)}')
|
461 |
+
# CONSOLE.log(f'n_rollout: {len(scenario_rollouts[0].joint_scenes)}')
|
462 |
+
# CONSOLE.log(f'x shape: {scenario_rollouts[0].joint_scenes[0].x.shape}')
|
463 |
+
|
464 |
+
return scenario_rollouts
|
465 |
+
|
466 |
+
|
467 |
+
""" Compute Metric Features """
|
468 |
+
|
469 |
+
def _compute_metametric(
|
470 |
+
config: long_metrics_pb2.SimAgentMetricsConfig, # type: ignore
|
471 |
+
metrics: long_metrics_pb2.SimAgentMetrics, # type: ignore
|
472 |
+
):
|
473 |
+
"""Computes the meta-metric aggregation."""
|
474 |
+
metametric = 0.0
|
475 |
+
for field_name in _METRIC_FIELD_NAMES:
|
476 |
+
likelihood_field_name = field_name + '_likelihood'
|
477 |
+
weight = getattr(config, field_name).metametric_weight
|
478 |
+
metric_score = getattr(metrics, likelihood_field_name)
|
479 |
+
metametric += weight * metric_score
|
480 |
+
return metametric
|
481 |
+
|
482 |
+
|
483 |
+
@dataclasses.dataclass(frozen=True)
|
484 |
+
class MetricFeatures:
|
485 |
+
|
486 |
+
object_id: Tensor
|
487 |
+
valid: Tensor
|
488 |
+
linear_speed: Tensor
|
489 |
+
linear_acceleration: Tensor
|
490 |
+
angular_speed: Tensor
|
491 |
+
angular_acceleration: Tensor
|
492 |
+
distance_to_nearest_object: Tensor
|
493 |
+
collision_per_step: Tensor
|
494 |
+
time_to_collision: Tensor
|
495 |
+
distance_to_road_edge: Tensor
|
496 |
+
offroad_per_step: Tensor
|
497 |
+
num_placement: Tensor
|
498 |
+
num_removement: Tensor
|
499 |
+
distance_placement: Tensor
|
500 |
+
distance_removement: Tensor
|
501 |
+
|
502 |
+
@classmethod
|
503 |
+
def from_file(cls, file_path: str):
|
504 |
+
|
505 |
+
if not os.path.exists(file_path):
|
506 |
+
raise FileNotFoundError(f'Not found file {file_path}')
|
507 |
+
|
508 |
+
with open(file_path, 'rb') as f:
|
509 |
+
feat_dict = pickle.load(f)
|
510 |
+
|
511 |
+
fields = [field.name for field in dataclasses.fields(cls)]
|
512 |
+
init_dict = dict()
|
513 |
+
|
514 |
+
for field in fields:
|
515 |
+
if field in feat_dict:
|
516 |
+
init_dict[field] = feat_dict[field]
|
517 |
+
else:
|
518 |
+
init_dict[field] = None
|
519 |
+
|
520 |
+
return cls(**init_dict)
|
521 |
+
|
522 |
+
def unfold(self, size: int, step: int):
|
523 |
+
return MetricFeatures(
|
524 |
+
object_id=self.object_id,
|
525 |
+
valid=self.valid.unfold(1, size, step),
|
526 |
+
linear_speed=self.linear_speed.unfold(1, size, step),
|
527 |
+
linear_acceleration=self.linear_acceleration.unfold(1, size, step),
|
528 |
+
angular_speed=self.angular_speed.unfold(1, size, step),
|
529 |
+
angular_acceleration=self.angular_acceleration.unfold(1, size, step),
|
530 |
+
distance_to_nearest_object=self.distance_to_nearest_object.unfold(1, size, step),
|
531 |
+
collision_per_step=self.collision_per_step.unfold(1, size, step),
|
532 |
+
time_to_collision=self.time_to_collision.unfold(1, size, step),
|
533 |
+
distance_to_road_edge=self.distance_to_road_edge.unfold(1, size, step),
|
534 |
+
offroad_per_step=self.offroad_per_step.unfold(1, size, step),
|
535 |
+
num_placement=self.num_placement.unfold(1, size // SHIFT, step // SHIFT),
|
536 |
+
num_removement=self.num_removement.unfold(1, size // SHIFT, step // SHIFT),
|
537 |
+
distance_placement=self.distance_placement.unfold(1, size // SHIFT, step // SHIFT),
|
538 |
+
distance_removement=self.distance_removement.unfold(1, size // SHIFT, step // SHIFT),
|
539 |
+
)
|
540 |
+
|
541 |
+
|
542 |
+
def compute_metric_features(
|
543 |
+
simulate_trajectories: ObjectTrajectories,
|
544 |
+
evaluate_agent_ids: Optional[Tensor]=None,
|
545 |
+
scenario_log: Optional[scenario_pb2.Scenario]=None, # type: ignore
|
546 |
+
) -> MetricFeatures:
|
547 |
+
|
548 |
+
if evaluate_agent_ids is not None:
|
549 |
+
evaluate_trajectories = simulate_trajectories.gather_objects_by_id(
|
550 |
+
evaluate_agent_ids
|
551 |
+
)
|
552 |
+
else:
|
553 |
+
evaluate_trajectories = simulate_trajectories
|
554 |
+
|
555 |
+
# valid mask
|
556 |
+
validity_mask = evaluate_trajectories.valid
|
557 |
+
validity_mask = validity_mask[:, submission_specs.CURRENT_TIME_INDEX + 1:]
|
558 |
+
|
559 |
+
# ! Kinematics-related features, i.e. speed and acceleration, this needs
|
560 |
+
# history steps to be prepended to make the first evaluate step valid.
|
561 |
+
# Resulted `lienar_speed` and others: (n_object_to_evaluate, n_future_step)
|
562 |
+
linear_speed, linear_accel, angular_speed, angular_accel = (
|
563 |
+
trajectory_features.compute_kinematic_features(
|
564 |
+
evaluate_trajectories.x,
|
565 |
+
evaluate_trajectories.y,
|
566 |
+
evaluate_trajectories.z,
|
567 |
+
evaluate_trajectories.heading,
|
568 |
+
seconds_per_step=submission_specs.STEP_DURATION_SECONDS))
|
569 |
+
# Removes the data corresponding to the history time interval.
|
570 |
+
linear_speed, linear_accel, angular_speed, angular_accel = (
|
571 |
+
map(lambda t: t[:, submission_specs.CURRENT_TIME_INDEX + 1:],
|
572 |
+
[linear_speed, linear_accel, angular_speed, angular_accel])
|
573 |
+
)
|
574 |
+
|
575 |
+
# ! Distances to nearest objects.
|
576 |
+
# evaluate_object_mask = torch.any(
|
577 |
+
# evaluate_agent_ids[:, None] == simulated_trajectories.object_id, axis=0
|
578 |
+
# )
|
579 |
+
evaluate_object_mask = torch.ones(len(simulate_trajectories.object_id)).bool()
|
580 |
+
distances_to_objects = interact_features.compute_distance_to_nearest_object(
|
581 |
+
center_x=simulate_trajectories.x,
|
582 |
+
center_y=simulate_trajectories.y,
|
583 |
+
center_z=simulate_trajectories.z,
|
584 |
+
length=simulate_trajectories.length,
|
585 |
+
width=simulate_trajectories.width,
|
586 |
+
height=simulate_trajectories.height,
|
587 |
+
heading=simulate_trajectories.heading,
|
588 |
+
valid=simulate_trajectories.valid,
|
589 |
+
evaluated_object_mask=evaluate_object_mask,
|
590 |
+
)
|
591 |
+
distances_to_objects = (
|
592 |
+
distances_to_objects[:, submission_specs.CURRENT_TIME_INDEX + 1:])
|
593 |
+
is_colliding_per_step = torch.lt(
|
594 |
+
distances_to_objects, interact_features.COLLISION_DISTANCE_THRESHOLD)
|
595 |
+
|
596 |
+
# ! Time to collision
|
597 |
+
times_to_collision = (
|
598 |
+
interact_features.compute_time_to_collision_with_object_in_front(
|
599 |
+
center_x=simulate_trajectories.x,
|
600 |
+
center_y=simulate_trajectories.y,
|
601 |
+
length=simulate_trajectories.length,
|
602 |
+
width=simulate_trajectories.width,
|
603 |
+
heading=simulate_trajectories.heading,
|
604 |
+
valid=simulate_trajectories.valid,
|
605 |
+
evaluated_object_mask=evaluate_object_mask,
|
606 |
+
seconds_per_step=submission_specs.STEP_DURATION_SECONDS,
|
607 |
+
)
|
608 |
+
)
|
609 |
+
times_to_collision = times_to_collision[:, submission_specs.CURRENT_TIME_INDEX + 1:]
|
610 |
+
|
611 |
+
# ! Distance to road edge
|
612 |
+
distances_to_road_edge = torch.empty_like(distances_to_objects)
|
613 |
+
is_offroad_per_step = torch.empty_like(is_colliding_per_step)
|
614 |
+
if scenario_log is not None:
|
615 |
+
road_edges = []
|
616 |
+
for map_feature in scenario_log.map_features:
|
617 |
+
if map_feature.HasField('road_edge'):
|
618 |
+
road_edges.append(map_feature.road_edge.polyline)
|
619 |
+
distances_to_road_edge = map_features.compute_distance_to_road_edge(
|
620 |
+
center_x=simulate_trajectories.x,
|
621 |
+
center_y=simulate_trajectories.y,
|
622 |
+
center_z=simulate_trajectories.z,
|
623 |
+
length=simulate_trajectories.length,
|
624 |
+
width=simulate_trajectories.width,
|
625 |
+
height=simulate_trajectories.height,
|
626 |
+
heading=simulate_trajectories.heading,
|
627 |
+
valid=simulate_trajectories.valid,
|
628 |
+
evaluated_object_mask=evaluate_object_mask,
|
629 |
+
road_edge_polylines=road_edges,
|
630 |
+
)
|
631 |
+
distances_to_road_edge = distances_to_road_edge[:, submission_specs.CURRENT_TIME_INDEX + 1:]
|
632 |
+
is_offroad_per_step = torch.gt(
|
633 |
+
distances_to_road_edge, map_features.OFFROAD_DISTANCE_THRESHOLD
|
634 |
+
)
|
635 |
+
|
636 |
+
# ! Placement
|
637 |
+
if simulate_trajectories.av_id == simulate_trajectories.processed_av_id == -1:
|
638 |
+
n_agent, n_step_10hz = linear_speed.shape
|
639 |
+
num_placement = torch.zeros((n_step_10hz // SHIFT,))
|
640 |
+
num_removement = torch.zeros((n_step_10hz // SHIFT,))
|
641 |
+
distance_placement = torch.zeros((n_agent, n_step_10hz // SHIFT))
|
642 |
+
distance_removement = torch.zeros((n_agent, n_step_10hz // SHIFT))
|
643 |
+
|
644 |
+
else:
|
645 |
+
assert simulate_trajectories.av_id == simulate_trajectories.processed_av_id, \
|
646 |
+
f"Got duplicated av_id: {simulate_trajectories.av_id} and {simulate_trajectories.processed_av_id}"
|
647 |
+
num_placement, num_removement = (
|
648 |
+
placement_features.compute_num_placement(
|
649 |
+
state=simulate_trajectories.state,
|
650 |
+
valid=simulate_trajectories.token_valid,
|
651 |
+
av_id=simulate_trajectories.processed_av_id,
|
652 |
+
object_id=simulate_trajectories.processed_object_id,
|
653 |
+
agent_state=AGENT_STATE,
|
654 |
+
)
|
655 |
+
)
|
656 |
+
num_placement = num_placement[submission_specs.CURRENT_TIME_INDEX // SHIFT:]
|
657 |
+
num_removement = num_removement[submission_specs.CURRENT_TIME_INDEX // SHIFT:]
|
658 |
+
distance_placement, distance_removement = (
|
659 |
+
placement_features.compute_distance_placement(
|
660 |
+
position=simulate_trajectories.token_pos,
|
661 |
+
state=simulate_trajectories.state,
|
662 |
+
valid=simulate_trajectories.valid,
|
663 |
+
av_id=simulate_trajectories.processed_av_id,
|
664 |
+
object_id=simulate_trajectories.processed_object_id,
|
665 |
+
agent_state=AGENT_STATE,
|
666 |
+
)
|
667 |
+
)
|
668 |
+
distance_placement = distance_placement[:, submission_specs.CURRENT_TIME_INDEX // SHIFT:]
|
669 |
+
distance_removement = distance_removement[:, submission_specs.CURRENT_TIME_INDEX // SHIFT:]
|
670 |
+
# distance_placement = distance_placement[distance_placement > 0]
|
671 |
+
# distance_removement = distance_removement[distance_removement > 0]
|
672 |
+
|
673 |
+
# print out some results for debugging
|
674 |
+
# CONSOLE.log(f'trajectory x: {simulate_trajectories.x.shape}, \n{simulate_trajectories.x}')
|
675 |
+
# CONSOLE.log(f'linear speed: {linear_speed.shape}, \n{linear_speed}')
|
676 |
+
# CONSOLE.log(f'distances: {distances_to_objects.shape}, \n{distances_to_objects}')
|
677 |
+
# CONSOLE.log(f'time to collision: {times_to_collision.shape}, {times_to_collision}')
|
678 |
+
|
679 |
+
return MetricFeatures(
|
680 |
+
object_id=simulate_trajectories.object_id,
|
681 |
+
valid=validity_mask,
|
682 |
+
# kinematic
|
683 |
+
linear_speed=linear_speed,
|
684 |
+
linear_acceleration=linear_accel,
|
685 |
+
angular_speed=angular_speed,
|
686 |
+
angular_acceleration=angular_accel,
|
687 |
+
# interact
|
688 |
+
distance_to_nearest_object=distances_to_objects,
|
689 |
+
collision_per_step=is_colliding_per_step,
|
690 |
+
time_to_collision=times_to_collision,
|
691 |
+
# map
|
692 |
+
distance_to_road_edge=distances_to_road_edge,
|
693 |
+
offroad_per_step=is_offroad_per_step,
|
694 |
+
# placement
|
695 |
+
num_placement=num_placement[None, ...],
|
696 |
+
num_removement=num_removement[None, ...],
|
697 |
+
distance_placement=distance_placement,
|
698 |
+
distance_removement=distance_removement,
|
699 |
+
)
|
700 |
+
|
701 |
+
|
702 |
+
@dataclass(frozen=True)
|
703 |
+
class LogDistributions:
|
704 |
+
|
705 |
+
linear_speed: Tensor
|
706 |
+
linear_acceleration: Tensor
|
707 |
+
angular_speed: Tensor
|
708 |
+
angular_acceleration: Tensor
|
709 |
+
distance_to_nearest_object: Tensor
|
710 |
+
collision_indication: Tensor
|
711 |
+
time_to_collision: Tensor
|
712 |
+
distance_to_road_edge: Tensor
|
713 |
+
num_placement: Tensor
|
714 |
+
num_removement: Tensor
|
715 |
+
distance_placement: Tensor
|
716 |
+
distance_removement: Tensor
|
717 |
+
offroad_indication: Optional[Tensor] = None
|
718 |
+
|
719 |
+
|
720 |
+
""" Compute Metrics """
|
721 |
+
|
722 |
+
def _assert_and_return_batch_size(
|
723 |
+
log_samples: Tensor,
|
724 |
+
sim_samples: Tensor
|
725 |
+
) -> int:
|
726 |
+
"""Asserts consistency in the tensor shapes and returns batch size.
|
727 |
+
|
728 |
+
Args:
|
729 |
+
log_samples: A tensor of shape (batch_size, log_sample_size).
|
730 |
+
sim_samples: A tensor of shape (batch_size, sim_sample_size).
|
731 |
+
|
732 |
+
Returns:
|
733 |
+
The `batch_size`.
|
734 |
+
"""
|
735 |
+
assert log_samples.shape[0] == sim_samples.shape[0], "Log and Sim batch sizes must be equal."
|
736 |
+
return log_samples.shape[0]
|
737 |
+
|
738 |
+
|
739 |
+
def _reduce_average_with_validity(
|
740 |
+
tensor: Tensor, validity: Tensor) -> Tensor:
|
741 |
+
"""Returns the tensor's average, only selecting valid items.
|
742 |
+
|
743 |
+
Args:
|
744 |
+
tensor: A float tensor of any shape.
|
745 |
+
validity: A boolean tensor of the same shape as `tensor`.
|
746 |
+
|
747 |
+
Returns:
|
748 |
+
A float tensor of shape (1,), containing the average of the valid elements
|
749 |
+
of `tensor`.
|
750 |
+
"""
|
751 |
+
if tensor.shape != validity.shape:
|
752 |
+
raise ValueError('Shapes of `tensor` and `validity` must be the same.'
|
753 |
+
f'(Actual: {tensor.shape}, {validity.shape}).')
|
754 |
+
cond_sum = torch.sum(torch.where(validity, tensor, torch.zeros_like(tensor)))
|
755 |
+
valid_sum = torch.sum(validity)
|
756 |
+
if valid_sum == 0:
|
757 |
+
return torch.tensor(0.)
|
758 |
+
return cond_sum / valid_sum
|
759 |
+
|
760 |
+
|
761 |
+
def histogram_estimate(
|
762 |
+
config: long_metrics_pb2.SimAgentMetricsConfig.HistogramEstimate, # type: ignore
|
763 |
+
log_samples: Tensor,
|
764 |
+
sim_samples: Tensor,
|
765 |
+
) -> Tensor:
|
766 |
+
"""Computes log-likelihoods of samples based on histograms.
|
767 |
+
|
768 |
+
Args:
|
769 |
+
config: A configuration dictionary, similar to the one in TensorFlow.
|
770 |
+
log_samples: A tensor of shape (batch_size, log_sample_size),
|
771 |
+
containing `log_sample_size` samples from `batch_size` independent
|
772 |
+
populations.
|
773 |
+
sim_samples: A tensor of shape (batch_size, sim_sample_size),
|
774 |
+
containing `sim_sample_size` samples from `batch_size` independent
|
775 |
+
populations.
|
776 |
+
|
777 |
+
Returns:
|
778 |
+
A tensor of shape (batch_size, log_sample_size), where each element (i, k)
|
779 |
+
is the log likelihood of the log sample (i, k) under the sim distribution
|
780 |
+
(i).
|
781 |
+
"""
|
782 |
+
batch_size = _assert_and_return_batch_size(log_samples, sim_samples)
|
783 |
+
|
784 |
+
# We generate `num_bins`+1 edges for the histogram buckets.
|
785 |
+
edges = torch.linspace(
|
786 |
+
config.min_val, config.max_val, config.num_bins + 1
|
787 |
+
).float()
|
788 |
+
|
789 |
+
# Clip the samples to avoid errors with histograms.
|
790 |
+
log_samples = torch.clamp(log_samples, config.min_val, config.max_val)
|
791 |
+
sim_samples = torch.clamp(sim_samples, config.min_val, config.max_val)
|
792 |
+
|
793 |
+
# Create the categorical distribution for simulation. `tfp.histogram` returns
|
794 |
+
# a tensor of shape (num_bins, batch_size), so we need to transpose to conform
|
795 |
+
# to `tfp.distribution.Categorical`, which requires `probs` to be
|
796 |
+
# (batch_size, num_bins).
|
797 |
+
sim_counts = torch.vmap(lambda x: torch.histogram(x, bins=edges).hist)(sim_samples)
|
798 |
+
sim_counts += config.additive_smoothing_pseudocount
|
799 |
+
distributions = torch.distributions.Categorical(probs=sim_counts)
|
800 |
+
|
801 |
+
# Generate the counts for the log distribution. We reshape the log samples to
|
802 |
+
# (batch_size * log_sample_size, 1), so every log sample is independently
|
803 |
+
# scored.
|
804 |
+
log_values_flat = log_samples.reshape(-1, 1)
|
805 |
+
# Shape of log_counts: (batch_size * log_sample_size, num_bins).
|
806 |
+
log_counts = torch.vmap(lambda x: torch.histogram(x, bins=edges).hist)(log_values_flat)
|
807 |
+
# Identify which bin each sample belongs to and get the log probability of
|
808 |
+
# that bin under the sim distribution.
|
809 |
+
max_log_bin = log_counts.argmax(dim=-1)
|
810 |
+
batched_max_log_bin = max_log_bin.reshape(batch_size, -1)
|
811 |
+
|
812 |
+
# Since we have defined the categorical distribution to have `batch_size`
|
813 |
+
# independent populations, tfp expects this `batch_size` to be in the last
|
814 |
+
# dimension of the tensor, so transpose the log bins to
|
815 |
+
# (log_sample_size, batch_size).
|
816 |
+
log_likelihood = distributions.log_prob(batched_max_log_bin.transpose(0, 1))
|
817 |
+
|
818 |
+
# Return log likelihood in the shape (batch_size, log_sample_size)
|
819 |
+
return log_likelihood.transpose(0, 1)
|
820 |
+
|
821 |
+
|
822 |
+
def log_likelihood_estimate_timeseries(
|
823 |
+
field: str,
|
824 |
+
feature_config: long_metrics_pb2.SimAgentMetricsConfig.FeatureConfig, # type: ignore
|
825 |
+
sim_values: Tensor,
|
826 |
+
log_distributions: torch.distributions.Categorical,
|
827 |
+
estimate_method: str='histogram',
|
828 |
+
) -> Tensor:
|
829 |
+
"""Computes the log-likelihood estimates for a time-series simulated feature.
|
830 |
+
|
831 |
+
Args:
|
832 |
+
feature_config: A time-series compatible `FeatureConfig`.
|
833 |
+
log_distributions: A float Tensor with shape (batch_sizie, n_bins).
|
834 |
+
sim_values: A float Tensor with shape (n_objects / n_scenarios, n_segments, n_steps).
|
835 |
+
|
836 |
+
Returns:
|
837 |
+
A tensor of shape (n_objects, n_steps) containing the simulation probability
|
838 |
+
estimates of the simulation features under the logged distribution of the same
|
839 |
+
feature.
|
840 |
+
"""
|
841 |
+
assert sim_values.ndim == 3, f'Expect sim_values.ndim==3, got {sim_values.ndim}, shape {sim_values.shape} for {field}'
|
842 |
+
|
843 |
+
sim_values_flat = sim_values.reshape(-1, 1) # [n_objects * n_segments * n_steps]
|
844 |
+
|
845 |
+
# if not feature_config.independent_timesteps:
|
846 |
+
# # If time steps needs to be considered independent, reshape:
|
847 |
+
# # - `sim_values` as (n_objects, n_rollouts * n_steps)
|
848 |
+
# # - `log_values` as (n_objects, n_steps)
|
849 |
+
# # If values in time are instead to be compared per-step, reshape:
|
850 |
+
# # - `sim_values` as (n_objects * n_steps, n_rollouts)
|
851 |
+
# # - `log_values` as (n_objects * n_steps, 1)
|
852 |
+
# sim_values = sim_values.reshape(-1, 1) # n_rollouts=1
|
853 |
+
|
854 |
+
# if feature_config.independent_timesteps:
|
855 |
+
# sim_values = sim_values.permute(1, 0, 2).reshape(n_objects, n_rollouts * n_steps)
|
856 |
+
# else:
|
857 |
+
# sim_values = sim_values.permute(1, 2, 0).reshape(n_objects * n_steps, n_rollouts)
|
858 |
+
# log_values = log_values.reshape(n_objects * n_steps, 1)
|
859 |
+
|
860 |
+
# ! calculate distributions for simulate features
|
861 |
+
if estimate_method == 'histogram':
|
862 |
+
config = feature_config.histogram
|
863 |
+
elif estimate_method == 'bernoulli':
|
864 |
+
config = (
|
865 |
+
long_metrics_pb2.SimAgentMetricsConfig.HistogramEstimate(
|
866 |
+
min_val=-0.5, max_val=0.5, num_bins=2,
|
867 |
+
additive_smoothing_pseudocount=feature_config.bernoulli.additive_smoothing_pseudocount
|
868 |
+
)
|
869 |
+
)
|
870 |
+
sim_values_flat = sim_values_flat.float() # cast torch.bool to torch.float32
|
871 |
+
|
872 |
+
# We generate `num_bins`+1 edges for the histogram buckets.
|
873 |
+
edges = torch.linspace(
|
874 |
+
config.min_val, config.max_val, config.num_bins + 1
|
875 |
+
).float()
|
876 |
+
|
877 |
+
sim_counts = torch.vmap(lambda x: torch.histogram(x, bins=edges).hist)(sim_values_flat) # [batch_size, num_bins]
|
878 |
+
# Identify which bin each sample belongs to and get the log probability of
|
879 |
+
# that bin under the sim distribution.
|
880 |
+
max_sim_bin = sim_counts.argmax(dim=-1)
|
881 |
+
batched_max_sim_bin = max_sim_bin.reshape(1, -1) # `batch_size` = 1, follows the log distributions
|
882 |
+
|
883 |
+
sim_likelihood = log_distributions.log_prob(batched_max_sim_bin.transpose(0, 1)).flatten()
|
884 |
+
return sim_likelihood.reshape(*sim_values.shape) # [n_objects, n_segments, n_steps]
|
885 |
+
|
886 |
+
|
887 |
+
def compute_scenario_metrics_for_bundle(
|
888 |
+
config: long_metrics_pb2.SimAgentMetricsConfig, # type: ignore
|
889 |
+
log_distributions: LogDistributions,
|
890 |
+
scenario_log: Optional[scenario_pb2.Scenario], # type: ignore
|
891 |
+
scenario_rollouts: ScenarioRollouts,
|
892 |
+
) -> long_metrics_pb2.SimAgentMetrics: # type: ignore
|
893 |
+
|
894 |
+
features_fields = [field.name for field in dataclasses.fields(MetricFeatures)]
|
895 |
+
features_fields.remove('object_id')
|
896 |
+
|
897 |
+
# ! compute simluation features
|
898 |
+
# CONSOLE.log('[on yellow] Compute sim features [/]')
|
899 |
+
sim_features = collections.defaultdict(list)
|
900 |
+
for simulate_trajectories in tqdm(scenario_rollouts.joint_scenes, leave=False, desc='rollouts ...'): # n_rollout=1
|
901 |
+
rollout_features = compute_metric_features(
|
902 |
+
simulate_trajectories,
|
903 |
+
evaluate_agent_ids=None,
|
904 |
+
scenario_log=scenario_log
|
905 |
+
)
|
906 |
+
|
907 |
+
for field in features_fields:
|
908 |
+
sim_features[field].append(getattr(rollout_features, field))
|
909 |
+
|
910 |
+
for field in features_fields:
|
911 |
+
if sim_features[field][0] is not None:
|
912 |
+
sim_features[field] = torch.concat(sim_features[field], dim=0) # n_rollout for dim=0
|
913 |
+
|
914 |
+
sim_features = MetricFeatures(
|
915 |
+
**sim_features, object_id=None,
|
916 |
+
)
|
917 |
+
# after unfold: linear_speed shape [n_agent, n_window, window_size],
|
918 |
+
# num_placement shape [n_scenario=1, n_window, window_size]
|
919 |
+
flattened_sim_features = copy.deepcopy(sim_features)
|
920 |
+
sim_features = sim_features.unfold(size=submission_specs.N_SIMULATION_STEPS, step=SHIFT)
|
921 |
+
# CONSOLE.log(f'sim linear_speed feature: {sim_features.linear_speed.shape}')
|
922 |
+
# CONSOLE.log(f'sim num_placement feature: {sim_features.num_placement.shape}')
|
923 |
+
|
924 |
+
## ! compute metrics
|
925 |
+
|
926 |
+
# ! kinematics-related metrics
|
927 |
+
linear_speed_log_likelihood = log_likelihood_estimate_timeseries(
|
928 |
+
field='linear_speed',
|
929 |
+
feature_config=config.linear_speed,
|
930 |
+
sim_values=sim_features.linear_speed,
|
931 |
+
log_distributions=log_distributions.linear_speed,
|
932 |
+
)
|
933 |
+
angular_speed_log_likelihood = log_likelihood_estimate_timeseries(
|
934 |
+
field='angular_speed',
|
935 |
+
feature_config=config.angular_speed,
|
936 |
+
sim_values=sim_features.angular_speed,
|
937 |
+
log_distributions=log_distributions.angular_speed,
|
938 |
+
)
|
939 |
+
speed_validity, acceleration_validity = (
|
940 |
+
trajectory_features.compute_kinematic_validity(flattened_sim_features.valid)
|
941 |
+
)
|
942 |
+
speed_validity = speed_validity.unfold(1, size=submission_specs.N_SIMULATION_STEPS, step=SHIFT)
|
943 |
+
acceleration_validity = acceleration_validity.unfold(1, size=submission_specs.N_SIMULATION_STEPS, step=SHIFT)
|
944 |
+
linear_speed_likelihood = torch.exp(_reduce_average_with_validity(
|
945 |
+
linear_speed_log_likelihood, speed_validity))
|
946 |
+
angular_speed_likelihood = torch.exp(_reduce_average_with_validity(
|
947 |
+
angular_speed_log_likelihood, speed_validity))
|
948 |
+
# CONSOLE.log(f'linear_speed_likelihood: {linear_speed_likelihood}')
|
949 |
+
# CONSOLE.log(f'angular_speed_likelihood: {angular_speed_likelihood}')
|
950 |
+
|
951 |
+
linear_accel_log_likelihood = log_likelihood_estimate_timeseries(
|
952 |
+
field='linear_acceleration',
|
953 |
+
feature_config=config.linear_acceleration,
|
954 |
+
sim_values=sim_features.linear_acceleration,
|
955 |
+
log_distributions=log_distributions.linear_acceleration,
|
956 |
+
)
|
957 |
+
angular_accel_log_likelihood = log_likelihood_estimate_timeseries(
|
958 |
+
field='angular_acceleration',
|
959 |
+
feature_config=config.angular_acceleration,
|
960 |
+
sim_values=sim_features.angular_acceleration,
|
961 |
+
log_distributions=log_distributions.angular_acceleration,
|
962 |
+
)
|
963 |
+
linear_accel_likelihood = torch.exp(_reduce_average_with_validity(
|
964 |
+
linear_accel_log_likelihood, acceleration_validity))
|
965 |
+
angular_accel_likelihood = torch.exp(_reduce_average_with_validity(
|
966 |
+
angular_accel_log_likelihood, acceleration_validity))
|
967 |
+
# CONSOLE.log(f'linear_accel_likelihood: {linear_accel_likelihood}')
|
968 |
+
# CONSOLE.log(f'angular_accel_likelihood: {angular_accel_likelihood}')
|
969 |
+
|
970 |
+
# ! collision and distance to other objects.
|
971 |
+
|
972 |
+
sim_collision_indication = torch.any(
|
973 |
+
torch.where(sim_features.valid, sim_features.collision_per_step, False),
|
974 |
+
dim=2)[..., None] # add a dummy time dimension
|
975 |
+
collision_score = log_likelihood_estimate_timeseries(
|
976 |
+
field='collision_indication',
|
977 |
+
feature_config=config.collision_indication,
|
978 |
+
sim_values=sim_collision_indication,
|
979 |
+
log_distributions=log_distributions.collision_indication,
|
980 |
+
estimate_method='bernoulli',
|
981 |
+
)
|
982 |
+
collision_likelihood = torch.exp(torch.mean(collision_score))
|
983 |
+
|
984 |
+
distance_to_objects_log_likelihodd = log_likelihood_estimate_timeseries(
|
985 |
+
field='distance_to_nearest_object',
|
986 |
+
feature_config=config.distance_to_nearest_object,
|
987 |
+
sim_values=sim_features.distance_to_nearest_object,
|
988 |
+
log_distributions=log_distributions.distance_to_nearest_object,
|
989 |
+
)
|
990 |
+
distance_to_objects_likelihodd = torch.exp(_reduce_average_with_validity(
|
991 |
+
distance_to_objects_log_likelihodd, sim_features.valid))
|
992 |
+
# CONSOLE.log(f'distance_to_objects_likelihodd: {distance_to_objects_likelihodd}')
|
993 |
+
|
994 |
+
ttc_log_likelihood = log_likelihood_estimate_timeseries(
|
995 |
+
field='time_to_collision',
|
996 |
+
feature_config=config.time_to_collision,
|
997 |
+
sim_values=sim_features.time_to_collision,
|
998 |
+
log_distributions=log_distributions.time_to_collision,
|
999 |
+
)
|
1000 |
+
ttc_likelihood = torch.exp(_reduce_average_with_validity(
|
1001 |
+
ttc_log_likelihood, sim_features.valid))
|
1002 |
+
# CONSOLE.log(f'ttc_likelihood: {ttc_likelihood}')
|
1003 |
+
|
1004 |
+
# ! offroad and distance to road edge.
|
1005 |
+
|
1006 |
+
# distance_to_road_edge_log_likelihood = log_likelihood_estimate_timeseries(
|
1007 |
+
# field='distance_to_road_edge',
|
1008 |
+
# sim_values=sim_features.distance_to_road_edge,
|
1009 |
+
# log_distributions=log_distributions.distance_to_road_edge,
|
1010 |
+
# )
|
1011 |
+
# distance_to_road_edge_likelihood = torch.exp(_reduce_average_with_validity(
|
1012 |
+
# distance_to_road_edge_log_likelihood, sim_features.valid))
|
1013 |
+
# CONSOLE.log(f'distance_to_road_edge_likelihood: {distance_to_road_edge_likelihood}')
|
1014 |
+
|
1015 |
+
# ! placement
|
1016 |
+
|
1017 |
+
num_placement_log_likelihood = log_likelihood_estimate_timeseries(
|
1018 |
+
field='num_placement',
|
1019 |
+
feature_config=config.num_placement,
|
1020 |
+
sim_values=sim_features.num_placement.float(),
|
1021 |
+
log_distributions=log_distributions.num_placement,
|
1022 |
+
)
|
1023 |
+
num_placement_likelihood = torch.exp(torch.mean(num_placement_log_likelihood))
|
1024 |
+
num_removement_log_likelihood = log_likelihood_estimate_timeseries(
|
1025 |
+
field='num_removement',
|
1026 |
+
feature_config=config.num_removement,
|
1027 |
+
sim_values=sim_features.num_removement.float(),
|
1028 |
+
log_distributions=log_distributions.num_removement,
|
1029 |
+
)
|
1030 |
+
num_removement_likelihood = torch.exp(torch.mean(num_removement_log_likelihood))
|
1031 |
+
# CONSOLE.log(f'num_placement_likelihood: {num_placement_likelihood}')
|
1032 |
+
# CONSOLE.log(f'num_removement_likelihood: {num_removement_likelihood}')
|
1033 |
+
|
1034 |
+
# tensor([[0.0013, 0.0078, 0.0194, 0.0373, 0.0628, 0.0938, 0.1232, 0.1470, 0.1701,
|
1035 |
+
# 0.3371]])
|
1036 |
+
# tensor([[0.0201, 0.0570, 0.0689, 0.0839, 0.1029, 0.1172, 0.1282, 0.1286, 0.1237,
|
1037 |
+
# 0.1695]])
|
1038 |
+
distance_placement_log_likelihood = log_likelihood_estimate_timeseries(
|
1039 |
+
field='distance_placement',
|
1040 |
+
feature_config=config.distance_placement,
|
1041 |
+
sim_values=sim_features.distance_placement,
|
1042 |
+
log_distributions=log_distributions.distance_placement,
|
1043 |
+
)
|
1044 |
+
distance_placement_validity = (
|
1045 |
+
(sim_features.distance_placement > config.distance_placement.histogram.min_val) &
|
1046 |
+
(sim_features.distance_placement < config.distance_placement.histogram.max_val)
|
1047 |
+
)
|
1048 |
+
distance_placement_likelihood = torch.exp(_reduce_average_with_validity(
|
1049 |
+
distance_placement_log_likelihood, distance_placement_validity))
|
1050 |
+
distance_removement_log_likelihood = log_likelihood_estimate_timeseries(
|
1051 |
+
field='distance_removement',
|
1052 |
+
feature_config=config.distance_removement,
|
1053 |
+
sim_values=sim_features.distance_removement,
|
1054 |
+
log_distributions=log_distributions.distance_removement,
|
1055 |
+
)
|
1056 |
+
distance_removement_validity = (
|
1057 |
+
(sim_features.distance_removement > config.distance_removement.histogram.min_val) &
|
1058 |
+
(sim_features.distance_removement < config.distance_removement.histogram.max_val)
|
1059 |
+
)
|
1060 |
+
distance_removement_likelihood = torch.exp(_reduce_average_with_validity(
|
1061 |
+
distance_removement_log_likelihood, distance_removement_validity))
|
1062 |
+
|
1063 |
+
# ==== Simulated collision and offroad rates ====
|
1064 |
+
simulated_collision_rate = torch.sum(
|
1065 |
+
sim_collision_indication.long()
|
1066 |
+
) / torch.sum(torch.ones_like(sim_collision_indication).long())
|
1067 |
+
# simulated_offroad_rate = tf.reduce_sum(
|
1068 |
+
# # `sim_offroad_indication` shape: (n_samples, n_objects).
|
1069 |
+
# tf.cast(sim_offroad_indication, tf.int32)
|
1070 |
+
# ) / tf.reduce_sum(tf.ones_like(sim_offroad_indication, dtype=tf.int32))
|
1071 |
+
|
1072 |
+
# ==== Meta metric ====
|
1073 |
+
likelihood_metrics = {
|
1074 |
+
'linear_speed_likelihood': float(linear_speed_likelihood.numpy()),
|
1075 |
+
'linear_acceleration_likelihood': float(linear_accel_likelihood.numpy()),
|
1076 |
+
'angular_speed_likelihood': float(angular_speed_likelihood.numpy()),
|
1077 |
+
'angular_acceleration_likelihood': float(angular_accel_likelihood.numpy()),
|
1078 |
+
'distance_to_nearest_object_likelihood': float(distance_to_objects_likelihodd.numpy()),
|
1079 |
+
'collision_indication_likelihood': float(collision_likelihood.numpy()),
|
1080 |
+
'time_to_collision_likelihood': float(ttc_likelihood.numpy()),
|
1081 |
+
# 'distance_to_road_edge_likelihoodfloat(': distance_road_edge_likelihood.nump)y(),
|
1082 |
+
# 'offroad_indication_likelihoodfloat(': offroad_likelihood.nump)y(),
|
1083 |
+
'num_placement_likelihood': float(num_placement_likelihood.numpy()),
|
1084 |
+
'num_removement_likelihood': float(num_removement_likelihood.numpy()),
|
1085 |
+
'distance_placement_likelihood': float(distance_placement_likelihood.numpy()),
|
1086 |
+
'distance_removement_likelihood': float(distance_removement_likelihood.numpy()),
|
1087 |
+
}
|
1088 |
+
|
1089 |
+
metametric = _compute_metametric(
|
1090 |
+
config, long_metrics_pb2.SimAgentMetrics(**likelihood_metrics)
|
1091 |
+
)
|
1092 |
+
# CONSOLE.log(f'metametric: {metametric}')
|
1093 |
+
|
1094 |
+
return long_metrics_pb2.SimAgentMetrics(
|
1095 |
+
scenario_id=scenario_rollouts.scenario_id,
|
1096 |
+
metametric=metametric,
|
1097 |
+
simulated_collision_rate=float(simulated_collision_rate.numpy()),
|
1098 |
+
# simulated_offroad_rate=simulated_offroad_rate.numpy(),
|
1099 |
+
**likelihood_metrics,
|
1100 |
+
)
|
1101 |
+
|
1102 |
+
|
1103 |
+
""" Log Features """
|
1104 |
+
|
1105 |
+
def _get_log_distributions(
|
1106 |
+
field: str,
|
1107 |
+
feature_config: long_metrics_pb2.SimAgentMetricsConfig.FeatureConfig, # type: ignore
|
1108 |
+
log_values: Tensor,
|
1109 |
+
estimate_method: str = 'histogram',
|
1110 |
+
) -> Tensor:
|
1111 |
+
"""Computes the log-likelihood estimates for a time-series simulated feature.
|
1112 |
+
|
1113 |
+
Args:
|
1114 |
+
feature_config: A time-series compatible `FeatureConfig`.
|
1115 |
+
log_values: A float Tensor with shape (n_objects, n_steps).
|
1116 |
+
sim_values: A float Tensor with shape (n_rollouts, n_objects, n_steps).
|
1117 |
+
|
1118 |
+
Returns:
|
1119 |
+
A tensor of shape (n_objects, n_steps) containing the log probability
|
1120 |
+
estimates of the log features under the simulated distribution of the same
|
1121 |
+
feature.
|
1122 |
+
"""
|
1123 |
+
assert log_values.ndim == 2, f'Expect log_values.ndim==2, got {log_values.ndim}, shape {log_values.shape} for {field}'
|
1124 |
+
|
1125 |
+
# [n_objects, n_steps] -> [n_objects * n_steps]
|
1126 |
+
log_samples = log_values.reshape(-1)
|
1127 |
+
|
1128 |
+
# ! estimate
|
1129 |
+
if estimate_method == 'histogram':
|
1130 |
+
config = feature_config.histogram
|
1131 |
+
elif estimate_method == 'bernoulli':
|
1132 |
+
config = (
|
1133 |
+
long_metrics_pb2.SimAgentMetricsConfig.HistogramEstimate(
|
1134 |
+
min_val=-0.5, max_val=0.5, num_bins=2,
|
1135 |
+
additive_smoothing_pseudocount=feature_config.bernoulli.additive_smoothing_pseudocount
|
1136 |
+
)
|
1137 |
+
)
|
1138 |
+
log_samples = log_samples.float() # cast torch.bool to torch.float32
|
1139 |
+
|
1140 |
+
# We generate `num_bins`+1 edges for the histogram buckets.
|
1141 |
+
edges = torch.linspace(
|
1142 |
+
config.min_val, config.max_val, config.num_bins + 1
|
1143 |
+
).float()
|
1144 |
+
|
1145 |
+
if field in ('distance_placement', 'distance_removement'):
|
1146 |
+
log_samples = log_samples[(log_samples > config.min_val) & (log_samples < config.max_val)]
|
1147 |
+
|
1148 |
+
# Clip the samples to avoid errors with histograms. Nonetheless, the min/max
|
1149 |
+
# values should be configured to never hit this condition in practice.
|
1150 |
+
log_samples = torch.clamp(log_samples, config.min_val, config.max_val)
|
1151 |
+
|
1152 |
+
# Create the categorical distribution for simulation. `tfp.histogram` returns
|
1153 |
+
# a tensor of shape (num_bins, batch_size), so we need to transpose to conform
|
1154 |
+
# to `tfp.distribution.Categorical`, which requires `probs` to be
|
1155 |
+
# (batch_size, num_bins).
|
1156 |
+
log_counts = torch.histogram(log_samples, bins=edges).hist.unsqueeze(dim=0) # [1, n_samples]
|
1157 |
+
log_counts += config.additive_smoothing_pseudocount
|
1158 |
+
distributions = torch.distributions.Categorical(probs=log_counts)
|
1159 |
+
|
1160 |
+
return distributions
|
1161 |
+
|
1162 |
+
|
1163 |
+
class LongMetric(Metric):
|
1164 |
+
|
1165 |
+
log_features: MetricFeatures
|
1166 |
+
|
1167 |
+
def __init__(
|
1168 |
+
self,
|
1169 |
+
prefix: str='',
|
1170 |
+
log_features_dir: str='data/waymo_processed/log_features/',
|
1171 |
+
config_path: str='dev/metrics/metric_config.textproto',
|
1172 |
+
) -> None:
|
1173 |
+
super().__init__()
|
1174 |
+
self.prefix = prefix
|
1175 |
+
self.metrics_config = self.load_metrics_config(config_path)
|
1176 |
+
|
1177 |
+
self.use_log = False
|
1178 |
+
|
1179 |
+
self.field_names = [
|
1180 |
+
"metametric",
|
1181 |
+
"average_displacement_error",
|
1182 |
+
"min_average_displacement_error",
|
1183 |
+
"linear_speed_likelihood",
|
1184 |
+
"linear_acceleration_likelihood",
|
1185 |
+
"angular_speed_likelihood",
|
1186 |
+
"angular_acceleration_likelihood",
|
1187 |
+
'distance_to_nearest_object_likelihood',
|
1188 |
+
'collision_indication_likelihood',
|
1189 |
+
'time_to_collision_likelihood',
|
1190 |
+
# 'distance_to_road_edge_likelihood',
|
1191 |
+
# 'offroad_indication_likelihood',
|
1192 |
+
'simulated_collision_rate',
|
1193 |
+
# 'simulated_offroad_rate',
|
1194 |
+
'num_placement_likelihood',
|
1195 |
+
'num_removement_likelihood',
|
1196 |
+
'distance_placement_likelihood',
|
1197 |
+
'distance_removement_likelihood',
|
1198 |
+
]
|
1199 |
+
for k in self.field_names:
|
1200 |
+
self.add_state(k, default=torch.tensor(0.), dist_reduce_fx='sum')
|
1201 |
+
self.add_state('scenario_counter', default=torch.tensor(0.), dist_reduce_fx='sum')
|
1202 |
+
self.add_state('placement_valid_scenario_counter', default=torch.tensor(0.), dist_reduce_fx='sum')
|
1203 |
+
self.add_state('removement_valid_scenario_counter', default=torch.tensor(0.), dist_reduce_fx='sum')
|
1204 |
+
|
1205 |
+
# get log features
|
1206 |
+
log_features_path = os.path.join(log_features_dir, 'total_features.pkl')
|
1207 |
+
if not os.path.exists(log_features_path):
|
1208 |
+
CONSOLE.log(f'[on yellow] Log features does not exist, loading now ... [/]')
|
1209 |
+
log_features = aggregate_log_metric_features(log_features_dir)
|
1210 |
+
else:
|
1211 |
+
log_features = MetricFeatures.from_file(log_features_path)
|
1212 |
+
CONSOLE.log(f'Loaded log features from {log_features_path}')
|
1213 |
+
self.log_features = log_features
|
1214 |
+
|
1215 |
+
self._compute_distributions()
|
1216 |
+
CONSOLE.log(f"Calculated log distributions:\n{self.log_distributions}")
|
1217 |
+
|
1218 |
+
def _compute_distributions(self):
|
1219 |
+
self.log_distributions = LogDistributions(
|
1220 |
+
linear_speed = _get_log_distributions('linear_speed',
|
1221 |
+
self.metrics_config.linear_speed, self.log_features.linear_speed,
|
1222 |
+
),
|
1223 |
+
linear_acceleration = _get_log_distributions('linear_acceleration',
|
1224 |
+
self.metrics_config.linear_acceleration, self.log_features.linear_acceleration,
|
1225 |
+
),
|
1226 |
+
angular_speed = _get_log_distributions('angular_speed',
|
1227 |
+
self.metrics_config.angular_speed, self.log_features.angular_speed,
|
1228 |
+
),
|
1229 |
+
angular_acceleration = _get_log_distributions('angular_acceleration',
|
1230 |
+
self.metrics_config.angular_acceleration, self.log_features.angular_acceleration,
|
1231 |
+
),
|
1232 |
+
distance_to_nearest_object = _get_log_distributions('distance_to_nearest_object',
|
1233 |
+
self.metrics_config.distance_to_nearest_object, self.log_features.distance_to_nearest_object,
|
1234 |
+
),
|
1235 |
+
collision_indication = _get_log_distributions('collision_indication',
|
1236 |
+
self.metrics_config.collision_indication,
|
1237 |
+
log_collision_indication := torch.any(
|
1238 |
+
torch.where(self.log_features.valid, self.log_features.collision_per_step, False), dim=1
|
1239 |
+
)[..., None], # add a dummy time dimension
|
1240 |
+
estimate_method = 'bernoulli',
|
1241 |
+
),
|
1242 |
+
time_to_collision = _get_log_distributions('time_to_collision',
|
1243 |
+
self.metrics_config.time_to_collision, self.log_features.time_to_collision,
|
1244 |
+
),
|
1245 |
+
distance_to_road_edge = _get_log_distributions('distance_to_road_edge',
|
1246 |
+
self.metrics_config.distance_to_road_edge, self.log_features.distance_to_road_edge,
|
1247 |
+
),
|
1248 |
+
# dist_offroad_indication = _get_log_distributions(
|
1249 |
+
# 'offroad_indication',
|
1250 |
+
# self.metrics_config.offroad_indication,
|
1251 |
+
# log_offroad_indication := torch.any(
|
1252 |
+
# torch.where(self.log_features.valid, self.log_features.offroad_per_step, False), dim=1
|
1253 |
+
# ),
|
1254 |
+
# ),
|
1255 |
+
num_placement = _get_log_distributions('num_placement',
|
1256 |
+
self.metrics_config.num_placement, self.log_features.num_placement.float(),
|
1257 |
+
),
|
1258 |
+
num_removement = _get_log_distributions('num_removement',
|
1259 |
+
self.metrics_config.num_removement, self.log_features.num_removement.float(),
|
1260 |
+
),
|
1261 |
+
distance_placement = _get_log_distributions('distance_placement',
|
1262 |
+
self.metrics_config.distance_placement, (
|
1263 |
+
self.log_features.distance_placement[self.log_features.distance_placement > 0])[None, ...],
|
1264 |
+
),
|
1265 |
+
distance_removement = _get_log_distributions('distance_removement',
|
1266 |
+
self.metrics_config.distance_removement, (
|
1267 |
+
self.log_features.distance_removement[self.log_features.distance_removement > 0])[None, ...],
|
1268 |
+
),
|
1269 |
+
)
|
1270 |
+
|
1271 |
+
def _compute_scenario_metrics(
|
1272 |
+
self,
|
1273 |
+
scenario_file: Optional[str],
|
1274 |
+
scenario_rollout: ScenarioRollouts,
|
1275 |
+
) -> long_metrics_pb2.SimAgentMetrics: # type: ignore
|
1276 |
+
|
1277 |
+
scenario_log = None
|
1278 |
+
if self.use_log and scenario_file is not None:
|
1279 |
+
if not os.path.exists(scenario_file):
|
1280 |
+
raise FileNotFoundError(f"Not found file {scenario_file}")
|
1281 |
+
scenario_log = scenario_pb2.Scenario()
|
1282 |
+
for data in tf.data.TFRecordDataset([scenario_file], compression_type=''):
|
1283 |
+
scenario_log.ParseFromString(bytes(data.numpy()))
|
1284 |
+
break
|
1285 |
+
|
1286 |
+
return compute_scenario_metrics_for_bundle(
|
1287 |
+
self.metrics_config, self.log_distributions, scenario_log, scenario_rollout
|
1288 |
+
)
|
1289 |
+
|
1290 |
+
def compute_metrics(self, outputs: dict) -> List[long_metrics_pb2.SimAgentMetrics]: # type: ignore
|
1291 |
+
"""
|
1292 |
+
`outputs` is a dict directly generated by predict models:
|
1293 |
+
>>> outputs = dict(
|
1294 |
+
>>> scenario_id=get_scenario_id_int_tensor(data['scenario_id'], device),
|
1295 |
+
>>> agent_id=agent_id,
|
1296 |
+
>>> agent_batch=agent_batch,
|
1297 |
+
>>> pred_traj=pred_traj,
|
1298 |
+
>>> pred_z=pred_z,
|
1299 |
+
>>> pred_head=pred_head,
|
1300 |
+
>>> pred_shape=pred_shape,
|
1301 |
+
>>> pred_type=pred_type,
|
1302 |
+
>>> pred_state=pred_state,
|
1303 |
+
>>> )
|
1304 |
+
"""
|
1305 |
+
|
1306 |
+
scenario_rollouts = output_to_rollouts(outputs)
|
1307 |
+
log_paths: List[str] = outputs['tfrecord_path']
|
1308 |
+
|
1309 |
+
pool_scenario_metrics = []
|
1310 |
+
for _scenario_file, _scenario_rollout in tqdm(
|
1311 |
+
zip(log_paths, scenario_rollouts), leave=False, desc='scenarios ...'): # n_scenarios
|
1312 |
+
pool_scenario_metrics.append(
|
1313 |
+
self._compute_scenario_metrics(
|
1314 |
+
_scenario_file, _scenario_rollout,
|
1315 |
+
)
|
1316 |
+
)
|
1317 |
+
|
1318 |
+
return pool_scenario_metrics
|
1319 |
+
|
1320 |
+
def update(
|
1321 |
+
self,
|
1322 |
+
outputs: Optional[dict]=None,
|
1323 |
+
metrics: Optional[List[long_metrics_pb2.SimAgentMetrics]]=None # type: ignore
|
1324 |
+
) -> None:
|
1325 |
+
|
1326 |
+
if metrics is None:
|
1327 |
+
assert outputs is not None, f'`outputs` should not be None!'
|
1328 |
+
metrics = self.compute_metrics(outputs)
|
1329 |
+
|
1330 |
+
for scenario_metrics in metrics:
|
1331 |
+
self.scenario_counter += 1
|
1332 |
+
|
1333 |
+
self.metametric += scenario_metrics.metametric
|
1334 |
+
self.average_displacement_error += (
|
1335 |
+
scenario_metrics.average_displacement_error
|
1336 |
+
)
|
1337 |
+
self.min_average_displacement_error += (
|
1338 |
+
scenario_metrics.min_average_displacement_error
|
1339 |
+
)
|
1340 |
+
self.linear_speed_likelihood += scenario_metrics.linear_speed_likelihood
|
1341 |
+
self.linear_acceleration_likelihood += (
|
1342 |
+
scenario_metrics.linear_acceleration_likelihood
|
1343 |
+
)
|
1344 |
+
self.angular_speed_likelihood += scenario_metrics.angular_speed_likelihood
|
1345 |
+
self.angular_acceleration_likelihood += (
|
1346 |
+
scenario_metrics.angular_acceleration_likelihood
|
1347 |
+
)
|
1348 |
+
self.distance_to_nearest_object_likelihood += (
|
1349 |
+
scenario_metrics.distance_to_nearest_object_likelihood
|
1350 |
+
)
|
1351 |
+
self.collision_indication_likelihood += (
|
1352 |
+
scenario_metrics.collision_indication_likelihood
|
1353 |
+
)
|
1354 |
+
self.time_to_collision_likelihood += (
|
1355 |
+
scenario_metrics.time_to_collision_likelihood
|
1356 |
+
)
|
1357 |
+
# self.distance_to_road_edge_likelihood += (
|
1358 |
+
# scenario_metrics.distance_to_road_edge_likelihood
|
1359 |
+
# )
|
1360 |
+
# self.offroad_indication_likelihood += (
|
1361 |
+
# scenario_metrics.offroad_indication_likelihood
|
1362 |
+
# )
|
1363 |
+
self.simulated_collision_rate += scenario_metrics.simulated_collision_rate
|
1364 |
+
# self.simulated_offroad_rate += scenario_metrics.simulated_offroad_rate
|
1365 |
+
|
1366 |
+
self.num_placement_likelihood += (
|
1367 |
+
scenario_metrics.num_placement_likelihood
|
1368 |
+
)
|
1369 |
+
self.num_removement_likelihood += (
|
1370 |
+
scenario_metrics.num_removement_likelihood
|
1371 |
+
)
|
1372 |
+
self.distance_placement_likelihood += (
|
1373 |
+
scenario_metrics.distance_placement_likelihood
|
1374 |
+
)
|
1375 |
+
self.distance_removement_likelihood += (
|
1376 |
+
scenario_metrics.distance_removement_likelihood
|
1377 |
+
)
|
1378 |
+
|
1379 |
+
if scenario_metrics.distance_placement_likelihood > 0:
|
1380 |
+
self.placement_valid_scenario_counter += 1
|
1381 |
+
|
1382 |
+
if scenario_metrics.distance_removement_likelihood > 0:
|
1383 |
+
self.removement_valid_scenario_counter += 1
|
1384 |
+
|
1385 |
+
def compute(self) -> Dict[str, Tensor]:
|
1386 |
+
metrics_dict = {}
|
1387 |
+
for k in self.field_names:
|
1388 |
+
if k not in ('distance_placement', 'distance_removement'):
|
1389 |
+
metrics_dict[k] = getattr(self, k) / self.scenario_counter
|
1390 |
+
if k == 'distance_placement':
|
1391 |
+
metrics_dict[k] = getattr(self, k) / self.placement_valid_scenario_counter
|
1392 |
+
if k == 'distance_removement':
|
1393 |
+
metrics_dict[k] = getattr(self, k) / self.removement_valid_scenario_counter
|
1394 |
+
|
1395 |
+
mean_metrics = long_metrics_pb2.SimAgentMetrics(
|
1396 |
+
scenario_id='', **metrics_dict,
|
1397 |
+
)
|
1398 |
+
final_metrics = self.aggregate_metrics_to_buckets(
|
1399 |
+
self.metrics_config, mean_metrics
|
1400 |
+
)
|
1401 |
+
CONSOLE.log(f'final_metrics:\n{final_metrics}')
|
1402 |
+
|
1403 |
+
out_dict = {
|
1404 |
+
f"{self.prefix}/wosac/realism_meta_metric": final_metrics.realism_meta_metric,
|
1405 |
+
f"{self.prefix}/wosac/kinematic_metrics": final_metrics.kinematic_metrics,
|
1406 |
+
f"{self.prefix}/wosac/interactive_metrics": final_metrics.interactive_metrics,
|
1407 |
+
f"{self.prefix}/wosac/map_based_metrics": final_metrics.map_based_metrics,
|
1408 |
+
f"{self.prefix}/wosac/placement_based_metrics": final_metrics.placement_based_metrics,
|
1409 |
+
f"{self.prefix}/wosac/min_ade": final_metrics.min_ade,
|
1410 |
+
f"{self.prefix}/wosac/scenario_counter": int(self.scenario_counter),
|
1411 |
+
}
|
1412 |
+
for k in self.field_names:
|
1413 |
+
out_dict[f"{self.prefix}/wosac_likelihood/{k}"] = float(metrics_dict[k])
|
1414 |
+
|
1415 |
+
return out_dict
|
1416 |
+
|
1417 |
+
@staticmethod
|
1418 |
+
def aggregate_metrics_to_buckets(
|
1419 |
+
config: long_metrics_pb2.SimAgentMetricsConfig, # type: ignore
|
1420 |
+
metrics: long_metrics_pb2.SimAgentMetrics # type: ignore
|
1421 |
+
) -> long_metrics_pb2.SimAgentsBucketedMetrics: # type: ignore
|
1422 |
+
"""Aggregates metrics into buckets for better readability."""
|
1423 |
+
bucketed_metrics = {}
|
1424 |
+
for bucket_name, fields_in_bucket in _METRIC_FIELD_NAMES_BY_BUCKET.items():
|
1425 |
+
weighted_metric, weights_sum = 0.0, 0.0
|
1426 |
+
for field_name in fields_in_bucket:
|
1427 |
+
likelihood_field_name = field_name + '_likelihood'
|
1428 |
+
weight = getattr(config, field_name).metametric_weight
|
1429 |
+
metric_score = getattr(metrics, likelihood_field_name)
|
1430 |
+
weighted_metric += weight * metric_score
|
1431 |
+
weights_sum += weight
|
1432 |
+
if weights_sum == 0:
|
1433 |
+
weights_sum = 1 # FIXME: hack!!!
|
1434 |
+
# raise ValueError('The bucket\'s weight sum is zero. Check your metrics'
|
1435 |
+
# ' config.')
|
1436 |
+
bucketed_metrics[bucket_name] = weighted_metric / weights_sum
|
1437 |
+
|
1438 |
+
return long_metrics_pb2.SimAgentsBucketedMetrics(
|
1439 |
+
realism_meta_metric=metrics.metametric,
|
1440 |
+
kinematic_metrics=bucketed_metrics['kinematic'],
|
1441 |
+
interactive_metrics=bucketed_metrics['interactive'],
|
1442 |
+
map_based_metrics=bucketed_metrics['map_based'],
|
1443 |
+
placement_based_metrics=bucketed_metrics['placement_based'],
|
1444 |
+
min_ade=metrics.min_average_displacement_error,
|
1445 |
+
simulated_collision_rate=metrics.simulated_collision_rate,
|
1446 |
+
simulated_offroad_rate=metrics.simulated_offroad_rate,
|
1447 |
+
)
|
1448 |
+
|
1449 |
+
@staticmethod
|
1450 |
+
def load_metrics_config(config_path: str = 'dev/metrics/metric_config.textproto',
|
1451 |
+
) -> long_metrics_pb2.SimAgentMetricsConfig: # type: ignore
|
1452 |
+
config = long_metrics_pb2.SimAgentMetricsConfig()
|
1453 |
+
with open(config_path, 'r') as f:
|
1454 |
+
text_format.Parse(f.read(), config)
|
1455 |
+
return config
|
1456 |
+
|
1457 |
+
def dumps(self, dir):
|
1458 |
+
from datetime import datetime
|
1459 |
+
|
1460 |
+
timestamp = datetime.now().strftime("%m_%d_%H%M%S")
|
1461 |
+
|
1462 |
+
results = self.compute()
|
1463 |
+
path = os.path.join(dir, f'{self.prefix}_{timestamp}.json')
|
1464 |
+
with open(path, 'w', encoding='utf-8') as f:
|
1465 |
+
json.dump(results, f, indent=4)
|
1466 |
+
|
1467 |
+
CONSOLE.log(f'Saved results to [bold][yellow]{path}')
|
1468 |
+
|
1469 |
+
|
1470 |
+
""" Preprocess Methods """
|
1471 |
+
|
1472 |
+
def _dump_log_metric_features(
|
1473 |
+
pkl_dir: str,
|
1474 |
+
tfrecords_dir: str,
|
1475 |
+
save_dir: str,
|
1476 |
+
transform: WaymoTargetBuilder,
|
1477 |
+
token_processor: TokenProcessor,
|
1478 |
+
scenario_id: str,
|
1479 |
+
):
|
1480 |
+
|
1481 |
+
try:
|
1482 |
+
|
1483 |
+
tqdm.write(f'Processing scenario {scenario_id}')
|
1484 |
+
save_path = os.path.join(save_dir, f'{scenario_id}.pkl')
|
1485 |
+
if os.path.exists(save_path):
|
1486 |
+
return
|
1487 |
+
|
1488 |
+
# load gt data
|
1489 |
+
pkl_file = os.path.join(pkl_dir, f'{scenario_id}.pkl')
|
1490 |
+
if not os.path.exists(pkl_file):
|
1491 |
+
raise FileNotFoundError(f"Not found file {pkl_file}")
|
1492 |
+
tfrecord_file = os.path.join(tfrecords_dir, f'{scenario_id}.tfrecords')
|
1493 |
+
if not os.path.exists(tfrecord_file):
|
1494 |
+
raise FileNotFoundError(f"Not found file {tfrecord_file}")
|
1495 |
+
|
1496 |
+
scenario_log = scenario_pb2.Scenario()
|
1497 |
+
for data in tf.data.TFRecordDataset([tfrecord_file], compression_type=''):
|
1498 |
+
scenario_log.ParseFromString(bytes(data.numpy()))
|
1499 |
+
break
|
1500 |
+
|
1501 |
+
with open(pkl_file, 'rb') as f:
|
1502 |
+
log_data = pickle.load(f)
|
1503 |
+
|
1504 |
+
# preprocess data
|
1505 |
+
log_data = transform._score_trained_agents(log_data) # get `train_mask`
|
1506 |
+
log_data = token_processor._tokenize_agent(log_data)
|
1507 |
+
|
1508 |
+
# convert to `JointScene` and compute features
|
1509 |
+
log_trajectories = scenario_to_trajectories(scenario_log, processed_scenario=log_data)
|
1510 |
+
# log_trajectories = ObjectTrajectories.init_from_processed_scenario(data)
|
1511 |
+
|
1512 |
+
# NOTE: we do not consider the `evaluation_agent_ids` here
|
1513 |
+
# evaluate_agent_ids = torch.tensor(
|
1514 |
+
# get_evaluation_agent_ids(scenario_log)
|
1515 |
+
# )
|
1516 |
+
evaluate_agent_ids = None
|
1517 |
+
log_features = compute_metric_features(
|
1518 |
+
log_trajectories, evaluate_agent_ids=evaluate_agent_ids, #scenario_log=scenario_log,
|
1519 |
+
)
|
1520 |
+
|
1521 |
+
# save to pkl file
|
1522 |
+
with open(save_path, 'wb') as f:
|
1523 |
+
pickle.dump(log_features, f)
|
1524 |
+
|
1525 |
+
except Exception as e:
|
1526 |
+
CONSOLE.log(f'[on red] Failed to process scenario {scenario_id} due to {e}.[/]')
|
1527 |
+
return
|
1528 |
+
|
1529 |
+
|
1530 |
+
def dump_log_metric_features(log_dir, save_dir):
|
1531 |
+
|
1532 |
+
buffer_size = 128
|
1533 |
+
|
1534 |
+
# file loaders
|
1535 |
+
pkl_dir = os.path.join(log_dir, 'validation')
|
1536 |
+
if not os.path.exists(pkl_dir):
|
1537 |
+
raise RuntimeError(f'Not found folder {pkl_dir}')
|
1538 |
+
tfrecords_dir = os.path.join(log_dir, 'validation_tfrecords_splitted')
|
1539 |
+
if not os.path.exists(tfrecords_dir):
|
1540 |
+
raise RuntimeError(f'Not found folder {tfrecords_dir}')
|
1541 |
+
|
1542 |
+
files = list(fnmatch.filter(os.listdir(pkl_dir), '*.pkl'))
|
1543 |
+
json_path = os.path.join(log_dir, 'meta_infos.json')
|
1544 |
+
meta_infos = json.load(open(json_path, 'r', encoding='utf-8'))['validation']
|
1545 |
+
CONSOLE.log(f"Loaded meta infos from {json_path}")
|
1546 |
+
available_scenarios = list(meta_infos.keys())
|
1547 |
+
df = pd.DataFrame.from_dict(meta_infos, orient='index')
|
1548 |
+
available_scenarios_set = set(available_scenarios)
|
1549 |
+
df_filtered = df[(df.index.isin(available_scenarios_set)) & (df['num_agents'] >= 8) & (df['num_agents'] < buffer_size)]
|
1550 |
+
valid_scenarios = set(df_filtered.index)
|
1551 |
+
files = list(tqdm(filter(lambda fn: fn.removesuffix('.pkl') in valid_scenarios, files), leave=False))
|
1552 |
+
|
1553 |
+
scenario_ids = list(map(lambda fn: fn.removesuffix('.pkl'), files))
|
1554 |
+
CONSOLE.log(f'Loaded {len(scenario_ids)} scenarios from validation split.')
|
1555 |
+
|
1556 |
+
# initialize
|
1557 |
+
transform = WaymoTargetBuilder(num_historical_steps=11,
|
1558 |
+
num_future_steps=80,
|
1559 |
+
max_num=32)
|
1560 |
+
|
1561 |
+
token_processor = TokenProcessor(token_size=2048,
|
1562 |
+
state_token={'invalid': 0, 'valid': 1, 'enter': 2, 'exit': 3},
|
1563 |
+
pl2seed_radius=75)
|
1564 |
+
|
1565 |
+
partial_dump_gt_metric_features = partial(
|
1566 |
+
_dump_log_metric_features, pkl_dir, tfrecords_dir, save_dir, transform, token_processor)
|
1567 |
+
|
1568 |
+
for scenario_id in tqdm(scenario_ids, leave=False, desc='scenarios ...'):
|
1569 |
+
|
1570 |
+
partial_dump_gt_metric_features(scenario_id)
|
1571 |
+
|
1572 |
+
|
1573 |
+
def batch_dump_log_metric_features(log_dir, save_dir, num_workers=64):
|
1574 |
+
|
1575 |
+
buffer_size = 128
|
1576 |
+
|
1577 |
+
# file loaders
|
1578 |
+
pkl_dir = os.path.join(log_dir, 'validation')
|
1579 |
+
if not os.path.exists(pkl_dir):
|
1580 |
+
raise RuntimeError(f'Not found folder {pkl_dir}')
|
1581 |
+
tfrecords_dir = os.path.join(log_dir, 'validation_tfrecords_splitted')
|
1582 |
+
if not os.path.exists(tfrecords_dir):
|
1583 |
+
raise RuntimeError(f'Not found folder {tfrecords_dir}')
|
1584 |
+
|
1585 |
+
files = list(fnmatch.filter(os.listdir(pkl_dir), '*.pkl'))
|
1586 |
+
json_path = os.path.join(log_dir, 'meta_infos.json')
|
1587 |
+
meta_infos = json.load(open(json_path, 'r', encoding='utf-8'))['validation']
|
1588 |
+
CONSOLE.log(f"Loaded meta infos from {json_path}")
|
1589 |
+
available_scenarios = list(meta_infos.keys())
|
1590 |
+
df = pd.DataFrame.from_dict(meta_infos, orient='index')
|
1591 |
+
available_scenarios_set = set(available_scenarios)
|
1592 |
+
df_filtered = df[(df.index.isin(available_scenarios_set)) & (df['num_agents'] >= 8) & (df['num_agents'] < buffer_size)]
|
1593 |
+
valid_scenarios = set(df_filtered.index)
|
1594 |
+
files = list(tqdm(filter(lambda fn: fn.removesuffix('.pkl') in valid_scenarios, files), leave=False))
|
1595 |
+
|
1596 |
+
scenario_ids = list(map(lambda fn: fn.removesuffix('.pkl'), files))
|
1597 |
+
CONSOLE.log(f'Loaded {len(scenario_ids)} scenarios from validation split.')
|
1598 |
+
|
1599 |
+
# initialize
|
1600 |
+
transform = WaymoTargetBuilder(num_historical_steps=11,
|
1601 |
+
num_future_steps=80,
|
1602 |
+
max_num=32)
|
1603 |
+
|
1604 |
+
token_processor = TokenProcessor(token_size=2048,
|
1605 |
+
state_token={'invalid': 0, 'valid': 1, 'enter': 2, 'exit': 3},
|
1606 |
+
pl2seed_radius=75)
|
1607 |
+
|
1608 |
+
partial_dump_gt_metric_features = partial(
|
1609 |
+
_dump_log_metric_features, pkl_dir, tfrecords_dir, save_dir, transform, token_processor)
|
1610 |
+
|
1611 |
+
with multiprocessing.Pool(num_workers) as p:
|
1612 |
+
list(tqdm(p.imap_unordered(partial_dump_gt_metric_features, scenario_ids), total=len(scenario_ids)))
|
1613 |
+
|
1614 |
+
|
1615 |
+
def aggregate_log_metric_features(load_dir):
|
1616 |
+
|
1617 |
+
files = list(fnmatch.filter(os.listdir(load_dir), '*.pkl'))
|
1618 |
+
if 'total_features.pkl' in files:
|
1619 |
+
files.remove('total_features.pkl')
|
1620 |
+
CONSOLE.log(f'Loaded {len(files)} scenarios from dumpped log metric features')
|
1621 |
+
|
1622 |
+
features_fields = [field.name for field in dataclasses.fields(MetricFeatures)]
|
1623 |
+
features_fields.remove('object_id')
|
1624 |
+
|
1625 |
+
# load and append
|
1626 |
+
total_features = collections.defaultdict(list)
|
1627 |
+
for file in tqdm(files, leave=False, desc='scenario ...'):
|
1628 |
+
|
1629 |
+
with open(os.path.join(load_dir, file), 'rb') as f:
|
1630 |
+
log_features = pickle.load(f)
|
1631 |
+
|
1632 |
+
for field in features_fields:
|
1633 |
+
total_features[field].append(getattr(log_features, field))
|
1634 |
+
|
1635 |
+
# aggregate
|
1636 |
+
features_info = dict()
|
1637 |
+
for field in (pbar := tqdm(features_fields, leave=False)):
|
1638 |
+
pbar.set_postfix(f=field)
|
1639 |
+
if total_features[field][0] is not None:
|
1640 |
+
total_features[field] = torch.concat(total_features[field], dim=0) # n_agent or n_scenario
|
1641 |
+
features_info[field] = total_features[field].shape
|
1642 |
+
CONSOLE.log(f'Aggregated log features:\n{features_info}')
|
1643 |
+
|
1644 |
+
# save
|
1645 |
+
save_path = os.path.join(load_dir, 'total_features.pkl')
|
1646 |
+
with open(save_path, 'wb') as f:
|
1647 |
+
pickle.dump(total_features, f)
|
1648 |
+
CONSOLE.log(f'Saved total features to [green]{save_path}.[/]')
|
1649 |
+
|
1650 |
+
return MetricFeatures(**total_features, object_id=None)
|
1651 |
+
|
1652 |
+
|
1653 |
+
def _compute_metrics(
|
1654 |
+
metric: LongMetric,
|
1655 |
+
load_dir: str,
|
1656 |
+
verbose: bool,
|
1657 |
+
rollouts_file: str,
|
1658 |
+
) -> List[long_metrics_pb2.SimAgentMetrics]: # type: ignore
|
1659 |
+
|
1660 |
+
if verbose:
|
1661 |
+
print(f'Processing {rollouts_file}')
|
1662 |
+
|
1663 |
+
with open(os.path.join(load_dir, rollouts_file), 'rb') as f:
|
1664 |
+
rollouts = pickle.load(f)
|
1665 |
+
# CONSOLE.log(f'Loaded rollouts from {rollouts_file}')
|
1666 |
+
|
1667 |
+
return metric.compute_metrics(rollouts)
|
1668 |
+
|
1669 |
+
|
1670 |
+
def compute_metrics(load_dir, rollouts_files):
|
1671 |
+
|
1672 |
+
log_every_n_steps = 100
|
1673 |
+
|
1674 |
+
metric = LongMetric('val_close_long')
|
1675 |
+
CONSOLE.log(f'metrics config:\n{metric.metrics_config}')
|
1676 |
+
|
1677 |
+
i = 0
|
1678 |
+
for rollouts_file in tqdm(rollouts_files, leave=False, desc='Rollouts files ...'):
|
1679 |
+
|
1680 |
+
# ! compute metrics and update
|
1681 |
+
metric.update(
|
1682 |
+
metrics=_compute_metrics(metric, load_dir, verbose=False, rollouts_file=rollouts_file)
|
1683 |
+
)
|
1684 |
+
|
1685 |
+
if i % log_every_n_steps == 0:
|
1686 |
+
CONSOLE.log(f'Step={i}:\n{metric.compute()}')
|
1687 |
+
|
1688 |
+
i += 1
|
1689 |
+
|
1690 |
+
CONSOLE.log(f'[bold][yellow] Compute metrics completed!')
|
1691 |
+
CONSOLE.log(f'[bold][yellow] Final metrics: [/]\n {metric.compute()}')
|
1692 |
+
|
1693 |
+
|
1694 |
+
def batch_compute_metrics(load_dir, rollouts_files, num_workers, save_dir=None):
|
1695 |
+
from queue import Queue
|
1696 |
+
from threading import Thread
|
1697 |
+
|
1698 |
+
if save_dir is None:
|
1699 |
+
save_dir = load_dir
|
1700 |
+
|
1701 |
+
results_buffer = Queue()
|
1702 |
+
|
1703 |
+
log_every_n_steps = 20
|
1704 |
+
|
1705 |
+
metric = LongMetric('val_close_long')
|
1706 |
+
CONSOLE.log(f'metrics config:\n{metric.metrics_config}')
|
1707 |
+
|
1708 |
+
def _collect_result():
|
1709 |
+
while True:
|
1710 |
+
r = results_buffer.get()
|
1711 |
+
if r is None:
|
1712 |
+
break
|
1713 |
+
metric.update(metrics=r)
|
1714 |
+
results_buffer.task_done()
|
1715 |
+
|
1716 |
+
collector = Thread(target=_collect_result, daemon=True)
|
1717 |
+
collector.start()
|
1718 |
+
|
1719 |
+
partial_compute_metrics = partial(_compute_metrics, metric, load_dir, True)
|
1720 |
+
|
1721 |
+
# ! compute metrics in batch
|
1722 |
+
with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor:
|
1723 |
+
# results = list(executor.map(partial_compute_metrics, rollouts_files))
|
1724 |
+
futures = [executor.submit(partial_compute_metrics, rollouts_file) for rollouts_file in rollouts_files]
|
1725 |
+
# results = [f.result() for f in concurrent.futures.as_completed(futures)]
|
1726 |
+
|
1727 |
+
for i, future in tqdm(enumerate(concurrent.futures.as_completed(futures)), total=len(futures), leave=False):
|
1728 |
+
results_buffer.put(future.result())
|
1729 |
+
|
1730 |
+
if i % log_every_n_steps == 0:
|
1731 |
+
CONSOLE.log(f'Step={i}:\n{metric.compute()}')
|
1732 |
+
metric.dumps(save_dir)
|
1733 |
+
|
1734 |
+
results_buffer.put(None)
|
1735 |
+
collector.join()
|
1736 |
+
|
1737 |
+
CONSOLE.log(f'[bold][yellow] Compute metrics completed!')
|
1738 |
+
CONSOLE.log(f'[bold][yellow] Final metrics: [/]\n {metric.compute()}')
|
1739 |
+
|
1740 |
+
# save results to disk
|
1741 |
+
metric.dumps(save_dir)
|
1742 |
+
|
1743 |
+
|
1744 |
+
if __name__ == "__main__":
|
1745 |
+
parser = ArgumentParser()
|
1746 |
+
parser.add_argument('--dump_log', action='store_true')
|
1747 |
+
parser.add_argument('--dump_sim', action='store_true')
|
1748 |
+
parser.add_argument('--aggregate_log', action='store_true')
|
1749 |
+
parser.add_argument('--num_workers', type=int, default=32)
|
1750 |
+
parser.add_argument('--compute_metric', action='store_true')
|
1751 |
+
parser.add_argument('--log_dir', type=str, default='data/waymo_processed/')
|
1752 |
+
parser.add_argument('--sim_dir', type=str, default=None, required=False)
|
1753 |
+
parser.add_argument('--save_dir', type=str, default='results', required=False)
|
1754 |
+
parser.add_argument('--no_batch', action='store_true')
|
1755 |
+
parser.add_argument('--debug', action='store_true')
|
1756 |
+
parser.add_argument('--debug_batch', action='store_true')
|
1757 |
+
args = parser.parse_args()
|
1758 |
+
|
1759 |
+
if args.dump_log:
|
1760 |
+
|
1761 |
+
save_dir = os.path.join(args.log_dir, 'log_features')
|
1762 |
+
os.makedirs(save_dir, exist_ok=True)
|
1763 |
+
|
1764 |
+
if args.no_batch or args.debug:
|
1765 |
+
dump_log_metric_features(args.log_dir, save_dir)
|
1766 |
+
else:
|
1767 |
+
batch_dump_log_metric_features(args.log_dir, save_dir)
|
1768 |
+
|
1769 |
+
elif args.aggregate_log:
|
1770 |
+
|
1771 |
+
load_dir = os.path.join(args.log_dir, 'log_features')
|
1772 |
+
aggregate_log_metric_features(load_dir)
|
1773 |
+
|
1774 |
+
elif args.compute_metric:
|
1775 |
+
|
1776 |
+
assert args.sim_dir is not None and os.path.exists(args.sim_dir), \
|
1777 |
+
f'Folder {args.sim_dir} does not exist!'
|
1778 |
+
rollouts_files = list(sorted(fnmatch.filter(os.listdir(args.sim_dir), 'idx_*_rollouts.pkl')))
|
1779 |
+
CONSOLE.log(f'Found {len(rollouts_files)} rollouts files.')
|
1780 |
+
|
1781 |
+
os.makedirs(args.save_dir, exist_ok=True)
|
1782 |
+
if args.no_batch:
|
1783 |
+
compute_metrics(args.sim_dir, rollouts_files)
|
1784 |
+
|
1785 |
+
else:
|
1786 |
+
multiprocessing.set_start_method('spawn', force=True)
|
1787 |
+
batch_compute_metrics(args.sim_dir, rollouts_files, args.num_workers, save_dir=args.save_dir)
|
1788 |
+
|
1789 |
+
elif args.debug:
|
1790 |
+
|
1791 |
+
debug_path = 'output/scalable_smart_long/validation_catk/idx_0_0_rollouts.pkl'
|
1792 |
+
|
1793 |
+
# ! for debugging
|
1794 |
+
with open(debug_path, 'rb') as f:
|
1795 |
+
rollouts = pickle.load(f)
|
1796 |
+
metric = LongMetric('debug')
|
1797 |
+
CONSOLE.log(f'metrics config: {metric.metrics_config}')
|
1798 |
+
|
1799 |
+
metric.update(outputs=rollouts)
|
1800 |
+
CONSOLE.log(f'metrics:\n{metric.compute()}')
|
1801 |
+
|
1802 |
+
|
1803 |
+
elif args.debug_batch:
|
1804 |
+
|
1805 |
+
rollouts_files = ['idx_0_rollouts.pkl'] * 1000
|
1806 |
+
CONSOLE.log(f'Found {len(rollouts_files)} rollouts files.')
|
1807 |
+
|
1808 |
+
sim_dir = 'dev/metrics/'
|
1809 |
+
|
1810 |
+
os.makedirs(args.save_dir, exist_ok=True)
|
1811 |
+
multiprocessing.set_start_method('spawn', force=True)
|
1812 |
+
batch_compute_metrics(args.sim_dir, rollouts_files, args.num_workers, save_dir=args.save_dir)
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/metrics/geometry_utils.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from torch import Tensor
|
4 |
+
from typing import Tuple
|
5 |
+
|
6 |
+
|
7 |
+
NUM_VERTICES_IN_BOX = 4
|
8 |
+
|
9 |
+
|
10 |
+
def minkowski_sum_of_box_and_box_points(box1_points: Tensor,
|
11 |
+
box2_points: Tensor) -> Tensor:
|
12 |
+
"""Batched Minkowski sum of two boxes (counter-clockwise corners in xy)."""
|
13 |
+
point_order_1 = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3], dtype=torch.long)
|
14 |
+
point_order_2 = torch.tensor([0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long)
|
15 |
+
|
16 |
+
box1_start_idx, downmost_box1_edge_direction = _get_downmost_edge_in_box(
|
17 |
+
box1_points)
|
18 |
+
box2_start_idx, downmost_box2_edge_direction = _get_downmost_edge_in_box(
|
19 |
+
box2_points)
|
20 |
+
|
21 |
+
condition = (cross_product_2d(downmost_box1_edge_direction, downmost_box2_edge_direction) >= 0.)
|
22 |
+
condition = condition.repeat(1, 8)
|
23 |
+
|
24 |
+
box1_point_order = torch.where(condition, point_order_2, point_order_1)
|
25 |
+
box1_point_order = (box1_point_order + box1_start_idx) % NUM_VERTICES_IN_BOX
|
26 |
+
ordered_box1_points = torch.gather(
|
27 |
+
box1_points, 1, box1_point_order.unsqueeze(-1).expand(-1, -1, 2))
|
28 |
+
|
29 |
+
box2_point_order = torch.where(condition, point_order_1, point_order_2)
|
30 |
+
box2_point_order = (box2_point_order + box2_start_idx) % NUM_VERTICES_IN_BOX
|
31 |
+
ordered_box2_points = torch.gather(
|
32 |
+
box2_points, 1, box2_point_order.unsqueeze(-1).expand(-1, -1, 2))
|
33 |
+
|
34 |
+
minkowski_sum = ordered_box1_points + ordered_box2_points
|
35 |
+
|
36 |
+
return minkowski_sum
|
37 |
+
|
38 |
+
|
39 |
+
def signed_distance_from_point_to_convex_polygon(query_points: Tensor, polygon_points: Tensor) -> Tensor:
|
40 |
+
"""Finds the signed distances from query points to convex polygons."""
|
41 |
+
tangent_unit_vectors, normal_unit_vectors, edge_lengths = _get_edge_info(
|
42 |
+
polygon_points)
|
43 |
+
|
44 |
+
query_points = query_points.unsqueeze(1)
|
45 |
+
vertices_to_query_vectors = query_points - polygon_points
|
46 |
+
vertices_distances = torch.norm(vertices_to_query_vectors, dim=-1)
|
47 |
+
|
48 |
+
edge_signed_perp_distances = torch.sum(-normal_unit_vectors * vertices_to_query_vectors, dim=-1)
|
49 |
+
|
50 |
+
is_inside = torch.all(edge_signed_perp_distances <= 0, dim=-1)
|
51 |
+
|
52 |
+
projection_along_tangent = torch.sum(tangent_unit_vectors * vertices_to_query_vectors, dim=-1)
|
53 |
+
projection_along_tangent_proportion = projection_along_tangent / edge_lengths
|
54 |
+
|
55 |
+
is_projection_on_edge = (projection_along_tangent_proportion >= 0.) & (
|
56 |
+
projection_along_tangent_proportion <= 1.)
|
57 |
+
|
58 |
+
edge_perp_distances = edge_signed_perp_distances.abs()
|
59 |
+
edge_distances = torch.where(is_projection_on_edge, edge_perp_distances, torch.tensor(np.inf))
|
60 |
+
|
61 |
+
edge_and_vertex_distance = torch.cat([edge_distances, vertices_distances], dim=-1)
|
62 |
+
min_distance = torch.min(edge_and_vertex_distance, dim=-1)[0]
|
63 |
+
|
64 |
+
signed_distances = torch.where(is_inside, -min_distance, min_distance)
|
65 |
+
|
66 |
+
return signed_distances
|
67 |
+
|
68 |
+
|
69 |
+
def _get_downmost_edge_in_box(box: Tensor) -> Tuple[Tensor, Tensor]:
|
70 |
+
"""Finds the downmost (lowest y-coordinate) edge in the box."""
|
71 |
+
downmost_vertex_idx = torch.argmin(box[..., 1], dim=-1, keepdim=True)
|
72 |
+
|
73 |
+
edge_start_vertex = torch.gather(box, 1, downmost_vertex_idx.unsqueeze(-1).expand(-1, -1, 2))
|
74 |
+
edge_end_idx = (downmost_vertex_idx + 1) % NUM_VERTICES_IN_BOX
|
75 |
+
edge_end_vertex = torch.gather(box, 1, edge_end_idx.unsqueeze(-1).expand(-1, -1, 2))
|
76 |
+
|
77 |
+
downmost_edge = edge_end_vertex - edge_start_vertex
|
78 |
+
downmost_edge_length = torch.norm(downmost_edge, dim=-1, keepdim=True)
|
79 |
+
downmost_edge_direction = downmost_edge / downmost_edge_length
|
80 |
+
|
81 |
+
return downmost_vertex_idx, downmost_edge_direction
|
82 |
+
|
83 |
+
|
84 |
+
def cross_product_2d(a: Tensor, b: Tensor) -> Tensor:
|
85 |
+
return a[..., 0] * b[..., 1] - a[..., 1] * b[..., 0]
|
86 |
+
|
87 |
+
|
88 |
+
def dot_product_2d(a: Tensor, b: Tensor) -> Tensor:
|
89 |
+
return a[..., 0] * b[..., 0] + a[..., 1] * b[..., 1]
|
90 |
+
|
91 |
+
|
92 |
+
def _get_edge_info(polygon_points: Tensor):
|
93 |
+
"""
|
94 |
+
Computes properties about the edges of a polygon.
|
95 |
+
|
96 |
+
Args:
|
97 |
+
polygon_points: Tensor containing the vertices of each polygon, with
|
98 |
+
shape (num_polygons, num_points_per_polygon, 2). Each polygon is assumed
|
99 |
+
to have an equal number of vertices.
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
tangent_unit_vectors: A unit vector in (x,y) with the same direction as
|
103 |
+
the tangent to the edge. Shape: (num_polygons, num_points_per_polygon, 2).
|
104 |
+
normal_unit_vectors: A unit vector in (x,y) with the same direction as
|
105 |
+
the normal to the edge.
|
106 |
+
Shape: (num_polygons, num_points_per_polygon, 2).
|
107 |
+
edge_lengths: Lengths of the edges.
|
108 |
+
Shape (num_polygons, num_points_per_polygon).
|
109 |
+
"""
|
110 |
+
# Shift the polygon points by 1 position to get the edges.
|
111 |
+
first_point_in_polygon = polygon_points[:, 0:1, :] # Shape: (num_polygons, 1, 2)
|
112 |
+
shifted_polygon_points = torch.cat([polygon_points[:, 1:, :], first_point_in_polygon], dim=1)
|
113 |
+
# Shape: (num_polygons, num_points_per_polygon, 2)
|
114 |
+
|
115 |
+
edge_vectors = shifted_polygon_points - polygon_points # Shape: (num_polygons, num_points_per_polygon, 2)
|
116 |
+
edge_lengths = torch.norm(edge_vectors, dim=-1) # Shape: (num_polygons, num_points_per_polygon)
|
117 |
+
|
118 |
+
# Avoid division by zero by adding a small epsilon
|
119 |
+
eps = torch.finfo(edge_lengths.dtype).eps
|
120 |
+
tangent_unit_vectors = edge_vectors / (edge_lengths[..., None] + eps) # Shape: (num_polygons, num_points_per_polygon, 2)
|
121 |
+
|
122 |
+
normal_unit_vectors = torch.stack(
|
123 |
+
[-tangent_unit_vectors[..., 1], tangent_unit_vectors[..., 0]], dim=-1
|
124 |
+
) # Shape: (num_polygons, num_points_per_polygon, 2)
|
125 |
+
|
126 |
+
return tangent_unit_vectors, normal_unit_vectors, edge_lengths
|
127 |
+
|
128 |
+
|
129 |
+
def rotate_2d_points(xys: Tensor, rotation_yaws: Tensor) -> Tensor:
|
130 |
+
"""Rotates `xys` counterclockwise using the `rotation_yaws`."""
|
131 |
+
cos_yaws = torch.cos(rotation_yaws)
|
132 |
+
sin_yaws = torch.sin(rotation_yaws)
|
133 |
+
|
134 |
+
rotated_x = cos_yaws * xys[..., 0] - sin_yaws * xys[..., 1]
|
135 |
+
rotated_y = sin_yaws * xys[..., 0] + cos_yaws * xys[..., 1]
|
136 |
+
|
137 |
+
return torch.stack([rotated_x, rotated_y], axis=-1)
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/metrics/interact_features.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from torch import Tensor
|
4 |
+
|
5 |
+
from dev.metrics import box_utils
|
6 |
+
from dev.metrics import geometry_utils
|
7 |
+
from dev.metrics import trajectory_features
|
8 |
+
|
9 |
+
|
10 |
+
EXTREMELY_LARGE_DISTANCE = 1e10
|
11 |
+
COLLISION_DISTANCE_THRESHOLD = 0.0
|
12 |
+
CORNER_ROUNDING_FACTOR = 0.7
|
13 |
+
MAX_HEADING_DIFF = math.radians(75.0)
|
14 |
+
MAX_HEADING_DIFF_FOR_SMALL_OVERLAP = math.radians(10.0)
|
15 |
+
SMALL_OVERLAP_THRESHOLD = 0.5
|
16 |
+
MAXIMUM_TIME_TO_COLLISION = 5.0
|
17 |
+
|
18 |
+
|
19 |
+
def compute_distance_to_nearest_object(
|
20 |
+
center_x: Tensor,
|
21 |
+
center_y: Tensor,
|
22 |
+
center_z: Tensor,
|
23 |
+
length: Tensor,
|
24 |
+
width: Tensor,
|
25 |
+
height: Tensor,
|
26 |
+
heading: Tensor,
|
27 |
+
valid: Tensor,
|
28 |
+
evaluated_object_mask: Tensor,
|
29 |
+
corner_rounding_factor: float = CORNER_ROUNDING_FACTOR,
|
30 |
+
) -> Tensor:
|
31 |
+
"""Computes the distance to nearest object for each of the evaluated objects."""
|
32 |
+
boxes = torch.stack([center_x, center_y, center_z, length, width, height, heading], dim=-1)
|
33 |
+
num_objects, num_steps, num_features = boxes.shape
|
34 |
+
|
35 |
+
shrinking_distance = (torch.minimum(boxes[:, :, 3], boxes[:, :, 4]) * corner_rounding_factor / 2.)
|
36 |
+
|
37 |
+
boxes = torch.cat([
|
38 |
+
boxes[:, :, :3],
|
39 |
+
boxes[:, :, 3:4] - 2.0 * shrinking_distance[..., None],
|
40 |
+
boxes[:, :, 4:5] - 2.0 * shrinking_distance[..., None],
|
41 |
+
boxes[:, :, 5:]
|
42 |
+
], dim=2)
|
43 |
+
|
44 |
+
boxes = boxes.reshape(num_objects * num_steps, num_features)
|
45 |
+
|
46 |
+
box_corners = box_utils.get_upright_3d_box_corners(boxes)[:, :4, :2]
|
47 |
+
box_corners = box_corners.reshape(num_objects, num_steps, 4, 2)
|
48 |
+
|
49 |
+
eval_corners = box_corners[evaluated_object_mask]
|
50 |
+
num_eval_objects = eval_corners.shape[0]
|
51 |
+
other_corners = box_corners[~evaluated_object_mask]
|
52 |
+
all_corners = torch.cat([eval_corners, other_corners], dim=0)
|
53 |
+
|
54 |
+
eval_corners = eval_corners.unsqueeze(1).expand(num_eval_objects, num_objects, num_steps, 4, 2)
|
55 |
+
all_corners = all_corners.unsqueeze(0).expand(num_eval_objects, num_objects, num_steps, 4, 2)
|
56 |
+
|
57 |
+
eval_corners = eval_corners.reshape(num_eval_objects * num_objects * num_steps, 4, 2)
|
58 |
+
all_corners = all_corners.reshape(num_eval_objects * num_objects * num_steps, 4, 2)
|
59 |
+
|
60 |
+
neg_all_corners = -1.0 * all_corners
|
61 |
+
minkowski_sum = geometry_utils.minkowski_sum_of_box_and_box_points(
|
62 |
+
box1_points=eval_corners, box2_points=neg_all_corners,
|
63 |
+
)
|
64 |
+
|
65 |
+
assert minkowski_sum.shape[1:] == (8, 2), f"Shape mismatch: {minkowski_sum.shape}, expected (*, 8, 2)"
|
66 |
+
signed_distances_flat = (
|
67 |
+
geometry_utils.signed_distance_from_point_to_convex_polygon(
|
68 |
+
query_points=torch.zeros_like(minkowski_sum[:, 0, :]),
|
69 |
+
polygon_points=minkowski_sum,
|
70 |
+
)
|
71 |
+
)
|
72 |
+
|
73 |
+
signed_distances = signed_distances_flat.reshape(num_eval_objects, num_objects, num_steps)
|
74 |
+
|
75 |
+
eval_shrinking_distance = shrinking_distance[evaluated_object_mask]
|
76 |
+
other_shrinking_distance = shrinking_distance[~evaluated_object_mask]
|
77 |
+
all_shrinking_distance = torch.cat([eval_shrinking_distance, other_shrinking_distance], dim=0)
|
78 |
+
|
79 |
+
signed_distances -= eval_shrinking_distance.unsqueeze(1)
|
80 |
+
signed_distances -= all_shrinking_distance.unsqueeze(0)
|
81 |
+
|
82 |
+
self_mask = torch.eye(num_eval_objects, num_objects, dtype=torch.float32)[:, :, None]
|
83 |
+
signed_distances = signed_distances + self_mask * EXTREMELY_LARGE_DISTANCE
|
84 |
+
|
85 |
+
eval_validity = valid[evaluated_object_mask]
|
86 |
+
other_validity = valid[~evaluated_object_mask]
|
87 |
+
all_validity = torch.cat([eval_validity, other_validity], dim=0)
|
88 |
+
|
89 |
+
valid_mask = eval_validity.unsqueeze(1) & all_validity.unsqueeze(0)
|
90 |
+
|
91 |
+
signed_distances = torch.where(valid_mask, signed_distances, EXTREMELY_LARGE_DISTANCE)
|
92 |
+
|
93 |
+
return torch.min(signed_distances, dim=1).values
|
94 |
+
|
95 |
+
|
96 |
+
def compute_time_to_collision_with_object_in_front(
|
97 |
+
*,
|
98 |
+
center_x: Tensor,
|
99 |
+
center_y: Tensor,
|
100 |
+
length: Tensor,
|
101 |
+
width: Tensor,
|
102 |
+
heading: Tensor,
|
103 |
+
valid: Tensor,
|
104 |
+
evaluated_object_mask: Tensor,
|
105 |
+
seconds_per_step: float,
|
106 |
+
) -> Tensor:
|
107 |
+
"""Computes the time-to-collision of the evaluated objects."""
|
108 |
+
# `speed` shape: (num_objects, num_steps)
|
109 |
+
speed = trajectory_features.compute_kinematic_features(
|
110 |
+
x=center_x,
|
111 |
+
y=center_y,
|
112 |
+
z=torch.zeros_like(center_x),
|
113 |
+
heading=heading,
|
114 |
+
seconds_per_step=seconds_per_step,
|
115 |
+
)[0]
|
116 |
+
|
117 |
+
boxes = torch.stack([center_x, center_y, length, width, heading, speed], dim=-1)
|
118 |
+
boxes = boxes.permute(1, 0, 2) # (num_steps, num_objects, 6)
|
119 |
+
valid = valid.permute(1, 0)
|
120 |
+
|
121 |
+
eval_boxes = boxes[:, evaluated_object_mask]
|
122 |
+
ego_xy, ego_sizes, ego_yaw, ego_speed = torch.split(eval_boxes, [2, 2, 1, 1], dim=-1)
|
123 |
+
other_xy, other_sizes, other_yaw, _ = torch.split(boxes, [2, 2, 1, 1], dim=-1)
|
124 |
+
|
125 |
+
yaw_diff = torch.abs(other_yaw[:, None] - ego_yaw[:, :, None])
|
126 |
+
yaw_diff_cos = torch.cos(yaw_diff)
|
127 |
+
yaw_diff_sin = torch.sin(yaw_diff)
|
128 |
+
|
129 |
+
other_long_offset = geometry_utils.dot_product_2d(
|
130 |
+
other_sizes[:, None] / 2.0, torch.abs(torch.cat([yaw_diff_cos, yaw_diff_sin], dim=-1))
|
131 |
+
)
|
132 |
+
other_lat_offset = geometry_utils.dot_product_2d(
|
133 |
+
other_sizes[:, None] / 2.0, torch.abs(torch.cat([yaw_diff_sin, yaw_diff_cos], dim=-1))
|
134 |
+
)
|
135 |
+
|
136 |
+
other_relative_xy = geometry_utils.rotate_2d_points(
|
137 |
+
(other_xy[:, None] - ego_xy[:, :, None]), -ego_yaw
|
138 |
+
)
|
139 |
+
|
140 |
+
long_distance = (
|
141 |
+
other_relative_xy[..., 0] - ego_sizes[:, :, None, 0] / 2.0 - other_long_offset
|
142 |
+
)
|
143 |
+
lat_overlap = (
|
144 |
+
torch.abs(other_relative_xy[..., 1]) - ego_sizes[:, :, None, 1] / 2.0 - other_lat_offset
|
145 |
+
)
|
146 |
+
|
147 |
+
following_mask = _get_object_following_mask(
|
148 |
+
long_distance, lat_overlap, yaw_diff[..., 0]
|
149 |
+
)
|
150 |
+
valid_mask = valid[:, None] & following_mask
|
151 |
+
|
152 |
+
masked_long_distance = (
|
153 |
+
long_distance + (1.0 - valid_mask.float()) * EXTREMELY_LARGE_DISTANCE
|
154 |
+
)
|
155 |
+
|
156 |
+
box_ahead_index = masked_long_distance.argmin(dim=-1)
|
157 |
+
distance_to_box_ahead = torch.gather(
|
158 |
+
masked_long_distance, -1, box_ahead_index.unsqueeze(-1)
|
159 |
+
).squeeze(-1)
|
160 |
+
|
161 |
+
speed_broadcast = speed.T[:, None, :].expand_as(masked_long_distance)
|
162 |
+
box_ahead_speed = torch.gather(speed_broadcast, -1, box_ahead_index.unsqueeze(-1)).squeeze(-1)
|
163 |
+
|
164 |
+
rel_speed = ego_speed[..., 0] - box_ahead_speed
|
165 |
+
time_to_collision = torch.where(
|
166 |
+
rel_speed > 0.0,
|
167 |
+
torch.minimum(distance_to_box_ahead / rel_speed,
|
168 |
+
torch.tensor(MAXIMUM_TIME_TO_COLLISION)), # the float will be broadcasted automatically
|
169 |
+
MAXIMUM_TIME_TO_COLLISION,
|
170 |
+
)
|
171 |
+
|
172 |
+
return time_to_collision.T
|
173 |
+
|
174 |
+
|
175 |
+
def _get_object_following_mask(
|
176 |
+
longitudinal_distance: Tensor,
|
177 |
+
lateral_overlap: Tensor,
|
178 |
+
yaw_diff: Tensor,
|
179 |
+
) -> Tensor:
|
180 |
+
"""Checks whether objects satisfy criteria for following another object.
|
181 |
+
|
182 |
+
An object on which the criteria are applied is called "ego object" in this
|
183 |
+
function to disambiguate it from the other objects acting as obstacles.
|
184 |
+
|
185 |
+
An "ego" object is considered to be following another object if they satisfy
|
186 |
+
conditions on the longitudinal distance, lateral overlap, and yaw alignment
|
187 |
+
between them.
|
188 |
+
|
189 |
+
Args:
|
190 |
+
longitudinal_distance: A float Tensor with shape (batch_dim, num_egos,
|
191 |
+
num_others) containing longitudinal distances from the back side of each
|
192 |
+
ego box to every other box.
|
193 |
+
lateral_overlap: A float Tensor with shape (batch_dim, num_egos, num_others)
|
194 |
+
containing lateral overlaps of other boxes over the trails of ego boxes.
|
195 |
+
yaw_diff: A float Tensor with shape (batch_dim, num_egos, num_others)
|
196 |
+
containing absolute yaw differences between egos and other boxes.
|
197 |
+
|
198 |
+
Returns:
|
199 |
+
A boolean Tensor with shape (batch_dim, num_egos, num_others) indicating for
|
200 |
+
each ego box if it is following the other boxes.
|
201 |
+
"""
|
202 |
+
# Check object is ahead of the ego box's front.
|
203 |
+
valid_mask = longitudinal_distance > 0.0
|
204 |
+
|
205 |
+
# Check alignment.
|
206 |
+
valid_mask = torch.logical_and(valid_mask, yaw_diff <= MAX_HEADING_DIFF)
|
207 |
+
|
208 |
+
# Check object is directly ahead of the ego box.
|
209 |
+
valid_mask = torch.logical_and(valid_mask, lateral_overlap < 0.0)
|
210 |
+
|
211 |
+
# Check strict alignment if the overlap is small.
|
212 |
+
# `lateral_overlap` is a signed penetration distance: it is negative when the
|
213 |
+
# boxes have an actual lateral overlap.
|
214 |
+
return torch.logical_and(
|
215 |
+
valid_mask,
|
216 |
+
torch.logical_or(
|
217 |
+
lateral_overlap < -SMALL_OVERLAP_THRESHOLD,
|
218 |
+
yaw_diff <= MAX_HEADING_DIFF_FOR_SMALL_OVERLAP,
|
219 |
+
),
|
220 |
+
)
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/metrics/map_features.py
ADDED
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import Tensor
|
3 |
+
from typing import Optional, Sequence
|
4 |
+
|
5 |
+
from dev.metrics import box_utils
|
6 |
+
from dev.metrics import geometry_utils
|
7 |
+
from dev.metrics.protos import map_pb2
|
8 |
+
|
9 |
+
# Constant distance to apply when distances are invalid. This will avoid the
|
10 |
+
# propagation of nans and should be reduced out when taking the minimum anyway.
|
11 |
+
EXTREMELY_LARGE_DISTANCE = 1e10
|
12 |
+
# Off-road threshold, i.e. smallest distance away from the road edge that is
|
13 |
+
# considered to be a off-road.
|
14 |
+
OFFROAD_DISTANCE_THRESHOLD = 0.0
|
15 |
+
|
16 |
+
# How close the start and end point of a map feature need to be for the feature
|
17 |
+
# to be considered cyclic, in m^2.
|
18 |
+
_CYCLIC_MAP_FEATURE_TOLERANCE_M2 = 1.0
|
19 |
+
# Scaling factor for vertical distances used when finding the closest segment to
|
20 |
+
# a query point. This prevents wrong associations in cases with under- and
|
21 |
+
# over-passes.
|
22 |
+
_Z_STRETCH_FACTOR = 3.0
|
23 |
+
|
24 |
+
_Polyline = Sequence[map_pb2.MapPoint]
|
25 |
+
|
26 |
+
|
27 |
+
def compute_distance_to_road_edge(
|
28 |
+
*,
|
29 |
+
center_x: Tensor,
|
30 |
+
center_y: Tensor,
|
31 |
+
center_z: Tensor,
|
32 |
+
length: Tensor,
|
33 |
+
width: Tensor,
|
34 |
+
height: Tensor,
|
35 |
+
heading: Tensor,
|
36 |
+
valid: Tensor,
|
37 |
+
evaluated_object_mask: Tensor,
|
38 |
+
road_edge_polylines: Sequence[_Polyline],
|
39 |
+
) -> Tensor:
|
40 |
+
"""Computes the distance to the road edge for each of the evaluated objects."""
|
41 |
+
if not road_edge_polylines:
|
42 |
+
raise ValueError('Missing road edges.')
|
43 |
+
|
44 |
+
# Concatenate tensors to have the same convention as `box_utils`.
|
45 |
+
boxes = torch.stack([center_x, center_y, center_z, length, width, height, heading], dim=-1)
|
46 |
+
num_objects, num_steps, num_features = boxes.shape
|
47 |
+
boxes = boxes.reshape(num_objects * num_steps, num_features)
|
48 |
+
|
49 |
+
# Compute box corners using `box_utils`, and take the xyz coords of the bottom corners.
|
50 |
+
box_corners = box_utils.get_upright_3d_box_corners(boxes)[:, :4]
|
51 |
+
box_corners = box_corners.reshape(num_objects, num_steps, 4, 3)
|
52 |
+
|
53 |
+
# Gather objects in the evaluation set
|
54 |
+
eval_corners = box_corners[evaluated_object_mask]
|
55 |
+
num_eval_objects = eval_corners.shape[0]
|
56 |
+
|
57 |
+
# Flatten query points.
|
58 |
+
flat_eval_corners = eval_corners.reshape(-1, 3)
|
59 |
+
|
60 |
+
# Tensorize road edges.
|
61 |
+
polylines_tensor = _tensorize_polylines(road_edge_polylines)
|
62 |
+
is_polyline_cyclic = _check_polyline_cycles(road_edge_polylines)
|
63 |
+
|
64 |
+
# Compute distances for all query points.
|
65 |
+
corner_distance_to_road_edge = _compute_signed_distance_to_polylines(
|
66 |
+
xyzs=flat_eval_corners, polylines=polylines_tensor,
|
67 |
+
is_polyline_cyclic=is_polyline_cyclic, z_stretch=_Z_STRETCH_FACTOR
|
68 |
+
)
|
69 |
+
|
70 |
+
# Reshape back to (num_evaluated_objects, num_steps, 4)
|
71 |
+
corner_distance_to_road_edge = corner_distance_to_road_edge.reshape(num_eval_objects, num_steps, 4)
|
72 |
+
|
73 |
+
# Reduce to most off-road corner.
|
74 |
+
signed_distances = torch.max(corner_distance_to_road_edge, dim=-1)[0]
|
75 |
+
|
76 |
+
# Mask out invalid boxes.
|
77 |
+
eval_validity = valid[evaluated_object_mask]
|
78 |
+
|
79 |
+
return torch.where(eval_validity, signed_distances, -EXTREMELY_LARGE_DISTANCE)
|
80 |
+
|
81 |
+
|
82 |
+
def _tensorize_polylines(polylines):
|
83 |
+
"""Stacks a sequence of polylines into a tensor.
|
84 |
+
|
85 |
+
Args:
|
86 |
+
polylines: A sequence of Polyline objects.
|
87 |
+
|
88 |
+
Returns:
|
89 |
+
A float tensor with shape (num_polylines, max_length, 4) containing xyz
|
90 |
+
coordinates and a validity flag for all points in the polylines. Polylines
|
91 |
+
are padded with zeros up to the length of the longest one.
|
92 |
+
"""
|
93 |
+
polyline_tensors = []
|
94 |
+
max_length = 0
|
95 |
+
|
96 |
+
for polyline in polylines:
|
97 |
+
# Skip degenerate polylines.
|
98 |
+
if len(polyline) < 2:
|
99 |
+
continue
|
100 |
+
max_length = max(max_length, len(polyline))
|
101 |
+
polyline_tensors.append(
|
102 |
+
torch.tensor(
|
103 |
+
[[map_point.x, map_point.y, map_point.z, 1.0] for map_point in polyline],
|
104 |
+
dtype=torch.float32
|
105 |
+
)
|
106 |
+
)
|
107 |
+
|
108 |
+
# Pad and stack polylines
|
109 |
+
padded_polylines = [
|
110 |
+
torch.cat([p, torch.zeros((max_length - p.shape[0], 4), dtype=torch.float32)], dim=0)
|
111 |
+
for p in polyline_tensors
|
112 |
+
]
|
113 |
+
|
114 |
+
return torch.stack(padded_polylines, dim=0)
|
115 |
+
|
116 |
+
|
117 |
+
def _check_polyline_cycles(polylines):
|
118 |
+
"""Checks if given polylines are cyclic and returns the result as a tensor.
|
119 |
+
|
120 |
+
Args:
|
121 |
+
polylines: A sequence of Polyline objects.
|
122 |
+
tolerance: A float representing the cyclic tolerance.
|
123 |
+
|
124 |
+
Returns:
|
125 |
+
A bool tensor with shape (num_polylines) indicating whether each polyline is cyclic.
|
126 |
+
"""
|
127 |
+
cycles = []
|
128 |
+
for polyline in polylines:
|
129 |
+
# Skip degenerate polylines.
|
130 |
+
if len(polyline) < 2:
|
131 |
+
continue
|
132 |
+
first_point = torch.tensor([polyline[0].x, polyline[0].y, polyline[0].z], dtype=torch.float32)
|
133 |
+
last_point = torch.tensor([polyline[-1].x, polyline[-1].y, polyline[-1].z], dtype=torch.float32)
|
134 |
+
cycles.append(torch.sum((first_point - last_point) ** 2) < _CYCLIC_MAP_FEATURE_TOLERANCE_M2)
|
135 |
+
|
136 |
+
return torch.stack(cycles, dim=0)
|
137 |
+
|
138 |
+
|
139 |
+
def _compute_signed_distance_to_polylines(
|
140 |
+
xyzs: Tensor,
|
141 |
+
polylines: Tensor,
|
142 |
+
is_polyline_cyclic: Optional[Tensor] = None,
|
143 |
+
z_stretch: float = 1.0,
|
144 |
+
) -> Tensor:
|
145 |
+
"""Computes the signed distance to the 2D boundary defined by polylines.
|
146 |
+
|
147 |
+
Negative distances correspond to being inside the boundary (e.g. on the
|
148 |
+
road), positive distances to being outside (e.g. off-road).
|
149 |
+
|
150 |
+
The polylines should be oriented such that port side is inside the boundary
|
151 |
+
and starboard is outside, a.k.a counterclockwise winding order.
|
152 |
+
|
153 |
+
The altitudes i.e. the z-coordinates of query points and polyline segments
|
154 |
+
are used to pair each query point with the most relevant segment, that is
|
155 |
+
closest and at the right altitude. The distances returned are 2D distances in
|
156 |
+
the xy plane.
|
157 |
+
|
158 |
+
Args:
|
159 |
+
xyzs: A float Tensor of shape (num_points, 3) containing xyz coordinates of
|
160 |
+
query points.
|
161 |
+
polylines: Tensor with shape (num_polylines, num_segments+1, 4) containing
|
162 |
+
sequences of xyz coordinates and validity, representing start and end
|
163 |
+
points of consecutive segments.
|
164 |
+
is_polyline_cyclic: A boolean Tensor with shape (num_polylines) indicating
|
165 |
+
whether each polyline is cyclic. If None, all polylines are considered
|
166 |
+
non-cyclic.
|
167 |
+
z_stretch: Factor by which to scale distances over the z axis. This can be
|
168 |
+
done to ensure edge points from the wrong level (e.g. overpasses) are not
|
169 |
+
selected. Defaults to 1.0 (no stretching).
|
170 |
+
|
171 |
+
Returns:
|
172 |
+
A tensor of shape (num_points), containing the signed 2D distance from
|
173 |
+
queried points to the nearest polyline.
|
174 |
+
"""
|
175 |
+
num_points = xyzs.shape[0]
|
176 |
+
assert xyzs.shape == (num_points, 3), f"Expected shape ({num_points}, 3), but got {xyzs.shape}"
|
177 |
+
num_polylines = polylines.shape[0]
|
178 |
+
num_segments = polylines.shape[1] - 1
|
179 |
+
assert polylines.shape == (num_polylines, num_segments + 1, 4), \
|
180 |
+
f"Expected shape ({num_polylines}, {num_segments + 1}, 4), but got {polylines.shape}"
|
181 |
+
|
182 |
+
# shape: (num_polylines, num_segments+1)
|
183 |
+
is_point_valid = polylines[:, :, 3].bool()
|
184 |
+
# shape: (num_polylines, num_segments)
|
185 |
+
is_segment_valid = is_point_valid[:, :-1] & is_point_valid[:, 1:]
|
186 |
+
|
187 |
+
if is_polyline_cyclic is None:
|
188 |
+
is_polyline_cyclic = torch.zeros(num_polylines, dtype=torch.bool)
|
189 |
+
else:
|
190 |
+
assert is_polyline_cyclic.shape == (num_polylines,), \
|
191 |
+
f"Expected shape ({num_polylines},), but got {is_polyline_cyclic.shape}"
|
192 |
+
|
193 |
+
# Get distance to each segment.
|
194 |
+
# shape: (num_points, num_polylines, num_segments, 3)
|
195 |
+
xyz_starts = polylines[None, :, :-1, :3]
|
196 |
+
xyz_ends = polylines[None, :, 1:, :3]
|
197 |
+
start_to_point = xyzs[:, None, None, :3] - xyz_starts
|
198 |
+
start_to_end = xyz_ends - xyz_starts
|
199 |
+
|
200 |
+
# Relative coordinate of point projection on segment.
|
201 |
+
# shape: (num_points, num_polylines, num_segments)
|
202 |
+
numerator = geometry_utils.dot_product_2d(
|
203 |
+
start_to_point[..., :2], start_to_end[..., :2]
|
204 |
+
)
|
205 |
+
denominator = geometry_utils.dot_product_2d(
|
206 |
+
start_to_end[..., :2], start_to_end[..., :2]
|
207 |
+
)
|
208 |
+
rel_t = torch.where(denominator != 0, numerator / denominator, torch.zeros_like(numerator))
|
209 |
+
|
210 |
+
# Negative if point is on port side of segment, positive if point on
|
211 |
+
# starboard side of segment.
|
212 |
+
# shape: (num_points, num_polylines, num_segments)
|
213 |
+
n = torch.sign(
|
214 |
+
geometry_utils.cross_product_2d(
|
215 |
+
start_to_point[..., :2], start_to_end[..., :2]
|
216 |
+
)
|
217 |
+
)
|
218 |
+
|
219 |
+
# Compute the absolute 3D distance to segment.
|
220 |
+
# The vertical component is scaled by `z-stretch` to increase the separation
|
221 |
+
# between different road altitudes.
|
222 |
+
# shape: (num_points, num_polylines, num_segments, 3)
|
223 |
+
segment_to_point = start_to_point - (
|
224 |
+
start_to_end * torch.clamp(rel_t, 0.0, 1.0)[..., None]
|
225 |
+
)
|
226 |
+
stretch_vector = torch.tensor([1.0, 1.0, z_stretch], dtype=torch.float32)
|
227 |
+
distance_to_segment_3d = torch.norm(
|
228 |
+
segment_to_point * stretch_vector[None, None, None],
|
229 |
+
dim=-1,
|
230 |
+
)
|
231 |
+
|
232 |
+
# Absolute planar distance to segment.
|
233 |
+
# shape: (num_points, num_polylines, num_segments)
|
234 |
+
distance_to_segment_2d = torch.norm(segment_to_point[..., :2], dim=-1)
|
235 |
+
|
236 |
+
# Padded start-to-end segments.
|
237 |
+
# shape: (num_points, num_polylines, num_segments+2, 2)
|
238 |
+
start_to_end_padded = torch.cat(
|
239 |
+
[
|
240 |
+
start_to_end[:, :, -1:, :2],
|
241 |
+
start_to_end[..., :2],
|
242 |
+
start_to_end[:, :, :1, :2],
|
243 |
+
],
|
244 |
+
dim=-2,
|
245 |
+
)
|
246 |
+
|
247 |
+
# shape: (num_points, num_polylines, num_segments+1)
|
248 |
+
is_locally_convex = torch.gt(
|
249 |
+
geometry_utils.cross_product_2d(
|
250 |
+
start_to_end_padded[:, :, :-1], start_to_end_padded[:, :, 1:]
|
251 |
+
),
|
252 |
+
0.,
|
253 |
+
)
|
254 |
+
|
255 |
+
# Get shifted versions of `n` and `is_segment_valid`. If the polyline is
|
256 |
+
# cyclic, the tensors are rolled, else they are padded with their edge value.
|
257 |
+
# shape: (num_points, num_polylines, num_segments)
|
258 |
+
n_prior = torch.cat(
|
259 |
+
[
|
260 |
+
torch.where(
|
261 |
+
is_polyline_cyclic[None, :, None],
|
262 |
+
n[:, :, -1:],
|
263 |
+
n[:, :, :1],
|
264 |
+
),
|
265 |
+
n[:, :, :-1],
|
266 |
+
],
|
267 |
+
dim=-1,
|
268 |
+
)
|
269 |
+
n_next = torch.cat(
|
270 |
+
[
|
271 |
+
n[:, :, 1:],
|
272 |
+
torch.where(
|
273 |
+
is_polyline_cyclic[None, :, None],
|
274 |
+
n[:, :, :1],
|
275 |
+
n[:, :, -1:],
|
276 |
+
),
|
277 |
+
],
|
278 |
+
dim=-1,
|
279 |
+
)
|
280 |
+
# shape: (num_polylines, num_segments)
|
281 |
+
is_prior_segment_valid = torch.cat(
|
282 |
+
[
|
283 |
+
torch.where(
|
284 |
+
is_polyline_cyclic[:, None],
|
285 |
+
is_segment_valid[:, -1:],
|
286 |
+
is_segment_valid[:, :1],
|
287 |
+
),
|
288 |
+
is_segment_valid[:, :-1],
|
289 |
+
],
|
290 |
+
dim=-1,
|
291 |
+
)
|
292 |
+
is_next_segment_valid = torch.cat(
|
293 |
+
[
|
294 |
+
is_segment_valid[:, 1:],
|
295 |
+
torch.where(
|
296 |
+
is_polyline_cyclic[:, None],
|
297 |
+
is_segment_valid[:, :1],
|
298 |
+
is_segment_valid[:, -1:],
|
299 |
+
),
|
300 |
+
],
|
301 |
+
dim=-1,
|
302 |
+
)
|
303 |
+
|
304 |
+
# shape: (num_points, num_polylines, num_segments)
|
305 |
+
sign_if_before = torch.where(
|
306 |
+
is_locally_convex[:, :, :-1],
|
307 |
+
torch.maximum(n, n_prior),
|
308 |
+
torch.minimum(n, n_prior),
|
309 |
+
)
|
310 |
+
sign_if_after = torch.where(
|
311 |
+
is_locally_convex[:, :, 1:], torch.maximum(n, n_next), torch.minimum(n, n_next)
|
312 |
+
)
|
313 |
+
|
314 |
+
# shape: (num_points, num_polylines, num_segments)
|
315 |
+
sign_to_segment = torch.where(
|
316 |
+
(rel_t < 0.0) & is_prior_segment_valid,
|
317 |
+
sign_if_before,
|
318 |
+
torch.where((rel_t > 1.0) & is_next_segment_valid, sign_if_after, n),
|
319 |
+
)
|
320 |
+
|
321 |
+
# Flatten polylines together.
|
322 |
+
# shape: (num_points, all_segments)
|
323 |
+
distance_to_segment_3d = distance_to_segment_3d.view(num_points, num_polylines * num_segments)
|
324 |
+
distance_to_segment_2d = distance_to_segment_2d.view(num_points, num_polylines * num_segments)
|
325 |
+
sign_to_segment = sign_to_segment.view(num_points, num_polylines * num_segments)
|
326 |
+
|
327 |
+
# Mask out invalid segments.
|
328 |
+
# shape: (all_segments)
|
329 |
+
is_segment_valid = is_segment_valid.view(num_polylines * num_segments)
|
330 |
+
# shape: (num_points, all_segments)
|
331 |
+
distance_to_segment_3d = torch.where(
|
332 |
+
is_segment_valid[None],
|
333 |
+
distance_to_segment_3d,
|
334 |
+
EXTREMELY_LARGE_DISTANCE,
|
335 |
+
)
|
336 |
+
distance_to_segment_2d = torch.where(
|
337 |
+
is_segment_valid[None],
|
338 |
+
distance_to_segment_2d,
|
339 |
+
EXTREMELY_LARGE_DISTANCE,
|
340 |
+
)
|
341 |
+
|
342 |
+
# Get closest segment according to absolute 3D distance and return the
|
343 |
+
# corresponding signed 2D distance.
|
344 |
+
# shape: (num_points)
|
345 |
+
closest_segment_index = torch.argmin(distance_to_segment_3d, dim=-1)
|
346 |
+
distance_sign = torch.gather(sign_to_segment, 1, closest_segment_index.unsqueeze(-1)).squeeze(-1)
|
347 |
+
distance_2d = torch.gather(distance_to_segment_2d, 1, closest_segment_index.unsqueeze(-1)).squeeze(-1)
|
348 |
+
|
349 |
+
return distance_sign * distance_2d
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/metrics/placement_features.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import Tensor
|
3 |
+
from typing import Optional, Sequence, List
|
4 |
+
|
5 |
+
|
6 |
+
def compute_num_placement(
|
7 |
+
valid: Tensor, # [n_agent, n_step]
|
8 |
+
state: Tensor, # [n_agent, n_step]
|
9 |
+
av_id: int,
|
10 |
+
object_id: Tensor,
|
11 |
+
agent_state: List[str],
|
12 |
+
) -> Tensor:
|
13 |
+
|
14 |
+
enter_state = agent_state.index('enter')
|
15 |
+
exit_state = agent_state.index('exit')
|
16 |
+
|
17 |
+
av_index = object_id.tolist().index(av_id)
|
18 |
+
state[av_index] = -1 # we do not incorporate the sdc
|
19 |
+
|
20 |
+
is_bos = state == enter_state
|
21 |
+
is_eos = state == exit_state
|
22 |
+
|
23 |
+
num_bos = torch.sum(is_bos, dim=0)
|
24 |
+
num_eos = torch.sum(is_eos, dim=0)
|
25 |
+
|
26 |
+
return num_bos, num_eos
|
27 |
+
|
28 |
+
|
29 |
+
def compute_distance_placement(
|
30 |
+
position: Tensor,
|
31 |
+
state: Tensor,
|
32 |
+
valid: Tensor,
|
33 |
+
av_id: int,
|
34 |
+
object_id: Tensor,
|
35 |
+
agent_state: List[str],
|
36 |
+
) -> Tensor:
|
37 |
+
|
38 |
+
enter_state = agent_state.index('enter')
|
39 |
+
exit_state = agent_state.index('exit')
|
40 |
+
|
41 |
+
av_index = object_id.tolist().index(av_id)
|
42 |
+
state[av_index] = -1 # we do not incorporate the sdc
|
43 |
+
distance = torch.norm(position - position[av_index : av_index + 1], p=2, dim=-1)
|
44 |
+
|
45 |
+
bos_distance = distance * (state == enter_state)
|
46 |
+
eos_distance = distance * (state == exit_state)
|
47 |
+
|
48 |
+
return bos_distance, eos_distance
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/metrics/protos/long_metrics_pb2.py
ADDED
@@ -0,0 +1,648 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
3 |
+
# source: dev/metrics/protos/long_metrics.proto
|
4 |
+
|
5 |
+
from google.protobuf import descriptor as _descriptor
|
6 |
+
from google.protobuf import message as _message
|
7 |
+
from google.protobuf import reflection as _reflection
|
8 |
+
from google.protobuf import symbol_database as _symbol_database
|
9 |
+
# @@protoc_insertion_point(imports)
|
10 |
+
|
11 |
+
_sym_db = _symbol_database.Default()
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
DESCRIPTOR = _descriptor.FileDescriptor(
|
17 |
+
name='dev/metrics/protos/long_metrics.proto',
|
18 |
+
package='long_metric',
|
19 |
+
syntax='proto2',
|
20 |
+
serialized_options=None,
|
21 |
+
create_key=_descriptor._internal_create_key,
|
22 |
+
serialized_pb=b'\n%dev/metrics/protos/long_metrics.proto\x12\x0blong_metric\"\xb4\x0c\n\x15SimAgentMetricsConfig\x12\x46\n\x0clinear_speed\x18\x01 \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12M\n\x13linear_acceleration\x18\x02 \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12G\n\rangular_speed\x18\x03 \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12N\n\x14\x61ngular_acceleration\x18\x04 \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12T\n\x1a\x64istance_to_nearest_object\x18\x05 \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12N\n\x14\x63ollision_indication\x18\x06 \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12K\n\x11time_to_collision\x18\x07 \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12O\n\x15\x64istance_to_road_edge\x18\x08 \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12L\n\x12offroad_indication\x18\t \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12G\n\rnum_placement\x18\n \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12H\n\x0enum_removement\x18\x0b \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12L\n\x12\x64istance_placement\x18\x0c \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x12M\n\x13\x64istance_removement\x18\r \x01(\x0b\x32\x30.long_metric.SimAgentMetricsConfig.FeatureConfig\x1a\xc0\x02\n\rFeatureConfig\x12I\n\thistogram\x18\x01 \x01(\x0b\x32\x34.long_metric.SimAgentMetricsConfig.HistogramEstimateH\x00\x12R\n\x0ekernel_density\x18\x02 \x01(\x0b\x32\x38.long_metric.SimAgentMetricsConfig.KernelDensityEstimateH\x00\x12I\n\tbernoulli\x18\x03 \x01(\x0b\x32\x34.long_metric.SimAgentMetricsConfig.BernoulliEstimateH\x00\x12\x1d\n\x15independent_timesteps\x18\x04 \x01(\x08\x12\x19\n\x11metametric_weight\x18\x05 \x01(\x02\x42\x0b\n\testimator\x1av\n\x11HistogramEstimate\x12\x0f\n\x07min_val\x18\x01 \x01(\x02\x12\x0f\n\x07max_val\x18\x02 \x01(\x02\x12\x10\n\x08num_bins\x18\x03 \x01(\x05\x12-\n\x1e\x61\x64\x64itive_smoothing_pseudocount\x18\x04 \x01(\x02:\x05\x30.001\x1a*\n\x15KernelDensityEstimate\x12\x11\n\tbandwidth\x18\x01 \x01(\x02\x1a\x42\n\x11\x42\x65rnoulliEstimate\x12-\n\x1e\x61\x64\x64itive_smoothing_pseudocount\x18\x04 \x01(\x02:\x05\x30.001\"\xbf\x05\n\x0fSimAgentMetrics\x12\x13\n\x0bscenario_id\x18\x01 \x01(\t\x12\x12\n\nmetametric\x18\x02 \x01(\x02\x12\"\n\x1a\x61verage_displacement_error\x18\x03 \x01(\x02\x12&\n\x1emin_average_displacement_error\x18\x13 \x01(\x02\x12\x1f\n\x17linear_speed_likelihood\x18\x04 \x01(\x02\x12&\n\x1elinear_acceleration_likelihood\x18\x05 \x01(\x02\x12 \n\x18\x61ngular_speed_likelihood\x18\x06 \x01(\x02\x12\'\n\x1f\x61ngular_acceleration_likelihood\x18\x07 \x01(\x02\x12-\n%distance_to_nearest_object_likelihood\x18\x08 \x01(\x02\x12\'\n\x1f\x63ollision_indication_likelihood\x18\t \x01(\x02\x12$\n\x1ctime_to_collision_likelihood\x18\n \x01(\x02\x12(\n distance_to_road_edge_likelihood\x18\x0b \x01(\x02\x12%\n\x1doffroad_indication_likelihood\x18\x0c \x01(\x02\x12 \n\x18num_placement_likelihood\x18\r \x01(\x02\x12!\n\x19num_removement_likelihood\x18\x0e \x01(\x02\x12%\n\x1d\x64istance_placement_likelihood\x18\x0f \x01(\x02\x12&\n\x1e\x64istance_removement_likelihood\x18\x10 \x01(\x02\x12 \n\x18simulated_collision_rate\x18\x11 \x01(\x02\x12\x1e\n\x16simulated_offroad_rate\x18\x12 \x01(\x02\"\xfe\x01\n\x18SimAgentsBucketedMetrics\x12\x1b\n\x13realism_meta_metric\x18\x01 \x01(\x02\x12\x19\n\x11kinematic_metrics\x18\x02 \x01(\x02\x12\x1b\n\x13interactive_metrics\x18\x05 \x01(\x02\x12\x19\n\x11map_based_metrics\x18\x06 \x01(\x02\x12\x1f\n\x17placement_based_metrics\x18\x07 \x01(\x02\x12\x0f\n\x07min_ade\x18\x08 \x01(\x02\x12 \n\x18simulated_collision_rate\x18\t \x01(\x02\x12\x1e\n\x16simulated_offroad_rate\x18\n \x01(\x02'
|
23 |
+
)
|
24 |
+
|
25 |
+
|
26 |
+
|
27 |
+
|
28 |
+
_SIMAGENTMETRICSCONFIG_FEATURECONFIG = _descriptor.Descriptor(
|
29 |
+
name='FeatureConfig',
|
30 |
+
full_name='long_metric.SimAgentMetricsConfig.FeatureConfig',
|
31 |
+
filename=None,
|
32 |
+
file=DESCRIPTOR,
|
33 |
+
containing_type=None,
|
34 |
+
create_key=_descriptor._internal_create_key,
|
35 |
+
fields=[
|
36 |
+
_descriptor.FieldDescriptor(
|
37 |
+
name='histogram', full_name='long_metric.SimAgentMetricsConfig.FeatureConfig.histogram', index=0,
|
38 |
+
number=1, type=11, cpp_type=10, label=1,
|
39 |
+
has_default_value=False, default_value=None,
|
40 |
+
message_type=None, enum_type=None, containing_type=None,
|
41 |
+
is_extension=False, extension_scope=None,
|
42 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
43 |
+
_descriptor.FieldDescriptor(
|
44 |
+
name='kernel_density', full_name='long_metric.SimAgentMetricsConfig.FeatureConfig.kernel_density', index=1,
|
45 |
+
number=2, type=11, cpp_type=10, label=1,
|
46 |
+
has_default_value=False, default_value=None,
|
47 |
+
message_type=None, enum_type=None, containing_type=None,
|
48 |
+
is_extension=False, extension_scope=None,
|
49 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
50 |
+
_descriptor.FieldDescriptor(
|
51 |
+
name='bernoulli', full_name='long_metric.SimAgentMetricsConfig.FeatureConfig.bernoulli', index=2,
|
52 |
+
number=3, type=11, cpp_type=10, label=1,
|
53 |
+
has_default_value=False, default_value=None,
|
54 |
+
message_type=None, enum_type=None, containing_type=None,
|
55 |
+
is_extension=False, extension_scope=None,
|
56 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
57 |
+
_descriptor.FieldDescriptor(
|
58 |
+
name='independent_timesteps', full_name='long_metric.SimAgentMetricsConfig.FeatureConfig.independent_timesteps', index=3,
|
59 |
+
number=4, type=8, cpp_type=7, label=1,
|
60 |
+
has_default_value=False, default_value=False,
|
61 |
+
message_type=None, enum_type=None, containing_type=None,
|
62 |
+
is_extension=False, extension_scope=None,
|
63 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
64 |
+
_descriptor.FieldDescriptor(
|
65 |
+
name='metametric_weight', full_name='long_metric.SimAgentMetricsConfig.FeatureConfig.metametric_weight', index=4,
|
66 |
+
number=5, type=2, cpp_type=6, label=1,
|
67 |
+
has_default_value=False, default_value=float(0),
|
68 |
+
message_type=None, enum_type=None, containing_type=None,
|
69 |
+
is_extension=False, extension_scope=None,
|
70 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
71 |
+
],
|
72 |
+
extensions=[
|
73 |
+
],
|
74 |
+
nested_types=[],
|
75 |
+
enum_types=[
|
76 |
+
],
|
77 |
+
serialized_options=None,
|
78 |
+
is_extendable=False,
|
79 |
+
syntax='proto2',
|
80 |
+
extension_ranges=[],
|
81 |
+
oneofs=[
|
82 |
+
_descriptor.OneofDescriptor(
|
83 |
+
name='estimator', full_name='long_metric.SimAgentMetricsConfig.FeatureConfig.estimator',
|
84 |
+
index=0, containing_type=None,
|
85 |
+
create_key=_descriptor._internal_create_key,
|
86 |
+
fields=[]),
|
87 |
+
],
|
88 |
+
serialized_start=1091,
|
89 |
+
serialized_end=1411,
|
90 |
+
)
|
91 |
+
|
92 |
+
_SIMAGENTMETRICSCONFIG_HISTOGRAMESTIMATE = _descriptor.Descriptor(
|
93 |
+
name='HistogramEstimate',
|
94 |
+
full_name='long_metric.SimAgentMetricsConfig.HistogramEstimate',
|
95 |
+
filename=None,
|
96 |
+
file=DESCRIPTOR,
|
97 |
+
containing_type=None,
|
98 |
+
create_key=_descriptor._internal_create_key,
|
99 |
+
fields=[
|
100 |
+
_descriptor.FieldDescriptor(
|
101 |
+
name='min_val', full_name='long_metric.SimAgentMetricsConfig.HistogramEstimate.min_val', index=0,
|
102 |
+
number=1, type=2, cpp_type=6, label=1,
|
103 |
+
has_default_value=False, default_value=float(0),
|
104 |
+
message_type=None, enum_type=None, containing_type=None,
|
105 |
+
is_extension=False, extension_scope=None,
|
106 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
107 |
+
_descriptor.FieldDescriptor(
|
108 |
+
name='max_val', full_name='long_metric.SimAgentMetricsConfig.HistogramEstimate.max_val', index=1,
|
109 |
+
number=2, type=2, cpp_type=6, label=1,
|
110 |
+
has_default_value=False, default_value=float(0),
|
111 |
+
message_type=None, enum_type=None, containing_type=None,
|
112 |
+
is_extension=False, extension_scope=None,
|
113 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
114 |
+
_descriptor.FieldDescriptor(
|
115 |
+
name='num_bins', full_name='long_metric.SimAgentMetricsConfig.HistogramEstimate.num_bins', index=2,
|
116 |
+
number=3, type=5, cpp_type=1, label=1,
|
117 |
+
has_default_value=False, default_value=0,
|
118 |
+
message_type=None, enum_type=None, containing_type=None,
|
119 |
+
is_extension=False, extension_scope=None,
|
120 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
121 |
+
_descriptor.FieldDescriptor(
|
122 |
+
name='additive_smoothing_pseudocount', full_name='long_metric.SimAgentMetricsConfig.HistogramEstimate.additive_smoothing_pseudocount', index=3,
|
123 |
+
number=4, type=2, cpp_type=6, label=1,
|
124 |
+
has_default_value=True, default_value=float(0.001),
|
125 |
+
message_type=None, enum_type=None, containing_type=None,
|
126 |
+
is_extension=False, extension_scope=None,
|
127 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
128 |
+
],
|
129 |
+
extensions=[
|
130 |
+
],
|
131 |
+
nested_types=[],
|
132 |
+
enum_types=[
|
133 |
+
],
|
134 |
+
serialized_options=None,
|
135 |
+
is_extendable=False,
|
136 |
+
syntax='proto2',
|
137 |
+
extension_ranges=[],
|
138 |
+
oneofs=[
|
139 |
+
],
|
140 |
+
serialized_start=1413,
|
141 |
+
serialized_end=1531,
|
142 |
+
)
|
143 |
+
|
144 |
+
_SIMAGENTMETRICSCONFIG_KERNELDENSITYESTIMATE = _descriptor.Descriptor(
|
145 |
+
name='KernelDensityEstimate',
|
146 |
+
full_name='long_metric.SimAgentMetricsConfig.KernelDensityEstimate',
|
147 |
+
filename=None,
|
148 |
+
file=DESCRIPTOR,
|
149 |
+
containing_type=None,
|
150 |
+
create_key=_descriptor._internal_create_key,
|
151 |
+
fields=[
|
152 |
+
_descriptor.FieldDescriptor(
|
153 |
+
name='bandwidth', full_name='long_metric.SimAgentMetricsConfig.KernelDensityEstimate.bandwidth', index=0,
|
154 |
+
number=1, type=2, cpp_type=6, label=1,
|
155 |
+
has_default_value=False, default_value=float(0),
|
156 |
+
message_type=None, enum_type=None, containing_type=None,
|
157 |
+
is_extension=False, extension_scope=None,
|
158 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
159 |
+
],
|
160 |
+
extensions=[
|
161 |
+
],
|
162 |
+
nested_types=[],
|
163 |
+
enum_types=[
|
164 |
+
],
|
165 |
+
serialized_options=None,
|
166 |
+
is_extendable=False,
|
167 |
+
syntax='proto2',
|
168 |
+
extension_ranges=[],
|
169 |
+
oneofs=[
|
170 |
+
],
|
171 |
+
serialized_start=1533,
|
172 |
+
serialized_end=1575,
|
173 |
+
)
|
174 |
+
|
175 |
+
_SIMAGENTMETRICSCONFIG_BERNOULLIESTIMATE = _descriptor.Descriptor(
|
176 |
+
name='BernoulliEstimate',
|
177 |
+
full_name='long_metric.SimAgentMetricsConfig.BernoulliEstimate',
|
178 |
+
filename=None,
|
179 |
+
file=DESCRIPTOR,
|
180 |
+
containing_type=None,
|
181 |
+
create_key=_descriptor._internal_create_key,
|
182 |
+
fields=[
|
183 |
+
_descriptor.FieldDescriptor(
|
184 |
+
name='additive_smoothing_pseudocount', full_name='long_metric.SimAgentMetricsConfig.BernoulliEstimate.additive_smoothing_pseudocount', index=0,
|
185 |
+
number=4, type=2, cpp_type=6, label=1,
|
186 |
+
has_default_value=True, default_value=float(0.001),
|
187 |
+
message_type=None, enum_type=None, containing_type=None,
|
188 |
+
is_extension=False, extension_scope=None,
|
189 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
190 |
+
],
|
191 |
+
extensions=[
|
192 |
+
],
|
193 |
+
nested_types=[],
|
194 |
+
enum_types=[
|
195 |
+
],
|
196 |
+
serialized_options=None,
|
197 |
+
is_extendable=False,
|
198 |
+
syntax='proto2',
|
199 |
+
extension_ranges=[],
|
200 |
+
oneofs=[
|
201 |
+
],
|
202 |
+
serialized_start=1577,
|
203 |
+
serialized_end=1643,
|
204 |
+
)
|
205 |
+
|
206 |
+
_SIMAGENTMETRICSCONFIG = _descriptor.Descriptor(
|
207 |
+
name='SimAgentMetricsConfig',
|
208 |
+
full_name='long_metric.SimAgentMetricsConfig',
|
209 |
+
filename=None,
|
210 |
+
file=DESCRIPTOR,
|
211 |
+
containing_type=None,
|
212 |
+
create_key=_descriptor._internal_create_key,
|
213 |
+
fields=[
|
214 |
+
_descriptor.FieldDescriptor(
|
215 |
+
name='linear_speed', full_name='long_metric.SimAgentMetricsConfig.linear_speed', index=0,
|
216 |
+
number=1, type=11, cpp_type=10, label=1,
|
217 |
+
has_default_value=False, default_value=None,
|
218 |
+
message_type=None, enum_type=None, containing_type=None,
|
219 |
+
is_extension=False, extension_scope=None,
|
220 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
221 |
+
_descriptor.FieldDescriptor(
|
222 |
+
name='linear_acceleration', full_name='long_metric.SimAgentMetricsConfig.linear_acceleration', index=1,
|
223 |
+
number=2, type=11, cpp_type=10, label=1,
|
224 |
+
has_default_value=False, default_value=None,
|
225 |
+
message_type=None, enum_type=None, containing_type=None,
|
226 |
+
is_extension=False, extension_scope=None,
|
227 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
228 |
+
_descriptor.FieldDescriptor(
|
229 |
+
name='angular_speed', full_name='long_metric.SimAgentMetricsConfig.angular_speed', index=2,
|
230 |
+
number=3, type=11, cpp_type=10, label=1,
|
231 |
+
has_default_value=False, default_value=None,
|
232 |
+
message_type=None, enum_type=None, containing_type=None,
|
233 |
+
is_extension=False, extension_scope=None,
|
234 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
235 |
+
_descriptor.FieldDescriptor(
|
236 |
+
name='angular_acceleration', full_name='long_metric.SimAgentMetricsConfig.angular_acceleration', index=3,
|
237 |
+
number=4, type=11, cpp_type=10, label=1,
|
238 |
+
has_default_value=False, default_value=None,
|
239 |
+
message_type=None, enum_type=None, containing_type=None,
|
240 |
+
is_extension=False, extension_scope=None,
|
241 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
242 |
+
_descriptor.FieldDescriptor(
|
243 |
+
name='distance_to_nearest_object', full_name='long_metric.SimAgentMetricsConfig.distance_to_nearest_object', index=4,
|
244 |
+
number=5, type=11, cpp_type=10, label=1,
|
245 |
+
has_default_value=False, default_value=None,
|
246 |
+
message_type=None, enum_type=None, containing_type=None,
|
247 |
+
is_extension=False, extension_scope=None,
|
248 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
249 |
+
_descriptor.FieldDescriptor(
|
250 |
+
name='collision_indication', full_name='long_metric.SimAgentMetricsConfig.collision_indication', index=5,
|
251 |
+
number=6, type=11, cpp_type=10, label=1,
|
252 |
+
has_default_value=False, default_value=None,
|
253 |
+
message_type=None, enum_type=None, containing_type=None,
|
254 |
+
is_extension=False, extension_scope=None,
|
255 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
256 |
+
_descriptor.FieldDescriptor(
|
257 |
+
name='time_to_collision', full_name='long_metric.SimAgentMetricsConfig.time_to_collision', index=6,
|
258 |
+
number=7, type=11, cpp_type=10, label=1,
|
259 |
+
has_default_value=False, default_value=None,
|
260 |
+
message_type=None, enum_type=None, containing_type=None,
|
261 |
+
is_extension=False, extension_scope=None,
|
262 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
263 |
+
_descriptor.FieldDescriptor(
|
264 |
+
name='distance_to_road_edge', full_name='long_metric.SimAgentMetricsConfig.distance_to_road_edge', index=7,
|
265 |
+
number=8, type=11, cpp_type=10, label=1,
|
266 |
+
has_default_value=False, default_value=None,
|
267 |
+
message_type=None, enum_type=None, containing_type=None,
|
268 |
+
is_extension=False, extension_scope=None,
|
269 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
270 |
+
_descriptor.FieldDescriptor(
|
271 |
+
name='offroad_indication', full_name='long_metric.SimAgentMetricsConfig.offroad_indication', index=8,
|
272 |
+
number=9, type=11, cpp_type=10, label=1,
|
273 |
+
has_default_value=False, default_value=None,
|
274 |
+
message_type=None, enum_type=None, containing_type=None,
|
275 |
+
is_extension=False, extension_scope=None,
|
276 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
277 |
+
_descriptor.FieldDescriptor(
|
278 |
+
name='num_placement', full_name='long_metric.SimAgentMetricsConfig.num_placement', index=9,
|
279 |
+
number=10, type=11, cpp_type=10, label=1,
|
280 |
+
has_default_value=False, default_value=None,
|
281 |
+
message_type=None, enum_type=None, containing_type=None,
|
282 |
+
is_extension=False, extension_scope=None,
|
283 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
284 |
+
_descriptor.FieldDescriptor(
|
285 |
+
name='num_removement', full_name='long_metric.SimAgentMetricsConfig.num_removement', index=10,
|
286 |
+
number=11, type=11, cpp_type=10, label=1,
|
287 |
+
has_default_value=False, default_value=None,
|
288 |
+
message_type=None, enum_type=None, containing_type=None,
|
289 |
+
is_extension=False, extension_scope=None,
|
290 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
291 |
+
_descriptor.FieldDescriptor(
|
292 |
+
name='distance_placement', full_name='long_metric.SimAgentMetricsConfig.distance_placement', index=11,
|
293 |
+
number=12, type=11, cpp_type=10, label=1,
|
294 |
+
has_default_value=False, default_value=None,
|
295 |
+
message_type=None, enum_type=None, containing_type=None,
|
296 |
+
is_extension=False, extension_scope=None,
|
297 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
298 |
+
_descriptor.FieldDescriptor(
|
299 |
+
name='distance_removement', full_name='long_metric.SimAgentMetricsConfig.distance_removement', index=12,
|
300 |
+
number=13, type=11, cpp_type=10, label=1,
|
301 |
+
has_default_value=False, default_value=None,
|
302 |
+
message_type=None, enum_type=None, containing_type=None,
|
303 |
+
is_extension=False, extension_scope=None,
|
304 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
305 |
+
],
|
306 |
+
extensions=[
|
307 |
+
],
|
308 |
+
nested_types=[_SIMAGENTMETRICSCONFIG_FEATURECONFIG, _SIMAGENTMETRICSCONFIG_HISTOGRAMESTIMATE, _SIMAGENTMETRICSCONFIG_KERNELDENSITYESTIMATE, _SIMAGENTMETRICSCONFIG_BERNOULLIESTIMATE, ],
|
309 |
+
enum_types=[
|
310 |
+
],
|
311 |
+
serialized_options=None,
|
312 |
+
is_extendable=False,
|
313 |
+
syntax='proto2',
|
314 |
+
extension_ranges=[],
|
315 |
+
oneofs=[
|
316 |
+
],
|
317 |
+
serialized_start=55,
|
318 |
+
serialized_end=1643,
|
319 |
+
)
|
320 |
+
|
321 |
+
|
322 |
+
_SIMAGENTMETRICS = _descriptor.Descriptor(
|
323 |
+
name='SimAgentMetrics',
|
324 |
+
full_name='long_metric.SimAgentMetrics',
|
325 |
+
filename=None,
|
326 |
+
file=DESCRIPTOR,
|
327 |
+
containing_type=None,
|
328 |
+
create_key=_descriptor._internal_create_key,
|
329 |
+
fields=[
|
330 |
+
_descriptor.FieldDescriptor(
|
331 |
+
name='scenario_id', full_name='long_metric.SimAgentMetrics.scenario_id', index=0,
|
332 |
+
number=1, type=9, cpp_type=9, label=1,
|
333 |
+
has_default_value=False, default_value=b"".decode('utf-8'),
|
334 |
+
message_type=None, enum_type=None, containing_type=None,
|
335 |
+
is_extension=False, extension_scope=None,
|
336 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
337 |
+
_descriptor.FieldDescriptor(
|
338 |
+
name='metametric', full_name='long_metric.SimAgentMetrics.metametric', index=1,
|
339 |
+
number=2, type=2, cpp_type=6, label=1,
|
340 |
+
has_default_value=False, default_value=float(0),
|
341 |
+
message_type=None, enum_type=None, containing_type=None,
|
342 |
+
is_extension=False, extension_scope=None,
|
343 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
344 |
+
_descriptor.FieldDescriptor(
|
345 |
+
name='average_displacement_error', full_name='long_metric.SimAgentMetrics.average_displacement_error', index=2,
|
346 |
+
number=3, type=2, cpp_type=6, label=1,
|
347 |
+
has_default_value=False, default_value=float(0),
|
348 |
+
message_type=None, enum_type=None, containing_type=None,
|
349 |
+
is_extension=False, extension_scope=None,
|
350 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
351 |
+
_descriptor.FieldDescriptor(
|
352 |
+
name='min_average_displacement_error', full_name='long_metric.SimAgentMetrics.min_average_displacement_error', index=3,
|
353 |
+
number=19, type=2, cpp_type=6, label=1,
|
354 |
+
has_default_value=False, default_value=float(0),
|
355 |
+
message_type=None, enum_type=None, containing_type=None,
|
356 |
+
is_extension=False, extension_scope=None,
|
357 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
358 |
+
_descriptor.FieldDescriptor(
|
359 |
+
name='linear_speed_likelihood', full_name='long_metric.SimAgentMetrics.linear_speed_likelihood', index=4,
|
360 |
+
number=4, type=2, cpp_type=6, label=1,
|
361 |
+
has_default_value=False, default_value=float(0),
|
362 |
+
message_type=None, enum_type=None, containing_type=None,
|
363 |
+
is_extension=False, extension_scope=None,
|
364 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
365 |
+
_descriptor.FieldDescriptor(
|
366 |
+
name='linear_acceleration_likelihood', full_name='long_metric.SimAgentMetrics.linear_acceleration_likelihood', index=5,
|
367 |
+
number=5, type=2, cpp_type=6, label=1,
|
368 |
+
has_default_value=False, default_value=float(0),
|
369 |
+
message_type=None, enum_type=None, containing_type=None,
|
370 |
+
is_extension=False, extension_scope=None,
|
371 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
372 |
+
_descriptor.FieldDescriptor(
|
373 |
+
name='angular_speed_likelihood', full_name='long_metric.SimAgentMetrics.angular_speed_likelihood', index=6,
|
374 |
+
number=6, type=2, cpp_type=6, label=1,
|
375 |
+
has_default_value=False, default_value=float(0),
|
376 |
+
message_type=None, enum_type=None, containing_type=None,
|
377 |
+
is_extension=False, extension_scope=None,
|
378 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
379 |
+
_descriptor.FieldDescriptor(
|
380 |
+
name='angular_acceleration_likelihood', full_name='long_metric.SimAgentMetrics.angular_acceleration_likelihood', index=7,
|
381 |
+
number=7, type=2, cpp_type=6, label=1,
|
382 |
+
has_default_value=False, default_value=float(0),
|
383 |
+
message_type=None, enum_type=None, containing_type=None,
|
384 |
+
is_extension=False, extension_scope=None,
|
385 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
386 |
+
_descriptor.FieldDescriptor(
|
387 |
+
name='distance_to_nearest_object_likelihood', full_name='long_metric.SimAgentMetrics.distance_to_nearest_object_likelihood', index=8,
|
388 |
+
number=8, type=2, cpp_type=6, label=1,
|
389 |
+
has_default_value=False, default_value=float(0),
|
390 |
+
message_type=None, enum_type=None, containing_type=None,
|
391 |
+
is_extension=False, extension_scope=None,
|
392 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
393 |
+
_descriptor.FieldDescriptor(
|
394 |
+
name='collision_indication_likelihood', full_name='long_metric.SimAgentMetrics.collision_indication_likelihood', index=9,
|
395 |
+
number=9, type=2, cpp_type=6, label=1,
|
396 |
+
has_default_value=False, default_value=float(0),
|
397 |
+
message_type=None, enum_type=None, containing_type=None,
|
398 |
+
is_extension=False, extension_scope=None,
|
399 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
400 |
+
_descriptor.FieldDescriptor(
|
401 |
+
name='time_to_collision_likelihood', full_name='long_metric.SimAgentMetrics.time_to_collision_likelihood', index=10,
|
402 |
+
number=10, type=2, cpp_type=6, label=1,
|
403 |
+
has_default_value=False, default_value=float(0),
|
404 |
+
message_type=None, enum_type=None, containing_type=None,
|
405 |
+
is_extension=False, extension_scope=None,
|
406 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
407 |
+
_descriptor.FieldDescriptor(
|
408 |
+
name='distance_to_road_edge_likelihood', full_name='long_metric.SimAgentMetrics.distance_to_road_edge_likelihood', index=11,
|
409 |
+
number=11, type=2, cpp_type=6, label=1,
|
410 |
+
has_default_value=False, default_value=float(0),
|
411 |
+
message_type=None, enum_type=None, containing_type=None,
|
412 |
+
is_extension=False, extension_scope=None,
|
413 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
414 |
+
_descriptor.FieldDescriptor(
|
415 |
+
name='offroad_indication_likelihood', full_name='long_metric.SimAgentMetrics.offroad_indication_likelihood', index=12,
|
416 |
+
number=12, type=2, cpp_type=6, label=1,
|
417 |
+
has_default_value=False, default_value=float(0),
|
418 |
+
message_type=None, enum_type=None, containing_type=None,
|
419 |
+
is_extension=False, extension_scope=None,
|
420 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
421 |
+
_descriptor.FieldDescriptor(
|
422 |
+
name='num_placement_likelihood', full_name='long_metric.SimAgentMetrics.num_placement_likelihood', index=13,
|
423 |
+
number=13, type=2, cpp_type=6, label=1,
|
424 |
+
has_default_value=False, default_value=float(0),
|
425 |
+
message_type=None, enum_type=None, containing_type=None,
|
426 |
+
is_extension=False, extension_scope=None,
|
427 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
428 |
+
_descriptor.FieldDescriptor(
|
429 |
+
name='num_removement_likelihood', full_name='long_metric.SimAgentMetrics.num_removement_likelihood', index=14,
|
430 |
+
number=14, type=2, cpp_type=6, label=1,
|
431 |
+
has_default_value=False, default_value=float(0),
|
432 |
+
message_type=None, enum_type=None, containing_type=None,
|
433 |
+
is_extension=False, extension_scope=None,
|
434 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
435 |
+
_descriptor.FieldDescriptor(
|
436 |
+
name='distance_placement_likelihood', full_name='long_metric.SimAgentMetrics.distance_placement_likelihood', index=15,
|
437 |
+
number=15, type=2, cpp_type=6, label=1,
|
438 |
+
has_default_value=False, default_value=float(0),
|
439 |
+
message_type=None, enum_type=None, containing_type=None,
|
440 |
+
is_extension=False, extension_scope=None,
|
441 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
442 |
+
_descriptor.FieldDescriptor(
|
443 |
+
name='distance_removement_likelihood', full_name='long_metric.SimAgentMetrics.distance_removement_likelihood', index=16,
|
444 |
+
number=16, type=2, cpp_type=6, label=1,
|
445 |
+
has_default_value=False, default_value=float(0),
|
446 |
+
message_type=None, enum_type=None, containing_type=None,
|
447 |
+
is_extension=False, extension_scope=None,
|
448 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
449 |
+
_descriptor.FieldDescriptor(
|
450 |
+
name='simulated_collision_rate', full_name='long_metric.SimAgentMetrics.simulated_collision_rate', index=17,
|
451 |
+
number=17, type=2, cpp_type=6, label=1,
|
452 |
+
has_default_value=False, default_value=float(0),
|
453 |
+
message_type=None, enum_type=None, containing_type=None,
|
454 |
+
is_extension=False, extension_scope=None,
|
455 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
456 |
+
_descriptor.FieldDescriptor(
|
457 |
+
name='simulated_offroad_rate', full_name='long_metric.SimAgentMetrics.simulated_offroad_rate', index=18,
|
458 |
+
number=18, type=2, cpp_type=6, label=1,
|
459 |
+
has_default_value=False, default_value=float(0),
|
460 |
+
message_type=None, enum_type=None, containing_type=None,
|
461 |
+
is_extension=False, extension_scope=None,
|
462 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
463 |
+
],
|
464 |
+
extensions=[
|
465 |
+
],
|
466 |
+
nested_types=[],
|
467 |
+
enum_types=[
|
468 |
+
],
|
469 |
+
serialized_options=None,
|
470 |
+
is_extendable=False,
|
471 |
+
syntax='proto2',
|
472 |
+
extension_ranges=[],
|
473 |
+
oneofs=[
|
474 |
+
],
|
475 |
+
serialized_start=1646,
|
476 |
+
serialized_end=2349,
|
477 |
+
)
|
478 |
+
|
479 |
+
|
480 |
+
_SIMAGENTSBUCKETEDMETRICS = _descriptor.Descriptor(
|
481 |
+
name='SimAgentsBucketedMetrics',
|
482 |
+
full_name='long_metric.SimAgentsBucketedMetrics',
|
483 |
+
filename=None,
|
484 |
+
file=DESCRIPTOR,
|
485 |
+
containing_type=None,
|
486 |
+
create_key=_descriptor._internal_create_key,
|
487 |
+
fields=[
|
488 |
+
_descriptor.FieldDescriptor(
|
489 |
+
name='realism_meta_metric', full_name='long_metric.SimAgentsBucketedMetrics.realism_meta_metric', index=0,
|
490 |
+
number=1, type=2, cpp_type=6, label=1,
|
491 |
+
has_default_value=False, default_value=float(0),
|
492 |
+
message_type=None, enum_type=None, containing_type=None,
|
493 |
+
is_extension=False, extension_scope=None,
|
494 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
495 |
+
_descriptor.FieldDescriptor(
|
496 |
+
name='kinematic_metrics', full_name='long_metric.SimAgentsBucketedMetrics.kinematic_metrics', index=1,
|
497 |
+
number=2, type=2, cpp_type=6, label=1,
|
498 |
+
has_default_value=False, default_value=float(0),
|
499 |
+
message_type=None, enum_type=None, containing_type=None,
|
500 |
+
is_extension=False, extension_scope=None,
|
501 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
502 |
+
_descriptor.FieldDescriptor(
|
503 |
+
name='interactive_metrics', full_name='long_metric.SimAgentsBucketedMetrics.interactive_metrics', index=2,
|
504 |
+
number=5, type=2, cpp_type=6, label=1,
|
505 |
+
has_default_value=False, default_value=float(0),
|
506 |
+
message_type=None, enum_type=None, containing_type=None,
|
507 |
+
is_extension=False, extension_scope=None,
|
508 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
509 |
+
_descriptor.FieldDescriptor(
|
510 |
+
name='map_based_metrics', full_name='long_metric.SimAgentsBucketedMetrics.map_based_metrics', index=3,
|
511 |
+
number=6, type=2, cpp_type=6, label=1,
|
512 |
+
has_default_value=False, default_value=float(0),
|
513 |
+
message_type=None, enum_type=None, containing_type=None,
|
514 |
+
is_extension=False, extension_scope=None,
|
515 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
516 |
+
_descriptor.FieldDescriptor(
|
517 |
+
name='placement_based_metrics', full_name='long_metric.SimAgentsBucketedMetrics.placement_based_metrics', index=4,
|
518 |
+
number=7, type=2, cpp_type=6, label=1,
|
519 |
+
has_default_value=False, default_value=float(0),
|
520 |
+
message_type=None, enum_type=None, containing_type=None,
|
521 |
+
is_extension=False, extension_scope=None,
|
522 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
523 |
+
_descriptor.FieldDescriptor(
|
524 |
+
name='min_ade', full_name='long_metric.SimAgentsBucketedMetrics.min_ade', index=5,
|
525 |
+
number=8, type=2, cpp_type=6, label=1,
|
526 |
+
has_default_value=False, default_value=float(0),
|
527 |
+
message_type=None, enum_type=None, containing_type=None,
|
528 |
+
is_extension=False, extension_scope=None,
|
529 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
530 |
+
_descriptor.FieldDescriptor(
|
531 |
+
name='simulated_collision_rate', full_name='long_metric.SimAgentsBucketedMetrics.simulated_collision_rate', index=6,
|
532 |
+
number=9, type=2, cpp_type=6, label=1,
|
533 |
+
has_default_value=False, default_value=float(0),
|
534 |
+
message_type=None, enum_type=None, containing_type=None,
|
535 |
+
is_extension=False, extension_scope=None,
|
536 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
537 |
+
_descriptor.FieldDescriptor(
|
538 |
+
name='simulated_offroad_rate', full_name='long_metric.SimAgentsBucketedMetrics.simulated_offroad_rate', index=7,
|
539 |
+
number=10, type=2, cpp_type=6, label=1,
|
540 |
+
has_default_value=False, default_value=float(0),
|
541 |
+
message_type=None, enum_type=None, containing_type=None,
|
542 |
+
is_extension=False, extension_scope=None,
|
543 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
544 |
+
],
|
545 |
+
extensions=[
|
546 |
+
],
|
547 |
+
nested_types=[],
|
548 |
+
enum_types=[
|
549 |
+
],
|
550 |
+
serialized_options=None,
|
551 |
+
is_extendable=False,
|
552 |
+
syntax='proto2',
|
553 |
+
extension_ranges=[],
|
554 |
+
oneofs=[
|
555 |
+
],
|
556 |
+
serialized_start=2352,
|
557 |
+
serialized_end=2606,
|
558 |
+
)
|
559 |
+
|
560 |
+
_SIMAGENTMETRICSCONFIG_FEATURECONFIG.fields_by_name['histogram'].message_type = _SIMAGENTMETRICSCONFIG_HISTOGRAMESTIMATE
|
561 |
+
_SIMAGENTMETRICSCONFIG_FEATURECONFIG.fields_by_name['kernel_density'].message_type = _SIMAGENTMETRICSCONFIG_KERNELDENSITYESTIMATE
|
562 |
+
_SIMAGENTMETRICSCONFIG_FEATURECONFIG.fields_by_name['bernoulli'].message_type = _SIMAGENTMETRICSCONFIG_BERNOULLIESTIMATE
|
563 |
+
_SIMAGENTMETRICSCONFIG_FEATURECONFIG.containing_type = _SIMAGENTMETRICSCONFIG
|
564 |
+
_SIMAGENTMETRICSCONFIG_FEATURECONFIG.oneofs_by_name['estimator'].fields.append(
|
565 |
+
_SIMAGENTMETRICSCONFIG_FEATURECONFIG.fields_by_name['histogram'])
|
566 |
+
_SIMAGENTMETRICSCONFIG_FEATURECONFIG.fields_by_name['histogram'].containing_oneof = _SIMAGENTMETRICSCONFIG_FEATURECONFIG.oneofs_by_name['estimator']
|
567 |
+
_SIMAGENTMETRICSCONFIG_FEATURECONFIG.oneofs_by_name['estimator'].fields.append(
|
568 |
+
_SIMAGENTMETRICSCONFIG_FEATURECONFIG.fields_by_name['kernel_density'])
|
569 |
+
_SIMAGENTMETRICSCONFIG_FEATURECONFIG.fields_by_name['kernel_density'].containing_oneof = _SIMAGENTMETRICSCONFIG_FEATURECONFIG.oneofs_by_name['estimator']
|
570 |
+
_SIMAGENTMETRICSCONFIG_FEATURECONFIG.oneofs_by_name['estimator'].fields.append(
|
571 |
+
_SIMAGENTMETRICSCONFIG_FEATURECONFIG.fields_by_name['bernoulli'])
|
572 |
+
_SIMAGENTMETRICSCONFIG_FEATURECONFIG.fields_by_name['bernoulli'].containing_oneof = _SIMAGENTMETRICSCONFIG_FEATURECONFIG.oneofs_by_name['estimator']
|
573 |
+
_SIMAGENTMETRICSCONFIG_HISTOGRAMESTIMATE.containing_type = _SIMAGENTMETRICSCONFIG
|
574 |
+
_SIMAGENTMETRICSCONFIG_KERNELDENSITYESTIMATE.containing_type = _SIMAGENTMETRICSCONFIG
|
575 |
+
_SIMAGENTMETRICSCONFIG_BERNOULLIESTIMATE.containing_type = _SIMAGENTMETRICSCONFIG
|
576 |
+
_SIMAGENTMETRICSCONFIG.fields_by_name['linear_speed'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG
|
577 |
+
_SIMAGENTMETRICSCONFIG.fields_by_name['linear_acceleration'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG
|
578 |
+
_SIMAGENTMETRICSCONFIG.fields_by_name['angular_speed'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG
|
579 |
+
_SIMAGENTMETRICSCONFIG.fields_by_name['angular_acceleration'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG
|
580 |
+
_SIMAGENTMETRICSCONFIG.fields_by_name['distance_to_nearest_object'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG
|
581 |
+
_SIMAGENTMETRICSCONFIG.fields_by_name['collision_indication'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG
|
582 |
+
_SIMAGENTMETRICSCONFIG.fields_by_name['time_to_collision'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG
|
583 |
+
_SIMAGENTMETRICSCONFIG.fields_by_name['distance_to_road_edge'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG
|
584 |
+
_SIMAGENTMETRICSCONFIG.fields_by_name['offroad_indication'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG
|
585 |
+
_SIMAGENTMETRICSCONFIG.fields_by_name['num_placement'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG
|
586 |
+
_SIMAGENTMETRICSCONFIG.fields_by_name['num_removement'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG
|
587 |
+
_SIMAGENTMETRICSCONFIG.fields_by_name['distance_placement'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG
|
588 |
+
_SIMAGENTMETRICSCONFIG.fields_by_name['distance_removement'].message_type = _SIMAGENTMETRICSCONFIG_FEATURECONFIG
|
589 |
+
DESCRIPTOR.message_types_by_name['SimAgentMetricsConfig'] = _SIMAGENTMETRICSCONFIG
|
590 |
+
DESCRIPTOR.message_types_by_name['SimAgentMetrics'] = _SIMAGENTMETRICS
|
591 |
+
DESCRIPTOR.message_types_by_name['SimAgentsBucketedMetrics'] = _SIMAGENTSBUCKETEDMETRICS
|
592 |
+
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
|
593 |
+
|
594 |
+
SimAgentMetricsConfig = _reflection.GeneratedProtocolMessageType('SimAgentMetricsConfig', (_message.Message,), {
|
595 |
+
|
596 |
+
'FeatureConfig' : _reflection.GeneratedProtocolMessageType('FeatureConfig', (_message.Message,), {
|
597 |
+
'DESCRIPTOR' : _SIMAGENTMETRICSCONFIG_FEATURECONFIG,
|
598 |
+
'__module__' : 'dev.metrics.protos.long_metrics_pb2'
|
599 |
+
# @@protoc_insertion_point(class_scope:long_metric.SimAgentMetricsConfig.FeatureConfig)
|
600 |
+
})
|
601 |
+
,
|
602 |
+
|
603 |
+
'HistogramEstimate' : _reflection.GeneratedProtocolMessageType('HistogramEstimate', (_message.Message,), {
|
604 |
+
'DESCRIPTOR' : _SIMAGENTMETRICSCONFIG_HISTOGRAMESTIMATE,
|
605 |
+
'__module__' : 'dev.metrics.protos.long_metrics_pb2'
|
606 |
+
# @@protoc_insertion_point(class_scope:long_metric.SimAgentMetricsConfig.HistogramEstimate)
|
607 |
+
})
|
608 |
+
,
|
609 |
+
|
610 |
+
'KernelDensityEstimate' : _reflection.GeneratedProtocolMessageType('KernelDensityEstimate', (_message.Message,), {
|
611 |
+
'DESCRIPTOR' : _SIMAGENTMETRICSCONFIG_KERNELDENSITYESTIMATE,
|
612 |
+
'__module__' : 'dev.metrics.protos.long_metrics_pb2'
|
613 |
+
# @@protoc_insertion_point(class_scope:long_metric.SimAgentMetricsConfig.KernelDensityEstimate)
|
614 |
+
})
|
615 |
+
,
|
616 |
+
|
617 |
+
'BernoulliEstimate' : _reflection.GeneratedProtocolMessageType('BernoulliEstimate', (_message.Message,), {
|
618 |
+
'DESCRIPTOR' : _SIMAGENTMETRICSCONFIG_BERNOULLIESTIMATE,
|
619 |
+
'__module__' : 'dev.metrics.protos.long_metrics_pb2'
|
620 |
+
# @@protoc_insertion_point(class_scope:long_metric.SimAgentMetricsConfig.BernoulliEstimate)
|
621 |
+
})
|
622 |
+
,
|
623 |
+
'DESCRIPTOR' : _SIMAGENTMETRICSCONFIG,
|
624 |
+
'__module__' : 'dev.metrics.protos.long_metrics_pb2'
|
625 |
+
# @@protoc_insertion_point(class_scope:long_metric.SimAgentMetricsConfig)
|
626 |
+
})
|
627 |
+
_sym_db.RegisterMessage(SimAgentMetricsConfig)
|
628 |
+
_sym_db.RegisterMessage(SimAgentMetricsConfig.FeatureConfig)
|
629 |
+
_sym_db.RegisterMessage(SimAgentMetricsConfig.HistogramEstimate)
|
630 |
+
_sym_db.RegisterMessage(SimAgentMetricsConfig.KernelDensityEstimate)
|
631 |
+
_sym_db.RegisterMessage(SimAgentMetricsConfig.BernoulliEstimate)
|
632 |
+
|
633 |
+
SimAgentMetrics = _reflection.GeneratedProtocolMessageType('SimAgentMetrics', (_message.Message,), {
|
634 |
+
'DESCRIPTOR' : _SIMAGENTMETRICS,
|
635 |
+
'__module__' : 'dev.metrics.protos.long_metrics_pb2'
|
636 |
+
# @@protoc_insertion_point(class_scope:long_metric.SimAgentMetrics)
|
637 |
+
})
|
638 |
+
_sym_db.RegisterMessage(SimAgentMetrics)
|
639 |
+
|
640 |
+
SimAgentsBucketedMetrics = _reflection.GeneratedProtocolMessageType('SimAgentsBucketedMetrics', (_message.Message,), {
|
641 |
+
'DESCRIPTOR' : _SIMAGENTSBUCKETEDMETRICS,
|
642 |
+
'__module__' : 'dev.metrics.protos.long_metrics_pb2'
|
643 |
+
# @@protoc_insertion_point(class_scope:long_metric.SimAgentsBucketedMetrics)
|
644 |
+
})
|
645 |
+
_sym_db.RegisterMessage(SimAgentsBucketedMetrics)
|
646 |
+
|
647 |
+
|
648 |
+
# @@protoc_insertion_point(module_scope)
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/metrics/protos/map_pb2.py
ADDED
@@ -0,0 +1,1070 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
3 |
+
# source: dev/metrics/protos/map.proto
|
4 |
+
|
5 |
+
from google.protobuf import descriptor as _descriptor
|
6 |
+
from google.protobuf import message as _message
|
7 |
+
from google.protobuf import reflection as _reflection
|
8 |
+
from google.protobuf import symbol_database as _symbol_database
|
9 |
+
# @@protoc_insertion_point(imports)
|
10 |
+
|
11 |
+
_sym_db = _symbol_database.Default()
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
DESCRIPTOR = _descriptor.FileDescriptor(
|
17 |
+
name='dev/metrics/protos/map.proto',
|
18 |
+
package='long_metric',
|
19 |
+
syntax='proto2',
|
20 |
+
serialized_options=None,
|
21 |
+
create_key=_descriptor._internal_create_key,
|
22 |
+
serialized_pb=b'\n\x1c\x64\x65v/metrics/protos/map.proto\x12\x0blong_metric\"g\n\x03Map\x12-\n\x0cmap_features\x18\x01 \x03(\x0b\x32\x17.long_metric.MapFeature\x12\x31\n\x0e\x64ynamic_states\x18\x02 \x03(\x0b\x32\x19.long_metric.DynamicState\"c\n\x0c\x44ynamicState\x12\x19\n\x11timestamp_seconds\x18\x01 \x01(\x01\x12\x38\n\x0blane_states\x18\x02 \x03(\x0b\x32#.long_metric.TrafficSignalLaneState\"\xfe\x02\n\x16TrafficSignalLaneState\x12\x0c\n\x04lane\x18\x01 \x01(\x03\x12\x38\n\x05state\x18\x02 \x01(\x0e\x32).long_metric.TrafficSignalLaneState.State\x12)\n\nstop_point\x18\x03 \x01(\x0b\x32\x15.long_metric.MapPoint\"\xf0\x01\n\x05State\x12\x16\n\x12LANE_STATE_UNKNOWN\x10\x00\x12\x19\n\x15LANE_STATE_ARROW_STOP\x10\x01\x12\x1c\n\x18LANE_STATE_ARROW_CAUTION\x10\x02\x12\x17\n\x13LANE_STATE_ARROW_GO\x10\x03\x12\x13\n\x0fLANE_STATE_STOP\x10\x04\x12\x16\n\x12LANE_STATE_CAUTION\x10\x05\x12\x11\n\rLANE_STATE_GO\x10\x06\x12\x1c\n\x18LANE_STATE_FLASHING_STOP\x10\x07\x12\x1f\n\x1bLANE_STATE_FLASHING_CAUTION\x10\x08\"\xdb\x02\n\nMapFeature\x12\n\n\x02id\x18\x01 \x01(\x03\x12\'\n\x04lane\x18\x03 \x01(\x0b\x32\x17.long_metric.LaneCenterH\x00\x12*\n\troad_line\x18\x04 \x01(\x0b\x32\x15.long_metric.RoadLineH\x00\x12*\n\troad_edge\x18\x05 \x01(\x0b\x32\x15.long_metric.RoadEdgeH\x00\x12*\n\tstop_sign\x18\x07 \x01(\x0b\x32\x15.long_metric.StopSignH\x00\x12+\n\tcrosswalk\x18\x08 \x01(\x0b\x32\x16.long_metric.CrosswalkH\x00\x12,\n\nspeed_bump\x18\t \x01(\x0b\x32\x16.long_metric.SpeedBumpH\x00\x12)\n\x08\x64riveway\x18\n \x01(\x0b\x32\x15.long_metric.DrivewayH\x00\x42\x0e\n\x0c\x66\x65\x61ture_data\"+\n\x08MapPoint\x12\t\n\x01x\x18\x01 \x01(\x01\x12\t\n\x01y\x18\x02 \x01(\x01\x12\t\n\x01z\x18\x03 \x01(\x01\"\x9b\x01\n\x0f\x42oundarySegment\x12\x18\n\x10lane_start_index\x18\x01 \x01(\x05\x12\x16\n\x0elane_end_index\x18\x02 \x01(\x05\x12\x1b\n\x13\x62oundary_feature_id\x18\x03 \x01(\x03\x12\x39\n\rboundary_type\x18\x04 \x01(\x0e\x32\".long_metric.RoadLine.RoadLineType\"\xc0\x01\n\x0cLaneNeighbor\x12\x12\n\nfeature_id\x18\x01 \x01(\x03\x12\x18\n\x10self_start_index\x18\x02 \x01(\x05\x12\x16\n\x0eself_end_index\x18\x03 \x01(\x05\x12\x1c\n\x14neighbor_start_index\x18\x04 \x01(\x05\x12\x1a\n\x12neighbor_end_index\x18\x05 \x01(\x05\x12\x30\n\nboundaries\x18\x06 \x03(\x0b\x32\x1c.long_metric.BoundarySegment\"\xfb\x03\n\nLaneCenter\x12\x17\n\x0fspeed_limit_mph\x18\x01 \x01(\x01\x12.\n\x04type\x18\x02 \x01(\x0e\x32 .long_metric.LaneCenter.LaneType\x12\x15\n\rinterpolating\x18\x03 \x01(\x08\x12\'\n\x08polyline\x18\x08 \x03(\x0b\x32\x15.long_metric.MapPoint\x12\x17\n\x0b\x65ntry_lanes\x18\t \x03(\x03\x42\x02\x10\x01\x12\x16\n\nexit_lanes\x18\n \x03(\x03\x42\x02\x10\x01\x12\x35\n\x0fleft_boundaries\x18\r \x03(\x0b\x32\x1c.long_metric.BoundarySegment\x12\x36\n\x10right_boundaries\x18\x0e \x03(\x0b\x32\x1c.long_metric.BoundarySegment\x12\x31\n\x0eleft_neighbors\x18\x0b \x03(\x0b\x32\x19.long_metric.LaneNeighbor\x12\x32\n\x0fright_neighbors\x18\x0c \x03(\x0b\x32\x19.long_metric.LaneNeighbor\"]\n\x08LaneType\x12\x12\n\x0eTYPE_UNDEFINED\x10\x00\x12\x10\n\x0cTYPE_FREEWAY\x10\x01\x12\x17\n\x13TYPE_SURFACE_STREET\x10\x02\x12\x12\n\x0eTYPE_BIKE_LANE\x10\x03\"\xbf\x01\n\x08RoadEdge\x12\x30\n\x04type\x18\x01 \x01(\x0e\x32\".long_metric.RoadEdge.RoadEdgeType\x12\'\n\x08polyline\x18\x02 \x03(\x0b\x32\x15.long_metric.MapPoint\"X\n\x0cRoadEdgeType\x12\x10\n\x0cTYPE_UNKNOWN\x10\x00\x12\x1b\n\x17TYPE_ROAD_EDGE_BOUNDARY\x10\x01\x12\x19\n\x15TYPE_ROAD_EDGE_MEDIAN\x10\x02\"\xfa\x02\n\x08RoadLine\x12\x30\n\x04type\x18\x01 \x01(\x0e\x32\".long_metric.RoadLine.RoadLineType\x12\'\n\x08polyline\x18\x02 \x03(\x0b\x32\x15.long_metric.MapPoint\"\x92\x02\n\x0cRoadLineType\x12\x10\n\x0cTYPE_UNKNOWN\x10\x00\x12\x1c\n\x18TYPE_BROKEN_SINGLE_WHITE\x10\x01\x12\x1b\n\x17TYPE_SOLID_SINGLE_WHITE\x10\x02\x12\x1b\n\x17TYPE_SOLID_DOUBLE_WHITE\x10\x03\x12\x1d\n\x19TYPE_BROKEN_SINGLE_YELLOW\x10\x04\x12\x1d\n\x19TYPE_BROKEN_DOUBLE_YELLOW\x10\x05\x12\x1c\n\x18TYPE_SOLID_SINGLE_YELLOW\x10\x06\x12\x1c\n\x18TYPE_SOLID_DOUBLE_YELLOW\x10\x07\x12\x1e\n\x1aTYPE_PASSING_DOUBLE_YELLOW\x10\x08\"A\n\x08StopSign\x12\x0c\n\x04lane\x18\x01 \x03(\x03\x12\'\n\x08position\x18\x02 \x01(\x0b\x32\x15.long_metric.MapPoint\"3\n\tCrosswalk\x12&\n\x07polygon\x18\x01 \x03(\x0b\x32\x15.long_metric.MapPoint\"3\n\tSpeedBump\x12&\n\x07polygon\x18\x01 \x03(\x0b\x32\x15.long_metric.MapPoint\"2\n\x08\x44riveway\x12&\n\x07polygon\x18\x01 \x03(\x0b\x32\x15.long_metric.MapPoint'
|
23 |
+
)
|
24 |
+
|
25 |
+
|
26 |
+
|
27 |
+
_TRAFFICSIGNALLANESTATE_STATE = _descriptor.EnumDescriptor(
|
28 |
+
name='State',
|
29 |
+
full_name='long_metric.TrafficSignalLaneState.State',
|
30 |
+
filename=None,
|
31 |
+
file=DESCRIPTOR,
|
32 |
+
create_key=_descriptor._internal_create_key,
|
33 |
+
values=[
|
34 |
+
_descriptor.EnumValueDescriptor(
|
35 |
+
name='LANE_STATE_UNKNOWN', index=0, number=0,
|
36 |
+
serialized_options=None,
|
37 |
+
type=None,
|
38 |
+
create_key=_descriptor._internal_create_key),
|
39 |
+
_descriptor.EnumValueDescriptor(
|
40 |
+
name='LANE_STATE_ARROW_STOP', index=1, number=1,
|
41 |
+
serialized_options=None,
|
42 |
+
type=None,
|
43 |
+
create_key=_descriptor._internal_create_key),
|
44 |
+
_descriptor.EnumValueDescriptor(
|
45 |
+
name='LANE_STATE_ARROW_CAUTION', index=2, number=2,
|
46 |
+
serialized_options=None,
|
47 |
+
type=None,
|
48 |
+
create_key=_descriptor._internal_create_key),
|
49 |
+
_descriptor.EnumValueDescriptor(
|
50 |
+
name='LANE_STATE_ARROW_GO', index=3, number=3,
|
51 |
+
serialized_options=None,
|
52 |
+
type=None,
|
53 |
+
create_key=_descriptor._internal_create_key),
|
54 |
+
_descriptor.EnumValueDescriptor(
|
55 |
+
name='LANE_STATE_STOP', index=4, number=4,
|
56 |
+
serialized_options=None,
|
57 |
+
type=None,
|
58 |
+
create_key=_descriptor._internal_create_key),
|
59 |
+
_descriptor.EnumValueDescriptor(
|
60 |
+
name='LANE_STATE_CAUTION', index=5, number=5,
|
61 |
+
serialized_options=None,
|
62 |
+
type=None,
|
63 |
+
create_key=_descriptor._internal_create_key),
|
64 |
+
_descriptor.EnumValueDescriptor(
|
65 |
+
name='LANE_STATE_GO', index=6, number=6,
|
66 |
+
serialized_options=None,
|
67 |
+
type=None,
|
68 |
+
create_key=_descriptor._internal_create_key),
|
69 |
+
_descriptor.EnumValueDescriptor(
|
70 |
+
name='LANE_STATE_FLASHING_STOP', index=7, number=7,
|
71 |
+
serialized_options=None,
|
72 |
+
type=None,
|
73 |
+
create_key=_descriptor._internal_create_key),
|
74 |
+
_descriptor.EnumValueDescriptor(
|
75 |
+
name='LANE_STATE_FLASHING_CAUTION', index=8, number=8,
|
76 |
+
serialized_options=None,
|
77 |
+
type=None,
|
78 |
+
create_key=_descriptor._internal_create_key),
|
79 |
+
],
|
80 |
+
containing_type=None,
|
81 |
+
serialized_options=None,
|
82 |
+
serialized_start=394,
|
83 |
+
serialized_end=634,
|
84 |
+
)
|
85 |
+
_sym_db.RegisterEnumDescriptor(_TRAFFICSIGNALLANESTATE_STATE)
|
86 |
+
|
87 |
+
_LANECENTER_LANETYPE = _descriptor.EnumDescriptor(
|
88 |
+
name='LaneType',
|
89 |
+
full_name='long_metric.LaneCenter.LaneType',
|
90 |
+
filename=None,
|
91 |
+
file=DESCRIPTOR,
|
92 |
+
create_key=_descriptor._internal_create_key,
|
93 |
+
values=[
|
94 |
+
_descriptor.EnumValueDescriptor(
|
95 |
+
name='TYPE_UNDEFINED', index=0, number=0,
|
96 |
+
serialized_options=None,
|
97 |
+
type=None,
|
98 |
+
create_key=_descriptor._internal_create_key),
|
99 |
+
_descriptor.EnumValueDescriptor(
|
100 |
+
name='TYPE_FREEWAY', index=1, number=1,
|
101 |
+
serialized_options=None,
|
102 |
+
type=None,
|
103 |
+
create_key=_descriptor._internal_create_key),
|
104 |
+
_descriptor.EnumValueDescriptor(
|
105 |
+
name='TYPE_SURFACE_STREET', index=2, number=2,
|
106 |
+
serialized_options=None,
|
107 |
+
type=None,
|
108 |
+
create_key=_descriptor._internal_create_key),
|
109 |
+
_descriptor.EnumValueDescriptor(
|
110 |
+
name='TYPE_BIKE_LANE', index=3, number=3,
|
111 |
+
serialized_options=None,
|
112 |
+
type=None,
|
113 |
+
create_key=_descriptor._internal_create_key),
|
114 |
+
],
|
115 |
+
containing_type=None,
|
116 |
+
serialized_options=None,
|
117 |
+
serialized_start=1799,
|
118 |
+
serialized_end=1892,
|
119 |
+
)
|
120 |
+
_sym_db.RegisterEnumDescriptor(_LANECENTER_LANETYPE)
|
121 |
+
|
122 |
+
_ROADEDGE_ROADEDGETYPE = _descriptor.EnumDescriptor(
|
123 |
+
name='RoadEdgeType',
|
124 |
+
full_name='long_metric.RoadEdge.RoadEdgeType',
|
125 |
+
filename=None,
|
126 |
+
file=DESCRIPTOR,
|
127 |
+
create_key=_descriptor._internal_create_key,
|
128 |
+
values=[
|
129 |
+
_descriptor.EnumValueDescriptor(
|
130 |
+
name='TYPE_UNKNOWN', index=0, number=0,
|
131 |
+
serialized_options=None,
|
132 |
+
type=None,
|
133 |
+
create_key=_descriptor._internal_create_key),
|
134 |
+
_descriptor.EnumValueDescriptor(
|
135 |
+
name='TYPE_ROAD_EDGE_BOUNDARY', index=1, number=1,
|
136 |
+
serialized_options=None,
|
137 |
+
type=None,
|
138 |
+
create_key=_descriptor._internal_create_key),
|
139 |
+
_descriptor.EnumValueDescriptor(
|
140 |
+
name='TYPE_ROAD_EDGE_MEDIAN', index=2, number=2,
|
141 |
+
serialized_options=None,
|
142 |
+
type=None,
|
143 |
+
create_key=_descriptor._internal_create_key),
|
144 |
+
],
|
145 |
+
containing_type=None,
|
146 |
+
serialized_options=None,
|
147 |
+
serialized_start=1998,
|
148 |
+
serialized_end=2086,
|
149 |
+
)
|
150 |
+
_sym_db.RegisterEnumDescriptor(_ROADEDGE_ROADEDGETYPE)
|
151 |
+
|
152 |
+
_ROADLINE_ROADLINETYPE = _descriptor.EnumDescriptor(
|
153 |
+
name='RoadLineType',
|
154 |
+
full_name='long_metric.RoadLine.RoadLineType',
|
155 |
+
filename=None,
|
156 |
+
file=DESCRIPTOR,
|
157 |
+
create_key=_descriptor._internal_create_key,
|
158 |
+
values=[
|
159 |
+
_descriptor.EnumValueDescriptor(
|
160 |
+
name='TYPE_UNKNOWN', index=0, number=0,
|
161 |
+
serialized_options=None,
|
162 |
+
type=None,
|
163 |
+
create_key=_descriptor._internal_create_key),
|
164 |
+
_descriptor.EnumValueDescriptor(
|
165 |
+
name='TYPE_BROKEN_SINGLE_WHITE', index=1, number=1,
|
166 |
+
serialized_options=None,
|
167 |
+
type=None,
|
168 |
+
create_key=_descriptor._internal_create_key),
|
169 |
+
_descriptor.EnumValueDescriptor(
|
170 |
+
name='TYPE_SOLID_SINGLE_WHITE', index=2, number=2,
|
171 |
+
serialized_options=None,
|
172 |
+
type=None,
|
173 |
+
create_key=_descriptor._internal_create_key),
|
174 |
+
_descriptor.EnumValueDescriptor(
|
175 |
+
name='TYPE_SOLID_DOUBLE_WHITE', index=3, number=3,
|
176 |
+
serialized_options=None,
|
177 |
+
type=None,
|
178 |
+
create_key=_descriptor._internal_create_key),
|
179 |
+
_descriptor.EnumValueDescriptor(
|
180 |
+
name='TYPE_BROKEN_SINGLE_YELLOW', index=4, number=4,
|
181 |
+
serialized_options=None,
|
182 |
+
type=None,
|
183 |
+
create_key=_descriptor._internal_create_key),
|
184 |
+
_descriptor.EnumValueDescriptor(
|
185 |
+
name='TYPE_BROKEN_DOUBLE_YELLOW', index=5, number=5,
|
186 |
+
serialized_options=None,
|
187 |
+
type=None,
|
188 |
+
create_key=_descriptor._internal_create_key),
|
189 |
+
_descriptor.EnumValueDescriptor(
|
190 |
+
name='TYPE_SOLID_SINGLE_YELLOW', index=6, number=6,
|
191 |
+
serialized_options=None,
|
192 |
+
type=None,
|
193 |
+
create_key=_descriptor._internal_create_key),
|
194 |
+
_descriptor.EnumValueDescriptor(
|
195 |
+
name='TYPE_SOLID_DOUBLE_YELLOW', index=7, number=7,
|
196 |
+
serialized_options=None,
|
197 |
+
type=None,
|
198 |
+
create_key=_descriptor._internal_create_key),
|
199 |
+
_descriptor.EnumValueDescriptor(
|
200 |
+
name='TYPE_PASSING_DOUBLE_YELLOW', index=8, number=8,
|
201 |
+
serialized_options=None,
|
202 |
+
type=None,
|
203 |
+
create_key=_descriptor._internal_create_key),
|
204 |
+
],
|
205 |
+
containing_type=None,
|
206 |
+
serialized_options=None,
|
207 |
+
serialized_start=2193,
|
208 |
+
serialized_end=2467,
|
209 |
+
)
|
210 |
+
_sym_db.RegisterEnumDescriptor(_ROADLINE_ROADLINETYPE)
|
211 |
+
|
212 |
+
|
213 |
+
_MAP = _descriptor.Descriptor(
|
214 |
+
name='Map',
|
215 |
+
full_name='long_metric.Map',
|
216 |
+
filename=None,
|
217 |
+
file=DESCRIPTOR,
|
218 |
+
containing_type=None,
|
219 |
+
create_key=_descriptor._internal_create_key,
|
220 |
+
fields=[
|
221 |
+
_descriptor.FieldDescriptor(
|
222 |
+
name='map_features', full_name='long_metric.Map.map_features', index=0,
|
223 |
+
number=1, type=11, cpp_type=10, label=3,
|
224 |
+
has_default_value=False, default_value=[],
|
225 |
+
message_type=None, enum_type=None, containing_type=None,
|
226 |
+
is_extension=False, extension_scope=None,
|
227 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
228 |
+
_descriptor.FieldDescriptor(
|
229 |
+
name='dynamic_states', full_name='long_metric.Map.dynamic_states', index=1,
|
230 |
+
number=2, type=11, cpp_type=10, label=3,
|
231 |
+
has_default_value=False, default_value=[],
|
232 |
+
message_type=None, enum_type=None, containing_type=None,
|
233 |
+
is_extension=False, extension_scope=None,
|
234 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
235 |
+
],
|
236 |
+
extensions=[
|
237 |
+
],
|
238 |
+
nested_types=[],
|
239 |
+
enum_types=[
|
240 |
+
],
|
241 |
+
serialized_options=None,
|
242 |
+
is_extendable=False,
|
243 |
+
syntax='proto2',
|
244 |
+
extension_ranges=[],
|
245 |
+
oneofs=[
|
246 |
+
],
|
247 |
+
serialized_start=45,
|
248 |
+
serialized_end=148,
|
249 |
+
)
|
250 |
+
|
251 |
+
|
252 |
+
_DYNAMICSTATE = _descriptor.Descriptor(
|
253 |
+
name='DynamicState',
|
254 |
+
full_name='long_metric.DynamicState',
|
255 |
+
filename=None,
|
256 |
+
file=DESCRIPTOR,
|
257 |
+
containing_type=None,
|
258 |
+
create_key=_descriptor._internal_create_key,
|
259 |
+
fields=[
|
260 |
+
_descriptor.FieldDescriptor(
|
261 |
+
name='timestamp_seconds', full_name='long_metric.DynamicState.timestamp_seconds', index=0,
|
262 |
+
number=1, type=1, cpp_type=5, label=1,
|
263 |
+
has_default_value=False, default_value=float(0),
|
264 |
+
message_type=None, enum_type=None, containing_type=None,
|
265 |
+
is_extension=False, extension_scope=None,
|
266 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
267 |
+
_descriptor.FieldDescriptor(
|
268 |
+
name='lane_states', full_name='long_metric.DynamicState.lane_states', index=1,
|
269 |
+
number=2, type=11, cpp_type=10, label=3,
|
270 |
+
has_default_value=False, default_value=[],
|
271 |
+
message_type=None, enum_type=None, containing_type=None,
|
272 |
+
is_extension=False, extension_scope=None,
|
273 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
274 |
+
],
|
275 |
+
extensions=[
|
276 |
+
],
|
277 |
+
nested_types=[],
|
278 |
+
enum_types=[
|
279 |
+
],
|
280 |
+
serialized_options=None,
|
281 |
+
is_extendable=False,
|
282 |
+
syntax='proto2',
|
283 |
+
extension_ranges=[],
|
284 |
+
oneofs=[
|
285 |
+
],
|
286 |
+
serialized_start=150,
|
287 |
+
serialized_end=249,
|
288 |
+
)
|
289 |
+
|
290 |
+
|
291 |
+
_TRAFFICSIGNALLANESTATE = _descriptor.Descriptor(
|
292 |
+
name='TrafficSignalLaneState',
|
293 |
+
full_name='long_metric.TrafficSignalLaneState',
|
294 |
+
filename=None,
|
295 |
+
file=DESCRIPTOR,
|
296 |
+
containing_type=None,
|
297 |
+
create_key=_descriptor._internal_create_key,
|
298 |
+
fields=[
|
299 |
+
_descriptor.FieldDescriptor(
|
300 |
+
name='lane', full_name='long_metric.TrafficSignalLaneState.lane', index=0,
|
301 |
+
number=1, type=3, cpp_type=2, label=1,
|
302 |
+
has_default_value=False, default_value=0,
|
303 |
+
message_type=None, enum_type=None, containing_type=None,
|
304 |
+
is_extension=False, extension_scope=None,
|
305 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
306 |
+
_descriptor.FieldDescriptor(
|
307 |
+
name='state', full_name='long_metric.TrafficSignalLaneState.state', index=1,
|
308 |
+
number=2, type=14, cpp_type=8, label=1,
|
309 |
+
has_default_value=False, default_value=0,
|
310 |
+
message_type=None, enum_type=None, containing_type=None,
|
311 |
+
is_extension=False, extension_scope=None,
|
312 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
313 |
+
_descriptor.FieldDescriptor(
|
314 |
+
name='stop_point', full_name='long_metric.TrafficSignalLaneState.stop_point', index=2,
|
315 |
+
number=3, type=11, cpp_type=10, label=1,
|
316 |
+
has_default_value=False, default_value=None,
|
317 |
+
message_type=None, enum_type=None, containing_type=None,
|
318 |
+
is_extension=False, extension_scope=None,
|
319 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
320 |
+
],
|
321 |
+
extensions=[
|
322 |
+
],
|
323 |
+
nested_types=[],
|
324 |
+
enum_types=[
|
325 |
+
_TRAFFICSIGNALLANESTATE_STATE,
|
326 |
+
],
|
327 |
+
serialized_options=None,
|
328 |
+
is_extendable=False,
|
329 |
+
syntax='proto2',
|
330 |
+
extension_ranges=[],
|
331 |
+
oneofs=[
|
332 |
+
],
|
333 |
+
serialized_start=252,
|
334 |
+
serialized_end=634,
|
335 |
+
)
|
336 |
+
|
337 |
+
|
338 |
+
_MAPFEATURE = _descriptor.Descriptor(
|
339 |
+
name='MapFeature',
|
340 |
+
full_name='long_metric.MapFeature',
|
341 |
+
filename=None,
|
342 |
+
file=DESCRIPTOR,
|
343 |
+
containing_type=None,
|
344 |
+
create_key=_descriptor._internal_create_key,
|
345 |
+
fields=[
|
346 |
+
_descriptor.FieldDescriptor(
|
347 |
+
name='id', full_name='long_metric.MapFeature.id', index=0,
|
348 |
+
number=1, type=3, cpp_type=2, label=1,
|
349 |
+
has_default_value=False, default_value=0,
|
350 |
+
message_type=None, enum_type=None, containing_type=None,
|
351 |
+
is_extension=False, extension_scope=None,
|
352 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
353 |
+
_descriptor.FieldDescriptor(
|
354 |
+
name='lane', full_name='long_metric.MapFeature.lane', index=1,
|
355 |
+
number=3, type=11, cpp_type=10, label=1,
|
356 |
+
has_default_value=False, default_value=None,
|
357 |
+
message_type=None, enum_type=None, containing_type=None,
|
358 |
+
is_extension=False, extension_scope=None,
|
359 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
360 |
+
_descriptor.FieldDescriptor(
|
361 |
+
name='road_line', full_name='long_metric.MapFeature.road_line', index=2,
|
362 |
+
number=4, type=11, cpp_type=10, label=1,
|
363 |
+
has_default_value=False, default_value=None,
|
364 |
+
message_type=None, enum_type=None, containing_type=None,
|
365 |
+
is_extension=False, extension_scope=None,
|
366 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
367 |
+
_descriptor.FieldDescriptor(
|
368 |
+
name='road_edge', full_name='long_metric.MapFeature.road_edge', index=3,
|
369 |
+
number=5, type=11, cpp_type=10, label=1,
|
370 |
+
has_default_value=False, default_value=None,
|
371 |
+
message_type=None, enum_type=None, containing_type=None,
|
372 |
+
is_extension=False, extension_scope=None,
|
373 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
374 |
+
_descriptor.FieldDescriptor(
|
375 |
+
name='stop_sign', full_name='long_metric.MapFeature.stop_sign', index=4,
|
376 |
+
number=7, type=11, cpp_type=10, label=1,
|
377 |
+
has_default_value=False, default_value=None,
|
378 |
+
message_type=None, enum_type=None, containing_type=None,
|
379 |
+
is_extension=False, extension_scope=None,
|
380 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
381 |
+
_descriptor.FieldDescriptor(
|
382 |
+
name='crosswalk', full_name='long_metric.MapFeature.crosswalk', index=5,
|
383 |
+
number=8, type=11, cpp_type=10, label=1,
|
384 |
+
has_default_value=False, default_value=None,
|
385 |
+
message_type=None, enum_type=None, containing_type=None,
|
386 |
+
is_extension=False, extension_scope=None,
|
387 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
388 |
+
_descriptor.FieldDescriptor(
|
389 |
+
name='speed_bump', full_name='long_metric.MapFeature.speed_bump', index=6,
|
390 |
+
number=9, type=11, cpp_type=10, label=1,
|
391 |
+
has_default_value=False, default_value=None,
|
392 |
+
message_type=None, enum_type=None, containing_type=None,
|
393 |
+
is_extension=False, extension_scope=None,
|
394 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
395 |
+
_descriptor.FieldDescriptor(
|
396 |
+
name='driveway', full_name='long_metric.MapFeature.driveway', index=7,
|
397 |
+
number=10, type=11, cpp_type=10, label=1,
|
398 |
+
has_default_value=False, default_value=None,
|
399 |
+
message_type=None, enum_type=None, containing_type=None,
|
400 |
+
is_extension=False, extension_scope=None,
|
401 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
402 |
+
],
|
403 |
+
extensions=[
|
404 |
+
],
|
405 |
+
nested_types=[],
|
406 |
+
enum_types=[
|
407 |
+
],
|
408 |
+
serialized_options=None,
|
409 |
+
is_extendable=False,
|
410 |
+
syntax='proto2',
|
411 |
+
extension_ranges=[],
|
412 |
+
oneofs=[
|
413 |
+
_descriptor.OneofDescriptor(
|
414 |
+
name='feature_data', full_name='long_metric.MapFeature.feature_data',
|
415 |
+
index=0, containing_type=None,
|
416 |
+
create_key=_descriptor._internal_create_key,
|
417 |
+
fields=[]),
|
418 |
+
],
|
419 |
+
serialized_start=637,
|
420 |
+
serialized_end=984,
|
421 |
+
)
|
422 |
+
|
423 |
+
|
424 |
+
_MAPPOINT = _descriptor.Descriptor(
|
425 |
+
name='MapPoint',
|
426 |
+
full_name='long_metric.MapPoint',
|
427 |
+
filename=None,
|
428 |
+
file=DESCRIPTOR,
|
429 |
+
containing_type=None,
|
430 |
+
create_key=_descriptor._internal_create_key,
|
431 |
+
fields=[
|
432 |
+
_descriptor.FieldDescriptor(
|
433 |
+
name='x', full_name='long_metric.MapPoint.x', index=0,
|
434 |
+
number=1, type=1, cpp_type=5, label=1,
|
435 |
+
has_default_value=False, default_value=float(0),
|
436 |
+
message_type=None, enum_type=None, containing_type=None,
|
437 |
+
is_extension=False, extension_scope=None,
|
438 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
439 |
+
_descriptor.FieldDescriptor(
|
440 |
+
name='y', full_name='long_metric.MapPoint.y', index=1,
|
441 |
+
number=2, type=1, cpp_type=5, label=1,
|
442 |
+
has_default_value=False, default_value=float(0),
|
443 |
+
message_type=None, enum_type=None, containing_type=None,
|
444 |
+
is_extension=False, extension_scope=None,
|
445 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
446 |
+
_descriptor.FieldDescriptor(
|
447 |
+
name='z', full_name='long_metric.MapPoint.z', index=2,
|
448 |
+
number=3, type=1, cpp_type=5, label=1,
|
449 |
+
has_default_value=False, default_value=float(0),
|
450 |
+
message_type=None, enum_type=None, containing_type=None,
|
451 |
+
is_extension=False, extension_scope=None,
|
452 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
453 |
+
],
|
454 |
+
extensions=[
|
455 |
+
],
|
456 |
+
nested_types=[],
|
457 |
+
enum_types=[
|
458 |
+
],
|
459 |
+
serialized_options=None,
|
460 |
+
is_extendable=False,
|
461 |
+
syntax='proto2',
|
462 |
+
extension_ranges=[],
|
463 |
+
oneofs=[
|
464 |
+
],
|
465 |
+
serialized_start=986,
|
466 |
+
serialized_end=1029,
|
467 |
+
)
|
468 |
+
|
469 |
+
|
470 |
+
_BOUNDARYSEGMENT = _descriptor.Descriptor(
|
471 |
+
name='BoundarySegment',
|
472 |
+
full_name='long_metric.BoundarySegment',
|
473 |
+
filename=None,
|
474 |
+
file=DESCRIPTOR,
|
475 |
+
containing_type=None,
|
476 |
+
create_key=_descriptor._internal_create_key,
|
477 |
+
fields=[
|
478 |
+
_descriptor.FieldDescriptor(
|
479 |
+
name='lane_start_index', full_name='long_metric.BoundarySegment.lane_start_index', index=0,
|
480 |
+
number=1, type=5, cpp_type=1, label=1,
|
481 |
+
has_default_value=False, default_value=0,
|
482 |
+
message_type=None, enum_type=None, containing_type=None,
|
483 |
+
is_extension=False, extension_scope=None,
|
484 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
485 |
+
_descriptor.FieldDescriptor(
|
486 |
+
name='lane_end_index', full_name='long_metric.BoundarySegment.lane_end_index', index=1,
|
487 |
+
number=2, type=5, cpp_type=1, label=1,
|
488 |
+
has_default_value=False, default_value=0,
|
489 |
+
message_type=None, enum_type=None, containing_type=None,
|
490 |
+
is_extension=False, extension_scope=None,
|
491 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
492 |
+
_descriptor.FieldDescriptor(
|
493 |
+
name='boundary_feature_id', full_name='long_metric.BoundarySegment.boundary_feature_id', index=2,
|
494 |
+
number=3, type=3, cpp_type=2, label=1,
|
495 |
+
has_default_value=False, default_value=0,
|
496 |
+
message_type=None, enum_type=None, containing_type=None,
|
497 |
+
is_extension=False, extension_scope=None,
|
498 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
499 |
+
_descriptor.FieldDescriptor(
|
500 |
+
name='boundary_type', full_name='long_metric.BoundarySegment.boundary_type', index=3,
|
501 |
+
number=4, type=14, cpp_type=8, label=1,
|
502 |
+
has_default_value=False, default_value=0,
|
503 |
+
message_type=None, enum_type=None, containing_type=None,
|
504 |
+
is_extension=False, extension_scope=None,
|
505 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
506 |
+
],
|
507 |
+
extensions=[
|
508 |
+
],
|
509 |
+
nested_types=[],
|
510 |
+
enum_types=[
|
511 |
+
],
|
512 |
+
serialized_options=None,
|
513 |
+
is_extendable=False,
|
514 |
+
syntax='proto2',
|
515 |
+
extension_ranges=[],
|
516 |
+
oneofs=[
|
517 |
+
],
|
518 |
+
serialized_start=1032,
|
519 |
+
serialized_end=1187,
|
520 |
+
)
|
521 |
+
|
522 |
+
|
523 |
+
_LANENEIGHBOR = _descriptor.Descriptor(
|
524 |
+
name='LaneNeighbor',
|
525 |
+
full_name='long_metric.LaneNeighbor',
|
526 |
+
filename=None,
|
527 |
+
file=DESCRIPTOR,
|
528 |
+
containing_type=None,
|
529 |
+
create_key=_descriptor._internal_create_key,
|
530 |
+
fields=[
|
531 |
+
_descriptor.FieldDescriptor(
|
532 |
+
name='feature_id', full_name='long_metric.LaneNeighbor.feature_id', index=0,
|
533 |
+
number=1, type=3, cpp_type=2, label=1,
|
534 |
+
has_default_value=False, default_value=0,
|
535 |
+
message_type=None, enum_type=None, containing_type=None,
|
536 |
+
is_extension=False, extension_scope=None,
|
537 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
538 |
+
_descriptor.FieldDescriptor(
|
539 |
+
name='self_start_index', full_name='long_metric.LaneNeighbor.self_start_index', index=1,
|
540 |
+
number=2, type=5, cpp_type=1, label=1,
|
541 |
+
has_default_value=False, default_value=0,
|
542 |
+
message_type=None, enum_type=None, containing_type=None,
|
543 |
+
is_extension=False, extension_scope=None,
|
544 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
545 |
+
_descriptor.FieldDescriptor(
|
546 |
+
name='self_end_index', full_name='long_metric.LaneNeighbor.self_end_index', index=2,
|
547 |
+
number=3, type=5, cpp_type=1, label=1,
|
548 |
+
has_default_value=False, default_value=0,
|
549 |
+
message_type=None, enum_type=None, containing_type=None,
|
550 |
+
is_extension=False, extension_scope=None,
|
551 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
552 |
+
_descriptor.FieldDescriptor(
|
553 |
+
name='neighbor_start_index', full_name='long_metric.LaneNeighbor.neighbor_start_index', index=3,
|
554 |
+
number=4, type=5, cpp_type=1, label=1,
|
555 |
+
has_default_value=False, default_value=0,
|
556 |
+
message_type=None, enum_type=None, containing_type=None,
|
557 |
+
is_extension=False, extension_scope=None,
|
558 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
559 |
+
_descriptor.FieldDescriptor(
|
560 |
+
name='neighbor_end_index', full_name='long_metric.LaneNeighbor.neighbor_end_index', index=4,
|
561 |
+
number=5, type=5, cpp_type=1, label=1,
|
562 |
+
has_default_value=False, default_value=0,
|
563 |
+
message_type=None, enum_type=None, containing_type=None,
|
564 |
+
is_extension=False, extension_scope=None,
|
565 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
566 |
+
_descriptor.FieldDescriptor(
|
567 |
+
name='boundaries', full_name='long_metric.LaneNeighbor.boundaries', index=5,
|
568 |
+
number=6, type=11, cpp_type=10, label=3,
|
569 |
+
has_default_value=False, default_value=[],
|
570 |
+
message_type=None, enum_type=None, containing_type=None,
|
571 |
+
is_extension=False, extension_scope=None,
|
572 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
573 |
+
],
|
574 |
+
extensions=[
|
575 |
+
],
|
576 |
+
nested_types=[],
|
577 |
+
enum_types=[
|
578 |
+
],
|
579 |
+
serialized_options=None,
|
580 |
+
is_extendable=False,
|
581 |
+
syntax='proto2',
|
582 |
+
extension_ranges=[],
|
583 |
+
oneofs=[
|
584 |
+
],
|
585 |
+
serialized_start=1190,
|
586 |
+
serialized_end=1382,
|
587 |
+
)
|
588 |
+
|
589 |
+
|
590 |
+
_LANECENTER = _descriptor.Descriptor(
|
591 |
+
name='LaneCenter',
|
592 |
+
full_name='long_metric.LaneCenter',
|
593 |
+
filename=None,
|
594 |
+
file=DESCRIPTOR,
|
595 |
+
containing_type=None,
|
596 |
+
create_key=_descriptor._internal_create_key,
|
597 |
+
fields=[
|
598 |
+
_descriptor.FieldDescriptor(
|
599 |
+
name='speed_limit_mph', full_name='long_metric.LaneCenter.speed_limit_mph', index=0,
|
600 |
+
number=1, type=1, cpp_type=5, label=1,
|
601 |
+
has_default_value=False, default_value=float(0),
|
602 |
+
message_type=None, enum_type=None, containing_type=None,
|
603 |
+
is_extension=False, extension_scope=None,
|
604 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
605 |
+
_descriptor.FieldDescriptor(
|
606 |
+
name='type', full_name='long_metric.LaneCenter.type', index=1,
|
607 |
+
number=2, type=14, cpp_type=8, label=1,
|
608 |
+
has_default_value=False, default_value=0,
|
609 |
+
message_type=None, enum_type=None, containing_type=None,
|
610 |
+
is_extension=False, extension_scope=None,
|
611 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
612 |
+
_descriptor.FieldDescriptor(
|
613 |
+
name='interpolating', full_name='long_metric.LaneCenter.interpolating', index=2,
|
614 |
+
number=3, type=8, cpp_type=7, label=1,
|
615 |
+
has_default_value=False, default_value=False,
|
616 |
+
message_type=None, enum_type=None, containing_type=None,
|
617 |
+
is_extension=False, extension_scope=None,
|
618 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
619 |
+
_descriptor.FieldDescriptor(
|
620 |
+
name='polyline', full_name='long_metric.LaneCenter.polyline', index=3,
|
621 |
+
number=8, type=11, cpp_type=10, label=3,
|
622 |
+
has_default_value=False, default_value=[],
|
623 |
+
message_type=None, enum_type=None, containing_type=None,
|
624 |
+
is_extension=False, extension_scope=None,
|
625 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
626 |
+
_descriptor.FieldDescriptor(
|
627 |
+
name='entry_lanes', full_name='long_metric.LaneCenter.entry_lanes', index=4,
|
628 |
+
number=9, type=3, cpp_type=2, label=3,
|
629 |
+
has_default_value=False, default_value=[],
|
630 |
+
message_type=None, enum_type=None, containing_type=None,
|
631 |
+
is_extension=False, extension_scope=None,
|
632 |
+
serialized_options=b'\020\001', file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
633 |
+
_descriptor.FieldDescriptor(
|
634 |
+
name='exit_lanes', full_name='long_metric.LaneCenter.exit_lanes', index=5,
|
635 |
+
number=10, type=3, cpp_type=2, label=3,
|
636 |
+
has_default_value=False, default_value=[],
|
637 |
+
message_type=None, enum_type=None, containing_type=None,
|
638 |
+
is_extension=False, extension_scope=None,
|
639 |
+
serialized_options=b'\020\001', file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
640 |
+
_descriptor.FieldDescriptor(
|
641 |
+
name='left_boundaries', full_name='long_metric.LaneCenter.left_boundaries', index=6,
|
642 |
+
number=13, type=11, cpp_type=10, label=3,
|
643 |
+
has_default_value=False, default_value=[],
|
644 |
+
message_type=None, enum_type=None, containing_type=None,
|
645 |
+
is_extension=False, extension_scope=None,
|
646 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
647 |
+
_descriptor.FieldDescriptor(
|
648 |
+
name='right_boundaries', full_name='long_metric.LaneCenter.right_boundaries', index=7,
|
649 |
+
number=14, type=11, cpp_type=10, label=3,
|
650 |
+
has_default_value=False, default_value=[],
|
651 |
+
message_type=None, enum_type=None, containing_type=None,
|
652 |
+
is_extension=False, extension_scope=None,
|
653 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
654 |
+
_descriptor.FieldDescriptor(
|
655 |
+
name='left_neighbors', full_name='long_metric.LaneCenter.left_neighbors', index=8,
|
656 |
+
number=11, type=11, cpp_type=10, label=3,
|
657 |
+
has_default_value=False, default_value=[],
|
658 |
+
message_type=None, enum_type=None, containing_type=None,
|
659 |
+
is_extension=False, extension_scope=None,
|
660 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
661 |
+
_descriptor.FieldDescriptor(
|
662 |
+
name='right_neighbors', full_name='long_metric.LaneCenter.right_neighbors', index=9,
|
663 |
+
number=12, type=11, cpp_type=10, label=3,
|
664 |
+
has_default_value=False, default_value=[],
|
665 |
+
message_type=None, enum_type=None, containing_type=None,
|
666 |
+
is_extension=False, extension_scope=None,
|
667 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
668 |
+
],
|
669 |
+
extensions=[
|
670 |
+
],
|
671 |
+
nested_types=[],
|
672 |
+
enum_types=[
|
673 |
+
_LANECENTER_LANETYPE,
|
674 |
+
],
|
675 |
+
serialized_options=None,
|
676 |
+
is_extendable=False,
|
677 |
+
syntax='proto2',
|
678 |
+
extension_ranges=[],
|
679 |
+
oneofs=[
|
680 |
+
],
|
681 |
+
serialized_start=1385,
|
682 |
+
serialized_end=1892,
|
683 |
+
)
|
684 |
+
|
685 |
+
|
686 |
+
_ROADEDGE = _descriptor.Descriptor(
|
687 |
+
name='RoadEdge',
|
688 |
+
full_name='long_metric.RoadEdge',
|
689 |
+
filename=None,
|
690 |
+
file=DESCRIPTOR,
|
691 |
+
containing_type=None,
|
692 |
+
create_key=_descriptor._internal_create_key,
|
693 |
+
fields=[
|
694 |
+
_descriptor.FieldDescriptor(
|
695 |
+
name='type', full_name='long_metric.RoadEdge.type', index=0,
|
696 |
+
number=1, type=14, cpp_type=8, label=1,
|
697 |
+
has_default_value=False, default_value=0,
|
698 |
+
message_type=None, enum_type=None, containing_type=None,
|
699 |
+
is_extension=False, extension_scope=None,
|
700 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
701 |
+
_descriptor.FieldDescriptor(
|
702 |
+
name='polyline', full_name='long_metric.RoadEdge.polyline', index=1,
|
703 |
+
number=2, type=11, cpp_type=10, label=3,
|
704 |
+
has_default_value=False, default_value=[],
|
705 |
+
message_type=None, enum_type=None, containing_type=None,
|
706 |
+
is_extension=False, extension_scope=None,
|
707 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
708 |
+
],
|
709 |
+
extensions=[
|
710 |
+
],
|
711 |
+
nested_types=[],
|
712 |
+
enum_types=[
|
713 |
+
_ROADEDGE_ROADEDGETYPE,
|
714 |
+
],
|
715 |
+
serialized_options=None,
|
716 |
+
is_extendable=False,
|
717 |
+
syntax='proto2',
|
718 |
+
extension_ranges=[],
|
719 |
+
oneofs=[
|
720 |
+
],
|
721 |
+
serialized_start=1895,
|
722 |
+
serialized_end=2086,
|
723 |
+
)
|
724 |
+
|
725 |
+
|
726 |
+
_ROADLINE = _descriptor.Descriptor(
|
727 |
+
name='RoadLine',
|
728 |
+
full_name='long_metric.RoadLine',
|
729 |
+
filename=None,
|
730 |
+
file=DESCRIPTOR,
|
731 |
+
containing_type=None,
|
732 |
+
create_key=_descriptor._internal_create_key,
|
733 |
+
fields=[
|
734 |
+
_descriptor.FieldDescriptor(
|
735 |
+
name='type', full_name='long_metric.RoadLine.type', index=0,
|
736 |
+
number=1, type=14, cpp_type=8, label=1,
|
737 |
+
has_default_value=False, default_value=0,
|
738 |
+
message_type=None, enum_type=None, containing_type=None,
|
739 |
+
is_extension=False, extension_scope=None,
|
740 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
741 |
+
_descriptor.FieldDescriptor(
|
742 |
+
name='polyline', full_name='long_metric.RoadLine.polyline', index=1,
|
743 |
+
number=2, type=11, cpp_type=10, label=3,
|
744 |
+
has_default_value=False, default_value=[],
|
745 |
+
message_type=None, enum_type=None, containing_type=None,
|
746 |
+
is_extension=False, extension_scope=None,
|
747 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
748 |
+
],
|
749 |
+
extensions=[
|
750 |
+
],
|
751 |
+
nested_types=[],
|
752 |
+
enum_types=[
|
753 |
+
_ROADLINE_ROADLINETYPE,
|
754 |
+
],
|
755 |
+
serialized_options=None,
|
756 |
+
is_extendable=False,
|
757 |
+
syntax='proto2',
|
758 |
+
extension_ranges=[],
|
759 |
+
oneofs=[
|
760 |
+
],
|
761 |
+
serialized_start=2089,
|
762 |
+
serialized_end=2467,
|
763 |
+
)
|
764 |
+
|
765 |
+
|
766 |
+
_STOPSIGN = _descriptor.Descriptor(
|
767 |
+
name='StopSign',
|
768 |
+
full_name='long_metric.StopSign',
|
769 |
+
filename=None,
|
770 |
+
file=DESCRIPTOR,
|
771 |
+
containing_type=None,
|
772 |
+
create_key=_descriptor._internal_create_key,
|
773 |
+
fields=[
|
774 |
+
_descriptor.FieldDescriptor(
|
775 |
+
name='lane', full_name='long_metric.StopSign.lane', index=0,
|
776 |
+
number=1, type=3, cpp_type=2, label=3,
|
777 |
+
has_default_value=False, default_value=[],
|
778 |
+
message_type=None, enum_type=None, containing_type=None,
|
779 |
+
is_extension=False, extension_scope=None,
|
780 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
781 |
+
_descriptor.FieldDescriptor(
|
782 |
+
name='position', full_name='long_metric.StopSign.position', index=1,
|
783 |
+
number=2, type=11, cpp_type=10, label=1,
|
784 |
+
has_default_value=False, default_value=None,
|
785 |
+
message_type=None, enum_type=None, containing_type=None,
|
786 |
+
is_extension=False, extension_scope=None,
|
787 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
788 |
+
],
|
789 |
+
extensions=[
|
790 |
+
],
|
791 |
+
nested_types=[],
|
792 |
+
enum_types=[
|
793 |
+
],
|
794 |
+
serialized_options=None,
|
795 |
+
is_extendable=False,
|
796 |
+
syntax='proto2',
|
797 |
+
extension_ranges=[],
|
798 |
+
oneofs=[
|
799 |
+
],
|
800 |
+
serialized_start=2469,
|
801 |
+
serialized_end=2534,
|
802 |
+
)
|
803 |
+
|
804 |
+
|
805 |
+
_CROSSWALK = _descriptor.Descriptor(
|
806 |
+
name='Crosswalk',
|
807 |
+
full_name='long_metric.Crosswalk',
|
808 |
+
filename=None,
|
809 |
+
file=DESCRIPTOR,
|
810 |
+
containing_type=None,
|
811 |
+
create_key=_descriptor._internal_create_key,
|
812 |
+
fields=[
|
813 |
+
_descriptor.FieldDescriptor(
|
814 |
+
name='polygon', full_name='long_metric.Crosswalk.polygon', index=0,
|
815 |
+
number=1, type=11, cpp_type=10, label=3,
|
816 |
+
has_default_value=False, default_value=[],
|
817 |
+
message_type=None, enum_type=None, containing_type=None,
|
818 |
+
is_extension=False, extension_scope=None,
|
819 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
820 |
+
],
|
821 |
+
extensions=[
|
822 |
+
],
|
823 |
+
nested_types=[],
|
824 |
+
enum_types=[
|
825 |
+
],
|
826 |
+
serialized_options=None,
|
827 |
+
is_extendable=False,
|
828 |
+
syntax='proto2',
|
829 |
+
extension_ranges=[],
|
830 |
+
oneofs=[
|
831 |
+
],
|
832 |
+
serialized_start=2536,
|
833 |
+
serialized_end=2587,
|
834 |
+
)
|
835 |
+
|
836 |
+
|
837 |
+
_SPEEDBUMP = _descriptor.Descriptor(
|
838 |
+
name='SpeedBump',
|
839 |
+
full_name='long_metric.SpeedBump',
|
840 |
+
filename=None,
|
841 |
+
file=DESCRIPTOR,
|
842 |
+
containing_type=None,
|
843 |
+
create_key=_descriptor._internal_create_key,
|
844 |
+
fields=[
|
845 |
+
_descriptor.FieldDescriptor(
|
846 |
+
name='polygon', full_name='long_metric.SpeedBump.polygon', index=0,
|
847 |
+
number=1, type=11, cpp_type=10, label=3,
|
848 |
+
has_default_value=False, default_value=[],
|
849 |
+
message_type=None, enum_type=None, containing_type=None,
|
850 |
+
is_extension=False, extension_scope=None,
|
851 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
852 |
+
],
|
853 |
+
extensions=[
|
854 |
+
],
|
855 |
+
nested_types=[],
|
856 |
+
enum_types=[
|
857 |
+
],
|
858 |
+
serialized_options=None,
|
859 |
+
is_extendable=False,
|
860 |
+
syntax='proto2',
|
861 |
+
extension_ranges=[],
|
862 |
+
oneofs=[
|
863 |
+
],
|
864 |
+
serialized_start=2589,
|
865 |
+
serialized_end=2640,
|
866 |
+
)
|
867 |
+
|
868 |
+
|
869 |
+
_DRIVEWAY = _descriptor.Descriptor(
|
870 |
+
name='Driveway',
|
871 |
+
full_name='long_metric.Driveway',
|
872 |
+
filename=None,
|
873 |
+
file=DESCRIPTOR,
|
874 |
+
containing_type=None,
|
875 |
+
create_key=_descriptor._internal_create_key,
|
876 |
+
fields=[
|
877 |
+
_descriptor.FieldDescriptor(
|
878 |
+
name='polygon', full_name='long_metric.Driveway.polygon', index=0,
|
879 |
+
number=1, type=11, cpp_type=10, label=3,
|
880 |
+
has_default_value=False, default_value=[],
|
881 |
+
message_type=None, enum_type=None, containing_type=None,
|
882 |
+
is_extension=False, extension_scope=None,
|
883 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
884 |
+
],
|
885 |
+
extensions=[
|
886 |
+
],
|
887 |
+
nested_types=[],
|
888 |
+
enum_types=[
|
889 |
+
],
|
890 |
+
serialized_options=None,
|
891 |
+
is_extendable=False,
|
892 |
+
syntax='proto2',
|
893 |
+
extension_ranges=[],
|
894 |
+
oneofs=[
|
895 |
+
],
|
896 |
+
serialized_start=2642,
|
897 |
+
serialized_end=2692,
|
898 |
+
)
|
899 |
+
|
900 |
+
_MAP.fields_by_name['map_features'].message_type = _MAPFEATURE
|
901 |
+
_MAP.fields_by_name['dynamic_states'].message_type = _DYNAMICSTATE
|
902 |
+
_DYNAMICSTATE.fields_by_name['lane_states'].message_type = _TRAFFICSIGNALLANESTATE
|
903 |
+
_TRAFFICSIGNALLANESTATE.fields_by_name['state'].enum_type = _TRAFFICSIGNALLANESTATE_STATE
|
904 |
+
_TRAFFICSIGNALLANESTATE.fields_by_name['stop_point'].message_type = _MAPPOINT
|
905 |
+
_TRAFFICSIGNALLANESTATE_STATE.containing_type = _TRAFFICSIGNALLANESTATE
|
906 |
+
_MAPFEATURE.fields_by_name['lane'].message_type = _LANECENTER
|
907 |
+
_MAPFEATURE.fields_by_name['road_line'].message_type = _ROADLINE
|
908 |
+
_MAPFEATURE.fields_by_name['road_edge'].message_type = _ROADEDGE
|
909 |
+
_MAPFEATURE.fields_by_name['stop_sign'].message_type = _STOPSIGN
|
910 |
+
_MAPFEATURE.fields_by_name['crosswalk'].message_type = _CROSSWALK
|
911 |
+
_MAPFEATURE.fields_by_name['speed_bump'].message_type = _SPEEDBUMP
|
912 |
+
_MAPFEATURE.fields_by_name['driveway'].message_type = _DRIVEWAY
|
913 |
+
_MAPFEATURE.oneofs_by_name['feature_data'].fields.append(
|
914 |
+
_MAPFEATURE.fields_by_name['lane'])
|
915 |
+
_MAPFEATURE.fields_by_name['lane'].containing_oneof = _MAPFEATURE.oneofs_by_name['feature_data']
|
916 |
+
_MAPFEATURE.oneofs_by_name['feature_data'].fields.append(
|
917 |
+
_MAPFEATURE.fields_by_name['road_line'])
|
918 |
+
_MAPFEATURE.fields_by_name['road_line'].containing_oneof = _MAPFEATURE.oneofs_by_name['feature_data']
|
919 |
+
_MAPFEATURE.oneofs_by_name['feature_data'].fields.append(
|
920 |
+
_MAPFEATURE.fields_by_name['road_edge'])
|
921 |
+
_MAPFEATURE.fields_by_name['road_edge'].containing_oneof = _MAPFEATURE.oneofs_by_name['feature_data']
|
922 |
+
_MAPFEATURE.oneofs_by_name['feature_data'].fields.append(
|
923 |
+
_MAPFEATURE.fields_by_name['stop_sign'])
|
924 |
+
_MAPFEATURE.fields_by_name['stop_sign'].containing_oneof = _MAPFEATURE.oneofs_by_name['feature_data']
|
925 |
+
_MAPFEATURE.oneofs_by_name['feature_data'].fields.append(
|
926 |
+
_MAPFEATURE.fields_by_name['crosswalk'])
|
927 |
+
_MAPFEATURE.fields_by_name['crosswalk'].containing_oneof = _MAPFEATURE.oneofs_by_name['feature_data']
|
928 |
+
_MAPFEATURE.oneofs_by_name['feature_data'].fields.append(
|
929 |
+
_MAPFEATURE.fields_by_name['speed_bump'])
|
930 |
+
_MAPFEATURE.fields_by_name['speed_bump'].containing_oneof = _MAPFEATURE.oneofs_by_name['feature_data']
|
931 |
+
_MAPFEATURE.oneofs_by_name['feature_data'].fields.append(
|
932 |
+
_MAPFEATURE.fields_by_name['driveway'])
|
933 |
+
_MAPFEATURE.fields_by_name['driveway'].containing_oneof = _MAPFEATURE.oneofs_by_name['feature_data']
|
934 |
+
_BOUNDARYSEGMENT.fields_by_name['boundary_type'].enum_type = _ROADLINE_ROADLINETYPE
|
935 |
+
_LANENEIGHBOR.fields_by_name['boundaries'].message_type = _BOUNDARYSEGMENT
|
936 |
+
_LANECENTER.fields_by_name['type'].enum_type = _LANECENTER_LANETYPE
|
937 |
+
_LANECENTER.fields_by_name['polyline'].message_type = _MAPPOINT
|
938 |
+
_LANECENTER.fields_by_name['left_boundaries'].message_type = _BOUNDARYSEGMENT
|
939 |
+
_LANECENTER.fields_by_name['right_boundaries'].message_type = _BOUNDARYSEGMENT
|
940 |
+
_LANECENTER.fields_by_name['left_neighbors'].message_type = _LANENEIGHBOR
|
941 |
+
_LANECENTER.fields_by_name['right_neighbors'].message_type = _LANENEIGHBOR
|
942 |
+
_LANECENTER_LANETYPE.containing_type = _LANECENTER
|
943 |
+
_ROADEDGE.fields_by_name['type'].enum_type = _ROADEDGE_ROADEDGETYPE
|
944 |
+
_ROADEDGE.fields_by_name['polyline'].message_type = _MAPPOINT
|
945 |
+
_ROADEDGE_ROADEDGETYPE.containing_type = _ROADEDGE
|
946 |
+
_ROADLINE.fields_by_name['type'].enum_type = _ROADLINE_ROADLINETYPE
|
947 |
+
_ROADLINE.fields_by_name['polyline'].message_type = _MAPPOINT
|
948 |
+
_ROADLINE_ROADLINETYPE.containing_type = _ROADLINE
|
949 |
+
_STOPSIGN.fields_by_name['position'].message_type = _MAPPOINT
|
950 |
+
_CROSSWALK.fields_by_name['polygon'].message_type = _MAPPOINT
|
951 |
+
_SPEEDBUMP.fields_by_name['polygon'].message_type = _MAPPOINT
|
952 |
+
_DRIVEWAY.fields_by_name['polygon'].message_type = _MAPPOINT
|
953 |
+
DESCRIPTOR.message_types_by_name['Map'] = _MAP
|
954 |
+
DESCRIPTOR.message_types_by_name['DynamicState'] = _DYNAMICSTATE
|
955 |
+
DESCRIPTOR.message_types_by_name['TrafficSignalLaneState'] = _TRAFFICSIGNALLANESTATE
|
956 |
+
DESCRIPTOR.message_types_by_name['MapFeature'] = _MAPFEATURE
|
957 |
+
DESCRIPTOR.message_types_by_name['MapPoint'] = _MAPPOINT
|
958 |
+
DESCRIPTOR.message_types_by_name['BoundarySegment'] = _BOUNDARYSEGMENT
|
959 |
+
DESCRIPTOR.message_types_by_name['LaneNeighbor'] = _LANENEIGHBOR
|
960 |
+
DESCRIPTOR.message_types_by_name['LaneCenter'] = _LANECENTER
|
961 |
+
DESCRIPTOR.message_types_by_name['RoadEdge'] = _ROADEDGE
|
962 |
+
DESCRIPTOR.message_types_by_name['RoadLine'] = _ROADLINE
|
963 |
+
DESCRIPTOR.message_types_by_name['StopSign'] = _STOPSIGN
|
964 |
+
DESCRIPTOR.message_types_by_name['Crosswalk'] = _CROSSWALK
|
965 |
+
DESCRIPTOR.message_types_by_name['SpeedBump'] = _SPEEDBUMP
|
966 |
+
DESCRIPTOR.message_types_by_name['Driveway'] = _DRIVEWAY
|
967 |
+
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
|
968 |
+
|
969 |
+
Map = _reflection.GeneratedProtocolMessageType('Map', (_message.Message,), {
|
970 |
+
'DESCRIPTOR' : _MAP,
|
971 |
+
'__module__' : 'dev.metrics.protos.map_pb2'
|
972 |
+
# @@protoc_insertion_point(class_scope:long_metric.Map)
|
973 |
+
})
|
974 |
+
_sym_db.RegisterMessage(Map)
|
975 |
+
|
976 |
+
DynamicState = _reflection.GeneratedProtocolMessageType('DynamicState', (_message.Message,), {
|
977 |
+
'DESCRIPTOR' : _DYNAMICSTATE,
|
978 |
+
'__module__' : 'dev.metrics.protos.map_pb2'
|
979 |
+
# @@protoc_insertion_point(class_scope:long_metric.DynamicState)
|
980 |
+
})
|
981 |
+
_sym_db.RegisterMessage(DynamicState)
|
982 |
+
|
983 |
+
TrafficSignalLaneState = _reflection.GeneratedProtocolMessageType('TrafficSignalLaneState', (_message.Message,), {
|
984 |
+
'DESCRIPTOR' : _TRAFFICSIGNALLANESTATE,
|
985 |
+
'__module__' : 'dev.metrics.protos.map_pb2'
|
986 |
+
# @@protoc_insertion_point(class_scope:long_metric.TrafficSignalLaneState)
|
987 |
+
})
|
988 |
+
_sym_db.RegisterMessage(TrafficSignalLaneState)
|
989 |
+
|
990 |
+
MapFeature = _reflection.GeneratedProtocolMessageType('MapFeature', (_message.Message,), {
|
991 |
+
'DESCRIPTOR' : _MAPFEATURE,
|
992 |
+
'__module__' : 'dev.metrics.protos.map_pb2'
|
993 |
+
# @@protoc_insertion_point(class_scope:long_metric.MapFeature)
|
994 |
+
})
|
995 |
+
_sym_db.RegisterMessage(MapFeature)
|
996 |
+
|
997 |
+
MapPoint = _reflection.GeneratedProtocolMessageType('MapPoint', (_message.Message,), {
|
998 |
+
'DESCRIPTOR' : _MAPPOINT,
|
999 |
+
'__module__' : 'dev.metrics.protos.map_pb2'
|
1000 |
+
# @@protoc_insertion_point(class_scope:long_metric.MapPoint)
|
1001 |
+
})
|
1002 |
+
_sym_db.RegisterMessage(MapPoint)
|
1003 |
+
|
1004 |
+
BoundarySegment = _reflection.GeneratedProtocolMessageType('BoundarySegment', (_message.Message,), {
|
1005 |
+
'DESCRIPTOR' : _BOUNDARYSEGMENT,
|
1006 |
+
'__module__' : 'dev.metrics.protos.map_pb2'
|
1007 |
+
# @@protoc_insertion_point(class_scope:long_metric.BoundarySegment)
|
1008 |
+
})
|
1009 |
+
_sym_db.RegisterMessage(BoundarySegment)
|
1010 |
+
|
1011 |
+
LaneNeighbor = _reflection.GeneratedProtocolMessageType('LaneNeighbor', (_message.Message,), {
|
1012 |
+
'DESCRIPTOR' : _LANENEIGHBOR,
|
1013 |
+
'__module__' : 'dev.metrics.protos.map_pb2'
|
1014 |
+
# @@protoc_insertion_point(class_scope:long_metric.LaneNeighbor)
|
1015 |
+
})
|
1016 |
+
_sym_db.RegisterMessage(LaneNeighbor)
|
1017 |
+
|
1018 |
+
LaneCenter = _reflection.GeneratedProtocolMessageType('LaneCenter', (_message.Message,), {
|
1019 |
+
'DESCRIPTOR' : _LANECENTER,
|
1020 |
+
'__module__' : 'dev.metrics.protos.map_pb2'
|
1021 |
+
# @@protoc_insertion_point(class_scope:long_metric.LaneCenter)
|
1022 |
+
})
|
1023 |
+
_sym_db.RegisterMessage(LaneCenter)
|
1024 |
+
|
1025 |
+
RoadEdge = _reflection.GeneratedProtocolMessageType('RoadEdge', (_message.Message,), {
|
1026 |
+
'DESCRIPTOR' : _ROADEDGE,
|
1027 |
+
'__module__' : 'dev.metrics.protos.map_pb2'
|
1028 |
+
# @@protoc_insertion_point(class_scope:long_metric.RoadEdge)
|
1029 |
+
})
|
1030 |
+
_sym_db.RegisterMessage(RoadEdge)
|
1031 |
+
|
1032 |
+
RoadLine = _reflection.GeneratedProtocolMessageType('RoadLine', (_message.Message,), {
|
1033 |
+
'DESCRIPTOR' : _ROADLINE,
|
1034 |
+
'__module__' : 'dev.metrics.protos.map_pb2'
|
1035 |
+
# @@protoc_insertion_point(class_scope:long_metric.RoadLine)
|
1036 |
+
})
|
1037 |
+
_sym_db.RegisterMessage(RoadLine)
|
1038 |
+
|
1039 |
+
StopSign = _reflection.GeneratedProtocolMessageType('StopSign', (_message.Message,), {
|
1040 |
+
'DESCRIPTOR' : _STOPSIGN,
|
1041 |
+
'__module__' : 'dev.metrics.protos.map_pb2'
|
1042 |
+
# @@protoc_insertion_point(class_scope:long_metric.StopSign)
|
1043 |
+
})
|
1044 |
+
_sym_db.RegisterMessage(StopSign)
|
1045 |
+
|
1046 |
+
Crosswalk = _reflection.GeneratedProtocolMessageType('Crosswalk', (_message.Message,), {
|
1047 |
+
'DESCRIPTOR' : _CROSSWALK,
|
1048 |
+
'__module__' : 'dev.metrics.protos.map_pb2'
|
1049 |
+
# @@protoc_insertion_point(class_scope:long_metric.Crosswalk)
|
1050 |
+
})
|
1051 |
+
_sym_db.RegisterMessage(Crosswalk)
|
1052 |
+
|
1053 |
+
SpeedBump = _reflection.GeneratedProtocolMessageType('SpeedBump', (_message.Message,), {
|
1054 |
+
'DESCRIPTOR' : _SPEEDBUMP,
|
1055 |
+
'__module__' : 'dev.metrics.protos.map_pb2'
|
1056 |
+
# @@protoc_insertion_point(class_scope:long_metric.SpeedBump)
|
1057 |
+
})
|
1058 |
+
_sym_db.RegisterMessage(SpeedBump)
|
1059 |
+
|
1060 |
+
Driveway = _reflection.GeneratedProtocolMessageType('Driveway', (_message.Message,), {
|
1061 |
+
'DESCRIPTOR' : _DRIVEWAY,
|
1062 |
+
'__module__' : 'dev.metrics.protos.map_pb2'
|
1063 |
+
# @@protoc_insertion_point(class_scope:long_metric.Driveway)
|
1064 |
+
})
|
1065 |
+
_sym_db.RegisterMessage(Driveway)
|
1066 |
+
|
1067 |
+
|
1068 |
+
_LANECENTER.fields_by_name['entry_lanes']._options = None
|
1069 |
+
_LANECENTER.fields_by_name['exit_lanes']._options = None
|
1070 |
+
# @@protoc_insertion_point(module_scope)
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/metrics/protos/scenario_pb2.py
ADDED
@@ -0,0 +1,454 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
3 |
+
# source: dev/metrics/protos/scenario.proto
|
4 |
+
|
5 |
+
from google.protobuf import descriptor as _descriptor
|
6 |
+
from google.protobuf import message as _message
|
7 |
+
from google.protobuf import reflection as _reflection
|
8 |
+
from google.protobuf import symbol_database as _symbol_database
|
9 |
+
# @@protoc_insertion_point(imports)
|
10 |
+
|
11 |
+
_sym_db = _symbol_database.Default()
|
12 |
+
|
13 |
+
|
14 |
+
from dev.metrics.protos import map_pb2 as dev_dot_metrics_dot_protos_dot_map__pb2
|
15 |
+
|
16 |
+
|
17 |
+
DESCRIPTOR = _descriptor.FileDescriptor(
|
18 |
+
name='dev/metrics/protos/scenario.proto',
|
19 |
+
package='long_metric',
|
20 |
+
syntax='proto2',
|
21 |
+
serialized_options=None,
|
22 |
+
create_key=_descriptor._internal_create_key,
|
23 |
+
serialized_pb=b'\n!dev/metrics/protos/scenario.proto\x12\x0blong_metric\x1a\x1c\x64\x65v/metrics/protos/map.proto\"\xba\x01\n\x0bObjectState\x12\x10\n\x08\x63\x65nter_x\x18\x02 \x01(\x01\x12\x10\n\x08\x63\x65nter_y\x18\x03 \x01(\x01\x12\x10\n\x08\x63\x65nter_z\x18\x04 \x01(\x01\x12\x0e\n\x06length\x18\x05 \x01(\x02\x12\r\n\x05width\x18\x06 \x01(\x02\x12\x0e\n\x06height\x18\x07 \x01(\x02\x12\x0f\n\x07heading\x18\x08 \x01(\x02\x12\x12\n\nvelocity_x\x18\t \x01(\x02\x12\x12\n\nvelocity_y\x18\n \x01(\x02\x12\r\n\x05valid\x18\x0b \x01(\x08\"\xd8\x01\n\x05Track\x12\n\n\x02id\x18\x01 \x01(\x05\x12\x32\n\x0bobject_type\x18\x02 \x01(\x0e\x32\x1d.long_metric.Track.ObjectType\x12(\n\x06states\x18\x03 \x03(\x0b\x32\x18.long_metric.ObjectState\"e\n\nObjectType\x12\x0e\n\nTYPE_UNSET\x10\x00\x12\x10\n\x0cTYPE_VEHICLE\x10\x01\x12\x13\n\x0fTYPE_PEDESTRIAN\x10\x02\x12\x10\n\x0cTYPE_CYCLIST\x10\x03\x12\x0e\n\nTYPE_OTHER\x10\x04\"K\n\x0f\x44ynamicMapState\x12\x38\n\x0blane_states\x18\x01 \x03(\x0b\x32#.long_metric.TrafficSignalLaneState\"\xa5\x01\n\x12RequiredPrediction\x12\x13\n\x0btrack_index\x18\x01 \x01(\x05\x12\x43\n\ndifficulty\x18\x02 \x01(\x0e\x32/.long_metric.RequiredPrediction.DifficultyLevel\"5\n\x0f\x44ifficultyLevel\x12\x08\n\x04NONE\x10\x00\x12\x0b\n\x07LEVEL_1\x10\x01\x12\x0b\n\x07LEVEL_2\x10\x02\"\xdc\x02\n\x08Scenario\x12\x13\n\x0bscenario_id\x18\x05 \x01(\t\x12\x1a\n\x12timestamps_seconds\x18\x01 \x03(\x01\x12\x1a\n\x12\x63urrent_time_index\x18\n \x01(\x05\x12\"\n\x06tracks\x18\x02 \x03(\x0b\x32\x12.long_metric.Track\x12\x38\n\x12\x64ynamic_map_states\x18\x07 \x03(\x0b\x32\x1c.long_metric.DynamicMapState\x12-\n\x0cmap_features\x18\x08 \x03(\x0b\x32\x17.long_metric.MapFeature\x12\x17\n\x0fsdc_track_index\x18\x06 \x01(\x05\x12\x1b\n\x13objects_of_interest\x18\x04 \x03(\x05\x12:\n\x11tracks_to_predict\x18\x0b \x03(\x0b\x32\x1f.long_metric.RequiredPredictionJ\x04\x08\t\x10\n'
|
24 |
+
,
|
25 |
+
dependencies=[dev_dot_metrics_dot_protos_dot_map__pb2.DESCRIPTOR,])
|
26 |
+
|
27 |
+
|
28 |
+
|
29 |
+
_TRACK_OBJECTTYPE = _descriptor.EnumDescriptor(
|
30 |
+
name='ObjectType',
|
31 |
+
full_name='long_metric.Track.ObjectType',
|
32 |
+
filename=None,
|
33 |
+
file=DESCRIPTOR,
|
34 |
+
create_key=_descriptor._internal_create_key,
|
35 |
+
values=[
|
36 |
+
_descriptor.EnumValueDescriptor(
|
37 |
+
name='TYPE_UNSET', index=0, number=0,
|
38 |
+
serialized_options=None,
|
39 |
+
type=None,
|
40 |
+
create_key=_descriptor._internal_create_key),
|
41 |
+
_descriptor.EnumValueDescriptor(
|
42 |
+
name='TYPE_VEHICLE', index=1, number=1,
|
43 |
+
serialized_options=None,
|
44 |
+
type=None,
|
45 |
+
create_key=_descriptor._internal_create_key),
|
46 |
+
_descriptor.EnumValueDescriptor(
|
47 |
+
name='TYPE_PEDESTRIAN', index=2, number=2,
|
48 |
+
serialized_options=None,
|
49 |
+
type=None,
|
50 |
+
create_key=_descriptor._internal_create_key),
|
51 |
+
_descriptor.EnumValueDescriptor(
|
52 |
+
name='TYPE_CYCLIST', index=3, number=3,
|
53 |
+
serialized_options=None,
|
54 |
+
type=None,
|
55 |
+
create_key=_descriptor._internal_create_key),
|
56 |
+
_descriptor.EnumValueDescriptor(
|
57 |
+
name='TYPE_OTHER', index=4, number=4,
|
58 |
+
serialized_options=None,
|
59 |
+
type=None,
|
60 |
+
create_key=_descriptor._internal_create_key),
|
61 |
+
],
|
62 |
+
containing_type=None,
|
63 |
+
serialized_options=None,
|
64 |
+
serialized_start=385,
|
65 |
+
serialized_end=486,
|
66 |
+
)
|
67 |
+
_sym_db.RegisterEnumDescriptor(_TRACK_OBJECTTYPE)
|
68 |
+
|
69 |
+
_REQUIREDPREDICTION_DIFFICULTYLEVEL = _descriptor.EnumDescriptor(
|
70 |
+
name='DifficultyLevel',
|
71 |
+
full_name='long_metric.RequiredPrediction.DifficultyLevel',
|
72 |
+
filename=None,
|
73 |
+
file=DESCRIPTOR,
|
74 |
+
create_key=_descriptor._internal_create_key,
|
75 |
+
values=[
|
76 |
+
_descriptor.EnumValueDescriptor(
|
77 |
+
name='NONE', index=0, number=0,
|
78 |
+
serialized_options=None,
|
79 |
+
type=None,
|
80 |
+
create_key=_descriptor._internal_create_key),
|
81 |
+
_descriptor.EnumValueDescriptor(
|
82 |
+
name='LEVEL_1', index=1, number=1,
|
83 |
+
serialized_options=None,
|
84 |
+
type=None,
|
85 |
+
create_key=_descriptor._internal_create_key),
|
86 |
+
_descriptor.EnumValueDescriptor(
|
87 |
+
name='LEVEL_2', index=2, number=2,
|
88 |
+
serialized_options=None,
|
89 |
+
type=None,
|
90 |
+
create_key=_descriptor._internal_create_key),
|
91 |
+
],
|
92 |
+
containing_type=None,
|
93 |
+
serialized_options=None,
|
94 |
+
serialized_start=678,
|
95 |
+
serialized_end=731,
|
96 |
+
)
|
97 |
+
_sym_db.RegisterEnumDescriptor(_REQUIREDPREDICTION_DIFFICULTYLEVEL)
|
98 |
+
|
99 |
+
|
100 |
+
_OBJECTSTATE = _descriptor.Descriptor(
|
101 |
+
name='ObjectState',
|
102 |
+
full_name='long_metric.ObjectState',
|
103 |
+
filename=None,
|
104 |
+
file=DESCRIPTOR,
|
105 |
+
containing_type=None,
|
106 |
+
create_key=_descriptor._internal_create_key,
|
107 |
+
fields=[
|
108 |
+
_descriptor.FieldDescriptor(
|
109 |
+
name='center_x', full_name='long_metric.ObjectState.center_x', index=0,
|
110 |
+
number=2, type=1, cpp_type=5, label=1,
|
111 |
+
has_default_value=False, default_value=float(0),
|
112 |
+
message_type=None, enum_type=None, containing_type=None,
|
113 |
+
is_extension=False, extension_scope=None,
|
114 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
115 |
+
_descriptor.FieldDescriptor(
|
116 |
+
name='center_y', full_name='long_metric.ObjectState.center_y', index=1,
|
117 |
+
number=3, type=1, cpp_type=5, label=1,
|
118 |
+
has_default_value=False, default_value=float(0),
|
119 |
+
message_type=None, enum_type=None, containing_type=None,
|
120 |
+
is_extension=False, extension_scope=None,
|
121 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
122 |
+
_descriptor.FieldDescriptor(
|
123 |
+
name='center_z', full_name='long_metric.ObjectState.center_z', index=2,
|
124 |
+
number=4, type=1, cpp_type=5, label=1,
|
125 |
+
has_default_value=False, default_value=float(0),
|
126 |
+
message_type=None, enum_type=None, containing_type=None,
|
127 |
+
is_extension=False, extension_scope=None,
|
128 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
129 |
+
_descriptor.FieldDescriptor(
|
130 |
+
name='length', full_name='long_metric.ObjectState.length', index=3,
|
131 |
+
number=5, type=2, cpp_type=6, label=1,
|
132 |
+
has_default_value=False, default_value=float(0),
|
133 |
+
message_type=None, enum_type=None, containing_type=None,
|
134 |
+
is_extension=False, extension_scope=None,
|
135 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
136 |
+
_descriptor.FieldDescriptor(
|
137 |
+
name='width', full_name='long_metric.ObjectState.width', index=4,
|
138 |
+
number=6, type=2, cpp_type=6, label=1,
|
139 |
+
has_default_value=False, default_value=float(0),
|
140 |
+
message_type=None, enum_type=None, containing_type=None,
|
141 |
+
is_extension=False, extension_scope=None,
|
142 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
143 |
+
_descriptor.FieldDescriptor(
|
144 |
+
name='height', full_name='long_metric.ObjectState.height', index=5,
|
145 |
+
number=7, type=2, cpp_type=6, label=1,
|
146 |
+
has_default_value=False, default_value=float(0),
|
147 |
+
message_type=None, enum_type=None, containing_type=None,
|
148 |
+
is_extension=False, extension_scope=None,
|
149 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
150 |
+
_descriptor.FieldDescriptor(
|
151 |
+
name='heading', full_name='long_metric.ObjectState.heading', index=6,
|
152 |
+
number=8, type=2, cpp_type=6, label=1,
|
153 |
+
has_default_value=False, default_value=float(0),
|
154 |
+
message_type=None, enum_type=None, containing_type=None,
|
155 |
+
is_extension=False, extension_scope=None,
|
156 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
157 |
+
_descriptor.FieldDescriptor(
|
158 |
+
name='velocity_x', full_name='long_metric.ObjectState.velocity_x', index=7,
|
159 |
+
number=9, type=2, cpp_type=6, label=1,
|
160 |
+
has_default_value=False, default_value=float(0),
|
161 |
+
message_type=None, enum_type=None, containing_type=None,
|
162 |
+
is_extension=False, extension_scope=None,
|
163 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
164 |
+
_descriptor.FieldDescriptor(
|
165 |
+
name='velocity_y', full_name='long_metric.ObjectState.velocity_y', index=8,
|
166 |
+
number=10, type=2, cpp_type=6, label=1,
|
167 |
+
has_default_value=False, default_value=float(0),
|
168 |
+
message_type=None, enum_type=None, containing_type=None,
|
169 |
+
is_extension=False, extension_scope=None,
|
170 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
171 |
+
_descriptor.FieldDescriptor(
|
172 |
+
name='valid', full_name='long_metric.ObjectState.valid', index=9,
|
173 |
+
number=11, type=8, cpp_type=7, label=1,
|
174 |
+
has_default_value=False, default_value=False,
|
175 |
+
message_type=None, enum_type=None, containing_type=None,
|
176 |
+
is_extension=False, extension_scope=None,
|
177 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
178 |
+
],
|
179 |
+
extensions=[
|
180 |
+
],
|
181 |
+
nested_types=[],
|
182 |
+
enum_types=[
|
183 |
+
],
|
184 |
+
serialized_options=None,
|
185 |
+
is_extendable=False,
|
186 |
+
syntax='proto2',
|
187 |
+
extension_ranges=[],
|
188 |
+
oneofs=[
|
189 |
+
],
|
190 |
+
serialized_start=81,
|
191 |
+
serialized_end=267,
|
192 |
+
)
|
193 |
+
|
194 |
+
|
195 |
+
_TRACK = _descriptor.Descriptor(
|
196 |
+
name='Track',
|
197 |
+
full_name='long_metric.Track',
|
198 |
+
filename=None,
|
199 |
+
file=DESCRIPTOR,
|
200 |
+
containing_type=None,
|
201 |
+
create_key=_descriptor._internal_create_key,
|
202 |
+
fields=[
|
203 |
+
_descriptor.FieldDescriptor(
|
204 |
+
name='id', full_name='long_metric.Track.id', index=0,
|
205 |
+
number=1, type=5, cpp_type=1, label=1,
|
206 |
+
has_default_value=False, default_value=0,
|
207 |
+
message_type=None, enum_type=None, containing_type=None,
|
208 |
+
is_extension=False, extension_scope=None,
|
209 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
210 |
+
_descriptor.FieldDescriptor(
|
211 |
+
name='object_type', full_name='long_metric.Track.object_type', index=1,
|
212 |
+
number=2, type=14, cpp_type=8, label=1,
|
213 |
+
has_default_value=False, default_value=0,
|
214 |
+
message_type=None, enum_type=None, containing_type=None,
|
215 |
+
is_extension=False, extension_scope=None,
|
216 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
217 |
+
_descriptor.FieldDescriptor(
|
218 |
+
name='states', full_name='long_metric.Track.states', index=2,
|
219 |
+
number=3, type=11, cpp_type=10, label=3,
|
220 |
+
has_default_value=False, default_value=[],
|
221 |
+
message_type=None, enum_type=None, containing_type=None,
|
222 |
+
is_extension=False, extension_scope=None,
|
223 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
224 |
+
],
|
225 |
+
extensions=[
|
226 |
+
],
|
227 |
+
nested_types=[],
|
228 |
+
enum_types=[
|
229 |
+
_TRACK_OBJECTTYPE,
|
230 |
+
],
|
231 |
+
serialized_options=None,
|
232 |
+
is_extendable=False,
|
233 |
+
syntax='proto2',
|
234 |
+
extension_ranges=[],
|
235 |
+
oneofs=[
|
236 |
+
],
|
237 |
+
serialized_start=270,
|
238 |
+
serialized_end=486,
|
239 |
+
)
|
240 |
+
|
241 |
+
|
242 |
+
_DYNAMICMAPSTATE = _descriptor.Descriptor(
|
243 |
+
name='DynamicMapState',
|
244 |
+
full_name='long_metric.DynamicMapState',
|
245 |
+
filename=None,
|
246 |
+
file=DESCRIPTOR,
|
247 |
+
containing_type=None,
|
248 |
+
create_key=_descriptor._internal_create_key,
|
249 |
+
fields=[
|
250 |
+
_descriptor.FieldDescriptor(
|
251 |
+
name='lane_states', full_name='long_metric.DynamicMapState.lane_states', index=0,
|
252 |
+
number=1, type=11, cpp_type=10, label=3,
|
253 |
+
has_default_value=False, default_value=[],
|
254 |
+
message_type=None, enum_type=None, containing_type=None,
|
255 |
+
is_extension=False, extension_scope=None,
|
256 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
257 |
+
],
|
258 |
+
extensions=[
|
259 |
+
],
|
260 |
+
nested_types=[],
|
261 |
+
enum_types=[
|
262 |
+
],
|
263 |
+
serialized_options=None,
|
264 |
+
is_extendable=False,
|
265 |
+
syntax='proto2',
|
266 |
+
extension_ranges=[],
|
267 |
+
oneofs=[
|
268 |
+
],
|
269 |
+
serialized_start=488,
|
270 |
+
serialized_end=563,
|
271 |
+
)
|
272 |
+
|
273 |
+
|
274 |
+
_REQUIREDPREDICTION = _descriptor.Descriptor(
|
275 |
+
name='RequiredPrediction',
|
276 |
+
full_name='long_metric.RequiredPrediction',
|
277 |
+
filename=None,
|
278 |
+
file=DESCRIPTOR,
|
279 |
+
containing_type=None,
|
280 |
+
create_key=_descriptor._internal_create_key,
|
281 |
+
fields=[
|
282 |
+
_descriptor.FieldDescriptor(
|
283 |
+
name='track_index', full_name='long_metric.RequiredPrediction.track_index', index=0,
|
284 |
+
number=1, type=5, cpp_type=1, label=1,
|
285 |
+
has_default_value=False, default_value=0,
|
286 |
+
message_type=None, enum_type=None, containing_type=None,
|
287 |
+
is_extension=False, extension_scope=None,
|
288 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
289 |
+
_descriptor.FieldDescriptor(
|
290 |
+
name='difficulty', full_name='long_metric.RequiredPrediction.difficulty', index=1,
|
291 |
+
number=2, type=14, cpp_type=8, label=1,
|
292 |
+
has_default_value=False, default_value=0,
|
293 |
+
message_type=None, enum_type=None, containing_type=None,
|
294 |
+
is_extension=False, extension_scope=None,
|
295 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
296 |
+
],
|
297 |
+
extensions=[
|
298 |
+
],
|
299 |
+
nested_types=[],
|
300 |
+
enum_types=[
|
301 |
+
_REQUIREDPREDICTION_DIFFICULTYLEVEL,
|
302 |
+
],
|
303 |
+
serialized_options=None,
|
304 |
+
is_extendable=False,
|
305 |
+
syntax='proto2',
|
306 |
+
extension_ranges=[],
|
307 |
+
oneofs=[
|
308 |
+
],
|
309 |
+
serialized_start=566,
|
310 |
+
serialized_end=731,
|
311 |
+
)
|
312 |
+
|
313 |
+
|
314 |
+
_SCENARIO = _descriptor.Descriptor(
|
315 |
+
name='Scenario',
|
316 |
+
full_name='long_metric.Scenario',
|
317 |
+
filename=None,
|
318 |
+
file=DESCRIPTOR,
|
319 |
+
containing_type=None,
|
320 |
+
create_key=_descriptor._internal_create_key,
|
321 |
+
fields=[
|
322 |
+
_descriptor.FieldDescriptor(
|
323 |
+
name='scenario_id', full_name='long_metric.Scenario.scenario_id', index=0,
|
324 |
+
number=5, type=9, cpp_type=9, label=1,
|
325 |
+
has_default_value=False, default_value=b"".decode('utf-8'),
|
326 |
+
message_type=None, enum_type=None, containing_type=None,
|
327 |
+
is_extension=False, extension_scope=None,
|
328 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
329 |
+
_descriptor.FieldDescriptor(
|
330 |
+
name='timestamps_seconds', full_name='long_metric.Scenario.timestamps_seconds', index=1,
|
331 |
+
number=1, type=1, cpp_type=5, label=3,
|
332 |
+
has_default_value=False, default_value=[],
|
333 |
+
message_type=None, enum_type=None, containing_type=None,
|
334 |
+
is_extension=False, extension_scope=None,
|
335 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
336 |
+
_descriptor.FieldDescriptor(
|
337 |
+
name='current_time_index', full_name='long_metric.Scenario.current_time_index', index=2,
|
338 |
+
number=10, type=5, cpp_type=1, label=1,
|
339 |
+
has_default_value=False, default_value=0,
|
340 |
+
message_type=None, enum_type=None, containing_type=None,
|
341 |
+
is_extension=False, extension_scope=None,
|
342 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
343 |
+
_descriptor.FieldDescriptor(
|
344 |
+
name='tracks', full_name='long_metric.Scenario.tracks', index=3,
|
345 |
+
number=2, type=11, cpp_type=10, label=3,
|
346 |
+
has_default_value=False, default_value=[],
|
347 |
+
message_type=None, enum_type=None, containing_type=None,
|
348 |
+
is_extension=False, extension_scope=None,
|
349 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
350 |
+
_descriptor.FieldDescriptor(
|
351 |
+
name='dynamic_map_states', full_name='long_metric.Scenario.dynamic_map_states', index=4,
|
352 |
+
number=7, type=11, cpp_type=10, label=3,
|
353 |
+
has_default_value=False, default_value=[],
|
354 |
+
message_type=None, enum_type=None, containing_type=None,
|
355 |
+
is_extension=False, extension_scope=None,
|
356 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
357 |
+
_descriptor.FieldDescriptor(
|
358 |
+
name='map_features', full_name='long_metric.Scenario.map_features', index=5,
|
359 |
+
number=8, type=11, cpp_type=10, label=3,
|
360 |
+
has_default_value=False, default_value=[],
|
361 |
+
message_type=None, enum_type=None, containing_type=None,
|
362 |
+
is_extension=False, extension_scope=None,
|
363 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
364 |
+
_descriptor.FieldDescriptor(
|
365 |
+
name='sdc_track_index', full_name='long_metric.Scenario.sdc_track_index', index=6,
|
366 |
+
number=6, type=5, cpp_type=1, label=1,
|
367 |
+
has_default_value=False, default_value=0,
|
368 |
+
message_type=None, enum_type=None, containing_type=None,
|
369 |
+
is_extension=False, extension_scope=None,
|
370 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
371 |
+
_descriptor.FieldDescriptor(
|
372 |
+
name='objects_of_interest', full_name='long_metric.Scenario.objects_of_interest', index=7,
|
373 |
+
number=4, type=5, cpp_type=1, label=3,
|
374 |
+
has_default_value=False, default_value=[],
|
375 |
+
message_type=None, enum_type=None, containing_type=None,
|
376 |
+
is_extension=False, extension_scope=None,
|
377 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
378 |
+
_descriptor.FieldDescriptor(
|
379 |
+
name='tracks_to_predict', full_name='long_metric.Scenario.tracks_to_predict', index=8,
|
380 |
+
number=11, type=11, cpp_type=10, label=3,
|
381 |
+
has_default_value=False, default_value=[],
|
382 |
+
message_type=None, enum_type=None, containing_type=None,
|
383 |
+
is_extension=False, extension_scope=None,
|
384 |
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
385 |
+
],
|
386 |
+
extensions=[
|
387 |
+
],
|
388 |
+
nested_types=[],
|
389 |
+
enum_types=[
|
390 |
+
],
|
391 |
+
serialized_options=None,
|
392 |
+
is_extendable=False,
|
393 |
+
syntax='proto2',
|
394 |
+
extension_ranges=[],
|
395 |
+
oneofs=[
|
396 |
+
],
|
397 |
+
serialized_start=734,
|
398 |
+
serialized_end=1082,
|
399 |
+
)
|
400 |
+
|
401 |
+
_TRACK.fields_by_name['object_type'].enum_type = _TRACK_OBJECTTYPE
|
402 |
+
_TRACK.fields_by_name['states'].message_type = _OBJECTSTATE
|
403 |
+
_TRACK_OBJECTTYPE.containing_type = _TRACK
|
404 |
+
_DYNAMICMAPSTATE.fields_by_name['lane_states'].message_type = dev_dot_metrics_dot_protos_dot_map__pb2._TRAFFICSIGNALLANESTATE
|
405 |
+
_REQUIREDPREDICTION.fields_by_name['difficulty'].enum_type = _REQUIREDPREDICTION_DIFFICULTYLEVEL
|
406 |
+
_REQUIREDPREDICTION_DIFFICULTYLEVEL.containing_type = _REQUIREDPREDICTION
|
407 |
+
_SCENARIO.fields_by_name['tracks'].message_type = _TRACK
|
408 |
+
_SCENARIO.fields_by_name['dynamic_map_states'].message_type = _DYNAMICMAPSTATE
|
409 |
+
_SCENARIO.fields_by_name['map_features'].message_type = dev_dot_metrics_dot_protos_dot_map__pb2._MAPFEATURE
|
410 |
+
_SCENARIO.fields_by_name['tracks_to_predict'].message_type = _REQUIREDPREDICTION
|
411 |
+
DESCRIPTOR.message_types_by_name['ObjectState'] = _OBJECTSTATE
|
412 |
+
DESCRIPTOR.message_types_by_name['Track'] = _TRACK
|
413 |
+
DESCRIPTOR.message_types_by_name['DynamicMapState'] = _DYNAMICMAPSTATE
|
414 |
+
DESCRIPTOR.message_types_by_name['RequiredPrediction'] = _REQUIREDPREDICTION
|
415 |
+
DESCRIPTOR.message_types_by_name['Scenario'] = _SCENARIO
|
416 |
+
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
|
417 |
+
|
418 |
+
ObjectState = _reflection.GeneratedProtocolMessageType('ObjectState', (_message.Message,), {
|
419 |
+
'DESCRIPTOR' : _OBJECTSTATE,
|
420 |
+
'__module__' : 'dev.metrics.protos.scenario_pb2'
|
421 |
+
# @@protoc_insertion_point(class_scope:long_metric.ObjectState)
|
422 |
+
})
|
423 |
+
_sym_db.RegisterMessage(ObjectState)
|
424 |
+
|
425 |
+
Track = _reflection.GeneratedProtocolMessageType('Track', (_message.Message,), {
|
426 |
+
'DESCRIPTOR' : _TRACK,
|
427 |
+
'__module__' : 'dev.metrics.protos.scenario_pb2'
|
428 |
+
# @@protoc_insertion_point(class_scope:long_metric.Track)
|
429 |
+
})
|
430 |
+
_sym_db.RegisterMessage(Track)
|
431 |
+
|
432 |
+
DynamicMapState = _reflection.GeneratedProtocolMessageType('DynamicMapState', (_message.Message,), {
|
433 |
+
'DESCRIPTOR' : _DYNAMICMAPSTATE,
|
434 |
+
'__module__' : 'dev.metrics.protos.scenario_pb2'
|
435 |
+
# @@protoc_insertion_point(class_scope:long_metric.DynamicMapState)
|
436 |
+
})
|
437 |
+
_sym_db.RegisterMessage(DynamicMapState)
|
438 |
+
|
439 |
+
RequiredPrediction = _reflection.GeneratedProtocolMessageType('RequiredPrediction', (_message.Message,), {
|
440 |
+
'DESCRIPTOR' : _REQUIREDPREDICTION,
|
441 |
+
'__module__' : 'dev.metrics.protos.scenario_pb2'
|
442 |
+
# @@protoc_insertion_point(class_scope:long_metric.RequiredPrediction)
|
443 |
+
})
|
444 |
+
_sym_db.RegisterMessage(RequiredPrediction)
|
445 |
+
|
446 |
+
Scenario = _reflection.GeneratedProtocolMessageType('Scenario', (_message.Message,), {
|
447 |
+
'DESCRIPTOR' : _SCENARIO,
|
448 |
+
'__module__' : 'dev.metrics.protos.scenario_pb2'
|
449 |
+
# @@protoc_insertion_point(class_scope:long_metric.Scenario)
|
450 |
+
})
|
451 |
+
_sym_db.RegisterMessage(Scenario)
|
452 |
+
|
453 |
+
|
454 |
+
# @@protoc_insertion_point(module_scope)
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/metrics/trajectory_features.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from torch import Tensor
|
4 |
+
from typing import Tuple
|
5 |
+
|
6 |
+
|
7 |
+
def _wrap_angle(angle: Tensor) -> Tensor:
|
8 |
+
return (angle + np.pi) % (2 * np.pi) - np.pi
|
9 |
+
|
10 |
+
|
11 |
+
def central_diff(t: Tensor, pad_value: float) -> Tensor:
|
12 |
+
pad_shape = (*t.shape[:-1], 1)
|
13 |
+
pad_tensor = torch.full(pad_shape, pad_value, dtype=t.dtype, device=t.device)
|
14 |
+
diff_t = (t[..., 2:] - t[..., :-2]) / 2
|
15 |
+
return torch.cat([pad_tensor, diff_t, pad_tensor], dim=-1)
|
16 |
+
|
17 |
+
|
18 |
+
def central_logical_and(t: Tensor, pad_value: bool) -> Tensor:
|
19 |
+
pad_shape = (*t.shape[:-1], 1)
|
20 |
+
pad_tensor = torch.full(pad_shape, pad_value, dtype=torch.bool, device=t.device)
|
21 |
+
diff_t = torch.logical_and(t[..., 2:], t[..., :-2])
|
22 |
+
return torch.cat([pad_tensor, diff_t, pad_tensor], dim=-1)
|
23 |
+
|
24 |
+
|
25 |
+
def compute_displacement_error(x, y, z, ref_x, ref_y, ref_z) -> Tensor:
|
26 |
+
return torch.norm(
|
27 |
+
torch.stack([x, y, z], dim=-1) - torch.stack([ref_x, ref_y, ref_z], dim=-1),
|
28 |
+
p=2, dim=-1
|
29 |
+
)
|
30 |
+
|
31 |
+
|
32 |
+
def compute_kinematic_features(
|
33 |
+
x: Tensor,
|
34 |
+
y: Tensor,
|
35 |
+
z: Tensor,
|
36 |
+
heading: Tensor,
|
37 |
+
seconds_per_step: float
|
38 |
+
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
|
39 |
+
dpos = central_diff(torch.stack([x, y, z], dim=0), pad_value=np.nan)
|
40 |
+
linear_speed = torch.norm(dpos, p=2, dim=0) / seconds_per_step
|
41 |
+
linear_accel = central_diff(linear_speed, pad_value=np.nan) / seconds_per_step
|
42 |
+
dh_step = _wrap_angle(central_diff(heading, pad_value=np.nan) * 2) / 2
|
43 |
+
dh = dh_step / seconds_per_step
|
44 |
+
d2h_step = _wrap_angle(central_diff(dh_step, pad_value=np.nan) * 2) / 2
|
45 |
+
d2h = d2h_step / (seconds_per_step ** 2)
|
46 |
+
return linear_speed, linear_accel, dh, d2h
|
47 |
+
|
48 |
+
|
49 |
+
def compute_kinematic_validity(valid: Tensor) -> Tuple[Tensor, Tensor]:
|
50 |
+
speed_validity = central_logical_and(valid, pad_value=False)
|
51 |
+
acceleration_validity = central_logical_and(speed_validity, pad_value=False)
|
52 |
+
return speed_validity, acceleration_validity
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/metrics/val_close_long_metrics.json
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"val_close_long/wosac/realism_meta_metric": 0.6187323331832886,
|
3 |
+
"val_close_long/wosac/kinematic_metrics": 0.6323384046554565,
|
4 |
+
"val_close_long/wosac/interactive_metrics": 0.5528579354286194,
|
5 |
+
"val_close_long/wosac/map_based_metrics": 0.0,
|
6 |
+
"val_close_long/wosac/placement_based_metrics": 0.6086956858634949,
|
7 |
+
"val_close_long/wosac/min_ade": 0.0,
|
8 |
+
"val_close_long/wosac/scenario_counter": 61,
|
9 |
+
"val_close_long/wosac_likelihood/metametric": 0.6187323331832886,
|
10 |
+
"val_close_long/wosac_likelihood/average_displacement_error": 0.0,
|
11 |
+
"val_close_long/wosac_likelihood/min_average_displacement_error": 0.0,
|
12 |
+
"val_close_long/wosac_likelihood/linear_speed_likelihood": 0.11858943104743958,
|
13 |
+
"val_close_long/wosac_likelihood/linear_acceleration_likelihood": 0.6093839406967163,
|
14 |
+
"val_close_long/wosac_likelihood/angular_speed_likelihood": 0.8988037705421448,
|
15 |
+
"val_close_long/wosac_likelihood/angular_acceleration_likelihood": 0.9025763869285583,
|
16 |
+
"val_close_long/wosac_likelihood/distance_to_nearest_object_likelihood": 0.10390616208314896,
|
17 |
+
"val_close_long/wosac_likelihood/collision_indication_likelihood": 0.6108496785163879,
|
18 |
+
"val_close_long/wosac_likelihood/time_to_collision_likelihood": 0.8568302989006042,
|
19 |
+
"val_close_long/wosac_likelihood/simulated_collision_rate": 0.030917881056666374,
|
20 |
+
"val_close_long/wosac_likelihood/num_placement_likelihood": 0.7245867848396301,
|
21 |
+
"val_close_long/wosac_likelihood/num_removement_likelihood": 0.6228984594345093,
|
22 |
+
"val_close_long/wosac_likelihood/distance_placement_likelihood": 1.0,
|
23 |
+
"val_close_long/wosac_likelihood/distance_removement_likelihood": 0.08729743212461472
|
24 |
+
}
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/model/smart.py
ADDED
@@ -0,0 +1,1100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import contextlib
|
3 |
+
import pytorch_lightning as pl
|
4 |
+
import math
|
5 |
+
import numpy as np
|
6 |
+
import pickle
|
7 |
+
import random
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from tqdm import tqdm
|
11 |
+
from torch_geometric.data import Batch
|
12 |
+
from torch_geometric.data import HeteroData
|
13 |
+
from torch.optim.lr_scheduler import LambdaLR
|
14 |
+
from collections import defaultdict
|
15 |
+
|
16 |
+
from dev.utils.func import angle_between_2d_vectors
|
17 |
+
from dev.modules.layers import OccLoss
|
18 |
+
from dev.modules.attr_tokenizer import Attr_Tokenizer
|
19 |
+
from dev.modules.smart_decoder import SMARTDecoder
|
20 |
+
from dev.datasets.preprocess import TokenProcessor
|
21 |
+
from dev.metrics.compute_metrics import *
|
22 |
+
from dev.utils.metrics import *
|
23 |
+
from dev.utils.visualization import *
|
24 |
+
|
25 |
+
|
26 |
+
class SMART(pl.LightningModule):
|
27 |
+
|
28 |
+
def __init__(self, model_config, save_path: os.PathLike="", logger=None, **kwargs) -> None:
|
29 |
+
super(SMART, self).__init__()
|
30 |
+
self.save_hyperparameters()
|
31 |
+
self.model_config = model_config
|
32 |
+
self.warmup_steps = model_config.warmup_steps
|
33 |
+
self.lr = model_config.lr
|
34 |
+
self.total_steps = model_config.total_steps
|
35 |
+
self.dataset = model_config.dataset
|
36 |
+
self.input_dim = model_config.input_dim
|
37 |
+
self.hidden_dim = model_config.hidden_dim
|
38 |
+
self.output_dim = model_config.output_dim
|
39 |
+
self.output_head = model_config.output_head
|
40 |
+
self.num_historical_steps = model_config.num_historical_steps
|
41 |
+
self.num_future_steps = model_config.decoder.num_future_steps
|
42 |
+
self.num_freq_bands = model_config.num_freq_bands
|
43 |
+
self.save_path = save_path
|
44 |
+
self.vis_map = False
|
45 |
+
self.noise = True
|
46 |
+
self.local_logger = logger
|
47 |
+
self.max_epochs = kwargs.get('max_epochs', 0)
|
48 |
+
module_dir = os.path.dirname(os.path.dirname(__file__))
|
49 |
+
|
50 |
+
self.map_token_traj_path = os.path.join(module_dir, 'tokens/map_traj_token5.pkl')
|
51 |
+
self.init_map_token()
|
52 |
+
|
53 |
+
self.predict_motion = model_config.predict_motion
|
54 |
+
self.predict_state = model_config.predict_state
|
55 |
+
self.predict_map = model_config.predict_map
|
56 |
+
self.predict_occ = model_config.predict_occ
|
57 |
+
self.pl2seed_radius = model_config.decoder.pl2seed_radius
|
58 |
+
self.token_size = model_config.decoder.token_size
|
59 |
+
|
60 |
+
# if `disable_grid_token` is True, then we process all locations as
|
61 |
+
# the continuous values. Besides, no occupancy grid input.
|
62 |
+
# Also, no need to predict the xy offset.
|
63 |
+
self.disable_grid_token = getattr(model_config, 'disable_grid_token') \
|
64 |
+
if hasattr(model_config, 'disable_grid_token') else False
|
65 |
+
self.use_grid_token = not self.disable_grid_token
|
66 |
+
if self.disable_grid_token:
|
67 |
+
self.predict_occ = False
|
68 |
+
|
69 |
+
self.token_processer = TokenProcessor(self.token_size,
|
70 |
+
training=self.training,
|
71 |
+
predict_motion=self.predict_motion,
|
72 |
+
predict_state=self.predict_state,
|
73 |
+
predict_map=self.predict_map,
|
74 |
+
state_token=model_config.state_token,
|
75 |
+
pl2seed_radius=self.pl2seed_radius)
|
76 |
+
|
77 |
+
self.attr_tokenizer = Attr_Tokenizer(grid_range=self.model_config.grid_range,
|
78 |
+
grid_interval=self.model_config.grid_interval,
|
79 |
+
radius=model_config.decoder.pl2seed_radius,
|
80 |
+
angle_interval=self.model_config.angle_interval)
|
81 |
+
|
82 |
+
# state tokens
|
83 |
+
self.invalid_state = int(self.model_config.state_token['invalid'])
|
84 |
+
self.valid_state = int(self.model_config.state_token['valid'])
|
85 |
+
self.enter_state = int(self.model_config.state_token['enter'])
|
86 |
+
self.exit_state = int(self.model_config.state_token['exit'])
|
87 |
+
|
88 |
+
self.seed_size = int(model_config.decoder.seed_size)
|
89 |
+
|
90 |
+
self.encoder = SMARTDecoder(
|
91 |
+
decoder_type=model_config.decoder_type,
|
92 |
+
dataset=model_config.dataset,
|
93 |
+
input_dim=model_config.input_dim,
|
94 |
+
hidden_dim=model_config.hidden_dim,
|
95 |
+
num_historical_steps=model_config.num_historical_steps,
|
96 |
+
num_freq_bands=model_config.num_freq_bands,
|
97 |
+
num_heads=model_config.num_heads,
|
98 |
+
head_dim=model_config.head_dim,
|
99 |
+
dropout=model_config.dropout,
|
100 |
+
num_map_layers=model_config.decoder.num_map_layers,
|
101 |
+
num_agent_layers=model_config.decoder.num_agent_layers,
|
102 |
+
pl2pl_radius=model_config.decoder.pl2pl_radius,
|
103 |
+
pl2a_radius=model_config.decoder.pl2a_radius,
|
104 |
+
pl2seed_radius=model_config.decoder.pl2seed_radius,
|
105 |
+
a2a_radius=model_config.decoder.a2a_radius,
|
106 |
+
a2sa_radius=model_config.decoder.a2sa_radius,
|
107 |
+
pl2sa_radius=model_config.decoder.pl2sa_radius,
|
108 |
+
time_span=model_config.decoder.time_span,
|
109 |
+
map_token={'traj_src': self.map_token['traj_src']},
|
110 |
+
token_size=self.token_size,
|
111 |
+
attr_tokenizer=self.attr_tokenizer,
|
112 |
+
predict_motion=self.predict_motion,
|
113 |
+
predict_state=self.predict_state,
|
114 |
+
predict_map=self.predict_map,
|
115 |
+
predict_occ=self.predict_occ,
|
116 |
+
state_token=model_config.state_token,
|
117 |
+
use_grid_token=self.use_grid_token,
|
118 |
+
seed_size=self.seed_size,
|
119 |
+
buffer_size=model_config.decoder.buffer_size,
|
120 |
+
num_recurrent_steps_val=model_config.num_recurrent_steps_val,
|
121 |
+
loss_weight=model_config.loss_weight,
|
122 |
+
logger=logger,
|
123 |
+
)
|
124 |
+
self.minADE = minADE(max_guesses=1)
|
125 |
+
self.minFDE = minFDE(max_guesses=1)
|
126 |
+
self.TokenCls = TokenCls(max_guesses=1)
|
127 |
+
self.StateCls = TokenCls(max_guesses=1)
|
128 |
+
self.StateAccuracy = StateAccuracy(state_token=self.model_config.state_token)
|
129 |
+
self.GridOverlapRate = GridOverlapRate(num_step=18,
|
130 |
+
state_token=self.model_config.state_token,
|
131 |
+
seed_size=self.encoder.agent_encoder.num_seed_feature)
|
132 |
+
# self.NumInsertAccuracy = NumInsertAccuracy()
|
133 |
+
self.loss_weight = model_config.loss_weight
|
134 |
+
|
135 |
+
self.token_cls_loss = nn.CrossEntropyLoss(label_smoothing=0.1)
|
136 |
+
if self.predict_map:
|
137 |
+
self.map_token_loss = nn.CrossEntropyLoss(label_smoothing=0.1)
|
138 |
+
if self.predict_state:
|
139 |
+
self.state_cls_loss = nn.CrossEntropyLoss(
|
140 |
+
torch.tensor(self.loss_weight['state_weight']))
|
141 |
+
self.state_cls_loss_seed = nn.CrossEntropyLoss(
|
142 |
+
torch.tensor(self.loss_weight['seed_state_weight']))
|
143 |
+
self.type_cls_loss_seed = nn.CrossEntropyLoss(
|
144 |
+
torch.tensor(self.loss_weight['seed_type_weight']))
|
145 |
+
self.pos_cls_loss_seed = nn.CrossEntropyLoss(label_smoothing=0.1)
|
146 |
+
self.head_cls_loss_seed = nn.CrossEntropyLoss()
|
147 |
+
self.offset_reg_loss_seed = nn.MSELoss()
|
148 |
+
self.shape_reg_loss_seed = nn.MSELoss()
|
149 |
+
self.pos_reg_loss_seed = nn.MSELoss()
|
150 |
+
if self.predict_occ:
|
151 |
+
self.occ_cls_loss = nn.CrossEntropyLoss()
|
152 |
+
self.agent_occ_loss_seed = nn.BCEWithLogitsLoss(
|
153 |
+
pos_weight=torch.tensor([self.loss_weight['agent_occ_pos_weight']]))
|
154 |
+
self.pt_occ_loss_seed = nn.BCEWithLogitsLoss(
|
155 |
+
pos_weight=torch.tensor([self.loss_weight['pt_occ_pos_weight']]))
|
156 |
+
# self.agent_occ_loss_seed = OccLoss()
|
157 |
+
# self.pt_occ_loss_seed = OccLoss()
|
158 |
+
# self.agent_occ_loss_seed = nn.BCEWithLogitsLoss()
|
159 |
+
# self.pt_occ_loss_seed = nn.BCEWithLogitsLoss()
|
160 |
+
self.rollout_num = 1
|
161 |
+
|
162 |
+
self.val_open_loop = model_config.val_open_loop
|
163 |
+
self.val_close_loop = model_config.val_close_loop
|
164 |
+
self.val_insert = model_config.val_insert or bool(os.getenv('VAL_INSERT'))
|
165 |
+
self.n_rollout_close_val = model_config.n_rollout_close_val
|
166 |
+
self.t = kwargs.get('t', 2)
|
167 |
+
|
168 |
+
# for validation / test
|
169 |
+
self._mode = 'training'
|
170 |
+
self._long_metrics = None
|
171 |
+
self._online_metric = False
|
172 |
+
self._save_validate_reuslts = False
|
173 |
+
self._plot_rollouts = False
|
174 |
+
|
175 |
+
def set(self, mode: str = 'train'):
|
176 |
+
self._mode = mode
|
177 |
+
|
178 |
+
if mode == 'validation':
|
179 |
+
self._online_metric = True
|
180 |
+
self._save_validate_reuslts = True
|
181 |
+
self._long_metrics = LongMetric('val_close_long')
|
182 |
+
|
183 |
+
elif mode == 'test':
|
184 |
+
self._save_validate_reuslts = True
|
185 |
+
|
186 |
+
elif mode == 'plot_rollouts':
|
187 |
+
self._plot_rollouts = True
|
188 |
+
|
189 |
+
def init_map_token(self):
|
190 |
+
self.argmin_sample_len = 3
|
191 |
+
map_token_traj = pickle.load(open(self.map_token_traj_path, 'rb'))
|
192 |
+
self.map_token = {'traj_src': map_token_traj['traj_src'], }
|
193 |
+
traj_end_theta = np.arctan2(self.map_token['traj_src'][:, -1, 1]-self.map_token['traj_src'][:, -2, 1],
|
194 |
+
self.map_token['traj_src'][:, -1, 0]-self.map_token['traj_src'][:, -2, 0])
|
195 |
+
indices = torch.linspace(0, self.map_token['traj_src'].shape[1]-1, steps=self.argmin_sample_len).long()
|
196 |
+
self.map_token['sample_pt'] = torch.from_numpy(self.map_token['traj_src'][:, indices]).to(torch.float)
|
197 |
+
self.map_token['traj_end_theta'] = torch.from_numpy(traj_end_theta).to(torch.float)
|
198 |
+
self.map_token['traj_src'] = torch.from_numpy(self.map_token['traj_src']).to(torch.float)
|
199 |
+
|
200 |
+
def get_agent_inputs(self, data: HeteroData):
|
201 |
+
res = self.encoder.get_agent_inputs(data)
|
202 |
+
return res
|
203 |
+
|
204 |
+
def forward(self, data: HeteroData):
|
205 |
+
res = self.encoder(data)
|
206 |
+
return res
|
207 |
+
|
208 |
+
def maybe_autocast(self, dtype=torch.float16):
|
209 |
+
enable_autocast = self.device != torch.device("cpu")
|
210 |
+
|
211 |
+
if enable_autocast:
|
212 |
+
return torch.cuda.amp.autocast(dtype=dtype)
|
213 |
+
else:
|
214 |
+
return contextlib.nullcontext()
|
215 |
+
|
216 |
+
def check_inputs(self, data: HeteroData):
|
217 |
+
inputs = self.get_agent_inputs(data)
|
218 |
+
next_token_idx_gt = inputs['next_token_idx_gt']
|
219 |
+
next_state_idx_gt = inputs['next_state_idx_gt'].clone()
|
220 |
+
next_token_eval_mask = inputs['next_token_eval_mask'].clone()
|
221 |
+
raw_agent_valid_mask = inputs['raw_agent_valid_mask'].clone()
|
222 |
+
|
223 |
+
self.StateAccuracy.update(state_idx=next_state_idx_gt,
|
224 |
+
valid_mask=raw_agent_valid_mask)
|
225 |
+
|
226 |
+
state_token = inputs['state_token']
|
227 |
+
grid_index = inputs['grid_index']
|
228 |
+
self.GridOverlapRate.update(state_token=state_token,
|
229 |
+
grid_index=grid_index)
|
230 |
+
|
231 |
+
print(self.StateAccuracy)
|
232 |
+
print(self.GridOverlapRate)
|
233 |
+
# self.log('valid_accuracy', self.StateAccuracy.compute()['valid'], prog_bar=True, on_step=True, on_epoch=True, batch_size=1)
|
234 |
+
# self.log('invalid_accuracy', self.StateAccuracy.compute()['invalid'], prog_bar=True, on_step=True, on_epoch=True, batch_size=1)
|
235 |
+
|
236 |
+
def training_step(self,
|
237 |
+
data,
|
238 |
+
batch_idx):
|
239 |
+
|
240 |
+
data = self.token_processer(data)
|
241 |
+
|
242 |
+
data = self.match_token_map(data)
|
243 |
+
data = self.sample_pt_pred(data)
|
244 |
+
|
245 |
+
# find map tokens for entering agents
|
246 |
+
data = self._fetch_enterings(data)
|
247 |
+
|
248 |
+
data['batch_size_a'] = data['agent']['ptr'][1:] - data['agent']['ptr'][:-1]
|
249 |
+
data['batch_size_pl'] = data['pt_token']['ptr'][1:] - data['pt_token']['ptr'][:-1]
|
250 |
+
if isinstance(data, Batch):
|
251 |
+
data['agent']['av_index'] += data['agent']['ptr'][:-1]
|
252 |
+
|
253 |
+
if int(os.getenv("CHECK_INPUTS", 0)):
|
254 |
+
return self.check_inputs(data)
|
255 |
+
|
256 |
+
pred = self(data)
|
257 |
+
|
258 |
+
loss = 0
|
259 |
+
|
260 |
+
log_params = dict(prog_bar=True, on_step=True, on_epoch=True, batch_size=1, sync_dist=True)
|
261 |
+
|
262 |
+
if pred.get('occ_decoder', False):
|
263 |
+
|
264 |
+
agent_occ = pred['agent_occ']
|
265 |
+
agent_occ_gt = pred['agent_occ_gt']
|
266 |
+
agent_occ_eval_mask = pred['agent_occ_eval_mask']
|
267 |
+
pt_occ = pred['pt_occ']
|
268 |
+
pt_occ_gt = pred['pt_occ_gt']
|
269 |
+
pt_occ_eval_mask = pred['pt_occ_eval_mask']
|
270 |
+
|
271 |
+
agent_occ_cls_loss = self.occ_cls_loss(agent_occ[agent_occ_eval_mask],
|
272 |
+
agent_occ_gt[agent_occ_eval_mask])
|
273 |
+
pt_occ_cls_loss = self.occ_cls_loss(pt_occ[pt_occ_eval_mask],
|
274 |
+
pt_occ_gt[pt_occ_eval_mask])
|
275 |
+
self.log('agent_occ_cls_loss', agent_occ_cls_loss, **log_params)
|
276 |
+
self.log('pt_occ_cls_loss', pt_occ_cls_loss, **log_params)
|
277 |
+
loss = loss + agent_occ_cls_loss + pt_occ_cls_loss
|
278 |
+
|
279 |
+
# plot
|
280 |
+
# plot_scenario_ids = ['74ad7b76d5906d39', '1351ea8b8333ddcb', '1352066cc3c0508d', '135436833ce5b9e7', '13570a32432d449', '13577c32a81336fb']
|
281 |
+
if random.random() < 4e-5 or os.getenv('DEBUG'):
|
282 |
+
num_step = pred['num_step']
|
283 |
+
num_agent = pred['num_agent']
|
284 |
+
num_pt = pred['num_pt']
|
285 |
+
with torch.no_grad():
|
286 |
+
agent_occ = agent_occ.reshape(num_step, num_agent, -1).transpose(0, 1)
|
287 |
+
agent_occ_gt = agent_occ_gt.reshape(num_step, num_agent).transpose(0, 1)
|
288 |
+
agent_occ_gt[agent_occ_gt == -1] = self.encoder.agent_encoder.grid_size // 2
|
289 |
+
agent_occ_gt = torch.nn.functional.one_hot(agent_occ_gt, num_classes=self.encoder.agent_encoder.grid_size)
|
290 |
+
agent_occ = self.attr_tokenizer.pad_square(agent_occ.softmax(-1).detach().cpu().numpy())[0]
|
291 |
+
agent_occ_gt = self.attr_tokenizer.pad_square(agent_occ_gt.detach().cpu().numpy())[0]
|
292 |
+
plot_occ_grid(pred['scenario_id'][0],
|
293 |
+
agent_occ,
|
294 |
+
gt_occ=agent_occ_gt,
|
295 |
+
mode='agent',
|
296 |
+
save_path=self.save_path,
|
297 |
+
prefix=f'training_{self.global_step:06d}_')
|
298 |
+
pt_occ = pt_occ.reshape(num_step, num_pt, -1).transpose(0, 1)
|
299 |
+
pt_occ_gt = pt_occ_gt.reshape(num_step, num_pt).transpose(0, 1)
|
300 |
+
pt_occ_gt[pt_occ_gt == -1] = self.encoder.agent_encoder.grid_size // 2
|
301 |
+
pt_occ_gt = torch.nn.functional.one_hot(pt_occ_gt, num_classes=self.encoder.agent_encoder.grid_size)
|
302 |
+
pt_occ = self.attr_tokenizer.pad_square(pt_occ.sigmoid().detach().cpu().numpy())[0]
|
303 |
+
pt_occ_gt = self.attr_tokenizer.pad_square(pt_occ_gt.detach().cpu().numpy())[0]
|
304 |
+
plot_occ_grid(pred['scenario_id'][0],
|
305 |
+
pt_occ,
|
306 |
+
gt_occ=pt_occ_gt,
|
307 |
+
mode='pt',
|
308 |
+
save_path=self.save_path,
|
309 |
+
prefix=f'training_{self.global_step:06d}_')
|
310 |
+
|
311 |
+
return loss
|
312 |
+
|
313 |
+
train_mask = data['agent']['train_mask']
|
314 |
+
# remove_ina_mask = data['agent']['remove_ina_mask']
|
315 |
+
|
316 |
+
# motion token loss
|
317 |
+
if self.predict_motion:
|
318 |
+
|
319 |
+
next_token_idx = pred['next_token_idx']
|
320 |
+
next_token_prob = pred['next_token_prob'] # (a, t, token_size)
|
321 |
+
next_token_idx_gt = pred['next_token_idx_gt'] # (a, t)
|
322 |
+
next_token_eval_mask = pred['next_token_eval_mask'] # (a, t)
|
323 |
+
next_token_eval_mask &= train_mask[:, None]
|
324 |
+
|
325 |
+
token_cls_loss = self.token_cls_loss(next_token_prob[next_token_eval_mask],
|
326 |
+
next_token_idx_gt[next_token_eval_mask]) * self.loss_weight['token_cls_loss']
|
327 |
+
self.log('token_cls_loss', token_cls_loss, **log_params)
|
328 |
+
|
329 |
+
loss = loss + token_cls_loss
|
330 |
+
|
331 |
+
# record motion predict precision of certain timesteps of centain type of agents
|
332 |
+
with torch.no_grad():
|
333 |
+
agent_state_idx_gt = data['agent']['state_idx']
|
334 |
+
index = torch.nonzero(agent_state_idx_gt == self.enter_state)
|
335 |
+
for i in range(10):
|
336 |
+
index[:, 1] += 1
|
337 |
+
index = index[index[:, 1] < agent_state_idx_gt.shape[1] - 1]
|
338 |
+
prob = next_token_prob[index[:, 0], index[:, 1]]
|
339 |
+
gt = next_token_idx_gt[index[:, 0], index[:, 1]]
|
340 |
+
mask = next_token_eval_mask[index[:, 0], index[:, 1]]
|
341 |
+
step_token_cls_loss = self.token_cls_loss(prob[mask], gt[mask])
|
342 |
+
self.log(f's{i}', step_token_cls_loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=1, sync_dist=True)
|
343 |
+
|
344 |
+
# state token loss
|
345 |
+
if self.predict_state:
|
346 |
+
|
347 |
+
next_state_idx = pred['next_state_idx']
|
348 |
+
next_state_prob = pred['next_state_prob']
|
349 |
+
next_state_idx_gt = pred['next_state_idx_gt']
|
350 |
+
next_state_eval_mask = pred['next_state_eval_mask'] # (num_agent, num_timestep)
|
351 |
+
|
352 |
+
state_cls_loss = self.state_cls_loss(next_state_prob[next_state_eval_mask],
|
353 |
+
next_state_idx_gt[next_state_eval_mask]) * self.loss_weight['state_cls_loss']
|
354 |
+
if torch.isnan(state_cls_loss):
|
355 |
+
print("Found NaN in state_cls_loss!!!")
|
356 |
+
print(next_state_prob.shape)
|
357 |
+
print(next_state_idx_gt.shape)
|
358 |
+
print(next_state_eval_mask.shape)
|
359 |
+
print(next_state_idx_gt[next_state_eval_mask].shape)
|
360 |
+
self.log('state_cls_loss', state_cls_loss, **log_params)
|
361 |
+
|
362 |
+
loss = loss + state_cls_loss
|
363 |
+
|
364 |
+
next_state_idx_seed = pred['next_state_idx_seed']
|
365 |
+
next_state_prob_seed = pred['next_state_prob_seed']
|
366 |
+
next_state_idx_gt_seed = pred['next_state_idx_gt_seed']
|
367 |
+
next_type_prob_seed = pred['next_type_prob_seed']
|
368 |
+
next_type_idx_gt_seed = pred['next_type_idx_gt_seed']
|
369 |
+
next_shape_seed = pred['next_shape_seed']
|
370 |
+
next_shape_gt_seed = pred['next_shape_gt_seed']
|
371 |
+
next_state_eval_mask_seed = pred['next_state_eval_mask_seed']
|
372 |
+
next_attr_eval_mask_seed = pred['next_attr_eval_mask_seed']
|
373 |
+
|
374 |
+
# when num_seed_gt=0 loss term will be NaN
|
375 |
+
state_cls_loss_seed = self.state_cls_loss_seed(next_state_prob_seed[next_state_eval_mask_seed],
|
376 |
+
next_state_idx_gt_seed[next_state_eval_mask_seed]) * self.loss_weight['state_cls_loss']
|
377 |
+
state_cls_loss_seed = torch.nan_to_num(state_cls_loss_seed)
|
378 |
+
self.log('seed_state_cls_loss', state_cls_loss_seed, **log_params)
|
379 |
+
|
380 |
+
type_cls_loss_seed = self.type_cls_loss_seed(next_type_prob_seed[next_attr_eval_mask_seed],
|
381 |
+
next_type_idx_gt_seed[next_attr_eval_mask_seed]) * self.loss_weight['type_cls_loss']
|
382 |
+
shape_reg_loss_seed = self.shape_reg_loss_seed(next_shape_seed[next_attr_eval_mask_seed],
|
383 |
+
next_shape_gt_seed[next_attr_eval_mask_seed]) * self.loss_weight['shape_reg_loss']
|
384 |
+
type_cls_loss_seed = torch.nan_to_num(type_cls_loss_seed)
|
385 |
+
shape_reg_loss_seed = torch.nan_to_num(shape_reg_loss_seed)
|
386 |
+
self.log('seed_type_cls_loss', type_cls_loss_seed, **log_params)
|
387 |
+
self.log('seed_shape_reg_loss', shape_reg_loss_seed, **log_params)
|
388 |
+
|
389 |
+
loss = loss + state_cls_loss_seed + type_cls_loss_seed + shape_reg_loss_seed
|
390 |
+
|
391 |
+
next_head_rel_prob_seed = pred['next_head_rel_prob_seed']
|
392 |
+
next_head_rel_index_gt_seed = pred['next_head_rel_index_gt_seed']
|
393 |
+
next_offset_xy_seed = pred['next_offset_xy_seed']
|
394 |
+
next_offset_xy_gt_seed = pred['next_offset_xy_gt_seed']
|
395 |
+
next_head_eval_mask_seed = pred['next_head_eval_mask_seed']
|
396 |
+
|
397 |
+
if self.use_grid_token:
|
398 |
+
next_pos_rel_prob_seed = pred['next_pos_rel_prob_seed']
|
399 |
+
next_pos_rel_index_gt_seed = pred['next_pos_rel_index_gt_seed']
|
400 |
+
|
401 |
+
pos_cls_loss_seed = self.pos_cls_loss_seed(next_pos_rel_prob_seed[next_attr_eval_mask_seed],
|
402 |
+
next_pos_rel_index_gt_seed[next_attr_eval_mask_seed]) * self.loss_weight['pos_cls_loss']
|
403 |
+
offset_reg_loss_seed = self.offset_reg_loss_seed(next_offset_xy_seed[next_head_eval_mask_seed],
|
404 |
+
next_offset_xy_gt_seed[next_head_eval_mask_seed]) * self.loss_weight['offset_reg_loss']
|
405 |
+
pos_cls_loss_seed = torch.nan_to_num(pos_cls_loss_seed)
|
406 |
+
self.log('seed_pos_cls_loss', pos_cls_loss_seed, **log_params)
|
407 |
+
self.log('seed_offset_reg_loss', offset_reg_loss_seed, **log_params)
|
408 |
+
|
409 |
+
loss = loss + pos_cls_loss_seed + offset_reg_loss_seed
|
410 |
+
|
411 |
+
else:
|
412 |
+
next_pos_rel_xy_seed = pred['next_pos_rel_xy_seed']
|
413 |
+
next_pos_rel_xy_gt_seed = pred['next_pos_rel_xy_gt_seed']
|
414 |
+
pos_reg_loss_seed = self.pos_reg_loss_seed(next_pos_rel_xy_seed[next_attr_eval_mask_seed],
|
415 |
+
next_pos_rel_xy_gt_seed[next_attr_eval_mask_seed]) * self.loss_weight['pos_reg_loss']
|
416 |
+
pos_reg_loss_seed = torch.nan_to_num(pos_reg_loss_seed)
|
417 |
+
self.log('seed_pos_reg_loss', pos_reg_loss_seed, **log_params)
|
418 |
+
loss = loss + pos_reg_loss_seed
|
419 |
+
|
420 |
+
head_cls_loss_seed = self.head_cls_loss_seed(next_head_rel_prob_seed[next_head_eval_mask_seed],
|
421 |
+
next_head_rel_index_gt_seed[next_head_eval_mask_seed]) * self.loss_weight['head_cls_loss']
|
422 |
+
self.log('seed_head_cls_loss', head_cls_loss_seed, **log_params)
|
423 |
+
|
424 |
+
loss = loss + head_cls_loss_seed
|
425 |
+
|
426 |
+
# plot_scenario_ids = ['74ad7b76d5906d39', '1351ea8b8333ddcb', '1352066cc3c0508d', '135436833ce5b9e7', '13570a32432d449', '13577c32a81336fb']
|
427 |
+
if random.random() < 4e-5 or int(os.getenv('DEBUG', 0)):
|
428 |
+
with torch.no_grad():
|
429 |
+
# plot probability of inserting new agent (agent-timestep)
|
430 |
+
raw_next_state_prob_seed = pred['raw_next_state_prob_seed']
|
431 |
+
plot_prob_seed(pred['scenario_id'][0],
|
432 |
+
torch.softmax(raw_next_state_prob_seed, dim=-1
|
433 |
+
)[..., -1].detach().cpu().numpy(),
|
434 |
+
self.save_path,
|
435 |
+
prefix=f'training_{self.global_step:06d}_',
|
436 |
+
indices=pred['target_indices'].cpu().numpy())
|
437 |
+
|
438 |
+
# plot heatmap of inserting new agent
|
439 |
+
if self.use_grid_token:
|
440 |
+
next_pos_rel_prob_seed = pred['next_pos_rel_prob_seed']
|
441 |
+
if next_pos_rel_prob_seed.shape[0] > 0:
|
442 |
+
next_pos_rel_prob_seed = torch.softmax(next_pos_rel_prob_seed, dim=-1).detach().cpu().numpy()
|
443 |
+
indices = next_pos_rel_index_gt_seed.detach().cpu().numpy()
|
444 |
+
mask = next_attr_eval_mask_seed.detach().cpu().numpy().astype(np.bool_)
|
445 |
+
indices[~mask] = -1
|
446 |
+
prob, indices = self.attr_tokenizer.pad_square(next_pos_rel_prob_seed, indices)
|
447 |
+
plot_insert_grid(pred['scenario_id'][0],
|
448 |
+
prob,
|
449 |
+
indices=indices,
|
450 |
+
save_path=self.save_path,
|
451 |
+
prefix=f'training_{self.global_step:06d}_')
|
452 |
+
|
453 |
+
if self.predict_occ:
|
454 |
+
|
455 |
+
neighbor_agent_grid_idx = pred['neighbor_agent_grid_idx']
|
456 |
+
neighbor_agent_grid_index_gt = pred['neighbor_agent_grid_index_gt']
|
457 |
+
neighbor_agent_grid_index_eval_mask = pred['neighbor_agent_grid_index_eval_mask']
|
458 |
+
neighbor_pt_grid_idx = pred['neighbor_pt_grid_idx']
|
459 |
+
neighbor_pt_grid_index_gt = pred['neighbor_pt_grid_index_gt']
|
460 |
+
neighbor_pt_grid_index_eval_mask = pred['neighbor_pt_grid_index_eval_mask']
|
461 |
+
|
462 |
+
neighbor_agent_grid_cls_loss = self.occ_cls_loss(neighbor_agent_grid_idx[neighbor_agent_grid_index_eval_mask],
|
463 |
+
neighbor_agent_grid_index_gt[neighbor_agent_grid_index_eval_mask])
|
464 |
+
neighbor_pt_grid_cls_loss = self.occ_cls_loss(neighbor_pt_grid_idx[neighbor_pt_grid_index_eval_mask],
|
465 |
+
neighbor_pt_grid_index_gt[neighbor_pt_grid_index_eval_mask])
|
466 |
+
# self.log('neighbor_agent_grid_cls_loss', neighbor_agent_grid_cls_loss, **log_params)
|
467 |
+
# self.log('neighbor_pt_grid_cls_loss', neighbor_pt_grid_cls_loss, **log_params)
|
468 |
+
# loss = loss + neighbor_agent_grid_cls_loss + neighbor_pt_grid_cls_loss
|
469 |
+
|
470 |
+
grid_agent_occ_seed = pred['grid_agent_occ_seed']
|
471 |
+
grid_pt_occ_seed = pred['grid_pt_occ_seed']
|
472 |
+
grid_agent_occ_gt_seed = pred['grid_agent_occ_gt_seed'].float()
|
473 |
+
grid_pt_occ_gt_seed = pred['grid_pt_occ_gt_seed'].float()
|
474 |
+
grid_agent_occ_eval_mask_seed = pred['grid_agent_occ_eval_mask_seed']
|
475 |
+
grid_pt_occ_eval_mask_seed = pred['grid_pt_occ_eval_mask_seed']
|
476 |
+
|
477 |
+
# plot_scenario_ids = ['74ad7b76d5906d39', '1351ea8b8333ddcb', '1352066cc3c0508d', '135436833ce5b9e7', '13570a32432d449', '13577c32a81336fb']
|
478 |
+
if random.random() < 4e-5 or os.getenv('DEBUG'):
|
479 |
+
with torch.no_grad():
|
480 |
+
agent_occ = self.attr_tokenizer.pad_square(grid_agent_occ_seed.sigmoid().detach().cpu().numpy())[0]
|
481 |
+
agent_occ_gt = self.attr_tokenizer.pad_square(grid_agent_occ_gt_seed.detach().cpu().numpy())[0]
|
482 |
+
plot_occ_grid(pred['scenario_id'][0],
|
483 |
+
agent_occ,
|
484 |
+
gt_occ=agent_occ_gt,
|
485 |
+
mode='agent',
|
486 |
+
save_path=self.save_path,
|
487 |
+
prefix=f'training_{self.global_step:06d}_')
|
488 |
+
pt_occ = self.attr_tokenizer.pad_square(grid_pt_occ_seed.sigmoid().detach().cpu().numpy())[0]
|
489 |
+
pt_occ_gt = self.attr_tokenizer.pad_square(grid_pt_occ_gt_seed.detach().cpu().numpy())[0]
|
490 |
+
plot_occ_grid(pred['scenario_id'][0],
|
491 |
+
pt_occ,
|
492 |
+
gt_occ=pt_occ_gt,
|
493 |
+
mode='pt',
|
494 |
+
save_path=self.save_path,
|
495 |
+
prefix=f'training_{self.global_step:06d}_')
|
496 |
+
|
497 |
+
grid_agent_occ_gt_seed[grid_agent_occ_gt_seed == -1] = 0
|
498 |
+
if grid_agent_occ_gt_seed.min() < 0 or grid_agent_occ_gt_seed.max() > 1 or \
|
499 |
+
grid_pt_occ_gt_seed.min() < 0 or grid_pt_occ_gt_seed.max() > 1:
|
500 |
+
raise RuntimeError("Occurred invalid values in occ gt")
|
501 |
+
|
502 |
+
agent_occ_loss = self.agent_occ_loss_seed(grid_agent_occ_seed[grid_agent_occ_eval_mask_seed],
|
503 |
+
grid_agent_occ_gt_seed[grid_agent_occ_eval_mask_seed]) * self.loss_weight['agent_occ_loss']
|
504 |
+
pt_occ_loss = self.pt_occ_loss_seed(grid_pt_occ_seed[grid_pt_occ_eval_mask_seed],
|
505 |
+
grid_pt_occ_gt_seed[grid_pt_occ_eval_mask_seed]) * self.loss_weight['pt_occ_loss']
|
506 |
+
|
507 |
+
self.log('agent_occ_loss', agent_occ_loss, **log_params)
|
508 |
+
self.log('pt_occ_loss', pt_occ_loss, **log_params)
|
509 |
+
loss = loss + agent_occ_loss + pt_occ_loss
|
510 |
+
|
511 |
+
if os.getenv('LOG_TRAIN', False) and (self.predict_motion or self.predict_state):
|
512 |
+
for a in range(next_token_idx.shape[0]):
|
513 |
+
print(f"agent: {a}")
|
514 |
+
if self.predict_motion:
|
515 |
+
print(f"pred motion: {next_token_idx[a, :, 0].tolist()}, \ngt motion: {next_token_idx_gt[a, :].tolist()}")
|
516 |
+
print(f"train mask: {next_token_eval_mask[a].long().tolist()}")
|
517 |
+
if self.predict_state:
|
518 |
+
print(f"pred state: {next_state_idx[a, :, 0].tolist()}, \ngt state: {next_state_idx_gt[a, :].tolist()}")
|
519 |
+
print(f"train mask: {next_state_eval_mask[a].long().tolist()}")
|
520 |
+
num_sa = next_state_idx_seed[..., 0].sum(dim=-1).bool().sum()
|
521 |
+
for sa in range(num_sa):
|
522 |
+
print(f"seed agent: {sa}")
|
523 |
+
print(f"seed pred state: {next_state_idx_seed[sa, :, 0].tolist()}, \ngt seed state: {next_state_idx_gt_seed[sa, :].tolist()}")
|
524 |
+
# if sa < next_pos_rel_seed.shape[0]:
|
525 |
+
# print(f"pred pos: {next_pos_rel_seed[sa, :, 0].tolist()}, \ngt pos: {next_pos_rel_gt_seed[sa, :, 0].tolist()}")
|
526 |
+
# print(f"pred head: {next_head_rel_seed[sa].tolist()}, \ngt head: {next_head_rel_gt_seed[sa].tolist()}")
|
527 |
+
# print(f"seed train mask: {next_state_eval_mask_seed[sa].long().tolist()}")
|
528 |
+
|
529 |
+
# map token loss
|
530 |
+
if self.predict_map:
|
531 |
+
|
532 |
+
map_next_token_prob = pred['map_next_token_prob']
|
533 |
+
map_next_token_idx_gt = pred['map_next_token_idx_gt']
|
534 |
+
map_next_token_eval_mask = pred['map_next_token_eval_mask']
|
535 |
+
|
536 |
+
map_token_loss = self.map_token_loss(map_next_token_prob[map_next_token_eval_mask], map_next_token_idx_gt[map_next_token_eval_mask]) * self.loss_weight['map_token_loss']
|
537 |
+
self.log('map_token_loss', map_token_loss, prog_bar=True, on_step=True, on_epoch=True, batch_size=1)
|
538 |
+
loss = loss + map_token_loss
|
539 |
+
|
540 |
+
allocated = torch.cuda.memory_allocated(device='cuda:0') / (1024 ** 3)
|
541 |
+
reserved = torch.cuda.memory_reserved(device='cuda:0') / (1024 ** 3)
|
542 |
+
self.log('allocated', allocated, **log_params)
|
543 |
+
self.log('reserved', reserved, **log_params)
|
544 |
+
|
545 |
+
return loss
|
546 |
+
|
547 |
+
def validation_step(self,
|
548 |
+
data,
|
549 |
+
batch_idx):
|
550 |
+
|
551 |
+
self.debug = int(os.getenv('DEBUG', 0))
|
552 |
+
|
553 |
+
# ! validation in training process
|
554 |
+
if (
|
555 |
+
self._mode == 'training' and (
|
556 |
+
self.current_epoch not in [5, 10, 20, 25, self.max_epochs] or random.random() > 5e-4) and
|
557 |
+
not self.debug
|
558 |
+
):
|
559 |
+
self.val_open_loop = False
|
560 |
+
self.val_close_loop = False
|
561 |
+
return
|
562 |
+
|
563 |
+
if int(os.getenv('NO_VAL', 0)) or int(os.getenv("CHECK_INPUTS", 0)):
|
564 |
+
return
|
565 |
+
|
566 |
+
# ! check if save exists
|
567 |
+
if not self._plot_rollouts:
|
568 |
+
rollouts_path = os.path.join(self.save_path, f'idx_{self.trainer.global_rank}_{batch_idx}_rollouts.pkl')
|
569 |
+
if os.path.exists(rollouts_path):
|
570 |
+
tqdm.write(f'Skipped batch {batch_idx}')
|
571 |
+
return
|
572 |
+
else:
|
573 |
+
rollouts_path = os.path.join(self.save_path, f'{data["scenario_id"][0]}.gif')
|
574 |
+
if os.path.exists(rollouts_path):
|
575 |
+
tqdm.write(f'Skipped scenario {data["scenario_id"][0]}')
|
576 |
+
return
|
577 |
+
|
578 |
+
# ! data preparation
|
579 |
+
data = self.token_processer(data)
|
580 |
+
|
581 |
+
data = self.match_token_map(data)
|
582 |
+
data = self.sample_pt_pred(data)
|
583 |
+
|
584 |
+
# find map tokens for entering agents
|
585 |
+
data = self._fetch_enterings(data)
|
586 |
+
|
587 |
+
data['batch_size_a'] = data['agent']['ptr'][1:] - data['agent']['ptr'][:-1]
|
588 |
+
data['batch_size_pl'] = data['pt_token']['ptr'][1:] - data['pt_token']['ptr'][:-1]
|
589 |
+
if isinstance(data, Batch):
|
590 |
+
data['agent']['av_index'] += data['agent']['ptr'][:-1]
|
591 |
+
|
592 |
+
if int(os.getenv('NEAREST_POS', 0)):
|
593 |
+
pred = self.encoder.predict_nearest_pos(data, rank=self.local_rank)
|
594 |
+
return
|
595 |
+
|
596 |
+
# if self.insert_agent:
|
597 |
+
# pred = self.encoder.insert_agent(data)
|
598 |
+
# return
|
599 |
+
|
600 |
+
# ! open-loop validation
|
601 |
+
if self.val_open_loop or int(os.getenv('OPEN_LOOP', 0)):
|
602 |
+
|
603 |
+
pred = self(data)
|
604 |
+
|
605 |
+
# pred['next_state_prob_seed'] = torch.softmax(pred['next_state_prob_seed'], dim=-1)[..., -1]
|
606 |
+
# plot_prob_seed(pred, self.save_path, suffix=f'_training')
|
607 |
+
|
608 |
+
loss = 0
|
609 |
+
|
610 |
+
if self.predict_motion:
|
611 |
+
|
612 |
+
# motion token
|
613 |
+
next_token_idx = pred['next_token_idx']
|
614 |
+
next_token_idx_gt = pred['next_token_idx_gt'] # (num_agent, num_step, 10)
|
615 |
+
next_token_prob = pred['next_token_prob']
|
616 |
+
next_token_eval_mask = pred['next_token_eval_mask']
|
617 |
+
|
618 |
+
token_cls_loss = self.token_cls_loss(next_token_prob[next_token_eval_mask], next_token_idx_gt[next_token_eval_mask])
|
619 |
+
loss = loss + token_cls_loss
|
620 |
+
|
621 |
+
if self.predict_state:
|
622 |
+
|
623 |
+
# state token
|
624 |
+
next_state_idx = pred['next_state_idx']
|
625 |
+
next_state_idx_gt = pred['next_state_idx_gt']
|
626 |
+
next_state_prob = pred['next_state_prob']
|
627 |
+
next_state_eval_mask = pred['next_state_eval_mask']
|
628 |
+
|
629 |
+
state_cls_loss = self.state_cls_loss(next_state_prob[next_state_eval_mask], next_state_idx_gt[next_state_eval_mask])
|
630 |
+
loss = loss + state_cls_loss
|
631 |
+
|
632 |
+
# seed state token
|
633 |
+
next_state_idx_seed = pred['next_state_idx_seed']
|
634 |
+
next_state_idx_gt_seed = pred['next_state_idx_gt_seed']
|
635 |
+
|
636 |
+
if self.predict_occ:
|
637 |
+
|
638 |
+
grid_agent_occ_seed = pred['grid_agent_occ_seed']
|
639 |
+
grid_pt_occ_seed = pred['grid_pt_occ_seed']
|
640 |
+
grid_agent_occ_gt_seed = pred['grid_agent_occ_gt_seed'].float()
|
641 |
+
grid_pt_occ_gt_seed = pred['grid_pt_occ_gt_seed'].float()
|
642 |
+
|
643 |
+
agent_occ = self.attr_tokenizer.pad_square(grid_agent_occ_seed.sigmoid().detach().cpu().numpy())[0]
|
644 |
+
agent_occ_gt = self.attr_tokenizer.pad_square(grid_agent_occ_gt_seed.detach().cpu().numpy())[0]
|
645 |
+
plot_occ_grid(pred['scenario_id'][0],
|
646 |
+
agent_occ,
|
647 |
+
gt_occ=agent_occ_gt,
|
648 |
+
mode='agent',
|
649 |
+
save_path=self.save_path,
|
650 |
+
prefix=f'eval_')
|
651 |
+
pt_occ = self.attr_tokenizer.pad_square(grid_pt_occ_seed.sigmoid().detach().cpu().numpy())[0]
|
652 |
+
pt_occ_gt = self.attr_tokenizer.pad_square(grid_pt_occ_gt_seed.detach().cpu().numpy())[0]
|
653 |
+
plot_occ_grid(pred['scenario_id'][0],
|
654 |
+
pt_occ,
|
655 |
+
gt_occ=pt_occ_gt,
|
656 |
+
mode='pt',
|
657 |
+
save_path=self.save_path,
|
658 |
+
prefix=f'eval_')
|
659 |
+
|
660 |
+
self.log('val_loss', loss, prog_bar=True, on_step=True, on_epoch=True, batch_size=1, sync_dist=True)
|
661 |
+
|
662 |
+
if self.val_insert:
|
663 |
+
|
664 |
+
pred = self(data)
|
665 |
+
|
666 |
+
next_state_idx_seed = pred['next_state_idx_seed']
|
667 |
+
next_state_idx_gt_seed = pred['next_state_idx_gt_seed']
|
668 |
+
|
669 |
+
self.NumInsertAccuracy.update(next_state_idx_seed=next_state_idx_seed,
|
670 |
+
next_state_idx_gt_seed=next_state_idx_gt_seed)
|
671 |
+
|
672 |
+
return
|
673 |
+
|
674 |
+
# ! close-loop validation
|
675 |
+
if self.val_close_loop and (self.predict_motion or self.predict_state):
|
676 |
+
|
677 |
+
rollouts = []
|
678 |
+
for _ in tqdm(range(self.n_rollout_close_val), leave=False, desc='Rollout ...'):
|
679 |
+
rollout = self.encoder.inference(data.clone())
|
680 |
+
rollouts.append(rollout)
|
681 |
+
|
682 |
+
av_index = int(rollout['ego_index'])
|
683 |
+
scenario_id = rollout['scenario_id'][0]
|
684 |
+
|
685 |
+
# motion tokens
|
686 |
+
if self.predict_motion:
|
687 |
+
|
688 |
+
if self._plot_rollouts: # only plot gifs for last 2 epochs for efficiency
|
689 |
+
plot_val(data, rollout, av_index, self.save_path, pl2seed_radius=self.pl2seed_radius, attr_tokenizer=self.attr_tokenizer)
|
690 |
+
|
691 |
+
# next_token_idx = pred['next_token_idx'][..., None]
|
692 |
+
# next_token_idx_gt = pred['next_token_idx_gt'][:, 2:] # hard code 2=11//5
|
693 |
+
# next_token_eval_mask = pred['next_token_eval_mask'][:, 2:]
|
694 |
+
|
695 |
+
# gt_traj = pred['gt_traj']
|
696 |
+
# pred_traj = pred['pred_traj']
|
697 |
+
# pred_head = pred['pred_head']
|
698 |
+
|
699 |
+
# self.TokenCls.update(pred=next_token_idx[next_token_eval_mask], target=next_token_idx_gt[next_token_eval_mask],
|
700 |
+
# valid_mask=next_token_eval_mask[next_token_eval_mask])
|
701 |
+
# self.log('val_token_cls_acc', self.TokenCls, prog_bar=True, on_step=True, on_epoch=True, batch_size=1, sync_dist=True)
|
702 |
+
|
703 |
+
# remove the agents which are unseen at current step
|
704 |
+
# eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps - 1]
|
705 |
+
|
706 |
+
# self.minADE.update(pred=pred_traj[eval_mask], target=gt_traj[eval_mask], valid_mask=valid_mask[eval_mask])
|
707 |
+
# self.minFDE.update(pred=pred_traj[eval_mask], target=gt_traj[eval_mask], valid_mask=valid_mask[eval_mask])
|
708 |
+
# print('ade: ', self.minADE.compute(), 'fde: ', self.minFDE.compute())
|
709 |
+
|
710 |
+
# self.log('val_minADE', self.minADE, prog_bar=True, on_step=True, on_epoch=True, batch_size=1)
|
711 |
+
# self.log('val_minFDE', self.minFDE, prog_bar=True, on_step=True, on_epoch=True, batch_size=1)
|
712 |
+
|
713 |
+
# state tokens
|
714 |
+
if self.predict_state:
|
715 |
+
|
716 |
+
if self.use_grid_token:
|
717 |
+
next_pos_rel_prob_seed = rollout['next_pos_rel_prob_seed'].cpu().numpy() # (s, t, grid_size)
|
718 |
+
prob, _ = self.attr_tokenizer.pad_square(next_pos_rel_prob_seed)
|
719 |
+
|
720 |
+
if self._plot_rollouts:
|
721 |
+
if self.use_grid_token:
|
722 |
+
plot_insert_grid(scenario_id,
|
723 |
+
prob,
|
724 |
+
save_path=self.save_path,
|
725 |
+
prefix=f'inference_')
|
726 |
+
plot_prob_seed(scenario_id,
|
727 |
+
rollout['next_state_prob_seed'].cpu().numpy(),
|
728 |
+
self.save_path,
|
729 |
+
prefix=f'inference_')
|
730 |
+
|
731 |
+
next_state_idx = rollout['next_state_idx'][..., None]
|
732 |
+
# next_state_idx_gt = rollout['next_state_idx_gt'][:, 2:]
|
733 |
+
# next_state_eval_mask = rollout['next_state_eval_mask'][:, 2:]
|
734 |
+
|
735 |
+
# self.StateCls.update(pred=next_state_idx[next_token_eval_mask], target=next_state_idx_gt[next_token_eval_mask],
|
736 |
+
# valid_mask=next_token_eval_mask[next_token_eval_mask])
|
737 |
+
# self.log('val_state_cls_acc', self.TokenCls, prog_bar=True, on_step=True, on_epoch=True, batch_size=1, sync_dist=True)
|
738 |
+
|
739 |
+
self.StateAccuracy.update(state_idx=next_state_idx[..., 0])
|
740 |
+
self.log('valid_accuracy', self.StateAccuracy.compute()['valid'], prog_bar=True, on_step=True, on_epoch=True, batch_size=1)
|
741 |
+
self.log('invalid_accuracy', self.StateAccuracy.compute()['invalid'], prog_bar=True, on_step=True, on_epoch=True, batch_size=1)
|
742 |
+
self.local_logger.info(rollout['log_message'])
|
743 |
+
# print(rollout['log_message'])
|
744 |
+
# print(self.StateAccuracy)
|
745 |
+
|
746 |
+
if self.predict_occ:
|
747 |
+
|
748 |
+
grid_agent_occ_seed = rollout['grid_agent_occ_seed']
|
749 |
+
grid_pt_occ_seed = rollout['grid_pt_occ_seed']
|
750 |
+
grid_agent_occ_gt_seed = rollout['grid_agent_occ_gt_seed']
|
751 |
+
|
752 |
+
agent_occ = self.attr_tokenizer.pad_square(grid_agent_occ_seed.sigmoid().cpu().numpy())[0]
|
753 |
+
agent_occ_gt = self.attr_tokenizer.pad_square(grid_agent_occ_gt_seed.sigmoid().cpu().numpy())[0]
|
754 |
+
if self._plot_rollouts:
|
755 |
+
plot_occ_grid(scenario_id,
|
756 |
+
agent_occ,
|
757 |
+
gt_occ=agent_occ_gt,
|
758 |
+
mode='agent',
|
759 |
+
save_path=self.save_path,
|
760 |
+
prefix=f'inference_')
|
761 |
+
|
762 |
+
if self._online_metric or self._save_validate_reuslts:
|
763 |
+
|
764 |
+
# ! format results
|
765 |
+
pred_valid, token_pos, token_head = [], [], []
|
766 |
+
pred_traj, pred_head, pred_z = [], [], []
|
767 |
+
pred_shape, pred_type, pred_state = [], [], []
|
768 |
+
agent_id = []
|
769 |
+
for rollout in rollouts:
|
770 |
+
pred_valid.append(rollout['pred_valid'])
|
771 |
+
token_pos.append(rollout['pos_a'])
|
772 |
+
token_head.append(rollout['head_a'])
|
773 |
+
pred_traj.append(rollout['pred_traj'])
|
774 |
+
pred_head.append(rollout['pred_head'])
|
775 |
+
pred_z.append(rollout['pred_z'])
|
776 |
+
pred_shape.append(rollout['eval_shape'])
|
777 |
+
pred_type.append(rollout['pred_type'])
|
778 |
+
pred_state.append(rollout['next_state_idx'])
|
779 |
+
agent_id.append(rollout['agent_id'])
|
780 |
+
|
781 |
+
pred_valid = torch.stack(pred_valid, dim=1)
|
782 |
+
token_pos = torch.stack(token_pos, dim=1)
|
783 |
+
token_head = torch.stack(token_head, dim=1)
|
784 |
+
pred_traj = torch.stack(pred_traj, dim=1) # (n_agent, n_rollout, n_step, 2)
|
785 |
+
pred_head = torch.stack(pred_head, dim=1)
|
786 |
+
pred_z = torch.stack(pred_z, dim=1)
|
787 |
+
pred_shape = torch.stack(pred_shape, dim=1) # [n_agent, n_rollout, 3]
|
788 |
+
pred_type = torch.stack(pred_type, dim=1) # [n_agent, n_rollout]
|
789 |
+
pred_state = torch.stack(pred_state, dim=1) # [n_agent, n_rollout, n_step // shift]
|
790 |
+
agent_id = torch.stack(agent_id, dim=1) # [n_agent, n_rollout]
|
791 |
+
|
792 |
+
agent_batch = torch.zeros((pred_traj.shape[0]), dtype=torch.long)
|
793 |
+
rollouts = dict(
|
794 |
+
_scenario_id=data['scenario_id'],
|
795 |
+
scenario_id=get_scenario_id_int_tensor(data['scenario_id']),
|
796 |
+
av_id=int(rollouts[0]['agent_id'][rollouts[0]['ego_index']]), # NOTE: hard code!!!
|
797 |
+
agent_id=agent_id.cpu(),
|
798 |
+
agent_batch=agent_batch.cpu(),
|
799 |
+
pred_traj=pred_traj.cpu(),
|
800 |
+
pred_z=pred_z.cpu(),
|
801 |
+
pred_head=pred_head.cpu(),
|
802 |
+
pred_shape=pred_shape.cpu(),
|
803 |
+
pred_type=pred_type.cpu(),
|
804 |
+
pred_state=pred_state.cpu(),
|
805 |
+
pred_valid=pred_valid.cpu(),
|
806 |
+
token_pos=token_pos.cpu(),
|
807 |
+
token_head=token_head.cpu(),
|
808 |
+
tfrecord_path=data['tfrecord_path'],
|
809 |
+
)
|
810 |
+
|
811 |
+
if self._save_validate_reuslts:
|
812 |
+
with open(rollouts_path, 'wb') as f:
|
813 |
+
pickle.dump(rollouts, f)
|
814 |
+
|
815 |
+
if self._online_metric:
|
816 |
+
self._long_metrics.update(rollouts)
|
817 |
+
|
818 |
+
def on_validation_start(self):
|
819 |
+
self.scenario_rollouts = []
|
820 |
+
self.batch_metric = defaultdict(list)
|
821 |
+
|
822 |
+
def on_validation_epoch_end(self):
|
823 |
+
if self.val_close_loop:
|
824 |
+
|
825 |
+
if self._long_metrics is not None:
|
826 |
+
epoch_long_metrics = self._long_metrics.compute()
|
827 |
+
if self.global_rank == 0:
|
828 |
+
epoch_long_metrics['epoch'] = self.current_epoch
|
829 |
+
self.logger.log_metrics(epoch_long_metrics)
|
830 |
+
|
831 |
+
self._long_metrics.reset()
|
832 |
+
|
833 |
+
self.minADE.reset()
|
834 |
+
self.minFDE.reset()
|
835 |
+
self.StateAccuracy.reset()
|
836 |
+
|
837 |
+
def configure_optimizers(self):
|
838 |
+
optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
|
839 |
+
|
840 |
+
def lr_lambda(current_step):
|
841 |
+
if current_step + 1 < self.warmup_steps:
|
842 |
+
return float(current_step + 1) / float(max(1, self.warmup_steps))
|
843 |
+
return max(
|
844 |
+
0.0, 0.5 * (1.0 + math.cos(math.pi * (current_step - self.warmup_steps) / float(max(1, self.total_steps - self.warmup_steps))))
|
845 |
+
)
|
846 |
+
|
847 |
+
lr_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
|
848 |
+
return [optimizer], [lr_scheduler]
|
849 |
+
|
850 |
+
def load_state_from_file(self, filename, to_cpu=False):
|
851 |
+
logger = self.local_logger
|
852 |
+
|
853 |
+
if not os.path.isfile(filename):
|
854 |
+
raise FileNotFoundError
|
855 |
+
|
856 |
+
logger.info('==> Loading parameters from checkpoint %s to %s' % (filename, 'CPU' if to_cpu else 'GPU'))
|
857 |
+
loc_type = torch.device('cpu') if to_cpu else None
|
858 |
+
checkpoint = torch.load(filename, map_location=loc_type)
|
859 |
+
|
860 |
+
version = checkpoint.get("version", None)
|
861 |
+
if version is not None:
|
862 |
+
logger.info('==> Checkpoint trained from version: %s' % version)
|
863 |
+
|
864 |
+
|
865 |
+
model_state_disk = checkpoint['state_dict']
|
866 |
+
logger.info(f'The number of disk ckpt keys: {len(model_state_disk)}')
|
867 |
+
|
868 |
+
model_state = self.state_dict()
|
869 |
+
model_state_disk_filter = {}
|
870 |
+
for key, val in model_state_disk.items():
|
871 |
+
if key in model_state and model_state_disk[key].shape == model_state[key].shape:
|
872 |
+
model_state_disk_filter[key] = val
|
873 |
+
else:
|
874 |
+
if key not in model_state:
|
875 |
+
print(f'Ignore key in disk (not found in model): {key}, shape={val.shape}')
|
876 |
+
else:
|
877 |
+
print(f'Ignore key in disk (shape does not match): {key}, load_shape={val.shape}, model_shape={model_state[key].shape}')
|
878 |
+
|
879 |
+
model_state_disk = model_state_disk_filter
|
880 |
+
missing_keys, unexpected_keys = self.load_state_dict(model_state_disk, strict=False)
|
881 |
+
|
882 |
+
logger.info(f'Missing keys: {missing_keys}')
|
883 |
+
logger.info(f'The number of missing keys: {len(missing_keys)}')
|
884 |
+
logger.info(f'The number of unexpected keys: {len(unexpected_keys)}')
|
885 |
+
logger.info('==> Done (total keys %d)' % (len(model_state)))
|
886 |
+
|
887 |
+
epoch = checkpoint.get('epoch', -1)
|
888 |
+
it = checkpoint.get('it', 0.0)
|
889 |
+
|
890 |
+
return it, epoch
|
891 |
+
|
892 |
+
def match_token_map(self, data):
|
893 |
+
traj_pos = data['map_save']['traj_pos'].to(torch.float)
|
894 |
+
traj_theta = data['map_save']['traj_theta'].to(torch.float)
|
895 |
+
pl_idx_list = data['map_save']['pl_idx_list']
|
896 |
+
token_sample_pt = self.map_token['sample_pt'].to(traj_pos.device)
|
897 |
+
token_src = self.map_token['traj_src'].to(traj_pos.device)
|
898 |
+
max_traj_len = self.map_token['traj_src'].shape[1]
|
899 |
+
pl_num = traj_pos.shape[0]
|
900 |
+
|
901 |
+
pt_token_pos = traj_pos[:, 0, :].clone()
|
902 |
+
pt_token_orientation = traj_theta.clone()
|
903 |
+
cos, sin = traj_theta.cos(), traj_theta.sin()
|
904 |
+
rot_mat = traj_theta.new_zeros(pl_num, 2, 2)
|
905 |
+
rot_mat[..., 0, 0] = cos
|
906 |
+
rot_mat[..., 0, 1] = -sin
|
907 |
+
rot_mat[..., 1, 0] = sin
|
908 |
+
rot_mat[..., 1, 1] = cos
|
909 |
+
traj_pos_local = torch.bmm((traj_pos - traj_pos[:, 0:1]), rot_mat.view(-1, 2, 2))
|
910 |
+
distance = torch.sum((token_sample_pt[None] - traj_pos_local.unsqueeze(1)) ** 2, dim=(-2, -1))
|
911 |
+
pt_token_id = torch.argmin(distance, dim=1)
|
912 |
+
|
913 |
+
if self.noise:
|
914 |
+
topk_indices = torch.argsort(torch.sum((token_sample_pt[None] - traj_pos_local.unsqueeze(1))**2, dim=(-2, -1)), dim=1)[:, :8]
|
915 |
+
sample_topk = torch.randint(0, topk_indices.shape[-1], size=(topk_indices.shape[0], 1), device=topk_indices.device)
|
916 |
+
pt_token_id = torch.gather(topk_indices, 1, sample_topk).squeeze(-1)
|
917 |
+
|
918 |
+
# cos, sin = traj_theta.cos(), traj_theta.sin()
|
919 |
+
# rot_mat = traj_theta.new_zeros(pl_num, 2, 2)
|
920 |
+
# rot_mat[..., 0, 0] = cos
|
921 |
+
# rot_mat[..., 0, 1] = sin
|
922 |
+
# rot_mat[..., 1, 0] = -sin
|
923 |
+
# rot_mat[..., 1, 1] = cos
|
924 |
+
# token_src_world = torch.bmm(token_src[None, ...].repeat(pl_num, 1, 1, 1).reshape(pl_num, -1, 2),
|
925 |
+
# rot_mat.view(-1, 2, 2)).reshape(pl_num, token_src.shape[0], max_traj_len, 2) + traj_pos[:, None, [0], :]
|
926 |
+
# token_src_world_select = token_src_world.view(-1, 1024, 11, 2)[torch.arange(pt_token_id.view(-1).shape[0]), pt_token_id.view(-1)].view(pl_num, max_traj_len, 2)
|
927 |
+
|
928 |
+
pl_idx_full = pl_idx_list.clone()
|
929 |
+
token2pl = torch.stack([torch.arange(len(pl_idx_list), device=traj_pos.device), pl_idx_full.long()])
|
930 |
+
count_nums = []
|
931 |
+
for pl in pl_idx_full.unique():
|
932 |
+
pt = token2pl[0, token2pl[1, :] == pl]
|
933 |
+
left_side = (data['pt_token']['side'][pt] == 0).sum()
|
934 |
+
right_side = (data['pt_token']['side'][pt] == 1).sum()
|
935 |
+
center_side = (data['pt_token']['side'][pt] == 2).sum()
|
936 |
+
count_nums.append(torch.Tensor([left_side, right_side, center_side]))
|
937 |
+
count_nums = torch.stack(count_nums, dim=0)
|
938 |
+
num_polyline = int(count_nums.max().item())
|
939 |
+
traj_mask = torch.zeros((int(len(pl_idx_full.unique())), 3, num_polyline), dtype=bool)
|
940 |
+
idx_matrix = torch.arange(traj_mask.size(2)).unsqueeze(0).unsqueeze(0)
|
941 |
+
idx_matrix = idx_matrix.expand(traj_mask.size(0), traj_mask.size(1), -1)
|
942 |
+
counts_num_expanded = count_nums.unsqueeze(-1)
|
943 |
+
mask_update = idx_matrix < counts_num_expanded
|
944 |
+
traj_mask[mask_update] = True
|
945 |
+
|
946 |
+
data['pt_token']['traj_mask'] = traj_mask
|
947 |
+
data['pt_token']['position'] = torch.cat([pt_token_pos, torch.zeros((data['pt_token']['num_nodes'], 1),
|
948 |
+
device=traj_pos.device, dtype=torch.float)], dim=-1)
|
949 |
+
data['pt_token']['orientation'] = pt_token_orientation
|
950 |
+
data['pt_token']['height'] = data['pt_token']['position'][:, -1]
|
951 |
+
data[('pt_token', 'to', 'map_polygon')] = {}
|
952 |
+
data[('pt_token', 'to', 'map_polygon')]['edge_index'] = token2pl # (2, num_points)
|
953 |
+
data['pt_token']['token_idx'] = pt_token_id
|
954 |
+
|
955 |
+
# data['pt_token']['batch'] = torch.zeros(data['pt_token']['num_nodes'], device=traj_pos.device).long()
|
956 |
+
# data['pt_token']['ptr'] = torch.tensor([0, data['pt_token']['num_nodes']], device=traj_pos.device).long()
|
957 |
+
|
958 |
+
return data
|
959 |
+
|
960 |
+
def sample_pt_pred(self, data):
|
961 |
+
traj_mask = data['pt_token']['traj_mask']
|
962 |
+
raw_pt_index = torch.arange(1, traj_mask.shape[2]).repeat(traj_mask.shape[0], traj_mask.shape[1], 1)
|
963 |
+
masked_pt_index = raw_pt_index.view(-1)[torch.randperm(raw_pt_index.numel())[:traj_mask.shape[0] * traj_mask.shape[1] * ((traj_mask.shape[2] - 1) // 3)].reshape(traj_mask.shape[0], traj_mask.shape[1], (traj_mask.shape[2] - 1) // 3)]
|
964 |
+
masked_pt_index = torch.sort(masked_pt_index, -1)[0]
|
965 |
+
pt_valid_mask = traj_mask.clone()
|
966 |
+
pt_valid_mask.scatter_(2, masked_pt_index, False)
|
967 |
+
pt_pred_mask = traj_mask.clone()
|
968 |
+
pt_pred_mask.scatter_(2, masked_pt_index, False)
|
969 |
+
tmp_mask = pt_pred_mask.clone()
|
970 |
+
tmp_mask[:, :, :] = True
|
971 |
+
tmp_mask.scatter_(2, masked_pt_index-1, False)
|
972 |
+
pt_pred_mask.masked_fill_(tmp_mask, False)
|
973 |
+
pt_pred_mask = pt_pred_mask * torch.roll(traj_mask, shifts=-1, dims=2)
|
974 |
+
pt_target_mask = torch.roll(pt_pred_mask, shifts=1, dims=2)
|
975 |
+
|
976 |
+
data['pt_token']['pt_valid_mask'] = pt_valid_mask[traj_mask]
|
977 |
+
data['pt_token']['pt_pred_mask'] = pt_pred_mask[traj_mask]
|
978 |
+
data['pt_token']['pt_target_mask'] = pt_target_mask[traj_mask]
|
979 |
+
|
980 |
+
return data
|
981 |
+
|
982 |
+
def _fetch_enterings(self, data: HeteroData, plot: bool=False):
|
983 |
+
data['agent']['grid_token_idx'] = torch.zeros_like(data['agent']['state_idx']).long()
|
984 |
+
data['agent']['grid_offset_xy'] = torch.zeros_like(data['agent']['token_pos'])
|
985 |
+
data['agent']['heading_token_idx'] = torch.zeros_like(data['agent']['state_idx']).long()
|
986 |
+
data['agent']['sort_indices'] = torch.zeros_like(data['agent']['state_idx']).long()
|
987 |
+
data['agent']['inrange_mask'] = torch.zeros_like(data['agent']['state_idx']).bool()
|
988 |
+
data['agent']['bos_mask'] = torch.zeros_like(data['agent']['state_idx']).bool()
|
989 |
+
|
990 |
+
data['agent']['pos_xy'] = torch.zeros_like(data['agent']['token_pos'])
|
991 |
+
if self.predict_occ:
|
992 |
+
num_step = data['agent']['state_idx'].shape[1]
|
993 |
+
data['agent']['pt_grid_token_idx'] = torch.zeros_like(data['pt_token']['token_idx'])[None].repeat(num_step, 1).long()
|
994 |
+
|
995 |
+
for b in range(data.num_graphs):
|
996 |
+
av_index = int(data['agent']['av_index'][b])
|
997 |
+
agent_batch_mask = data['agent']['batch'] == b
|
998 |
+
pt_batch_mask = data['pt_token']['batch'] == b
|
999 |
+
pt_token_idx = data['pt_token']['token_idx'][pt_batch_mask]
|
1000 |
+
pt_pos = data['pt_token']['position'][pt_batch_mask]
|
1001 |
+
agent_token_pos = data['agent']['token_pos'][agent_batch_mask]
|
1002 |
+
agent_token_heading = data['agent']['token_heading'][agent_batch_mask]
|
1003 |
+
state_idx = data['agent']['state_idx'][agent_batch_mask]
|
1004 |
+
ego_pos = agent_token_pos[av_index] # NOTE: `av_index` will be added by `ptr` later
|
1005 |
+
ego_heading = agent_token_heading[av_index]
|
1006 |
+
|
1007 |
+
grid_token_idx = torch.full(state_idx.shape, -1, device=state_idx.device)
|
1008 |
+
offset_xy = torch.zeros_like(agent_token_pos)
|
1009 |
+
sort_indices = torch.zeros_like(grid_token_idx)
|
1010 |
+
pt_grid_token_idx = torch.full((state_idx.shape[1], *pt_token_idx.shape), -1, device=pt_token_idx.device)
|
1011 |
+
|
1012 |
+
pos_xy = torch.zeros((*state_idx.shape, 2), device=state_idx.device)
|
1013 |
+
|
1014 |
+
is_bos = []
|
1015 |
+
is_inrange = []
|
1016 |
+
for t in range(agent_token_pos.shape[1]): # num_step
|
1017 |
+
|
1018 |
+
# tokenize position
|
1019 |
+
is_bos_t = state_idx[:, t] == self.enter_state
|
1020 |
+
is_invalid_t = state_idx[:, t] == self.invalid_state
|
1021 |
+
is_inrange_t = ((agent_token_pos[:, t] - ego_pos[[t]]) ** 2).sum(-1).sqrt() <= self.pl2seed_radius
|
1022 |
+
grid_index_t, offset_xy_t = self.attr_tokenizer.encode_pos(x=agent_token_pos[~is_invalid_t & is_inrange_t, t],
|
1023 |
+
y=ego_pos[[t]],
|
1024 |
+
theta_y=ego_heading[[t]])
|
1025 |
+
grid_token_idx[~is_invalid_t & is_inrange_t, t] = grid_index_t
|
1026 |
+
offset_xy[~is_invalid_t & is_inrange_t, t] = offset_xy_t
|
1027 |
+
|
1028 |
+
pos_xy[~is_invalid_t & is_inrange_t, t] = agent_token_pos[~is_invalid_t & is_inrange_t, t] - ego_pos[[t]]
|
1029 |
+
|
1030 |
+
# distance = ((agent_token_pos[:, t] - ego_pos[[t]]) ** 2).sum(-1).sqrt()
|
1031 |
+
head_vector = torch.stack([ego_heading[[t]].cos(), ego_heading[[t]].sin()], dim=-1)
|
1032 |
+
distance = angle_between_2d_vectors(ctr_vector=head_vector,
|
1033 |
+
nbr_vector=agent_token_pos[:, t] - ego_pos[[t]])
|
1034 |
+
# distance = torch.rand(agent_token_pos.shape[0], device=agent_token_pos.device)
|
1035 |
+
distance[~(is_bos_t & is_inrange_t)] = torch.inf
|
1036 |
+
sort_dist, sort_indice = distance.sort()
|
1037 |
+
sort_indice[torch.isinf(sort_dist)] = av_index
|
1038 |
+
sort_indices[:, t] = sort_indice
|
1039 |
+
|
1040 |
+
is_bos.append(is_bos_t)
|
1041 |
+
is_inrange.append(is_inrange_t)
|
1042 |
+
|
1043 |
+
# tokenize pt token
|
1044 |
+
if self.predict_occ:
|
1045 |
+
is_inrange_t = ((pt_pos[:, :2] - ego_pos[None, t]) ** 2).sum(-1).sqrt() <= self.pl2seed_radius
|
1046 |
+
grid_index_t, _ = self.attr_tokenizer.encode_pos(x=pt_pos[is_inrange_t, :2],
|
1047 |
+
y=ego_pos[[t]],
|
1048 |
+
theta_y=ego_heading[[t]])
|
1049 |
+
|
1050 |
+
pt_grid_token_idx[t, is_inrange_t] = grid_index_t
|
1051 |
+
|
1052 |
+
# tokenize heading
|
1053 |
+
rel_heading = agent_token_heading - ego_heading[None, ...]
|
1054 |
+
heading_token_idx = self.attr_tokenizer.encode_heading(rel_heading)
|
1055 |
+
|
1056 |
+
data['agent']['grid_token_idx'][agent_batch_mask] = grid_token_idx
|
1057 |
+
data['agent']['grid_offset_xy'][agent_batch_mask] = offset_xy
|
1058 |
+
data['agent']['pos_xy'][agent_batch_mask] = pos_xy
|
1059 |
+
data['agent']['heading_token_idx'][agent_batch_mask] = heading_token_idx
|
1060 |
+
data['agent']['sort_indices'][agent_batch_mask] = sort_indices
|
1061 |
+
data['agent']['inrange_mask'][agent_batch_mask] = torch.stack(is_inrange, dim=1)
|
1062 |
+
data['agent']['bos_mask'][agent_batch_mask] = torch.stack(is_bos, dim=1)
|
1063 |
+
if self.predict_occ:
|
1064 |
+
data['agent']['pt_grid_token_idx'][:, pt_batch_mask] = pt_grid_token_idx
|
1065 |
+
|
1066 |
+
plot = False
|
1067 |
+
if plot:
|
1068 |
+
scenario_id = data['scenario_id'][b]
|
1069 |
+
dummy_prob = np.zeros((ego_pos.shape[0], self.attr_tokenizer.grid.shape[0])) + .5
|
1070 |
+
indices = grid_token_idx[:, 1:][state_idx[:, 1:] == self.enter_state].reshape(-1).cpu().numpy()
|
1071 |
+
dummy_prob, indices = self.attr_tokenizer.pad_square(dummy_prob, indices)
|
1072 |
+
# plot_insert_grid(scenario_id, dummy_prob,
|
1073 |
+
# self.attr_tokenizer.grid.cpu().numpy(),
|
1074 |
+
# ego_pos.cpu().numpy(),
|
1075 |
+
# None,
|
1076 |
+
# save_path=os.path.join(self.save_path, 'vis'),
|
1077 |
+
# indices=indices[np.newaxis, ...],
|
1078 |
+
# inference=True,
|
1079 |
+
# all_t_in_one=True)
|
1080 |
+
|
1081 |
+
enter_index = [grid_token_idx[:, i][state_idx[:, i] == self.enter_state].tolist()
|
1082 |
+
for i in range(agent_token_pos.shape[1])]
|
1083 |
+
agent_labels = [[f'A{i}'] * agent_token_pos.shape[1] for i in range(agent_token_pos.shape[0])]
|
1084 |
+
plot_scenario(scenario_id,
|
1085 |
+
data['map_point']['position'].cpu().numpy(),
|
1086 |
+
agent_token_pos.cpu().numpy(),
|
1087 |
+
agent_token_heading.cpu().numpy(),
|
1088 |
+
state_idx.cpu().numpy(),
|
1089 |
+
types=list(map(lambda i: self.encoder.agent_encoder.agent_type[i],
|
1090 |
+
data['agent']['type'].tolist())),
|
1091 |
+
av_index=av_index,
|
1092 |
+
pl2seed_radius=self.pl2seed_radius,
|
1093 |
+
attr_tokenizer=self.attr_tokenizer,
|
1094 |
+
enter_index=enter_index,
|
1095 |
+
save_gif=False,
|
1096 |
+
save_path=os.path.join(self.save_path, 'vis'),
|
1097 |
+
agent_labels=agent_labels,
|
1098 |
+
tokenized=True)
|
1099 |
+
|
1100 |
+
return data
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/modules/agent_decoder.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/modules/attr_tokenizer.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import numpy as np
|
4 |
+
from dev.utils.func import wrap_angle, angle_between_2d_vectors
|
5 |
+
|
6 |
+
|
7 |
+
class Attr_Tokenizer(nn.Module):
|
8 |
+
|
9 |
+
def __init__(self, grid_range, grid_interval, radius, angle_interval):
|
10 |
+
super().__init__()
|
11 |
+
self.grid_range = grid_range
|
12 |
+
self.grid_interval = grid_interval
|
13 |
+
self.radius = radius
|
14 |
+
self.angle_interval = angle_interval
|
15 |
+
self.heading = torch.pi / 2
|
16 |
+
self._prepare_grid()
|
17 |
+
|
18 |
+
self.grid_size = self.grid.shape[0]
|
19 |
+
self.angle_size = int(360. / self.angle_interval)
|
20 |
+
|
21 |
+
assert torch.all(self.grid[self.grid_size // 2] == 0.)
|
22 |
+
|
23 |
+
def _prepare_grid(self):
|
24 |
+
num_grid = int(self.grid_range / self.grid_interval) + 1 # Do not use '//'
|
25 |
+
|
26 |
+
x = torch.linspace(0, num_grid - 1, steps=num_grid)
|
27 |
+
y = torch.linspace(0, num_grid - 1, steps=num_grid)
|
28 |
+
grid_x, grid_y = torch.meshgrid(x, y, indexing='xy')
|
29 |
+
grid = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=-1) # (n^2, 2)
|
30 |
+
grid = grid.reshape(num_grid, num_grid, 2).flip(dims=[0]).reshape(-1, 2)
|
31 |
+
grid = (grid - x.shape[0] // 2) * self.grid_interval
|
32 |
+
|
33 |
+
distance = (grid ** 2).sum(-1).sqrt()
|
34 |
+
square_mask = ((distance <= self.radius) & (distance >= 0.)) | (distance == 0.)
|
35 |
+
self.register_buffer('grid', grid[square_mask])
|
36 |
+
self.register_buffer('dist', torch.norm(self.grid, p=2, dim=-1))
|
37 |
+
head_vector = torch.stack([torch.tensor(self.heading).cos(), torch.tensor(self.heading).sin()])
|
38 |
+
self.register_buffer('dir', angle_between_2d_vectors(ctr_vector=head_vector.unsqueeze(0),
|
39 |
+
nbr_vector=self.grid)) # (-pi, pi]
|
40 |
+
|
41 |
+
self.num_grid = num_grid
|
42 |
+
self.square_mask = square_mask.numpy()
|
43 |
+
|
44 |
+
def _apply_rot(self, x, theta):
|
45 |
+
# x: (b, l, 2) e.g. (num_step, num_agent, 2)
|
46 |
+
# theta: (b,) e.g. (num_step,)
|
47 |
+
cos, sin = theta.cos(), theta.sin()
|
48 |
+
rot_mat = torch.zeros((theta.shape[0], 2, 2), device=theta.device)
|
49 |
+
rot_mat[:, 0, 0] = cos
|
50 |
+
rot_mat[:, 0, 1] = sin
|
51 |
+
rot_mat[:, 1, 0] = -sin
|
52 |
+
rot_mat[:, 1, 1] = cos
|
53 |
+
x = torch.bmm(x, rot_mat)
|
54 |
+
return x
|
55 |
+
|
56 |
+
def pad_square(self, prob, indices=None):
|
57 |
+
# square_mask: bool array of shape (n^2,)
|
58 |
+
# prob: float array of shape (num_step, m)
|
59 |
+
pad_prob = np.zeros((*prob.shape[:-1], self.square_mask.shape[0]))
|
60 |
+
pad_prob[..., self.square_mask] = prob
|
61 |
+
|
62 |
+
square_indices = np.arange(self.square_mask.shape[0])
|
63 |
+
circle_indices = np.concatenate([square_indices[self.square_mask], [-1]])
|
64 |
+
if indices is not None:
|
65 |
+
indices = circle_indices[indices]
|
66 |
+
|
67 |
+
return pad_prob, indices
|
68 |
+
|
69 |
+
def get_grid(self, x, theta=None):
|
70 |
+
x = x.reshape(-1, 2)
|
71 |
+
grid = self.grid[None, ...].to(x.device)
|
72 |
+
if theta is not None:
|
73 |
+
grid = self._apply_rot(grid, (theta - self.heading).expand(x.shape[0]))
|
74 |
+
return x[:, None] + grid
|
75 |
+
|
76 |
+
def encode_pos(self, x, y, theta_y=None):
|
77 |
+
assert x.dim() == y.dim() and x.shape[-1] == 2 and y.shape[-1] == 2, \
|
78 |
+
f"Invalid input shape x: {x.shape}, y: {y.shape}."
|
79 |
+
centered_x = x - y
|
80 |
+
if theta_y is not None:
|
81 |
+
centered_x = self._apply_rot(centered_x[:, None], -(theta_y - self.heading).expand(x.shape[0]))[:, 0]
|
82 |
+
distance = ((centered_x[:, None] - self.grid.to(x.device)[None, ...]) ** 2).sum(-1).sqrt()
|
83 |
+
index = torch.argmin(distance, dim=-1)
|
84 |
+
|
85 |
+
grid_xy = self.grid[index]
|
86 |
+
offset_xy = centered_x - grid_xy
|
87 |
+
|
88 |
+
return index.long(), offset_xy
|
89 |
+
|
90 |
+
def decode_pos(self, index, y=None, theta_y=None):
|
91 |
+
assert torch.all((index >= 0) & (index < self.grid_size))
|
92 |
+
centered_x = self.grid.to(index.device)[index.long()]
|
93 |
+
if y is not None:
|
94 |
+
if theta_y is not None:
|
95 |
+
centered_x = self._apply_rot(centered_x[:, None], (theta_y - self.heading).expand(centered_x.shape[0]))[:, 0]
|
96 |
+
x = centered_x + y
|
97 |
+
return x.float()
|
98 |
+
return centered_x.float()
|
99 |
+
|
100 |
+
def encode_heading(self, heading):
|
101 |
+
heading = (wrap_angle(heading) + torch.pi) / (2 * torch.pi) * 360
|
102 |
+
index = heading // self.angle_interval
|
103 |
+
return index.long()
|
104 |
+
|
105 |
+
def decode_heading(self, index):
|
106 |
+
assert torch.all(index >= 0) and torch.all(index < (360 / self.angle_interval))
|
107 |
+
angles = index * self.angle_interval - 180
|
108 |
+
angles = angles / 360 * (2 * torch.pi)
|
109 |
+
return angles.float()
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/modules/debug.py
ADDED
@@ -0,0 +1,1439 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Mapping, Optional
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from torch_cluster import radius, radius_graph
|
6 |
+
from torch_geometric.data import HeteroData, Batch
|
7 |
+
from torch_geometric.utils import dense_to_sparse, subgraph
|
8 |
+
|
9 |
+
from dev.modules.layers import *
|
10 |
+
from dev.modules.map_decoder import discretize_neighboring
|
11 |
+
from dev.utils.geometry import angle_between_2d_vectors, wrap_angle
|
12 |
+
from dev.utils.weight_init import weight_init
|
13 |
+
|
14 |
+
|
15 |
+
def cal_polygon_contour(x, y, theta, width, length):
|
16 |
+
left_front_x = x + 0.5 * length * math.cos(theta) - 0.5 * width * math.sin(theta)
|
17 |
+
left_front_y = y + 0.5 * length * math.sin(theta) + 0.5 * width * math.cos(theta)
|
18 |
+
left_front = (left_front_x, left_front_y)
|
19 |
+
|
20 |
+
right_front_x = x + 0.5 * length * math.cos(theta) + 0.5 * width * math.sin(theta)
|
21 |
+
right_front_y = y + 0.5 * length * math.sin(theta) - 0.5 * width * math.cos(theta)
|
22 |
+
right_front = (right_front_x, right_front_y)
|
23 |
+
|
24 |
+
right_back_x = x - 0.5 * length * math.cos(theta) + 0.5 * width * math.sin(theta)
|
25 |
+
right_back_y = y - 0.5 * length * math.sin(theta) - 0.5 * width * math.cos(theta)
|
26 |
+
right_back = (right_back_x, right_back_y)
|
27 |
+
|
28 |
+
left_back_x = x - 0.5 * length * math.cos(theta) - 0.5 * width * math.sin(theta)
|
29 |
+
left_back_y = y - 0.5 * length * math.sin(theta) + 0.5 * width * math.cos(theta)
|
30 |
+
left_back = (left_back_x, left_back_y)
|
31 |
+
polygon_contour = [left_front, right_front, right_back, left_back]
|
32 |
+
|
33 |
+
return polygon_contour
|
34 |
+
|
35 |
+
|
36 |
+
class SMARTAgentDecoder(nn.Module):
|
37 |
+
|
38 |
+
def __init__(self,
|
39 |
+
dataset: str,
|
40 |
+
input_dim: int,
|
41 |
+
hidden_dim: int,
|
42 |
+
num_historical_steps: int,
|
43 |
+
num_interaction_steps: int,
|
44 |
+
time_span: Optional[int],
|
45 |
+
pl2a_radius: float,
|
46 |
+
pl2seed_radius: float,
|
47 |
+
a2a_radius: float,
|
48 |
+
num_freq_bands: int,
|
49 |
+
num_layers: int,
|
50 |
+
num_heads: int,
|
51 |
+
head_dim: int,
|
52 |
+
dropout: float,
|
53 |
+
token_data: Dict,
|
54 |
+
token_size: int,
|
55 |
+
special_token_index: list=[],
|
56 |
+
predict_motion: bool=False,
|
57 |
+
predict_state: bool=False,
|
58 |
+
predict_map: bool=False,
|
59 |
+
state_token: Dict[str, int]=None,
|
60 |
+
seed_size: int=5) -> None:
|
61 |
+
|
62 |
+
super(SMARTAgentDecoder, self).__init__()
|
63 |
+
self.dataset = dataset
|
64 |
+
self.input_dim = input_dim
|
65 |
+
self.hidden_dim = hidden_dim
|
66 |
+
self.num_historical_steps = num_historical_steps
|
67 |
+
self.num_interaction_steps = num_interaction_steps
|
68 |
+
self.time_span = time_span if time_span is not None else num_historical_steps
|
69 |
+
self.pl2a_radius = pl2a_radius
|
70 |
+
self.pl2seed_radius = pl2seed_radius
|
71 |
+
self.a2a_radius = a2a_radius
|
72 |
+
self.num_freq_bands = num_freq_bands
|
73 |
+
self.num_layers = num_layers
|
74 |
+
self.num_heads = num_heads
|
75 |
+
self.head_dim = head_dim
|
76 |
+
self.dropout = dropout
|
77 |
+
self.special_token_index = special_token_index
|
78 |
+
self.predict_motion = predict_motion
|
79 |
+
self.predict_state = predict_state
|
80 |
+
self.predict_map = predict_map
|
81 |
+
|
82 |
+
# state tokens
|
83 |
+
self.state_type = list(state_token.keys())
|
84 |
+
self.state_token = state_token
|
85 |
+
self.invalid_state = int(state_token['invalid'])
|
86 |
+
self.valid_state = int(state_token['valid'])
|
87 |
+
self.enter_state = int(state_token['enter'])
|
88 |
+
self.exit_state = int(state_token['exit'])
|
89 |
+
|
90 |
+
self.seed_state_type = ['invalid', 'enter']
|
91 |
+
self.valid_state_type = ['invalid', 'valid', 'exit']
|
92 |
+
|
93 |
+
input_dim_x_a = 2
|
94 |
+
input_dim_r_t = 4
|
95 |
+
input_dim_r_pt2a = 3
|
96 |
+
input_dim_r_a2a = 3
|
97 |
+
input_dim_token = 8 # tokens: (token_size, 4, 2)
|
98 |
+
|
99 |
+
self.seed_size = seed_size
|
100 |
+
|
101 |
+
self.all_agent_type = ['veh', 'ped', 'cyc', 'background', 'invalid', 'seed']
|
102 |
+
self.seed_agent_type = ['veh', 'ped', 'cyc', 'seed']
|
103 |
+
self.type_a_emb = nn.Embedding(len(self.all_agent_type), hidden_dim)
|
104 |
+
self.shape_emb = MLPLayer(3, hidden_dim, hidden_dim)
|
105 |
+
if self.predict_state:
|
106 |
+
self.state_a_emb = nn.Embedding(len(self.state_type), hidden_dim)
|
107 |
+
self.invalid_shape_value = .1
|
108 |
+
self.motion_gap = 1.
|
109 |
+
self.heading_gap = 1.
|
110 |
+
|
111 |
+
self.x_a_emb = FourierEmbedding(input_dim=input_dim_x_a, hidden_dim=hidden_dim, num_freq_bands=num_freq_bands)
|
112 |
+
self.r_t_emb = FourierEmbedding(input_dim=input_dim_r_t, hidden_dim=hidden_dim, num_freq_bands=num_freq_bands)
|
113 |
+
self.r_pt2a_emb = FourierEmbedding(input_dim=input_dim_r_pt2a, hidden_dim=hidden_dim,
|
114 |
+
num_freq_bands=num_freq_bands)
|
115 |
+
self.r_a2a_emb = FourierEmbedding(input_dim=input_dim_r_a2a, hidden_dim=hidden_dim,
|
116 |
+
num_freq_bands=num_freq_bands)
|
117 |
+
self.token_emb_veh = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)
|
118 |
+
self.token_emb_ped = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)
|
119 |
+
self.token_emb_cyc = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)
|
120 |
+
self.no_token_emb = nn.Embedding(1, hidden_dim)
|
121 |
+
self.bos_token_emb = nn.Embedding(1, hidden_dim)
|
122 |
+
# FIXME: do we need this???
|
123 |
+
self.token_emb_offset = MLPEmbedding(input_dim=2, hidden_dim=hidden_dim)
|
124 |
+
|
125 |
+
num_inputs = 2
|
126 |
+
if self.predict_state:
|
127 |
+
num_inputs = 3
|
128 |
+
self.fusion_emb = MLPEmbedding(input_dim=self.hidden_dim * num_inputs, hidden_dim=self.hidden_dim)
|
129 |
+
|
130 |
+
self.t_attn_layers = nn.ModuleList(
|
131 |
+
[AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
|
132 |
+
bipartite=False, has_pos_emb=True) for _ in range(num_layers)]
|
133 |
+
)
|
134 |
+
self.pt2a_attn_layers = nn.ModuleList(
|
135 |
+
[AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
|
136 |
+
bipartite=True, has_pos_emb=True) for _ in range(num_layers)]
|
137 |
+
)
|
138 |
+
self.a2a_attn_layers = nn.ModuleList(
|
139 |
+
[AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
|
140 |
+
bipartite=False, has_pos_emb=True) for _ in range(num_layers)]
|
141 |
+
)
|
142 |
+
self.token_size = token_size # 2048
|
143 |
+
# agent motion prediction head
|
144 |
+
self.token_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
|
145 |
+
output_dim=self.token_size)
|
146 |
+
# agent state prediction head
|
147 |
+
if self.predict_state:
|
148 |
+
|
149 |
+
self.seed_feature = nn.Embedding(self.seed_size, self.hidden_dim)
|
150 |
+
|
151 |
+
self.state_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
|
152 |
+
output_dim=len(self.valid_state_type))
|
153 |
+
|
154 |
+
self.seed_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
|
155 |
+
output_dim=hidden_dim)
|
156 |
+
|
157 |
+
self.seed_state_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
|
158 |
+
output_dim=len(self.seed_state_type))
|
159 |
+
self.seed_type_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
|
160 |
+
output_dim=len(self.seed_agent_type))
|
161 |
+
# entering token prediction
|
162 |
+
# FIXME: this is just under test!!!
|
163 |
+
# self.bos_pl_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
|
164 |
+
# output_dim=200)
|
165 |
+
# self.bos_offset_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
|
166 |
+
# output_dim=2601)
|
167 |
+
self.trajectory_token = token_data['token'] # dict('veh', 'ped', 'cyc') (2048, 4, 2)
|
168 |
+
self.trajectory_token_traj = token_data['traj'] # (2048, 6, 3)
|
169 |
+
self.trajectory_token_all = token_data['token_all'] # (2048, 6, 4, 2)
|
170 |
+
self.apply(weight_init)
|
171 |
+
|
172 |
+
self.shift = 5
|
173 |
+
self.beam_size = 5
|
174 |
+
self.hist_mask = True
|
175 |
+
self.temporal_attn_to_invalid = True
|
176 |
+
self.temporal_attn_seed = False
|
177 |
+
|
178 |
+
# FIXME: This is just under test!!!
|
179 |
+
# self.mapping_network = MappingNetwork(z_dim=hidden_dim, w_dim=hidden_dim, num_layers=num_layers)
|
180 |
+
|
181 |
+
def transform_rel(self, token_traj, prev_pos, prev_heading=None):
|
182 |
+
if prev_heading is None:
|
183 |
+
diff_xy = prev_pos[:, :, -1, :] - prev_pos[:, :, -2, :]
|
184 |
+
prev_heading = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0])
|
185 |
+
|
186 |
+
num_agent, num_step, traj_num, traj_dim = token_traj.shape
|
187 |
+
cos, sin = prev_heading.cos(), prev_heading.sin()
|
188 |
+
rot_mat = torch.zeros((num_agent, num_step, 2, 2), device=prev_heading.device)
|
189 |
+
rot_mat[:, :, 0, 0] = cos
|
190 |
+
rot_mat[:, :, 0, 1] = -sin
|
191 |
+
rot_mat[:, :, 1, 0] = sin
|
192 |
+
rot_mat[:, :, 1, 1] = cos
|
193 |
+
agent_diff_rel = torch.bmm(token_traj.view(-1, traj_num, 2), rot_mat.view(-1, 2, 2)).view(num_agent, num_step, traj_num, traj_dim)
|
194 |
+
agent_pred_rel = agent_diff_rel + prev_pos[:, :, -1:, :]
|
195 |
+
return agent_pred_rel
|
196 |
+
|
197 |
+
def agent_token_embedding(self, data, agent_token_index, agent_state, pos_a, head_a, inference=False,
|
198 |
+
filter_mask=None, av_index=None):
|
199 |
+
|
200 |
+
if filter_mask is None:
|
201 |
+
filter_mask = torch.ones_like(agent_state[:, 2], dtype=torch.bool)
|
202 |
+
|
203 |
+
num_agent, num_step, traj_dim = pos_a.shape # traj_dim=2
|
204 |
+
agent_type = data['agent']['type'][filter_mask]
|
205 |
+
veh_mask = (agent_type == 0)
|
206 |
+
ped_mask = (agent_type == 1)
|
207 |
+
cyc_mask = (agent_type == 2)
|
208 |
+
|
209 |
+
# set the position of invalid agents to the position of ego agent
|
210 |
+
# note here we only set invalid steps BEFORE the bos token!
|
211 |
+
# is_invalid = agent_state == self.invalid_state
|
212 |
+
# is_bos = agent_state == self.enter_state
|
213 |
+
# bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0))
|
214 |
+
# bos_mask = torch.arange(num_step).expand(num_agent, -1).to(agent_state.device) < bos_index[:, None]
|
215 |
+
# is_invalid[~bos_mask] = False
|
216 |
+
|
217 |
+
# ego_pos_a = pos_a[av_index].clone()
|
218 |
+
# ego_head_vector_a = head_vector_a[av_index].clone()
|
219 |
+
# pos_a[is_invalid] = ego_pos_a[None, :].repeat(pos_a.shape[0], 1, 1)[is_invalid]
|
220 |
+
# head_vector_a[is_invalid] = ego_head_vector_a[None, :].repeat(head_vector_a.shape[0], 1, 1)[is_invalid]
|
221 |
+
|
222 |
+
motion_vector_a, head_vector_a = self.build_vector_a(pos_a, head_a, agent_state)
|
223 |
+
|
224 |
+
trajectory_token_veh = torch.from_numpy(self.trajectory_token['veh']).clone().to(pos_a.device).to(torch.float)
|
225 |
+
trajectory_token_ped = torch.from_numpy(self.trajectory_token['ped']).clone().to(pos_a.device).to(torch.float)
|
226 |
+
trajectory_token_cyc = torch.from_numpy(self.trajectory_token['cyc']).clone().to(pos_a.device).to(torch.float)
|
227 |
+
self.agent_token_emb_veh = self.token_emb_veh(trajectory_token_veh.view(trajectory_token_veh.shape[0], -1)) # (token_size, 8)
|
228 |
+
self.agent_token_emb_ped = self.token_emb_ped(trajectory_token_ped.view(trajectory_token_ped.shape[0], -1))
|
229 |
+
self.agent_token_emb_cyc = self.token_emb_cyc(trajectory_token_cyc.view(trajectory_token_cyc.shape[0], -1))
|
230 |
+
|
231 |
+
# add bos token embedding
|
232 |
+
self.agent_token_emb_veh = torch.cat([self.agent_token_emb_veh, self.bos_token_emb(torch.zeros(1, device=pos_a.device).long())])
|
233 |
+
self.agent_token_emb_ped = torch.cat([self.agent_token_emb_ped, self.bos_token_emb(torch.zeros(1, device=pos_a.device).long())])
|
234 |
+
self.agent_token_emb_cyc = torch.cat([self.agent_token_emb_cyc, self.bos_token_emb(torch.zeros(1, device=pos_a.device).long())])
|
235 |
+
|
236 |
+
# add invalid token embedding
|
237 |
+
self.agent_token_emb_veh = torch.cat([self.agent_token_emb_veh, self.no_token_emb(torch.zeros(1, device=pos_a.device).long())])
|
238 |
+
self.agent_token_emb_ped = torch.cat([self.agent_token_emb_ped, self.no_token_emb(torch.zeros(1, device=pos_a.device).long())])
|
239 |
+
self.agent_token_emb_cyc = torch.cat([self.agent_token_emb_cyc, self.no_token_emb(torch.zeros(1, device=pos_a.device).long())])
|
240 |
+
|
241 |
+
if inference:
|
242 |
+
agent_token_traj_all = torch.zeros((num_agent, self.token_size, self.shift + 1, 4, 2), device=pos_a.device)
|
243 |
+
trajectory_token_all_veh = torch.from_numpy(self.trajectory_token_all['veh']).clone().to(pos_a.device).to(torch.float)
|
244 |
+
trajectory_token_all_ped = torch.from_numpy(self.trajectory_token_all['ped']).clone().to(pos_a.device).to(torch.float)
|
245 |
+
trajectory_token_all_cyc = torch.from_numpy(self.trajectory_token_all['cyc']).clone().to(pos_a.device).to(torch.float)
|
246 |
+
agent_token_traj_all[veh_mask] = torch.cat(
|
247 |
+
[trajectory_token_all_veh[:, :self.shift], trajectory_token_veh[:, None, ...]], dim=1)
|
248 |
+
agent_token_traj_all[ped_mask] = torch.cat(
|
249 |
+
[trajectory_token_all_ped[:, :self.shift], trajectory_token_ped[:, None, ...]], dim=1)
|
250 |
+
agent_token_traj_all[cyc_mask] = torch.cat(
|
251 |
+
[trajectory_token_all_cyc[:, :self.shift], trajectory_token_cyc[:, None, ...]], dim=1)
|
252 |
+
|
253 |
+
# additional token embeddings are already added -> -1: invalid, -2: bos
|
254 |
+
agent_token_emb = torch.zeros((num_agent, num_step, self.hidden_dim), device=pos_a.device)
|
255 |
+
agent_token_emb[veh_mask] = self.agent_token_emb_veh[agent_token_index[veh_mask]]
|
256 |
+
agent_token_emb[ped_mask] = self.agent_token_emb_ped[agent_token_index[ped_mask]]
|
257 |
+
agent_token_emb[cyc_mask] = self.agent_token_emb_cyc[agent_token_index[cyc_mask]]
|
258 |
+
|
259 |
+
# 'vehicle', 'pedestrian', 'cyclist', 'background'
|
260 |
+
is_invalid = (agent_state == self.invalid_state) & (agent_state != self.enter_state)
|
261 |
+
agent_types = data['agent']['type'][filter_mask].long().repeat_interleave(repeats=num_step, dim=0)
|
262 |
+
agent_types[is_invalid.reshape(-1)] = self.all_agent_type.index('invalid')
|
263 |
+
agent_shapes = data['agent']['shape'][filter_mask, self.num_historical_steps - 1, :].repeat_interleave(repeats=num_step, dim=0)
|
264 |
+
agent_shapes[is_invalid.reshape(-1)] = self.invalid_shape_value
|
265 |
+
|
266 |
+
categorical_embs = [self.type_a_emb(agent_types), self.shape_emb(agent_shapes)]
|
267 |
+
feature_a = torch.stack(
|
268 |
+
[torch.norm(motion_vector_a[:, :, :2], p=2, dim=-1),
|
269 |
+
angle_between_2d_vectors(ctr_vector=head_vector_a, nbr_vector=motion_vector_a[:, :, :2]),
|
270 |
+
], dim=-1) # (num_agent, num_shifted_step, 2)
|
271 |
+
|
272 |
+
x_a = self.x_a_emb(continuous_inputs=feature_a.view(-1, feature_a.size(-1)),
|
273 |
+
categorical_embs=categorical_embs)
|
274 |
+
x_a = x_a.view(-1, num_step, self.hidden_dim) # (num_agent, num_step, hidden_dim)
|
275 |
+
|
276 |
+
s_a = self.state_a_emb(agent_state.reshape(-1).long()).reshape(num_agent, num_step, self.hidden_dim)
|
277 |
+
feat_a = torch.cat((agent_token_emb, x_a, s_a), dim=-1) # (num_agent, num_step, hidden_dim * 3)
|
278 |
+
feat_a = self.fusion_emb(feat_a) # (num_agent, num_step, hidden_dim)
|
279 |
+
|
280 |
+
# seed agent feature
|
281 |
+
motion_vector_seed = motion_vector_a[av_index : av_index + 1]
|
282 |
+
head_vector_seed = head_vector_a[av_index : av_index + 1]
|
283 |
+
feat_seed = self.build_invalid_agent_feature(num_step, pos_a.device, type_index=self.all_agent_type.index('seed'),
|
284 |
+
motion_vector=motion_vector_seed, head_vector=head_vector_seed)
|
285 |
+
|
286 |
+
# replace the features of steps before bos of valid agents with the corresponding invalid agent features
|
287 |
+
# is_bos = agent_state == self.enter_state
|
288 |
+
# is_eos = agent_state == self.exit_state
|
289 |
+
# bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0))
|
290 |
+
# eos_index = torch.where(is_eos.any(dim=1), torch.argmax(is_eos.long(), dim=1), torch.tensor(num_step - 1))
|
291 |
+
# is_before_bos = torch.arange(num_step).expand(num_agent, -1).to(agent_state.device) < bos_index[:, None]
|
292 |
+
# is_after_eos = torch.arange(num_step).expand(num_agent, -1).to(agent_state.device) > eos_index[:, None] + 1
|
293 |
+
# feat_ina = self.build_invalid_agent_feature(num_step, pos_a.device)
|
294 |
+
# feat_a[is_before_bos | is_after_eos] = feat_ina.repeat(num_agent, 1, 1)[is_before_bos | is_after_eos]
|
295 |
+
|
296 |
+
# print("train")
|
297 |
+
# is_bos = agent_state == self.enter_state
|
298 |
+
# is_eos = agent_state == self.exit_state
|
299 |
+
# bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0))
|
300 |
+
# eos_index = torch.where(is_eos.any(dim=1), torch.argmax(is_eos.long(), dim=1), torch.tensor(num_step - 1))
|
301 |
+
# mask = torch.arange(num_step).expand(num_agent, -1).to(agent_state.device)
|
302 |
+
# mask = (mask >= bos_index[:, None]) & (mask <= eos_index[:, None] + 1)
|
303 |
+
# is_invalid[mask] = False
|
304 |
+
# print(feat_a.sum(dim=-1)[is_invalid])
|
305 |
+
|
306 |
+
feat_a = torch.cat([feat_a, feat_seed], dim=0) # (num_agent + 1, num_step, hidden_dim)
|
307 |
+
|
308 |
+
# feat_a_sum = feat_a.sum(dim=-1)
|
309 |
+
# for a in range(num_agent):
|
310 |
+
# print(f"agent {a}:")
|
311 |
+
# print(f"state: {agent_state[a, :]}")
|
312 |
+
# print(f"feat_a_sum: {feat_a_sum[a, :]}")
|
313 |
+
# exit(1)
|
314 |
+
|
315 |
+
if inference:
|
316 |
+
return feat_a, head_vector_a, agent_token_traj_all, agent_token_emb, categorical_embs
|
317 |
+
else:
|
318 |
+
return feat_a, head_vector_a
|
319 |
+
|
320 |
+
def build_vector_a(self, pos_a, head_a, state_a):
|
321 |
+
num_agent = pos_a.shape[0]
|
322 |
+
|
323 |
+
motion_vector_a = torch.cat([pos_a.new_zeros(num_agent, 1, self.input_dim),
|
324 |
+
pos_a[:, 1:] - pos_a[:, :-1]], dim=1)
|
325 |
+
|
326 |
+
# update the relative motion/head vectors
|
327 |
+
is_bos = state_a == self.enter_state
|
328 |
+
motion_vector_a[is_bos] = self.motion_gap
|
329 |
+
|
330 |
+
is_last_eos = state_a.roll(shifts=1, dims=1) == self.exit_state
|
331 |
+
is_last_eos[:, 0] = False
|
332 |
+
motion_vector_a[is_last_eos] = -self.motion_gap
|
333 |
+
|
334 |
+
head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1)
|
335 |
+
|
336 |
+
return motion_vector_a, head_vector_a
|
337 |
+
|
338 |
+
def build_invalid_agent_feature(self, num_step, device, motion_vector=None, head_vector=None, type_index=None, shape_value=None):
|
339 |
+
invalid_agent_token_emb = self.no_token_emb(torch.zeros(1, device=device).long())[:, None].repeat(1, num_step, 1)
|
340 |
+
|
341 |
+
if motion_vector is None or head_vector is None:
|
342 |
+
motion_vector = torch.zeros((1, num_step, 2), device=device)
|
343 |
+
head_vector = torch.stack([torch.cos(torch.zeros(1, device=device)), torch.sin(torch.zeros(1, device=device))], dim=-1)[:, None, :].repeat(1, num_step, 1)
|
344 |
+
|
345 |
+
feature_ina = torch.stack(
|
346 |
+
[torch.norm(motion_vector[:, :, :2], p=2, dim=-1),
|
347 |
+
angle_between_2d_vectors(ctr_vector=head_vector, nbr_vector=motion_vector[:, :, :2]),
|
348 |
+
], dim=-1)
|
349 |
+
|
350 |
+
if type_index is None:
|
351 |
+
type_index = self.all_agent_type.index('invalid')
|
352 |
+
if shape_value is None:
|
353 |
+
shape_value = torch.full((1, 3), self.invalid_shape_value, device=device)
|
354 |
+
|
355 |
+
categorical_embs_ina = [self.type_a_emb(torch.tensor([type_index], device=device)),
|
356 |
+
self.shape_emb(shape_value)]
|
357 |
+
x_ina = self.x_a_emb(continuous_inputs=feature_ina.view(-1, feature_ina.size(-1)),
|
358 |
+
categorical_embs=categorical_embs_ina)
|
359 |
+
x_ina = x_ina.view(-1, num_step, self.hidden_dim) # (1, num_step, hidden_dim)
|
360 |
+
|
361 |
+
s_ina = self.state_a_emb(torch.tensor([self.invalid_state], device=device))[:, None].repeat(1, num_step, 1) # NOTE: do not use `expand`
|
362 |
+
|
363 |
+
feat_ina = torch.cat((invalid_agent_token_emb, x_ina, s_ina), dim=-1)
|
364 |
+
feat_ina = self.fusion_emb(feat_ina) # (1, num_step, hidden_dim)
|
365 |
+
|
366 |
+
return feat_ina
|
367 |
+
|
368 |
+
def build_temporal_edge(self, pos_a, head_a, head_vector_a, state_a, mask, inference_mask=None, av_index=None):
|
369 |
+
|
370 |
+
num_agent = pos_a.shape[0]
|
371 |
+
hist_mask = mask.clone()
|
372 |
+
|
373 |
+
if not self.temporal_attn_to_invalid:
|
374 |
+
hist_mask[state_a == self.invalid_state] = False
|
375 |
+
|
376 |
+
# set the position of invalid agents to the position of ego agent
|
377 |
+
ego_pos_a = pos_a[av_index].clone() # (num_step, 2)
|
378 |
+
ego_head_a = head_a[av_index].clone()
|
379 |
+
ego_head_vector_a = head_vector_a[av_index].clone()
|
380 |
+
ego_state_a = state_a[av_index].clone()
|
381 |
+
# is_invalid = state_a == self.invalid_state
|
382 |
+
# pos_a[is_invalid] = ego_pos_a[None, :].repeat(pos_a.shape[0], 1, 1)[is_invalid]
|
383 |
+
# head_a[is_invalid] = ego_head_a[None, :].repeat(head_a.shape[0], 1)[is_invalid]
|
384 |
+
|
385 |
+
# add seed agent
|
386 |
+
pos_a = torch.cat([pos_a, ego_pos_a[None]], dim=0)
|
387 |
+
head_a = torch.cat([head_a, ego_head_a[None]], dim=0)
|
388 |
+
state_a = torch.cat([state_a, ego_state_a[None]], dim=0)
|
389 |
+
head_vector_a = torch.cat([head_vector_a, ego_head_vector_a[None]], dim=0)
|
390 |
+
hist_mask = torch.cat([hist_mask, torch.ones_like(hist_mask[0:1])], dim=0).bool()
|
391 |
+
if not self.temporal_attn_seed:
|
392 |
+
hist_mask[-1:] = False
|
393 |
+
if inference_mask is not None:
|
394 |
+
inference_mask[-1:] = False
|
395 |
+
|
396 |
+
pos_t = pos_a.reshape(-1, self.input_dim) # (num_agent * num_step, ...)
|
397 |
+
head_t = head_a.reshape(-1)
|
398 |
+
head_vector_t = head_vector_a.reshape(-1, 2)
|
399 |
+
|
400 |
+
# for those invalid agents won't predict any motion token, we don't attend to them
|
401 |
+
is_bos = state_a == self.enter_state
|
402 |
+
is_bos[-1] = False
|
403 |
+
bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0))
|
404 |
+
motion_predict_start_index = torch.clamp(bos_index - self.time_span / self.shift + 1, min=0)
|
405 |
+
motion_predict_mask = torch.arange(hist_mask.shape[1]).expand(hist_mask.shape[0], -1).to(hist_mask.device)
|
406 |
+
motion_predict_mask = motion_predict_mask >= motion_predict_start_index[:, None]
|
407 |
+
hist_mask[~motion_predict_mask] = False
|
408 |
+
|
409 |
+
if self.hist_mask and self.training:
|
410 |
+
hist_mask[
|
411 |
+
torch.arange(mask.shape[0]).unsqueeze(1), torch.randint(0, mask.shape[1], (num_agent, 10))] = False
|
412 |
+
mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1)
|
413 |
+
elif inference_mask is not None:
|
414 |
+
mask_t = hist_mask.unsqueeze(2) & inference_mask.unsqueeze(1)
|
415 |
+
else:
|
416 |
+
mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1)
|
417 |
+
|
418 |
+
# mask_t: (num_agent, 18, 18), edge_index_t: (2, num_edge)
|
419 |
+
edge_index_t = dense_to_sparse(mask_t)[0]
|
420 |
+
edge_index_t = edge_index_t[:, (edge_index_t[1] - edge_index_t[0] > 0) &
|
421 |
+
(edge_index_t[1] - edge_index_t[0] <= self.time_span / self.shift)]
|
422 |
+
rel_pos_t = pos_t[edge_index_t[0]] - pos_t[edge_index_t[1]]
|
423 |
+
rel_head_t = wrap_angle(head_t[edge_index_t[0]] - head_t[edge_index_t[1]])
|
424 |
+
|
425 |
+
# FIXME relative motion/head for bos/eos token
|
426 |
+
# is_next_bos = state_a.roll(shifts=-1, dims=1) == self.enter_state
|
427 |
+
# is_next_bos[:, -1] = False # the last step
|
428 |
+
# is_next_bos_t = is_next_bos.reshape(-1)
|
429 |
+
# rel_pos_t[is_next_bos_t[edge_index_t[0]]] = -self.bos_motion
|
430 |
+
# rel_pos_t[is_next_bos_t[edge_index_t[1]]] = self.bos_motion
|
431 |
+
# rel_head_t[is_next_bos_t[edge_index_t[0]]] = -torch.pi
|
432 |
+
# rel_head_t[is_next_bos_t[edge_index_t[1]]] = torch.pi
|
433 |
+
|
434 |
+
# is_last_eos = state_a.roll(shifts=1, dims=1) == self.exit_state
|
435 |
+
# is_last_eos[:, 0] = False # the first step
|
436 |
+
# is_last_eos_t = is_last_eos.reshape(-1)
|
437 |
+
# rel_pos_t[is_last_eos_t[edge_index_t[0]]] = -self.bos_motion
|
438 |
+
# rel_pos_t[is_last_eos_t[edge_index_t[1]]] = self.bos_motion
|
439 |
+
# rel_head_t[is_last_eos_t[edge_index_t[0]]] = -torch.pi
|
440 |
+
# rel_head_t[is_last_eos_t[edge_index_t[1]]] = torch.pi
|
441 |
+
|
442 |
+
# handle the bos token of ego agent
|
443 |
+
# is_invalid = state_a == self.invalid_state
|
444 |
+
# is_invalid_t = is_invalid.reshape(-1)
|
445 |
+
# is_ego_bos = (ego_state_a == self.enter_state)[None, :].expand(num_agent + 1, -1)
|
446 |
+
# is_ego_bos_t = is_ego_bos.reshape(-1)
|
447 |
+
# rel_pos_t[is_invalid_t[edge_index_t[0]] & is_ego_bos_t[edge_index_t[0]]] = 0.
|
448 |
+
# rel_pos_t[is_invalid_t[edge_index_t[1]] & is_ego_bos_t[edge_index_t[1]]] = 0.
|
449 |
+
# rel_head_t[is_invalid_t[edge_index_t[0]] & is_ego_bos_t[edge_index_t[0]]] = 0.
|
450 |
+
# rel_head_t[is_invalid_t[edge_index_t[1]] & is_ego_bos_t[edge_index_t[1]]] = 0.
|
451 |
+
|
452 |
+
# handle the invalid steps
|
453 |
+
is_invalid = state_a == self.invalid_state
|
454 |
+
is_invalid_t = is_invalid.reshape(-1)
|
455 |
+
rel_pos_t[is_invalid_t[edge_index_t[0]]] = -self.motion_gap
|
456 |
+
rel_pos_t[is_invalid_t[edge_index_t[1]]] = self.motion_gap
|
457 |
+
rel_head_t[is_invalid_t[edge_index_t[0]]] = -self.heading_gap
|
458 |
+
rel_head_t[is_invalid_t[edge_index_t[1]]] = self.heading_gap
|
459 |
+
|
460 |
+
r_t = torch.stack(
|
461 |
+
[torch.norm(rel_pos_t[:, :2], p=2, dim=-1),
|
462 |
+
angle_between_2d_vectors(ctr_vector=head_vector_t[edge_index_t[1]], nbr_vector=rel_pos_t[:, :2]),
|
463 |
+
rel_head_t,
|
464 |
+
edge_index_t[0] - edge_index_t[1]], dim=-1)
|
465 |
+
r_t = self.r_t_emb(continuous_inputs=r_t, categorical_embs=None)
|
466 |
+
|
467 |
+
return edge_index_t, r_t
|
468 |
+
|
469 |
+
def build_interaction_edge(self, pos_a, head_a, head_vector_a, state_a, batch_s, mask_a, inference_mask=None, av_index=None):
|
470 |
+
num_agent, num_step, _ = pos_a.shape
|
471 |
+
|
472 |
+
pos_a = torch.cat([pos_a, pos_a[av_index][None]], dim=0)
|
473 |
+
head_a = torch.cat([head_a, head_a[av_index][None]], dim=0)
|
474 |
+
state_a = torch.cat([state_a, state_a[av_index][None]], dim=0)
|
475 |
+
head_vector_a = torch.cat([head_vector_a, head_vector_a[av_index][None]], dim=0)
|
476 |
+
|
477 |
+
pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
|
478 |
+
head_s = head_a.transpose(0, 1).reshape(-1)
|
479 |
+
head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2)
|
480 |
+
if inference_mask is not None:
|
481 |
+
mask_a = mask_a & inference_mask
|
482 |
+
mask_s = mask_a.transpose(0, 1).reshape(-1)
|
483 |
+
|
484 |
+
# seed agent
|
485 |
+
mask_seed = state_a[av_index] != self.invalid_state
|
486 |
+
pos_seed = pos_a[av_index]
|
487 |
+
edge_index_seed2a = radius(x=pos_seed[:, :2], y=pos_s[:, :2], r=self.pl2seed_radius,
|
488 |
+
batch_x=torch.arange(num_step).to(pos_s.device), batch_y=batch_s, max_num_neighbors=300)
|
489 |
+
edge_index_seed2a = edge_index_seed2a[:, mask_s[edge_index_seed2a[0]] & mask_seed[edge_index_seed2a[1]]]
|
490 |
+
|
491 |
+
# convert to global index (must be unilateral connection)
|
492 |
+
edge_index_seed2a[1, :] = (edge_index_seed2a[1, :] + 1) * (num_agent + 1) - 1
|
493 |
+
|
494 |
+
# build agent2agent bilateral connection
|
495 |
+
edge_index_a2a = radius_graph(x=pos_s[:, :2], r=self.a2a_radius, batch=batch_s, loop=False,
|
496 |
+
max_num_neighbors=300)
|
497 |
+
edge_index_a2a = subgraph(subset=mask_s, edge_index=edge_index_a2a)[0]
|
498 |
+
|
499 |
+
# add the edges which connect seed agents
|
500 |
+
edge_index_a2a = torch.cat([edge_index_a2a, edge_index_seed2a], dim=-1)
|
501 |
+
|
502 |
+
# set the position of invalid agents to the position of ego agent
|
503 |
+
# ego_pos_a = pos_a[av_index].clone()
|
504 |
+
# ego_head_a = head_a[av_index].clone()
|
505 |
+
# is_invalid = state_a == self.invalid_state
|
506 |
+
# pos_a[is_invalid] = ego_pos_a[None, :].repeat(pos_a.shape[0], 1, 1)[is_invalid]
|
507 |
+
# head_a[is_invalid] = ego_head_a[None, :].repeat(head_a.shape[0], 1)[is_invalid]
|
508 |
+
|
509 |
+
rel_pos_a2a = pos_s[edge_index_a2a[0]] - pos_s[edge_index_a2a[1]]
|
510 |
+
rel_head_a2a = wrap_angle(head_s[edge_index_a2a[0]] - head_s[edge_index_a2a[1]])
|
511 |
+
|
512 |
+
# relative motion/head for bos/eos token
|
513 |
+
# is_bos = state_a == self.enter_state
|
514 |
+
# is_bos_s = is_bos.transpose(0, 1).reshape(-1)
|
515 |
+
# rel_pos_a2a[is_bos_s[edge_index_a2a[0]]] = -self.bos_motion
|
516 |
+
# rel_pos_a2a[is_bos_s[edge_index_a2a[1]]] = self.bos_motion
|
517 |
+
# rel_head_a2a[is_bos_s[edge_index_a2a[0]]] = -torch.pi
|
518 |
+
# rel_head_a2a[is_bos_s[edge_index_a2a[1]]] = torch.pi
|
519 |
+
|
520 |
+
# is_last_eos = state_a.roll(shifts=-1, dims=1) == self.exit_state
|
521 |
+
# is_last_eos[:, 0] = False # first step
|
522 |
+
# is_last_eos_s = is_last_eos.transpose(0, 1).reshape(-1)
|
523 |
+
# rel_pos_a2a[is_last_eos_s[edge_index_a2a[0]]] = -self.bos_motion
|
524 |
+
# rel_pos_a2a[is_last_eos_s[edge_index_a2a[1]]] = self.bos_motion
|
525 |
+
# rel_head_a2a[is_last_eos_s[edge_index_a2a[0]]] = -torch.pi
|
526 |
+
# rel_head_a2a[is_last_eos_s[edge_index_a2a[1]]] = torch.pi
|
527 |
+
|
528 |
+
# handle the bos token of ego agent
|
529 |
+
# is_invalid = state_a == self.invalid_state
|
530 |
+
# is_invalid_s = is_invalid.transpose(0, 1).reshape(-1)
|
531 |
+
# is_ego_bos = (state_a[av_index] == self.enter_state)[None, :].expand(num_agent + 1, -1)
|
532 |
+
# is_ego_bos_s = is_ego_bos.transpose(0, 1).reshape(-1)
|
533 |
+
# rel_pos_a2a[is_invalid_s[edge_index_a2a[0]] & is_ego_bos_s[edge_index_a2a[0]]] = 0.
|
534 |
+
# rel_pos_a2a[is_invalid_s[edge_index_a2a[1]] & is_ego_bos_s[edge_index_a2a[1]]] = 0.
|
535 |
+
# rel_head_a2a[is_invalid_s[edge_index_a2a[0]] & is_ego_bos_s[edge_index_a2a[0]]] = 0.
|
536 |
+
# rel_head_a2a[is_invalid_s[edge_index_a2a[1]] & is_ego_bos_s[edge_index_a2a[1]]] = 0.
|
537 |
+
|
538 |
+
# handle the invalid steps
|
539 |
+
is_invalid = state_a == self.invalid_state
|
540 |
+
is_invalid_s = is_invalid.transpose(0, 1).reshape(-1)
|
541 |
+
rel_pos_a2a[is_invalid_s[edge_index_a2a[0]]] = -self.motion_gap
|
542 |
+
rel_pos_a2a[is_invalid_s[edge_index_a2a[1]]] = self.motion_gap
|
543 |
+
rel_head_a2a[is_invalid_s[edge_index_a2a[0]]] = -self.heading_gap
|
544 |
+
rel_head_a2a[is_invalid_s[edge_index_a2a[1]]] = self.heading_gap
|
545 |
+
|
546 |
+
r_a2a = torch.stack(
|
547 |
+
[torch.norm(rel_pos_a2a[:, :2], p=2, dim=-1),
|
548 |
+
angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_a2a[1]], nbr_vector=rel_pos_a2a[:, :2]),
|
549 |
+
rel_head_a2a], dim=-1)
|
550 |
+
r_a2a = self.r_a2a_emb(continuous_inputs=r_a2a, categorical_embs=None)
|
551 |
+
|
552 |
+
return edge_index_a2a, r_a2a
|
553 |
+
|
554 |
+
def build_map2agent_edge(self, data, num_step, pos_a, head_a, head_vector_a, state_a, batch_s, batch_pl,
|
555 |
+
mask, inference_mask=None, av_index=None):
|
556 |
+
|
557 |
+
num_agent, num_step, _ = pos_a.shape
|
558 |
+
|
559 |
+
mask_pl2a = mask.clone()
|
560 |
+
if inference_mask is not None:
|
561 |
+
mask_pl2a = mask_pl2a & inference_mask
|
562 |
+
mask_pl2a = mask_pl2a.transpose(0, 1).reshape(-1)
|
563 |
+
|
564 |
+
pos_a = torch.cat([pos_a, pos_a[av_index][None]], dim=0)
|
565 |
+
state_a = torch.cat([state_a, state_a[av_index][None]], dim=0)
|
566 |
+
head_a = torch.cat([head_a, head_a[av_index][None]], dim=0)
|
567 |
+
head_vector_a = torch.cat([head_vector_a, head_vector_a[av_index][None]], dim=0)
|
568 |
+
|
569 |
+
pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
|
570 |
+
head_s = head_a.transpose(0, 1).reshape(-1)
|
571 |
+
head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2)
|
572 |
+
|
573 |
+
ori_pos_pl = data['pt_token']['position'][:, :self.input_dim].contiguous()
|
574 |
+
ori_orient_pl = data['pt_token']['orientation'].contiguous()
|
575 |
+
pos_pl = ori_pos_pl.repeat(num_step, 1) # not `repeat_interleave`
|
576 |
+
orient_pl = ori_orient_pl.repeat(num_step)
|
577 |
+
|
578 |
+
# seed agent
|
579 |
+
mask_seed = state_a[av_index] != self.invalid_state
|
580 |
+
pos_seed = pos_a[av_index]
|
581 |
+
edge_index_pl2seed = radius(x=pos_seed[:, :2], y=pos_pl[:, :2], r=self.pl2seed_radius,
|
582 |
+
batch_x=torch.arange(num_step).to(pos_s.device), batch_y=batch_pl, max_num_neighbors=600)
|
583 |
+
edge_index_pl2seed = edge_index_pl2seed[:, mask_seed[edge_index_pl2seed[1]]]
|
584 |
+
|
585 |
+
# convert to global index
|
586 |
+
edge_index_pl2seed[1, :] = (edge_index_pl2seed[1, :] + 1) * (num_agent + 1) - 1
|
587 |
+
|
588 |
+
# build map2agent directed graph
|
589 |
+
# edge_index_pl2a[0]: pl token; edge_index_pl2a[1]: agent token
|
590 |
+
edge_index_pl2a = radius(x=pos_s[:, :2], y=pos_pl[:, :2], r=self.pl2a_radius,
|
591 |
+
batch_x=batch_s, batch_y=batch_pl, max_num_neighbors=300)
|
592 |
+
# We force invalid agents to interact with **all** (visible in current window) map tokens
|
593 |
+
# invalid_node_index_a = torch.where(bos_state_s.bool())[0]
|
594 |
+
# sampled_node_index_m = torch.arange(ori_pos_pl.shape[0]).to(pos_pl.device)
|
595 |
+
# if kwargs.get('sample_pt_indices', None) is not None:
|
596 |
+
# sampled_node_index_m = sampled_node_index_m[kwargs['sample_pt_indices'].long()]
|
597 |
+
# grid_a, grid_b = torch.meshgrid(sampled_node_index_m, invalid_node_index_a, indexing='ij')
|
598 |
+
# invalid_edge_index_pl2a = torch.stack([grid_a.reshape(-1), grid_b.reshape(-1)], dim=0)
|
599 |
+
# edge_index_pl2a = torch.concat([edge_index_pl2a, invalid_edge_index_pl2a], dim=-1)
|
600 |
+
# remove the edges which connect with motion-invalid agents
|
601 |
+
edge_index_pl2a = edge_index_pl2a[:, mask_pl2a[edge_index_pl2a[1]]]
|
602 |
+
|
603 |
+
# add the edges which connect seed agents with map tokens
|
604 |
+
edge_index_pl2a = torch.cat([edge_index_pl2a, edge_index_pl2seed], dim=-1)
|
605 |
+
|
606 |
+
# set the position of invalid agents to the position of ego agent
|
607 |
+
# ego_pos_a = pos_a[av_index].clone()
|
608 |
+
# ego_head_a = head_a[av_index].clone()
|
609 |
+
# is_invalid = state_a == self.invalid_state
|
610 |
+
# pos_a[is_invalid] = ego_pos_a[None, :].repeat(pos_a.shape[0], 1, 1)[is_invalid]
|
611 |
+
# head_a[is_invalid] = ego_head_a[None, :].repeat(head_a.shape[0], 1)[is_invalid]
|
612 |
+
|
613 |
+
rel_pos_pl2a = pos_pl[edge_index_pl2a[0]] - pos_s[edge_index_pl2a[1]]
|
614 |
+
rel_orient_pl2a = wrap_angle(orient_pl[edge_index_pl2a[0]] - head_s[edge_index_pl2a[1]])
|
615 |
+
|
616 |
+
# handle the invalid steps
|
617 |
+
is_invalid = state_a == self.invalid_state
|
618 |
+
is_invalid_s = is_invalid.transpose(0, 1).reshape(-1)
|
619 |
+
rel_pos_pl2a[is_invalid_s[edge_index_pl2a[1]]] = self.motion_gap
|
620 |
+
rel_orient_pl2a[is_invalid_s[edge_index_pl2a[1]]] = self.heading_gap
|
621 |
+
|
622 |
+
r_pl2a = torch.stack(
|
623 |
+
[torch.norm(rel_pos_pl2a[:, :2], p=2, dim=-1),
|
624 |
+
angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_pl2a[1]], nbr_vector=rel_pos_pl2a[:, :2]),
|
625 |
+
rel_orient_pl2a], dim=-1)
|
626 |
+
r_pl2a = self.r_pt2a_emb(continuous_inputs=r_pl2a, categorical_embs=None)
|
627 |
+
|
628 |
+
return edge_index_pl2a, r_pl2a
|
629 |
+
|
630 |
+
def get_inputs(self, data: HeteroData) -> Dict[str, torch.Tensor]:
|
631 |
+
|
632 |
+
pos_a = data['agent']['token_pos']
|
633 |
+
head_a = data['agent']['token_heading']
|
634 |
+
agent_category = data['agent']['category']
|
635 |
+
agent_token_index = data['agent']['token_idx']
|
636 |
+
agent_state_index = data['agent']['state_idx']
|
637 |
+
mask = data['agent']['raw_agent_valid_mask'].clone()
|
638 |
+
# mask[agent_category != 3] = False
|
639 |
+
|
640 |
+
if not self.predict_state:
|
641 |
+
agent_state_index = None
|
642 |
+
|
643 |
+
next_token_index_gt = agent_token_index.roll(shifts=-1, dims=1)
|
644 |
+
next_state_index_gt = agent_state_index.roll(shifts=-1, dims=1)
|
645 |
+
|
646 |
+
if self.predict_state:
|
647 |
+
next_token_eval_mask = mask.clone()
|
648 |
+
next_token_eval_mask = next_token_eval_mask * next_token_eval_mask.roll(shifts=1, dims=1)
|
649 |
+
bos_token_index = torch.nonzero(agent_state_index == 2)
|
650 |
+
eos_token_index = torch.nonzero(agent_state_index == 3)
|
651 |
+
next_token_eval_mask[bos_token_index[:, 0], bos_token_index[:, 1]] = 1
|
652 |
+
for eos_token_index_ in eos_token_index:
|
653 |
+
if not next_token_eval_mask[eos_token_index_[0], eos_token_index_[1]]:
|
654 |
+
next_token_eval_mask[eos_token_index_[0], eos_token_index_[1]:] = 0
|
655 |
+
next_token_eval_mask = next_token_eval_mask.roll(shifts=-1, dims=1)
|
656 |
+
# TODO: next_state_eval_mask !!!
|
657 |
+
|
658 |
+
if next_token_index_gt[next_token_eval_mask].min() < 0:
|
659 |
+
raise RuntimeError()
|
660 |
+
|
661 |
+
next_token_eval_mask[:, -1] = False
|
662 |
+
|
663 |
+
return {'token_pos': pos_a,
|
664 |
+
'token_heading': head_a,
|
665 |
+
'agent_category': agent_category,
|
666 |
+
'next_token_idx_gt': next_token_index_gt,
|
667 |
+
'next_state_idx_gt': next_state_index_gt,
|
668 |
+
'next_token_eval_mask': next_token_eval_mask,
|
669 |
+
'raw_agent_valid_mask': data['agent']['raw_agent_valid_mask'],
|
670 |
+
}
|
671 |
+
|
672 |
+
def forward(self,
|
673 |
+
data: HeteroData,
|
674 |
+
map_enc: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
675 |
+
|
676 |
+
pos_a = data['agent']['token_pos'].clone() # (num_agent, num_shifted_step, 2)
|
677 |
+
head_a = data['agent']['token_heading'].clone() # (num_agent, num_shifted_step)
|
678 |
+
num_agent, num_step, traj_dim = pos_a.shape # e.g. (50, 18, 2)
|
679 |
+
agent_category = data['agent']['category'].clone() # (num_agent,)
|
680 |
+
agent_token_index = data['agent']['token_idx'].clone() # (num_agent, num_step)
|
681 |
+
agent_state_index = data['agent']['state_idx'].clone() # (num_agent, num_step)
|
682 |
+
agent_type_index = data['agent']['type'].clone() # (num_agent, num_step)
|
683 |
+
agent_enter_pl_token_idx = None
|
684 |
+
agent_enter_offset_token_idx = None
|
685 |
+
|
686 |
+
device = pos_a.device
|
687 |
+
|
688 |
+
seed_step_mask = agent_state_index[:, 1:] == self.enter_state
|
689 |
+
if torch.any(seed_step_mask.sum(dim=0) > self.seed_size):
|
690 |
+
print(agent_state_index)
|
691 |
+
print(agent_state_index.shape)
|
692 |
+
print(seed_step_mask.long())
|
693 |
+
print(seed_step_mask.sum(dim=0))
|
694 |
+
raise RuntimeError(f"Seed size {self.seed_size} is too small.")
|
695 |
+
|
696 |
+
# fix pos and head of invalid agents
|
697 |
+
av_index = int(data['agent']['av_index'])
|
698 |
+
# ego_pos_a = pos_a[av_index].clone() # (num_shifted_step, 2)
|
699 |
+
# ego_head_vector_a = head_vector_a[av_index] # (num_shifted_step, 2)
|
700 |
+
# is_invalid = agent_state_index == self.invalid_state
|
701 |
+
# pos_a[is_invalid] = ego_pos_a[None, :].expand(pos_a.shape[0], -1, -1)[is_invalid]
|
702 |
+
# head_vector_a[is_invalid] = ego_head_vector_a[None, :].expand(head_vector_a.shape[0], -1, -1)[is_invalid]
|
703 |
+
|
704 |
+
if not self.predict_state:
|
705 |
+
agent_state_index = None
|
706 |
+
|
707 |
+
feat_a, head_vector_a = self.agent_token_embedding(data, agent_token_index, agent_state_index, pos_a, head_a, av_index=av_index)
|
708 |
+
|
709 |
+
# build masks
|
710 |
+
mask = data['agent']['raw_agent_valid_mask'].clone()
|
711 |
+
temporal_mask = mask.clone()
|
712 |
+
interact_mask = mask.clone()
|
713 |
+
if self.predict_state:
|
714 |
+
|
715 |
+
agent_enter_offset_token_idx = data['agent']['neighbor_token_idx']
|
716 |
+
agent_enter_pl_token_idx = data['agent']['map_bos_token_idx']
|
717 |
+
agent_enter_pl_token_id = data['agent']['map_bos_token_id']
|
718 |
+
|
719 |
+
is_bos = agent_state_index == self.enter_state
|
720 |
+
is_eos = agent_state_index == self.exit_state
|
721 |
+
bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0))
|
722 |
+
eos_index = torch.where(is_eos.any(dim=1), torch.argmax(is_eos.long(), dim=1), torch.tensor(num_step - 1)) # not `-1`
|
723 |
+
|
724 |
+
temporal_mask = torch.ones_like(mask)
|
725 |
+
motion_mask = torch.arange(mask.shape[1]).expand(mask.shape[0], -1).to(device)
|
726 |
+
motion_mask = (motion_mask > bos_index[:, None]) & (motion_mask <= eos_index[:, None])
|
727 |
+
temporal_mask[motion_mask] = mask[motion_mask]
|
728 |
+
|
729 |
+
interact_mask[agent_state_index == self.enter_state] = True
|
730 |
+
interact_mask = torch.cat([interact_mask, torch.ones_like(interact_mask[:1])]).bool() # placeholder
|
731 |
+
|
732 |
+
edge_index_t, r_t = self.build_temporal_edge(pos_a, head_a, head_vector_a, agent_state_index, temporal_mask,
|
733 |
+
av_index=av_index)
|
734 |
+
|
735 |
+
# +1: placeholder for seed agent
|
736 |
+
# if isinstance(data, Batch):
|
737 |
+
# print(data['agent']['batch'], data.num_graphs)
|
738 |
+
# batch_s = torch.cat([data['agent']['batch'] + data.num_graphs * t for t in range(num_step)], dim=0)
|
739 |
+
# batch_pl = torch.cat([data['pt_token']['batch'] + data.num_graphs * t for t in range(num_step)], dim=0)
|
740 |
+
# else:
|
741 |
+
batch_s = torch.arange(num_step, device=device).repeat_interleave(data['agent']['num_nodes'] + 1)
|
742 |
+
batch_pl = torch.arange(num_step, device=device).repeat_interleave(data['pt_token']['num_nodes'])
|
743 |
+
|
744 |
+
edge_index_a2a, r_a2a = self.build_interaction_edge(pos_a, head_a, head_vector_a, agent_state_index, batch_s,
|
745 |
+
interact_mask, av_index=av_index)
|
746 |
+
|
747 |
+
agent_category = torch.cat([agent_category, torch.full(agent_category[-1:].shape, 3, device=device)])
|
748 |
+
interact_mask[agent_category != 3] = False
|
749 |
+
edge_index_pl2a, r_pl2a = self.build_map2agent_edge(data, num_step, pos_a, head_a, head_vector_a,
|
750 |
+
agent_state_index, batch_s, batch_pl, interact_mask, av_index=av_index)
|
751 |
+
|
752 |
+
# mapping network
|
753 |
+
# z = torch.randn(num_agent, self.hidden_dim).to(feat_a.device)
|
754 |
+
# w = self.mapping_network(z)
|
755 |
+
|
756 |
+
for i in range(self.num_layers):
|
757 |
+
|
758 |
+
# feat_a = feat_a + w[:, None]
|
759 |
+
|
760 |
+
feat_a = feat_a.reshape(-1, self.hidden_dim) # (num_agent, num_step, hidden_dim) -> (seq_len, hidden_dim)
|
761 |
+
feat_a = self.t_attn_layers[i](feat_a, r_t, edge_index_t)
|
762 |
+
|
763 |
+
feat_a = feat_a.reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim)
|
764 |
+
feat_a = self.pt2a_attn_layers[i]((
|
765 |
+
map_enc['x_pt'].repeat_interleave(repeats=num_step, dim=0).reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(
|
766 |
+
-1, self.hidden_dim), feat_a), r_pl2a, edge_index_pl2a)
|
767 |
+
|
768 |
+
feat_a = self.a2a_attn_layers[i](feat_a, r_a2a, edge_index_a2a)
|
769 |
+
feat_a = feat_a.reshape(num_step, -1, self.hidden_dim).transpose(0, 1)
|
770 |
+
|
771 |
+
# next motion token
|
772 |
+
next_token_prob = self.token_predict_head(feat_a[:-1]) # (num_agent, num_step, token_size)
|
773 |
+
next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1)
|
774 |
+
_, next_token_idx = torch.topk(next_token_prob_softmax, k=10, dim=-1) # (num_agent, num_step, 10)
|
775 |
+
|
776 |
+
next_token_index_gt = agent_token_index.roll(shifts=-1, dims=1)
|
777 |
+
|
778 |
+
# next state token
|
779 |
+
next_state_prob = self.state_predict_head(feat_a[:-1])
|
780 |
+
next_state_idx = next_state_prob.softmax(dim=-1).argmax(dim=-1, keepdim=True) # (num_agent, num_step, 1)
|
781 |
+
|
782 |
+
next_state_index_gt = agent_state_index.roll(shifts=-1, dims=1) # (invalid, valid, exit)
|
783 |
+
|
784 |
+
# seed agent
|
785 |
+
feat_seed = self.seed_head(feat_a[-1:]) + self.seed_feature.weight[:, None]
|
786 |
+
next_state_prob_seed = self.seed_state_predict_head(feat_seed)
|
787 |
+
next_state_idx_seed = next_state_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True) # (self.seed_size, num_step, 1)
|
788 |
+
|
789 |
+
next_type_prob_seed = self.seed_type_predict_head(feat_seed)
|
790 |
+
next_type_idx_seed = next_type_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True)
|
791 |
+
|
792 |
+
next_type_index_gt = agent_type_index[:, None].expand(-1, num_step).roll(shifts=-1, dims=1)
|
793 |
+
|
794 |
+
# polygon token for bos token
|
795 |
+
# next_bos_pl_prob = self.bos_pl_predict_head(feat_a)
|
796 |
+
# next_bos_pl_prob_softmax = torch.softmax(next_bos_pl_prob, dim=-1)
|
797 |
+
# _, next_bos_pl_idx = torch.topk(next_bos_pl_prob_softmax, k=1, dim=-1) # (num_agent, num_step, 1)
|
798 |
+
|
799 |
+
# next_bos_pl_index_gt = agent_enter_pl_token_id.roll(shifts=-1, dims=-1)
|
800 |
+
|
801 |
+
# offset token for bos token
|
802 |
+
# next_bos_offset_prob = self.bos_offset_predict_head(feat_a)
|
803 |
+
# next_bos_offset_prob_softmax = torch.softmax(next_bos_offset_prob, dim=-1)
|
804 |
+
# _, next_bos_offset_idx = torch.topk(next_bos_offset_prob_softmax, k=1, dim=-1)
|
805 |
+
|
806 |
+
# next_bos_offset_index_gt = agent_enter_offset_token_idx.roll(shifts=-1, dims=-1)
|
807 |
+
|
808 |
+
# next token prediction mask
|
809 |
+
bos_token_index = torch.nonzero(agent_state_index == self.enter_state)
|
810 |
+
eos_token_index = torch.nonzero(agent_state_index == self.exit_state)
|
811 |
+
|
812 |
+
# mask for motion tokens
|
813 |
+
next_token_eval_mask = mask.clone()
|
814 |
+
next_token_eval_mask = next_token_eval_mask * next_token_eval_mask.roll(shifts=-1, dims=1) * next_token_eval_mask.roll(shifts=1, dims=1)
|
815 |
+
for bos_token_index_ in bos_token_index:
|
816 |
+
next_token_eval_mask[bos_token_index_[0], bos_token_index_[1] : bos_token_index_[1] + 1] = 1
|
817 |
+
next_token_eval_mask[bos_token_index_[0], bos_token_index_[1] + 1 : bos_token_index_[1] + 2] = \
|
818 |
+
mask[bos_token_index_[0], bos_token_index_[1] + 2 : bos_token_index_[1] + 3]
|
819 |
+
next_token_eval_mask[eos_token_index[:, 0], eos_token_index[:, 1]] = 0
|
820 |
+
|
821 |
+
# mask for state tokens
|
822 |
+
next_state_eval_mask = mask.clone()
|
823 |
+
next_state_eval_mask = next_state_eval_mask * next_state_eval_mask.roll(shifts=-1, dims=1) * next_state_eval_mask.roll(shifts=1, dims=1)
|
824 |
+
for bos_token_index_ in bos_token_index:
|
825 |
+
next_state_eval_mask[bos_token_index_[0], :bos_token_index_[1]] = 0
|
826 |
+
next_state_eval_mask[bos_token_index_[0], bos_token_index_[1] : bos_token_index_[1] + 1] = 1
|
827 |
+
next_state_eval_mask[bos_token_index_[0], bos_token_index_[1] + 1 : bos_token_index_[1] + 2] = \
|
828 |
+
mask[bos_token_index_[0], bos_token_index_[1] + 2 : bos_token_index_[1] + 3]
|
829 |
+
for eos_token_index_ in eos_token_index:
|
830 |
+
next_state_eval_mask[eos_token_index_[0], eos_token_index_[1] + 1:] = 1
|
831 |
+
next_state_eval_mask[eos_token_index_[0], eos_token_index_[1] : eos_token_index_[1] + 1] = \
|
832 |
+
mask[eos_token_index_[0], eos_token_index_[1] - 1 : eos_token_index_[1]]
|
833 |
+
|
834 |
+
# seed agents
|
835 |
+
next_bos_token_index = torch.nonzero(next_state_index_gt == self.enter_state)
|
836 |
+
next_bos_token_index = next_bos_token_index[next_bos_token_index[:, 1] < num_step - 1]
|
837 |
+
|
838 |
+
next_state_index_gt_seed = torch.full((self.seed_size, num_step), self.seed_state_type.index('invalid'), device=next_state_index_gt.device)
|
839 |
+
next_type_index_gt_seed = torch.full((self.seed_size, num_step), self.seed_agent_type.index('seed'), device=next_state_index_gt.device)
|
840 |
+
next_eval_mask_seed = torch.ones_like(next_state_index_gt_seed)
|
841 |
+
|
842 |
+
num_seed = torch.zeros(num_step, device=next_state_index_gt.device).long()
|
843 |
+
for next_bos_token_index_ in next_bos_token_index:
|
844 |
+
if num_seed[next_bos_token_index_[1]] < self.seed_size:
|
845 |
+
next_state_index_gt_seed[num_seed[next_bos_token_index_[1]], next_bos_token_index_[1]] = self.seed_state_type.index('enter')
|
846 |
+
next_type_index_gt_seed[num_seed[next_bos_token_index_[1]], next_bos_token_index_[1]] = next_type_index_gt[next_bos_token_index_[0], next_bos_token_index_[1]]
|
847 |
+
num_seed[next_bos_token_index_[1]] += 1
|
848 |
+
|
849 |
+
# the last timestep is the beginning of the sequence (also the input)
|
850 |
+
next_token_eval_mask[:, -1] = 0
|
851 |
+
next_state_eval_mask[:, -1] = 0
|
852 |
+
next_eval_mask_seed[:, -1] = 0
|
853 |
+
# next_bos_token_eval_mask[:, -1] = False
|
854 |
+
|
855 |
+
# no invalid motion token will be supervised
|
856 |
+
if (next_token_index_gt[next_token_eval_mask] < 0).any():
|
857 |
+
raise RuntimeError()
|
858 |
+
|
859 |
+
next_state_index_gt[next_state_index_gt == self.exit_state] = self.valid_state_type.index('exit')
|
860 |
+
|
861 |
+
return {'x_a': feat_a,
|
862 |
+
# motion token
|
863 |
+
'next_token_idx': next_token_idx,
|
864 |
+
'next_token_prob': next_token_prob,
|
865 |
+
'next_token_idx_gt': next_token_index_gt,
|
866 |
+
'next_token_eval_mask': next_token_eval_mask.bool(),
|
867 |
+
# state token
|
868 |
+
'next_state_idx': next_state_idx,
|
869 |
+
'next_state_prob': next_state_prob,
|
870 |
+
'next_state_idx_gt': next_state_index_gt,
|
871 |
+
'next_state_eval_mask': next_state_eval_mask.bool(),
|
872 |
+
# seed agent
|
873 |
+
'next_state_idx_seed': next_state_idx_seed,
|
874 |
+
'next_state_prob_seed': next_state_prob_seed,
|
875 |
+
'next_state_idx_gt_seed': next_state_index_gt_seed,
|
876 |
+
'next_type_idx_seed': next_type_idx_seed,
|
877 |
+
'next_type_prob_seed': next_type_prob_seed,
|
878 |
+
'next_type_idx_gt_seed': next_type_index_gt_seed,
|
879 |
+
'next_eval_mask_seed': next_eval_mask_seed.bool(),
|
880 |
+
# pl token for bos
|
881 |
+
# 'next_bos_pl_idx': next_bos_pl_idx,
|
882 |
+
# 'next_bos_pl_prob': next_bos_pl_prob,
|
883 |
+
# 'next_bos_pl_index_gt': next_bos_pl_index_gt,
|
884 |
+
# offset token for bos
|
885 |
+
# 'next_bos_offset_idx': next_bos_offset_idx,
|
886 |
+
# 'next_bos_offset_prob': next_bos_offset_prob,
|
887 |
+
# 'next_bos_offset_index_gt': next_bos_offset_index_gt,
|
888 |
+
# 'next_bos_token_eval_mask': next_bos_token_eval_mask,
|
889 |
+
}
|
890 |
+
|
891 |
+
def inference(self,
|
892 |
+
data: HeteroData,
|
893 |
+
map_enc: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
894 |
+
|
895 |
+
start_state_idx = data['agent']['state_idx'][:, (self.num_historical_steps - 1) // self.shift]
|
896 |
+
filter_mask = (start_state_idx == self.valid_state) | (start_state_idx == self.exit_state)
|
897 |
+
seed_step_mask = data['agent']['state_idx'][:, (self.num_historical_steps - 1) // self.shift:] == self.enter_state
|
898 |
+
seed_agent_index_per_step = [torch.nonzero(seed_step_mask[:, t]).squeeze(dim=-1) for t in range(seed_step_mask.shape[1])]
|
899 |
+
if torch.any(seed_step_mask.sum(dim=0) > self.seed_size):
|
900 |
+
raise RuntimeError(f"Seed size {self.seed_size} is too small.")
|
901 |
+
|
902 |
+
# num_historical_steps=11
|
903 |
+
eval_mask = data['agent']['valid_mask'][filter_mask, self.num_historical_steps - 1]
|
904 |
+
|
905 |
+
if self.predict_state:
|
906 |
+
eval_mask = torch.ones_like(eval_mask).bool()
|
907 |
+
|
908 |
+
# agent attributes
|
909 |
+
pos_a = data['agent']['token_pos'][filter_mask].clone() # (num_agent, num_step, 2)
|
910 |
+
state_a = data['agent']['state_idx'][filter_mask].clone() # (num_agent, num_step)
|
911 |
+
head_a = data['agent']['token_heading'][filter_mask].clone() # (num_agent, num_step)
|
912 |
+
gt_traj = data['agent']['position'][filter_mask, self.num_historical_steps:, :self.input_dim].contiguous()
|
913 |
+
num_agent, num_step, traj_dim = pos_a.shape
|
914 |
+
|
915 |
+
av_index = int(data['agent']['av_index'])
|
916 |
+
av_index -= (~filter_mask[:av_index]).sum()
|
917 |
+
|
918 |
+
# map attributes
|
919 |
+
pos_pl = data['pt_token']['position'][:, :2].clone() # (num_pl, 2)
|
920 |
+
|
921 |
+
# make future steps to zero
|
922 |
+
pos_a[:, (self.num_historical_steps - 1) // self.shift:] = 0
|
923 |
+
state_a[:, (self.num_historical_steps - 1) // self.shift:] = 0
|
924 |
+
head_a[:, (self.num_historical_steps - 1) // self.shift:] = 0
|
925 |
+
|
926 |
+
agent_valid_mask = data['agent']['raw_agent_valid_mask'][filter_mask].clone() # token_valid_mask
|
927 |
+
agent_valid_mask[:, (self.num_historical_steps - 1) // self.shift:] = True
|
928 |
+
agent_valid_mask[~eval_mask] = False
|
929 |
+
agent_token_index = data['agent']['token_idx'][filter_mask]
|
930 |
+
agent_state_index = data['agent']['state_idx'][filter_mask]
|
931 |
+
agent_type = data['agent']['type'][filter_mask]
|
932 |
+
agent_category = data['agent']['category'][filter_mask]
|
933 |
+
|
934 |
+
feat_a, head_vector_a, agent_token_traj_all, agent_token_emb, categorical_embs = self.agent_token_embedding(data,
|
935 |
+
agent_token_index,
|
936 |
+
agent_state_index,
|
937 |
+
pos_a,
|
938 |
+
head_a,
|
939 |
+
inference=True,
|
940 |
+
filter_mask=filter_mask,
|
941 |
+
av_index=av_index,
|
942 |
+
)
|
943 |
+
feat_seed = feat_a[-1:]
|
944 |
+
feat_a = feat_a[:-1]
|
945 |
+
|
946 |
+
agent_type = data["agent"]["type"][filter_mask]
|
947 |
+
veh_mask = agent_type == 0
|
948 |
+
cyc_mask = agent_type == 2
|
949 |
+
ped_mask = agent_type == 1
|
950 |
+
|
951 |
+
# self.num_recurrent_steps_val = 91 - 11 = 80
|
952 |
+
self.num_recurrent_steps_val = data["agent"]['position'].shape[1] - self.num_historical_steps
|
953 |
+
pred_traj = torch.zeros(pos_a.shape[0], self.num_recurrent_steps_val, 2, device=feat_a.device) # (num_agent, 80, 2)
|
954 |
+
pred_head = torch.zeros(pos_a.shape[0], self.num_recurrent_steps_val, device=feat_a.device)
|
955 |
+
pred_type = agent_type.clone()
|
956 |
+
pred_state = torch.zeros(pos_a.shape[0], self.num_recurrent_steps_val, device=feat_a.device)
|
957 |
+
pred_prob = torch.zeros(pos_a.shape[0], self.num_recurrent_steps_val // self.shift, device=feat_a.device) # (num_agent, 80 // 5 = 16)
|
958 |
+
next_token_idx_list = []
|
959 |
+
next_state_idx_list = []
|
960 |
+
next_bos_pl_idx_list = []
|
961 |
+
next_bos_offset_idx_list = []
|
962 |
+
feat_a_t_dict = {}
|
963 |
+
feat_sa_t_dict = {}
|
964 |
+
|
965 |
+
# build masks (init)
|
966 |
+
mask = agent_valid_mask.clone()
|
967 |
+
temporal_mask = mask.clone()
|
968 |
+
interact_mask = mask.clone()
|
969 |
+
if self.predict_state:
|
970 |
+
|
971 |
+
# find bos and eos index
|
972 |
+
is_bos = agent_state_index == self.enter_state
|
973 |
+
is_eos = agent_state_index == self.exit_state
|
974 |
+
bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0))
|
975 |
+
eos_index = torch.where(is_eos.any(dim=1), torch.argmax(is_eos.long(), dim=1), torch.tensor(num_step - 1))
|
976 |
+
|
977 |
+
temporal_mask = torch.ones_like(mask)
|
978 |
+
motion_mask = torch.arange(mask.shape[1]).expand(mask.shape[0], mask.shape[1]).to(mask.device)
|
979 |
+
motion_mask = (motion_mask > bos_index[:, None]) & (motion_mask <= eos_index[:, None])
|
980 |
+
motion_mask[:, self.num_historical_steps // self.shift:] = False
|
981 |
+
temporal_mask[motion_mask] = mask[motion_mask]
|
982 |
+
|
983 |
+
interact_mask = torch.ones_like(mask)
|
984 |
+
non_motion_mask = ~motion_mask
|
985 |
+
non_motion_mask[:, self.num_historical_steps // self.shift:] = False
|
986 |
+
interact_mask[non_motion_mask] = False
|
987 |
+
interact_mask[agent_state_index == self.enter_state] = True
|
988 |
+
|
989 |
+
temporal_mask[:, (self.num_historical_steps - 1) // self.shift:] = True
|
990 |
+
interact_mask[:, (self.num_historical_steps - 1) // self.shift:] = True
|
991 |
+
|
992 |
+
# mapping network
|
993 |
+
# z = torch.randn(num_agent, self.hidden_dim).to(feat_a.device)
|
994 |
+
# w = self.mapping_network(z)
|
995 |
+
|
996 |
+
# we only need to predict 16 next tokens
|
997 |
+
for t in range(self.num_recurrent_steps_val // self.shift):
|
998 |
+
|
999 |
+
# feat_a = feat_a + w[:, None]
|
1000 |
+
num_agent = pos_a.shape[0]
|
1001 |
+
|
1002 |
+
if t == 0:
|
1003 |
+
inference_mask = temporal_mask.clone()
|
1004 |
+
inference_mask = torch.cat([inference_mask, torch.ones_like(inference_mask[-1:])])
|
1005 |
+
inference_mask[:, (self.num_historical_steps - 1) // self.shift + t:] = False
|
1006 |
+
else:
|
1007 |
+
inference_mask = torch.zeros_like(temporal_mask)
|
1008 |
+
inference_mask = torch.cat([inference_mask, torch.zeros_like(inference_mask[-1:])])
|
1009 |
+
inference_mask[:, max((self.num_historical_steps - 1) // self.shift + t - (self.num_interaction_steps // self.shift), 0) :
|
1010 |
+
(self.num_historical_steps - 1) // self.shift + t] = True
|
1011 |
+
|
1012 |
+
interact_mask = torch.cat([interact_mask, torch.ones_like(interact_mask[:1])]).bool() # placeholder
|
1013 |
+
|
1014 |
+
edge_index_t, r_t = self.build_temporal_edge(pos_a, head_a, head_vector_a, state_a, temporal_mask, inference_mask,
|
1015 |
+
av_index=av_index)
|
1016 |
+
|
1017 |
+
# +1: placeholder for seed agent
|
1018 |
+
batch_s = torch.arange(num_step, device=pos_a.device).repeat_interleave(num_agent + 1)
|
1019 |
+
batch_pl = torch.arange(num_step, device=pos_a.device).repeat_interleave(data['pt_token']['num_nodes'])
|
1020 |
+
|
1021 |
+
# In the inference stage, we only infer the current stage for recurrent
|
1022 |
+
edge_index_a2a, r_a2a = self.build_interaction_edge(pos_a, head_a, head_vector_a, state_a, batch_s,
|
1023 |
+
interact_mask, inference_mask, av_index=av_index)
|
1024 |
+
edge_index_pl2a, r_pl2a = self.build_map2agent_edge(data, num_step, pos_a, head_a, head_vector_a, state_a, batch_s, batch_pl,
|
1025 |
+
interact_mask, inference_mask, av_index=av_index)
|
1026 |
+
interact_mask = interact_mask[:-1]
|
1027 |
+
|
1028 |
+
# if t > 0:
|
1029 |
+
# feat_a_sum = feat_a.sum(dim=-1)
|
1030 |
+
# for a in range(pos_a.shape[0]):
|
1031 |
+
# t_1 = (self.num_historical_steps - 1) // self.shift + t - 1
|
1032 |
+
# print(f"agent {a} t_1 {t_1}")
|
1033 |
+
# print(f"token: {next_token_idx[a]}")
|
1034 |
+
# print(f"state: {next_state_idx[a]}")
|
1035 |
+
# print(f"feat_a_sum: {feat_a_sum[a, t_1]}")
|
1036 |
+
|
1037 |
+
for i in range(self.num_layers):
|
1038 |
+
|
1039 |
+
if (i in feat_a_t_dict) and (i in feat_sa_t_dict):
|
1040 |
+
feat_a = feat_a_t_dict[i]
|
1041 |
+
feat_seed = feat_sa_t_dict[i]
|
1042 |
+
|
1043 |
+
feat_a = torch.cat([feat_a, feat_seed], dim=0)
|
1044 |
+
|
1045 |
+
feat_a = feat_a.reshape(-1, self.hidden_dim)
|
1046 |
+
feat_a = self.t_attn_layers[i](feat_a, r_t, edge_index_t)
|
1047 |
+
|
1048 |
+
feat_a = feat_a.reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim)
|
1049 |
+
feat_a = self.pt2a_attn_layers[i]((
|
1050 |
+
map_enc['x_pt'].repeat_interleave(repeats=num_step, dim=0).reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(
|
1051 |
+
-1, self.hidden_dim), feat_a), r_pl2a, edge_index_pl2a)
|
1052 |
+
|
1053 |
+
feat_a = self.a2a_attn_layers[i](feat_a, r_a2a, edge_index_a2a)
|
1054 |
+
feat_a = feat_a.reshape(num_step, -1, self.hidden_dim).transpose(0, 1)
|
1055 |
+
|
1056 |
+
feat_seed = feat_a[-1:] # (1, num_step, hidden_dim)
|
1057 |
+
feat_a = feat_a[:-1] # (num_agent, num_step, hidden_dim)
|
1058 |
+
|
1059 |
+
if t == 0:
|
1060 |
+
feat_a_t_dict[i + 1] = feat_a
|
1061 |
+
feat_sa_t_dict[i + 1] = feat_seed
|
1062 |
+
else:
|
1063 |
+
# update agent features at current step
|
1064 |
+
n = feat_a_t_dict[i + 1].shape[0]
|
1065 |
+
feat_a_t_dict[i + 1][:n, (self.num_historical_steps - 1) // self.shift - 1 + t] = feat_a[:n, (self.num_historical_steps - 1) // self.shift - 1 + t]
|
1066 |
+
# add newly inserted agent features (only when t changed)
|
1067 |
+
if feat_a.shape[0] > n:
|
1068 |
+
m = feat_a.shape[0] - n
|
1069 |
+
feat_a_t_dict[i + 1] = torch.cat([feat_a_t_dict[i + 1], feat_a[-m:]])
|
1070 |
+
# update seed agent features at current step
|
1071 |
+
feat_sa_t_dict[i + 1][:, (self.num_historical_steps - 1) // self.shift - 1 + t] = feat_seed[:, (self.num_historical_steps - 1) // self.shift - 1 + t]
|
1072 |
+
|
1073 |
+
# next motion token
|
1074 |
+
next_token_prob = self.token_predict_head(feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t])
|
1075 |
+
next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1)
|
1076 |
+
topk_token_prob, next_token_idx = torch.topk(next_token_prob_softmax, k=self.beam_size, dim=-1) # both (num_agent, beam_size) e.g. (31, 5)
|
1077 |
+
|
1078 |
+
# next state token
|
1079 |
+
next_state_prob = self.state_predict_head(feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t])
|
1080 |
+
next_state_idx = next_state_prob.softmax(dim=-1).argmax(dim=-1)
|
1081 |
+
next_state_idx[next_state_idx == self.valid_state_type.index('exit')] = self.exit_state
|
1082 |
+
|
1083 |
+
# seed agent
|
1084 |
+
feat_seed = self.seed_head(feat_seed) + self.seed_feature.weight[:, None]
|
1085 |
+
next_state_prob_seed = self.seed_state_predict_head(feat_seed[:, (self.num_historical_steps - 1) // self.shift - 1 + t])
|
1086 |
+
next_state_idx_seed = next_state_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True)
|
1087 |
+
next_state_idx_seed[next_state_idx_seed == self.seed_state_type.index('enter')] = self.enter_state
|
1088 |
+
|
1089 |
+
next_type_prob_seed = self.seed_type_predict_head(feat_seed[:, (self.num_historical_steps - 1) // self.shift - 1 + t])
|
1090 |
+
next_type_idx_seed = next_type_prob_seed.softmax(dim=-1).argmax(dim=-1, keepdim=True)
|
1091 |
+
|
1092 |
+
# print(f"t: {t}")
|
1093 |
+
# print(next_type_idx_seed[..., 0].tolist())
|
1094 |
+
|
1095 |
+
# bos pl prediction
|
1096 |
+
# next_bos_pl_prob = self.bos_pl_predict_head(feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t])
|
1097 |
+
# next_bos_pl_prob_softmax = torch.softmax(next_bos_pl_prob, dim=-1)
|
1098 |
+
# next_bos_pl_idx = torch.argmax(next_bos_pl_prob_softmax, dim=-1)
|
1099 |
+
|
1100 |
+
# bos offset prediction
|
1101 |
+
# next_bos_offset_prob = self.bos_offset_predict_head(feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t])
|
1102 |
+
# next_bos_offset_prob_softmax = torch.softmax(next_bos_offset_prob, dim=-1)
|
1103 |
+
# next_bos_offset_idx = torch.argmax(next_bos_offset_prob_softmax, dim=-1)
|
1104 |
+
|
1105 |
+
# convert the predicted token to a 0.5s (6 timesteps) trajectory
|
1106 |
+
expanded_token_index = next_token_idx[..., None, None, None].expand(-1, -1, 6, 4, 2)
|
1107 |
+
next_token_traj = torch.gather(agent_token_traj_all, 1, expanded_token_index) # (num_agent, beam_size, 6, 4, 2)
|
1108 |
+
|
1109 |
+
# apply rotation and translation on 'next_token_traj'
|
1110 |
+
theta = head_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t]
|
1111 |
+
cos, sin = theta.cos(), theta.sin()
|
1112 |
+
rot_mat = torch.zeros((num_agent, 2, 2), device=theta.device)
|
1113 |
+
rot_mat[:, 0, 0] = cos
|
1114 |
+
rot_mat[:, 0, 1] = sin
|
1115 |
+
rot_mat[:, 1, 0] = -sin
|
1116 |
+
rot_mat[:, 1, 1] = cos
|
1117 |
+
agent_diff_rel = torch.bmm(next_token_traj.view(-1, 4, 2),
|
1118 |
+
rot_mat[:, None, None, ...].repeat(1, self.beam_size, self.shift + 1, 1, 1).view(
|
1119 |
+
-1, 2, 2)).view(num_agent, self.beam_size, self.shift + 1, 4, 2)
|
1120 |
+
agent_pred_rel = agent_diff_rel + pos_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t, :][:, None, None, None, ...]
|
1121 |
+
|
1122 |
+
# sample 1 most probable index of top beam_size tokens, (num_agent, beam_size) -> (num_agent, 1)
|
1123 |
+
# then sample the agent_pred_rel, (num_agent, beam_size, 6, 4, 2) -> (num_agent, 6, 4, 2)
|
1124 |
+
sample_token_index = torch.multinomial(topk_token_prob, 1).to(agent_pred_rel.device)
|
1125 |
+
next_token_idx = next_token_idx.gather(dim=1, index=sample_token_index).squeeze(-1)
|
1126 |
+
agent_pred_rel = agent_pred_rel.gather(dim=1,
|
1127 |
+
index=sample_token_index[..., None, None, None].expand(-1, -1, 6, 4,
|
1128 |
+
2))[:, 0, ...]
|
1129 |
+
|
1130 |
+
# get predicted position and heading of current shifted timesteps
|
1131 |
+
diff_xy = agent_pred_rel[:, 1:, 0, :] - agent_pred_rel[:, 1:, 3, :]
|
1132 |
+
pred_traj[:num_agent, t * 5 : (t + 1) * 5] = agent_pred_rel[:, 1:, ...].clone().mean(dim=2)
|
1133 |
+
pred_head[:num_agent, t * 5 : (t + 1) * 5] = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0])
|
1134 |
+
pred_state[:num_agent, t * 5 : (t + 1) * 5] = next_state_idx[:, None].repeat(1, 5)
|
1135 |
+
# pred_prob[:num_agent, t] = topk_token_prob.gather(dim=-1, index=sample_token_index)[:, 0] # (num_agent, beam_size) -> (num_agent,)
|
1136 |
+
|
1137 |
+
# update pos/head/state of current step
|
1138 |
+
pos_a[:, (self.num_historical_steps - 1) // self.shift + t] = agent_pred_rel[:, -1, ...].clone().mean(dim=1)
|
1139 |
+
diff_xy = agent_pred_rel[:, -1, 0, :] - agent_pred_rel[:, -1, 3, :]
|
1140 |
+
theta = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])
|
1141 |
+
head_a[:, (self.num_historical_steps - 1) // self.shift + t] = theta
|
1142 |
+
state_a[:, (self.num_historical_steps - 1) // self.shift + t] = next_state_idx
|
1143 |
+
|
1144 |
+
# the case that the current predicted state token is invalid/exit
|
1145 |
+
is_eos = next_state_idx == self.exit_state
|
1146 |
+
is_invalid = next_state_idx == self.invalid_state
|
1147 |
+
|
1148 |
+
next_token_idx[is_invalid] = -1
|
1149 |
+
pos_a[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = 0.
|
1150 |
+
head_a[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = 0.
|
1151 |
+
|
1152 |
+
mask[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = False # to handle those newly-added agents
|
1153 |
+
interact_mask[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = False
|
1154 |
+
|
1155 |
+
agent_token_emb[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = self.no_token_emb(torch.zeros(1, device=pos_a.device).long())
|
1156 |
+
|
1157 |
+
type_emb = categorical_embs[0].reshape(num_agent, num_step, -1)
|
1158 |
+
shape_emb = categorical_embs[1].reshape(num_agent, num_step, -1)
|
1159 |
+
type_emb[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = self.type_a_emb(torch.tensor(self.all_agent_type.index('invalid'), device=pos_a.device).long())
|
1160 |
+
shape_emb[is_invalid, (self.num_historical_steps - 1) // self.shift + t] = self.shape_emb(torch.full((1, 3), self.invalid_shape_value, device=pos_a.device))
|
1161 |
+
categorical_embs = [type_emb.reshape(-1, self.hidden_dim), shape_emb.reshape(-1, self.hidden_dim)]
|
1162 |
+
|
1163 |
+
# FIXME: need to discuss!!!
|
1164 |
+
# if is_eos.any():
|
1165 |
+
|
1166 |
+
# pos_a[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = 0.
|
1167 |
+
# head_a[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = 0.
|
1168 |
+
# mask[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = False # to handle those newly-added agents
|
1169 |
+
# interact_mask[torch.cat([is_eos, torch.zeros(1, device=is_eos.device).bool()]), (self.num_historical_steps - 1) // self.shift + t + 1:] = False
|
1170 |
+
|
1171 |
+
# agent_token_emb[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = self.no_token_emb(torch.zeros(1, device=pos_a.device).long())
|
1172 |
+
|
1173 |
+
# type_emb = categorical_embs[0].reshape(num_agent, num_step, -1)
|
1174 |
+
# shape_emb = categorical_embs[1].reshape(num_agent, num_step, -1)
|
1175 |
+
# type_emb[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = self.type_a_emb(torch.tensor(self.all_agent_type.index('invalid'), device=pos_a.device).long())
|
1176 |
+
# shape_emb[is_eos, (self.num_historical_steps - 1) // self.shift + t + 1:] = self.shape_emb(torch.full((1, 3), self.invalid_shape_value, device=pos_a.device))
|
1177 |
+
# categorical_embs = [type_emb.reshape(-1, self.hidden_dim), shape_emb.reshape(-1, self.hidden_dim)]
|
1178 |
+
|
1179 |
+
# for sa in range(next_state_idx_seed.shape[0]):
|
1180 |
+
# if next_state_idx_seed[sa] == self.enter_state:
|
1181 |
+
# print(f"agent {sa} is entering at step {t}")
|
1182 |
+
|
1183 |
+
# insert new agents (from seed agent)
|
1184 |
+
seed_agent_index_cur_step = seed_agent_index_per_step[t]
|
1185 |
+
num_new_agent = min(len(seed_agent_index_cur_step), next_state_idx_seed.bool().sum())
|
1186 |
+
new_agent_mask = next_state_idx_seed.bool()
|
1187 |
+
next_state_idx_seed = next_state_idx_seed[new_agent_mask]
|
1188 |
+
next_state_idx_seed = next_state_idx_seed[:num_new_agent]
|
1189 |
+
next_type_idx_seed = next_type_idx_seed[new_agent_mask]
|
1190 |
+
next_type_idx_seed = next_type_idx_seed[:num_new_agent]
|
1191 |
+
selected_agent_index_cur_step = seed_agent_index_cur_step[:num_new_agent]
|
1192 |
+
agent_token_index = torch.cat([agent_token_index, data['agent']['token_idx'][selected_agent_index_cur_step]])
|
1193 |
+
agent_state_index = torch.cat([agent_state_index, data['agent']['state_idx'][selected_agent_index_cur_step]])
|
1194 |
+
agent_category = torch.cat([agent_category, data['agent']['category'][selected_agent_index_cur_step]])
|
1195 |
+
agent_valid_mask = torch.cat([agent_valid_mask, data['agent']['raw_agent_valid_mask'][selected_agent_index_cur_step]])
|
1196 |
+
gt_traj = torch.cat([gt_traj, data['agent']['position'][selected_agent_index_cur_step, self.num_historical_steps:, :self.input_dim]])
|
1197 |
+
|
1198 |
+
# FIXME: under test!!! bos token index is -2
|
1199 |
+
next_state_idx = torch.cat([next_state_idx, next_state_idx_seed], dim=0).long()
|
1200 |
+
next_token_idx = torch.cat([next_token_idx, torch.zeros(num_new_agent, device=next_token_idx.device) - 2], dim=0).long()
|
1201 |
+
mask = torch.cat([mask, torch.ones(num_new_agent, num_step, device=mask.device)], dim=0).bool()
|
1202 |
+
temporal_mask = torch.cat([temporal_mask, torch.ones(num_new_agent, num_step, device=temporal_mask.device)], dim=0).bool()
|
1203 |
+
interact_mask = torch.cat([interact_mask, torch.ones(num_new_agent, num_step, device=interact_mask.device)], dim=0).bool()
|
1204 |
+
|
1205 |
+
# new_pos_a = ego_pos_a[None].repeat(num_new_agent, 1, 1)
|
1206 |
+
# new_head_a = ego_head_a[None].repeat(num_new_agent, 1)
|
1207 |
+
new_pos_a = torch.zeros(num_new_agent, num_step, 2, device=pos_a.device)
|
1208 |
+
new_head_a = torch.zeros(num_new_agent, num_step, device=pos_a.device)
|
1209 |
+
new_state_a = torch.zeros(num_new_agent, num_step, device=state_a.device)
|
1210 |
+
new_shape_a = torch.full((num_new_agent, num_step, 3), self.invalid_shape_value, device=pos_a.device)
|
1211 |
+
new_type_a = torch.full((num_new_agent, num_step), self.all_agent_type.index('invalid'), device=pos_a.device)
|
1212 |
+
|
1213 |
+
if num_new_agent > 0:
|
1214 |
+
gt_bos_pos_a = data['agent']['position'][seed_agent_index_cur_step[:num_new_agent], (self.num_historical_steps - 1) // self.shift + t]
|
1215 |
+
new_pos_a[:, (self.num_historical_steps - 1) // self.shift + t] = gt_bos_pos_a[:, :2].clone()
|
1216 |
+
pos_a = torch.cat([pos_a, new_pos_a], dim=0)
|
1217 |
+
|
1218 |
+
gt_bos_head_a = data['agent']['heading'][seed_agent_index_cur_step[:num_new_agent], (self.num_historical_steps - 1) // self.shift + t]
|
1219 |
+
new_head_a[:, (self.num_historical_steps - 1) // self.shift + t] = gt_bos_head_a.clone()
|
1220 |
+
head_a = torch.cat([head_a, new_head_a], dim=0)
|
1221 |
+
|
1222 |
+
gt_bos_shape_a = data['agent']['shape'][seed_agent_index_cur_step[:num_new_agent], self.num_historical_steps - 1]
|
1223 |
+
gt_bos_type_a = data['agent']['type'][seed_agent_index_cur_step[:num_new_agent]]
|
1224 |
+
new_shape_a[:, (self.num_historical_steps - 1) // self.shift + t:] = gt_bos_shape_a.clone()[:, None]
|
1225 |
+
new_type_a[:, (self.num_historical_steps - 1) // self.shift + t:] = gt_bos_type_a.clone()[:, None]
|
1226 |
+
# new_type_a[:, (self.num_historical_steps - 1) // self.shift + t] = next_type_idx_seed
|
1227 |
+
pred_type = torch.cat([pred_type, new_type_a[:, (self.num_historical_steps - 1) // self.shift + t]])
|
1228 |
+
|
1229 |
+
new_state_a[:, (self.num_historical_steps - 1) // self.shift + t] = self.enter_state
|
1230 |
+
state_a = torch.cat([state_a, new_state_a], dim=0)
|
1231 |
+
|
1232 |
+
mask[-num_new_agent:, :(self.num_historical_steps - 1) // self.shift + t + 1] = 0
|
1233 |
+
interact_mask[-num_new_agent:, :(self.num_historical_steps - 1) // self.shift + t] = 0
|
1234 |
+
|
1235 |
+
# update all steps
|
1236 |
+
new_pred_traj = torch.zeros(num_new_agent, self.num_recurrent_steps_val, 2, device=pos_a.device)
|
1237 |
+
new_pred_traj[:, t * 5 : (t + 1) * 5] = new_pos_a[:, (self.num_historical_steps - 1) // self.shift + t][:, None].repeat(1, 5, 1)
|
1238 |
+
pred_traj = torch.cat([pred_traj, new_pred_traj], dim=0)
|
1239 |
+
|
1240 |
+
new_pred_head = torch.zeros(num_new_agent, self.num_recurrent_steps_val, device=pos_a.device)
|
1241 |
+
new_pred_head[:, t * 5 : (t + 1) * 5] = new_head_a[:, (self.num_historical_steps - 1) // self.shift + t][:, None].repeat(1, 5)
|
1242 |
+
pred_head = torch.cat([pred_head, new_pred_head], dim=0)
|
1243 |
+
|
1244 |
+
new_pred_state = torch.zeros(num_new_agent, self.num_recurrent_steps_val, device=pos_a.device)
|
1245 |
+
new_pred_state[:, t * 5 : (t + 1) * 5] = next_state_idx_seed[:, None].repeat(1, 5)
|
1246 |
+
pred_state = torch.cat([pred_state, new_pred_state], dim=0)
|
1247 |
+
|
1248 |
+
# handle the position/heading of bos token
|
1249 |
+
# bos_pl_pos = pos_pl[next_bos_pl_idx[is_bos].long()]
|
1250 |
+
# bos_offset_pos = discretize_neighboring(neighbor_index=next_bos_offset_idx[is_bos])
|
1251 |
+
# pos_a[is_bos, (self.num_historical_steps - 1) // self.shift + t] += (bos_pl_pos + bos_offset_pos)
|
1252 |
+
# # headings before bos token remains 0 which align with training process
|
1253 |
+
# head_a[is_bos, (self.num_historical_steps - 1) // self.shift + t] += 0.
|
1254 |
+
|
1255 |
+
# add new agents token embeddings
|
1256 |
+
agent_token_emb = torch.cat([agent_token_emb, self.no_token_emb(torch.zeros(1, device=pos_a.device).long())[None, :].repeat(num_new_agent, num_step, 1)])
|
1257 |
+
veh_mask = torch.cat([veh_mask, next_type_idx_seed == self.seed_agent_type.index('veh')])
|
1258 |
+
ped_mask = torch.cat([ped_mask, next_type_idx_seed == self.seed_agent_type.index('ped')])
|
1259 |
+
cyc_mask = torch.cat([cyc_mask, next_type_idx_seed == self.seed_agent_type.index('cyc')])
|
1260 |
+
|
1261 |
+
# add new agents trajectory embeddings
|
1262 |
+
trajectory_token_veh = torch.from_numpy(self.trajectory_token['veh']).clone().to(pos_a.device).to(torch.float)
|
1263 |
+
trajectory_token_ped = torch.from_numpy(self.trajectory_token['ped']).clone().to(pos_a.device).to(torch.float)
|
1264 |
+
trajectory_token_cyc = torch.from_numpy(self.trajectory_token['cyc']).clone().to(pos_a.device).to(torch.float)
|
1265 |
+
|
1266 |
+
new_agent_token_traj_all = torch.zeros((num_new_agent, self.token_size, self.shift + 1, 4, 2), device=pos_a.device)
|
1267 |
+
trajectory_token_all_veh = torch.from_numpy(self.trajectory_token_all['veh']).clone().to(pos_a.device).to(torch.float)
|
1268 |
+
trajectory_token_all_ped = torch.from_numpy(self.trajectory_token_all['ped']).clone().to(pos_a.device).to(torch.float)
|
1269 |
+
trajectory_token_all_cyc = torch.from_numpy(self.trajectory_token_all['cyc']).clone().to(pos_a.device).to(torch.float)
|
1270 |
+
new_agent_token_traj_all[next_type_idx_seed == 0] = torch.cat(
|
1271 |
+
[trajectory_token_all_veh[:, :self.shift], trajectory_token_veh[:, None, ...]], dim=1)
|
1272 |
+
new_agent_token_traj_all[next_type_idx_seed == 1] = torch.cat(
|
1273 |
+
[trajectory_token_all_ped[:, :self.shift], trajectory_token_ped[:, None, ...]], dim=1)
|
1274 |
+
new_agent_token_traj_all[next_type_idx_seed == 2] = torch.cat(
|
1275 |
+
[trajectory_token_all_cyc[:, :self.shift], trajectory_token_cyc[:, None, ...]], dim=1)
|
1276 |
+
|
1277 |
+
agent_token_traj_all = torch.cat([agent_token_traj_all, new_agent_token_traj_all], dim=0)
|
1278 |
+
|
1279 |
+
# add new agents categorical embeddings
|
1280 |
+
new_categorical_embs = [self.type_a_emb(new_type_a.reshape(-1).long()), self.shape_emb(new_shape_a.reshape(-1, 3))]
|
1281 |
+
categorical_embs = [torch.cat([categorical_embs[0], new_categorical_embs[0]], dim=0),
|
1282 |
+
torch.cat([categorical_embs[1], new_categorical_embs[1]], dim=0)]
|
1283 |
+
|
1284 |
+
# update token embeddings of current step
|
1285 |
+
agent_token_emb[veh_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_veh[
|
1286 |
+
next_token_idx[veh_mask]]
|
1287 |
+
agent_token_emb[ped_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_ped[
|
1288 |
+
next_token_idx[ped_mask]]
|
1289 |
+
agent_token_emb[cyc_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_cyc[
|
1290 |
+
next_token_idx[cyc_mask]]
|
1291 |
+
|
1292 |
+
motion_vector_a, head_vector_a = self.build_vector_a(pos_a, head_a, state_a)
|
1293 |
+
|
1294 |
+
motion_vector_a[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0.
|
1295 |
+
head_vector_a[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0.
|
1296 |
+
x_a = torch.stack(
|
1297 |
+
[torch.norm(motion_vector_a[:, :, :2], p=2, dim=-1),
|
1298 |
+
angle_between_2d_vectors(ctr_vector=head_vector_a, nbr_vector=motion_vector_a[:, :, :2])], dim=-1)
|
1299 |
+
|
1300 |
+
x_b = x_a.clone()
|
1301 |
+
x_a = self.x_a_emb(continuous_inputs=x_a.view(-1, x_a.size(-1)),
|
1302 |
+
categorical_embs=categorical_embs)
|
1303 |
+
x_a = x_a.view(-1, num_step, self.hidden_dim)
|
1304 |
+
|
1305 |
+
s_a = self.state_a_emb(state_a.reshape(-1).long()).reshape(num_agent + num_new_agent, num_step, self.hidden_dim)
|
1306 |
+
feat_a = torch.cat((agent_token_emb, x_a, s_a), dim=-1)
|
1307 |
+
feat_a = self.fusion_emb(feat_a)
|
1308 |
+
|
1309 |
+
# if t >= 15:
|
1310 |
+
# print(f"inference {t}")
|
1311 |
+
# is_invalid = state_a == self.invalid_state
|
1312 |
+
# is_bos = state_a == self.enter_state
|
1313 |
+
# is_eos = state_a == self.exit_state
|
1314 |
+
# bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0))
|
1315 |
+
# eos_index = torch.where(is_eos.any(dim=1), torch.argmax(is_eos.long(), dim=1), torch.tensor(num_step - 1))
|
1316 |
+
# mask = torch.arange(num_step).expand(num_agent + num_new_agent, -1).to(state_a.device)
|
1317 |
+
# mask = (mask >= bos_index[:, None]) & (mask <= eos_index[:, None] + 1)
|
1318 |
+
# is_invalid[mask] = False
|
1319 |
+
# is_invalid[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = False
|
1320 |
+
# print(pos_a[:, :((self.num_historical_steps - 1) // self.shift + 1 + t)])
|
1321 |
+
# print(state_a[:, :((self.num_historical_steps - 1) // self.shift + 1 + t)])
|
1322 |
+
# print(pos_a[is_invalid][:, 0])
|
1323 |
+
# print(head_a[is_invalid])
|
1324 |
+
# print(categorical_embs[0].sum(dim=-1)[is_invalid.reshape(-1)])
|
1325 |
+
# print(categorical_embs[1].sum(dim=-1)[is_invalid.reshape(-1)])
|
1326 |
+
# print(motion_vector_a[is_invalid][:, 0])
|
1327 |
+
# print(head_vector_a[is_invalid][:, 0])
|
1328 |
+
# print(x_b.sum(dim=-1)[is_invalid])
|
1329 |
+
# print(x_a.sum(dim=-1)[is_invalid])
|
1330 |
+
# for a in range(state_a.shape[0]):
|
1331 |
+
# print(f"agent: {a}")
|
1332 |
+
# print(state_a[a])
|
1333 |
+
# print(is_invalid[a].long())
|
1334 |
+
# print(pos_a[a, :, 0])
|
1335 |
+
# print(motion_vector_a[a, :, 0])
|
1336 |
+
# print(s_a.sum(dim=-1)[is_invalid])
|
1337 |
+
# print(feat_a.sum(dim=-1)[is_invalid])
|
1338 |
+
|
1339 |
+
# replace the features of steps before bos of valid agents with the corresponding seed agent features
|
1340 |
+
# is_bos = state_a == self.enter_state
|
1341 |
+
# bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(num_step))
|
1342 |
+
# before_bos_mask = torch.arange(num_step).expand(num_agent + num_new_agent, -1).to(state_a.device) < bos_index[:, None]
|
1343 |
+
# feat_a[before_bos_mask] = feat_seed.repeat(num_agent + num_new_agent, 1, 1)[before_bos_mask]
|
1344 |
+
|
1345 |
+
# build seed agent features
|
1346 |
+
motion_vector_seed = motion_vector_a[av_index : av_index + 1]
|
1347 |
+
head_vector_seed = head_vector_a[av_index : av_index + 1]
|
1348 |
+
feat_seed = self.build_invalid_agent_feature(num_step, pos_a.device, type_index=self.all_agent_type.index('seed'),
|
1349 |
+
motion_vector=motion_vector_seed, head_vector=head_vector_seed)
|
1350 |
+
# print(f"inference {t}")
|
1351 |
+
# print(feat_seed.sum(dim=-1))
|
1352 |
+
|
1353 |
+
next_token_idx_list.append(next_token_idx[:, None])
|
1354 |
+
next_state_idx_list.append(next_state_idx[:, None])
|
1355 |
+
# next_bos_pl_idx_list.append(next_bos_pl_idx[:, None])
|
1356 |
+
# next_bos_offset_idx_list.append(next_bos_offset_idx[:, None])
|
1357 |
+
|
1358 |
+
# TODO: check this
|
1359 |
+
# agent_valid_mask[agent_category != 3] = False
|
1360 |
+
|
1361 |
+
# print("inference")
|
1362 |
+
# is_invalid = state_a == self.invalid_state
|
1363 |
+
# is_bos = state_a == self.enter_state
|
1364 |
+
# is_eos = state_a == self.exit_state
|
1365 |
+
# bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0))
|
1366 |
+
# eos_index = torch.where(is_eos.any(dim=1), torch.argmax(is_eos.long(), dim=1), torch.tensor(num_step - 1))
|
1367 |
+
# mask = torch.arange(num_step).expand(num_agent, -1).to(state_a.device)
|
1368 |
+
# mask = (mask >= bos_index[:, None]) & (mask <= eos_index[:, None] + 1)
|
1369 |
+
# is_invalid[mask] = False
|
1370 |
+
# print(feat_a.sum(dim=-1)[is_invalid])
|
1371 |
+
# print(pos_a[is_invalid][: 0])
|
1372 |
+
# print(head_a[is_invalid])
|
1373 |
+
# exit(1)
|
1374 |
+
|
1375 |
+
num_agent = pos_a.shape[0]
|
1376 |
+
for i in range(len(next_token_idx_list)):
|
1377 |
+
next_token_idx_list[i] = torch.cat([next_token_idx_list[i], torch.zeros(num_agent - next_token_idx_list[i].shape[0], 1, device=next_token_idx_list[i].device) - 1], dim=0).long()
|
1378 |
+
next_state_idx_list[i] = torch.cat([next_state_idx_list[i], torch.zeros(num_agent - next_state_idx_list[i].shape[0], 1, device=next_state_idx_list[i].device)], dim=0).long()
|
1379 |
+
|
1380 |
+
# eval mask
|
1381 |
+
next_token_eval_mask = agent_valid_mask.clone()
|
1382 |
+
next_state_eval_mask = agent_valid_mask.clone()
|
1383 |
+
bos_token_index = torch.nonzero(agent_state_index == self.enter_state)
|
1384 |
+
eos_token_index = torch.nonzero(agent_state_index == self.exit_state)
|
1385 |
+
|
1386 |
+
next_token_eval_mask[bos_token_index[:, 0], bos_token_index[:, 1]] = 1
|
1387 |
+
|
1388 |
+
for bos_token_index_i in bos_token_index:
|
1389 |
+
next_state_eval_mask[bos_token_index_i[0], :bos_token_index_i[1] + 2] = 1
|
1390 |
+
for eos_token_index_i in eos_token_index:
|
1391 |
+
next_state_eval_mask[eos_token_index_i[0], eos_token_index_i[1]:] = 1
|
1392 |
+
|
1393 |
+
# add history attributes
|
1394 |
+
num_agent = pred_traj.shape[0]
|
1395 |
+
num_init_agent = filter_mask.sum()
|
1396 |
+
|
1397 |
+
pred_traj = torch.cat([pred_traj, torch.zeros(num_agent, self.num_historical_steps - 1, *(pred_traj.shape[2:]), device=pred_traj.device)], dim=1)
|
1398 |
+
pred_head = torch.cat([pred_head, torch.zeros(num_agent, self.num_historical_steps - 1, *(pred_head.shape[2:]), device=pred_head.device)], dim=1)
|
1399 |
+
pred_state = torch.cat([pred_state, torch.zeros(num_agent, self.num_historical_steps - 1, *(pred_state.shape[2:]), device=pred_state.device)], dim=1)
|
1400 |
+
|
1401 |
+
pred_state[:num_init_agent, :self.num_historical_steps - 1] = data['agent']['state_idx'][filter_mask, :(self.num_historical_steps - 1) // self.shift].repeat_interleave(repeats=self.shift, dim=1)
|
1402 |
+
|
1403 |
+
historical_token_idx = data['agent']['token_idx'][filter_mask, :(self.num_historical_steps - 1) // self.shift]
|
1404 |
+
historical_token_idx[historical_token_idx < 0] = 0
|
1405 |
+
historical_token_traj_all = torch.gather(agent_token_traj_all, 1, historical_token_idx[..., None, None, None].expand(-1, -1, 6, 4, 2))
|
1406 |
+
init_theta = head_a[:num_init_agent, 0]
|
1407 |
+
cos, sin = init_theta.cos(), init_theta.sin()
|
1408 |
+
rot_mat = torch.zeros((num_init_agent, 2, 2), device=init_theta.device)
|
1409 |
+
rot_mat[:, 0, 0] = cos
|
1410 |
+
rot_mat[:, 0, 1] = sin
|
1411 |
+
rot_mat[:, 1, 0] = -sin
|
1412 |
+
rot_mat[:, 1, 1] = cos
|
1413 |
+
historical_token_traj_all = torch.bmm(historical_token_traj_all.view(-1, 4, 2),
|
1414 |
+
rot_mat[:, None, None, ...].repeat(1, (self.num_historical_steps - 1) // self.shift, self.shift + 1, 1, 1).view(
|
1415 |
+
-1, 2, 2)).view(num_init_agent, (self.num_historical_steps - 1) // self.shift, self.shift + 1, 4, 2)
|
1416 |
+
historical_token_traj_all = historical_token_traj_all + pos_a[:num_init_agent, 0, :][:, None, None, None, ...]
|
1417 |
+
pred_traj[:num_init_agent, :self.num_historical_steps - 1] = historical_token_traj_all[:, :, 1:, ...].clone().mean(dim=3).reshape(num_init_agent, -1, 2)
|
1418 |
+
diff_xy = historical_token_traj_all[..., 1:, 0, :] - historical_token_traj_all[..., 1:, 3, :]
|
1419 |
+
pred_head[:num_init_agent, :self.num_historical_steps - 1] = torch.arctan2(diff_xy[..., 1], diff_xy[..., 0]).reshape(num_init_agent, -1)
|
1420 |
+
|
1421 |
+
return {
|
1422 |
+
'av_index': av_index,
|
1423 |
+
'valid_mask': agent_valid_mask[:, self.num_historical_steps:],
|
1424 |
+
'pos_a': pos_a[:, (self.num_historical_steps - 1) // self.shift:],
|
1425 |
+
'head_a': head_a[:, (self.num_historical_steps - 1) // self.shift:],
|
1426 |
+
'gt_traj': gt_traj,
|
1427 |
+
'pred_traj': pred_traj,
|
1428 |
+
'pred_head': pred_head,
|
1429 |
+
'pred_type': list(map(lambda i: self.seed_agent_type[i], pred_type.tolist())),
|
1430 |
+
'pred_state': pred_state,
|
1431 |
+
'next_token_idx': torch.cat(next_token_idx_list, dim=-1), # (num_agent, num_step)
|
1432 |
+
'next_token_idx_gt': agent_token_index,
|
1433 |
+
'next_state_idx': torch.cat(next_state_idx_list, dim=-1) if len(next_state_idx_list) > 0 else None,
|
1434 |
+
'next_state_idx_gt': agent_state_index,
|
1435 |
+
'next_token_eval_mask': next_token_eval_mask,
|
1436 |
+
'next_state_eval_mask': next_state_eval_mask,
|
1437 |
+
# 'next_bos_pl_idx': torch.cat(next_bos_pl_idx_list, dim=-1),
|
1438 |
+
# 'next_bos_offset_idx': torch.cat(next_bos_offset_idx_list, dim=-1),
|
1439 |
+
}
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/modules/layers.py
ADDED
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from typing import List, Optional, Tuple, Union
|
7 |
+
from torch_geometric.nn.conv import MessagePassing
|
8 |
+
from torch_geometric.utils import softmax
|
9 |
+
|
10 |
+
from dev.utils.func import weight_init
|
11 |
+
|
12 |
+
|
13 |
+
__all__ = ['AttentionLayer', 'FourierEmbedding', 'MLPEmbedding', 'MLPLayer', 'MappingNetwork']
|
14 |
+
|
15 |
+
|
16 |
+
class AttentionLayer(MessagePassing):
|
17 |
+
|
18 |
+
def __init__(self,
|
19 |
+
hidden_dim: int,
|
20 |
+
num_heads: int,
|
21 |
+
head_dim: int,
|
22 |
+
dropout: float,
|
23 |
+
bipartite: bool,
|
24 |
+
has_pos_emb: bool,
|
25 |
+
**kwargs) -> None:
|
26 |
+
super(AttentionLayer, self).__init__(aggr='add', node_dim=0, **kwargs)
|
27 |
+
self.num_heads = num_heads
|
28 |
+
self.head_dim = head_dim
|
29 |
+
self.has_pos_emb = has_pos_emb
|
30 |
+
self.scale = head_dim ** -0.5
|
31 |
+
|
32 |
+
self.to_q = nn.Linear(hidden_dim, head_dim * num_heads)
|
33 |
+
self.to_k = nn.Linear(hidden_dim, head_dim * num_heads, bias=False)
|
34 |
+
self.to_v = nn.Linear(hidden_dim, head_dim * num_heads)
|
35 |
+
if has_pos_emb:
|
36 |
+
self.to_k_r = nn.Linear(hidden_dim, head_dim * num_heads, bias=False)
|
37 |
+
self.to_v_r = nn.Linear(hidden_dim, head_dim * num_heads)
|
38 |
+
self.to_s = nn.Linear(hidden_dim, head_dim * num_heads)
|
39 |
+
self.to_g = nn.Linear(head_dim * num_heads + hidden_dim, head_dim * num_heads)
|
40 |
+
self.to_out = nn.Linear(head_dim * num_heads, hidden_dim)
|
41 |
+
self.attn_drop = nn.Dropout(dropout)
|
42 |
+
self.ff_mlp = nn.Sequential(
|
43 |
+
nn.Linear(hidden_dim, hidden_dim * 4),
|
44 |
+
nn.ReLU(inplace=True),
|
45 |
+
nn.Dropout(dropout),
|
46 |
+
nn.Linear(hidden_dim * 4, hidden_dim),
|
47 |
+
)
|
48 |
+
if bipartite:
|
49 |
+
self.attn_prenorm_x_src = nn.LayerNorm(hidden_dim)
|
50 |
+
self.attn_prenorm_x_dst = nn.LayerNorm(hidden_dim)
|
51 |
+
else:
|
52 |
+
self.attn_prenorm_x_src = nn.LayerNorm(hidden_dim)
|
53 |
+
self.attn_prenorm_x_dst = self.attn_prenorm_x_src
|
54 |
+
if has_pos_emb:
|
55 |
+
self.attn_prenorm_r = nn.LayerNorm(hidden_dim)
|
56 |
+
self.attn_postnorm = nn.LayerNorm(hidden_dim)
|
57 |
+
self.ff_prenorm = nn.LayerNorm(hidden_dim)
|
58 |
+
self.ff_postnorm = nn.LayerNorm(hidden_dim)
|
59 |
+
self.apply(weight_init)
|
60 |
+
|
61 |
+
def forward(self,
|
62 |
+
x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
|
63 |
+
r: Optional[torch.Tensor],
|
64 |
+
edge_index: torch.Tensor) -> torch.Tensor:
|
65 |
+
if isinstance(x, torch.Tensor):
|
66 |
+
x_src = x_dst = self.attn_prenorm_x_src(x)
|
67 |
+
else:
|
68 |
+
x_src, x_dst = x
|
69 |
+
x_src = self.attn_prenorm_x_src(x_src)
|
70 |
+
x_dst = self.attn_prenorm_x_dst(x_dst)
|
71 |
+
x = x[1]
|
72 |
+
if self.has_pos_emb and r is not None:
|
73 |
+
r = self.attn_prenorm_r(r)
|
74 |
+
x = x + self.attn_postnorm(self._attn_block(x_src, x_dst, r, edge_index))
|
75 |
+
x = x + self.ff_postnorm(self._ff_block(self.ff_prenorm(x)))
|
76 |
+
return x
|
77 |
+
|
78 |
+
def message(self,
|
79 |
+
q_i: torch.Tensor,
|
80 |
+
k_j: torch.Tensor,
|
81 |
+
v_j: torch.Tensor,
|
82 |
+
r: Optional[torch.Tensor],
|
83 |
+
index: torch.Tensor,
|
84 |
+
ptr: Optional[torch.Tensor]) -> torch.Tensor:
|
85 |
+
if self.has_pos_emb and r is not None:
|
86 |
+
k_j = k_j + self.to_k_r(r).view(-1, self.num_heads, self.head_dim)
|
87 |
+
v_j = v_j + self.to_v_r(r).view(-1, self.num_heads, self.head_dim)
|
88 |
+
sim = (q_i * k_j).sum(dim=-1) * self.scale
|
89 |
+
attn = softmax(sim, index, ptr)
|
90 |
+
self.attention_weight = attn.sum(-1).detach()
|
91 |
+
attn = self.attn_drop(attn)
|
92 |
+
return v_j * attn.unsqueeze(-1)
|
93 |
+
|
94 |
+
def update(self,
|
95 |
+
inputs: torch.Tensor,
|
96 |
+
x_dst: torch.Tensor) -> torch.Tensor:
|
97 |
+
inputs = inputs.view(-1, self.num_heads * self.head_dim)
|
98 |
+
g = torch.sigmoid(self.to_g(torch.cat([inputs, x_dst], dim=-1)))
|
99 |
+
return inputs + g * (self.to_s(x_dst) - inputs)
|
100 |
+
|
101 |
+
def _attn_block(self,
|
102 |
+
x_src: torch.Tensor,
|
103 |
+
x_dst: torch.Tensor,
|
104 |
+
r: Optional[torch.Tensor],
|
105 |
+
edge_index: torch.Tensor) -> torch.Tensor:
|
106 |
+
q = self.to_q(x_dst).view(-1, self.num_heads, self.head_dim)
|
107 |
+
k = self.to_k(x_src).view(-1, self.num_heads, self.head_dim)
|
108 |
+
v = self.to_v(x_src).view(-1, self.num_heads, self.head_dim)
|
109 |
+
agg = self.propagate(edge_index=edge_index, x_dst=x_dst, q=q, k=k, v=v, r=r)
|
110 |
+
return self.to_out(agg)
|
111 |
+
|
112 |
+
def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
|
113 |
+
return self.ff_mlp(x)
|
114 |
+
|
115 |
+
|
116 |
+
class FourierEmbedding(nn.Module):
|
117 |
+
|
118 |
+
def __init__(self,
|
119 |
+
input_dim: int,
|
120 |
+
hidden_dim: int,
|
121 |
+
num_freq_bands: int) -> None:
|
122 |
+
super(FourierEmbedding, self).__init__()
|
123 |
+
self.input_dim = input_dim
|
124 |
+
self.hidden_dim = hidden_dim
|
125 |
+
|
126 |
+
self.freqs = nn.Embedding(input_dim, num_freq_bands) if input_dim != 0 else None
|
127 |
+
self.mlps = nn.ModuleList(
|
128 |
+
[nn.Sequential(
|
129 |
+
nn.Linear(num_freq_bands * 2 + 1, hidden_dim),
|
130 |
+
nn.LayerNorm(hidden_dim),
|
131 |
+
nn.ReLU(inplace=True),
|
132 |
+
nn.Linear(hidden_dim, hidden_dim),
|
133 |
+
)
|
134 |
+
for _ in range(input_dim)])
|
135 |
+
self.to_out = nn.Sequential(
|
136 |
+
nn.LayerNorm(hidden_dim),
|
137 |
+
nn.ReLU(inplace=True),
|
138 |
+
nn.Linear(hidden_dim, hidden_dim),
|
139 |
+
)
|
140 |
+
self.apply(weight_init)
|
141 |
+
|
142 |
+
def forward(self,
|
143 |
+
continuous_inputs: Optional[torch.Tensor] = None,
|
144 |
+
categorical_embs: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
|
145 |
+
if continuous_inputs is None:
|
146 |
+
if categorical_embs is not None:
|
147 |
+
x = torch.stack(categorical_embs).sum(dim=0)
|
148 |
+
else:
|
149 |
+
raise ValueError('Both continuous_inputs and categorical_embs are None')
|
150 |
+
else:
|
151 |
+
x = continuous_inputs.unsqueeze(-1) * self.freqs.weight * 2 * math.pi
|
152 |
+
# Warning: if your data are noisy, don't use learnable sinusoidal embedding
|
153 |
+
x = torch.cat([x.cos(), x.sin(), continuous_inputs.unsqueeze(-1)], dim=-1)
|
154 |
+
continuous_embs: List[Optional[torch.Tensor]] = [None] * self.input_dim
|
155 |
+
for i in range(self.input_dim):
|
156 |
+
continuous_embs[i] = self.mlps[i](x[:, i])
|
157 |
+
x = torch.stack(continuous_embs).sum(dim=0)
|
158 |
+
if categorical_embs is not None:
|
159 |
+
x = x + torch.stack(categorical_embs).sum(dim=0)
|
160 |
+
return self.to_out(x)
|
161 |
+
|
162 |
+
|
163 |
+
class MLPEmbedding(nn.Module):
|
164 |
+
def __init__(self,
|
165 |
+
input_dim: int,
|
166 |
+
hidden_dim: int) -> None:
|
167 |
+
super(MLPEmbedding, self).__init__()
|
168 |
+
self.input_dim = input_dim
|
169 |
+
self.hidden_dim = hidden_dim
|
170 |
+
self.mlp = nn.Sequential(
|
171 |
+
nn.Linear(input_dim, 128),
|
172 |
+
nn.LayerNorm(128),
|
173 |
+
nn.ReLU(inplace=True),
|
174 |
+
nn.Linear(128, hidden_dim),
|
175 |
+
nn.LayerNorm(hidden_dim),
|
176 |
+
nn.ReLU(inplace=True),
|
177 |
+
nn.Linear(hidden_dim, hidden_dim))
|
178 |
+
self.apply(weight_init)
|
179 |
+
|
180 |
+
def forward(self,
|
181 |
+
continuous_inputs: Optional[torch.Tensor] = None,
|
182 |
+
categorical_embs: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
|
183 |
+
if continuous_inputs is None:
|
184 |
+
if categorical_embs is not None:
|
185 |
+
x = torch.stack(categorical_embs).sum(dim=0)
|
186 |
+
else:
|
187 |
+
raise ValueError('Both continuous_inputs and categorical_embs are None')
|
188 |
+
else:
|
189 |
+
x = self.mlp(continuous_inputs)
|
190 |
+
if categorical_embs is not None:
|
191 |
+
x = x + torch.stack(categorical_embs).sum(dim=0)
|
192 |
+
return x
|
193 |
+
|
194 |
+
|
195 |
+
class MLPLayer(nn.Module):
|
196 |
+
|
197 |
+
def __init__(self,
|
198 |
+
input_dim: int,
|
199 |
+
hidden_dim: int=None,
|
200 |
+
output_dim: int=None) -> None:
|
201 |
+
super(MLPLayer, self).__init__()
|
202 |
+
|
203 |
+
if hidden_dim is None:
|
204 |
+
hidden_dim = output_dim
|
205 |
+
|
206 |
+
self.mlp = nn.Sequential(
|
207 |
+
nn.Linear(input_dim, hidden_dim),
|
208 |
+
nn.LayerNorm(hidden_dim),
|
209 |
+
nn.ReLU(inplace=True),
|
210 |
+
nn.Linear(hidden_dim, output_dim),
|
211 |
+
)
|
212 |
+
self.apply(weight_init)
|
213 |
+
|
214 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
215 |
+
return self.mlp(x)
|
216 |
+
|
217 |
+
|
218 |
+
class MappingNetwork(nn.Module):
|
219 |
+
def __init__(self, z_dim, w_dim, layer_dim=None, num_layers=8):
|
220 |
+
super().__init__()
|
221 |
+
|
222 |
+
if not layer_dim:
|
223 |
+
layer_dim = w_dim
|
224 |
+
layer_dims = [z_dim] + [layer_dim] * (num_layers - 1) + [w_dim]
|
225 |
+
|
226 |
+
layers = []
|
227 |
+
for i in range(num_layers):
|
228 |
+
layers.extend([
|
229 |
+
nn.Linear(layer_dims[i], layer_dims[i + 1]),
|
230 |
+
nn.LeakyReLU(),
|
231 |
+
])
|
232 |
+
self.layers = nn.Sequential(*layers)
|
233 |
+
|
234 |
+
def forward(self, z):
|
235 |
+
w = self.layers(z)
|
236 |
+
return w
|
237 |
+
|
238 |
+
|
239 |
+
# class FocalLoss:
|
240 |
+
# def __init__(self, alpha: float=.25, gamma: float=2):
|
241 |
+
# self.alpha = alpha
|
242 |
+
# self.gamma = gamma
|
243 |
+
|
244 |
+
# def __call__(self, inputs, targets):
|
245 |
+
# prob = inputs.sigmoid()
|
246 |
+
# ce_loss = F.binary_cross_entropy_with_logits(inputs, targets.float(), reduction='none')
|
247 |
+
# p_t = prob * targets + (1 - prob) * (1 - targets)
|
248 |
+
# loss = ce_loss * ((1 - p_t) ** self.gamma)
|
249 |
+
|
250 |
+
# if self.alpha >= 0:
|
251 |
+
# alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
|
252 |
+
# loss = alpha_t * loss
|
253 |
+
|
254 |
+
# return loss.mean()
|
255 |
+
|
256 |
+
|
257 |
+
class FocalLoss(nn.Module):
|
258 |
+
"""Focal Loss, as described in https://arxiv.org/abs/1708.02002.
|
259 |
+
It is essentially an enhancement to cross entropy loss and is
|
260 |
+
useful for classification tasks when there is a large class imbalance.
|
261 |
+
x is expected to contain raw, unnormalized scores for each class.
|
262 |
+
y is expected to contain class labels.
|
263 |
+
Shape:
|
264 |
+
- x: (batch_size, C) or (batch_size, C, d1, d2, ..., dK), K > 0.
|
265 |
+
- y: (batch_size,) or (batch_size, d1, d2, ..., dK), K > 0.
|
266 |
+
"""
|
267 |
+
|
268 |
+
def __init__(
|
269 |
+
self,
|
270 |
+
alpha: Optional[torch.Tensor] = None,
|
271 |
+
gamma: float = 0.0,
|
272 |
+
reduction: str = "mean",
|
273 |
+
ignore_index: int = -100,
|
274 |
+
):
|
275 |
+
"""Constructor.
|
276 |
+
Args:
|
277 |
+
alpha (Tensor, optional): Weights for each class. Defaults to None.
|
278 |
+
gamma (float, optional): A constant, as described in the paper.
|
279 |
+
Defaults to 0.
|
280 |
+
reduction (str, optional): 'mean', 'sum' or 'none'.
|
281 |
+
Defaults to 'mean'.
|
282 |
+
ignore_index (int, optional): class label to ignore.
|
283 |
+
Defaults to -100.
|
284 |
+
"""
|
285 |
+
if reduction not in ("mean", "sum", "none"):
|
286 |
+
raise ValueError('Reduction must be one of: "mean", "sum", "none".')
|
287 |
+
|
288 |
+
super().__init__()
|
289 |
+
self.alpha = alpha
|
290 |
+
self.gamma = gamma
|
291 |
+
self.ignore_index = ignore_index
|
292 |
+
self.reduction = reduction
|
293 |
+
|
294 |
+
self.nll_loss = nn.NLLLoss(
|
295 |
+
weight=alpha, reduction="none", ignore_index=ignore_index
|
296 |
+
)
|
297 |
+
|
298 |
+
def __repr__(self):
|
299 |
+
arg_keys = ["alpha", "gamma", "ignore_index", "reduction"]
|
300 |
+
arg_vals = [self.__dict__[k] for k in arg_keys]
|
301 |
+
arg_strs = [f"{k}={v}" for k, v in zip(arg_keys, arg_vals)]
|
302 |
+
arg_str = ", ".join(arg_strs)
|
303 |
+
return f"{type(self).__name__}({arg_str})"
|
304 |
+
|
305 |
+
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
306 |
+
if x.ndim > 2:
|
307 |
+
# (N, C, d1, d2, ..., dK) --> (N * d1 * ... * dK, C)
|
308 |
+
c = x.shape[1]
|
309 |
+
x = x.permute(0, *range(2, x.ndim), 1).reshape(-1, c)
|
310 |
+
# (N, d1, d2, ..., dK) --> (N * d1 * ... * dK,)
|
311 |
+
y = y.view(-1)
|
312 |
+
|
313 |
+
unignored_mask = y != self.ignore_index
|
314 |
+
y = y[unignored_mask]
|
315 |
+
if len(y) == 0:
|
316 |
+
return 0.0
|
317 |
+
x = x[unignored_mask]
|
318 |
+
|
319 |
+
# compute weighted cross entropy term: -alpha * log(pt)
|
320 |
+
# (alpha is already part of self.nll_loss)
|
321 |
+
log_p = F.log_softmax(x, dim=-1)
|
322 |
+
ce = self.nll_loss(log_p, y)
|
323 |
+
|
324 |
+
# get true class column from each row
|
325 |
+
all_rows = torch.arange(len(x))
|
326 |
+
log_pt = log_p[all_rows, y]
|
327 |
+
|
328 |
+
# compute focal term: (1 - pt)^gamma
|
329 |
+
pt = log_pt.exp()
|
330 |
+
focal_term = (1 - pt) ** self.gamma
|
331 |
+
|
332 |
+
# the full loss: -alpha * ((1 - pt)^gamma) * log(pt)
|
333 |
+
loss = focal_term * ce
|
334 |
+
|
335 |
+
if self.reduction == "mean":
|
336 |
+
loss = loss.mean()
|
337 |
+
elif self.reduction == "sum":
|
338 |
+
loss = loss.sum()
|
339 |
+
|
340 |
+
return loss
|
341 |
+
|
342 |
+
|
343 |
+
class OccLoss(nn.Module):
|
344 |
+
|
345 |
+
# geo_scal_loss
|
346 |
+
def __init__(self):
|
347 |
+
super().__init__()
|
348 |
+
|
349 |
+
def forward(self, pred, target, mask=None):
|
350 |
+
|
351 |
+
nonempty_probs = torch.sigmoid(pred)
|
352 |
+
empty_probs = 1 - nonempty_probs
|
353 |
+
|
354 |
+
if mask is None:
|
355 |
+
mask = torch.ones_like(target).bool()
|
356 |
+
|
357 |
+
nonempty_target = target == 1
|
358 |
+
nonempty_target = nonempty_target[mask].float()
|
359 |
+
nonempty_probs = nonempty_probs[mask]
|
360 |
+
empty_probs = empty_probs[mask]
|
361 |
+
|
362 |
+
intersection = (nonempty_target * nonempty_probs).sum()
|
363 |
+
precision = intersection / nonempty_probs.sum()
|
364 |
+
recall = intersection / nonempty_target.sum()
|
365 |
+
spec = ((1 - nonempty_target) * (empty_probs)).sum() / (1 - nonempty_target).sum()
|
366 |
+
|
367 |
+
return (
|
368 |
+
F.binary_cross_entropy(precision, torch.ones_like(precision))
|
369 |
+
+ F.binary_cross_entropy(recall, torch.ones_like(recall))
|
370 |
+
+ F.binary_cross_entropy(spec, torch.ones_like(spec))
|
371 |
+
)
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/modules/map_decoder.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch_cluster import radius_graph
|
5 |
+
from torch_geometric.data import Batch
|
6 |
+
from torch_geometric.data import HeteroData
|
7 |
+
from torch_geometric.utils import subgraph
|
8 |
+
|
9 |
+
from dev.modules.layers import MLPLayer, AttentionLayer, FourierEmbedding, MLPEmbedding
|
10 |
+
from dev.utils.func import weight_init, wrap_angle, angle_between_2d_vectors
|
11 |
+
|
12 |
+
|
13 |
+
class SMARTMapDecoder(nn.Module):
|
14 |
+
|
15 |
+
def __init__(self,
|
16 |
+
dataset: str,
|
17 |
+
input_dim: int,
|
18 |
+
hidden_dim: int,
|
19 |
+
num_historical_steps: int,
|
20 |
+
pl2pl_radius: float,
|
21 |
+
num_freq_bands: int,
|
22 |
+
num_layers: int,
|
23 |
+
num_heads: int,
|
24 |
+
head_dim: int,
|
25 |
+
dropout: float,
|
26 |
+
map_token) -> None:
|
27 |
+
|
28 |
+
super(SMARTMapDecoder, self).__init__()
|
29 |
+
self.dataset = dataset
|
30 |
+
self.input_dim = input_dim
|
31 |
+
self.hidden_dim = hidden_dim
|
32 |
+
self.num_historical_steps = num_historical_steps
|
33 |
+
self.pl2pl_radius = pl2pl_radius
|
34 |
+
self.num_freq_bands = num_freq_bands
|
35 |
+
self.num_layers = num_layers
|
36 |
+
self.num_heads = num_heads
|
37 |
+
self.head_dim = head_dim
|
38 |
+
self.dropout = dropout
|
39 |
+
|
40 |
+
if input_dim == 2:
|
41 |
+
input_dim_r_pt2pt = 3
|
42 |
+
elif input_dim == 3:
|
43 |
+
input_dim_r_pt2pt = 4
|
44 |
+
else:
|
45 |
+
raise ValueError('{} is not a valid dimension'.format(input_dim))
|
46 |
+
|
47 |
+
self.type_pt_emb = nn.Embedding(17, hidden_dim)
|
48 |
+
self.side_pt_emb = nn.Embedding(4, hidden_dim)
|
49 |
+
self.polygon_type_emb = nn.Embedding(4, hidden_dim)
|
50 |
+
self.light_pl_emb = nn.Embedding(4, hidden_dim)
|
51 |
+
|
52 |
+
self.r_pt2pt_emb = FourierEmbedding(input_dim=input_dim_r_pt2pt, hidden_dim=hidden_dim,
|
53 |
+
num_freq_bands=num_freq_bands)
|
54 |
+
self.pt2pt_layers = nn.ModuleList(
|
55 |
+
[AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
|
56 |
+
bipartite=False, has_pos_emb=True) for _ in range(num_layers)]
|
57 |
+
)
|
58 |
+
self.token_size = 1024
|
59 |
+
self.token_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
|
60 |
+
output_dim=self.token_size)
|
61 |
+
input_dim_token = 22
|
62 |
+
self.token_emb = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)
|
63 |
+
self.map_token = map_token
|
64 |
+
self.apply(weight_init)
|
65 |
+
self.mask_pt = False
|
66 |
+
|
67 |
+
def maybe_autocast(self, dtype=torch.float32):
|
68 |
+
return torch.cuda.amp.autocast(dtype=dtype)
|
69 |
+
|
70 |
+
def forward(self, data: HeteroData) -> Dict[str, torch.Tensor]:
|
71 |
+
pt_valid_mask = data['pt_token']['pt_valid_mask']
|
72 |
+
pt_pred_mask = data['pt_token']['pt_pred_mask']
|
73 |
+
pt_target_mask = data['pt_token']['pt_target_mask']
|
74 |
+
mask_s = pt_valid_mask
|
75 |
+
|
76 |
+
pos_pt = data['pt_token']['position'][:, :self.input_dim].contiguous()
|
77 |
+
orient_pt = data['pt_token']['orientation'].contiguous()
|
78 |
+
orient_vector_pt = torch.stack([orient_pt.cos(), orient_pt.sin()], dim=-1)
|
79 |
+
token_sample_pt = self.map_token['traj_src'].to(pos_pt.device).to(torch.float)
|
80 |
+
pt_token_emb_src = self.token_emb(token_sample_pt.view(token_sample_pt.shape[0], -1))
|
81 |
+
pt_token_emb = pt_token_emb_src[data['pt_token']['token_idx']]
|
82 |
+
|
83 |
+
x_pt = pt_token_emb
|
84 |
+
|
85 |
+
token2pl = data[('pt_token', 'to', 'map_polygon')]['edge_index']
|
86 |
+
token_light_type = data['map_polygon']['light_type'][token2pl[1]]
|
87 |
+
x_pt_categorical_embs = [self.type_pt_emb(data['pt_token']['type'].long()),
|
88 |
+
self.polygon_type_emb(data['pt_token']['pl_type'].long()),
|
89 |
+
self.light_pl_emb(token_light_type.long()),]
|
90 |
+
x_pt = x_pt + torch.stack(x_pt_categorical_embs).sum(dim=0)
|
91 |
+
edge_index_pt2pt = radius_graph(x=pos_pt[:, :2], r=self.pl2pl_radius,
|
92 |
+
batch=data['pt_token']['batch'] if isinstance(data, Batch) else None,
|
93 |
+
loop=False, max_num_neighbors=100)
|
94 |
+
if self.mask_pt:
|
95 |
+
edge_index_pt2pt = subgraph(subset=mask_s, edge_index=edge_index_pt2pt)[0]
|
96 |
+
rel_pos_pt2pt = pos_pt[edge_index_pt2pt[0]] - pos_pt[edge_index_pt2pt[1]]
|
97 |
+
rel_orient_pt2pt = wrap_angle(orient_pt[edge_index_pt2pt[0]] - orient_pt[edge_index_pt2pt[1]])
|
98 |
+
if self.input_dim == 2:
|
99 |
+
r_pt2pt = torch.stack(
|
100 |
+
[torch.norm(rel_pos_pt2pt[:, :2], p=2, dim=-1),
|
101 |
+
angle_between_2d_vectors(ctr_vector=orient_vector_pt[edge_index_pt2pt[1]],
|
102 |
+
nbr_vector=rel_pos_pt2pt[:, :2]),
|
103 |
+
rel_orient_pt2pt], dim=-1)
|
104 |
+
elif self.input_dim == 3:
|
105 |
+
r_pt2pt = torch.stack(
|
106 |
+
[torch.norm(rel_pos_pt2pt[:, :2], p=2, dim=-1),
|
107 |
+
angle_between_2d_vectors(ctr_vector=orient_vector_pt[edge_index_pt2pt[1]],
|
108 |
+
nbr_vector=rel_pos_pt2pt[:, :2]),
|
109 |
+
rel_pos_pt2pt[:, -1],
|
110 |
+
rel_orient_pt2pt], dim=-1)
|
111 |
+
else:
|
112 |
+
raise ValueError('{} is not a valid dimension'.format(self.input_dim))
|
113 |
+
|
114 |
+
# layers
|
115 |
+
r_pt2pt = self.r_pt2pt_emb(continuous_inputs=r_pt2pt, categorical_embs=None)
|
116 |
+
for i in range(self.num_layers):
|
117 |
+
x_pt = self.pt2pt_layers[i](x_pt, r_pt2pt, edge_index_pt2pt)
|
118 |
+
|
119 |
+
next_token_prob = self.token_predict_head(x_pt[pt_pred_mask])
|
120 |
+
next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1)
|
121 |
+
_, next_token_idx = torch.topk(next_token_prob_softmax, k=10, dim=-1)
|
122 |
+
next_token_index_gt = data['pt_token']['token_idx'][pt_target_mask]
|
123 |
+
|
124 |
+
return {
|
125 |
+
'x_pt': x_pt,
|
126 |
+
'map_next_token_idx': next_token_idx,
|
127 |
+
'map_next_token_prob': next_token_prob,
|
128 |
+
'map_next_token_idx_gt': next_token_index_gt,
|
129 |
+
'map_next_token_eval_mask': pt_pred_mask[pt_pred_mask]
|
130 |
+
}
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/modules/occ_decoder.py
ADDED
@@ -0,0 +1,927 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from typing import Dict, Mapping, Optional, Literal
|
7 |
+
from torch_cluster import radius, radius_graph
|
8 |
+
from torch_geometric.data import HeteroData, Batch
|
9 |
+
from torch_geometric.utils import dense_to_sparse, subgraph
|
10 |
+
from scipy.optimize import linear_sum_assignment
|
11 |
+
|
12 |
+
from dev.modules.attr_tokenizer import Attr_Tokenizer
|
13 |
+
from dev.modules.layers import *
|
14 |
+
from dev.utils.visualization import *
|
15 |
+
from dev.utils.func import angle_between_2d_vectors, wrap_angle, weight_init
|
16 |
+
|
17 |
+
|
18 |
+
class SMARTOccDecoder(nn.Module):
|
19 |
+
|
20 |
+
def __init__(self,
|
21 |
+
dataset: str,
|
22 |
+
input_dim: int,
|
23 |
+
hidden_dim: int,
|
24 |
+
num_historical_steps: int,
|
25 |
+
time_span: Optional[int],
|
26 |
+
pl2a_radius: float,
|
27 |
+
pl2seed_radius: float,
|
28 |
+
a2a_radius: float,
|
29 |
+
a2sa_radius: float,
|
30 |
+
pl2sa_radius: float,
|
31 |
+
num_freq_bands: int,
|
32 |
+
num_layers: int,
|
33 |
+
num_heads: int,
|
34 |
+
head_dim: int,
|
35 |
+
dropout: float,
|
36 |
+
token_data: Dict,
|
37 |
+
token_size: int,
|
38 |
+
special_token_index: list=[],
|
39 |
+
attr_tokenizer: Attr_Tokenizer=None,
|
40 |
+
predict_motion: bool=False,
|
41 |
+
predict_state: bool=False,
|
42 |
+
predict_map: bool=False,
|
43 |
+
predict_occ: bool=False,
|
44 |
+
state_token: Dict[str, int]=None,
|
45 |
+
seed_size: int=5,
|
46 |
+
buffer_size: int=32,
|
47 |
+
loss_weight: dict=None,
|
48 |
+
logger=None) -> None:
|
49 |
+
|
50 |
+
super(SMARTOccDecoder, self).__init__()
|
51 |
+
self.dataset = dataset
|
52 |
+
self.input_dim = input_dim
|
53 |
+
self.hidden_dim = hidden_dim
|
54 |
+
self.num_historical_steps = num_historical_steps
|
55 |
+
self.time_span = time_span if time_span is not None else num_historical_steps
|
56 |
+
self.pl2a_radius = pl2a_radius
|
57 |
+
self.pl2seed_radius = pl2seed_radius
|
58 |
+
self.a2a_radius = a2a_radius
|
59 |
+
self.a2sa_radius = a2sa_radius
|
60 |
+
self.pl2sa_radius = pl2sa_radius
|
61 |
+
self.num_freq_bands = num_freq_bands
|
62 |
+
self.num_layers = num_layers
|
63 |
+
self.num_heads = num_heads
|
64 |
+
self.head_dim = head_dim
|
65 |
+
self.dropout = dropout
|
66 |
+
self.special_token_index = special_token_index
|
67 |
+
self.predict_motion = predict_motion
|
68 |
+
self.predict_state = predict_state
|
69 |
+
self.predict_map = predict_map
|
70 |
+
self.predict_occ = predict_occ
|
71 |
+
self.loss_weight = loss_weight
|
72 |
+
self.logger = logger
|
73 |
+
|
74 |
+
self.attr_tokenizer = attr_tokenizer
|
75 |
+
|
76 |
+
# state tokens
|
77 |
+
self.state_type = list(state_token.keys())
|
78 |
+
self.state_token = state_token
|
79 |
+
self.invalid_state = int(state_token['invalid'])
|
80 |
+
self.valid_state = int(state_token['valid'])
|
81 |
+
self.enter_state = int(state_token['enter'])
|
82 |
+
self.exit_state = int(state_token['exit'])
|
83 |
+
|
84 |
+
self.seed_state_type = ['invalid', 'enter']
|
85 |
+
self.valid_state_type = ['invalid', 'valid', 'exit']
|
86 |
+
|
87 |
+
input_dim_r_pt2a = 3
|
88 |
+
input_dim_r_a2a = 3
|
89 |
+
|
90 |
+
self.seed_size = seed_size
|
91 |
+
self.buffer_size = buffer_size
|
92 |
+
|
93 |
+
self.agent_type = ['veh', 'ped', 'cyc', 'seed']
|
94 |
+
self.type_a_emb = nn.Embedding(len(self.agent_type), hidden_dim)
|
95 |
+
self.shape_emb = MLPEmbedding(input_dim=3, hidden_dim=hidden_dim)
|
96 |
+
self.state_a_emb = nn.Embedding(len(self.state_type), hidden_dim)
|
97 |
+
self.motion_gap = 1.
|
98 |
+
self.heading_gap = 1.
|
99 |
+
self.invalid_shape_value = .1
|
100 |
+
self.invalid_motion_value = -2.
|
101 |
+
self.invalid_head_value = -2.
|
102 |
+
|
103 |
+
self.r_pt2a_emb = FourierEmbedding(input_dim=input_dim_r_pt2a, hidden_dim=hidden_dim,
|
104 |
+
num_freq_bands=num_freq_bands)
|
105 |
+
self.r_a2a_emb = FourierEmbedding(input_dim=input_dim_r_a2a, hidden_dim=hidden_dim,
|
106 |
+
num_freq_bands=num_freq_bands)
|
107 |
+
|
108 |
+
self.token_size = token_size # 2048
|
109 |
+
self.grid_size = self.attr_tokenizer.grid_size
|
110 |
+
self.angle_size = self.attr_tokenizer.angle_size
|
111 |
+
self.agent_limit = 3
|
112 |
+
self.pt_limit = 10
|
113 |
+
self.grid_agent_occ_head = MLPLayer(input_dim=hidden_dim, hidden_dim=self.grid_size,
|
114 |
+
output_dim=self.agent_limit * self.grid_size)
|
115 |
+
self.grid_pt_occ_head = MLPLayer(input_dim=hidden_dim, hidden_dim=self.grid_size,
|
116 |
+
output_dim=self.pt_limit * self.grid_size)
|
117 |
+
|
118 |
+
# self.num_seed_feature = 1
|
119 |
+
# self.num_seed_feature = self.seed_size
|
120 |
+
self.num_seed_feature = 10
|
121 |
+
|
122 |
+
self.trajectory_token = token_data['token'] # dict('veh', 'ped', 'cyc') (2048, 4, 2)
|
123 |
+
self.trajectory_token_traj = token_data['traj'] # (2048, 6, 3)
|
124 |
+
self.trajectory_token_all = token_data['token_all'] # (2048, 6, 4, 2)
|
125 |
+
self.apply(weight_init)
|
126 |
+
|
127 |
+
self.shift = 5
|
128 |
+
self.beam_size = 5
|
129 |
+
self.hist_mask = True
|
130 |
+
self.temporal_attn_to_invalid = False
|
131 |
+
self.use_rel = False
|
132 |
+
|
133 |
+
# seed agent
|
134 |
+
self.temporal_attn_seed = False
|
135 |
+
self.seed_attn_to_av = True
|
136 |
+
self.seed_use_ego_motion = False
|
137 |
+
|
138 |
+
def transform_rel(self, token_traj, prev_pos, prev_heading=None):
|
139 |
+
if prev_heading is None:
|
140 |
+
diff_xy = prev_pos[:, :, -1, :] - prev_pos[:, :, -2, :]
|
141 |
+
prev_heading = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0])
|
142 |
+
|
143 |
+
num_agent, num_step, traj_num, traj_dim = token_traj.shape
|
144 |
+
cos, sin = prev_heading.cos(), prev_heading.sin()
|
145 |
+
rot_mat = torch.zeros((num_agent, num_step, 2, 2), device=prev_heading.device)
|
146 |
+
rot_mat[:, :, 0, 0] = cos
|
147 |
+
rot_mat[:, :, 0, 1] = -sin
|
148 |
+
rot_mat[:, :, 1, 0] = sin
|
149 |
+
rot_mat[:, :, 1, 1] = cos
|
150 |
+
agent_diff_rel = torch.bmm(token_traj.view(-1, traj_num, 2), rot_mat.view(-1, 2, 2)).view(num_agent, num_step, traj_num, traj_dim)
|
151 |
+
agent_pred_rel = agent_diff_rel + prev_pos[:, :, -1:, :]
|
152 |
+
return agent_pred_rel
|
153 |
+
|
154 |
+
def _agent_token_embedding(self, data, agent_token_index, agent_state, agent_offset_token_idx, pos_a, head_a,
|
155 |
+
inference=False, filter_mask=None, av_index=None):
|
156 |
+
|
157 |
+
if filter_mask is None:
|
158 |
+
filter_mask = torch.ones_like(agent_state[:, 2], dtype=torch.bool)
|
159 |
+
|
160 |
+
num_agent, num_step, traj_dim = pos_a.shape # traj_dim=2
|
161 |
+
agent_type = data['agent']['type'][filter_mask]
|
162 |
+
veh_mask = (agent_type == 0)
|
163 |
+
ped_mask = (agent_type == 1)
|
164 |
+
cyc_mask = (agent_type == 2)
|
165 |
+
|
166 |
+
motion_vector_a, head_vector_a = self._build_vector_a(pos_a, head_a, agent_state)
|
167 |
+
|
168 |
+
trajectory_token_veh = torch.from_numpy(self.trajectory_token['veh']).clone().to(pos_a.device).to(torch.float)
|
169 |
+
trajectory_token_ped = torch.from_numpy(self.trajectory_token['ped']).clone().to(pos_a.device).to(torch.float)
|
170 |
+
trajectory_token_cyc = torch.from_numpy(self.trajectory_token['cyc']).clone().to(pos_a.device).to(torch.float)
|
171 |
+
self.agent_token_emb_veh = self.token_emb_veh(trajectory_token_veh.view(trajectory_token_veh.shape[0], -1)) # (token_size, 8)
|
172 |
+
self.agent_token_emb_ped = self.token_emb_ped(trajectory_token_ped.view(trajectory_token_ped.shape[0], -1))
|
173 |
+
self.agent_token_emb_cyc = self.token_emb_cyc(trajectory_token_cyc.view(trajectory_token_cyc.shape[0], -1))
|
174 |
+
|
175 |
+
# add bos token embedding
|
176 |
+
self.agent_token_emb_veh = torch.cat([self.agent_token_emb_veh, self.bos_token_emb(torch.zeros(1, device=pos_a.device).long())])
|
177 |
+
self.agent_token_emb_ped = torch.cat([self.agent_token_emb_ped, self.bos_token_emb(torch.zeros(1, device=pos_a.device).long())])
|
178 |
+
self.agent_token_emb_cyc = torch.cat([self.agent_token_emb_cyc, self.bos_token_emb(torch.zeros(1, device=pos_a.device).long())])
|
179 |
+
|
180 |
+
# add invalid token embedding
|
181 |
+
self.agent_token_emb_veh = torch.cat([self.agent_token_emb_veh, self.no_token_emb(torch.zeros(1, device=pos_a.device).long())])
|
182 |
+
self.agent_token_emb_ped = torch.cat([self.agent_token_emb_ped, self.no_token_emb(torch.zeros(1, device=pos_a.device).long())])
|
183 |
+
self.agent_token_emb_cyc = torch.cat([self.agent_token_emb_cyc, self.no_token_emb(torch.zeros(1, device=pos_a.device).long())])
|
184 |
+
|
185 |
+
# self.grid_token_emb = self.token_emb_grid(torch.stack([self.attr_tokenizer.dist,
|
186 |
+
# self.attr_tokenizer.dir], dim=-1).to(pos_a.device))
|
187 |
+
self.grid_token_emb = self.token_emb_grid(self.attr_tokenizer.grid)
|
188 |
+
self.grid_token_emb = torch.cat([self.grid_token_emb, self.invalid_offset_token_emb(torch.zeros(1, device=pos_a.device).long())])
|
189 |
+
|
190 |
+
if inference:
|
191 |
+
agent_token_traj_all = torch.zeros((num_agent, self.token_size, self.shift + 1, 4, 2), device=pos_a.device)
|
192 |
+
trajectory_token_all_veh = torch.from_numpy(self.trajectory_token_all['veh']).clone().to(pos_a.device).to(torch.float)
|
193 |
+
trajectory_token_all_ped = torch.from_numpy(self.trajectory_token_all['ped']).clone().to(pos_a.device).to(torch.float)
|
194 |
+
trajectory_token_all_cyc = torch.from_numpy(self.trajectory_token_all['cyc']).clone().to(pos_a.device).to(torch.float)
|
195 |
+
agent_token_traj_all[veh_mask] = torch.cat(
|
196 |
+
[trajectory_token_all_veh[:, :self.shift], trajectory_token_veh[:, None, ...]], dim=1)
|
197 |
+
agent_token_traj_all[ped_mask] = torch.cat(
|
198 |
+
[trajectory_token_all_ped[:, :self.shift], trajectory_token_ped[:, None, ...]], dim=1)
|
199 |
+
agent_token_traj_all[cyc_mask] = torch.cat(
|
200 |
+
[trajectory_token_all_cyc[:, :self.shift], trajectory_token_cyc[:, None, ...]], dim=1)
|
201 |
+
|
202 |
+
# additional token embeddings are already added -> -1: invalid, -2: bos
|
203 |
+
agent_token_emb = torch.zeros((num_agent, num_step, self.hidden_dim), device=pos_a.device)
|
204 |
+
agent_token_emb[veh_mask] = self.agent_token_emb_veh[agent_token_index[veh_mask]]
|
205 |
+
agent_token_emb[ped_mask] = self.agent_token_emb_ped[agent_token_index[ped_mask]]
|
206 |
+
agent_token_emb[cyc_mask] = self.agent_token_emb_cyc[agent_token_index[cyc_mask]]
|
207 |
+
|
208 |
+
offset_token_emb = self.grid_token_emb[agent_offset_token_idx]
|
209 |
+
|
210 |
+
# 'vehicle', 'pedestrian', 'cyclist', 'background'
|
211 |
+
is_invalid = agent_state == self.invalid_state
|
212 |
+
agent_types = data['agent']['type'].clone()[filter_mask].long().repeat_interleave(repeats=num_step, dim=0)
|
213 |
+
agent_types[is_invalid.reshape(-1)] = self.agent_type.index('seed')
|
214 |
+
agent_shapes = data['agent']['shape'].clone()[filter_mask, self.num_historical_steps - 1, :].repeat_interleave(repeats=num_step, dim=0)
|
215 |
+
agent_shapes[is_invalid.reshape(-1)] = self.invalid_shape_value
|
216 |
+
|
217 |
+
# TODO: fix ego_pos in inference mode
|
218 |
+
offset_pos = pos_a - pos_a[av_index].repeat_interleave(repeats=data['batch_size_a'], dim=0)
|
219 |
+
feat_a, categorical_embs = self._build_agent_feature(num_step, pos_a.device,
|
220 |
+
motion_vector_a,
|
221 |
+
head_vector_a,
|
222 |
+
agent_token_emb,
|
223 |
+
offset_token_emb,
|
224 |
+
offset_pos=offset_pos,
|
225 |
+
type=agent_types,
|
226 |
+
shape=agent_shapes,
|
227 |
+
state=agent_state,
|
228 |
+
n=num_agent)
|
229 |
+
|
230 |
+
if inference:
|
231 |
+
return feat_a, agent_token_traj_all, agent_token_emb, categorical_embs
|
232 |
+
else:
|
233 |
+
# seed agent feature
|
234 |
+
if self.seed_use_ego_motion:
|
235 |
+
motion_vector_seed = motion_vector_a[av_index].repeat_interleave(repeats=self.num_seed_feature, dim=0)
|
236 |
+
head_vector_seed = head_vector_a[av_index].repeat_interleave(repeats=self.num_seed_feature, dim=0)
|
237 |
+
else:
|
238 |
+
motion_vector_seed = head_vector_seed = None
|
239 |
+
feat_seed, _ = self._build_agent_feature(num_step, pos_a.device,
|
240 |
+
motion_vector_seed,
|
241 |
+
head_vector_seed,
|
242 |
+
state_index=self.invalid_state,
|
243 |
+
n=data.num_graphs * self.num_seed_feature)
|
244 |
+
|
245 |
+
feat_a = torch.cat([feat_a, feat_seed], dim=0) # (a + n, t, d)
|
246 |
+
|
247 |
+
return feat_a
|
248 |
+
|
249 |
+
def _build_vector_a(self, pos_a, head_a, state_a):
|
250 |
+
num_agent = pos_a.shape[0]
|
251 |
+
|
252 |
+
motion_vector_a = torch.cat([pos_a.new_zeros(num_agent, 1, self.input_dim),
|
253 |
+
pos_a[:, 1:] - pos_a[:, :-1]], dim=1)
|
254 |
+
|
255 |
+
motion_vector_a[state_a == self.invalid_state] = self.invalid_motion_value
|
256 |
+
|
257 |
+
# invalid -> valid
|
258 |
+
is_last_invalid = (state_a.roll(shifts=1, dims=1) == self.invalid_state) & (state_a != self.invalid_state)
|
259 |
+
is_last_invalid[:, 0] = state_a[:, 0] == self.enter_state
|
260 |
+
motion_vector_a[is_last_invalid] = self.motion_gap
|
261 |
+
|
262 |
+
# valid -> invalid
|
263 |
+
is_last_valid = (state_a.roll(shifts=1, dims=1) != self.invalid_state) & (state_a == self.invalid_state)
|
264 |
+
is_last_valid[:, 0] = False
|
265 |
+
motion_vector_a[is_last_valid] = -self.motion_gap
|
266 |
+
|
267 |
+
head_a[state_a == self.invalid_state] == self.invalid_head_value
|
268 |
+
head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1)
|
269 |
+
|
270 |
+
return motion_vector_a, head_vector_a
|
271 |
+
|
272 |
+
def _build_agent_feature(self, num_step, device,
|
273 |
+
motion_vector=None,
|
274 |
+
head_vector=None,
|
275 |
+
agent_token_emb=None,
|
276 |
+
agent_grid_emb=None,
|
277 |
+
offset_pos=None,
|
278 |
+
type=None,
|
279 |
+
shape=None,
|
280 |
+
categorical_embs_a=None,
|
281 |
+
state=None,
|
282 |
+
state_index=None,
|
283 |
+
n=1):
|
284 |
+
|
285 |
+
if agent_token_emb is None:
|
286 |
+
agent_token_emb = self.no_token_emb(torch.zeros(1, device=device).long())[:, None].repeat(n, num_step, 1)
|
287 |
+
if state is not None:
|
288 |
+
agent_token_emb[state == self.enter_state] = self.bos_token_emb(torch.zeros(1, device=device).long())
|
289 |
+
|
290 |
+
if agent_grid_emb is None:
|
291 |
+
agent_grid_emb = self.grid_token_emb[None, None, self.grid_size // 2].repeat(n, num_step, 1)
|
292 |
+
|
293 |
+
if motion_vector is None or head_vector is None:
|
294 |
+
pos_a = torch.zeros((n, num_step, 2), device=device)
|
295 |
+
head_a = torch.zeros((n, num_step), device=device)
|
296 |
+
if state is None:
|
297 |
+
state = torch.full((n, num_step), self.invalid_state, device=device)
|
298 |
+
motion_vector, head_vector = self._build_vector_a(pos_a, head_a, state)
|
299 |
+
|
300 |
+
if offset_pos is None:
|
301 |
+
offset_pos = torch.zeros_like(motion_vector)
|
302 |
+
|
303 |
+
feature_a = torch.stack(
|
304 |
+
[torch.norm(motion_vector[:, :, :2], p=2, dim=-1),
|
305 |
+
angle_between_2d_vectors(ctr_vector=head_vector, nbr_vector=motion_vector[:, :, :2]),
|
306 |
+
# torch.norm(offset_pos[:, :, :2], p=2, dim=-1),
|
307 |
+
], dim=-1)
|
308 |
+
|
309 |
+
if categorical_embs_a is None:
|
310 |
+
if type is None:
|
311 |
+
type = torch.tensor([self.agent_type.index('seed')], device=device)
|
312 |
+
if shape is None:
|
313 |
+
shape = torch.full((1, 3), self.invalid_shape_value, device=device)
|
314 |
+
|
315 |
+
categorical_embs_a = [self.type_a_emb(type.reshape(-1)), self.shape_emb(shape.reshape(-1, shape.shape[-1]))]
|
316 |
+
|
317 |
+
x_a = self.x_a_emb(continuous_inputs=feature_a.view(-1, feature_a.size(-1)),
|
318 |
+
categorical_embs=categorical_embs_a)
|
319 |
+
x_a = x_a.view(-1, num_step, self.hidden_dim) # (a, t, d)
|
320 |
+
|
321 |
+
if state is None:
|
322 |
+
assert state_index is not None, f"state index need to be set when state tensor is None!"
|
323 |
+
state = torch.tensor([state_index], device=device)[:, None].repeat(n, num_step, 1) # do not use `expand`
|
324 |
+
s_a = self.state_a_emb(state.reshape(-1).long()).reshape(n, num_step, self.hidden_dim)
|
325 |
+
|
326 |
+
feat_a = torch.cat((agent_token_emb, x_a, s_a, agent_grid_emb), dim=-1)
|
327 |
+
feat_a = self.fusion_emb(feat_a) # (a, t, d)
|
328 |
+
|
329 |
+
return feat_a, categorical_embs_a
|
330 |
+
|
331 |
+
def _pad_feat(self, num_graph, av_index, *feats, num_seed_feature=None):
|
332 |
+
|
333 |
+
if num_seed_feature is None:
|
334 |
+
num_seed_feature = self.num_seed_feature
|
335 |
+
|
336 |
+
padded_feats = tuple()
|
337 |
+
for i in range(len(feats)):
|
338 |
+
padded_feats += (torch.cat([feats[i], feats[i][av_index].repeat_interleave(
|
339 |
+
repeats=num_seed_feature, dim=0)],
|
340 |
+
dim=0
|
341 |
+
),)
|
342 |
+
|
343 |
+
pad_mask = torch.ones(*padded_feats[0].shape[:2], device=feats[0].device).bool() # (a, t)
|
344 |
+
pad_mask[-num_graph * num_seed_feature:] = False
|
345 |
+
|
346 |
+
return padded_feats + (pad_mask,)
|
347 |
+
|
348 |
+
def _build_seed_feat(self, data, pos_a, head_a, state_a, head_vector_a, mask, sort_indices, av_index):
|
349 |
+
seed_mask = sort_indices != av_index.repeat_interleave(repeats=data['batch_size_a'], dim=0)[:, None]
|
350 |
+
# TODO: fix batch_size!!!
|
351 |
+
print(mask.shape, sort_indices.shape, seed_mask.shape)
|
352 |
+
mask[-data.num_graphs * self.num_seed_feature:] = seed_mask[:self.num_seed_feature]
|
353 |
+
|
354 |
+
insert_pos_a = torch.gather(pos_a, dim=0, index=sort_indices[:self.num_seed_feature, :, None].expand(-1, -1, pos_a.shape[-1]))
|
355 |
+
pos_a[mask] = insert_pos_a[mask[-self.num_seed_feature:]]
|
356 |
+
|
357 |
+
state_a[-data.num_graphs * self.num_seed_feature:] = self.enter_state
|
358 |
+
|
359 |
+
return pos_a, head_a, state_a, head_vector_a, mask
|
360 |
+
|
361 |
+
def _build_temporal_edge(self, data, pos_a, head_a, state_a, head_vector_a, mask, inference_mask=None):
|
362 |
+
|
363 |
+
num_graph = data.num_graphs
|
364 |
+
num_agent = pos_a.shape[0]
|
365 |
+
hist_mask = mask.clone()
|
366 |
+
|
367 |
+
if not self.temporal_attn_to_invalid:
|
368 |
+
is_bos = state_a == self.enter_state
|
369 |
+
bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0))
|
370 |
+
history_invalid_mask = torch.arange(mask.shape[1]).expand(mask.shape[0], mask.shape[1]).to(mask.device)
|
371 |
+
history_invalid_mask = (history_invalid_mask < bos_index[:, None])
|
372 |
+
hist_mask[history_invalid_mask] = False
|
373 |
+
|
374 |
+
if not self.temporal_attn_seed:
|
375 |
+
hist_mask[-num_graph * self.num_seed_feature:] = False
|
376 |
+
if inference_mask is not None:
|
377 |
+
inference_mask[-num_graph * self.num_seed_feature:] = False
|
378 |
+
else:
|
379 |
+
# WARNING: if use temporal attn to seed
|
380 |
+
# we need to fix the pos/head of seed!!!
|
381 |
+
raise RuntimeError("Wrong settings!")
|
382 |
+
|
383 |
+
pos_t = pos_a.reshape(-1, self.input_dim) # (num_agent * num_step, ...)
|
384 |
+
head_t = head_a.reshape(-1)
|
385 |
+
head_vector_t = head_vector_a.reshape(-1, 2)
|
386 |
+
|
387 |
+
# for those invalid agents won't predict any motion token, we don't attend to them
|
388 |
+
is_bos = state_a == self.enter_state
|
389 |
+
is_bos[-num_graph * self.num_seed_feature:] = False
|
390 |
+
bos_index = torch.where(is_bos.any(dim=1), torch.argmax(is_bos.long(), dim=1), torch.tensor(0))
|
391 |
+
motion_predict_start_index = torch.clamp(bos_index - self.time_span / self.shift + 1, min=0)
|
392 |
+
motion_predict_mask = torch.arange(hist_mask.shape[1]).expand(hist_mask.shape[0], -1).to(hist_mask.device)
|
393 |
+
motion_predict_mask = motion_predict_mask >= motion_predict_start_index[:, None]
|
394 |
+
hist_mask[~motion_predict_mask] = False
|
395 |
+
|
396 |
+
if self.hist_mask and self.training:
|
397 |
+
hist_mask[
|
398 |
+
torch.arange(mask.shape[0]).unsqueeze(1), torch.randint(0, mask.shape[1], (num_agent, 10))] = False
|
399 |
+
mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1)
|
400 |
+
elif inference_mask is not None:
|
401 |
+
mask_t = hist_mask.unsqueeze(2) & inference_mask.unsqueeze(1)
|
402 |
+
else:
|
403 |
+
mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1)
|
404 |
+
|
405 |
+
# mask_t: (num_agent, 18, 18), edge_index_t: (2, num_edge)
|
406 |
+
edge_index_t = dense_to_sparse(mask_t)[0]
|
407 |
+
edge_index_t = edge_index_t[:, (edge_index_t[1] - edge_index_t[0] > 0) &
|
408 |
+
(edge_index_t[1] - edge_index_t[0] <= self.time_span / self.shift)]
|
409 |
+
rel_pos_t = pos_t[edge_index_t[0]] - pos_t[edge_index_t[1]]
|
410 |
+
rel_head_t = wrap_angle(head_t[edge_index_t[0]] - head_t[edge_index_t[1]])
|
411 |
+
|
412 |
+
# handle the invalid steps
|
413 |
+
is_invalid = state_a == self.invalid_state
|
414 |
+
is_invalid_t = is_invalid.reshape(-1)
|
415 |
+
|
416 |
+
rel_pos_t[is_invalid_t[edge_index_t[0]] & ~is_invalid_t[edge_index_t[1]]] = -self.motion_gap
|
417 |
+
rel_pos_t[~is_invalid_t[edge_index_t[0]] & is_invalid_t[edge_index_t[1]]] = self.motion_gap
|
418 |
+
rel_head_t[is_invalid_t[edge_index_t[0]] & ~is_invalid_t[edge_index_t[1]]] = -self.heading_gap
|
419 |
+
rel_head_t[~is_invalid_t[edge_index_t[1]] & is_invalid_t[edge_index_t[1]]] = self.heading_gap
|
420 |
+
|
421 |
+
rel_pos_t[is_invalid_t[edge_index_t[0]] & is_invalid_t[edge_index_t[1]]] = self.invalid_motion_value
|
422 |
+
rel_head_t[is_invalid_t[edge_index_t[0]] & is_invalid_t[edge_index_t[1]]] = self.invalid_head_value
|
423 |
+
|
424 |
+
r_t = torch.stack(
|
425 |
+
[torch.norm(rel_pos_t[:, :2], p=2, dim=-1),
|
426 |
+
angle_between_2d_vectors(ctr_vector=head_vector_t[edge_index_t[1]], nbr_vector=rel_pos_t[:, :2]),
|
427 |
+
rel_head_t,
|
428 |
+
edge_index_t[0] - edge_index_t[1]], dim=-1)
|
429 |
+
r_t = self.r_t_emb(continuous_inputs=r_t, categorical_embs=None)
|
430 |
+
|
431 |
+
return edge_index_t, r_t
|
432 |
+
|
433 |
+
def _build_interaction_edge(self, data, pos_a, head_a, state_a, head_vector_a, batch_s, mask, pad_mask=None, inference_mask=None,
|
434 |
+
av_index=None, seq_mask=None, seq_index=None, grid_index_a=None, **plot_kwargs):
|
435 |
+
num_graph = data.num_graphs
|
436 |
+
num_agent, num_step, _ = pos_a.shape
|
437 |
+
is_training = inference_mask is None
|
438 |
+
|
439 |
+
mask_a = mask.clone()
|
440 |
+
|
441 |
+
if pad_mask is None:
|
442 |
+
pad_mask = torch.ones_like(state_a).bool()
|
443 |
+
|
444 |
+
pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
|
445 |
+
head_s = head_a.transpose(0, 1).reshape(-1)
|
446 |
+
head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2)
|
447 |
+
pad_mask_s = pad_mask.transpose(0, 1).reshape(-1)
|
448 |
+
if inference_mask is not None:
|
449 |
+
mask_a = mask_a & inference_mask
|
450 |
+
mask_s = mask_a.transpose(0, 1).reshape(-1)
|
451 |
+
|
452 |
+
# build agent2agent bilateral connection
|
453 |
+
edge_index_a2a = radius_graph(x=pos_s[:, :2], r=self.a2a_radius, batch=batch_s, loop=False,
|
454 |
+
max_num_neighbors=300)
|
455 |
+
edge_index_a2a = subgraph(subset=mask_s & pad_mask_s, edge_index=edge_index_a2a)[0]
|
456 |
+
|
457 |
+
if os.getenv('PLOT_EDGE', False):
|
458 |
+
plot_interact_edge(edge_index_a2a, data['scenario_id'], data['batch_size_a'].cpu(), self.num_seed_feature, num_step,
|
459 |
+
av_index=av_index, **plot_kwargs)
|
460 |
+
|
461 |
+
rel_pos_a2a = pos_s[edge_index_a2a[0]] - pos_s[edge_index_a2a[1]]
|
462 |
+
rel_head_a2a = wrap_angle(head_s[edge_index_a2a[0]] - head_s[edge_index_a2a[1]])
|
463 |
+
|
464 |
+
# handle the invalid steps
|
465 |
+
is_invalid = state_a == self.invalid_state
|
466 |
+
is_invalid_s = is_invalid.transpose(0, 1).reshape(-1)
|
467 |
+
|
468 |
+
rel_pos_a2a[is_invalid_s[edge_index_a2a[0]] & ~is_invalid_s[edge_index_a2a[1]]] = -self.motion_gap
|
469 |
+
rel_pos_a2a[~is_invalid_s[edge_index_a2a[0]] & is_invalid_s[edge_index_a2a[1]]] = self.motion_gap
|
470 |
+
rel_head_a2a[is_invalid_s[edge_index_a2a[0]] & ~is_invalid_s[edge_index_a2a[1]]] = -self.heading_gap
|
471 |
+
rel_head_a2a[~is_invalid_s[edge_index_a2a[1]] & is_invalid_s[edge_index_a2a[1]]] = self.heading_gap
|
472 |
+
|
473 |
+
rel_pos_a2a[is_invalid_s[edge_index_a2a[0]] & is_invalid_s[edge_index_a2a[1]]] = self.invalid_motion_value
|
474 |
+
rel_head_a2a[is_invalid_s[edge_index_a2a[0]] & is_invalid_s[edge_index_a2a[1]]] = self.invalid_head_value
|
475 |
+
|
476 |
+
r_a2a = torch.stack(
|
477 |
+
[torch.norm(rel_pos_a2a[:, :2], p=2, dim=-1),
|
478 |
+
angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_a2a[1]], nbr_vector=rel_pos_a2a[:, :2]),
|
479 |
+
rel_head_a2a,
|
480 |
+
torch.zeros_like(edge_index_a2a[0])], dim=-1)
|
481 |
+
r_a2a = self.r_a2a_emb(continuous_inputs=r_a2a, categorical_embs=None)
|
482 |
+
|
483 |
+
# add the edges which connect seed agents
|
484 |
+
if is_training:
|
485 |
+
mask_av = torch.ones_like(mask_a).bool()
|
486 |
+
if not self.seed_attn_to_av:
|
487 |
+
mask_av[av_index] = False
|
488 |
+
mask_a &= mask_av
|
489 |
+
edge_index_seed2a, r_seed2a = self._build_a2sa_edge(data, pos_a, head_a, head_vector_a, batch_s,
|
490 |
+
mask_a.clone(), ~pad_mask.clone(), inference_mask=inference_mask,
|
491 |
+
r=self.pl2seed_radius, max_num_neighbors=300,
|
492 |
+
seq_mask=seq_mask, seq_index=seq_index, grid_index_a=grid_index_a, mode='grid')
|
493 |
+
|
494 |
+
if os.getenv('PLOT_EDGE', False):
|
495 |
+
plot_interact_edge(edge_index_seed2a, data['scenario_id'], data['batch_size_a'].cpu(), self.num_seed_feature, num_step,
|
496 |
+
'interact_edge_map_seed', av_index=av_index, **plot_kwargs)
|
497 |
+
|
498 |
+
edge_index_a2a = torch.cat([edge_index_a2a, edge_index_seed2a], dim=-1)
|
499 |
+
r_a2a = torch.cat([r_a2a, r_seed2a])
|
500 |
+
|
501 |
+
return edge_index_a2a, r_a2a, (edge_index_a2a.shape[1], edge_index_seed2a.shape[1]) #, nearest_dict
|
502 |
+
|
503 |
+
return edge_index_a2a, r_a2a
|
504 |
+
|
505 |
+
def _build_map2agent_edge(self, data, pos_a, head_a, state_a, head_vector_a, batch_s, batch_pl,
|
506 |
+
mask, pad_mask=None, inference_mask=None, av_index=None, **kwargs):
|
507 |
+
num_graph = data.num_graphs
|
508 |
+
num_agent, num_step, _ = pos_a.shape
|
509 |
+
is_training = inference_mask is None
|
510 |
+
|
511 |
+
mask_pl2a = mask.clone()
|
512 |
+
|
513 |
+
if pad_mask is None:
|
514 |
+
pad_mask = torch.ones_like(state_a).bool()
|
515 |
+
|
516 |
+
pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
|
517 |
+
head_s = head_a.transpose(0, 1).reshape(-1)
|
518 |
+
head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2)
|
519 |
+
pad_mask_s = pad_mask.transpose(0, 1).reshape(-1)
|
520 |
+
if inference_mask is not None:
|
521 |
+
mask_pl2a = mask_pl2a & inference_mask
|
522 |
+
mask_pl2a = mask_pl2a.transpose(0, 1).reshape(-1)
|
523 |
+
|
524 |
+
ori_pos_pl = data['pt_token']['position'][:, :self.input_dim].contiguous()
|
525 |
+
ori_orient_pl = data['pt_token']['orientation'].contiguous()
|
526 |
+
pos_pl = ori_pos_pl.repeat(num_step, 1) # not `repeat_interleave`
|
527 |
+
orient_pl = ori_orient_pl.repeat(num_step)
|
528 |
+
|
529 |
+
# build map2agent directed graph
|
530 |
+
# edge_index_pl2a = radius(x=pos_s[:, :2], y=pos_pl[:, :2], r=self.pl2a_radius,
|
531 |
+
# batch_x=batch_s, batch_y=batch_pl, max_num_neighbors=300)
|
532 |
+
edge_index_pl2a = radius(x=pos_pl[:, :2], y=pos_s[:, :2], r=self.pl2a_radius,
|
533 |
+
batch_x=batch_pl, batch_y=batch_s, max_num_neighbors=5)
|
534 |
+
edge_index_pl2a = edge_index_pl2a[[1, 0]]
|
535 |
+
edge_index_pl2a = edge_index_pl2a[:, mask_pl2a[edge_index_pl2a[1]] &
|
536 |
+
pad_mask_s[edge_index_pl2a[1]]]
|
537 |
+
|
538 |
+
rel_pos_pl2a = pos_pl[edge_index_pl2a[0]] - pos_s[edge_index_pl2a[1]]
|
539 |
+
rel_orient_pl2a = wrap_angle(orient_pl[edge_index_pl2a[0]] - head_s[edge_index_pl2a[1]])
|
540 |
+
|
541 |
+
# handle the invalid steps
|
542 |
+
is_invalid = state_a == self.invalid_state
|
543 |
+
is_invalid_s = is_invalid.transpose(0, 1).reshape(-1)
|
544 |
+
rel_pos_pl2a[is_invalid_s[edge_index_pl2a[1]]] = self.motion_gap
|
545 |
+
rel_orient_pl2a[is_invalid_s[edge_index_pl2a[1]]] = self.heading_gap
|
546 |
+
|
547 |
+
r_pl2a = torch.stack(
|
548 |
+
[torch.norm(rel_pos_pl2a[:, :2], p=2, dim=-1),
|
549 |
+
angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_pl2a[1]], nbr_vector=rel_pos_pl2a[:, :2]),
|
550 |
+
rel_orient_pl2a], dim=-1)
|
551 |
+
r_pl2a = self.r_pt2a_emb(continuous_inputs=r_pl2a, categorical_embs=None)
|
552 |
+
|
553 |
+
# add the edges which connect seed agents
|
554 |
+
if is_training:
|
555 |
+
edge_index_pl2seed, r_pl2seed = self._build_map2sa_edge(data, pos_a, head_a, head_vector_a, batch_s, batch_pl,
|
556 |
+
~pad_mask.clone(), inference_mask=inference_mask,
|
557 |
+
r=self.pl2seed_radius, max_num_neighbors=2048, mode='grid')
|
558 |
+
|
559 |
+
# sanity check
|
560 |
+
# pl2a_index = torch.zeros(pos_a.shape[0], num_step)
|
561 |
+
# pl2a_r = torch.zeros(pos_a.shape[0], num_step)
|
562 |
+
# for src_index in torch.unique(edge_index_pl2seed[1]):
|
563 |
+
# src_row = src_index % pos_a.shape[0]
|
564 |
+
# src_col = src_index // pos_a.shape[0]
|
565 |
+
# pl2a_index[src_row, src_col] = edge_index_pl2seed[0, edge_index_pl2seed[1] == src_index].sum()
|
566 |
+
# pl2a_r[src_row, src_col] = r_pl2seed[edge_index_pl2seed[1] == src_index].sum()
|
567 |
+
# print(pl2a_index)
|
568 |
+
# print(pl2a_r)
|
569 |
+
# exit(1)
|
570 |
+
|
571 |
+
if os.getenv('PLOT_EDGE', False):
|
572 |
+
plot_interact_edge(edge_index_pl2seed, data['scenario_id'], data['batch_size_a'].cpu(), self.num_seed_feature, num_step,
|
573 |
+
'interact_edge_map_seed', av_index=av_index)
|
574 |
+
|
575 |
+
edge_index_pl2a = torch.cat([edge_index_pl2a, edge_index_pl2seed], dim=-1)
|
576 |
+
r_pl2a = torch.cat([r_pl2a, r_pl2seed])
|
577 |
+
|
578 |
+
return edge_index_pl2a, r_pl2a, (edge_index_pl2a.shape[1], edge_index_pl2seed.shape[1])
|
579 |
+
|
580 |
+
return edge_index_pl2a, r_pl2a
|
581 |
+
|
582 |
+
def _build_a2sa_edge(self, data, pos_a, head_a, head_vector_a, batch_s, mask_a, mask_sa,
|
583 |
+
inference_mask=None, r=None, max_num_neighbors=8, seq_mask=None, seq_index=None,
|
584 |
+
grid_index_a=None, mode: Literal['grid', 'heading']='heading', **plot_kwargs):
|
585 |
+
|
586 |
+
num_agent, num_step, _ = pos_a.shape
|
587 |
+
is_training = inference_mask is None
|
588 |
+
|
589 |
+
pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
|
590 |
+
head_s = head_a.transpose(0, 1).reshape(-1)
|
591 |
+
head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2)
|
592 |
+
if inference_mask is not None:
|
593 |
+
mask_a = mask_a & inference_mask
|
594 |
+
mask_sa = mask_sa & inference_mask
|
595 |
+
mask_s = mask_a.transpose(0, 1).reshape(-1)
|
596 |
+
mask_s_sa = mask_sa.transpose(0, 1).reshape(-1)
|
597 |
+
|
598 |
+
# build seed_agent2agent unilateral connection
|
599 |
+
assert r is not None, "r needs to be specified!"
|
600 |
+
# edge_index_a2sa = radius(x=pos_s[mask_s_sa, :2], y=pos_s[:, :2], r=r,
|
601 |
+
# batch_x=batch_s[mask_s_sa], batch_y=batch_s, max_num_neighbors=max_num_neighbors)
|
602 |
+
edge_index_a2sa = radius(x=pos_s[:, :2], y=pos_s[mask_s_sa, :2], r=r,
|
603 |
+
batch_x=batch_s, batch_y=batch_s[mask_s_sa], max_num_neighbors=max_num_neighbors)
|
604 |
+
edge_index_a2sa = edge_index_a2sa[[1, 0]]
|
605 |
+
edge_index_a2sa = edge_index_a2sa[:, ~mask_s_sa[edge_index_a2sa[0]] & mask_s[edge_index_a2sa[0]]]
|
606 |
+
|
607 |
+
# only for seed agent sequence training
|
608 |
+
if seq_mask is not None:
|
609 |
+
edge_mask = seq_mask[edge_index_a2sa[1]]
|
610 |
+
edge_mask = torch.gather(edge_mask, dim=1, index=edge_index_a2sa[0, :, None] % num_agent)[:, 0]
|
611 |
+
edge_index_a2sa = edge_index_a2sa[:, edge_mask]
|
612 |
+
|
613 |
+
if seq_index is None:
|
614 |
+
seq_index = torch.zeros(num_agent, device=pos_a.device).long()
|
615 |
+
if seq_index.dim() == 1:
|
616 |
+
seq_index = seq_index[:, None].repeat(1, num_step)
|
617 |
+
seq_index = seq_index.transpose(0, 1).reshape(-1)
|
618 |
+
assert seq_index.shape[0] == pos_s.shape[0], f"Inconsistent lenght {seq_index.shape[0]} and {pos_s.shape[0]}!"
|
619 |
+
|
620 |
+
# convert to global index
|
621 |
+
all_index = torch.arange(pos_s.shape[0], device=pos_a.device).long()
|
622 |
+
sa_index = all_index[mask_s_sa]
|
623 |
+
edge_index_a2sa[1] = sa_index[edge_index_a2sa[1]]
|
624 |
+
|
625 |
+
# plot edge index TODO: now only support bs=1
|
626 |
+
if os.getenv('PLOT_EDGE_INFERENCE', False) and not is_training:
|
627 |
+
num_agent, num_step, _ = pos_a.shape
|
628 |
+
# plot_interact_edge(edge_index_a2sa, data['scenario_id'], data['batch_size_a'].cpu(), 1, num_step,
|
629 |
+
# 'interact_a2sa_edge_map', **plot_kwargs)
|
630 |
+
plot_interact_edge(edge_index_a2sa, data['scenario_id'], torch.tensor([num_agent - 1]), 1, num_step,
|
631 |
+
f"interact_a2sa_edge_map_infer_{plot_kwargs['tag']}", **plot_kwargs)
|
632 |
+
|
633 |
+
rel_pos_a2sa = pos_s[edge_index_a2sa[0]] - pos_s[edge_index_a2sa[1]]
|
634 |
+
rel_head_a2sa = wrap_angle(head_s[edge_index_a2sa[0]] - head_s[edge_index_a2sa[1]])
|
635 |
+
|
636 |
+
r_a2sa = torch.stack(
|
637 |
+
[torch.norm(rel_pos_a2sa[:, :2], p=2, dim=-1),
|
638 |
+
angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_a2sa[1]], nbr_vector=rel_pos_a2sa[:, :2]),
|
639 |
+
rel_head_a2sa,
|
640 |
+
seq_index[edge_index_a2sa[0]] - seq_index[edge_index_a2sa[1]]], dim=-1)
|
641 |
+
r_a2sa = self.r_a2sa_emb(continuous_inputs=r_a2sa, categorical_embs=None)
|
642 |
+
|
643 |
+
return edge_index_a2sa, r_a2sa
|
644 |
+
|
645 |
+
def _build_map2sa_edge(self, data, pos_a, head_a, head_vector_a, batch_s, batch_pl,
|
646 |
+
mask_sa, inference_mask=None, r=None, max_num_neighbors=32, mode: Literal['grid', 'heading']='heading'):
|
647 |
+
|
648 |
+
_, num_step, _ = pos_a.shape
|
649 |
+
|
650 |
+
mask_pl2sa = torch.ones_like(mask_sa).bool()
|
651 |
+
if inference_mask is not None:
|
652 |
+
mask_pl2sa = mask_pl2sa & inference_mask
|
653 |
+
mask_pl2sa = mask_pl2sa.transpose(0, 1).reshape(-1)
|
654 |
+
mask_s_sa = mask_sa.transpose(0, 1).reshape(-1)
|
655 |
+
|
656 |
+
pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
|
657 |
+
head_s = head_a.transpose(0, 1).reshape(-1)
|
658 |
+
head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2)
|
659 |
+
|
660 |
+
ori_pos_pl = data['pt_token']['position'][:, :self.input_dim].contiguous()
|
661 |
+
ori_orient_pl = data['pt_token']['orientation'].contiguous()
|
662 |
+
pos_pl = ori_pos_pl.repeat(num_step, 1) # not `repeat_interleave`
|
663 |
+
orient_pl = ori_orient_pl.repeat(num_step)
|
664 |
+
|
665 |
+
# build map2agent directed graph
|
666 |
+
assert r is not None, "r needs to be specified!"
|
667 |
+
# edge_index_pl2sa = radius(x=pos_s[mask_s_sa, :2], y=pos_pl[:, :2], r=r,
|
668 |
+
# batch_x=batch_s[mask_s_sa], batch_y=batch_pl, max_num_neighbors=max_num_neighbors)
|
669 |
+
edge_index_pl2sa = radius(x=pos_pl[:, :2], y=pos_s[mask_s_sa, :2], r=r,
|
670 |
+
batch_x=batch_pl, batch_y=batch_s[mask_s_sa], max_num_neighbors=max_num_neighbors)
|
671 |
+
edge_index_pl2sa = edge_index_pl2sa[[1, 0]]
|
672 |
+
edge_index_pl2sa = edge_index_pl2sa[:, mask_pl2sa[mask_s_sa][edge_index_pl2sa[1]]]
|
673 |
+
|
674 |
+
# convert to global index
|
675 |
+
all_index = torch.arange(pos_s.shape[0], device=pos_a.device).long()
|
676 |
+
sa_index = all_index[mask_s_sa]
|
677 |
+
edge_index_pl2sa[1] = sa_index[edge_index_pl2sa[1]]
|
678 |
+
|
679 |
+
# plot edge map
|
680 |
+
# if os.getenv('PLOT_EDGE', False):
|
681 |
+
# plot_map_edge(edge_index_pl2sa, pos_s[:, :2], data, save_path='map2sa_edge_map')
|
682 |
+
|
683 |
+
rel_pos_pl2sa = pos_pl[edge_index_pl2sa[0]] - pos_s[edge_index_pl2sa[1]]
|
684 |
+
rel_orient_pl2sa = wrap_angle(orient_pl[edge_index_pl2sa[0]] - head_s[edge_index_pl2sa[1]])
|
685 |
+
|
686 |
+
r_pl2sa = torch.stack(
|
687 |
+
[torch.norm(rel_pos_pl2sa[:, :2], p=2, dim=-1),
|
688 |
+
angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_pl2sa[1]], nbr_vector=rel_pos_pl2sa[:, :2]),
|
689 |
+
rel_orient_pl2sa], dim=-1)
|
690 |
+
r_pl2sa = self.r_pt2sa_emb(continuous_inputs=r_pl2sa, categorical_embs=None)
|
691 |
+
|
692 |
+
return edge_index_pl2sa, r_pl2sa
|
693 |
+
|
694 |
+
def _build_sa2sa_edge(self, data, pos_a, head_a, state_a, head_vector_a, batch_s, mask, inference_mask=None, **plot_kwargs):
|
695 |
+
|
696 |
+
num_agent = pos_a.shape[0]
|
697 |
+
|
698 |
+
pos_t = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
|
699 |
+
head_t = head_a.reshape(-1)
|
700 |
+
head_vector_t = head_vector_a.reshape(-1, 2)
|
701 |
+
|
702 |
+
if inference_mask is not None:
|
703 |
+
mask_t = mask.unsqueeze(2) & inference_mask.unsqueeze(1)
|
704 |
+
else:
|
705 |
+
mask_t = mask.unsqueeze(2) & mask.unsqueeze(1)
|
706 |
+
|
707 |
+
edge_index_sa2sa = dense_to_sparse(mask_t)[0]
|
708 |
+
edge_index_sa2sa = edge_index_sa2sa[:, edge_index_sa2sa[1] - edge_index_sa2sa[0] > 0]
|
709 |
+
rel_pos_t = pos_t[edge_index_sa2sa[0]] - pos_t[edge_index_sa2sa[1]]
|
710 |
+
rel_head_t = wrap_angle(head_t[edge_index_sa2sa[0]] - head_t[edge_index_sa2sa[1]])
|
711 |
+
|
712 |
+
r_t = torch.stack(
|
713 |
+
[torch.norm(rel_pos_t[:, :2], p=2, dim=-1),
|
714 |
+
angle_between_2d_vectors(ctr_vector=head_vector_t[edge_index_sa2sa[1]], nbr_vector=rel_pos_t[:, :2]),
|
715 |
+
rel_head_t,
|
716 |
+
edge_index_sa2sa[0] - edge_index_sa2sa[1]], dim=-1)
|
717 |
+
r_sa2sa = self.r_sa2sa_emb(continuous_inputs=r_t, categorical_embs=None)
|
718 |
+
|
719 |
+
return edge_index_sa2sa, r_sa2sa
|
720 |
+
|
721 |
+
def _build_seq(self, device, num_agent, num_step, av_index, sort_indices):
|
722 |
+
"""
|
723 |
+
Args:
|
724 |
+
sort_indices (torch.Tensor): shape (num_agent, num_atep)
|
725 |
+
"""
|
726 |
+
# sort_indices = sort_indices[:self.num_seed_feature]
|
727 |
+
seq_mask = torch.ones(self.num_seed_feature, num_step, num_agent + self.num_seed_feature, device=device).bool()
|
728 |
+
seq_mask[..., -self.num_seed_feature:] = False
|
729 |
+
for t in range(num_step):
|
730 |
+
for s in range(self.num_seed_feature):
|
731 |
+
seq_mask[s, t, sort_indices[s:, t].flatten().long()] = False
|
732 |
+
if self.seed_attn_to_av:
|
733 |
+
seq_mask[..., av_index] = True
|
734 |
+
seq_mask = seq_mask.transpose(0, 1).reshape(-1, num_agent + self.num_seed_feature)
|
735 |
+
|
736 |
+
seq_index = torch.cat([torch.zeros(num_agent), torch.arange(self.num_seed_feature) + 1]).to(device)
|
737 |
+
seq_index = seq_index[:, None].repeat(1, num_step)
|
738 |
+
for t in range(num_step):
|
739 |
+
for s in range(self.num_seed_feature):
|
740 |
+
seq_index[sort_indices[s : s + 1, t].flatten().long(), t] = s + 1
|
741 |
+
seq_index[av_index] = 0
|
742 |
+
|
743 |
+
return seq_mask, seq_index
|
744 |
+
|
745 |
+
def _build_occ_gt(self, data, seq_mask, pos_rel_index_gt, pos_rel_index_gt_seed, mask_seed,
|
746 |
+
edge_index=None, mode='edge_index'):
|
747 |
+
"""
|
748 |
+
Args:
|
749 |
+
seq_mask (torch.Tensor): shape (num_step * num_seed_feature, num_agent + self.num_seed_feature)
|
750 |
+
pos_rel_index_gt (torch.Tensor): shape (num_agent, num_step)
|
751 |
+
pos_rel_index_gt_seed (torch.Tensor): shape (num_seed, num_step)
|
752 |
+
"""
|
753 |
+
num_agent = data['agent']['state_idx'].shape[0] + self.num_seed_feature
|
754 |
+
num_step = data['agent']['state_idx'].shape[1]
|
755 |
+
data['agent']['agent_occ'] = torch.zeros(data.num_graphs * self.num_seed_feature, num_step, self.attr_tokenizer.grid_size,
|
756 |
+
device=data['agent']['state_idx'].device).long()
|
757 |
+
data['agent']['map_occ'] = torch.zeros(data.num_graphs, num_step, self.attr_tokenizer.grid_size,
|
758 |
+
device=data['agent']['state_idx'].device).long()
|
759 |
+
|
760 |
+
if mode == 'edge_index':
|
761 |
+
|
762 |
+
assert edge_index is not None, f"Need edge_index input!"
|
763 |
+
for src_index in torch.unique(edge_index[1]):
|
764 |
+
# decode src
|
765 |
+
src_row = src_index % num_agent - (num_agent - self.num_seed_feature)
|
766 |
+
src_col = src_index // num_agent
|
767 |
+
# decode tgt
|
768 |
+
tgt_indexes = edge_index[0, edge_index[1] == src_index]
|
769 |
+
tgt_rows = tgt_indexes % num_agent
|
770 |
+
tgt_cols = tgt_indexes // num_agent
|
771 |
+
assert tgt_rows.max() < num_agent - self.num_seed_feature, f"Invalid {tgt_rows}"
|
772 |
+
assert torch.unique(tgt_cols).shape[0] == 1 and torch.unique(tgt_cols)[0] == src_col
|
773 |
+
data['agent']['agent_occ'][src_row, src_col, pos_rel_index_gt[tgt_rows, tgt_cols]] = 1
|
774 |
+
|
775 |
+
else:
|
776 |
+
|
777 |
+
seq_mask = seq_mask.reshape(num_step, self.num_seed_feature, -1).transpose(0, 1)[..., :-self.num_seed_feature]
|
778 |
+
for s in range(self.num_seed_feature):
|
779 |
+
for t in range(num_step):
|
780 |
+
index = pos_rel_index_gt[seq_mask[s, t], t]
|
781 |
+
data['agent']['agent_occ'][s, t, index[index != -1]] = 1
|
782 |
+
if t > 0 and s < pos_rel_index_gt_seed.shape[0] and mask_seed[s, t - 1]: # insert agents
|
783 |
+
data['agent']['agent_occ'][s, t, pos_rel_index_gt_seed[s, t - 1]] = -1
|
784 |
+
|
785 |
+
# TODO: fix batch_size!!!
|
786 |
+
pt_grid_token_idx = data['agent']['pt_grid_token_idx'] # (t, num_pt)
|
787 |
+
for t in range(num_step):
|
788 |
+
data['agent']['map_occ'][:, t, pt_grid_token_idx[t][pt_grid_token_idx[t] != -1]] = 1
|
789 |
+
data['agent']['map_occ'] = data['agent']['map_occ'].repeat_interleave(repeats=self.num_seed_feature, dim=0)
|
790 |
+
|
791 |
+
def forward(self,
|
792 |
+
data: HeteroData,
|
793 |
+
map_enc: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
794 |
+
|
795 |
+
pos_a = data['agent']['token_pos'].clone() # (a, t, 2)
|
796 |
+
head_a = data['agent']['token_heading'].clone() # (a, t)
|
797 |
+
num_agent, num_step, traj_dim = pos_a.shape # e.g. (50, 18, 2)
|
798 |
+
num_pt = data['pt_token']['position'].shape[0]
|
799 |
+
agent_category = data['agent']['category'].clone() # (a,)
|
800 |
+
agent_shape = data['agent']['shape'][:, self.num_historical_steps - 1].clone() # (a, 3)
|
801 |
+
agent_token_index = data['agent']['token_idx'].clone() # (a, t)
|
802 |
+
agent_state_index = data['agent']['state_idx'].clone()
|
803 |
+
agent_type_index = data['agent']['type'].clone()
|
804 |
+
|
805 |
+
av_index = data['agent']['av_index'].long()
|
806 |
+
ego_pos = pos_a[av_index]
|
807 |
+
ego_head = head_a[av_index]
|
808 |
+
|
809 |
+
_, head_vector_a = self._build_vector_a(pos_a, head_a, agent_state_index)
|
810 |
+
|
811 |
+
agent_grid_token_idx = data['agent']['grid_token_idx']
|
812 |
+
agent_grid_offset_xy = data['agent']['grid_offset_xy']
|
813 |
+
agent_head_token_idx = data['agent']['heading_token_idx']
|
814 |
+
sort_indices = data['agent']['sort_indices']
|
815 |
+
pt_grid_token_idx = data['agent']['pt_grid_token_idx']
|
816 |
+
|
817 |
+
ori_pos_pl = data['pt_token']['position'][:, :self.input_dim].contiguous()
|
818 |
+
ori_orient_pl = data['pt_token']['orientation'].contiguous()
|
819 |
+
pos_pl = ori_pos_pl.repeat(num_step, 1)
|
820 |
+
orient_pl = ori_orient_pl.repeat(num_step)
|
821 |
+
|
822 |
+
# build relative 3d descriptors
|
823 |
+
pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
|
824 |
+
head_s = head_a.transpose(0, 1).reshape(-1)
|
825 |
+
|
826 |
+
ego_pos_a = ego_pos.repeat_interleave(repeats=data['batch_size_a'], dim=0)
|
827 |
+
ego_head_a = ego_head.repeat_interleave(repeats=data['batch_size_a'], dim=0)
|
828 |
+
ego_pos_s = ego_pos_a.transpose(0, 1).reshape(-1, self.input_dim)
|
829 |
+
ego_head_s = ego_head_a.transpose(0, 1).reshape(-1)
|
830 |
+
rel_pos_a2a = pos_s - ego_pos_s
|
831 |
+
rel_head_a2a = head_s - ego_head_s
|
832 |
+
|
833 |
+
ego_pos_pl = ego_pos.repeat_interleave(repeats=data['batch_size_pl'], dim=0)
|
834 |
+
ego_head_pl = ego_head.repeat_interleave(repeats=data['batch_size_pl'], dim=0)
|
835 |
+
ego_pos_s = ego_pos_pl.transpose(0, 1).reshape(-1, self.input_dim)
|
836 |
+
ego_head_s = ego_head_pl.transpose(0, 1).reshape(-1)
|
837 |
+
rel_pos_pl2a = pos_pl - ego_pos_s
|
838 |
+
rel_head_pl2a = orient_pl - ego_head_s
|
839 |
+
|
840 |
+
# releative encodings
|
841 |
+
ego_head_vector_a = head_vector_a[av_index].repeat_interleave(repeats=data['batch_size_a'], dim=0)
|
842 |
+
ego_head_vector_s = ego_head_vector_a.transpose(0, 1).reshape(-1, 2)
|
843 |
+
r_a2a = torch.stack(
|
844 |
+
[torch.norm(rel_pos_a2a[:, :2], p=2, dim=-1),
|
845 |
+
angle_between_2d_vectors(ctr_vector=ego_head_vector_s, nbr_vector=rel_pos_a2a[:, :2]),
|
846 |
+
rel_head_a2a], dim=-1)
|
847 |
+
r_a2a = self.r_a2a_emb(continuous_inputs=r_a2a, categorical_embs=None) # [N, hidden_dim]
|
848 |
+
|
849 |
+
ego_head_vector_a = head_vector_a[av_index].repeat_interleave(repeats=data['batch_size_pl'], dim=0)
|
850 |
+
ego_head_vector_s = ego_head_vector_a.transpose(0, 1).reshape(-1, 2)
|
851 |
+
r_pl2a = torch.stack(
|
852 |
+
[torch.norm(rel_pos_pl2a[:, :2], p=2, dim=-1),
|
853 |
+
angle_between_2d_vectors(ctr_vector=ego_head_vector_s, nbr_vector=rel_pos_pl2a[:, :2]),
|
854 |
+
rel_head_pl2a], dim=-1)
|
855 |
+
r_pl2a = self.r_pt2a_emb(continuous_inputs=r_pl2a, categorical_embs=None) # [M, d]
|
856 |
+
|
857 |
+
r_a2a = r_a2a.reshape(num_step, num_agent, -1).transpose(0, 1)
|
858 |
+
r_pl2a = r_pl2a.reshape(num_step, num_pt, -1).transpose(0, 1)
|
859 |
+
select_agent = torch.randperm(num_agent)[:self.agent_limit]
|
860 |
+
select_pt = torch.randperm(num_pt)[:self.pt_limit]
|
861 |
+
r_a2a = r_a2a[select_agent]
|
862 |
+
r_pl2a = r_pl2a[select_pt]
|
863 |
+
|
864 |
+
# aggregate to global feature
|
865 |
+
r_a2a = r_a2a.mean(dim=0) # [t, d]
|
866 |
+
r_pl2a = r_pl2a.mean(dim=0)
|
867 |
+
|
868 |
+
# decode grid index of neighbor agents
|
869 |
+
agent_occ = self.grid_agent_occ_head(r_a2a) # [t, grid_size]
|
870 |
+
pt_occ = self.grid_pt_occ_head(r_pl2a)
|
871 |
+
|
872 |
+
# 1.
|
873 |
+
# agent_occ_gt = torch.zeros_like(agent_occ).long()
|
874 |
+
# pt_occ_gt = torch.zeros_like(pt_occ).long()
|
875 |
+
|
876 |
+
# for t in range(num_step):
|
877 |
+
# agent_occ_gt[t, agent_grid_token_idx[:, t][agent_grid_token_idx[:, t] != -1]] = 1
|
878 |
+
# pt_occ_gt[t, pt_grid_token_idx[t][pt_grid_token_idx[t] != -1]] = 1
|
879 |
+
|
880 |
+
# agent_occ_gt[:, self.grid_size // 2] = 0
|
881 |
+
# pt_occ_gt[:, self.grid_size // 2] = 0
|
882 |
+
|
883 |
+
# agent_occ_eval_mask = torch.ones_like(agent_occ_gt)
|
884 |
+
# agent_occ_eval_mask[0] = 0
|
885 |
+
# agent_occ_eval_mask[:, self.grid_size // 2] = 0
|
886 |
+
# pt_occ_eval_mask = torch.ones_like(pt_occ_gt)
|
887 |
+
# pt_occ_eval_mask[0] = 0
|
888 |
+
# pt_occ_eval_mask[:, self.grid_size // 2] = 0
|
889 |
+
|
890 |
+
# 2.
|
891 |
+
# agent_occ_gt = agent_grid_token_idx.transpose(0, 1).reshape(-1)
|
892 |
+
# pt_occ_gt = pt_grid_token_idx.reshape(-1)
|
893 |
+
|
894 |
+
# agent_occ_eval_mask = torch.zeros_like(agent_occ_gt)
|
895 |
+
# agent_occ_eval_mask[torch.randperm(agent_occ_gt.shape[0])[:(num_step * 10)]] = 1
|
896 |
+
# agent_occ_eval_mask[agent_occ_gt == -1] = 0
|
897 |
+
|
898 |
+
# pt_occ_eval_mask = torch.zeros_like(pt_occ_gt)
|
899 |
+
# pt_occ_eval_mask[torch.randperm(pt_occ_gt.shape[0])[:(num_step * 300)]] = 1
|
900 |
+
# pt_occ_eval_mask[pt_occ_gt == -1] = 0
|
901 |
+
|
902 |
+
# 3.
|
903 |
+
agent_occ = agent_occ.reshape(num_step, self.agent_limit, -1)
|
904 |
+
pt_occ = pt_occ.reshape(num_step, self.pt_limit, -1)
|
905 |
+
agent_occ_gt = agent_grid_token_idx[select_agent].transpose(0, 1)
|
906 |
+
pt_occ_gt = pt_grid_token_idx[:, select_pt]
|
907 |
+
agent_occ_eval_mask = agent_occ_gt != -1
|
908 |
+
pt_occ_eval_mask = pt_occ_gt != -1
|
909 |
+
|
910 |
+
agent_occ = agent_occ[:, :agent_occ_gt.shape[1]]
|
911 |
+
pt_occ = pt_occ[:, :pt_occ_gt.shape[1]]
|
912 |
+
|
913 |
+
return {'occ_decoder': True,
|
914 |
+
'num_step': num_step,
|
915 |
+
'num_agent': self.agent_limit, # num_agent
|
916 |
+
'num_pt': self.pt_limit, # num_pt
|
917 |
+
'agent_occ': agent_occ,
|
918 |
+
'agent_occ_gt': agent_occ_gt,
|
919 |
+
'agent_occ_eval_mask': agent_occ_eval_mask.bool(),
|
920 |
+
'pt_occ': pt_occ,
|
921 |
+
'pt_occ_gt': pt_occ_gt,
|
922 |
+
'pt_occ_eval_mask': pt_occ_eval_mask.bool(),
|
923 |
+
}
|
924 |
+
|
925 |
+
def inference(self, *args, **kwargs):
|
926 |
+
return self(*args, **kwargs)
|
927 |
+
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/modules/smart_decoder.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Optional
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch_geometric.data import HeteroData
|
5 |
+
from dev.modules.attr_tokenizer import Attr_Tokenizer
|
6 |
+
from dev.modules.agent_decoder import SMARTAgentDecoder
|
7 |
+
from dev.modules.occ_decoder import SMARTOccDecoder
|
8 |
+
from dev.modules.map_decoder import SMARTMapDecoder
|
9 |
+
|
10 |
+
|
11 |
+
DECODER = {'agent_decoder': SMARTAgentDecoder,
|
12 |
+
'occ_decoder': SMARTOccDecoder}
|
13 |
+
|
14 |
+
|
15 |
+
class SMARTDecoder(nn.Module):
|
16 |
+
|
17 |
+
def __init__(self,
|
18 |
+
decoder_type: str,
|
19 |
+
dataset: str,
|
20 |
+
input_dim: int,
|
21 |
+
hidden_dim: int,
|
22 |
+
num_historical_steps: int,
|
23 |
+
pl2pl_radius: float,
|
24 |
+
time_span: Optional[int],
|
25 |
+
pl2a_radius: float,
|
26 |
+
pl2seed_radius: float,
|
27 |
+
a2a_radius: float,
|
28 |
+
a2sa_radius: float,
|
29 |
+
pl2sa_radius: float,
|
30 |
+
num_freq_bands: int,
|
31 |
+
num_map_layers: int,
|
32 |
+
num_agent_layers: int,
|
33 |
+
num_heads: int,
|
34 |
+
head_dim: int,
|
35 |
+
dropout: float,
|
36 |
+
map_token: Dict,
|
37 |
+
token_size=512,
|
38 |
+
attr_tokenizer: Attr_Tokenizer=None,
|
39 |
+
predict_motion: bool=False,
|
40 |
+
predict_state: bool=False,
|
41 |
+
predict_map: bool=False,
|
42 |
+
predict_occ: bool=False,
|
43 |
+
use_grid_token: bool=False,
|
44 |
+
state_token: Dict[str, int]=None,
|
45 |
+
seed_size: int=5,
|
46 |
+
buffer_size: int=32,
|
47 |
+
num_recurrent_steps_val: int=-1,
|
48 |
+
loss_weight: dict=None,
|
49 |
+
logger=None) -> None:
|
50 |
+
|
51 |
+
super(SMARTDecoder, self).__init__()
|
52 |
+
|
53 |
+
self.map_encoder = SMARTMapDecoder(
|
54 |
+
dataset=dataset,
|
55 |
+
input_dim=input_dim,
|
56 |
+
hidden_dim=hidden_dim,
|
57 |
+
num_historical_steps=num_historical_steps,
|
58 |
+
pl2pl_radius=pl2pl_radius,
|
59 |
+
num_freq_bands=num_freq_bands,
|
60 |
+
num_layers=num_map_layers,
|
61 |
+
num_heads=num_heads,
|
62 |
+
head_dim=head_dim,
|
63 |
+
dropout=dropout,
|
64 |
+
map_token=map_token,
|
65 |
+
)
|
66 |
+
|
67 |
+
assert decoder_type in list(DECODER.keys()), f"Unsupport decoder type: {decoder_type}"
|
68 |
+
self.agent_encoder = DECODER[decoder_type](
|
69 |
+
dataset=dataset,
|
70 |
+
input_dim=input_dim,
|
71 |
+
hidden_dim=hidden_dim,
|
72 |
+
num_historical_steps=num_historical_steps,
|
73 |
+
time_span=time_span,
|
74 |
+
pl2a_radius=pl2a_radius,
|
75 |
+
pl2seed_radius=pl2seed_radius,
|
76 |
+
a2a_radius=a2a_radius,
|
77 |
+
a2sa_radius=a2sa_radius,
|
78 |
+
pl2sa_radius=pl2sa_radius,
|
79 |
+
num_freq_bands=num_freq_bands,
|
80 |
+
num_layers=num_agent_layers,
|
81 |
+
num_heads=num_heads,
|
82 |
+
head_dim=head_dim,
|
83 |
+
dropout=dropout,
|
84 |
+
token_size=token_size,
|
85 |
+
attr_tokenizer=attr_tokenizer,
|
86 |
+
predict_motion=predict_motion,
|
87 |
+
predict_state=predict_state,
|
88 |
+
predict_map=predict_map,
|
89 |
+
predict_occ=predict_occ,
|
90 |
+
state_token=state_token,
|
91 |
+
use_grid_token=use_grid_token,
|
92 |
+
seed_size=seed_size,
|
93 |
+
buffer_size=buffer_size,
|
94 |
+
num_recurrent_steps_val=num_recurrent_steps_val,
|
95 |
+
loss_weight=loss_weight,
|
96 |
+
logger=logger,
|
97 |
+
)
|
98 |
+
self.map_enc = None
|
99 |
+
self.predict_motion = predict_motion
|
100 |
+
self.predict_state = predict_state
|
101 |
+
self.predict_map = predict_map
|
102 |
+
self.predict_occ = predict_occ
|
103 |
+
self.data_keys = ["agent_valid_mask", "category", "valid_mask", "av_index", "scenario_id", "shape"]
|
104 |
+
|
105 |
+
def get_agent_inputs(self, data: HeteroData) -> Dict[str, torch.Tensor]:
|
106 |
+
return self.agent_encoder.get_inputs(data)
|
107 |
+
|
108 |
+
def forward(self, data: HeteroData) -> Dict[str, torch.Tensor]:
|
109 |
+
map_enc = self.map_encoder(data)
|
110 |
+
|
111 |
+
agent_enc = {}
|
112 |
+
if self.predict_motion or self.predict_state or self.predict_occ:
|
113 |
+
agent_enc = self.agent_encoder(data, map_enc)
|
114 |
+
|
115 |
+
return {**map_enc, **agent_enc, **{k: data[k] for k in self.data_keys}}
|
116 |
+
|
117 |
+
def inference(self, data: HeteroData) -> Dict[str, torch.Tensor]:
|
118 |
+
map_enc = self.map_encoder(data)
|
119 |
+
|
120 |
+
agent_enc = {}
|
121 |
+
if self.predict_motion or self.predict_state or self.predict_occ:
|
122 |
+
agent_enc = self.agent_encoder.inference(data, map_enc)
|
123 |
+
|
124 |
+
return {**map_enc, **agent_enc, **{k: data[k] for k in self.data_keys}}
|
125 |
+
|
126 |
+
def inference_no_map(self, data: HeteroData, map_enc) -> Dict[str, torch.Tensor]:
|
127 |
+
agent_enc = self.agent_encoder.inference(data, map_enc)
|
128 |
+
return {**map_enc, **agent_enc}
|
129 |
+
|
130 |
+
def insert_agent(self, data: HeteroData) -> Dict[str, torch.Tensor]:
|
131 |
+
map_enc = self.map_encoder(data)
|
132 |
+
agent_enc = self.agent_encoder.insert(data, map_enc)
|
133 |
+
return {**map_enc, **agent_enc, **{k: data[k] for k in self.data_keys}}
|
134 |
+
|
135 |
+
def predict_nearest_pos(self, data: HeteroData, rank) -> Dict[str, torch.Tensor]:
|
136 |
+
map_enc = self.map_encoder(data)
|
137 |
+
self.agent_encoder.predict_nearest_pos(data, map_enc, rank)
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/utils/cluster_reader.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import pickle
|
3 |
+
import pandas as pd
|
4 |
+
import json
|
5 |
+
|
6 |
+
|
7 |
+
class LoadScenarioFromCeph:
|
8 |
+
def __init__(self):
|
9 |
+
from petrel_client.client import Client
|
10 |
+
self.file_client = Client('~/petreloss.conf')
|
11 |
+
|
12 |
+
def list(self, dir_path):
|
13 |
+
return list(self.file_client.list(dir_path))
|
14 |
+
|
15 |
+
def save(self, data, url):
|
16 |
+
self.file_client.put(url, pickle.dumps(data))
|
17 |
+
|
18 |
+
def read_correct_csv(self, scenario_path):
|
19 |
+
output = pd.read_csv(io.StringIO(self.file_client.get(scenario_path).decode('utf-8')), engine="python")
|
20 |
+
return output
|
21 |
+
|
22 |
+
def contains(self, url):
|
23 |
+
return self.file_client.contains(url)
|
24 |
+
|
25 |
+
def read_string(self, csv_url):
|
26 |
+
from io import StringIO
|
27 |
+
df = pd.read_csv(StringIO(str(self.file_client.get(csv_url), 'utf-8')), sep='\s+', low_memory=False)
|
28 |
+
return df
|
29 |
+
|
30 |
+
def read(self, scenario_path):
|
31 |
+
with io.BytesIO(self.file_client.get(scenario_path)) as f:
|
32 |
+
datas = pickle.load(f)
|
33 |
+
return datas
|
34 |
+
|
35 |
+
def read_json(self, path):
|
36 |
+
with io.BytesIO(self.file_client.get(path)) as f:
|
37 |
+
data = json.load(f)
|
38 |
+
return data
|
39 |
+
|
40 |
+
def read_csv(self, scenario_path):
|
41 |
+
return pickle.loads(self.file_client.get(scenario_path))
|
42 |
+
|
43 |
+
def read_model(self, model_path):
|
44 |
+
with io.BytesIO(self.file_client.get(model_path)) as f:
|
45 |
+
pass
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/utils/func.py
ADDED
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import time
|
3 |
+
import os
|
4 |
+
import yaml
|
5 |
+
import easydict
|
6 |
+
import math
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
from rich.console import Console
|
10 |
+
from typing import Any, List, Optional, Mapping
|
11 |
+
|
12 |
+
from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
|
13 |
+
|
14 |
+
|
15 |
+
CONSOLE = Console(width=128)
|
16 |
+
|
17 |
+
|
18 |
+
def check_nan_inf(t, s):
|
19 |
+
assert not torch.isinf(t).any(), f"{s} is inf, {t}"
|
20 |
+
assert not torch.isnan(t).any(), f"{s} is nan, {t}"
|
21 |
+
|
22 |
+
|
23 |
+
def safe_list_index(ls: List[Any], elem: Any) -> Optional[int]:
|
24 |
+
try:
|
25 |
+
return ls.index(elem)
|
26 |
+
except ValueError:
|
27 |
+
return None
|
28 |
+
|
29 |
+
|
30 |
+
def angle_between_2d_vectors(
|
31 |
+
ctr_vector: torch.Tensor,
|
32 |
+
nbr_vector: torch.Tensor) -> torch.Tensor:
|
33 |
+
return torch.atan2(ctr_vector[..., 0] * nbr_vector[..., 1] - ctr_vector[..., 1] * nbr_vector[..., 0],
|
34 |
+
(ctr_vector[..., :2] * nbr_vector[..., :2]).sum(dim=-1))
|
35 |
+
|
36 |
+
|
37 |
+
def angle_between_3d_vectors(
|
38 |
+
ctr_vector: torch.Tensor,
|
39 |
+
nbr_vector: torch.Tensor) -> torch.Tensor:
|
40 |
+
return torch.atan2(torch.cross(ctr_vector, nbr_vector, dim=-1).norm(p=2, dim=-1),
|
41 |
+
(ctr_vector * nbr_vector).sum(dim=-1))
|
42 |
+
|
43 |
+
|
44 |
+
def side_to_directed_lineseg(
|
45 |
+
query_point: torch.Tensor,
|
46 |
+
start_point: torch.Tensor,
|
47 |
+
end_point: torch.Tensor) -> str:
|
48 |
+
cond = ((end_point[0] - start_point[0]) * (query_point[1] - start_point[1]) -
|
49 |
+
(end_point[1] - start_point[1]) * (query_point[0] - start_point[0]))
|
50 |
+
if cond > 0:
|
51 |
+
return 'LEFT'
|
52 |
+
elif cond < 0:
|
53 |
+
return 'RIGHT'
|
54 |
+
else:
|
55 |
+
return 'CENTER'
|
56 |
+
|
57 |
+
|
58 |
+
def wrap_angle(
|
59 |
+
angle: torch.Tensor,
|
60 |
+
min_val: float = -math.pi,
|
61 |
+
max_val: float = math.pi) -> torch.Tensor:
|
62 |
+
return min_val + (angle + max_val) % (max_val - min_val)
|
63 |
+
|
64 |
+
|
65 |
+
def load_config_act(path):
|
66 |
+
""" load config file"""
|
67 |
+
with open(path, 'r') as f:
|
68 |
+
cfg = yaml.load(f, Loader=yaml.FullLoader)
|
69 |
+
return easydict.EasyDict(cfg)
|
70 |
+
|
71 |
+
|
72 |
+
def load_config_init(path):
|
73 |
+
""" load config file"""
|
74 |
+
path = os.path.join('init/configs', f'{path}.yaml')
|
75 |
+
with open(path, 'r') as f:
|
76 |
+
cfg = yaml.load(f, Loader=yaml.FullLoader)
|
77 |
+
return cfg
|
78 |
+
|
79 |
+
|
80 |
+
class Logging:
|
81 |
+
|
82 |
+
def make_log_dir(self, dirname='logs'):
|
83 |
+
now_dir = os.path.dirname(__file__)
|
84 |
+
path = os.path.join(now_dir, dirname)
|
85 |
+
path = os.path.normpath(path)
|
86 |
+
if not os.path.exists(path):
|
87 |
+
os.mkdir(path)
|
88 |
+
return path
|
89 |
+
|
90 |
+
def get_log_filename(self):
|
91 |
+
filename = "{}.log".format(time.strftime("%Y-%m-%d-%H%M%S", time.localtime()))
|
92 |
+
filename = os.path.join(self.make_log_dir(), filename)
|
93 |
+
filename = os.path.normpath(filename)
|
94 |
+
return filename
|
95 |
+
|
96 |
+
def log(self, level='DEBUG', name="simagent"):
|
97 |
+
logger = logging.getLogger(name)
|
98 |
+
level = getattr(logging, level)
|
99 |
+
logger.setLevel(level)
|
100 |
+
if not logger.handlers:
|
101 |
+
sh = logging.StreamHandler()
|
102 |
+
fh = logging.FileHandler(filename=self.get_log_filename(), mode='a',encoding="utf-8")
|
103 |
+
fmt = logging.Formatter("%(asctime)s-%(levelname)s-%(filename)s-Line:%(lineno)d-Message:%(message)s")
|
104 |
+
sh.setFormatter(fmt=fmt)
|
105 |
+
fh.setFormatter(fmt=fmt)
|
106 |
+
logger.addHandler(sh)
|
107 |
+
logger.addHandler(fh)
|
108 |
+
return logger
|
109 |
+
|
110 |
+
def add_log(self, logger, level='DEBUG'):
|
111 |
+
level = getattr(logging, level)
|
112 |
+
logger.setLevel(level)
|
113 |
+
if not logger.handlers:
|
114 |
+
sh = logging.StreamHandler()
|
115 |
+
fh = logging.FileHandler(filename=self.get_log_filename(), mode='a',encoding="utf-8")
|
116 |
+
fmt = logging.Formatter("%(asctime)s-%(levelname)s-%(filename)s-Line:%(lineno)d-Message:%(message)s")
|
117 |
+
sh.setFormatter(fmt=fmt)
|
118 |
+
fh.setFormatter(fmt=fmt)
|
119 |
+
logger.addHandler(sh)
|
120 |
+
logger.addHandler(fh)
|
121 |
+
return logger
|
122 |
+
|
123 |
+
|
124 |
+
# Adapted from 'CatK'
|
125 |
+
class RankedLogger(logging.LoggerAdapter):
|
126 |
+
"""A multi-GPU-friendly python command line logger."""
|
127 |
+
|
128 |
+
def __init__(
|
129 |
+
self,
|
130 |
+
name: str = __name__,
|
131 |
+
rank_zero_only: bool = False,
|
132 |
+
extra: Optional[Mapping[str, object]] = None,
|
133 |
+
) -> None:
|
134 |
+
"""Initializes a multi-GPU-friendly python command line logger that logs on all processes
|
135 |
+
with their rank prefixed in the log message.
|
136 |
+
|
137 |
+
:param name: The name of the logger. Default is ``__name__``.
|
138 |
+
:param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`.
|
139 |
+
:param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
|
140 |
+
"""
|
141 |
+
logger = logging.getLogger(name)
|
142 |
+
super().__init__(logger=logger, extra=extra)
|
143 |
+
self.rank_zero_only = rank_zero_only
|
144 |
+
|
145 |
+
def log(
|
146 |
+
self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs
|
147 |
+
) -> None:
|
148 |
+
"""Delegate a log call to the underlying logger, after prefixing its message with the rank
|
149 |
+
of the process it's being logged from. If `'rank'` is provided, then the log will only
|
150 |
+
occur on that rank/process.
|
151 |
+
|
152 |
+
:param level: The level to log at. Look at `logging.__init__.py` for more information.
|
153 |
+
:param msg: The message to log.
|
154 |
+
:param rank: The rank to log at.
|
155 |
+
:param args: Additional args to pass to the underlying logging function.
|
156 |
+
:param kwargs: Any additional keyword args to pass to the underlying logging function.
|
157 |
+
"""
|
158 |
+
if self.isEnabledFor(level):
|
159 |
+
msg, kwargs = self.process(msg, kwargs)
|
160 |
+
current_rank = getattr(rank_zero_only, "rank", None)
|
161 |
+
if current_rank is None:
|
162 |
+
raise RuntimeError(
|
163 |
+
"The `rank_zero_only.rank` needs to be set before use"
|
164 |
+
)
|
165 |
+
msg = rank_prefixed_message(msg, current_rank)
|
166 |
+
if self.rank_zero_only:
|
167 |
+
if current_rank == 0:
|
168 |
+
self.logger.log(level, msg, *args, **kwargs)
|
169 |
+
else:
|
170 |
+
if rank is None:
|
171 |
+
self.logger.log(level, msg, *args, **kwargs)
|
172 |
+
elif current_rank == rank:
|
173 |
+
self.logger.log(level, msg, *args, **kwargs)
|
174 |
+
|
175 |
+
|
176 |
+
|
177 |
+
def weight_init(m: nn.Module) -> None:
|
178 |
+
if isinstance(m, nn.Linear):
|
179 |
+
nn.init.xavier_uniform_(m.weight)
|
180 |
+
if m.bias is not None:
|
181 |
+
nn.init.zeros_(m.bias)
|
182 |
+
elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
|
183 |
+
fan_in = m.in_channels / m.groups
|
184 |
+
fan_out = m.out_channels / m.groups
|
185 |
+
bound = (6.0 / (fan_in + fan_out)) ** 0.5
|
186 |
+
nn.init.uniform_(m.weight, -bound, bound)
|
187 |
+
if m.bias is not None:
|
188 |
+
nn.init.zeros_(m.bias)
|
189 |
+
elif isinstance(m, nn.Embedding):
|
190 |
+
nn.init.normal_(m.weight, mean=0.0, std=0.02)
|
191 |
+
elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
|
192 |
+
nn.init.ones_(m.weight)
|
193 |
+
nn.init.zeros_(m.bias)
|
194 |
+
elif isinstance(m, nn.LayerNorm):
|
195 |
+
nn.init.ones_(m.weight)
|
196 |
+
nn.init.zeros_(m.bias)
|
197 |
+
elif isinstance(m, nn.MultiheadAttention):
|
198 |
+
if m.in_proj_weight is not None:
|
199 |
+
fan_in = m.embed_dim
|
200 |
+
fan_out = m.embed_dim
|
201 |
+
bound = (6.0 / (fan_in + fan_out)) ** 0.5
|
202 |
+
nn.init.uniform_(m.in_proj_weight, -bound, bound)
|
203 |
+
else:
|
204 |
+
nn.init.xavier_uniform_(m.q_proj_weight)
|
205 |
+
nn.init.xavier_uniform_(m.k_proj_weight)
|
206 |
+
nn.init.xavier_uniform_(m.v_proj_weight)
|
207 |
+
if m.in_proj_bias is not None:
|
208 |
+
nn.init.zeros_(m.in_proj_bias)
|
209 |
+
nn.init.xavier_uniform_(m.out_proj.weight)
|
210 |
+
if m.out_proj.bias is not None:
|
211 |
+
nn.init.zeros_(m.out_proj.bias)
|
212 |
+
if m.bias_k is not None:
|
213 |
+
nn.init.normal_(m.bias_k, mean=0.0, std=0.02)
|
214 |
+
if m.bias_v is not None:
|
215 |
+
nn.init.normal_(m.bias_v, mean=0.0, std=0.02)
|
216 |
+
elif isinstance(m, (nn.LSTM, nn.LSTMCell)):
|
217 |
+
for name, param in m.named_parameters():
|
218 |
+
if 'weight_ih' in name:
|
219 |
+
for ih in param.chunk(4, 0):
|
220 |
+
nn.init.xavier_uniform_(ih)
|
221 |
+
elif 'weight_hh' in name:
|
222 |
+
for hh in param.chunk(4, 0):
|
223 |
+
nn.init.orthogonal_(hh)
|
224 |
+
elif 'weight_hr' in name:
|
225 |
+
nn.init.xavier_uniform_(param)
|
226 |
+
elif 'bias_ih' in name:
|
227 |
+
nn.init.zeros_(param)
|
228 |
+
elif 'bias_hh' in name:
|
229 |
+
nn.init.zeros_(param)
|
230 |
+
nn.init.ones_(param.chunk(4, 0)[1])
|
231 |
+
elif isinstance(m, (nn.GRU, nn.GRUCell)):
|
232 |
+
for name, param in m.named_parameters():
|
233 |
+
if 'weight_ih' in name:
|
234 |
+
for ih in param.chunk(3, 0):
|
235 |
+
nn.init.xavier_uniform_(ih)
|
236 |
+
elif 'weight_hh' in name:
|
237 |
+
for hh in param.chunk(3, 0):
|
238 |
+
nn.init.orthogonal_(hh)
|
239 |
+
elif 'bias_ih' in name:
|
240 |
+
nn.init.zeros_(param)
|
241 |
+
elif 'bias_hh' in name:
|
242 |
+
nn.init.zeros_(param)
|
243 |
+
|
244 |
+
|
245 |
+
def pos2posemb(pos, num_pos_feats=128, temperature=10000):
|
246 |
+
|
247 |
+
scale = 2 * math.pi
|
248 |
+
pos = pos * scale
|
249 |
+
dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device)
|
250 |
+
dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
|
251 |
+
|
252 |
+
D = pos.shape[-1]
|
253 |
+
pos_dims = []
|
254 |
+
for i in range(D):
|
255 |
+
pos_dim_i = pos[..., i, None] / dim_t
|
256 |
+
pos_dim_i = torch.stack((pos_dim_i[..., 0::2].sin(), pos_dim_i[..., 1::2].cos()), dim=-1).flatten(-2)
|
257 |
+
pos_dims.append(pos_dim_i)
|
258 |
+
posemb = torch.cat(pos_dims, dim=-1)
|
259 |
+
|
260 |
+
return posemb
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/utils/graph.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
from typing import List, Optional, Tuple, Union
|
4 |
+
from torch_geometric.utils import coalesce
|
5 |
+
from torch_geometric.utils import degree
|
6 |
+
|
7 |
+
|
8 |
+
def add_edges(
|
9 |
+
from_edge_index: torch.Tensor,
|
10 |
+
to_edge_index: torch.Tensor,
|
11 |
+
from_edge_attr: Optional[torch.Tensor] = None,
|
12 |
+
to_edge_attr: Optional[torch.Tensor] = None,
|
13 |
+
replace: bool = True) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
14 |
+
from_edge_index = from_edge_index.to(device=to_edge_index.device, dtype=to_edge_index.dtype)
|
15 |
+
mask = ((to_edge_index[0].unsqueeze(-1) == from_edge_index[0].unsqueeze(0)) &
|
16 |
+
(to_edge_index[1].unsqueeze(-1) == from_edge_index[1].unsqueeze(0)))
|
17 |
+
if replace:
|
18 |
+
to_mask = mask.any(dim=1)
|
19 |
+
if from_edge_attr is not None and to_edge_attr is not None:
|
20 |
+
from_edge_attr = from_edge_attr.to(device=to_edge_attr.device, dtype=to_edge_attr.dtype)
|
21 |
+
to_edge_attr = torch.cat([to_edge_attr[~to_mask], from_edge_attr], dim=0)
|
22 |
+
to_edge_index = torch.cat([to_edge_index[:, ~to_mask], from_edge_index], dim=1)
|
23 |
+
else:
|
24 |
+
from_mask = mask.any(dim=0)
|
25 |
+
if from_edge_attr is not None and to_edge_attr is not None:
|
26 |
+
from_edge_attr = from_edge_attr.to(device=to_edge_attr.device, dtype=to_edge_attr.dtype)
|
27 |
+
to_edge_attr = torch.cat([to_edge_attr, from_edge_attr[~from_mask]], dim=0)
|
28 |
+
to_edge_index = torch.cat([to_edge_index, from_edge_index[:, ~from_mask]], dim=1)
|
29 |
+
return to_edge_index, to_edge_attr
|
30 |
+
|
31 |
+
|
32 |
+
def merge_edges(
|
33 |
+
edge_indices: List[torch.Tensor],
|
34 |
+
edge_attrs: Optional[List[torch.Tensor]] = None,
|
35 |
+
reduce: str = 'add') -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
36 |
+
edge_index = torch.cat(edge_indices, dim=1)
|
37 |
+
if edge_attrs is not None:
|
38 |
+
edge_attr = torch.cat(edge_attrs, dim=0)
|
39 |
+
else:
|
40 |
+
edge_attr = None
|
41 |
+
return coalesce(edge_index=edge_index, edge_attr=edge_attr, reduce=reduce)
|
42 |
+
|
43 |
+
|
44 |
+
def complete_graph(
|
45 |
+
num_nodes: Union[int, Tuple[int, int]],
|
46 |
+
ptr: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
|
47 |
+
loop: bool = False,
|
48 |
+
device: Optional[Union[torch.device, str]] = None) -> torch.Tensor:
|
49 |
+
if ptr is None:
|
50 |
+
if isinstance(num_nodes, int):
|
51 |
+
num_src, num_dst = num_nodes, num_nodes
|
52 |
+
else:
|
53 |
+
num_src, num_dst = num_nodes
|
54 |
+
edge_index = torch.cartesian_prod(torch.arange(num_src, dtype=torch.long, device=device),
|
55 |
+
torch.arange(num_dst, dtype=torch.long, device=device)).t()
|
56 |
+
else:
|
57 |
+
if isinstance(ptr, torch.Tensor):
|
58 |
+
ptr_src, ptr_dst = ptr, ptr
|
59 |
+
num_src_batch = num_dst_batch = ptr[1:] - ptr[:-1]
|
60 |
+
else:
|
61 |
+
ptr_src, ptr_dst = ptr
|
62 |
+
num_src_batch = ptr_src[1:] - ptr_src[:-1]
|
63 |
+
num_dst_batch = ptr_dst[1:] - ptr_dst[:-1]
|
64 |
+
edge_index = torch.cat(
|
65 |
+
[torch.cartesian_prod(torch.arange(num_src, dtype=torch.long, device=device),
|
66 |
+
torch.arange(num_dst, dtype=torch.long, device=device)) + p
|
67 |
+
for num_src, num_dst, p in zip(num_src_batch, num_dst_batch, torch.stack([ptr_src, ptr_dst], dim=1))],
|
68 |
+
dim=0)
|
69 |
+
edge_index = edge_index.t()
|
70 |
+
if isinstance(num_nodes, int) and not loop:
|
71 |
+
edge_index = edge_index[:, edge_index[0] != edge_index[1]]
|
72 |
+
return edge_index.contiguous()
|
73 |
+
|
74 |
+
|
75 |
+
def bipartite_dense_to_sparse(adj: torch.Tensor) -> torch.Tensor:
|
76 |
+
index = adj.nonzero(as_tuple=True)
|
77 |
+
if len(index) == 3:
|
78 |
+
batch_src = index[0] * adj.size(1)
|
79 |
+
batch_dst = index[0] * adj.size(2)
|
80 |
+
index = (batch_src + index[1], batch_dst + index[2])
|
81 |
+
return torch.stack(index, dim=0)
|
82 |
+
|
83 |
+
|
84 |
+
def unbatch(
|
85 |
+
src: torch.Tensor,
|
86 |
+
batch: torch.Tensor,
|
87 |
+
dim: int = 0) -> List[torch.Tensor]:
|
88 |
+
sizes = degree(batch, dtype=torch.long).tolist()
|
89 |
+
return src.split(sizes, dim)
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/utils/metrics.py
ADDED
@@ -0,0 +1,692 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import os
|
3 |
+
import itertools
|
4 |
+
import multiprocessing as mp
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from pathlib import Path
|
7 |
+
from torch.nn import CrossEntropyLoss
|
8 |
+
from torch_scatter import gather_csr
|
9 |
+
from torch_scatter import segment_csr
|
10 |
+
from torchmetrics import Metric
|
11 |
+
from typing import Optional, Tuple, Dict, List
|
12 |
+
|
13 |
+
|
14 |
+
__all__ = ['minADE', 'minFDE', 'TokenCls', 'StateAccuracy', 'GridOverlapRate']
|
15 |
+
|
16 |
+
|
17 |
+
class CustomCrossEntropyLoss(CrossEntropyLoss):
|
18 |
+
|
19 |
+
def __init__(self, label_smoothing=0.0, reduction='mean'):
|
20 |
+
super(CustomCrossEntropyLoss, self).__init__()
|
21 |
+
self.label_smoothing = label_smoothing
|
22 |
+
self.reduction = reduction
|
23 |
+
|
24 |
+
def forward(self, input, target):
|
25 |
+
num_classes = input.size(1)
|
26 |
+
|
27 |
+
log_probs = F.log_softmax(input, dim=1)
|
28 |
+
|
29 |
+
with torch.no_grad():
|
30 |
+
smooth_target = torch.zeros_like(log_probs).scatter_(1, target.unsqueeze(1), 1)
|
31 |
+
smooth_target = smooth_target * (1 - self.label_smoothing) + self.label_smoothing / num_classes
|
32 |
+
|
33 |
+
loss = -torch.sum(log_probs * smooth_target, dim=1)
|
34 |
+
|
35 |
+
if self.reduction == 'mean':
|
36 |
+
return loss.mean()
|
37 |
+
elif self.reduction == 'sum':
|
38 |
+
return loss.sum()
|
39 |
+
else:
|
40 |
+
return loss
|
41 |
+
|
42 |
+
|
43 |
+
def topk(
|
44 |
+
max_guesses: int,
|
45 |
+
pred: torch.Tensor,
|
46 |
+
prob: Optional[torch.Tensor] = None,
|
47 |
+
ptr: Optional[torch.Tensor] = None,
|
48 |
+
joint: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
|
49 |
+
max_guesses = min(max_guesses, pred.size(1))
|
50 |
+
if max_guesses == pred.size(1):
|
51 |
+
if prob is not None:
|
52 |
+
prob = prob / prob.sum(dim=-1, keepdim=True)
|
53 |
+
else:
|
54 |
+
prob = pred.new_ones((pred.size(0), max_guesses)) / max_guesses
|
55 |
+
return pred, prob
|
56 |
+
else:
|
57 |
+
if prob is not None:
|
58 |
+
if joint:
|
59 |
+
if ptr is None:
|
60 |
+
inds_topk = torch.topk((prob / prob.sum(dim=-1, keepdim=True)).mean(dim=0, keepdim=True),
|
61 |
+
k=max_guesses, dim=-1, largest=True, sorted=True)[1]
|
62 |
+
inds_topk = inds_topk.repeat(pred.size(0), 1)
|
63 |
+
else:
|
64 |
+
inds_topk = torch.topk(segment_csr(src=prob / prob.sum(dim=-1, keepdim=True), indptr=ptr,
|
65 |
+
reduce='mean'),
|
66 |
+
k=max_guesses, dim=-1, largest=True, sorted=True)[1]
|
67 |
+
inds_topk = gather_csr(src=inds_topk, indptr=ptr)
|
68 |
+
else:
|
69 |
+
inds_topk = torch.topk(prob, k=max_guesses, dim=-1, largest=True, sorted=True)[1]
|
70 |
+
pred_topk = pred[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]
|
71 |
+
prob_topk = prob[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]
|
72 |
+
prob_topk = prob_topk / prob_topk.sum(dim=-1, keepdim=True)
|
73 |
+
else:
|
74 |
+
pred_topk = pred[:, :max_guesses]
|
75 |
+
prob_topk = pred.new_ones((pred.size(0), max_guesses)) / max_guesses
|
76 |
+
return pred_topk, prob_topk
|
77 |
+
|
78 |
+
|
79 |
+
def topkind(
|
80 |
+
max_guesses: int,
|
81 |
+
pred: torch.Tensor,
|
82 |
+
prob: Optional[torch.Tensor] = None,
|
83 |
+
ptr: Optional[torch.Tensor] = None,
|
84 |
+
joint: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
85 |
+
max_guesses = min(max_guesses, pred.size(1))
|
86 |
+
if max_guesses == pred.size(1):
|
87 |
+
if prob is not None:
|
88 |
+
prob = prob / prob.sum(dim=-1, keepdim=True)
|
89 |
+
else:
|
90 |
+
prob = pred.new_ones((pred.size(0), max_guesses)) / max_guesses
|
91 |
+
return pred, prob, None
|
92 |
+
else:
|
93 |
+
if prob is not None:
|
94 |
+
if joint:
|
95 |
+
if ptr is None:
|
96 |
+
inds_topk = torch.topk((prob / prob.sum(dim=-1, keepdim=True)).mean(dim=0, keepdim=True),
|
97 |
+
k=max_guesses, dim=-1, largest=True, sorted=True)[1]
|
98 |
+
inds_topk = inds_topk.repeat(pred.size(0), 1)
|
99 |
+
else:
|
100 |
+
inds_topk = torch.topk(segment_csr(src=prob / prob.sum(dim=-1, keepdim=True), indptr=ptr,
|
101 |
+
reduce='mean'),
|
102 |
+
k=max_guesses, dim=-1, largest=True, sorted=True)[1]
|
103 |
+
inds_topk = gather_csr(src=inds_topk, indptr=ptr)
|
104 |
+
else:
|
105 |
+
inds_topk = torch.topk(prob, k=max_guesses, dim=-1, largest=True, sorted=True)[1]
|
106 |
+
pred_topk = pred[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]
|
107 |
+
prob_topk = prob[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]
|
108 |
+
prob_topk = prob_topk / prob_topk.sum(dim=-1, keepdim=True)
|
109 |
+
else:
|
110 |
+
pred_topk = pred[:, :max_guesses]
|
111 |
+
prob_topk = pred.new_ones((pred.size(0), max_guesses)) / max_guesses
|
112 |
+
return pred_topk, prob_topk, inds_topk
|
113 |
+
|
114 |
+
|
115 |
+
def valid_filter(
|
116 |
+
pred: torch.Tensor,
|
117 |
+
target: torch.Tensor,
|
118 |
+
prob: Optional[torch.Tensor] = None,
|
119 |
+
valid_mask: Optional[torch.Tensor] = None,
|
120 |
+
ptr: Optional[torch.Tensor] = None,
|
121 |
+
keep_invalid_final_step: bool = True) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
122 |
+
torch.Tensor, torch.Tensor]:
|
123 |
+
if valid_mask is None:
|
124 |
+
valid_mask = target.new_ones(target.size()[:-1], dtype=torch.bool)
|
125 |
+
if keep_invalid_final_step:
|
126 |
+
filter_mask = valid_mask.any(dim=-1)
|
127 |
+
else:
|
128 |
+
filter_mask = valid_mask[:, -1]
|
129 |
+
pred = pred[filter_mask]
|
130 |
+
target = target[filter_mask]
|
131 |
+
if prob is not None:
|
132 |
+
prob = prob[filter_mask]
|
133 |
+
valid_mask = valid_mask[filter_mask]
|
134 |
+
if ptr is not None:
|
135 |
+
num_nodes_batch = segment_csr(src=filter_mask.long(), indptr=ptr, reduce='sum')
|
136 |
+
ptr = num_nodes_batch.new_zeros((num_nodes_batch.size(0) + 1,))
|
137 |
+
torch.cumsum(num_nodes_batch, dim=0, out=ptr[1:])
|
138 |
+
else:
|
139 |
+
ptr = target.new_tensor([0, target.size(0)])
|
140 |
+
return pred, target, prob, valid_mask, ptr
|
141 |
+
|
142 |
+
|
143 |
+
def new_batch_nms(pred_trajs, dist_thresh, num_ret_modes=6):
|
144 |
+
"""
|
145 |
+
|
146 |
+
Args:
|
147 |
+
pred_trajs (batch_size, num_modes, num_timestamps, 7)
|
148 |
+
pred_scores (batch_size, num_modes):
|
149 |
+
dist_thresh (float):
|
150 |
+
num_ret_modes (int, optional): Defaults to 6.
|
151 |
+
|
152 |
+
Returns:
|
153 |
+
ret_trajs (batch_size, num_ret_modes, num_timestamps, 5)
|
154 |
+
ret_scores (batch_size, num_ret_modes)
|
155 |
+
ret_idxs (batch_size, num_ret_modes)
|
156 |
+
"""
|
157 |
+
batch_size, num_modes, num_timestamps, num_feat_dim = pred_trajs.shape
|
158 |
+
pred_goals = pred_trajs[:, :, -1, :]
|
159 |
+
dist = (pred_goals[:, :, None, 0:2] - pred_goals[:, None, :, 0:2]).norm(dim=-1)
|
160 |
+
nearby_neighbor = dist < dist_thresh
|
161 |
+
pred_scores = nearby_neighbor.sum(dim=-1) / num_modes
|
162 |
+
|
163 |
+
sorted_idxs = pred_scores.argsort(dim=-1, descending=True)
|
164 |
+
bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes)
|
165 |
+
sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs]
|
166 |
+
sorted_pred_trajs = pred_trajs[bs_idxs_full, sorted_idxs] # (batch_size, num_modes, num_timestamps, 7)
|
167 |
+
sorted_pred_goals = sorted_pred_trajs[:, :, -1, :] # (batch_size, num_modes, 7)
|
168 |
+
|
169 |
+
dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)
|
170 |
+
point_cover_mask = (dist < dist_thresh)
|
171 |
+
|
172 |
+
point_val = sorted_pred_scores.clone() # (batch_size, N)
|
173 |
+
point_val_selected = torch.zeros_like(point_val) # (batch_size, N)
|
174 |
+
|
175 |
+
ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long()
|
176 |
+
ret_trajs = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes, num_timestamps, num_feat_dim)
|
177 |
+
ret_scores = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes)
|
178 |
+
bs_idxs = torch.arange(batch_size).type_as(ret_idxs)
|
179 |
+
|
180 |
+
for k in range(num_ret_modes):
|
181 |
+
cur_idx = point_val.argmax(dim=-1) # (batch_size)
|
182 |
+
ret_idxs[:, k] = cur_idx
|
183 |
+
|
184 |
+
new_cover_mask = point_cover_mask[bs_idxs, cur_idx] # (batch_size, N)
|
185 |
+
point_val = point_val * (~new_cover_mask).float() # (batch_size, N)
|
186 |
+
point_val_selected[bs_idxs, cur_idx] = -1
|
187 |
+
point_val += point_val_selected
|
188 |
+
|
189 |
+
ret_trajs[:, k] = sorted_pred_trajs[bs_idxs, cur_idx]
|
190 |
+
ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx]
|
191 |
+
|
192 |
+
bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes)
|
193 |
+
|
194 |
+
ret_idxs = sorted_idxs[bs_idxs, ret_idxs]
|
195 |
+
return ret_trajs, ret_scores, ret_idxs
|
196 |
+
|
197 |
+
|
198 |
+
def batch_nms(pred_trajs, pred_scores,
|
199 |
+
dist_thresh, num_ret_modes=6,
|
200 |
+
mode='static', speed=None):
|
201 |
+
"""
|
202 |
+
|
203 |
+
Args:
|
204 |
+
pred_trajs (batch_size, num_modes, num_timestamps, 7)
|
205 |
+
pred_scores (batch_size, num_modes):
|
206 |
+
dist_thresh (float):
|
207 |
+
num_ret_modes (int, optional): Defaults to 6.
|
208 |
+
|
209 |
+
Returns:
|
210 |
+
ret_trajs (batch_size, num_ret_modes, num_timestamps, 5)
|
211 |
+
ret_scores (batch_size, num_ret_modes)
|
212 |
+
ret_idxs (batch_size, num_ret_modes)
|
213 |
+
"""
|
214 |
+
batch_size, num_modes, num_timestamps, num_feat_dim = pred_trajs.shape
|
215 |
+
|
216 |
+
sorted_idxs = pred_scores.argsort(dim=-1, descending=True)
|
217 |
+
bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes)
|
218 |
+
sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs]
|
219 |
+
sorted_pred_trajs = pred_trajs[bs_idxs_full, sorted_idxs] # (batch_size, num_modes, num_timestamps, 7)
|
220 |
+
sorted_pred_goals = sorted_pred_trajs[:, :, -1, :] # (batch_size, num_modes, 7)
|
221 |
+
|
222 |
+
if mode == "speed":
|
223 |
+
scale = torch.ones(batch_size).to(sorted_pred_goals.device)
|
224 |
+
lon_dist_thresh = 4 * scale
|
225 |
+
lat_dist_thresh = 0.5 * scale
|
226 |
+
lon_dist = (sorted_pred_goals[:, :, None, [0]] - sorted_pred_goals[:, None, :, [0]]).norm(dim=-1)
|
227 |
+
lat_dist = (sorted_pred_goals[:, :, None, [1]] - sorted_pred_goals[:, None, :, [1]]).norm(dim=-1)
|
228 |
+
point_cover_mask = (lon_dist < lon_dist_thresh[:, None, None]) & (lat_dist < lat_dist_thresh[:, None, None])
|
229 |
+
else:
|
230 |
+
dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)
|
231 |
+
point_cover_mask = (dist < dist_thresh)
|
232 |
+
|
233 |
+
point_val = sorted_pred_scores.clone() # (batch_size, N)
|
234 |
+
point_val_selected = torch.zeros_like(point_val) # (batch_size, N)
|
235 |
+
|
236 |
+
ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long()
|
237 |
+
ret_trajs = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes, num_timestamps, num_feat_dim)
|
238 |
+
ret_scores = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes)
|
239 |
+
bs_idxs = torch.arange(batch_size).type_as(ret_idxs)
|
240 |
+
|
241 |
+
for k in range(num_ret_modes):
|
242 |
+
cur_idx = point_val.argmax(dim=-1) # (batch_size)
|
243 |
+
ret_idxs[:, k] = cur_idx
|
244 |
+
|
245 |
+
new_cover_mask = point_cover_mask[bs_idxs, cur_idx] # (batch_size, N)
|
246 |
+
point_val = point_val * (~new_cover_mask).float() # (batch_size, N)
|
247 |
+
point_val_selected[bs_idxs, cur_idx] = -1
|
248 |
+
point_val += point_val_selected
|
249 |
+
|
250 |
+
ret_trajs[:, k] = sorted_pred_trajs[bs_idxs, cur_idx]
|
251 |
+
ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx]
|
252 |
+
|
253 |
+
bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes)
|
254 |
+
|
255 |
+
ret_idxs = sorted_idxs[bs_idxs, ret_idxs]
|
256 |
+
return ret_trajs, ret_scores, ret_idxs
|
257 |
+
|
258 |
+
|
259 |
+
def batch_nms_token(pred_trajs, pred_scores,
|
260 |
+
dist_thresh, num_ret_modes=6,
|
261 |
+
mode='static', speed=None):
|
262 |
+
"""
|
263 |
+
Args:
|
264 |
+
pred_trajs (batch_size, num_modes, num_timestamps, 7)
|
265 |
+
pred_scores (batch_size, num_modes):
|
266 |
+
dist_thresh (float):
|
267 |
+
num_ret_modes (int, optional): Defaults to 6.
|
268 |
+
|
269 |
+
Returns:
|
270 |
+
ret_trajs (batch_size, num_ret_modes, num_timestamps, 5)
|
271 |
+
ret_scores (batch_size, num_ret_modes)
|
272 |
+
ret_idxs (batch_size, num_ret_modes)
|
273 |
+
"""
|
274 |
+
batch_size, num_modes, num_feat_dim = pred_trajs.shape
|
275 |
+
|
276 |
+
sorted_idxs = pred_scores.argsort(dim=-1, descending=True)
|
277 |
+
bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes)
|
278 |
+
sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs]
|
279 |
+
sorted_pred_goals = pred_trajs[bs_idxs_full, sorted_idxs] # (batch_size, num_modes, num_timestamps, 7)
|
280 |
+
|
281 |
+
if mode == "nearby":
|
282 |
+
dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)
|
283 |
+
values, indices = torch.topk(dist, 5, dim=-1, largest=False)
|
284 |
+
thresh_hold = values[..., -1]
|
285 |
+
point_cover_mask = dist < thresh_hold[..., None]
|
286 |
+
else:
|
287 |
+
dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)
|
288 |
+
point_cover_mask = (dist < dist_thresh)
|
289 |
+
|
290 |
+
point_val = sorted_pred_scores.clone() # (batch_size, N)
|
291 |
+
point_val_selected = torch.zeros_like(point_val) # (batch_size, N)
|
292 |
+
|
293 |
+
ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long()
|
294 |
+
ret_goals = sorted_pred_goals.new_zeros(batch_size, num_ret_modes, num_feat_dim)
|
295 |
+
ret_scores = sorted_pred_goals.new_zeros(batch_size, num_ret_modes)
|
296 |
+
bs_idxs = torch.arange(batch_size).type_as(ret_idxs)
|
297 |
+
|
298 |
+
for k in range(num_ret_modes):
|
299 |
+
cur_idx = point_val.argmax(dim=-1) # (batch_size)
|
300 |
+
ret_idxs[:, k] = cur_idx
|
301 |
+
|
302 |
+
new_cover_mask = point_cover_mask[bs_idxs, cur_idx] # (batch_size, N)
|
303 |
+
point_val = point_val * (~new_cover_mask).float() # (batch_size, N)
|
304 |
+
point_val_selected[bs_idxs, cur_idx] = -1
|
305 |
+
point_val += point_val_selected
|
306 |
+
|
307 |
+
ret_goals[:, k] = sorted_pred_goals[bs_idxs, cur_idx]
|
308 |
+
ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx]
|
309 |
+
|
310 |
+
bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes)
|
311 |
+
|
312 |
+
ret_idxs = sorted_idxs[bs_idxs, ret_idxs]
|
313 |
+
return ret_goals, ret_scores, ret_idxs
|
314 |
+
|
315 |
+
|
316 |
+
class TokenCls(Metric):
|
317 |
+
|
318 |
+
def __init__(self,
|
319 |
+
max_guesses: int = 6,
|
320 |
+
**kwargs) -> None:
|
321 |
+
super(TokenCls, self).__init__(**kwargs)
|
322 |
+
self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
|
323 |
+
self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
|
324 |
+
self.max_guesses = max_guesses
|
325 |
+
|
326 |
+
def update(self,
|
327 |
+
pred: torch.Tensor,
|
328 |
+
target: torch.Tensor,
|
329 |
+
valid_mask: Optional[torch.Tensor] = None) -> None:
|
330 |
+
target = target[..., None]
|
331 |
+
acc = (pred[:, :self.max_guesses] == target).any(dim=1) * valid_mask
|
332 |
+
self.sum += acc.sum()
|
333 |
+
self.count += valid_mask.sum()
|
334 |
+
|
335 |
+
def compute(self) -> torch.Tensor:
|
336 |
+
return self.sum / self.count
|
337 |
+
|
338 |
+
|
339 |
+
class minMultiFDE(Metric):
|
340 |
+
|
341 |
+
def __init__(self,
|
342 |
+
max_guesses: int = 6,
|
343 |
+
**kwargs) -> None:
|
344 |
+
super(minMultiFDE, self).__init__(**kwargs)
|
345 |
+
self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
|
346 |
+
self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
|
347 |
+
self.max_guesses = max_guesses
|
348 |
+
|
349 |
+
def update(self,
|
350 |
+
pred: torch.Tensor,
|
351 |
+
target: torch.Tensor,
|
352 |
+
prob: Optional[torch.Tensor] = None,
|
353 |
+
valid_mask: Optional[torch.Tensor] = None,
|
354 |
+
keep_invalid_final_step: bool = True) -> None:
|
355 |
+
pred, target, prob, valid_mask, _ = valid_filter(pred, target, prob, valid_mask, None, keep_invalid_final_step)
|
356 |
+
pred_topk, _ = topk(self.max_guesses, pred, prob)
|
357 |
+
inds_last = (valid_mask * torch.arange(1, valid_mask.size(-1) + 1, device=self.device)).argmax(dim=-1)
|
358 |
+
self.sum += torch.norm(pred_topk[torch.arange(pred.size(0)), :, inds_last] -
|
359 |
+
target[torch.arange(pred.size(0)), inds_last].unsqueeze(-2),
|
360 |
+
p=2, dim=-1).min(dim=-1)[0].sum()
|
361 |
+
self.count += pred.size(0)
|
362 |
+
|
363 |
+
def compute(self) -> torch.Tensor:
|
364 |
+
return self.sum / self.count
|
365 |
+
|
366 |
+
|
367 |
+
class minFDE(Metric):
|
368 |
+
|
369 |
+
def __init__(self,
|
370 |
+
max_guesses: int = 6,
|
371 |
+
**kwargs) -> None:
|
372 |
+
super(minFDE, self).__init__(**kwargs)
|
373 |
+
self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
|
374 |
+
self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
|
375 |
+
self.max_guesses = max_guesses
|
376 |
+
self.eval_timestep = 70
|
377 |
+
|
378 |
+
def update(self,
|
379 |
+
pred: torch.Tensor,
|
380 |
+
target: torch.Tensor,
|
381 |
+
prob: Optional[torch.Tensor] = None,
|
382 |
+
valid_mask: Optional[torch.Tensor] = None,
|
383 |
+
keep_invalid_final_step: bool = True) -> None:
|
384 |
+
eval_timestep = min(self.eval_timestep, pred.shape[1]) - 1
|
385 |
+
self.sum += ((torch.norm(pred[:, eval_timestep-1:eval_timestep] - target[:, eval_timestep-1:eval_timestep], p=2, dim=-1) *
|
386 |
+
valid_mask[:, eval_timestep-1].unsqueeze(1)).sum(dim=-1)).sum()
|
387 |
+
self.count += valid_mask[:, eval_timestep-1].sum()
|
388 |
+
|
389 |
+
def compute(self) -> torch.Tensor:
|
390 |
+
return self.sum / self.count
|
391 |
+
|
392 |
+
|
393 |
+
class minMultiADE(Metric):
|
394 |
+
|
395 |
+
def __init__(self,
|
396 |
+
max_guesses: int = 6,
|
397 |
+
**kwargs) -> None:
|
398 |
+
super(minMultiADE, self).__init__(**kwargs)
|
399 |
+
self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
|
400 |
+
self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
|
401 |
+
self.max_guesses = max_guesses
|
402 |
+
|
403 |
+
def update(self,
|
404 |
+
pred: torch.Tensor,
|
405 |
+
target: torch.Tensor,
|
406 |
+
prob: Optional[torch.Tensor] = None,
|
407 |
+
valid_mask: Optional[torch.Tensor] = None,
|
408 |
+
keep_invalid_final_step: bool = True,
|
409 |
+
min_criterion: str = 'FDE') -> None:
|
410 |
+
pred, target, prob, valid_mask, _ = valid_filter(pred, target, prob, valid_mask, None, keep_invalid_final_step)
|
411 |
+
pred_topk, _ = topk(self.max_guesses, pred, prob)
|
412 |
+
if min_criterion == 'FDE':
|
413 |
+
inds_last = (valid_mask * torch.arange(1, valid_mask.size(-1) + 1, device=self.device)).argmax(dim=-1)
|
414 |
+
inds_best = torch.norm(
|
415 |
+
pred_topk[torch.arange(pred.size(0)), :, inds_last] -
|
416 |
+
target[torch.arange(pred.size(0)), inds_last].unsqueeze(-2), p=2, dim=-1).argmin(dim=-1)
|
417 |
+
self.sum += ((torch.norm(pred_topk[torch.arange(pred.size(0)), inds_best] - target, p=2, dim=-1) *
|
418 |
+
valid_mask).sum(dim=-1) / valid_mask.sum(dim=-1)).sum()
|
419 |
+
elif min_criterion == 'ADE':
|
420 |
+
self.sum += ((torch.norm(pred_topk - target.unsqueeze(1), p=2, dim=-1) *
|
421 |
+
valid_mask.unsqueeze(1)).sum(dim=-1).min(dim=-1)[0] / valid_mask.sum(dim=-1)).sum()
|
422 |
+
else:
|
423 |
+
raise ValueError('{} is not a valid criterion'.format(min_criterion))
|
424 |
+
self.count += pred.size(0)
|
425 |
+
|
426 |
+
def compute(self) -> torch.Tensor:
|
427 |
+
return self.sum / self.count
|
428 |
+
|
429 |
+
|
430 |
+
class minADE(Metric):
|
431 |
+
|
432 |
+
def __init__(self,
|
433 |
+
max_guesses: int = 6,
|
434 |
+
**kwargs) -> None:
|
435 |
+
super(minADE, self).__init__(**kwargs)
|
436 |
+
self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
|
437 |
+
self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
|
438 |
+
self.max_guesses = max_guesses
|
439 |
+
self.eval_timestep = 70
|
440 |
+
|
441 |
+
def update(self,
|
442 |
+
pred: torch.Tensor,
|
443 |
+
target: torch.Tensor,
|
444 |
+
prob: Optional[torch.Tensor] = None,
|
445 |
+
valid_mask: Optional[torch.Tensor] = None,
|
446 |
+
keep_invalid_final_step: bool = True,
|
447 |
+
min_criterion: str = 'ADE') -> None:
|
448 |
+
# pred, target, prob, valid_mask, _ = valid_filter(pred, target, prob, valid_mask, None, keep_invalid_final_step)
|
449 |
+
# pred_topk, _ = topk(self.max_guesses, pred, prob)
|
450 |
+
# if min_criterion == 'FDE':
|
451 |
+
# inds_last = (valid_mask * torch.arange(1, valid_mask.size(-1) + 1, device=self.device)).argmax(dim=-1)
|
452 |
+
# inds_best = torch.norm(
|
453 |
+
# pred[torch.arange(pred.size(0)), :, inds_last] -
|
454 |
+
# target[torch.arange(pred.size(0)), inds_last].unsqueeze(-2), p=2, dim=-1).argmin(dim=-1)
|
455 |
+
# self.sum += ((torch.norm(pred[torch.arange(pred.size(0)), inds_best] - target, p=2, dim=-1) *
|
456 |
+
# valid_mask).sum(dim=-1) / valid_mask.sum(dim=-1)).sum()
|
457 |
+
# elif min_criterion == 'ADE':
|
458 |
+
# self.sum += ((torch.norm(pred - target.unsqueeze(1), p=2, dim=-1) *
|
459 |
+
# valid_mask.unsqueeze(1)).sum(dim=-1).min(dim=-1)[0] / valid_mask.sum(dim=-1)).sum()
|
460 |
+
# else:
|
461 |
+
# raise ValueError('{} is not a valid criterion'.format(min_criterion))
|
462 |
+
eval_timestep = min(self.eval_timestep, pred.shape[1])
|
463 |
+
self.sum += ((torch.norm(pred[:, :eval_timestep] - target[:, :eval_timestep], p=2, dim=-1) * valid_mask[:, :eval_timestep]).sum(dim=-1) / pred.shape[1]).sum()
|
464 |
+
self.count += valid_mask[:, :eval_timestep].any(dim=-1).sum()
|
465 |
+
|
466 |
+
def compute(self) -> torch.Tensor:
|
467 |
+
return self.sum / self.count
|
468 |
+
|
469 |
+
|
470 |
+
class AverageMeter(Metric):
|
471 |
+
|
472 |
+
def __init__(self, **kwargs) -> None:
|
473 |
+
super(AverageMeter, self).__init__(**kwargs)
|
474 |
+
self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
|
475 |
+
self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
|
476 |
+
|
477 |
+
def update(self, val: torch.Tensor) -> None:
|
478 |
+
self.sum += val.sum()
|
479 |
+
self.count += val.numel()
|
480 |
+
|
481 |
+
def compute(self) -> torch.Tensor:
|
482 |
+
return self.sum / self.count
|
483 |
+
|
484 |
+
|
485 |
+
class StateAccuracy(Metric):
|
486 |
+
|
487 |
+
def __init__(self, state_token: Dict[str, int], **kwargs) -> None:
|
488 |
+
super().__init__(**kwargs)
|
489 |
+
self.invalid_state = int(state_token['invalid'])
|
490 |
+
self.valid_state = int(state_token['valid'])
|
491 |
+
self.enter_state = int(state_token['enter'])
|
492 |
+
self.exit_state = int(state_token['exit'])
|
493 |
+
|
494 |
+
self.add_state('valid', default=torch.tensor(0), dist_reduce_fx='sum')
|
495 |
+
self.add_state('valid_count', default=torch.tensor(0), dist_reduce_fx='sum')
|
496 |
+
self.add_state('invalid', default=torch.tensor(0), dist_reduce_fx='sum')
|
497 |
+
self.add_state('invalid_count', default=torch.tensor(0), dist_reduce_fx='sum')
|
498 |
+
|
499 |
+
def update(self,
|
500 |
+
state_idx: torch.Tensor,
|
501 |
+
valid_mask: Optional[torch.Tensor] = None) -> None:
|
502 |
+
|
503 |
+
num_agent, num_step = state_idx.shape
|
504 |
+
|
505 |
+
# check the evaluation outputs
|
506 |
+
for a in range(num_agent):
|
507 |
+
bos_idx = torch.where(state_idx[a] == self.enter_state)[0]
|
508 |
+
eos_idx = torch.where(state_idx[a] == self.exit_state)[0]
|
509 |
+
bos = 0
|
510 |
+
eos = num_step - 1
|
511 |
+
if len(bos_idx) > 0:
|
512 |
+
bos = bos_idx[0]
|
513 |
+
self.invalid += (state_idx[a, :bos] == self.invalid_state).sum()
|
514 |
+
self.invalid_count += len(state_idx[a, :bos])
|
515 |
+
if len(eos_idx) > 0:
|
516 |
+
eos = eos_idx[0]
|
517 |
+
self.invalid += (state_idx[a, eos + 1:] == self.invalid_state).sum()
|
518 |
+
self.invalid_count += len(state_idx[a, eos + 1:])
|
519 |
+
self.valid += (state_idx[a, bos + 1 : eos] == self.valid_state).sum()
|
520 |
+
self.valid_count += len(state_idx[a, bos + 1 : eos])
|
521 |
+
|
522 |
+
# check the tokenization
|
523 |
+
if valid_mask is not None:
|
524 |
+
|
525 |
+
state_idx = state_idx.roll(shifts=1, dims=1)
|
526 |
+
|
527 |
+
for a in range(num_agent):
|
528 |
+
bos_idx = torch.where(state_idx[a] == self.enter_state)[0]
|
529 |
+
eos_idx = torch.where(state_idx[a] == self.exit_state)[0]
|
530 |
+
bos = 0
|
531 |
+
eos = num_step - 1
|
532 |
+
if len(bos_idx) > 0:
|
533 |
+
bos = bos_idx[0]
|
534 |
+
self.invalid += (valid_mask[a, :bos] == 0).sum()
|
535 |
+
self.invalid_count += len(valid_mask[a, :bos])
|
536 |
+
if len(eos_idx) > 0:
|
537 |
+
eos = eos_idx[-1]
|
538 |
+
self.invalid += (valid_mask[a, eos + 1:] != 0).sum()
|
539 |
+
self.invalid_count += len(valid_mask[a, eos + 1:])
|
540 |
+
self.invalid += (((state_idx[a, bos : eos + 1] > 0) != valid_mask[a, bos : eos + 1])[valid_mask[a, bos : eos + 1] == 0]).sum()
|
541 |
+
self.invalid_count += (valid_mask[a, bos : eos + 1] == 0).sum()
|
542 |
+
self.valid += (((state_idx[a, bos : eos + 1] > 0) != valid_mask[a, bos : eos + 1])[valid_mask[a, bos : eos + 1] == 1]).sum()
|
543 |
+
self.valid_count += (valid_mask[a, bos : eos + 1] == 1).sum()
|
544 |
+
|
545 |
+
def compute(self) -> Dict[str, torch.Tensor]:
|
546 |
+
return {'valid': self.valid / self.valid_count,
|
547 |
+
'invalid': self.invalid / self.invalid_count,
|
548 |
+
}
|
549 |
+
|
550 |
+
def __repr__(self):
|
551 |
+
head = "Results of " + self.__class__.__name__
|
552 |
+
results = self.compute()
|
553 |
+
body = [
|
554 |
+
"valid: {}".format(results['valid']),
|
555 |
+
"invalid: {}".format(results['invalid']),
|
556 |
+
]
|
557 |
+
_repr_indent = 4
|
558 |
+
lines = [head] + [" " * _repr_indent + line for line in body]
|
559 |
+
return "\n".join(lines)
|
560 |
+
|
561 |
+
|
562 |
+
class GridOverlapRate(Metric):
|
563 |
+
|
564 |
+
def __init__(self, num_step, state_token, seed_size, **kwargs) -> None:
|
565 |
+
super().__init__(**kwargs)
|
566 |
+
self.num_step = num_step
|
567 |
+
self.enter_state = int(state_token['enter'])
|
568 |
+
self.seed_size = seed_size
|
569 |
+
self.add_state('num_overlap_t', default=torch.zeros(num_step).long(), dist_reduce_fx='sum')
|
570 |
+
self.add_state('num_insert_agent_t', default=torch.zeros(num_step).long(), dist_reduce_fx='sum')
|
571 |
+
self.add_state('num_total_agent_t', default=torch.zeros(num_step).long(), dist_reduce_fx='sum')
|
572 |
+
self.add_state('num_exceed_seed_t', default=torch.zeros(num_step).long(), dist_reduce_fx='sum')
|
573 |
+
|
574 |
+
def update(self,
|
575 |
+
state_token: torch.Tensor,
|
576 |
+
grid_index: torch.Tensor) -> None:
|
577 |
+
|
578 |
+
for t in range(self.num_step):
|
579 |
+
inrange_mask_t = grid_index[:, t] != -1
|
580 |
+
insert_mask_t = (state_token[:, t] == self.enter_state) & inrange_mask_t
|
581 |
+
self.num_total_agent_t[t] += inrange_mask_t.sum()
|
582 |
+
self.num_insert_agent_t[t] += insert_mask_t.sum()
|
583 |
+
self.num_exceed_seed_t[t] += int(insert_mask_t.sum() >= self.seed_size)
|
584 |
+
|
585 |
+
occupied_grids = set(grid_index[:, t][(grid_index[:, t] != -1) & (state_token[:, t] != self.enter_state)].tolist())
|
586 |
+
to_inserted_grids = grid_index[:, t][(grid_index[:, t] != -1) & (state_token[:, t] == self.enter_state)].tolist()
|
587 |
+
while to_inserted_grids:
|
588 |
+
grid_index_t_i = to_inserted_grids.pop()
|
589 |
+
if grid_index_t_i in occupied_grids:
|
590 |
+
self.num_overlap_t[t] += 1
|
591 |
+
occupied_grids.add(grid_index_t_i)
|
592 |
+
|
593 |
+
def compute(self) -> Dict[str, torch.Tensor]:
|
594 |
+
overlap_rate_t = self.num_overlap_t / self.num_insert_agent_t
|
595 |
+
overlap_rate_t.nan_to_num_()
|
596 |
+
return {'num_overlap_t': self.num_overlap_t,
|
597 |
+
'num_insert_agent_t': self.num_insert_agent_t,
|
598 |
+
'num_total_agent_t': self.num_total_agent_t,
|
599 |
+
'overlap_rate_t': overlap_rate_t,
|
600 |
+
'num_exceed_seed_t': self.num_exceed_seed_t,
|
601 |
+
}
|
602 |
+
|
603 |
+
def __repr__(self):
|
604 |
+
head = "Results of " + self.__class__.__name__
|
605 |
+
results = self.compute()
|
606 |
+
body = [
|
607 |
+
"num_overlap_t: {}".format(results['num_overlap_t'].tolist()),
|
608 |
+
"num_insert_agent_t: {}".format(results['num_insert_agent_t'].tolist()),
|
609 |
+
"num_total_agent_t: {}".format(results['num_total_agent_t'].tolist()),
|
610 |
+
"overlap_rate_t: {}".format(results['overlap_rate_t'].tolist()),
|
611 |
+
"num_exceed_seed_t: {}".format(results['num_exceed_seed_t'].tolist()),
|
612 |
+
]
|
613 |
+
_repr_indent = 4
|
614 |
+
lines = [head] + [" " * _repr_indent + line for line in body]
|
615 |
+
return "\n".join(lines)
|
616 |
+
|
617 |
+
|
618 |
+
class NumInsertAccuracy(Metric):
|
619 |
+
|
620 |
+
def __init__(self, state_token: Dict[str, int], **kwargs) -> None:
|
621 |
+
super().__init__(**kwargs)
|
622 |
+
self.invalid_state = int(state_token['invalid'])
|
623 |
+
self.valid_state = int(state_token['valid'])
|
624 |
+
self.enter_state = int(state_token['enter'])
|
625 |
+
self.exit_state = int(state_token['exit'])
|
626 |
+
|
627 |
+
self.add_state('valid', default=torch.tensor(0), dist_reduce_fx='sum')
|
628 |
+
self.add_state('valid_count', default=torch.tensor(0), dist_reduce_fx='sum')
|
629 |
+
self.add_state('invalid', default=torch.tensor(0), dist_reduce_fx='sum')
|
630 |
+
self.add_state('invalid_count', default=torch.tensor(0), dist_reduce_fx='sum')
|
631 |
+
|
632 |
+
def update(self,
|
633 |
+
state_idx: torch.Tensor,
|
634 |
+
valid_mask: Optional[torch.Tensor] = None) -> None:
|
635 |
+
|
636 |
+
num_agent, num_step = state_idx.shape
|
637 |
+
|
638 |
+
# check the evaluation outputs
|
639 |
+
for a in range(num_agent):
|
640 |
+
bos_idx = torch.where(state_idx[a] == self.enter_state)[0]
|
641 |
+
eos_idx = torch.where(state_idx[a] == self.exit_state)[0]
|
642 |
+
bos = 0
|
643 |
+
eos = num_step - 1
|
644 |
+
if len(bos_idx) > 0:
|
645 |
+
bos = bos_idx[0]
|
646 |
+
self.invalid += (state_idx[a, :bos] == self.invalid_state).sum()
|
647 |
+
self.invalid_count += len(state_idx[a, :bos])
|
648 |
+
if len(eos_idx) > 0:
|
649 |
+
eos = eos_idx[0]
|
650 |
+
self.invalid += (state_idx[a, eos + 1:] == self.invalid_state).sum()
|
651 |
+
self.invalid_count += len(state_idx[a, eos + 1:])
|
652 |
+
self.valid += (state_idx[a, bos + 1 : eos] == self.valid_state).sum()
|
653 |
+
self.valid_count += len(state_idx[a, bos + 1 : eos])
|
654 |
+
|
655 |
+
# check the tokenization
|
656 |
+
if valid_mask is not None:
|
657 |
+
|
658 |
+
state_idx = state_idx.roll(shifts=1, dims=1)
|
659 |
+
|
660 |
+
for a in range(num_agent):
|
661 |
+
bos_idx = torch.where(state_idx[a] == self.enter_state)[0]
|
662 |
+
eos_idx = torch.where(state_idx[a] == self.exit_state)[0]
|
663 |
+
bos = 0
|
664 |
+
eos = num_step - 1
|
665 |
+
if len(bos_idx) > 0:
|
666 |
+
bos = bos_idx[0]
|
667 |
+
self.invalid += (valid_mask[a, :bos] == 0).sum()
|
668 |
+
self.invalid_count += len(valid_mask[a, :bos])
|
669 |
+
if len(eos_idx) > 0:
|
670 |
+
eos = eos_idx[-1]
|
671 |
+
self.invalid += (valid_mask[a, eos + 1:] != 0).sum()
|
672 |
+
self.invalid_count += len(valid_mask[a, eos + 1:])
|
673 |
+
self.invalid += (((state_idx[a, bos : eos + 1] > 0) != valid_mask[a, bos : eos + 1])[valid_mask[a, bos : eos + 1] == 0]).sum()
|
674 |
+
self.invalid_count += (valid_mask[a, bos : eos + 1] == 0).sum()
|
675 |
+
self.valid += (((state_idx[a, bos : eos + 1] > 0) != valid_mask[a, bos : eos + 1])[valid_mask[a, bos : eos + 1] == 1]).sum()
|
676 |
+
self.valid_count += (valid_mask[a, bos : eos + 1] == 1).sum()
|
677 |
+
|
678 |
+
def compute(self) -> Dict[str, torch.Tensor]:
|
679 |
+
return {'valid': self.valid / self.valid_count,
|
680 |
+
'invalid': self.invalid / self.invalid_count,
|
681 |
+
}
|
682 |
+
|
683 |
+
def __repr__(self):
|
684 |
+
head = "Results of " + self.__class__.__name__
|
685 |
+
results = self.compute()
|
686 |
+
body = [
|
687 |
+
"valid: {}".format(results['valid']),
|
688 |
+
"invalid: {}".format(results['invalid']),
|
689 |
+
]
|
690 |
+
_repr_indent = 4
|
691 |
+
lines = [head] + [" " * _repr_indent + line for line in body]
|
692 |
+
return "\n".join(lines)
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/dev/utils/visualization.py
ADDED
@@ -0,0 +1,1145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import pickle
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
import tensorflow as tf
|
7 |
+
import numpy as np
|
8 |
+
import numpy.typing as npt
|
9 |
+
import fnmatch
|
10 |
+
import seaborn as sns
|
11 |
+
import matplotlib.axes as Axes
|
12 |
+
import matplotlib.transforms as mtransforms
|
13 |
+
from PIL import Image
|
14 |
+
from functools import wraps
|
15 |
+
from typing import Sequence, Union, Optional
|
16 |
+
from tqdm import tqdm
|
17 |
+
from typing import List, Literal
|
18 |
+
from argparse import ArgumentParser
|
19 |
+
from scipy.ndimage.filters import gaussian_filter
|
20 |
+
from matplotlib.patches import FancyBboxPatch, Polygon, Rectangle, Circle
|
21 |
+
from matplotlib.collections import LineCollection
|
22 |
+
from torch_geometric.data import HeteroData, Dataset
|
23 |
+
from waymo_open_dataset.protos import scenario_pb2
|
24 |
+
|
25 |
+
from dev.utils.func import CONSOLE
|
26 |
+
from dev.modules.attr_tokenizer import Attr_Tokenizer
|
27 |
+
from dev.datasets.preprocess import TokenProcessor, cal_polygon_contour, AGENT_TYPE
|
28 |
+
from dev.datasets.scalable_dataset import WaymoTargetBuilder
|
29 |
+
|
30 |
+
|
31 |
+
__all__ = ['plot_occ_grid', 'plot_interact_edge', 'plot_map_edge', 'plot_insert_grid', 'plot_binary_map',
|
32 |
+
'plot_map_token', 'plot_prob_seed', 'plot_scenario', 'get_heatmap', 'draw_heatmap', 'plot_val', 'plot_tokenize']
|
33 |
+
|
34 |
+
|
35 |
+
def safe_run(func):
|
36 |
+
|
37 |
+
@wraps(func)
|
38 |
+
def wrapper1(*args, **kwargs):
|
39 |
+
try:
|
40 |
+
return func(*args, **kwargs)
|
41 |
+
except Exception as e:
|
42 |
+
print(e)
|
43 |
+
return
|
44 |
+
|
45 |
+
@wraps(func)
|
46 |
+
def wrapper2(*args, **kwargs):
|
47 |
+
return func(*args, **kwargs)
|
48 |
+
|
49 |
+
if int(os.getenv('DEBUG', 0)):
|
50 |
+
return wrapper2
|
51 |
+
else:
|
52 |
+
return wrapper1
|
53 |
+
|
54 |
+
|
55 |
+
@safe_run
|
56 |
+
def plot_occ_grid(scenario_id, occ, gt_occ=None, save_path='', mode='agent', prefix=''):
|
57 |
+
|
58 |
+
def generate_box_edges(matrix, find_value=1):
|
59 |
+
y, x = np.where(matrix == find_value)
|
60 |
+
edges = []
|
61 |
+
|
62 |
+
for xi, yi in zip(x, y):
|
63 |
+
edges.append([(xi - 0.5, yi - 0.5), (xi + 0.5, yi - 0.5)])
|
64 |
+
edges.append([(xi + 0.5, yi - 0.5), (xi + 0.5, yi + 0.5)])
|
65 |
+
edges.append([(xi + 0.5, yi + 0.5), (xi - 0.5, yi + 0.5)])
|
66 |
+
edges.append([(xi - 0.5, yi + 0.5), (xi - 0.5, yi - 0.5)])
|
67 |
+
|
68 |
+
return edges
|
69 |
+
|
70 |
+
os.makedirs(save_path, exist_ok=True)
|
71 |
+
n = int(math.sqrt(occ.shape[-1]))
|
72 |
+
|
73 |
+
plot_n = 3
|
74 |
+
plot_t = 5
|
75 |
+
|
76 |
+
occ_list = []
|
77 |
+
for i in range(plot_n):
|
78 |
+
for j in range(plot_t):
|
79 |
+
occ_list.append(occ[i, j].reshape(n, n))
|
80 |
+
|
81 |
+
occ_gt_list = []
|
82 |
+
if gt_occ is not None:
|
83 |
+
for i in range(plot_n):
|
84 |
+
for j in range(plot_t):
|
85 |
+
occ_gt_list.append(gt_occ[i, j].reshape(n, n))
|
86 |
+
|
87 |
+
row_labels = [f'n={n}' for n in range(plot_n)]
|
88 |
+
col_labels = [f't={t}' for t in range(plot_t)]
|
89 |
+
|
90 |
+
fig, axes = plt.subplots(plot_n, plot_t, figsize=(9, 6))
|
91 |
+
plt.subplots_adjust(wspace=0.1, hspace=0.1)
|
92 |
+
|
93 |
+
for i, ax in enumerate(axes.flat):
|
94 |
+
# NOTE: do not set vmin and vamx!
|
95 |
+
ax.imshow(occ_list[i], cmap='viridis', interpolation='nearest')
|
96 |
+
ax.axis('off')
|
97 |
+
|
98 |
+
if occ_gt_list:
|
99 |
+
gt_edges = generate_box_edges(occ_gt_list[i])
|
100 |
+
gts = LineCollection(gt_edges, colors='blue', linewidths=0.5)
|
101 |
+
ax.add_collection(gts)
|
102 |
+
insert_edges = generate_box_edges(occ_gt_list[i], find_value=-1)
|
103 |
+
inserts = LineCollection(insert_edges, colors='red', linewidths=0.5)
|
104 |
+
ax.add_collection(inserts)
|
105 |
+
|
106 |
+
ax.add_patch(plt.Rectangle((-0.5, -0.5), occ_list[i].shape[1], occ_list[i].shape[0],
|
107 |
+
linewidth=2, edgecolor='black', facecolor='none'))
|
108 |
+
|
109 |
+
for i, ax in enumerate(axes[:, 0]):
|
110 |
+
ax.annotate(row_labels[i], xy=(-0.1, 0.5), xycoords="axes fraction",
|
111 |
+
fontsize=12, ha="right", va="center", rotation=0)
|
112 |
+
|
113 |
+
for j, ax in enumerate(axes[0, :]):
|
114 |
+
ax.annotate(col_labels[j], xy=(0.5, 1.05), xycoords="axes fraction",
|
115 |
+
fontsize=12, ha="center", va="bottom")
|
116 |
+
|
117 |
+
plt.savefig(os.path.join(save_path, f'{prefix}{scenario_id}_occ_{mode}.png'), dpi=500, bbox_inches='tight')
|
118 |
+
plt.close()
|
119 |
+
|
120 |
+
|
121 |
+
@safe_run
|
122 |
+
def plot_interact_edge(edge_index, scenario_ids, batch_sizes, num_seed, num_step, save_path='interact_edge_map',
|
123 |
+
**kwargs):
|
124 |
+
|
125 |
+
num_batch = len(scenario_ids)
|
126 |
+
batches = torch.cat([
|
127 |
+
torch.arange(num_batch).repeat_interleave(repeats=batch_sizes, dim=0),
|
128 |
+
torch.arange(num_batch).repeat_interleave(repeats=num_seed, dim=0),
|
129 |
+
], dim=0).repeat(num_step).numpy()
|
130 |
+
|
131 |
+
num_agent = batch_sizes.sum() + num_seed * num_batch
|
132 |
+
batch_sizes = torch.nn.functional.pad(batch_sizes, (1, 0), mode='constant', value=0)
|
133 |
+
ptr = torch.cumsum(batch_sizes, dim=0)
|
134 |
+
# assume difference scenarios and different timestep have the same number of seed agents
|
135 |
+
ptr_seed = torch.tensor(np.array([0] + [num_seed] * num_batch), device=ptr.device)
|
136 |
+
|
137 |
+
all_av_index = None
|
138 |
+
if 'av_index' in kwargs:
|
139 |
+
all_av_index = kwargs.pop('av_index').cpu() - ptr[:-1]
|
140 |
+
|
141 |
+
is_bos = np.zeros((batch_sizes.sum(), num_step)).astype(np.bool_)
|
142 |
+
if 'is_bos' in kwargs:
|
143 |
+
is_bos = kwargs.pop('is_bos').cpu().numpy()
|
144 |
+
|
145 |
+
src_index = torch.unique(edge_index[1])
|
146 |
+
for idx, src in enumerate(tqdm(src_index)):
|
147 |
+
|
148 |
+
src_batch = batches[src]
|
149 |
+
|
150 |
+
src_row = src % num_agent
|
151 |
+
if src_row // batch_sizes.sum() > 0:
|
152 |
+
seed_row = src_row % batch_sizes.sum() - ptr_seed[src_batch]
|
153 |
+
src_row = batch_sizes[src_batch + 1] + seed_row
|
154 |
+
else:
|
155 |
+
src_row = src_row - ptr[src_batch]
|
156 |
+
|
157 |
+
src_col = src // (num_agent)
|
158 |
+
src_mask = np.zeros((batch_sizes[src_batch + 1] + num_seed, num_step))
|
159 |
+
src_mask[src_row, src_col] = 1
|
160 |
+
|
161 |
+
tgt_mask = np.zeros((src_mask.shape[0], num_step))
|
162 |
+
tgt_index = edge_index[0, edge_index[1] == src]
|
163 |
+
for tgt in tgt_index:
|
164 |
+
|
165 |
+
tgt_batch = batches[tgt]
|
166 |
+
|
167 |
+
tgt_row = tgt % num_agent
|
168 |
+
if tgt_row // batch_sizes.sum() > 0:
|
169 |
+
seed_row = tgt_row % batch_sizes.sum() - ptr_seed[tgt_batch]
|
170 |
+
tgt_row = batch_sizes[tgt_batch + 1] + seed_row
|
171 |
+
else:
|
172 |
+
tgt_row = tgt_row - ptr[tgt_batch]
|
173 |
+
|
174 |
+
tgt_col = tgt // num_agent
|
175 |
+
tgt_mask[tgt_row, tgt_col] = 1
|
176 |
+
assert tgt_batch == src_batch
|
177 |
+
|
178 |
+
selected_step = tgt_mask.sum(axis=0) > 0
|
179 |
+
if selected_step.sum() > 1:
|
180 |
+
print(f"\nidx={idx}", src.item(), src_row.item(), src_col.item())
|
181 |
+
print(selected_step)
|
182 |
+
print(edge_index[:, edge_index[1] == src].tolist())
|
183 |
+
|
184 |
+
if all_av_index is not None:
|
185 |
+
kwargs['av_index'] = int(all_av_index[src_batch])
|
186 |
+
|
187 |
+
t = kwargs.get('t', src_col)
|
188 |
+
n = kwargs.get('n', 0)
|
189 |
+
is_bos_batch = is_bos[ptr[src_batch] : ptr[src_batch + 1]]
|
190 |
+
plot_binary_map(src_mask, tgt_mask, save_path, suffix=f'_{scenario_ids[src_batch]}_{t:02d}_{n:02d}_{idx:04d}',
|
191 |
+
is_bos=is_bos_batch, **kwargs)
|
192 |
+
|
193 |
+
|
194 |
+
@safe_run
|
195 |
+
def plot_map_edge(edge_index, pos_a, data, save_path='map_edge_map'):
|
196 |
+
|
197 |
+
map_points = data['map_point']['position'][:, :2].cpu().numpy()
|
198 |
+
token_pos = data['pt_token']['position'][:, :2].cpu().numpy()
|
199 |
+
token_heading = data['pt_token']['orientation'].cpu().numpy()
|
200 |
+
num_pt = token_pos.shape[0]
|
201 |
+
|
202 |
+
agent_index = torch.unique(edge_index[1])
|
203 |
+
for i in tqdm(agent_index):
|
204 |
+
xy = pos_a[i].cpu().numpy()
|
205 |
+
pt_index = edge_index[0, edge_index[1] == i].cpu().numpy()
|
206 |
+
pt_index = pt_index % num_pt
|
207 |
+
|
208 |
+
plt.subplots_adjust(left=0.3, right=0.7, top=0.7, bottom=0.3)
|
209 |
+
_, ax = plt.subplots()
|
210 |
+
ax.set_axis_off()
|
211 |
+
|
212 |
+
plot_map_token(ax, map_points, token_pos[pt_index], token_heading[pt_index], colors='blue')
|
213 |
+
|
214 |
+
ax.scatter(xy[0], xy[1], s=0.5, c='red', edgecolors='none')
|
215 |
+
|
216 |
+
os.makedirs(save_path, exist_ok=True)
|
217 |
+
plt.savefig(os.path.join(save_path, f'map_{i}.png'), dpi=600, bbox_inches='tight')
|
218 |
+
plt.close()
|
219 |
+
|
220 |
+
|
221 |
+
def get_heatmap(x, y, prob, s=3, bins=1000):
|
222 |
+
heatmap, xedges, yedges = np.histogram2d(x, y, bins=bins, weights=prob, density=True)
|
223 |
+
|
224 |
+
heatmap = gaussian_filter(heatmap, sigma=s)
|
225 |
+
|
226 |
+
extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
|
227 |
+
return heatmap.T, extent
|
228 |
+
|
229 |
+
|
230 |
+
@safe_run
|
231 |
+
def draw_heatmap(vector, vector_prob, gt_idx):
|
232 |
+
fig, ax = plt.subplots(figsize=(10, 10))
|
233 |
+
vector_prob = vector_prob.cpu().numpy()
|
234 |
+
|
235 |
+
for j in range(vector.shape[0]):
|
236 |
+
if j in gt_idx:
|
237 |
+
color = (0, 0, 1)
|
238 |
+
else:
|
239 |
+
grey_scale = max(0, 0.9 - vector_prob[j])
|
240 |
+
color = (0.9, grey_scale, grey_scale)
|
241 |
+
|
242 |
+
# if lane[j, k, -1] == 0: continue
|
243 |
+
x0, y0, x1, y1, = vector[j, :4]
|
244 |
+
ax.plot((x0, x1), (y0, y1), color=color, linewidth=2)
|
245 |
+
|
246 |
+
return plt
|
247 |
+
|
248 |
+
|
249 |
+
@safe_run
|
250 |
+
def plot_insert_grid(scenario_id, prob, grid, ego_pos, map, save_path='', prefix='', inference=False, indices=None,
|
251 |
+
all_t_in_one=False):
|
252 |
+
|
253 |
+
"""
|
254 |
+
prob: float array of shape (num_step, num_grid)
|
255 |
+
grid: float array of shape (num_grid, 2)
|
256 |
+
"""
|
257 |
+
|
258 |
+
os.makedirs(save_path, exist_ok=True)
|
259 |
+
|
260 |
+
n = int(math.sqrt(prob.shape[1]))
|
261 |
+
|
262 |
+
# grid = grid[:, np.newaxis] + ego_pos[np.newaxis, ...]
|
263 |
+
for t in range(ego_pos.shape[0]):
|
264 |
+
|
265 |
+
plt.subplots_adjust(left=0.3, right=0.7, top=0.7, bottom=0.3)
|
266 |
+
_, ax = plt.subplots()
|
267 |
+
|
268 |
+
# plot probability
|
269 |
+
prob_t = prob[t].reshape(n, n)
|
270 |
+
plt.imshow(prob_t, cmap='viridis', interpolation='nearest')
|
271 |
+
|
272 |
+
if indices is not None:
|
273 |
+
indice = indices[t]
|
274 |
+
|
275 |
+
if isinstance(indice, (int, float, np.int_)):
|
276 |
+
indice = [indice]
|
277 |
+
|
278 |
+
for _indice in indice:
|
279 |
+
if _indice == -1: continue
|
280 |
+
|
281 |
+
row = _indice // n
|
282 |
+
col = _indice % n
|
283 |
+
|
284 |
+
rect = Rectangle((col - 0.5, row - 0.5), 1, 1, edgecolor='red', facecolor='none', lw=2)
|
285 |
+
ax.add_patch(rect)
|
286 |
+
|
287 |
+
ax.grid(False)
|
288 |
+
ax.set_aspect('equal', adjustable='box')
|
289 |
+
|
290 |
+
plt.title('Prob of Rel Position Grid')
|
291 |
+
plt.savefig(os.path.join(save_path, f'{prefix}{scenario_id}_heat_map_{t}.png'), dpi=300, bbox_inches='tight')
|
292 |
+
plt.close()
|
293 |
+
|
294 |
+
if all_t_in_one:
|
295 |
+
break
|
296 |
+
|
297 |
+
|
298 |
+
@safe_run
|
299 |
+
def plot_insert_grid(scenario_id, prob, indices=None, save_path='', prefix='', inference=False):
|
300 |
+
|
301 |
+
"""
|
302 |
+
prob: float array of shape (num_seed, num_step, num_grid)
|
303 |
+
grid: float array of shape (num_grid, 2)
|
304 |
+
"""
|
305 |
+
|
306 |
+
os.makedirs(save_path, exist_ok=True)
|
307 |
+
|
308 |
+
n = int(math.sqrt(prob.shape[-1]))
|
309 |
+
|
310 |
+
plot_n = 3
|
311 |
+
plot_t = 5
|
312 |
+
|
313 |
+
prob_list = []
|
314 |
+
for i in range(plot_n):
|
315 |
+
for j in range(plot_t):
|
316 |
+
prob_list.append(prob[i, j].reshape(n, n))
|
317 |
+
|
318 |
+
indice_list = []
|
319 |
+
if indices is not None:
|
320 |
+
for i in range(plot_n):
|
321 |
+
for j in range(plot_t):
|
322 |
+
indice_list.append(indices[i, j])
|
323 |
+
|
324 |
+
row_labels = [f'n={n}' for n in range(plot_n)]
|
325 |
+
col_labels = [f't={t}' for t in range(plot_t)]
|
326 |
+
|
327 |
+
fig, axes = plt.subplots(plot_n, plot_t, figsize=(9, 6))
|
328 |
+
fig.suptitle('Prob of Insert Position Grid')
|
329 |
+
plt.subplots_adjust(wspace=0.1, hspace=0.1)
|
330 |
+
|
331 |
+
for i, ax in enumerate(axes.flat):
|
332 |
+
ax.imshow(prob_list[i], cmap='viridis', interpolation='nearest')
|
333 |
+
ax.axis('off')
|
334 |
+
|
335 |
+
if indice_list:
|
336 |
+
row = indice_list[i] // n
|
337 |
+
col = indice_list[i] % n
|
338 |
+
rect = Rectangle((col - .5, row - .5), 1, 1, edgecolor='red', facecolor='none', lw=2)
|
339 |
+
ax.add_patch(rect)
|
340 |
+
|
341 |
+
ax.add_patch(plt.Rectangle((-0.5, -0.5), prob_list[i].shape[1], prob_list[i].shape[0],
|
342 |
+
linewidth=2, edgecolor='black', facecolor='none'))
|
343 |
+
|
344 |
+
for i, ax in enumerate(axes[:, 0]):
|
345 |
+
ax.annotate(row_labels[i], xy=(-0.1, 0.5), xycoords="axes fraction",
|
346 |
+
fontsize=12, ha="right", va="center", rotation=0)
|
347 |
+
|
348 |
+
for j, ax in enumerate(axes[0, :]):
|
349 |
+
ax.annotate(col_labels[j], xy=(0.5, 1.05), xycoords="axes fraction",
|
350 |
+
fontsize=12, ha="center", va="bottom")
|
351 |
+
|
352 |
+
ax.grid(False)
|
353 |
+
ax.set_aspect('equal', adjustable='box')
|
354 |
+
|
355 |
+
plt.savefig(os.path.join(save_path, f'{prefix}{scenario_id}_insert_map.png'), dpi=500, bbox_inches='tight')
|
356 |
+
plt.close()
|
357 |
+
|
358 |
+
|
359 |
+
@safe_run
|
360 |
+
def plot_binary_map(src_mask, tgt_mask, save_path='', suffix='', av_index=None, is_bos=None, **kwargs):
|
361 |
+
|
362 |
+
from matplotlib.colors import ListedColormap
|
363 |
+
os.makedirs(save_path, exist_ok=True)
|
364 |
+
|
365 |
+
fig, axes = plt.subplots(1, 2, figsize=(10, 8))
|
366 |
+
|
367 |
+
title = []
|
368 |
+
if kwargs.get('t', None) is not None:
|
369 |
+
t = kwargs['t']
|
370 |
+
title.append(f't={t}')
|
371 |
+
|
372 |
+
if kwargs.get('n', None) is not None:
|
373 |
+
n = kwargs['n']
|
374 |
+
title.append(f'n={n}')
|
375 |
+
|
376 |
+
plt.title(' '.join(title))
|
377 |
+
|
378 |
+
cmap = ListedColormap(['white', 'green'])
|
379 |
+
axes[0].imshow(src_mask, cmap=cmap, interpolation='nearest')
|
380 |
+
|
381 |
+
cmap = ListedColormap(['white', 'orange'])
|
382 |
+
axes[1].imshow(tgt_mask, cmap=cmap, interpolation='nearest')
|
383 |
+
|
384 |
+
if av_index is not None:
|
385 |
+
rect = Rectangle((-0.5, av_index - 0.5), src_mask.shape[1], 1, edgecolor='red', facecolor='none', lw=2)
|
386 |
+
axes[0].add_patch(rect)
|
387 |
+
rect = Rectangle((-0.5, av_index - 0.5), tgt_mask.shape[1], 1, edgecolor='red', facecolor='none', lw=2)
|
388 |
+
axes[1].add_patch(rect)
|
389 |
+
|
390 |
+
if is_bos is not None:
|
391 |
+
rows, cols = np.where(is_bos)
|
392 |
+
for row, col in zip(rows, cols):
|
393 |
+
rect = Rectangle((col - 0.5, row - 0.5), 1, 1, edgecolor='blue', facecolor='none', lw=1)
|
394 |
+
axes[0].add_patch(rect)
|
395 |
+
rect = Rectangle((col - 0.5, row - 0.5), 1, 1, edgecolor='blue', facecolor='none', lw=1)
|
396 |
+
axes[1].add_patch(rect)
|
397 |
+
|
398 |
+
for ax in axes:
|
399 |
+
ax.set_xticks(range(src_mask.shape[1] + 1), minor=False)
|
400 |
+
ax.set_yticks(range(src_mask.shape[0] + 1), minor=False)
|
401 |
+
ax.grid(which='major', color='gray', linestyle='--', linewidth=0.5)
|
402 |
+
|
403 |
+
plt.savefig(os.path.join(save_path, f'map{suffix}.png'), dpi=300, bbox_inches='tight')
|
404 |
+
plt.close()
|
405 |
+
|
406 |
+
|
407 |
+
@safe_run
|
408 |
+
def plot_prob_seed(scenario_id, prob, save_path, prefix='', indices=None):
|
409 |
+
|
410 |
+
os.makedirs(save_path, exist_ok=True)
|
411 |
+
|
412 |
+
plt.figure(figsize=(8, 5))
|
413 |
+
plt.imshow(prob, cmap='viridis', aspect='auto')
|
414 |
+
plt.colorbar()
|
415 |
+
|
416 |
+
plt.title('Seed Probability')
|
417 |
+
|
418 |
+
if indices is not None:
|
419 |
+
|
420 |
+
for col in range(indices.shape[1]):
|
421 |
+
for row in indices[:, col]:
|
422 |
+
|
423 |
+
if row == -1: continue
|
424 |
+
|
425 |
+
rect = Rectangle((col - 0.5, row - 0.5), 1, 1, edgecolor='red', facecolor='none', lw=2)
|
426 |
+
plt.gca().add_patch(rect)
|
427 |
+
|
428 |
+
plt.tight_layout()
|
429 |
+
plt.savefig(os.path.join(save_path, f'{prefix}{scenario_id}_prob_seed.png'), dpi=300, bbox_inches='tight')
|
430 |
+
plt.close()
|
431 |
+
|
432 |
+
|
433 |
+
@safe_run
|
434 |
+
def plot_raw():
|
435 |
+
plt.figure(figsize=(30, 30))
|
436 |
+
plt.rcParams['axes.facecolor']='white'
|
437 |
+
|
438 |
+
data_path = '/u/xiuyu/work/dev4/data/waymo/scenario/training'
|
439 |
+
os.makedirs("data/vis/raw/0/", exist_ok=True)
|
440 |
+
file_list = os.listdir(data_path)
|
441 |
+
|
442 |
+
for cnt_file, file in enumerate(file_list):
|
443 |
+
file_path = os.path.join(data_path, file)
|
444 |
+
dataset = tf.data.TFRecordDataset(file_path, compression_type='')
|
445 |
+
for scenario_idx, data in enumerate(dataset):
|
446 |
+
scenario = scenario_pb2.Scenario()
|
447 |
+
scenario.ParseFromString(bytearray(data.numpy()))
|
448 |
+
tqdm.write(f"scenario id: {scenario.scenario_id}")
|
449 |
+
|
450 |
+
# draw maps
|
451 |
+
for i in range(len(scenario.map_features)):
|
452 |
+
|
453 |
+
# draw lanes
|
454 |
+
if str(scenario.map_features[i].lane) != '':
|
455 |
+
line_x = [z.x for z in scenario.map_features[i].lane.polyline]
|
456 |
+
line_y = [z.y for z in scenario.map_features[i].lane.polyline]
|
457 |
+
plt.scatter(line_x, line_y, c='g', s=5)
|
458 |
+
plt.text(line_x[0], line_y[0], str(scenario.map_features[i].id), fontdict={'family': 'serif', 'size': 20, 'color': 'green'})
|
459 |
+
|
460 |
+
# draw road_edge
|
461 |
+
if str(scenario.map_features[i].road_edge) != '':
|
462 |
+
road_edge_x = [polyline.x for polyline in scenario.map_features[i].road_edge.polyline]
|
463 |
+
road_edge_y = [polyline.y for polyline in scenario.map_features[i].road_edge.polyline]
|
464 |
+
plt.scatter(road_edge_x, road_edge_y)
|
465 |
+
plt.text(road_edge_x[0], road_edge_y[0], scenario.map_features[i].road_edge.type, fontdict={'family': 'serif', 'size': 20, 'color': 'black'})
|
466 |
+
if scenario.map_features[i].road_edge.type == 2:
|
467 |
+
plt.scatter(road_edge_x, road_edge_y, c='k')
|
468 |
+
elif scenario.map_features[i].road_edge.type == 3:
|
469 |
+
plt.scatter(road_edge_x, road_edge_y, c='purple')
|
470 |
+
print(scenario.map_features[i].road_edge)
|
471 |
+
else:
|
472 |
+
plt.scatter(road_edge_x, road_edge_y, c='k')
|
473 |
+
|
474 |
+
# draw road_line
|
475 |
+
if str(scenario.map_features[i].road_line) != '':
|
476 |
+
road_line_x = [j.x for j in scenario.map_features[i].road_line.polyline]
|
477 |
+
road_line_y = [j.y for j in scenario.map_features[i].road_line.polyline]
|
478 |
+
if scenario.map_features[i].road_line.type == 7:
|
479 |
+
plt.plot(road_line_x, road_line_y, c='y')
|
480 |
+
elif scenario.map_features[i].road_line.type == 8:
|
481 |
+
plt.plot(road_line_x, road_line_y, c='y')
|
482 |
+
elif scenario.map_features[i].road_line.type == 6:
|
483 |
+
plt.plot(road_line_x, road_line_y, c='y')
|
484 |
+
elif scenario.map_features[i].road_line.type == 1:
|
485 |
+
for i in range(int(len(road_line_x) / 7)):
|
486 |
+
plt.plot(road_line_x[i * 7 : 5 + i * 7], road_line_y[i * 7 : 5 + i * 7], color='w')
|
487 |
+
elif scenario.map_features[i].road_line.type == 2:
|
488 |
+
plt.plot(road_line_x, road_line_y, c='w')
|
489 |
+
else:
|
490 |
+
plt.plot(road_line_x, road_line_y, c='w')
|
491 |
+
|
492 |
+
# draw tracks
|
493 |
+
scenario_has_invalid_tracks = False
|
494 |
+
for i in range(len(scenario.tracks)):
|
495 |
+
traj_x = [center.center_x for center in scenario.tracks[i].states]
|
496 |
+
traj_y = [center.center_y for center in scenario.tracks[i].states]
|
497 |
+
head = [center.heading for center in scenario.tracks[i].states]
|
498 |
+
valid = [center.valid for center in scenario.tracks[i].states]
|
499 |
+
print(valid)
|
500 |
+
if i == scenario.sdc_track_index:
|
501 |
+
plt.scatter(traj_x[0], traj_y[0], s=140, c='r', marker='s')
|
502 |
+
plt.scatter([x for x, v in zip(traj_x, valid) if v],
|
503 |
+
[y for y, v in zip(traj_y, valid) if v], s=14, c='r')
|
504 |
+
plt.scatter([x for x, v in zip(traj_x, valid) if not v],
|
505 |
+
[y for y, v in zip(traj_y, valid) if not v], s=14, c='m')
|
506 |
+
else:
|
507 |
+
plt.scatter(traj_x[0], traj_y[0], s=140, c='k', marker='s')
|
508 |
+
plt.scatter([x for x, v in zip(traj_x, valid) if v],
|
509 |
+
[y for y, v in zip(traj_y, valid) if v], s=14, c='b')
|
510 |
+
plt.scatter([x for x, v in zip(traj_x, valid) if not v],
|
511 |
+
[y for y, v in zip(traj_y, valid) if not v], s=14, c='m')
|
512 |
+
if valid.count(False) > 0:
|
513 |
+
scenario_has_invalid_tracks = True
|
514 |
+
if scenario_has_invalid_tracks:
|
515 |
+
plt.savefig(f"scenario_{scenario_idx}_{scenario.scenario_id}.png")
|
516 |
+
plt.clf()
|
517 |
+
breakpoint()
|
518 |
+
break
|
519 |
+
|
520 |
+
|
521 |
+
colors = [
|
522 |
+
('#1f77b4', '#1a5a8a'), # blue
|
523 |
+
('#2ca02c', '#217721'), # green
|
524 |
+
('#ff7f0e', '#cc660b'), # orange
|
525 |
+
('#9467bd', '#6f4a91'), # purple
|
526 |
+
('#d62728', '#a31d1d'), # red
|
527 |
+
('#000000', '#000000'), # black
|
528 |
+
]
|
529 |
+
|
530 |
+
@safe_run
|
531 |
+
def plot_gif():
|
532 |
+
data_path = "/u/xiuyu/work/dev4/data/waymo_processed/training"
|
533 |
+
os.makedirs("data/vis/processed/0/gif", exist_ok=True)
|
534 |
+
file_list = os.listdir(data_path)
|
535 |
+
|
536 |
+
for scenario_idx, file in tqdm(enumerate(file_list), leave=False, desc="Scenario"):
|
537 |
+
|
538 |
+
fig, ax = plt.subplots()
|
539 |
+
ax.set_axis_off()
|
540 |
+
|
541 |
+
file_path = os.path.join(data_path, file)
|
542 |
+
data = pickle.load(open(file_path, "rb"))
|
543 |
+
scenario_id = data['scenario_id']
|
544 |
+
|
545 |
+
save_path = os.path.join("data/vis/processed/0/gif",
|
546 |
+
f"scenario_{scenario_idx}_{scenario_id}.gif")
|
547 |
+
if os.path.exists(save_path):
|
548 |
+
tqdm.write(f"Skipped {save_path}.")
|
549 |
+
continue
|
550 |
+
|
551 |
+
# draw maps
|
552 |
+
ax.scatter(data['map_point']['position'][:, 0],
|
553 |
+
data['map_point']['position'][:, 1], s=0.2, c='black', edgecolors='none')
|
554 |
+
|
555 |
+
# draw agents
|
556 |
+
agent_data = data['agent']
|
557 |
+
av_index = agent_data['av_index']
|
558 |
+
position = agent_data['position'] # (num_agent, 91, 3)
|
559 |
+
heading = agent_data['heading'] # (num_agent, 91)
|
560 |
+
shape = agent_data['shape'] # (num_agent, 91, 3)
|
561 |
+
category = agent_data['category'] # (num_agent,)
|
562 |
+
valid_mask = (position[..., 0] != 0) & (position[..., 1] != 0) # (num_agent, 91)
|
563 |
+
|
564 |
+
num_agent = valid_mask.shape[0]
|
565 |
+
num_timestep = position.shape[1]
|
566 |
+
is_av = np.arange(num_agent) == int(av_index)
|
567 |
+
|
568 |
+
is_blue = valid_mask.sum(axis=1) == num_timestep
|
569 |
+
is_green = ~valid_mask[:, 0] & valid_mask[:, -1]
|
570 |
+
is_orange = valid_mask[:, 0] & ~valid_mask[:, -1]
|
571 |
+
is_purple = (valid_mask.sum(axis=1) != num_timestep
|
572 |
+
) & (~is_green) & (~is_orange)
|
573 |
+
agent_colors = np.zeros((num_agent,))
|
574 |
+
agent_colors[is_blue] = 1
|
575 |
+
agent_colors[is_green] = 2
|
576 |
+
agent_colors[is_orange] = 3
|
577 |
+
agent_colors[is_purple] = 4
|
578 |
+
agent_colors[is_av] = 5
|
579 |
+
|
580 |
+
veh_mask = category == 1
|
581 |
+
ped_mask = category == 2
|
582 |
+
cyc_mask = category == 3
|
583 |
+
shape[veh_mask, :, 1] = 1.8
|
584 |
+
shape[veh_mask, :, 0] = 1.8
|
585 |
+
shape[ped_mask, :, 1] = 0.5
|
586 |
+
shape[ped_mask, :, 0] = 0.5
|
587 |
+
shape[cyc_mask, :, 1] = 1.0
|
588 |
+
shape[cyc_mask, :, 0] = 1.0
|
589 |
+
|
590 |
+
fig_paths = []
|
591 |
+
for tid in tqdm(range(num_timestep), leave=False, desc="Timestep"):
|
592 |
+
current_valid_mask = valid_mask[:, tid]
|
593 |
+
xs = position[current_valid_mask, tid, 0]
|
594 |
+
ys = position[current_valid_mask, tid, 1]
|
595 |
+
widths = shape[current_valid_mask, tid, 1]
|
596 |
+
lengths = shape[current_valid_mask, tid, 0]
|
597 |
+
angles = heading[current_valid_mask, tid]
|
598 |
+
current_agent_colors = agent_colors[current_valid_mask]
|
599 |
+
|
600 |
+
drawn_agents = []
|
601 |
+
contours = cal_polygon_contour(xs, ys, angles, widths, lengths) # (num_agent, 4, 2)
|
602 |
+
contours = np.concatenate([contours, contours[:, 0:1]], axis=1) # (num_agent, 5, 2)
|
603 |
+
for x, y, width, length, angle, color_type in zip(
|
604 |
+
xs, ys, widths, lengths, angles, current_agent_colors):
|
605 |
+
agent = plt.Rectangle((x, y), width, length, angle=((angle + np.pi / 2) / np.pi * 360) % 360,
|
606 |
+
linewidth=0.2,
|
607 |
+
facecolor=colors[int(color_type) - 1][0],
|
608 |
+
edgecolor=colors[int(color_type) - 1][1])
|
609 |
+
ax.add_patch(agent)
|
610 |
+
drawn_agents.append(agent)
|
611 |
+
plt.gca().set_aspect('equal', adjustable='box')
|
612 |
+
# for contour, color_type in zip(contours, agent_colors):
|
613 |
+
# drawn_agent = ax.plot(contour[:, 0], contour[:, 1])
|
614 |
+
# drawn_agents.append(drawn_agent)
|
615 |
+
|
616 |
+
fig_path = os.path.join("data/vis/processed/0/",
|
617 |
+
f"scenario_{scenario_idx}_{scenario_id}_{tid}.png")
|
618 |
+
plt.savefig(fig_path, dpi=600)
|
619 |
+
fig_paths.append(fig_path)
|
620 |
+
|
621 |
+
for drawn_agent in drawn_agents:
|
622 |
+
drawn_agent.remove()
|
623 |
+
|
624 |
+
plt.close()
|
625 |
+
|
626 |
+
# generate gif
|
627 |
+
import imageio.v2 as imageio
|
628 |
+
images = []
|
629 |
+
for fig_path in tqdm(fig_paths, leave=False, desc="Generate gif ..."):
|
630 |
+
images.append(imageio.imread(fig_path))
|
631 |
+
imageio.mimsave(save_path, images, duration=0.1)
|
632 |
+
|
633 |
+
|
634 |
+
@safe_run
|
635 |
+
def plot_map_token(ax: Axes, map_points: npt.NDArray, token_pos: npt.NDArray, token_heading: npt.NDArray, colors: Union[str, npt.NDArray]=None):
|
636 |
+
|
637 |
+
plot_map(ax, map_points)
|
638 |
+
|
639 |
+
x, y = token_pos[:, 0], token_pos[:, 1]
|
640 |
+
u = np.cos(token_heading)
|
641 |
+
v = np.sin(token_heading)
|
642 |
+
|
643 |
+
if colors is None:
|
644 |
+
colors = np.random.rand(x.shape[0], 3)
|
645 |
+
ax.quiver(x, y, u, v, angles='xy', scale_units='xy', scale=0.2, color=colors, width=0.005,
|
646 |
+
headwidth=0.2, headlength=2)
|
647 |
+
ax.scatter(x, y, color='blue', s=0.2, edgecolors='none')
|
648 |
+
ax.axis("equal")
|
649 |
+
|
650 |
+
|
651 |
+
@safe_run
|
652 |
+
def plot_map(ax: Axes, map_points: npt.NDArray, color='black'):
|
653 |
+
ax.scatter(map_points[:, 0], map_points[:, 1], s=0.2, c=color, edgecolors='none')
|
654 |
+
|
655 |
+
xmin = np.min(map_points[:, 0])
|
656 |
+
xmax = np.max(map_points[:, 0])
|
657 |
+
ymin = np.min(map_points[:, 1])
|
658 |
+
ymax = np.max(map_points[:, 1])
|
659 |
+
ax.set_xlim(xmin, xmax)
|
660 |
+
ax.set_ylim(ymin, ymax)
|
661 |
+
|
662 |
+
|
663 |
+
@safe_run
|
664 |
+
def plot_agent(ax: Axes, xy: Sequence[float], heading: float, type: str, state, is_av: bool=False,
|
665 |
+
pl2seed_radius: float=25., attr_tokenizer: Attr_Tokenizer=None, enter_index: list=[], **kwargs):
|
666 |
+
|
667 |
+
if type == 'veh':
|
668 |
+
length = 4.3
|
669 |
+
width = 1.8
|
670 |
+
size = 1.0
|
671 |
+
elif type == 'ped':
|
672 |
+
length = 0.5
|
673 |
+
width = 0.5
|
674 |
+
size = 0.1
|
675 |
+
elif type == 'cyc':
|
676 |
+
length = 1.9
|
677 |
+
width = 0.5
|
678 |
+
size = 0.3
|
679 |
+
else:
|
680 |
+
raise ValueError(f"Unsupported agent type {type}")
|
681 |
+
|
682 |
+
if kwargs.get('label', None) is not None:
|
683 |
+
ax.text(
|
684 |
+
xy[0] + 1.5, xy[1] + 1.5,
|
685 |
+
kwargs['label'], fontsize=2, color="darkred", ha="center", va="center"
|
686 |
+
)
|
687 |
+
|
688 |
+
patch = FancyBboxPatch([-length / 2, -width / 2], length, width, linewidth=.2, **kwargs)
|
689 |
+
transform = (
|
690 |
+
mtransforms.Affine2D().rotate(heading).translate(xy[0], xy[1])
|
691 |
+
+ ax.transData
|
692 |
+
)
|
693 |
+
patch.set_transform(transform)
|
694 |
+
|
695 |
+
kwargs['label'] = None
|
696 |
+
angles = [0, 2 * np.pi / 3, np.pi, 4 * np.pi / 3]
|
697 |
+
pts = np.stack([size * np.cos(angles), size * np.sin(angles)], axis=-1)
|
698 |
+
center_patch = Polygon(pts, zorder=10., linewidth=.2, **kwargs)
|
699 |
+
center_patch.set_transform(transform)
|
700 |
+
|
701 |
+
ax.add_patch(patch)
|
702 |
+
ax.add_patch(center_patch)
|
703 |
+
|
704 |
+
if is_av:
|
705 |
+
|
706 |
+
if attr_tokenizer is not None:
|
707 |
+
|
708 |
+
circle_patch = Circle(
|
709 |
+
(xy[0], xy[1]), pl2seed_radius, linewidth=0.5, edgecolor='gray', linestyle='--', facecolor='none'
|
710 |
+
)
|
711 |
+
ax.add_patch(circle_patch)
|
712 |
+
|
713 |
+
grid = attr_tokenizer.get_grid(torch.tensor(np.array(xy)).float(),
|
714 |
+
torch.tensor(np.array([heading])).float()).numpy()[0] # (num_grid, 2)
|
715 |
+
ax.scatter(grid[:, 0], grid[:, 1], s=0.3, c='blue', edgecolors='none')
|
716 |
+
ax.text(grid[0, 0], grid[0, 1], 'Front', fontsize=2, color='darkred', ha='center', va='center')
|
717 |
+
ax.text(grid[-1, 0], grid[-1, 1], 'Back', fontsize=2, color='darkred', ha='center', va='center')
|
718 |
+
|
719 |
+
if enter_index:
|
720 |
+
for i in enter_index:
|
721 |
+
ax.plot(grid[int(i), 0], grid[int(i), 1], marker='x', color='red', markersize=1)
|
722 |
+
|
723 |
+
return patch, center_patch
|
724 |
+
|
725 |
+
|
726 |
+
@safe_run
|
727 |
+
def plot_all(map, xs, ys, angles, types, colors, is_avs, pl2seed_radius: float=25.,
|
728 |
+
attr_tokenizer: Attr_Tokenizer=None, enter_index: list=[], labels: list=[], **kwargs):
|
729 |
+
|
730 |
+
plt.subplots_adjust(left=0.3, right=0.7, top=0.7, bottom=0.3)
|
731 |
+
_, ax = plt.subplots()
|
732 |
+
ax.set_axis_off()
|
733 |
+
|
734 |
+
plot_map(ax, map)
|
735 |
+
|
736 |
+
if not labels:
|
737 |
+
labels = [None] * xs.shape[0]
|
738 |
+
|
739 |
+
for x, y, angle, type, color, label, is_av in zip(xs, ys, angles, types, colors, labels, is_avs):
|
740 |
+
assert type in ('veh', 'ped', 'cyc'), f"Unsupported type {type}."
|
741 |
+
plot_agent(ax, [x, y], angle.item(), type, None, is_av, facecolor=color, edgecolor='k', label=label,
|
742 |
+
pl2seed_radius=pl2seed_radius, attr_tokenizer=attr_tokenizer, enter_index=enter_index)
|
743 |
+
|
744 |
+
ax.grid(False)
|
745 |
+
ax.set_aspect('equal', adjustable='box')
|
746 |
+
|
747 |
+
# ax.legend(loc='best', frameon=True)
|
748 |
+
|
749 |
+
if kwargs.get('save_path', None):
|
750 |
+
plt.savefig(kwargs['save_path'], dpi=600, bbox_inches="tight")
|
751 |
+
|
752 |
+
plt.close()
|
753 |
+
|
754 |
+
return ax
|
755 |
+
|
756 |
+
|
757 |
+
@safe_run
|
758 |
+
def plot_file(gt_folder: str,
|
759 |
+
folder: Optional[str] = None,
|
760 |
+
files: Optional[str] = None):
|
761 |
+
from dev.metrics.compute_metrics import _unbatch
|
762 |
+
|
763 |
+
if files is None:
|
764 |
+
assert os.path.exists(folder), f'Path {folder} does not exist.'
|
765 |
+
files = list(fnmatch.filter(os.listdir(folder), 'idx_*_rollouts.pkl'))
|
766 |
+
CONSOLE.log(f'Found {len(files)} rollouts files from {folder}.')
|
767 |
+
|
768 |
+
|
769 |
+
if folder is None:
|
770 |
+
assert os.path.exists(files), f'Path {files} does not exist.'
|
771 |
+
folder = os.path.dirname(files)
|
772 |
+
files = [files]
|
773 |
+
|
774 |
+
parent, folder_name = os.path.split(folder.rstrip(os.sep))
|
775 |
+
save_path = os.path.join(parent, f'{folder_name}_plots')
|
776 |
+
|
777 |
+
for file in (pbar := tqdm(files, leave=False, desc='Plotting files ...')):
|
778 |
+
pbar.set_postfix(file=file)
|
779 |
+
|
780 |
+
with open(os.path.join(folder, file), 'rb') as f:
|
781 |
+
preds = pickle.load(f)
|
782 |
+
|
783 |
+
scenario_ids = preds['_scenario_id']
|
784 |
+
agent_batch = preds['agent_batch']
|
785 |
+
agent_id = _unbatch(preds['agent_id'], agent_batch)
|
786 |
+
preds_traj = _unbatch(preds['pred_traj'], agent_batch)
|
787 |
+
preds_head = _unbatch(preds['pred_head'], agent_batch)
|
788 |
+
preds_type = _unbatch(preds['pred_type'], agent_batch)
|
789 |
+
preds_state = _unbatch(preds['pred_state'], agent_batch)
|
790 |
+
preds_valid = _unbatch(preds['pred_valid'], agent_batch)
|
791 |
+
|
792 |
+
for i, scenario_id in enumerate(scenario_ids):
|
793 |
+
n_rollouts = preds_traj[0].shape[1]
|
794 |
+
|
795 |
+
for j in range(n_rollouts): # 1
|
796 |
+
pred = dict(scenario_id=[scenario_id],
|
797 |
+
pred_traj=preds_traj[i][:, j],
|
798 |
+
pred_head=preds_head[i][:, j],
|
799 |
+
pred_state=preds_state[i][:, j],
|
800 |
+
pred_type=preds_type[i][:, j],
|
801 |
+
)
|
802 |
+
av_index = agent_id[i][:, 0].tolist().index(preds['av_id']) # NOTE: hard code!!!
|
803 |
+
|
804 |
+
data_path = os.path.join(gt_folder, 'validation', f'{scenario_id}.pkl')
|
805 |
+
with open(data_path, 'rb') as f:
|
806 |
+
data = pickle.load(f)
|
807 |
+
plot_val(data, pred, av_index=av_index, save_path=save_path)
|
808 |
+
|
809 |
+
|
810 |
+
@safe_run
|
811 |
+
def plot_val(data: Union[dict, str], pred: dict, av_index: int, save_path: str, suffix: str='',
|
812 |
+
pl2seed_radius: float=75., attr_tokenizer=None, **kwargs):
|
813 |
+
|
814 |
+
if isinstance(data, str):
|
815 |
+
assert data.endswith('.pkl'), f'Got invalid data path {data}.'
|
816 |
+
assert os.path.exists(data), f'Path {data} does not exist.'
|
817 |
+
with open(data, 'rb') as f:
|
818 |
+
data = pickle.load(f)
|
819 |
+
|
820 |
+
map_point = data['map_point']['position'].cpu().numpy()
|
821 |
+
|
822 |
+
scenario_id = pred['scenario_id'][0]
|
823 |
+
pred_traj = pred['pred_traj'].cpu().numpy() # (num_agent, num_future_step, 2)
|
824 |
+
pred_type = list(map(lambda i: AGENT_TYPE[i], pred['pred_type'].tolist()))
|
825 |
+
pred_state = pred['pred_state'].cpu().numpy()
|
826 |
+
pred_head = pred['pred_head'].cpu().numpy()
|
827 |
+
ids = np.arange(pred_traj.shape[0])
|
828 |
+
|
829 |
+
if 'agent_labels' in pred:
|
830 |
+
kwargs.update(agent_labels=pred['agent_labels'])
|
831 |
+
|
832 |
+
plot_scenario(scenario_id, map_point, pred_traj, pred_head, pred_state, pred_type,
|
833 |
+
av_index=av_index, ids=ids, save_path=save_path, suffix=suffix,
|
834 |
+
pl2seed_radius=pl2seed_radius, attr_tokenizer=attr_tokenizer, **kwargs)
|
835 |
+
|
836 |
+
|
837 |
+
@safe_run
|
838 |
+
def plot_scenario(scenario_id: str,
|
839 |
+
map_data: npt.NDArray,
|
840 |
+
traj: npt.NDArray,
|
841 |
+
heading: npt.NDArray,
|
842 |
+
state: npt.NDArray,
|
843 |
+
types: List[str],
|
844 |
+
av_index: int,
|
845 |
+
color_type: Literal['state', 'type', 'seed']='seed',
|
846 |
+
state_type: List[str]=['invalid', 'valid', 'enter', 'exit'],
|
847 |
+
plot_enter: bool=False,
|
848 |
+
suffix: str='',
|
849 |
+
pl2seed_radius: float=25.,
|
850 |
+
attr_tokenizer: Attr_Tokenizer=None,
|
851 |
+
enter_index: List[list] = [],
|
852 |
+
save_gif: bool=True,
|
853 |
+
tokenized: bool=False,
|
854 |
+
agent_labels: List[List[Optional[str]]] = [],
|
855 |
+
**kwargs):
|
856 |
+
|
857 |
+
num_historical_steps = 11
|
858 |
+
shift = 5
|
859 |
+
num_agent, num_timestep = traj.shape[:2]
|
860 |
+
|
861 |
+
if tokenized:
|
862 |
+
num_historical_steps = 2
|
863 |
+
shift = 1
|
864 |
+
|
865 |
+
if 'save_path' in kwargs and kwargs['save_path'] != '':
|
866 |
+
os.makedirs(kwargs['save_path'], exist_ok=True)
|
867 |
+
save_id = int(max([0] + list(map(lambda fname: int(fname.split("_")[-1]),
|
868 |
+
filter(lambda fname: fname.startswith(scenario_id)
|
869 |
+
and os.path.isdir(os.path.join(kwargs['save_path'], fname)),
|
870 |
+
os.listdir(kwargs['save_path'])))))) + 1
|
871 |
+
os.makedirs(f"{kwargs['save_path']}/{scenario_id}_{str(save_id).zfill(3)}", exist_ok=True)
|
872 |
+
|
873 |
+
if save_id > 1:
|
874 |
+
try:
|
875 |
+
import shutil
|
876 |
+
shutil.rmtree(f"{kwargs['save_path']}/{scenario_id}_{str(save_id - 1).zfill(3)}")
|
877 |
+
except:
|
878 |
+
pass
|
879 |
+
|
880 |
+
visible_mask = state != state_type.index('invalid')
|
881 |
+
if not plot_enter:
|
882 |
+
visible_mask &= (state != state_type.index('enter'))
|
883 |
+
|
884 |
+
last_valid_step = visible_mask.shape[1] - 1 - torch.argmax(torch.Tensor(visible_mask).flip(dims=[1]).long(), dim=1)
|
885 |
+
ids = None
|
886 |
+
if 'ids' in kwargs:
|
887 |
+
ids = kwargs['ids']
|
888 |
+
last_valid_step = {int(ids[i]): int(last_valid_step[i]) for i in range(len(ids))}
|
889 |
+
|
890 |
+
# agent colors
|
891 |
+
agent_colors = np.zeros((num_agent, num_timestep, 3))
|
892 |
+
|
893 |
+
agent_palette = sns.color_palette('husl', n_colors=7)
|
894 |
+
state_colors = {state: np.array(agent_palette[i]) for i, state in enumerate(state_type)}
|
895 |
+
seed_colors = {seed: np.array(agent_palette[i]) for i, seed in enumerate(['existing', 'entered', 'exited'])}
|
896 |
+
|
897 |
+
if color_type == 'state':
|
898 |
+
for t in range(state.shape[1]):
|
899 |
+
agent_colors[state[:, t] == state_type.index('invalid'), t * shift : (t + 1) * shift] = state_colors['invalid']
|
900 |
+
agent_colors[state[:, t] == state_type.index('valid'), t * shift : (t + 1) * shift] = state_colors['valid']
|
901 |
+
agent_colors[state[:, t] == state_type.index('enter'), t * shift : (t + 1) * shift] = state_colors['enter']
|
902 |
+
agent_colors[state[:, t] == state_type.index('exit'), t * shift : (t + 1) * shift] = state_colors['exit']
|
903 |
+
|
904 |
+
if color_type == 'seed':
|
905 |
+
agent_colors[:, :] = seed_colors['existing']
|
906 |
+
is_exited = np.any(state[:, num_historical_steps - 1:] == state_type.index('exit'), axis=-1)
|
907 |
+
is_entered = np.any(state[:, num_historical_steps - 1:] == state_type.index('enter'), axis=-1)
|
908 |
+
is_entered[av_index + 1:] = True # NOTE: hard code, need improvment
|
909 |
+
agent_colors[is_exited, :] = seed_colors['exited']
|
910 |
+
agent_colors[is_entered, :] = seed_colors['entered']
|
911 |
+
|
912 |
+
agent_colors[av_index, :] = np.array(agent_palette[-1])
|
913 |
+
is_av = np.zeros_like(state[:, 0]).astype(np.bool_)
|
914 |
+
is_av[av_index] = True
|
915 |
+
|
916 |
+
# draw agents
|
917 |
+
fig_paths = []
|
918 |
+
for tid in tqdm(range(num_timestep), leave=False, desc="Plot ..."):
|
919 |
+
mask_t = visible_mask[:, tid]
|
920 |
+
xs = traj[mask_t, tid, 0]
|
921 |
+
ys = traj[mask_t, tid, 1]
|
922 |
+
angles = heading[mask_t, tid]
|
923 |
+
colors = agent_colors[mask_t, tid]
|
924 |
+
types_t = [types[i] for i, mask in enumerate(mask_t) if mask]
|
925 |
+
if ids is not None:
|
926 |
+
ids_t = ids[mask_t]
|
927 |
+
is_av_t = is_av[mask_t]
|
928 |
+
enter_index_t = enter_index[tid] if enter_index else None
|
929 |
+
labels = []
|
930 |
+
if agent_labels:
|
931 |
+
labels = [agent_labels[i][tid // shift] for i in range(len(agent_labels)) if mask_t[i]]
|
932 |
+
|
933 |
+
fig_path = None
|
934 |
+
if 'save_path' in kwargs:
|
935 |
+
save_path = kwargs['save_path']
|
936 |
+
fig_path = os.path.join(f"{save_path}/{scenario_id}_{str(save_id).zfill(3)}", f"{tid}.png")
|
937 |
+
fig_paths.append(fig_path)
|
938 |
+
|
939 |
+
plot_all(map_data, xs, ys, angles, types_t, colors=colors, save_path=fig_path, is_avs=is_av_t,
|
940 |
+
pl2seed_radius=pl2seed_radius, attr_tokenizer=attr_tokenizer, enter_index=enter_index_t, labels=labels)
|
941 |
+
|
942 |
+
# generate gif
|
943 |
+
if fig_paths and save_gif:
|
944 |
+
os.makedirs(os.path.join(save_path, 'gifs'), exist_ok=True)
|
945 |
+
images = []
|
946 |
+
gif_path = f"{save_path}/gifs/{scenario_id}_{str(save_id).zfill(3)}.gif"
|
947 |
+
for fig_path in tqdm(fig_paths, leave=False, desc="Generate gif ..."):
|
948 |
+
images.append(Image.open(fig_path))
|
949 |
+
try:
|
950 |
+
images[0].save(gif_path, save_all=True, append_images=images[1:], duration=100, loop=0)
|
951 |
+
tqdm.write(f"Saved gif at {gif_path}")
|
952 |
+
try:
|
953 |
+
import shutil
|
954 |
+
shutil.rmtree(f"{save_path}/{scenario_id}_{str(save_id).zfill(3)}")
|
955 |
+
os.remove(f"{save_path}/gifs/{scenario_id}_{str(save_id - 1).zfill(3)}.gif")
|
956 |
+
except:
|
957 |
+
pass
|
958 |
+
except Exception as e:
|
959 |
+
tqdm.write(f"{e}! Failed to save gif at {gif_path}")
|
960 |
+
|
961 |
+
|
962 |
+
def match_token_map(data):
|
963 |
+
|
964 |
+
# init map token
|
965 |
+
argmin_sample_len = 3
|
966 |
+
map_token_traj_path = '/u/xiuyu/work/dev4/dev/tokens/map_traj_token5.pkl'
|
967 |
+
|
968 |
+
map_token_traj = pickle.load(open(map_token_traj_path, 'rb'))
|
969 |
+
map_token = {'traj_src': map_token_traj['traj_src'], }
|
970 |
+
traj_end_theta = np.arctan2(map_token['traj_src'][:, -1, 1] - map_token['traj_src'][:, -2, 1],
|
971 |
+
map_token['traj_src'][:, -1, 0] - map_token['traj_src'][:, -2, 0])
|
972 |
+
indices = torch.linspace(0, map_token['traj_src'].shape[1]-1, steps=argmin_sample_len).long()
|
973 |
+
map_token['sample_pt'] = torch.from_numpy(map_token['traj_src'][:, indices]).to(torch.float)
|
974 |
+
map_token['traj_end_theta'] = torch.from_numpy(traj_end_theta).to(torch.float)
|
975 |
+
map_token['traj_src'] = torch.from_numpy(map_token['traj_src']).to(torch.float)
|
976 |
+
|
977 |
+
traj_pos = data['map_save']['traj_pos'].to(torch.float)
|
978 |
+
traj_theta = data['map_save']['traj_theta'].to(torch.float)
|
979 |
+
pl_idx_list = data['map_save']['pl_idx_list']
|
980 |
+
token_sample_pt = map_token['sample_pt'].to(traj_pos.device)
|
981 |
+
token_src = map_token['traj_src'].to(traj_pos.device)
|
982 |
+
max_traj_len = map_token['traj_src'].shape[1]
|
983 |
+
pl_num = traj_pos.shape[0]
|
984 |
+
|
985 |
+
pt_token_pos = traj_pos[:, 0, :].clone()
|
986 |
+
pt_token_orientation = traj_theta.clone()
|
987 |
+
cos, sin = traj_theta.cos(), traj_theta.sin()
|
988 |
+
rot_mat = traj_theta.new_zeros(pl_num, 2, 2)
|
989 |
+
rot_mat[..., 0, 0] = cos
|
990 |
+
rot_mat[..., 0, 1] = -sin
|
991 |
+
rot_mat[..., 1, 0] = sin
|
992 |
+
rot_mat[..., 1, 1] = cos
|
993 |
+
traj_pos_local = torch.bmm((traj_pos - traj_pos[:, 0:1]), rot_mat.view(-1, 2, 2))
|
994 |
+
distance = torch.sum((token_sample_pt[None] - traj_pos_local.unsqueeze(1)) ** 2, dim=(-2, -1))
|
995 |
+
pt_token_id = torch.argmin(distance, dim=1)
|
996 |
+
|
997 |
+
noise = False
|
998 |
+
if noise:
|
999 |
+
topk_indices = torch.argsort(torch.sum((token_sample_pt[None] - traj_pos_local.unsqueeze(1)) ** 2, dim=(-2, -1)), dim=1)[:, :8]
|
1000 |
+
sample_topk = torch.randint(0, topk_indices.shape[-1], size=(topk_indices.shape[0], 1), device=topk_indices.device)
|
1001 |
+
pt_token_id = torch.gather(topk_indices, 1, sample_topk).squeeze(-1)
|
1002 |
+
|
1003 |
+
# cos, sin = traj_theta.cos(), traj_theta.sin()
|
1004 |
+
# rot_mat = traj_theta.new_zeros(pl_num, 2, 2)
|
1005 |
+
# rot_mat[..., 0, 0] = cos
|
1006 |
+
# rot_mat[..., 0, 1] = sin
|
1007 |
+
# rot_mat[..., 1, 0] = -sin
|
1008 |
+
# rot_mat[..., 1, 1] = cos
|
1009 |
+
# token_src_world = torch.bmm(token_src[None, ...].repeat(pl_num, 1, 1, 1).reshape(pl_num, -1, 2),
|
1010 |
+
# rot_mat.view(-1, 2, 2)).reshape(pl_num, token_src.shape[0], max_traj_len, 2) + traj_pos[:, None, [0], :]
|
1011 |
+
# token_src_world_select = token_src_world.view(-1, 1024, 11, 2)[torch.arange(pt_token_id.view(-1).shape[0]), pt_token_id.view(-1)].view(pl_num, max_traj_len, 2)
|
1012 |
+
|
1013 |
+
pl_idx_full = pl_idx_list.clone()
|
1014 |
+
token2pl = torch.stack([torch.arange(len(pl_idx_list), device=traj_pos.device), pl_idx_full.long()])
|
1015 |
+
count_nums = []
|
1016 |
+
for pl in pl_idx_full.unique():
|
1017 |
+
pt = token2pl[0, token2pl[1, :] == pl]
|
1018 |
+
left_side = (data['pt_token']['side'][pt] == 0).sum()
|
1019 |
+
right_side = (data['pt_token']['side'][pt] == 1).sum()
|
1020 |
+
center_side = (data['pt_token']['side'][pt] == 2).sum()
|
1021 |
+
count_nums.append(torch.Tensor([left_side, right_side, center_side]))
|
1022 |
+
count_nums = torch.stack(count_nums, dim=0)
|
1023 |
+
num_polyline = int(count_nums.max().item())
|
1024 |
+
traj_mask = torch.zeros((int(len(pl_idx_full.unique())), 3, num_polyline), dtype=bool)
|
1025 |
+
idx_matrix = torch.arange(traj_mask.size(2)).unsqueeze(0).unsqueeze(0)
|
1026 |
+
idx_matrix = idx_matrix.expand(traj_mask.size(0), traj_mask.size(1), -1)
|
1027 |
+
counts_num_expanded = count_nums.unsqueeze(-1)
|
1028 |
+
mask_update = idx_matrix < counts_num_expanded
|
1029 |
+
traj_mask[mask_update] = True
|
1030 |
+
|
1031 |
+
data['pt_token']['traj_mask'] = traj_mask
|
1032 |
+
data['pt_token']['position'] = torch.cat([pt_token_pos, torch.zeros((data['pt_token']['num_nodes'], 1),
|
1033 |
+
device=traj_pos.device, dtype=torch.float)], dim=-1)
|
1034 |
+
data['pt_token']['orientation'] = pt_token_orientation
|
1035 |
+
data['pt_token']['height'] = data['pt_token']['position'][:, -1]
|
1036 |
+
data[('pt_token', 'to', 'map_polygon')] = {}
|
1037 |
+
data[('pt_token', 'to', 'map_polygon')]['edge_index'] = token2pl # (2, num_points)
|
1038 |
+
data['pt_token']['token_idx'] = pt_token_id
|
1039 |
+
return data
|
1040 |
+
|
1041 |
+
|
1042 |
+
@safe_run
|
1043 |
+
def plot_tokenize(data, save_path: str):
|
1044 |
+
|
1045 |
+
shift = 5
|
1046 |
+
token_size = 2048
|
1047 |
+
pl2seed_radius = 75
|
1048 |
+
|
1049 |
+
# transformation
|
1050 |
+
transform = WaymoTargetBuilder(num_historical_steps=11,
|
1051 |
+
num_future_steps=80,
|
1052 |
+
max_num=32,
|
1053 |
+
training=False)
|
1054 |
+
|
1055 |
+
grid_range = 150.
|
1056 |
+
grid_interval = 3.
|
1057 |
+
angle_interval = 3.
|
1058 |
+
attr_tokenizer = Attr_Tokenizer(grid_range=grid_range,
|
1059 |
+
grid_interval=grid_interval,
|
1060 |
+
radius=pl2seed_radius,
|
1061 |
+
angle_interval=angle_interval)
|
1062 |
+
|
1063 |
+
# tokenization
|
1064 |
+
token_processor = TokenProcessor(token_size,
|
1065 |
+
training=False,
|
1066 |
+
predict_motion=True,
|
1067 |
+
predict_state=True,
|
1068 |
+
predict_map=True,
|
1069 |
+
state_token={'invalid': 0, 'valid': 1, 'enter': 2, 'exit': 3},
|
1070 |
+
pl2seed_radius=pl2seed_radius)
|
1071 |
+
CONSOLE.log(f"Loaded token processor with token_size: {token_size}")
|
1072 |
+
|
1073 |
+
# preprocess
|
1074 |
+
data: HeteroData = transform(data)
|
1075 |
+
tokenized_data = token_processor(data)
|
1076 |
+
CONSOLE.log(f"Keys in tokenized data:\n{tokenized_data.keys()}")
|
1077 |
+
|
1078 |
+
# plot
|
1079 |
+
agent_data = tokenized_data['agent']
|
1080 |
+
map_data = tokenized_data['map_point']
|
1081 |
+
# CONSOLE.log(f"Keys in agent data:\n{agent_data.keys()}")
|
1082 |
+
|
1083 |
+
av_index = agent_data['av_index']
|
1084 |
+
raw_traj = agent_data['position'][..., :2].contiguous() # [n_agent, n_step, 2]
|
1085 |
+
raw_heading = agent_data['heading'] # [n_agent, n_step]
|
1086 |
+
|
1087 |
+
traj = agent_data['traj_pos'][..., :2].contiguous() # [n_agent, n_step, 6, 2]
|
1088 |
+
traj = traj[:, :, 1:, :].flatten(1, 2)
|
1089 |
+
traj = torch.cat([raw_traj[:, :1], traj], dim=1)
|
1090 |
+
heading = agent_data['traj_heading'] # [n_agent, n_step, 6]
|
1091 |
+
heading = heading[:, :, 1:].flatten(1, 2)
|
1092 |
+
heading = torch.cat([raw_heading[:, :1], heading], dim=1)
|
1093 |
+
|
1094 |
+
agent_state = agent_data['state_idx'].repeat_interleave(repeats=shift, dim=-1)
|
1095 |
+
agent_state = torch.cat([torch.zeros_like(agent_state[:, :1]), agent_state], dim=1)
|
1096 |
+
agent_type = agent_data['type']
|
1097 |
+
ids = np.arange(raw_traj.shape[0])
|
1098 |
+
|
1099 |
+
plot_scenario(scenario_id=tokenized_data['scenario_id'],
|
1100 |
+
map_data=tokenized_data['map_point']['position'].numpy(),
|
1101 |
+
traj=raw_traj.numpy(),
|
1102 |
+
heading=raw_heading.numpy(),
|
1103 |
+
state=agent_state.numpy(),
|
1104 |
+
types=list(map(lambda i: AGENT_TYPE[i], agent_type.tolist())),
|
1105 |
+
av_index=av_index,
|
1106 |
+
ids=ids,
|
1107 |
+
save_path=save_path,
|
1108 |
+
pl2seed_radius=pl2seed_radius,
|
1109 |
+
attr_tokenizer=attr_tokenizer,
|
1110 |
+
color_type='state',
|
1111 |
+
)
|
1112 |
+
|
1113 |
+
|
1114 |
+
if __name__ == "__main__":
|
1115 |
+
parser = ArgumentParser()
|
1116 |
+
parser.add_argument('--data_path', type=str, default='/u/xiuyu/work/dev4/data/waymo_processed')
|
1117 |
+
parser.add_argument('--tfrecord_dir', type=str, default='validation_tfrecords_splitted')
|
1118 |
+
# plot tokenized data
|
1119 |
+
parser.add_argument('--save_folder', type=str, default='plot_gt')
|
1120 |
+
parser.add_argument('--split', type=str, default='validation')
|
1121 |
+
parser.add_argument('--scenario_id', type=str, default=None)
|
1122 |
+
parser.add_argument('--plot_tokenize', action='store_true')
|
1123 |
+
# plot generated rollouts
|
1124 |
+
parser.add_argument('--plot_file', action='store_true')
|
1125 |
+
parser.add_argument('--folder_path', type=str, default=None)
|
1126 |
+
parser.add_argument('--file_path', type=str, default=None)
|
1127 |
+
args = parser.parse_args()
|
1128 |
+
|
1129 |
+
if args.plot_tokenize:
|
1130 |
+
|
1131 |
+
scenario_id = "74ad7b76d5906d39"
|
1132 |
+
# scenario_id = "1d60300bc06f4801"
|
1133 |
+
data_path = os.path.join(args.data_path, args.split, f"{scenario_id}.pkl")
|
1134 |
+
data = pickle.load(open(data_path, "rb"))
|
1135 |
+
data['tfrecord_path'] = os.path.join(args.tfrecord_dir, f'{scenario_id}.tfrecords')
|
1136 |
+
CONSOLE.log(f"Loaded scenario {scenario_id}")
|
1137 |
+
|
1138 |
+
save_path = os.path.join(args.data_path, args.save_folder, args.split)
|
1139 |
+
os.makedirs(save_path, exist_ok=True)
|
1140 |
+
|
1141 |
+
plot_tokenize(data, save_path)
|
1142 |
+
|
1143 |
+
if args.plot_file:
|
1144 |
+
|
1145 |
+
plot_file(args.data_path, folder=args.folder_path, files=args.file_path)
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/environment.yml
ADDED
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: traj
|
2 |
+
channels:
|
3 |
+
- conda-forge
|
4 |
+
- defaults
|
5 |
+
dependencies:
|
6 |
+
- _libgcc_mutex=0.1=conda_forge
|
7 |
+
- _openmp_mutex=4.5=2_gnu
|
8 |
+
- ca-certificates=2025.1.31=hbcca054_0
|
9 |
+
- ld_impl_linux-64=2.38=h1181459_1
|
10 |
+
- libffi=3.4.4=h6a678d5_1
|
11 |
+
- libgcc=14.2.0=h77fa898_1
|
12 |
+
- libgcc-ng=14.2.0=h69a702a_1
|
13 |
+
- libgomp=14.2.0=h77fa898_1
|
14 |
+
- libstdcxx-ng=11.2.0=h1234567_1
|
15 |
+
- ncdu=1.16=h0f457ee_0
|
16 |
+
- ncurses=6.4=h6a678d5_0
|
17 |
+
- openssl=3.4.0=h7b32b05_1
|
18 |
+
- pip=24.2=py39h06a4308_0
|
19 |
+
- python=3.9.19=h955ad1f_1
|
20 |
+
- readline=8.2=h5eee18b_0
|
21 |
+
- sqlite=3.45.3=h5eee18b_0
|
22 |
+
- tk=8.6.14=h39e8969_0
|
23 |
+
- wheel=0.43.0=py39h06a4308_0
|
24 |
+
- xz=5.4.6=h5eee18b_1
|
25 |
+
- zlib=1.2.13=h5eee18b_1
|
26 |
+
- pip:
|
27 |
+
- absl-py==1.4.0
|
28 |
+
- addict==2.4.0
|
29 |
+
- aiohappyeyeballs==2.4.0
|
30 |
+
- aiohttp==3.10.5
|
31 |
+
- aiosignal==1.3.1
|
32 |
+
- anyio==4.4.0
|
33 |
+
- appdirs==1.4.4
|
34 |
+
- argon2-cffi==23.1.0
|
35 |
+
- argon2-cffi-bindings==21.2.0
|
36 |
+
- array-record==0.5.1
|
37 |
+
- arrow==1.3.0
|
38 |
+
- asttokens==2.4.1
|
39 |
+
- astunparse==1.6.3
|
40 |
+
- async-lru==2.0.4
|
41 |
+
- async-timeout==4.0.3
|
42 |
+
- attrs==24.2.0
|
43 |
+
- av==12.3.0
|
44 |
+
- babel==2.16.0
|
45 |
+
- beautifulsoup4==4.12.3
|
46 |
+
- bidict==0.23.1
|
47 |
+
- bleach==6.1.0
|
48 |
+
- blinker==1.8.2
|
49 |
+
- cachetools==5.5.0
|
50 |
+
- certifi==2024.7.4
|
51 |
+
- cffi==1.17.0
|
52 |
+
- chardet==5.2.0
|
53 |
+
- charset-normalizer==3.3.2
|
54 |
+
- click==8.1.7
|
55 |
+
- cloudpickle==3.0.0
|
56 |
+
- colorlog==6.8.2
|
57 |
+
- comet-ml==3.45.0
|
58 |
+
- comm==0.2.2
|
59 |
+
- configargparse==1.7
|
60 |
+
- configobj==5.0.8
|
61 |
+
- contourpy==1.3.0
|
62 |
+
- cryptography==43.0.0
|
63 |
+
- cycler==0.12.1
|
64 |
+
- dacite==1.8.1
|
65 |
+
- dash==2.17.1
|
66 |
+
- dash-core-components==2.0.0
|
67 |
+
- dash-html-components==2.0.0
|
68 |
+
- dash-table==5.0.0
|
69 |
+
- dask==2023.3.1
|
70 |
+
- dataclass-array==1.5.1
|
71 |
+
- debugpy==1.8.5
|
72 |
+
- decorator==5.1.1
|
73 |
+
- defusedxml==0.7.1
|
74 |
+
- descartes==1.1.0
|
75 |
+
- dm-tree==0.1.8
|
76 |
+
- docker-pycreds==0.4.0
|
77 |
+
- docstring-parser==0.16
|
78 |
+
- dulwich==0.22.1
|
79 |
+
- easydict==1.13
|
80 |
+
- einops==0.8.0
|
81 |
+
- einsum==0.3.0
|
82 |
+
- embreex==2.17.7.post5
|
83 |
+
- etils==1.5.2
|
84 |
+
- eval-type-backport==0.2.0
|
85 |
+
- everett==3.1.0
|
86 |
+
- exceptiongroup==1.2.2
|
87 |
+
- executing==2.0.1
|
88 |
+
- fastjsonschema==2.20.0
|
89 |
+
- filelock==3.15.4
|
90 |
+
- fire==0.6.0
|
91 |
+
- flask==3.0.3
|
92 |
+
- flatbuffers==24.3.25
|
93 |
+
- fonttools==4.53.1
|
94 |
+
- fqdn==1.5.1
|
95 |
+
- frozenlist==1.4.1
|
96 |
+
- fsspec==2024.6.1
|
97 |
+
- gast==0.4.0
|
98 |
+
- gdown==5.2.0
|
99 |
+
- gitdb==4.0.11
|
100 |
+
- gitpython==3.1.43
|
101 |
+
- google-auth==2.16.2
|
102 |
+
- google-auth-oauthlib==1.0.0
|
103 |
+
- google-pasta==0.2.0
|
104 |
+
- grpcio==1.66.1
|
105 |
+
- h11==0.14.0
|
106 |
+
- h5py==3.11.0
|
107 |
+
- httpcore==1.0.5
|
108 |
+
- httpx==0.27.2
|
109 |
+
- idna==3.8
|
110 |
+
- imageio==2.35.1
|
111 |
+
- immutabledict==2.2.0
|
112 |
+
- importlib-metadata==8.4.0
|
113 |
+
- importlib-resources==6.4.4
|
114 |
+
- ipykernel==6.29.5
|
115 |
+
- ipython==8.18.1
|
116 |
+
- ipywidgets==8.1.5
|
117 |
+
- isoduration==20.11.0
|
118 |
+
- itsdangerous==2.2.0
|
119 |
+
- jax==0.4.30
|
120 |
+
- jaxlib==0.4.30
|
121 |
+
- jaxtyping==0.2.33
|
122 |
+
- jedi==0.19.1
|
123 |
+
- jinja2==3.1.4
|
124 |
+
- joblib==1.4.2
|
125 |
+
- json5==0.9.25
|
126 |
+
- jsonpointer==3.0.0
|
127 |
+
- jsonschema==4.23.0
|
128 |
+
- jsonschema-specifications==2023.12.1
|
129 |
+
- jupyter-client==8.6.2
|
130 |
+
- jupyter-core==5.7.2
|
131 |
+
- jupyter-events==0.10.0
|
132 |
+
- jupyter-lsp==2.2.5
|
133 |
+
- jupyter-server==2.14.2
|
134 |
+
- jupyter-server-terminals==0.5.3
|
135 |
+
- jupyterlab==4.2.5
|
136 |
+
- jupyterlab-pygments==0.3.0
|
137 |
+
- jupyterlab-server==2.27.3
|
138 |
+
- jupyterlab-widgets==3.0.13
|
139 |
+
- keras==2.12.0
|
140 |
+
- kiwisolver==1.4.5
|
141 |
+
- lark==1.2.2
|
142 |
+
- lazy-loader==0.4
|
143 |
+
- libclang==18.1.1
|
144 |
+
- lightning-utilities==0.11.6
|
145 |
+
- locket==1.0.0
|
146 |
+
- lxml==5.3.0
|
147 |
+
- manifold3d==2.5.1
|
148 |
+
- mapbox-earcut==1.0.2
|
149 |
+
- markdown==3.7
|
150 |
+
- markdown-it-py==3.0.0
|
151 |
+
- markupsafe==2.1.5
|
152 |
+
- matplotlib==3.9.2
|
153 |
+
- matplotlib-inline==0.1.7
|
154 |
+
- mdurl==0.1.2
|
155 |
+
- mediapy==1.2.2
|
156 |
+
- mistune==3.0.2
|
157 |
+
- ml-dtypes==0.4.0
|
158 |
+
- mpmath==1.3.0
|
159 |
+
- msgpack==1.0.8
|
160 |
+
- msgpack-numpy==0.4.8
|
161 |
+
- multidict==6.0.5
|
162 |
+
- namex==0.0.8
|
163 |
+
- nbclient==0.10.0
|
164 |
+
- nbconvert==7.16.4
|
165 |
+
- nbformat==5.10.4
|
166 |
+
- nerfacc==0.5.2
|
167 |
+
- nerfstudio==0.3.4
|
168 |
+
- nest-asyncio==1.6.0
|
169 |
+
- networkx==3.2.1
|
170 |
+
- ninja==1.11.1.1
|
171 |
+
- nodeenv==1.9.1
|
172 |
+
- notebook-shim==0.2.4
|
173 |
+
- numpy==1.23.0
|
174 |
+
- nuscenes-devkit==1.1.11
|
175 |
+
- nvidia-cublas-cu12==12.1.3.1
|
176 |
+
- nvidia-cuda-cupti-cu12==12.1.105
|
177 |
+
- nvidia-cuda-nvrtc-cu12==12.1.105
|
178 |
+
- nvidia-cuda-runtime-cu12==12.1.105
|
179 |
+
- nvidia-cudnn-cu12==9.1.0.70
|
180 |
+
- nvidia-cufft-cu12==11.0.2.54
|
181 |
+
- nvidia-curand-cu12==10.3.2.106
|
182 |
+
- nvidia-cusolver-cu12==11.4.5.107
|
183 |
+
- nvidia-cusparse-cu12==12.1.0.106
|
184 |
+
- nvidia-nccl-cu12==2.20.5
|
185 |
+
- nvidia-nvjitlink-cu12==12.6.20
|
186 |
+
- nvidia-nvtx-cu12==12.1.105
|
187 |
+
- oauthlib==3.2.2
|
188 |
+
- open3d==0.18.0
|
189 |
+
- opencv-python==4.6.0.66
|
190 |
+
- openexr==1.3.9
|
191 |
+
- opt-einsum==3.3.0
|
192 |
+
- optree==0.12.1
|
193 |
+
- overrides==7.7.0
|
194 |
+
- packaging==24.1
|
195 |
+
- pandas==1.5.3
|
196 |
+
- pandocfilters==1.5.1
|
197 |
+
- parso==0.8.4
|
198 |
+
- partd==1.4.2
|
199 |
+
- pexpect==4.9.0
|
200 |
+
- pillow==9.2.0
|
201 |
+
- platformdirs==4.2.2
|
202 |
+
- plotly==5.13.1
|
203 |
+
- prometheus-client==0.20.0
|
204 |
+
- promise==2.3
|
205 |
+
- prompt-toolkit==3.0.47
|
206 |
+
- protobuf==3.20.3
|
207 |
+
- psutil==6.0.0
|
208 |
+
- ptyprocess==0.7.0
|
209 |
+
- pure-eval==0.2.3
|
210 |
+
- pyarrow==10.0.0
|
211 |
+
- pyasn1==0.6.0
|
212 |
+
- pyasn1-modules==0.4.0
|
213 |
+
- pycocotools==2.0.8
|
214 |
+
- pycollada==0.8
|
215 |
+
- pycparser==2.22
|
216 |
+
- pygments==2.18.0
|
217 |
+
- pyliblzfse==0.4.1
|
218 |
+
- pymeshlab==2023.12.post1
|
219 |
+
- pyngrok==7.2.0
|
220 |
+
- pyparsing==3.1.4
|
221 |
+
- pyquaternion==0.9.9
|
222 |
+
- pysocks==1.7.1
|
223 |
+
- python-box==6.1.0
|
224 |
+
- python-dateutil==2.9.0.post0
|
225 |
+
- python-engineio==4.9.1
|
226 |
+
- python-json-logger==2.0.7
|
227 |
+
- python-socketio==5.11.3
|
228 |
+
- pytorch-lightning==2.4.0
|
229 |
+
- pytz==2024.1
|
230 |
+
- pywavelets==1.6.0
|
231 |
+
- pyyaml==6.0.2
|
232 |
+
- pyzmq==26.2.0
|
233 |
+
- rawpy==0.22.0
|
234 |
+
- referencing==0.35.1
|
235 |
+
- requests==2.32.3
|
236 |
+
- requests-oauthlib==2.0.0
|
237 |
+
- requests-toolbelt==1.0.0
|
238 |
+
- retrying==1.3.4
|
239 |
+
- rfc3339-validator==0.1.4
|
240 |
+
- rfc3986-validator==0.1.1
|
241 |
+
- rich==13.8.0
|
242 |
+
- rpds-py==0.20.0
|
243 |
+
- rsa==4.9
|
244 |
+
- rtree==1.3.0
|
245 |
+
- scikit-image==0.20.0
|
246 |
+
- scikit-learn==1.2.2
|
247 |
+
- scipy==1.9.1
|
248 |
+
- seaborn==0.13.2
|
249 |
+
- semantic-version==2.10.0
|
250 |
+
- send2trash==1.8.3
|
251 |
+
- sentry-sdk==2.13.0
|
252 |
+
- setproctitle==1.3.3
|
253 |
+
- setuptools==67.6.0
|
254 |
+
- shapely==1.8.5.post1
|
255 |
+
- shtab==1.7.1
|
256 |
+
- simple-websocket==1.0.0
|
257 |
+
- simplejson==3.19.3
|
258 |
+
- six==1.16.0
|
259 |
+
- smmap==5.0.1
|
260 |
+
- sniffio==1.3.1
|
261 |
+
- soupsieve==2.6
|
262 |
+
- splines==0.3.0
|
263 |
+
- stack-data==0.6.3
|
264 |
+
- svg-path==6.3
|
265 |
+
- sympy==1.13.2
|
266 |
+
- tenacity==9.0.0
|
267 |
+
- tensorboard==2.12.3
|
268 |
+
- tensorboard-data-server==0.7.2
|
269 |
+
- tensorflow==2.12.0
|
270 |
+
- tensorflow-addons==0.23.0
|
271 |
+
- tensorflow-datasets==4.9.3
|
272 |
+
- tensorflow-estimator==2.12.0
|
273 |
+
- tensorflow-graphics==2021.12.3
|
274 |
+
- tensorflow-io-gcs-filesystem==0.37.1
|
275 |
+
- tensorflow-metadata==1.15.0
|
276 |
+
- tensorflow-probability==0.19.0
|
277 |
+
- termcolor==2.4.0
|
278 |
+
- terminado==0.18.1
|
279 |
+
- threadpoolctl==3.5.0
|
280 |
+
- tifffile==2024.8.28
|
281 |
+
- timm==0.6.7
|
282 |
+
- tinycss2==1.3.0
|
283 |
+
- toml==0.10.2
|
284 |
+
- tomli==2.0.1
|
285 |
+
- toolz==0.12.1
|
286 |
+
- torch==2.4.0
|
287 |
+
- torch-cluster==1.6.3+pt24cu121
|
288 |
+
- torch-fidelity==0.3.0
|
289 |
+
- torch-geometric==2.5.3
|
290 |
+
- torch-scatter==2.1.2+pt24cu121
|
291 |
+
- torch-sparse==0.6.18+pt24cu121
|
292 |
+
- torchmetrics==1.4.1
|
293 |
+
- torchvision==0.19.0
|
294 |
+
- tornado==6.4.1
|
295 |
+
- tqdm==4.66.5
|
296 |
+
- traitlets==5.14.3
|
297 |
+
- trimesh==4.4.7
|
298 |
+
- triton==3.0.0
|
299 |
+
- typeguard==2.13.3
|
300 |
+
- types-python-dateutil==2.9.0.20240821
|
301 |
+
- typing-extensions==4.12.2
|
302 |
+
- tyro==0.8.10
|
303 |
+
- tzdata==2024.1
|
304 |
+
- uri-template==1.3.0
|
305 |
+
- urllib3==2.2.2
|
306 |
+
- vhacdx==0.0.8.post1
|
307 |
+
- viser==0.1.3
|
308 |
+
- visu3d==1.5.1
|
309 |
+
- wandb==0.17.8
|
310 |
+
- waymo-open-dataset-tf-2-12-0==1.6.4
|
311 |
+
- wcwidth==0.2.13
|
312 |
+
- webcolors==24.8.0
|
313 |
+
- webencodings==0.5.1
|
314 |
+
- websocket-client==1.8.0
|
315 |
+
- websockets==13.0.1
|
316 |
+
- werkzeug==3.0.4
|
317 |
+
- widgetsnbextension==4.0.13
|
318 |
+
- wrapt==1.14.1
|
319 |
+
- wsproto==1.2.0
|
320 |
+
- wurlitzer==3.1.1
|
321 |
+
- xatlas==0.0.9
|
322 |
+
- xxhash==3.5.0
|
323 |
+
- yarl==1.9.11
|
324 |
+
- yourdfpy==0.0.56
|
325 |
+
- zipp==3.20.1
|
326 |
+
prefix: /u/xiuyu/anaconda3/envs/traffic
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/run.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytorch_lightning as pl
|
2 |
+
import os
|
3 |
+
import shutil
|
4 |
+
import fnmatch
|
5 |
+
import torch
|
6 |
+
from argparse import ArgumentParser
|
7 |
+
from pytorch_lightning.callbacks import LearningRateMonitor
|
8 |
+
from pytorch_lightning.callbacks import ModelCheckpoint
|
9 |
+
from pytorch_lightning.strategies import DDPStrategy
|
10 |
+
from pytorch_lightning.loggers import WandbLogger
|
11 |
+
|
12 |
+
from dev.utils.func import RankedLogger, load_config_act, CONSOLE
|
13 |
+
from dev.datasets.scalable_dataset import MultiDataModule
|
14 |
+
from dev.model.smart import SMART
|
15 |
+
|
16 |
+
|
17 |
+
def backup(source_dir, backup_dir):
|
18 |
+
"""
|
19 |
+
Back up the source directory (code and configs) to a backup directory.
|
20 |
+
"""
|
21 |
+
|
22 |
+
if os.path.exists(backup_dir):
|
23 |
+
return
|
24 |
+
os.makedirs(backup_dir, exist_ok=False)
|
25 |
+
|
26 |
+
# Helper function to check if a path matches exclude patterns
|
27 |
+
def should_exclude(path):
|
28 |
+
for pattern in exclude_patterns:
|
29 |
+
if fnmatch.fnmatch(os.path.basename(path), pattern):
|
30 |
+
return True
|
31 |
+
return False
|
32 |
+
|
33 |
+
# Iterate through the files and directories in source_dir
|
34 |
+
for root, dirs, files in os.walk(source_dir):
|
35 |
+
# Skip excluded directories
|
36 |
+
dirs[:] = [d for d in dirs if not should_exclude(d)]
|
37 |
+
|
38 |
+
# Determine the relative path and destination path
|
39 |
+
rel_path = os.path.relpath(root, source_dir)
|
40 |
+
dest_dir = os.path.join(backup_dir, rel_path)
|
41 |
+
os.makedirs(dest_dir, exist_ok=True)
|
42 |
+
|
43 |
+
# Copy all relevant files
|
44 |
+
for file in files:
|
45 |
+
if any(fnmatch.fnmatch(file, pattern) for pattern in include_patterns):
|
46 |
+
shutil.copy2(os.path.join(root, file), os.path.join(dest_dir, file))
|
47 |
+
|
48 |
+
logger.info(f"Backup completed. Files saved to: {backup_dir}")
|
49 |
+
|
50 |
+
|
51 |
+
if __name__ == '__main__':
|
52 |
+
pl.seed_everything(2024, workers=True)
|
53 |
+
torch.set_printoptions(precision=3)
|
54 |
+
|
55 |
+
parser = ArgumentParser()
|
56 |
+
parser.add_argument('--config', type=str, default='configs/ours_long_term.yaml')
|
57 |
+
parser.add_argument('--pretrain_ckpt', type=str, default=None,
|
58 |
+
help='Path to any pretrained model, will only load its parameters.'
|
59 |
+
)
|
60 |
+
parser.add_argument('--ckpt_path', type=str, default=None,
|
61 |
+
help='Path to any trained model, will load all the states.'
|
62 |
+
)
|
63 |
+
parser.add_argument('--save_ckpt_path', type=str, default='output/debug',
|
64 |
+
help='Path to save the checkpoints in training mode'
|
65 |
+
)
|
66 |
+
parser.add_argument('--save_path', type=str, default=None,
|
67 |
+
help='Path to save the inference results in validation and test mode.'
|
68 |
+
)
|
69 |
+
parser.add_argument('--wandb', action='store_true',
|
70 |
+
help='Whether to use wandb logger in training.'
|
71 |
+
)
|
72 |
+
parser.add_argument('--devices', type=int, default=1)
|
73 |
+
parser.add_argument('--train', action='store_true')
|
74 |
+
parser.add_argument('--validate', action='store_true')
|
75 |
+
parser.add_argument('--test', action='store_true')
|
76 |
+
parser.add_argument('--plot_rollouts', action='store_true')
|
77 |
+
args = parser.parse_args()
|
78 |
+
|
79 |
+
if not (args.train or args.validate or args.test or args.plot_rollouts):
|
80 |
+
raise RuntimeError(f"Got invalid action, should be one of ['train', 'validate', 'test', 'plot_rollouts']")
|
81 |
+
|
82 |
+
# ! setup logger
|
83 |
+
logger = RankedLogger(__name__, rank_zero_only=True)
|
84 |
+
|
85 |
+
# ! backup codes
|
86 |
+
exclude_patterns = ['*output*', '*logs', 'wandb', 'data', '*debug*', '*backup*', 'interact_*', '*edge_map*', '__pycache__']
|
87 |
+
include_patterns = ['*.py', '*.json', '*.yaml', '*.yml', '*.sh']
|
88 |
+
backup(os.getcwd(), os.path.join(args.save_ckpt_path, 'backups'))
|
89 |
+
|
90 |
+
config = load_config_act(args.config)
|
91 |
+
|
92 |
+
wandb_logger = None
|
93 |
+
if args.wandb and not int(os.getenv('DEBUG', 0)):
|
94 |
+
# squeue -O username,state,nodelist,gres,minmemory,numcpus,name
|
95 |
+
wandb_logger = WandbLogger(project='simagent')
|
96 |
+
|
97 |
+
trainer_config = config.Trainer
|
98 |
+
max_epochs = trainer_config.max_epochs
|
99 |
+
|
100 |
+
# ! setup datamodule and model
|
101 |
+
datamodule = MultiDataModule(**vars(config.Dataset), logger=logger)
|
102 |
+
model = SMART(config.Model, save_path=args.save_ckpt_path, logger=logger, max_epochs=max_epochs)
|
103 |
+
if args.pretrain_ckpt:
|
104 |
+
model.load_state_from_file(filename=args.pretrain_ckpt)
|
105 |
+
strategy = DDPStrategy(find_unused_parameters=True, gradient_as_bucket_view=True)
|
106 |
+
logger.info(f'Build model: {model.__class__.__name__} datamodule: {datamodule.__class__.__name__}')
|
107 |
+
|
108 |
+
# ! checkpoint configuration
|
109 |
+
every_n_epochs = 1
|
110 |
+
if int(os.getenv('OVERFIT', 0)):
|
111 |
+
max_epochs = trainer_config.overfit_epochs
|
112 |
+
every_n_epochs = 100
|
113 |
+
|
114 |
+
if int(os.getenv('CHECK_INPUTS', 0)):
|
115 |
+
max_epochs = 1
|
116 |
+
|
117 |
+
check_val_every_n_epoch = 1 # save checkpoints for each epoch
|
118 |
+
model_checkpoint = ModelCheckpoint(dirpath=args.save_ckpt_path,
|
119 |
+
filename='{epoch:02d}',
|
120 |
+
save_top_k=5,
|
121 |
+
monitor='epoch',
|
122 |
+
mode='max',
|
123 |
+
save_last=True,
|
124 |
+
every_n_train_steps=1000,
|
125 |
+
save_on_train_epoch_end=True)
|
126 |
+
|
127 |
+
# ! setup trainer
|
128 |
+
lr_monitor = LearningRateMonitor(logging_interval='epoch')
|
129 |
+
trainer = pl.Trainer(accelerator=trainer_config.accelerator, devices=args.devices if args.devices is not None else trainer_config.devices,
|
130 |
+
strategy=strategy, logger=wandb_logger,
|
131 |
+
accumulate_grad_batches=trainer_config.accumulate_grad_batches,
|
132 |
+
num_nodes=trainer_config.num_nodes,
|
133 |
+
callbacks=[model_checkpoint, lr_monitor],
|
134 |
+
max_epochs=max_epochs,
|
135 |
+
num_sanity_val_steps=0,
|
136 |
+
check_val_every_n_epoch=check_val_every_n_epoch,
|
137 |
+
log_every_n_steps=1,
|
138 |
+
gradient_clip_val=0.5)
|
139 |
+
logger.info(f'Build trainer: {trainer.__class__.__name__}')
|
140 |
+
|
141 |
+
# ! run
|
142 |
+
if args.train:
|
143 |
+
|
144 |
+
logger.info(f'Start training ...')
|
145 |
+
trainer.fit(model, datamodule, ckpt_path=args.ckpt_path)
|
146 |
+
|
147 |
+
# NOTE: here both validation and test process use validation split data
|
148 |
+
# for validation, we enable the online metric calculation with results dumping
|
149 |
+
# for test, we disable it and only dump the inference results.
|
150 |
+
else:
|
151 |
+
|
152 |
+
if args.save_path is not None:
|
153 |
+
save_path = args.save_path
|
154 |
+
else:
|
155 |
+
assert args.ckpt_path is not None and os.path.exists(args.ckpt_path), \
|
156 |
+
f'Path {args.ckpt_path} not exists!'
|
157 |
+
save_path = os.path.join(os.path.dirname(args.ckpt_path), 'validation')
|
158 |
+
os.makedirs(save_path, exist_ok=True)
|
159 |
+
CONSOLE.log(f'Results will be saved to [yellow]{save_path}[/]')
|
160 |
+
|
161 |
+
model.save_path = save_path
|
162 |
+
|
163 |
+
if not args.ckpt_path:
|
164 |
+
CONSOLE.log(f'[yellow] Warning: no checkpoint will be loaded in validation! [/]')
|
165 |
+
|
166 |
+
if args.validate:
|
167 |
+
|
168 |
+
CONSOLE.log('[on blue] Start validating ... [/]')
|
169 |
+
model.set(mode='validation')
|
170 |
+
|
171 |
+
elif args.test:
|
172 |
+
|
173 |
+
CONSOLE.log('[on blue] Sart testing ... [/]')
|
174 |
+
model.set(mode='test')
|
175 |
+
|
176 |
+
elif args.plot_rollouts:
|
177 |
+
|
178 |
+
CONSOLE.log('[on blue] Sart generating ... [/]')
|
179 |
+
model.set(mode='plot_rollouts')
|
180 |
+
|
181 |
+
trainer.validate(model, datamodule, ckpt_path=args.ckpt_path)
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/scripts/aggregate_log_metric_features.sh
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
export TF_CPP_MIN_LOG_LEVEL='2'
|
4 |
+
export PYTHONPATH='.'
|
5 |
+
|
6 |
+
# dump all features
|
7 |
+
echo 'Start dump all log features ...'
|
8 |
+
python dev/metrics/compute_metrics.py --dump_log --no_batch
|
9 |
+
|
10 |
+
sleep 20
|
11 |
+
|
12 |
+
# aggregate features
|
13 |
+
echo 'Start aggregate log features ...'
|
14 |
+
python dev/metrics/compute_metrics.py --aggregate_log
|
15 |
+
|
16 |
+
echo 'Done!
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/scripts/c128.sh
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
#SBATCH --job-name c128 # Job name
|
4 |
+
### Logging
|
5 |
+
#SBATCH --output=%j.out # Stdout (%j expands to jobId)
|
6 |
+
#SBATCH --error=%j.err # Stderr (%j expands to jobId)
|
7 |
+
### Node info
|
8 |
+
#SBATCH --nodes=1 # Single node or multi node
|
9 |
+
#SBATCH --time 100:00:00 # Max time (hh:mm:ss)
|
10 |
+
#SBATCH --gres=gpu:0 # GPUs per node
|
11 |
+
#SBATCH --mem=128G # Recommend 32G per GPU
|
12 |
+
#SBATCH --ntasks-per-node=1 # Tasks per node
|
13 |
+
#SBATCH --cpus-per-task=64 # Recommend 8 per GPU
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/scripts/c64.sh
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
#SBATCH --job-name c128 # Job name
|
4 |
+
### Logging
|
5 |
+
#SBATCH --output=%j.out # Stdout (%j expands to jobId)
|
6 |
+
#SBATCH --error=%j.err # Stderr (%j expands to jobId)
|
7 |
+
### Node info
|
8 |
+
#SBATCH --nodes=1 # Single node or multi node
|
9 |
+
#SBATCH --time 100:00:00 # Max time (hh:mm:ss)
|
10 |
+
#SBATCH --gres=gpu:0 # GPUs per node
|
11 |
+
#SBATCH --mem=128G # Recommend 32G per GPU
|
12 |
+
#SBATCH --ntasks-per-node=1 # Tasks per node
|
13 |
+
#SBATCH --cpus-per-task=64 # Recommend 8 per GPU
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/scripts/compute_metrics.sh
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
export TORCH_LOGS='0'
|
4 |
+
export TF_CPP_MIN_LOG_LEVEL='2'
|
5 |
+
export PYTHONPATH='.'
|
6 |
+
|
7 |
+
NUM_WORKERS=$1
|
8 |
+
SIM_DIR=$2
|
9 |
+
|
10 |
+
echo 'Start running ...'
|
11 |
+
python dev/metrics/compute_metrics.py --compute_metric --num_workers "$NUM_WORKERS" --sim_dir "$SIM_DIR" ${@:3}
|
12 |
+
|
13 |
+
echo 'Done!
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/scripts/data_preprocess.sh
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# env
|
4 |
+
source ~/anaconda3/etc/profile.d/conda.sh
|
5 |
+
conda config --append envs_dirs ~/.conda/envs
|
6 |
+
conda activate traj
|
7 |
+
|
8 |
+
echo "Starting running..."
|
9 |
+
|
10 |
+
# multi-GPU training
|
11 |
+
cd ~/work/dev6/thirdparty/dev4
|
12 |
+
PYTHONPATH='..':$PYTHONPATH python3 data_preprocess.py --split validation
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/scripts/data_preprocess_loop.sh
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
RED='\033[0;31m'
|
4 |
+
NC='\033[0m'
|
5 |
+
|
6 |
+
cd /u/xiuyu/work/dev4/
|
7 |
+
|
8 |
+
trap "echo -e \"${RED}Stopping script...${NC}\"; kill -- -$$" SIGINT
|
9 |
+
|
10 |
+
while true; do
|
11 |
+
echo -e "${RED}Start running ...${NC}"
|
12 |
+
PYTHONPATH='.':$PYTHONPATH setsid python data_preprocess.py --split training &
|
13 |
+
PID=$!
|
14 |
+
|
15 |
+
sleep 1200
|
16 |
+
|
17 |
+
echo -e "${RED}Sending SIGINT to process group $PID...${NC}"
|
18 |
+
PGID=$(ps -o pgid= -p $PID | tail -n 1 | tr -d ' ')
|
19 |
+
kill -- -$PGID
|
20 |
+
wait $PID
|
21 |
+
|
22 |
+
sleep 10
|
23 |
+
done
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/scripts/debug.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
import torch.optim as optim
|
4 |
+
|
5 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
6 |
+
|
7 |
+
embedding = nn.Embedding(180, 128).to(device)
|
8 |
+
gt = torch.randint(0, 2, (180, 2048)).to(device)
|
9 |
+
head = nn.Linear(128, 2048).to(device)
|
10 |
+
optimizer = optim.Adam([embedding.weight, head.weight])
|
11 |
+
|
12 |
+
while True:
|
13 |
+
pred = head(embedding.weight).sigmoid()
|
14 |
+
loss = nn.MSELoss()(pred, gt.float())
|
15 |
+
optimizer.zero_grad()
|
16 |
+
loss.backward()
|
17 |
+
optimizer.step()
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/scripts/debug_map.py
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pickle
|
3 |
+
import torch
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import numpy as np
|
6 |
+
from tqdm import tqdm
|
7 |
+
from argparse import ArgumentParser
|
8 |
+
from dev.datasets.preprocess import TokenProcessor
|
9 |
+
from dev.transforms.target_builder import WaymoTargetBuilder
|
10 |
+
|
11 |
+
|
12 |
+
colors = [
|
13 |
+
('#1f77b4', '#1a5a8a'), # blue
|
14 |
+
('#2ca02c', '#217721'), # green
|
15 |
+
('#ff7f0e', '#cc660b'), # orange
|
16 |
+
('#9467bd', '#6f4a91'), # purple
|
17 |
+
('#d62728', '#a31d1d'), # red
|
18 |
+
('#000000', '#000000'), # black
|
19 |
+
]
|
20 |
+
|
21 |
+
|
22 |
+
def draw_map(tokenize_data, token_processor: TokenProcessor, index, posfix):
|
23 |
+
print("Drawing raw data ...")
|
24 |
+
shift = 5
|
25 |
+
token_size = 2048
|
26 |
+
|
27 |
+
traj_token = token_processor.trajectory_token["veh"]
|
28 |
+
traj_token_all = token_processor.trajectory_token_all["veh"]
|
29 |
+
|
30 |
+
plt.subplots_adjust(left=0.3, right=0.7, top=0.7, bottom=0.3)
|
31 |
+
fig, ax = plt.subplots()
|
32 |
+
ax.set_axis_off()
|
33 |
+
|
34 |
+
scenario_id = data['scenario_id']
|
35 |
+
ax.scatter(tokenize_data["map_point"]["position"][:, 0],
|
36 |
+
tokenize_data["map_point"]["position"][:, 1], s=0.2, c='black', edgecolors='none')
|
37 |
+
|
38 |
+
index = np.array(index).astype(np.int32)
|
39 |
+
agent_data = tokenize_data["agent"]
|
40 |
+
token_index = agent_data["token_idx"][index]
|
41 |
+
token_valid_mask = agent_data["agent_valid_mask"][index]
|
42 |
+
|
43 |
+
num_agent, num_token = token_index.shape
|
44 |
+
tokens = traj_token[token_index.view(-1)].reshape(num_agent, num_token, 4, 2)
|
45 |
+
tokens_all = traj_token_all[token_index.view(-1)].reshape(num_agent, num_token, 6, 4, 2)
|
46 |
+
|
47 |
+
position = agent_data['position'][index, :, :2] # (num_agent, 91, 2)
|
48 |
+
heading = agent_data['heading'][index] # (num_agent, 91)
|
49 |
+
valid_mask = (position[..., 0] != 0) & (position[..., 1] != 0) # (num_agent, 91)
|
50 |
+
# TODO: fix this
|
51 |
+
if args.smart:
|
52 |
+
for shifted_tid in range(token_valid_mask.shape[1]):
|
53 |
+
valid_mask[:, shifted_tid * shift : (shifted_tid + 1) * shift] = token_valid_mask[:, shifted_tid : shifted_tid + 1].repeat(1, shift)
|
54 |
+
else:
|
55 |
+
for shifted_tid in range(token_index.shape[1]):
|
56 |
+
valid_mask[:, shifted_tid * shift : (shifted_tid + 1) * shift] = token_index[:, shifted_tid : shifted_tid + 1] != token_size + 2
|
57 |
+
last_valid_step = valid_mask.shape[1] - 1 - torch.argmax(valid_mask.flip(dims=[1]).long(), dim=1)
|
58 |
+
last_valid_step = {int(index[i]): int(last_valid_step[i]) for i in range(len(index))}
|
59 |
+
|
60 |
+
_, token_num, token_contour_dim, feat_dim = tokens.shape
|
61 |
+
tokens_src = tokens.reshape(num_agent, token_num * token_contour_dim, feat_dim)
|
62 |
+
tokens_all_src = tokens_all.reshape(num_agent, token_num * 6 * token_contour_dim, feat_dim)
|
63 |
+
prev_heading = heading[:, 0]
|
64 |
+
prev_pos = position[:, 0]
|
65 |
+
|
66 |
+
fig_paths = []
|
67 |
+
agent_colors = np.zeros((num_agent, position.shape[1]))
|
68 |
+
shape = np.zeros((num_agent, position.shape[1], 2)) + 3.
|
69 |
+
for tid in tqdm(range(shift, position.shape[1], shift), leave=False, desc="Token ..."):
|
70 |
+
cos, sin = prev_heading.cos(), prev_heading.sin()
|
71 |
+
rot_mat = prev_heading.new_zeros(num_agent, 2, 2)
|
72 |
+
rot_mat[:, 0, 0] = cos
|
73 |
+
rot_mat[:, 0, 1] = sin
|
74 |
+
rot_mat[:, 1, 0] = -sin
|
75 |
+
rot_mat[:, 1, 1] = cos
|
76 |
+
tokens_world = torch.bmm(torch.from_numpy(tokens_src).float(), rot_mat).reshape(num_agent,
|
77 |
+
token_num,
|
78 |
+
token_contour_dim,
|
79 |
+
feat_dim)
|
80 |
+
tokens_all_world = torch.bmm(torch.from_numpy(tokens_all_src).float(), rot_mat).reshape(num_agent,
|
81 |
+
token_num,
|
82 |
+
6,
|
83 |
+
token_contour_dim,
|
84 |
+
feat_dim)
|
85 |
+
tokens_world += prev_pos[:, None, None, :2]
|
86 |
+
tokens_all_world += prev_pos[:, None, None, None, :2]
|
87 |
+
tokens_select = tokens_world[:, tid // shift - 1] # (num_agent, token_contour_dim, feat_dim)
|
88 |
+
tokens_all_select = tokens_all_world[:, tid // shift - 1] # (num_agent, 6, token_contour_dim, feat_dim)
|
89 |
+
|
90 |
+
diff_xy = tokens_select[:, 0, :] - tokens_select[:, 3, :]
|
91 |
+
prev_heading = heading[:, tid].clone()
|
92 |
+
# prev_heading[valid_mask[:, tid - shift]] = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])[
|
93 |
+
# valid_mask[:, tid - shift]]
|
94 |
+
prev_pos = position[:, tid].clone()
|
95 |
+
# prev_pos[valid_mask[:, tid - shift]] = tokens_select.mean(dim=1)[valid_mask[:, tid - shift]]
|
96 |
+
|
97 |
+
# NOTE tokens_pos equals to tokens_all_pos[:, -1]
|
98 |
+
tokens_pos = tokens_select.mean(dim=1) # (num_agent, 2)
|
99 |
+
tokens_all_pos = tokens_all_select.mean(dim=2) # (num_agent, 6, 2)
|
100 |
+
|
101 |
+
# colors
|
102 |
+
cur_token_index = token_index[:, tid // shift - 1]
|
103 |
+
is_bos = cur_token_index == token_size
|
104 |
+
is_eos = cur_token_index == token_size + 1
|
105 |
+
is_invalid = cur_token_index == token_size + 2
|
106 |
+
is_valid = ~is_bos & ~is_eos & ~is_invalid
|
107 |
+
agent_colors[is_valid, tid - shift : tid] = 1
|
108 |
+
agent_colors[is_bos, tid - shift : tid] = 2
|
109 |
+
agent_colors[is_eos, tid - shift : tid] = 3
|
110 |
+
agent_colors[is_invalid, tid - shift : tid] = 4
|
111 |
+
|
112 |
+
for i in tqdm(range(shift), leave=False, desc="Timestep ..."):
|
113 |
+
global_tid = tid - shift + i
|
114 |
+
cur_valid_mask = valid_mask[:, tid - shift] # only when the last tokenized timestep is valid the current shifts trajectory is valid
|
115 |
+
xs = tokens_all_pos[cur_valid_mask, i, 0]
|
116 |
+
ys = tokens_all_pos[cur_valid_mask, i, 1]
|
117 |
+
widths = shape[cur_valid_mask, global_tid, 1]
|
118 |
+
lengths = shape[cur_valid_mask, global_tid, 0]
|
119 |
+
angles = heading[cur_valid_mask, global_tid]
|
120 |
+
cur_agent_colors = agent_colors[cur_valid_mask, global_tid]
|
121 |
+
current_index = index[cur_valid_mask]
|
122 |
+
|
123 |
+
drawn_agents = []
|
124 |
+
drawn_texts = []
|
125 |
+
for x, y, width, length, angle, color_type, id in zip(
|
126 |
+
xs, ys, widths, lengths, angles, cur_agent_colors, current_index):
|
127 |
+
if x < 3000: continue
|
128 |
+
agent = plt.Rectangle((x, y), width, length, # angle=((angle + np.pi / 2) / np.pi * 360) % 360,
|
129 |
+
linewidth=0.2,
|
130 |
+
facecolor=colors[int(color_type) - 1][0],
|
131 |
+
edgecolor=colors[int(color_type) - 1][1])
|
132 |
+
ax.add_patch(agent)
|
133 |
+
text = plt.text(x-4, y-4, f"{str(id)}:{str(global_tid)}", fontdict={'family': 'serif', 'size': 3, 'color': 'red'})
|
134 |
+
|
135 |
+
if global_tid != last_valid_step[id]:
|
136 |
+
drawn_agents.append(agent)
|
137 |
+
drawn_texts.append(text)
|
138 |
+
|
139 |
+
# draw timestep to be tokenized
|
140 |
+
if global_tid % shift == 0:
|
141 |
+
tokenize_agent = plt.Rectangle((x, y), width, length, # angle=((angle + np.pi / 2) / np.pi * 360) % 360,
|
142 |
+
linewidth=0.2, fill=False,
|
143 |
+
edgecolor=colors[int(color_type) - 1][1])
|
144 |
+
ax.add_patch(tokenize_agent)
|
145 |
+
|
146 |
+
plt.gca().set_aspect('equal', adjustable='box')
|
147 |
+
|
148 |
+
fig_path = f"debug/tokenize/steps/{scenario_id}_{global_tid}.png"
|
149 |
+
plt.savefig(fig_path, dpi=600, bbox_inches="tight")
|
150 |
+
fig_paths.append(fig_path)
|
151 |
+
|
152 |
+
for drawn_agent, drawn_text in zip(drawn_agents, drawn_texts):
|
153 |
+
drawn_agent.remove()
|
154 |
+
drawn_text.remove()
|
155 |
+
|
156 |
+
plt.close()
|
157 |
+
|
158 |
+
# generate gif
|
159 |
+
import imageio.v2 as imageio
|
160 |
+
images = []
|
161 |
+
for fig_path in tqdm(fig_paths, leave=False, desc="Generate gif ..."):
|
162 |
+
images.append(imageio.imread(fig_path))
|
163 |
+
imageio.mimsave(f"debug/tokenize/{scenario_id}_tokenize_{posfix}.gif", images, duration=0.1)
|
164 |
+
|
165 |
+
|
166 |
+
def main(data):
|
167 |
+
|
168 |
+
token_size = 2048
|
169 |
+
|
170 |
+
os.makedirs("debug/tokenize/steps/", exist_ok=True)
|
171 |
+
scenario_id = data["scenario_id"]
|
172 |
+
|
173 |
+
selected_agents_index = [1, 21, 35, 36, 46]
|
174 |
+
|
175 |
+
# raw data
|
176 |
+
if not os.path.exists(f"debug/tokenize/{scenario_id}_raw.gif"):
|
177 |
+
draw_raw(data, selected_agents_index)
|
178 |
+
|
179 |
+
# tokenization
|
180 |
+
token_processor = TokenProcessor(token_size, disable_invalid=args.smart)
|
181 |
+
print(f"Loaded token processor with token_size: {token_size}")
|
182 |
+
data = token_processor.preprocess(data)
|
183 |
+
|
184 |
+
# tokenzied data
|
185 |
+
posfix = "smart" if args.smart else "ours"
|
186 |
+
# if not os.path.exists(f"debug/tokenize/{scenario_id}_tokenize_{posfix}.gif"):
|
187 |
+
draw_tokenize(data, token_processor, selected_agents_index, posfix)
|
188 |
+
|
189 |
+
target_builder = WaymoTargetBuilder(num_historical_steps=11, num_future_steps=80)
|
190 |
+
data = target_builder(data)
|
191 |
+
|
192 |
+
|
193 |
+
if __name__ == "__main__":
|
194 |
+
parser = ArgumentParser(description="Testing script parameters")
|
195 |
+
parser.add_argument("--smart", action="store_true")
|
196 |
+
parser.add_argument("--data_path", type=str, default="/u/xiuyu/work/dev4/data/waymo_processed/training")
|
197 |
+
args = parser.parse_args()
|
198 |
+
|
199 |
+
scenario_id = "74ad7b76d5906d39"
|
200 |
+
data_path = os.path.join(args.data_path, f"{scenario_id}.pkl")
|
201 |
+
data = pickle.load(open(data_path, "rb"))
|
202 |
+
print(f"Loaded scenario {scenario_id}")
|
203 |
+
|
204 |
+
main(data)
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/scripts/g2.sh
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
#SBATCH --job-name g2 # Job name
|
4 |
+
### Logging
|
5 |
+
#SBATCH --output=%j.out # Stdout (%j expands to jobId)
|
6 |
+
#SBATCH --error=%j.err # Stderr (%j expands to jobId)
|
7 |
+
### Node info
|
8 |
+
#SBATCH --nodes=1 # Single node or multi node
|
9 |
+
#SBATCH --nodelist=sota-2
|
10 |
+
#SBATCH --time 24:00:00 # Max time (hh:mm:ss)
|
11 |
+
#SBATCH --gres=gpu:2 # GPUs per node
|
12 |
+
#SBATCH --mem=96G # Recommend 32G per GPU
|
13 |
+
#SBATCH --ntasks-per-node=1 # Tasks per node
|
14 |
+
#SBATCH --cpus-per-task=16 # Recommend 8 per GPU
|
15 |
+
|
16 |
+
export NCCL_DEBUG=INFO
|
17 |
+
export REQUESTS_CA_BUNDLE="/etc/ssl/certs/ca-certificates.crt"
|
18 |
+
export HTTPS_PROXY="https://192.168.0.10:443/"
|
19 |
+
export https_proxy="https://192.168.0.10:443/"
|
20 |
+
|
21 |
+
export TEST_VAL_TRAIN=False
|
22 |
+
export TEST_VAL_PRED=True
|
23 |
+
export WANDB=True
|
24 |
+
|
25 |
+
sleep 86400
|
26 |
+
|
27 |
+
cd /u/xiuyu/work/dev4
|
28 |
+
PYTHONPATH=".":$PYTHONPATH python3 train.py \
|
29 |
+
--devices 2 \
|
30 |
+
--config configs/train/train_scalable_with_state.yaml \
|
31 |
+
--save_ckpt_path output/seed_1k_pure_seed_150_3_emb_head_3_debug \
|
32 |
+
--pretrain_ckpt output/ours_map_pretrain/epoch=31.ckpt
|
33 |
+
|
34 |
+
PYTHONPATH=".":$PYTHONPATH python val.py \
|
35 |
+
--config configs/validation/val_scalable_with_state.yaml \
|
36 |
+
--save_path output/seed_debug \
|
37 |
+
--pretrain_ckpt output/seed_1k_pure_seed_150_3_emb_head_3/last.ckpt
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/scripts/g4.sh
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
#SBATCH --job-name g4 # Job name
|
4 |
+
### Logging
|
5 |
+
#SBATCH --output=%j.out # Stdout (%j expands to jobId)
|
6 |
+
#SBATCH --error=%j.err # Stderr (%j expands to jobId)
|
7 |
+
### Node info
|
8 |
+
#SBATCH --nodes=1 # Single node or multi node
|
9 |
+
#SBATCH --nodelist=sota-1
|
10 |
+
#SBATCH --time 72:00:00 # Max time (hh:mm:ss)
|
11 |
+
#SBATCH --gres=gpu:4 # GPUs per node
|
12 |
+
#SBATCH --mem=128G # Recommend 32G per GPU
|
13 |
+
#SBATCH --ntasks-per-node=1 # Tasks per node
|
14 |
+
#SBATCH --cpus-per-task=32 # Recommend 8 per GPU
|
15 |
+
|
16 |
+
export REQUESTS_CA_BUNDLE="/etc/ssl/certs/ca-certificates.crt"
|
17 |
+
export HTTPS_PROXY="https://192.168.0.10:443/"
|
18 |
+
export https_proxy="https://192.168.0.10:443/"
|
19 |
+
|
20 |
+
export TEST_VAL_TRAIN=0
|
21 |
+
export TEST_VAL_PRED=1
|
22 |
+
export WANDB=1
|
23 |
+
|
24 |
+
sleep 604800
|
25 |
+
|
26 |
+
cd /u/xiuyu/work/dev4
|
27 |
+
PYTHONPATH=".":$PYTHONPATH python3 train.py \
|
28 |
+
--devices 4 \
|
29 |
+
--config configs/train/train_scalable_with_state.yaml \
|
30 |
+
--save_ckpt_path output/seq_1k_10_150_3_3_encode_occ_separate_offsets \
|
31 |
+
--pretrain_ckpt output/pretrain_scalable_map/epoch=31.ckpt
|
32 |
+
|
33 |
+
PYTHONPATH=".":$PYTHONPATH python val.py \
|
34 |
+
--config configs/ours_long_term.yaml \
|
35 |
+
--ckpt_path output/seq_5k_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long/epoch=31.ckpt
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/scripts/g8.sh
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
#SBATCH --job-name g8 # Job name
|
4 |
+
### Logging
|
5 |
+
#SBATCH --output=%j.out # Stdout (%j expands to jobId)
|
6 |
+
#SBATCH --error=%j.err # Stderr (%j expands to jobId)
|
7 |
+
### Node info
|
8 |
+
#SBATCH --nodes=1 # Single node or multi node
|
9 |
+
#SBATCH --nodelist=sota-6
|
10 |
+
#SBATCH --time 120:00:00 # Max time (hh:mm:ss)
|
11 |
+
#SBATCH --gres=gpu:8 # GPUs per node
|
12 |
+
#SBATCH --mem=256G # Recommend 32G per GPU
|
13 |
+
#SBATCH --ntasks-per-node=1 # Tasks per node
|
14 |
+
#SBATCH --cpus-per-task=32 # Recommend 8 per GPU
|
15 |
+
|
16 |
+
export REQUESTS_CA_BUNDLE="/etc/ssl/certs/ca-certificates.crt"
|
17 |
+
export HTTPS_PROXY="https://192.168.0.10:443/"
|
18 |
+
export https_proxy="https://192.168.0.10:443/"
|
19 |
+
|
20 |
+
export TEST_VAL_TRAIN=0
|
21 |
+
export TEST_VAL_PRED=1
|
22 |
+
export WANDB=1
|
23 |
+
|
24 |
+
sleep 864000
|
25 |
+
|
26 |
+
cd /u/xiuyu/work/dev4
|
27 |
+
PYTHONPATH=".":$PYTHONPATH python3 train.py \
|
28 |
+
--devices 8 \
|
29 |
+
--config configs/train/train_scalable_long_term.yaml \
|
30 |
+
--save_ckpt_path output/seq_5k_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long \
|
31 |
+
--pretrain_ckpt output/pretrain_scalable_map/epoch=31.ckpt
|
32 |
+
|
33 |
+
PYTHONPATH=".":$PYTHONPATH python3 train.py \
|
34 |
+
--devices 8 \
|
35 |
+
--config configs/ours_long_term.yaml \
|
36 |
+
--save_ckpt_path output2/seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid
|
37 |
+
|
38 |
+
PYTHONPATH=".":$PYTHONPATH python3 train.py \
|
39 |
+
--config configs/ours_long_term.yaml \
|
40 |
+
--save_ckpt_path output2/debug
|
41 |
+
|
42 |
+
PYTHONPATH=".":$PYTHONPATH python val.py \
|
43 |
+
--config configs/ours_long_term.yaml \
|
44 |
+
--ckpt_path output2/bug_seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long/last.ckpt
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/scripts/hf_model.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
from huggingface_hub import upload_folder, upload_file, hf_hub_download
|
4 |
+
from rich.console import Console
|
5 |
+
from rich.panel import Panel
|
6 |
+
from rich import box, style
|
7 |
+
from rich.table import Table
|
8 |
+
|
9 |
+
CONSOLE = Console(width=120)
|
10 |
+
|
11 |
+
|
12 |
+
def upload():
|
13 |
+
|
14 |
+
if args.folder_path:
|
15 |
+
|
16 |
+
try:
|
17 |
+
if token is not None:
|
18 |
+
upload_folder(repo_id=args.repo_id, folder_path=args.folder_path, ignore_patterns=ignore_patterns, token=token)
|
19 |
+
else:
|
20 |
+
upload_folder(repo_id=args.repo_id, folder_path=args.folder_path, ignore_patterns=ignore_patterns)
|
21 |
+
table = Table(title=None, show_header=False, box=box.MINIMAL, title_style=style.Style(bold=True))
|
22 |
+
table.add_row(f"Model id {args.repo_id}", str(args.folder_path))
|
23 |
+
CONSOLE.print(Panel(table, title="[bold][green]:tada: Upload completed DO NOT forget specify the model id in methods! :tada:[/bold]", expand=False))
|
24 |
+
|
25 |
+
except Exception as e:
|
26 |
+
CONSOLE.print(f"[bold][yellow]:tada: Upload failed due to {e}.")
|
27 |
+
raise e
|
28 |
+
|
29 |
+
if args.file_path:
|
30 |
+
|
31 |
+
try:
|
32 |
+
if token is not None:
|
33 |
+
upload_file(
|
34 |
+
path_or_fileobj=args.file_path,
|
35 |
+
path_in_repo=os.path.basename(args.file_path),
|
36 |
+
repo_id=args.repo_id,
|
37 |
+
repo_type='model',
|
38 |
+
token=token
|
39 |
+
)
|
40 |
+
else:
|
41 |
+
upload_file(
|
42 |
+
path_or_fileobj=args.file_path,
|
43 |
+
path_in_repo=os.path.basename(args.file_path),
|
44 |
+
repo_id=args.repo_id,
|
45 |
+
repo_type='model',
|
46 |
+
)
|
47 |
+
table = Table(title=None, show_header=False, box=box.MINIMAL, title_style=style.Style(bold=True))
|
48 |
+
table.add_row(f"Model id {args.repo_id}", str(args.file_path))
|
49 |
+
CONSOLE.print(Panel(table, title="[bold][green]:tada: Upload completed! :tada:[/bold]", expand=False))
|
50 |
+
|
51 |
+
except Exception as e:
|
52 |
+
CONSOLE.print(f"[bold][yellow]:tada: Upload failed due to {e}.")
|
53 |
+
raise e
|
54 |
+
|
55 |
+
|
56 |
+
def download():
|
57 |
+
|
58 |
+
try:
|
59 |
+
if token is not None:
|
60 |
+
ckpt_path = hf_hub_download(
|
61 |
+
repo_id=args.repo_id,
|
62 |
+
filename=args.file_path,
|
63 |
+
token=token
|
64 |
+
)
|
65 |
+
else:
|
66 |
+
ckpt_path = hf_hub_download(
|
67 |
+
repo_id=args.repo_id,
|
68 |
+
filename=args.file_path,
|
69 |
+
)
|
70 |
+
table = Table(title=None, show_header=False, box=box.MINIMAL, title_style=style.Style(bold=True))
|
71 |
+
table.add_row(f"Model id {args.repo_id}", str(args.file_path))
|
72 |
+
CONSOLE.print(Panel(table, title=f"[bold][green]:tada: Download completed to {ckpt_path}! :tada:[/bold]", expand=False))
|
73 |
+
|
74 |
+
if args.save_path is not None:
|
75 |
+
os.makedirs(args.save_path, exist_ok=True)
|
76 |
+
import shutil
|
77 |
+
shutil.copy(ckpt_path, os.path.join(args.save_path, args.file_path))
|
78 |
+
|
79 |
+
except Exception as e:
|
80 |
+
CONSOLE.print(f"[bold][yellow]:tada: Download failed due to {e}.")
|
81 |
+
raise e
|
82 |
+
|
83 |
+
return ckpt_path
|
84 |
+
|
85 |
+
|
86 |
+
if __name__ == "__main__":
|
87 |
+
parser = argparse.ArgumentParser()
|
88 |
+
parser.add_argument("--repo_id", type=str, default=None, required=True)
|
89 |
+
parser.add_argument("--upload", action="store_true")
|
90 |
+
parser.add_argument("--download", action="store_true")
|
91 |
+
parser.add_argument("--folder_path", type=str, default=None, required=False)
|
92 |
+
parser.add_argument("--file_path", type=str, default=None, required=False)
|
93 |
+
parser.add_argument("--save_path", type=str, default=None, required=False)
|
94 |
+
parser.add_argument("--token", type=str, default=None, required=False)
|
95 |
+
args = parser.parse_args()
|
96 |
+
|
97 |
+
token = args.token or os.getenv("hf_token", None)
|
98 |
+
ignore_patterns = ["**/optimizer.bin", "**/random_states*", "**/scaler.pt", "**/scheduler.bin"]
|
99 |
+
|
100 |
+
if not (args.folder_path or args.file_path):
|
101 |
+
raise RuntimeError(f'Choose either folder path or file path please!')
|
102 |
+
|
103 |
+
if len(args.repo_id.split('/')) != 2:
|
104 |
+
raise RuntimeError(f'Invalid repo_id: {args.repo_id}, please use in [use-id]/[repo-name] format')
|
105 |
+
CONSOLE.log(f"Use repo: [bold][yellow] {args.repo_id}")
|
106 |
+
|
107 |
+
if args.upload:
|
108 |
+
upload()
|
109 |
+
|
110 |
+
if args.download:
|
111 |
+
download()
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/scripts/pretrain_map.sh
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
mkdir -p job_out
|
4 |
+
|
5 |
+
#SBATCH --job-name YOUR_JOB_NAME # Job name
|
6 |
+
### Logging
|
7 |
+
#SBATCH --output=job_out/%j.out # Stdout (%j expands to jobId)
|
8 |
+
#SBATCH --error=job_out/%j.err # Stderr (%j expands to jobId)
|
9 |
+
### Node info
|
10 |
+
#SBATCH --nodes=1 # Single node or multi node
|
11 |
+
#SBATCH --nodelist=sota-6
|
12 |
+
#SBATCH --time 20:00:00 # Max time (hh:mm:ss)
|
13 |
+
#SBATCH --gres=gpu:4 # GPUs per node
|
14 |
+
#SBATCH --mem=256G # Recommend 32G per GPU
|
15 |
+
#SBATCH --ntasks-per-node=4 # Tasks per node
|
16 |
+
#SBATCH --cpus-per-task=256 # Recommend 8 per GPU
|
17 |
+
### Whatever your job needs to do
|
18 |
+
|
19 |
+
export REQUESTS_CA_BUNDLE="/etc/ssl/certs/ca-certificates.crt"
|
20 |
+
export HTTPS_PROXY="https://192.168.0.10:443/"
|
21 |
+
export https_proxy="https://192.168.0.10:443/"
|
22 |
+
|
23 |
+
export TEST_VAL_PRED=True
|
24 |
+
export WANDB=True
|
25 |
+
|
26 |
+
cd /u/xiuyu/work/dev4
|
27 |
+
PYTHONPATH=".":$PYTHONPATH python3 train.py --config configs/train/pretrain_scalable_map.yaml --save_ckpt_path output/ours_map_pretrain
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/scripts/run_eval.sh
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#! /bin/bash
|
2 |
+
|
3 |
+
# env
|
4 |
+
export REQUESTS_CA_BUNDLE="/etc/ssl/certs/ca-certificates.crt"
|
5 |
+
export HTTPS_PROXY="https://192.168.0.10:443/"
|
6 |
+
export https_proxy="https://192.168.0.10:443/"
|
7 |
+
|
8 |
+
export WANDB=1
|
9 |
+
|
10 |
+
# args
|
11 |
+
DEVICES=$1
|
12 |
+
CONFIG='configs/ours_long_term.yaml'
|
13 |
+
# CKPT_PATH='output/scalable_smart_long/last.ckpt'
|
14 |
+
CKPT_PATH='output2/seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long/last.ckpt'
|
15 |
+
|
16 |
+
# run
|
17 |
+
PYTHONPATH=".":$PYTHONPATH python3 run.py \
|
18 |
+
--devices $DEVICES \
|
19 |
+
--config $CONFIG \
|
20 |
+
--ckpt_path $CKPT_PATH ${@:2}
|
seq_10_150_3_3_encode_occ_separate_offsets_bs8_128_no_seqindex_long_lastvalid_ablation_grid/backups/scripts/run_train.sh
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#! /bin/bash
|
2 |
+
|
3 |
+
# env
|
4 |
+
export REQUESTS_CA_BUNDLE="/etc/ssl/certs/ca-certificates.crt"
|
5 |
+
export HTTPS_PROXY="https://192.168.0.10:443/"
|
6 |
+
export https_proxy="https://192.168.0.10:443/"
|
7 |
+
|
8 |
+
export WANDB=1
|
9 |
+
|
10 |
+
# args
|
11 |
+
DEVICES=$1
|
12 |
+
CONFIG='configs/ours_long_term.yaml'
|
13 |
+
SAVE_CKPT_PATH='output/scalable_smart_long'
|
14 |
+
|
15 |
+
# run
|
16 |
+
PYTHONPATH=".":$PYTHONPATH python3 run.py \
|
17 |
+
--train \
|
18 |
+
--devices $DEVICES \
|
19 |
+
--config $CONFIG \
|
20 |
+
--save_ckpt_path $SAVE_CKPT_PATH
|