|
import os |
|
import sys |
|
from dataclasses import dataclass |
|
from pathlib import Path |
|
from typing import Optional, Union |
|
import draccus |
|
import numpy as np |
|
import tqdm |
|
|
|
from PIL import Image |
|
import torch |
|
|
|
import tabletop |
|
from dm_env import StepType as st |
|
import imageio |
|
import time |
|
import yaml |
|
from scripts.agilex_model import create_model |
|
import random |
|
from PIL import Image, ImageDraw, ImageFont |
|
|
|
def set_seed_everywhere(seed: int):
    """Seed every random number generator this script relies on.

    Seeds PyTorch (CPU and all CUDA devices), NumPy's global generator and
    Python's ``random`` module with the same value, puts cuDNN into
    deterministic mode with autotuning disabled, and records the seed in the
    ``PYTHONHASHSEED`` environment variable.

    Args:
        seed: Integer seed shared by all generators.
    """
    # Seed PyTorch (CPU, then CUDA), NumPy and the stdlib generator in turn.
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN speed for reproducible kernel selection.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # NOTE: Python reads PYTHONHASHSEED at interpreter start-up, so this only
    # influences processes launched after this call.
    os.environ["PYTHONHASHSEED"] = str(seed)
|
|
|
def save_rollout_video(rollout_images, idx, success, task_description, log_file=None, folder=None, subtitile=None):
    """Saves an MP4 replay of an episode.

    The video is written at 30 fps to ``./rollouts/<DATE>/`` (or to
    ``./rollouts/<DATE>/<folder>/videos/`` when ``folder`` is given); the
    directory is created if needed. The file name embeds ``DATE_TIME``, the
    episode index, the success flag and a sanitized task description.

    Args:
        rollout_images: Iterable of frames appended to the video in order.
            When a subtitle is drawn each frame goes through
            ``Image.fromarray`` first.
        idx: Episode index, used in the file name.
        success: Success flag, used in the file name.
        task_description: Free-form task text; lower-cased, with spaces,
            newlines and dots replaced by underscores and truncated to 50
            characters for the file name.
        log_file: Optional object with a ``write`` method that receives the
            "saved" message as well.
        folder: Optional sub-folder (under the date folder) for this run.
        subtitile: Optional caption drawn in white, horizontally centred near
            the bottom edge of every frame. (The parameter name is kept as-is
            for keyword callers.)

    Returns:
        The path of the written MP4 file.
    """
    if folder is None:
        rollout_dir = f"./rollouts/{DATE}"
    else:
        rollout_dir = f"./rollouts/{DATE}/{folder}/videos"
    os.makedirs(rollout_dir, exist_ok=True)

    processed_task_description = task_description.lower().replace(" ", "_").replace("\n", "_").replace(".", "_")[:50]
    mp4_path = f"{rollout_dir}/{DATE_TIME}--episode={idx}--success={success}--task={processed_task_description}.mp4"

    # Load the caption font once instead of once per frame, and fall back to
    # Pillow's built-in font when the DejaVu file is not installed so a missing
    # system font cannot abort the run.
    font = None
    if subtitile:
        font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
        try:
            font = ImageFont.truetype(font_path, size=24)
        except OSError:
            font = ImageFont.load_default()

    video_writer = imageio.get_writer(mp4_path, fps=30)
    try:
        for img in rollout_images:
            if subtitile:
                pil_img = Image.fromarray(img)
                draw = ImageDraw.Draw(pil_img)

                # Centre the caption horizontally, 10 px above the bottom edge.
                bbox = draw.textbbox((0, 0), subtitile, font=font)
                text_width = bbox[2] - bbox[0]
                text_height = bbox[3] - bbox[1]
                text_x = (pil_img.width - text_width) // 2
                text_y = pil_img.height - text_height - 10

                draw.text((text_x, text_y), subtitile, font=font, fill="white")
                img = np.array(pil_img)

            video_writer.append_data(img)
    finally:
        # Always release the writer, even if a frame could not be encoded.
        video_writer.close()

    print(f"Saved rollout MP4 at path {mp4_path}")
    if log_file is not None:
        log_file.write(f"Saved rollout MP4 at path {mp4_path}\n")
    return mp4_path
