#!/usr/bin/python3
import argparse
import json
import os
import pickle
import sys
import threading
import time
from collections import deque

import cv2
import h5py
import numpy as np
import torch
import yaml
from einops import rearrange
from PIL import Image

from cv_bridge import CvBridge
from geometry_msgs.msg import Twist
from nav_msgs.msg import Odometry
from std_msgs.msg import Bool, Header, String

from node import InferenceNode
from scripts.agilex_model import create_model
class RDTNode(InferenceNode):
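    """ROS 2 inference node wrapping an RDT (Robotics Diffusion Transformer)
    policy behind the InferenceNode interface from node.py.

    Camera frames and joint positions are assumed to be buffered by
    InferenceNode; this class predicts actions in chunks and replays them
    one per control step at the configured frequency.
    """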
def __init__(self, action_chunk, instruction, ckpt_dir, unnorm_key, hz=20, max_timestep=1000, dataset_name=None, single_arm=True, lang_embed_name=''):
self.ckpt_dir = ckpt_dir
self.lang_embed_name = f'outs/{lang_embed_name}.pt'
self.run_name = f'rdt_{ckpt_dir.split("/")[-1]}' # for video name
self.single_arm = single_arm
super().__init__(hz=hz, max_timestep=max_timestep, dataset_name=dataset_name, single_arm=single_arm)
self.obs['language_instruction'] = f'{instruction}'
self.action_chunk = action_chunk
self.action_counter = 0
self.unnorm_key = unnorm_key
        self.prompt_sub = self._node.create_subscription(String, '/vla/prompt', self.prompt_callback, 1)
self.attn = None
    def prompt_callback(self, msg):
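        """Handle a String message on /vla/prompt: run the policy's
        prompt-conditioned query on the current camera frame and print the
        result (inference_prompt is assumed to be provided by the policy).
        """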
if self.policy is not None:
img = self.obs['image']
pil_image = Image.fromarray(img)
print(self.policy.inference_prompt(pil_image, msg.data))
def bringup_model(self):
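        """Instantiate the RDT policy from configs/base.yaml and load the
        pre-computed language-instruction embeddings from outs/<name>.pt.
        """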
with open('configs/base.yaml', "r") as fp:
config = yaml.safe_load(fp)
self.policy = create_model(
args=config,
dtype=torch.bfloat16,
pretrained=self.ckpt_dir,
# pretrained_text_encoder_name_or_path="google/t5-v1_1-xxl",
pretrained_vision_encoder_name_or_path="google/siglip-so400m-patch14-384",
control_frequency=20,
single_arm=self.single_arm
)
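        # The embedding file is assumed to be produced offline, e.g. by
        # encoding the instruction with the T5-XXL text encoder and saving a
        # dict containing an "embeddings" tensor; that script is not part of
        # this file.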
self.lang_embeddings = torch.load(self.lang_embed_name)["embeddings"]
def inference_fn(self):
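        """Run one forward pass of the policy and return an action chunk.

        The six image slots are assumed to follow the RDT convention of
        (previous, current) frames for an exterior view plus two wrist views;
        unused views are passed as None. Proprioception is the latest
        joint-position sample (right arm only in single-arm mode).
        """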
if self.single_arm:
image_arrs = [
self.frame_buffer[-2],
None,
None,
self.frame_buffer[-1],
None,
None
# self.left_frame_buffer[-1],
]
else:
image_arrs = [
self.frame_buffer[-2],
self.left_frame_buffer[-2],
None,
self.frame_buffer[-1],
self.left_frame_buffer[-1],
None
]
images = [Image.fromarray(arr) if arr is not None else None
for arr in image_arrs]
if self.single_arm:
proprio = torch.tensor(self.joint_pos_buffer[-1][7:]).unsqueeze(0)
else:
proprio = torch.tensor(self.joint_pos_buffer[-1]).unsqueeze(0)
actions = self.policy.step(
proprio=proprio,
images=images,
text_embeds=self.lang_embeddings
).squeeze(0).cpu().numpy()
return actions
def inference(self):
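        """Execute the current action chunk open-loop.

        A new chunk of `action_chunk` actions is predicted only after the
        previous chunk has been fully consumed; in between, the stored
        actions are replayed one per control step.
        """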
if self.action_counter == 0:
with torch.inference_mode():
                # self.actions has shape (chunk length, action dim).
                start_time = time.time()
                self.actions = self.inference_fn()
                end_time = time.time()
                print(f'inference time: {end_time - start_time:.6f} sec')
        # Execute the next action of the current chunk.
        action = self.actions[self.action_counter]
        # action[-1] = action[-1] * 4.0  # optional gripper scaling (disabled)
        if self.single_arm:
            self.joint_action(None, action)
        else:
            self.joint_action(action[:7], action[7:])
        # Alternative (disabled): delta end-effector control, accumulating the
        # predicted deltas into absolute target poses before sending them.
        # self.target_ee_left += np.array(action[:6])
        # self.target_ee_right += np.array(action[7:-1])
        # action_target_ee_left = np.concatenate([self.target_ee_left, [action[6]]])
        # action_target_ee_right = np.concatenate([self.target_ee_right, [action[-1]]])
        # self.ee_action(action_target_ee_left, action_target_ee_right)
        self.action_counter += 1
        if self.action_counter == self.action_chunk:
            self.action_counter = 0
def done_callback(self, msg):
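        """Toggle an inference episode on a 'done' signal.

        First call: move to an initial pose, publish the sync flag and start
        video recording. Second call: stop, re-initialize the robot and reset
        the action counter for the next episode.
        """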
if not self.start:
            # Reset either to a dataset initial pose (joint control) or to the
            # current end-effector pose (delta EE control) before starting.
            if self.data_list is not None:
                with h5py.File(self.data_list[self.num], 'r') as root:
                    skip = 5  # use frame index 5 of the episode as the initial pose
                    if self.single_arm:
                        self.target_joint_right = root['observation']['joint_pos'][skip, :7]
                        self.joint_action(None, self.target_joint_right)
                    else:
                        self.target_joint_left = root['observation']['joint_pos'][skip, :7]
                        self.target_joint_right = root['observation']['joint_pos'][skip, 7:]
                        self.joint_action(self.target_joint_left, self.target_joint_right)
                time.sleep(2)
else:
self.target_ee_left = self.obs['left_pose']
self.target_ee_right = self.obs['right_pose']
print('Inference & Video Recording Start')
self.start = True
msg = Bool()
msg.data = True
self.sync_pub.publish(msg)
self.window.video_start()
else:
self.start = False
msg = Bool()
msg.data = False
self.sync_pub.publish(msg)
self.init_robot()
self.action_counter = 0
if self.window.video_recording:
self.window.video_stop()
self.initialize()
print('Next Inference Ready')
if __name__ == "__main__":
ckpt_dir = '/home/univ/workspace/rdt-ckpts/checkpoint-38000'
action_chunk = 64
hz = 20
instruction = 'handover the stuffed doll'
unnorm_key = 'handover_kirby'
single_arm = False
dataset_name = [
'vla_upright_mug',
'vla_sweep_screws',
'vla_pick_ball_place_bin',
'twinvla_handover_kirby',
'twinvla_put_bottle',
'twinvla_detach_ball',
'twinvla_tear_paper_towel'
]
lang_embed_name = [
'upright_mug',
'sweep_screws',
'pick_ball_place_bin',
'handover_kirby'
]
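    # num indexes both lists above; 3 selects the bimanual handover task
    # ('twinvla_handover_kirby' / 'handover_kirby'), consistent with unnorm_key.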
num = 3
node = RDTNode(
action_chunk=action_chunk,
instruction=instruction,
ckpt_dir=ckpt_dir,
unnorm_key=unnorm_key,
hz=hz,
max_timestep=1000,
dataset_name=dataset_name[num],
lang_embed_name=lang_embed_name[num],
single_arm=single_arm
)
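    # Display loop: show the camera view(s) and the current instruction; when
    # idle, also query the overlay/boundary helpers (assumed to be provided by
    # InferenceNode and its window object).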
while True:
try:
if node.single_arm:
img = cv2.cvtColor(node.obs['image'], cv2.COLOR_BGR2RGB)
else:
left_img = cv2.cvtColor(node.obs['leftview_image'], cv2.COLOR_BGR2RGB)
right_img = cv2.cvtColor(node.obs['image'], cv2.COLOR_BGR2RGB)
img = cv2.hconcat([left_img, right_img])
if node.start:
node.window.show(img, overlay_img=None, text=node.obs['language_instruction'])
else:
# print(node.attn)
node.boundary_query()
node.window.show(img, overlay_img=node.overlay_img, text=node.obs['language_instruction'], grid=node.grid)
        except KeyboardInterrupt:
            node.ros_close()
            break
        except Exception as e:
            print(f"An error occurred: {e}")