|
import rclpy |
|
from sensor_msgs.msg import JointState |
|
import numpy as np |
|
import time |
|
import h5py |
|
from real_robot_env import AnubisRobotEnv |
|
import cv2 |
|
from tqdm import tqdm |
|
import os |
|
from scripts.agilex_model import create_model |
|
import random |
|
from PIL import Image, ImageDraw, ImageFont |
|
import yaml |
|
import torch |
|
|
|
class RDTInferenceRobotEnv(AnubisRobotEnv):
    """Robot environment that sends actions predicted by an RDT model.

    The model is queried once every ``ACTION_CHUNK_SIZE`` calls to
    :meth:`inference`; the returned chunk of actions is cached in
    ``self.actions`` and replayed one action per call in between.
    """

    # Number of control steps served from one model prediction. It is both
    # the re-query period and the modulus used to index into the cached chunk.
    # NOTE(review): assumes model.step() returns at least this many actions.
    ACTION_CHUNK_SIZE = 64

    def __init__(self, hz=20, max_timestep=500, task_name='', num_rollout=1, model_name='1st'):
        """Store model settings, initialise the base environment and frame cache.

        Args:
            hz: Forwarded to the base environment.
            max_timestep: Forwarded to the base environment.
            task_name: Forwarded to the base environment; also selects the
                checkpoint directory.
            num_rollout: Forwarded to the base environment.
            model_name: Stored on the instance as ``self.model_name``.
        """
        # These are assigned before super().__init__() — presumably because
        # the base constructor needs them (e.g. to load the model); confirm
        # against AnubisRobotEnv before reordering.
        self.model_name = model_name
        self.checkpoint = f'/home/rllab/workspace/jellyho/RoboticsDiffusionTransformer/checkpoints/{task_name}'
        super().__init__(hz=hz, max_timestep=max_timestep, task_name=task_name, num_rollout=num_rollout)

        # Previous-step camera frames consumed by inference(). All three must
        # exist up front so the first inference() call cannot hit an
        # AttributeError; None marks "no previous frame yet".
        self.last_front_img = None
        self.last_right_wrist_img = None
        self.last_left_wrist_img = None

        # Not read anywhere in this class; retained only so external code
        # that may reference these attribute names keeps working.
        self.right_wrist_img = None
        self.left_wrist_img = None

    def bringup_model(self):
        """Load the RDT policy described by ``configs/base.yaml`` into ``self.model``."""
        # Path is relative to the current working directory.
        with open('configs/base.yaml', "r") as fp:
            config = yaml.safe_load(fp)
        self.model = create_model(
            args=config,
            dtype=torch.bfloat16,
            pretrained=self.checkpoint,
            pretrained_text_encoder_name_or_path="google/t5-v1_1-xxl",
            pretrained_vision_encoder_name_or_path="google/siglip-so400m-patch14-384",
            # NOTE(review): hard-coded; equals the default ``hz`` but does not
            # follow a non-default value — confirm whether it should.
            control_frequency=20
        )
        print('model loaded')

    def inference(self):
        """Send the next action, querying the model when a new chunk is due.

        Reads the latest observation from ``self.obs``, refreshes
        ``self.actions`` when ``self.action_counter`` is a multiple of
        ``ACTION_CHUNK_SIZE``, sends the action for the current step via
        ``send_action`` and records the current frames as the "previous"
        frames for the next model query.
        """
        chunk_size = self.ACTION_CHUNK_SIZE
        proprio = self.obs['eef_pose']

        if self.action_counter % chunk_size == 0:
            front_img = self.obs['agentview_image']
            right_wrist_img = self.obs['rightview_image']
            left_wrist_img = self.obs['leftview_image']

            # Previous-step frames first, then current frames, each in
            # front / right-wrist / left-wrist order.
            image_arrs = [
                self.last_front_img,
                self.last_right_wrist_img,
                self.last_left_wrist_img,
                front_img,
                right_wrist_img,
                left_wrist_img
            ]
            # Missing previous frames are passed through to the model as None.
            images = [Image.fromarray(arr) if arr is not None else None for arr in image_arrs]

            # Add a leading batch dimension of size 1.
            proprio = torch.tensor(proprio).unsqueeze(0)
            with torch.inference_mode():
                # Drop the batch dimension again and keep the chunk as a
                # NumPy array indexed by step.
                self.actions = self.model.step(
                    proprio=proprio,
                    images=images,
                    instruction=self.instruction
                ).squeeze(0).cpu().numpy()

        # Position of the current step inside the cached action chunk.
        idx = self.action_counter % chunk_size
        act = self.actions[idx]
        self.action_counter += 1

        # Remember the current frames as the "previous" ones for the next query.
        self.last_front_img = self.obs['agentview_image']
        self.last_right_wrist_img = self.obs['rightview_image']
        self.last_left_wrist_img = self.obs['leftview_image']

        self.send_action(act)
|
|
|
if __name__ == '__main__':
    # Run configuration for this evaluation session.
    task_name = 'anubis_towel_kirby'
    rollout_num = 20
    hz = 20
    model_name = 'rdt'

    node = RDTInferenceRobotEnv(
        hz=hz,
        max_timestep=800,
        task_name=task_name,
        num_rollout=rollout_num,
        model_name=model_name
    )

    # The try wraps the whole display loop so that an interrupt or an error
    # ends the loop right after the node is closed, instead of continuing to
    # iterate against a closed node.
    # NOTE(review): ros_close() is not called when the loop finishes normally;
    # confirm whether the base environment already shuts itself down then.
    try:
        while node.rollout_counter < rollout_num:
            # Converted from RGB to BGR before being handed to the window.
            img = cv2.cvtColor(node.obs['agentview_image'], cv2.COLOR_RGB2BGR)
            if node.start:
                node.window.show(img, text=node.instruction)
            else:
                node.window.show(img, overlay_img=node.overlay_img, text=node.instruction)
                # While no rollout is running, keep the previous-frame cache
                # filled with the latest frames so the first model query of
                # the next rollout has a previous observation to use.
                node.last_front_img = node.obs['agentview_image']
                node.last_right_wrist_img = node.obs['rightview_image']
                node.last_left_wrist_img = node.obs['leftview_image']
    except KeyboardInterrupt:
        node.ros_close()
    except Exception as e:
        print(f"An error occurred: {e}")
        node.ros_close()