import rclpy
from sensor_msgs.msg import JointState
import numpy as np
import time
import h5py
from real_robot_env import AnubisRobotEnv
import cv2
from tqdm import tqdm
import os
from scripts.agilex_model import create_model
import random
from PIL import Image, ImageDraw, ImageFont
import yaml
import torch

# Number of control steps served from one model query before the model is
# queried again (inference() indexes the predicted chunk modulo this value).
ACTION_CHUNK_SIZE = 64


class RDTInferenceRobotEnv(AnubisRobotEnv):
    """Robot environment that drives the robot with an RDT policy.

    The model is queried once every ``ACTION_CHUNK_SIZE`` steps with the
    previous and current frames of three cameras plus the end-effector pose;
    the returned chunk of actions is then replayed one action per step.

    ``self.obs``, ``self.action_counter``, ``self.instruction`` and
    ``self.send_action`` are not defined in this file; they are presumably
    provided by ``AnubisRobotEnv`` -- confirm against the base class.
    """

    def __init__(self, hz=20, max_timestep=500, task_name='', num_rollout=1, model_name='1st'):
        """Store model settings and initialise the base environment.

        Args:
            hz: Forwarded unchanged to ``AnubisRobotEnv``.
            max_timestep: Forwarded unchanged to ``AnubisRobotEnv``.
            task_name: Forwarded to ``AnubisRobotEnv`` and also used as the
                checkpoint directory name.
            num_rollout: Forwarded unchanged to ``AnubisRobotEnv``.
            model_name: Stored on the instance as ``self.model_name``.
        """
        # These are assigned before the base constructor runs, presumably
        # because it calls bringup_model(), which needs self.checkpoint --
        # TODO confirm against AnubisRobotEnv.
        self.model_name = model_name
        self.checkpoint = f'/home/rllab/workspace/jellyho/RoboticsDiffusionTransformer/checkpoints/{task_name}'
        super().__init__(hz=hz, max_timestep=max_timestep, task_name=task_name, num_rollout=num_rollout)

        # Previous-frame buffers read by inference(). All three must exist
        # before the first inference() call; the original only created
        # last_front_img under the name inference() actually reads.
        self.last_front_img = None
        self.last_right_wrist_img = None
        self.last_left_wrist_img = None

        # Kept for backward compatibility with the attribute names the
        # original constructor created; nothing in this file reads them.
        self.right_wrist_img = None
        self.left_wrist_img = None

    def bringup_model(self):
        """Load the model config and build the policy into ``self.model``.

        Reads ``configs/base.yaml`` (a relative path, so it is resolved
        against the current working directory) and passes it, together with
        the checkpoint path and the encoder names, to ``create_model``.
        """
        with open('configs/base.yaml', "r") as fp:
            config = yaml.safe_load(fp)

        self.model = create_model(
            args=config,
            dtype=torch.bfloat16,
            pretrained=self.checkpoint,
            pretrained_text_encoder_name_or_path="google/t5-v1_1-xxl",
            pretrained_vision_encoder_name_or_path="google/siglip-so400m-patch14-384",
            control_frequency=20,
        )
        print('model loaded')

    def inference(self):
        """Send one action, querying the model at the start of each chunk.

        When ``self.action_counter`` is a multiple of ``ACTION_CHUNK_SIZE``
        the model is called with six images (previous front/right/left, then
        current front/right/left; a missing previous frame is passed as
        ``None``) and the end-effector pose. The result is cached in
        ``self.actions`` and indexed by ``action_counter % ACTION_CHUNK_SIZE``
        on this and the following steps.
        """
        # Read the observation once so the frames given to the model and the
        # frames remembered as "previous" below are the same ones.
        obs = self.obs
        front_img = obs['agentview_image']
        right_wrist_img = obs['rightview_image']
        left_wrist_img = obs['leftview_image']

        step_in_chunk = self.action_counter % ACTION_CHUNK_SIZE
        if step_in_chunk == 0:
            frames = [
                self.last_front_img,
                self.last_right_wrist_img,
                self.last_left_wrist_img,
                front_img,
                right_wrist_img,
                left_wrist_img,
            ]
            images = [Image.fromarray(arr) if arr is not None else None for arr in frames]
            # Add a leading batch dimension of size 1.
            proprio = torch.tensor(obs['eef_pose']).unsqueeze(0)

            with torch.inference_mode():
                # The leading dimension is squeezed away; the result is
                # presumably (chunk, action_dim) -- TODO confirm.
                self.actions = self.model.step(
                    proprio=proprio,
                    images=images,
                    instruction=self.instruction,
                ).squeeze(0).cpu().numpy()

        act = self.actions[step_in_chunk]
        self.action_counter += 1

        # Remember the current frames as "previous" for the next model query.
        self.last_front_img = front_img
        self.last_right_wrist_img = right_wrist_img
        self.last_left_wrist_img = left_wrist_img

        self.send_action(act)


def main():
    """Run the configured rollouts, showing the front camera in a window."""
    task_name = 'anubis_towel_kirby'
    rollout_num = 20
    hz = 20
    # Alternative value used by the author: 'oxe'.
    model_name = 'rdt'

    node = RDTInferenceRobotEnv(
        hz=hz,
        max_timestep=800,
        task_name=task_name,
        num_rollout=rollout_num,
        model_name=model_name,
    )

    while node.rollout_counter < rollout_num:
        try:
            img = cv2.cvtColor(node.obs['agentview_image'], cv2.COLOR_RGB2BGR)
            if node.start:
                node.window.show(img, text=node.instruction)
            else:
                node.window.show(img, overlay_img=node.overlay_img, text=node.instruction)
                # While node.start is false, keep the previous-frame buffers
                # current so the first model query has a previous frame;
                # inference() updates them itself on every step it runs.
                node.last_front_img = node.obs['agentview_image']
                node.last_right_wrist_img = node.obs['rightview_image']
                node.last_left_wrist_img = node.obs['leftview_image']
        except KeyboardInterrupt:
            node.ros_close()
            # Stop looping once the node has been closed.
            break
        except Exception as e:
            print(f"An error occurred: {e}")
            node.ros_close()
            # The node is closed; continuing would only repeat the failure.
            break


if __name__ == '__main__':
    main()