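"""RDT inference node for the Anubis bimanual robot.

Loads a fine-tuned Robotics Diffusion Transformer (RDT) checkpoint and runs
closed-loop rollouts through AnubisRobotEnv: the policy is queried every 64
control ticks with the previous and current frames from three cameras plus
end-effector proprioception, and the returned 64-step action chunk is
replayed one action per tick.
"""
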
import cv2
import torch
import yaml
from PIL import Image

from real_robot_env import AnubisRobotEnv
from scripts.agilex_model import create_model


class RDTInferenceRobotEnv(AnubisRobotEnv):
    def __init__(self, hz=20, max_timestep=500, task_name='', num_rollout=1, model_name='1st'):
        self.model_name = model_name
        self.checkpoint = f'/home/rllab/workspace/jellyho/RoboticsDiffusionTransformer/checkpoints/{task_name}'
        super().__init__(hz=hz, max_timestep=max_timestep, task_name=task_name, num_rollout=num_rollout)
        # Image-history buffers: the model is conditioned on the previous and
        # current frame from each of the three cameras.
        self.last_front_img = None
        self.last_right_wrist_img = None
        self.last_left_wrist_img = None
        # Cursor into the current 64-step action chunk (may also be managed by
        # AnubisRobotEnv; initialized here defensively).
        self.action_counter = 0

    def bringup_model(self):
        # Build the RDT policy from the base config, loading the fine-tuned
        # checkpoint alongside the frozen T5 text encoder and SigLIP vision
        # encoder.
        with open('configs/base.yaml', "r") as fp:
            config = yaml.safe_load(fp)
        self.model = create_model(
            args=config,
            dtype=torch.bfloat16,
            pretrained=self.checkpoint,
            pretrained_text_encoder_name_or_path="google/t5-v1_1-xxl",
            pretrained_vision_encoder_name_or_path="google/siglip-so400m-patch14-384",
            control_frequency=20,
        )
        print('model loaded')
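
    # The policy returns a 64-step action chunk; a fresh chunk is requested
    # every 64 control ticks, and intermediate ticks replay cached actions.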
    def inference(self):
        proprio = self.obs['eef_pose']
        if self.action_counter % 64 == 0:
            # Re-query the model: condition on the previous and current frame
            # from each camera (front, right wrist, left wrist).
            front_img = self.obs['agentview_image']
            right_wrist_img = self.obs['rightview_image']
            left_wrist_img = self.obs['leftview_image']
            image_arrs = [
                self.last_front_img,
                self.last_right_wrist_img,
                self.last_left_wrist_img,
                front_img,
                right_wrist_img,
                left_wrist_img,
            ]
            images = [Image.fromarray(arr) if arr is not None else None for arr in image_arrs]
            proprio = torch.tensor(proprio).unsqueeze(0)
            with torch.inference_mode():
                self.actions = self.model.step(
                    proprio=proprio,
                    images=images,
                    instruction=self.instruction,
                ).squeeze(0).cpu().numpy()
        # Replay one action from the cached chunk and refresh the
        # image-history buffers.
        idx = self.action_counter % 64
        act = self.actions[idx]
        self.action_counter += 1
        self.last_front_img = self.obs['agentview_image']
        self.last_right_wrist_img = self.obs['rightview_image']
        self.last_left_wrist_img = self.obs['leftview_image']
        self.send_action(act)
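

# Rollout driver: shows the agent-view stream in the visualization window and
# keeps the image-history buffers fresh; calling inference() each control tick
# is assumed to happen inside AnubisRobotEnv (e.g. from a ROS timer callback).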
if __name__ == '__main__':
    task_name = 'anubis_towel_kirby'
    rollout_num = 20
    hz = 20
    model_name = 'rdt'
    # model_name = 'oxe'

    node = RDTInferenceRobotEnv(
        hz=hz,
        max_timestep=800,
        task_name=task_name,
        num_rollout=rollout_num,
        model_name=model_name,
    )
    while node.rollout_counter < rollout_num:
        try:
            img = cv2.cvtColor(node.obs['agentview_image'], cv2.COLOR_RGB2BGR)
            if node.start:
                node.window.show(img, text=node.instruction)
            else:
                node.window.show(img, overlay_img=node.overlay_img, text=node.instruction)
            node.last_front_img = node.obs['agentview_image']
            node.last_right_wrist_img = node.obs['rightview_image']
            node.last_left_wrist_img = node.obs['leftview_image']
        except KeyboardInterrupt:
            node.ros_close()
            break
        except Exception as e:
            print(f"An error occurred: {e}")
            node.ros_close()
            break
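
# Example invocation (the filename below is hypothetical; task name, rollout
# count, and control rate are edited in the __main__ block above rather than
# passed as CLI flags):
#   $ python rdt_inference_env.py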