# Source: RoboticsDiffusionTransformer / real_robot_env.py (uploaded by euijinrnd with the upload-large-folder tool, commit 9de9fbf, 10.5 kB)
import dm_env
from absl import logging
import rclpy
from sensor_msgs.msg import Image, JointState
from std_msgs.msg import Bool
from std_msgs.msg import Int32
import numpy as np
import threading
import time
# from visualize_utils import window
import random
from scipy.spatial.transform import Rotation
from glob import glob
import os
import h5py
import cv2
class AnubisRobotEnv:
    """ROS 2 environment that streams robot observations and publishes actions.

    The constructor creates the ROS 2 node, subscribes to three camera topics,
    the ``/eef_pose`` topic and the ``/done`` trigger, and starts three
    background threads: a control loop running at ``hz``, a 30 Hz video
    writer, and a daemon thread running ``rclpy.spin``.

    Subclasses must implement :meth:`bringup_model` and :meth:`inference`.

    NOTE(review): :meth:`initialize` reads ``self.model_name`` and
    :meth:`reset` calls ``self._observation()``; neither is defined in this
    class, so they are presumably provided by the subclass
    (``bringup_model`` runs before ``initialize``) -- confirm.
    """

    # (rows, cols, channels) every incoming camera frame is reshaped to.
    _IMAGE_SHAPE = (480, 640, 3)
    # Minimum length of an action vector: 2 x (3 position + 6 rotation + 1 scalar).
    _ACTION_DIM = 20
    # Upper bound, in seconds, on how long ros_close waits for each thread.
    _JOIN_TIMEOUT = 1.0

    def __init__(self, hz=20, max_timestep=1000, task_name='', num_rollout=1):
        """Create the ROS 2 node, start the worker threads and reset the env.

        Args:
            hz: Control frequency of the inference loop, in Hz.
            max_timestep: Number of control steps after which a rollout is
                ended automatically.
            task_name: Key into ``lang_dict`` selecting the language
                instruction and the demonstration directory.
            num_rollout: Stored on the instance; not read anywhere in this
                class.

        Raises:
            KeyError: If ``task_name`` is not one of the known tasks.
        """
        self.lang_dict = {
            'anubis_brush_to_pan': 'insert the brush to the dustpan',
            'anubis_carrot_to_bag': 'pick up the carrot and put into the bag',
            'anubis_towel_kirby': 'take the towel off the kirby doll',
        }
        # Validate before touching ROS so a bad task name does not leave an
        # initialised rclpy context and a half-built node behind.
        if task_name not in self.lang_dict:
            raise KeyError(
                f'Unknown task_name {task_name!r}; '
                f'expected one of {sorted(self.lang_dict)}'
            )
        self.task_name = task_name
        self.instruction = self.lang_dict[task_name]

        rclpy.init()  # initialize ROS2 node
        self._node = rclpy.create_node('anubis_robot_env_node')
        self._subscriber_bringup()
        print('ROS2 node created')

        self.window = None
        self.start = False
        self.thread_done = False
        self.hz = hz  # control frequency
        self.action_counter = 0
        self.num_rollout = num_rollout
        self.rollout_counter = 0
        # NOTE: glob() returns files in filesystem order, which is not
        # guaranteed to be stable across machines.
        self.data_list = glob(f'/home/rllab/workspace/jellyho/demo_collection/{self.task_name}/*.hdf5')
        self.overlay_img = None
        self.max_timestep = max_timestep

        # 16-value pose in the same layout as the outgoing action message.
        # Not used elsewhere in this class; kept for subclasses/callers.
        self.init_action = JointState()
        self.init_action.position = [
            0.20620185010895048,
            0.16183641523267392,
            0.2277105000367078,
            -0.42093861525667453,
            0.6546518510233503,
            -0.5770953981378887,
            0.24739146627474096,
            -1.6,
            0.21136149716403216,
            -0.16027684481842075,
            0.21879985782478842,
            0.6606782591766969,
            -0.428768621033297,
            0.2340722378552696,
            -0.569975345900049,
            -1.6,
        ]

        print('Initializing Anubis Robot Environment')
        self.thread = PeriodicThread(1 / self.hz, self.timer_callback)
        self.thread.start()
        self.video_thread = PeriodicThread(1 / 30, self.video_timer_callback)
        self.video_thread.start()
        self.timer_thread = threading.Thread(target=rclpy.spin, args=(self._node,), daemon=True)
        self.timer_thread.start()
        print('Threads started')

        self.bringup_model()
        self.initialize()
        logging.set_verbosity(logging.INFO)
        logging.info('AnubisRobotEnv successfully initialized.')

    def init_robot_pose(self, demo):
        """Publish the first recorded action of a demonstration file.

        Args:
            demo: Demonstration index; wrapped modulo the number of files.

        Raises:
            FileNotFoundError: If no demonstration files were found for the
                current task.
        """
        if not self.data_list:
            raise FileNotFoundError(
                f'No demonstration .hdf5 files found for task {self.task_name!r}'
            )
        index = demo % len(self.data_list)
        print('Initializing robot pose', index)
        # Indexing the dataset copies the row into memory, so the file can be
        # closed before the action is published.
        with h5py.File(self.data_list[index], 'r') as root:
            first_action = root['action']['eef_pose'][0]
        self.publish_action(first_action)

    def initialize(self):
        """Reset the step counter, (re)open the video window and home the robot."""
        self.curr_timestep = 0
        if self.window is None:
            # Imported lazily, as in the original; the module-level import is
            # commented out at the top of the file.
            from visualize_utils import window
            self.window = window('ENV Observation', video_path=f'{self.model_name}-{self.task_name}', video_fps=30, video_size=(640, 480), show=False)
        else:
            self.window.init_video()
        self.send_demo(self.rollout_counter)
        self.init_robot_pose(self.rollout_counter)

    def reset(self):
        """Block until the control loop has ticked, then return a restart step."""
        while not self.thread_done:
            time.sleep(0.01)
        self.thread_done = False
        return dm_env.restart(observation=self._observation())

    def bringup_model(self):
        """Load the policy/model. Must be implemented by subclasses."""
        raise NotImplementedError

    def inference(self):
        """Run one control step. Must be implemented by subclasses."""
        raise NotImplementedError

    def _subscriber_bringup(self):
        """Create all subscribers and publishers and the observation buffers."""
        ###### Initial Setup #####
        self.obs = {}
        self.action = {}

        ###### OBSERVATION ######
        # Camera images; each buffer starts as a black frame.
        self._node.create_subscription(Image, '/camera_center/camera/color/image_raw', self.agentview_image_callback, 10)
        self.obs['agentview_image'] = np.zeros(shape=self._IMAGE_SHAPE, dtype=np.uint8)
        self._node.create_subscription(Image, '/camera_right/camera/color/image_raw', self.rightview_image_callback, 10)
        self.obs['rightview_image'] = np.zeros(shape=self._IMAGE_SHAPE, dtype=np.uint8)
        self._node.create_subscription(Image, '/camera_left/camera/color/image_raw', self.leftview_image_callback, 10)
        self.obs['leftview_image'] = np.zeros(shape=self._IMAGE_SHAPE, dtype=np.uint8)

        # End-effector pose, stored in the 20-value layout built by
        # eef_pose_callback.
        self._node.create_subscription(JointState, '/eef_pose', self.eef_pose_callback, 10)
        self.obs['eef_pose'] = np.zeros(shape=(self._ACTION_DIM,), dtype=np.float64)

        self.obs['language_instruction'] = ''

        ##### TRIGGER #####
        self._node.create_subscription(Bool, '/done', self.done_callback, 10)
        self.demo_pub = self._node.create_publisher(Int32, '/demo', 10)
        self.action_pub = self._node.create_publisher(JointState, '/teleop/eef_pose', 10)

    def send_demo(self, num):
        """Publish ``num`` on the ``/demo`` topic."""
        demo_msg = Int32()
        demo_msg.data = num
        self.demo_pub.publish(demo_msg)

    #### OBS ###########
    def _image_from_msg(self, msg):
        """Reshape the flat pixel buffer of an Image message to ``_IMAGE_SHAPE``."""
        return np.reshape(msg.data, self._IMAGE_SHAPE)

    def agentview_image_callback(self, msg):
        """Store the latest centre-camera frame."""
        self.obs['agentview_image'] = self._image_from_msg(msg)

    def rightview_image_callback(self, msg):
        """Store the latest right-camera frame, rotated by 180 degrees."""
        self.obs['rightview_image'] = np.rot90(self._image_from_msg(msg), 2)

    def leftview_image_callback(self, msg):
        """Store the latest left-camera frame."""
        self.obs['leftview_image'] = self._image_from_msg(msg)

    def eef_pose_callback(self, msg):
        """Convert a 16-value pose message into the 20-value observation.

        The incoming ``position`` field is read as two consecutive groups of
        ``[3 values, 4-value scalar-last quaternion, 1 value]``. Each
        quaternion is replaced by its 6-D rotation representation.
        NOTE(review): the groups presumably hold xyz + orientation + gripper
        for each of two arms -- confirm against the publisher.
        """
        received = np.array(msg.position)
        eef_pose = np.zeros(shape=(self._ACTION_DIM,), dtype=np.float64)
        eef_pose[:3] = received[:3]
        eef_pose[3:9] = self.quat_to_6d(received[3:7], scalar_first=False)
        eef_pose[9] = received[7]
        eef_pose[10:13] = received[8:11]
        eef_pose[13:19] = self.quat_to_6d(received[11:15], scalar_first=False)
        eef_pose[19] = received[15]
        self.obs['eef_pose'] = eef_pose

    def _action_to_positions(self, action):
        """Convert a 20-value action into the 16-value message payload.

        This is the inverse of the layout built by :meth:`eef_pose_callback`:
        each 6-D rotation becomes a scalar-last quaternion.

        Args:
            action: Array-like with at least 20 values once singleton
                dimensions are squeezed away.

        Returns:
            A list of 16 Python floats.

        Raises:
            ValueError: If the squeezed action is not a vector of length >= 20.
        """
        action = np.squeeze(np.asarray(action, dtype=np.float64))
        if action.ndim != 1 or action.shape[0] < self._ACTION_DIM:
            raise ValueError(
                f'Expected an action with at least {self._ACTION_DIM} values, '
                f'got shape {action.shape}'
            )
        positions = np.zeros(16)
        positions[0:3] = action[0:3]
        positions[3:7] = self.sixd_to_quat(action[3:9])
        positions[7] = action[9]
        positions[8:11] = action[10:13]
        positions[11:15] = self.sixd_to_quat(action[13:19])
        positions[15] = action[19]
        return positions.astype(float).tolist()

    def send_action(self, act):
        """Publish ``act`` on ``/teleop/eef_pose`` while a rollout is active.

        Does nothing when ``self.start`` is False.
        """
        if not self.start:
            return
        self.publish_action(act)

    def publish_action(self, action):
        """Publish ``action`` on ``/teleop/eef_pose`` unconditionally."""
        action_msg = JointState()
        action_msg.position = self._action_to_positions(action)
        self.action_pub.publish(action_msg)

    def _finish_rollout(self):
        """Stop the current rollout, close its video and prepare the next one."""
        self.start = False
        self.action_counter = 0
        self.rollout_counter += 1
        if self.window.video_recording:
            self.window.video_stop()
        self.initialize()

    def done_callback(self, msg):
        """Toggle between 'rollout running' and 'waiting' on each ``/done``.

        The Bool payload is not inspected; every message flips the state.
        """
        if self.window is None:
            # The spin thread starts before initialize() has created the
            # window, so a very early trigger has nothing to act on yet.
            logging.warning('Ignoring /done trigger: environment is still initializing.')
            return
        if not self.start:
            print('Inference & Video Recording Start')
            self.start = True
            self.window.video_start()
        else:
            self._finish_rollout()
            print('Next Inference Ready')

    def timer_callback(self):
        """Control-loop tick: run inference and enforce the step limit."""
        if self.start:
            self.inference()
            self.curr_timestep += 1
            if self.curr_timestep >= self.max_timestep:
                print("Max timestep reached, resetting environment.")
                self._finish_rollout()
        # Signals reset() that at least one tick has completed.
        self.thread_done = True

    def video_timer_callback(self):
        """Write one video frame while a rollout is being recorded."""
        if self.start and self.window.video_recording:
            self.window.video_write()

    def quat_to_6d(self, quat, scalar_first=False):
        """Return the 6-D rotation representation of a quaternion.

        Args:
            quat: Four quaternion components.
            scalar_first: True if ``quat`` is ordered (w, x, y, z); False for
                (x, y, z, w).

        Returns:
            The first two columns of the rotation matrix, flattened row by
            row into six values.
        """
        quat = np.asarray(quat, dtype=np.float64)
        if scalar_first:
            # Reorder to SciPy's native (x, y, z, w) by hand so this does not
            # depend on the ``scalar_first`` keyword (SciPy >= 1.14 only).
            quat = np.roll(quat, -1, axis=-1)
        mat = Rotation.from_quat(quat).as_matrix()
        return mat[:, :2].flatten()

    def sixd_to_quat(self, sixd, scalar_first=False):
        """Return the quaternion for a 6-D rotation representation.

        Args:
            sixd: Six values in the layout produced by :meth:`quat_to_6d`.
            scalar_first: True to return (w, x, y, z); False for (x, y, z, w).

        Returns:
            A length-4 array holding the quaternion.
        """
        mat = np.zeros((3, 3))
        mat[:, :2] = np.asarray(sixd, dtype=np.float64).reshape(3, 2)
        # The third axis is the cross product of the first two. The columns
        # are not re-orthonormalised here; that is left to from_matrix.
        mat[:, 2] = np.cross(mat[:, 0], mat[:, 1])
        quat = Rotation.from_matrix(mat).as_quat()
        return np.roll(quat, 1, axis=-1) if scalar_first else quat

    def ros_close(self):
        """Stop recording, stop the worker threads and shut ROS 2 down."""
        if self.window is not None and self.window.video_recording:
            self.window.video_stop()

        current = threading.current_thread()
        workers = (self.thread, self.video_thread)
        for worker in workers:
            worker.stop()
        # Wait for the loops to finish before the node they publish through
        # is destroyed. A thread cannot join itself, hence the identity check.
        for worker in workers:
            if worker is not current:
                worker.join(timeout=self._JOIN_TIMEOUT)

        self._node.destroy_node()
        rclpy.shutdown()
        # The spin thread is a plain threading.Thread and has no stop()
        # method; it ends once rclpy has shut down.
        if self.timer_thread is not current:
            self.timer_thread.join(timeout=self._JOIN_TIMEOUT)
