Upload diffusion policy checkpoint at step 100000
Browse files
README.md
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Diffusion Policy Checkpoint - Step 100000
|
2 |
+
|
3 |
+
这是第100000步的diffusion policy训练checkpoint,用于机器人视觉运动控制。
|
4 |
+
|
5 |
+
## 模型信息
|
6 |
+
|
7 |
+
- **模型类型**: Diffusion Policy
|
8 |
+
- **训练步数**: 100,000
|
9 |
+
- **视觉骨干网络**: resnet18
|
10 |
+
- **输入观察步数**: 2
|
11 |
+
- **动作步数**: 8
|
12 |
+
- **时间范围**: 16
|
13 |
+
- **扩散步数**: 100
|
14 |
+
|
15 |
+
## 输入特征
|
16 |
+
|
17 |
+
- **状态观察**: [14] 维
|
18 |
+
- **头部相机图像**: [3, 256, 256] (RGB, 256x256)
|
19 |
+
|
20 |
+
## 输出特征
|
21 |
+
|
22 |
+
- **动作**: [16] 维
|
23 |
+
|
24 |
+
## 训练配置
|
25 |
+
|
26 |
+
- **批次大小**: 16
|
27 |
+
- **学习率**: 0.0001
|
28 |
+
- **优化器**: adam
|
29 |
+
- **数据集**: /home/shuo/research/lerobot/data/lerobot_dataset/rainbow_real
|
30 |
+
|
31 |
+
## 文件结构
|
32 |
+
|
33 |
+
```
|
34 |
+
100000/
|
35 |
+
├── pretrained_model/
|
36 |
+
│ ├── model.safetensors # 模型权重 (~1005MB)
|
37 |
+
│ ├── config.json # 模型配置
|
38 |
+
│ └── train_config.json # 训练配置
|
39 |
+
└── training_state/
|
40 |
+
├── optimizer_state.safetensors # 优化器状态 (~2GB)
|
41 |
+
├── scheduler_state.json # 学习率调度器状态
|
42 |
+
├── optimizer_param_groups.json # 优化器参数组
|
43 |
+
├── rng_state.safetensors # 随机数生成器状态
|
44 |
+
└── training_step.json # 训练步数信息
|
45 |
+
```
|
46 |
+
|
47 |
+
## 使用方法
|
48 |
+
|
49 |
+
### 加载模型
|
50 |
+
|
51 |
+
```python
|
52 |
+
import torch
|
53 |
+
from lerobot.common.policies.diffusion import DiffusionPolicy
|
54 |
+
|
55 |
+
# 加载配置
|
56 |
+
config_path = "pretrained_model/config.json"
|
57 |
+
with open(config_path, 'r') as f:
|
58 |
+
config = json.load(f)
|
59 |
+
|
60 |
+
# 创建模型
|
61 |
+
policy = DiffusionPolicy(config)
|
62 |
+
|
63 |
+
# 加载权重
|
64 |
+
checkpoint = torch.load("pretrained_model/model.safetensors", map_location='cpu')
|
65 |
+
policy.load_state_dict(checkpoint)
|
66 |
+
```
|
67 |
+
|
68 |
+
### 恢复训练
|
69 |
+
|
70 |
+
```python
|
71 |
+
# 加载训练状态
|
72 |
+
training_state = torch.load("training_state/optimizer_state.safetensors")
|
73 |
+
optimizer.load_state_dict(training_state)
|
74 |
+
|
75 |
+
# 继续训练
|
76 |
+
# ... 训练代码
|
77 |
+
```
|
78 |
+
|
79 |
+
## 数据集
|
80 |
+
|
81 |
+
此模型在rainbow_real数据集上训练,包含机器人操作任务。
|
82 |
+
|
83 |
+
## 许可证
|
84 |
+
|
85 |
+
请参考原始项目的许可证。
|