File size: 2,041 Bytes
8df2653 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
# Diffusion Policy Checkpoint - Step 100000
这是第100000步的diffusion policy训练checkpoint,用于机器人视觉运动控制。
## 模型信息
- **模型类型**: Diffusion Policy
- **训练步数**: 100,000
- **视觉骨干网络**: resnet18
- **输入观察步数**: 2
- **动作步数**: 8
- **时间范围**: 16
- **扩散步数**: 100
## 输入特征
- **状态观察**: [14] 维
- **头部相机图像**: [3, 256, 256] (RGB, 256x256)
## 输出特征
- **动作**: [16] 维
## 训练配置
- **批次大小**: 16
- **学习率**: 0.0001
- **优化器**: adam
- **数据集**: /home/shuo/research/lerobot/data/lerobot_dataset/rainbow_real
## 文件结构
```
100000/
├── pretrained_model/
│ ├── model.safetensors # 模型权重 (~1005MB)
│ ├── config.json # 模型配置
│ └── train_config.json # 训练配置
└── training_state/
├── optimizer_state.safetensors # 优化器状态 (~2GB)
├── scheduler_state.json # 学习率调度器状态
├── optimizer_param_groups.json # 优化器参数组
├── rng_state.safetensors # 随机数生成器状态
└── training_step.json # 训练步数信息
```
## 使用方法
### 加载模型
```python
import torch
from lerobot.common.policies.diffusion import DiffusionPolicy
# 加载配置
config_path = "pretrained_model/config.json"
with open(config_path, 'r') as f:
config = json.load(f)
# 创建模型
policy = DiffusionPolicy(config)
# 加载权重
checkpoint = torch.load("pretrained_model/model.safetensors", map_location='cpu')
policy.load_state_dict(checkpoint)
```
### 恢复训练
```python
# 加载训练状态
training_state = torch.load("training_state/optimizer_state.safetensors")
optimizer.load_state_dict(training_state)
# 继续训练
# ... 训练代码
```
## 数据集
此模型在rainbow_real数据集上训练,包含机器人操作任务。
## 许可证
请参考原始项目的许可证。
|