# Spaces: Runtime error
"""Evaluate diffusion-policy checkpoints inside the Genie world-model simulator.

For every checkpoint listed in ``success_rates`` (the value is the ``sr`` tag
in the checkpoint filename), run ``NUM_EVAL_TRIALS`` closed-loop rollouts of
at most ``MAX_STEPS`` simulator steps. Write one mp4 per trial under
``data/policy_eval_videos/policy_<sr>/``; after the first frame, each video
frame is the Genie prediction and the physics-simulator frame side by side.
Print the wall-clock time spent per checkpoint.
"""
import numpy as np
import cv2
import os
import torch
import tqdm
import time
import imageio
from sim.simulator import GenieSimulator, RobomimicSimulator
from diffusion_policy.util.pytorch_util import dict_apply
from sim.policy import DiffusionPolicy

DP_RES = 84  # side length (pixels) of the square image fed to the policy
MAX_STEPS = 100  # upper bound on simulator steps per trial
NUM_EVAL_TRIALS = 50  # rollouts per checkpoint


def _to_policy_image(frame):
    """Resize ``frame`` to ``DP_RES`` x ``DP_RES`` and move channels first.

    Assumes ``frame`` is an (H, W, C) image array (TODO confirm against
    GenieSimulator). Under that assumption the result is (C, DP_RES, DP_RES).
    """
    return cv2.resize(frame, (DP_RES, DP_RES)).transpose(2, 0, 1)


if __name__ == '__main__':
    robomimic_simulator = RobomimicSimulator(env_name='lift')
    genie_simulator = GenieSimulator(
        image_encoder_type='temporalvae',
        image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
        quantize=False,
        backbone_type="stmar",
        backbone_ckpt="data/mar_ckpt/robomimic_mixed",
        prompt_horizon=11,
        action_stride=1,
        domain='robomimic',
        physics_simulator=robomimic_simulator,
        physics_simulator_teacher_force=None,
    )
    # Release the simulator even if a checkpoint load or rollout raises.
    try:
        # Explicit check rather than `assert`, so it still runs under `python -O`.
        if genie_simulator.action_stride != 1:
            raise NotImplementedError("currently only support action stride of 1")

        # Checkpoints to evaluate, identified by the success-rate tag in their filename.
        success_rates = [1.00, 0.70, 0.52, 0.38]
        eval_time_taken = [0.0] * len(success_rates)
        for index, sr in enumerate(success_rates):
            # load the policy
            diffusion_policy = DiffusionPolicy(f'data/dp_ckpt/dp_lift_sr{sr:.2f}.ckpt')
            n_obs_steps = diffusion_policy.n_obs_steps
            video_dir = f'data/policy_eval_videos/policy_{sr:.2f}'
            os.makedirs(video_dir, exist_ok=True)

            for trial in range(NUM_EVAL_TRIALS):
                genie_image = genie_simulator.reset()
                # The reset frame is stored as the first video frame, alongside
                # [pred | gt] concatenations that must share its size. So it is
                # presumably twice as wide as a predicted frame, and the policy
                # only sees its left half. Keep that half as the current
                # observation; the full-width frame would be squashed on resize.
                obs_image = genie_image[:, :genie_image.shape[1] // 2]
                # Fill the whole observation history with copies of the first frame.
                obs_dict_buf = dict_apply(
                    {'agentview_image': _to_policy_image(obs_image)},
                    lambda x: x[np.newaxis].repeat(n_obs_steps, axis=0),
                )
                simulation_frames = [genie_image]
                done = False
                start_time = time.time()
                with tqdm.tqdm(total=MAX_STEPS) as pbar:
                    while not done and pbar.n < MAX_STEPS:
                        # Shift the history one slot older and write the newest
                        # observation into the last slot.
                        obs_dict_buf = dict_apply(
                            obs_dict_buf,
                            lambda x: np.roll(x, shift=-1, axis=0),
                        )
                        obs_dict_buf['agentview_image'][-1] = _to_policy_image(obs_image)
                        # Predict an action chunk from the history (with a
                        # leading batch dimension of 1 added by unsqueeze).
                        traj = diffusion_policy.generate_action(dict_apply(
                            obs_dict_buf,
                            lambda x: torch.from_numpy(x).to(
                                device=diffusion_policy.device, dtype=diffusion_policy.dtype
                            ).unsqueeze(0)
                        ))['action'].squeeze(0).detach().cpu().numpy()
                        # Execute the chunk. Stop as soon as the episode ends, so a
                        # later step cannot clear `done`. Also stop when the step
                        # budget is spent, so a chunk cannot overrun MAX_STEPS.
                        for action in traj:
                            result = genie_simulator.step(action[np.newaxis])
                            obs_image = result['pred_next_frame']
                            simulation_frames.append(
                                np.concatenate([obs_image, result['gt_next_frame']], axis=1)
                            )
                            pbar.update(1)
                            done = result['done']
                            if done or pbar.n >= MAX_STEPS:
                                break
                eval_time_taken[index] += time.time() - start_time
                # save the simulation frames
                video_path = f'{video_dir}/{trial:02d}.mp4'
                print(f"Saving {len(simulation_frames)} frames to {video_path}")
                imageio.mimsave(video_path, simulation_frames, fps=10)
            print(f"This checkpoint took {eval_time_taken[index]} seconds to evaluate")

        print("======= Evaluation Time Taken =======")
        for sr, t in zip(success_rates, eval_time_taken):
            print(f"SR={sr:.2f}: {t:.2f} seconds")
        # eval_time_taken holds per-checkpoint totals, so mean / trials = time per trial.
        print(f"Average time taken per eval: {np.mean(eval_time_taken) / NUM_EVAL_TRIALS:.2f} seconds")
        print("======= Simulation Done =======")
    finally:
        genie_simulator.close()