Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import imageio | |
| from sim.simulator import Simulator | |
| from sim.policy import Policy | |
| from sim.viewer import ImageViewer | |
| from typing import List, Tuple | |
# Per-step metric accumulators. InteractiveDigitalWorld.step() appends to them
# whenever the simulator result carries the matching key. Being module-level,
# they are shared by every world instance in this process and never cleared.
step_time, psnr, delta_psnr = [], [], []
class InteractiveDigitalWorld:
    """Run a policy against a simulator, optionally showing and recording frames.

    Each ``step()`` asks ``policy`` for an action given the current observation,
    advances ``simulator`` with it, and records the resulting frame. When the
    simulator result also carries ``'gt_next_frame'``, the recorded/displayed
    frame is the prediction and the ground truth concatenated along axis 1
    (presumably the width axis of HxWxC images -- confirm). Optional per-step
    metrics (``'psnr'``, ``'delta_psnr'``, ``'step_time'``) are appended to the
    module-level lists of the same names and summarised by ``close()``.
    """

    def __init__(
        self,
        simulator: "Simulator",
        policy: "Policy",
        offscreen: bool = True,  # if False, show live window
        window_size: Tuple[int, int] = (512, 512),
    ):
        """Reset ``simulator`` and record its initial observation.

        Args:
            simulator: provides ``dt``, ``reset()``, ``step(action)`` and ``close()``.
            policy: provides ``generate_action(obs)``.
            offscreen: when False, frames are also pushed to a live ``ImageViewer``.
            window_size: size of the live viewer window (unused when offscreen).
        """
        self.simulator = simulator
        self.policy = policy
        self.offscreen = offscreen
        self.video_frames: List[np.ndarray] = []
        self.dt = simulator.dt
        self.obs = self.simulator.reset()  # input to policy
        self.video_frames.append(self.obs)
        if not offscreen:
            self.viewer = ImageViewer(
                window_name=(
                    f"Simulator: {simulator.__class__.__name__} | "
                    f"Policy: {policy.__class__.__name__}"
                ),
                refresh_rate=self.dt,
                window_size=window_size,
            )
            self.viewer.update_image(self.obs)

    def step(self) -> None:
        """Advance one step: query the policy, step the simulator, record the frame."""
        action = self.policy.generate_action(self.obs)
        result = self.simulator.step(action)
        pred_frame = result['pred_next_frame']
        frame = pred_frame
        if 'gt_next_frame' in result:
            # Prediction and ground truth side by side, for viewing/recording only.
            frame = np.concatenate([pred_frame, result['gt_next_frame']], axis=1)
        if 'psnr' in result:
            psnr.append(result['psnr'])
        if 'delta_psnr' in result:
            delta_psnr.append(result['delta_psnr'])
        if 'step_time' in result:
            step_time.append(result['step_time'])
        # The policy keeps seeing only the predicted frame; feeding it the
        # side-by-side frame would change the observation's shape after the
        # first step with ground truth.
        self.obs = pred_frame
        if not self.offscreen:
            self.viewer.update_image(frame)
        self.video_frames.append(frame)

    def save_video(self, save_path: str, as_gif: bool = False) -> None:
        """Write the recorded frames to ``save_path`` at ``1 / dt`` frames per second.

        Args:
            save_path: output file path.
            as_gif: write a GIF instead of an MP4.
        """
        # NOTE(review): frames recorded after a step with ground truth are
        # twice as wide as the initial frame from simulator.reset() (unless
        # reset already returns a side-by-side frame); video writers may
        # reject mixed frame sizes -- confirm.
        fmt = 'GIF' if as_gif else 'mp4'
        imageio.mimsave(save_path, self.video_frames, format=fmt, fps=1 / self.dt)
        print(f"{'GIF' if as_gif else 'MP4'} saved to {save_path}")

    def reset(self) -> None:
        """Reset the simulator and start a fresh recording from its initial frame."""
        self.obs = self.simulator.reset()
        # Same invariant as __init__: the recording begins with the initial obs.
        self.video_frames = [self.obs]
        if not self.offscreen:
            self.viewer.update_image(self.obs)

    def close(self) -> None:
        """Close the simulator (and live viewer), then print metric summaries.

        Each summary is skipped when its metric list is empty.
        """
        self.simulator.close()
        if not self.offscreen:
            self.viewer.stop()
        # report stats (mean is taken over data between q1 and q3)
        for title, data in (
            ("Timing", step_time),
            ("PSNR", psnr),
            ("Delta PSNR", delta_psnr),
        ):
            if len(data) > 0:
                mean, median = analyze_scalar_sequence(data)
                print(
                    f"=========== {title} ===========\n"
                    f"Mean: {mean}\n"
                    f"Median: {median}\n"
                )


def analyze_scalar_sequence(data: List[float]) -> Tuple[float, float]:
    """Return ``(interquartile_mean, median)`` of a non-empty sequence.

    The mean only covers values between the 25th and 75th percentiles
    (inclusive), which damps the effect of outliers.
    """
    q1 = np.percentile(data, 25, method='nearest')
    median = np.median(data)
    q3 = np.percentile(data, 75, method='nearest')
    # 'nearest' makes q1 and q3 actual elements of data, so the window is never empty.
    mean = np.mean([t for t in data if q1 <= t <= q3])
    return mean, median