class PeriodicThread(threading.Thread):
    """Thread that calls a function repeatedly at a fixed target period.

    Each cycle calls ``function(*args, **kwargs)`` and then waits for whatever
    is left of ``interval``. If the call takes longer than ``interval`` the
    next call starts immediately.
    """

    def __init__(self, interval, function, *args, **kwargs):
        """Store the period and the callable; the thread is not started here.

        Args:
            interval: Target period between the starts of two calls, in
                seconds.
            function: Callable invoked once per period.
            *args: Positional arguments forwarded to ``function``.
            **kwargs: Keyword arguments forwarded to ``function``.
        """
        super().__init__()
        self.interval = interval
        self.function = function
        self.args = args
        self.kwargs = kwargs
        self.stop_event = threading.Event()
        self._lock = threading.Lock()

    def run(self):
        """Call the function once per period until :meth:`stop` is called."""
        while not self.stop_event.is_set():
            # Monotonic clock: elapsed time must not be affected by
            # wall-clock adjustments.
            started = time.monotonic()
            self.function(*self.args, **self.kwargs)
            # Read under the same lock change_period() writes under.
            with self._lock:
                interval = self.interval
            remaining = max(0, interval - (time.monotonic() - started))
            # Wait on the event instead of sleeping so stop() ends the wait
            # immediately rather than after the rest of the interval.
            self.stop_event.wait(remaining)

    def stop(self):
        """Ask the loop to exit; returns without waiting for the thread."""
        self.stop_event.set()

    def change_period(self, new_interval):
        """Set a new period, in seconds, used from the next wait onwards."""
        with self._lock:
            self.interval = new_interval