# Pipeline of Pre-Training RDT
First, install the prerequisites for RDT (see the [README](../README.md#installation)). Then, install the prerequisites for the TensorFlow Dataset pipeline in a separate Conda environment.
## Installation for TensorFlow Dataset
```bash
# Under the root directory of this repo
conda create -n rdt-data python=3.10
conda activate rdt-data
# Install all the prerequisites
pip install -r requirements_data.txt
# Or you can manually install each package (please refer to requirements_data.txt for specific versions)
pip install tfds-nightly gsutil tensorflow Pillow pyyaml opencv-python tensorflow-graphics imageio[ffmpeg]
# If the download is too slow, you can specify an alternative index (note that tfds-nightly is not available in the Tsinghua mirror)
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple gsutil tensorflow Pillow pyyaml opencv-python tensorflow-graphics imageio[ffmpeg]
```
## Download and Prepare Datasets
Below we describe how to download each of our pre-training datasets. If you plan to pre-train on a subset of them, just download the ones you need. You can also fine-tune RDT through this pipeline, but only if your target dataset is listed below or available in Google Cloud Storage.
| Dataset | Sample Percentage (%) |
| ---- | ---- |
| RT-1 Dataset | 9.00 |
| TACO Dataset | 1.99 |
| JACO Play Dataset | 1.10 |
| Cable Routing Dataset | 0.27 |
| NYU Door Opening | 0.33 |
| Viola | 0.40 |
| Berkeley UR5 | 1.06 |
| TOTO | 1.06 |
| Kuka | 1.66 |
| Language Table | 3.32 |
| Columbia Cairlab Pusht Real | 0.40 |
| Stanford Kuka Multimodal Dataset | 1.83 |
| Stanford Hydra Dataset | 0.80 |
| Austin Buds Dataset | 0.23 |
| Maniskill Dataset | 5.78 |
| Furniture Bench Dataset | 2.36 |
| UCSD Kitchen Dataset | 0.40 |
| UCSD Pick And Place Dataset | 1.23 |
| Austin Sailor Dataset | 0.50 |
| Austin Sirius Dataset | 0.80 |
| BC Z | 6.91 |
| UTokyo PR2 Opening Fridge | 0.30 |
| UTokyo PR2 Tabletop Manipulation | 0.50 |
| UTokyo Xarm Pick And Place | 0.33 |
| UTokyo Xarm Bimanual | 0.03 |
| Berkeley MVP | 0.73 |
| Berkeley RPT | 1.00 |
| KAIST Nonprehensile | 0.46 |
| Tokyo U LSMO | 0.23 |
| DLR Sara Grid Clamp | 0.03 |
| Robocook | 1.66 |
| Imperialcollege Sawyer Wrist Cam | 0.43 |
| Iamlab CMU Pickup Insert | 0.83 |
| UTAustin Mutex | 1.29 |
| Fanuc Manipulation | 0.66 |
| Play Fusion | 0.80 |
| Droid | 10.06 |
| FMB | 1.39 |
| Dobb·E | 1.20 |
| QUT Dexterous Manipulation | 0.46 |
| Aloha Dataset | 4.98 |
| Mobile Aloha Dataset | 4.98 |
| Roboset | 4.48 |
| RH20T | 10.99 |
| Calvin Dataset | 3.32 |
| Bridgev2 | 7.44 |
Before anything else, link the dataset directory on your disk to a subfolder of this repo:
```bash
ln -s /path/to/dataset /path/to/repo/RoboticsDiffusionTransformer/data/datasets
```
### Open X-Embodiment
Specify the correct path to the `gsutil` binary of your Conda environment in [this file](../data/openx_embod/download.sh#L72).
Run the following commands to download our selected datasets from the Open X-Embodiment:
```bash
# Under the root directory of this repo
cd data/openx_embod
# Download all datasets
bash download_openx_embod.sh
```
Note: By modifying `download_openx_embod.sh`, you can download any dataset hosted on Google Cloud Storage (as long as it can be downloaded with `gsutil` and is stored in the `TFRecord` format), not just the ones we have listed.
### Mobile ALOHA Dataset
Download the Mobile ALOHA Dataset from the [official website](https://mobile-aloha.github.io) to `data/datasets/aloha`, then run:
```bash
cd data/aloha
# Convert the dataset to TFRecord
python hdf5totfrecords.py
```
### Bridgev2
Run:
```bash
cd data/bridgev2
# Download and preprocess the dataset
sh download.sh
```
### Calvin
Run:
```bash
cd data/calvin
# Download and preprocess the dataset
sh download.sh
# Convert the dataset to TFRecord format
python hdf5totfrecords.py
```
### RH20T
Download the RH20T Dataset from the [official website](https://rh20t.github.io/#download) to `data/datasets/rh20t`, then run:
```bash
cd data/rh20t
# Convert the dataset to TFRecord
python hdf5totfrecords.py
```
### RoboSet
Run:
```bash
cd data/roboset
# Download and preprocess the dataset
sh download.sh
```
## If You Want to Train on a New Dataset
If you want to train on a new dataset (e.g., `my_pretrain_dataset`) through this pre-training pipeline, you need to modify several files as follows:
##### 1. `configs/dataset_control_freq.json`
Add the control frequency (in Hz) of your dataset, e.g., an entry like `"my_pretrain_dataset": 25` for a robot controlled at 25 Hz.
##### 2. `data/preprocess_scripts/my_pretrain_dataset.py`
If your dataset can be loaded by `tfds.builder_from_directory()`, you only need to download it into the Open X-Embodiment folder `data/datasets/openx_embod` and implement the `process_step()` function. You may also need to specify the tfds loading path at L78 of [this file](../data/vla_dataset.py#L78). We refer to `data/preprocess_scripts/droid.py` for an example.
If not, you need to first convert it into TFRecords and then implement both `load_dataset()` and `process_step()`. We refer to `data/agilex/hdf5totfrecords.py` and `data/preprocess_scripts/agilex.py` for examples.
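For the first case, loading via `tfds.builder_from_directory()` roughly looks like the following sketch (the builder directory and version below are placeholders, and the actual loading path used by the pipeline is configured in `data/vla_dataset.py`):
```python
import tensorflow_datasets as tfds

# Hypothetical builder directory; it must contain the TFDS metadata and TFRecord shards
builder = tfds.builder_from_directory(
    "data/datasets/openx_embod/my_pretrain_dataset/1.0.0")
dataset = builder.as_dataset(split="train")

# Open X-Embodiment datasets follow the RLDS layout: one episode per element,
# each with a nested `steps` dataset of frames
for episode in dataset.take(1):
    for step in episode["steps"]:
        # The image keys among these are what you list in configs/dataset_img_keys.json
        print(step["observation"].keys())
```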
Here are descriptions of the two functions:
##### `load_dataset(seed: int)`
- Returns a dataset that supports iteration and the `repeat` method, with randomness controlled by the given `seed`.
- Suggested implementation: use `tf.data.Dataset.from_generator` and `tf.data.TFRecordDataset` (see the sketch below).
- Each element yielded by the iterator represents one episode and should have the following structure:
  - `step`: an iterable dataset object containing the frames of the episode. Each frame contains:
    - `observation`: a dictionary containing your images:
      - `your_first_image_key`: one of your RGB observation images.
      - ...
    - `other_attribute`: any other relevant attributes.
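As a rough, hypothetical illustration (the file layout, feature schema, and key names below are made up; see `data/preprocess_scripts/agilex.py` for the real implementation), a `load_dataset()` with this shape could be built as follows:
```python
import tensorflow as tf

# Hypothetical per-frame schema; it must match whatever your TFRecord converter writes
_FEATURES = {
    "exterior-cam": tf.io.FixedLenFeature([], tf.string),  # JPEG-encoded RGB image
    "qpos": tf.io.FixedLenFeature([7], tf.float32),        # joint positions
    "terminate": tf.io.FixedLenFeature([], tf.int64),      # 1 on the last frame
}

def _parse_frame(raw_record):
    frame = tf.io.parse_single_example(raw_record, _FEATURES)
    return {
        "observation": {"exterior-cam": frame["exterior-cam"]},
        "qpos": frame["qpos"],
        "terminate": frame["terminate"],
    }

def load_dataset(seed: int):
    # Assume one TFRecord file per episode
    episode_files = tf.data.Dataset.list_files(
        "data/datasets/my_pretrain_dataset/tfrecords/*.tfrecord",
        shuffle=True, seed=seed)
    # Each element is one episode: a dict whose `step` field is a nested dataset
    # of parsed frames (tf.data supports datasets as element components)
    return episode_files.map(
        lambda path: {"step": tf.data.TFRecordDataset(path).map(_parse_frame)})
```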
##### `process_step(step: dict) -> dict`
Processes a single frame and returns a dictionary with the following keys:
- `observation`:
- `your_first_view_image: tf.Tensor`: Your first view image.
- `arm_concat: tf.Tensor`: Concatenation of physical states.
- `format: tf.constant(string)`: Format of `arm_concat` (e.g., `arm_joint_pos_0,arm_joint_pos_1,arm_joint_pos_2`).
- `action`: Frame action (leave empty if there's none).
- `arm_concat`: Same as in `observation`.
- `format`: Same as in `observation`.
- `terminate: tf.Tensor`: a boolean tensor indicating whether the episode ends.
**IMPORTANT**: You should only use TensorFlow functions for any branch or loop operations. For example, use `tf.cond` instead of `if`.
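Continuing the hypothetical schema from the sketch above, a matching `process_step()` might look like this (the image key, the 7-DoF joint vector, and the format string are assumptions for illustration; note how branching goes through `tf.cond`):
```python
import tensorflow as tf

# Hypothetical format string for a 7-DoF arm, following the naming shown above
_FORMAT_STR = ",".join(f"arm_joint_pos_{i}" for i in range(7))

def process_step(step: dict) -> dict:
    # Decode the JPEG bytes stored by the (hypothetical) TFRecord converter
    image = tf.io.decode_jpeg(step["observation"]["exterior-cam"], channels=3)

    # Concatenate all available physical states into one vector (here: joint positions only)
    arm_concat = step["qpos"]

    # Use tf.cond instead of a Python `if` for data-dependent branching
    terminate = tf.cond(
        tf.equal(step["terminate"], 1),
        lambda: tf.constant(True),
        lambda: tf.constant(False),
    )

    return {
        "observation": {
            "exterior-cam": image,
            "arm_concat": arm_concat,
            "format": tf.constant(_FORMAT_STR),
        },
        "action": {
            # Illustrative placeholder: reuse the state here, or leave it empty
            # if your dataset has no explicit actions
            "arm_concat": arm_concat,
            "format": tf.constant(_FORMAT_STR),
        },
        "terminate": terminate,
    }
```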
##### 3. `configs/dataset_img_keys.json`
Add the image keys of your dataset. For example:
```json
"my_pretrain_dataset": {
"image_keys": [
"exterior-cam",
"right-wrist-cam",
"left-wrist-cam",
"left-wrist-cam"
],
"image_mask": [1, 1, 1, 0]
}
```
- To make TensorFlow happy, you have to specify four images in this order: `exterior-cam, right-wrist-cam, left-wrist-cam, any-cam`. Each key should correspond to an observation image key in your `step` attributes.
- If you only have a single wrist, just make it a *right* wrist.
- The `image_mask` indicates whether each image is valid (1) or not (0).
- What if you don't have four images? Simply fill the remaining positions by repeating one of your existing keys and set their masks to 0 (invalid).
- The key order is *strict*. For example, if you don't have an exterior camera but have both wrists, pad the exterior position with a repeated wrist key (masked out) and use the following:
```json
"my_pretrain_dataset": {
"image_keys": [
"right-wrist-cam",
"right-wrist-cam",
"left-wrist-cam",
"left-wrist-cam"
],
"image_mask": [0, 1, 1, 0]
}
```
- During training, only the first *three* cameras will be used.
##### 4. `configs/dataset_stat.json`
Compute the statistics (min, max, mean, and std) for your dataset:
```bash
# Use -h to see the full usage
python -m data.compute_dataset_stat --skip_exist
```
This will update the `dataset_stat.json` file with your dataset's statistics.
##### 5. `data/vla_dataset.py`
- Add your dataset to `DATASET_NAMES_NOOPENX` if it cannot be loaded by `tfds.builder_from_directory()`.
- If your dataset only contains action but no proprioception (i.e., robot state), add your dataset to `DATASET_NAMES_NO_STATE` in [this file](../data/preprocess.py).
- Normally, we consider the future state as the action of the current timestep. If you want to use different actions, you will need to implement additional functions. We refer to `flatten_episode_agilex()` in [this file](../data/episode_transform.py) and `_generate_json_state_agilex()` in [this file](../data/preprocess.py) for examples. You may also refer to L318 in [this file](../data/preprocess.py) and L128 in [this file](../data/vla_dataset.py) for how to select your dataset and preprocess it differently.
## Start Pre-Training
We employ a producer-consumer framework built on TensorFlow Dataset for fast data loading. Since most of the datasets in the Open X-Embodiment are stored as `TFRecord` files, we convert all pre-training datasets into `TFRecord` for storage. During pre-training, the producer process decompresses the data from `TFRecord` and stores it in a buffer on the hard disk, while the consumer process reads data from the buffer in random order and feeds it to model training. This not only decouples the `TensorFlow` and `PyTorch` environments but also alleviates the performance loss caused by a small in-memory shuffling buffer.
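The sketch below is not the repository's actual producer/consumer code; it only illustrates the idea of decoupling the two sides through a fixed-size buffer of sample files on disk (the path and slot count are made up):
```python
import os
import pickle
import random

BUF_PATH = "/path/to/buffer"   # plays the role of `buf_path` in configs/base.yaml
NUM_SLOTS = 100_000            # made-up number of on-disk buffer slots

def producer(tf_dataset):
    """TensorFlow side: decode samples from TFRecord and keep refreshing the disk buffer."""
    for i, sample in enumerate(tf_dataset.as_numpy_iterator()):
        slot = os.path.join(BUF_PATH, f"sample_{i % NUM_SLOTS}.pkl")
        with open(slot, "wb") as f:
            pickle.dump(sample, f)  # overwrite slots in round-robin fashion

def consume_batch(batch_size):
    """PyTorch side: read a random subset of slots, i.e., shuffle on disk rather than in memory."""
    slots = random.sample(os.listdir(BUF_PATH), batch_size)
    return [pickle.load(open(os.path.join(BUF_PATH, s), "rb")) for s in slots]
```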
[This file](../configs/base.yaml) includes configurations relevant to the model architecture (such as the number of heads and the hidden dimension) and to data processing. You may need to modify `buf_path` (L22) to your real buffer path; this buffer serves as the on-disk shuffling buffer for data loading.
Configurations relevant to training are passed through *command-line arguments*. Use `python main.py -h` to see their descriptions. We provide an example pre-training script in [this file](../pretrain.sh) (`pretrain.sh`). You may need to modify some of the parameters in this file, such as `CUTLASS_PATH` and `WANDB_PROJECT`.
You may need to modify the list of pre-training datasets in [this file](../configs/pretrain_datasets.json) and their corresponding sampling weights in [this file](../configs/pretrain_sample_weights.json). If you want to fine-tune RDT through this pipeline, you may need to remove the datasets you don't need from the list.
Before starting pre-training, first launch the data producer process (if you use multiple nodes, run this command on each node):
```bash
# Under the root directory of this repo
conda activate rdt-data
# Use -h to see the full usage
python -m data.producer --fill_up
# Please proceed to the next step only AFTER the fill-up process has finished
```
Then, we run the pre-training script:
```bash
source pretrain.sh
```
Note: You can monitor the training process by observing `loss` (smoothed with a long-window moving average), `overall_avg_sample_mse`, and the sampling MSE of each dataset in [Wandb](https://wandb.ai/site) or [TensorBoard](https://www.tensorflow.org/tensorboard). We have empirically found that the lower the `overall_avg_sample_mse`, the better the model performs.