#!/usr/bin/python3
import time

import h5py
import torch
import yaml
import cv2
from PIL import Image
from std_msgs.msg import String, Bool

from node import InferenceNode
from scripts.agilex_model import create_model


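# Inference node for RDT (Robotics Diffusion Transformer). Wraps a pretrained
# RDT checkpoint behind the InferenceNode interface: it listens for prompts on
# /vla/prompt, predicts action chunks from the camera and joint-state buffers,
# and streams joint commands at the control rate.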
class RDTNode(InferenceNode):
    def __init__(self, action_chunk, instruction, ckpt_dir, unnorm_key, hz=20,
                 max_timestep=1000, dataset_name=None, single_arm=True, lang_embed_name=''):
        self.ckpt_dir = ckpt_dir
        self.lang_embed_name = f'outs/{lang_embed_name}.pt'
        self.run_name = f'rdt_{ckpt_dir.split("/")[-1]}'
        self.single_arm = single_arm
        super().__init__(hz=hz, max_timestep=max_timestep, dataset_name=dataset_name, single_arm=single_arm)
        self.obs['language_instruction'] = instruction
        self.action_chunk = action_chunk
        self.action_counter = 0
        self.unnorm_key = unnorm_key
        # Name the callback differently from the subscription handle so the
        # bound method is not shadowed by the subscription object.
        self.prompt_sub = self._node.create_subscription(String, '/vla/prompt', self.prompt_callback, 1)
        self.attn = None

    def prompt_callback(self, msg):
        # One-off prompt query against the latest camera frame.
        if self.policy is not None:
            img = self.obs['image']
            pil_image = Image.fromarray(img)
            print(self.policy.inference_prompt(pil_image, msg.data))

    def bringup_model(self):
        # Instantiate the pretrained RDT policy from the base config.
        with open('configs/base.yaml', 'r') as fp:
            config = yaml.safe_load(fp)
        self.policy = create_model(
            args=config,
            dtype=torch.bfloat16,
            pretrained=self.ckpt_dir,
            pretrained_vision_encoder_name_or_path="google/siglip-so400m-patch14-384",
            control_frequency=20,
            single_arm=self.single_arm,
        )
        # Language conditioning comes from precomputed text embeddings rather
        # than being encoded online.
        self.lang_embeddings = torch.load(self.lang_embed_name)["embeddings"]

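    # Observation layout: six image slots, i.e. up to three camera views for
    # each of the two most recent timesteps (t-1, then t). Here the main view
    # fills the first slot, the left view (dual-arm only) the second, and
    # unused slots are passed as None.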
    def inference_fn(self):
        if self.single_arm:
            image_arrs = [
                self.frame_buffer[-2],
                None,
                None,
                self.frame_buffer[-1],
                None,
                None
            ]
        else:
            image_arrs = [
                self.frame_buffer[-2],
                self.left_frame_buffer[-2],
                None,
                self.frame_buffer[-1],
                self.left_frame_buffer[-1],
                None
            ]
        images = [Image.fromarray(arr) if arr is not None else None
                  for arr in image_arrs]
        # Proprioception: the last 7 joints (right arm) in single-arm mode,
        # all 14 joints otherwise.
        if self.single_arm:
            proprio = torch.tensor(self.joint_pos_buffer[-1][7:]).unsqueeze(0)
        else:
            proprio = torch.tensor(self.joint_pos_buffer[-1]).unsqueeze(0)

        actions = self.policy.step(
            proprio=proprio,
            images=images,
            text_embeds=self.lang_embeddings
        ).squeeze(0).cpu().numpy()

        return actions

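    # Chunked execution: query the policy once every action_chunk control
    # steps and replay the predicted chunk open-loop in between.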
    def inference(self):
        if self.action_counter == 0:
            with torch.inference_mode():
                start_time = time.time()
                self.actions = self.inference_fn()
                end_time = time.time()
                print(f'{end_time - start_time:.6f} sec')

        action = self.actions[self.action_counter]

        if self.single_arm:
            self.joint_action(None, action)
        else:
            self.joint_action(action[:7], action[7:])

        self.action_counter += 1
        if self.action_counter == self.action_chunk:
            self.action_counter = 0

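    # Toggle callback: the first trigger arms the run (optionally
    # pre-positioning the robot from a reference episode), the second trigger
    # stops recording and resets for the next run.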
    def done_callback(self, msg):
        if not self.start:
            # Move to the first recorded pose of the selected episode when a
            # reference dataset is available.
            if self.data_list is not None:
                root = h5py.File(self.data_list[self.num], 'r')
                skip = 5
                if self.single_arm:
                    self.target_joint_right = root['observation']['joint_pos'][skip, :7]
                    self.joint_action(None, self.target_joint_right)
                else:
                    self.target_joint_left = root['observation']['joint_pos'][skip, :7]
                    self.target_joint_right = root['observation']['joint_pos'][skip, 7:]
                    self.joint_action(self.target_joint_left, self.target_joint_right)
                time.sleep(2)
            else:
                self.target_ee_left = self.obs['left_pose']
                self.target_ee_right = self.obs['right_pose']
            print('Inference & Video Recording Start')
            self.start = True
            msg = Bool()
            msg.data = True
            self.sync_pub.publish(msg)
            self.window.video_start()
        else:
            self.start = False
            msg = Bool()
            msg.data = False
            self.sync_pub.publish(msg)
            self.init_robot()
            self.action_counter = 0
            if self.window.video_recording:
                self.window.video_stop()
            self.initialize()
            print('Next Inference Ready')


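# Standalone bringup. Note that dataset_name and lang_embed_name are separate,
# index-aligned lists, so num must be a valid index into both.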
if __name__ == "__main__":
    ckpt_dir = '/home/univ/workspace/rdt-ckpts/checkpoint-38000'

    action_chunk = 64
    hz = 20

    instruction = 'handover the stuffed doll'
    unnorm_key = 'handover_kirby'
    single_arm = False
    dataset_name = [
        'vla_upright_mug',
        'vla_sweep_screws',
        'vla_pick_ball_place_bin',
        'twinvla_handover_kirby',
        'twinvla_put_bottle',
        'twinvla_detach_ball',
        'twinvla_tear_paper_towel'
    ]
    lang_embed_name = [
        'upright_mug',
        'sweep_screws',
        'pick_ball_place_bin',
        'handover_kirby'
    ]
    num = 3

    node = RDTNode(
        action_chunk=action_chunk,
        instruction=instruction,
        ckpt_dir=ckpt_dir,
        unnorm_key=unnorm_key,
        hz=hz,
        max_timestep=1000,
        dataset_name=dataset_name[num],
        lang_embed_name=lang_embed_name[num],
        single_arm=single_arm
    )

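    # Display loop: show the (concatenated) camera view with the instruction
    # overlaid; Ctrl+C closes the node and exits.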
    while True:
        try:
            if node.single_arm:
                img = cv2.cvtColor(node.obs['image'], cv2.COLOR_BGR2RGB)
            else:
                left_img = cv2.cvtColor(node.obs['leftview_image'], cv2.COLOR_BGR2RGB)
                right_img = cv2.cvtColor(node.obs['image'], cv2.COLOR_BGR2RGB)
                img = cv2.hconcat([left_img, right_img])
            if node.start:
                node.window.show(img, overlay_img=None, text=node.obs['language_instruction'])
            else:
                node.boundary_query()
                node.window.show(img, overlay_img=node.overlay_img, text=node.obs['language_instruction'], grid=node.grid)
        except KeyboardInterrupt:
            node.ros_close()
            break
        except Exception as e:
            print(f"An error occurred: {e}")