|
import dm_env |
|
from absl import logging |
|
|
|
import rclpy |
|
from sensor_msgs.msg import Image, JointState |
|
from std_msgs.msg import Bool |
|
from std_msgs.msg import Int32 |
|
import numpy as np |
|
import threading |
|
import time |
|
|
|
import random |
|
from scipy.spatial.transform import Rotation |
|
from glob import glob |
|
import os |
|
import h5py |
|
import cv2 |
|
|
|
class AnubisRobotEnv:
    """ROS 2 harness that runs a learned policy on the bimanual Anubis robot.

    The environment subscribes to three colour camera streams, the
    ``/eef_pose`` topic and a ``/done`` toggle. It publishes end-effector
    targets on ``/teleop/eef_pose`` and the current demo index on ``/demo``.
    ``timer_callback`` (which calls ``inference``) runs on a periodic thread at
    ``hz``. Video frames are written by a second periodic thread at 30 Hz.

    Subclasses must implement ``bringup_model`` and ``inference``.
    NOTE(review): ``initialize`` reads ``self.model_name`` and ``reset`` calls
    ``self._observation()``; neither is defined here, so presumably a subclass
    provides them. Confirm.

    Pose layouts (per arm; the two arms are concatenated):
      * 16-D ROS layout: xyz(3) + quaternion in scalar-last x, y, z, w
        order(4) + one extra scalar(1), presumably the gripper command.
      * 20-D layout used for ``obs['eef_pose']`` and actions: xyz(3) +
        6-D rotation(6) + the same extra scalar(1).
    """

    # Shape of the zero-filled placeholder frames stored in ``self.obs``
    # before the first camera message arrives.
    _PLACEHOLDER_IMAGE_SHAPE = (480, 640, 3)

    def __init__(self, hz=20, max_timestep=1000, task_name='', num_rollout=1,
                 data_root='/home/rllab/workspace/jellyho/demo_collection'):
        """Bring up ROS, index the demos, start worker threads and the model.

        Args:
            hz: Rate, in Hz, at which ``timer_callback`` is scheduled.
            max_timestep: Inference ticks after which a rollout is ended.
            task_name: Key into ``lang_dict``; also the demo sub-folder name.
            num_rollout: Stored on the instance; not consumed in this class.
            data_root: Directory holding one sub-folder of ``*.hdf5`` demos
                per task.

        Raises:
            KeyError: If ``task_name`` is not a key of ``lang_dict``.
        """
        # Resolve the task before touching ROS, so an unknown task name does
        # not leave an initialised rclpy context behind.
        self.lang_dict = {
            'anubis_brush_to_pan': 'insert the brush to the dustpan',
            'anubis_carrot_to_bag': 'pick up the carrot and put into the bag',
            'anubis_towel_kirby': 'take the towel off the kirby doll',
        }
        self.task_name = task_name
        self.instruction = self.lang_dict[self.task_name]

        rclpy.init()
        self._node = rclpy.create_node('anubis_robot_env_node')
        self._subscriber_bringup()
        print('ROS2 node created')

        self.window = None
        self.start = False
        self.thread_done = False
        self.hz = hz
        self.action_counter = 0
        self.num_rollout = num_rollout
        self.rollout_counter = 0
        self.curr_timestep = 0

        # glob() order depends on the filesystem. Sort the list so that demo
        # index N (also published on /demo) always maps to the same file.
        self.data_list = sorted(glob(os.path.join(data_root, self.task_name, '*.hdf5')))

        self.overlay_img = None
        self.max_timestep = max_timestep

        self.init_action = JointState()
        self.init_action.position = [
            0.20620185010895048,
            0.16183641523267392,
            0.2277105000367078,
            -0.42093861525667453,
            0.6546518510233503,
            -0.5770953981378887,
            0.24739146627474096,
            -1.6,
            0.21136149716403216,
            -0.16027684481842075,
            0.21879985782478842,
            0.6606782591766969,
            -0.428768621033297,
            0.2340722378552696,
            -0.569975345900049,
            -1.6,
        ]

        print('Initializing Anubis Robot Environment')

        self.thread = PeriodicThread(1 / self.hz, self.timer_callback)
        self.thread.start()

        self.video_thread = PeriodicThread(1 / 30, self.video_timer_callback)
        self.video_thread.start()

        # rclpy.spin blocks, so subscriber callbacks are serviced on a daemon
        # thread.
        self.timer_thread = threading.Thread(target=rclpy.spin, args=(self._node,), daemon=True)
        self.timer_thread.start()
        print('Threads started')

        self.bringup_model()
        self.initialize()

        logging.set_verbosity(logging.INFO)
        logging.info('AnubisRobotEnv successfully initialized.')

    def init_robot_pose(self, demo):
        """Publish the first recorded ``eef_pose`` action of a demo.

        Args:
            demo: Demo index, wrapped modulo the number of available demos.

        Raises:
            FileNotFoundError: If no ``*.hdf5`` demos were found for the task.
        """
        if not self.data_list:
            raise FileNotFoundError(f'No .hdf5 demos found for task {self.task_name!r}')
        idx = demo % len(self.data_list)
        print('Initializing robot pose', idx)
        # Use a context manager so the HDF5 handle is closed. Indexing [0]
        # reads the row into memory before the file closes.
        with h5py.File(self.data_list[idx], 'r') as root:
            first_action = root['action']['eef_pose'][0]
        self.publish_action(first_action)

    def initialize(self):
        """Prepare a rollout.

        Resets the step counter, creates (or re-arms) the video window,
        announces the demo index and sends the robot to that demo's start.
        """
        self.curr_timestep = 0
        if self.window is None:
            # Project-local helper, imported lazily on first use.
            from visualize_utils import window
            self.window = window(
                'ENV Observation',
                video_path=f'{self.model_name}-{self.task_name}',
                video_fps=30,
                video_size=(640, 480),
                show=False,
            )
        else:
            self.window.init_video()
        self.send_demo(self.rollout_counter)
        self.init_robot_pose(self.rollout_counter)

    def reset(self):
        """Wait for the control thread to finish a tick, then restart.

        Returns a ``dm_env.restart`` TimeStep built from ``self._observation()``.
        """
        while not self.thread_done:
            time.sleep(0.01)
        self.thread_done = False
        return dm_env.restart(observation=self._observation())

    def bringup_model(self):
        """Load the policy; must be implemented by subclasses."""
        raise NotImplementedError

    def inference(self):
        """Run one policy step; must be implemented by subclasses."""
        raise NotImplementedError

    def ros_close(self):
        """Stop recording and worker threads, then tear down ROS.

        ``threading.Thread`` has no ``stop()`` method. The spin thread returns
        once rclpy is shut down, so it is only joined (with a timeout) at the
        end.
        """
        if self.window is not None and self.window.video_recording:
            self.window.video_stop()
        periodic = (self.thread, self.video_thread)
        for worker in periodic:
            worker.stop()
        for worker in periodic:
            # Joining the current thread would raise RuntimeError, e.g. when
            # called from inside inference().
            if worker is not threading.current_thread():
                worker.join()
        self._node.destroy_node()
        rclpy.shutdown()
        self.timer_thread.join(timeout=1.0)

    def _subscriber_bringup(self):
        """Create the ROS subscriptions and publishers used by the environment.

        Subscribes to the centre/right/left colour cameras, ``/eef_pose`` and
        ``/done``. Seeds ``self.obs`` with zero placeholders until the first
        messages arrive. Creates the ``/demo`` and ``/teleop/eef_pose``
        publishers.
        """
        self.obs = {}
        self.action = {}

        cameras = (
            ('/camera_center/camera/color/image_raw', self.agentview_image_callback, 'agentview_image'),
            ('/camera_right/camera/color/image_raw', self.rightview_image_callback, 'rightview_image'),
            ('/camera_left/camera/color/image_raw', self.leftview_image_callback, 'leftview_image'),
        )
        for topic, callback, key in cameras:
            self._node.create_subscription(Image, topic, callback, 10)
            self.obs[key] = np.zeros(shape=self._PLACEHOLDER_IMAGE_SHAPE, dtype=np.uint8)

        self._node.create_subscription(JointState, '/eef_pose', self.eef_pose_callback, 10)
        self.obs['eef_pose'] = np.zeros(shape=(20,), dtype=np.float64)

        self.obs['language_instruction'] = ''

        self._node.create_subscription(Bool, '/done', self.done_callback, 10)

        self.demo_pub = self._node.create_publisher(Int32, '/demo', 10)
        self.action_pub = self._node.create_publisher(JointState, '/teleop/eef_pose', 10)

    def send_demo(self, num):
        """Publish the demo index ``num`` on ``/demo``."""
        demo_msg = Int32()
        demo_msg.data = num
        self.demo_pub.publish(demo_msg)

    @staticmethod
    def _image_from_msg(msg):
        """Reshape an Image message's pixel buffer to (height, width, 3).

        Assumes a tightly packed 3-channel encoding (step == 3 * width);
        TODO confirm against the camera driver configuration.
        """
        return np.reshape(msg.data, (msg.height, msg.width, 3))

    def agentview_image_callback(self, msg):
        """Store the centre camera frame."""
        self.obs['agentview_image'] = self._image_from_msg(msg)

    def rightview_image_callback(self, msg):
        """Store the right camera frame, rotated by 180 degrees."""
        self.obs['rightview_image'] = np.rot90(self._image_from_msg(msg), 2)

    def leftview_image_callback(self, msg):
        """Store the left camera frame."""
        self.obs['leftview_image'] = self._image_from_msg(msg)

    def _pose16_to_pose20(self, pose16):
        """Convert a 16-D quaternion pose to the 20-D 6-D-rotation layout."""
        src = np.asarray(pose16, dtype=np.float64)
        out = np.zeros(20, dtype=np.float64)
        for arm in range(2):
            s, d = 8 * arm, 10 * arm  # per-arm offsets in source / destination
            out[d:d + 3] = src[s:s + 3]
            out[d + 3:d + 9] = self.quat_to_6d(src[s + 3:s + 7], scalar_first=False)
            out[d + 9] = src[s + 7]
        return out

    def _pose20_to_pose16(self, pose20):
        """Convert a 20-D 6-D-rotation pose to a 16-element list of floats.

        A leading singleton batch axis, e.g. shape (1, 20), is squeezed away.

        Raises:
            ValueError: If the squeezed input is not 1-D with at least 20 entries.
        """
        src = np.asarray(pose20, dtype=np.float64).squeeze()
        if src.ndim != 1 or src.shape[0] < 20:
            raise ValueError(f'expected a 20-D pose, got shape {src.shape}')
        out = np.zeros(16, dtype=np.float64)
        for arm in range(2):
            s, d = 10 * arm, 8 * arm  # per-arm offsets in source / destination
            out[d:d + 3] = src[s:s + 3]
            out[d + 3:d + 7] = self.sixd_to_quat(src[s + 3:s + 9])
            out[d + 7] = src[s + 9]
        return out.tolist()

    def eef_pose_callback(self, msg):
        """Store the end-effector pose from ``msg.position`` in the 20-D layout."""
        self.obs['eef_pose'] = self._pose16_to_pose20(msg.position)

    def _publish_pose20(self, pose20):
        """Pack a 20-D pose into a JointState and publish it on /teleop/eef_pose."""
        action_msg = JointState()
        action_msg.position = self._pose20_to_pose16(pose20)
        self.action_pub.publish(action_msg)

    def send_action(self, act):
        """Publish a 20-D action, but only while a rollout is running."""
        if not self.start:
            return
        self._publish_pose20(act)

    def publish_action(self, action):
        """Publish a 20-D action unconditionally (e.g. to move to a start pose)."""
        self._publish_pose20(action)

    def _finish_rollout(self):
        """End the current rollout and initialise the next one.

        Stops inference and recording, advances the demo index and calls
        ``initialize`` for the next rollout.
        """
        # NOTE(review): this is reachable from both the spin thread
        # (done_callback) and the control thread (timer_callback) without a
        # lock, as in the original code.
        self.start = False
        if self.window.video_recording:
            self.window.video_stop()
        self.rollout_counter += 1
        self.action_counter = 0
        self.initialize()

    def done_callback(self, msg):
        """Toggle handler for ``/done``.

        The first message starts inference and recording; the next one ends
        the rollout.
        """
        if not self.start:
            print('Inference & Video Recording Start')
            self.start = True
            self.window.video_start()
        else:
            self._finish_rollout()
            print('Next Inference Ready')

    def timer_callback(self):
        """Control tick: run inference while started and cap rollout length."""
        if self.start:
            self.inference()
            self.curr_timestep += 1
            if self.curr_timestep >= self.max_timestep:
                print("Max timestep reached, resetting environment.")
                self._finish_rollout()
        # Lets reset() know a tick has completed.
        # NOTE(review): the source indentation was lost; confirm this flag is
        # meant to be set on every tick.
        self.thread_done = True

    def video_timer_callback(self):
        """Write a video frame while a rollout is being recorded."""
        if self.start and self.window.video_recording:
            self.window.video_write()

    def quat_to_6d(self, quat, scalar_first=False):
        """Convert a quaternion to the 6-D rotation representation.

        The result is the first two columns of the rotation matrix, flattened
        row-major: [m00, m01, m10, m11, m20, m21].

        Args:
            quat: Quaternion, scalar-last (x, y, z, w) unless ``scalar_first``.
            scalar_first: Interpret ``quat`` as (w, x, y, z).
        """
        q = np.asarray(quat, dtype=np.float64)
        if scalar_first:
            # Reorder to scalar-last ourselves, so SciPy < 1.14 (which lacks
            # the ``scalar_first`` kwarg) also works.
            q = np.roll(q, -1)
        return Rotation.from_quat(q).as_matrix()[:, :2].flatten()

    def sixd_to_quat(self, sixd, scalar_first=False):
        """Convert a 6-D rotation representation back to a quaternion.

        The two 3-vectors are orthonormalised with Gram-Schmidt and completed
        with their cross product. Noisy, non-orthonormal inputs (e.g. network
        outputs) therefore map to a proper rotation that keeps the first
        column's direction.

        Args:
            sixd: Six values laid out as produced by ``quat_to_6d``.
            scalar_first: Return (w, x, y, z) instead of (x, y, z, w).

        Raises:
            ValueError: If a column is ~zero or the two columns are ~parallel.
        """
        cols = np.asarray(sixd, dtype=np.float64).reshape(3, 2)
        a1, a2 = cols[:, 0], cols[:, 1]
        n1 = np.linalg.norm(a1)
        if n1 < 1e-9:
            raise ValueError('degenerate 6-D rotation: first column is ~zero')
        b1 = a1 / n1
        b2 = a2 - np.dot(b1, a2) * b1
        n2 = np.linalg.norm(b2)
        if n2 < 1e-9:
            raise ValueError('degenerate 6-D rotation: columns are ~parallel')
        b2 = b2 / n2
        mat = np.column_stack((b1, b2, np.cross(b1, b2)))
        quat = Rotation.from_matrix(mat).as_quat()
        if scalar_first:
            quat = np.roll(quat, 1)
        return quat
