euijinrnd's picture
Add files using upload-large-folder tool
d899b9f verified
# Pipeline of Pre-Training RDT
Firstly, you need to install the prerequisites for RDT (see [README](../README.md#installation)). Then, you can install the prerequisites for TensorFlow Dataset (in another Conda environment).
## Installation for TensorFlow Dataset
```bash
# Under the root directory of this repo
conda create -n rdt-data python=3.10
conda activate rdt-data
# Install all the prequisites
pip install -r requirements_data.txt
# Or you can manually install each package (please refer to requirements_data.txt for specific versions)
pip install tfds-nightly gsutil tensorflow Pillow pyyaml opencv-python tensorflow-graphics imageio[ffmpeg]
# If the speed is too slow, you can specify alternative sources (tfds-nightly is not available in Tsinghua mirror)
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple gsutil tensorflow Pillow pyyaml opencv-python tensorflow-graphics imageio[ffmpeg]
```
## Download and Prepare Datasets
We introduce how to download each of our pre-training datasets. If you plan to pre-train on a subset of them, just download the ones you need. You can also fine-tune RDT through this pipeline only if your target dataset is included below or in the Google Cloud Storage.
| Dataset | Sample Percentage (%) |
| ---- | ---- |
| RT-1 Dataset | 9.00 |
| TACO Dataset | 1.99 |
| JACO Play Dataset | 1.10 |
| Cable Routing Dataset | 0.27 |
| NYU Door Opening | 0.33 |
| Viola | 0.40 |
| Berkeley UR5 | 1.06 |
| TOTO | 1.06 |
| Kuka | 1.66 |
| Language Table | 3.32 |
| Columbia Cairlab Pusht Real | 0.40 |
| Stanford Kuka Multimodal Dataset | 1.83 |
| Stanford Hydra Dataset | 0.80 |
| Austin Buds Dataset | 0.23 |
| Maniskill Dataset | 5.78 |
| Furniture Bench Dataset | 2.36 |
| UCSD Kitchen Dataset | 0.40 |
| UCSD Pick And Place Dataset | 1.23 |
| Austin Sailor Dataset | 0.50 |
| Austin Sirius Dataset | 0.80 |
| BC Z | 6.91 |
| UTokyo PR2 Opening Fridge | 0.30 |
| UTokyo PR2 Tabletop Manipulation | 0.50 |
| UTokyo Xarm Pick And Place | 0.33 |
| UTokyo Xarm Bimanual | 0.03 |
| Berkeley MVP | 0.73 |
| Berkeley RPT | 1.00 |
| KAIST Nonprehensile | 0.46 |
| Tokyo U LSMO | 0.23 |
| DLR Sara Grid Clamp | 0.03 |
| Robocook | 1.66 |
| Imperialcollege Sawyer Wrist Cam | 0.43 |
| Iamlab CMU Pickup Insert | 0.83 |
| UTAustin Mutex | 1.29 |
| Fanuc Manipulation | 0.66 |
| Play Fusion | 0.80 |
| Droid | 10.06 |
| FMB| 1.39 |
| Dobb·E | 1.20 |
| QUT Dexterous Manipulation | 0.46 |
| Aloha Dataset | 4.98 |
| Mobile Aloha Dataset | 4.98 |
| Roboset | 4.48 |
| RH20T | 10.99 |
| Calvin Dataset | 3.32 |
| Bridgev2 | 7.44 |
Before everything, let's link the dataset directory on your disk to a subfolder of this repo:
```bash
ln -s /path/to/dataset /path/to/repo/RoboticsDiffusionTransformer/data/datasets
```
### Open X-Embodiment
Specify the correct path to the `gsutil` in your Conda in [this file](../data/openx_embod/download.sh#L72).
Run the following commands to download our selected datasets for the Open X-Embodiment:
```bash
# Under the root directory of this repo
cd data/openx_embod
# Download all datasets
bash download_openx_embod.sh
```
Note: By modifying `download_openx_embod.sh`, you can download any dataset on the Google Cloud (as long as it can be downloaded with `gsutil` and is stored in `TFRecord` format), not just the ones we have listed.
### Mobile ALOHA Dataset
Download the Mobile ALOHA Dataset from the [official website](https://mobile-aloha.github.io) to `data/datasets/aloha`, then run:
```bash
cd data/aloha
# Convert the dataset to TFRecord
python hdf5totfrecords.py
```
### Bridgev2
Run:
```bash
cd data/bridgev2
# Download and preprocess the dataset
sh download.sh
```
### Calvin
Run:
```bash
cd data/calvin
# Download and preprocess the dataset
sh download.sh
# Convert the dataset to TFRecord format
python hdf5totfrecords.py
```
### RH20T
Download the RH20T Dataset from there [official website](https://rh20t.github.io/#download) to `data/datasets/rh20t`, then run
```bash
cd data/rh20t
# Convert the dataset to TFRecord
python hdf5totfrecords.py
```
### RoboSet
Run:
```bash
cd data/roboset
# Download and preprocess the dataset
sh download.sh
```
## If Want to Train on a New Dataset
If you want to train on a new dataset (e.g., `my_pretrain_dataset`) through this pre-training pipeline, you need to modify several files as follows:
##### 1. `configs/dataset_control_freq.json`
Add the control frequency of your dataset.
##### 2. `data/preprocess_scripts/my_pretrain_dataset.py`
If your dataset can be loaded by `tfds.builder_from_directory()`, then you only need to download it into the folder of Open X-Embodiment `data/datasets/openx_embod` and implement the function of `process_step()`. You may need to specify the tfds loading path in L78 (see [this file](../data/vla_dataset.py#L78)). We refer to `data/preprocess_scripts/droid.py` for an example.
If not, you need to first convert it into TFRecords and then implement both `load_dataset()` and `process_step()`. We refer to `data/agilex/hdf5totfrecords.py` and `data/preprocess_scripts/agilex.py` for examples.
Here some descriptions:
##### `load_dataset(seed: int)`
- Returns a dataset that supports iterator and `repeat` method with a random seed.
- Suggested implementation: Use `tf.data.Dataset.from_generator` and `tf.data.TFRecordDataset`.
- The iterator should return a subdataset that supports iterator representing one episode with the following structure:
- `step`: A dataset object that supports iterator containing multiple frames per episode.
- `observation`: A dictionary containing your images.
- `your_first_image_key`: Your observation RGB image keys.
- ...
- `other_attribute`: Any other relevant attributes.
##### `process_step(step: dict) -> dict`
Processes a single frame and returns a dictionary with the following keys:
- `observation`:
- `your_first_view_image: tf.Tensor`: Your first view image.
- `arm_concat: tf.Tensor`: Concatenation of physical states.
- `format: tf.constant(string)`: Format of `arm_concat` (e.g., `arm_joint_pos_0,arm_joint_pos_1,arm_joint_pos_2`).
- `action`: Frame action (leave empty if there's none).
- `arm_concat`: Same as in `observation`.
- `format`: Same as in `observation`.
- `terminate: tf.Tensor`: Boolean Tensor indicates if the episode ends.
**IMPORTANT**: You should only use TensorFlow functions for any branch or loop operations. For example, use `tf.cond` instead of `if`.
##### 3. `configs/dataset_img_keys.json`
Add the image keys of your dataset. For example:
```json
"my_pretrain_dataset": {
"image_keys": [
"exterior-cam",
"right-wrist-cam",
"left-wrist-cam",
"left-wrist-cam"
],
"image_mask": [1, 1, 1, 0]
}
```
- To make TensorFlow happy, you have to specify four images in this order: `exterior-cam, right-wrist-cam, left-wrist-cam, any-cam`. Each key should correspond to your `step` attribute key of observation images.
- If you only have a single wrist, just make it a *right* wrist.
- The `image_mask` indicates whether each image is valid (1) or not (0).
- What if you don’t have four images? Simply repeat the images in the following positions and set their masks to 0 (invalid).
- The key order is *strict*. If you don't have the exterior camera but have both wrists, leave the exterior position blank (or pad) and use the following:
```json
"my_pretrain_dataset": {
"image_keys": [
"right-wrist-cam",
"right-wrist-cam",
"left-wrist-cam",
"left-wrist-cam"
],
"image_mask": [0, 1, 1, 0]
}
```
- During training, only the first *three* cameras will be used.
##### 4. `configs/dataset_stat.json`
Compute the statistics (min, max, mean, and std) for your dataset:
```bash
# Use -h to see the full usage
python -m data.compute_dataset_stat --skip_exist
```
This will update the `dataset_stat.json` file with your dataset's statistics.
##### 5. `data/vla_dataset.py`
- Add your dataset to `DATASET_NAMES_NOOPENX` if it cannot be loaded by `tfds.builder_from_directory()`.
- If your dataset only contains action but no proprioception (i.e., robot state), add your dataset to `DATASET_NAMES_NO_STATE` in [this file](../data/preprocess.py).
- Normally, we consider the future state as the action of current timestep. If you want to use different actions, you should implement more functions. We refer to `flatten_episode_agilex()` in [this file](../data/episode_transform.py) and `_generate_json_state_agilex()` in [this file](../data/preprocess.py) for examples. You may also refer to L318 in [this file](../data/preprocess.py) and L128 in [this file](../data/vla_dataset.py) for how to select your dataset and preprocess it differently.
## Start Pre-Training
We employ a producer-consumer framework with TensorFlow Dataset for fast data loading. Since most of the datasets in the Open X-Embodiment are stored in the form of `TFRecord`, we convert all pre-training datasets into `TFRecord` for storage. In pre-training, we use the producer process to decompress the data from `TFRecord` and store it in a buffer on the hard disk. At the same time, we use the consumer process to read data from the buffer in a disorderly order and feed it to the model training. This not only decouples the `TensorFlow` and `PyTorch` environments but also alleviates the training performance loss caused by the small size of the shuffling buffer in the memory.
[This file](../configs/base.yaml) includes configurations relevant to model architecture (including number of heads, hidden dimension, and so on) and data processing. You may need to modify `buf_path` (L22) to your real buffer path. This buffer is used as disk shuffling buffer for data loading.
Configurations relevant to training are passed through *Command Line Arguments*. Use `python main.py -h ` to see the descriptions. We provide an example pre-training script in [this file](../pretrain.sh) (`pretrain.sh`). You may need to modify some of the parameters in this file, such as `CUTLASS_PATH` and `WANDB_PROJECT`.
You may need to modify the list of pre-training datasets in [this file](../configs/pretrain_datasets.json) and their corresponding sampling weights in [this file](../configs/pretrain_sample_weights.json). If you want to fine-tune RDT through this pipeline, you may need to remove abundant datasets in the list.
Before start pre-training, we first start the data producer process (if you use multiple nodes, you should run this command in each node):
```bash
# Under the root directory of this repo
conda activate rdt-data
# Use -h to see the full usage
python -m data.producer --fill_up
# Please proceed to the next step AFTER finishing the filling up process
```
Then, we run the pre-training script:
```bash
source pretrain.sh
```
Note: You can monitor the training process by observing `loss` (through a long window moving average), `overall_avg_sample_mse`, and the sampling MSE of each dataset in [Wandb](https://wandb.ai/site) or [TensorBoard](https://www.tensorflow.org/tensorboard). We empirically found that the lower the `overall_avg_sample_mse`, the better the model performs.