import cv2
import torch
import yaml
from PIL import Image

from real_robot_env import AnubisRobotEnv
from scripts.agilex_model import create_model

class RDTInferenceRobotEnv(AnubisRobotEnv):
    def __init__(self, hz=20, max_timestep=500, task_name='', num_rollout=1, model_name='1st'):
        self.model_name = model_name
        self.checkpoint = f'/home/rllab/workspace/jellyho/RoboticsDiffusionTransformer/checkpoints/{task_name}'
        super().__init__(hz=hz, max_timestep=max_timestep, task_name=task_name, num_rollout=num_rollout)
        # Previous camera frames, fed to the model alongside the current ones;
        # inference() reads these as last_*_img, so the names must match.
        self.last_front_img = None
        self.last_right_wrist_img = None
        self.last_left_wrist_img = None
        # Position within the current 64-step action chunk, and the chunk itself.
        self.action_counter = 0
        self.actions = None

    def bringup_model(self):
        # Build the RDT policy from the base config and the task checkpoint.
        with open('configs/base.yaml', "r") as fp:
            config = yaml.safe_load(fp)
        self.model = create_model(
            args=config,
            dtype=torch.bfloat16,
            pretrained=self.checkpoint,
            pretrained_text_encoder_name_or_path="google/t5-v1_1-xxl",
            pretrained_vision_encoder_name_or_path="google/siglip-so400m-patch14-384",
            control_frequency=20
        )
        print('model loaded')

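    # Receding-horizon execution: inference() queries the policy once every
    # 64 control ticks and replays the returned action chunk open-loop in
    # between. At hz=20 that is one model call roughly every 64 / 20 = 3.2 s;
    # the 64 is assumed to match the chunk length the checkpoint was trained
    # with, not something this script verifies.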
    def inference(self):
        # Query the policy only at chunk boundaries; execute the cached
        # 64-step action chunk open-loop in between.
        if self.action_counter % 64 == 0:
            proprio = self.obs['eef_pose']
            front_img = self.obs['agentview_image']
            right_wrist_img = self.obs['rightview_image']
            left_wrist_img = self.obs['leftview_image']
            # Two frames per camera: the previous frame first, then the
            # current one, ordered (front, right wrist, left wrist).
            image_arrs = [
                self.last_front_img,
                self.last_right_wrist_img,
                self.last_left_wrist_img,
                front_img,
                right_wrist_img,
                left_wrist_img
            ]
            images = [Image.fromarray(arr) if arr is not None else None for arr in image_arrs]
            proprio = torch.tensor(proprio).unsqueeze(0)
            with torch.inference_mode():
                self.actions = self.model.step(
                    proprio=proprio,
                    images=images,
                    instruction=self.instruction
                ).squeeze(0).cpu().numpy()
        # Pop the next action from the current chunk.
        idx = self.action_counter % 64
        act = self.actions[idx]
        self.action_counter += 1
        # Save the frames just observed as the one-step image history.
        self.last_front_img = self.obs['agentview_image']
        self.last_right_wrist_img = self.obs['rightview_image']
        self.last_left_wrist_img = self.obs['leftview_image']
        self.send_action(act)

if __name__ == '__main__':
    task_name = 'anubis_towel_kirby'
    rollout_num = 20
    hz = 20
    model_name = 'rdt'
    # model_name = 'oxe'
    node = RDTInferenceRobotEnv(
        hz=hz,
        max_timestep=800,
        task_name=task_name,
        num_rollout=rollout_num,
        model_name=model_name
    )
    while node.rollout_counter < rollout_num:
        try:
            img = cv2.cvtColor(node.obs['agentview_image'], cv2.COLOR_RGB2BGR)
            if node.start:
                node.window.show(img, text=node.instruction)
            else:
                # Before the rollout starts, show the setup overlay and keep
                # the previous-frame buffers warm so the first model query
                # already has a one-step image history.
                node.window.show(img, overlay_img=node.overlay_img, text=node.instruction)
                node.last_front_img = node.obs['agentview_image']
                node.last_right_wrist_img = node.obs['rightview_image']
                node.last_left_wrist_img = node.obs['leftview_image']
        except KeyboardInterrupt:
            # Break instead of closing here, so ros_close() runs exactly once
            # after the loop exits.
            break
        except Exception as e:
            print(f"An error occurred: {e}")
    node.ros_close()
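
# Rough usage sketch (the ROS bringup, camera topics, and the start trigger
# are handled by real_robot_env.AnubisRobotEnv, which is not shown here):
# with the robot stack running, launch this script directly, e.g.
#     python rdt_inference.py   # script name is illustrative
# set up the scene against the overlay, trigger the start flag however the
# env exposes it, and the policy replans every 64 steps (~3.2 s at 20 Hz).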