euijinrnd committed on
Commit d899b9f · verified · 1 Parent(s): eef26ad

Add files using upload-large-folder tool

Files changed (44)
  1. configs/base.yaml +71 -0
  2. configs/calvin_rel_traj_location_bounds_task_ABC_D.json +50 -0
  3. configs/dataset_control_freq.json +73 -0
  4. configs/dataset_img_keys.json +674 -0
  5. configs/dataset_stat.json +0 -0
  6. configs/finetune_datasets.json +5 -0
  7. configs/finetune_sample_weights.json +5 -0
  8. configs/pretrain_datasets.json +3 -0
  9. configs/pretrain_sample_weights.json +3 -0
  10. configs/state_vec.py +114 -0
  11. configs/zero2.json +14 -0
  12. data/aloha/hdf5totfrecords.py +98 -0
  13. data/aloha/unzip_data.sh +3 -0
  14. data/bridgev2/bridgedata_numpy_to_tfrecord.py +174 -0
  15. data/bridgev2/bridgedata_raw_to_numpy.py +316 -0
  16. data/bridgev2/download.sh +13 -0
  17. data/calvin/download.sh +19 -0
  18. data/calvin/hdf5totfrecords.py +92 -0
  19. data/rh20t/hdf5totfrecords.py +200 -0
  20. data/roboset/download.py +42 -0
  21. data/roboset/download.sh +21 -0
  22. data/roboset/h5totfrecords.py +82 -0
  23. data/roboset/links.txt +197 -0
  24. docs/pretrain.md +270 -0
  25. docs/test_6drot.py +99 -0
  26. eval_sim/eval_dp.py +166 -0
  27. eval_sim/eval_octo.py +182 -0
  28. eval_sim/eval_openvla.py +175 -0
  29. eval_sim/eval_rdt_maniskill.py +137 -0
  30. lang_embed/aloha_dish_drainer.pt +3 -0
  31. lang_embed/aloha_handover_box.pt +3 -0
  32. lang_embed/aloha_lift_box.pt +3 -0
  33. lang_embed/aloha_shoes_table.pt +3 -0
  34. lang_embed/anubis_brush_to_pan.pt +3 -0
  35. lang_embed/anubis_carrot_to_bag.pt +3 -0
  36. lang_embed/anubis_towel_kirby.pt +3 -0
  37. scripts/agilex_inference.py +658 -0
  38. scripts/agilex_model.py +313 -0
  39. scripts/encode_lang_batch.py +76 -0
  40. scripts/maniskill_model.py +277 -0
  41. train/dataset.py +467 -0
  42. train/image_corrupt.py +44 -0
  43. train/sample.py +99 -0
  44. train/train.py +509 -0
configs/base.yaml ADDED
@@ -0,0 +1,71 @@
+common:
+  # The number of historical images
+  img_history_size: 2
+  # The number of future actions to predict
+  action_chunk_size: 64
+  # The number of cameras to be used in the model
+  num_cameras: 3
+  # Dimension of state/action; we use the same space for both state and action
+  # This MUST be equal to configs/state_vec.py
+  state_dim: 128
+
+
+dataset:
+  # We extract the data from the raw dataset
+  # and store it in a disk buffer via a producer.
+  # During training, a consumer reads data randomly from the buffer.
+  # The producer replaces data that has already been
+  # read by the consumer with new data.
+
+  # The path to the buffer (at least 400GB)
+  buf_path: /home/jellyho/RDTBuffer
+  # The number of chunks in the buffer
+  buf_num_chunks: 128
+  # The number of samples (steps rather than episodes) in each chunk
+  buf_chunk_size: 128
+
+  # We filter out episodes shorter than `epsd_len_thresh_low`
+  epsd_len_thresh_low: 32
+  # For episodes longer than `epsd_len_thresh_high`,
+  # we randomly sample `epsd_len_thresh_high` steps each time we load the episode
+  # to better balance the training datasets
+  epsd_len_thresh_high: 2048
+  # How to fit the image size
+  image_aspect_ratio: pad
+  # Maximum number of language tokens
+  tokenizer_max_length: 1024
+
+model:
+  # Config for condition adaptors
+  lang_adaptor: mlp2x_gelu
+  img_adaptor: mlp2x_gelu
+  state_adaptor: mlp3x_gelu
+  lang_token_dim: 4096
+  img_token_dim: 1152
+  # Dim of action or proprioception vector
+  # A `state` refers to an action or a proprioception vector
+  state_token_dim: 128
+  # Config for the RDT structure
+  rdt:
+    # 1B: num_heads 32, hidden_size 2048
+    hidden_size: 2048
+    depth: 28
+    num_heads: 32
+    cond_pos_embed_type: multimodal
+  # For the noise scheduler
+  noise_scheduler:
+    type: ddpm
+    num_train_timesteps: 1000
+    num_inference_timesteps: 5
+    beta_schedule: squaredcos_cap_v2  # Critical choice
+    prediction_type: sample
+    clip_sample: False
+  # For EMA (parameter averaging)
+  # We do not use EMA currently
+  ema:
+    update_after_step: 0
+    inv_gamma: 1.0
+    power: 0.75
+    min_value: 0.0
+    max_value: 0.9999
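
The config above is plain YAML; a minimal sketch of loading it and reading a few of the fields (assuming PyYAML is installed and the repo root is the working directory — not part of the committed code):

```python
# Minimal sketch: load configs/base.yaml and read a few of the fields above.
import yaml

with open("configs/base.yaml", "r") as f:
    config = yaml.safe_load(f)

print(config["common"]["img_history_size"])   # 2 historical images
print(config["common"]["action_chunk_size"])  # 64 future actions
print(config["common"]["state_dim"])          # 128, must match configs/state_vec.py
print(config["model"]["rdt"]["hidden_size"])  # 2048
```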
configs/calvin_rel_traj_location_bounds_task_ABC_D.json ADDED
@@ -0,0 +1,50 @@
+{
+    "A": [
+        [
+            -0.2691913843154907,
+            -0.21995729207992554,
+            -0.182277649641037
+        ],
+        [
+            0.35127854347229004,
+            0.2769763469696045,
+            0.17159393429756165
+        ]
+    ],
+    "B": [
+        [
+            -0.2576896846294403,
+            -0.22244493663311005,
+            -0.20557966828346252
+        ],
+        [
+            0.32854634523391724,
+            0.2922680974006653,
+            0.17373555898666382
+        ]
+    ],
+    "C": [
+        [
+            -0.29205888509750366,
+            -0.24688798189163208,
+            -0.17577645182609558
+        ],
+        [
+            0.25053921341896057,
+            0.3277084231376648,
+            0.16431939601898193
+        ]
+    ],
+    "D": [
+        [
+            -0.25131964683532715,
+            -0.15233077108860016,
+            -0.13294968008995056
+        ],
+        [
+            0.19209328293800354,
+            0.19344553351402283,
+            0.1370421051979065
+        ]
+    ]
+}
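
The file name suggests these are per-split [min, max] bounds on relative trajectory locations for CALVIN tasks A–D. A hedged sketch of how such bounds could be used for min-max normalization (an assumption; the committed training code that consumes this file is not shown in this diff):

```python
# Sketch: normalize a relative XYZ displacement to [-1, 1] with the bounds above.
# Assumes the first triple is the per-axis minimum and the second the maximum.
import json
import numpy as np

with open("configs/calvin_rel_traj_location_bounds_task_ABC_D.json") as f:
    bounds = json.load(f)

low, high = (np.array(b) for b in bounds["A"])       # bounds for split "A"
rel_xyz = np.array([0.05, -0.10, 0.02])              # hypothetical displacement
normalized = 2.0 * (rel_xyz - low) / (high - low) - 1.0
denormalized = (normalized + 1.0) / 2.0 * (high - low) + low  # inverse mapping
```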
configs/dataset_control_freq.json ADDED
@@ -0,0 +1,73 @@
+{
+    "fractal20220817_data": 3,
+    "taco_play": 15,
+    "jaco_play": 10,
+    "berkeley_cable_routing": 10,
+    "nyu_door_opening_surprising_effectiveness": 3,
+    "viola": 20,
+    "berkeley_autolab_ur5": 5,
+    "toto": 30,
+    "kuka": 10,
+    "language_table": 10,
+    "columbia_cairlab_pusht_real": 10,
+    "stanford_kuka_multimodal_dataset_converted_externally_to_rlds": 20,
+    "nyu_rot_dataset_converted_externally_to_rlds": 3,
+    "stanford_hydra_dataset_converted_externally_to_rlds": 10,
+    "austin_buds_dataset_converted_externally_to_rlds": 20,
+    "nyu_franka_play_dataset_converted_externally_to_rlds": 3,
+    "maniskill_dataset_converted_externally_to_rlds": 20,
+    "furniture_bench_dataset_converted_externally_to_rlds": 10,
+    "ucsd_kitchen_dataset_converted_externally_to_rlds": 2,
+    "ucsd_pick_and_place_dataset_converted_externally_to_rlds": 3,
+    "austin_sailor_dataset_converted_externally_to_rlds": 20,
+    "austin_sirius_dataset_converted_externally_to_rlds": 20,
+    "bc_z": 10,
+    "utokyo_pr2_opening_fridge_converted_externally_to_rlds": 10,
+    "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": 10,
+    "utokyo_xarm_pick_and_place_converted_externally_to_rlds": 10,
+    "utokyo_xarm_bimanual_converted_externally_to_rlds": 10,
+    "berkeley_mvp_converted_externally_to_rlds": 5,
+    "berkeley_rpt_converted_externally_to_rlds": 30,
+    "kaist_nonprehensile_converted_externally_to_rlds": 10,
+    "stanford_mask_vit_converted_externally_to_rlds": 0,
+    "tokyo_u_lsmo_converted_externally_to_rlds": 10,
+    "dlr_sara_pour_converted_externally_to_rlds": 10,
+    "dlr_sara_grid_clamp_converted_externally_to_rlds": 10,
+    "dlr_edan_shared_control_converted_externally_to_rlds": 5,
+    "asu_table_top_converted_externally_to_rlds": 12.5,
+    "stanford_robocook_converted_externally_to_rlds": 5,
+    "eth_agent_affordances": 66.6,
+    "imperialcollege_sawyer_wrist_cam": 10,
+    "iamlab_cmu_pickup_insert_converted_externally_to_rlds": 20,
+    "uiuc_d3field": 1,
+    "utaustin_mutex": 20,
+    "berkeley_fanuc_manipulation": 10,
+    "cmu_play_fusion": 5,
+    "cmu_stretch": 10,
+    "berkeley_gnm_recon": 3,
+    "berkeley_gnm_cory_hall": 5,
+    "berkeley_gnm_sac_son": 10,
+    "robo_net": 1,
+    "roboturk_real_towercreation": 10,
+    "roboturk_real_laundrylayout": 10,
+    "roboturk_real_objectsearch": 10,
+    "aloha_mobile": 50,
+    "aloha_static": 50,
+    "roboset": 5,
+    "droid": 15,
+    "fmb": 10,
+    "dobbe": 30,
+    "qut_dexterous_manpulation": 30,
+    "agilex": 25,
+    "rh20t": 10,
+    "calvin": 30,
+    "bridgev2": 5,
+    "aloha_dish_drainer": 20,
+    "aloha_handover_box": 20,
+    "aloha_shoes_table": 20,
+    "aloha_lift_box": 20,
+    "aloha_box_into_pot": 20,
+    "anubis_towel_kirby": 20,
+    "anubis_carrot_to_bag": 20,
+    "anubis_brush_to_pan": 20
+}
configs/dataset_img_keys.json ADDED
@@ -0,0 +1,674 @@
1
+ {
2
+ "anubis_towel_kirby": {
3
+ "image_keys": [
4
+ "agentview_image",
5
+ "right_wrist_image",
6
+ "left_wrist_image",
7
+ "agentview_image"
8
+ ],
9
+ "image_mask":[
10
+ 1,1,1,0
11
+ ]
12
+ },
13
+ "anubis_carrot_to_bag": {
14
+ "image_keys": [
15
+ "agentview_image",
16
+ "right_wrist_image",
17
+ "left_wrist_image",
18
+ "agentview_image"
19
+ ],
20
+ "image_mask":[
21
+ 1,1,1,0
22
+ ]
23
+ },
24
+ "anubis_brush_to_pan": {
25
+ "image_keys": [
26
+ "agentview_image",
27
+ "right_wrist_image",
28
+ "left_wrist_image",
29
+ "agentview_image"
30
+ ],
31
+ "image_mask":[
32
+ 1,1,1,0
33
+ ]
34
+ },
35
+ "aloha_box_into_pot": {
36
+ "image_keys": [
37
+ "agentview_image",
38
+ "right_wrist_image",
39
+ "left_wrist_image",
40
+ "agentview_image"
41
+ ],
42
+ "image_mask":[
43
+ 1,1,1,0
44
+ ]
45
+ },
46
+ "aloha_box_into_pot_easy": {
47
+ "image_keys": [
48
+ "agentview_image",
49
+ "right_wrist_image",
50
+ "left_wrist_image",
51
+ "agentview_image"
52
+ ],
53
+ "image_mask":[
54
+ 1,1,1,0
55
+ ]
56
+ },
57
+ "aloha_dish_drainer": {
58
+ "image_keys": [
59
+ "agentview_image",
60
+ "right_wrist_image",
61
+ "left_wrist_image",
62
+ "agentview_image"
63
+ ],
64
+ "image_mask":[
65
+ 1,1,1,0
66
+ ]
67
+ },
68
+ "aloha_handover_box": {
69
+ "image_keys": [
70
+ "agentview_image",
71
+ "right_wrist_image",
72
+ "left_wrist_image",
73
+ "agentview_image"
74
+ ],
75
+ "image_mask":[
76
+ 1,1,1,0
77
+ ]
78
+ },
79
+ "aloha_shoes_table": {
80
+ "image_keys": [
81
+ "agentview_image",
82
+ "right_wrist_image",
83
+ "left_wrist_image",
84
+ "agentview_image"
85
+ ],
86
+ "image_mask":[
87
+ 1,1,1,0
88
+ ]
89
+ },
90
+ "aloha_lift_box": {
91
+ "image_keys": [
92
+ "agentview_image",
93
+ "right_wrist_image",
94
+ "left_wrist_image",
95
+ "agentview_image"
96
+ ],
97
+ "image_mask":[
98
+ 1,1,1,0
99
+ ]
100
+ },
101
+ "fractal20220817_data": {
102
+ "image_keys": [
103
+ "image",
104
+ "image",
105
+ "image",
106
+ "image"
107
+ ],
108
+ "image_mask":[
109
+ 1,0,0,0
110
+ ]
111
+ },
112
+ "taco_play": {
113
+ "image_keys": [
114
+ "rgb_static",
115
+ "rgb_gripper",
116
+ "rgb_static",
117
+ "rgb_static"
118
+ ],
119
+ "image_mask":[
120
+ 1,1,0,0
121
+ ]
122
+ },
123
+ "jaco_play": {
124
+ "image_keys": [
125
+ "image",
126
+ "image_wrist",
127
+ "image_wrist",
128
+ "image_wrist"
129
+ ],
130
+ "image_mask":[
131
+ 1,1,0,0
132
+ ]
133
+ },
134
+ "berkeley_cable_routing": {
135
+ "image_keys": [
136
+ "image",
137
+ "wrist45_image",
138
+ "wrist225_image",
139
+ "top_image"
140
+ ],
141
+ "image_mask":[1,1,0,1]
142
+ },
143
+ "nyu_door_opening_surprising_effectiveness": {
144
+ "image_keys": [
145
+ "image",
146
+ "image",
147
+ "image",
148
+ "image"
149
+ ],
150
+ "image_mask":[1,0,0,0]
151
+ },
152
+ "viola": {
153
+ "image_keys": [
154
+ "agentview_rgb",
155
+ "eye_in_hand_rgb",
156
+ "eye_in_hand_rgb",
157
+ "eye_in_hand_rgb"
158
+ ],
159
+ "image_mask":[1,1,0,0]
160
+ },
161
+ "berkeley_autolab_ur5": {
162
+ "image_keys": [
163
+ "image",
164
+ "hand_image",
165
+ "hand_image",
166
+ "hand_image"
167
+ ],
168
+ "image_mask":[1,1,0,0]
169
+ },
170
+ "toto": {
171
+ "image_keys": [
172
+ "image",
173
+ "image",
174
+ "image",
175
+ "image"
176
+ ],
177
+ "image_mask":[1,0,0,0]
178
+ },
179
+ "kuka": {
180
+ "image_keys": [
181
+ "image",
182
+ "image",
183
+ "image",
184
+ "image"
185
+ ],
186
+ "image_mask":[1,0,0,0]
187
+ },
188
+ "language_table": {
189
+ "image_keys": [
190
+ "rgb",
191
+ "rgb",
192
+ "rgb",
193
+ "rgb"
194
+ ],
195
+ "image_mask":[1,0,0,0]
196
+ },
197
+ "columbia_cairlab_pusht_real": {
198
+ "image_keys": [
199
+ "image",
200
+ "wrist_image",
201
+ "wrist_image",
202
+ "wrist_image"
203
+ ],
204
+ "image_mask":[1,1,0,0]
205
+ },
206
+ "stanford_kuka_multimodal_dataset_converted_externally_to_rlds": {
207
+ "image_keys": [
208
+ "image",
209
+ "image",
210
+ "image",
211
+ "image"
212
+ ],
213
+ "image_mask":[1,0,0,0]
214
+ },
215
+ "nyu_rot_dataset_converted_externally_to_rlds": {
216
+ "image_keys": [
217
+ "image",
218
+ "image",
219
+ "image",
220
+ "image"
221
+ ],
222
+ "image_mask":[1,0,0,0]
223
+ },
224
+ "stanford_hydra_dataset_converted_externally_to_rlds": {
225
+ "image_keys": [
226
+ "image",
227
+ "wrist_image",
228
+ "wrist_image",
229
+ "wrist_image"
230
+ ],
231
+ "image_mask":[1,1,0,0]
232
+ },
233
+ "austin_buds_dataset_converted_externally_to_rlds": {
234
+ "image_keys": [
235
+ "image",
236
+ "wrist_image",
237
+ "wrist_image",
238
+ "wrist_image"
239
+ ],
240
+ "image_mask":[1,1,0,0]
241
+ },
242
+ "nyu_franka_play_dataset_converted_externally_to_rlds": {
243
+ "image_keys": [
244
+ "image",
245
+ "image_additional_view",
246
+ "image_additional_view",
247
+ "image_additional_view"
248
+ ],
249
+ "image_mask":[1,0,0,1]
250
+ },
251
+ "maniskill_dataset_converted_externally_to_rlds": {
252
+ "image_keys": [
253
+ "image",
254
+ "wrist_image",
255
+ "wrist_image",
256
+ "wrist_image"
257
+ ],
258
+ "image_mask":[1,1,0,0]
259
+ },
260
+ "furniture_bench_dataset_converted_externally_to_rlds": {
261
+ "image_keys": [
262
+ "image",
263
+ "wrist_image",
264
+ "wrist_image",
265
+ "wrist_image"
266
+ ],
267
+ "image_mask":[1,1,0,0]
268
+ },
269
+ "ucsd_kitchen_dataset_converted_externally_to_rlds": {
270
+ "image_keys": [
271
+ "image",
272
+ "image",
273
+ "image",
274
+ "image"
275
+ ],
276
+ "image_mask":[1,0,0,0]
277
+ },
278
+ "ucsd_pick_and_place_dataset_converted_externally_to_rlds": {
279
+ "image_keys": [
280
+ "image",
281
+ "image",
282
+ "image",
283
+ "image"
284
+ ],
285
+ "image_mask":[1,0,0,0]
286
+ },
287
+ "austin_sailor_dataset_converted_externally_to_rlds": {
288
+ "image_keys": [
289
+ "image",
290
+ "wrist_image",
291
+ "wrist_image",
292
+ "wrist_image"
293
+ ],
294
+ "image_mask":[1,1,0,0]
295
+ },
296
+ "austin_sirius_dataset_converted_externally_to_rlds": {
297
+ "image_keys": [
298
+ "image",
299
+ "wrist_image",
300
+ "wrist_image",
301
+ "wrist_image"
302
+ ],
303
+ "image_mask":[1,1,0,0]
304
+ },
305
+ "bc_z": {
306
+ "image_keys": [
307
+ "image",
308
+ "image",
309
+ "image",
310
+ "image"
311
+ ],
312
+ "image_mask":[1,0,0,0]
313
+ },
314
+ "utokyo_pr2_opening_fridge_converted_externally_to_rlds": {
315
+ "image_keys": [
316
+ "image",
317
+ "image",
318
+ "image",
319
+ "image"
320
+ ],
321
+ "image_mask":[1,0,0,0]
322
+ },
323
+ "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": {
324
+ "image_keys": [
325
+ "image",
326
+ "image",
327
+ "image",
328
+ "image"
329
+ ],
330
+ "image_mask":[1,0,0,0]
331
+ },
332
+ "utokyo_xarm_pick_and_place_converted_externally_to_rlds": {
333
+ "image_keys": [
334
+ "image",
335
+ "hand_image",
336
+ "hand_image",
337
+ "image2"
338
+ ],
339
+ "image_mask":[1,1,0,1]
340
+ },
341
+ "utokyo_xarm_bimanual_converted_externally_to_rlds": {
342
+ "image_keys": [
343
+ "image",
344
+ "image",
345
+ "image",
346
+ "image"
347
+ ],
348
+ "image_mask":[1,0,0,0]
349
+ },
350
+ "berkeley_mvp_converted_externally_to_rlds": {
351
+ "image_keys": [
352
+ "hand_image",
353
+ "hand_image",
354
+ "hand_image",
355
+ "hand_image"
356
+ ],
357
+ "image_mask":[0,1,0,0]
358
+ },
359
+ "berkeley_rpt_converted_externally_to_rlds": {
360
+ "image_keys": [
361
+ "hand_image",
362
+ "hand_image",
363
+ "hand_image",
364
+ "hand_image"
365
+ ],
366
+ "image_mask":[0,1,0,0]
367
+ },
368
+ "kaist_nonprehensile_converted_externally_to_rlds": {
369
+ "image_keys": [
370
+ "image",
371
+ "image",
372
+ "image",
373
+ "image"
374
+ ],
375
+ "image_mask":[1,0,0,0]
376
+ },
377
+ "stanford_mask_vit_converted_externally_to_rlds": {
378
+ "image_keys": [
379
+ "image",
380
+ "image",
381
+ "image",
382
+ "image"
383
+ ],
384
+ "image_mask":[1,0,0,0]
385
+ },
386
+ "tokyo_u_lsmo_converted_externally_to_rlds": {
387
+ "image_keys": [
388
+ "image",
389
+ "image",
390
+ "image",
391
+ "image"
392
+ ],
393
+ "image_mask":[1,0,0,0]
394
+ },
395
+ "dlr_sara_pour_converted_externally_to_rlds": {
396
+ "image_keys": [
397
+ "image",
398
+ "image",
399
+ "image",
400
+ "image"
401
+ ],
402
+ "image_mask":[1,0,0,0]
403
+ },
404
+ "dlr_sara_grid_clamp_converted_externally_to_rlds": {
405
+ "image_keys": [
406
+ "image",
407
+ "image",
408
+ "image",
409
+ "image"
410
+ ],
411
+ "image_mask":[1,0,0,0]
412
+ },
413
+ "dlr_edan_shared_control_converted_externally_to_rlds": {
414
+ "image_keys": [
415
+ "image",
416
+ "image",
417
+ "image",
418
+ "image"
419
+ ],
420
+ "image_mask":[1,0,0,0]
421
+ },
422
+ "asu_table_top_converted_externally_to_rlds": {
423
+ "image_keys": [
424
+ "image",
425
+ "image",
426
+ "image",
427
+ "image"
428
+ ],
429
+ "image_mask":[1,0,0,0]
430
+ },
431
+ "stanford_robocook_converted_externally_to_rlds": {
432
+ "image_keys": [
433
+ "image_2",
434
+ "image_1",
435
+ "image_3",
436
+ "image_4"
437
+ ],
438
+ "image_mask":[1,0,0,1]
439
+ },
440
+ "eth_agent_affordances": {
441
+ "image_keys": [
442
+ "image",
443
+ "image",
444
+ "image",
445
+ "image"
446
+ ],
447
+ "image_mask":[1,0,0,0]
448
+ },
449
+ "imperialcollege_sawyer_wrist_cam": {
450
+ "image_keys": [
451
+ "image",
452
+ "wrist_image",
453
+ "wrist_image",
454
+ "wrist_image"
455
+ ],
456
+ "image_mask":[0,1,0,0]
457
+ },
458
+ "iamlab_cmu_pickup_insert_converted_externally_to_rlds": {
459
+ "image_keys": [
460
+ "image",
461
+ "wrist_image",
462
+ "wrist_image",
463
+ "wrist_image"
464
+ ],
465
+ "image_mask":[1,1,0,0]
466
+ },
467
+ "uiuc_d3field": {
468
+ "image_keys": [
469
+ "image_1",
470
+ "image_2",
471
+ "image_3",
472
+ "image_4"
473
+ ],
474
+ "image_mask":[1,0,0,1]
475
+ },
476
+ "utaustin_mutex": {
477
+ "image_keys": [
478
+ "image",
479
+ "wrist_image",
480
+ "wrist_image",
481
+ "wrist_image"
482
+ ],
483
+ "image_mask":[1,1,0,0]
484
+ },
485
+ "berkeley_fanuc_manipulation": {
486
+ "image_keys": [
487
+ "image",
488
+ "wrist_image",
489
+ "wrist_image",
490
+ "wrist_image"
491
+ ],
492
+ "image_mask":[1,1,0,0]
493
+ },
494
+ "cmu_play_fusion": {
495
+ "image_keys": [
496
+ "image",
497
+ "image",
498
+ "image",
499
+ "image"
500
+ ],
501
+ "image_mask":[1,0,0,0]
502
+ },
503
+ "cmu_stretch": {
504
+ "image_keys": [
505
+ "image",
506
+ "image",
507
+ "image",
508
+ "image"
509
+ ],
510
+ "image_mask":[1,0,0,0]
511
+ },
512
+ "berkeley_gnm_recon": {
513
+ "image_keys": [
514
+ "image",
515
+ "image",
516
+ "image",
517
+ "image"
518
+ ],
519
+ "image_mask":[1,0,0,0]
520
+ },
521
+ "berkeley_gnm_cory_hall": {
522
+ "image_keys": [
523
+ "image",
524
+ "image",
525
+ "image",
526
+ "image"
527
+ ],
528
+ "image_mask":[1,0,0,0]
529
+ },
530
+ "berkeley_gnm_sac_son": {
531
+ "image_keys": [
532
+ "image",
533
+ "image",
534
+ "image",
535
+ "image"
536
+ ],
537
+ "image_mask":[1,0,0,0]
538
+ },
539
+ "robo_net": {
540
+ "image_keys": [
541
+ "image",
542
+ "image1",
543
+ "image2",
544
+ "image2"
545
+ ],
546
+ "image_mask":[1,0,0,1]
547
+ },
548
+ "roboturk_real_towercreation": {
549
+ "image_keys": [
550
+ "top_rgb_frame",
551
+ "front_rgb_frame",
552
+ "front_rgb_frame",
553
+ "front_rgb_frame"
554
+ ],
555
+ "image_mask":[1,0,0,1]
556
+ },
557
+ "roboturk_real_laundrylayout": {
558
+ "image_keys": [
559
+ "top_rgb_frame",
560
+ "front_rgb_frame",
561
+ "front_rgb_frame",
562
+ "front_rgb_frame"
563
+ ],
564
+ "image_mask":[1,0,0,1]
565
+ },
566
+ "roboturk_real_objectsearch": {
567
+ "image_keys": [
568
+ "top_rgb_frame",
569
+ "front_rgb_frame",
570
+ "front_rgb_frame",
571
+ "front_rgb_frame"
572
+ ],
573
+ "image_mask":[1,0,0,1]
574
+ },
575
+ "aloha_mobile": {
576
+ "image_keys": [
577
+ "cam_high",
578
+ "cam_right_wrist",
579
+ "cam_left_wrist",
580
+ "cam_right_wrist"
581
+ ],
582
+ "image_mask":[1,1,1,0]
583
+ },
584
+ "aloha_static": {
585
+ "image_keys": [
586
+ "cam_high",
587
+ "cam_right_wrist",
588
+ "cam_left_wrist",
589
+ "cam_low"
590
+ ],
591
+ "image_mask":[1,1,1,1]
592
+ },
593
+ "roboset": {
594
+ "image_keys": [
595
+ "rgb_top",
596
+ "rgb_right",
597
+ "rgb_left",
598
+ "rgb_right"
599
+ ],
600
+ "image_mask":[1,1,1,0]
601
+ },
602
+ "droid": {
603
+ "image_keys": [
604
+ "exterior_image_1_left",
605
+ "wrist_image_left",
606
+ "wrist_image_left",
607
+ "exterior_image_2_left"
608
+ ],
609
+ "image_mask":[1,1,0,1]
610
+ },
611
+ "fmb": {
612
+ "image_keys": [
613
+ "image_side_1",
614
+ "image_wrist_1",
615
+ "image_wrist_1",
616
+ "image_side_2"
617
+ ],
618
+ "image_mask":[1,1,0,1]
619
+ },
620
+ "dobbe": {
621
+ "image_keys": [
622
+ "wrist_image",
623
+ "wrist_image",
624
+ "wrist_image",
625
+ "wrist_image"
626
+ ],
627
+ "image_mask":[0,1,0,0]
628
+ },
629
+ "qut_dexterous_manpulation": {
630
+ "image_keys": [
631
+ "image",
632
+ "wrist_image",
633
+ "wrist_image",
634
+ "wrist_image"
635
+ ],
636
+ "image_mask":[1,1,0,0]
637
+ },
638
+ "agilex": {
639
+ "image_keys": [
640
+ "cam_high",
641
+ "cam_right_wrist",
642
+ "cam_left_wrist",
643
+ "cam_right_wrist"
644
+ ],
645
+ "image_mask":[1,1,1,0]
646
+ },
647
+ "rh20t": {
648
+ "image_keys": [
649
+ "image",
650
+ "image",
651
+ "image",
652
+ "image"
653
+ ],
654
+ "image_mask":[1,0,0,0]
655
+ },
656
+ "calvin": {
657
+ "image_keys": [
658
+ "rgb_static",
659
+ "rgb_gripper",
660
+ "rgb_gripper",
661
+ "rgb_gripper"
662
+ ],
663
+ "image_mask":[1,1,0,0]
664
+ },
665
+ "bridgev2": {
666
+ "image_keys": [
667
+ "images0",
668
+ "images0",
669
+ "images0",
670
+ "images0"
671
+ ],
672
+ "image_mask":[1,0,0,0]
673
+ }
674
+ }
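
Each entry above pairs four `image_keys` (camera streams, padded by repetition) with an `image_mask` marking which of the four slots are real cameras. A sketch of how an entry might be consumed (an assumption; the actual consumer, presumably train/dataset.py, is not shown in this diff):

```python
# Sketch: select the camera streams for one dataset and drop masked-out slots.
import json

with open("configs/dataset_img_keys.json") as f:
    img_keys = json.load(f)

entry = img_keys["aloha_static"]
keys, mask = entry["image_keys"], entry["image_mask"]

# `episode_step` is a hypothetical dict mapping camera key -> decoded image array.
episode_step = {k: None for k in keys}
images = [episode_step[k] if m == 1 else None for k, m in zip(keys, mask)]
```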
configs/dataset_stat.json ADDED
The diff for this file is too large to render. See raw diff
 
configs/finetune_datasets.json ADDED
@@ -0,0 +1,5 @@
+[
+    "anubis_brush_to_pan",
+    "anubis_carrot_to_bag",
+    "anubis_towel_kirby"
+]
configs/finetune_sample_weights.json ADDED
@@ -0,0 +1,5 @@
+{
+    "anubis_towel_kirby": 100,
+    "anubis_carrot_to_bag": 100,
+    "anubis_brush_to_pan": 100
+}
configs/pretrain_datasets.json ADDED
@@ -0,0 +1,3 @@
+[
+    "aloha_box_into_pot_easy"
+]
configs/pretrain_sample_weights.json ADDED
@@ -0,0 +1,3 @@
+{
+    "aloha_box_into_pot_easy": 100
+}
configs/state_vec.py ADDED
@@ -0,0 +1,114 @@
+STATE_VEC_IDX_MAPPING = {
+    # [0, 10): right arm joint positions
+    **{
+        'arm_joint_{}_pos'.format(i): i for i in range(10)
+    },
+    **{
+        'right_arm_joint_{}_pos'.format(i): i for i in range(10)
+    },
+    # [10, 15): right gripper joint positions
+    **{
+        'gripper_joint_{}_pos'.format(i): i + 10 for i in range(5)
+    },
+    **{
+        'right_gripper_joint_{}_pos'.format(i): i + 10 for i in range(5)
+    },
+    'gripper_open': 10,  # alias of right_gripper_joint_0_pos
+    'right_gripper_open': 10,
+    # [15, 25): right arm joint velocities
+    **{
+        'arm_joint_{}_vel'.format(i): i + 15 for i in range(10)
+    },
+    **{
+        'right_arm_joint_{}_vel'.format(i): i + 15 for i in range(10)
+    },
+    # [25, 30): right gripper joint velocities
+    **{
+        'gripper_joint_{}_vel'.format(i): i + 25 for i in range(5)
+    },
+    **{
+        'right_gripper_joint_{}_vel'.format(i): i + 25 for i in range(5)
+    },
+    'gripper_open_vel': 25,  # alias of right_gripper_joint_0_vel
+    'right_gripper_open_vel': 25,
+    # [30, 33): right end effector positions
+    'eef_pos_x': 30,
+    'right_eef_pos_x': 30,
+    'eef_pos_y': 31,
+    'right_eef_pos_y': 31,
+    'eef_pos_z': 32,
+    'right_eef_pos_z': 32,
+    # [33, 39): right end effector 6D pose
+    'eef_angle_0': 33,
+    'right_eef_angle_0': 33,
+    'eef_angle_1': 34,
+    'right_eef_angle_1': 34,
+    'eef_angle_2': 35,
+    'right_eef_angle_2': 35,
+    'eef_angle_3': 36,
+    'right_eef_angle_3': 36,
+    'eef_angle_4': 37,
+    'right_eef_angle_4': 37,
+    'eef_angle_5': 38,
+    'right_eef_angle_5': 38,
+    # [39, 42): right end effector velocities
+    'eef_vel_x': 39,
+    'right_eef_vel_x': 39,
+    'eef_vel_y': 40,
+    'right_eef_vel_y': 40,
+    'eef_vel_z': 41,
+    'right_eef_vel_z': 41,
+    # [42, 45): right end effector angular velocities
+    'eef_angular_vel_roll': 42,
+    'right_eef_angular_vel_roll': 42,
+    'eef_angular_vel_pitch': 43,
+    'right_eef_angular_vel_pitch': 43,
+    'eef_angular_vel_yaw': 44,
+    'right_eef_angular_vel_yaw': 44,
+    # [45, 50): reserved
+    # [50, 60): left arm joint positions
+    **{
+        'left_arm_joint_{}_pos'.format(i): i + 50 for i in range(10)
+    },
+    # [60, 65): left gripper joint positions
+    **{
+        'left_gripper_joint_{}_pos'.format(i): i + 60 for i in range(5)
+    },
+    'left_gripper_open': 60,  # alias of left_gripper_joint_0_pos
+    # [65, 75): left arm joint velocities
+    **{
+        'left_arm_joint_{}_vel'.format(i): i + 65 for i in range(10)
+    },
+    # [75, 80): left gripper joint velocities
+    **{
+        'left_gripper_joint_{}_vel'.format(i): i + 75 for i in range(5)
+    },
+    'left_gripper_open_vel': 75,  # alias of left_gripper_joint_0_vel
+    # [80, 83): left end effector positions
+    'left_eef_pos_x': 80,
+    'left_eef_pos_y': 81,
+    'left_eef_pos_z': 82,
+    # [83, 89): left end effector 6D pose
+    'left_eef_angle_0': 83,
+    'left_eef_angle_1': 84,
+    'left_eef_angle_2': 85,
+    'left_eef_angle_3': 86,
+    'left_eef_angle_4': 87,
+    'left_eef_angle_5': 88,
+    # [89, 92): left end effector velocities
+    'left_eef_vel_x': 89,
+    'left_eef_vel_y': 90,
+    'left_eef_vel_z': 91,
+    # [92, 95): left end effector angular velocities
+    'left_eef_angular_vel_roll': 92,
+    'left_eef_angular_vel_pitch': 93,
+    'left_eef_angular_vel_yaw': 94,
+    # [95, 100): reserved
+    # [100, 102): base linear velocities
+    'base_vel_x': 100,
+    'base_vel_y': 101,
+    # [102, 103): base angular velocities
+    'base_angular_vel': 102,
+    # [103, 128): reserved
+}
+STATE_VEC_LEN = 128
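
A short sketch of how this mapping can be used to pack a robot's proprioception into the 128-dimensional unified state vector (the joint values below are made-up examples, and the import assumes the repo root is on PYTHONPATH):

```python
# Sketch: fill the unified 128-dim state vector for a 7-DoF right arm + gripper.
import numpy as np
from configs.state_vec import STATE_VEC_IDX_MAPPING, STATE_VEC_LEN

state = np.zeros(STATE_VEC_LEN, dtype=np.float32)
mask = np.zeros(STATE_VEC_LEN, dtype=np.float32)   # marks which slots are populated

arm_qpos = [0.1, -0.2, 0.3, 0.0, 0.5, -0.1, 0.2]   # hypothetical joint positions
for j, q in enumerate(arm_qpos):
    idx = STATE_VEC_IDX_MAPPING[f"right_arm_joint_{j}_pos"]
    state[idx], mask[idx] = q, 1.0

idx = STATE_VEC_IDX_MAPPING["right_gripper_open"]  # alias of right_gripper_joint_0_pos
state[idx], mask[idx] = 1.0, 1.0                   # gripper fully open
```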
configs/zero2.json ADDED
@@ -0,0 +1,14 @@
+{
+    "bf16": {
+        "enabled": "auto"
+    },
+    "train_micro_batch_size_per_gpu": "auto",
+    "train_batch_size": "auto",
+    "gradient_accumulation_steps": "auto",
+    "zero_optimization": {
+        "stage": 2,
+        "overlap_comm": true,
+        "contiguous_gradients": true,
+        "sub_group_size": 1e9
+    }
+}
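
This DeepSpeed ZeRO-2 config leaves precision and batch sizes as "auto", which is typically resolved by a higher-level trainer. As one common wiring (an assumption, not necessarily how train/train.py does it), the Hugging Face Trainer integration fills the "auto" fields from its own arguments:

```python
# Sketch: passing configs/zero2.json through the HF Trainer DeepSpeed integration.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="checkpoints/rdt",      # hypothetical output dir
    per_device_train_batch_size=32,    # resolves "train_micro_batch_size_per_gpu": "auto"
    gradient_accumulation_steps=1,     # resolves "gradient_accumulation_steps": "auto"
    bf16=True,                         # resolves "bf16.enabled": "auto"
    deepspeed="configs/zero2.json",    # hand the ZeRO-2 config to DeepSpeed
)
```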
data/aloha/hdf5totfrecords.py ADDED
@@ -0,0 +1,98 @@
+import tensorflow as tf
+import h5py
+import os
+import fnmatch
+import cv2
+import numpy as np
+from tqdm import tqdm
+
+def decode_img(img):
+    return cv2.cvtColor(cv2.imdecode(np.frombuffer(img, np.uint8), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
+
+def decode_all_imgs(imgs):
+    return [decode_img(img) for img in imgs]
+
+def _bytes_feature(value):
+    """Returns a bytes_list from a string / byte."""
+    if isinstance(value, type(tf.constant(0))):
+        value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
+    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
+
+def _bool_feature(value):
+    """Returns a bool_list from a boolean."""
+    return tf.train.Feature(int64_list=tf.train.Int64List(value=[int(value)]))
+
+def serialize_example(action, base_action, qpos, qvel, cam_high, cam_left_wrist, cam_right_wrist, cam_low, instruction, terminate_episode):
+    if base_action is not None:
+        feature = {
+            'action': _bytes_feature(tf.io.serialize_tensor(action)),
+            'base_action': _bytes_feature(tf.io.serialize_tensor(base_action)),
+            'qpos': _bytes_feature(tf.io.serialize_tensor(qpos)),
+            'qvel': _bytes_feature(tf.io.serialize_tensor(qvel)),
+            'cam_high': _bytes_feature(tf.io.serialize_tensor(cam_high)),
+            'cam_left_wrist': _bytes_feature(tf.io.serialize_tensor(cam_left_wrist)),
+            'cam_right_wrist': _bytes_feature(tf.io.serialize_tensor(cam_right_wrist)),
+            'instruction': _bytes_feature(instruction),
+            'terminate_episode': _bool_feature(terminate_episode)
+        }
+    else:
+        feature = {
+            'action': _bytes_feature(tf.io.serialize_tensor(action)),
+            'qpos': _bytes_feature(tf.io.serialize_tensor(qpos)),
+            'qvel': _bytes_feature(tf.io.serialize_tensor(qvel)),
+            'cam_high': _bytes_feature(tf.io.serialize_tensor(cam_high)),
+            'cam_left_wrist': _bytes_feature(tf.io.serialize_tensor(cam_left_wrist)),
+            'cam_right_wrist': _bytes_feature(tf.io.serialize_tensor(cam_right_wrist)),
+            'cam_low': _bytes_feature(tf.io.serialize_tensor(cam_low)),
+            'instruction': _bytes_feature(instruction),
+            'terminate_episode': _bool_feature(terminate_episode)
+        }
+    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
+    return example_proto.SerializeToString()
+
+def write_tfrecords(root_dir, out_dir):
+    if not os.path.exists(out_dir):
+        os.makedirs(out_dir)
+    num_files = 0
+    for root, dirs, files in os.walk(root_dir):
+        num_files += len(fnmatch.filter(files, '*.hdf5'))
+    with tqdm(total=num_files) as pbar:
+        for root, dirs, files in os.walk(root_dir):
+            for filename in fnmatch.filter(files, '*.hdf5'):
+                filepath = os.path.join(root, filename)
+                with h5py.File(filepath, 'r') as f:
+                    if 'instruction' not in f:
+                        continue
+                    pbar.update(1)
+                    output_dir = os.path.join(out_dir, os.path.relpath(root, root_dir))
+                    if not os.path.exists(output_dir):
+                        os.makedirs(output_dir)
+                    print(f"Writing TFRecords to {output_dir}")
+                    tfrecord_path = os.path.join(output_dir, filename.replace('.hdf5', '.tfrecord'))
+                    with tf.io.TFRecordWriter(tfrecord_path) as writer:
+                        num_episodes = f['action'].shape[0]
+                        for i in range(num_episodes):
+                            action = f['action'][i]
+                            if 'base_action' in f:
+                                base_action = f['base_action'][i]
+                            else:
+                                base_action = None
+                            qpos = f['observations']['qpos'][i]
+                            qvel = f['observations']['qvel'][i]
+                            cam_high = decode_img(f['observations']['images']['cam_high'][i])
+                            cam_left_wrist = decode_img(f['observations']['images']['cam_left_wrist'][i])
+                            cam_right_wrist = decode_img(f['observations']['images']['cam_right_wrist'][i])
+                            if 'cam_low' in f['observations']['images']:
+                                cam_low = decode_img(f['observations']['images']['cam_low'][i])
+                            else:
+                                cam_low = None
+                            instruction = f['instruction'][()]
+                            terminate_episode = i == num_episodes - 1
+                            serialized_example = serialize_example(action, base_action, qpos, qvel, cam_high, cam_left_wrist, cam_right_wrist, cam_low, instruction, terminate_episode)
+                            writer.write(serialized_example)
+                    print(f"TFRecords written to {tfrecord_path}")
+    print(f"TFRecords written to {out_dir}")
+
+root_dir = '../datasets/aloha/'
+output_dir = '../datasets/aloha/tfrecords/'
+
+write_tfrecords(root_dir, output_dir)
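
For completeness, a sketch of reading one of these TFRecords back, mirroring the keys written by serialize_example for the `base_action` layout (the tensor dtypes and the file path are assumptions, not part of the committed code):

```python
# Sketch: parse records written by write_tfrecords above (base_action layout).
import tensorflow as tf

feature_description = {
    'action': tf.io.FixedLenFeature([], tf.string),
    'base_action': tf.io.FixedLenFeature([], tf.string),
    'qpos': tf.io.FixedLenFeature([], tf.string),
    'qvel': tf.io.FixedLenFeature([], tf.string),
    'cam_high': tf.io.FixedLenFeature([], tf.string),
    'cam_left_wrist': tf.io.FixedLenFeature([], tf.string),
    'cam_right_wrist': tf.io.FixedLenFeature([], tf.string),
    'instruction': tf.io.FixedLenFeature([], tf.string),
    'terminate_episode': tf.io.FixedLenFeature([], tf.int64),
}

def parse_step(proto):
    parsed = tf.io.parse_single_example(proto, feature_description)
    qpos = tf.io.parse_tensor(parsed['qpos'], out_type=tf.float64)        # dtype assumed
    cam_high = tf.io.parse_tensor(parsed['cam_high'], out_type=tf.uint8)  # decoded RGB frames
    return qpos, cam_high, parsed['instruction']

# Hypothetical file path under the output directory used above.
ds = tf.data.TFRecordDataset('../datasets/aloha/tfrecords/episode_0.tfrecord')
for qpos, cam_high, instruction in ds.map(parse_step).take(1):
    print(qpos.shape, cam_high.shape, instruction.numpy())
```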
data/aloha/unzip_data.sh ADDED
@@ -0,0 +1,3 @@
+cd ../datasets/aloha/
+
+unzip aloha_mobile.zip
data/bridgev2/bridgedata_numpy_to_tfrecord.py ADDED
@@ -0,0 +1,174 @@
+"""
+Converts data from the BridgeData numpy format to TFRecord format.
+
+Consider the following directory structure for the input data:
+
+    bridgedata_numpy/
+        rss/
+            toykitchen2/
+                set_table/
+                    00/
+                        train/
+                            out.npy
+                        val/
+                            out.npy
+        icra/
+            ...
+
+The --depth parameter controls how much of the data to process at the
+--input_path; for example, if --depth=5, then --input_path should be
+"bridgedata_numpy", and all data will be processed. If --depth=3, then
+--input_path should be "bridgedata_numpy/rss/toykitchen2", and only data
+under "toykitchen2" will be processed.
+
+The same directory structure will be replicated under --output_path. For
+example, in the second case, the output will be written to
+"{output_path}/set_table/00/...".
+
+Can read/write directly from/to Google Cloud Storage.
+
+Written by Kevin Black ([email protected]).
+"""
+import os
+from multiprocessing import Pool
+
+import numpy as np
+import tensorflow as tf
+import tqdm
+from absl import app, flags, logging
+import pickle
+from multiprocessing import cpu_count
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string("input_path", None, "Input path", required=True)
+flags.DEFINE_string("output_path", None, "Output path", required=True)
+flags.DEFINE_integer(
+    "depth",
+    5,
+    "Number of directories deep to traverse. Looks for {input_path}/dir_1/dir_2/.../dir_{depth-1}/train/out.npy",
+)
+flags.DEFINE_bool("overwrite", False, "Overwrite existing files")
+num_workers = 8
+flags.DEFINE_integer("num_workers", num_workers, "Number of threads to use")
+
+print(f"using {num_workers} workers")
+
+def tensor_feature(value):
+    return tf.train.Feature(
+        bytes_list=tf.train.BytesList(value=[tf.io.serialize_tensor(value).numpy()])
+    )
+
+def _bytes_feature(value):
+    """Returns a bytes_list from a string / byte."""
+    if isinstance(value, type(tf.constant(0))):
+        value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
+    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.encode('utf-8')]))
+
+def _strings_feature(string_list):
+    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[s.encode('utf-8') for s in string_list]))
+
+def _bool_feature(value):
+    """Returns a bool_list from a boolean."""
+    return tf.train.Feature(int64_list=tf.train.Int64List(value=[int(value)]))
+
+
+def process(path):
+    # with tf.io.gfile.GFile(path, "rb") as f:
+    #     arr = np.load(f, allow_pickle=True)
+    try:
+        with tf.io.gfile.GFile(path, "rb") as f:
+            arr = np.load(path, allow_pickle=True)
+    except Exception as e:
+        print(f"Error loading {path}: {e}")
+        return
+
+    dirname = os.path.dirname(os.path.abspath(path))
+    outpath = os.path.join(FLAGS.output_path, *dirname.split(os.sep)[-FLAGS.depth:])
+
+    if tf.io.gfile.exists(outpath):
+        if FLAGS.overwrite:
+            logging.info(f"Deleting {outpath}")
+            tf.io.gfile.rmtree(outpath)
+        else:
+            logging.info(f"Skipping {outpath}")
+            return
+
+    if len(arr) == 0:
+        logging.info(f"Skipping {path}, empty")
+        return
+
+    tf.io.gfile.makedirs(outpath)
+
+    for i, traj in enumerate(arr):
+        write_path = f"{outpath}/out_{i}.tfrecord"
+        with tf.io.TFRecordWriter(write_path) as writer:
+            truncates = np.zeros(len(traj["actions"]), dtype=np.bool_)
+            truncates[-1] = True
+            frames_num = len(traj["observations"])
+            # remove empty strings
+            traj["language"] = [x for x in traj["language"] if x != ""]
+            if len(traj["language"]) == 0:
+                traj["language"] = [""]
+            instr = traj["language"][0]
+            if len(traj["language"]) > 2:
+                print(len(traj["language"]))
+            for i in range(frames_num):
+                tf_features = {
+                    "observations/images0": tensor_feature(
+                        np.array(
+                            [traj["observations"][i]["images0"]],
+                            dtype=np.uint8,
+                        )
+                    ),
+                    "observations/state": tensor_feature(
+                        np.array(
+                            [traj["observations"][i]["state"]],
+                            dtype=np.float32,
+                        )
+                    ),
+                    "observations/qpos": tensor_feature(
+                        np.array(
+                            [traj["observations"][i]["qpos"]],
+                            dtype=np.float32,
+                        )
+                    ),
+                    "observations/eef_transform": tensor_feature(
+                        np.array(
+                            [traj["observations"][i]["eef_transform"]],
+                            dtype=np.float32,
+                        )
+                    ),
+                    "language": _bytes_feature(instr),
+                    "actions": tensor_feature(
+                        np.array(traj["actions"][i], dtype=np.float32)
+                    ),
+                    "truncates": _bool_feature(i == frames_num - 1),
+                }
+                example = tf.train.Example(
+                    features=tf.train.Features(feature=tf_features)
+                )
+                writer.write(example.SerializeToString())
+
+
+def main(_):
+    assert FLAGS.depth >= 1
+
+    paths = tf.io.gfile.glob(
+        tf.io.gfile.join(FLAGS.input_path, *("*" * (FLAGS.depth - 1)))
+    )
+    paths = [f"{p}/train/out.npy" for p in paths] + [f"{p}/val/out.npy" for p in paths]
+    # num_episodes = 0
+    # for dirpath in paths:
+    #     with tf.io.gfile.GFile(dirpath, "rb") as f:
+    #         arr = np.load(dirpath, allow_pickle=True)
+    #     num_episodes += len(arr)
+    # print(num_episodes)
+    with Pool(FLAGS.num_workers) as p:
+        list(tqdm.tqdm(p.imap(process, paths), total=len(paths)))
+
+
+if __name__ == "__main__":
+    app.run(main)
data/bridgev2/bridgedata_raw_to_numpy.py ADDED
@@ -0,0 +1,316 @@
1
+ """
2
+ Converts data from the BridgeData raw format to numpy format.
3
+
4
+ Consider the following directory structure for the input data:
5
+
6
+ bridgedata_raw/
7
+ rss/
8
+ toykitchen2/
9
+ set_table/
10
+ 00/
11
+ 2022-01-01_00-00-00/
12
+ collection_metadata.json
13
+ config.json
14
+ diagnostics.png
15
+ raw/
16
+ traj_group0/
17
+ traj0/
18
+ obs_dict.pkl
19
+ policy_out.pkl
20
+ agent_data.pkl
21
+ images0/
22
+ im_0.jpg
23
+ im_1.jpg
24
+ ...
25
+ ...
26
+ ...
27
+ 01/
28
+ ...
29
+
30
+ The --depth parameter controls how much of the data to process at the
31
+ --input_path; for example, if --depth=5, then --input_path should be
32
+ "bridgedata_raw", and all data will be processed. If --depth=3, then
33
+ --input_path should be "bridgedata_raw/rss/toykitchen2", and only data
34
+ under "toykitchen2" will be processed.
35
+
36
+ The same directory structure will be replicated under --output_path. For
37
+ example, in the second case, the output will be written to
38
+ "{output_path}/set_table/00/...".
39
+
40
+ Squashes images to 128x128.
41
+
42
+ Can write directly to Google Cloud Storage, but not read from it.
43
+
44
+ Written by Kevin Black ([email protected]).
45
+ """
46
+ import copy
47
+ import glob
48
+ import os
49
+ import pickle
50
+ import random
51
+ from collections import defaultdict
52
+ from datetime import datetime
53
+ from functools import partial
54
+ from multiprocessing import Pool
55
+
56
+ import numpy as np
57
+ import tensorflow as tf
58
+ import tqdm
59
+ from absl import app, flags, logging
60
+ from PIL import Image
61
+
62
+ FLAGS = flags.FLAGS
63
+
64
+ flags.DEFINE_string("input_path", None, "Input path", required=True)
65
+ flags.DEFINE_string("output_path", None, "Output path", required=True)
66
+ flags.DEFINE_integer(
67
+ "depth",
68
+ 5,
69
+ "Number of directories deep to traverse to the dated directory. Looks for"
70
+ "{input_path}/dir_1/dir_2/.../dir_{depth-1}/2022-01-01_00-00-00/...",
71
+ )
72
+ flags.DEFINE_bool("overwrite", False, "Overwrite existing files")
73
+ flags.DEFINE_float(
74
+ "train_proportion", 0.9, "Proportion of data to use for training (rather than val)"
75
+ )
76
+ flags.DEFINE_integer("num_workers", 8, "Number of threads to use")
77
+ flags.DEFINE_integer("im_size", 128, "Image size")
78
+
79
+
80
+ def squash(path):
81
+ im = Image.open(path)
82
+ # im = im.resize((FLAGS.im_size, FLAGS.im_size), Image.Resampling.LANCZOS)
83
+ out = np.asarray(im).astype(np.uint8)
84
+ return out
85
+
86
+
87
+ def process_images(path): # processes images at a trajectory level
88
+ names = sorted(
89
+ [x for x in os.listdir(path) if "images" in x and not "depth" in x],
90
+ key=lambda x: int(x.split("images")[1]),
91
+ )
92
+ image_path = [
93
+ os.path.join(path, x)
94
+ for x in os.listdir(path)
95
+ if "images" in x and not "depth" in x
96
+ ]
97
+ image_path = sorted(image_path, key=lambda x: int(x.split("images")[1]))
98
+
99
+ images_out = defaultdict(list)
100
+ if len(image_path) == 0:
101
+ return None, None
102
+
103
+ tlen = len(glob.glob(image_path[0] + "/im_*.jpg"))
104
+
105
+ for i, name in enumerate(names):
106
+ for t in range(tlen):
107
+ images_out[name].append(squash(image_path[i] + "/im_{}.jpg".format(t)))
108
+
109
+ images_out = dict(images_out)
110
+
111
+ obs, next_obs = dict(), dict()
112
+
113
+ for n in names:
114
+ obs[n] = images_out[n][:-1]
115
+ next_obs[n] = images_out[n][1:]
116
+ return obs, next_obs
117
+
118
+
119
+ def process_state(path):
120
+ fp = os.path.join(path, "obs_dict.pkl")
121
+ with open(fp, "rb") as f:
122
+ x = pickle.load(f)
123
+ qpos = None if "qpos" not in x.keys() else x["qpos"]
124
+ qvel = None if "qvel" not in x.keys() else x["qvel"]
125
+ eef_transform = None if "eef_transform" not in x.keys() else x["eef_transform"]
126
+ return x["full_state"][:-1], x["full_state"][1:], qpos, qvel, eef_transform
127
+
128
+ def process_time(path):
129
+ fp = os.path.join(path, "obs_dict.pkl")
130
+ with open(fp, "rb") as f:
131
+ x = pickle.load(f)
132
+ return x["time_stamp"][:-1], x["time_stamp"][1:]
133
+
134
+
135
+ def process_actions(path): # gets actions
136
+ fp = os.path.join(path, "policy_out.pkl")
137
+ with open(fp, "rb") as f:
138
+ act_list = pickle.load(f)
139
+ if isinstance(act_list[0], dict):
140
+ act_list = [x["actions"] for x in act_list]
141
+ return act_list
142
+
143
+
144
+ # processes each data collection attempt
145
+ def process_dc(path, train_ratio=0.9):
146
+ # a mystery left by the greats of the past
147
+ if "lmdb" in path:
148
+ logging.warning(f"Skipping {path} because uhhhh lmdb?")
149
+ return [], [], [], []
150
+
151
+ all_dicts_train = list()
152
+ all_dicts_test = list()
153
+ all_rews_train = list()
154
+ all_rews_test = list()
155
+
156
+ # Data collected prior to 7-23 has a delay of 1, otherwise a delay of 0
157
+ date_time = datetime.strptime(path.split("/")[-1], "%Y-%m-%d_%H-%M-%S")
158
+ latency_shift = date_time < datetime(2021, 7, 23)
159
+
160
+ search_path = os.path.join(path, "raw", "traj_group*", "traj*")
161
+ all_traj = glob.glob(search_path)
162
+ if all_traj == []:
163
+ logging.info(f"no trajs found in {search_path}")
164
+ return [], [], [], []
165
+
166
+ random.shuffle(all_traj)
167
+
168
+ num_traj = len(all_traj)
169
+ for itraj, tp in tqdm.tqdm(enumerate(all_traj)):
170
+ try:
171
+ out = dict()
172
+
173
+ ld = os.listdir(tp)
174
+
175
+ assert "obs_dict.pkl" in ld, tp + ":" + str(ld)
176
+ assert "policy_out.pkl" in ld, tp + ":" + str(ld)
177
+ # assert "agent_data.pkl" in ld, tp + ":" + str(ld) # not used
178
+
179
+ obs, next_obs = process_images(tp)
180
+ if obs is None:
181
+ return
182
+ acts = process_actions(tp)
183
+ state, next_state, qpos, qvel, eef_transform = process_state(tp)
184
+ time_stamp, next_time_stamp = process_time(tp)
185
+ term = [0] * len(acts)
186
+ if "lang.txt" in ld:
187
+ with open(os.path.join(tp, "lang.txt")) as f:
188
+ lang = list(f)
189
+ lang = [l.strip() for l in lang if "confidence" not in l]
190
+ else:
191
+ # empty string is a placeholder for data with no language label
192
+ lang = [""]
193
+
194
+ out["observations"] = obs
195
+ out["observations"]["state"] = state
196
+ out["observations"]["time_stamp"] = time_stamp
197
+ if qpos is not None:
198
+ out["observations"]["qpos"] = qpos
199
+ else:
200
+ return None, None, None, None
201
+ if qvel is not None:
202
+ out["observations"]["qvel"] = qvel
203
+ if eef_transform is not None:
204
+ out["observations"]["eef_transform"] = eef_transform
205
+ out["next_observations"] = next_obs
206
+ out["next_observations"]["state"] = next_state
207
+ out["next_observations"]["time_stamp"] = next_time_stamp
208
+
209
+
210
+ out["observations"] = [
211
+ dict(zip(out["observations"], t))
212
+ for t in zip(*out["observations"].values())
213
+ ]
214
+ out["next_observations"] = [
215
+ dict(zip(out["next_observations"], t))
216
+ for t in zip(*out["next_observations"].values())
217
+ ]
218
+
219
+ out["actions"] = acts
220
+ out["terminals"] = term
221
+ out["language"] = lang
222
+
223
+ # shift the actions according to camera latency
224
+ if latency_shift:
225
+ out["observations"] = out["observations"][1:]
226
+ out["next_observations"] = out["next_observations"][1:]
227
+ out["actions"] = out["actions"][:-1]
228
+ out["terminals"] = term[:-1]
229
+
230
+ labeled_rew = copy.deepcopy(out["terminals"])[:]
231
+ labeled_rew[-2:] = [1, 1]
232
+
233
+ traj_len = len(out["observations"])
234
+ assert len(out["next_observations"]) == traj_len
235
+ assert len(out["actions"]) == traj_len
236
+ assert len(out["terminals"]) == traj_len
237
+ assert len(labeled_rew) == traj_len
238
+
239
+ if itraj < int(num_traj * train_ratio):
240
+ all_dicts_train.append(out)
241
+ all_rews_train.append(labeled_rew)
242
+ else:
243
+ all_dicts_test.append(out)
244
+ all_rews_test.append(labeled_rew)
245
+ except FileNotFoundError as e:
246
+ logging.error(e)
247
+ continue
248
+ except AssertionError as e:
249
+ logging.error(e)
250
+ continue
251
+
252
+ return all_dicts_train, all_dicts_test, all_rews_train, all_rews_test
253
+
254
+
255
+ def make_numpy(path, train_proportion):
256
+ dirname = os.path.abspath(path)
257
+ outpath = os.path.join(
258
+ FLAGS.output_path, *dirname.split(os.sep)[-(max(FLAGS.depth - 1, 1)) :]
259
+ )
260
+
261
+ if os.path.exists(outpath):
262
+ if FLAGS.overwrite:
263
+ logging.info(f"Deleting {outpath}")
264
+ tf.io.gfile.rmtree(outpath)
265
+ else:
266
+ logging.info(f"Skipping {outpath}")
267
+ return
268
+
269
+ outpath_train = tf.io.gfile.join(outpath, "train")
270
+ outpath_val = tf.io.gfile.join(outpath, "val")
271
+ tf.io.gfile.makedirs(outpath_train)
272
+ tf.io.gfile.makedirs(outpath_val)
273
+
274
+ lst_train = []
275
+ lst_val = []
276
+ rew_train_l = []
277
+ rew_val_l = []
278
+
279
+ for dated_folder in os.listdir(path):
280
+ curr_train, curr_val, rew_train, rew_val = process_dc(
281
+ os.path.join(path, dated_folder), train_ratio=train_proportion
282
+ )
283
+ if curr_train is None:
284
+ continue
285
+ lst_train.extend(curr_train)
286
+ lst_val.extend(curr_val)
287
+ rew_train_l.extend(rew_train)
288
+ rew_val_l.extend(rew_val)
289
+
290
+ if len(lst_train) == 0 or len(lst_val) == 0:
291
+ return
292
+
293
+ with tf.io.gfile.GFile(tf.io.gfile.join(outpath_train, "out.npy"), "wb") as f:
294
+ np.save(f, lst_train)
295
+ with tf.io.gfile.GFile(tf.io.gfile.join(outpath_val, "out.npy"), "wb") as f:
296
+ np.save(f, lst_val)
297
+
298
+ # doesn't seem like these are ever used anymore
299
+ # np.save(os.path.join(outpath_train, "out_rew.npy"), rew_train_l)
300
+ # np.save(os.path.join(outpath_val, "out_rew.npy"), rew_val_l)
301
+
302
+
303
+ def main(_):
304
+ assert FLAGS.depth >= 1
305
+
306
+ # each path is a directory that contains dated directories
307
+ paths = glob.glob(os.path.join(FLAGS.input_path, *("*" * (FLAGS.depth - 1))))
308
+
309
+ worker_fn = partial(make_numpy, train_proportion=FLAGS.train_proportion)
310
+
311
+ with Pool(FLAGS.num_workers) as p:
312
+ list(tqdm.tqdm(p.imap(worker_fn, paths), total=len(paths)))
313
+
314
+
315
+ if __name__ == "__main__":
316
+ app.run(main)
data/bridgev2/download.sh ADDED
@@ -0,0 +1,13 @@
+# Download the dataset to ../datasets/bridgev2
+mkdir -p ../datasets/bridgev2
+wget -O ../datasets/bridgev2/demos_8_17.zip https://rail.eecs.berkeley.edu/datasets/bridge_release/data/demos_8_17.zip
+mkdir -p ../datasets/bridgev2/raw
+# Unzip the dataset
+unzip '../datasets/bridgev2/*.zip' -d ../datasets/bridgev2/raw
+# Convert the dataset to numpy
+python bridgedata_raw_to_numpy.py --input ../datasets/bridgev2/raw --output ../datasets/bridgev2/npy
+# Convert the dataset to tfrecords
+python bridgedata_numpy_to_tfrecord.py --input ../datasets/bridgev2/npy --output ../datasets/bridgev2/tfrecords
+# Remove the raw data and numpy data
+rm -rf ../datasets/bridgev2/raw
+rm -rf ../datasets/bridgev2/npy
data/calvin/download.sh ADDED
@@ -0,0 +1,19 @@
+#!/bin/bash
+
+echo "Downloading CALVIN dataset..."
+
+# Create the calvin folder in ../datasets/calvin/
+mkdir -p ../datasets/calvin/
+
+cd ../datasets/calvin/
+
+# You can use this for faster downloading
+# aria2c -x 16 -s 16 http://calvin.cs.uni-freiburg.de/dataset/task_ABC_D.zip
+
+wget http://calvin.cs.uni-freiburg.de/dataset/task_ABC_D.zip
+
+echo "Unzipping CALVIN dataset..."
+
+unzip task_ABC_D.zip
+
+echo "Done downloading and unzipping CALVIN dataset."
data/calvin/hdf5totfrecords.py ADDED
@@ -0,0 +1,92 @@
+import tensorflow as tf
+import os
+import numpy as np
+from tqdm import tqdm
+
+
+def _bytes_feature(value):
+    """Returns a bytes_list from a string / byte."""
+    if isinstance(value, type(tf.constant(0))):
+        value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
+    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
+
+
+def _bool_feature(value):
+    """Returns a bool_list from a boolean."""
+    return tf.train.Feature(int64_list=tf.train.Int64List(value=[int(value)]))
+
+
+def serialize_example(action, robot_obs, rgb_static, rgb_gripper, instruction, terminate_episode):
+    # Features for fixed-length fields
+    feature = {
+        'action': _bytes_feature(tf.io.serialize_tensor(action)),
+        'robot_obs': _bytes_feature(tf.io.serialize_tensor(robot_obs)),
+        'rgb_static': _bytes_feature(tf.io.serialize_tensor(rgb_static)),
+        'rgb_gripper': _bytes_feature(tf.io.serialize_tensor(rgb_gripper)),
+        'terminate_episode': _bool_feature(terminate_episode),
+        'instruction': _bytes_feature(instruction),
+    }
+
+    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
+    return example_proto.SerializeToString()
+
+
+def write_tfrecords(root_dir, out_dir):
+    if not os.path.exists(out_dir):
+        os.makedirs(out_dir)
+
+    # Get the language annotations and the corresponding episode indices
+    f = np.load(os.path.join(root_dir, "lang_annotations/auto_lang_ann.npy"), allow_pickle=True)
+    lang = f.item()['language']['ann']
+    lang = np.array([x.encode('utf-8') for x in lang])
+    lang_start_end_idx = f.item()['info']['indx']
+    num_ep = len(lang_start_end_idx)
+
+    with tqdm(total=num_ep) as pbar:
+        for episode_idx, (start_idx, end_idx) in enumerate(lang_start_end_idx):
+            pbar.update(1)
+
+            step_files = [
+                f"episode_{str(i).zfill(7)}.npz"
+                for i in range(start_idx, end_idx + 1)
+            ]
+            action = []
+            robot_obs = []
+            rgb_static = []
+            rgb_gripper = []
+            instr = lang[episode_idx]
+            for step_file in step_files:
+                filepath = os.path.join(root_dir, step_file)
+                f = np.load(filepath)
+                # Get the relevant fields
+                action.append(f['actions'])
+                robot_obs.append(f['robot_obs'])
+                rgb_static.append(f['rgb_static'])
+                rgb_gripper.append(f['rgb_gripper'])
+
+            tfrecord_path = os.path.join(out_dir, f'{episode_idx:07d}.tfrecord')
+            print(f"Writing TFRecords to {tfrecord_path}")
+            with tf.io.TFRecordWriter(tfrecord_path) as writer:
+                for i in range(len(step_files)):
+                    serialized_example = serialize_example(
+                        action[i], robot_obs[i], rgb_static[i], rgb_gripper[i], instr, i == len(step_files) - 1
+                    )
+                    writer.write(serialized_example)
+
+output_dirs = [
+    '../datasets/calvin/tfrecords/training',
+    '../datasets/calvin/tfrecords/validation'
+]
+
+for output_dir in output_dirs:
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+
+root_dirs = [
+    '../datasets/calvin/task_ABC_D/training',
+    '../datasets/calvin/task_ABC_D/validation'
+]
+
+for root_dir, output_dir in zip(root_dirs, output_dirs):
+    print(f"Writing TFRecords to {output_dir}")
+    write_tfrecords(root_dir, output_dir)
data/rh20t/hdf5totfrecords.py ADDED
@@ -0,0 +1,200 @@
1
+ import numpy as np
2
+ import os
3
+ import cv2
4
+ from multiprocessing import Pool, cpu_count, current_process
5
+ import tensorflow as tf
6
+ from tqdm import tqdm
7
+ import json
8
+
9
+ def _parse_function(proto):
10
+ # Define how to parse the data here.
11
+ feature_description = {
12
+ 'joint': tf.io.FixedLenFeature([], tf.string),
13
+ 'image': tf.io.FixedLenFeature([], tf.string),
14
+ 'instruction': tf.io.FixedLenFeature([], tf.string),
15
+ 'terminate_episode': tf.io.FixedLenFeature([], tf.int64),
16
+ 'gripper': tf.io.FixedLenFeature([], tf.string, default_value=""),
17
+ 'tcp': tf.io.FixedLenFeature([], tf.string, default_value=""),
18
+ 'tcp_base': tf.io.FixedLenFeature([], tf.string, default_value="")
19
+ }
20
+ parsed_features = tf.io.parse_single_example(proto, feature_description)
21
+ # Parse tensors
22
+ parsed_features['joint'] = tf.io.parse_tensor(parsed_features['joint'], out_type=tf.float64)
23
+ parsed_features['image'] = tf.io.parse_tensor(parsed_features['image'], out_type=tf.uint8)
24
+ parsed_features['instruction'] = tf.io.parse_tensor(parsed_features['instruction'], out_type=tf.string)
25
+ parsed_features['gripper'] = tf.cond(
26
+ tf.math.equal(parsed_features['gripper'], ""),
27
+ lambda: tf.constant([], dtype=tf.float64),
28
+ lambda: tf.io.parse_tensor(parsed_features['gripper'], out_type=tf.float64)
29
+ )
30
+ parsed_features['tcp'] = tf.cond(
31
+ tf.math.equal(parsed_features['tcp'], ""),
32
+ lambda: tf.constant([], dtype=tf.float64),
33
+ lambda: tf.io.parse_tensor(parsed_features['tcp'], out_type=tf.float64)
34
+ )
35
+ parsed_features['tcp_base'] = tf.cond(
36
+ tf.math.equal(parsed_features['tcp_base'], ""),
37
+ lambda: tf.constant([], dtype=tf.float64),
38
+ lambda: tf.io.parse_tensor(parsed_features['tcp_base'], out_type=tf.float64)
39
+ )
40
+ return parsed_features
41
+
42
+ def convert_color(color_file, color_timestamps):
43
+     """
+     Read every frame from the color video, resize it to 640x360,
+     and return the frames as a list of BGR arrays.
+ 
+     Args:
+     - color_file: path to the color video file;
+     - color_timestamps: the color timestamps (kept for interface compatibility; unused here).
+     """
49
+ cap = cv2.VideoCapture(color_file)
50
+ cnt = 0
51
+ frames = []
52
+ while True:
53
+ ret, frame = cap.read()
54
+ if ret:
55
+ resized_frame = cv2.resize(frame, (640, 360))
56
+ frames.append(resized_frame)
57
+ cnt += 1
58
+ else:
59
+ break
60
+ cap.release()
61
+ return frames
62
+
63
+ def _bytes_feature(value):
64
+ if isinstance(value, type(tf.constant(0))):
65
+ value = value.numpy()
66
+ return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
67
+
68
+ def _bool_feature(value):
69
+ return tf.train.Feature(int64_list=tf.train.Int64List(value=[int(value)]))
70
+
71
+ def serialize_example(joint,gripper,tcp,tcp_base,image,instruction,terminate_episode):
72
+ feature = {
73
+ 'joint': _bytes_feature(tf.io.serialize_tensor(joint)),
74
+ 'image': _bytes_feature(tf.io.serialize_tensor(image)),
75
+ 'instruction': _bytes_feature(tf.io.serialize_tensor(instruction)),
76
+ 'terminate_episode': _bool_feature(terminate_episode),
77
+ }
78
+ if gripper is not None:
79
+ feature['gripper'] = _bytes_feature(tf.io.serialize_tensor(gripper))
80
+ if tcp is not None:
81
+ feature['tcp'] = _bytes_feature(tf.io.serialize_tensor(tcp))
82
+ if tcp_base is not None:
83
+ feature['tcp_base'] = _bytes_feature(tf.io.serialize_tensor(tcp_base))
84
+ example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
85
+ return example_proto.SerializeToString()
86
+
87
+ def compress_tfrecord(tfrecord_path):
+     # Materialize all parsed examples BEFORE re-opening the file for writing;
+     # reading and overwriting the same TFRecord simultaneously would corrupt it.
+     parsed_examples = list(tf.data.TFRecordDataset(tfrecord_path).map(_parse_function))
+     if len(parsed_examples) == 0:
+         return
+     if len(parsed_examples[0]['image'].numpy().shape) <= 1:  # already compressed
+         return
+ 
+     # Write to a temporary file first, then atomically replace the original.
+     tmp_path = tfrecord_path + '.tmp'
+     with tf.io.TFRecordWriter(tmp_path) as writer:
+         for features in parsed_examples:
+             _, compressed_image = cv2.imencode('.jpg', features['image'].numpy())
+             compressed_bytes = tf.convert_to_tensor(compressed_image.tobytes(), dtype=tf.string)
+ 
+             feature_dict = {
+                 'joint': _bytes_feature(tf.io.serialize_tensor(features['joint'])),
+                 'image': _bytes_feature(tf.io.serialize_tensor(compressed_bytes)),
+                 'instruction': _bytes_feature(tf.io.serialize_tensor(features['instruction'])),
+                 'terminate_episode': tf.train.Feature(
+                     int64_list=tf.train.Int64List(value=[int(features['terminate_episode'])])),
+                 'gripper': _bytes_feature(tf.io.serialize_tensor(features['gripper'])),
+                 'tcp': _bytes_feature(tf.io.serialize_tensor(features['tcp'])),
+                 'tcp_base': _bytes_feature(tf.io.serialize_tensor(features['tcp_base']))
+             }
+             example_proto = tf.train.Example(features=tf.train.Features(feature=feature_dict))
+             writer.write(example_proto.SerializeToString())
+     os.replace(tmp_path, tfrecord_path)
+     print(f"compressed {tfrecord_path}")
120
+
121
+ def write_task(args):
122
+ task_dir,output_dir = args
123
+
124
+ all_instructions = json.load(open('./instruction.json'))
125
+ instruction = None
126
+ for taskid in list(all_instructions.keys()):
127
+ if taskid in task_dir:
128
+ instruction = all_instructions[taskid]['task_description_english']
129
+ if instruction is None:
130
+ return
131
+
132
+ if not os.path.exists(output_dir):
133
+ os.makedirs(output_dir)
134
+ joints = np.load(os.path.join(task_dir,"transformed/joint.npy"),allow_pickle=True).item()
135
+ if not os.path.exists(os.path.join(task_dir,"transformed/gripper.npy")):
136
+ return
137
+ grippers = np.load(os.path.join(task_dir,"transformed/gripper.npy"),allow_pickle=True).item()
138
+ tcps = np.load(os.path.join(task_dir,"transformed/tcp.npy"),allow_pickle=True).item()
139
+ tcp_bases = np.load(os.path.join(task_dir,"transformed/tcp_base.npy"),allow_pickle=True).item()
140
+
141
+ for camid in joints.keys():
142
+ timesteps = joints[camid]
143
+ if len(timesteps) == 0:
144
+ continue
145
+ tfrecord_path = os.path.join(output_dir,f'cam_{camid}.tfrecord')
146
+ timesteps_file = os.path.join(task_dir,f'cam_{camid}/timestamps.npy')
147
+
148
+ if not os.path.exists(timesteps_file):
149
+ continue
150
+ if os.path.exists(tfrecord_path) and os.path.getsize(tfrecord_path) > 0:
151
+ continue
152
+
153
+ timesteps_file = np.load(timesteps_file,allow_pickle=True).item()
154
+ images = convert_color(os.path.join(task_dir,f'cam_{camid}/color.mp4'),timesteps_file['color'])
155
+ if len(timesteps) != len(images):  # known RH20T issue: timestamp and frame counts can mismatch; skip this camera
156
+ continue
157
+ with tf.io.TFRecordWriter(tfrecord_path) as writer:
158
+ for i,timestep in enumerate(timesteps):
159
+ # image = cv2.imread(os.path.join(img_dir,f"{timestep}.jpg"))
160
+ image = cv2.imencode('.jpg', images[i])[1].tobytes()
161
+ joint_pos = joints[camid][timestep]
162
+ tcp_item = next((item for item in tcps[camid] if item['timestamp'] == timestep), None)
+ tcp_base_item = next((item for item in tcp_bases[camid] if item['timestamp'] == timestep), None)
+ if tcp_item is None or tcp_base_item is None:
+     continue  # skip frames whose TCP record is missing
+ tcp = tcp_item['tcp']
+ tcp_base = tcp_base_item['tcp']
164
+ if timestep not in grippers[camid]:
165
+ gripper_pos = None
166
+ else:
167
+ gripper_pos = grippers[camid][timestep]['gripper_info']
168
+ terminate_episode = i == len(timesteps) - 1
169
+ # read from instruction.json
170
+ serialized_example = serialize_example(joint_pos,gripper_pos,tcp,tcp_base,image,instruction,terminate_episode)
171
+ writer.write(serialized_example)
172
+
173
+
174
+ def write_tfrecords(root_dir,output_dir,num_processes = None):
175
+ if not os.path.exists(output_dir):
176
+ os.makedirs(output_dir)
177
+ if num_processes is None:
178
+ num_processes = cpu_count()
179
+
180
+ num_files = 0
181
+ args = []
182
+ for dirs in os.listdir(root_dir):
183
+ for task in os.listdir(os.path.join(root_dir,dirs)):
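+ # Skip human-demonstration folders; only robot data is converted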
184
+ if 'human' in task:
185
+ continue
186
+ task_dir = os.path.join(root_dir,dirs,task)
187
+ joint_path = os.path.join(task_dir,"transformed/joint.npy")
188
+ if not os.path.exists(joint_path):
189
+ continue
190
+ num_files += 1
191
+ task_out = os.path.join(output_dir,dirs,task)
192
+ os.makedirs(task_out,exist_ok=True)
193
+ args.append((task_dir,task_out))
194
+
195
+ with tqdm(total=num_files, desc="Processing files") as pbar:
196
+ with Pool(num_processes) as pool:
197
+ for _ in pool.imap_unordered(write_task, args):
198
+ pbar.update(1)
199
+
200
+ write_tfrecords('../datasets/rh20t/raw_data/','../datasets/rh20t/tfrecords/')
data/roboset/download.py ADDED
@@ -0,0 +1,42 @@
1
+ import requests
2
+ import os
+ import sys
3
+ from tqdm import tqdm
4
+
5
+ links = []
6
+ with open('links.txt', 'r', encoding='utf-8') as file:
7
+ for line in file:
8
+ links.append(line.strip())
9
+
10
+ download_dir = "../datasets/roboset"
11
+ os.makedirs(download_dir, exist_ok=True)
12
+
13
+ for link in links:
14
+ filename = os.path.basename(link)
15
+ filepath = os.path.join(download_dir, filename)
16
+ print(f"Downloading {filename} from {link}")
17
+
18
+ response = requests.get(link, stream=True)
19
+ total_size_in_bytes = int(response.headers.get('content-length', 0))
20
+ block_size = 1024
21
+
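+ # Skip files that already exist locally with the expected size (crude resume support)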
22
+ if os.path.exists(filepath):
23
+ local_size = os.path.getsize(filepath)
24
+ if local_size == total_size_in_bytes:
25
+ print(f"{filename} already exists")
26
+ continue
27
+
28
+ progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
29
+
30
+ with open(filepath, 'wb') as f:
31
+ for data in response.iter_content(block_size):
32
+ progress_bar.update(len(data))
33
+ f.write(data)
34
+
35
+ progress_bar.close()
36
+
37
+ if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
+     print(f"ERROR: incomplete download of {filename}")
+     sys.exit(1)  # exit non-zero so the wrapper script (download.sh) retries
39
+
40
+ print(f"Downloaded {filename}")
41
+
42
+ print("All files processed.")
data/roboset/download.sh ADDED
@@ -0,0 +1,21 @@
1
+ #!/bin/bash
2
+
3
+ while true; do
4
+ python download.py
5
+ EXIT_CODE=$?
6
+ if [ $EXIT_CODE -ne 0 ]; then
7
+ echo "Download exited with code $EXIT_CODE. Restarting..."
8
+ else
9
+ echo "Download exited with code 0. Not restarting."
10
+ break
11
+ fi
12
+ done
13
+
14
+ # Unzip all the files in the ../datasets/roboset/ directory
15
+ cd ../datasets/roboset/
16
+ for file in *.tar.gz; do
17
+ tar -xzvf "$file"
18
+ done
19
+
20
+ ## Convert the dataset to TFRecords
+ # (return to data/roboset first so the conversion script's relative paths resolve correctly)
+ cd -
+ python h5totfrecords.py
data/roboset/h5totfrecords.py ADDED
@@ -0,0 +1,82 @@
1
+ import tensorflow as tf
2
+ import h5py
3
+ import os
4
+ import fnmatch
5
+ from tqdm import tqdm
6
+ from multiprocessing import Pool, cpu_count, current_process
7
+
8
+ def _bytes_feature(value):
9
+ if isinstance(value, type(tf.constant(0))):
10
+ value = value.numpy()
11
+ return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
12
+
13
+ def _bool_feature(value):
14
+ return tf.train.Feature(int64_list=tf.train.Int64List(value=[int(value)]))
15
+
16
+ def serialize_example(action, action_gripper, qpos, qvel, qpos_gripper, qvel_gripper, rgb_left, rgb_right, rgb_top, terminate_episode):
17
+ feature = {
18
+ 'action': _bytes_feature(tf.io.serialize_tensor(action)),
19
+ 'action_gripper': _bytes_feature(tf.io.serialize_tensor(action_gripper)),
20
+ 'qpos': _bytes_feature(tf.io.serialize_tensor(qpos)),
21
+ 'qvel': _bytes_feature(tf.io.serialize_tensor(qvel)),
22
+ 'qpos_gripper': _bytes_feature(tf.io.serialize_tensor(qpos_gripper)),
23
+ 'qvel_gripper': _bytes_feature(tf.io.serialize_tensor(qvel_gripper)),
24
+ 'rgb_left': _bytes_feature(tf.io.serialize_tensor(rgb_left)),
25
+ 'rgb_right': _bytes_feature(tf.io.serialize_tensor(rgb_right)),
26
+ 'rgb_top': _bytes_feature(tf.io.serialize_tensor(rgb_top)),
27
+ 'terminate_episode': _bool_feature(terminate_episode),
28
+ }
29
+ example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
30
+ return example_proto.SerializeToString()
31
+
32
+ def process_file(params):
33
+ filepath, output_dir = params
34
+ with h5py.File(filepath, 'r') as f:
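+ # Each RoboSet .h5 file contains multiple trials; each trial is written to its own TFRecord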
35
+ for Trial in f.keys():
36
+ data = f[Trial]['data']
37
+ tfrecord_path = os.path.join(output_dir, os.path.basename(filepath).replace('.h5', f'_{Trial}.tfrecord'))
38
+ if os.path.exists(tfrecord_path) and os.path.getsize(tfrecord_path) > 0:
39
+ continue
40
+ with tf.io.TFRecordWriter(tfrecord_path) as writer:
41
+ num_episodes = data['ctrl_arm'].shape[0]
42
+ for i in range(num_episodes):
43
+ action = data['ctrl_arm'][i]
44
+ action_gripper = data['ctrl_ee'][i]
45
+ qpos = data['qp_arm'][i]
46
+ qvel = data['qv_arm'][i]
47
+ qpos_gripper = data['qp_ee'][i]
48
+ qvel_gripper = data['qv_ee'][i]
49
+ rgb_left = data['rgb_left'][i]
50
+ rgb_right = data['rgb_right'][i]
51
+ rgb_top = data['rgb_top'][i]
52
+ terminate_episode = i == num_episodes - 1
53
+ serialized_example = serialize_example(action, action_gripper, qpos, qvel, qpos_gripper, qvel_gripper, rgb_left, rgb_right, rgb_top, terminate_episode)
54
+ writer.write(serialized_example)
55
+
56
+ def write_tfrecords(root_dir, out_dir, num_processes=None):
57
+ if not os.path.exists(out_dir):
58
+ os.makedirs(out_dir)
59
+
60
+ if num_processes is None:
61
+ num_processes = cpu_count()
62
+
63
+ file_list = []
64
+ num_files = 0
65
+ for root, dirs, files in os.walk(root_dir):
66
+ for filename in fnmatch.filter(files, '*.h5'):
67
+ filepath = os.path.join(root, filename)
68
+ output_dir = os.path.join(out_dir, os.path.relpath(os.path.dirname(filepath), root_dir))
69
+ if not os.path.exists(output_dir):
70
+ os.makedirs(output_dir)
71
+ num_files += 1
72
+ file_list.append((filepath, output_dir))
73
+
74
+ with tqdm(total=num_files, desc="Processing files") as pbar:
75
+ with Pool(num_processes) as pool:
76
+ for _ in pool.imap_unordered(process_file, file_list):
77
+ pbar.update(1)
78
+
79
+ root_dir = '../datasets/roboset/'
80
+ output_dir = '../datasets/roboset/tfrecords/'
81
+
82
+ write_tfrecords(root_dir, output_dir)
data/roboset/links.txt ADDED
@@ -0,0 +1,197 @@
1
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_1_Blocks_895/AutonomousRoboSet_Set_1_Blocks_895.tar.gz
2
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_2_SoftToys_12585/Autonomous_RoboSet_Set_2_SoftToys_839_0.tar.gz
3
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_2_SoftToys_12585/Autonomous_RoboSet_Set_2_SoftToys_839_1.tar.gz
4
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_2_SoftToys_12585/Autonomous_RoboSet_Set_2_SoftToys_839_2.tar.gz
5
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_2_SoftToys_12585/Autonomous_RoboSet_Set_2_SoftToys_839_3.tar.gz
6
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_2_SoftToys_12585/Autonomous_RoboSet_Set_2_SoftToys_839_4.tar.gz
7
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_2_SoftToys_12585/Autonomous_RoboSet_Set_2_SoftToys_839_5.tar.gz
8
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_2_SoftToys_12585/Autonomous_RoboSet_Set_2_SoftToys_839_6.tar.gz
9
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_2_SoftToys_12585/Autonomous_RoboSet_Set_2_SoftToys_839_7.tar.gz
10
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_2_SoftToys_12585/Autonomous_RoboSet_Set_2_SoftToys_839_8.tar.gz
11
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_2_SoftToys_12585/Autonomous_RoboSet_Set_2_SoftToys_839_9.tar.gz
12
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_2_SoftToys_12585/Autonomous_RoboSet_Set_2_SoftToys_839_10.tar.gz
13
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_2_SoftToys_12585/Autonomous_RoboSet_Set_2_SoftToys_839_11.tar.gz
14
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_2_SoftToys_12585/Autonomous_RoboSet_Set_2_SoftToys_839_12.tar.gz
15
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_2_SoftToys_12585/Autonomous_RoboSet_Set_2_SoftToys_839_13.tar.gz
16
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_3_Blocks_and_Toys_980/Autonomous_RoboSet_Set_3_Blocks_and_Toys_980.tar.gz
17
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_4_Medium_Block_7/Autonomous_RoboSet_Set_4_Medium_Block_7.tar.gz
18
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_5_Bottle_Cube_14/Autonomous_RoboSet_Set_5_Bottle_Cube_14.tar.gz
19
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_6_Planar_Push_120/Autonomous_RoboSet_Set_6_Planar_Push_120.tar.gz
20
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_7_Pick_Orange_Block_607/Autonomous_RoboSet_Set_7_Pick_Orange_Block_607_0.tar.gz
21
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_7_Pick_Orange_Block_607/Autonomous_RoboSet_Set_7_Pick_Orange_Block_607_1.tar.gz
22
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_7_Pick_Orange_Block_607/Autonomous_RoboSet_Set_7_Pick_Orange_Block_607_2.tar.gz
23
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_7_Pick_Orange_Block_607/Autonomous_RoboSet_Set_7_Pick_Orange_Block_607_3.tar.gz
24
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_7_Pick_Orange_Block_607/Autonomous_RoboSet_Set_7_Pick_Orange_Block_607_4.tar.gz
25
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_8_Pick_Bottle_10/Autonomous_RoboSet_Set_8_Pick_Bottle_10.tar.gz
26
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_5420/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_542_0.tar.gz
27
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_5420/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_542_1.tar.gz
28
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_5420/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_542_2.tar.gz
29
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_5420/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_542_3.tar.gz
30
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_5420/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_542_4.tar.gz
31
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_5420/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_542_5.tar.gz
32
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_5420/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_542_6.tar.gz
33
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_5420/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_542_7.tar.gz
34
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_5420/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_542_8.tar.gz
35
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_5420/Autonomous_RoboSet_Set_9_Pick_Wooden_Block_542_9.tar.gz
36
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_10_Pick_Block_Eval_1837/Autonomous_RoboSet_Set_10_Pick_Block_Eval_167_0.tar.gz
37
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_10_Pick_Block_Eval_1837/Autonomous_RoboSet_Set_10_Pick_Block_Eval_167_1.tar.gz
38
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_10_Pick_Block_Eval_1837/Autonomous_RoboSet_Set_10_Pick_Block_Eval_167_2.tar.gz
39
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_10_Pick_Block_Eval_1837/Autonomous_RoboSet_Set_10_Pick_Block_Eval_167_3.tar.gz
40
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_10_Pick_Block_Eval_1837/Autonomous_RoboSet_Set_10_Pick_Block_Eval_167_4.tar.gz
41
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_10_Pick_Block_Eval_1837/Autonomous_RoboSet_Set_10_Pick_Block_Eval_167_5.tar.gz
42
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_10_Pick_Block_Eval_1837/Autonomous_RoboSet_Set_10_Pick_Block_Eval_167_6.tar.gz
43
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_10_Pick_Block_Eval_1837/Autonomous_RoboSet_Set_10_Pick_Block_Eval_167_7.tar.gz
44
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_10_Pick_Block_Eval_1837/Autonomous_RoboSet_Set_10_Pick_Block_Eval_167_8.tar.gz
45
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_10_Pick_Block_Eval_1837/Autonomous_RoboSet_Set_10_Pick_Block_Eval_167_9.tar.gz
46
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_10_Pick_Block_Eval_1837/Autonomous_RoboSet_Set_10_Pick_Block_Eval_167_10.tar.gz
47
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_2070/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_207_0.tar.gz
48
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_2070/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_207_1.tar.gz
49
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_2070/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_207_2.tar.gz
50
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_2070/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_207_3.tar.gz
51
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_2070/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_207_4.tar.gz
52
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_2070/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_207_5.tar.gz
53
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_2070/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_207_6.tar.gz
54
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_2070/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_207_7.tar.gz
55
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_2070/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_207_8.tar.gz
56
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_2070/Autonomous_RoboSet_Set_11_Pick_Bottle_Eval_207_9.tar.gz
57
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_12_Pick_Block_Eval_2000/Autonomous_RoboSet_Set_12_Pick_Block_Eval_200_0.tar.gz
58
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_12_Pick_Block_Eval_2000/Autonomous_RoboSet_Set_12_Pick_Block_Eval_200_1.tar.gz
59
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_12_Pick_Block_Eval_2000/Autonomous_RoboSet_Set_12_Pick_Block_Eval_200_2.tar.gz
60
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_12_Pick_Block_Eval_2000/Autonomous_RoboSet_Set_12_Pick_Block_Eval_200_3.tar.gz
61
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_12_Pick_Block_Eval_2000/Autonomous_RoboSet_Set_12_Pick_Block_Eval_200_4.tar.gz
62
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_12_Pick_Block_Eval_2000/Autonomous_RoboSet_Set_12_Pick_Block_Eval_200_5.tar.gz
63
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_12_Pick_Block_Eval_2000/Autonomous_RoboSet_Set_12_Pick_Block_Eval_200_6.tar.gz
64
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_12_Pick_Block_Eval_2000/Autonomous_RoboSet_Set_12_Pick_Block_Eval_200_7.tar.gz
65
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_12_Pick_Block_Eval_2000/Autonomous_RoboSet_Set_12_Pick_Block_Eval_200_8.tar.gz
66
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_12_Pick_Block_Eval_2000/Autonomous_RoboSet_Set_12_Pick_Block_Eval_200_9.tar.gz
67
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_4170/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_278_0.tar.gz
68
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_4170/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_278_1.tar.gz
69
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_4170/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_278_2.tar.gz
70
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_4170/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_278_3.tar.gz
71
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_4170/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_278_4.tar.gz
72
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_4170/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_278_5.tar.gz
73
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_4170/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_278_6.tar.gz
74
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_4170/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_278_7.tar.gz
75
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_4170/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_278_8.tar.gz
76
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_4170/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_278_9.tar.gz
77
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_4170/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_278_10.tar.gz
78
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_4170/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_278_11.tar.gz
79
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_4170/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_278_12.tar.gz
80
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_4170/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_278_13.tar.gz
81
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_4170/Autonomous_RoboSet_Set_13_Bin_Reorient_Eval_278_14.tar.gz
82
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_0.tar.gz
83
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_1.tar.gz
84
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_2.tar.gz
85
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_3.tar.gz
86
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_4.tar.gz
87
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_5.tar.gz
88
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_6.tar.gz
89
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_7.tar.gz
90
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_8.tar.gz
91
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_9.tar.gz
92
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_10.tar.gz
93
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_11.tar.gz
94
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_12.tar.gz
95
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_13.tar.gz
96
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_14.tar.gz
97
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_15.tar.gz
98
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_16.tar.gz
99
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_17.tar.gz
100
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_18.tar.gz
101
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_14_Bin_Push_Eval_11300/Autonomous_RoboSet_Set_14_Bin_Push_Eval_565_19.tar.gz
102
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_1470/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_147_0.tar.gz
103
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_1470/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_147_1.tar.gz
104
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_1470/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_147_2.tar.gz
105
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_1470/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_147_3.tar.gz
106
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_1470/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_147_4.tar.gz
107
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_1470/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_147_5.tar.gz
108
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_1470/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_147_6.tar.gz
109
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_1470/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_147_7.tar.gz
110
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_1470/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_147_8.tar.gz
111
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_1470/Autonomous_RoboSet_Set_15_Bin_Reorient_Eval_2_147_9.tar.gz
112
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_2115/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_141_0.tar.gz
113
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_2115/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_141_1.tar.gz
114
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_2115/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_141_2.tar.gz
115
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_2115/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_141_3.tar.gz
116
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_2115/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_141_4.tar.gz
117
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_2115/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_141_5.tar.gz
118
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_2115/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_141_6.tar.gz
119
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_2115/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_141_7.tar.gz
120
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_2115/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_141_8.tar.gz
121
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_2115/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_141_9.tar.gz
122
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_2115/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_141_10.tar.gz
123
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_2115/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_141_11.tar.gz
124
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_2115/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_141_12.tar.gz
125
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_2115/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_141_13.tar.gz
126
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_2115/Autonomous_RoboSet_Set_16_Bin_Reorient_Eval_3_141_14.tar.gz
127
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_0.tar.gz
128
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_1.tar.gz
129
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_2.tar.gz
130
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_3.tar.gz
131
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_4.tar.gz
132
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_5.tar.gz
133
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_6.tar.gz
134
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_7.tar.gz
135
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_8.tar.gz
136
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_9.tar.gz
137
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_10.tar.gz
138
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_11.tar.gz
139
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_12.tar.gz
140
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_13.tar.gz
141
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_14.tar.gz
142
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_15.tar.gz
143
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_16.tar.gz
144
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_17.tar.gz
145
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_18.tar.gz
146
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_19.tar.gz
147
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_20.tar.gz
148
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_21.tar.gz
149
+ https://dl.fbaipublicfiles.com/RoboSet/AutonomousSet/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_10465/Autonomous_RoboSet_Set_17_Plannar_Push_Eval_455_22.tar.gz
150
+ https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/pick_banana_place_in_mug.tar.gz
151
+ https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/pick_banana_place_in_strainer.tar.gz
152
+ https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/pick_banana_from_plate_place_on_table.tar.gz
153
+ https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/pick_banana_from_toaster_place_on_table.tar.gz
154
+ https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/pick_banana_place_on_plate.tar.gz
155
+ https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/pick_banana_place_on_toaster.tar.gz
156
+ https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/pick_ketchup_place_in_strainer.tar.gz
157
+ https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/pick_ketchup_place_in_toaster.tar.gz
158
+ https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/pick_ketchup_from_strainer_place_on_table.tar.gz
159
+ https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/pick_ketchup_from_plate_place_on_table.tar.gz
160
+ https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/pick_ketchup_from_toaster_place_on_table.tar.gz
161
+ https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/pick_ketchup_place_on_plate.tar.gz
162
+ https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/pick_ketchup_place_on_toaster.tar.gz
163
+ https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/drag_mug_backward.tar.gz
164
+ https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/drag_mug_forward.tar.gz
165
+ https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/drag_mug_from_left_to_right.tar.gz
166
+ https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/drag_mug_from_right_to_left.tar.gz
167
+ https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/drag_strainer_backward.tar.gz
168
+ https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/drag_strainer_forward.tar.gz
169
+ https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/drag_strainer_left_to_right.tar.gz
170
+ https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/drag_strainer_right_to_left.tar.gz
171
+ https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/flap_open_toaster_oven.tar.gz
172
+ https://dl.fbaipublicfiles.com/RoboSet/KinestheticSet/Activities/flap_close_toaster_oven.tar.gz
173
+ http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/baking_prep/baking_prep_slide_open_drawer_scene_1.tar.gz
174
+ http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/baking_prep/baking_prep_slide_open_drawer_scene_4.tar.gz
175
+ http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/baking_prep/baking_prep_pick_butter_scene_1.tar.gz
176
+ http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/baking_prep/baking_prep_pick_butter_scene_4.tar.gz
177
+ http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/baking_prep/baking_prep_place_butter_scene_1.tar.gz
178
+ http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/baking_prep/baking_prep_place_butter_scene_4.tar.gz
179
+ http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/baking_prep/baking_prep_slide_close_drawer_scene_1.tar.gz
180
+ http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/baking_prep/baking_prep_slide_close_drawer_scene_4.tar.gz
181
+ http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/clean_kitchen/clean_kitchen_pick_lid_scene_3.tar.gz
182
+ http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/clean_kitchen/clean_kitchen_cap_lid_scene_3.tar.gz
183
+ http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/clean_kitchen/clean_kitchen_slide_close_drawer_scene_3.tar.gz
184
+ http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/clean_kitchen/clean_kitchen_flap_close_oven_Scene_3.tar.gz
185
+ http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/clean_kitchen/clean_kitchen_pick_towel_scene_3.tar.gz
186
+ http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/clean_kitchen/clean_kitchen_Wipe_Counter_Scene_3.tar.gz
187
+ http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/heat_soup/heat_soup_flap_open_oven_Scene_2.tar.gz
188
+ http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/heat_soup/heat_soup_flap_open_oven_Scene_4.tar.gz
189
+ http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/heat_soup/heat_soup_pick_bowl_scene_2.tar.gz
190
+ http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/heat_soup/heat_soup_pick_bowl_scene_4.tar.gz
191
+ http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/heat_soup/heat_soup_slide_in_bowl_scene_2.tar.gz
192
+ http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/heat_soup/heat_soup_slide_in_bowl_scene_4.tar.gz
193
+ http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/heat_soup/heat_soup_flap_close_oven_scene_2.tar.gz
194
+ http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/make_tea/make_tea_Uncap_Lid_Scene_2.tar.gz
195
+ http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/make_tea/make_tea_place_lid_scene_2.tar.gz
196
+ http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/make_tea/make_tea_pick_tea_scene_2.tar.gz
197
+ http://dl.fbaipublicfiles.com/RoboSet/TeleoperationSet/Activities/make_tea/make_tea_place_tea_scene_2.tar.gz
docs/pretrain.md ADDED
@@ -0,0 +1,270 @@
1
+ # Pipeline of Pre-Training RDT
2
+
3
+ Firstly, you need to install the prerequisites for RDT (see [README](../README.md#installation)). Then, you can install the prerequisites for TensorFlow Dataset (in another Conda environment).
4
+
5
+ ## Installation for TensorFlow Dataset
6
+
7
+ ```bash
8
+ # Under the root directory of this repo
9
+ conda create -n rdt-data python=3.10
10
+ conda activate rdt-data
11
+
12
+ # Install all the prerequisites
13
+ pip install -r requirements_data.txt
14
+ # Or you can manually install each package (please refer to requirements_data.txt for specific versions)
15
+ pip install tfds-nightly gsutil tensorflow Pillow pyyaml opencv-python tensorflow-graphics imageio[ffmpeg]
16
+ # If the download is too slow, you can specify an alternative index (tfds-nightly is not available in the Tsinghua mirror)
17
+ pip install -i https://pypi.tuna.tsinghua.edu.cn/simple gsutil tensorflow Pillow pyyaml opencv-python tensorflow-graphics imageio[ffmpeg]
18
+ ```
19
+
20
+ ## Download and Prepare Datasets
21
+
22
+ Below we describe how to download each of our pre-training datasets. If you plan to pre-train on a subset of them, just download the ones you need. You can also fine-tune RDT through this pipeline, but only if your target dataset is listed below or available in Google Cloud Storage.
23
+
24
+ | Dataset | Sample Percentage (%) |
25
+ | ---- | ---- |
26
+ | RT-1 Dataset | 9.00 |
27
+ | TACO Dataset | 1.99 |
28
+ | JACO Play Dataset | 1.10 |
29
+ | Cable Routing Dataset | 0.27 |
30
+ | NYU Door Opening | 0.33 |
31
+ | Viola | 0.40 |
32
+ | Berkeley UR5 | 1.06 |
33
+ | TOTO | 1.06 |
34
+ | Kuka | 1.66 |
35
+ | Language Table | 3.32 |
36
+ | Columbia Cairlab Pusht Real | 0.40 |
37
+ | Stanford Kuka Multimodal Dataset | 1.83 |
38
+ | Stanford Hydra Dataset | 0.80 |
39
+ | Austin Buds Dataset | 0.23 |
40
+ | Maniskill Dataset | 5.78 |
41
+ | Furniture Bench Dataset | 2.36 |
42
+ | UCSD Kitchen Dataset | 0.40 |
43
+ | UCSD Pick And Place Dataset | 1.23 |
44
+ | Austin Sailor Dataset | 0.50 |
45
+ | Austin Sirius Dataset | 0.80 |
46
+ | BC Z | 6.91 |
47
+ | UTokyo PR2 Opening Fridge | 0.30 |
48
+ | UTokyo PR2 Tabletop Manipulation | 0.50 |
49
+ | UTokyo Xarm Pick And Place | 0.33 |
50
+ | UTokyo Xarm Bimanual | 0.03 |
51
+ | Berkeley MVP | 0.73 |
52
+ | Berkeley RPT | 1.00 |
53
+ | KAIST Nonprehensile | 0.46 |
54
+ | Tokyo U LSMO | 0.23 |
55
+ | DLR Sara Grid Clamp | 0.03 |
56
+ | Robocook | 1.66 |
57
+ | Imperialcollege Sawyer Wrist Cam | 0.43 |
58
+ | Iamlab CMU Pickup Insert | 0.83 |
59
+ | UTAustin Mutex | 1.29 |
60
+ | Fanuc Manipulation | 0.66 |
61
+ | Play Fusion | 0.80 |
62
+ | Droid | 10.06 |
63
+ | FMB | 1.39 |
64
+ | Dobb·E | 1.20 |
65
+ | QUT Dexterous Manipulation | 0.46 |
66
+ | Aloha Dataset | 4.98 |
67
+ | Mobile Aloha Dataset | 4.98 |
68
+ | Roboset | 4.48 |
69
+ | RH20T | 10.99 |
70
+ | Calvin Dataset | 3.32 |
71
+ | Bridgev2 | 7.44 |
72
+
73
+ Before anything else, link the dataset directory on your disk to a subfolder of this repo:
74
+
75
+ ```bash
76
+ ln -s /path/to/dataset /path/to/repo/RoboticsDiffusionTransformer/data/datasets
77
+ ```
78
+
79
+ ### Open X-Embodiment
80
+
81
+ Specify the correct path to the `gsutil` in your Conda in [this file](../data/openx_embod/download.sh#L72).
82
+
83
+ Run the following commands to download our selected datasets for the Open X-Embodiment:
84
+
85
+ ```bash
86
+ # Under the root directory of this repo
87
+ cd data/openx_embod
88
+ # Download all datasets
89
+ bash download_openx_embod.sh
90
+ ```
91
+
92
+ Note: By modifying `download_openx_embod.sh`, you can download any dataset on the Google Cloud (as long as it can be downloaded with `gsutil` and is stored in `TFRecord` format), not just the ones we have listed.
93
+
94
+ ### Mobile ALOHA Dataset
95
+
96
+ Download the Mobile ALOHA Dataset from the [official website](https://mobile-aloha.github.io) to `data/datasets/aloha`, then run:
97
+
98
+ ```bash
99
+ cd data/aloha
100
+ # Convert the dataset to TFRecord
101
+ python hdf5totfrecords.py
102
+ ```
103
+
104
+ ### Bridgev2
105
+
106
+ Run:
107
+
108
+ ```bash
109
+ cd data/bridgev2
110
+ # Download and preprocess the dataset
111
+ sh download.sh
112
+ ```
113
+
114
+ ### Calvin
115
+
116
+ Run:
117
+
118
+ ```bash
119
+ cd data/calvin
120
+ # Download and preprocess the dataset
121
+ sh download.sh
122
+ # Convert the dataset to TFRecord format
123
+ python hdf5totfrecords.py
124
+ ```
125
+
126
+ ### RH20T
127
+
128
+ Download the RH20T Dataset from the [official website](https://rh20t.github.io/#download) to `data/datasets/rh20t`, then run:
129
+
130
+ ```bash
131
+ cd data/rh20t
132
+ # Convert the dataset to TFRecord
133
+ python hdf5totfrecords.py
134
+ ```
135
+
136
+ ### RoboSet
137
+
138
+ Run:
139
+
140
+ ```bash
141
+ cd data/roboset
142
+ # Download and preprocess the dataset
143
+ sh download.sh
144
+ ```
145
+
146
+ ## If You Want to Train on a New Dataset
147
+
148
+
149
+ If you want to train on a new dataset (e.g., `my_pretrain_dataset`) through this pre-training pipeline, you need to modify several files as follows:
150
+
151
+ ##### 1. `configs/dataset_control_freq.json`
152
+
153
+ Add the control frequency of your dataset.
154
+
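+ For example, assuming the file is a flat JSON mapping from dataset name to control frequency in Hz, an entry for a hypothetical dataset recorded at 25 Hz would look like this (both the name and the value are placeholders):
+ 
+ ```json
+ {
+     "my_pretrain_dataset": 25
+ }
+ ```
+ 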
155
+ ##### 2. `data/preprocess_scripts/my_pretrain_dataset.py`
156
+
157
+ If your dataset can be loaded by `tfds.builder_from_directory()`, you only need to download it into the Open X-Embodiment folder `data/datasets/openx_embod` and implement the `process_step()` function. You may also need to specify the tfds loading path at L78 (see [this file](../data/vla_dataset.py#L78)). We refer to `data/preprocess_scripts/droid.py` for an example.
158
+
159
+ If not, you need to first convert it into TFRecords and then implement both `load_dataset()` and `process_step()`. We refer to `data/agilex/hdf5totfrecords.py` and `data/preprocess_scripts/agilex.py` for examples.
160
+
161
+ Here are brief descriptions of these two functions:
162
+
163
+ ##### `load_dataset(seed: int)`
164
+
165
+ - Returns a dataset object that supports iteration and a `repeat` method, seeded by `seed`.
+ - Suggested implementation: use `tf.data.Dataset.from_generator` and `tf.data.TFRecordDataset` (see the sketch after this list).
+ - Each element yielded by the iterator represents one episode and should have the following structure:
+   - `step`: A dataset object (supporting iteration) that contains the frames of this episode.
+     - `observation`: A dictionary containing your images.
+       - `your_first_image_key`: Your observation RGB image keys.
+       - ...
+     - `other_attribute`: Any other relevant attributes.
173
+
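+ Below is a minimal sketch of what `load_dataset()` could look like for a TFRecord-based dataset. Everything in it is illustrative: the file layout, the `_FEATURES` keys, and the `MyPretrainDataset` wrapper are assumptions made for the example, not the actual pipeline code.
+ 
+ ```python
+ import glob
+ import random
+ 
+ import tensorflow as tf
+ 
+ # Hypothetical per-step schema; use whatever keys you serialized into your TFRecords.
+ _FEATURES = {
+     'exterior-cam': tf.io.FixedLenFeature([], tf.string),
+     'arm_joint_pos': tf.io.FixedLenFeature([], tf.string),
+     'terminate_episode': tf.io.FixedLenFeature([], tf.int64),
+ }
+ 
+ def _parse_step(proto):
+     return tf.io.parse_single_example(proto, _FEATURES)
+ 
+ class MyPretrainDataset:
+     """Yields one element per episode; each element's 'step' iterates over its frames."""
+ 
+     def __init__(self, seed: int, do_repeat: bool = False):
+         # Illustrative layout: one TFRecord file per episode.
+         self._files = sorted(glob.glob(
+             'data/datasets/my_pretrain_dataset/tfrecords/*.tfrecord'))
+         self._seed = seed
+         self._do_repeat = do_repeat
+ 
+     def repeat(self):
+         # Return an endlessly repeating view over the same episodes.
+         return MyPretrainDataset(self._seed, do_repeat=True)
+ 
+     def __iter__(self):
+         rng = random.Random(self._seed)
+         while True:
+             files = list(self._files)
+             rng.shuffle(files)
+             for path in files:
+                 # 'step' is a dataset of parsed frames for this episode.
+                 yield {'step': tf.data.TFRecordDataset(path).map(_parse_step)}
+             if not self._do_repeat:
+                 break
+ 
+ def load_dataset(seed: int):
+     return MyPretrainDataset(seed)
+ ```
+ 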
174
+ ##### `process_step(step: dict) -> dict`
175
+
176
+ Processes a single frame and returns a dictionary with the following keys:
177
+
178
+ - `observation`:
+   - `your_first_view_image: tf.Tensor`: Your first-view image.
+   - `arm_concat: tf.Tensor`: Concatenation of the physical states.
+   - `format: tf.constant(string)`: Format of `arm_concat` (e.g., `arm_joint_pos_0,arm_joint_pos_1,arm_joint_pos_2`).
+ - `action`: The action of this frame (leave it empty if there is none).
+   - `arm_concat`: Same as in `observation`.
+   - `format`: Same as in `observation`.
+ - `terminate: tf.Tensor`: A boolean tensor indicating whether the episode ends.
186
+
187
+ **IMPORTANT**: You should only use TensorFlow functions for any branch or loop operations. For example, use `tf.cond` instead of `if`.
188
+
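+ For concreteness, here is a hedged sketch of `process_step()` for a hypothetical single-arm dataset with one exterior camera and seven joints. The `step` keys used here (`exterior-cam`, `arm_joint_pos`, `arm_action`, `terminate_episode`) and the format string are placeholders that must match your own data and `configs/dataset_img_keys.json`:
+ 
+ ```python
+ import tensorflow as tf
+ 
+ def process_step(step: dict) -> dict:
+     # Format string naming each entry of `arm_concat` (placeholder joint names).
+     state_format = tf.constant(
+         'arm_joint_pos_0,arm_joint_pos_1,arm_joint_pos_2,'
+         'arm_joint_pos_3,arm_joint_pos_4,arm_joint_pos_5,arm_joint_pos_6')
+ 
+     return {
+         'observation': {
+             # The image key must match the one declared in configs/dataset_img_keys.json.
+             'exterior-cam': step['observation']['exterior-cam'],
+             'arm_concat': step['observation']['arm_joint_pos'],
+             'format': state_format,
+         },
+         'action': {
+             # Leave this empty if your dataset has no explicit action.
+             'arm_concat': step['arm_action'],
+             'format': state_format,
+         },
+         # Remember: use tf.cond / tf.where instead of Python `if` for any branching.
+         'terminate': tf.cast(step['terminate_episode'], tf.bool),
+     }
+ ```
+ 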
189
+ ##### 3. `configs/dataset_img_keys.json`
190
+
191
+ Add the image keys of your dataset. For example:
192
+
193
+ ```json
194
+ "my_pretrain_dataset": {
195
+ "image_keys": [
196
+ "exterior-cam",
197
+ "right-wrist-cam",
198
+ "left-wrist-cam",
199
+ "left-wrist-cam"
200
+ ],
201
+ "image_mask": [1, 1, 1, 0]
202
+ }
203
+ ```
204
+
205
+ - To make TensorFlow happy, you have to specify four images in this exact order: `exterior-cam, right-wrist-cam, left-wrist-cam, any-cam`. Each key should correspond to an image key in the `observation` dictionary of your `step`.
206
+
207
+ - If you only have a single wrist, just make it a *right* wrist.
208
+
209
+ - The `image_mask` indicates whether each image is valid (1) or not (0).
210
+
211
+ - If you have fewer than four images, simply repeat one of the existing images in the remaining positions and set their masks to 0 (invalid).
212
+
213
+ - The key order is *strict*. If you do not have an exterior camera but have both wrist cameras, pad the exterior position with one of the wrist cameras and mask it out, as in the following:
214
+
215
+ ```json
216
+ "my_pretrain_dataset": {
217
+ "image_keys": [
218
+ "right-wrist-cam",
219
+ "right-wrist-cam",
220
+ "left-wrist-cam",
221
+ "left-wrist-cam"
222
+ ],
223
+ "image_mask": [0, 1, 1, 0]
224
+ }
225
+ ```
226
+
227
+ - During training, only the first *three* cameras will be used.
+ 
+ ##### 4. `configs/dataset_stat.json`
229
+
230
+ Compute the statistics (min, max, mean, and std) for your dataset:
231
+
232
+ ```bash
233
+ # Use -h to see the full usage
234
+ python -m data.compute_dataset_stat --skip_exist
235
+ ```
236
+ This will update the `dataset_stat.json` file with your dataset's statistics.
237
+
238
+ ##### 5. `data/vla_dataset.py`
239
+
240
+ - Add your dataset to `DATASET_NAMES_NOOPENX` if it cannot be loaded by `tfds.builder_from_directory()`.
241
+ - If your dataset only contains action but no proprioception (i.e., robot state), add your dataset to `DATASET_NAMES_NO_STATE` in [this file](../data/preprocess.py).
242
+ - Normally, we treat the state at the next timestep as the action of the current timestep. If you want to use different actions, you will need to implement additional functions. We refer to `flatten_episode_agilex()` in [this file](../data/episode_transform.py) and `_generate_json_state_agilex()` in [this file](../data/preprocess.py) for examples. You may also refer to L318 in [this file](../data/preprocess.py) and L128 in [this file](../data/vla_dataset.py) for how to select your dataset and preprocess it differently.
243
+
244
+ ## Start Pre-Training
245
+
246
+ We employ a producer-consumer framework built on TensorFlow Dataset for fast data loading. Since most of the datasets in the Open X-Embodiment are stored as `TFRecord`s, we convert all pre-training datasets into `TFRecord` format for storage. During pre-training, a producer process decompresses the data from `TFRecord` files and stores it in a buffer on the hard disk, while a consumer process reads data from the buffer in random order and feeds it to model training. This not only decouples the `TensorFlow` and `PyTorch` environments but also alleviates the performance loss caused by a small in-memory shuffling buffer. A toy illustration of the disk shuffle buffer is given below.
247
+
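+ The following toy sketch only illustrates the idea of the disk shuffle buffer; the file layout, names, and pickle serialization are invented for illustration, and the real producer and consumer live in this repo's data-loading code:
+ 
+ ```python
+ import os
+ import pickle
+ import random
+ 
+ BUF_PATH = '/path/to/buffer'   # the `buf_path` configured in configs/base.yaml
+ NUM_CHUNKS = 128               # illustrative values; see the buffer settings in base.yaml
+ CHUNK_SIZE = 128
+ 
+ def produce(sample):
+     # Producer (TensorFlow side): decode a sample from TFRecord and drop it into a
+     # random slot of the on-disk buffer, overwriting slots that were already consumed.
+     chunk, slot = random.randrange(NUM_CHUNKS), random.randrange(CHUNK_SIZE)
+     path = os.path.join(BUF_PATH, f'chunk_{chunk}', f'sample_{slot}.pkl')
+     os.makedirs(os.path.dirname(path), exist_ok=True)
+     with open(path, 'wb') as f:
+         pickle.dump(sample, f)
+ 
+ def consume():
+     # Consumer (PyTorch side): read back a random slot, independently of TensorFlow.
+     chunk, slot = random.randrange(NUM_CHUNKS), random.randrange(CHUNK_SIZE)
+     path = os.path.join(BUF_PATH, f'chunk_{chunk}', f'sample_{slot}.pkl')
+     with open(path, 'rb') as f:
+         return pickle.load(f)
+ ```
+ 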
248
+ [This file](../configs/base.yaml) includes the configurations relevant to the model architecture (such as the number of heads and the hidden dimension) and to data processing. You may need to modify `buf_path` (L22) to your real buffer path; this buffer serves as the disk-based shuffling buffer for data loading.
249
+
250
+ Configurations relevant to training are passed through *Command Line Arguments*. Use `python main.py -h` to see their descriptions. We provide an example pre-training script in [this file](../pretrain.sh) (`pretrain.sh`); you may need to modify some of its parameters, such as `CUTLASS_PATH` and `WANDB_PROJECT`.
251
+
252
+ You may need to modify the list of pre-training datasets in [this file](../configs/pretrain_datasets.json) and their corresponding sampling weights in [this file](../configs/pretrain_sample_weights.json). If you want to fine-tune RDT through this pipeline, remove the datasets you do not need from the list.
253
+
254
+ Before starting pre-training, first launch the data producer process (if you use multiple nodes, run this command on each node):
255
+
256
+ ```bash
257
+ # Under the root directory of this repo
258
+ conda activate rdt-data
259
+ # Use -h to see the full usage
260
+ python -m data.producer --fill_up
261
+ # Please proceed to the next step AFTER finishing the filling up process
262
+ ```
263
+
264
+ Then, we run the pre-training script:
265
+
266
+ ```bash
267
+ source pretrain.sh
268
+ ```
269
+
270
+ Note: You can monitor the training process by observing `loss` (smoothed with a long-window moving average), `overall_avg_sample_mse`, and the sampling MSE of each dataset in [Wandb](https://wandb.ai/site) or [TensorBoard](https://www.tensorflow.org/tensorboard). We have empirically found that the lower the `overall_avg_sample_mse`, the better the model performs.
docs/test_6drot.py ADDED
@@ -0,0 +1,99 @@
1
+ import numpy as np
2
+ from scipy.spatial.transform import Rotation as R
3
+
4
+
5
+ def convert_quaternion_to_euler(quat):
6
+ """
7
+ Convert a quaternion (xyzw) to Euler angles (rpy).
8
+ """
9
+ # Normalize
10
+ quat = quat / np.linalg.norm(quat)
11
+ euler = R.from_quat(quat).as_euler('xyz')
12
+
13
+ return euler
14
+
15
+
16
+ def convert_euler_to_quaternion(euler):
17
+ """
18
+ Convert Euler angles (rpy) to a quaternion (xyzw).
19
+ """
20
+ quat = R.from_euler('xyz', euler).as_quat()
21
+
22
+ return quat
23
+
24
+
25
+ def convert_euler_to_rotation_matrix(euler):
26
+ """
27
+ Convert Euler angles (rpy) to rotation matrix (3x3).
28
+ """
29
+ rotmat = R.from_euler('xyz', euler).as_matrix()
+ 
+ return rotmat
32
+
33
+
34
+ def convert_rotation_matrix_to_euler(rotmat):
35
+ """
36
+ Convert rotation matrix (3x3) to Euler angles (rpy).
37
+ """
38
+ r = R.from_matrix(rotmat)
39
+ euler = r.as_euler('xyz', degrees=False)
40
+
41
+ return euler
42
+
43
+
44
+ def normalize_vector(v):
45
+ v_mag = np.linalg.norm(v, axis=-1, keepdims=True)
46
+ v_mag = np.maximum(v_mag, 1e-8)
47
+ return v / v_mag
48
+
49
+
50
+ def cross_product(u, v):
51
+ i = u[:,1]*v[:,2] - u[:,2]*v[:,1]
52
+ j = u[:,2]*v[:,0] - u[:,0]*v[:,2]
53
+ k = u[:,0]*v[:,1] - u[:,1]*v[:,0]
54
+
55
+ out = np.stack((i, j, k), axis=1)
56
+ return out
57
+
58
+
59
+ def compute_rotation_matrix_from_ortho6d(ortho6d):
60
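+ # Recover the full rotation matrix from the 6D representation with a
+ # Gram-Schmidt-style orthonormalization: normalize the first vector, build
+ # the third axis from cross products, then recompute the second axis so the
+ # three column vectors are orthonormal.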
+ x_raw = ortho6d[:, 0:3]
61
+ y_raw = ortho6d[:, 3:6]
62
+
63
+ x = normalize_vector(x_raw)
64
+ z = cross_product(x, y_raw)
65
+ z = normalize_vector(z)
66
+ y = cross_product(z, x)
67
+
68
+ x = x.reshape(-1, 3, 1)
69
+ y = y.reshape(-1, 3, 1)
70
+ z = z.reshape(-1, 3, 1)
71
+ matrix = np.concatenate((x, y, z), axis=2)
72
+ return matrix
73
+
74
+
75
+ def compute_ortho6d_from_rotation_matrix(matrix):
76
+ # The ortho6d represents the first two column vectors a1 and a2 of the
77
+ # rotation matrix: [ | , |, | ]
78
+ # [ a1, a2, a3]
79
+ # [ | , |, | ]
80
+ ortho6d = matrix[:, :, :2].transpose(0, 2, 1).reshape(matrix.shape[0], -1)
81
+ return ortho6d
82
+
83
+
84
+ # Test
85
+ if __name__ == "__main__":
86
+ # Randomly generate a set of Euler angles (rpy)
87
+ euler = np.random.rand(3) * 2 * np.pi - np.pi
88
+ euler = euler[None, :] # Add batch dimension
89
+ print(f"Input Euler angles: {euler}")
90
+
91
+ # Convert to 6D Rotation
92
+ rotmat = convert_euler_to_rotation_matrix(euler)
93
+ ortho6d = compute_ortho6d_from_rotation_matrix(rotmat)
94
+ print(f"6D Rotation: {ortho6d}")
95
+
96
+ # Convert back to Euler angles
97
+ rotmat_recovered = compute_rotation_matrix_from_ortho6d(ortho6d)
98
+ euler_recovered = convert_rotation_matrix_to_euler(rotmat_recovered)
99
+ print(f"Recovered Euler angles: {euler_recovered}")
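+
+ # The recovered Euler angles need not match the input element-wise, because
+ # Euler angles are not unique (pitch is folded into [-pi/2, pi/2]); both
+ # triplets describe the same rotation. Comparing rotation matrices is the
+ # unambiguous consistency check:
+ assert np.allclose(rotmat, rotmat_recovered, atol=1e-6)
+ print("Rotation matrices match: the 6D round-trip is consistent")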
eval_sim/eval_dp.py ADDED
@@ -0,0 +1,166 @@
1
+ from typing import Callable, List, Type
2
+ import gymnasium as gym
3
+ import numpy as np
4
+ from mani_skill.envs.sapien_env import BaseEnv
5
+ from mani_skill.utils import common, gym_utils
6
+ import argparse
7
+ import yaml
8
+ import torch
9
+ from collections import deque
10
+ from PIL import Image
11
+ import cv2
12
+ import imageio
13
+ from functools import partial
14
+
15
+ from diffusion_policy.workspace.robotworkspace import RobotWorkspace
16
+
17
+ def parse_args(args=None):
18
+ parser = argparse.ArgumentParser()
19
+ parser.add_argument("-e", "--env-id", type=str, default="PickCube-v1", help="Environment to evaluate on.")
20
+ parser.add_argument("-o", "--obs-mode", type=str, default="rgb", help="Observation mode to use. Usually this is kept as 'none' since observations do not need to be stored; they can be replayed later via the mani_skill.trajectory.replay_trajectory script.")
21
+ parser.add_argument("-n", "--num-traj", type=int, default=25, help="Number of trajectories to generate.")
22
+ parser.add_argument("--only-count-success", action="store_true", help="If true, generates trajectories until num_traj of them are successful and only saves the successful trajectories/videos")
23
+ parser.add_argument("--reward-mode", type=str)
24
+ parser.add_argument("-b", "--sim-backend", type=str, default="auto", help="Which simulation backend to use. Can be 'auto', 'cpu', 'gpu'")
25
+ parser.add_argument("--render-mode", type=str, default="rgb_array", help="can be 'sensors' or 'rgb_array' which only affect what is saved to videos")
26
+ parser.add_argument("--vis", action="store_true", help="whether or not to open a GUI to visualize the solution live")
27
+ parser.add_argument("--save-video", action="store_true", help="whether or not to save videos locally")
28
+ parser.add_argument("--traj-name", type=str, help="The name of the trajectory .h5 file that will be created.")
29
+ parser.add_argument("--shader", default="default", type=str, help="Change shader used for rendering. Default is 'default' which is very fast. Can also be 'rt' for ray tracing and generating photo-realistic renders. Can also be 'rt-fast' for a faster but lower quality ray-traced renderer")
30
+ parser.add_argument("--record-dir", type=str, default="demos", help="where to save the recorded trajectories")
31
+ parser.add_argument("--num-procs", type=int, default=1, help="Number of processes to use to help parallelize the trajectory replay process. This uses CPU multiprocessing and only works with the CPU simulation backend at the moment.")
32
+ parser.add_argument("--random_seed", type=int, default=0, help="Random seed for the environment.")
33
+ parser.add_argument("--pretrained_path", type=str, default=None, help="Path to the pretrained model checkpoint.")
34
+
35
+ return parser.parse_args()
36
+
37
+ task2lang = {
38
+ "PegInsertionSide-v1": "Pick up a orange-white peg and insert the orange end into the box with a hole in it.",
39
+ "PickCube-v1": "Grasp a red cube and move it to a target goal position.",
40
+ "StackCube-v1": "Pick up a red cube and stack it on top of a green cube and let go of the cube without it falling.",
41
+ "PlugCharger-v1": "Pick up one of the misplaced shapes on the board/kit and insert it into the correct empty slot.",
42
+ "PushCube-v1": "Push and move a cube to a goal region in front of it."
43
+ }
44
+ import random
45
+ import os
46
+
47
+ args = parse_args()
48
+ seed = args.random_seed
49
+ random.seed(seed)
50
+ os.environ['PYTHONHASHSEED'] = str(seed)
51
+ np.random.seed(seed)
52
+ torch.manual_seed(seed)
53
+ torch.cuda.manual_seed(seed)
54
+ torch.backends.cudnn.deterministic = True
55
+ torch.backends.cudnn.benchmark = False
56
+
57
+ env_id = args.env_id
58
+ env = gym.make(
59
+ env_id,
60
+ obs_mode=args.obs_mode,
61
+ control_mode="pd_joint_pos",
62
+ render_mode=args.render_mode,
63
+ reward_mode="dense" if args.reward_mode is None else args.reward_mode,
64
+ sensor_configs=dict(shader_pack=args.shader),
65
+ human_render_camera_configs=dict(shader_pack=args.shader),
66
+ viewer_camera_configs=dict(shader_pack=args.shader),
67
+ sim_backend=args.sim_backend
68
+ )
69
+
70
+ from diffusion_policy.workspace.robotworkspace import RobotWorkspace
71
+ import hydra
72
+ import dill
73
+
74
+ checkpoint_path = args.pretrained_path
75
+ print(f"Loading policy from {checkpoint_path}. Task is {task2lang[env_id]}")
76
+
77
+ def get_policy(output_dir, device):
78
+
79
+ # load checkpoint
80
+ payload = torch.load(open(checkpoint_path, 'rb'), pickle_module=dill)
81
+ cfg = payload['cfg']
82
+ cls = hydra.utils.get_class(cfg._target_)
83
+ workspace = cls(cfg, output_dir=output_dir)
84
+ workspace: RobotWorkspace
85
+ workspace.load_payload(payload, exclude_keys=None, include_keys=None)
86
+
87
+ # get policy from workspace
88
+ policy = workspace.model
89
+ if cfg.training.use_ema:
90
+ policy = workspace.ema_model
91
+
92
+ device = torch.device(device)
93
+ policy.to(device)
94
+ policy.eval()
95
+
96
+ return policy
97
+
98
+ policy = get_policy('./', device = 'cuda')
99
+ MAX_EPISODE_STEPS = 400
100
+ total_episodes = args.num_traj
101
+ success_count = 0
102
+ base_seed = 20241201
103
+ instr = task2lang[env_id]
104
+ import tqdm
105
+
106
+ DATA_STAT = {'state_min': [-0.7463043928146362, -0.0801204964518547, -0.4976441562175751, -2.657780647277832, -0.5742632150650024, 1.8309762477874756, -2.2423808574676514, 0.0, 0.0], 'state_max': [0.7645499110221863, 1.4967026710510254, 0.4650936424732208, -0.3866899907588959, 0.5505855679512024, 3.2900545597076416, 2.5737812519073486, 0.03999999910593033, 0.03999999910593033], 'action_min': [-0.7472005486488342, -0.08631071448326111, -0.4995281398296356, -2.658363103866577, -0.5751323103904724, 1.8290787935256958, -2.245187997817993, -1.0], 'action_max': [0.7654682397842407, 1.4984270334243774, 0.46786263585090637, -0.38181185722351074, 0.5517147779464722, 3.291581630706787, 2.575840711593628, 1.0], 'action_std': [0.2199309915304184, 0.18780815601348877, 0.13044124841690063, 0.30669933557510376, 0.1340624988079071, 0.24968451261520386, 0.9589747190475464, 0.9827960729598999], 'action_mean': [-0.00885344110429287, 0.5523102879524231, -0.007564723491668701, -2.0108158588409424, 0.004714342765510082, 2.615924596786499, 0.08461848646402359, -0.19301606714725494]}
107
+
108
+ state_min = torch.tensor(DATA_STAT['state_min']).cuda()
109
+ state_max = torch.tensor(DATA_STAT['state_max']).cuda()
110
+ action_min = torch.tensor(DATA_STAT['action_min']).cuda()
111
+ action_max = torch.tensor(DATA_STAT['action_max']).cuda()
112
+
113
+ for episode in tqdm.trange(total_episodes):
114
+ obs_window = deque(maxlen=2)
115
+ obs, _ = env.reset(seed = episode + base_seed)
116
+
117
+ img = env.render().cuda().float()
118
+ proprio = obs['agent']['qpos'][:].cuda()
119
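+ # Min-max normalize the proprioceptive state to [-1, 1] using the statistics above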
+ proprio = (proprio - state_min) / (state_max - state_min) * 2 - 1
120
+ obs_window.append({
121
+ 'agent_pos': proprio,
122
+ "head_cam": img.permute(0, 3, 1, 2),
123
+ })
124
+ obs_window.append({
125
+ 'agent_pos': proprio,
126
+ "head_cam": img.permute(0, 3, 1, 2),
127
+ })
128
+
129
+ global_steps = 0
130
+ video_frames = []
131
+
132
+ success_time = 0
133
+ done = False
134
+
135
+ while global_steps < MAX_EPISODE_STEPS and not done:
136
+ obs = obs_window[-1]
137
+ actions = policy.predict_action(obs)
138
+ actions = actions['action_pred'].squeeze(0)
139
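+ # Map the policy's [-1, 1] outputs back to the original action range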
+ actions = (actions + 1) / 2 * (action_max - action_min) + action_min
140
+ actions = actions.detach().cpu().numpy()
141
+ actions = actions[:8]
142
+ for idx in range(actions.shape[0]):
143
+ action = actions[idx]
144
+ obs, reward, terminated, truncated, info = env.step(action)
145
+ img = env.render().cuda().float()
146
+ proprio = obs['agent']['qpos'][:].cuda()
147
+ proprio = (proprio - state_min) / (state_max - state_min) * 2 - 1
148
+ obs_window.append({
149
+ 'agent_pos': proprio,
150
+ "head_cam": img.permute(0, 3, 1, 2),
151
+ })
152
+ video_frames.append(env.render().squeeze(0).detach().cpu().numpy())
153
+ global_steps += 1
154
+ if terminated or truncated:
155
+ assert "success" in info, sorted(info.keys())
156
+ if info['success']:
157
+ done = True
158
+ success_count += 1
159
+ break
160
+ print(f"Trial {episode+1} finished, success: {info['success']}, steps: {global_steps}")
161
+
162
+ success_rate = success_count / total_episodes * 100
163
+ print(f"Tested {total_episodes} episodes, success rate: {success_rate:.2f}%")
164
+ log_file = f"results_dp_{checkpoint_path.split('/')[-1].split('.')[0]}.txt"
165
+ with open(log_file, 'a') as f:
166
+ f.write(f"{args.env_id}:{seed}:{success_count}\n")
eval_sim/eval_octo.py ADDED
@@ -0,0 +1,182 @@
1
+ from typing import Callable, List, Type
2
+ import gymnasium as gym
3
+ import numpy as np
4
+ from mani_skill.envs.sapien_env import BaseEnv
5
+ from mani_skill.utils import common, gym_utils
6
+ import argparse
7
+ import yaml
8
+ import torch
9
+ from collections import deque
10
+ from PIL import Image
11
+ import cv2
12
+ from octo.model.octo_model import OctoModel
13
+ from octo.utils.train_callbacks import supply_rng
14
+ import imageio
15
+ import jax
16
+ import jax.numpy as jnp
17
+ from octo.utils.train_callbacks import supply_rng
18
+ from functools import partial
19
+
20
+ def parse_args(args=None):
21
+ parser = argparse.ArgumentParser()
22
+ parser.add_argument("-e", "--env-id", type=str, default="PickCube-v1", help="Environment to evaluate on.")
23
+ parser.add_argument("-o", "--obs-mode", type=str, default="rgb", help="Observation mode to use. Usually this is kept as 'none' since observations do not need to be stored; they can be replayed later via the mani_skill.trajectory.replay_trajectory script.")
24
+ parser.add_argument("-n", "--num-traj", type=int, default=25, help="Number of trajectories to generate.")
25
+ parser.add_argument("--only-count-success", action="store_true", help="If true, generates trajectories until num_traj of them are successful and only saves the successful trajectories/videos")
26
+ parser.add_argument("--reward-mode", type=str)
27
+ parser.add_argument("-b", "--sim-backend", type=str, default="auto", help="Which simulation backend to use. Can be 'auto', 'cpu', 'gpu'")
28
+ parser.add_argument("--render-mode", type=str, default="rgb_array", help="can be 'sensors' or 'rgb_array' which only affect what is saved to videos")
29
+ parser.add_argument("--vis", action="store_true", help="whether or not to open a GUI to visualize the solution live")
30
+ parser.add_argument("--save-video", action="store_true", help="whether or not to save videos locally")
31
+ parser.add_argument("--traj-name", type=str, help="The name of the trajectory .h5 file that will be created.")
32
+ parser.add_argument("--shader", default="default", type=str, help="Change shader used for rendering. Default is 'default' which is very fast. Can also be 'rt' for ray tracing and generating photo-realistic renders. Can also be 'rt-fast' for a faster but lower quality ray-traced renderer")
33
+ parser.add_argument("--record-dir", type=str, default="demos", help="where to save the recorded trajectories")
34
+ parser.add_argument("--num-procs", type=int, default=1, help="Number of processes to use to help parallelize the trajectory replay process. This uses CPU multiprocessing and only works with the CPU simulation backend at the moment.")
35
+ parser.add_argument("--random_seed", type=int, default=0, help="Random seed for the environment.")
36
+ parser.add_argument("--pretrained_path", type=str, default=None, help="Path to the pretrained model")
37
+ return parser.parse_args()
38
+
39
+ task2lang = {
40
+ "PegInsertionSide-v1": "Pick up a orange-white peg and insert the orange end into the box with a hole in it.",
41
+ "PickCube-v1": "Grasp a red cube and move it to a target goal position.",
42
+ "StackCube-v1": "Pick up a red cube and stack it on top of a green cube and let go of the cube without it falling.",
43
+ "PlugCharger-v1": "Pick up one of the misplaced shapes on the board/kit and insert it into the correct empty slot.",
44
+ "PushCube-v1": "Push and move a cube to a goal region in front of it."
45
+ }
46
+ import random
47
+ import os
48
+
49
+ args = parse_args()
50
+ seed = args.random_seed
51
+ random.seed(seed)
52
+ os.environ['PYTHONHASHSEED'] = str(seed)
53
+ np.random.seed(seed)
54
+ torch.manual_seed(seed)
55
+ torch.cuda.manual_seed(seed)
56
+ torch.backends.cudnn.deterministic = True
57
+ torch.backends.cudnn.benchmark = False
58
+
59
+ env_id = args.env_id
60
+ env = gym.make(
61
+ env_id,
62
+ obs_mode=args.obs_mode,
63
+ control_mode="pd_ee_delta_pose",
64
+ render_mode=args.render_mode,
65
+ reward_mode="dense" if args.reward_mode is None else args.reward_mode,
66
+ sensor_configs=dict(shader_pack=args.shader),
67
+ human_render_camera_configs=dict(shader_pack=args.shader),
68
+ viewer_camera_configs=dict(shader_pack=args.shader),
69
+ sim_backend=args.sim_backend
70
+ )
71
+
72
+ def sample_actions(
73
+ pretrained_model: OctoModel,
74
+ observations,
75
+ tasks,
76
+ rng,
77
+ ):
78
+ # add batch dim to observations
79
+ observations = jax.tree_map(lambda x: x[None], observations)
80
+ actions = pretrained_model.sample_actions(
81
+ observations,
82
+ tasks,
83
+ rng=rng,
84
+ )
85
+ # remove batch dim
86
+ return actions[0]
87
+
88
+ pretrain_path = args.pretrained_path
89
+ step = 1000000
90
+ model = OctoModel.load_pretrained(
91
+ pretrain_path,
92
+ step
93
+ )
94
+
95
+ policy = supply_rng(
96
+ partial(
97
+ sample_actions,
98
+ model,
99
+ )
100
+ )
101
+
102
+
103
+ import tensorflow as tf
104
+ def resize_img(image, size=(256, 256)):
105
+ image_tf = tf.convert_to_tensor(image, dtype=tf.float32)
106
+ image_tf = tf.expand_dims(image_tf, axis=0)
107
+ resized_tf = tf.image.resize(
108
+ image_tf,
109
+ size,
110
+ method=tf.image.ResizeMethod.LANCZOS3,
111
+ antialias=True
112
+ )
113
+ resized_tf = tf.squeeze(resized_tf)
114
+ resized_img = resized_tf.numpy().astype(np.uint8)
115
+ return resized_img
116
+
117
+ MAX_EPISODE_STEPS = 400
118
+ total_episodes = args.num_traj
119
+ success_count = 0
120
+ base_seed = 20241201
121
+ import tqdm
122
+
123
+ for episode in tqdm.trange(total_episodes):
124
+ task = model.create_tasks(texts=[task2lang[env_id]])
125
+ obs_window = deque(maxlen=2)
126
+ obs, _ = env.reset(seed = episode + base_seed)
127
+
128
+ img = env.render().squeeze(0).detach().cpu().numpy()
129
+ proprio = obs['agent']['qpos'][:]
130
+ obs_window.append({
131
+ 'proprio': proprio.detach().cpu().numpy(),
132
+ "image_primary": resize_img(img)[None],
133
+ "timestep_pad_mask": np.zeros((1),dtype = bool)
134
+ })
135
+ obs_window.append({
136
+ 'proprio': proprio.detach().cpu().numpy(),
137
+ "image_primary": resize_img(img)[None],
138
+ "timestep_pad_mask": np.ones((1),dtype = bool)
139
+ })
140
+
141
+ global_steps = 0
142
+ video_frames = []
143
+
144
+ success_time = 0
145
+ done = False
146
+
147
+ while global_steps < MAX_EPISODE_STEPS and not done:
148
+ obs = {
149
+ 'proprio': np.concatenate([obs_window[0]['proprio'], obs_window[1]['proprio']], axis=0),
150
+ "image_primary": np.concatenate([obs_window[0]['image_primary'], obs_window[1]['image_primary']], axis=0),
151
+ "timestep_pad_mask": np.concatenate([obs_window[0]['timestep_pad_mask'], obs_window[1]['timestep_pad_mask']], axis=0)
152
+ }
153
+ actions = policy(obs, task)
154
+ actions = jax.device_put(actions, device=jax.devices('cpu')[0])
155
+ actions = jax.device_get(actions)
156
+ # actions = actions[0:4]
157
+ for idx in range(actions.shape[0]):
158
+ action = actions[idx]
159
+ obs, reward, terminated, truncated, info = env.step(action)
160
+ img = env.render().squeeze(0).detach().cpu().numpy()
161
+ proprio = obs['agent']['qpos'][:]
162
+ obs_window.append({
163
+ 'proprio': proprio.detach().cpu().numpy(),
164
+ "image_primary": resize_img(img)[None],
165
+ "timestep_pad_mask": np.ones((1),dtype = bool)
166
+ })
167
+ video_frames.append(img)
168
+ global_steps += 1
169
+ if terminated or truncated:
170
+ assert "success" in info, sorted(info.keys())
171
+ if info['success']:
172
+ done = True
173
+ success_count += 1
174
+ break
175
+ print(f"Trial {episode+1} finished, success: {info['success']}, steps: {global_steps}")
176
+
177
+ success_rate = success_count / total_episodes * 100
178
+ print(f"Random seed: {seed}, Pretrained_path: {pretrain_path}")
179
+ print(f"Tested {total_episodes} episodes, success rate: {success_rate:.2f}%")
180
+ log_file = "results_octo.log"
181
+ with open(log_file, 'a') as f:
182
+ f.write(f"{seed}:{success_count}\n")
eval_sim/eval_openvla.py ADDED
@@ -0,0 +1,175 @@
1
+ from typing import Callable, List, Type
2
+ import gymnasium as gym
3
+ import numpy as np
4
+ from mani_skill.envs.sapien_env import BaseEnv
5
+ from mani_skill.utils import common, gym_utils
6
+ import argparse
7
+ import yaml
8
+ import torch
9
+ from collections import deque
10
+ from PIL import Image
11
+ import cv2
12
+ import imageio
13
+ from functools import partial
14
+ from torchvision.transforms.functional import center_crop
15
+
16
+ def parse_args(args=None):
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument("-e", "--env-id", type=str, default="PickCube-v1", help="Environment to evaluate on.")
19
+ parser.add_argument("-o", "--obs-mode", type=str, default="rgb", help="Observation mode to use. Usually this is kept as 'none' since observations do not need to be stored; they can be replayed later via the mani_skill.trajectory.replay_trajectory script.")
20
+ parser.add_argument("-n", "--num-traj", type=int, default=25, help="Number of trajectories to generate.")
21
+ parser.add_argument("--only-count-success", action="store_true", help="If true, generates trajectories until num_traj of them are successful and only saves the successful trajectories/videos")
22
+ parser.add_argument("--reward-mode", type=str)
23
+ parser.add_argument("-b", "--sim-backend", type=str, default="auto", help="Which simulation backend to use. Can be 'auto', 'cpu', 'gpu'")
24
+ parser.add_argument("--render-mode", type=str, default="rgb_array", help="can be 'sensors' or 'rgb_array' which only affect what is saved to videos")
25
+ parser.add_argument("--vis", action="store_true", help="whether or not to open a GUI to visualize the solution live")
26
+ parser.add_argument("--save-video", action="store_true", help="whether or not to save videos locally")
27
+ parser.add_argument("--traj-name", type=str, help="The name of the trajectory .h5 file that will be created.")
28
+ parser.add_argument("--shader", default="default", type=str, help="Change shader used for rendering. Default is 'default' which is very fast. Can also be 'rt' for ray tracing and generating photo-realistic renders. Can also be 'rt-fast' for a faster but lower quality ray-traced renderer")
29
+ parser.add_argument("--record-dir", type=str, default="demos", help="where to save the recorded trajectories")
30
+ parser.add_argument("--num-procs", type=int, default=1, help="Number of processes to use to help parallelize the trajectory replay process. This uses CPU multiprocessing and only works with the CPU simulation backend at the moment.")
31
+ parser.add_argument("--random_seed", type=int, default=0, help="Random seed for the environment.")
32
+ parser.add_argument("--pretrained_path", type=str, default=None, help="Path to the pretrained model")
33
+ return parser.parse_args()
34
+
35
+ task2lang = {
36
+ "PegInsertionSide-v1": "Pick up a orange-white peg and insert the orange end into the box with a hole in it.",
37
+ "PickCube-v1": "Grasp a red cube and move it to a target goal position.",
38
+ "StackCube-v1": "Pick up a red cube and stack it on top of a green cube and let go of the cube without it falling.",
39
+ "PlugCharger-v1": "Pick up one of the misplaced shapes on the board/kit and insert it into the correct empty slot.",
40
+ "PushCube-v1": "Push and move a cube to a goal region in front of it."
41
+ }
42
+
43
+ import random
44
+ import os
45
+
46
+ args = parse_args()
47
+ seed = args.random_seed
48
+ random.seed(seed)
49
+ os.environ['PYTHONHASHSEED'] = str(seed)
50
+ np.random.seed(seed)
51
+ torch.manual_seed(seed)
52
+ torch.cuda.manual_seed(seed)
53
+ torch.backends.cudnn.deterministic = True
54
+ torch.backends.cudnn.benchmark = False
55
+
56
+ from transformers import AutoModelForVision2Seq, AutoProcessor
57
+
58
+ DATA_STAT = {'mean': [ 0.00263866, 0.01804881, -0.02151551, -0.00384866, 0.00500441,
59
+ -0.00057146, -0.26013601], 'std': [0.06639539, 0.1246438 , 0.09675793, 0.03351422, 0.04930534,
60
+ 0.25787726, 0.96762997], 'max': [0.31303197, 0.77948809, 0.42906255, 0.20186238, 0.63990456,
61
+ 0.99999917, 1. ], 'min': [-0.31464151, -0.64183694, -0.62718982, -0.5888508 , -0.97813392,
62
+ -0.99999928, -1. ], 'q01': [-0.18656027, -0.31995443, -0.24702898, -0.18005923, -0.2164692 ,
63
+ -0.82366071, -1. ], 'q99': [0.18384692, 0.45547636, 0.27452313, 0.03571117, 0.1188747 ,
64
+ 0.85074112, 1. ]}
65
+
66
+ MODEL_PATH = args.pretrained_path
67
+
68
+ def make_policy():
69
+ device = torch.device('cuda')
70
+
71
+ processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
72
+ vla = AutoModelForVision2Seq.from_pretrained(
73
+ MODEL_PATH,
74
+ attn_implementation="flash_attention_2", # [Optional] Requires `flash_attn`
75
+ torch_dtype=torch.bfloat16,
76
+ low_cpu_mem_usage=True,
77
+ trust_remote_code=True
78
+ ).to(device)
79
+ vla.norm_stats["maniskill"] = {
80
+ "action": {
81
+ "min": np.array(DATA_STAT["min"]),
82
+ "max": np.array(DATA_STAT["max"]),
83
+ "mean": np.array(DATA_STAT["mean"]),
84
+ "std": np.array(DATA_STAT["std"]),
85
+ "q01": np.array(DATA_STAT["q01"]),
86
+ "q99": np.array(DATA_STAT["q99"]),
87
+ }
88
+ }
89
+
90
+ vla = vla.eval()
91
+
92
+ return vla, processor
93
+
94
+ vla, processor = make_policy()
95
+ success_counts = {}
96
+
97
+ for env_id in task2lang.keys():
98
+
99
+ env = gym.make(
100
+ env_id,
101
+ obs_mode=args.obs_mode,
102
+ control_mode="pd_ee_delta_pose",
103
+ render_mode=args.render_mode,
104
+ reward_mode="dense" if args.reward_mode is None else args.reward_mode,
105
+ sensor_configs=dict(shader_pack=args.shader),
106
+ human_render_camera_configs=dict(shader_pack=args.shader),
107
+ viewer_camera_configs=dict(shader_pack=args.shader),
108
+ sim_backend=args.sim_backend
109
+ )
110
+
111
+ MAX_EPISODE_STEPS = 400
112
+ total_episodes = args.num_traj
113
+ success_count = 0
114
+ base_seed = 20241201
115
+ import tqdm
116
+
117
+ for episode in tqdm.trange(total_episodes):
118
+ obs_window = deque(maxlen=2)
119
+ obs, _ = env.reset(seed = base_seed + episode)
120
+
121
+ img = env.render().squeeze(0).detach().cpu().numpy()
122
+ obs_window.append(img)
123
+
124
+ global_steps = 0
125
+ video_frames = []
126
+
127
+ success_time = 0
128
+ done = False
129
+
130
+ while global_steps < MAX_EPISODE_STEPS and not done:
131
+ obs = obs_window[-1]
132
+ image_arrs = [
133
+ obs_window[-1]
134
+ ]
135
+ images = [Image.fromarray(arr) for arr in image_arrs]
136
+ original_size = images[0].size
137
+ crop_scale = 0.9
138
+ sqrt_crop_scale = crop_scale
139
+ sqrt_crop_scale = np.sqrt(crop_scale)
140
+ images = [
141
+ center_crop(
142
+ img, output_size=(
143
+ int(sqrt_crop_scale * img.size[1]),
144
+ int(sqrt_crop_scale * img.size[0])
145
+ )
146
+ ) for img in images
147
+ ]
148
+ images = [img.resize(original_size, Image.Resampling.BILINEAR) for img in images]
149
+ # de-capitalize and remove trailing period
150
+ instruction = task2lang[env_id].lower()
151
+ prompt = f"In: What action should the robot take to {instruction}?\nOut:"
152
+ inputs = processor(prompt, images).to("cuda:0", dtype=torch.bfloat16)
153
+ actions = vla.predict_action(**inputs, unnorm_key="maniskill", do_sample=False)[None]
154
+ for idx in range(actions.shape[0]):
155
+ action = actions[idx]
156
+ # print(action)
157
+ # action = action * (np.array(DATA_STAT['std']) + 1e-8) + np.array(DATA_STAT['mean'])
158
+ obs, reward, terminated, truncated, info = env.step(action)
159
+ img = env.render().squeeze(0).detach().cpu().numpy()
160
+ obs_window.append(img)
161
+ video_frames.append(img)
162
+ global_steps += 1
163
+ if terminated or truncated:
164
+ assert "success" in info, sorted(info.keys())
165
+ if info['success']:
166
+ success_count += 1
167
+ done = True
168
+ break
169
+ print(f"Trial {episode+1} finished, success: {info['success']}, steps: {global_steps}")
170
+ success_counts[env_id] = success_count
171
+ print(f"Task {env_id} finished, success: {success_count}/{total_episodes}")
172
+
173
+ log_file = "results_ovla_all.log"
174
+ with open(log_file, 'a') as f:
175
+ f.write(f"{seed}:{success_counts}\n")
eval_sim/eval_rdt_maniskill.py ADDED
@@ -0,0 +1,137 @@
1
+ from typing import Callable, List, Type
2
+ import sys
3
+ sys.path.append('/')
4
+ import gymnasium as gym
5
+ import numpy as np
6
+ from mani_skill.envs.sapien_env import BaseEnv
7
+ from mani_skill.utils import common, gym_utils
8
+ import argparse
9
+ import yaml
10
+ from scripts.maniskill_model import create_model, RoboticDiffusionTransformerModel
11
+ import torch
12
+ from collections import deque
13
+ from PIL import Image
14
+ import cv2
15
+
16
+ def parse_args(args=None):
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument("-e", "--env-id", type=str, default="PickCube-v1", help="Environment to evaluate on.")
19
+ parser.add_argument("-o", "--obs-mode", type=str, default="rgb", help="Observation mode to use. Usually this is kept as 'none' since observations do not need to be stored; they can be replayed later via the mani_skill.trajectory.replay_trajectory script.")
20
+ parser.add_argument("-n", "--num-traj", type=int, default=25, help="Number of trajectories to test.")
21
+ parser.add_argument("--only-count-success", action="store_true", help="If true, generates trajectories until num_traj of them are successful and only saves the successful trajectories/videos")
22
+ parser.add_argument("--reward-mode", type=str)
23
+ parser.add_argument("-b", "--sim-backend", type=str, default="auto", help="Which simulation backend to use. Can be 'auto', 'cpu', 'gpu'")
24
+ parser.add_argument("--render-mode", type=str, default="rgb_array", help="can be 'sensors' or 'rgb_array' which only affect what is saved to videos")
25
+ parser.add_argument("--shader", default="default", type=str, help="Change shader used for rendering. Default is 'default' which is very fast. Can also be 'rt' for ray tracing and generating photo-realistic renders. Can also be 'rt-fast' for a faster but lower quality ray-traced renderer")
26
+ parser.add_argument("--num-procs", type=int, default=1, help="Number of processes to use to help parallelize the trajectory replay process. This uses CPU multiprocessing and only works with the CPU simulation backend at the moment.")
27
+ parser.add_argument("--pretrained_path", type=str, default=None, help="Path to the pretrained model")
28
+ parser.add_argument("--random_seed", type=int, default=0, help="Random seed for the environment.")
29
+ return parser.parse_args()
30
+
31
+ import random
32
+ import os
33
+
34
+ # set cuda
35
+ args = parse_args()
36
+ # set random seeds
37
+ seed = args.random_seed
38
+ random.seed(seed)
39
+ os.environ['PYTHONHASHSEED'] = str(seed)
40
+ np.random.seed(seed)
41
+ torch.manual_seed(seed)
42
+ torch.cuda.manual_seed(seed)
43
+ torch.backends.cudnn.deterministic = True
44
+ torch.backends.cudnn.benchmark = False
45
+
46
+ task2lang = {
47
+ "PegInsertionSide-v1": "Pick up a orange-white peg and insert the orange end into the box with a hole in it.",
48
+ "PickCube-v1": "Grasp a red cube and move it to a target goal position.",
49
+ "StackCube-v1": "Pick up a red cube and stack it on top of a green cube and let go of the cube without it falling.",
50
+ "PlugCharger-v1": "Pick up one of the misplaced shapes on the board/kit and insert it into the correct empty slot.",
51
+ "PushCube-v1": "Push and move a cube to a goal region in front of it."
52
+ }
53
+
54
+ env_id = args.env_id
55
+ env = gym.make(
56
+ env_id,
57
+ obs_mode=args.obs_mode,
58
+ control_mode="pd_joint_pos",
59
+ render_mode=args.render_mode,
60
+ reward_mode="dense" if args.reward_mode is None else args.reward_mode,
61
+ sensor_configs=dict(shader_pack=args.shader),
62
+ human_render_camera_configs=dict(shader_pack=args.shader),
63
+ viewer_camera_configs=dict(shader_pack=args.shader),
64
+ sim_backend=args.sim_backend
65
+ )
66
+
67
+ config_path = 'configs/base.yaml'
68
+ with open(config_path, "r") as fp:
69
+ config = yaml.safe_load(fp)
70
+ pretrained_text_encoder_name_or_path = "google/t5-v1_1-xxl"
71
+ pretrained_vision_encoder_name_or_path = "google/siglip-so400m-patch14-384"
72
+ pretrained_path = args.pretrained_path
73
+ policy = create_model(
74
+ args=config,
75
+ dtype=torch.bfloat16,
76
+ pretrained=pretrained_path,
77
+ pretrained_text_encoder_name_or_path=pretrained_text_encoder_name_or_path,
78
+ pretrained_vision_encoder_name_or_path=pretrained_vision_encoder_name_or_path
79
+ )
80
+
81
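+ # Cache the instruction embedding on disk so repeated evaluations skip re-encoding it with the T5 text encoder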
+ if os.path.exists(f'text_embed_{env_id}.pt'):
82
+ text_embed = torch.load(f'text_embed_{env_id}.pt')
83
+ else:
84
+ text_embed = policy.encode_instruction(task2lang[env_id])
85
+ torch.save(text_embed, f'text_embed_{env_id}.pt')
86
+
87
+ MAX_EPISODE_STEPS = 400
88
+ total_episodes = args.num_traj
89
+ success_count = 0
90
+
91
+ base_seed = 20241201
92
+ import tqdm
93
+ for episode in tqdm.trange(total_episodes):
94
+ obs_window = deque(maxlen=2)
95
+ obs, _ = env.reset(seed = episode + base_seed)
96
+ policy.reset()
97
+
98
+ img = env.render().squeeze(0).detach().cpu().numpy()
99
+ obs_window.append(None)
100
+ obs_window.append(np.array(img))
101
+ proprio = obs['agent']['qpos'][:, :-1]
102
+
103
+ global_steps = 0
104
+ video_frames = []
105
+
106
+ success_time = 0
107
+ done = False
108
+
109
+ while global_steps < MAX_EPISODE_STEPS and not done:
110
+ image_arrs = []
111
+ for window_img in obs_window:
112
+ image_arrs.append(window_img)
113
+ image_arrs.append(None)
114
+ image_arrs.append(None)
115
+ images = [Image.fromarray(arr) if arr is not None else None
116
+ for arr in image_arrs]
117
+ actions = policy.step(proprio, images, text_embed).squeeze(0).cpu().numpy()
118
+ # Take 8 steps since RDT is trained to predict interpolated 64 steps(actual 14 steps)
119
+ actions = actions[::4, :]
120
+ for idx in range(actions.shape[0]):
121
+ action = actions[idx]
122
+ obs, reward, terminated, truncated, info = env.step(action)
123
+ img = env.render().squeeze(0).detach().cpu().numpy()
124
+ obs_window.append(img)
125
+ proprio = obs['agent']['qpos'][:, :-1]
126
+ video_frames.append(img)
127
+ global_steps += 1
128
+ if terminated or truncated:
129
+ assert "success" in info, sorted(info.keys())
130
+ if info['success']:
131
+ success_count += 1
132
+ done = True
133
+ break
134
+ print(f"Trial {episode+1} finished, success: {info['success']}, steps: {global_steps}")
135
+
136
+ success_rate = success_count / total_episodes * 100
137
+ print(f"Success rate: {success_rate}%")
lang_embed/aloha_dish_drainer.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:903018f06a23f7d8b97480b5bf304442f0593a4d854dc9e0e0fd70822c52b82e
3
+ size 99667
lang_embed/aloha_handover_box.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4a343452e908910230df6ca0045320e121f2f59969314c6a1a09af88192b5e81
3
+ size 91475
lang_embed/aloha_lift_box.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7e2953aa4aad84e687f08f5819f3b3e6c3d4671cbb1d95539dacf4dcad2e9142
3
+ size 83263
lang_embed/aloha_shoes_table.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f606c5cbb856de1c7a362d7ab437d098abc76e78b2b731f77ded048851b87900
3
+ size 132494
lang_embed/anubis_brush_to_pan.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1155d325416a6d1be4a70bb49d7b15272ef0d75dc8a0437a42a9f4660c607a49
3
+ size 66904
lang_embed/anubis_carrot_to_bag.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:394c01aa62fbfd6fdf3d6b53684b10e377820b0a4d0c9425e08b549862beedb6
3
+ size 83293
lang_embed/anubis_towel_kirby.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:92e180dc03d6b6841810bafa499aeee1c43d45a9206b6401426b105bad1fd966
3
+ size 83283
scripts/agilex_inference.py ADDED
@@ -0,0 +1,658 @@
1
+ #!/home/lin/software/miniconda3/envs/aloha/bin/python
2
+ # -- coding: UTF-8
3
+ """
4
+ #!/usr/bin/python3
5
+ """
6
+
7
+ import argparse
8
+ import sys
9
+ import threading
10
+ import time
11
+ import yaml
12
+ from collections import deque
13
+
14
+ import numpy as np
15
+ import rospy
16
+ import torch
17
+ from cv_bridge import CvBridge
18
+ from geometry_msgs.msg import Twist
19
+ from nav_msgs.msg import Odometry
20
+ from PIL import Image as PImage
21
+ from sensor_msgs.msg import Image, JointState
22
+ from std_msgs.msg import Header
23
+ import cv2
24
+
25
+ from scripts.agilex_model import create_model
26
+
27
+ # sys.path.append("./")
28
+
29
+ CAMERA_NAMES = ['cam_high', 'cam_right_wrist', 'cam_left_wrist']
30
+
31
+ observation_window = None
32
+
33
+ lang_embeddings = None
34
+
35
+ # debug
36
+ preload_images = None
37
+
38
+
39
+ # Initialize the model
40
+ def make_policy(args):
41
+ with open(args.config_path, "r") as fp:
42
+ config = yaml.safe_load(fp)
43
+ args.config = config
44
+
45
+ # pretrained_text_encoder_name_or_path = "google/t5-v1_1-xxl"
46
+ pretrained_vision_encoder_name_or_path = "google/siglip-so400m-patch14-384"
47
+ model = create_model(
48
+ args=args.config,
49
+ dtype=torch.bfloat16,
50
+ pretrained=args.pretrained_model_name_or_path,
51
+ # pretrained_text_encoder_name_or_path=pretrained_text_encoder_name_or_path,
52
+ pretrained_vision_encoder_name_or_path=pretrained_vision_encoder_name_or_path,
53
+ control_frequency=args.ctrl_freq,
54
+ )
55
+
56
+ return model
57
+
58
+
59
+ def set_seed(seed):
60
+ torch.manual_seed(seed)
61
+ np.random.seed(seed)
62
+
63
+
64
+ # Interpolate the actions to make the robot move smoothly
65
+ def interpolate_action(args, prev_action, cur_action):
66
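+ # Linearly interpolate from prev_action to cur_action so that no joint moves
+ # more than arm_steps_length within a single published step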
+ steps = np.concatenate((np.array(args.arm_steps_length), np.array(args.arm_steps_length)), axis=0)
67
+ diff = np.abs(cur_action - prev_action)
68
+ step = np.ceil(diff / steps).astype(int)
69
+ step = np.max(step)
70
+ if step <= 1:
71
+ return cur_action[np.newaxis, :]
72
+ new_actions = np.linspace(prev_action, cur_action, step + 1)
73
+ return new_actions[1:]
74
+
75
+
76
+ def get_config(args):
77
+ config = {
78
+ 'episode_len': args.max_publish_step,
79
+ 'state_dim': 14,
80
+ 'chunk_size': args.chunk_size,
81
+ 'camera_names': CAMERA_NAMES,
82
+ }
83
+ return config
84
+
85
+
86
+ # Get the observation from the ROS topic
87
+ def get_ros_observation(args,ros_operator):
88
+ rate = rospy.Rate(args.publish_rate)
89
+ print_flag = True
90
+
91
+ while True and not rospy.is_shutdown():
92
+ result = ros_operator.get_frame()
93
+ if not result:
94
+ if print_flag:
95
+ print("sync failed in get_ros_observation")
96
+ print_flag = False
97
+ rate.sleep()
98
+ continue
99
+ print_flag = True
100
+ (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth,
101
+ puppet_arm_left, puppet_arm_right, robot_base) = result
102
+ # print(f"sync success when get_ros_observation")
103
+ return (img_front, img_left, img_right,
104
+ puppet_arm_left, puppet_arm_right)
105
+
106
+
107
+ # Update the observation window buffer
108
+ def update_observation_window(args, config, ros_operator):
109
+ # JPEG transformation
110
+ # Align with training
111
+ def jpeg_mapping(img):
112
+ img = cv2.imencode('.jpg', img)[1].tobytes()
113
+ img = cv2.imdecode(np.frombuffer(img, np.uint8), cv2.IMREAD_COLOR)
114
+ return img
115
+
116
+ global observation_window
117
+ if observation_window is None:
118
+ observation_window = deque(maxlen=2)
119
+
120
+ # Append the first dummy image
121
+ observation_window.append(
122
+ {
123
+ 'qpos': None,
124
+ 'images':
125
+ {
126
+ config["camera_names"][0]: None,
127
+ config["camera_names"][1]: None,
128
+ config["camera_names"][2]: None,
129
+ },
130
+ }
131
+ )
132
+
133
+ img_front, img_left, img_right, puppet_arm_left, puppet_arm_right = get_ros_observation(args,ros_operator)
134
+ img_front = jpeg_mapping(img_front)
135
+ img_left = jpeg_mapping(img_left)
136
+ img_right = jpeg_mapping(img_right)
137
+
138
+ qpos = np.concatenate(
139
+ (np.array(puppet_arm_left.position), np.array(puppet_arm_right.position)), axis=0)
140
+ qpos = torch.from_numpy(qpos).float().cuda()
141
+ observation_window.append(
142
+ {
143
+ 'qpos': qpos,
144
+ 'images':
145
+ {
146
+ config["camera_names"][0]: img_front,
147
+ config["camera_names"][1]: img_right,
148
+ config["camera_names"][2]: img_left,
149
+ },
150
+ }
151
+ )
152
+
153
+
154
+ # RDT inference
155
+ def inference_fn(args, config, policy, t):
156
+ global observation_window
157
+ global lang_embeddings
158
+
159
+ # print(f"Start inference_thread_fn: t={t}")
160
+ while True and not rospy.is_shutdown():
161
+ time1 = time.time()
162
+
163
+ # fetch images in sequence [front, right, left]
164
+ image_arrs = [
165
+ observation_window[-2]['images'][config['camera_names'][0]],
166
+ observation_window[-2]['images'][config['camera_names'][1]],
167
+ observation_window[-2]['images'][config['camera_names'][2]],
168
+
169
+ observation_window[-1]['images'][config['camera_names'][0]],
170
+ observation_window[-1]['images'][config['camera_names'][1]],
171
+ observation_window[-1]['images'][config['camera_names'][2]]
172
+ ]
173
+
174
+ # fetch debug images in sequence [front, right, left]
175
+ # image_arrs = [
176
+ # preload_images[config['camera_names'][0]][max(t - 1, 0)],
177
+ # preload_images[config['camera_names'][2]][max(t - 1, 0)],
178
+ # preload_images[config['camera_names'][1]][max(t - 1, 0)],
179
+ # preload_images[config['camera_names'][0]][t],
180
+ # preload_images[config['camera_names'][2]][t],
181
+ # preload_images[config['camera_names'][1]][t]
182
+ # ]
183
+ # # encode the images
184
+ # for i in range(len(image_arrs)):
185
+ # image_arrs[i] = cv2.imdecode(np.frombuffer(image_arrs[i], np.uint8), cv2.IMREAD_COLOR)
186
+ # proprio = torch.from_numpy(preload_images['qpos'][t]).float().cuda()
187
+
188
+ images = [PImage.fromarray(arr) if arr is not None else None
189
+ for arr in image_arrs]
190
+
191
+ # for i, pos in enumerate(['f', 'r', 'l'] * 2):
192
+ # images[i].save(f'{t}-{i}-{pos}.png')
193
+
194
+ # get last qpos in shape [14, ]
195
+ proprio = observation_window[-1]['qpos']
196
+ # unsqueeze to [1, 14]
197
+ proprio = proprio.unsqueeze(0)
198
+
199
+ # actions shaped as [1, 64, 14] in format [left, right]
200
+ actions = policy.step(
201
+ proprio=proprio,
202
+ images=images,
203
+ text_embeds=lang_embeddings
204
+ ).squeeze(0).cpu().numpy()
205
+ # print(f"inference_actions: {actions.squeeze()}")
206
+
207
+ print(f"Model inference time: {time.time() - time1} s")
208
+
209
+ # print(f"Finish inference_thread_fn: t={t}")
210
+ return actions
211
+
212
+
213
+ # Main loop for the manipulation task
214
+ def model_inference(args, config, ros_operator):
215
+ global lang_embeddings
216
+
217
+ # Load rdt model
218
+ policy = make_policy(args)
219
+
220
+ lang_dict = torch.load(args.lang_embeddings_path)
221
+ print(f"Running with instruction: \"{lang_dict['instruction']}\" from \"{lang_dict['name']}\"")
222
+ lang_embeddings = lang_dict["embeddings"]
223
+
224
+ max_publish_step = config['episode_len']
225
+ chunk_size = config['chunk_size']
226
+
227
+ # Initialize position of the puppet arm
228
+ left0 = [-0.00133514404296875, 0.00209808349609375, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, 3.557830810546875]
229
+ right0 = [-0.00133514404296875, 0.00438690185546875, 0.034523963928222656, -0.053597450256347656, -0.00476837158203125, -0.00209808349609375, 3.557830810546875]
230
+ left1 = [-0.00133514404296875, 0.00209808349609375, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3393220901489258]
231
+ right1 = [-0.00133514404296875, 0.00247955322265625, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3397035598754883]
232
+ ros_operator.puppet_arm_publish_continuous(left0, right0)
233
+ input("Press enter to continue")
234
+ ros_operator.puppet_arm_publish_continuous(left1, right1)
235
+ # Initialize the previous action to be the initial robot state
236
+ pre_action = np.zeros(config['state_dim'])
237
+ pre_action[:14] = np.array(
238
+ [-0.00133514404296875, 0.00209808349609375, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3393220901489258] +
239
+ [-0.00133514404296875, 0.00247955322265625, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3397035598754883]
240
+ )
241
+ action = None
242
+ # Inference loop
243
+ with torch.inference_mode():
244
+ while True and not rospy.is_shutdown():
245
+ # The current time step
246
+ t = 0
247
+ rate = rospy.Rate(args.publish_rate)
248
+
249
+ action_buffer = np.zeros([chunk_size, config['state_dim']])
250
+
251
+ while t < max_publish_step and not rospy.is_shutdown():
252
+ # Update observation window
253
+ update_observation_window(args, config, ros_operator)
254
+
255
+ # When coming to the end of the action chunk
256
+ if t % chunk_size == 0:
257
+ # Start inference
258
+ action_buffer = inference_fn(args, config, policy, t).copy()
259
+
260
+ raw_action = action_buffer[t % chunk_size]
261
+ action = raw_action
262
+ # Interpolate the original action sequence
263
+ if args.use_actions_interpolation:
264
+ # print(f"Time {t}, pre {pre_action}, act {action}")
265
+ interp_actions = interpolate_action(args, pre_action, action)
266
+ else:
267
+ interp_actions = action[np.newaxis, :]
268
+ # Execute the interpolated actions one by one
269
+ for act in interp_actions:
270
+ left_action = act[:7]
271
+ right_action = act[7:14]
272
+
273
+ if not args.disable_puppet_arm:
274
+ ros_operator.puppet_arm_publish(left_action, right_action) # puppet_arm_publish_continuous_thread
275
+
276
+ if args.use_robot_base:
277
+ vel_action = act[14:16]
278
+ ros_operator.robot_base_publish(vel_action)
279
+ rate.sleep()
280
+ # print(f"doing action: {act}")
281
+ t += 1
282
+
283
+ print("Published Step", t)
284
+ pre_action = action.copy()
285
+
286
+
287
+ # ROS operator class
288
+ class RosOperator:
289
+ def __init__(self, args):
290
+ self.robot_base_deque = None
291
+ self.puppet_arm_right_deque = None
292
+ self.puppet_arm_left_deque = None
293
+ self.img_front_deque = None
294
+ self.img_right_deque = None
295
+ self.img_left_deque = None
296
+ self.img_front_depth_deque = None
297
+ self.img_right_depth_deque = None
298
+ self.img_left_depth_deque = None
299
+ self.bridge = None
300
+ self.puppet_arm_left_publisher = None
301
+ self.puppet_arm_right_publisher = None
302
+ self.robot_base_publisher = None
303
+ self.puppet_arm_publish_thread = None
304
+ self.puppet_arm_publish_lock = None
305
+ self.args = args
306
+ self.init()
307
+ self.init_ros()
308
+
309
+ def init(self):
310
+ self.bridge = CvBridge()
311
+ self.img_left_deque = deque()
312
+ self.img_right_deque = deque()
313
+ self.img_front_deque = deque()
314
+ self.img_left_depth_deque = deque()
315
+ self.img_right_depth_deque = deque()
316
+ self.img_front_depth_deque = deque()
317
+ self.puppet_arm_left_deque = deque()
318
+ self.puppet_arm_right_deque = deque()
319
+ self.robot_base_deque = deque()
320
+ self.puppet_arm_publish_lock = threading.Lock()
321
+ self.puppet_arm_publish_lock.acquire()
322
+
323
+ def puppet_arm_publish(self, left, right):
324
+ joint_state_msg = JointState()
325
+ joint_state_msg.header = Header()
326
+ joint_state_msg.header.stamp = rospy.Time.now() # Set the timestamp
327
+ joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6'] # Set the joint names
328
+ joint_state_msg.position = left
329
+ self.puppet_arm_left_publisher.publish(joint_state_msg)
330
+ joint_state_msg.position = right
331
+ self.puppet_arm_right_publisher.publish(joint_state_msg)
332
+
333
+ def robot_base_publish(self, vel):
334
+ vel_msg = Twist()
335
+ vel_msg.linear.x = vel[0]
336
+ vel_msg.linear.y = 0
337
+ vel_msg.linear.z = 0
338
+ vel_msg.angular.x = 0
339
+ vel_msg.angular.y = 0
340
+ vel_msg.angular.z = vel[1]
341
+ self.robot_base_publisher.publish(vel_msg)
342
+
343
+ def puppet_arm_publish_continuous(self, left, right):
344
+ rate = rospy.Rate(self.args.publish_rate)
345
+ left_arm = None
346
+ right_arm = None
347
+ while True and not rospy.is_shutdown():
348
+ if len(self.puppet_arm_left_deque) != 0:
349
+ left_arm = list(self.puppet_arm_left_deque[-1].position)
350
+ if len(self.puppet_arm_right_deque) != 0:
351
+ right_arm = list(self.puppet_arm_right_deque[-1].position)
352
+ if left_arm is None or right_arm is None:
353
+ rate.sleep()
354
+ continue
355
+ else:
356
+ break
357
+ left_symbol = [1 if left[i] - left_arm[i] > 0 else -1 for i in range(len(left))]
358
+ right_symbol = [1 if right[i] - right_arm[i] > 0 else -1 for i in range(len(right))]
359
+ flag = True
360
+ step = 0
361
+ while flag and not rospy.is_shutdown():
362
+ if self.puppet_arm_publish_lock.acquire(False):
363
+ return
364
+ left_diff = [abs(left[i] - left_arm[i]) for i in range(len(left))]
365
+ right_diff = [abs(right[i] - right_arm[i]) for i in range(len(right))]
366
+ flag = False
367
+ for i in range(len(left)):
368
+ if left_diff[i] < self.args.arm_steps_length[i]:
369
+ left_arm[i] = left[i]
370
+ else:
371
+ left_arm[i] += left_symbol[i] * self.args.arm_steps_length[i]
372
+ flag = True
373
+ for i in range(len(right)):
374
+ if right_diff[i] < self.args.arm_steps_length[i]:
375
+ right_arm[i] = right[i]
376
+ else:
377
+ right_arm[i] += right_symbol[i] * self.args.arm_steps_length[i]
378
+ flag = True
379
+ joint_state_msg = JointState()
380
+ joint_state_msg.header = Header()
381
+ joint_state_msg.header.stamp = rospy.Time.now() # Set the timestamp
382
+ joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6'] # Set the joint names
383
+ joint_state_msg.position = left_arm
384
+ self.puppet_arm_left_publisher.publish(joint_state_msg)
385
+ joint_state_msg.position = right_arm
386
+ self.puppet_arm_right_publisher.publish(joint_state_msg)
387
+ step += 1
388
+ print("puppet_arm_publish_continuous:", step)
389
+ rate.sleep()
390
+
391
+ def puppet_arm_publish_linear(self, left, right):
392
+ num_step = 100
393
+ rate = rospy.Rate(200)
394
+
395
+ left_arm = None
396
+ right_arm = None
397
+
398
+ while True and not rospy.is_shutdown():
399
+ if len(self.puppet_arm_left_deque) != 0:
400
+ left_arm = list(self.puppet_arm_left_deque[-1].position)
401
+ if len(self.puppet_arm_right_deque) != 0:
402
+ right_arm = list(self.puppet_arm_right_deque[-1].position)
403
+ if left_arm is None or right_arm is None:
404
+ rate.sleep()
405
+ continue
406
+ else:
407
+ break
408
+
409
+ traj_left_list = np.linspace(left_arm, left, num_step)
410
+ traj_right_list = np.linspace(right_arm, right, num_step)
411
+
412
+ for i in range(len(traj_left_list)):
413
+ traj_left = traj_left_list[i]
414
+ traj_right = traj_right_list[i]
415
+ traj_left[-1] = left[-1]
416
+ traj_right[-1] = right[-1]
417
+ joint_state_msg = JointState()
418
+ joint_state_msg.header = Header()
419
+ joint_state_msg.header.stamp = rospy.Time.now() # Set the timestamp
420
+ joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6'] # Set the joint names
421
+ joint_state_msg.position = traj_left
422
+ self.puppet_arm_left_publisher.publish(joint_state_msg)
423
+ joint_state_msg.position = traj_right
424
+ self.puppet_arm_right_publisher.publish(joint_state_msg)
425
+ rate.sleep()
426
+
427
+ def puppet_arm_publish_continuous_thread(self, left, right):
428
+ if self.puppet_arm_publish_thread is not None:
429
+ self.puppet_arm_publish_lock.release()
430
+ self.puppet_arm_publish_thread.join()
431
+ self.puppet_arm_publish_lock.acquire(False)
432
+ self.puppet_arm_publish_thread = None
433
+ self.puppet_arm_publish_thread = threading.Thread(target=self.puppet_arm_publish_continuous, args=(left, right))
434
+ self.puppet_arm_publish_thread.start()
435
+
436
+ def get_frame(self):
437
+ if len(self.img_left_deque) == 0 or len(self.img_right_deque) == 0 or len(self.img_front_deque) == 0 or \
438
+ (self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or len(self.img_right_depth_deque) == 0 or len(self.img_front_depth_deque) == 0)):
439
+ return False
440
+ if self.args.use_depth_image:
441
+ frame_time = min([self.img_left_deque[-1].header.stamp.to_sec(), self.img_right_deque[-1].header.stamp.to_sec(), self.img_front_deque[-1].header.stamp.to_sec(),
442
+ self.img_left_depth_deque[-1].header.stamp.to_sec(), self.img_right_depth_deque[-1].header.stamp.to_sec(), self.img_front_depth_deque[-1].header.stamp.to_sec()])
443
+ else:
444
+ frame_time = min([self.img_left_deque[-1].header.stamp.to_sec(), self.img_right_deque[-1].header.stamp.to_sec(), self.img_front_deque[-1].header.stamp.to_sec()])
445
+
446
+ if len(self.img_left_deque) == 0 or self.img_left_deque[-1].header.stamp.to_sec() < frame_time:
447
+ return False
448
+ if len(self.img_right_deque) == 0 or self.img_right_deque[-1].header.stamp.to_sec() < frame_time:
449
+ return False
450
+ if len(self.img_front_deque) == 0 or self.img_front_deque[-1].header.stamp.to_sec() < frame_time:
451
+ return False
452
+ if len(self.puppet_arm_left_deque) == 0 or self.puppet_arm_left_deque[-1].header.stamp.to_sec() < frame_time:
453
+ return False
454
+ if len(self.puppet_arm_right_deque) == 0 or self.puppet_arm_right_deque[-1].header.stamp.to_sec() < frame_time:
455
+ return False
456
+ if self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or self.img_left_depth_deque[-1].header.stamp.to_sec() < frame_time):
457
+ return False
458
+ if self.args.use_depth_image and (len(self.img_right_depth_deque) == 0 or self.img_right_depth_deque[-1].header.stamp.to_sec() < frame_time):
459
+ return False
460
+ if self.args.use_depth_image and (len(self.img_front_depth_deque) == 0 or self.img_front_depth_deque[-1].header.stamp.to_sec() < frame_time):
461
+ return False
462
+ if self.args.use_robot_base and (len(self.robot_base_deque) == 0 or self.robot_base_deque[-1].header.stamp.to_sec() < frame_time):
463
+ return False
464
+
465
+ while self.img_left_deque[0].header.stamp.to_sec() < frame_time:
466
+ self.img_left_deque.popleft()
467
+ img_left = self.bridge.imgmsg_to_cv2(self.img_left_deque.popleft(), 'passthrough')
468
+
469
+ while self.img_right_deque[0].header.stamp.to_sec() < frame_time:
470
+ self.img_right_deque.popleft()
471
+ img_right = self.bridge.imgmsg_to_cv2(self.img_right_deque.popleft(), 'passthrough')
472
+
473
+ while self.img_front_deque[0].header.stamp.to_sec() < frame_time:
474
+ self.img_front_deque.popleft()
475
+ img_front = self.bridge.imgmsg_to_cv2(self.img_front_deque.popleft(), 'passthrough')
476
+
477
+ while self.puppet_arm_left_deque[0].header.stamp.to_sec() < frame_time:
478
+ self.puppet_arm_left_deque.popleft()
479
+ puppet_arm_left = self.puppet_arm_left_deque.popleft()
480
+
481
+ while self.puppet_arm_right_deque[0].header.stamp.to_sec() < frame_time:
482
+ self.puppet_arm_right_deque.popleft()
483
+ puppet_arm_right = self.puppet_arm_right_deque.popleft()
484
+
485
+ img_left_depth = None
486
+ if self.args.use_depth_image:
487
+ while self.img_left_depth_deque[0].header.stamp.to_sec() < frame_time:
488
+ self.img_left_depth_deque.popleft()
489
+ img_left_depth = self.bridge.imgmsg_to_cv2(self.img_left_depth_deque.popleft(), 'passthrough')
490
+
491
+ img_right_depth = None
492
+ if self.args.use_depth_image:
493
+ while self.img_right_depth_deque[0].header.stamp.to_sec() < frame_time:
494
+ self.img_right_depth_deque.popleft()
495
+ img_right_depth = self.bridge.imgmsg_to_cv2(self.img_right_depth_deque.popleft(), 'passthrough')
496
+
497
+ img_front_depth = None
498
+ if self.args.use_depth_image:
499
+ while self.img_front_depth_deque[0].header.stamp.to_sec() < frame_time:
500
+ self.img_front_depth_deque.popleft()
501
+ img_front_depth = self.bridge.imgmsg_to_cv2(self.img_front_depth_deque.popleft(), 'passthrough')
502
+
503
+ robot_base = None
504
+ if self.args.use_robot_base:
505
+ while self.robot_base_deque[0].header.stamp.to_sec() < frame_time:
506
+ self.robot_base_deque.popleft()
507
+ robot_base = self.robot_base_deque.popleft()
508
+
509
+ return (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth,
510
+ puppet_arm_left, puppet_arm_right, robot_base)
511
+
512
+ def img_left_callback(self, msg):
513
+ if len(self.img_left_deque) >= 2000:
514
+ self.img_left_deque.popleft()
515
+ self.img_left_deque.append(msg)
516
+
517
+ def img_right_callback(self, msg):
518
+ if len(self.img_right_deque) >= 2000:
519
+ self.img_right_deque.popleft()
520
+ self.img_right_deque.append(msg)
521
+
522
+ def img_front_callback(self, msg):
523
+ if len(self.img_front_deque) >= 2000:
524
+ self.img_front_deque.popleft()
525
+ self.img_front_deque.append(msg)
526
+
527
+ def img_left_depth_callback(self, msg):
528
+ if len(self.img_left_depth_deque) >= 2000:
529
+ self.img_left_depth_deque.popleft()
530
+ self.img_left_depth_deque.append(msg)
531
+
532
+ def img_right_depth_callback(self, msg):
533
+ if len(self.img_right_depth_deque) >= 2000:
534
+ self.img_right_depth_deque.popleft()
535
+ self.img_right_depth_deque.append(msg)
536
+
537
+ def img_front_depth_callback(self, msg):
538
+ if len(self.img_front_depth_deque) >= 2000:
539
+ self.img_front_depth_deque.popleft()
540
+ self.img_front_depth_deque.append(msg)
541
+
542
+ def puppet_arm_left_callback(self, msg):
543
+ if len(self.puppet_arm_left_deque) >= 2000:
544
+ self.puppet_arm_left_deque.popleft()
545
+ self.puppet_arm_left_deque.append(msg)
546
+
547
+ def puppet_arm_right_callback(self, msg):
548
+ if len(self.puppet_arm_right_deque) >= 2000:
549
+ self.puppet_arm_right_deque.popleft()
550
+ self.puppet_arm_right_deque.append(msg)
551
+
552
+ def robot_base_callback(self, msg):
553
+ if len(self.robot_base_deque) >= 2000:
554
+ self.robot_base_deque.popleft()
555
+ self.robot_base_deque.append(msg)
556
+
557
+ def init_ros(self):
558
+ rospy.init_node('joint_state_publisher', anonymous=True)
559
+ rospy.Subscriber(self.args.img_left_topic, Image, self.img_left_callback, queue_size=1000, tcp_nodelay=True)
560
+ rospy.Subscriber(self.args.img_right_topic, Image, self.img_right_callback, queue_size=1000, tcp_nodelay=True)
561
+ rospy.Subscriber(self.args.img_front_topic, Image, self.img_front_callback, queue_size=1000, tcp_nodelay=True)
562
+ if self.args.use_depth_image:
563
+ rospy.Subscriber(self.args.img_left_depth_topic, Image, self.img_left_depth_callback, queue_size=1000, tcp_nodelay=True)
564
+ rospy.Subscriber(self.args.img_right_depth_topic, Image, self.img_right_depth_callback, queue_size=1000, tcp_nodelay=True)
565
+ rospy.Subscriber(self.args.img_front_depth_topic, Image, self.img_front_depth_callback, queue_size=1000, tcp_nodelay=True)
566
+ rospy.Subscriber(self.args.puppet_arm_left_topic, JointState, self.puppet_arm_left_callback, queue_size=1000, tcp_nodelay=True)
567
+ rospy.Subscriber(self.args.puppet_arm_right_topic, JointState, self.puppet_arm_right_callback, queue_size=1000, tcp_nodelay=True)
568
+ rospy.Subscriber(self.args.robot_base_topic, Odometry, self.robot_base_callback, queue_size=1000, tcp_nodelay=True)
569
+ self.puppet_arm_left_publisher = rospy.Publisher(self.args.puppet_arm_left_cmd_topic, JointState, queue_size=10)
570
+ self.puppet_arm_right_publisher = rospy.Publisher(self.args.puppet_arm_right_cmd_topic, JointState, queue_size=10)
571
+ self.robot_base_publisher = rospy.Publisher(self.args.robot_base_cmd_topic, Twist, queue_size=10)
572
+
573
+
574
+ def get_arguments():
575
+ parser = argparse.ArgumentParser()
576
+ parser.add_argument('--max_publish_step', action='store', type=int,
577
+ help='Maximum number of action publishing steps', default=10000, required=False)
578
+ parser.add_argument('--seed', action='store', type=int,
579
+ help='Random seed', default=None, required=False)
580
+
581
+ parser.add_argument('--img_front_topic', action='store', type=str, help='img_front_topic',
582
+ default='/camera_f/color/image_raw', required=False)
583
+ parser.add_argument('--img_left_topic', action='store', type=str, help='img_left_topic',
584
+ default='/camera_l/color/image_raw', required=False)
585
+ parser.add_argument('--img_right_topic', action='store', type=str, help='img_right_topic',
586
+ default='/camera_r/color/image_raw', required=False)
587
+
588
+ parser.add_argument('--img_front_depth_topic', action='store', type=str, help='img_front_depth_topic',
589
+ default='/camera_f/depth/image_raw', required=False)
590
+ parser.add_argument('--img_left_depth_topic', action='store', type=str, help='img_left_depth_topic',
591
+ default='/camera_l/depth/image_raw', required=False)
592
+ parser.add_argument('--img_right_depth_topic', action='store', type=str, help='img_right_depth_topic',
593
+ default='/camera_r/depth/image_raw', required=False)
594
+
595
+ parser.add_argument('--puppet_arm_left_cmd_topic', action='store', type=str, help='puppet_arm_left_cmd_topic',
596
+ default='/master/joint_left', required=False)
597
+ parser.add_argument('--puppet_arm_right_cmd_topic', action='store', type=str, help='puppet_arm_right_cmd_topic',
598
+ default='/master/joint_right', required=False)
599
+ parser.add_argument('--puppet_arm_left_topic', action='store', type=str, help='puppet_arm_left_topic',
600
+ default='/puppet/joint_left', required=False)
601
+ parser.add_argument('--puppet_arm_right_topic', action='store', type=str, help='puppet_arm_right_topic',
602
+ default='/puppet/joint_right', required=False)
603
+
604
+ parser.add_argument('--robot_base_topic', action='store', type=str, help='robot_base_topic',
605
+ default='/odom_raw', required=False)
606
+ parser.add_argument('--robot_base_cmd_topic', action='store', type=str, help='robot_base_cmd_topic',
607
+ default='/cmd_vel', required=False)
608
+ parser.add_argument('--use_robot_base', action='store_true',
609
+ help='Whether to use the robot base to move around',
610
+ default=False, required=False)
611
+ parser.add_argument('--publish_rate', action='store', type=int,
612
+ help='The rate at which to publish the actions',
613
+ default=30, required=False)
614
+ parser.add_argument('--ctrl_freq', action='store', type=int,
615
+ help='The control frequency of the robot',
616
+ default=25, required=False)
617
+
618
+ parser.add_argument('--chunk_size', action='store', type=int,
619
+ help='Action chunk size',
620
+ default=64, required=False)
621
+ parser.add_argument('--arm_steps_length', action='store', type=float,
622
+ help='The maximum change allowed for each joint per timestep',
623
+ default=[0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.2], required=False)
624
+
625
+ parser.add_argument('--use_actions_interpolation', action='store_true',
626
+ help='Whether to interpolate the actions if the difference is too large',
627
+ default=False, required=False)
628
+ parser.add_argument('--use_depth_image', action='store_true',
629
+ help='Whether to use depth images',
630
+ default=False, required=False)
631
+
632
+ parser.add_argument('--disable_puppet_arm', action='store_true',
633
+ help='Whether to disable the puppet arm. This is useful for safe debugging', default=False)
634
+
635
+ parser.add_argument('--config_path', type=str, default="configs/base.yaml",
636
+ help='Path to the config file')
637
+ # parser.add_argument('--cfg_scale', type=float, default=2.0,
638
+ # help='the scaling factor used to modify the magnitude of the control features during denoising')
639
+ parser.add_argument('--pretrained_model_name_or_path', type=str, required=True, help='Name or path to the pretrained model')
640
+
641
+ parser.add_argument('--lang_embeddings_path', type=str, required=True,
642
+ help='Path to the pre-encoded language instruction embeddings')
643
+
644
+ args = parser.parse_args()
645
+ return args
646
+
647
+
648
+ def main():
649
+ args = get_arguments()
650
+ ros_operator = RosOperator(args)
651
+ if args.seed is not None:
652
+ set_seed(args.seed)
653
+ config = get_config(args)
654
+ model_inference(args, config, ros_operator)
655
+
656
+
657
+ if __name__ == '__main__':
658
+ main()
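
The get_frame routine above synchronizes several ROS message deques: it takes the oldest of the newest timestamps across the required streams as the frame time, then pops stale messages from the head of each deque before reading one aligned sample per stream. A minimal sketch of that alignment idea (not part of the commit), using plain float timestamps instead of ROS message headers:

from collections import deque

def align_streams(streams):
    """streams: dict name -> deque of (timestamp, payload), newest at the right."""
    if any(len(d) == 0 for d in streams.values()):
        return None
    # Oldest of the newest timestamps: every stream has data up to this time
    frame_time = min(d[-1][0] for d in streams.values())
    aligned = {}
    for name, d in streams.items():
        # Drop messages older than the frame time, then take the aligned one
        while d[0][0] < frame_time:
            d.popleft()
        aligned[name] = d.popleft()[1]
    return aligned

streams = {
    "img_front": deque([(0.98, "f0"), (1.00, "f1")]),
    "img_left": deque([(0.99, "l0"), (1.01, "l1")]),
    "puppet_arm_left": deque([(1.00, "q0")]),
}
print(align_streams(streams))  # {'img_front': 'f1', 'img_left': 'l1', 'puppet_arm_left': 'q0'}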
scripts/agilex_model.py ADDED
@@ -0,0 +1,313 @@
1
+ import os
2
+
3
+ import numpy as np
4
+ import torch
5
+ from PIL import Image
6
+ from torchvision import transforms
7
+
8
+ from configs.state_vec import STATE_VEC_IDX_MAPPING
9
+ from models.multimodal_encoder.siglip_encoder import SiglipVisionTower
10
+ from models.multimodal_encoder.t5_encoder import T5Embedder
11
+ from models.rdt_runner import RDTRunner
12
+
13
+
14
+ # The indices that the raw vector should be mapped to in the unified action vector
15
+ AGILEX_STATE_INDICES = [
16
+ STATE_VEC_IDX_MAPPING[f"left_arm_joint_{i}_pos"] for i in range(6)
17
+ ] + [
18
+ STATE_VEC_IDX_MAPPING["left_gripper_open"]
19
+ ] + [
20
+ STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(6)
21
+ ] + [
22
+ STATE_VEC_IDX_MAPPING[f"right_gripper_open"]
23
+ ]
24
+ TABLETOP_6D_INDICES_NAMES = [
25
+ 'left_eef_pos_x','left_eef_pos_y','left_eef_pos_z','left_eef_angle_0','left_eef_angle_1','left_eef_angle_2','left_eef_angle_3','left_eef_angle_4','left_eef_angle_5','left_gripper_open','right_eef_pos_x','right_eef_pos_y','right_eef_pos_z','right_eef_angle_0','right_eef_angle_1','right_eef_angle_2','right_eef_angle_3','right_eef_angle_4','right_eef_angle_5','right_gripper_open']
26
+ TABLETOP_6D_INDICES = [STATE_VEC_IDX_MAPPING[n] for n in TABLETOP_6D_INDICES_NAMES]
27
+
28
+ # Create the RDT model
29
+ def create_model(args, **kwargs):
30
+ model = RoboticDiffusionTransformerModel(args, **kwargs)
31
+ pretrained = kwargs.get("pretrained", None)
32
+ if (
33
+ pretrained is not None
34
+ and os.path.isfile(pretrained)
35
+ ):
36
+ model.load_pretrained_weights(pretrained)
37
+ return model
38
+
39
+
40
+ class RoboticDiffusionTransformerModel(object):
41
+ """A wrapper for the RDT model, which handles
42
+ 1. Model initialization
43
+ 2. Encodings of instructions
44
+ 3. Model inference
45
+ """
46
+ def __init__(
47
+ self, args,
48
+ device='cuda',
49
+ dtype=torch.bfloat16,
50
+ image_size=None,
51
+ control_frequency=25,
52
+ pretrained=None,
53
+ pretrained_vision_encoder_name_or_path=None,
54
+ pretrained_text_encoder_name_or_path=None
55
+ ):
56
+ self.args = args
57
+ self.dtype = dtype
58
+ self.image_size = image_size
59
+ self.device = device
60
+ self.control_frequency = control_frequency
61
+ # Load the text encoder so that instructions can be encoded on the fly at inference time
62
+ self.text_tokenizer, self.text_model = self.get_text_encoder(pretrained_text_encoder_name_or_path)
63
+ self.image_processor, self.vision_model = self.get_vision_encoder(pretrained_vision_encoder_name_or_path)
64
+ self.policy = self.get_policy(pretrained)
65
+
66
+ self.reset()
67
+
68
+ def get_policy(self, pretrained):
69
+ """Initialize the model."""
70
+ # Initialize model with arguments
71
+ if (
72
+ pretrained is None
73
+ or os.path.isfile(pretrained)
74
+ ):
75
+ img_cond_len = (self.args["common"]["img_history_size"]
76
+ * self.args["common"]["num_cameras"]
77
+ * self.vision_model.num_patches)
78
+
79
+ _model = RDTRunner(
80
+ action_dim=self.args["common"]["state_dim"],
81
+ pred_horizon=self.args["common"]["action_chunk_size"],
82
+ config=self.args["model"],
83
+ lang_token_dim=self.args["model"]["lang_token_dim"],
84
+ img_token_dim=self.args["model"]["img_token_dim"],
85
+ state_token_dim=self.args["model"]["state_token_dim"],
86
+ max_lang_cond_len=self.args["dataset"]["tokenizer_max_length"],
87
+ img_cond_len=img_cond_len,
88
+ img_pos_embed_config=[
89
+ # No initial pos embed in the last grid size
90
+ # since we've already done in ViT
91
+ ("image", (self.args["common"]["img_history_size"],
92
+ self.args["common"]["num_cameras"],
93
+ -self.vision_model.num_patches)),
94
+ ],
95
+ lang_pos_embed_config=[
96
+ # Similarly, no initial pos embed for language
97
+ ("lang", -self.args["dataset"]["tokenizer_max_length"]),
98
+ ],
99
+ dtype=self.dtype,
100
+ )
101
+ else:
102
+ _model = RDTRunner.from_pretrained(pretrained)
103
+
104
+ return _model
105
+
106
+ def get_text_encoder(self, pretrained_text_encoder_name_or_path):
107
+ text_embedder = T5Embedder(from_pretrained=pretrained_text_encoder_name_or_path,
108
+ model_max_length=self.args["dataset"]["tokenizer_max_length"],
109
+ device=self.device)
110
+ tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model
111
+ return tokenizer, text_encoder
112
+
113
+ def get_vision_encoder(self, pretrained_vision_encoder_name_or_path):
114
+ vision_encoder = SiglipVisionTower(vision_tower=pretrained_vision_encoder_name_or_path, args=None)
115
+ image_processor = vision_encoder.image_processor
116
+ return image_processor, vision_encoder
117
+
118
+ def reset(self):
119
+ """Set model to evaluation mode.
120
+ """
121
+ device = self.device
122
+ weight_dtype = self.dtype
123
+ self.policy.eval()
124
+ # self.text_model.eval()
125
+ self.vision_model.eval()
126
+
127
+ self.policy = self.policy.to(device, dtype=weight_dtype)
128
+ # self.text_model = self.text_model.to(device, dtype=weight_dtype)
129
+ self.vision_model = self.vision_model.to(device, dtype=weight_dtype)
130
+
131
+ def load_pretrained_weights(self, pretrained=None):
132
+ if pretrained is None:
133
+ return
134
+ print(f'Loading weights from {pretrained}')
135
+ filename = os.path.basename(pretrained)
136
+ if filename.endswith('.pt'):
137
+ checkpoint = torch.load(pretrained)
138
+ self.policy.load_state_dict(checkpoint["module"])
139
+ elif filename.endswith('.safetensors'):
140
+ from safetensors.torch import load_model
141
+ load_model(self.policy, pretrained)
142
+ else:
143
+ raise NotImplementedError(f"Unknown checkpoint format: {pretrained}")
144
+
145
+ def encode_instruction(self, instruction, device="cuda"):
146
+ """Encode string instruction to latent embeddings.
147
+
148
+ Args:
149
+ instruction: a string of instruction
150
+ device: a string of device
151
+
152
+ Returns:
153
+ pred: a tensor of latent embeddings of shape (text_max_length, 512)
154
+ """
155
+ tokens = self.text_tokenizer(
156
+ instruction, return_tensors="pt",
157
+ padding="longest",
158
+ truncation=True
159
+ )["input_ids"].to(device)
160
+
161
+ tokens = tokens.view(1, -1)
162
+ with torch.no_grad():
163
+ pred = self.text_model(tokens).last_hidden_state.detach()
164
+
165
+ return pred
166
+
167
+ def _format_joint_to_state(self, joints):
168
+ """
169
+ Format the joint proprioception into the unified action vector.
170
+
171
+ Args:
172
+ joints (torch.Tensor): The 6D EEF proprioception to be formatted.
173
+ qpos ([B, N, 20]).
174
+
175
+ Returns:
176
+ state (torch.Tensor): The formatted vector for RDT ([B, N, 128]).
177
+ """
178
+ # Rescale the gripper to the range of [0, 1] (the all-ones divisor below is currently an identity)
179
+ joints = joints / torch.tensor(
180
+ [[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]],
181
+ device=joints.device, dtype=joints.dtype
182
+ )
183
+
184
+ B, N, _ = joints.shape
185
+ state = torch.zeros(
186
+ (B, N, self.args["model"]["state_token_dim"]),
187
+ device=joints.device, dtype=joints.dtype
188
+ )
189
+ # Fill into the unified state vector
190
+ state[:, :, TABLETOP_6D_INDICES] = joints
191
+ # Assemble the mask indicating each dimension's availability
192
+ state_elem_mask = torch.zeros(
193
+ (B, self.args["model"]["state_token_dim"]),
194
+ device=joints.device, dtype=joints.dtype
195
+ )
196
+ state_elem_mask[:,TABLETOP_6D_INDICES] = 1
197
+ return state, state_elem_mask
198
+
199
+ def _unformat_action_to_joint(self, action):
200
+ """
201
+ Unformat the unified action vector into the joint action to be executed.
202
+
203
+ Args:
204
+ action (torch.Tensor): The unified action vector to be unformatted.
205
+ ([B, N, 128])
206
+
207
+ Returns:
208
+ joints (torch.Tensor): The unformatted robot joint action.
209
+ qpos ([B, N, 20]).
210
+ """
211
+ action_indices = TABLETOP_6D_INDICES
212
+ joints = action[:, :, action_indices]
213
+
214
+ # Rescale the gripper back to the action range
215
+ # Note that the action range and proprioception range are different
216
+ # for Mobile ALOHA robot
217
+ joints = joints * torch.tensor(
218
+ [[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]],
219
+ device=joints.device, dtype=joints.dtype
220
+ )
221
+
222
+ return joints
223
+
224
+ @torch.no_grad()
225
+ def step(self, proprio, images, instruction):
226
+ """
227
+ Predict the next action chunk given the
228
+ proprioceptive states, images, and instruction embeddings.
229
+
230
+ Args:
231
+ proprio: proprioceptive states
232
+ images: RGB images, the order should be
233
+ [ext_{t-1}, right_wrist_{t-1}, left_wrist_{t-1},
234
+ ext_{t}, right_wrist_{t}, left_wrist_{t}]
235
+ instruction: a string of instruction (encoded on the fly with the text encoder)
236
+
237
+ Returns:
238
+ action: predicted action
239
+ """
240
+ device = self.device
241
+ dtype = self.dtype
242
+
243
+ # The background image used for padding
244
+ background_color = np.array([
245
+ int(x*255) for x in self.image_processor.image_mean
246
+ ], dtype=np.uint8).reshape(1, 1, 3)
247
+ background_image = np.ones((
248
+ self.image_processor.size["height"],
249
+ self.image_processor.size["width"], 3), dtype=np.uint8
250
+ ) * background_color
251
+
252
+ # Preprocess the images by order and encode them
253
+ image_tensor_list = []
254
+ for image in images:
255
+ if image is None:
256
+ # Replace it with the background image
257
+ image = Image.fromarray(background_image)
258
+
259
+ if self.image_size is not None:
260
+ image = transforms.Resize(self.image_size)(image)
261
+
262
+ if self.args["dataset"].get("auto_adjust_image_brightness", False):
263
+ pixel_values = list(image.getdata())
264
+ average_brightness = sum(sum(pixel) for pixel in pixel_values) / (len(pixel_values) * 255.0 * 3)
265
+ if average_brightness <= 0.15:
266
+ image = transforms.ColorJitter(brightness=(1.75,1.75))(image)
267
+
268
+ if self.args["dataset"].get("image_aspect_ratio", "pad") == 'pad':
269
+ def expand2square(pil_img, background_color):
270
+ width, height = pil_img.size
271
+ if width == height:
272
+ return pil_img
273
+ elif width > height:
274
+ result = Image.new(pil_img.mode, (width, width), background_color)
275
+ result.paste(pil_img, (0, (width - height) // 2))
276
+ return result
277
+ else:
278
+ result = Image.new(pil_img.mode, (height, height), background_color)
279
+ result.paste(pil_img, ((height - width) // 2, 0))
280
+ return result
281
+ image = expand2square(image, tuple(int(x*255) for x in self.image_processor.image_mean))
282
+ image = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
283
+ image_tensor_list.append(image)
284
+
285
+ image_tensor = torch.stack(image_tensor_list, dim=0).to(device, dtype=dtype)
286
+
287
+ image_embeds = self.vision_model(image_tensor).detach()
288
+ image_embeds = image_embeds.reshape(-1, self.vision_model.hidden_size).unsqueeze(0)
289
+
290
+ # Prepare the proprioception states and the control frequency
291
+ joints = proprio.to(device).unsqueeze(0) # (1, 1, 20)
292
+ states, state_elem_mask = self._format_joint_to_state(joints) # (1, 1, 128), (1, 128)
293
+ states, state_elem_mask = states.to(device, dtype=dtype), state_elem_mask.to(device, dtype=dtype)
294
+ states = states[:, -1:, :] # (1, 1, 128)
295
+ ctrl_freqs = torch.tensor([self.control_frequency]).to(device)
296
+
297
+ # text_embeds = text_embeds.to(device, dtype=dtype)
298
+ text_embeds = self.encode_instruction(instruction=instruction)
299
+
300
+ # Predict the next action chunk given the inputs
301
+ trajectory = self.policy.predict_action(
302
+ lang_tokens=text_embeds,
303
+ lang_attn_mask=torch.ones(
304
+ text_embeds.shape[:2], dtype=torch.bool,
305
+ device=text_embeds.device),
306
+ img_tokens=image_embeds,
307
+ state_tokens=states,
308
+ action_mask=state_elem_mask.unsqueeze(1),
309
+ ctrl_freqs=ctrl_freqs
310
+ )
311
+ trajectory = self._unformat_action_to_joint(trajectory).to(torch.float32)
312
+
313
+ return trajectory
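
A hedged usage sketch (not part of the commit) for the wrapper above: load configs/base.yaml, build the model with create_model, and query an action chunk with step. The checkpoint path is a placeholder, the SigLIP identifier is an assumption, the T5 identifier follows the one used in scripts/encode_lang_batch.py below, and imports assume the repository root is on PYTHONPATH.

import yaml
import torch
from scripts.agilex_model import create_model

with open("configs/base.yaml", "r") as fp:
    config = yaml.safe_load(fp)

policy = create_model(
    config,
    pretrained="checkpoints/rdt/pytorch_model.pt",  # placeholder checkpoint path
    pretrained_vision_encoder_name_or_path="google/siglip-so400m-patch14-384",  # assumed SigLIP weights
    pretrained_text_encoder_name_or_path="google/t5-v1_1-xxl",
    control_frequency=25,
)

# Image order per the step() docstring:
# [ext_{t-1}, right_wrist_{t-1}, left_wrist_{t-1}, ext_t, right_wrist_t, left_wrist_t]
images = [None] * 6              # None entries are replaced by the background image
proprio = torch.zeros(1, 20)     # TABLETOP_6D layout: eef position + 6D rotation + gripper, both arms
actions = policy.step(proprio, images, "pick up the red block")
print(actions.shape)             # expected (1, 64, 20) with the default action_chunk_size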
scripts/encode_lang_batch.py ADDED
@@ -0,0 +1,76 @@
1
+ import os
2
+ import json
3
+
4
+ import torch
5
+ import yaml
6
+ from tqdm import tqdm
7
+
8
+ from models.multimodal_encoder.t5_encoder import T5Embedder
9
+
10
+
11
+ GPU = 0
12
+ MODEL_PATH = "google/t5-v1_1-xxl"
13
+ CONFIG_PATH = "configs/base.yaml"
14
+ # Modify the TARGET_DIR to your dataset path
15
+ TARGET_DIR = "data/datasets/openx_embod/singlevla_benchmark_ee"
16
+
17
+ # Note: if your GPU VRAM is less than 24GB,
18
+ # it is recommended to enable offloading by specifying an offload directory.
19
+ OFFLOAD_DIR = None # Specify your offload directory here, ensuring the directory exists.
20
+
21
+ def main():
22
+ with open(CONFIG_PATH, "r") as fp:
23
+ config = yaml.safe_load(fp)
24
+
25
+ device = torch.device(f"cuda:{GPU}")
26
+ text_embedder = T5Embedder(
27
+ from_pretrained=MODEL_PATH,
28
+ model_max_length=config["dataset"]["tokenizer_max_length"],
29
+ device=device,
30
+ use_offload_folder=OFFLOAD_DIR
31
+ )
32
+ tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model
33
+
34
+ # Get all the task paths
35
+ task_paths = []
36
+ for sub_dir in os.listdir(TARGET_DIR):
37
+ middle_dir = os.path.join(TARGET_DIR, sub_dir)
38
+ if os.path.isdir(middle_dir):
39
+ for task_dir in os.listdir(middle_dir):
40
+ task_path = os.path.join(middle_dir, task_dir)
41
+ if os.path.isdir(task_path):
42
+ task_paths.append(task_path)
43
+
44
+ # For each task, encode the instructions
45
+ for task_path in tqdm(task_paths):
46
+ # Load the instructions corresponding to the task from the directory
47
+ with open(os.path.join(task_path, 'expanded_instruction_gpt-4-turbo.json'), 'r') as f_instr:
48
+ instruction_dict = json.load(f_instr)
49
+ instructions = [instruction_dict['instruction']] + instruction_dict['simplified_instruction'] + \
50
+ instruction_dict['expanded_instruction']
51
+
52
+ # Encode the instructions
53
+ tokenized_res = tokenizer(
54
+ instructions, return_tensors="pt",
55
+ padding="longest",
56
+ truncation=True
57
+ )
58
+ tokens = tokenized_res["input_ids"].to(device)
59
+ attn_mask = tokenized_res["attention_mask"].to(device)
60
+
61
+ with torch.no_grad():
62
+ text_embeds = text_encoder(
63
+ input_ids=tokens,
64
+ attention_mask=attn_mask
65
+ )["last_hidden_state"].detach().cpu()
66
+
67
+ attn_mask = attn_mask.cpu().bool()
68
+
69
+ # Save the embeddings for training use
70
+ for i in range(len(instructions)):
71
+ text_embed = text_embeds[i][attn_mask[i]]
72
+ save_path = os.path.join(task_path, f"lang_embed_{i}.pt")
73
+ torch.save(text_embed, save_path)
74
+
75
+ if __name__ == "__main__":
76
+ main()
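
The script above writes one lang_embed_{i}.pt tensor per instruction into each task directory. A small sketch (not part of the commit) of how such a precomputed embedding can be loaded and batched for a policy that consumes text embeddings directly, e.g. the ManiSkill wrapper added below; the task path is a placeholder.

import torch

# Placeholder path: any task directory processed by the script above
embed_path = "data/datasets/openx_embod/singlevla_benchmark_ee/sub_dir/task_dir/lang_embed_0.pt"
text_embed = torch.load(embed_path)    # (num_tokens, hidden_dim), padding tokens already stripped
text_embeds = text_embed.unsqueeze(0)  # add a batch dimension -> (1, num_tokens, hidden_dim)
# text_embeds can now be passed as the text_embeds argument of a step() call
# that expects precomputed instruction embeddings.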
scripts/maniskill_model.py ADDED
@@ -0,0 +1,277 @@
1
+ import os
2
+
3
+ import numpy as np
4
+ import torch
5
+ from PIL import Image
6
+ from torchvision import transforms
7
+
8
+ from configs.state_vec import STATE_VEC_IDX_MAPPING
9
+ from models.multimodal_encoder.siglip_encoder import SiglipVisionTower
10
+ from models.multimodal_encoder.t5_encoder import T5Embedder
11
+ from models.rdt_runner import RDTRunner
12
+
13
+
14
+ MANISKILL_INDICES = [
15
+ STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(7)
16
+ ] + [
17
+ STATE_VEC_IDX_MAPPING[f"right_gripper_open"]
18
+ ]
19
+
20
+
21
+ def create_model(args, pretrained, **kwargs):
22
+ model = RoboticDiffusionTransformerModel(args, **kwargs)
23
+ if pretrained is not None:
24
+ model.load_pretrained_weights(pretrained)
25
+ return model
26
+
27
+
28
+ DATA_STAT = {'state_min': [-0.7463043928146362, -0.0801204964518547, -0.4976441562175751, -2.657780647277832, -0.5742632150650024, 1.8309762477874756, -2.2423808574676514, 0.0], 'state_max': [0.7645499110221863, 1.4967026710510254, 0.4650936424732208, -0.3866899907588959, 0.5505855679512024, 3.2900545597076416, 2.5737812519073486, 0.03999999910593033], 'action_min': [-0.7472005486488342, -0.08631071448326111, -0.4995281398296356, -2.658363103866577, -0.5751323103904724, 1.8290787935256958, -2.245187997817993, -1.0], 'action_max': [0.7654682397842407, 1.4984270334243774, 0.46786263585090637, -0.38181185722351074, 0.5517147779464722, 3.291581630706787, 2.575840711593628, 1.0]}
29
+
30
+ class RoboticDiffusionTransformerModel(object):
31
+ """A wrapper for the RDT model, which handles
32
+ 1. Model initialization
33
+ 2. Encodings of instructions
34
+ 3. Model inference
35
+ """
36
+ def __init__(
37
+ self, args,
38
+ device='cuda',
39
+ dtype=torch.bfloat16,
40
+ image_size=None,
41
+ control_frequency=25,
42
+ pretrained_text_encoder_name_or_path=None,
43
+ pretrained_vision_encoder_name_or_path=None,
44
+ ):
45
+ self.args = args
46
+ self.dtype = dtype
47
+ self.image_size = image_size
48
+ self.device = device
49
+ self.control_frequency = control_frequency
50
+ self.text_tokenizer, self.text_model = self.get_text_encoder(pretrained_text_encoder_name_or_path)
51
+ self.image_processor, self.vision_model = self.get_vision_encoder(pretrained_vision_encoder_name_or_path)
52
+ self.policy = self.get_policy()
53
+
54
+ self.state_min = torch.tensor(DATA_STAT['state_min']).to(device)
55
+ self.state_max = torch.tensor(DATA_STAT['state_max']).to(device)
56
+ self.action_min = torch.tensor(DATA_STAT['action_min']).to(device)
57
+ self.action_max = torch.tensor(DATA_STAT['action_max']).to(device)
58
+
59
+ self.reset()
60
+
61
+ def get_policy(self):
62
+ """Initialize the model."""
63
+ # Initialize model with arguments
64
+ img_cond_len = (self.args["common"]["img_history_size"]
65
+ * self.args["common"]["num_cameras"]
66
+ * self.vision_model.num_patches)
67
+
68
+ _model = RDTRunner(
69
+ action_dim=self.args["common"]["state_dim"],
70
+ pred_horizon=self.args["common"]["action_chunk_size"],
71
+ config=self.args["model"],
72
+ lang_token_dim=self.args["model"]["lang_token_dim"],
73
+ img_token_dim=self.args["model"]["img_token_dim"],
74
+ state_token_dim=self.args["model"]["state_token_dim"],
75
+ max_lang_cond_len=self.args["dataset"]["tokenizer_max_length"],
76
+ img_cond_len=img_cond_len,
77
+ img_pos_embed_config=[
78
+ # No initial pos embed in the last grid size
79
+ # since we've already done in ViT
80
+ ("image", (self.args["common"]["img_history_size"],
81
+ self.args["common"]["num_cameras"],
82
+ -self.vision_model.num_patches)),
83
+ ],
84
+ lang_pos_embed_config=[
85
+ # Similarly, no initial pos embed for language
86
+ ("lang", -self.args["dataset"]["tokenizer_max_length"]),
87
+ ],
88
+ dtype=self.dtype,
89
+ )
90
+
91
+ return _model
92
+
93
+ def get_text_encoder(self, pretrained_text_encoder_name_or_path):
94
+ text_embedder = T5Embedder(from_pretrained=pretrained_text_encoder_name_or_path,
95
+ model_max_length=self.args["dataset"]["tokenizer_max_length"],
96
+ device=self.device)
97
+ tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model
98
+ return tokenizer, text_encoder
99
+
100
+ def get_vision_encoder(self, pretrained_vision_encoder_name_or_path):
101
+ vision_encoder = SiglipVisionTower(vision_tower=pretrained_vision_encoder_name_or_path, args=None)
102
+ image_processor = vision_encoder.image_processor
103
+ return image_processor, vision_encoder
104
+
105
+ def reset(self):
106
+ """Set model to evaluation mode.
107
+ """
108
+ device = self.device
109
+ weight_dtype = self.dtype
110
+ self.policy.eval()
111
+ self.text_model.eval()
112
+ self.vision_model.eval()
113
+
114
+ self.policy = self.policy.to(device, dtype=weight_dtype)
115
+ self.text_model = self.text_model.to(device, dtype=weight_dtype)
116
+ self.vision_model = self.vision_model.to(device, dtype=weight_dtype)
117
+
118
+ def load_pretrained_weights(self, pretrained=None):
119
+ if pretrained is None:
120
+ return
121
+ print(f'Loading weights from {pretrained}')
122
+ filename = os.path.basename(pretrained)
123
+ if filename.endswith('.pt'):
124
+ checkpoint = torch.load(pretrained)
125
+ self.policy.load_state_dict(checkpoint["module"])
126
+ elif filename.endswith('.safetensors'):
127
+ from safetensors.torch import load_model
128
+ load_model(self.policy, pretrained)
129
+ else:
130
+ raise NotImplementedError(f"Unknown checkpoint format: {pretrained}")
131
+
132
+ def encode_instruction(self, instruction, device="cuda"):
133
+ """Encode string instruction to latent embeddings.
134
+
135
+ Args:
136
+ instruction: a string of instruction
137
+ device: a string of device
138
+
139
+ Returns:
140
+ pred: a tensor of latent embeddings of shape (text_max_length, 512)
141
+ """
142
+ tokens = self.text_tokenizer(
143
+ instruction, return_tensors="pt",
144
+ padding="longest",
145
+ truncation=True
146
+ )["input_ids"].to(device)
147
+
148
+ tokens = tokens.view(1, -1)
149
+ with torch.no_grad():
150
+ pred = self.text_model(tokens).last_hidden_state.detach()
151
+
152
+ return pred
153
+
154
+ def _format_joint_to_state(self, joints):
155
+ """
156
+ Format the robot joint state into the unified state vector.
157
+
158
+ Args:
159
+ joints (torch.Tensor): The joint state to be formatted.
160
+ qpos ([B, N, 14]).
161
+
162
+ Returns:
163
+ state (torch.Tensor): The formatted state for RDT ([B, N, 128]).
164
+ """
165
+ # Rescale the gripper
166
+ # joints = joints / torch.tensor(
167
+ # [[[1, 1, 1, 1, 1, 1, 4.7908, 1, 1, 1, 1, 1, 1, 4.7888]]],
168
+ # device=joints.device, dtype=joints.dtype
169
+ # )
170
+
171
+ # normalize to -1,1
172
+ joints = (joints - self.state_min) / (self.state_max - self.state_min) * 2 - 1
173
+ B, N, _ = joints.shape
174
+ state = torch.zeros(
175
+ (B, N, self.args["model"]["state_token_dim"]),
176
+ device=joints.device, dtype=joints.dtype
177
+ )
178
+ # assemble the unifed state vector
179
+ state[:, :, MANISKILL_INDICES] = joints
180
+ state_elem_mask = torch.zeros(
181
+ (B, self.args["model"]["state_token_dim"]),
182
+ device=joints.device, dtype=joints.dtype
183
+ )
184
+ state_elem_mask[:, MANISKILL_INDICES] = 1
185
+ return state, state_elem_mask
186
+
187
+ def _unformat_action_to_joint(self, action):
188
+ action_indices = MANISKILL_INDICES
189
+ joints = action[:, :, action_indices]
190
+
191
+ # denormalize to action space
192
+
193
+ joints = (joints + 1) / 2 * (self.action_max - self.action_min) + self.action_min
194
+
195
+ return joints
196
+
197
+ @torch.no_grad()
198
+ def step(self, proprio, images, text_embeds):
199
+ """
200
+ Args:
201
+ proprio: proprioceptive states
202
+ images: RGB images
203
+ text_embeds: instruction embeddings
204
+
205
+ Returns:
206
+ action: predicted action
207
+ """
208
+ device = self.device
209
+ dtype = self.dtype
210
+
211
+ background_color = np.array([
212
+ int(x*255) for x in self.image_processor.image_mean
213
+ ], dtype=np.uint8).reshape(1, 1, 3)
214
+ background_image = np.ones((
215
+ self.image_processor.size["height"],
216
+ self.image_processor.size["width"], 3), dtype=np.uint8
217
+ ) * background_color
218
+
219
+ image_tensor_list = []
220
+ for image in images:
221
+ if image is None:
222
+ # Replace it with the background image
223
+ image = Image.fromarray(background_image)
224
+
225
+ if self.image_size is not None:
226
+ image = transforms.Resize(self.data_args.image_size)(image)
227
+
228
+ if self.args["dataset"].get("auto_adjust_image_brightness", False):
229
+ pixel_values = list(image.getdata())
230
+ average_brightness = sum(sum(pixel) for pixel in pixel_values) / (len(pixel_values) * 255.0 * 3)
231
+ if average_brightness <= 0.15:
232
+ image = transforms.ColorJitter(brightness=(1.75,1.75))(image)
233
+
234
+ if self.args["dataset"].get("image_aspect_ratio", "pad") == 'pad':
235
+ def expand2square(pil_img, background_color):
236
+ width, height = pil_img.size
237
+ if width == height:
238
+ return pil_img
239
+ elif width > height:
240
+ result = Image.new(pil_img.mode, (width, width), background_color)
241
+ result.paste(pil_img, (0, (width - height) // 2))
242
+ return result
243
+ else:
244
+ result = Image.new(pil_img.mode, (height, height), background_color)
245
+ result.paste(pil_img, ((height - width) // 2, 0))
246
+ return result
247
+ image = expand2square(image, tuple(int(x*255) for x in self.image_processor.image_mean))
248
+ image = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
249
+ image_tensor_list.append(image)
250
+
251
+ image_tensor = torch.stack(image_tensor_list, dim=0).to(device, dtype=dtype)
252
+
253
+ image_embeds = self.vision_model(image_tensor).detach()
254
+ image_embeds = image_embeds.reshape(-1, self.vision_model.hidden_size).unsqueeze(0)
255
+
256
+ # history of actions
257
+ joints = proprio.to(device).unsqueeze(0) # (1, 1, 14)
258
+ states, state_elem_mask = self._format_joint_to_state(joints) # (1, 1, 128), (1, 128)
259
+ states, state_elem_mask = states.to(device, dtype=dtype), state_elem_mask.to(device, dtype=dtype)
260
+ states = states[:, -1:, :] # (1, 1, 128)
261
+ ctrl_freqs = torch.tensor([self.control_frequency]).to(device)
262
+
263
+ text_embeds = text_embeds.to(device, dtype=dtype)
264
+
265
+ trajectory = self.policy.predict_action(
266
+ lang_tokens=text_embeds,
267
+ lang_attn_mask=torch.ones(
268
+ text_embeds.shape[:2], dtype=torch.bool,
269
+ device=text_embeds.device),
270
+ img_tokens=image_embeds,
271
+ state_tokens=states,
272
+ action_mask=state_elem_mask.unsqueeze(1),
273
+ ctrl_freqs=ctrl_freqs
274
+ )
275
+ trajectory = self._unformat_action_to_joint(trajectory).to(torch.float32)
276
+
277
+ return trajectory
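
The ManiSkill wrapper min-max normalizes the proprioception into [-1, 1] with the DATA_STAT bounds and maps predicted actions back through the inverse transform. A minimal, self-contained sketch (not part of the commit) of that round trip with toy bounds:

import torch

state_min = torch.tensor([-1.0, 0.0])   # toy bounds, stand-ins for the DATA_STAT entries
state_max = torch.tensor([1.0, 0.04])

def normalize(x):     # raw -> [-1, 1]
    return (x - state_min) / (state_max - state_min) * 2 - 1

def denormalize(x):   # [-1, 1] -> raw
    return (x + 1) / 2 * (state_max - state_min) + state_min

raw = torch.tensor([0.5, 0.02])
assert torch.allclose(denormalize(normalize(raw)), raw)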
train/dataset.py ADDED
@@ -0,0 +1,467 @@
1
+ import traceback
2
+ import time
3
+ import os
4
+ import json
5
+ import math
6
+ import random
7
+ from typing import Dict, Sequence
8
+
9
+ import numpy as np
10
+ import torch
11
+ from torch.utils.data import Dataset
12
+ from torchvision import transforms
13
+ from PIL import Image
14
+ import transformers
15
+
16
+ from data.filelock import FileLock
17
+ from data.hdf5_vla_dataset import TabletopHDF5VLADataset, AnubisHDF5VLADataset
18
+ from train.image_corrupt import image_corrupt
19
+
20
+
21
+ def get_clean_item(chunk_dir):
22
+ """
23
+ Get indexes of clean items in a chunk.
24
+ """
25
+ dirty_bit = read_dirty_bit(chunk_dir)
26
+ return np.where(1 - dirty_bit)[0].tolist()
27
+
28
+
29
+ def save_dirty_bit(chunk_dir, dirty_bit):
30
+ """
31
+ Save the dirty bit to the chunk directory.
32
+ """
33
+ time_stmp = time.time()
34
+ while time.time() - time_stmp < 10.0:
35
+ try:
36
+ file_path = os.path.join(chunk_dir, "dirty_bit")
37
+ lock = FileLock(file_path)
38
+ lock.acquire_write_lock()
39
+ with open(file_path, 'wb') as file:
40
+ file.write(dirty_bit.tobytes())
41
+ lock.release_lock()
42
+ return
43
+ except KeyboardInterrupt:
44
+ lock.release_lock()
45
+ raise KeyboardInterrupt
46
+ except BaseException:
47
+ lock.release_lock()
48
+ continue
49
+ raise RuntimeError("Failed to save dirty bit.")
50
+
51
+
52
+ def read_dirty_bit(chunk_dir):
53
+ """
54
+ Read the dirty bit from the chunk directory.
55
+ """
56
+ # If error occurs, retry
57
+ time_stmp = time.time()
58
+ while time.time() - time_stmp < 10.0:
59
+ try:
60
+ file_path = os.path.join(chunk_dir, "dirty_bit")
61
+ lock = FileLock(file_path)
62
+ lock.acquire_read_lock()
63
+ with open(file_path, 'rb') as file:
64
+ dirty_bit = np.frombuffer(file.read(), dtype=np.uint8).copy()
65
+ lock.release_lock()
66
+ assert len(dirty_bit) > 0
67
+ return dirty_bit
68
+ except KeyboardInterrupt:
69
+ lock.release_lock()
70
+ raise KeyboardInterrupt
71
+ except BaseException:
72
+ lock.release_lock()
73
+ continue
74
+ raise RuntimeError("Failed to read dirty bit.")
75
+
76
+
77
+ class VLAConsumerDataset(Dataset):
78
+ """A vision-languange-action Dataset for supervised training.
79
+ This dataset will load data from the buffer directory.
80
+ """
81
+
82
+ def __init__(
83
+ self,
84
+ config,
85
+ tokenizer,
86
+ image_processor,
87
+ num_cameras,
88
+ img_history_size,
89
+ image_size=None,
90
+ auto_adjust_image_brightness=False,
91
+ image_aug=False,
92
+ dataset_type='pretrain',
93
+ cond_mask_prob=0.1,
94
+ cam_ext_mask_prob=-1.0,
95
+ state_noise_snr=None,
96
+ use_hdf5=False,
97
+ use_precomp_lang_embed=False,
98
+ task_name=None
99
+ ):
100
+ super(VLAConsumerDataset, self).__init__()
101
+
102
+ # Load the control frequency for each dataset
103
+ with open("configs/dataset_control_freq.json", 'r') as fp:
104
+ self.control_freq = json.load(fp)
105
+ # Load the dataset names
106
+ dataset_names_cfg = 'configs/pretrain_datasets.json' \
107
+ if dataset_type == 'pretrain' else 'configs/finetune_datasets.json'
108
+ with open(dataset_names_cfg, 'r') as file:
109
+ DATASET_NAMES = json.load(file)
110
+ # Create the mapping between dataset name and id
111
+ # self.dataset_name2id = {name: i for i, name in enumerate(DATASET_NAMES)}
112
+ # self.dataset_id2name = {i: name for i, name in enumerate(DATASET_NAMES)}
113
+ self.dataset_name2id = {task_name: 0}
114
+ self.dataset_id2name = {0: task_name}
115
+
116
+ self.image_processor = image_processor
117
+
118
+ self.buffer_dir = config["buf_path"]
119
+ self.num_chunks = config["buf_num_chunks"]
120
+ self.chunk_size = config["buf_chunk_size"]
121
+ self.tokenizer_max_length = config["tokenizer_max_length"]
122
+ self.image_aspect_ratio = config["image_aspect_ratio"]
123
+ self.state_noise_snr = state_noise_snr
124
+ self.num_cameras = num_cameras
125
+ self.img_history_size = img_history_size
126
+ self.cond_mask_prob = cond_mask_prob
127
+ self.cam_ext_mask_prob = cam_ext_mask_prob
128
+ self.use_hdf5 = use_hdf5
129
+ self.hdf5_dataset = None
130
+ if use_hdf5:
131
+ self.hdf5_dataset = AnubisHDF5VLADataset(task_name)
132
+ self.use_precomp_lang_embed = use_precomp_lang_embed
133
+ if use_precomp_lang_embed:
134
+ self.empty_lang_embed = torch.load("data/empty_lang_embed.pt")
135
+
136
+ # Load dataset stat
137
+ with open("configs/dataset_stat.json", 'r') as f:
138
+ dataset_stat = json.load(f)
139
+ self.dataset_stat = dataset_stat
140
+
141
+ self.tokenizer = tokenizer
142
+ self.image_size = image_size
143
+ self.auto_adjust_image_brightness = auto_adjust_image_brightness
144
+ self.image_aug = image_aug
145
+
146
+ self.last_content = None
147
+ self.last_meta = None
148
+
149
+ def get_dataset_name2id(self):
150
+ return self.dataset_name2id
151
+
152
+ def get_dataset_id2name(self):
153
+ return self.dataset_id2name
154
+
155
+ @staticmethod
156
+ def pairwise(iterable):
157
+ a = iter(iterable)
158
+ return zip(a, a)
159
+
160
+ @staticmethod
161
+ def _load_data_from_chunk(chunk_dir, chunk_item_idx):
162
+ # If error occurs, retry
163
+ time_stmp = time.time()
164
+ while time.time() - time_stmp < 10.0:
165
+ try:
166
+ locks = []
167
+ file_path = os.path.join(chunk_dir, f"json_content_{chunk_item_idx}.json")
168
+ lock = FileLock(file_path)
169
+ locks.append(lock)
170
+ lock.acquire_read_lock()
171
+ with open(file_path, 'r') as file:
172
+ json_content = json.load(file)
173
+ lock.release_lock()
174
+ file_path = os.path.join(chunk_dir, f"sample_{chunk_item_idx}.npz")
175
+ lock = FileLock(file_path)
176
+ locks.append(lock)
177
+ lock.acquire_read_lock()
178
+ with open(file_path, 'rb') as file:
179
+ sample_dict = np.load(file)
180
+ meta = tuple(sample_dict.values())
181
+ lock.release_lock()
182
+ return json_content, meta
183
+ except KeyboardInterrupt:
184
+ for lock in locks:
185
+ lock.release_lock()
186
+ raise KeyboardInterrupt
187
+ except BaseException:
188
+ for lock in locks:
189
+ lock.release_lock()
190
+ continue
191
+ raise RuntimeError("Failed to load sample.")
192
+
193
+ def __len__(self) -> int:
194
+ if self.use_hdf5:
195
+ return len(self.hdf5_dataset)
196
+ else:
197
+ return self.num_chunks * self.chunk_size
198
+
199
+ def _safe_load(self, index):
200
+ read_chunk_item_indices = []
201
+ # Start searching from a random chunk
202
+ read_chunk_idx = index // self.chunk_size
203
+ while len(read_chunk_item_indices) == 0:
204
+ read_chunk_dir = os.path.join(self.buffer_dir, f"chunk_{read_chunk_idx}")
205
+ try:
206
+ read_chunk_item_indices = get_clean_item(read_chunk_dir)
207
+ except BaseException as e:
208
+ # Print the error info
209
+ print("Error catched when searching a clean chunk:", e)
210
+ traceback.print_exc()
211
+ read_chunk_item_indices = []
212
+ read_chunk_idx = (read_chunk_idx + 1) % self.num_chunks
213
+
214
+ # read_chunk_item_index = random.choice(read_chunk_item_indices)
215
+ # read_chunk_item_index = read_chunk_item_indices.pop()
216
+ random_item_index = index % len(read_chunk_item_indices)
217
+ read_chunk_item_index = read_chunk_item_indices[random_item_index]
218
+
219
+ # Modify the dirty bit
220
+ try:
221
+ dirty_bit = read_dirty_bit(read_chunk_dir)
222
+ dirty_bit[read_chunk_item_index] = 1
223
+ save_dirty_bit(read_chunk_dir, dirty_bit)
224
+ except BaseException as e:
225
+ # Print the error info
226
+ print("Error catched when modifying the dirty bit:", e)
227
+ traceback.print_exc()
228
+
229
+ # load the sample
230
+ try:
231
+ content, meta = self._load_data_from_chunk(read_chunk_dir, read_chunk_item_index)
232
+ self.last_content, self.last_meta = content, meta
233
+ except BaseException as e:
234
+ # Print the error info
235
+ print("Error catched when loading sample:", e)
236
+ traceback.print_exc()
237
+
238
+ # If failed to load the data, return the last loaded data for robustness
239
+ content, meta = self.last_content, self.last_meta
240
+
241
+ return (content, *meta)
242
+
243
+ def __getitem__(self, index):
244
+ # For robustness, we will try to load the data until we succeed
245
+ while True:
246
+ data_dict = None
247
+ try:
248
+ if self.use_hdf5:
249
+ res = self.hdf5_dataset.get_item()
250
+ content = res['meta']
251
+ states = res['state']
252
+ actions = res['actions']
253
+ state_elem_mask = res['state_indicator']
254
+ image_metas = [
255
+ res['cam_high'], res['cam_high_mask'],
256
+ res['cam_right_wrist'], res['cam_right_wrist_mask'],
257
+ res['cam_left_wrist'], res['cam_left_wrist_mask'],
258
+ ]
259
+ state_std = res['state_std']
260
+ state_mean = res['state_mean']
261
+ state_norm = res['state_norm']
262
+ else:
263
+ (content, _, states, _, actions, _,
264
+ state_elem_mask, *image_metas,
265
+ state_std, state_mean, state_norm) = self._safe_load(index)
266
+
267
+ data_dict = {}
268
+ data_dict['dataset_name'] = content['dataset_name']
269
+ data_dict['data_idx'] = self.dataset_name2id[data_dict['dataset_name']]
270
+ data_dict['ctrl_freq'] = self.control_freq[data_dict['dataset_name']] \
271
+ if random.random() > self.cond_mask_prob else 0
272
+
273
+ if self.state_noise_snr is not None:
274
+ states += np.random.normal(
275
+ 0.0, state_std / np.sqrt(10 ** (self.state_noise_snr / 10)),
276
+ states.shape)
277
+ ds_state_mean = np.array(self.dataset_stat[data_dict['dataset_name']]['state_mean'])
278
+ ds_state_mean = np.tile(ds_state_mean[None], (states.shape[0], 1))
279
+ # Randomly mask the states by the mean state
280
+ data_dict["states"] = states \
281
+ if random.random() > self.cond_mask_prob else ds_state_mean
282
+ data_dict["actions"] = actions
283
+ data_dict["state_elem_mask"] = state_elem_mask \
284
+ if random.random() > self.cond_mask_prob else np.zeros_like(state_elem_mask)
285
+
286
+ # Stat for the episode that the step belongs to
287
+ data_dict["state_norm"] = state_norm
288
+
289
+ # We replace the invalid images with the background image
290
+ # and also randomly mask images by the background image
291
+ background_color = np.array([
292
+ int(x*255) for x in self.image_processor.image_mean
293
+ ], dtype=np.uint8).reshape(1, 1, 3)
294
+ background_image = np.ones((
295
+ self.image_processor.size["height"],
296
+ self.image_processor.size["width"], 3), dtype=np.uint8
297
+ ) * background_color
298
+
299
+ image_metas = list(self.pairwise(image_metas))
300
+ mask_probs = [self.cond_mask_prob] * self.num_cameras
301
+ if self.cam_ext_mask_prob >= 0.0:
302
+ mask_probs[0] = self.cam_ext_mask_prob
303
+ rearranged_images = []
304
+ for i in range(self.img_history_size):
305
+ for j in range(self.num_cameras):
306
+ images, image_mask = image_metas[j]
307
+ image, valid = images[i], image_mask[i]
308
+ if valid and (math.prod(image.shape) > 0) and \
309
+ (random.random() > mask_probs[j]):
310
+ rearranged_images.append((image, True))
311
+ else:
312
+ rearranged_images.append((background_image.copy(), False))
313
+
314
+ preprocessed_images = []
315
+ processor = self.image_processor
316
+ for image, valid in rearranged_images:
317
+ image = Image.fromarray(image)
318
+ if self.image_size is not None:
319
+ image = transforms.Resize(self.image_size)(image) # (1008, 336)
320
+ # assert image.height == 336, "We haven't prepare for training with images of different resolutions."
321
+
322
+ if valid and self.auto_adjust_image_brightness:
323
+ pixel_values = list(image.getdata())
324
+ average_brightness = sum(sum(pixel) for pixel in pixel_values) / (len(pixel_values) * 255.0 * 3)
325
+ if average_brightness <= 0.15:
326
+ image = transforms.ColorJitter(brightness=(1.75,1.75))(image)
327
+
328
+ # Only apply image augmentation to 50% of the images
329
+ if valid and self.image_aug and (random.random() > 0.5):
330
+ aug_type = random.choice([
331
+ "corrput_only", "color_only", "both"])
332
+ if aug_type != "corrput_only":
333
+ image = transforms.ColorJitter(
334
+ brightness=0.3, contrast=0.4, saturation=0.5, hue=0.03)(image)
335
+ if aug_type != "color_only":
336
+ image = image_corrupt(image)
337
+
338
+ if self.image_aspect_ratio == 'pad':
339
+ def expand2square(pil_img, background_color):
340
+ width, height = pil_img.size
341
+ if width == height:
342
+ return pil_img
343
+ elif width > height:
344
+ result = Image.new(pil_img.mode, (width, width), background_color)
345
+ result.paste(pil_img, (0, (width - height) // 2))
346
+ return result
347
+ else:
348
+ result = Image.new(pil_img.mode, (height, height), background_color)
349
+ result.paste(pil_img, ((height - width) // 2, 0))
350
+ return result
351
+ image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
352
+ image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
353
+ preprocessed_images.append(image)
354
+ data_dict["images"] = preprocessed_images
355
+
356
+ if self.use_precomp_lang_embed:
357
+ if content["instruction"][-1] == ".":
358
+ content["instruction"] = content["instruction"][:-1]
359
+ data_dict["lang_embed"] = torch.load(content["instruction"])['embeddings'][0] \
360
+ if random.random() > self.cond_mask_prob else self.empty_lang_embed ##FIXED
361
+ else:
362
+ instruction = content["instruction"] \
363
+ if random.random() > self.cond_mask_prob else ""
364
+ data_dict["input_ids"] = self.tokenizer(
365
+ instruction,
366
+ return_tensors="pt",
367
+ padding="longest",
368
+ truncation=False,
369
+ ).input_ids[0]
370
+
371
+ assert len(data_dict["input_ids"]) <= self.tokenizer_max_length, \
372
+ f"Instruction length {len(data_dict['input_ids'])} exceeds the maximum length {self.tokenizer_max_length}."
373
+
374
+ for k, v in data_dict.items():
375
+ if isinstance(v, np.ndarray):
376
+ data_dict[k] = torch.from_numpy(v)
377
+
378
+ for k, v in data_dict.items():
379
+ assert not isinstance(v, np.ndarray), f"key: {k}, value: {v}"
380
+ # data_dict[k] = torch.from_numpy(v)
381
+
382
+ return data_dict
383
+ except BaseException as e:
384
+ # Print the error info
385
+ if data_dict is not None:
386
+ print(f"Error catched when processing sample from {data_dict.get('dataset_name')}:", e)
387
+ else:
388
+ print(f"Error catched when processing sample:", e)
389
+ traceback.print_exc()
390
+ # Try incresing the index
391
+ index = (index + 1) % len(self)
392
+
393
+
394
+ class DataCollatorForVLAConsumerDataset(object):
395
+ """Collate examples for supervised training."""
396
+
397
+ def __init__(self, tokenizer: transformers.PreTrainedTokenizer) -> None:
398
+ self.tokenizer = tokenizer
399
+
400
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
401
+ batch = {
402
+ "states": [],
403
+ "actions": [],
404
+ "state_elem_mask": [],
405
+ "state_norm": [],
406
+ "images": [],
407
+ "data_indices": [],
408
+ "ctrl_freqs": []
409
+ }
410
+ input_ids = []
411
+ lang_embeds = []
412
+ lang_embed_lens = []
413
+
414
+ for instance in instances:
415
+ # Convert all the numpy arrays to tensor
416
+ keys_to_check = [
417
+ 'states', 'actions',
418
+ 'state_elem_mask', 'state_norm',
419
+ ]
420
+ for key in keys_to_check:
421
+ if isinstance(instance[key], torch.Tensor):
422
+ item = instance[key]
423
+ else:
424
+ item = torch.from_numpy(instance[key])
425
+ batch[key].append(item)
426
+
427
+ if "input_ids" in instance:
428
+ input_ids.append(instance["input_ids"])
429
+ else:
430
+ lang_embeds.append(instance["lang_embed"])
431
+ lang_embed_lens.append(instance["lang_embed"].shape[0])
432
+
433
+ batch["images"].append(torch.stack(instance["images"], dim=0))
434
+ batch["data_indices"].append(instance["data_idx"])
435
+ batch["ctrl_freqs"].append(instance["ctrl_freq"])
436
+
437
+ keys_to_stack = [
438
+ 'states', 'actions',
439
+ 'state_elem_mask', 'state_norm',
440
+ "images"
441
+ ]
442
+ for key in keys_to_stack:
443
+ batch[key] = torch.stack(batch[key], dim=0)
444
+
445
+ batch["ctrl_freqs"] = torch.tensor(batch["ctrl_freqs"])
446
+
447
+ if len(input_ids) > 0:
448
+ input_ids = torch.nn.utils.rnn.pad_sequence(
449
+ input_ids,
450
+ batch_first=True,
451
+ padding_value=self.tokenizer.pad_token_id)
452
+ batch["input_ids"] = input_ids
453
+ batch["lang_attn_mask"] = input_ids.ne(self.tokenizer.pad_token_id)
454
+ else:
455
+ lang_embeds = torch.nn.utils.rnn.pad_sequence(
456
+ lang_embeds,
457
+ batch_first=True,
458
+ padding_value=0)
459
+ input_lang_attn_mask = torch.zeros(
460
+ lang_embeds.shape[0], lang_embeds.shape[1], dtype=torch.bool)
461
+ for i, l in enumerate(lang_embed_lens):
462
+ input_lang_attn_mask[i, :l] = True
463
+ batch["lang_embeds"] = lang_embeds
464
+ batch["lang_attn_mask"] = input_lang_attn_mask
465
+
466
+
467
+ return batch
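
The buffer protocol used by VLAConsumerDataset relies on a per-chunk dirty-bit array: the producer fills a chunk and clears the bits, the consumer picks a clean slot, reads it, and sets its bit so the producer knows that slot can be replaced. A minimal sketch (not part of the commit) of that bookkeeping without the file locks:

import numpy as np

dirty_bit = np.zeros(8, dtype=np.uint8)     # one flag per sample slot in a chunk, 0 = clean
clean = np.where(1 - dirty_bit)[0].tolist() # indices of unread slots
item = clean[0]                             # consumer picks a clean slot ...
dirty_bit[item] = 1                         # ... and marks it dirty after reading
print(np.where(1 - dirty_bit)[0].tolist())  # remaining clean slots: [1, 2, 3, 4, 5, 6, 7]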
train/image_corrupt.py ADDED
@@ -0,0 +1,44 @@
1
+ import warnings
2
+ warnings.simplefilter(action='ignore', category=FutureWarning)
3
+
4
+ import numpy as np
5
+ np.bool = np.bool_
6
+ import imgaug.augmenters as iaa
7
+ from PIL import Image
8
+
9
+
10
+ # Define our sequence of augmentation steps that will be applied to every image.
11
+ seq = iaa.Sequential(
12
+ [
13
+ # Execute one of the following noise augmentations
14
+ iaa.OneOf([
15
+ iaa.AdditiveGaussianNoise(
16
+ loc=0, scale=(0.0, 0.05*255), per_channel=0.5
17
+ ),
18
+ iaa.AdditiveLaplaceNoise(scale=(0.0, 0.05*255), per_channel=0.5),
19
+ iaa.AdditivePoissonNoise(lam=(0.0, 0.05*255), per_channel=0.5)
20
+ ]),
21
+
22
+ # Execute one or none of the following blur augmentations
23
+ iaa.SomeOf((0, 1), [
24
+ iaa.OneOf([
25
+ iaa.GaussianBlur((0, 3.0)),
26
+ iaa.AverageBlur(k=(2, 7)),
27
+ iaa.MedianBlur(k=(3, 11)),
28
+ ]),
29
+ iaa.MotionBlur(k=(3, 36)),
30
+ ]),
31
+ ],
32
+ # do all of the above augmentations in random order
33
+ random_order=True
34
+ )
35
+
36
+
37
+ def image_corrupt(image: Image):
38
+ image_arr = np.array(image)
39
+ image_arr = image_arr[None, ...]
40
+
41
+ image_arr = seq(images=image_arr)
42
+
43
+ image = Image.fromarray(image_arr[0])
44
+ return image
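
A short usage sketch (not part of the commit) for the augmentation pipeline above, applied to a synthetic RGB image; it assumes imgaug is installed and the repository root is on PYTHONPATH:

import numpy as np
from PIL import Image
from train.image_corrupt import image_corrupt

img = Image.fromarray(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
corrupted = image_corrupt(img)   # PIL image with random noise and/or blur applied
corrupted.save("corrupted_preview.png")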
train/sample.py ADDED
@@ -0,0 +1,99 @@
from collections import defaultdict

import torch
import torch.nn.functional as F


@torch.no_grad()
def log_sample_res(
    text_encoder, vision_encoder, rdt, args,
    accelerator, weight_dtype, dataset_id2name, dataloader, logger
):
    logger.info(
        f"Running sampling for {args.num_sample_batches} batches..."
    )

    rdt.eval()

    loss_for_log = defaultdict(float)
    loss_counter = defaultdict(int)
    for step, batch in enumerate(dataloader):
        if step >= args.num_sample_batches:
            break

        data_indices = batch["data_indices"]
        ctrl_freqs = batch["ctrl_freqs"]
        state_norm = batch["state_norm"].to(dtype=weight_dtype)
        images = batch["images"].to(dtype=weight_dtype)
        states = batch["states"].to(dtype=weight_dtype)
        # We only use the last state as input
        states = states[:, -1:, :]
        actions = batch["actions"].to(dtype=weight_dtype)
        state_elem_mask = batch["state_elem_mask"].to(dtype=weight_dtype)

        batch_size, _, C, H, W = images.shape
        image_embeds = vision_encoder(images.reshape(-1, C, H, W)).detach()
        image_embeds = image_embeds.reshape((batch_size, -1, vision_encoder.hidden_size))

        lang_attn_mask = batch["lang_attn_mask"]
        text_embeds = batch["lang_embeds"].to(dtype=weight_dtype) \
            if args.precomp_lang_embed \
            else text_encoder(
                input_ids=batch["input_ids"],
                attention_mask=lang_attn_mask
            )["last_hidden_state"].detach()

        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            pred_actions = rdt.predict_action(
                lang_tokens=text_embeds,
                lang_attn_mask=lang_attn_mask,
                img_tokens=image_embeds,
                state_tokens=states,
                action_mask=state_elem_mask.unsqueeze(1),
                ctrl_freqs=ctrl_freqs
            )

        num_steps = pred_actions.shape[1]
        expanded_state_elem_mask = state_elem_mask.unsqueeze(1).tile((1, num_steps, 1)).float()
        expanded_state_norm = state_norm.unsqueeze(1).tile((1, num_steps, 1)).float()

        loss = F.mse_loss(pred_actions, actions, reduction='none').float()

        mse_loss_per_entry = ((loss * expanded_state_elem_mask).reshape((batch_size, -1)).sum(1)
                              / expanded_state_elem_mask.reshape((batch_size, -1)).sum(1))
        l2_loss_per_entry = loss.sqrt() / (expanded_state_norm + 1e-3)
        l2_loss_per_entry = ((l2_loss_per_entry * expanded_state_elem_mask).reshape((batch_size, -1)).sum(1)
                             / expanded_state_elem_mask.reshape((batch_size, -1)).sum(1))

        dataset_indices, mse_losses, l2_losses = accelerator.gather_for_metrics(
            (torch.LongTensor(data_indices).to(device=pred_actions.device),
             mse_loss_per_entry, l2_loss_per_entry),
        )
        dataset_indices = dataset_indices.tolist()
        if accelerator.is_main_process:
            for loss_suffix, losses in zip(["_sample_mse", "_sample_l2err"], [mse_losses, l2_losses]):
                for dataset_idx, loss_tensor in zip(dataset_indices, losses):
                    loss_name = dataset_id2name[dataset_idx] + loss_suffix
                    loss_for_log[loss_name] += loss_tensor.item()
                    loss_counter[loss_name] += 1

        mse_loss = (loss * expanded_state_elem_mask).sum() / expanded_state_elem_mask.sum()
        mse_loss_scaler = accelerator.gather(mse_loss).mean().item()
        loss_for_log["overall_avg_sample_mse"] += mse_loss_scaler

        l2_loss = loss.sqrt() / (expanded_state_norm + 1e-3)
        l2_loss = (l2_loss * expanded_state_elem_mask).sum() / expanded_state_elem_mask.sum()
        l2_loss_scaler = accelerator.gather(l2_loss).mean().item()
        loss_for_log["overall_avg_sample_l2err"] += l2_loss_scaler

    for name in loss_for_log:
        if name in ["overall_avg_sample_mse", "overall_avg_sample_l2err"]:
            loss_scaler = loss_for_log[name]
            loss_for_log[name] = round(loss_scaler / (args.num_sample_batches), 4)
        else:
            loss_for_log[name] = round(loss_for_log[name] / loss_counter[name], 4)

    rdt.train()
    torch.cuda.empty_cache()

    return dict(loss_for_log)
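For reference, the two per-sample metrics that log_sample_res aggregates are a masked MSE and an L2 error normalized by the per-dimension state norm, both averaged only over valid action dimensions. A minimal self-contained sketch of the same masked averaging on toy tensors (all shapes and values are made up; only PyTorch is required):

import torch
import torch.nn.functional as F

B, T, D = 2, 4, 3                                    # toy batch, horizon, action dim
pred = torch.randn(B, T, D)                          # stands in for pred_actions
gt = torch.randn(B, T, D)                            # stands in for actions
mask = torch.tensor([[1., 1., 0.], [1., 0., 0.]])    # stands in for state_elem_mask (B, D)
norm = torch.ones(B, D)                              # stands in for state_norm (B, D)

loss = F.mse_loss(pred, gt, reduction='none')        # (B, T, D), no reduction
mask_t = mask.unsqueeze(1).tile((1, T, 1))           # expand the mask over the horizon
norm_t = norm.unsqueeze(1).tile((1, T, 1))

mse_per_sample = (loss * mask_t).reshape(B, -1).sum(1) / mask_t.reshape(B, -1).sum(1)
l2err_per_sample = ((loss.sqrt() / (norm_t + 1e-3)) * mask_t).reshape(B, -1).sum(1) \
    / mask_t.reshape(B, -1).sum(1)
print(mse_per_sample, l2err_per_sample)              # one value per sample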
train/train.py ADDED
@@ -0,0 +1,509 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
import math
import os
from pathlib import Path

import diffusers
import torch
import torch.utils.checkpoint
import transformers
import yaml
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin, ProjectConfiguration, set_seed
from diffusers.optimization import get_scheduler
from diffusers.utils import is_wandb_available
from huggingface_hub import create_repo, upload_folder
from tqdm.auto import tqdm
from safetensors.torch import load_model

from models.ema_model import EMAModel
from models.multimodal_encoder.siglip_encoder import SiglipVisionTower
from models.multimodal_encoder.t5_encoder import T5Embedder
from models.rdt_runner import RDTRunner
from train.dataset import DataCollatorForVLAConsumerDataset, VLAConsumerDataset
from train.sample import log_sample_res


if is_wandb_available():
    import wandb


def save_model_card(repo_id: str, base_model: str = None, repo_folder=None):
    yaml = f"""
---
license: mit
base_model: {base_model}
language:
- en
pipeline_tag: robotics
library_name: transformers
tags:
- robotics
- pytorch
- multimodal
- pretraining
- vla
- diffusion
- rdt
---
    """
    model_card = f"""
# RDT - {repo_id}

This is an RDT model derived from {base_model}. The weights were trained using [RDT](https://rdt-robotics.github.io/rdt-robotics/).
    """
    with open(os.path.join(repo_folder, "README.md"), "w") as f:
        f.write(yaml + model_card)


def train(args, logger):
    # Read the config
    with open(args.config_path, "r") as fp:
        config = yaml.safe_load(fp)

    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)
    accelerator = Accelerator(
        deepspeed_plugin=DeepSpeedPlugin(
            hf_ds_config=args.deepspeed
        ) if args.deepspeed is not None else None,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_dir=logging_dir,
        project_config=accelerator_project_config,
    )

    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

        if args.push_to_hub:
            repo_id = create_repo(
                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
            ).repo_id

    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    if args.precomp_lang_embed:
        tokenizer, text_encoder = None, None
    else:
        text_embedder = T5Embedder(from_pretrained=args.pretrained_text_encoder_name_or_path,
                                   model_max_length=config["dataset"]["tokenizer_max_length"], device=accelerator.device)
        tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model

    vision_encoder = SiglipVisionTower(vision_tower=args.pretrained_vision_encoder_name_or_path, args=None)
    image_processor = vision_encoder.image_processor

    # Load from a pretrained checkpoint
    if (
        args.pretrained_model_name_or_path is not None
        and not os.path.isfile(args.pretrained_model_name_or_path)
    ):
        logger.info("Constructing model from pretrained checkpoint.")
        rdt = RDTRunner.from_pretrained(args.pretrained_model_name_or_path)
    else:
        logger.info("Constructing model from provided config.")
        # Calculate the image condition length
        img_cond_len = (config["common"]["img_history_size"]
                        * config["common"]["num_cameras"]
                        * vision_encoder.num_patches)
        rdt = RDTRunner(
            action_dim=config["common"]["state_dim"],
            pred_horizon=config["common"]["action_chunk_size"],
            config=config["model"],
            lang_token_dim=config["model"]["lang_token_dim"],
            img_token_dim=config["model"]["img_token_dim"],
            state_token_dim=config["model"]["state_token_dim"],
            max_lang_cond_len=config["dataset"]["tokenizer_max_length"],
            img_cond_len=img_cond_len,
            img_pos_embed_config=[
                # No initial pos embed in the last grid size
                # since we've already done it in ViT
                ("image", (config["common"]["img_history_size"],
                           config["common"]["num_cameras"],
                           -vision_encoder.num_patches)),
            ],
            lang_pos_embed_config=[
                # Similarly, no initial pos embed for language
                ("lang", -config["dataset"]["tokenizer_max_length"]),
            ],
            dtype=weight_dtype,
        )


    ema_rdt = copy.deepcopy(rdt)
    ema_model = EMAModel(
        ema_rdt,
        update_after_step=config["model"]["ema"]["update_after_step"],
        inv_gamma=config["model"]["ema"]["inv_gamma"],
        power=config["model"]["ema"]["power"],
        min_value=config["model"]["ema"]["min_value"],
        max_value=config["model"]["ema"]["max_value"]
    )

    # Create a custom saving hook so that `accelerator.save_state(...)` serializes in a nice format,
    # which ensures the model is saved in HuggingFace format (config.json + pytorch_model.bin)
    def save_model_hook(models, weights, output_dir):
        if accelerator.is_main_process:
            for model in models:
                model_to_save = model.module if hasattr(model, "module") else model  # type: ignore
                if isinstance(model_to_save, type(accelerator.unwrap_model(rdt))):
                    model_to_save.save_pretrained(output_dir)

    accelerator.register_save_state_pre_hook(save_model_hook)

    if args.gradient_checkpointing:
        # TODO:
        raise NotImplementedError("Gradient checkpointing is not yet implemented.")

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    # Optimizer creation
    params_to_optimize = rdt.parameters()
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Dataset and DataLoaders creation:
    train_dataset = VLAConsumerDataset(
        config=config["dataset"],
        tokenizer=tokenizer,
        image_processor=image_processor,
        num_cameras=config["common"]["num_cameras"],
        img_history_size=config["common"]["img_history_size"],
        dataset_type=args.dataset_type,
        image_aug=args.image_aug,
        cond_mask_prob=args.cond_mask_prob,
        cam_ext_mask_prob=args.cam_ext_mask_prob,
        state_noise_snr=args.state_noise_snr,
        use_hdf5=args.load_from_hdf5,
        use_precomp_lang_embed=args.precomp_lang_embed,
        task_name=args.dataset_name,
    )
    sample_dataset = VLAConsumerDataset(
        config=config["dataset"],
        tokenizer=tokenizer,
        image_processor=image_processor,
        num_cameras=config["common"]["num_cameras"],
        img_history_size=config["common"]["img_history_size"],
        dataset_type=args.dataset_type,
        image_aug=False,
        cond_mask_prob=0,
        cam_ext_mask_prob=-1,
        state_noise_snr=None,
        use_hdf5=args.load_from_hdf5,
        use_precomp_lang_embed=args.precomp_lang_embed,
        task_name=args.dataset_name,
    )

    data_collator = DataCollatorForVLAConsumerDataset(tokenizer)

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        collate_fn=data_collator,
        num_workers=args.dataloader_num_workers,
        pin_memory=True,
        persistent_workers=True
    )
    sample_dataloader = torch.utils.data.DataLoader(
        sample_dataset,
        batch_size=args.sample_batch_size,
        shuffle=True,
        collate_fn=data_collator,
        num_workers=args.dataloader_num_workers,
        pin_memory=True,
        persistent_workers=True
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )

    # Prepare everything with our `accelerator`.
    rdt, optimizer, train_dataloader, sample_dataloader, lr_scheduler = accelerator.prepare(
        rdt, optimizer, train_dataloader, sample_dataloader, lr_scheduler
    )

    ema_rdt.to(accelerator.device, dtype=weight_dtype)

    if text_encoder is not None:
        text_encoder.to(accelerator.device, dtype=weight_dtype)

    if vision_encoder is not None:
        vision_encoder.vision_tower.to(accelerator.device, dtype=weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("roboticDiffusionTransformer", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Load from a pretrained checkpoint
    if (
        args.resume_from_checkpoint is None
        and args.pretrained_model_name_or_path is not None
        and os.path.isfile(args.pretrained_model_name_or_path)
    ):
        # Since EMA is deprecated, we do not load EMA from the pretrained checkpoint
        logger.info("Loading from a pretrained checkpoint.")
        checkpoint = torch.load(args.pretrained_model_name_or_path)
        rdt.module.load_state_dict(checkpoint["module"])

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            try:
                accelerator.load_state(os.path.join(args.output_dir, path))  # load_module_strict=False
            except:
                # load deepspeed's state_dict
                logger.info("Resuming training state failed. Attempting to only load from model checkpoint.")
                checkpoint = torch.load(os.path.join(args.output_dir, path, "pytorch_model", "mp_rank_00_model_states.pt"))
                rdt.module.load_state_dict(checkpoint["module"])

            load_model(ema_rdt, os.path.join(args.output_dir, path, "ema", "model.safetensors"))
            global_step = int(path.split("-")[1])

            resume_global_step = global_step * args.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    loss_for_log = {}
    for epoch in range(first_epoch, args.num_train_epochs):

        rdt.train()

        # Set the progress_bar to the correct position
        if args.resume_from_checkpoint and epoch == first_epoch:
            progress_bar.update(resume_step // args.gradient_accumulation_steps)

        # Forward and backward...
        for batch in train_dataloader:
            with accelerator.accumulate(rdt):
                images = batch["images"].to(dtype=weight_dtype)
                states = batch["states"].to(dtype=weight_dtype)  # (B, T, D_a)
                # We only use the last state as input
                states = states[:, -1:, :]
                actions = batch["actions"].to(dtype=weight_dtype)
                state_elem_mask = batch["state_elem_mask"].to(dtype=weight_dtype)
                ctrl_freqs = batch["ctrl_freqs"]

                with torch.no_grad():
                    batch_size, _, C, H, W = images.shape
                    image_embeds = vision_encoder(images.reshape(-1, C, H, W)).detach()
                    image_embeds = image_embeds.reshape((batch_size, -1, vision_encoder.hidden_size))

                    lang_attn_mask = batch["lang_attn_mask"]
                    text_embeds = batch["lang_embeds"].to(dtype=weight_dtype) \
                        if args.precomp_lang_embed \
                        else text_encoder(
                            input_ids=batch["input_ids"],
                            attention_mask=lang_attn_mask
                        )["last_hidden_state"].detach()

                state_elem_mask = state_elem_mask.unsqueeze(1)
                loss = rdt(
                    lang_tokens=text_embeds,
                    lang_attn_mask=lang_attn_mask,
                    img_tokens=image_embeds,
                    state_tokens=states,
                    action_gt=actions,
                    action_mask=state_elem_mask,
                    ctrl_freqs=ctrl_freqs
                )

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = rdt.parameters()
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=args.set_grads_to_none)

            ema_model.step(accelerator.unwrap_model(rdt))

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if global_step % args.checkpointing_period == 0:
                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                    accelerator.save_state(save_path)
                    ema_save_path = os.path.join(save_path, "ema")
                    accelerator.save_model(ema_rdt, ema_save_path)
                    logger.info(f"Saved state to {save_path}")

                if args.sample_period > 0 and global_step % args.sample_period == 0:
                    sample_loss_for_log = log_sample_res(
                        text_encoder,
                        vision_encoder,
                        rdt,  # We do not use EMA currently
                        args,
                        accelerator,
                        weight_dtype,
                        sample_dataset.get_dataset_id2name(),
                        sample_dataloader,
                        logger,
                    )
                    logger.info(sample_loss_for_log)
                    accelerator.log(sample_loss_for_log, step=global_step)

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            logs.update(loss_for_log)
            # logger.info(logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        accelerator.unwrap_model(rdt).save_pretrained(args.output_dir)
        ema_save_path = os.path.join(args.output_dir, "ema")
        accelerator.save_model(ema_rdt, ema_save_path)

        logger.info(f"Saved Model to {args.output_dir}")

        if args.push_to_hub:
            save_model_card(
                repo_id,
                base_model=args.pretrained_model_name_or_path,
                repo_folder=args.output_dir,
            )
            upload_folder(
                repo_id=repo_id,
                folder_path=args.output_dir,
                commit_message="End of training",
                token=args.hub_token,
                allow_patterns=["pytorch_model.bin", "*.json", "*.md"],
                # ignore_patterns=["step_*", "epoch_*"],
            )

    accelerator.end_training()
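One detail in train() that is easy to misread is the resume bookkeeping after accelerator.load_state: global_step counts optimizer steps, while the inner dataloader loop counts micro-batches. A small worked example of the same arithmetic with made-up numbers (not taken from any real run):

# Hypothetical run: 1000 optimizer steps per epoch, gradient_accumulation_steps = 4,
# resuming from "checkpoint-2500", i.e. global_step = 2500.
num_update_steps_per_epoch = 1000
gradient_accumulation_steps = 4
global_step = 2500

resume_global_step = global_step * gradient_accumulation_steps        # 10000 micro-batches already seen
first_epoch = global_step // num_update_steps_per_epoch               # resume inside epoch 2
resume_step = resume_global_step % (num_update_steps_per_epoch
                                    * gradient_accumulation_steps)    # 2000 micro-batches into epoch 2

# The progress bar is then advanced by resume_step // gradient_accumulation_steps,
# i.e. the 500 optimizer steps already completed within epoch 2.
assert resume_step // gradient_accumulation_steps == 500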