Upload diffusion policy checkpoint at step 100000
README.md
ADDED
@@ -0,0 +1,85 @@
# Diffusion Policy Checkpoint - Step 100000

This is a Diffusion Policy training checkpoint saved at step 100,000, intended for robot visuomotor control.

## Model Information

- **Model type**: Diffusion Policy
- **Training steps**: 100,000
- **Vision backbone**: resnet18
- **Input observation steps**: 2
- **Action steps**: 8
- **Prediction horizon**: 16
- **Diffusion steps**: 100

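To make the relationship between the three step counts concrete, here is a minimal sketch of the receding-horizon scheme commonly used with Diffusion Policy: the model predicts a sequence of `horizon` actions, and a window of `n_action_steps` of them is executed before re-planning. The function name `executed_action_window` is hypothetical, and the offset convention (the window starts at the action aligned with the latest observation) is an assumption that may differ in the implementation that produced this checkpoint.

```python
def executed_action_window(n_obs_steps: int, n_action_steps: int, horizon: int) -> tuple[int, int]:
    """Return the [start, end) indices of the predicted actions that get executed.

    Assumed convention: the predicted sequence begins at the first observation
    step, so the action aligned with the latest observation sits at index
    n_obs_steps - 1.
    """
    start = n_obs_steps - 1
    end = start + n_action_steps
    if end > horizon:
        raise ValueError("n_action_steps does not fit inside the prediction horizon")
    return start, end


# Values from the model information list: 2 observation steps, 8 action steps, horizon 16.
start, end = executed_action_window(n_obs_steps=2, n_action_steps=8, horizon=16)
print(start, end)  # 1 9
```

Under this assumption, 8 of the 16 predicted actions are executed per planning cycle and the rest are discarded.
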
## Input Features

- **State observation**: [14] dimensions
- **Head camera image**: [3, 256, 256] (RGB, 256x256)

## Output Features

- **Action**: [16] dimensions

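The feature shapes listed in this README are per time step. The sketch below shows how they would expand into batched tensors, assuming the common `(batch, time, *feature)` layout. The dictionary keys are hypothetical placeholders, not names confirmed by this checkpoint's `config.json`.

```python
# Per-step feature shapes listed in this README; the key names are hypothetical.
FEATURE_SHAPES = {
    "observation.state": (14,),
    "observation.images.head": (3, 256, 256),
    "action": (16,),
}


def batched_shape(key: str, batch_size: int, n_steps: int) -> tuple[int, ...]:
    """Shape of a (batch, time, *feature) tensor for the given feature."""
    return (batch_size, n_steps, *FEATURE_SHAPES[key])


# Batch size 16, two observation steps, 16-step action horizon.
print(batched_shape("observation.state", 16, 2))        # (16, 2, 14)
print(batched_shape("observation.images.head", 16, 2))  # (16, 2, 3, 256, 256)
print(batched_shape("action", 16, 16))                  # (16, 16, 16)
```
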
## Training Configuration

- **Batch size**: 16
- **Learning rate**: 0.0001
- **Optimizer**: adam
- **Dataset**: /home/shuo/research/lerobot/data/lerobot_dataset/rainbow_real

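Two rough numbers follow from this configuration. The sketch below computes how many samples the run has consumed and estimates the size of the Adam state. It assumes one batch per optimizer step (no gradient accumulation) and that Adam keeps two moment tensors per parameter in the same dtype as the weights. The ~1005 MB weight size is taken from the file listing in this README.

```python
def samples_seen(steps: int, batch_size: int) -> int:
    """Training samples consumed, assuming one batch per optimizer step."""
    return steps * batch_size


def adam_state_mb(weights_mb: float) -> float:
    """Approximate Adam state size: first and second moments, same dtype as the weights."""
    return 2 * weights_mb


print(samples_seen(100_000, 16))  # 1600000
print(adam_state_mb(1005))        # 2010
```

The second estimate is consistent with the roughly 2 GB optimizer state file shipped with this checkpoint.
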
## File Structure

```
100000/
├── pretrained_model/
│   ├── model.safetensors            # model weights (~1005MB)
│   ├── config.json                  # model configuration
│   └── train_config.json            # training configuration
└── training_state/
    ├── optimizer_state.safetensors  # optimizer state (~2GB)
    ├── scheduler_state.json         # learning rate scheduler state
    ├── optimizer_param_groups.json  # optimizer parameter groups
    ├── rng_state.safetensors        # random number generator state
    └── training_step.json           # training step information
```

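Since the checkpoint spans several large files, it can help to confirm that a download is complete before loading anything. This is a minimal sketch using only the standard library. The helper name `missing_files` is hypothetical, and the file list is copied from the directory tree in this README.

```python
from pathlib import Path

# Relative paths taken from the directory tree in this README.
EXPECTED_FILES = [
    "pretrained_model/model.safetensors",
    "pretrained_model/config.json",
    "pretrained_model/train_config.json",
    "training_state/optimizer_state.safetensors",
    "training_state/scheduler_state.json",
    "training_state/optimizer_param_groups.json",
    "training_state/rng_state.safetensors",
    "training_state/training_step.json",
]


def missing_files(checkpoint_dir: str) -> list[str]:
    """Return the expected files that are absent from checkpoint_dir."""
    root = Path(checkpoint_dir)
    return [rel for rel in EXPECTED_FILES if not (root / rel).is_file()]
```

For a complete download, `missing_files("100000")` should return an empty list.
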
## Usage

### Loading the Model

```python
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

# from_pretrained reads config.json and model.safetensors from the directory,
# so the config and weights do not need to be loaded by hand.
# (torch.load cannot read .safetensors files.)
# The import path above matches LeRobot releases that use the `lerobot.common`
# package layout; adjust it if your installed version differs.
policy = DiffusionPolicy.from_pretrained("pretrained_model")
policy.eval()  # switch to inference mode
```

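Before loading the weights, it may be worth checking that `config.json` agrees with the values stated in this README. The sketch below does this with the standard library only. The JSON key names in `EXPECTED` are assumptions about how the config is serialized, and `config_mismatches` is a hypothetical helper, not part of LeRobot.

```python
import json
from pathlib import Path

# Values stated in this README; the JSON key names are assumptions.
EXPECTED = {
    "n_obs_steps": 2,
    "n_action_steps": 8,
    "horizon": 16,
    "num_train_timesteps": 100,
    "vision_backbone": "resnet18",
}


def config_mismatches(config_path: str) -> dict[str, tuple[object, object]]:
    """Map each differing key to (expected, found); a missing key reports found=None."""
    config = json.loads(Path(config_path).read_text(encoding="utf-8"))
    return {
        key: (want, config.get(key))
        for key, want in EXPECTED.items()
        if config.get(key) != want
    }
```

An empty result from `config_mismatches("pretrained_model/config.json")` would indicate that the file matches the README under these naming assumptions.
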
### Resuming Training

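```python
from pathlib import Path

from lerobot.common.utils.train_utils import load_training_state

# `optimizer` and `lr_scheduler` must already be built for the same policy and config.
# The training state is split across several files (flattened tensors in
# safetensors, param groups in JSON), so torch.load on one file will not work.
# Assumes a LeRobot release that provides load_training_state; the module path may differ.
step, optimizer, lr_scheduler = load_training_state(Path("."), optimizer, lr_scheduler)

# Continue training from `step`
# ... training code
```
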
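When planning a resumed run, the step counter stored with the checkpoint gives the starting point. This is a minimal sketch that reads it with the standard library. It assumes `training_step.json` holds a JSON object with an integer `"step"` field, which this README does not confirm, and `remaining_steps` is a hypothetical helper.

```python
import json
from pathlib import Path


def remaining_steps(training_state_dir: str, target_steps: int) -> int:
    """Steps left until target_steps, based on training_step.json.

    Assumes the file holds a JSON object with an integer "step" field.
    """
    path = Path(training_state_dir) / "training_step.json"
    current = json.loads(path.read_text(encoding="utf-8"))["step"]
    return max(target_steps - current, 0)
```

For this checkpoint, a target of 200,000 steps would leave 100,000 steps to run if the stored step is 100,000.
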
## Dataset

This model was trained on the rainbow_real dataset, which contains robot manipulation tasks.

## License

Please refer to the license of the original project.