Add files using upload-large-folder tool
- configs/base.yaml +71 -0
- configs/calvin_rel_traj_location_bounds_task_ABC_D.json +50 -0
- configs/dataset_control_freq.json +73 -0
- configs/dataset_img_keys.json +674 -0
- configs/dataset_stat.json +0 -0
- configs/finetune_datasets.json +5 -0
- configs/finetune_sample_weights.json +5 -0
- configs/pretrain_datasets.json +3 -0
- configs/pretrain_sample_weights.json +3 -0
- configs/state_vec.py +114 -0
- configs/zero2.json +14 -0
- data/aloha/hdf5totfrecords.py +98 -0
- data/aloha/unzip_data.sh +3 -0
- data/bridgev2/bridgedata_numpy_to_tfrecord.py +174 -0
- data/bridgev2/bridgedata_raw_to_numpy.py +316 -0
- data/bridgev2/download.sh +13 -0
- data/calvin/download.sh +19 -0
- data/calvin/hdf5totfrecords.py +92 -0
- data/rh20t/hdf5totfrecords.py +200 -0
- data/roboset/download.py +42 -0
- data/roboset/download.sh +21 -0
- data/roboset/h5totfrecords.py +82 -0
- data/roboset/links.txt +197 -0
- docs/pretrain.md +270 -0
- docs/test_6drot.py +99 -0
- eval_sim/eval_dp.py +166 -0
- eval_sim/eval_octo.py +182 -0
- eval_sim/eval_openvla.py +175 -0
- eval_sim/eval_rdt_maniskill.py +137 -0
- lang_embed/aloha_dish_drainer.pt +3 -0
- lang_embed/aloha_handover_box.pt +3 -0
- lang_embed/aloha_lift_box.pt +3 -0
- lang_embed/aloha_shoes_table.pt +3 -0
- lang_embed/anubis_brush_to_pan.pt +3 -0
- lang_embed/anubis_carrot_to_bag.pt +3 -0
- lang_embed/anubis_towel_kirby.pt +3 -0
- scripts/agilex_inference.py +658 -0
- scripts/agilex_model.py +313 -0
- scripts/encode_lang_batch.py +76 -0
- scripts/maniskill_model.py +277 -0
- train/dataset.py +467 -0
- train/image_corrupt.py +44 -0
- train/sample.py +99 -0
- train/train.py +509 -0
configs/base.yaml
ADDED
@@ -0,0 +1,71 @@
common:
  # The number of historical images
  img_history_size: 2
  # The number of future actions to predict
  action_chunk_size: 64
  # The number of cameras to be used in the model
  num_cameras: 3
  # Dimension for state/action, we use the same space for both state and action
  # This MUST be equal to configs/state_vec.py
  state_dim: 128


dataset:
  # We will extract the data from raw dataset
  # and store them in the disk buffer by producer
  # When training, we will read the data
  # randomly from the buffer by consumer
  # The producer will replace the data which has been
  # read by the consumer with new data

  # The path to the buffer (at least 400GB)
  buf_path: /home/jellyho/RDTBuffer
  # The number of chunks in the buffer
  buf_num_chunks: 128
  # The number of samples (step rather than episode) in each chunk
  buf_chunk_size: 128

  # We will filter the episodes with length less than `epsd_len_thresh_low`
  epsd_len_thresh_low: 32
  # For those more than `epsd_len_thresh_high`,
  # we will randomly sample `epsd_len_thresh_high` steps each time we load the episode
  # to better balance the training datasets
  epsd_len_thresh_high: 2048
  # How to fit the image size
  image_aspect_ratio: pad
  # Maximum number of language tokens
  tokenizer_max_length: 1024

model:
  # Config for condition adaptors
  lang_adaptor: mlp2x_gelu
  img_adaptor: mlp2x_gelu
  state_adaptor: mlp3x_gelu
  lang_token_dim: 4096
  img_token_dim: 1152
  # Dim of action or proprioception vector
  # A `state` refers to an action or a proprioception vector
  state_token_dim: 128
  # Config for RDT structure
  rdt:
    # 1B: num_head 32 hidden_size 2048
    hidden_size: 2048
    depth: 28
    num_heads: 32
    cond_pos_embed_type: multimodal
  # For noise scheduler
  noise_scheduler:
    type: ddpm
    num_train_timesteps: 1000
    num_inference_timesteps: 5
    beta_schedule: squaredcos_cap_v2  # Critical choice
    prediction_type: sample
    clip_sample: False
  # For EMA (params averaging)
  # We do not use EMA currently
  ema:
    update_after_step: 0
    inv_gamma: 1.0
    power: 0.75
    min_value: 0.0
    max_value: 0.9999
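Note: a minimal sketch of how this config might be loaded and cross-checked against configs/state_vec.py (assuming only that the file is valid YAML and that configs/ is importable as a package; the check itself is illustrative, not part of the repo):

# Minimal sketch: load configs/base.yaml and sanity-check state_dim
# against the unified state vector defined in configs/state_vec.py.
import yaml
from configs.state_vec import STATE_VEC_LEN

with open("configs/base.yaml", "r") as f:
    config = yaml.safe_load(f)

assert config["common"]["state_dim"] == STATE_VEC_LEN, \
    "state_dim must equal the length of the unified state vector"
print(config["model"]["rdt"])  # e.g. {'hidden_size': 2048, 'depth': 28, ...}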
configs/calvin_rel_traj_location_bounds_task_ABC_D.json
ADDED
@@ -0,0 +1,50 @@
{
    "A": [
        [
            -0.2691913843154907,
            -0.21995729207992554,
            -0.182277649641037
        ],
        [
            0.35127854347229004,
            0.2769763469696045,
            0.17159393429756165
        ]
    ],
    "B": [
        [
            -0.2576896846294403,
            -0.22244493663311005,
            -0.20557966828346252
        ],
        [
            0.32854634523391724,
            0.2922680974006653,
            0.17373555898666382
        ]
    ],
    "C": [
        [
            -0.29205888509750366,
            -0.24688798189163208,
            -0.17577645182609558
        ],
        [
            0.25053921341896057,
            0.3277084231376648,
            0.16431939601898193
        ]
    ],
    "D": [
        [
            -0.25131964683532715,
            -0.15233077108860016,
            -0.13294968008995056
        ],
        [
            0.19209328293800354,
            0.19344553351402283,
            0.1370421051979065
        ]
    ]
}
configs/dataset_control_freq.json
ADDED
@@ -0,0 +1,73 @@
{
    "fractal20220817_data": 3,
    "taco_play": 15,
    "jaco_play": 10,
    "berkeley_cable_routing": 10,
    "nyu_door_opening_surprising_effectiveness": 3,
    "viola": 20,
    "berkeley_autolab_ur5": 5,
    "toto": 30,
    "kuka": 10,
    "language_table": 10,
    "columbia_cairlab_pusht_real": 10,
    "stanford_kuka_multimodal_dataset_converted_externally_to_rlds": 20,
    "nyu_rot_dataset_converted_externally_to_rlds": 3,
    "stanford_hydra_dataset_converted_externally_to_rlds": 10,
    "austin_buds_dataset_converted_externally_to_rlds": 20,
    "nyu_franka_play_dataset_converted_externally_to_rlds": 3,
    "maniskill_dataset_converted_externally_to_rlds": 20,
    "furniture_bench_dataset_converted_externally_to_rlds": 10,
    "ucsd_kitchen_dataset_converted_externally_to_rlds": 2,
    "ucsd_pick_and_place_dataset_converted_externally_to_rlds": 3,
    "austin_sailor_dataset_converted_externally_to_rlds": 20,
    "austin_sirius_dataset_converted_externally_to_rlds": 20,
    "bc_z": 10,
    "utokyo_pr2_opening_fridge_converted_externally_to_rlds": 10,
    "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": 10,
    "utokyo_xarm_pick_and_place_converted_externally_to_rlds": 10,
    "utokyo_xarm_bimanual_converted_externally_to_rlds": 10,
    "berkeley_mvp_converted_externally_to_rlds": 5,
    "berkeley_rpt_converted_externally_to_rlds": 30,
    "kaist_nonprehensile_converted_externally_to_rlds": 10,
    "stanford_mask_vit_converted_externally_to_rlds": 0,
    "tokyo_u_lsmo_converted_externally_to_rlds": 10,
    "dlr_sara_pour_converted_externally_to_rlds": 10,
    "dlr_sara_grid_clamp_converted_externally_to_rlds": 10,
    "dlr_edan_shared_control_converted_externally_to_rlds": 5,
    "asu_table_top_converted_externally_to_rlds": 12.5,
    "stanford_robocook_converted_externally_to_rlds": 5,
    "eth_agent_affordances": 66.6,
    "imperialcollege_sawyer_wrist_cam": 10,
    "iamlab_cmu_pickup_insert_converted_externally_to_rlds": 20,
    "uiuc_d3field": 1,
    "utaustin_mutex": 20,
    "berkeley_fanuc_manipulation": 10,
    "cmu_play_fusion": 5,
    "cmu_stretch": 10,
    "berkeley_gnm_recon": 3,
    "berkeley_gnm_cory_hall": 5,
    "berkeley_gnm_sac_son": 10,
    "robo_net": 1,
    "roboturk_real_towercreation": 10,
    "roboturk_real_laundrylayout": 10,
    "roboturk_real_objectsearch": 10,
    "aloha_mobile": 50,
    "aloha_static": 50,
    "roboset": 5,
    "droid": 15,
    "fmb": 10,
    "dobbe": 30,
    "qut_dexterous_manpulation": 30,
    "agilex": 25,
    "rh20t": 10,
    "calvin": 30,
    "bridgev2": 5,
    "aloha_dish_drainer": 20,
    "aloha_handover_box": 20,
    "aloha_shoes_table": 20,
    "aloha_lift_box": 20,
    "aloha_box_into_pot": 20,
    "anubis_towel_kirby": 20,
    "anubis_carrot_to_bag": 20,
    "anubis_brush_to_pan": 20
}
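Note: the values above are per-dataset control frequencies in Hz. As a purely hypothetical illustration (the actual resampling logic lives in the training data pipeline, which is not shown here), a loader could derive a subsampling stride from them:

# Hypothetical illustration (not from this repo): pick a subsampling stride
# so that each dataset is read at roughly a common target frequency.
import json

with open("configs/dataset_control_freq.json") as f:
    control_freq = json.load(f)

def subsample_stride(dataset_name, target_freq=10):
    """Return how many raw steps to skip so the episode is ~target_freq Hz."""
    src_freq = control_freq[dataset_name]
    return max(1, round(src_freq / target_freq))

print(subsample_stride("aloha_mobile"))  # 50 Hz -> stride 5 at a 10 Hz target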
configs/dataset_img_keys.json
ADDED
@@ -0,0 +1,674 @@
{
    "anubis_towel_kirby": {"image_keys": ["agentview_image", "right_wrist_image", "left_wrist_image", "agentview_image"], "image_mask": [1, 1, 1, 0]},
    "anubis_carrot_to_bag": {"image_keys": ["agentview_image", "right_wrist_image", "left_wrist_image", "agentview_image"], "image_mask": [1, 1, 1, 0]},
    "anubis_brush_to_pan": {"image_keys": ["agentview_image", "right_wrist_image", "left_wrist_image", "agentview_image"], "image_mask": [1, 1, 1, 0]},
    "aloha_box_into_pot": {"image_keys": ["agentview_image", "right_wrist_image", "left_wrist_image", "agentview_image"], "image_mask": [1, 1, 1, 0]},
    "aloha_box_into_pot_easy": {"image_keys": ["agentview_image", "right_wrist_image", "left_wrist_image", "agentview_image"], "image_mask": [1, 1, 1, 0]},
    "aloha_dish_drainer": {"image_keys": ["agentview_image", "right_wrist_image", "left_wrist_image", "agentview_image"], "image_mask": [1, 1, 1, 0]},
    "aloha_handover_box": {"image_keys": ["agentview_image", "right_wrist_image", "left_wrist_image", "agentview_image"], "image_mask": [1, 1, 1, 0]},
    "aloha_shoes_table": {"image_keys": ["agentview_image", "right_wrist_image", "left_wrist_image", "agentview_image"], "image_mask": [1, 1, 1, 0]},
    "aloha_lift_box": {"image_keys": ["agentview_image", "right_wrist_image", "left_wrist_image", "agentview_image"], "image_mask": [1, 1, 1, 0]},
    "fractal20220817_data": {"image_keys": ["image", "image", "image", "image"], "image_mask": [1, 0, 0, 0]},
    "taco_play": {"image_keys": ["rgb_static", "rgb_gripper", "rgb_static", "rgb_static"], "image_mask": [1, 1, 0, 0]},
    "jaco_play": {"image_keys": ["image", "image_wrist", "image_wrist", "image_wrist"], "image_mask": [1, 1, 0, 0]},
    "berkeley_cable_routing": {"image_keys": ["image", "wrist45_image", "wrist225_image", "top_image"], "image_mask": [1, 1, 0, 1]},
    "nyu_door_opening_surprising_effectiveness": {"image_keys": ["image", "image", "image", "image"], "image_mask": [1, 0, 0, 0]},
    "viola": {"image_keys": ["agentview_rgb", "eye_in_hand_rgb", "eye_in_hand_rgb", "eye_in_hand_rgb"], "image_mask": [1, 1, 0, 0]},
    "berkeley_autolab_ur5": {"image_keys": ["image", "hand_image", "hand_image", "hand_image"], "image_mask": [1, 1, 0, 0]},
    "toto": {"image_keys": ["image", "image", "image", "image"], "image_mask": [1, 0, 0, 0]},
    "kuka": {"image_keys": ["image", "image", "image", "image"], "image_mask": [1, 0, 0, 0]},
    "language_table": {"image_keys": ["rgb", "rgb", "rgb", "rgb"], "image_mask": [1, 0, 0, 0]},
    "columbia_cairlab_pusht_real": {"image_keys": ["image", "wrist_image", "wrist_image", "wrist_image"], "image_mask": [1, 1, 0, 0]},
    "stanford_kuka_multimodal_dataset_converted_externally_to_rlds": {"image_keys": ["image", "image", "image", "image"], "image_mask": [1, 0, 0, 0]},
    "nyu_rot_dataset_converted_externally_to_rlds": {"image_keys": ["image", "image", "image", "image"], "image_mask": [1, 0, 0, 0]},
    "stanford_hydra_dataset_converted_externally_to_rlds": {"image_keys": ["image", "wrist_image", "wrist_image", "wrist_image"], "image_mask": [1, 1, 0, 0]},
    "austin_buds_dataset_converted_externally_to_rlds": {"image_keys": ["image", "wrist_image", "wrist_image", "wrist_image"], "image_mask": [1, 1, 0, 0]},
    "nyu_franka_play_dataset_converted_externally_to_rlds": {"image_keys": ["image", "image_additional_view", "image_additional_view", "image_additional_view"], "image_mask": [1, 0, 0, 1]},
    "maniskill_dataset_converted_externally_to_rlds": {"image_keys": ["image", "wrist_image", "wrist_image", "wrist_image"], "image_mask": [1, 1, 0, 0]},
    "furniture_bench_dataset_converted_externally_to_rlds": {"image_keys": ["image", "wrist_image", "wrist_image", "wrist_image"], "image_mask": [1, 1, 0, 0]},
    "ucsd_kitchen_dataset_converted_externally_to_rlds": {"image_keys": ["image", "image", "image", "image"], "image_mask": [1, 0, 0, 0]},
    "ucsd_pick_and_place_dataset_converted_externally_to_rlds": {"image_keys": ["image", "image", "image", "image"], "image_mask": [1, 0, 0, 0]},
    "austin_sailor_dataset_converted_externally_to_rlds": {"image_keys": ["image", "wrist_image", "wrist_image", "wrist_image"], "image_mask": [1, 1, 0, 0]},
    "austin_sirius_dataset_converted_externally_to_rlds": {"image_keys": ["image", "wrist_image", "wrist_image", "wrist_image"], "image_mask": [1, 1, 0, 0]},
    "bc_z": {"image_keys": ["image", "image", "image", "image"], "image_mask": [1, 0, 0, 0]},
    "utokyo_pr2_opening_fridge_converted_externally_to_rlds": {"image_keys": ["image", "image", "image", "image"], "image_mask": [1, 0, 0, 0]},
    "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": {"image_keys": ["image", "image", "image", "image"], "image_mask": [1, 0, 0, 0]},
    "utokyo_xarm_pick_and_place_converted_externally_to_rlds": {"image_keys": ["image", "hand_image", "hand_image", "image2"], "image_mask": [1, 1, 0, 1]},
    "utokyo_xarm_bimanual_converted_externally_to_rlds": {"image_keys": ["image", "image", "image", "image"], "image_mask": [1, 0, 0, 0]},
    "berkeley_mvp_converted_externally_to_rlds": {"image_keys": ["hand_image", "hand_image", "hand_image", "hand_image"], "image_mask": [0, 1, 0, 0]},
    "berkeley_rpt_converted_externally_to_rlds": {"image_keys": ["hand_image", "hand_image", "hand_image", "hand_image"], "image_mask": [0, 1, 0, 0]},
    "kaist_nonprehensile_converted_externally_to_rlds": {"image_keys": ["image", "image", "image", "image"], "image_mask": [1, 0, 0, 0]},
    "stanford_mask_vit_converted_externally_to_rlds": {"image_keys": ["image", "image", "image", "image"], "image_mask": [1, 0, 0, 0]},
    "tokyo_u_lsmo_converted_externally_to_rlds": {"image_keys": ["image", "image", "image", "image"], "image_mask": [1, 0, 0, 0]},
    "dlr_sara_pour_converted_externally_to_rlds": {"image_keys": ["image", "image", "image", "image"], "image_mask": [1, 0, 0, 0]},
    "dlr_sara_grid_clamp_converted_externally_to_rlds": {"image_keys": ["image", "image", "image", "image"], "image_mask": [1, 0, 0, 0]},
    "dlr_edan_shared_control_converted_externally_to_rlds": {"image_keys": ["image", "image", "image", "image"], "image_mask": [1, 0, 0, 0]},
    "asu_table_top_converted_externally_to_rlds": {"image_keys": ["image", "image", "image", "image"], "image_mask": [1, 0, 0, 0]},
    "stanford_robocook_converted_externally_to_rlds": {"image_keys": ["image_2", "image_1", "image_3", "image_4"], "image_mask": [1, 0, 0, 1]},
    "eth_agent_affordances": {"image_keys": ["image", "image", "image", "image"], "image_mask": [1, 0, 0, 0]},
    "imperialcollege_sawyer_wrist_cam": {"image_keys": ["image", "wrist_image", "wrist_image", "wrist_image"], "image_mask": [0, 1, 0, 0]},
    "iamlab_cmu_pickup_insert_converted_externally_to_rlds": {"image_keys": ["image", "wrist_image", "wrist_image", "wrist_image"], "image_mask": [1, 1, 0, 0]},
    "uiuc_d3field": {"image_keys": ["image_1", "image_2", "image_3", "image_4"], "image_mask": [1, 0, 0, 1]},
    "utaustin_mutex": {"image_keys": ["image", "wrist_image", "wrist_image", "wrist_image"], "image_mask": [1, 1, 0, 0]},
    "berkeley_fanuc_manipulation": {"image_keys": ["image", "wrist_image", "wrist_image", "wrist_image"], "image_mask": [1, 1, 0, 0]},
    "cmu_play_fusion": {"image_keys": ["image", "image", "image", "image"], "image_mask": [1, 0, 0, 0]},
    "cmu_stretch": {"image_keys": ["image", "image", "image", "image"], "image_mask": [1, 0, 0, 0]},
    "berkeley_gnm_recon": {"image_keys": ["image", "image", "image", "image"], "image_mask": [1, 0, 0, 0]},
    "berkeley_gnm_cory_hall": {"image_keys": ["image", "image", "image", "image"], "image_mask": [1, 0, 0, 0]},
    "berkeley_gnm_sac_son": {"image_keys": ["image", "image", "image", "image"], "image_mask": [1, 0, 0, 0]},
    "robo_net": {"image_keys": ["image", "image1", "image2", "image2"], "image_mask": [1, 0, 0, 1]},
    "roboturk_real_towercreation": {"image_keys": ["top_rgb_frame", "front_rgb_frame", "front_rgb_frame", "front_rgb_frame"], "image_mask": [1, 0, 0, 1]},
    "roboturk_real_laundrylayout": {"image_keys": ["top_rgb_frame", "front_rgb_frame", "front_rgb_frame", "front_rgb_frame"], "image_mask": [1, 0, 0, 1]},
    "roboturk_real_objectsearch": {"image_keys": ["top_rgb_frame", "front_rgb_frame", "front_rgb_frame", "front_rgb_frame"], "image_mask": [1, 0, 0, 1]},
    "aloha_mobile": {"image_keys": ["cam_high", "cam_right_wrist", "cam_left_wrist", "cam_right_wrist"], "image_mask": [1, 1, 1, 0]},
    "aloha_static": {"image_keys": ["cam_high", "cam_right_wrist", "cam_left_wrist", "cam_low"], "image_mask": [1, 1, 1, 1]},
    "roboset": {"image_keys": ["rgb_top", "rgb_right", "rgb_left", "rgb_right"], "image_mask": [1, 1, 1, 0]},
    "droid": {"image_keys": ["exterior_image_1_left", "wrist_image_left", "wrist_image_left", "exterior_image_2_left"], "image_mask": [1, 1, 0, 1]},
    "fmb": {"image_keys": ["image_side_1", "image_wrist_1", "image_wrist_1", "image_side_2"], "image_mask": [1, 1, 0, 1]},
    "dobbe": {"image_keys": ["wrist_image", "wrist_image", "wrist_image", "wrist_image"], "image_mask": [0, 1, 0, 0]},
    "qut_dexterous_manpulation": {"image_keys": ["image", "wrist_image", "wrist_image", "wrist_image"], "image_mask": [1, 1, 0, 0]},
    "agilex": {"image_keys": ["cam_high", "cam_right_wrist", "cam_left_wrist", "cam_right_wrist"], "image_mask": [1, 1, 1, 0]},
    "rh20t": {"image_keys": ["image", "image", "image", "image"], "image_mask": [1, 0, 0, 0]},
    "calvin": {"image_keys": ["rgb_static", "rgb_gripper", "rgb_gripper", "rgb_gripper"], "image_mask": [1, 1, 0, 0]},
    "bridgev2": {"image_keys": ["images0", "images0", "images0", "images0"], "image_mask": [1, 0, 0, 0]}
}
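Note: each entry names four camera slots plus a mask marking which slots carry a real view. As a hypothetical sketch of how a loader might consume this spec (the function, the per-step dict, and the placeholder image size are all assumptions, not the repo's actual data pipeline):

# Hypothetical illustration: gather a dataset's images in the fixed 4-slot
# order, padding masked slots with zeros.
import json
import numpy as np

with open("configs/dataset_img_keys.json") as f:
    img_keys = json.load(f)

def gather_images(dataset_name, step, image_size=(384, 384, 3)):
    spec = img_keys[dataset_name]
    images = []
    for key, valid in zip(spec["image_keys"], spec["image_mask"]):
        if valid:
            images.append(step[key])  # real camera view from the step dict
        else:
            images.append(np.zeros(image_size, dtype=np.uint8))  # padded slot
    return np.stack(images)  # (4, H, W, 3)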
configs/dataset_stat.json
ADDED
The diff for this file is too large to render.
configs/finetune_datasets.json
ADDED
@@ -0,0 +1,5 @@
[
    "anubis_brush_to_pan",
    "anubis_carrot_to_bag",
    "anubis_towel_kirby"
]
configs/finetune_sample_weights.json
ADDED
@@ -0,0 +1,5 @@
{
    "anubis_towel_kirby": 100,
    "anubis_carrot_to_bag": 100,
    "anubis_brush_to_pan": 100
}
configs/pretrain_datasets.json
ADDED
@@ -0,0 +1,3 @@
[
    "aloha_box_into_pot_easy"
]
configs/pretrain_sample_weights.json
ADDED
@@ -0,0 +1,3 @@
{
    "aloha_box_into_pot_easy": 100
}
configs/state_vec.py
ADDED
@@ -0,0 +1,114 @@
STATE_VEC_IDX_MAPPING = {
    # [0, 10): right arm joint positions
    **{
        'arm_joint_{}_pos'.format(i): i for i in range(10)
    },
    **{
        'right_arm_joint_{}_pos'.format(i): i for i in range(10)
    },
    # [10, 15): right gripper joint positions
    **{
        'gripper_joint_{}_pos'.format(i): i + 10 for i in range(5)
    },
    **{
        'right_gripper_joint_{}_pos'.format(i): i + 10 for i in range(5)
    },
    'gripper_open': 10,  # alias of right_gripper_joint_0_pos
    'right_gripper_open': 10,
    # [15, 25): right arm joint velocities
    **{
        'arm_joint_{}_vel'.format(i): i + 15 for i in range(10)
    },
    **{
        'right_arm_joint_{}_vel'.format(i): i + 15 for i in range(10)
    },
    # [25, 30): right gripper joint velocities
    **{
        'gripper_joint_{}_vel'.format(i): i + 25 for i in range(5)
    },
    **{
        'right_gripper_joint_{}_vel'.format(i): i + 25 for i in range(5)
    },
    'gripper_open_vel': 25,  # alias of right_gripper_joint_0_vel
    'right_gripper_open_vel': 25,
    # [30, 33): right end effector positions
    'eef_pos_x': 30,
    'right_eef_pos_x': 30,
    'eef_pos_y': 31,
    'right_eef_pos_y': 31,
    'eef_pos_z': 32,
    'right_eef_pos_z': 32,
    # [33, 39): right end effector 6D pose
    'eef_angle_0': 33,
    'right_eef_angle_0': 33,
    'eef_angle_1': 34,
    'right_eef_angle_1': 34,
    'eef_angle_2': 35,
    'right_eef_angle_2': 35,
    'eef_angle_3': 36,
    'right_eef_angle_3': 36,
    'eef_angle_4': 37,
    'right_eef_angle_4': 37,
    'eef_angle_5': 38,
    'right_eef_angle_5': 38,
    # [39, 42): right end effector velocities
    'eef_vel_x': 39,
    'right_eef_vel_x': 39,
    'eef_vel_y': 40,
    'right_eef_vel_y': 40,
    'eef_vel_z': 41,
    'right_eef_vel_z': 41,
    # [42, 45): right end effector angular velocities
    'eef_angular_vel_roll': 42,
    'right_eef_angular_vel_roll': 42,
    'eef_angular_vel_pitch': 43,
    'right_eef_angular_vel_pitch': 43,
    'eef_angular_vel_yaw': 44,
    'right_eef_angular_vel_yaw': 44,
    # [45, 50): reserved
    # [50, 60): left arm joint positions
    **{
        'left_arm_joint_{}_pos'.format(i): i + 50 for i in range(10)
    },
    # [60, 65): left gripper joint positions
    **{
        'left_gripper_joint_{}_pos'.format(i): i + 60 for i in range(5)
    },
    'left_gripper_open': 60,  # alias of left_gripper_joint_0_pos
    # [65, 75): left arm joint velocities
    **{
        'left_arm_joint_{}_vel'.format(i): i + 65 for i in range(10)
    },
    # [75, 80): left gripper joint velocities
    **{
        'left_gripper_joint_{}_vel'.format(i): i + 75 for i in range(5)
    },
    'left_gripper_open_vel': 75,  # alias of left_gripper_joint_0_vel
    # [80, 83): left end effector positions
    'left_eef_pos_x': 80,
    'left_eef_pos_y': 81,
    'left_eef_pos_z': 82,
    # [83, 89): left end effector 6D pose
    'left_eef_angle_0': 83,
    'left_eef_angle_1': 84,
    'left_eef_angle_2': 85,
    'left_eef_angle_3': 86,
    'left_eef_angle_4': 87,
    'left_eef_angle_5': 88,
    # [89, 92): left end effector velocities
    'left_eef_vel_x': 89,
    'left_eef_vel_y': 90,
    'left_eef_vel_z': 91,
    # [92, 95): left end effector angular velocities
    'left_eef_angular_vel_roll': 92,
    'left_eef_angular_vel_pitch': 93,
    'left_eef_angular_vel_yaw': 94,
    # [95, 100): reserved
    # [100, 102): base linear velocities
    'base_vel_x': 100,
    'base_vel_y': 101,
    # [102, 103): base angular velocities
    'base_angular_vel': 102,
    # [103, 128): reserved
}
STATE_VEC_LEN = 128
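Note: a minimal usage sketch for the mapping above (the pack_state helper and the mask convention shown here are illustrative assumptions, not code from this repo):

# Minimal sketch: pack a robot-specific observation into the 128-dim
# unified state vector, with a mask marking which dimensions are filled.
import numpy as np
from configs.state_vec import STATE_VEC_IDX_MAPPING, STATE_VEC_LEN

def pack_state(named_values):
    """named_values: dict like {'right_eef_pos_x': 0.31, 'gripper_open': 1.0}."""
    vec = np.zeros(STATE_VEC_LEN, dtype=np.float32)
    mask = np.zeros(STATE_VEC_LEN, dtype=np.float32)
    for name, value in named_values.items():
        idx = STATE_VEC_IDX_MAPPING[name]
        vec[idx] = value
        mask[idx] = 1.0
    return vec, mask

vec, mask = pack_state({"left_gripper_open": 1.0, "right_eef_pos_z": 0.12})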
configs/zero2.json
ADDED
@@ -0,0 +1,14 @@
{
    "bf16": {
        "enabled": "auto"
    },
    "train_micro_batch_size_per_gpu": "auto",
    "train_batch_size": "auto",
    "gradient_accumulation_steps": "auto",
    "zero_optimization": {
        "stage": 2,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9
    }
}
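Note: a DeepSpeed ZeRO-2 config with "auto" entries is commonly consumed through Hugging Face TrainingArguments, which fills the "auto" values from its own batch-size and precision settings. Whether train/train.py wires it exactly this way is an assumption; the sketch below only shows the general pattern, and the output path and batch sizes are placeholders:

# Hedged sketch: pass configs/zero2.json to a Hugging Face-style trainer.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="checkpoints/rdt",       # hypothetical path
    per_device_train_batch_size=32,     # placeholder; fills the "auto" fields
    gradient_accumulation_steps=1,
    bf16=True,
    deepspeed="configs/zero2.json",     # the file added above
)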
data/aloha/hdf5totfrecords.py
ADDED
@@ -0,0 +1,98 @@
import tensorflow as tf
import h5py
import os
import fnmatch
import cv2
import numpy as np
from tqdm import tqdm

def decode_img(img):
    return cv2.cvtColor(cv2.imdecode(np.frombuffer(img, np.uint8), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)

def decode_all_imgs(imgs):
    return [decode_img(img) for img in imgs]

def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _bool_feature(value):
    """Returns a bool_list from a boolean."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[int(value)]))

def serialize_example(action, base_action, qpos, qvel, cam_high, cam_left_wrist, cam_right_wrist, cam_low, instruction, terminate_episode):
    if base_action is not None:
        feature = {
            'action': _bytes_feature(tf.io.serialize_tensor(action)),
            'base_action': _bytes_feature(tf.io.serialize_tensor(base_action)),
            'qpos': _bytes_feature(tf.io.serialize_tensor(qpos)),
            'qvel': _bytes_feature(tf.io.serialize_tensor(qvel)),
            'cam_high': _bytes_feature(tf.io.serialize_tensor(cam_high)),
            'cam_left_wrist': _bytes_feature(tf.io.serialize_tensor(cam_left_wrist)),
            'cam_right_wrist': _bytes_feature(tf.io.serialize_tensor(cam_right_wrist)),
            'instruction': _bytes_feature(instruction),
            'terminate_episode': _bool_feature(terminate_episode)
        }
    else:
        feature = {
            'action': _bytes_feature(tf.io.serialize_tensor(action)),
            'qpos': _bytes_feature(tf.io.serialize_tensor(qpos)),
            'qvel': _bytes_feature(tf.io.serialize_tensor(qvel)),
            'cam_high': _bytes_feature(tf.io.serialize_tensor(cam_high)),
            'cam_left_wrist': _bytes_feature(tf.io.serialize_tensor(cam_left_wrist)),
            'cam_right_wrist': _bytes_feature(tf.io.serialize_tensor(cam_right_wrist)),
            'cam_low': _bytes_feature(tf.io.serialize_tensor(cam_low)),
            'instruction': _bytes_feature(instruction),
            'terminate_episode': _bool_feature(terminate_episode)
        }
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()

def write_tfrecords(root_dir, out_dir):
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    num_files = 0
    for root, dirs, files in os.walk(root_dir):
        num_files += len(fnmatch.filter(files, '*.hdf5'))
    with tqdm(total=num_files) as pbar:
        for root, dirs, files in os.walk(root_dir):
            for filename in fnmatch.filter(files, '*.hdf5'):
                filepath = os.path.join(root, filename)
                with h5py.File(filepath, 'r') as f:
                    if not 'instruction' in f:
                        continue
                    pbar.update(1)
                    output_dir = os.path.join(out_dir, os.path.relpath(root, root_dir))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    print(f"Writing TFRecords to {output_dir}")
                    tfrecord_path = os.path.join(output_dir, filename.replace('.hdf5', '.tfrecord'))
                    with tf.io.TFRecordWriter(tfrecord_path) as writer:
                        num_episodes = f['action'].shape[0]
                        for i in range(num_episodes):
                            action = f['action'][i]
                            if 'base_action' in f:
                                base_action = f['base_action'][i]
                            else:
                                base_action = None
                            qpos = f['observations']['qpos'][i]
                            qvel = f['observations']['qvel'][i]
                            cam_high = decode_img(f['observations']['images']['cam_high'][i])
                            cam_left_wrist = decode_img(f['observations']['images']['cam_left_wrist'][i])
                            cam_right_wrist = decode_img(f['observations']['images']['cam_right_wrist'][i])
                            if 'cam_low' in f['observations']['images']:
                                cam_low = decode_img(f['observations']['images']['cam_low'][i])
                            else:
                                cam_low = None
                            instruction = f['instruction'][()]
                            terminate_episode = i == num_episodes - 1
                            serialized_example = serialize_example(action, base_action, qpos, qvel, cam_high, cam_left_wrist, cam_right_wrist, cam_low, instruction, terminate_episode)
                            writer.write(serialized_example)
                    print(f"TFRecords written to {tfrecord_path}")
    print(f"TFRecords written to {out_dir}")

root_dir = '../datasets/aloha/'
output_dir = '../datasets/aloha/tfrecords/'

write_tfrecords(root_dir, output_dir)
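Note: a sketch of reading one of the written records back, using the same feature keys that serialize_example emits (static-ALOHA branch with cam_low shown; drop 'cam_low' and add 'base_action' for the mobile branch). The tensor dtypes and the example file name are assumptions about the source HDF5s, not something this script guarantees:

# Sketch: parse a step written by the converter above.
import tensorflow as tf

feature_spec = {
    'action': tf.io.FixedLenFeature([], tf.string),
    'qpos': tf.io.FixedLenFeature([], tf.string),
    'qvel': tf.io.FixedLenFeature([], tf.string),
    'cam_high': tf.io.FixedLenFeature([], tf.string),
    'cam_left_wrist': tf.io.FixedLenFeature([], tf.string),
    'cam_right_wrist': tf.io.FixedLenFeature([], tf.string),
    'cam_low': tf.io.FixedLenFeature([], tf.string),
    'instruction': tf.io.FixedLenFeature([], tf.int64) if False else tf.io.FixedLenFeature([], tf.string),
    'terminate_episode': tf.io.FixedLenFeature([], tf.int64),
}

def parse_step(record):
    ex = tf.io.parse_single_example(record, feature_spec)
    # out_type assumes the HDF5 stored float32 actions and uint8 images.
    action = tf.io.parse_tensor(ex['action'], out_type=tf.float32)
    cam_high = tf.io.parse_tensor(ex['cam_high'], out_type=tf.uint8)
    return action, cam_high, ex['instruction'], ex['terminate_episode']

# 'episode_0.tfrecord' is a hypothetical file name under the output directory.
ds = tf.data.TFRecordDataset('../datasets/aloha/tfrecords/episode_0.tfrecord')
for action, cam_high, instruction, done in ds.map(parse_step).take(1):
    print(action.shape, cam_high.shape, instruction.numpy(), bool(done))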
data/aloha/unzip_data.sh
ADDED
@@ -0,0 +1,3 @@
cd ../datasets/aloha/

unzip aloha_mobile.zip
data/bridgev2/bridgedata_numpy_to_tfrecord.py
ADDED
@@ -0,0 +1,174 @@
"""
Converts data from the BridgeData numpy format to TFRecord format.

Consider the following directory structure for the input data:

    bridgedata_numpy/
        rss/
            toykitchen2/
                set_table/
                    00/
                        train/
                            out.npy
                        val/
                            out.npy
        icra/
            ...

The --depth parameter controls how much of the data to process at the
--input_path; for example, if --depth=5, then --input_path should be
"bridgedata_numpy", and all data will be processed. If --depth=3, then
--input_path should be "bridgedata_numpy/rss/toykitchen2", and only data
under "toykitchen2" will be processed.

The same directory structure will be replicated under --output_path. For
example, in the second case, the output will be written to
"{output_path}/set_table/00/...".

Can read/write directly from/to Google Cloud Storage.

Written by Kevin Black ([email protected]).
"""
import os
from multiprocessing import Pool

import numpy as np
import tensorflow as tf
import tqdm
from absl import app, flags, logging
import pickle
from multiprocessing import cpu_count

FLAGS = flags.FLAGS

flags.DEFINE_string("input_path", None, "Input path", required=True)
flags.DEFINE_string("output_path", None, "Output path", required=True)
flags.DEFINE_integer(
    "depth",
    5,
    "Number of directories deep to traverse. Looks for {input_path}/dir_1/dir_2/.../dir_{depth-1}/train/out.npy",
)
flags.DEFINE_bool("overwrite", False, "Overwrite existing files")
num_workers = 8
flags.DEFINE_integer("num_workers", num_workers, "Number of threads to use")

print(f"using {num_workers} workers")

def tensor_feature(value):
    return tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[tf.io.serialize_tensor(value).numpy()])
    )

def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.encode('utf-8')]))

def _strings_feature(string_list):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[s.encode('utf-8') for s in string_list]))

def _bool_feature(value):
    """Returns a bool_list from a boolean."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[int(value)]))


def process(path):
    # with tf.io.gfile.GFile(path, "rb") as f:
    #     arr = np.load(f, allow_pickle=True)
    try:
        with tf.io.gfile.GFile(path, "rb") as f:
            arr = np.load(path, allow_pickle=True)
    except Exception as e:
        print(f"Error loading {path}: {e}")
        return

    dirname = os.path.dirname(os.path.abspath(path))
    outpath = os.path.join(FLAGS.output_path, *dirname.split(os.sep)[-FLAGS.depth :])

    if tf.io.gfile.exists(outpath):
        if FLAGS.overwrite:
            logging.info(f"Deleting {outpath}")
            tf.io.gfile.rmtree(outpath)
        else:
            logging.info(f"Skipping {outpath}")
            return

    if len(arr) == 0:
        logging.info(f"Skipping {path}, empty")
        return

    tf.io.gfile.makedirs(outpath)

    for i, traj in enumerate(arr):
        write_path = f"{outpath}/out_{i}.tfrecord"
        with tf.io.TFRecordWriter(write_path) as writer:
            truncates = np.zeros(len(traj["actions"]), dtype=np.bool_)
            truncates[-1] = True
            frames_num = len(traj["observations"])
            # remove empty string
            traj["language"] = [x for x in traj["language"] if x != ""]
            if len(traj["language"]) == 0:
                traj["language"] = [""]
            instr = traj["language"][0]
            if len(traj["language"]) > 2:
                print(len(traj["language"]))
            for i in range(frames_num):
                tf_features = {
                    "observations/images0": tensor_feature(
                        np.array(
                            [traj["observations"][i]["images0"]],
                            dtype=np.uint8,
                        )
                    ),
                    "observations/state": tensor_feature(
                        np.array(
                            [traj["observations"][i]["state"]],
                            dtype=np.float32,
                        )
                    ),
                    "observations/qpos": tensor_feature(
                        np.array(
                            [traj["observations"][i]["qpos"]],
                            dtype=np.float32,
                        )
                    ),
                    "observations/eef_transform": tensor_feature(
                        np.array(
                            [traj["observations"][i]["eef_transform"]],
                            dtype=np.float32,
                        )
                    ),
                    "language": _bytes_feature(instr),
                    "actions": tensor_feature(
                        np.array(traj["actions"][i], dtype=np.float32)
                    ),
                    "truncates": _bool_feature(i == frames_num - 1),
                }
                example = tf.train.Example(
                    features=tf.train.Features(
                        feature=tf_features
                    )
                )
                writer.write(example.SerializeToString())


def main(_):
    assert FLAGS.depth >= 1

    paths = tf.io.gfile.glob(
        tf.io.gfile.join(FLAGS.input_path, *("*" * (FLAGS.depth - 1)))
    )
    paths = [f"{p}/train/out.npy" for p in paths] + [f"{p}/val/out.npy" for p in paths]
    # num_episodes = 0
    # for dirpath in paths:
    #     with tf.io.gfile.GFile(dirpath, "rb") as f:
    #         arr = np.load(dirpath, allow_pickle=True)
    #     num_episodes += len(arr)
    # print(num_episodes)
    with Pool(FLAGS.num_workers) as p:
        list(tqdm.tqdm(p.imap(process, paths), total=len(paths)))


if __name__ == "__main__":
    app.run(main)
data/bridgev2/bridgedata_raw_to_numpy.py
ADDED
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Converts data from the BridgeData raw format to numpy format.
|
3 |
+
|
4 |
+
Consider the following directory structure for the input data:
|
5 |
+
|
6 |
+
bridgedata_raw/
|
7 |
+
rss/
|
8 |
+
toykitchen2/
|
9 |
+
set_table/
|
10 |
+
00/
|
11 |
+
2022-01-01_00-00-00/
|
12 |
+
collection_metadata.json
|
13 |
+
config.json
|
14 |
+
diagnostics.png
|
15 |
+
raw/
|
16 |
+
traj_group0/
|
17 |
+
traj0/
|
18 |
+
obs_dict.pkl
|
19 |
+
policy_out.pkl
|
20 |
+
agent_data.pkl
|
21 |
+
images0/
|
22 |
+
im_0.jpg
|
23 |
+
im_1.jpg
|
24 |
+
...
|
25 |
+
...
|
26 |
+
...
|
27 |
+
01/
|
28 |
+
...
|
29 |
+
|
30 |
+
The --depth parameter controls how much of the data to process at the
|
31 |
+
--input_path; for example, if --depth=5, then --input_path should be
|
32 |
+
"bridgedata_raw", and all data will be processed. If --depth=3, then
|
33 |
+
--input_path should be "bridgedata_raw/rss/toykitchen2", and only data
|
34 |
+
under "toykitchen2" will be processed.
|
35 |
+
|
36 |
+
The same directory structure will be replicated under --output_path. For
|
37 |
+
example, in the second case, the output will be written to
|
38 |
+
"{output_path}/set_table/00/...".
|
39 |
+
|
40 |
+
Squashes images to 128x128.
|
41 |
+
|
42 |
+
Can write directly to Google Cloud Storage, but not read from it.
|
43 |
+
|
44 |
+
Written by Kevin Black ([email protected]).
|
45 |
+
"""
|
46 |
+
import copy
|
47 |
+
import glob
|
48 |
+
import os
|
49 |
+
import pickle
|
50 |
+
import random
|
51 |
+
from collections import defaultdict
|
52 |
+
from datetime import datetime
|
53 |
+
from functools import partial
|
54 |
+
from multiprocessing import Pool
|
55 |
+
|
56 |
+
import numpy as np
|
57 |
+
import tensorflow as tf
|
58 |
+
import tqdm
|
59 |
+
from absl import app, flags, logging
|
60 |
+
from PIL import Image
|
61 |
+
|
62 |
+
FLAGS = flags.FLAGS
|
63 |
+
|
64 |
+
flags.DEFINE_string("input_path", None, "Input path", required=True)
|
65 |
+
flags.DEFINE_string("output_path", None, "Output path", required=True)
|
66 |
+
flags.DEFINE_integer(
|
67 |
+
"depth",
|
68 |
+
5,
|
69 |
+
"Number of directories deep to traverse to the dated directory. Looks for"
|
70 |
+
"{input_path}/dir_1/dir_2/.../dir_{depth-1}/2022-01-01_00-00-00/...",
|
71 |
+
)
|
72 |
+
flags.DEFINE_bool("overwrite", False, "Overwrite existing files")
|
73 |
+
flags.DEFINE_float(
|
74 |
+
"train_proportion", 0.9, "Proportion of data to use for training (rather than val)"
|
75 |
+
)
|
76 |
+
flags.DEFINE_integer("num_workers", 8, "Number of threads to use")
|
77 |
+
flags.DEFINE_integer("im_size", 128, "Image size")
|
78 |
+
|
79 |
+
|
80 |
+
def squash(path):
|
81 |
+
im = Image.open(path)
|
82 |
+
# im = im.resize((FLAGS.im_size, FLAGS.im_size), Image.Resampling.LANCZOS)
|
83 |
+
out = np.asarray(im).astype(np.uint8)
|
84 |
+
return out
|
85 |
+
|
86 |
+
|
87 |
+
def process_images(path): # processes images at a trajectory level
|
88 |
+
names = sorted(
|
89 |
+
[x for x in os.listdir(path) if "images" in x and not "depth" in x],
|
90 |
+
key=lambda x: int(x.split("images")[1]),
|
91 |
+
)
|
92 |
+
image_path = [
|
93 |
+
os.path.join(path, x)
|
94 |
+
for x in os.listdir(path)
|
95 |
+
if "images" in x and not "depth" in x
|
96 |
+
]
|
97 |
+
image_path = sorted(image_path, key=lambda x: int(x.split("images")[1]))
|
98 |
+
|
99 |
+
images_out = defaultdict(list)
|
100 |
+
if len(image_path) == 0:
|
101 |
+
return None, None
|
102 |
+
|
103 |
+
tlen = len(glob.glob(image_path[0] + "/im_*.jpg"))
|
104 |
+
|
105 |
+
for i, name in enumerate(names):
|
106 |
+
for t in range(tlen):
|
107 |
+
images_out[name].append(squash(image_path[i] + "/im_{}.jpg".format(t)))
|
108 |
+
|
109 |
+
images_out = dict(images_out)
|
110 |
+
|
111 |
+
obs, next_obs = dict(), dict()
|
112 |
+
|
113 |
+
for n in names:
|
114 |
+
obs[n] = images_out[n][:-1]
|
115 |
+
next_obs[n] = images_out[n][1:]
|
116 |
+
return obs, next_obs
|
117 |
+
|
118 |
+
|
119 |
+
def process_state(path):
|
120 |
+
fp = os.path.join(path, "obs_dict.pkl")
|
121 |
+
with open(fp, "rb") as f:
|
122 |
+
x = pickle.load(f)
|
123 |
+
qpos = None if "qpos" not in x.keys() else x["qpos"]
|
124 |
+
qvel = None if "qvel" not in x.keys() else x["qvel"]
|
125 |
+
eef_transform = None if "eef_transform" not in x.keys() else x["eef_transform"]
|
126 |
+
return x["full_state"][:-1], x["full_state"][1:], qpos, qvel, eef_transform
|
127 |
+
|
128 |
+
def process_time(path):
|
129 |
+
fp = os.path.join(path, "obs_dict.pkl")
|
130 |
+
with open(fp, "rb") as f:
|
131 |
+
x = pickle.load(f)
|
132 |
+
return x["time_stamp"][:-1], x["time_stamp"][1:]
|
133 |
+
|
134 |
+
|
135 |
+
def process_actions(path): # gets actions
|
136 |
+
fp = os.path.join(path, "policy_out.pkl")
|
137 |
+
with open(fp, "rb") as f:
|
138 |
+
act_list = pickle.load(f)
|
139 |
+
if isinstance(act_list[0], dict):
|
140 |
+
act_list = [x["actions"] for x in act_list]
|
141 |
+
return act_list
|
142 |
+
|
143 |
+
|
144 |
+
# processes each data collection attempt
|
145 |
+
def process_dc(path, train_ratio=0.9):
|
146 |
+
# a mystery left by the greats of the past
|
147 |
+
if "lmdb" in path:
|
148 |
+
logging.warning(f"Skipping {path} because uhhhh lmdb?")
|
149 |
+
return [], [], [], []
|
150 |
+
|
151 |
+
all_dicts_train = list()
|
152 |
+
all_dicts_test = list()
|
153 |
+
all_rews_train = list()
|
154 |
+
all_rews_test = list()
|
155 |
+
|
156 |
+
# Data collected prior to 7-23 has a delay of 1, otherwise a delay of 0
|
157 |
+
    date_time = datetime.strptime(path.split("/")[-1], "%Y-%m-%d_%H-%M-%S")
    latency_shift = date_time < datetime(2021, 7, 23)

    search_path = os.path.join(path, "raw", "traj_group*", "traj*")
    all_traj = glob.glob(search_path)
    if all_traj == []:
        logging.info(f"no trajs found in {search_path}")
        return [], [], [], []

    random.shuffle(all_traj)

    num_traj = len(all_traj)
    for itraj, tp in tqdm.tqdm(enumerate(all_traj)):
        try:
            out = dict()

            ld = os.listdir(tp)

            assert "obs_dict.pkl" in ld, tp + ":" + str(ld)
            assert "policy_out.pkl" in ld, tp + ":" + str(ld)
            # assert "agent_data.pkl" in ld, tp + ":" + str(ld)  # not used

            obs, next_obs = process_images(tp)
            if obs is None:
                return
            acts = process_actions(tp)
            state, next_state, qpos, qvel, eef_transform = process_state(tp)
            time_stamp, next_time_stamp = process_time(tp)
            term = [0] * len(acts)
            if "lang.txt" in ld:
                with open(os.path.join(tp, "lang.txt")) as f:
                    lang = list(f)
                lang = [l.strip() for l in lang if "confidence" not in l]
            else:
                # empty string is a placeholder for data with no language label
                lang = [""]

            out["observations"] = obs
            out["observations"]["state"] = state
            out["observations"]["time_stamp"] = time_stamp
            if qpos is not None:
                out["observations"]["qpos"] = qpos
            else:
                return None, None, None, None
            if qvel is not None:
                out["observations"]["qvel"] = qvel
            if eef_transform is not None:
                out["observations"]["eef_transform"] = eef_transform
            out["next_observations"] = next_obs
            out["next_observations"]["state"] = next_state
            out["next_observations"]["time_stamp"] = next_time_stamp

            out["observations"] = [
                dict(zip(out["observations"], t))
                for t in zip(*out["observations"].values())
            ]
            out["next_observations"] = [
                dict(zip(out["next_observations"], t))
                for t in zip(*out["next_observations"].values())
            ]

            out["actions"] = acts
            out["terminals"] = term
            out["language"] = lang

            # shift the actions according to camera latency
            if latency_shift:
                out["observations"] = out["observations"][1:]
                out["next_observations"] = out["next_observations"][1:]
                out["actions"] = out["actions"][:-1]
                out["terminals"] = term[:-1]

            labeled_rew = copy.deepcopy(out["terminals"])[:]
            labeled_rew[-2:] = [1, 1]

            traj_len = len(out["observations"])
            assert len(out["next_observations"]) == traj_len
            assert len(out["actions"]) == traj_len
            assert len(out["terminals"]) == traj_len
            assert len(labeled_rew) == traj_len

            if itraj < int(num_traj * train_ratio):
                all_dicts_train.append(out)
                all_rews_train.append(labeled_rew)
            else:
                all_dicts_test.append(out)
                all_rews_test.append(labeled_rew)
        except FileNotFoundError as e:
            logging.error(e)
            continue
        except AssertionError as e:
            logging.error(e)
            continue

    return all_dicts_train, all_dicts_test, all_rews_train, all_rews_test


def make_numpy(path, train_proportion):
    dirname = os.path.abspath(path)
    outpath = os.path.join(
        FLAGS.output_path, *dirname.split(os.sep)[-(max(FLAGS.depth - 1, 1)) :]
    )

    if os.path.exists(outpath):
        if FLAGS.overwrite:
            logging.info(f"Deleting {outpath}")
            tf.io.gfile.rmtree(outpath)
        else:
            logging.info(f"Skipping {outpath}")
            return

    outpath_train = tf.io.gfile.join(outpath, "train")
    outpath_val = tf.io.gfile.join(outpath, "val")
    tf.io.gfile.makedirs(outpath_train)
    tf.io.gfile.makedirs(outpath_val)

    lst_train = []
    lst_val = []
    rew_train_l = []
    rew_val_l = []

    for dated_folder in os.listdir(path):
        curr_train, curr_val, rew_train, rew_val = process_dc(
            os.path.join(path, dated_folder), train_ratio=train_proportion
        )
        if curr_train is None:
            continue
        lst_train.extend(curr_train)
        lst_val.extend(curr_val)
        rew_train_l.extend(rew_train)
        rew_val_l.extend(rew_val)

    if len(lst_train) == 0 or len(lst_val) == 0:
        return

    with tf.io.gfile.GFile(tf.io.gfile.join(outpath_train, "out.npy"), "wb") as f:
        np.save(f, lst_train)
    with tf.io.gfile.GFile(tf.io.gfile.join(outpath_val, "out.npy"), "wb") as f:
        np.save(f, lst_val)

    # doesn't seem like these are ever used anymore
    # np.save(os.path.join(outpath_train, "out_rew.npy"), rew_train_l)
    # np.save(os.path.join(outpath_val, "out_rew.npy"), rew_val_l)


def main(_):
    assert FLAGS.depth >= 1

    # each path is a directory that contains dated directories
    paths = glob.glob(os.path.join(FLAGS.input_path, *("*" * (FLAGS.depth - 1))))

    worker_fn = partial(make_numpy, train_proportion=FLAGS.train_proportion)

    with Pool(FLAGS.num_workers) as p:
        list(tqdm.tqdm(p.imap(worker_fn, paths), total=len(paths)))


if __name__ == "__main__":
    app.run(main)
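The script above writes each split as a single `out.npy` holding a list of per-trajectory dicts. Below is a minimal sanity-check sketch, not part of the repository; the subdirectory under `npy/` is a placeholder and depends on the raw-data layout and the `--depth` flag used.

```python
import numpy as np

# Placeholder path: point this at any train/out.npy produced by make_numpy above.
trajs = np.load("../datasets/bridgev2/npy/demos_8_17/train/out.npy", allow_pickle=True)

print(len(trajs), "trajectories")
first = trajs[0]
print(sorted(first.keys()))  # actions, language, next_observations, observations, terminals
# process_dc asserts that these lengths match within a trajectory.
print(len(first["observations"]), len(first["actions"]), len(first["terminals"]))
```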
data/bridgev2/download.sh
ADDED
@@ -0,0 +1,13 @@
# Download the dataset to ../datasets/bridgev2
mkdir -p ../datasets/bridgev2
wget -O ../datasets/bridgev2/demos_8_17.zip https://rail.eecs.berkeley.edu/datasets/bridge_release/data/demos_8_17.zip
mkdir -p ../datasets/bridgev2/raw
# Unzip the dataset
unzip '../datasets/bridgev2/*.zip' -d ../datasets/bridgev2/raw
# Convert the dataset to numpy
python bridgedata_raw_to_numpy.py --input ../datasets/bridgev2/raw --output ../datasets/bridgev2/npy
# Convert the dataset to tfrecords
python bridgedata_numpy_to_tfrecord.py --input ../datasets/bridgev2/npy --output ../datasets/bridgev2/tfrecords
# Remove the raw data and numpy data
rm -rf ../datasets/bridgev2/raw
rm -rf ../datasets/bridgev2/npy
data/calvin/download.sh
ADDED
@@ -0,0 +1,19 @@
#!/bin/bash

echo "Downloading CALVIN dataset..."

# Create calvin folder in ../datasets/calvin/
mkdir -p ../datasets/calvin/

cd ../datasets/calvin/

# You can use this for faster downloading
# aria2c -x 16 -s 16 http://calvin.cs.uni-freiburg.de/dataset/task_ABC_D.zip

wget http://calvin.cs.uni-freiburg.de/dataset/task_ABC_D.zip

echo "Unzipping CALVIN dataset..."

unzip task_ABC_D.zip

echo "Done downloading and unzipping CALVIN dataset."
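Before running the converter in the next file, it can help to confirm that an unzipped CALVIN step file exposes the arrays the converter reads (`actions`, `robot_obs`, `rgb_static`, `rgb_gripper`). This is an optional sketch, not part of the repository; the episode index is a placeholder.

```python
import numpy as np

# Placeholder episode index; any episode_XXXXXXX.npz under training/ works.
step = np.load("../datasets/calvin/task_ABC_D/training/episode_0000001.npz")
print(list(step.keys()))
for key in ["actions", "robot_obs", "rgb_static", "rgb_gripper"]:
    print(key, step[key].shape, step[key].dtype)
```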
data/calvin/hdf5totfrecords.py
ADDED
@@ -0,0 +1,92 @@
import tensorflow as tf
import os
import numpy as np
from tqdm import tqdm


def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _bool_feature(value):
    """Returns a bool_list from a boolean."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[int(value)]))


def serialize_example(action, robot_obs, rgb_static, rgb_gripper, instruction, terminate_episode):
    # Feature for fixed-length fields
    feature = {
        'action': _bytes_feature(tf.io.serialize_tensor(action)),
        'robot_obs': _bytes_feature(tf.io.serialize_tensor(robot_obs)),
        'rgb_static': _bytes_feature(tf.io.serialize_tensor(rgb_static)),
        'rgb_gripper': _bytes_feature(tf.io.serialize_tensor(rgb_gripper)),
        'terminate_episode': _bool_feature(terminate_episode),
        'instruction': _bytes_feature(instruction),
    }

    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()


def write_tfrecords(root_dir, out_dir):
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    # Get the language annotation and corresponding indices
    f = np.load(os.path.join(root_dir, "lang_annotations/auto_lang_ann.npy"), allow_pickle=True)
    lang = f.item()['language']['ann']
    lang = np.array([x.encode('utf-8') for x in lang])
    lang_start_end_idx = f.item()['info']['indx']
    num_ep = len(lang_start_end_idx)

    with tqdm(total=num_ep) as pbar:
        for episode_idx, (start_idx, end_idx) in enumerate(lang_start_end_idx):
            pbar.update(1)

            step_files = [
                f"episode_{str(i).zfill(7)}.npz"
                for i in range(start_idx, end_idx + 1)
            ]
            action = []
            robot_obs = []
            rgb_static = []
            rgb_gripper = []
            instr = lang[episode_idx]
            for step_file in step_files:
                filepath = os.path.join(root_dir, step_file)
                f = np.load(filepath)
                # Get the relevant entries
                action.append(f['actions'])
                robot_obs.append(f['robot_obs'])
                rgb_static.append(f['rgb_static'])
                rgb_gripper.append(f['rgb_gripper'])

            tfrecord_path = os.path.join(out_dir, f'{episode_idx:07d}.tfrecord')
            print(f"Writing TFRecords to {tfrecord_path}")
            with tf.io.TFRecordWriter(tfrecord_path) as writer:
                for i in range(len(step_files)):
                    serialized_example = serialize_example(
                        action[i], robot_obs[i], rgb_static[i], rgb_gripper[i], instr, i == len(step_files) - 1
                    )
                    writer.write(serialized_example)

output_dirs = [
    '../datasets/calvin/tfrecords/training',
    '../datasets/calvin/tfrecords/validation'
]

for output_dir in output_dirs:
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

root_dirs = [
    '../datasets/calvin/task_ABC_D/training',
    '../datasets/calvin/task_ABC_D/validation'
]

for root_dir, output_dir in zip(root_dirs, output_dirs):
    print(f"Writing TFRecords to {output_dir}")
    write_tfrecords(root_dir, output_dir)
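For reference, here is a sketch (not part of the repository) of how one could read back an episode TFRecord written by `write_tfrecords` above. The feature keys mirror `serialize_example`; the `out_type` values are assumptions about the CALVIN array dtypes and may need adjusting.

```python
import tensorflow as tf

feature_description = {
    'action': tf.io.FixedLenFeature([], tf.string),
    'robot_obs': tf.io.FixedLenFeature([], tf.string),
    'rgb_static': tf.io.FixedLenFeature([], tf.string),
    'rgb_gripper': tf.io.FixedLenFeature([], tf.string),
    'terminate_episode': tf.io.FixedLenFeature([], tf.int64),
    'instruction': tf.io.FixedLenFeature([], tf.string),
}

def parse_step(proto):
    parsed = tf.io.parse_single_example(proto, feature_description)
    action = tf.io.parse_tensor(parsed['action'], out_type=tf.float64)        # assumed dtype
    rgb_static = tf.io.parse_tensor(parsed['rgb_static'], out_type=tf.uint8)  # assumed dtype
    return action, rgb_static, parsed['instruction'], parsed['terminate_episode']

# The writer above names the first episode 0000000.tfrecord.
dataset = tf.data.TFRecordDataset('../datasets/calvin/tfrecords/training/0000000.tfrecord')
for action, rgb, instr, done in dataset.map(parse_step).take(2):
    print(action.shape, rgb.shape, instr.numpy(), int(done.numpy()))
```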
data/rh20t/hdf5totfrecords.py
ADDED
@@ -0,0 +1,200 @@
import numpy as np
import os
import cv2
from multiprocessing import Pool, cpu_count, current_process
import tensorflow as tf
from tqdm import tqdm
import json

def _parse_function(proto):
    # Define how to parse the data here.
    feature_description = {
        'joint': tf.io.FixedLenFeature([], tf.string),
        'image': tf.io.FixedLenFeature([], tf.string),
        'instruction': tf.io.FixedLenFeature([], tf.string),
        'terminate_episode': tf.io.FixedLenFeature([], tf.int64),
        'gripper': tf.io.FixedLenFeature([], tf.string, default_value=""),
        'tcp': tf.io.FixedLenFeature([], tf.string, default_value=""),
        'tcp_base': tf.io.FixedLenFeature([], tf.string, default_value="")
    }
    parsed_features = tf.io.parse_single_example(proto, feature_description)
    # Parse tensors
    parsed_features['joint'] = tf.io.parse_tensor(parsed_features['joint'], out_type=tf.float64)
    parsed_features['image'] = tf.io.parse_tensor(parsed_features['image'], out_type=tf.uint8)
    parsed_features['instruction'] = tf.io.parse_tensor(parsed_features['instruction'], out_type=tf.string)
    parsed_features['gripper'] = tf.cond(
        tf.math.equal(parsed_features['gripper'], ""),
        lambda: tf.constant([], dtype=tf.float64),
        lambda: tf.io.parse_tensor(parsed_features['gripper'], out_type=tf.float64)
    )
    parsed_features['tcp'] = tf.cond(
        tf.math.equal(parsed_features['tcp'], ""),
        lambda: tf.constant([], dtype=tf.float64),
        lambda: tf.io.parse_tensor(parsed_features['tcp'], out_type=tf.float64)
    )
    parsed_features['tcp_base'] = tf.cond(
        tf.math.equal(parsed_features['tcp_base'], ""),
        lambda: tf.constant([], dtype=tf.float64),
        lambda: tf.io.parse_tensor(parsed_features['tcp_base'], out_type=tf.float64)
    )
    return parsed_features

def convert_color(color_file, color_timestamps):
    """
    Args:
    - color_file: the color video file;
    - color_timestamps: the color timestamps;
    - dest_color_dir: the destination color directory.
    """
    cap = cv2.VideoCapture(color_file)
    cnt = 0
    frames = []
    while True:
        ret, frame = cap.read()
        if ret:
            resized_frame = cv2.resize(frame, (640, 360))
            frames.append(resized_frame)
            cnt += 1
        else:
            break
    cap.release()
    return frames

def _bytes_feature(value):
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _bool_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[int(value)]))

def serialize_example(joint, gripper, tcp, tcp_base, image, instruction, terminate_episode):
    feature = {
        'joint': _bytes_feature(tf.io.serialize_tensor(joint)),
        'image': _bytes_feature(tf.io.serialize_tensor(image)),
        'instruction': _bytes_feature(tf.io.serialize_tensor(instruction)),
        'terminate_episode': _bool_feature(terminate_episode),
    }
    if gripper is not None:
        feature['gripper'] = _bytes_feature(tf.io.serialize_tensor(gripper))
    if tcp is not None:
        feature['tcp'] = _bytes_feature(tf.io.serialize_tensor(tcp))
    if tcp_base is not None:
        feature['tcp_base'] = _bytes_feature(tf.io.serialize_tensor(tcp_base))
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()

def compress_tfrecord(tfrecord_path):
    raw_dataset = tf.data.TFRecordDataset(tfrecord_path)
    parsed_dataset = raw_dataset.map(_parse_function)

    # Serialize and write to a new TFRecord file
    with tf.io.TFRecordWriter(tfrecord_path) as writer:
        for features in parsed_dataset:
            image_tensor = features['image']
            image_np = image_tensor.numpy()
            if len(image_np.shape) <= 1:  # already compressed
                return
            _, compressed_image = cv2.imencode('.jpg', image_np)
            features['image'] = tf.io.serialize_tensor(tf.convert_to_tensor(compressed_image.tobytes(), dtype=tf.string))

            def _bytes_feature(value):
                """Returns a bytes_list from a string / byte."""
                if isinstance(value, type(tf.constant(0))):
                    value = value.numpy()
                return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

            feature_dict = {
                'joint': _bytes_feature(features['joint']),
                'image': _bytes_feature(features['image']),
                'instruction': _bytes_feature(features['instruction']),
                'terminate_episode': tf.train.Feature(int64_list=tf.train.Int64List(value=[features['terminate_episode']])),
                'gripper': _bytes_feature(features['gripper']),
                'tcp': _bytes_feature(features['tcp']),
                'tcp_base': _bytes_feature(features['tcp_base'])
            }
            example_proto = tf.train.Example(features=tf.train.Features(feature=feature_dict))
            serialized_example = example_proto.SerializeToString()
            writer.write(serialized_example)
    print(f"compressed {tfrecord_path}")

def write_task(args):
    task_dir, output_dir = args

    all_instructions = json.load(open('./instruction.json'))
    instruction = None
    for taskid in list(all_instructions.keys()):
        if taskid in task_dir:
            instruction = all_instructions[taskid]['task_description_english']
    if instruction is None:
        return

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    joints = np.load(os.path.join(task_dir, "transformed/joint.npy"), allow_pickle=True).item()
    if not os.path.exists(os.path.join(task_dir, "transformed/gripper.npy")):
        return
    grippers = np.load(os.path.join(task_dir, "transformed/gripper.npy"), allow_pickle=True).item()
    tcps = np.load(os.path.join(task_dir, "transformed/tcp.npy"), allow_pickle=True).item()
    tcp_bases = np.load(os.path.join(task_dir, "transformed/tcp_base.npy"), allow_pickle=True).item()

    for camid in joints.keys():
        timesteps = joints[camid]
        if len(timesteps) == 0:
            continue
        tfrecord_path = os.path.join(output_dir, f'cam_{camid}.tfrecord')
        timesteps_file = os.path.join(task_dir, f'cam_{camid}/timestamps.npy')

        if not os.path.exists(timesteps_file):
            continue
        if os.path.exists(tfrecord_path) and os.path.getsize(tfrecord_path) > 0:
            continue

        timesteps_file = np.load(timesteps_file, allow_pickle=True).item()
        images = convert_color(os.path.join(task_dir, f'cam_{camid}/color.mp4'), timesteps_file['color'])
        if len(timesteps) != len(images):  ## BUG FROM RH20T
            continue
        with tf.io.TFRecordWriter(tfrecord_path) as writer:
            for i, timestep in enumerate(timesteps):
                # image = cv2.imread(os.path.join(img_dir, f"{timestep}.jpg"))
                image = cv2.imencode('.jpg', images[i])[1].tobytes()
                joint_pos = joints[camid][timestep]
                tcp = next((item for item in tcps[camid] if item['timestamp'] == timestep), None)['tcp']
                tcp_base = next((item for item in tcp_bases[camid] if item['timestamp'] == timestep), None)['tcp']
                if timestep not in grippers[camid]:
                    gripper_pos = None
                else:
                    gripper_pos = grippers[camid][timestep]['gripper_info']
                terminate_episode = i == len(timesteps) - 1
                # read from instruction.json
                serialized_example = serialize_example(joint_pos, gripper_pos, tcp, tcp_base, image, instruction, terminate_episode)
                writer.write(serialized_example)


def write_tfrecords(root_dir, output_dir, num_processes=None):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    if num_processes is None:
        num_processes = cpu_count()

    num_files = 0
    args = []
    for dirs in os.listdir(root_dir):
        for task in os.listdir(os.path.join(root_dir, dirs)):
            if 'human' in task:
                continue
            task_dir = os.path.join(root_dir, dirs, task)
            joint_path = os.path.join(task_dir, "transformed/joint.npy")
            if not os.path.exists(joint_path):
                continue
            num_files += 1
            task_out = os.path.join(output_dir, dirs, task)
            os.makedirs(task_out, exist_ok=True)
            args.append((task_dir, task_out))

    with tqdm(total=num_files, desc="Processing files") as pbar:
        with Pool(num_processes) as pool:
            for _ in pool.imap_unordered(write_task, args):
                pbar.update(1)

write_tfrecords('../datasets/rh20t/raw_data/', '../datasets/rh20t/tfrecords/')
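`write_task` above looks up the language instruction from a local `instruction.json` keyed by RH20T task id. Purely as an illustration of the shape that lookup assumes, here is a sketch; the task id and description are made-up placeholders, not real entries.

```python
import json

# Hypothetical entry: the real file maps each RH20T task id to its metadata,
# and write_task reads the 'task_description_english' field.
example = {
    "task_0001": {
        "task_description_english": "pick up the cup and place it on the plate"
    }
}
print(json.dumps(example, indent=2))
```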
data/roboset/download.py
ADDED
@@ -0,0 +1,42 @@
import requests
import os
from tqdm import tqdm

links = []
with open('links.txt', 'r', encoding='utf-8') as file:
    for line in file:
        links.append(line.strip())

download_dir = "../datasets/roboset"
os.makedirs(download_dir, exist_ok=True)

for link in links:
    filename = os.path.basename(link)
    filepath = os.path.join(download_dir, filename)
    print(f"Downloading {filename} from {link}")

    response = requests.get(link, stream=True)
    total_size_in_bytes = int(response.headers.get('content-length', 0))
    block_size = 1024

    if os.path.exists(filepath):
        local_size = os.path.getsize(filepath)
        if local_size == total_size_in_bytes:
            print(f"{filename} already exists")
            continue

    progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)

    with open(filepath, 'wb') as f:
        for data in response.iter_content(block_size):
            progress_bar.update(len(data))
            f.write(data)

    progress_bar.close()

    if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
        print("ERROR, something went wrong")

    print(f"Downloaded {filename}")

print("All files processed.")
data/roboset/download.sh
ADDED
@@ -0,0 +1,21 @@
#!/bin/bash

while true; do
    python download.py
    EXIT_CODE=$?
    if [ $EXIT_CODE -ne 0 ]; then
        echo "Download exited with code $EXIT_CODE. Restarting..."
    else
        echo "Download exited with code 0. Not restarting."
        break
    fi
done

# Unzip all the files in the ../datasets/roboset/ directory
cd ../datasets/roboset/
for file in *.tar.gz; do
    tar -xzvf "$file"
done

# Convert the dataset to tfrecords
python h5totfrecords.py
data/roboset/h5totfrecords.py
ADDED
@@ -0,0 +1,82 @@
import tensorflow as tf
import h5py
import os
import fnmatch
from tqdm import tqdm
from multiprocessing import Pool, cpu_count, current_process

def _bytes_feature(value):
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _bool_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[int(value)]))

def serialize_example(action, action_gripper, qpos, qvel, qpos_gripper, qvel_gripper, rgb_left, rgb_right, rgb_top, terminate_episode):
    feature = {
        'action': _bytes_feature(tf.io.serialize_tensor(action)),
        'action_gripper': _bytes_feature(tf.io.serialize_tensor(action_gripper)),
        'qpos': _bytes_feature(tf.io.serialize_tensor(qpos)),
        'qvel': _bytes_feature(tf.io.serialize_tensor(qvel)),
        'qpos_gripper': _bytes_feature(tf.io.serialize_tensor(qpos_gripper)),
        'qvel_gripper': _bytes_feature(tf.io.serialize_tensor(qvel_gripper)),
        'rgb_left': _bytes_feature(tf.io.serialize_tensor(rgb_left)),
        'rgb_right': _bytes_feature(tf.io.serialize_tensor(rgb_right)),
        'rgb_top': _bytes_feature(tf.io.serialize_tensor(rgb_top)),
        'terminate_episode': _bool_feature(terminate_episode),
    }
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()

def process_file(params):
    filepath, output_dir = params
    with h5py.File(filepath, 'r') as f:
        for Trial in f.keys():
            data = f[Trial]['data']
            tfrecord_path = os.path.join(output_dir, os.path.basename(filepath).replace('.h5', f'_{Trial}.tfrecord'))
            if os.path.exists(tfrecord_path) and os.path.getsize(tfrecord_path) > 0:
                continue
            with tf.io.TFRecordWriter(tfrecord_path) as writer:
                num_episodes = data['ctrl_arm'].shape[0]
                for i in range(num_episodes):
                    action = data['ctrl_arm'][i]
                    action_gripper = data['ctrl_ee'][i]
                    qpos = data['qp_arm'][i]
                    qvel = data['qv_arm'][i]
                    qpos_gripper = data['qp_ee'][i]
                    qvel_gripper = data['qv_ee'][i]
                    rgb_left = data['rgb_left'][i]
                    rgb_right = data['rgb_right'][i]
                    rgb_top = data['rgb_top'][i]
                    terminate_episode = i == num_episodes - 1
                    serialized_example = serialize_example(action, action_gripper, qpos, qvel, qpos_gripper, qvel_gripper, rgb_left, rgb_right, rgb_top, terminate_episode)
                    writer.write(serialized_example)

def write_tfrecords(root_dir, out_dir, num_processes=None):
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    if num_processes is None:
        num_processes = cpu_count()

    file_list = []
    num_files = 0
    for root, dirs, files in os.walk(root_dir):
        for filename in fnmatch.filter(files, '*.h5'):
            filepath = os.path.join(root, filename)
            output_dir = os.path.join(out_dir, os.path.relpath(os.path.dirname(filepath), root_dir))
            if not os.path.exists(output_dir):
                os.makedirs(output_dir)
            num_files += 1
            file_list.append((filepath, output_dir))

    with tqdm(total=num_files, desc="Processing files") as pbar:
        with Pool(num_processes) as pool:
            for _ in pool.imap_unordered(process_file, file_list):
                pbar.update(1)

root_dir = '../datasets/roboset/'
output_dir = '../datasets/roboset/tfrecords/'

write_tfrecords(root_dir, output_dir)
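As a quick, optional check (not part of the repository), the HDF5 layout that `process_file` above expects can be inspected directly: trial groups at the top level, each with a `data` group holding the `ctrl_*`, `qp_*`, `qv_*`, and `rgb_*` arrays. The file name below is a placeholder.

```python
import h5py

# Placeholder file name: any extracted RoboSet .h5 trajectory file.
with h5py.File("../datasets/roboset/some_trajectory.h5", "r") as f:
    for trial in f.keys():
        data = f[trial]["data"]
        print(trial, sorted(data.keys()))
        print("  ctrl_arm shape:", data["ctrl_arm"].shape)
```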
data/roboset/links.txt
ADDED
@@ -0,0 +1,197 @@
1 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_1_Blocks_895/AutonomousRoboSet_Set_1_Blocks_895.tar.gz
|
2 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_2_SoftToys_12585/Autonomous_RoboSet_Set_2_SoftToys_839_0.tar.gz
|
3 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_2_SoftToys_12585/Autonomous_RoboSet_Set_2_SoftToys_839_1.tar.gz
|
4 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_2_SoftToys_12585/Autonomous_RoboSet_Set_2_SoftToys_839_2.tar.gz
|
5 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_2_SoftToys_12585/Autonomous_RoboSet_Set_2_SoftToys_839_3.tar.gz
|
6 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_2_SoftToys_12585/Autonomous_RoboSet_Set_2_SoftToys_839_4.tar.gz
|
7 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_2_SoftToys_12585/Autonomous_RoboSet_Set_2_SoftToys_839_5.tar.gz
|
8 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_2_SoftToys_12585/Autonomous_RoboSet_Set_2_SoftToys_839_6.tar.gz
|
9 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_2_SoftToys_12585/Autonomous_RoboSet_Set_2_SoftToys_839_7.tar.gz
|
10 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_2_SoftToys_12585/Autonomous_RoboSet_Set_2_SoftToys_839_8.tar.gz
|
11 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_2_SoftToys_12585/Autonomous_RoboSet_Set_2_SoftToys_839_9.tar.gz
|
12 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_2_SoftToys_12585/Autonomous_RoboSet_Set_2_SoftToys_839_10.tar.gz
|
13 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_2_SoftToys_12585/Autonomous_RoboSet_Set_2_SoftToys_839_11.tar.gz
|
14 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_2_SoftToys_12585/Autonomous_RoboSet_Set_2_SoftToys_839_12.tar.gz
|
15 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_2_SoftToys_12585/Autonomous_RoboSet_Set_2_SoftToys_839_13.tar.gz
|
16 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_3_Blocks_and_Toys_980/Autonomous_RoboSet_Set_3_Blocks_and_Toys_980.tar.gz
|
17 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_4_Medium_Block_7/Autonomous_RoboSet_Set_4_Medium_Block_7.tar.gz
|
18 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_5_Bottle_Cube_14/Autonomous_RoboSet_Set_5_Bottle_Cube_14.tar.gz
|
19 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_6_Planar_Push_120/Autonomous_RoboSet_Set_6_Planar_Push_120.tar.gz
|
20 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_7_Pick_Orange_Block_607/Autonomous_RoboSet_Set_7_Pick_Orange_Block_607_0.tar.gz
|
21 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_7_Pick_Orange_Block_607/Autonomous_RoboSet_Set_7_Pick_Orange_Block_607_1.tar.gz
|
22 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_7_Pick_Orange_Block_607/Autonomous_RoboSet_Set_7_Pick_Orange_Block_607_2.tar.gz
|
23 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_7_Pick_Orange_Block_607/Autonomous_RoboSet_Set_7_Pick_Orange_Block_607_3.tar.gz
|
24 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_7_Pick_Orange_Block_607/Autonomous_RoboSet_Set_7_Pick_Orange_Block_607_4.tar.gz
|
25 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_8_Pick_Bottle_10/Autonomous_RoboSet_Set_8_Pick_Bottle_10.tar.gz
|
26 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_5420/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_542_0.tar.gz
|
27 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_5420/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_542_1.tar.gz
|
28 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_5420/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_542_2.tar.gz
|
29 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_5420/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_542_3.tar.gz
|
30 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_5420/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_542_4.tar.gz
|
31 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_5420/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_542_5.tar.gz
|
32 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_5420/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_542_6.tar.gz
|
33 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_5420/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_542_7.tar.gz
|
34 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_5420/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_542_8.tar.gz
|
35 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_5420/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_542_9.tar.gz
|
36 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_10_Pick_Block_Eval_1837/Autonomous_RoboSet_Set_10_Pick_Block_Eval_167_0.tar.gz
|
37 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_10_Pick_Block_Eval_1837/Autonomous_RoboSet_Set_10_Pick_Block_Eval_167_1.tar.gz
|
38 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_10_Pick_Block_Eval_1837/Autonomous_RoboSet_Set_10_Pick_Block_Eval_167_2.tar.gz
|
39 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_10_Pick_Block_Eval_1837/Autonomous_RoboSet_Set_10_Pick_Block_Eval_167_3.tar.gz
|
40 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_10_Pick_Block_Eval_1837/Autonomous_RoboSet_Set_10_Pick_Block_Eval_167_4.tar.gz
|
41 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_10_Pick_Block_Eval_1837/Autonomous_RoboSet_Set_10_Pick_Block_Eval_167_5.tar.gz
|
42 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_10_Pick_Block_Eval_1837/Autonomous_RoboSet_Set_10_Pick_Block_Eval_167_6.tar.gz
|
43 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_10_Pick_Block_Eval_1837/Autonomous_RoboSet_Set_10_Pick_Block_Eval_167_7.tar.gz
|
44 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_10_Pick_Block_Eval_1837/Autonomous_RoboSet_Set_10_Pick_Block_Eval_167_8.tar.gz
|
45 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_10_Pick_Block_Eval_1837/Autonomous_RoboSet_Set_10_Pick_Block_Eval_167_9.tar.gz
|
46 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_10_Pick_Block_Eval_1837/Autonomous_RoboSet_Set_10_Pick_Block_Eval_167_10.tar.gz
|
47 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_2070/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_207_0.tar.gz
|
48 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_2070/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_207_1.tar.gz
|
49 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_2070/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_207_2.tar.gz
|
50 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_2070/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_207_3.tar.gz
|
51 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_2070/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_207_4.tar.gz
|
52 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_2070/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_207_5.tar.gz
|
53 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_2070/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_207_6.tar.gz
|
54 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_2070/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_207_7.tar.gz
|
55 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_2070/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_207_8.tar.gz
|
56 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_2070/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_207_9.tar.gz
|
57 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_12_Pick_Block_Eval_2000/Autonomous_RoboSet_Set_12_Pick_Block_Eval_200_0.tar.gz
|
58 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_12_Pick_Block_Eval_2000/Autonomous_RoboSet_Set_12_Pick_Block_Eval_200_1.tar.gz
|
59 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_12_Pick_Block_Eval_2000/Autonomous_RoboSet_Set_12_Pick_Block_Eval_200_2.tar.gz
|
60 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_12_Pick_Block_Eval_2000/Autonomous_RoboSet_Set_12_Pick_Block_Eval_200_3.tar.gz
|
61 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_12_Pick_Block_Eval_2000/Autonomous_RoboSet_Set_12_Pick_Block_Eval_200_4.tar.gz
|
62 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_12_Pick_Block_Eval_2000/Autonomous_RoboSet_Set_12_Pick_Block_Eval_200_5.tar.gz
|
63 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_12_Pick_Block_Eval_2000/Autonomous_RoboSet_Set_12_Pick_Block_Eval_200_6.tar.gz
|
64 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_12_Pick_Block_Eval_2000/Autonomous_RoboSet_Set_12_Pick_Block_Eval_200_7.tar.gz
|
65 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_12_Pick_Block_Eval_2000/Autonomous_RoboSet_Set_12_Pick_Block_Eval_200_8.tar.gz
|
66 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_12_Pick_Block_Eval_2000/Autonomous_RoboSet_Set_12_Pick_Block_Eval_200_9.tar.gz
|
67 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_4170/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_278_0.tar.gz
|
68 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_4170/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_278_1.tar.gz
|
69 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_4170/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_278_2.tar.gz
|
70 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_4170/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_278_3.tar.gz
|
71 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_4170/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_278_4.tar.gz
|
72 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_4170/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_278_5.tar.gz
|
73 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_4170/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_278_6.tar.gz
|
74 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_4170/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_278_7.tar.gz
|
75 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_4170/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_278_8.tar.gz
|
76 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_4170/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_278_9.tar.gz
|
77 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_4170/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_278_10.tar.gz
|
78 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_4170/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_278_11.tar.gz
|
79 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_4170/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_278_12.tar.gz
|
80 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_4170/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_278_13.tar.gz
|
81 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_4170/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_278_14.tar.gz
|
82 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_0.tar.gz
|
83 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_1.tar.gz
|
84 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_2.tar.gz
|
85 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_3.tar.gz
|
86 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_4.tar.gz
|
87 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_5.tar.gz
|
88 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_6.tar.gz
|
89 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_7.tar.gz
|
90 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_8.tar.gz
|
91 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_9.tar.gz
|
92 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_10.tar.gz
|
93 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_11.tar.gz
|
94 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_12.tar.gz
|
95 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_13.tar.gz
|
96 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_14.tar.gz
|
97 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_15.tar.gz
|
98 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_16.tar.gz
|
99 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_17.tar.gz
|
100 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_18.tar.gz
|
101 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_19.tar.gz
|
102 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_1470/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_147_0.tar.gz
|
103 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_1470/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_147_1.tar.gz
|
104 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_1470/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_147_2.tar.gz
|
105 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_1470/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_147_3.tar.gz
|
106 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_1470/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_147_4.tar.gz
|
107 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_1470/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_147_5.tar.gz
|
108 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_1470/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_147_6.tar.gz
|
109 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_1470/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_147_7.tar.gz
|
110 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_1470/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_147_8.tar.gz
|
111 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_1470/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_147_9.tar.gz
|
112 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_2115/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_141_0.tar.gz
|
113 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_2115/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_141_1.tar.gz
|
114 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_2115/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_141_2.tar.gz
|
115 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_2115/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_141_3.tar.gz
|
116 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_2115/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_141_4.tar.gz
|
117 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_2115/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_141_5.tar.gz
|
118 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_2115/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_141_6.tar.gz
|
119 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_2115/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_141_7.tar.gz
|
120 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_2115/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_141_8.tar.gz
|
121 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_2115/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_141_9.tar.gz
|
122 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_2115/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_141_10.tar.gz
|
123 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_2115/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_141_11.tar.gz
|
124 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_2115/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_141_12.tar.gz
|
125 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_2115/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_141_13.tar.gz
|
126 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_2115/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_141_14.tar.gz
|
127 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_0.tar.gz
|
128 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_1.tar.gz
|
129 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_2.tar.gz
|
130 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_3.tar.gz
|
131 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_4.tar.gz
|
132 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_5.tar.gz
|
133 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_6.tar.gz
|
134 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_7.tar.gz
|
135 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_8.tar.gz
|
136 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_9.tar.gz
|
137 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_10.tar.gz
|
138 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_11.tar.gz
|
139 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_12.tar.gz
|
140 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_13.tar.gz
|
141 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_14.tar.gz
|
142 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_15.tar.gz
|
143 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_16.tar.gz
|
144 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_17.tar.gz
|
145 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_18.tar.gz
|
146 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_19.tar.gz
|
147 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_20.tar.gz
|
148 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_21.tar.gz
|
149 |
+
https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_22.tar.gz
|
150 |
+
https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/pick_banana_place_in_mug.tar.gz
|
151 |
+
https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/pick_banana_place_in_strainer.tar.gz
|
152 |
+
https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/pick_banana_from_plate_place_on_table.tar.gz
|
153 |
+
https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/pick_banana_from_toaster_place_on_table.tar.gz
|
154 |
+
https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/pick_banana_place_on_plate.tar.gz
|
155 |
+
https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/pick_banana_place_on_toaster.tar.gz
|
156 |
+
https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/pick_ketchup_place_in_strainer.tar.gz
|
157 |
+
https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/pick_ketchup_place_in_toaster.tar.gz
|
158 |
+
https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/pick_ketchup_from_strainer_place_on_table.tar.gz
|
159 |
+
https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/pick_ketchup_from_plate_place_on_table.tar.gz
|
160 |
+
https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/pick_ketchup_from_toaster_place_on_table.tar.gz
|
161 |
+
https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/pick_ketchup_place_on_plate.tar.gz
|
162 |
+
https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/pick_ketchup_place_on_toaster.tar.gz
|
163 |
+
https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/drag_mug_backward.tar.gz
|
164 |
+
https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/drag_mug_forward.tar.gz
|
165 |
+
https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/drag_mug_from_left_to_right.tar.gz
|
166 |
+
https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/drag_mug_from_right_to_left.tar.gz
|
167 |
+
https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/drag_strainer_backward.tar.gz
|
168 |
+
https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/drag_strainer_forward.tar.gz
|
169 |
+
https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/drag_strainer_left_to_right.tar.gz
|
170 |
+
https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/drag_strainer_right_to_left.tar.gz
|
171 |
+
https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/flap_open_toaster_oven.tar.gz
|
172 |
+
https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/flap_close_toaster_oven.tar.gz
|
173 |
+
http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/baking_prep/baking_prep_slide_open_drawer_scene_1.tar.gz
|
174 |
+
http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/baking_prep/baking_prep_slide_open_drawer_scene_4.tar.gz
|
175 |
+
http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/baking_prep/baking_prep_pick_butter_scene_1.tar.gz
|
176 |
+
http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/baking_prep/baking_prep_pick_butter_scene_4.tar.gz
|
177 |
+
http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/baking_prep/baking_prep_place_butter_scene_1.tar.gz
|
178 |
+
http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/baking_prep/baking_prep_place_butter_scene_4.tar.gz
|
179 |
+
http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/baking_prep/baking_prep_slide_close_drawer_scene_1.tar.gz
|
180 |
+
http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/baking_prep/baking_prep_slide_close_drawer_scene_4.tar.gz
|
181 |
+
http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/clean_kitchen/clean_kitchen_pick_lid_scene_3.tar.gz
|
182 |
+
http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/clean_kitchen/clean_kitchen_cap_lid_scene_3.tar.gz
|
183 |
+
http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/clean_kitchen/clean_kitchen_slide_close_drawer_scene_3.tar.gz
|
184 |
+
http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/clean_kitchen/clean_kitchen_flap_close_oven_Scene_3.tar.gz
|
185 |
+
http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/clean_kitchen/clean_kitchen_pick_towel_scene_3.tar.gz
|
186 |
+
http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/clean_kitchen/clean_kitchen_Wipe_Counter_Scene_3.tar.gz
|
187 |
+
http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/heat_soup/heat_soup_flap_open_oven_Scene_2.tar.gz
|
188 |
+
http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/heat_soup/heat_soup_flap_open_oven_Scene_4.tar.gz
|
189 |
+
http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/heat_soup/heat_soup_pick_bowl_scene_2.tar.gz
|
190 |
+
http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/heat_soup/heat_soup_pick_bowl_scene_4.tar.gz
|
191 |
+
http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/heat_soup/heat_soup_slide_in_bowl_scene_2.tar.gz
|
192 |
+
http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/heat_soup/heat_soup_slide_in_bowl_scene_4.tar.gz
|
193 |
+
http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/heat_soup/heat_soup_flap_close_oven_scene_2.tar.gz
|
194 |
+
http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/make_tea/make_tea_Uncap_Lid_Scene_2.tar.gz
|
195 |
+
http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/make_tea/make_tea_place_lid_scene_2.tar.gz
|
196 |
+
http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/make_tea/make_tea_pick_tea_scene_2.tar.gz
|
197 |
+
http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/make_tea/make_tea_place_tea_scene_2.tar.gz
|
docs/pretrain.md
ADDED
@@ -0,0 +1,270 @@
# Pipeline of Pre-Training RDT

Firstly, you need to install the prerequisites for RDT (see [README](../README.md#installation)). Then, you can install the prerequisites for TensorFlow Dataset (in another Conda environment).

## Installation for TensorFlow Dataset

```bash
# Under the root directory of this repo
conda create -n rdt-data python=3.10
conda activate rdt-data

# Install all the prerequisites
pip install -r requirements_data.txt
# Or you can manually install each package (please refer to requirements_data.txt for specific versions)
pip install tfds-nightly gsutil tensorflow Pillow pyyaml opencv-python tensorflow-graphics imageio[ffmpeg]
# If the speed is too slow, you can specify alternative sources (tfds-nightly is not available in the Tsinghua mirror)
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple gsutil tensorflow Pillow pyyaml opencv-python tensorflow-graphics imageio[ffmpeg]
```

## Download and Prepare Datasets

We introduce how to download each of our pre-training datasets. If you plan to pre-train on a subset of them, just download the ones you need. You can also fine-tune RDT through this pipeline, but only if your target dataset is listed below or available in Google Cloud Storage.

| Dataset | Sample Percentage (%) |
| ---- | ---- |
| RT-1 Dataset | 9.00 |
| TACO Dataset | 1.99 |
| JACO Play Dataset | 1.10 |
| Cable Routing Dataset | 0.27 |
| NYU Door Opening | 0.33 |
| Viola | 0.40 |
| Berkeley UR5 | 1.06 |
| TOTO | 1.06 |
| Kuka | 1.66 |
| Language Table | 3.32 |
| Columbia Cairlab Pusht Real | 0.40 |
| Stanford Kuka Multimodal Dataset | 1.83 |
| Stanford Hydra Dataset | 0.80 |
| Austin Buds Dataset | 0.23 |
| Maniskill Dataset | 5.78 |
| Furniture Bench Dataset | 2.36 |
| UCSD Kitchen Dataset | 0.40 |
| UCSD Pick And Place Dataset | 1.23 |
| Austin Sailor Dataset | 0.50 |
| Austin Sirius Dataset | 0.80 |
| BC Z | 6.91 |
| UTokyo PR2 Opening Fridge | 0.30 |
| UTokyo PR2 Tabletop Manipulation | 0.50 |
| UTokyo Xarm Pick And Place | 0.33 |
| UTokyo Xarm Bimanual | 0.03 |
| Berkeley MVP | 0.73 |
| Berkeley RPT | 1.00 |
| KAIST Nonprehensile | 0.46 |
| Tokyo U LSMO | 0.23 |
| DLR Sara Grid Clamp | 0.03 |
| Robocook | 1.66 |
| Imperialcollege Sawyer Wrist Cam | 0.43 |
| Iamlab CMU Pickup Insert | 0.83 |
| UTAustin Mutex | 1.29 |
| Fanuc Manipulation | 0.66 |
| Play Fusion | 0.80 |
| Droid | 10.06 |
| FMB | 1.39 |
| Dobb·E | 1.20 |
| QUT Dexterous Manipulation | 0.46 |
| Aloha Dataset | 4.98 |
| Mobile Aloha Dataset | 4.98 |
| Roboset | 4.48 |
| RH20T | 10.99 |
| Calvin Dataset | 3.32 |
| Bridgev2 | 7.44 |

Before everything, let's link the dataset directory on your disk to a subfolder of this repo:

```bash
ln -s /path/to/dataset /path/to/repo/RoboticsDiffusionTransformer/data/datasets
```

### Open X-Embodiment

Specify the correct path to the `gsutil` binary of your Conda environment in [this file](../data/openx_embod/download.sh#L72).

Run the following commands to download our selected datasets from the Open X-Embodiment collection:

```bash
# Under the root directory of this repo
cd data/openx_embod
# Download all datasets
bash download_openx_embod.sh
```

Note: By modifying `download_openx_embod.sh`, you can download any dataset on Google Cloud (as long as it can be downloaded with `gsutil` and is stored in `TFRecord` format), not just the ones we have listed.
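
For example, a manual download of one extra dataset might look like the sketch below. The bucket layout is an assumption here (Open X-Embodiment datasets are generally served from `gs://gresearch/robotics/...`); check `download_openx_embod.sh` or `gsutil ls` for the paths actually used:

```bash
# Hypothetical: fetch one additional Open X-Embodiment dataset by hand.
# Verify the dataset folder name with `gsutil ls gs://gresearch/robotics/` first.
gsutil -m cp -r gs://gresearch/robotics/berkeley_autolab_ur5 data/datasets/openx_embod/
```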

### Mobile ALOHA Dataset

Download the Mobile ALOHA Dataset from the [official website](https://mobile-aloha.github.io) to `data/datasets/aloha`, then run:

```bash
cd data/aloha
# Convert the dataset to TFRecord
python hdf5totfrecords.py
```

### Bridgev2

Run:

```bash
cd data/bridgev2
# Download and preprocess the dataset
sh download.sh
```

### Calvin

Run:

```bash
cd data/calvin
# Download and preprocess the dataset
sh download.sh
# Convert the dataset to TFRecord format
python hdf5totfrecords.py
```

### RH20T

Download the RH20T Dataset from the [official website](https://rh20t.github.io/#download) to `data/datasets/rh20t`, then run:

```bash
cd data/rh20t
# Convert the dataset to TFRecord
python hdf5totfrecords.py
```

### RoboSet

Run:

```bash
cd data/roboset
# Download and preprocess the dataset
sh download.sh
```

## If You Want to Train on a New Dataset

If you want to train on a new dataset (e.g., `my_pretrain_dataset`) through this pre-training pipeline, you need to modify several files as follows:

##### 1. `configs/dataset_control_freq.json`

Add the control frequency of your dataset.
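
For instance, assuming the file is a flat mapping from dataset name to control frequency in Hz (the 25 below is only an illustrative value, not a recommendation), the new entry would look like:

```json
{
    "my_pretrain_dataset": 25
}
```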

##### 2. `data/preprocess_scripts/my_pretrain_dataset.py`

If your dataset can be loaded by `tfds.builder_from_directory()`, then you only need to download it into the Open X-Embodiment folder `data/datasets/openx_embod` and implement the function `process_step()`. You may need to specify the tfds loading path in L78 (see [this file](../data/vla_dataset.py#L78)). We refer to `data/preprocess_scripts/droid.py` for an example.

If not, you need to first convert it into TFRecords and then implement both `load_dataset()` and `process_step()`. We refer to `data/agilex/hdf5totfrecords.py` and `data/preprocess_scripts/agilex.py` for examples.

Here are some descriptions:

##### `load_dataset(seed: int)`

- Returns a dataset that supports iteration and a `repeat` method with a random seed.
- Suggested implementation: Use `tf.data.Dataset.from_generator` and `tf.data.TFRecordDataset`.
- The iterator should yield a sub-dataset (also supporting iteration) that represents one episode with the following structure:
  - `step`: A dataset object that supports iteration and contains multiple frames per episode.
    - `observation`: A dictionary containing your images.
      - `your_first_image_key`: Your observation RGB image keys.
      - ...
    - `other_attribute`: Any other relevant attributes.

##### `process_step(step: dict) -> dict`

Processes a single frame and returns a dictionary with the following keys:

- `observation`:
  - `your_first_view_image: tf.Tensor`: Your first view image.
  - `arm_concat: tf.Tensor`: Concatenation of physical states.
  - `format: tf.constant(string)`: Format of `arm_concat` (e.g., `arm_joint_pos_0,arm_joint_pos_1,arm_joint_pos_2`).
- `action`: Frame action (leave empty if there's none).
  - `arm_concat`: Same as in `observation`.
  - `format`: Same as in `observation`.
- `terminate: tf.Tensor`: A boolean tensor indicating whether the episode ends.

**IMPORTANT**: You should only use TensorFlow functions for any branch or loop operations. For example, use `tf.cond` instead of `if`. A minimal sketch of both functions is given below.
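
To make the expected interfaces concrete, here is a minimal sketch, not the repository's actual implementation: the TFRecord path, the feature names (`exterior_cam`, `qpos`, `terminate`), and the 7-DoF state format are assumptions for illustration only; see `data/preprocess_scripts/agilex.py` for a real example.

```python
import glob
import random

import tensorflow as tf


def _parse_frame(serialized_example):
    # Hypothetical TFRecord layout: one JPEG image, 7 joint positions,
    # and a termination flag per frame.
    features = tf.io.parse_single_example(
        serialized_example,
        {
            "exterior_cam": tf.io.FixedLenFeature([], tf.string),
            "qpos": tf.io.FixedLenFeature([7], tf.float32),
            "terminate": tf.io.FixedLenFeature([], tf.int64),
        },
    )
    return {
        "observation": {
            "exterior_cam": tf.io.decode_jpeg(features["exterior_cam"]),
            "qpos": features["qpos"],
        },
        "terminate": features["terminate"],
    }


def load_dataset(seed: int):
    """Return an episode-level dataset that supports iteration and `repeat`."""
    filenames = sorted(glob.glob("data/datasets/my_pretrain_dataset/*.tfrecord"))
    random.Random(seed).shuffle(filenames)

    class _EpisodeDataset:
        def __iter__(self):
            # Each yielded episode exposes its frames as a `step` dataset,
            # mirroring the structure described above.
            for filename in filenames:
                yield {"step": tf.data.TFRecordDataset(filename).map(_parse_frame)}

        def repeat(self):
            # Simplified: a real implementation would reshuffle and loop forever.
            return self

    return _EpisodeDataset()


def process_step(step: dict) -> dict:
    """Map one frame into the unified format expected by the pipeline."""
    fmt = tf.constant(
        "arm_joint_pos_0,arm_joint_pos_1,arm_joint_pos_2,arm_joint_pos_3,"
        "arm_joint_pos_4,arm_joint_pos_5,arm_joint_pos_6"
    )
    arm_concat = step["observation"]["qpos"]
    return {
        "observation": {
            "exterior_cam": step["observation"]["exterior_cam"],
            "arm_concat": arm_concat,
            "format": fmt,
        },
        "action": {
            "arm_concat": arm_concat,
            "format": fmt,
        },
        # Remember: use TensorFlow ops (tf.cond, tf.where, ...) for any branching.
        "terminate": tf.cast(step["terminate"], tf.bool),
    }
```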

##### 3. `configs/dataset_img_keys.json`

Add the image keys of your dataset. For example:

```json
"my_pretrain_dataset": {
    "image_keys": [
        "exterior-cam",
        "right-wrist-cam",
        "left-wrist-cam",
        "left-wrist-cam"
    ],
    "image_mask": [1, 1, 1, 0]
}
```

- To make TensorFlow happy, you have to specify four images in this order: `exterior-cam, right-wrist-cam, left-wrist-cam, any-cam`. Each key should correspond to an observation image key in your `step` structure.

- If you only have a single wrist camera, just treat it as the *right* wrist.

- The `image_mask` indicates whether each image is valid (1) or not (0).

- What if you don't have four images? Simply repeat existing images in the remaining positions and set their masks to 0 (invalid).

- The key order is *strict*. If you don't have the exterior camera but have both wrist cameras, pad the exterior position with a repeated key and use the following:

```json
"my_pretrain_dataset": {
    "image_keys": [
        "right-wrist-cam",
        "right-wrist-cam",
        "left-wrist-cam",
        "left-wrist-cam"
    ],
    "image_mask": [0, 1, 1, 0]
}
```

- During training, only the first *three* cameras will be used.

##### 4. `configs/dataset_stat.json`

Compute the statistics (min, max, mean, and std) for your dataset:

```bash
# Use -h to see the full usage
python -m data.compute_dataset_stat --skip_exist
```

This will update the `dataset_stat.json` file with your dataset's statistics.

##### 5. `data/vla_dataset.py`

- Add your dataset to `DATASET_NAMES_NOOPENX` if it cannot be loaded by `tfds.builder_from_directory()`.
- If your dataset only contains actions but no proprioception (i.e., robot state), add your dataset to `DATASET_NAMES_NO_STATE` in [this file](../data/preprocess.py).
- Normally, we consider the future state as the action of the current timestep. If you want to use different actions, you should implement more functions. We refer to `flatten_episode_agilex()` in [this file](../data/episode_transform.py) and `_generate_json_state_agilex()` in [this file](../data/preprocess.py) for examples. You may also refer to L318 in [this file](../data/preprocess.py) and L128 in [this file](../data/vla_dataset.py) for how to select your dataset and preprocess it differently.

## Start Pre-Training

We employ a producer-consumer framework with TensorFlow Dataset for fast data loading. Since most of the datasets in the Open X-Embodiment are stored as `TFRecord`s, we convert all pre-training datasets into `TFRecord` for storage. During pre-training, the producer process decompresses data from the `TFRecord`s and stores it in a buffer on the hard disk, while the consumer process reads data from that buffer in random order and feeds it to model training. This not only decouples the `TensorFlow` and `PyTorch` environments but also alleviates the training performance loss caused by the small size of an in-memory shuffling buffer.

[This file](../configs/base.yaml) includes configurations relevant to the model architecture (including the number of heads, hidden dimension, and so on) and data processing. You may need to modify `buf_path` (L22) to your real buffer path. This buffer is used as the disk shuffling buffer for data loading.

Configurations relevant to training are passed through *Command Line Arguments*. Use `python main.py -h` to see the descriptions. We provide an example pre-training script in [this file](../pretrain.sh) (`pretrain.sh`). You may need to modify some of the parameters in this file, such as `CUTLASS_PATH` and `WANDB_PROJECT`; see the sketch below.
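
As a rough illustration (the variable names come from the paragraph above; whether `pretrain.sh` reads them from the environment or hard-codes them is not shown here, so verify against the script itself):

```bash
# Hypothetical overrides; pretrain.sh may instead define these values internally,
# in which case edit them directly in the script before sourcing it.
export CUTLASS_PATH="/path/to/cutlass"
export WANDB_PROJECT="rdt-pretrain"
```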

You may need to modify the list of pre-training datasets in [this file](../configs/pretrain_datasets.json) and their corresponding sampling weights in [this file](../configs/pretrain_sample_weights.json). If you want to fine-tune RDT through this pipeline, you may need to remove the datasets you do not need from the list.

Before starting pre-training, we first start the data producer process (if you use multiple nodes, you should run this command on each node):

```bash
# Under the root directory of this repo
conda activate rdt-data
# Use -h to see the full usage
python -m data.producer --fill_up
# Please proceed to the next step AFTER the filling-up process has finished
```

Then, we run the pre-training script:

```bash
source pretrain.sh
```

Note: You can monitor the training process by observing `loss` (through a long-window moving average), `overall_avg_sample_mse`, and the sampling MSE of each dataset in [Wandb](https://wandb.ai/site) or [TensorBoard](https://www.tensorflow.org/tensorboard). We empirically found that the lower the `overall_avg_sample_mse`, the better the model performs.
docs/test_6drot.py
ADDED
@@ -0,0 +1,99 @@
import numpy as np
from scipy.spatial.transform import Rotation as R


def convert_quaternion_to_euler(quat):
    """
    Convert quaternion (xyzw) to Euler angles (rpy).
    """
    # Normalize
    quat = quat / np.linalg.norm(quat)
    euler = R.from_quat(quat).as_euler('xyz')

    return euler


def convert_euler_to_quaternion(euler):
    """
    Convert Euler angles (rpy) to quaternion (xyzw).
    """
    quat = R.from_euler('xyz', euler).as_quat()

    return quat


def convert_euler_to_rotation_matrix(euler):
    """
    Convert Euler angles (rpy) to rotation matrix (3x3).
    """
    rotmat = R.from_euler('xyz', euler).as_matrix()

    return rotmat


def convert_rotation_matrix_to_euler(rotmat):
    """
    Convert rotation matrix (3x3) to Euler angles (rpy).
    """
    r = R.from_matrix(rotmat)
    euler = r.as_euler('xyz', degrees=False)

    return euler


def normalize_vector(v):
    v_mag = np.linalg.norm(v, axis=-1, keepdims=True)
    v_mag = np.maximum(v_mag, 1e-8)
    return v / v_mag


def cross_product(u, v):
    i = u[:, 1] * v[:, 2] - u[:, 2] * v[:, 1]
    j = u[:, 2] * v[:, 0] - u[:, 0] * v[:, 2]
    k = u[:, 0] * v[:, 1] - u[:, 1] * v[:, 0]

    out = np.stack((i, j, k), axis=1)
    return out


def compute_rotation_matrix_from_ortho6d(ortho6d):
    x_raw = ortho6d[:, 0:3]
    y_raw = ortho6d[:, 3:6]

    x = normalize_vector(x_raw)
    z = cross_product(x, y_raw)
    z = normalize_vector(z)
    y = cross_product(z, x)

    x = x.reshape(-1, 3, 1)
    y = y.reshape(-1, 3, 1)
    z = z.reshape(-1, 3, 1)
    matrix = np.concatenate((x, y, z), axis=2)
    return matrix


def compute_ortho6d_from_rotation_matrix(matrix):
    # The ortho6d represents the first two column vectors a1 and a2 of the
    # rotation matrix: [ | , | , | ]
    #                  [ a1, a2, a3]
    #                  [ | , | , | ]
    ortho6d = matrix[:, :, :2].transpose(0, 2, 1).reshape(matrix.shape[0], -1)
    return ortho6d


# Test
if __name__ == "__main__":
    # Randomly generate an Euler angle
    euler = np.random.rand(3) * 2 * np.pi - np.pi
    euler = euler[None, :]  # Add batch dimension
    print(f"Input Euler angles: {euler}")

    # Convert to 6D rotation
    rotmat = convert_euler_to_rotation_matrix(euler)
    ortho6d = compute_ortho6d_from_rotation_matrix(rotmat)
    print(f"6D Rotation: {ortho6d}")

    # Convert back to Euler angles
    rotmat_recovered = compute_rotation_matrix_from_ortho6d(ortho6d)
    euler_recovered = convert_rotation_matrix_to_euler(rotmat_recovered)
    print(f"Recovered Euler angles: {euler_recovered}")
eval_sim/eval_dp.py
ADDED
@@ -0,0 +1,166 @@
from typing import Callable, List, Type
import gymnasium as gym
import numpy as np
from mani_skill.envs.sapien_env import BaseEnv
from mani_skill.utils import common, gym_utils
import argparse
import yaml
import torch
from collections import deque
from PIL import Image
import cv2
import imageio
from functools import partial

from diffusion_policy.workspace.robotworkspace import RobotWorkspace

def parse_args(args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("-e", "--env-id", type=str, default="PickCube-v1", help="Environment to run motion planning solver on.")
    parser.add_argument("-o", "--obs-mode", type=str, default="rgb", help="Observation mode to use. Usually this is kept as 'none' as observations are not necessary to be stored, they can be replayed later via the mani_skill.trajectory.replay_trajectory script.")
    parser.add_argument("-n", "--num-traj", type=int, default=25, help="Number of trajectories to generate.")
    parser.add_argument("--only-count-success", action="store_true", help="If true, generates trajectories until num_traj of them are successful and only saves the successful trajectories/videos")
    parser.add_argument("--reward-mode", type=str)
    parser.add_argument("-b", "--sim-backend", type=str, default="auto", help="Which simulation backend to use. Can be 'auto', 'cpu', 'gpu'")
    parser.add_argument("--render-mode", type=str, default="rgb_array", help="can be 'sensors' or 'rgb_array' which only affect what is saved to videos")
    parser.add_argument("--vis", action="store_true", help="whether or not to open a GUI to visualize the solution live")
    parser.add_argument("--save-video", action="store_true", help="whether or not to save videos locally")
    parser.add_argument("--traj-name", type=str, help="The name of the trajectory .h5 file that will be created.")
    parser.add_argument("--shader", default="default", type=str, help="Change shader used for rendering. Default is 'default' which is very fast. Can also be 'rt' for ray tracing and generating photo-realistic renders. Can also be 'rt-fast' for a faster but lower quality ray-traced renderer")
    parser.add_argument("--record-dir", type=str, default="demos", help="where to save the recorded trajectories")
    parser.add_argument("--num-procs", type=int, default=1, help="Number of processes to use to help parallelize the trajectory replay process. This uses CPU multiprocessing and only works with the CPU simulation backend at the moment.")
    parser.add_argument("--random_seed", type=int, default=0, help="Random seed for the environment.")
    parser.add_argument("--pretrained_path", type=str, default=None, help="Path to the pretrained model.")

    return parser.parse_args()

task2lang = {
    "PegInsertionSide-v1": "Pick up a orange-white peg and insert the orange end into the box with a hole in it.",
    "PickCube-v1": "Grasp a red cube and move it to a target goal position.",
    "StackCube-v1": "Pick up a red cube and stack it on top of a green cube and let go of the cube without it falling.",
    "PlugCharger-v1": "Pick up one of the misplaced shapes on the board/kit and insert it into the correct empty slot.",
    "PushCube-v1": "Push and move a cube to a goal region in front of it."
}
import random
import os

args = parse_args()
seed = args.random_seed
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

env_id = args.env_id
env = gym.make(
    env_id,
    obs_mode=args.obs_mode,
    control_mode="pd_joint_pos",
    render_mode=args.render_mode,
    reward_mode="dense" if args.reward_mode is None else args.reward_mode,
    sensor_configs=dict(shader_pack=args.shader),
    human_render_camera_configs=dict(shader_pack=args.shader),
    viewer_camera_configs=dict(shader_pack=args.shader),
    sim_backend=args.sim_backend
)

from diffusion_policy.workspace.robotworkspace import RobotWorkspace
import hydra
import dill

checkpoint_path = args.pretrained_path
print(f"Loading policy from {checkpoint_path}. Task is {task2lang[env_id]}")

def get_policy(output_dir, device):

    # load checkpoint
    payload = torch.load(open(checkpoint_path, 'rb'), pickle_module=dill)
    cfg = payload['cfg']
    cls = hydra.utils.get_class(cfg._target_)
    workspace = cls(cfg, output_dir=output_dir)
    workspace: RobotWorkspace
    workspace.load_payload(payload, exclude_keys=None, include_keys=None)

    # get policy from workspace
    policy = workspace.model
    if cfg.training.use_ema:
        policy = workspace.ema_model

    device = torch.device(device)
    policy.to(device)
    policy.eval()

    return policy

policy = get_policy('./', device='cuda')
MAX_EPISODE_STEPS = 400
total_episodes = args.num_traj
success_count = 0
base_seed = 20241201
instr = task2lang[env_id]
import tqdm

DATA_STAT = {'state_min': [-0.7463043928146362, -0.0801204964518547, -0.4976441562175751, -2.657780647277832, -0.5742632150650024, 1.8309762477874756, -2.2423808574676514, 0.0, 0.0], 'state_max': [0.7645499110221863, 1.4967026710510254, 0.4650936424732208, -0.3866899907588959, 0.5505855679512024, 3.2900545597076416, 2.5737812519073486, 0.03999999910593033, 0.03999999910593033], 'action_min': [-0.7472005486488342, -0.08631071448326111, -0.4995281398296356, -2.658363103866577, -0.5751323103904724, 1.8290787935256958, -2.245187997817993, -1.0], 'action_max': [0.7654682397842407, 1.4984270334243774, 0.46786263585090637, -0.38181185722351074, 0.5517147779464722, 3.291581630706787, 2.575840711593628, 1.0], 'action_std': [0.2199309915304184, 0.18780815601348877, 0.13044124841690063, 0.30669933557510376, 0.1340624988079071, 0.24968451261520386, 0.9589747190475464, 0.9827960729598999], 'action_mean': [-0.00885344110429287, 0.5523102879524231, -0.007564723491668701, -2.0108158588409424, 0.004714342765510082, 2.615924596786499, 0.08461848646402359, -0.19301606714725494]}

state_min = torch.tensor(DATA_STAT['state_min']).cuda()
state_max = torch.tensor(DATA_STAT['state_max']).cuda()
action_min = torch.tensor(DATA_STAT['action_min']).cuda()
action_max = torch.tensor(DATA_STAT['action_max']).cuda()

for episode in tqdm.trange(total_episodes):
    obs_window = deque(maxlen=2)
    obs, _ = env.reset(seed=episode + base_seed)

    img = env.render().cuda().float()
    proprio = obs['agent']['qpos'][:].cuda()
    proprio = (proprio - state_min) / (state_max - state_min) * 2 - 1
    obs_window.append({
        'agent_pos': proprio,
        "head_cam": img.permute(0, 3, 1, 2),
    })
    obs_window.append({
        'agent_pos': proprio,
        "head_cam": img.permute(0, 3, 1, 2),
    })

    global_steps = 0
    video_frames = []

    success_time = 0
    done = False

    while global_steps < MAX_EPISODE_STEPS and not done:
        obs = obs_window[-1]
        actions = policy.predict_action(obs)
        actions = actions['action_pred'].squeeze(0)
        actions = (actions + 1) / 2 * (action_max - action_min) + action_min
        actions = actions.detach().cpu().numpy()
        actions = actions[:8]
        for idx in range(actions.shape[0]):
            action = actions[idx]
            obs, reward, terminated, truncated, info = env.step(action)
            img = env.render().cuda().float()
            proprio = obs['agent']['qpos'][:].cuda()
            proprio = (proprio - state_min) / (state_max - state_min) * 2 - 1
            obs_window.append({
                'agent_pos': proprio,
                "head_cam": img.permute(0, 3, 1, 2),
            })
            video_frames.append(env.render().squeeze(0).detach().cpu().numpy())
            global_steps += 1
            if terminated or truncated:
                assert "success" in info, sorted(info.keys())
                if info['success']:
                    done = True
                    success_count += 1
                    break
    print(f"Trial {episode+1} finished, success: {info['success']}, steps: {global_steps}")

success_rate = success_count / total_episodes * 100
print(f"Tested {total_episodes} episodes, success rate: {success_rate:.2f}%")
log_file = f"results_dp_{checkpoint_path.split('/')[-1].split('.')[0]}.txt"
with open(log_file, 'a') as f:
    f.write(f"{args.env_id}:{seed}:{success_count}\n")
eval_sim/eval_octo.py
ADDED
@@ -0,0 +1,182 @@
from typing import Callable, List, Type
import gymnasium as gym
import numpy as np
from mani_skill.envs.sapien_env import BaseEnv
from mani_skill.utils import common, gym_utils
import argparse
import yaml
import torch
from collections import deque
from PIL import Image
import cv2
from octo.model.octo_model import OctoModel
from octo.utils.train_callbacks import supply_rng
import imageio
import jax
import jax.numpy as jnp
from functools import partial

def parse_args(args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("-e", "--env-id", type=str, default="PickCube-v1", help="Environment to run motion planning solver on.")
    parser.add_argument("-o", "--obs-mode", type=str, default="rgb", help="Observation mode to use. Usually this is kept as 'none' as observations are not necessary to be stored, they can be replayed later via the mani_skill.trajectory.replay_trajectory script.")
    parser.add_argument("-n", "--num-traj", type=int, default=25, help="Number of trajectories to generate.")
    parser.add_argument("--only-count-success", action="store_true", help="If true, generates trajectories until num_traj of them are successful and only saves the successful trajectories/videos")
    parser.add_argument("--reward-mode", type=str)
    parser.add_argument("-b", "--sim-backend", type=str, default="auto", help="Which simulation backend to use. Can be 'auto', 'cpu', 'gpu'")
    parser.add_argument("--render-mode", type=str, default="rgb_array", help="can be 'sensors' or 'rgb_array' which only affect what is saved to videos")
    parser.add_argument("--vis", action="store_true", help="whether or not to open a GUI to visualize the solution live")
    parser.add_argument("--save-video", action="store_true", help="whether or not to save videos locally")
    parser.add_argument("--traj-name", type=str, help="The name of the trajectory .h5 file that will be created.")
    parser.add_argument("--shader", default="default", type=str, help="Change shader used for rendering. Default is 'default' which is very fast. Can also be 'rt' for ray tracing and generating photo-realistic renders. Can also be 'rt-fast' for a faster but lower quality ray-traced renderer")
    parser.add_argument("--record-dir", type=str, default="demos", help="where to save the recorded trajectories")
    parser.add_argument("--num-procs", type=int, default=1, help="Number of processes to use to help parallelize the trajectory replay process. This uses CPU multiprocessing and only works with the CPU simulation backend at the moment.")
    parser.add_argument("--random_seed", type=int, default=0, help="Random seed for the environment.")
    parser.add_argument("--pretrained_path", type=str, default=None, help="Path to the pretrained model")
    return parser.parse_args()

task2lang = {
    "PegInsertionSide-v1": "Pick up a orange-white peg and insert the orange end into the box with a hole in it.",
    "PickCube-v1": "Grasp a red cube and move it to a target goal position.",
    "StackCube-v1": "Pick up a red cube and stack it on top of a green cube and let go of the cube without it falling.",
    "PlugCharger-v1": "Pick up one of the misplaced shapes on the board/kit and insert it into the correct empty slot.",
    "PushCube-v1": "Push and move a cube to a goal region in front of it."
}
import random
import os

args = parse_args()
seed = args.random_seed
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

env_id = args.env_id
env = gym.make(
    env_id,
    obs_mode=args.obs_mode,
    control_mode="pd_ee_delta_pose",
    render_mode=args.render_mode,
    reward_mode="dense" if args.reward_mode is None else args.reward_mode,
    sensor_configs=dict(shader_pack=args.shader),
    human_render_camera_configs=dict(shader_pack=args.shader),
    viewer_camera_configs=dict(shader_pack=args.shader),
    sim_backend=args.sim_backend
)

def sample_actions(
    pretrained_model: OctoModel,
    observations,
    tasks,
    rng,
):
    # add batch dim to observations
    observations = jax.tree_map(lambda x: x[None], observations)
    actions = pretrained_model.sample_actions(
        observations,
        tasks,
        rng=rng,
    )
    # remove batch dim
    return actions[0]

pretrain_path = args.pretrained_path
step = 1000000
model = OctoModel.load_pretrained(
    pretrain_path,
    step
)

policy = supply_rng(
    partial(
        sample_actions,
        model,
    )
)


import tensorflow as tf
def resize_img(image, size=(256, 256)):
    image_tf = tf.convert_to_tensor(image, dtype=tf.float32)
    image_tf = tf.expand_dims(image_tf, axis=0)
    resized_tf = tf.image.resize(
        image_tf,
        size,
        method=tf.image.ResizeMethod.LANCZOS3,
        antialias=True
    )
    resized_tf = tf.squeeze(resized_tf)
    resized_img = resized_tf.numpy().astype(np.uint8)
    return resized_img

MAX_EPISODE_STEPS = 400
total_episodes = args.num_traj
success_count = 0
base_seed = 20241201
import tqdm

for episode in tqdm.trange(total_episodes):
    task = model.create_tasks(texts=[task2lang[env_id]])
    obs_window = deque(maxlen=2)
    obs, _ = env.reset(seed=base_seed)

    img = env.render().squeeze(0).detach().cpu().numpy()
    proprio = obs['agent']['qpos'][:]
    obs_window.append({
        'proprio': proprio.detach().cpu().numpy(),
        "image_primary": resize_img(img)[None],
        "timestep_pad_mask": np.zeros((1), dtype=bool)
    })
    obs_window.append({
        'proprio': proprio.detach().cpu().numpy(),
        "image_primary": resize_img(img)[None],
        "timestep_pad_mask": np.ones((1), dtype=bool)
    })

    global_steps = 0
    video_frames = []

    success_time = 0
    done = False

    while global_steps < MAX_EPISODE_STEPS and not done:
        obs = {
            'proprio': np.concatenate([obs_window[0]['proprio'], obs_window[1]['proprio']], axis=0),
            "image_primary": np.concatenate([obs_window[0]['image_primary'], obs_window[1]['image_primary']], axis=0),
            "timestep_pad_mask": np.concatenate([obs_window[0]['timestep_pad_mask'], obs_window[1]['timestep_pad_mask']], axis=0)
        }
        actions = policy(obs, task)
        actions = jax.device_put(actions, device=jax.devices('cpu')[0])
        actions = jax.device_get(actions)
        # actions = actions[0:4]
        for idx in range(actions.shape[0]):
            action = actions[idx]
            obs, reward, terminated, truncated, info = env.step(action)
            img = env.render().squeeze(0).detach().cpu().numpy()
            proprio = obs['agent']['qpos'][:]
            obs_window.append({
                'proprio': proprio.detach().cpu().numpy(),
                "image_primary": resize_img(img)[None],
                "timestep_pad_mask": np.ones((1), dtype=bool)
            })
            video_frames.append(img)
            global_steps += 1
            if terminated or truncated:
                assert "success" in info, sorted(info.keys())
                if info['success']:
                    done = True
                    success_count += 1
                    break
    print(f"Trial {episode+1} finished, success: {info['success']}, steps: {global_steps}")

success_rate = success_count / total_episodes * 100
print(f"Random seed: {seed}, Pretrained_path: {pretrain_path}")
print(f"Tested {total_episodes} episodes, success rate: {success_rate:.2f}%")
log_file = "results_octo.log"
with open(log_file, 'a') as f:
    f.write(f"{seed}:{success_count}\n")
eval_sim/eval_openvla.py
ADDED
@@ -0,0 +1,175 @@
from typing import Callable, List, Type
import gymnasium as gym
import numpy as np
from mani_skill.envs.sapien_env import BaseEnv
from mani_skill.utils import common, gym_utils
import argparse
import yaml
import torch
from collections import deque
from PIL import Image
import cv2
import imageio
from functools import partial
from torchvision.transforms.functional import center_crop

def parse_args(args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("-e", "--env-id", type=str, default="PickCube-v1", help="Environment to run motion planning solver on.")
    parser.add_argument("-o", "--obs-mode", type=str, default="rgb", help="Observation mode to use. Usually this is kept as 'none' as observations are not necessary to be stored, they can be replayed later via the mani_skill.trajectory.replay_trajectory script.")
    parser.add_argument("-n", "--num-traj", type=int, default=25, help="Number of trajectories to generate.")
    parser.add_argument("--only-count-success", action="store_true", help="If true, generates trajectories until num_traj of them are successful and only saves the successful trajectories/videos")
    parser.add_argument("--reward-mode", type=str)
    parser.add_argument("-b", "--sim-backend", type=str, default="auto", help="Which simulation backend to use. Can be 'auto', 'cpu', 'gpu'")
    parser.add_argument("--render-mode", type=str, default="rgb_array", help="can be 'sensors' or 'rgb_array' which only affect what is saved to videos")
    parser.add_argument("--vis", action="store_true", help="whether or not to open a GUI to visualize the solution live")
    parser.add_argument("--save-video", action="store_true", help="whether or not to save videos locally")
    parser.add_argument("--traj-name", type=str, help="The name of the trajectory .h5 file that will be created.")
    parser.add_argument("--shader", default="default", type=str, help="Change shader used for rendering. Default is 'default' which is very fast. Can also be 'rt' for ray tracing and generating photo-realistic renders. Can also be 'rt-fast' for a faster but lower quality ray-traced renderer")
    parser.add_argument("--record-dir", type=str, default="demos", help="where to save the recorded trajectories")
    parser.add_argument("--num-procs", type=int, default=1, help="Number of processes to use to help parallelize the trajectory replay process. This uses CPU multiprocessing and only works with the CPU simulation backend at the moment.")
    parser.add_argument("--random_seed", type=int, default=0, help="Random seed for the environment.")
    parser.add_argument("--pretrained_path", type=str, default=None, help="Path to the pretrained model")
    return parser.parse_args()

task2lang = {
    "PegInsertionSide-v1": "Pick up a orange-white peg and insert the orange end into the box with a hole in it.",
    "PickCube-v1": "Grasp a red cube and move it to a target goal position.",
    "StackCube-v1": "Pick up a red cube and stack it on top of a green cube and let go of the cube without it falling.",
    "PlugCharger-v1": "Pick up one of the misplaced shapes on the board/kit and insert it into the correct empty slot.",
    "PushCube-v1": "Push and move a cube to a goal region in front of it."
}

import random
import os

args = parse_args()
seed = args.random_seed
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

from transformers import AutoModelForVision2Seq, AutoProcessor

DATA_STAT = {'mean': [0.00263866, 0.01804881, -0.02151551, -0.00384866, 0.00500441,
                      -0.00057146, -0.26013601], 'std': [0.06639539, 0.1246438, 0.09675793, 0.03351422, 0.04930534,
                      0.25787726, 0.96762997], 'max': [0.31303197, 0.77948809, 0.42906255, 0.20186238, 0.63990456,
                      0.99999917, 1.], 'min': [-0.31464151, -0.64183694, -0.62718982, -0.5888508, -0.97813392,
                      -0.99999928, -1.], 'q01': [-0.18656027, -0.31995443, -0.24702898, -0.18005923, -0.2164692,
                      -0.82366071, -1.], 'q99': [0.18384692, 0.45547636, 0.27452313, 0.03571117, 0.1188747,
                      0.85074112, 1.]}

MODEL_PATH = args.pretrained_path

def make_policy():
    device = torch.device('cuda')

    processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
    vla = AutoModelForVision2Seq.from_pretrained(
        MODEL_PATH,
        attn_implementation="flash_attention_2",  # [Optional] Requires `flash_attn`
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True
    ).to(device)
    vla.norm_stats["maniskill"] = {
        "action": {
            "min": np.array(DATA_STAT["min"]),
            "max": np.array(DATA_STAT["max"]),
            "mean": np.array(DATA_STAT["mean"]),
            "std": np.array(DATA_STAT["std"]),
            "q01": np.array(DATA_STAT["q01"]),
            "q99": np.array(DATA_STAT["q99"]),
        }
    }

    vla = vla.eval()

    return vla, processor

vla, processor = make_policy()
success_counts = {}

for env_id in task2lang.keys():

    env = gym.make(
        env_id,
        obs_mode=args.obs_mode,
        control_mode="pd_ee_delta_pose",
        render_mode=args.render_mode,
        reward_mode="dense" if args.reward_mode is None else args.reward_mode,
        sensor_configs=dict(shader_pack=args.shader),
        human_render_camera_configs=dict(shader_pack=args.shader),
        viewer_camera_configs=dict(shader_pack=args.shader),
        sim_backend=args.sim_backend
    )

    MAX_EPISODE_STEPS = 400
    total_episodes = args.num_traj
    success_count = 0
    base_seed = 20241201
    import tqdm

    for episode in tqdm.trange(total_episodes):
        obs_window = deque(maxlen=2)
        obs, _ = env.reset(seed=base_seed + episode)

        img = env.render().squeeze(0).detach().cpu().numpy()
        obs_window.append(img)

        global_steps = 0
        video_frames = []

        success_time = 0
        done = False

        while global_steps < MAX_EPISODE_STEPS and not done:
            obs = obs_window[-1]
            image_arrs = [
                obs_window[-1]
            ]
            images = [Image.fromarray(arr) for arr in image_arrs]
            original_size = images[0].size
            crop_scale = 0.9
            sqrt_crop_scale = crop_scale
            sqrt_crop_scale = np.sqrt(crop_scale)
            images = [
                center_crop(
                    img, output_size=(
                        int(sqrt_crop_scale * img.size[1]),
                        int(sqrt_crop_scale * img.size[0])
                    )
                ) for img in images
            ]
            images = [img.resize(original_size, Image.Resampling.BILINEAR) for img in images]
            # de-capitalize and remove trailing period
            instruction = task2lang[env_id].lower()
            prompt = f"In: What action should the robot take to {instruction}?\nOut:"
            inputs = processor(prompt, images).to("cuda:0", dtype=torch.bfloat16)
            actions = vla.predict_action(**inputs, unnorm_key="maniskill", do_sample=False)[None]
            for idx in range(actions.shape[0]):
                action = actions[idx]
                # print(action)
                # action = action * (np.array(DATA_STAT['std']) + 1e-8) + np.array(DATA_STAT['mean'])
                obs, reward, terminated, truncated, info = env.step(action)
                img = env.render().squeeze(0).detach().cpu().numpy()
                obs_window.append(img)
                video_frames.append(img)
                global_steps += 1
                if terminated or truncated:
                    assert "success" in info, sorted(info.keys())
                    if info['success']:
                        success_count += 1
                        done = True
                        break
        print(f"Trial {episode+1} finished, success: {info['success']}, steps: {global_steps}")
    success_counts[env_id] = success_count
    print(f"Task {env_id} finished, success: {success_count}/{total_episodes}")

log_file = "results_ovla_all.log"
with open(log_file, 'a') as f:
    f.write(f"{seed}:{success_counts}\n")
eval_sim/eval_rdt_maniskill.py
ADDED
@@ -0,0 +1,137 @@
from typing import Callable, List, Type
import sys
sys.path.append('/')
import gymnasium as gym
import numpy as np
from mani_skill.envs.sapien_env import BaseEnv
from mani_skill.utils import common, gym_utils
import argparse
import yaml
from scripts.maniskill_model import create_model, RoboticDiffusionTransformerModel
import torch
from collections import deque
from PIL import Image
import cv2

def parse_args(args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("-e", "--env-id", type=str, default="PickCube-v1", help="Environment to run motion planning solver on.")
    parser.add_argument("-o", "--obs-mode", type=str, default="rgb", help="Observation mode to use. Usually this is kept as 'none' as observations are not necessary to be stored, they can be replayed later via the mani_skill.trajectory.replay_trajectory script.")
    parser.add_argument("-n", "--num-traj", type=int, default=25, help="Number of trajectories to test.")
    parser.add_argument("--only-count-success", action="store_true", help="If true, generates trajectories until num_traj of them are successful and only saves the successful trajectories/videos")
    parser.add_argument("--reward-mode", type=str)
    parser.add_argument("-b", "--sim-backend", type=str, default="auto", help="Which simulation backend to use. Can be 'auto', 'cpu', 'gpu'")
    parser.add_argument("--render-mode", type=str, default="rgb_array", help="can be 'sensors' or 'rgb_array' which only affect what is saved to videos")
    parser.add_argument("--shader", default="default", type=str, help="Change shader used for rendering. Default is 'default' which is very fast. Can also be 'rt' for ray tracing and generating photo-realistic renders. Can also be 'rt-fast' for a faster but lower quality ray-traced renderer")
    parser.add_argument("--num-procs", type=int, default=1, help="Number of processes to use to help parallelize the trajectory replay process. This uses CPU multiprocessing and only works with the CPU simulation backend at the moment.")
    parser.add_argument("--pretrained_path", type=str, default=None, help="Path to the pretrained model")
    parser.add_argument("--random_seed", type=int, default=0, help="Random seed for the environment.")
    return parser.parse_args()

import random
import os

# set cuda
args = parse_args()
# set random seeds
seed = args.random_seed
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

task2lang = {
    "PegInsertionSide-v1": "Pick up a orange-white peg and insert the orange end into the box with a hole in it.",
    "PickCube-v1": "Grasp a red cube and move it to a target goal position.",
    "StackCube-v1": "Pick up a red cube and stack it on top of a green cube and let go of the cube without it falling.",
    "PlugCharger-v1": "Pick up one of the misplaced shapes on the board/kit and insert it into the correct empty slot.",
    "PushCube-v1": "Push and move a cube to a goal region in front of it."
}

env_id = args.env_id
env = gym.make(
    env_id,
    obs_mode=args.obs_mode,
    control_mode="pd_joint_pos",
    render_mode=args.render_mode,
    reward_mode="dense" if args.reward_mode is None else args.reward_mode,
    sensor_configs=dict(shader_pack=args.shader),
    human_render_camera_configs=dict(shader_pack=args.shader),
    viewer_camera_configs=dict(shader_pack=args.shader),
    sim_backend=args.sim_backend
)

config_path = 'configs/base.yaml'
with open(config_path, "r") as fp:
    config = yaml.safe_load(fp)
pretrained_text_encoder_name_or_path = "google/t5-v1_1-xxl"
pretrained_vision_encoder_name_or_path = "google/siglip-so400m-patch14-384"
pretrained_path = args.pretrained_path
policy = create_model(
    args=config,
    dtype=torch.bfloat16,
    pretrained=pretrained_path,
    pretrained_text_encoder_name_or_path=pretrained_text_encoder_name_or_path,
    pretrained_vision_encoder_name_or_path=pretrained_vision_encoder_name_or_path
)

if os.path.exists(f'text_embed_{env_id}.pt'):
    text_embed = torch.load(f'text_embed_{env_id}.pt')
else:
    text_embed = policy.encode_instruction(task2lang[env_id])
    torch.save(text_embed, f'text_embed_{env_id}.pt')

MAX_EPISODE_STEPS = 400
total_episodes = args.num_traj
success_count = 0

base_seed = 20241201
import tqdm
for episode in tqdm.trange(total_episodes):
    obs_window = deque(maxlen=2)
    obs, _ = env.reset(seed=episode + base_seed)
    policy.reset()

    img = env.render().squeeze(0).detach().cpu().numpy()
    obs_window.append(None)
    obs_window.append(np.array(img))
    proprio = obs['agent']['qpos'][:, :-1]

    global_steps = 0
    video_frames = []

    success_time = 0
    done = False

    while global_steps < MAX_EPISODE_STEPS and not done:
        image_arrs = []
        for window_img in obs_window:
            image_arrs.append(window_img)
        image_arrs.append(None)
        image_arrs.append(None)
        images = [Image.fromarray(arr) if arr is not None else None
                  for arr in image_arrs]
        actions = policy.step(proprio, images, text_embed).squeeze(0).cpu().numpy()
        # Take 8 steps since RDT is trained to predict interpolated 64 steps(actual 14 steps)
        actions = actions[::4, :]
        for idx in range(actions.shape[0]):
            action = actions[idx]
            obs, reward, terminated, truncated, info = env.step(action)
            img = env.render().squeeze(0).detach().cpu().numpy()
            obs_window.append(img)
            proprio = obs['agent']['qpos'][:, :-1]
            video_frames.append(img)
            global_steps += 1
            if terminated or truncated:
                assert "success" in info, sorted(info.keys())
                if info['success']:
                    success_count += 1
                    done = True
                    break
    print(f"Trial {episode+1} finished, success: {info['success']}, steps: {global_steps}")

success_rate = success_count / total_episodes * 100
print(f"Success rate: {success_rate}%")
lang_embed/aloha_dish_drainer.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:903018f06a23f7d8b97480b5bf304442f0593a4d854dc9e0e0fd70822c52b82e
size 99667
lang_embed/aloha_handover_box.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4a343452e908910230df6ca0045320e121f2f59969314c6a1a09af88192b5e81
size 91475
lang_embed/aloha_lift_box.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7e2953aa4aad84e687f08f5819f3b3e6c3d4671cbb1d95539dacf4dcad2e9142
size 83263
lang_embed/aloha_shoes_table.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f606c5cbb856de1c7a362d7ab437d098abc76e78b2b731f77ded048851b87900
size 132494
lang_embed/anubis_brush_to_pan.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1155d325416a6d1be4a70bb49d7b15272ef0d75dc8a0437a42a9f4660c607a49
size 66904
lang_embed/anubis_carrot_to_bag.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:394c01aa62fbfd6fdf3d6b53684b10e377820b0a4d0c9425e08b549862beedb6
size 83293
lang_embed/anubis_towel_kirby.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:92e180dc03d6b6841810bafa499aeee1c43d45a9206b6401426b105bad1fd966
size 83283
scripts/agilex_inference.py
ADDED
@@ -0,0 +1,658 @@
1 |
+
#!/home/lin/software/miniconda3/envs/aloha/bin/python
|
2 |
+
# -- coding: UTF-8
|
3 |
+
"""
|
4 |
+
#!/usr/bin/python3
|
5 |
+
"""
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
import sys
|
9 |
+
import threading
|
10 |
+
import time
|
11 |
+
import yaml
|
12 |
+
from collections import deque
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
import rospy
|
16 |
+
import torch
|
17 |
+
from cv_bridge import CvBridge
|
18 |
+
from geometry_msgs.msg import Twist
|
19 |
+
from nav_msgs.msg import Odometry
|
20 |
+
from PIL import Image as PImage
|
21 |
+
from sensor_msgs.msg import Image, JointState
|
22 |
+
from std_msgs.msg import Header
|
23 |
+
import cv2
|
24 |
+
|
25 |
+
from scripts.agilex_model import create_model
|
26 |
+
|
27 |
+
# sys.path.append("./")
|
28 |
+
|
29 |
+
CAMERA_NAMES = ['cam_high', 'cam_right_wrist', 'cam_left_wrist']
|
30 |
+
|
31 |
+
observation_window = None
|
32 |
+
|
33 |
+
lang_embeddings = None
|
34 |
+
|
35 |
+
# debug
|
36 |
+
preload_images = None
|
37 |
+
|
38 |
+
|
39 |
+
# Initialize the model
|
40 |
+
def make_policy(args):
|
41 |
+
with open(args.config_path, "r") as fp:
|
42 |
+
config = yaml.safe_load(fp)
|
43 |
+
args.config = config
|
44 |
+
|
45 |
+
# pretrained_text_encoder_name_or_path = "google/t5-v1_1-xxl"
|
46 |
+
pretrained_vision_encoder_name_or_path = "google/siglip-so400m-patch14-384"
|
47 |
+
model = create_model(
|
48 |
+
args=args.config,
|
49 |
+
dtype=torch.bfloat16,
|
50 |
+
pretrained=args.pretrained_model_name_or_path,
|
51 |
+
# pretrained_text_encoder_name_or_path=pretrained_text_encoder_name_or_path,
|
52 |
+
pretrained_vision_encoder_name_or_path=pretrained_vision_encoder_name_or_path,
|
53 |
+
control_frequency=args.ctrl_freq,
|
54 |
+
)
|
55 |
+
|
56 |
+
return model
|
57 |
+
|
58 |
+
|
59 |
+
def set_seed(seed):
|
60 |
+
torch.manual_seed(seed)
|
61 |
+
np.random.seed(seed)
|
62 |
+
|
63 |
+
|
64 |
+
# Interpolate the actions to make the robot move smoothly
|
65 |
+
def interpolate_action(args, prev_action, cur_action):
|
66 |
+
steps = np.concatenate((np.array(args.arm_steps_length), np.array(args.arm_steps_length)), axis=0)
|
67 |
+
diff = np.abs(cur_action - prev_action)
|
68 |
+
step = np.ceil(diff / steps).astype(int)
|
69 |
+
step = np.max(step)
|
70 |
+
if step <= 1:
|
71 |
+
return cur_action[np.newaxis, :]
|
72 |
+
new_actions = np.linspace(prev_action, cur_action, step + 1)
|
73 |
+
return new_actions[1:]
|
74 |
+
|
75 |
+
|
76 |
+
def get_config(args):
|
77 |
+
config = {
|
78 |
+
'episode_len': args.max_publish_step,
|
79 |
+
'state_dim': 14,
|
80 |
+
'chunk_size': args.chunk_size,
|
81 |
+
'camera_names': CAMERA_NAMES,
|
82 |
+
}
|
83 |
+
return config
|
84 |
+
|
85 |
+
|
86 |
+
# Get the observation from the ROS topic
|
87 |
+
def get_ros_observation(args,ros_operator):
|
88 |
+
rate = rospy.Rate(args.publish_rate)
|
89 |
+
print_flag = True
|
90 |
+
|
91 |
+
while not rospy.is_shutdown():
|
92 |
+
result = ros_operator.get_frame()
|
93 |
+
if not result:
|
94 |
+
if print_flag:
|
95 |
+
print("syn fail when get_ros_observation")
|
96 |
+
print_flag = False
|
97 |
+
rate.sleep()
|
98 |
+
continue
|
99 |
+
print_flag = True
|
100 |
+
(img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth,
|
101 |
+
puppet_arm_left, puppet_arm_right, robot_base) = result
|
102 |
+
# print(f"sync success when get_ros_observation")
|
103 |
+
return (img_front, img_left, img_right,
|
104 |
+
puppet_arm_left, puppet_arm_right)
|
105 |
+
|
106 |
+
|
107 |
+
# Update the observation window buffer
|
108 |
+
def update_observation_window(args, config, ros_operator):
|
109 |
+
# JPEG transformation
|
110 |
+
# Align with training
|
111 |
+
def jpeg_mapping(img):
|
112 |
+
img = cv2.imencode('.jpg', img)[1].tobytes()
|
113 |
+
img = cv2.imdecode(np.frombuffer(img, np.uint8), cv2.IMREAD_COLOR)
|
114 |
+
return img
|
115 |
+
|
116 |
+
global observation_window
|
117 |
+
if observation_window is None:
|
118 |
+
observation_window = deque(maxlen=2)
|
119 |
+
|
120 |
+
# Append the first dummy image
|
121 |
+
observation_window.append(
|
122 |
+
{
|
123 |
+
'qpos': None,
|
124 |
+
'images':
|
125 |
+
{
|
126 |
+
config["camera_names"][0]: None,
|
127 |
+
config["camera_names"][1]: None,
|
128 |
+
config["camera_names"][2]: None,
|
129 |
+
},
|
130 |
+
}
|
131 |
+
)
|
132 |
+
|
133 |
+
img_front, img_left, img_right, puppet_arm_left, puppet_arm_right = get_ros_observation(args,ros_operator)
|
134 |
+
img_front = jpeg_mapping(img_front)
|
135 |
+
img_left = jpeg_mapping(img_left)
|
136 |
+
img_right = jpeg_mapping(img_right)
|
137 |
+
|
138 |
+
qpos = np.concatenate(
|
139 |
+
(np.array(puppet_arm_left.position), np.array(puppet_arm_right.position)), axis=0)
|
140 |
+
qpos = torch.from_numpy(qpos).float().cuda()
|
141 |
+
observation_window.append(
|
142 |
+
{
|
143 |
+
'qpos': qpos,
|
144 |
+
'images':
|
145 |
+
{
|
146 |
+
config["camera_names"][0]: img_front,
|
147 |
+
config["camera_names"][1]: img_right,
|
148 |
+
config["camera_names"][2]: img_left,
|
149 |
+
},
|
150 |
+
}
|
151 |
+
)
|
152 |
+
|
153 |
+
|
154 |
+
# RDT inference
|
155 |
+
def inference_fn(args, config, policy, t):
|
156 |
+
global observation_window
|
157 |
+
global lang_embeddings
|
158 |
+
|
159 |
+
# print(f"Start inference_thread_fn: t={t}")
|
160 |
+
while not rospy.is_shutdown():
|
161 |
+
time1 = time.time()
|
162 |
+
|
163 |
+
# fetch images in sequence [front, right, left]
|
164 |
+
image_arrs = [
|
165 |
+
observation_window[-2]['images'][config['camera_names'][0]],
|
166 |
+
observation_window[-2]['images'][config['camera_names'][1]],
|
167 |
+
observation_window[-2]['images'][config['camera_names'][2]],
|
168 |
+
|
169 |
+
observation_window[-1]['images'][config['camera_names'][0]],
|
170 |
+
observation_window[-1]['images'][config['camera_names'][1]],
|
171 |
+
observation_window[-1]['images'][config['camera_names'][2]]
|
172 |
+
]
|
173 |
+
|
174 |
+
# fetch debug images in sequence [front, right, left]
|
175 |
+
# image_arrs = [
|
176 |
+
# preload_images[config['camera_names'][0]][max(t - 1, 0)],
|
177 |
+
# preload_images[config['camera_names'][2]][max(t - 1, 0)],
|
178 |
+
# preload_images[config['camera_names'][1]][max(t - 1, 0)],
|
179 |
+
# preload_images[config['camera_names'][0]][t],
|
180 |
+
# preload_images[config['camera_names'][2]][t],
|
181 |
+
# preload_images[config['camera_names'][1]][t]
|
182 |
+
# ]
|
183 |
+
# # encode the images
|
184 |
+
# for i in range(len(image_arrs)):
|
185 |
+
# image_arrs[i] = cv2.imdecode(np.frombuffer(image_arrs[i], np.uint8), cv2.IMREAD_COLOR)
|
186 |
+
# proprio = torch.from_numpy(preload_images['qpos'][t]).float().cuda()
|
187 |
+
|
188 |
+
images = [PImage.fromarray(arr) if arr is not None else None
|
189 |
+
for arr in image_arrs]
|
190 |
+
|
191 |
+
# for i, pos in enumerate(['f', 'r', 'l'] * 2):
|
192 |
+
# images[i].save(f'{t}-{i}-{pos}.png')
|
193 |
+
|
194 |
+
# get last qpos in shape [14, ]
|
195 |
+
proprio = observation_window[-1]['qpos']
|
196 |
+
# unsqueeze to [1, 14]
|
197 |
+
proprio = proprio.unsqueeze(0)
|
198 |
+
|
199 |
+
# actions shaped as [1, 64, 14] in format [left, right]
|
200 |
+
actions = policy.step(
|
201 |
+
proprio=proprio,
|
202 |
+
images=images,
|
203 |
+
text_embeds=lang_embeddings
|
204 |
+
).squeeze(0).cpu().numpy()
|
205 |
+
# print(f"inference_actions: {actions.squeeze()}")
|
206 |
+
|
207 |
+
print(f"Model inference time: {time.time() - time1} s")
|
208 |
+
|
209 |
+
# print(f"Finish inference_thread_fn: t={t}")
|
210 |
+
return actions
|
211 |
+
|
212 |
+
|
213 |
+
# Main loop for the manipulation task
|
214 |
+
def model_inference(args, config, ros_operator):
|
215 |
+
global lang_embeddings
|
216 |
+
|
217 |
+
# Load rdt model
|
218 |
+
policy = make_policy(args)
|
219 |
+
|
220 |
+
lang_dict = torch.load(args.lang_embeddings_path)
|
221 |
+
print(f"Running with instruction: \"{lang_dict['instruction']}\" from \"{lang_dict['name']}\"")
|
222 |
+
lang_embeddings = lang_dict["embeddings"]
|
223 |
+
|
224 |
+
max_publish_step = config['episode_len']
|
225 |
+
chunk_size = config['chunk_size']
|
226 |
+
|
227 |
+
# Initialize position of the puppet arm
|
228 |
+
left0 = [-0.00133514404296875, 0.00209808349609375, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, 3.557830810546875]
|
229 |
+
right0 = [-0.00133514404296875, 0.00438690185546875, 0.034523963928222656, -0.053597450256347656, -0.00476837158203125, -0.00209808349609375, 3.557830810546875]
|
230 |
+
left1 = [-0.00133514404296875, 0.00209808349609375, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3393220901489258]
|
231 |
+
right1 = [-0.00133514404296875, 0.00247955322265625, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3397035598754883]
|
232 |
+
ros_operator.puppet_arm_publish_continuous(left0, right0)
|
233 |
+
input("Press enter to continue")
|
234 |
+
ros_operator.puppet_arm_publish_continuous(left1, right1)
|
235 |
+
# Initialize the previous action to be the initial robot state
|
236 |
+
pre_action = np.zeros(config['state_dim'])
|
237 |
+
pre_action[:14] = np.array(
|
238 |
+
[-0.00133514404296875, 0.00209808349609375, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3393220901489258] +
|
239 |
+
[-0.00133514404296875, 0.00247955322265625, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3397035598754883]
|
240 |
+
)
|
241 |
+
action = None
|
242 |
+
# Inference loop
|
243 |
+
with torch.inference_mode():
|
244 |
+
while not rospy.is_shutdown():
|
245 |
+
# The current time step
|
246 |
+
t = 0
|
247 |
+
rate = rospy.Rate(args.publish_rate)
|
248 |
+
|
249 |
+
action_buffer = np.zeros([chunk_size, config['state_dim']])
|
250 |
+
|
251 |
+
while t < max_publish_step and not rospy.is_shutdown():
|
252 |
+
# Update observation window
|
253 |
+
update_observation_window(args, config, ros_operator)
|
254 |
+
|
255 |
+
# When coming to the end of the action chunk
|
256 |
+
if t % chunk_size == 0:
|
257 |
+
# Start inference
|
258 |
+
action_buffer = inference_fn(args, config, policy, t).copy()
|
259 |
+
|
260 |
+
raw_action = action_buffer[t % chunk_size]
|
261 |
+
action = raw_action
|
262 |
+
# Interpolate the original action sequence
|
263 |
+
if args.use_actions_interpolation:
|
264 |
+
# print(f"Time {t}, pre {pre_action}, act {action}")
|
265 |
+
interp_actions = interpolate_action(args, pre_action, action)
|
266 |
+
else:
|
267 |
+
interp_actions = action[np.newaxis, :]
|
268 |
+
# Execute the interpolated actions one by one
|
269 |
+
for act in interp_actions:
|
270 |
+
left_action = act[:7]
|
271 |
+
right_action = act[7:14]
|
272 |
+
|
273 |
+
if not args.disable_puppet_arm:
|
274 |
+
ros_operator.puppet_arm_publish(left_action, right_action) # puppet_arm_publish_continuous_thread
|
275 |
+
|
276 |
+
if args.use_robot_base:
|
277 |
+
vel_action = act[14:16]
|
278 |
+
ros_operator.robot_base_publish(vel_action)
|
279 |
+
rate.sleep()
|
280 |
+
# print(f"doing action: {act}")
|
281 |
+
t += 1
|
282 |
+
|
283 |
+
print("Published Step", t)
|
284 |
+
pre_action = action.copy()
|
285 |
+
|
286 |
+
|
287 |
+
# ROS operator class
|
288 |
+
class RosOperator:
|
289 |
+
def __init__(self, args):
|
290 |
+
self.robot_base_deque = None
|
291 |
+
self.puppet_arm_right_deque = None
|
292 |
+
self.puppet_arm_left_deque = None
|
293 |
+
self.img_front_deque = None
|
294 |
+
self.img_right_deque = None
|
295 |
+
self.img_left_deque = None
|
296 |
+
self.img_front_depth_deque = None
|
297 |
+
self.img_right_depth_deque = None
|
298 |
+
self.img_left_depth_deque = None
|
299 |
+
self.bridge = None
|
300 |
+
self.puppet_arm_left_publisher = None
|
301 |
+
self.puppet_arm_right_publisher = None
|
302 |
+
self.robot_base_publisher = None
|
303 |
+
self.puppet_arm_publish_thread = None
|
304 |
+
self.puppet_arm_publish_lock = None
|
305 |
+
self.args = args
|
306 |
+
self.init()
|
307 |
+
self.init_ros()
|
308 |
+
|
309 |
+
def init(self):
|
310 |
+
self.bridge = CvBridge()
|
311 |
+
self.img_left_deque = deque()
|
312 |
+
self.img_right_deque = deque()
|
313 |
+
self.img_front_deque = deque()
|
314 |
+
self.img_left_depth_deque = deque()
|
315 |
+
self.img_right_depth_deque = deque()
|
316 |
+
self.img_front_depth_deque = deque()
|
317 |
+
self.puppet_arm_left_deque = deque()
|
318 |
+
self.puppet_arm_right_deque = deque()
|
319 |
+
self.robot_base_deque = deque()
|
320 |
+
self.puppet_arm_publish_lock = threading.Lock()
|
321 |
+
self.puppet_arm_publish_lock.acquire()
|
322 |
+
|
323 |
+
def puppet_arm_publish(self, left, right):
|
324 |
+
joint_state_msg = JointState()
|
325 |
+
joint_state_msg.header = Header()
|
326 |
+
joint_state_msg.header.stamp = rospy.Time.now()  # Set the timestamp
|
327 |
+
joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6']  # Set the joint names
|
328 |
+
joint_state_msg.position = left
|
329 |
+
self.puppet_arm_left_publisher.publish(joint_state_msg)
|
330 |
+
joint_state_msg.position = right
|
331 |
+
self.puppet_arm_right_publisher.publish(joint_state_msg)
|
332 |
+
|
333 |
+
def robot_base_publish(self, vel):
|
334 |
+
vel_msg = Twist()
|
335 |
+
vel_msg.linear.x = vel[0]
|
336 |
+
vel_msg.linear.y = 0
|
337 |
+
vel_msg.linear.z = 0
|
338 |
+
vel_msg.angular.x = 0
|
339 |
+
vel_msg.angular.y = 0
|
340 |
+
vel_msg.angular.z = vel[1]
|
341 |
+
self.robot_base_publisher.publish(vel_msg)
|
342 |
+
|
343 |
+
def puppet_arm_publish_continuous(self, left, right):
|
344 |
+
rate = rospy.Rate(self.args.publish_rate)
|
345 |
+
left_arm = None
|
346 |
+
right_arm = None
|
347 |
+
while not rospy.is_shutdown():
|
348 |
+
if len(self.puppet_arm_left_deque) != 0:
|
349 |
+
left_arm = list(self.puppet_arm_left_deque[-1].position)
|
350 |
+
if len(self.puppet_arm_right_deque) != 0:
|
351 |
+
right_arm = list(self.puppet_arm_right_deque[-1].position)
|
352 |
+
if left_arm is None or right_arm is None:
|
353 |
+
rate.sleep()
|
354 |
+
continue
|
355 |
+
else:
|
356 |
+
break
|
357 |
+
left_symbol = [1 if left[i] - left_arm[i] > 0 else -1 for i in range(len(left))]
|
358 |
+
right_symbol = [1 if right[i] - right_arm[i] > 0 else -1 for i in range(len(right))]
|
359 |
+
flag = True
|
360 |
+
step = 0
|
361 |
+
while flag and not rospy.is_shutdown():
|
362 |
+
if self.puppet_arm_publish_lock.acquire(False):
|
363 |
+
return
|
364 |
+
left_diff = [abs(left[i] - left_arm[i]) for i in range(len(left))]
|
365 |
+
right_diff = [abs(right[i] - right_arm[i]) for i in range(len(right))]
|
366 |
+
flag = False
|
367 |
+
for i in range(len(left)):
|
368 |
+
if left_diff[i] < self.args.arm_steps_length[i]:
|
369 |
+
left_arm[i] = left[i]
|
370 |
+
else:
|
371 |
+
left_arm[i] += left_symbol[i] * self.args.arm_steps_length[i]
|
372 |
+
flag = True
|
373 |
+
for i in range(len(right)):
|
374 |
+
if right_diff[i] < self.args.arm_steps_length[i]:
|
375 |
+
right_arm[i] = right[i]
|
376 |
+
else:
|
377 |
+
right_arm[i] += right_symbol[i] * self.args.arm_steps_length[i]
|
378 |
+
flag = True
|
379 |
+
joint_state_msg = JointState()
|
380 |
+
joint_state_msg.header = Header()
|
381 |
+
joint_state_msg.header.stamp = rospy.Time.now()  # Set the timestamp
|
382 |
+
joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6']  # Set the joint names
|
383 |
+
joint_state_msg.position = left_arm
|
384 |
+
self.puppet_arm_left_publisher.publish(joint_state_msg)
|
385 |
+
joint_state_msg.position = right_arm
|
386 |
+
self.puppet_arm_right_publisher.publish(joint_state_msg)
|
387 |
+
step += 1
|
388 |
+
print("puppet_arm_publish_continuous:", step)
|
389 |
+
rate.sleep()
|
390 |
+
|
391 |
+
def puppet_arm_publish_linear(self, left, right):
|
392 |
+
num_step = 100
|
393 |
+
rate = rospy.Rate(200)
|
394 |
+
|
395 |
+
left_arm = None
|
396 |
+
right_arm = None
|
397 |
+
|
398 |
+
while not rospy.is_shutdown():
|
399 |
+
if len(self.puppet_arm_left_deque) != 0:
|
400 |
+
left_arm = list(self.puppet_arm_left_deque[-1].position)
|
401 |
+
if len(self.puppet_arm_right_deque) != 0:
|
402 |
+
right_arm = list(self.puppet_arm_right_deque[-1].position)
|
403 |
+
if left_arm is None or right_arm is None:
|
404 |
+
rate.sleep()
|
405 |
+
continue
|
406 |
+
else:
|
407 |
+
break
|
408 |
+
|
409 |
+
traj_left_list = np.linspace(left_arm, left, num_step)
|
410 |
+
traj_right_list = np.linspace(right_arm, right, num_step)
|
411 |
+
|
412 |
+
for i in range(len(traj_left_list)):
|
413 |
+
traj_left = traj_left_list[i]
|
414 |
+
traj_right = traj_right_list[i]
|
415 |
+
traj_left[-1] = left[-1]
|
416 |
+
traj_right[-1] = right[-1]
|
417 |
+
joint_state_msg = JointState()
|
418 |
+
joint_state_msg.header = Header()
|
419 |
+
joint_state_msg.header.stamp = rospy.Time.now()  # Set the timestamp
|
420 |
+
joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6']  # Set the joint names
|
421 |
+
joint_state_msg.position = traj_left
|
422 |
+
self.puppet_arm_left_publisher.publish(joint_state_msg)
|
423 |
+
joint_state_msg.position = traj_right
|
424 |
+
self.puppet_arm_right_publisher.publish(joint_state_msg)
|
425 |
+
rate.sleep()
|
426 |
+
|
427 |
+
def puppet_arm_publish_continuous_thread(self, left, right):
|
428 |
+
if self.puppet_arm_publish_thread is not None:
|
429 |
+
self.puppet_arm_publish_lock.release()
|
430 |
+
self.puppet_arm_publish_thread.join()
|
431 |
+
self.puppet_arm_publish_lock.acquire(False)
|
432 |
+
self.puppet_arm_publish_thread = None
|
433 |
+
self.puppet_arm_publish_thread = threading.Thread(target=self.puppet_arm_publish_continuous, args=(left, right))
|
434 |
+
self.puppet_arm_publish_thread.start()
|
435 |
+
|
436 |
+
def get_frame(self):
|
437 |
+
if len(self.img_left_deque) == 0 or len(self.img_right_deque) == 0 or len(self.img_front_deque) == 0 or \
|
438 |
+
(self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or len(self.img_right_depth_deque) == 0 or len(self.img_front_depth_deque) == 0)):
|
439 |
+
return False
|
440 |
+
if self.args.use_depth_image:
|
441 |
+
frame_time = min([self.img_left_deque[-1].header.stamp.to_sec(), self.img_right_deque[-1].header.stamp.to_sec(), self.img_front_deque[-1].header.stamp.to_sec(),
|
442 |
+
self.img_left_depth_deque[-1].header.stamp.to_sec(), self.img_right_depth_deque[-1].header.stamp.to_sec(), self.img_front_depth_deque[-1].header.stamp.to_sec()])
|
443 |
+
else:
|
444 |
+
frame_time = min([self.img_left_deque[-1].header.stamp.to_sec(), self.img_right_deque[-1].header.stamp.to_sec(), self.img_front_deque[-1].header.stamp.to_sec()])
|
445 |
+
|
446 |
+
if len(self.img_left_deque) == 0 or self.img_left_deque[-1].header.stamp.to_sec() < frame_time:
|
447 |
+
return False
|
448 |
+
if len(self.img_right_deque) == 0 or self.img_right_deque[-1].header.stamp.to_sec() < frame_time:
|
449 |
+
return False
|
450 |
+
if len(self.img_front_deque) == 0 or self.img_front_deque[-1].header.stamp.to_sec() < frame_time:
|
451 |
+
return False
|
452 |
+
if len(self.puppet_arm_left_deque) == 0 or self.puppet_arm_left_deque[-1].header.stamp.to_sec() < frame_time:
|
453 |
+
return False
|
454 |
+
if len(self.puppet_arm_right_deque) == 0 or self.puppet_arm_right_deque[-1].header.stamp.to_sec() < frame_time:
|
455 |
+
return False
|
456 |
+
if self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or self.img_left_depth_deque[-1].header.stamp.to_sec() < frame_time):
|
457 |
+
return False
|
458 |
+
if self.args.use_depth_image and (len(self.img_right_depth_deque) == 0 or self.img_right_depth_deque[-1].header.stamp.to_sec() < frame_time):
|
459 |
+
return False
|
460 |
+
if self.args.use_depth_image and (len(self.img_front_depth_deque) == 0 or self.img_front_depth_deque[-1].header.stamp.to_sec() < frame_time):
|
461 |
+
return False
|
462 |
+
if self.args.use_robot_base and (len(self.robot_base_deque) == 0 or self.robot_base_deque[-1].header.stamp.to_sec() < frame_time):
|
463 |
+
return False
|
464 |
+
|
465 |
+
while self.img_left_deque[0].header.stamp.to_sec() < frame_time:
|
466 |
+
self.img_left_deque.popleft()
|
467 |
+
img_left = self.bridge.imgmsg_to_cv2(self.img_left_deque.popleft(), 'passthrough')
|
468 |
+
|
469 |
+
while self.img_right_deque[0].header.stamp.to_sec() < frame_time:
|
470 |
+
self.img_right_deque.popleft()
|
471 |
+
img_right = self.bridge.imgmsg_to_cv2(self.img_right_deque.popleft(), 'passthrough')
|
472 |
+
|
473 |
+
while self.img_front_deque[0].header.stamp.to_sec() < frame_time:
|
474 |
+
self.img_front_deque.popleft()
|
475 |
+
img_front = self.bridge.imgmsg_to_cv2(self.img_front_deque.popleft(), 'passthrough')
|
476 |
+
|
477 |
+
while self.puppet_arm_left_deque[0].header.stamp.to_sec() < frame_time:
|
478 |
+
self.puppet_arm_left_deque.popleft()
|
479 |
+
puppet_arm_left = self.puppet_arm_left_deque.popleft()
|
480 |
+
|
481 |
+
while self.puppet_arm_right_deque[0].header.stamp.to_sec() < frame_time:
|
482 |
+
self.puppet_arm_right_deque.popleft()
|
483 |
+
puppet_arm_right = self.puppet_arm_right_deque.popleft()
|
484 |
+
|
485 |
+
img_left_depth = None
|
486 |
+
if self.args.use_depth_image:
|
487 |
+
while self.img_left_depth_deque[0].header.stamp.to_sec() < frame_time:
|
488 |
+
self.img_left_depth_deque.popleft()
|
489 |
+
img_left_depth = self.bridge.imgmsg_to_cv2(self.img_left_depth_deque.popleft(), 'passthrough')
|
490 |
+
|
491 |
+
img_right_depth = None
|
492 |
+
if self.args.use_depth_image:
|
493 |
+
while self.img_right_depth_deque[0].header.stamp.to_sec() < frame_time:
|
494 |
+
self.img_right_depth_deque.popleft()
|
495 |
+
img_right_depth = self.bridge.imgmsg_to_cv2(self.img_right_depth_deque.popleft(), 'passthrough')
|
496 |
+
|
497 |
+
img_front_depth = None
|
498 |
+
if self.args.use_depth_image:
|
499 |
+
while self.img_front_depth_deque[0].header.stamp.to_sec() < frame_time:
|
500 |
+
self.img_front_depth_deque.popleft()
|
501 |
+
img_front_depth = self.bridge.imgmsg_to_cv2(self.img_front_depth_deque.popleft(), 'passthrough')
|
502 |
+
|
503 |
+
robot_base = None
|
504 |
+
if self.args.use_robot_base:
|
505 |
+
while self.robot_base_deque[0].header.stamp.to_sec() < frame_time:
|
506 |
+
self.robot_base_deque.popleft()
|
507 |
+
robot_base = self.robot_base_deque.popleft()
|
508 |
+
|
509 |
+
return (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth,
|
510 |
+
puppet_arm_left, puppet_arm_right, robot_base)
|
511 |
+
|
512 |
+
def img_left_callback(self, msg):
|
513 |
+
if len(self.img_left_deque) >= 2000:
|
514 |
+
self.img_left_deque.popleft()
|
515 |
+
self.img_left_deque.append(msg)
|
516 |
+
|
517 |
+
def img_right_callback(self, msg):
|
518 |
+
if len(self.img_right_deque) >= 2000:
|
519 |
+
self.img_right_deque.popleft()
|
520 |
+
self.img_right_deque.append(msg)
|
521 |
+
|
522 |
+
def img_front_callback(self, msg):
|
523 |
+
if len(self.img_front_deque) >= 2000:
|
524 |
+
self.img_front_deque.popleft()
|
525 |
+
self.img_front_deque.append(msg)
|
526 |
+
|
527 |
+
def img_left_depth_callback(self, msg):
|
528 |
+
if len(self.img_left_depth_deque) >= 2000:
|
529 |
+
self.img_left_depth_deque.popleft()
|
530 |
+
self.img_left_depth_deque.append(msg)
|
531 |
+
|
532 |
+
def img_right_depth_callback(self, msg):
|
533 |
+
if len(self.img_right_depth_deque) >= 2000:
|
534 |
+
self.img_right_depth_deque.popleft()
|
535 |
+
self.img_right_depth_deque.append(msg)
|
536 |
+
|
537 |
+
def img_front_depth_callback(self, msg):
|
538 |
+
if len(self.img_front_depth_deque) >= 2000:
|
539 |
+
self.img_front_depth_deque.popleft()
|
540 |
+
self.img_front_depth_deque.append(msg)
|
541 |
+
|
542 |
+
def puppet_arm_left_callback(self, msg):
|
543 |
+
if len(self.puppet_arm_left_deque) >= 2000:
|
544 |
+
self.puppet_arm_left_deque.popleft()
|
545 |
+
self.puppet_arm_left_deque.append(msg)
|
546 |
+
|
547 |
+
def puppet_arm_right_callback(self, msg):
|
548 |
+
if len(self.puppet_arm_right_deque) >= 2000:
|
549 |
+
self.puppet_arm_right_deque.popleft()
|
550 |
+
self.puppet_arm_right_deque.append(msg)
|
551 |
+
|
552 |
+
def robot_base_callback(self, msg):
|
553 |
+
if len(self.robot_base_deque) >= 2000:
|
554 |
+
self.robot_base_deque.popleft()
|
555 |
+
self.robot_base_deque.append(msg)
|
556 |
+
|
557 |
+
def init_ros(self):
|
558 |
+
rospy.init_node('joint_state_publisher', anonymous=True)
|
559 |
+
rospy.Subscriber(self.args.img_left_topic, Image, self.img_left_callback, queue_size=1000, tcp_nodelay=True)
|
560 |
+
rospy.Subscriber(self.args.img_right_topic, Image, self.img_right_callback, queue_size=1000, tcp_nodelay=True)
|
561 |
+
rospy.Subscriber(self.args.img_front_topic, Image, self.img_front_callback, queue_size=1000, tcp_nodelay=True)
|
562 |
+
if self.args.use_depth_image:
|
563 |
+
rospy.Subscriber(self.args.img_left_depth_topic, Image, self.img_left_depth_callback, queue_size=1000, tcp_nodelay=True)
|
564 |
+
rospy.Subscriber(self.args.img_right_depth_topic, Image, self.img_right_depth_callback, queue_size=1000, tcp_nodelay=True)
|
565 |
+
rospy.Subscriber(self.args.img_front_depth_topic, Image, self.img_front_depth_callback, queue_size=1000, tcp_nodelay=True)
|
566 |
+
rospy.Subscriber(self.args.puppet_arm_left_topic, JointState, self.puppet_arm_left_callback, queue_size=1000, tcp_nodelay=True)
|
567 |
+
rospy.Subscriber(self.args.puppet_arm_right_topic, JointState, self.puppet_arm_right_callback, queue_size=1000, tcp_nodelay=True)
|
568 |
+
rospy.Subscriber(self.args.robot_base_topic, Odometry, self.robot_base_callback, queue_size=1000, tcp_nodelay=True)
|
569 |
+
self.puppet_arm_left_publisher = rospy.Publisher(self.args.puppet_arm_left_cmd_topic, JointState, queue_size=10)
|
570 |
+
self.puppet_arm_right_publisher = rospy.Publisher(self.args.puppet_arm_right_cmd_topic, JointState, queue_size=10)
|
571 |
+
self.robot_base_publisher = rospy.Publisher(self.args.robot_base_cmd_topic, Twist, queue_size=10)
|
572 |
+
|
573 |
+
|
574 |
+
def get_arguments():
|
575 |
+
parser = argparse.ArgumentParser()
|
576 |
+
parser.add_argument('--max_publish_step', action='store', type=int,
|
577 |
+
help='Maximum number of action publishing steps', default=10000, required=False)
|
578 |
+
parser.add_argument('--seed', action='store', type=int,
|
579 |
+
help='Random seed', default=None, required=False)
|
580 |
+
|
581 |
+
parser.add_argument('--img_front_topic', action='store', type=str, help='img_front_topic',
|
582 |
+
default='/camera_f/color/image_raw', required=False)
|
583 |
+
parser.add_argument('--img_left_topic', action='store', type=str, help='img_left_topic',
|
584 |
+
default='/camera_l/color/image_raw', required=False)
|
585 |
+
parser.add_argument('--img_right_topic', action='store', type=str, help='img_right_topic',
|
586 |
+
default='/camera_r/color/image_raw', required=False)
|
587 |
+
|
588 |
+
parser.add_argument('--img_front_depth_topic', action='store', type=str, help='img_front_depth_topic',
|
589 |
+
default='/camera_f/depth/image_raw', required=False)
|
590 |
+
parser.add_argument('--img_left_depth_topic', action='store', type=str, help='img_left_depth_topic',
|
591 |
+
default='/camera_l/depth/image_raw', required=False)
|
592 |
+
parser.add_argument('--img_right_depth_topic', action='store', type=str, help='img_right_depth_topic',
|
593 |
+
default='/camera_r/depth/image_raw', required=False)
|
594 |
+
|
595 |
+
parser.add_argument('--puppet_arm_left_cmd_topic', action='store', type=str, help='puppet_arm_left_cmd_topic',
|
596 |
+
default='/master/joint_left', required=False)
|
597 |
+
parser.add_argument('--puppet_arm_right_cmd_topic', action='store', type=str, help='puppet_arm_right_cmd_topic',
|
598 |
+
default='/master/joint_right', required=False)
|
599 |
+
parser.add_argument('--puppet_arm_left_topic', action='store', type=str, help='puppet_arm_left_topic',
|
600 |
+
default='/puppet/joint_left', required=False)
|
601 |
+
parser.add_argument('--puppet_arm_right_topic', action='store', type=str, help='puppet_arm_right_topic',
|
602 |
+
default='/puppet/joint_right', required=False)
|
603 |
+
|
604 |
+
parser.add_argument('--robot_base_topic', action='store', type=str, help='robot_base_topic',
|
605 |
+
default='/odom_raw', required=False)
|
606 |
+
parser.add_argument('--robot_base_cmd_topic', action='store', type=str, help='robot_base_cmd_topic',
|
607 |
+
default='/cmd_vel', required=False)
|
608 |
+
parser.add_argument('--use_robot_base', action='store_true',
|
609 |
+
help='Whether to use the robot base to move around',
|
610 |
+
default=False, required=False)
|
611 |
+
parser.add_argument('--publish_rate', action='store', type=int,
|
612 |
+
help='The rate at which to publish the actions',
|
613 |
+
default=30, required=False)
|
614 |
+
parser.add_argument('--ctrl_freq', action='store', type=int,
|
615 |
+
help='The control frequency of the robot',
|
616 |
+
default=25, required=False)
|
617 |
+
|
618 |
+
parser.add_argument('--chunk_size', action='store', type=int,
|
619 |
+
help='Action chunk size',
|
620 |
+
default=64, required=False)
|
621 |
+
parser.add_argument('--arm_steps_length', action='store', type=float,
|
622 |
+
help='The maximum change allowed for each joint per timestep',
|
623 |
+
default=[0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.2], required=False)
|
624 |
+
|
625 |
+
parser.add_argument('--use_actions_interpolation', action='store_true',
|
626 |
+
help='Whether to interpolate the actions if the difference is too large',
|
627 |
+
default=False, required=False)
|
628 |
+
parser.add_argument('--use_depth_image', action='store_true',
|
629 |
+
help='Whether to use depth images',
|
630 |
+
default=False, required=False)
|
631 |
+
|
632 |
+
parser.add_argument('--disable_puppet_arm', action='store_true',
|
633 |
+
help='Whether to disable the puppet arm. This is useful for safely debugging', default=False)
|
634 |
+
|
635 |
+
parser.add_argument('--config_path', type=str, default="configs/base.yaml",
|
636 |
+
help='Path to the config file')
|
637 |
+
# parser.add_argument('--cfg_scale', type=float, default=2.0,
|
638 |
+
# help='the scaling factor used to modify the magnitude of the control features during denoising')
|
639 |
+
parser.add_argument('--pretrained_model_name_or_path', type=str, required=True, help='Name or path to the pretrained model')
|
640 |
+
|
641 |
+
parser.add_argument('--lang_embeddings_path', type=str, required=True,
|
642 |
+
help='Path to the pre-encoded language instruction embeddings')
|
643 |
+
|
644 |
+
args = parser.parse_args()
|
645 |
+
return args
|
646 |
+
|
647 |
+
|
648 |
+
def main():
|
649 |
+
args = get_arguments()
|
650 |
+
ros_operator = RosOperator(args)
|
651 |
+
if args.seed is not None:
|
652 |
+
set_seed(args.seed)
|
653 |
+
config = get_config(args)
|
654 |
+
model_inference(args, config, ros_operator)
|
655 |
+
|
656 |
+
|
657 |
+
if __name__ == '__main__':
|
658 |
+
main()
|
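For reference, `interpolate_action` in the script above caps the per-step joint change at `arm_steps_length` by linearly interpolating between the previous and current 14-dim action. A standalone sketch of the same logic with toy values (the `Namespace` below reuses the script's default step limits; everything else is made up for illustration):

from argparse import Namespace

import numpy as np

def interpolate_action(args, prev_action, cur_action):
    # Same idea as in scripts/agilex_inference.py: limit each joint's change per published step.
    steps = np.concatenate((np.array(args.arm_steps_length), np.array(args.arm_steps_length)), axis=0)
    diff = np.abs(cur_action - prev_action)
    step = int(np.max(np.ceil(diff / steps)))
    if step <= 1:
        return cur_action[np.newaxis, :]
    return np.linspace(prev_action, cur_action, step + 1)[1:]

args = Namespace(arm_steps_length=[0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.2])
prev_action = np.zeros(14)
cur_action = np.full(14, 0.05)
interp = interpolate_action(args, prev_action, cur_action)
print(interp.shape)  # (5, 14): five intermediate actions, at most 0.01 rad per arm joint each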
scripts/agilex_model.py
ADDED
@@ -0,0 +1,313 @@
1 |
+
import os
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from PIL import Image
|
6 |
+
from torchvision import transforms
|
7 |
+
|
8 |
+
from configs.state_vec import STATE_VEC_IDX_MAPPING
|
9 |
+
from models.multimodal_encoder.siglip_encoder import SiglipVisionTower
|
10 |
+
from models.multimodal_encoder.t5_encoder import T5Embedder
|
11 |
+
from models.rdt_runner import RDTRunner
|
12 |
+
|
13 |
+
|
14 |
+
# The indices that the raw vector should be mapped to in the unified action vector
|
15 |
+
AGILEX_STATE_INDICES = [
|
16 |
+
STATE_VEC_IDX_MAPPING[f"left_arm_joint_{i}_pos"] for i in range(6)
|
17 |
+
] + [
|
18 |
+
STATE_VEC_IDX_MAPPING["left_gripper_open"]
|
19 |
+
] + [
|
20 |
+
STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(6)
|
21 |
+
] + [
|
22 |
+
STATE_VEC_IDX_MAPPING[f"right_gripper_open"]
|
23 |
+
]
|
24 |
+
TABLETOP_6D_INDICES_NAMES = [
|
25 |
+
'left_eef_pos_x','left_eef_pos_y','left_eef_pos_z','left_eef_angle_0','left_eef_angle_1','left_eef_angle_2','left_eef_angle_3','left_eef_angle_4','left_eef_angle_5','left_gripper_open','right_eef_pos_x','right_eef_pos_y','right_eef_pos_z','right_eef_angle_0','right_eef_angle_1','right_eef_angle_2','right_eef_angle_3','right_eef_angle_4','right_eef_angle_5','right_gripper_open']
|
26 |
+
TABLETOP_6D_INDICES = [STATE_VEC_IDX_MAPPING[n] for n in TABLETOP_6D_INDICES_NAMES]
|
27 |
+
|
28 |
+
# Create the RDT model
|
29 |
+
def create_model(args, **kwargs):
|
30 |
+
model = RoboticDiffusionTransformerModel(args, **kwargs)
|
31 |
+
pretrained = kwargs.get("pretrained", None)
|
32 |
+
if (
|
33 |
+
pretrained is not None
|
34 |
+
and os.path.isfile(pretrained)
|
35 |
+
):
|
36 |
+
model.load_pretrained_weights(pretrained)
|
37 |
+
return model
|
38 |
+
|
39 |
+
|
40 |
+
class RoboticDiffusionTransformerModel(object):
|
41 |
+
"""A wrapper for the RDT model, which handles
|
42 |
+
1. Model initialization
|
43 |
+
2. Encodings of instructions
|
44 |
+
3. Model inference
|
45 |
+
"""
|
46 |
+
def __init__(
|
47 |
+
self, args,
|
48 |
+
device='cuda',
|
49 |
+
dtype=torch.bfloat16,
|
50 |
+
image_size=None,
|
51 |
+
control_frequency=25,
|
52 |
+
pretrained=None,
|
53 |
+
pretrained_vision_encoder_name_or_path=None,
|
54 |
+
pretrained_text_encoder_name_or_path=None
|
55 |
+
):
|
56 |
+
self.args = args
|
57 |
+
self.dtype = dtype
|
58 |
+
self.image_size = image_size
|
59 |
+
self.device = device
|
60 |
+
self.control_frequency = control_frequency
|
61 |
+
# Build the T5 text encoder used by encode_instruction() below
|
62 |
+
self.text_tokenizer, self.text_model = self.get_text_encoder(pretrained_text_encoder_name_or_path)
|
63 |
+
self.image_processor, self.vision_model = self.get_vision_encoder(pretrained_vision_encoder_name_or_path)
|
64 |
+
self.policy = self.get_policy(pretrained)
|
65 |
+
|
66 |
+
self.reset()
|
67 |
+
|
68 |
+
def get_policy(self, pretrained):
|
69 |
+
"""Initialize the model."""
|
70 |
+
# Initialize model with arguments
|
71 |
+
if (
|
72 |
+
pretrained is None
|
73 |
+
or os.path.isfile(pretrained)
|
74 |
+
):
|
75 |
+
img_cond_len = (self.args["common"]["img_history_size"]
|
76 |
+
* self.args["common"]["num_cameras"]
|
77 |
+
* self.vision_model.num_patches)
|
78 |
+
|
79 |
+
_model = RDTRunner(
|
80 |
+
action_dim=self.args["common"]["state_dim"],
|
81 |
+
pred_horizon=self.args["common"]["action_chunk_size"],
|
82 |
+
config=self.args["model"],
|
83 |
+
lang_token_dim=self.args["model"]["lang_token_dim"],
|
84 |
+
img_token_dim=self.args["model"]["img_token_dim"],
|
85 |
+
state_token_dim=self.args["model"]["state_token_dim"],
|
86 |
+
max_lang_cond_len=self.args["dataset"]["tokenizer_max_length"],
|
87 |
+
img_cond_len=img_cond_len,
|
88 |
+
img_pos_embed_config=[
|
89 |
+
# No initial pos embed in the last grid size
|
90 |
+
# since we've already done in ViT
|
91 |
+
("image", (self.args["common"]["img_history_size"],
|
92 |
+
self.args["common"]["num_cameras"],
|
93 |
+
-self.vision_model.num_patches)),
|
94 |
+
],
|
95 |
+
lang_pos_embed_config=[
|
96 |
+
# Similarly, no initial pos embed for language
|
97 |
+
("lang", -self.args["dataset"]["tokenizer_max_length"]),
|
98 |
+
],
|
99 |
+
dtype=self.dtype,
|
100 |
+
)
|
101 |
+
else:
|
102 |
+
_model = RDTRunner.from_pretrained(pretrained)
|
103 |
+
|
104 |
+
return _model
|
105 |
+
|
106 |
+
def get_text_encoder(self, pretrained_text_encoder_name_or_path):
|
107 |
+
text_embedder = T5Embedder(from_pretrained=pretrained_text_encoder_name_or_path,
|
108 |
+
model_max_length=self.args["dataset"]["tokenizer_max_length"],
|
109 |
+
device=self.device)
|
110 |
+
tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model
|
111 |
+
return tokenizer, text_encoder
|
112 |
+
|
113 |
+
def get_vision_encoder(self, pretrained_vision_encoder_name_or_path):
|
114 |
+
vision_encoder = SiglipVisionTower(vision_tower=pretrained_vision_encoder_name_or_path, args=None)
|
115 |
+
image_processor = vision_encoder.image_processor
|
116 |
+
return image_processor, vision_encoder
|
117 |
+
|
118 |
+
def reset(self):
|
119 |
+
"""Set model to evaluation mode.
|
120 |
+
"""
|
121 |
+
device = self.device
|
122 |
+
weight_dtype = self.dtype
|
123 |
+
self.policy.eval()
|
124 |
+
# self.text_model.eval()
|
125 |
+
self.vision_model.eval()
|
126 |
+
|
127 |
+
self.policy = self.policy.to(device, dtype=weight_dtype)
|
128 |
+
# self.text_model = self.text_model.to(device, dtype=weight_dtype)
|
129 |
+
self.vision_model = self.vision_model.to(device, dtype=weight_dtype)
|
130 |
+
|
131 |
+
def load_pretrained_weights(self, pretrained=None):
|
132 |
+
if pretrained is None:
|
133 |
+
return
|
134 |
+
print(f'Loading weights from {pretrained}')
|
135 |
+
filename = os.path.basename(pretrained)
|
136 |
+
if filename.endswith('.pt'):
|
137 |
+
checkpoint = torch.load(pretrained)
|
138 |
+
self.policy.load_state_dict(checkpoint["module"])
|
139 |
+
elif filename.endswith('.safetensors'):
|
140 |
+
from safetensors.torch import load_model
|
141 |
+
load_model(self.policy, pretrained)
|
142 |
+
else:
|
143 |
+
raise NotImplementedError(f"Unknown checkpoint format: {pretrained}")
|
144 |
+
|
145 |
+
def encode_instruction(self, instruction, device="cuda"):
|
146 |
+
"""Encode string instruction to latent embeddings.
|
147 |
+
|
148 |
+
Args:
|
149 |
+
instruction: a string of instruction
|
150 |
+
device: a string of device
|
151 |
+
|
152 |
+
Returns:
|
153 |
+
pred: a tensor of latent embeddings of shape (text_max_length, 512)
|
154 |
+
"""
|
155 |
+
tokens = self.text_tokenizer(
|
156 |
+
instruction, return_tensors="pt",
|
157 |
+
padding="longest",
|
158 |
+
truncation=True
|
159 |
+
)["input_ids"].to(device)
|
160 |
+
|
161 |
+
tokens = tokens.view(1, -1)
|
162 |
+
with torch.no_grad():
|
163 |
+
pred = self.text_model(tokens).last_hidden_state.detach()
|
164 |
+
|
165 |
+
return pred
|
166 |
+
|
167 |
+
def _format_joint_to_state(self, joints):
|
168 |
+
"""
|
169 |
+
Format the 6D EEF proprioception into the unified state vector.
|
170 |
+
|
171 |
+
Args:
|
172 |
+
joints (torch.Tensor): The 6D EEF proprioception to be formatted.
|
173 |
+
qpos ([B, N, 20]).
|
174 |
+
|
175 |
+
Returns:
|
176 |
+
state (torch.Tensor): The formatted vector for RDT ([B, N, 128]).
|
177 |
+
"""
|
178 |
+
# Rescale the gripper to the range of [0, 1]
|
179 |
+
joints = joints / torch.tensor(
|
180 |
+
[[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]],
|
181 |
+
device=joints.device, dtype=joints.dtype
|
182 |
+
)
|
183 |
+
|
184 |
+
B, N, _ = joints.shape
|
185 |
+
state = torch.zeros(
|
186 |
+
(B, N, self.args["model"]["state_token_dim"]),
|
187 |
+
device=joints.device, dtype=joints.dtype
|
188 |
+
)
|
189 |
+
# Fill into the unified state vector
|
190 |
+
state[:, :, TABLETOP_6D_INDICES] = joints
|
191 |
+
# Assemble the mask indicating each dimension's availability
|
192 |
+
state_elem_mask = torch.zeros(
|
193 |
+
(B, self.args["model"]["state_token_dim"]),
|
194 |
+
device=joints.device, dtype=joints.dtype
|
195 |
+
)
|
196 |
+
state_elem_mask[:,TABLETOP_6D_INDICES] = 1
|
197 |
+
return state, state_elem_mask
|
198 |
+
|
199 |
+
def _unformat_action_to_joint(self, action):
|
200 |
+
"""
|
201 |
+
Unformat the unified action vector into the joint action to be executed.
|
202 |
+
|
203 |
+
Args:
|
204 |
+
action (torch.Tensor): The unified action vector to be unformatted.
|
205 |
+
([B, N, 128])
|
206 |
+
|
207 |
+
Returns:
|
208 |
+
joints (torch.Tensor): The unformatted robot joint action.
|
209 |
+
qpos ([B, N, 20]).
|
210 |
+
"""
|
211 |
+
action_indices = TABLETOP_6D_INDICES
|
212 |
+
joints = action[:, :, action_indices]
|
213 |
+
|
214 |
+
# Rescale the gripper back to the action range
|
215 |
+
# Note that the action range and proprioception range are different
|
216 |
+
# for Mobile ALOHA robot
|
217 |
+
joints = joints * torch.tensor(
|
218 |
+
[[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]],
|
219 |
+
device=joints.device, dtype=joints.dtype
|
220 |
+
)
|
221 |
+
|
222 |
+
return joints
|
223 |
+
|
224 |
+
@torch.no_grad()
|
225 |
+
def step(self, proprio, images, instruction):
|
226 |
+
"""
|
227 |
+
Predict the next action chunk given the
|
228 |
+
proprioceptive states, images, and language instruction.
|
229 |
+
|
230 |
+
Args:
|
231 |
+
proprio: proprioceptive states
|
232 |
+
images: RGB images, the order should be
|
233 |
+
[ext_{t-1}, right_wrist_{t-1}, left_wrist_{t-1},
|
234 |
+
ext_{t}, right_wrist_{t}, left_wrist_{t}]
|
235 |
+
instruction: a string language instruction (encoded internally with the T5 text encoder)
|
236 |
+
|
237 |
+
Returns:
|
238 |
+
action: predicted action
|
239 |
+
"""
|
240 |
+
device = self.device
|
241 |
+
dtype = self.dtype
|
242 |
+
|
243 |
+
# The background image used for padding
|
244 |
+
background_color = np.array([
|
245 |
+
int(x*255) for x in self.image_processor.image_mean
|
246 |
+
], dtype=np.uint8).reshape(1, 1, 3)
|
247 |
+
background_image = np.ones((
|
248 |
+
self.image_processor.size["height"],
|
249 |
+
self.image_processor.size["width"], 3), dtype=np.uint8
|
250 |
+
) * background_color
|
251 |
+
|
252 |
+
# Preprocess the images by order and encode them
|
253 |
+
image_tensor_list = []
|
254 |
+
for image in images:
|
255 |
+
if image is None:
|
256 |
+
# Replace it with the background image
|
257 |
+
image = Image.fromarray(background_image)
|
258 |
+
|
259 |
+
if self.image_size is not None:
|
260 |
+
image = transforms.Resize(self.image_size)(image)
|
261 |
+
|
262 |
+
if self.args["dataset"].get("auto_adjust_image_brightness", False):
|
263 |
+
pixel_values = list(image.getdata())
|
264 |
+
average_brightness = sum(sum(pixel) for pixel in pixel_values) / (len(pixel_values) * 255.0 * 3)
|
265 |
+
if average_brightness <= 0.15:
|
266 |
+
image = transforms.ColorJitter(brightness=(1.75,1.75))(image)
|
267 |
+
|
268 |
+
if self.args["dataset"].get("image_aspect_ratio", "pad") == 'pad':
|
269 |
+
def expand2square(pil_img, background_color):
|
270 |
+
width, height = pil_img.size
|
271 |
+
if width == height:
|
272 |
+
return pil_img
|
273 |
+
elif width > height:
|
274 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
275 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
276 |
+
return result
|
277 |
+
else:
|
278 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
279 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
280 |
+
return result
|
281 |
+
image = expand2square(image, tuple(int(x*255) for x in self.image_processor.image_mean))
|
282 |
+
image = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
283 |
+
image_tensor_list.append(image)
|
284 |
+
|
285 |
+
image_tensor = torch.stack(image_tensor_list, dim=0).to(device, dtype=dtype)
|
286 |
+
|
287 |
+
image_embeds = self.vision_model(image_tensor).detach()
|
288 |
+
image_embeds = image_embeds.reshape(-1, self.vision_model.hidden_size).unsqueeze(0)
|
289 |
+
|
290 |
+
# Prepare the proprioception states and the control frequency
|
291 |
+
joints = proprio.to(device).unsqueeze(0)  # (1, 1, 20) for the tabletop 6D EEF proprioception
|
292 |
+
states, state_elem_mask = self._format_joint_to_state(joints) # (1, 1, 128), (1, 128)
|
293 |
+
states, state_elem_mask = states.to(device, dtype=dtype), state_elem_mask.to(device, dtype=dtype)
|
294 |
+
states = states[:, -1:, :] # (1, 1, 128)
|
295 |
+
ctrl_freqs = torch.tensor([self.control_frequency]).to(device)
|
296 |
+
|
297 |
+
# text_embeds = text_embeds.to(device, dtype=dtype)
|
298 |
+
text_embeds = self.encode_instruction(instruction=instruction)
|
299 |
+
|
300 |
+
# Predict the next action chunk given the inputs
|
301 |
+
trajectory = self.policy.predict_action(
|
302 |
+
lang_tokens=text_embeds,
|
303 |
+
lang_attn_mask=torch.ones(
|
304 |
+
text_embeds.shape[:2], dtype=torch.bool,
|
305 |
+
device=text_embeds.device),
|
306 |
+
img_tokens=image_embeds,
|
307 |
+
state_tokens=states,
|
308 |
+
action_mask=state_elem_mask.unsqueeze(1),
|
309 |
+
ctrl_freqs=ctrl_freqs
|
310 |
+
)
|
311 |
+
trajectory = self._unformat_action_to_joint(trajectory).to(torch.float32)
|
312 |
+
|
313 |
+
return trajectory
|
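`_format_joint_to_state` and `_unformat_action_to_joint` above work by scattering the robot-specific vector into the 128-dimensional unified state space defined in `configs/state_vec.py` and gathering it back out of the predicted action chunk. A minimal sketch of that scatter/gather pattern (the index values below are hypothetical; the real ones come from `STATE_VEC_IDX_MAPPING`):

import torch

STATE_DIM = 128                    # size of the unified state/action vector
indices = list(range(50, 70))      # hypothetical positions for a 20-dim tabletop 6D EEF state

# Scatter the robot-specific proprioception (B, N, 20) into the unified vector.
joints = torch.randn(1, 1, len(indices))
state = torch.zeros(1, 1, STATE_DIM)
state[:, :, indices] = joints

# Mask marking which of the 128 dimensions are actually populated.
state_elem_mask = torch.zeros(1, STATE_DIM)
state_elem_mask[:, indices] = 1

# Gather the robot-specific action back out of a predicted chunk (B, horizon, 128).
pred_chunk = torch.randn(1, 64, STATE_DIM)
robot_action = pred_chunk[:, :, indices]
print(state.shape, state_elem_mask.shape, robot_action.shape)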
scripts/encode_lang_batch.py
ADDED
@@ -0,0 +1,76 @@
1 |
+
import os
|
2 |
+
import json
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import yaml
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
from models.multimodal_encoder.t5_encoder import T5Embedder
|
9 |
+
|
10 |
+
|
11 |
+
GPU = 0
|
12 |
+
MODEL_PATH = "google/t5-v1_1-xxl"
|
13 |
+
CONFIG_PATH = "configs/base.yaml"
|
14 |
+
# Modify the TARGET_DIR to your dataset path
|
15 |
+
TARGET_DIR = "data/datasets/openx_embod/singlevla_benchmark_ee"
|
16 |
+
|
17 |
+
# Note: if your GPU VRAM is less than 24GB,
|
18 |
+
# it is recommended to enable offloading by specifying an offload directory.
|
19 |
+
OFFLOAD_DIR = None # Specify your offload directory here, ensuring the directory exists.
|
20 |
+
|
21 |
+
def main():
|
22 |
+
with open(CONFIG_PATH, "r") as fp:
|
23 |
+
config = yaml.safe_load(fp)
|
24 |
+
|
25 |
+
device = torch.device(f"cuda:{GPU}")
|
26 |
+
text_embedder = T5Embedder(
|
27 |
+
from_pretrained=MODEL_PATH,
|
28 |
+
model_max_length=config["dataset"]["tokenizer_max_length"],
|
29 |
+
device=device,
|
30 |
+
use_offload_folder=OFFLOAD_DIR
|
31 |
+
)
|
32 |
+
tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model
|
33 |
+
|
34 |
+
# Get all the task paths
|
35 |
+
task_paths = []
|
36 |
+
for sub_dir in os.listdir(TARGET_DIR):
|
37 |
+
middle_dir = os.path.join(TARGET_DIR, sub_dir)
|
38 |
+
if os.path.isdir(middle_dir):
|
39 |
+
for task_dir in os.listdir(middle_dir):
|
40 |
+
task_path = os.path.join(middle_dir, task_dir)
|
41 |
+
if os.path.isdir(task_path):
|
42 |
+
task_paths.append(task_path)
|
43 |
+
|
44 |
+
# For each task, encode the instructions
|
45 |
+
for task_path in tqdm(task_paths):
|
46 |
+
# Load the instructions corresponding to the task from the directory
|
47 |
+
with open(os.path.join(task_path, 'expanded_instruction_gpt-4-turbo.json'), 'r') as f_instr:
|
48 |
+
instruction_dict = json.load(f_instr)
|
49 |
+
instructions = [instruction_dict['instruction']] + instruction_dict['simplified_instruction'] + \
|
50 |
+
instruction_dict['expanded_instruction']
|
51 |
+
|
52 |
+
# Encode the instructions
|
53 |
+
tokenized_res = tokenizer(
|
54 |
+
instructions, return_tensors="pt",
|
55 |
+
padding="longest",
|
56 |
+
truncation=True
|
57 |
+
)
|
58 |
+
tokens = tokenized_res["input_ids"].to(device)
|
59 |
+
attn_mask = tokenized_res["attention_mask"].to(device)
|
60 |
+
|
61 |
+
with torch.no_grad():
|
62 |
+
text_embeds = text_encoder(
|
63 |
+
input_ids=tokens,
|
64 |
+
attention_mask=attn_mask
|
65 |
+
)["last_hidden_state"].detach().cpu()
|
66 |
+
|
67 |
+
attn_mask = attn_mask.cpu().bool()
|
68 |
+
|
69 |
+
# Save the embeddings for training use
|
70 |
+
for i in range(len(instructions)):
|
71 |
+
text_embed = text_embeds[i][attn_mask[i]]
|
72 |
+
save_path = os.path.join(task_path, f"lang_embed_{i}.pt")
|
73 |
+
torch.save(text_embed, save_path)
|
74 |
+
|
75 |
+
if __name__ == "__main__":
|
76 |
+
main()
|
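The batch encoder above expects each task directory to contain an `expanded_instruction_gpt-4-turbo.json` whose `instruction` field is a string and whose `simplified_instruction` and `expanded_instruction` fields are lists of strings; it then writes one `lang_embed_{i}.pt` per instruction next to the JSON file. A sketch of the expected layout (the wording is illustrative; only the key names are taken from the script):

import json

instruction_dict = {
    "instruction": "Put the dish on the drainer.",
    "simplified_instruction": ["Place the dish on the rack."],
    "expanded_instruction": ["Pick up the dish from the table and set it down onto the dish drainer."],
}
with open("expanded_instruction_gpt-4-turbo.json", "w") as f:
    json.dump(instruction_dict, f, indent=2)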
scripts/maniskill_model.py
ADDED
@@ -0,0 +1,277 @@
1 |
+
import os
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from PIL import Image
|
6 |
+
from torchvision import transforms
|
7 |
+
|
8 |
+
from configs.state_vec import STATE_VEC_IDX_MAPPING
|
9 |
+
from models.multimodal_encoder.siglip_encoder import SiglipVisionTower
|
10 |
+
from models.multimodal_encoder.t5_encoder import T5Embedder
|
11 |
+
from models.rdt_runner import RDTRunner
|
12 |
+
|
13 |
+
|
14 |
+
MANISKILL_INDICES = [
|
15 |
+
STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(7)
|
16 |
+
] + [
|
17 |
+
STATE_VEC_IDX_MAPPING[f"right_gripper_open"]
|
18 |
+
]
|
19 |
+
|
20 |
+
|
21 |
+
def create_model(args, pretrained, **kwargs):
|
22 |
+
model = RoboticDiffusionTransformerModel(args, **kwargs)
|
23 |
+
if pretrained is not None:
|
24 |
+
model.load_pretrained_weights(pretrained)
|
25 |
+
return model
|
26 |
+
|
27 |
+
|
28 |
+
DATA_STAT = {'state_min': [-0.7463043928146362, -0.0801204964518547, -0.4976441562175751, -2.657780647277832, -0.5742632150650024, 1.8309762477874756, -2.2423808574676514, 0.0], 'state_max': [0.7645499110221863, 1.4967026710510254, 0.4650936424732208, -0.3866899907588959, 0.5505855679512024, 3.2900545597076416, 2.5737812519073486, 0.03999999910593033], 'action_min': [-0.7472005486488342, -0.08631071448326111, -0.4995281398296356, -2.658363103866577, -0.5751323103904724, 1.8290787935256958, -2.245187997817993, -1.0], 'action_max': [0.7654682397842407, 1.4984270334243774, 0.46786263585090637, -0.38181185722351074, 0.5517147779464722, 3.291581630706787, 2.575840711593628, 1.0]}
|
29 |
+
|
30 |
+
class RoboticDiffusionTransformerModel(object):
|
31 |
+
"""A wrapper for the RDT model, which handles
|
32 |
+
1. Model initialization
|
33 |
+
2. Encodings of instructions
|
34 |
+
3. Model inference
|
35 |
+
"""
|
36 |
+
def __init__(
|
37 |
+
self, args,
|
38 |
+
device='cuda',
|
39 |
+
dtype=torch.bfloat16,
|
40 |
+
image_size=None,
|
41 |
+
control_frequency=25,
|
42 |
+
pretrained_text_encoder_name_or_path=None,
|
43 |
+
pretrained_vision_encoder_name_or_path=None,
|
44 |
+
):
|
45 |
+
self.args = args
|
46 |
+
self.dtype = dtype
|
47 |
+
self.image_size = image_size
|
48 |
+
self.device = device
|
49 |
+
self.control_frequency = control_frequency
|
50 |
+
self.text_tokenizer, self.text_model = self.get_text_encoder(pretrained_text_encoder_name_or_path)
|
51 |
+
self.image_processor, self.vision_model = self.get_vision_encoder(pretrained_vision_encoder_name_or_path)
|
52 |
+
self.policy = self.get_policy()
|
53 |
+
|
54 |
+
self.state_min = torch.tensor(DATA_STAT['state_min']).to(device)
|
55 |
+
self.state_max = torch.tensor(DATA_STAT['state_max']).to(device)
|
56 |
+
self.action_min = torch.tensor(DATA_STAT['action_min']).to(device)
|
57 |
+
self.action_max = torch.tensor(DATA_STAT['action_max']).to(device)
|
58 |
+
|
59 |
+
self.reset()
|
60 |
+
|
61 |
+
def get_policy(self):
|
62 |
+
"""Initialize the model."""
|
63 |
+
# Initialize model with arguments
|
64 |
+
img_cond_len = (self.args["common"]["img_history_size"]
|
65 |
+
* self.args["common"]["num_cameras"]
|
66 |
+
* self.vision_model.num_patches)
|
67 |
+
|
68 |
+
_model = RDTRunner(
|
69 |
+
action_dim=self.args["common"]["state_dim"],
|
70 |
+
pred_horizon=self.args["common"]["action_chunk_size"],
|
71 |
+
config=self.args["model"],
|
72 |
+
lang_token_dim=self.args["model"]["lang_token_dim"],
|
73 |
+
img_token_dim=self.args["model"]["img_token_dim"],
|
74 |
+
state_token_dim=self.args["model"]["state_token_dim"],
|
75 |
+
max_lang_cond_len=self.args["dataset"]["tokenizer_max_length"],
|
76 |
+
img_cond_len=img_cond_len,
|
77 |
+
img_pos_embed_config=[
|
78 |
+
# No initial pos embed in the last grid size
|
79 |
+
# since we've already done in ViT
|
80 |
+
("image", (self.args["common"]["img_history_size"],
|
81 |
+
self.args["common"]["num_cameras"],
|
82 |
+
-self.vision_model.num_patches)),
|
83 |
+
],
|
84 |
+
lang_pos_embed_config=[
|
85 |
+
# Similarly, no initial pos embed for language
|
86 |
+
("lang", -self.args["dataset"]["tokenizer_max_length"]),
|
87 |
+
],
|
88 |
+
dtype=self.dtype,
|
89 |
+
)
|
90 |
+
|
91 |
+
return _model
|
92 |
+
|
93 |
+
def get_text_encoder(self, pretrained_text_encoder_name_or_path):
|
94 |
+
text_embedder = T5Embedder(from_pretrained=pretrained_text_encoder_name_or_path,
|
95 |
+
model_max_length=self.args["dataset"]["tokenizer_max_length"],
|
96 |
+
device=self.device)
|
97 |
+
tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model
|
98 |
+
return tokenizer, text_encoder
|
99 |
+
|
100 |
+
def get_vision_encoder(self, pretrained_vision_encoder_name_or_path):
|
101 |
+
vision_encoder = SiglipVisionTower(vision_tower=pretrained_vision_encoder_name_or_path, args=None)
|
102 |
+
image_processor = vision_encoder.image_processor
|
103 |
+
return image_processor, vision_encoder
|
104 |
+
|
105 |
+
def reset(self):
|
106 |
+
"""Set model to evaluation mode.
|
107 |
+
"""
|
108 |
+
device = self.device
|
109 |
+
weight_dtype = self.dtype
|
110 |
+
self.policy.eval()
|
111 |
+
self.text_model.eval()
|
112 |
+
self.vision_model.eval()
|
113 |
+
|
114 |
+
self.policy = self.policy.to(device, dtype=weight_dtype)
|
115 |
+
self.text_model = self.text_model.to(device, dtype=weight_dtype)
|
116 |
+
self.vision_model = self.vision_model.to(device, dtype=weight_dtype)
|
117 |
+
|
118 |
+
def load_pretrained_weights(self, pretrained=None):
|
119 |
+
if pretrained is None:
|
120 |
+
return
|
121 |
+
print(f'Loading weights from {pretrained}')
|
122 |
+
filename = os.path.basename(pretrained)
|
123 |
+
if filename.endswith('.pt'):
|
124 |
+
checkpoint = torch.load(pretrained)
|
125 |
+
self.policy.load_state_dict(checkpoint["module"])
|
126 |
+
elif filename.endswith('.safetensors'):
|
127 |
+
from safetensors.torch import load_model
|
128 |
+
load_model(self.policy, pretrained)
|
129 |
+
else:
|
130 |
+
raise NotImplementedError(f"Unknown checkpoint format: {pretrained}")
|
131 |
+
|
132 |
+
def encode_instruction(self, instruction, device="cuda"):
|
133 |
+
"""Encode string instruction to latent embeddings.
|
134 |
+
|
135 |
+
Args:
|
136 |
+
instruction: a string containing the instruction
|
137 |
+
device: a string specifying the device
|
138 |
+
|
139 |
+
Returns:
|
140 |
+
pred: a tensor of latent embeddings of shape (text_max_length, 512)
|
141 |
+
"""
|
142 |
+
tokens = self.text_tokenizer(
|
143 |
+
instruction, return_tensors="pt",
|
144 |
+
padding="longest",
|
145 |
+
truncation=True
|
146 |
+
)["input_ids"].to(device)
|
147 |
+
|
148 |
+
tokens = tokens.view(1, -1)
|
149 |
+
with torch.no_grad():
|
150 |
+
pred = self.text_model(tokens).last_hidden_state.detach()
|
151 |
+
|
152 |
+
return pred
|
153 |
+
|
154 |
+
def _format_joint_to_state(self, joints):
|
155 |
+
"""
|
156 |
+
Format the robot joint state into the unified state vector.
|
157 |
+
|
158 |
+
Args:
|
159 |
+
joints (torch.Tensor): The joint state to be formatted.
|
160 |
+
qpos ([B, N, 14]).
|
161 |
+
|
162 |
+
Returns:
|
163 |
+
state (torch.Tensor): The formatted state for RDT ([B, N, 128]).
|
164 |
+
"""
|
165 |
+
# Rescale the gripper
|
166 |
+
# joints = joints / torch.tensor(
|
167 |
+
# [[[1, 1, 1, 1, 1, 1, 4.7908, 1, 1, 1, 1, 1, 1, 4.7888]]],
|
168 |
+
# device=joints.device, dtype=joints.dtype
|
169 |
+
# )
|
170 |
+
|
171 |
+
# normalize to -1,1
|
172 |
+
joints = (joints - self.state_min) / (self.state_max - self.state_min) * 2 - 1
|
173 |
+
B, N, _ = joints.shape
|
174 |
+
state = torch.zeros(
|
175 |
+
(B, N, self.args["model"]["state_token_dim"]),
|
176 |
+
device=joints.device, dtype=joints.dtype
|
177 |
+
)
|
178 |
+
# assemble the unified state vector
|
179 |
+
state[:, :, MANISKILL_INDICES] = joints
|
180 |
+
state_elem_mask = torch.zeros(
|
181 |
+
(B, self.args["model"]["state_token_dim"]),
|
182 |
+
device=joints.device, dtype=joints.dtype
|
183 |
+
)
|
184 |
+
state_elem_mask[:, MANISKILL_INDICES] = 1
|
185 |
+
return state, state_elem_mask
|
186 |
+
|
187 |
+
def _unformat_action_to_joint(self, action):
|
188 |
+
action_indices = MANISKILL_INDICES
|
189 |
+
joints = action[:, :, action_indices]
|
190 |
+
|
191 |
+
# denormalize to action space
|
192 |
+
|
193 |
+
joints = (joints + 1) / 2 * (self.action_max - self.action_min) + self.action_min
|
194 |
+
|
195 |
+
return joints
|
196 |
+
|
197 |
+
@torch.no_grad()
|
198 |
+
def step(self, proprio, images, text_embeds):
|
199 |
+
"""
|
200 |
+
Args:
|
201 |
+
proprio: proprioceptive states
|
202 |
+
images: RGB images
|
203 |
+
text_embeds: instruction embeddings
|
204 |
+
|
205 |
+
Returns:
|
206 |
+
action: predicted action
|
207 |
+
"""
|
208 |
+
device = self.device
|
209 |
+
dtype = self.dtype
|
210 |
+
|
211 |
+
background_color = np.array([
|
212 |
+
int(x*255) for x in self.image_processor.image_mean
|
213 |
+
], dtype=np.uint8).reshape(1, 1, 3)
|
214 |
+
background_image = np.ones((
|
215 |
+
self.image_processor.size["height"],
|
216 |
+
self.image_processor.size["width"], 3), dtype=np.uint8
|
217 |
+
) * background_color
|
218 |
+
|
219 |
+
image_tensor_list = []
|
220 |
+
for image in images:
|
221 |
+
if image is None:
|
222 |
+
# Replace it with the background image
|
223 |
+
image = Image.fromarray(background_image)
|
224 |
+
|
225 |
+
if self.image_size is not None:
|
226 |
+
image = transforms.Resize(self.image_size)(image)
|
227 |
+
|
228 |
+
if self.args["dataset"].get("auto_adjust_image_brightness", False):
|
229 |
+
pixel_values = list(image.getdata())
|
230 |
+
average_brightness = sum(sum(pixel) for pixel in pixel_values) / (len(pixel_values) * 255.0 * 3)
|
231 |
+
if average_brightness <= 0.15:
|
232 |
+
image = transforms.ColorJitter(brightness=(1.75,1.75))(image)
|
233 |
+
|
234 |
+
if self.args["dataset"].get("image_aspect_ratio", "pad") == 'pad':
|
235 |
+
def expand2square(pil_img, background_color):
|
236 |
+
width, height = pil_img.size
|
237 |
+
if width == height:
|
238 |
+
return pil_img
|
239 |
+
elif width > height:
|
240 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
241 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
242 |
+
return result
|
243 |
+
else:
|
244 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
245 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
246 |
+
return result
|
247 |
+
image = expand2square(image, tuple(int(x*255) for x in self.image_processor.image_mean))
|
248 |
+
image = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
249 |
+
image_tensor_list.append(image)
|
250 |
+
|
251 |
+
image_tensor = torch.stack(image_tensor_list, dim=0).to(device, dtype=dtype)
|
252 |
+
|
253 |
+
image_embeds = self.vision_model(image_tensor).detach()
|
254 |
+
image_embeds = image_embeds.reshape(-1, self.vision_model.hidden_size).unsqueeze(0)
|
255 |
+
|
256 |
+
# history of proprioceptive states
|
257 |
+
joints = proprio.to(device).unsqueeze(0) # (1, 1, 14)
|
258 |
+
states, state_elem_mask = self._format_joint_to_state(joints) # (1, 1, 128), (1, 128)
|
259 |
+
states, state_elem_mask = states.to(device, dtype=dtype), state_elem_mask.to(device, dtype=dtype)
|
260 |
+
states = states[:, -1:, :] # (1, 1, 128)
|
261 |
+
ctrl_freqs = torch.tensor([self.control_frequency]).to(device)
|
262 |
+
|
263 |
+
text_embeds = text_embeds.to(device, dtype=dtype)
|
264 |
+
|
265 |
+
trajectory = self.policy.predict_action(
|
266 |
+
lang_tokens=text_embeds,
|
267 |
+
lang_attn_mask=torch.ones(
|
268 |
+
text_embeds.shape[:2], dtype=torch.bool,
|
269 |
+
device=text_embeds.device),
|
270 |
+
img_tokens=image_embeds,
|
271 |
+
state_tokens=states,
|
272 |
+
action_mask=state_elem_mask.unsqueeze(1),
|
273 |
+
ctrl_freqs=ctrl_freqs
|
274 |
+
)
|
275 |
+
trajectory = self._unformat_action_to_joint(trajectory).to(torch.float32)
|
276 |
+
|
277 |
+
return trajectory
|
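Note: the wrapper above is meant to be driven by an external evaluation loop. A minimal inference sketch is given below; the construction of `policy` (an instance of the wrapper defined earlier in this file), the checkpoint path, and the `env` / `get_proprio` / `get_images` helpers are placeholders for illustration, not part of this repository.
# Hedged inference sketch for the ManiSkill policy wrapper above.
# `policy` is assumed to be an already-constructed instance of the wrapper;
# `env`, `get_proprio`, and `get_images` are hypothetical helpers.
policy.reset()                                                   # eval mode + move to device
policy.load_pretrained_weights("checkpoints/rdt_finetuned.pt")   # hypothetical path

text_embed = policy.encode_instruction("pick up the red cube")   # (1, L, hidden_dim)
proprio = get_proprio(env)        # torch.Tensor of shape (1, 14), per the comment in step()
images = get_images(env)          # list of PIL images (or None), one per camera/history slot

# step() returns a denormalized action chunk of shape (1, action_chunk_size, len(MANISKILL_INDICES))
action_chunk = policy.step(proprio, images, text_embed)
for action in action_chunk[0]:
    env.step(action.cpu().numpy())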
train/dataset.py
ADDED
@@ -0,0 +1,467 @@
1 |
+
import traceback
|
2 |
+
import time
|
3 |
+
import os
|
4 |
+
import json
|
5 |
+
import math
|
6 |
+
import random
|
7 |
+
from typing import Dict, Sequence
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
from torch.utils.data import Dataset
|
12 |
+
from torchvision import transforms
|
13 |
+
from PIL import Image
|
14 |
+
import transformers
|
15 |
+
|
16 |
+
from data.filelock import FileLock
|
17 |
+
from data.hdf5_vla_dataset import TabletopHDF5VLADataset, AnubisHDF5VLADataset
|
18 |
+
from train.image_corrupt import image_corrupt
|
19 |
+
|
20 |
+
|
21 |
+
def get_clean_item(chunk_dir):
|
22 |
+
"""
|
23 |
+
Get indexes of clean items in a chunk.
|
24 |
+
"""
|
25 |
+
dirty_bit = read_dirty_bit(chunk_dir)
|
26 |
+
return np.where(1 - dirty_bit)[0].tolist()
|
27 |
+
|
28 |
+
|
29 |
+
def save_dirty_bit(chunk_dir, dirty_bit):
|
30 |
+
"""
|
31 |
+
Save the dirty bit to the chunk directory.
|
32 |
+
"""
|
33 |
+
time_stmp = time.time()
|
34 |
+
while time.time() - time_stmp < 10.0:
|
35 |
+
try:
|
36 |
+
file_path = os.path.join(chunk_dir, "dirty_bit")
|
37 |
+
lock = FileLock(file_path)
|
38 |
+
lock.acquire_write_lock()
|
39 |
+
with open(file_path, 'wb') as file:
|
40 |
+
file.write(dirty_bit.tobytes())
|
41 |
+
lock.release_lock()
|
42 |
+
return
|
43 |
+
except KeyboardInterrupt:
|
44 |
+
lock.release_lock()
|
45 |
+
raise KeyboardInterrupt
|
46 |
+
except BaseException:
|
47 |
+
lock.release_lock()
|
48 |
+
continue
|
49 |
+
raise RuntimeError("Failed to save dirty bit.")
|
50 |
+
|
51 |
+
|
52 |
+
def read_dirty_bit(chunk_dir):
|
53 |
+
"""
|
54 |
+
Read the dirty bit from the chunk directory.
|
55 |
+
"""
|
56 |
+
# If error occurs, retry
|
57 |
+
time_stmp = time.time()
|
58 |
+
while time.time() - time_stmp < 10.0:
|
59 |
+
try:
|
60 |
+
file_path = os.path.join(chunk_dir, "dirty_bit")
|
61 |
+
lock = FileLock(file_path)
|
62 |
+
lock.acquire_read_lock()
|
63 |
+
with open(file_path, 'rb') as file:
|
64 |
+
dirty_bit = np.frombuffer(file.read(), dtype=np.uint8).copy()
|
65 |
+
lock.release_lock()
|
66 |
+
assert len(dirty_bit) > 0
|
67 |
+
return dirty_bit
|
68 |
+
except KeyboardInterrupt:
|
69 |
+
lock.release_lock()
|
70 |
+
raise KeyboardInterrupt
|
71 |
+
except BaseException:
|
72 |
+
lock.release_lock()
|
73 |
+
continue
|
74 |
+
raise RuntimeError("Failed to read dirty bit.")
|
75 |
+
|
76 |
+
|
77 |
+
class VLAConsumerDataset(Dataset):
|
78 |
+
"""A vision-languange-action Dataset for supervised training.
|
79 |
+
This dataset will load data from the buffer directory.
|
80 |
+
"""
|
81 |
+
|
82 |
+
def __init__(
|
83 |
+
self,
|
84 |
+
config,
|
85 |
+
tokenizer,
|
86 |
+
image_processor,
|
87 |
+
num_cameras,
|
88 |
+
img_history_size,
|
89 |
+
image_size=None,
|
90 |
+
auto_adjust_image_brightness=False,
|
91 |
+
image_aug=False,
|
92 |
+
dataset_type='pretrain',
|
93 |
+
cond_mask_prob=0.1,
|
94 |
+
cam_ext_mask_prob=-1.0,
|
95 |
+
state_noise_snr=None,
|
96 |
+
use_hdf5=False,
|
97 |
+
use_precomp_lang_embed=False,
|
98 |
+
task_name=None
|
99 |
+
):
|
100 |
+
super(VLAConsumerDataset, self).__init__()
|
101 |
+
|
102 |
+
# Load the control frequency for each dataset
|
103 |
+
with open("configs/dataset_control_freq.json", 'r') as fp:
|
104 |
+
self.control_freq = json.load(fp)
|
105 |
+
# Load the dataset names
|
106 |
+
dataset_names_cfg = 'configs/pretrain_datasets.json' \
|
107 |
+
if dataset_type == 'pretrain' else 'configs/finetune_datasets.json'
|
108 |
+
with open(dataset_names_cfg, 'r') as file:
|
109 |
+
DATASET_NAMES = json.load(file)
|
110 |
+
# Create the mapping between dataset name and id
|
111 |
+
# self.dataset_name2id = {name: i for i, name in enumerate(DATASET_NAMES)}
|
112 |
+
# self.dataset_id2name = {i: name for i, name in enumerate(DATASET_NAMES)}
|
113 |
+
self.dataset_name2id = {task_name: 0}
|
114 |
+
self.dataset_id2name = {0: task_name}
|
115 |
+
|
116 |
+
self.image_processor = image_processor
|
117 |
+
|
118 |
+
self.buffer_dir = config["buf_path"]
|
119 |
+
self.num_chunks = config["buf_num_chunks"]
|
120 |
+
self.chunk_size = config["buf_chunk_size"]
|
121 |
+
self.tokenizer_max_length = config["tokenizer_max_length"]
|
122 |
+
self.image_aspect_ratio = config["image_aspect_ratio"]
|
123 |
+
self.state_noise_snr = state_noise_snr
|
124 |
+
self.num_cameras = num_cameras
|
125 |
+
self.img_history_size = img_history_size
|
126 |
+
self.cond_mask_prob = cond_mask_prob
|
127 |
+
self.cam_ext_mask_prob = cam_ext_mask_prob
|
128 |
+
self.use_hdf5 = use_hdf5
|
129 |
+
self.hdf5_dataset = None
|
130 |
+
if use_hdf5:
|
131 |
+
self.hdf5_dataset = AnubisHDF5VLADataset(task_name)
|
132 |
+
self.use_precomp_lang_embed = use_precomp_lang_embed
|
133 |
+
if use_precomp_lang_embed:
|
134 |
+
self.empty_lang_embed = torch.load("data/empty_lang_embed.pt")
|
135 |
+
|
136 |
+
# Load dataset stat
|
137 |
+
with open("configs/dataset_stat.json", 'r') as f:
|
138 |
+
dataset_stat = json.load(f)
|
139 |
+
self.dataset_stat = dataset_stat
|
140 |
+
|
141 |
+
self.tokenizer = tokenizer
|
142 |
+
self.image_size = image_size
|
143 |
+
self.auto_adjust_image_brightness = auto_adjust_image_brightness
|
144 |
+
self.image_aug = image_aug
|
145 |
+
|
146 |
+
self.last_content = None
|
147 |
+
self.last_meta = None
|
148 |
+
|
149 |
+
def get_dataset_name2id(self):
|
150 |
+
return self.dataset_name2id
|
151 |
+
|
152 |
+
def get_dataset_id2name(self):
|
153 |
+
return self.dataset_id2name
|
154 |
+
|
155 |
+
@staticmethod
|
156 |
+
def pairwise(iterable):
|
157 |
+
a = iter(iterable)
|
158 |
+
return zip(a, a)
|
159 |
+
|
160 |
+
@staticmethod
|
161 |
+
def _load_data_from_chunk(chunk_dir, chunk_item_idx):
|
162 |
+
# If error occurs, retry
|
163 |
+
time_stmp = time.time()
|
164 |
+
while time.time() - time_stmp < 10.0:
|
165 |
+
try:
|
166 |
+
locks = []
|
167 |
+
file_path = os.path.join(chunk_dir, f"json_content_{chunk_item_idx}.json")
|
168 |
+
lock = FileLock(file_path)
|
169 |
+
locks.append(lock)
|
170 |
+
lock.acquire_read_lock()
|
171 |
+
with open(file_path, 'r') as file:
|
172 |
+
json_content = json.load(file)
|
173 |
+
lock.release_lock()
|
174 |
+
file_path = os.path.join(chunk_dir, f"sample_{chunk_item_idx}.npz")
|
175 |
+
lock = FileLock(file_path)
|
176 |
+
locks.append(lock)
|
177 |
+
lock.acquire_read_lock()
|
178 |
+
with open(file_path, 'rb') as file:
|
179 |
+
sample_dict = np.load(file)
|
180 |
+
meta = tuple(sample_dict.values())
|
181 |
+
lock.release_lock()
|
182 |
+
return json_content, meta
|
183 |
+
except KeyboardInterrupt:
|
184 |
+
for lock in locks:
|
185 |
+
lock.release_lock()
|
186 |
+
raise KeyboardInterrupt
|
187 |
+
except BaseException:
|
188 |
+
for lock in locks:
|
189 |
+
lock.release_lock()
|
190 |
+
continue
|
191 |
+
raise RuntimeError("Failed to load sample.")
|
192 |
+
|
193 |
+
def __len__(self) -> int:
|
194 |
+
if self.use_hdf5:
|
195 |
+
return len(self.hdf5_dataset)
|
196 |
+
else:
|
197 |
+
return self.num_chunks * self.chunk_size
|
198 |
+
|
199 |
+
def _safe_load(self, index):
|
200 |
+
read_chunk_item_indices = []
|
201 |
+
# Start searching from the chunk that this index maps to
|
202 |
+
read_chunk_idx = index // self.chunk_size
|
203 |
+
while len(read_chunk_item_indices) == 0:
|
204 |
+
read_chunk_dir = os.path.join(self.buffer_dir, f"chunk_{read_chunk_idx}")
|
205 |
+
try:
|
206 |
+
read_chunk_item_indices = get_clean_item(read_chunk_dir)
|
207 |
+
except BaseException as e:
|
208 |
+
# Print the error info
|
209 |
+
print("Error catched when searching a clean chunk:", e)
|
210 |
+
traceback.print_exc()
|
211 |
+
read_chunk_item_indices = []
|
212 |
+
read_chunk_idx = (read_chunk_idx + 1) % self.num_chunks
|
213 |
+
|
214 |
+
# read_chunk_item_index = random.choice(read_chunk_item_indices)
|
215 |
+
# read_chunk_item_index = read_chunk_item_indices.pop()
|
216 |
+
random_item_index = index % len(read_chunk_item_indices)
|
217 |
+
read_chunk_item_index = read_chunk_item_indices[random_item_index]
|
218 |
+
|
219 |
+
# Modify the dirty bit
|
220 |
+
try:
|
221 |
+
dirty_bit = read_dirty_bit(read_chunk_dir)
|
222 |
+
dirty_bit[read_chunk_item_index] = 1
|
223 |
+
save_dirty_bit(read_chunk_dir, dirty_bit)
|
224 |
+
except BaseException as e:
|
225 |
+
# Print the error info
|
226 |
+
print("Error catched when modifying the dirty bit:", e)
|
227 |
+
traceback.print_exc()
|
228 |
+
|
229 |
+
# load the sample
|
230 |
+
try:
|
231 |
+
content, meta = self._load_data_from_chunk(read_chunk_dir, read_chunk_item_index)
|
232 |
+
self.last_content, self.last_meta = content, meta
|
233 |
+
except BaseException as e:
|
234 |
+
# Print the error info
|
235 |
+
print("Error catched when loading sample:", e)
|
236 |
+
traceback.print_exc()
|
237 |
+
|
238 |
+
# If failed to load the data, return the last loaded data for robustness
|
239 |
+
content, meta = self.last_content, self.last_meta
|
240 |
+
|
241 |
+
return (content, *meta)
|
242 |
+
|
243 |
+
def __getitem__(self, index):
|
244 |
+
# For robustness, we will try to load the data until we succeed
|
245 |
+
while True:
|
246 |
+
data_dict = None
|
247 |
+
try:
|
248 |
+
if self.use_hdf5:
|
249 |
+
res = self.hdf5_dataset.get_item()
|
250 |
+
content = res['meta']
|
251 |
+
states = res['state']
|
252 |
+
actions = res['actions']
|
253 |
+
state_elem_mask = res['state_indicator']
|
254 |
+
image_metas = [
|
255 |
+
res['cam_high'], res['cam_high_mask'],
|
256 |
+
res['cam_right_wrist'], res['cam_right_wrist_mask'],
|
257 |
+
res['cam_left_wrist'], res['cam_left_wrist_mask'],
|
258 |
+
]
|
259 |
+
state_std = res['state_std']
|
260 |
+
state_mean = res['state_mean']
|
261 |
+
state_norm = res['state_norm']
|
262 |
+
else:
|
263 |
+
(content, _, states, _, actions, _,
|
264 |
+
state_elem_mask, *image_metas,
|
265 |
+
state_std, state_mean, state_norm) = self._safe_load(index)
|
266 |
+
|
267 |
+
data_dict = {}
|
268 |
+
data_dict['dataset_name'] = content['dataset_name']
|
269 |
+
data_dict['data_idx'] = self.dataset_name2id[data_dict['dataset_name']]
|
270 |
+
data_dict['ctrl_freq'] = self.control_freq[data_dict['dataset_name']] \
|
271 |
+
if random.random() > self.cond_mask_prob else 0
|
272 |
+
|
273 |
+
if self.state_noise_snr is not None:
|
274 |
+
states += np.random.normal(
|
275 |
+
0.0, state_std / np.sqrt(10 ** (self.state_noise_snr / 10)),
|
276 |
+
states.shape)
|
277 |
+
ds_state_mean = np.array(self.dataset_stat[data_dict['dataset_name']]['state_mean'])
|
278 |
+
ds_state_mean = np.tile(ds_state_mean[None], (states.shape[0], 1))
|
279 |
+
# Randomly mask the states by the mean state
|
280 |
+
data_dict["states"] = states \
|
281 |
+
if random.random() > self.cond_mask_prob else ds_state_mean
|
282 |
+
data_dict["actions"] = actions
|
283 |
+
data_dict["state_elem_mask"] = state_elem_mask \
|
284 |
+
if random.random() > self.cond_mask_prob else np.zeros_like(state_elem_mask)
|
285 |
+
|
286 |
+
# Stat for the episode that the step belongs to
|
287 |
+
data_dict["state_norm"] = state_norm
|
288 |
+
|
289 |
+
# We replace the invalid images with the background image
|
290 |
+
# and also randomly mask images by the background image
|
291 |
+
background_color = np.array([
|
292 |
+
int(x*255) for x in self.image_processor.image_mean
|
293 |
+
], dtype=np.uint8).reshape(1, 1, 3)
|
294 |
+
background_image = np.ones((
|
295 |
+
self.image_processor.size["height"],
|
296 |
+
self.image_processor.size["width"], 3), dtype=np.uint8
|
297 |
+
) * background_color
|
298 |
+
|
299 |
+
image_metas = list(self.pairwise(image_metas))
|
300 |
+
mask_probs = [self.cond_mask_prob] * self.num_cameras
|
301 |
+
if self.cam_ext_mask_prob >= 0.0:
|
302 |
+
mask_probs[0] = self.cam_ext_mask_prob
|
303 |
+
rearranged_images = []
|
304 |
+
for i in range(self.img_history_size):
|
305 |
+
for j in range(self.num_cameras):
|
306 |
+
images, image_mask = image_metas[j]
|
307 |
+
image, valid = images[i], image_mask[i]
|
308 |
+
if valid and (math.prod(image.shape) > 0) and \
|
309 |
+
(random.random() > mask_probs[j]):
|
310 |
+
rearranged_images.append((image, True))
|
311 |
+
else:
|
312 |
+
rearranged_images.append((background_image.copy(), False))
|
313 |
+
|
314 |
+
preprocessed_images = []
|
315 |
+
processor = self.image_processor
|
316 |
+
for image, valid in rearranged_images:
|
317 |
+
image = Image.fromarray(image)
|
318 |
+
if self.image_size is not None:
|
319 |
+
image = transforms.Resize(self.image_size)(image) # (1008, 336)
|
320 |
+
# assert image.height == 336, "We haven't prepared for training with images of different resolutions."
|
321 |
+
|
322 |
+
if valid and self.auto_adjust_image_brightness:
|
323 |
+
pixel_values = list(image.getdata())
|
324 |
+
average_brightness = sum(sum(pixel) for pixel in pixel_values) / (len(pixel_values) * 255.0 * 3)
|
325 |
+
if average_brightness <= 0.15:
|
326 |
+
image = transforms.ColorJitter(brightness=(1.75,1.75))(image)
|
327 |
+
|
328 |
+
# Only apply image augmentation to 50% of the images
|
329 |
+
if valid and self.image_aug and (random.random() > 0.5):
|
330 |
+
aug_type = random.choice([
|
331 |
+
"corrput_only", "color_only", "both"])
|
332 |
+
if aug_type != "corrput_only":
|
333 |
+
image = transforms.ColorJitter(
|
334 |
+
brightness=0.3, contrast=0.4, saturation=0.5, hue=0.03)(image)
|
335 |
+
if aug_type != "color_only":
|
336 |
+
image = image_corrupt(image)
|
337 |
+
|
338 |
+
if self.image_aspect_ratio == 'pad':
|
339 |
+
def expand2square(pil_img, background_color):
|
340 |
+
width, height = pil_img.size
|
341 |
+
if width == height:
|
342 |
+
return pil_img
|
343 |
+
elif width > height:
|
344 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
345 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
346 |
+
return result
|
347 |
+
else:
|
348 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
349 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
350 |
+
return result
|
351 |
+
image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
|
352 |
+
image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
353 |
+
preprocessed_images.append(image)
|
354 |
+
data_dict["images"] = preprocessed_images
|
355 |
+
|
356 |
+
if self.use_precomp_lang_embed:
|
357 |
+
if content["instruction"][-1] == ".":
|
358 |
+
content["instruction"] = content["instruction"][:-1]
|
359 |
+
data_dict["lang_embed"] = torch.load(content["instruction"])['embeddings'][0] \
|
360 |
+
if random.random() > self.cond_mask_prob else self.empty_lang_embed ##FIXED
|
361 |
+
else:
|
362 |
+
instruction = content["instruction"] \
|
363 |
+
if random.random() > self.cond_mask_prob else ""
|
364 |
+
data_dict["input_ids"] = self.tokenizer(
|
365 |
+
instruction,
|
366 |
+
return_tensors="pt",
|
367 |
+
padding="longest",
|
368 |
+
truncation=False,
|
369 |
+
).input_ids[0]
|
370 |
+
|
371 |
+
assert len(data_dict["input_ids"]) <= self.tokenizer_max_length, \
|
372 |
+
f"Instruction length {len(data_dict['input_ids'])} exceeds the maximum length {self.tokenizer_max_length}."
|
373 |
+
|
374 |
+
for k, v in data_dict.items():
|
375 |
+
if isinstance(v, np.ndarray):
|
376 |
+
data_dict[k] = torch.from_numpy(v)
|
377 |
+
|
378 |
+
for k, v in data_dict.items():
|
379 |
+
assert not isinstance(v, np.ndarray), f"key: {k}, value: {v}"
|
380 |
+
# data_dict[k] = torch.from_numpy(v)
|
381 |
+
|
382 |
+
return data_dict
|
383 |
+
except BaseException as e:
|
384 |
+
# Print the error info
|
385 |
+
if data_dict is not None:
|
386 |
+
print(f"Error catched when processing sample from {data_dict.get('dataset_name')}:", e)
|
387 |
+
else:
|
388 |
+
print(f"Error catched when processing sample:", e)
|
389 |
+
traceback.print_exc()
|
390 |
+
# Try increasing the index
|
391 |
+
index = (index + 1) % len(self)
|
392 |
+
|
393 |
+
|
394 |
+
class DataCollatorForVLAConsumerDataset(object):
|
395 |
+
"""Collate examples for supervised training."""
|
396 |
+
|
397 |
+
def __init__(self, tokenizer: transformers.PreTrainedTokenizer) -> None:
|
398 |
+
self.tokenizer = tokenizer
|
399 |
+
|
400 |
+
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
401 |
+
batch = {
|
402 |
+
"states": [],
|
403 |
+
"actions": [],
|
404 |
+
"state_elem_mask": [],
|
405 |
+
"state_norm": [],
|
406 |
+
"images": [],
|
407 |
+
"data_indices": [],
|
408 |
+
"ctrl_freqs": []
|
409 |
+
}
|
410 |
+
input_ids = []
|
411 |
+
lang_embeds = []
|
412 |
+
lang_embed_lens = []
|
413 |
+
|
414 |
+
for instance in instances:
|
415 |
+
# Convert all the numpy arrays to tensor
|
416 |
+
keys_to_check = [
|
417 |
+
'states', 'actions',
|
418 |
+
'state_elem_mask', 'state_norm',
|
419 |
+
]
|
420 |
+
for key in keys_to_check:
|
421 |
+
if isinstance(instance[key], torch.Tensor):
|
422 |
+
item = instance[key]
|
423 |
+
else:
|
424 |
+
item = torch.from_numpy(instance[key])
|
425 |
+
batch[key].append(item)
|
426 |
+
|
427 |
+
if "input_ids" in instance:
|
428 |
+
input_ids.append(instance["input_ids"])
|
429 |
+
else:
|
430 |
+
lang_embeds.append(instance["lang_embed"])
|
431 |
+
lang_embed_lens.append(instance["lang_embed"].shape[0])
|
432 |
+
|
433 |
+
batch["images"].append(torch.stack(instance["images"], dim=0))
|
434 |
+
batch["data_indices"].append(instance["data_idx"])
|
435 |
+
batch["ctrl_freqs"].append(instance["ctrl_freq"])
|
436 |
+
|
437 |
+
keys_to_stack = [
|
438 |
+
'states', 'actions',
|
439 |
+
'state_elem_mask', 'state_norm',
|
440 |
+
"images"
|
441 |
+
]
|
442 |
+
for key in keys_to_stack:
|
443 |
+
batch[key] = torch.stack(batch[key], dim=0)
|
444 |
+
|
445 |
+
batch["ctrl_freqs"] = torch.tensor(batch["ctrl_freqs"])
|
446 |
+
|
447 |
+
if len(input_ids) > 0:
|
448 |
+
input_ids = torch.nn.utils.rnn.pad_sequence(
|
449 |
+
input_ids,
|
450 |
+
batch_first=True,
|
451 |
+
padding_value=self.tokenizer.pad_token_id)
|
452 |
+
batch["input_ids"] = input_ids
|
453 |
+
batch["lang_attn_mask"] = input_ids.ne(self.tokenizer.pad_token_id)
|
454 |
+
else:
|
455 |
+
lang_embeds = torch.nn.utils.rnn.pad_sequence(
|
456 |
+
lang_embeds,
|
457 |
+
batch_first=True,
|
458 |
+
padding_value=0)
|
459 |
+
input_lang_attn_mask = torch.zeros(
|
460 |
+
lang_embeds.shape[0], lang_embeds.shape[1], dtype=torch.bool)
|
461 |
+
for i, l in enumerate(lang_embed_lens):
|
462 |
+
input_lang_attn_mask[i, :l] = True
|
463 |
+
batch["lang_embeds"] = lang_embeds
|
464 |
+
batch["lang_attn_mask"] = input_lang_attn_mask
|
465 |
+
|
466 |
+
|
467 |
+
return batch
|
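For reference, a minimal sketch of how VLAConsumerDataset and its collator are wired into a PyTorch DataLoader (this mirrors what train/train.py does below; the SigLIP checkpoint name and the task name are assumptions for illustration):
# Minimal wiring sketch, assuming precomputed language embeddings are used
# (so no tokenizer is needed); checkpoint and task names are placeholders.
import yaml
import torch
from models.multimodal_encoder.siglip_encoder import SiglipVisionTower
from train.dataset import VLAConsumerDataset, DataCollatorForVLAConsumerDataset

with open("configs/base.yaml", "r") as fp:
    config = yaml.safe_load(fp)

vision_encoder = SiglipVisionTower(vision_tower="google/siglip-so400m-patch14-384", args=None)  # checkpoint name is an assumption

dataset = VLAConsumerDataset(
    config=config["dataset"],
    tokenizer=None,                                   # unused with precomputed embeddings
    image_processor=vision_encoder.image_processor,
    num_cameras=config["common"]["num_cameras"],
    img_history_size=config["common"]["img_history_size"],
    use_hdf5=True,
    use_precomp_lang_embed=True,
    task_name="my_task",                              # hypothetical task name
)
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=DataCollatorForVLAConsumerDataset(tokenizer=None),
    num_workers=8,
)
batch = next(iter(loader))
# batch["images"]: (B, img_history_size * num_cameras, C, H, W)
# batch["lang_embeds"] / batch["lang_attn_mask"]: padded per batch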
train/image_corrupt.py
ADDED
@@ -0,0 +1,44 @@
1 |
+
import warnings
|
2 |
+
warnings.simplefilter(action='ignore', category=FutureWarning)
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
np.bool = np.bool_
|
6 |
+
import imgaug.augmenters as iaa
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
|
10 |
+
# Define our sequence of augmentation steps that will be applied to every image.
|
11 |
+
seq = iaa.Sequential(
|
12 |
+
[
|
13 |
+
# Execute one of the following noise augmentations
|
14 |
+
iaa.OneOf([
|
15 |
+
iaa.AdditiveGaussianNoise(
|
16 |
+
loc=0, scale=(0.0, 0.05*255), per_channel=0.5
|
17 |
+
),
|
18 |
+
iaa.AdditiveLaplaceNoise(scale=(0.0, 0.05*255), per_channel=0.5),
|
19 |
+
iaa.AdditivePoissonNoise(lam=(0.0, 0.05*255), per_channel=0.5)
|
20 |
+
]),
|
21 |
+
|
22 |
+
# Execute one or none of the following blur augmentations
|
23 |
+
iaa.SomeOf((0, 1), [
|
24 |
+
iaa.OneOf([
|
25 |
+
iaa.GaussianBlur((0, 3.0)),
|
26 |
+
iaa.AverageBlur(k=(2, 7)),
|
27 |
+
iaa.MedianBlur(k=(3, 11)),
|
28 |
+
]),
|
29 |
+
iaa.MotionBlur(k=(3, 36)),
|
30 |
+
]),
|
31 |
+
],
|
32 |
+
# do all of the above augmentations in random order
|
33 |
+
random_order=True
|
34 |
+
)
|
35 |
+
|
36 |
+
|
37 |
+
def image_corrupt(image: Image):
|
38 |
+
image_arr = np.array(image)
|
39 |
+
image_arr = image_arr[None, ...]
|
40 |
+
|
41 |
+
image_arr = seq(images=image_arr)
|
42 |
+
|
43 |
+
image = Image.fromarray(image_arr[0])
|
44 |
+
return image
|
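A quick usage sketch for image_corrupt on a synthetic image (the output path is arbitrary):
# Corrupt a synthetic RGB image with the augmentation pipeline above.
import numpy as np
from PIL import Image
from train.image_corrupt import image_corrupt

rgb = (np.random.rand(224, 224, 3) * 255).astype(np.uint8)
corrupted = image_corrupt(Image.fromarray(rgb))   # same size, with random noise and/or blur applied
corrupted.save("corrupted_example.png")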
train/sample.py
ADDED
@@ -0,0 +1,99 @@
1 |
+
from collections import defaultdict
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
|
7 |
+
@torch.no_grad()
|
8 |
+
def log_sample_res(
|
9 |
+
text_encoder, vision_encoder, rdt, args,
|
10 |
+
accelerator, weight_dtype, dataset_id2name, dataloader, logger
|
11 |
+
):
|
12 |
+
logger.info(
|
13 |
+
f"Running sampling for {args.num_sample_batches} batches..."
|
14 |
+
)
|
15 |
+
|
16 |
+
rdt.eval()
|
17 |
+
|
18 |
+
loss_for_log = defaultdict(float)
|
19 |
+
loss_counter = defaultdict(int)
|
20 |
+
for step, batch in enumerate(dataloader):
|
21 |
+
if step >= args.num_sample_batches:
|
22 |
+
break
|
23 |
+
|
24 |
+
data_indices = batch["data_indices"]
|
25 |
+
ctrl_freqs = batch["ctrl_freqs"]
|
26 |
+
state_norm = batch["state_norm"].to(dtype=weight_dtype)
|
27 |
+
images = batch["images"].to(dtype=weight_dtype)
|
28 |
+
states = batch["states"].to(dtype=weight_dtype)
|
29 |
+
# We only use the last state as input
|
30 |
+
states = states[:, -1:, :]
|
31 |
+
actions = batch["actions"].to(dtype=weight_dtype)
|
32 |
+
state_elem_mask = batch["state_elem_mask"].to(dtype=weight_dtype)
|
33 |
+
|
34 |
+
batch_size, _, C, H, W = images.shape
|
35 |
+
image_embeds = vision_encoder(images.reshape(-1, C, H, W)).detach()
|
36 |
+
image_embeds = image_embeds.reshape((batch_size, -1, vision_encoder.hidden_size))
|
37 |
+
|
38 |
+
lang_attn_mask = batch["lang_attn_mask"]
|
39 |
+
text_embeds = batch["lang_embeds"].to(dtype=weight_dtype) \
|
40 |
+
if args.precomp_lang_embed \
|
41 |
+
else text_encoder(
|
42 |
+
input_ids=batch["input_ids"],
|
43 |
+
attention_mask=lang_attn_mask
|
44 |
+
)["last_hidden_state"].detach()
|
45 |
+
|
46 |
+
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
|
47 |
+
pred_actions = rdt.predict_action(
|
48 |
+
lang_tokens=text_embeds,
|
49 |
+
lang_attn_mask=lang_attn_mask,
|
50 |
+
img_tokens=image_embeds,
|
51 |
+
state_tokens=states,
|
52 |
+
action_mask=state_elem_mask.unsqueeze(1),
|
53 |
+
ctrl_freqs=ctrl_freqs
|
54 |
+
)
|
55 |
+
|
56 |
+
num_steps = pred_actions.shape[1]
|
57 |
+
expanded_state_elem_mask = state_elem_mask.unsqueeze(1).tile((1, num_steps, 1)).float()
|
58 |
+
expanded_state_norm = state_norm.unsqueeze(1).tile((1, num_steps, 1)).float()
|
59 |
+
|
60 |
+
loss = F.mse_loss(pred_actions, actions, reduction='none').float()
|
61 |
+
|
62 |
+
mse_loss_per_entry = ((loss * expanded_state_elem_mask).reshape((batch_size, -1)).sum(1)
|
63 |
+
/ expanded_state_elem_mask.reshape((batch_size, -1)).sum(1))
|
64 |
+
l2_loss_per_entry = loss.sqrt() / (expanded_state_norm + 1e-3)
|
65 |
+
l2_loss_per_entry = ((l2_loss_per_entry * expanded_state_elem_mask).reshape((batch_size, -1)).sum(1)
|
66 |
+
/ expanded_state_elem_mask.reshape((batch_size, -1)).sum(1))
|
67 |
+
|
68 |
+
dataset_indices, mse_losses, l2_losses = accelerator.gather_for_metrics(
|
69 |
+
(torch.LongTensor(data_indices).to(device=pred_actions.device),
|
70 |
+
mse_loss_per_entry, l2_loss_per_entry),
|
71 |
+
)
|
72 |
+
dataset_indices = dataset_indices.tolist()
|
73 |
+
if accelerator.is_main_process:
|
74 |
+
for loss_suffix, losses in zip(["_sample_mse", "_sample_l2err"], [mse_losses, l2_losses]):
|
75 |
+
for dataset_idx, loss_tensor in zip(dataset_indices, losses):
|
76 |
+
loss_name = dataset_id2name[dataset_idx] + loss_suffix
|
77 |
+
loss_for_log[loss_name] += loss_tensor.item()
|
78 |
+
loss_counter[loss_name] += 1
|
79 |
+
|
80 |
+
mse_loss = (loss * expanded_state_elem_mask).sum() / expanded_state_elem_mask.sum()
|
81 |
+
mse_loss_scaler = accelerator.gather(mse_loss).mean().item()
|
82 |
+
loss_for_log["overall_avg_sample_mse"] += mse_loss_scaler
|
83 |
+
|
84 |
+
l2_loss = loss.sqrt() / (expanded_state_norm + 1e-3)
|
85 |
+
l2_loss = (l2_loss * expanded_state_elem_mask).sum() / expanded_state_elem_mask.sum()
|
86 |
+
l2_loss_scaler = accelerator.gather(l2_loss).mean().item()
|
87 |
+
loss_for_log["overall_avg_sample_l2err"] += l2_loss_scaler
|
88 |
+
|
89 |
+
for name in loss_for_log:
|
90 |
+
if name in ["overall_avg_sample_mse", "overall_avg_sample_l2err"]:
|
91 |
+
loss_scaler = loss_for_log[name]
|
92 |
+
loss_for_log[name] = round(loss_scaler / (args.num_sample_batches), 4)
|
93 |
+
else:
|
94 |
+
loss_for_log[name] = round(loss_for_log[name] / loss_counter[name], 4)
|
95 |
+
|
96 |
+
rdt.train()
|
97 |
+
torch.cuda.empty_cache()
|
98 |
+
|
99 |
+
return dict(loss_for_log)
|
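The masked reductions in log_sample_res can be sanity-checked in isolation; a toy example with made-up shapes (B=2, horizon=4, state_dim=128, 14 valid dimensions):
# Toy check of the masked per-sample MSE used above; all numbers are illustrative.
import torch
import torch.nn.functional as F

B, T, D = 2, 4, 128
pred, gt = torch.randn(B, T, D), torch.randn(B, T, D)
elem_mask = torch.zeros(B, D)
elem_mask[:, :14] = 1                      # pretend only 14 dimensions are valid

loss = F.mse_loss(pred, gt, reduction='none')
mask = elem_mask.unsqueeze(1).tile((1, T, 1))
per_sample_mse = (loss * mask).reshape(B, -1).sum(1) / mask.reshape(B, -1).sum(1)
print(per_sample_mse.shape)                # torch.Size([2])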
train/train.py
ADDED
@@ -0,0 +1,509 @@
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
|
16 |
+
import copy
|
17 |
+
import logging
|
18 |
+
import math
|
19 |
+
import os
|
20 |
+
from pathlib import Path
|
21 |
+
|
22 |
+
import diffusers
|
23 |
+
import torch
|
24 |
+
import torch.utils.checkpoint
|
25 |
+
import transformers
|
26 |
+
import yaml
|
27 |
+
from accelerate import Accelerator
|
28 |
+
from accelerate.utils import DeepSpeedPlugin, ProjectConfiguration, set_seed
|
29 |
+
from diffusers.optimization import get_scheduler
|
30 |
+
from diffusers.utils import is_wandb_available
|
31 |
+
from huggingface_hub import create_repo, upload_folder
|
32 |
+
from tqdm.auto import tqdm
|
33 |
+
from safetensors.torch import load_model
|
34 |
+
|
35 |
+
from models.ema_model import EMAModel
|
36 |
+
from models.multimodal_encoder.siglip_encoder import SiglipVisionTower
|
37 |
+
from models.multimodal_encoder.t5_encoder import T5Embedder
|
38 |
+
from models.rdt_runner import RDTRunner
|
39 |
+
from train.dataset import DataCollatorForVLAConsumerDataset, VLAConsumerDataset
|
40 |
+
from train.sample import log_sample_res
|
41 |
+
|
42 |
+
|
43 |
+
if is_wandb_available():
|
44 |
+
import wandb
|
45 |
+
|
46 |
+
|
47 |
+
def save_model_card(repo_id: str, base_model: str = None, repo_folder=None):
|
48 |
+
yaml = f"""
|
49 |
+
---
|
50 |
+
license: mit
|
51 |
+
base_model: {base_model}
|
52 |
+
language:
|
53 |
+
- en
|
54 |
+
pipeline_tag: robotics
|
55 |
+
library_name: transformers
|
56 |
+
tags:
|
57 |
+
- robotics
|
58 |
+
- pytorch
|
59 |
+
- multimodal
|
60 |
+
- pretraining
|
61 |
+
- vla
|
62 |
+
- diffusion
|
63 |
+
- rdt
|
64 |
+
---
|
65 |
+
"""
|
66 |
+
model_card = f"""
|
67 |
+
# RDT - {repo_id}
|
68 |
+
|
69 |
+
This is a RDT model derived from {base_model}. The weights were trained using [RDT](https://rdt-robotics.github.io/rdt-robotics/).
|
70 |
+
"""
|
71 |
+
with open(os.path.join(repo_folder, "README.md"), "w") as f:
|
72 |
+
f.write(yaml + model_card)
|
73 |
+
|
74 |
+
|
75 |
+
def train(args, logger):
|
76 |
+
# Read the config
|
77 |
+
with open(args.config_path, "r") as fp:
|
78 |
+
config = yaml.safe_load(fp)
|
79 |
+
|
80 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
81 |
+
|
82 |
+
accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)
|
83 |
+
accelerator = Accelerator(
|
84 |
+
deepspeed_plugin=DeepSpeedPlugin(
|
85 |
+
hf_ds_config=args.deepspeed
|
86 |
+
) if args.deepspeed is not None else None,
|
87 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
88 |
+
mixed_precision=args.mixed_precision,
|
89 |
+
log_with=args.report_to,
|
90 |
+
project_dir=logging_dir,
|
91 |
+
project_config=accelerator_project_config,
|
92 |
+
)
|
93 |
+
|
94 |
+
if args.report_to == "wandb":
|
95 |
+
if not is_wandb_available():
|
96 |
+
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
97 |
+
|
98 |
+
# Make one log on every process with the configuration for debugging.
|
99 |
+
logging.basicConfig(
|
100 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
101 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
102 |
+
level=logging.INFO,
|
103 |
+
)
|
104 |
+
logger.info(accelerator.state, main_process_only=False)
|
105 |
+
if accelerator.is_local_main_process:
|
106 |
+
transformers.utils.logging.set_verbosity_warning()
|
107 |
+
diffusers.utils.logging.set_verbosity_info()
|
108 |
+
else:
|
109 |
+
transformers.utils.logging.set_verbosity_error()
|
110 |
+
diffusers.utils.logging.set_verbosity_error()
|
111 |
+
|
112 |
+
# If passed along, set the training seed now.
|
113 |
+
if args.seed is not None:
|
114 |
+
set_seed(args.seed)
|
115 |
+
|
116 |
+
# Handle the repository creation
|
117 |
+
if accelerator.is_main_process:
|
118 |
+
if args.output_dir is not None:
|
119 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
120 |
+
|
121 |
+
if args.push_to_hub:
|
122 |
+
repo_id = create_repo(
|
123 |
+
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
|
124 |
+
).repo_id
|
125 |
+
|
126 |
+
# For mixed precision training we cast the text_encoder and vae weights to half-precision
|
127 |
+
# as these models are only used for inference, keeping weights in full precision is not required.
|
128 |
+
weight_dtype = torch.float32
|
129 |
+
if accelerator.mixed_precision == "fp16":
|
130 |
+
weight_dtype = torch.float16
|
131 |
+
elif accelerator.mixed_precision == "bf16":
|
132 |
+
weight_dtype = torch.bfloat16
|
133 |
+
|
134 |
+
if args.precomp_lang_embed:
|
135 |
+
tokenizer, text_encoder = None, None
|
136 |
+
else:
|
137 |
+
text_embedder = T5Embedder(from_pretrained=args.pretrained_text_encoder_name_or_path,
|
138 |
+
model_max_length=config["dataset"]["tokenizer_max_length"], device=accelerator.device)
|
139 |
+
tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model
|
140 |
+
|
141 |
+
vision_encoder = SiglipVisionTower(vision_tower=args.pretrained_vision_encoder_name_or_path, args=None)
|
142 |
+
image_processor = vision_encoder.image_processor
|
143 |
+
|
144 |
+
# Load from a pretrained checkpoint
|
145 |
+
if (
|
146 |
+
args.pretrained_model_name_or_path is not None
|
147 |
+
and not os.path.isfile(args.pretrained_model_name_or_path)
|
148 |
+
):
|
149 |
+
logger.info("Constructing model from pretrained checkpoint.")
|
150 |
+
rdt = RDTRunner.from_pretrained(args.pretrained_model_name_or_path)
|
151 |
+
else:
|
152 |
+
logger.info("Constructing model from provided config.")
|
153 |
+
# Calculate the image condition length
|
154 |
+
img_cond_len = (config["common"]["img_history_size"]
|
155 |
+
* config["common"]["num_cameras"]
|
156 |
+
* vision_encoder.num_patches)
|
157 |
+
rdt = RDTRunner(
|
158 |
+
action_dim=config["common"]["state_dim"],
|
159 |
+
pred_horizon=config["common"]["action_chunk_size"],
|
160 |
+
config=config["model"],
|
161 |
+
lang_token_dim=config["model"]["lang_token_dim"],
|
162 |
+
img_token_dim=config["model"]["img_token_dim"],
|
163 |
+
state_token_dim=config["model"]["state_token_dim"],
|
164 |
+
max_lang_cond_len=config["dataset"]["tokenizer_max_length"],
|
165 |
+
img_cond_len=img_cond_len,
|
166 |
+
img_pos_embed_config=[
|
167 |
+
# No initial pos embed in the last grid size
|
168 |
+
# since we've already done it in the ViT
|
169 |
+
("image", (config["common"]["img_history_size"],
|
170 |
+
config["common"]["num_cameras"],
|
171 |
+
-vision_encoder.num_patches)),
|
172 |
+
],
|
173 |
+
lang_pos_embed_config=[
|
174 |
+
# Similarly, no initial pos embed for language
|
175 |
+
("lang", -config["dataset"]["tokenizer_max_length"]),
|
176 |
+
],
|
177 |
+
dtype=weight_dtype,
|
178 |
+
)
|
179 |
+
|
180 |
+
|
181 |
+
ema_rdt = copy.deepcopy(rdt)
|
182 |
+
ema_model = EMAModel(
|
183 |
+
ema_rdt,
|
184 |
+
update_after_step=config["model"]["ema"]["update_after_step"],
|
185 |
+
inv_gamma=config["model"]["ema"]["inv_gamma"],
|
186 |
+
power=config["model"]["ema"]["power"],
|
187 |
+
min_value=config["model"]["ema"]["min_value"],
|
188 |
+
max_value=config["model"]["ema"]["max_value"]
|
189 |
+
)
|
190 |
+
|
191 |
+
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
192 |
+
# which ensures the model is saved in the Hugging Face format (config.json + pytorch_model.bin)
|
193 |
+
def save_model_hook(models, weights, output_dir):
|
194 |
+
if accelerator.is_main_process:
|
195 |
+
for model in models:
|
196 |
+
model_to_save = model.module if hasattr(model, "module") else model # type: ignore
|
197 |
+
if isinstance(model_to_save, type(accelerator.unwrap_model(rdt))):
|
198 |
+
model_to_save.save_pretrained(output_dir)
|
199 |
+
|
200 |
+
accelerator.register_save_state_pre_hook(save_model_hook)
|
201 |
+
|
202 |
+
if args.gradient_checkpointing:
|
203 |
+
# TODO:
|
204 |
+
raise NotImplementedError("Gradient checkpointing is not yet implemented.")
|
205 |
+
|
206 |
+
# Enable TF32 for faster training on Ampere GPUs,
|
207 |
+
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
208 |
+
if args.allow_tf32:
|
209 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
210 |
+
|
211 |
+
if args.scale_lr:
|
212 |
+
args.learning_rate = (
|
213 |
+
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
214 |
+
)
|
215 |
+
|
216 |
+
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
|
217 |
+
if args.use_8bit_adam:
|
218 |
+
try:
|
219 |
+
import bitsandbytes as bnb
|
220 |
+
except ImportError:
|
221 |
+
raise ImportError(
|
222 |
+
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
223 |
+
)
|
224 |
+
|
225 |
+
optimizer_class = bnb.optim.AdamW8bit
|
226 |
+
else:
|
227 |
+
optimizer_class = torch.optim.AdamW
|
228 |
+
|
229 |
+
# Optimizer creation
|
230 |
+
params_to_optimize = rdt.parameters()
|
231 |
+
optimizer = optimizer_class(
|
232 |
+
params_to_optimize,
|
233 |
+
lr=args.learning_rate,
|
234 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
235 |
+
weight_decay=args.adam_weight_decay,
|
236 |
+
eps=args.adam_epsilon,
|
237 |
+
)
|
238 |
+
|
239 |
+
# Dataset and DataLoaders creation:
|
240 |
+
train_dataset = VLAConsumerDataset(
|
241 |
+
config=config["dataset"],
|
242 |
+
tokenizer=tokenizer,
|
243 |
+
image_processor=image_processor,
|
244 |
+
num_cameras=config["common"]["num_cameras"],
|
245 |
+
img_history_size=config["common"]["img_history_size"],
|
246 |
+
dataset_type=args.dataset_type,
|
247 |
+
image_aug=args.image_aug,
|
248 |
+
cond_mask_prob=args.cond_mask_prob,
|
249 |
+
cam_ext_mask_prob=args.cam_ext_mask_prob,
|
250 |
+
state_noise_snr=args.state_noise_snr,
|
251 |
+
use_hdf5=args.load_from_hdf5,
|
252 |
+
use_precomp_lang_embed=args.precomp_lang_embed,
|
253 |
+
task_name=args.dataset_name,
|
254 |
+
)
|
255 |
+
sample_dataset = VLAConsumerDataset(
|
256 |
+
config=config["dataset"],
|
257 |
+
tokenizer=tokenizer,
|
258 |
+
image_processor=image_processor,
|
259 |
+
num_cameras=config["common"]["num_cameras"],
|
260 |
+
img_history_size=config["common"]["img_history_size"],
|
261 |
+
dataset_type=args.dataset_type,
|
262 |
+
image_aug=False,
|
263 |
+
cond_mask_prob=0,
|
264 |
+
cam_ext_mask_prob=-1,
|
265 |
+
state_noise_snr=None,
|
266 |
+
use_hdf5=args.load_from_hdf5,
|
267 |
+
use_precomp_lang_embed=args.precomp_lang_embed,
|
268 |
+
task_name=args.dataset_name,
|
269 |
+
)
|
270 |
+
|
271 |
+
data_collator = DataCollatorForVLAConsumerDataset(tokenizer)
|
272 |
+
|
273 |
+
train_dataloader = torch.utils.data.DataLoader(
|
274 |
+
train_dataset,
|
275 |
+
batch_size=args.train_batch_size,
|
276 |
+
shuffle=True,
|
277 |
+
collate_fn=data_collator,
|
278 |
+
num_workers=args.dataloader_num_workers,
|
279 |
+
pin_memory=True,
|
280 |
+
persistent_workers=True
|
281 |
+
)
|
282 |
+
sample_dataloader = torch.utils.data.DataLoader(
|
283 |
+
sample_dataset,
|
284 |
+
batch_size=args.sample_batch_size,
|
285 |
+
shuffle=True,
|
286 |
+
collate_fn=data_collator,
|
287 |
+
num_workers=args.dataloader_num_workers,
|
288 |
+
pin_memory=True,
|
289 |
+
persistent_workers=True
|
290 |
+
)
|
291 |
+
|
292 |
+
# Scheduler and math around the number of training steps.
|
293 |
+
overrode_max_train_steps = False
|
294 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
295 |
+
if args.max_train_steps is None:
|
296 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
297 |
+
overrode_max_train_steps = True
|
298 |
+
|
299 |
+
lr_scheduler = get_scheduler(
|
300 |
+
args.lr_scheduler,
|
301 |
+
optimizer=optimizer,
|
302 |
+
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
|
303 |
+
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
304 |
+
num_cycles=args.lr_num_cycles,
|
305 |
+
power=args.lr_power,
|
306 |
+
)
|
307 |
+
|
308 |
+
# Prepare everything with our `accelerator`.
|
309 |
+
rdt, optimizer, train_dataloader, sample_dataloader, lr_scheduler = accelerator.prepare(
|
310 |
+
rdt, optimizer, train_dataloader, sample_dataloader, lr_scheduler
|
311 |
+
)
|
312 |
+
|
313 |
+
ema_rdt.to(accelerator.device, dtype=weight_dtype)
|
314 |
+
|
315 |
+
if text_encoder is not None:
|
316 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
317 |
+
|
318 |
+
if vision_encoder is not None:
|
319 |
+
vision_encoder.vision_tower.to(accelerator.device, dtype=weight_dtype)
|
320 |
+
|
321 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
322 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
323 |
+
if overrode_max_train_steps:
|
324 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
325 |
+
# Afterwards we recalculate our number of training epochs
|
326 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
327 |
+
|
328 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
329 |
+
# The trackers initializes automatically on the main process.
|
330 |
+
if accelerator.is_main_process:
|
331 |
+
accelerator.init_trackers("roboticDiffusionTransformer", config=vars(args))
|
332 |
+
|
333 |
+
# Train!
|
334 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
335 |
+
|
336 |
+
logger.info("***** Running training *****")
|
337 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
338 |
+
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
|
339 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
340 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
341 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
342 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
343 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
344 |
+
global_step = 0
|
345 |
+
first_epoch = 0
|
346 |
+
|
347 |
+
# Load from a pretrained checkpoint
|
348 |
+
if (
|
349 |
+
args.resume_from_checkpoint is None
|
350 |
+
and args.pretrained_model_name_or_path is not None
|
351 |
+
and os.path.isfile(args.pretrained_model_name_or_path)
|
352 |
+
):
|
353 |
+
# Since EMA is deprecated, we do not load EMA from the pretrained checkpoint
|
354 |
+
logger.info("Loading from a pretrained checkpoint.")
|
355 |
+
checkpoint = torch.load(args.pretrained_model_name_or_path)
|
356 |
+
rdt.module.load_state_dict(checkpoint["module"])
|
357 |
+
|
358 |
+
# Potentially load in the weights and states from a previous save
|
359 |
+
if args.resume_from_checkpoint:
|
360 |
+
if args.resume_from_checkpoint != "latest":
|
361 |
+
path = os.path.basename(args.resume_from_checkpoint)
|
362 |
+
else:
|
363 |
+
# Get the most recent checkpoint
|
364 |
+
dirs = os.listdir(args.output_dir)
|
365 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
366 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
367 |
+
path = dirs[-1] if len(dirs) > 0 else None
|
368 |
+
|
369 |
+
if path is None:
|
370 |
+
accelerator.print(
|
371 |
+
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
372 |
+
)
|
373 |
+
args.resume_from_checkpoint = None
|
374 |
+
else:
|
375 |
+
accelerator.print(f"Resuming from checkpoint {path}")
|
376 |
+
try:
|
377 |
+
accelerator.load_state(os.path.join(args.output_dir, path)) # load_module_strict=False
|
378 |
+
except:
|
379 |
+
# load deepspeed's state_dict
|
380 |
+
logger.info("Resuming training state failed. Attempting to only load from model checkpoint.")
|
381 |
+
checkpoint = torch.load(os.path.join(args.output_dir, path, "pytorch_model", "mp_rank_00_model_states.pt"))
|
382 |
+
rdt.module.load_state_dict(checkpoint["module"])
|
383 |
+
|
384 |
+
load_model(ema_rdt, os.path.join(args.output_dir, path, "ema", "model.safetensors"))
|
385 |
+
global_step = int(path.split("-")[1])
|
386 |
+
|
387 |
+
resume_global_step = global_step * args.gradient_accumulation_steps
|
388 |
+
first_epoch = global_step // num_update_steps_per_epoch
|
389 |
+
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
|
390 |
+
|
391 |
+
# Only show the progress bar once on each machine.
|
392 |
+
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
393 |
+
progress_bar.set_description("Steps")
|
394 |
+
|
395 |
+
loss_for_log = {}
|
396 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
397 |
+
|
398 |
+
rdt.train()
|
399 |
+
|
400 |
+
# Set the progress_bar to correct position
|
401 |
+
if args.resume_from_checkpoint and epoch == first_epoch:
|
402 |
+
progress_bar.update(resume_step // args.gradient_accumulation_steps)
|
403 |
+
|
404 |
+
# Forward and backward...
|
405 |
+
for batch in train_dataloader:
|
406 |
+
with accelerator.accumulate(rdt):
|
407 |
+
images = batch["images"].to(dtype=weight_dtype)
|
408 |
+
states = batch["states"].to(dtype=weight_dtype) # (B, T, D_a)
|
409 |
+
# We only use the last state as input
|
410 |
+
states = states[:, -1:, :]
|
411 |
+
actions = batch["actions"].to(dtype=weight_dtype)
|
412 |
+
state_elem_mask = batch["state_elem_mask"].to(dtype=weight_dtype)
|
413 |
+
ctrl_freqs = batch["ctrl_freqs"]
|
414 |
+
|
415 |
+
with torch.no_grad():
|
416 |
+
batch_size, _, C, H, W = images.shape
|
417 |
+
image_embeds = vision_encoder(images.reshape(-1, C, H, W)).detach()
|
418 |
+
image_embeds = image_embeds.reshape((batch_size, -1, vision_encoder.hidden_size))
|
419 |
+
|
420 |
+
lang_attn_mask = batch["lang_attn_mask"]
|
421 |
+
text_embeds = batch["lang_embeds"].to(dtype=weight_dtype) \
|
422 |
+
if args.precomp_lang_embed \
|
423 |
+
else text_encoder(
|
424 |
+
input_ids=batch["input_ids"],
|
425 |
+
attention_mask=lang_attn_mask
|
426 |
+
)["last_hidden_state"].detach()
|
427 |
+
|
428 |
+
state_elem_mask = state_elem_mask.unsqueeze(1)
|
429 |
+
loss = rdt(
|
430 |
+
lang_tokens=text_embeds,
|
431 |
+
lang_attn_mask=lang_attn_mask,
|
432 |
+
img_tokens=image_embeds,
|
433 |
+
state_tokens=states,
|
434 |
+
action_gt=actions,
|
435 |
+
action_mask=state_elem_mask,
|
436 |
+
ctrl_freqs=ctrl_freqs
|
437 |
+
)
|
438 |
+
|
439 |
+
accelerator.backward(loss)
|
440 |
+
if accelerator.sync_gradients:
|
441 |
+
params_to_clip = rdt.parameters()
|
442 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
443 |
+
optimizer.step()
|
444 |
+
lr_scheduler.step()
|
445 |
+
optimizer.zero_grad(set_to_none=args.set_grads_to_none)
|
446 |
+
|
447 |
+
ema_model.step(accelerator.unwrap_model(rdt))
|
448 |
+
|
449 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
450 |
+
if accelerator.sync_gradients:
|
451 |
+
progress_bar.update(1)
|
452 |
+
global_step += 1
|
453 |
+
|
454 |
+
if global_step % args.checkpointing_period == 0:
|
455 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
456 |
+
accelerator.save_state(save_path)
|
457 |
+
ema_save_path = os.path.join(save_path, f"ema")
|
458 |
+
accelerator.save_model(ema_rdt, ema_save_path)
|
459 |
+
logger.info(f"Saved state to {save_path}")
|
460 |
+
|
461 |
+
if args.sample_period > 0 and global_step % args.sample_period == 0:
|
462 |
+
sample_loss_for_log = log_sample_res(
|
463 |
+
text_encoder,
|
464 |
+
vision_encoder,
|
465 |
+
rdt, # We do not use EMA currently
|
466 |
+
args,
|
467 |
+
accelerator,
|
468 |
+
weight_dtype,
|
469 |
+
sample_dataset.get_dataset_id2name(),
|
470 |
+
sample_dataloader,
|
471 |
+
logger,
|
472 |
+
)
|
473 |
+
logger.info(sample_loss_for_log)
|
474 |
+
accelerator.log(sample_loss_for_log, step=global_step)
|
475 |
+
|
476 |
+
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
477 |
+
progress_bar.set_postfix(**logs)
|
478 |
+
logs.update(loss_for_log)
|
479 |
+
# logger.info(logs)
|
480 |
+
accelerator.log(logs, step=global_step)
|
481 |
+
|
482 |
+
if global_step >= args.max_train_steps:
|
483 |
+
break
|
484 |
+
|
485 |
+
# Create the pipeline using the trained modules and save it.
|
486 |
+
accelerator.wait_for_everyone()
|
487 |
+
if accelerator.is_main_process:
|
488 |
+
accelerator.unwrap_model(rdt).save_pretrained(args.output_dir)
|
489 |
+
ema_save_path = os.path.join(args.output_dir, f"ema")
|
490 |
+
accelerator.save_model(ema_rdt, ema_save_path)
|
491 |
+
|
492 |
+
logger.info(f"Saved Model to {args.output_dir}")
|
493 |
+
|
494 |
+
if args.push_to_hub:
|
495 |
+
save_model_card(
|
496 |
+
repo_id,
|
497 |
+
base_model=args.pretrained_model_name_or_path,
|
498 |
+
repo_folder=args.output_dir,
|
499 |
+
)
|
500 |
+
upload_folder(
|
501 |
+
repo_id=repo_id,
|
502 |
+
folder_path=args.output_dir,
|
503 |
+
commit_message="End of training",
|
504 |
+
token=args.hub_token,
|
505 |
+
allow_patterns=["pytorch_model.bin", "*.json", "*.md"],
|
506 |
+
# ignore_patterns=["step_*", "epoch_*"],
|
507 |
+
)
|
508 |
+
|
509 |
+
accelerator.end_training()
|