|
|
|
# Take a single clock reading so DATE and DATE_TIME always describe the same
# instant; two independent strftime() calls could straddle midnight and send
# videos and summaries to different day folders.
_RUN_START = time.localtime()
DATE = time.strftime("%Y_%m_%d", _RUN_START)
DATE_TIME = time.strftime("%Y_%m_%d-%H_%M_%S", _RUN_START)

# NOTE(review): these are assigned after the file's imports have already run.
# If any imported package reads MUJOCO_GL (or TOKENIZERS_PARALLELISM) at import
# time, the assignment has to move above that import -- confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["MUJOCO_GL"] = "egl"
|
|
|
@dataclass
class Config:
    """Options for the tabletop evaluation entry point ``eval_tabletop``."""

    # Checkpoint handed to create_model as `pretrained`; its last path
    # component also names the output folder.
    checkpoint: Union[str, Path] = ""
    # Not read by the code visible in this file.
    twin_or_sing: str = "TwinVLA"
    # Task passed to tabletop.env(); also used in output file and folder names.
    task_name: str = "aloha_dish_drainer"
    # Action space passed to tabletop.env().
    action_space: str = "ee_6d_pos"
    # Not read by the code visible in this file.
    num_steps_wait: int = 10
    # Number of evaluation episodes to run.
    num_trials_per_task: int = 5
    # Number of actions executed from each model prediction before the model
    # is queried again.
    action_len: int = 20
    # When true, env.task.benchmark_init(env.physics, rollout_id) is called
    # after each reset.
    benchmark: bool = True

    # Not read by the code visible in this file.
    run_id_note: Optional[str] = None
    # Base seed; episode i additionally seeds NumPy with seed + i.
    seed: int = 48