|
|
|
class PeriodicThread(threading.Thread):
    """Thread that calls ``function(*args, **kwargs)`` roughly every ``interval`` s.

    The wait after each call is shortened by the time the call itself took.
    ``stop()`` interrupts a pending wait immediately. ``change_period()``
    takes effect from the next tick.
    """

    def __init__(self, interval, function, *args, **kwargs):
        """Create the (not yet started) thread.

        Args:
            interval: Target period between calls, in seconds.
            function: Callable invoked once per period.
            *args: Positional arguments forwarded to ``function``.
            **kwargs: Keyword arguments forwarded to ``function``.
        """
        super().__init__()
        self.interval = interval
        self.function = function
        self.args = args
        self.kwargs = kwargs
        self.stop_event = threading.Event()
        # Serialises reads/writes of ``interval`` between run() and change_period().
        self._lock = threading.Lock()

    def run(self):
        """Call ``function`` periodically until ``stop()`` is called."""
        while not self.stop_event.is_set():
            # monotonic() is immune to wall-clock jumps (NTP, manual changes).
            started = time.monotonic()
            self.function(*self.args, **self.kwargs)
            with self._lock:
                interval = self.interval
            remaining = max(0.0, interval - (time.monotonic() - started))
            # Event.wait returns as soon as stop() sets the event, unlike
            # time.sleep, so shutdown does not wait out a whole period.
            self.stop_event.wait(remaining)

    def stop(self):
        """Ask the loop to exit; the in-flight call, if any, completes first."""
        self.stop_event.set()

    def change_period(self, new_interval):
        """Set a new period in seconds, used from the next tick onward."""
        with self._lock:
            self.interval = new_interval