|
|
|
@draccus.wrap()
def eval_tabletop(cfg: Config) -> None:
    """Roll out a checkpoint on one tabletop task and record the results.

    For each of ``cfg.num_trials_per_task`` episodes the environment is reset
    (and, when ``cfg.benchmark`` is set, re-initialised through
    ``env.task.benchmark_init``). The model is queried for a chunk of actions
    every ``cfg.action_len`` environment steps, and the episode ends once the
    step reward equals ``env.task.max_reward`` or the environment reports its
    last step. The ``back`` camera frames of every episode are saved as a
    video, and a success-rate / return summary is written to
    ``rollouts/<DATE>/<checkpoint name>-<task>-<seed>/summary.txt``.

    Args:
        cfg: Evaluation settings.

    Raises:
        ValueError: If ``cfg.num_trials_per_task`` is not positive (there
            would be nothing to summarise).
    """
    if cfg.num_trials_per_task <= 0:
        raise ValueError(
            f"num_trials_per_task must be positive, got {cfg.num_trials_per_task}"
        )

    set_seed_everywhere(cfg.seed)
    with open('configs/base.yaml', "r") as fp:
        config = yaml.safe_load(fp)

    model = create_model(
        args=config,
        dtype=torch.bfloat16,
        pretrained=cfg.checkpoint,
        pretrained_text_encoder_name_or_path="google/t5-v1_1-xxl",
        pretrained_vision_encoder_name_or_path="google/siglip-so400m-patch14-384",
        control_frequency=20
    )

    # Config allows a Path checkpoint, which has no .split(); go through str().
    # The same name is used for the video folder and the summary folder.
    run_name = f"{str(cfg.checkpoint).split('/')[-1]}-{cfg.task_name}-{cfg.seed}"

    env = tabletop.env(cfg.task_name, cfg.action_space)
    highest_rewards = []
    episode_returns = []
    for rollout_id in tqdm.tqdm(range(cfg.num_trials_per_task)):
        # Re-seed NumPy per episode so each rollout is reproducible on its own.
        np.random.seed(cfg.seed + rollout_id)
        ts = env.reset()
        if cfg.benchmark:
            ts = env.task.benchmark_init(env.physics, rollout_id)
        action_counter = 0
        replay_images = []
        rewards = []
        last_front_img = ts.observation['images']['back']
        last_right_wrist_img = ts.observation['images']['wrist_right']
        last_left_wrist_img = ts.observation['images']['wrist_left']
        with torch.inference_mode():
            while True:
                obs = ts.observation
                replay_images.append(obs['images']['back'])
                if action_counter == 0:
                    # Start of a new chunk: query the model with the stored
                    # "last" frames followed by the current frames.
                    front_img = obs['images']['back']
                    right_wrist_img = obs['images']['wrist_right']
                    left_wrist_img = obs['images']['wrist_left']
                    image_arrs = [
                        last_front_img,
                        last_right_wrist_img,
                        last_left_wrist_img,
                        front_img,
                        right_wrist_img,
                        left_wrist_img
                    ]
                    images = [Image.fromarray(arr) if arr is not None else None for arr in image_arrs]
                    proprio = torch.tensor(obs['ee_6d_pos']).unsqueeze(0)
                    actions = model.step(
                        proprio=proprio,
                        images=images,
                        instruction=obs['language_instruction']
                    ).squeeze(0).cpu().numpy()

                action = actions[action_counter]
                ts = env.step(action)
                rewards.append(ts.reward)
                action_counter += 1
                if action_counter == cfg.action_len:
                    # Chunk exhausted; the next iteration queries the model again.
                    action_counter = 0
                if ts.reward == env.task.max_reward or ts.step_type == st.LAST:
                    break
                # NOTE(review): these are refreshed after every step, so when
                # the model is next queried they hold the same observation as
                # the "current" frames. If the model expects the previous
                # timestep here, capture them before env.step -- confirm.
                last_front_img = ts.observation['images']['back']
                last_right_wrist_img = ts.observation['images']['wrist_right']
                last_left_wrist_img = ts.observation['images']['wrist_left']

        # Drop None rewards once and use the filtered list for BOTH the sum and
        # the max; np.max over a list containing None would raise.
        valid_rewards = [r for r in rewards if r is not None]
        episode_return = np.sum(valid_rewards)
        episode_returns.append(episode_return)
        # Falls back to 0 when no step produced a reward (assumes rewards are
        # non-negative, as the 0..max_reward histogram below implies).
        episode_highest_reward = np.max(valid_rewards) if valid_rewards else 0
        highest_rewards.append(episode_highest_reward)
        env_max_reward = env.task.max_reward

        save_rollout_video(
            replay_images,
            rollout_id,
            success=episode_highest_reward == env_max_reward,
            task_description=cfg.task_name,
            folder=run_name,
            subtitile=obs['language_instruction'],
        )
        replay_images.clear()

    highest_rewards_arr = np.array(highest_rewards)
    success_rate = np.mean(highest_rewards_arr == env_max_reward)
    avg_return = np.mean(episode_returns)
    summary_str = f'\nSuccess rate: {success_rate}\nAverage return: {avg_return}\n\n'
    # Cumulative histogram: how many episodes reached at least reward r.
    for r in range(env_max_reward + 1):
        more_or_equal_r = (highest_rewards_arr >= r).sum()
        more_or_equal_r_rate = more_or_equal_r / cfg.num_trials_per_task
        summary_str += f'Reward >= {r}: {more_or_equal_r}/{cfg.num_trials_per_task} = {more_or_equal_r_rate*100}%\n'

    log_dir = Path('rollouts') / DATE / run_name
    log_dir.mkdir(parents=True, exist_ok=True)
    summary_file = log_dir / "summary.txt"
    with summary_file.open("w") as f:
        f.write(summary_str)
|
|
|
# Script entry point. eval_tabletop is decorated with draccus.wrap(), so it is
# called without arguments here; the Config is presumably built by draccus from
# the command line -- confirm against the draccus documentation.
if __name__ == "__main__":
    eval_tabletop()