Spaces:
Runtime error
Runtime error
fix
Browse files- sim/simulator.py +32 -32
sim/simulator.py
CHANGED
|
@@ -27,13 +27,13 @@ class Simulator:
|
|
| 27 |
@torch.inference_mode()
|
| 28 |
def step(self, action):
|
| 29 |
raise NotImplementedError
|
| 30 |
-
|
| 31 |
def reset(self):
|
| 32 |
raise NotImplementedError
|
| 33 |
-
|
| 34 |
def close(self):
|
| 35 |
raise NotImplementedError
|
| 36 |
-
|
| 37 |
@property
|
| 38 |
def dt(self):
|
| 39 |
raise NotImplementedError
|
|
@@ -46,16 +46,16 @@ class PhysicsSimulator(Simulator):
|
|
| 46 |
# physics engine should be able to update dt
|
| 47 |
def set_dt(self, dt):
|
| 48 |
raise NotImplementedError
|
| 49 |
-
|
| 50 |
# physics engine should be able to get scene state
|
| 51 |
# e.g., robot joint positions, object positions, etc.
|
| 52 |
def get_raw_state(self, port: Optional[str] = None):
|
| 53 |
raise NotImplementedError
|
| 54 |
-
|
| 55 |
@property
|
| 56 |
def action_dimension(self):
|
| 57 |
raise NotImplementedError
|
| 58 |
-
|
| 59 |
|
| 60 |
class LearnedSimulator(Simulator):
|
| 61 |
def __init__(self):
|
|
@@ -65,9 +65,9 @@ class LearnedSimulator(Simulator):
|
|
| 65 |
# data replayed respect physics, so we inherit from PhysicsSimulator
|
| 66 |
# it can be considered as a special case of PhysicsSimulator
|
| 67 |
class ReplaySimulator(PhysicsSimulator):
|
| 68 |
-
def __init__(self,
|
| 69 |
-
frames,
|
| 70 |
-
prompt_horizon: int = 0,
|
| 71 |
dt: Optional[float] = None
|
| 72 |
):
|
| 73 |
super().__init__()
|
|
@@ -76,10 +76,10 @@ class ReplaySimulator(PhysicsSimulator):
|
|
| 76 |
assert self.frame_idx < len(self.frames)
|
| 77 |
self._dt = dt
|
| 78 |
self.prompt_horizon = prompt_horizon
|
| 79 |
-
|
| 80 |
def __len__(self):
|
| 81 |
return len(self.frames) - self.prompt_horizon
|
| 82 |
-
|
| 83 |
def step(self, action):
|
| 84 |
frame = self.frames[self.frame_idx]
|
| 85 |
assert self.frame_idx < len(self.frames)
|
|
@@ -87,20 +87,20 @@ class ReplaySimulator(PhysicsSimulator):
|
|
| 87 |
return {
|
| 88 |
'pred_next_frame': frame
|
| 89 |
}
|
| 90 |
-
|
| 91 |
def reset(self): # return current frame = last frame of prompt
|
| 92 |
self.frame_idx = self.prompt_horizon
|
| 93 |
return self.prompt()[-1]
|
| 94 |
-
|
| 95 |
def prompt(self):
|
| 96 |
return self.frames[:self.prompt_horizon]
|
| 97 |
-
|
| 98 |
@property
|
| 99 |
def dt(self):
|
| 100 |
return self._dt
|
| 101 |
-
|
| 102 |
|
| 103 |
-
|
|
|
|
| 104 |
|
| 105 |
class GenieSimulator(LearnedSimulator):
|
| 106 |
|
|
@@ -164,7 +164,7 @@ class GenieSimulator(LearnedSimulator):
|
|
| 164 |
elif backbone_type == "stmar":
|
| 165 |
inference_iterations = 2
|
| 166 |
|
| 167 |
-
# misc
|
| 168 |
self.device = torch.device(device)
|
| 169 |
self.measure_step_time = measure_step_time
|
| 170 |
self.compute_psnr = compute_psnr
|
|
@@ -200,11 +200,11 @@ class GenieSimulator(LearnedSimulator):
|
|
| 200 |
else:
|
| 201 |
self.backbone = STMAR.from_pretrained(backbone_ckpt)
|
| 202 |
self.backbone = self.backbone.to(device=self.device).eval()
|
| 203 |
-
|
| 204 |
self.post_processor = post_processor
|
| 205 |
-
|
| 206 |
# load physics simulator if available
|
| 207 |
-
# the phys sim to get ground truth image,
|
| 208 |
# assume the phys sim has aligned prompt frames
|
| 209 |
self.gt_phys_sim = physics_simulator
|
| 210 |
self.gt_teacher_force = physics_simulator_teacher_force
|
|
@@ -237,7 +237,7 @@ class GenieSimulator(LearnedSimulator):
|
|
| 237 |
# return: (H, W, 3)
|
| 238 |
assert self.cached_latent_frames is not None and self.cached_actions is not None, \
|
| 239 |
"Model is not prompted yet. Please call `set_initial_state` first."
|
| 240 |
-
|
| 241 |
if action.ndim == 1:
|
| 242 |
action = np.tile(action, (self.action_stride, 1))
|
| 243 |
|
|
@@ -273,7 +273,7 @@ class GenieSimulator(LearnedSimulator):
|
|
| 273 |
start_time = time.time()
|
| 274 |
pred_next_latent_state = self.backbone.maskgit_generate(
|
| 275 |
input_latent_states,
|
| 276 |
-
out_t=
|
| 277 |
maskgit_steps=self.inference_iterations,
|
| 278 |
temperature=self.sampling_temperature,
|
| 279 |
action_ids=input_actions,
|
|
@@ -310,7 +310,7 @@ class GenieSimulator(LearnedSimulator):
|
|
| 310 |
# compute PSNR against ground truth
|
| 311 |
if self.compute_psnr:
|
| 312 |
psnr = skimage.metrics.peak_signal_noise_ratio(
|
| 313 |
-
image_true=gt_next_frame / 255.,
|
| 314 |
image_test=pred_next_frame / 255.,
|
| 315 |
data_range=1.0
|
| 316 |
)
|
|
@@ -348,7 +348,7 @@ class GenieSimulator(LearnedSimulator):
|
|
| 348 |
|
| 349 |
if self.gt_teacher_force is not None and self.step_count % self.gt_teacher_force == 0:
|
| 350 |
pred_next_latent_state = self._encode_image(gt_next_frame)
|
| 351 |
-
|
| 352 |
# update history buffer
|
| 353 |
self.cached_latent_frames = torch.cat([
|
| 354 |
self.cached_latent_frames[1:], pred_next_latent_state.unsqueeze(0)
|
|
@@ -356,7 +356,7 @@ class GenieSimulator(LearnedSimulator):
|
|
| 356 |
self.cached_actions = torch.cat([
|
| 357 |
self.cached_actions[1:], action.unsqueeze(0)
|
| 358 |
])
|
| 359 |
-
|
| 360 |
# post processing
|
| 361 |
if self.post_processor is not None:
|
| 362 |
pred_next_frame = self.post_processor(pred_next_frame, action)
|
|
@@ -364,7 +364,7 @@ class GenieSimulator(LearnedSimulator):
|
|
| 364 |
self.step_count += 1
|
| 365 |
|
| 366 |
return step_result
|
| 367 |
-
|
| 368 |
|
| 369 |
@torch.inference_mode()
|
| 370 |
def _encode_image(self, image: np.ndarray) -> torch.Tensor:
|
|
@@ -422,11 +422,11 @@ class GenieSimulator(LearnedSimulator):
|
|
| 422 |
decoded_image = decoded_image.squeeze(0).to(torch.float32).detach().cpu().numpy()
|
| 423 |
decoded_image = self._unnormalize_image(decoded_image).transpose(1, 2, 0)
|
| 424 |
return decoded_image
|
| 425 |
-
|
| 426 |
|
| 427 |
def _normalize_image(self, image: np.ndarray) -> np.ndarray:
|
| 428 |
# (H, W, 3) normalized to [-1, 1]
|
| 429 |
-
# if `resize`, resize the shorter side to `resized_res`
|
| 430 |
# and then do a center crop
|
| 431 |
|
| 432 |
image = np.asarray(image, dtype=np.float32)
|
|
@@ -435,7 +435,7 @@ class GenieSimulator(LearnedSimulator):
|
|
| 435 |
|
| 436 |
# resize if asked
|
| 437 |
if self.resize_image:
|
| 438 |
-
resized_res = self.resize_image_resolution
|
| 439 |
if H < W:
|
| 440 |
Hnew, Wnew = resized_res, int(resized_res * W / H)
|
| 441 |
else:
|
|
@@ -469,7 +469,7 @@ class GenieSimulator(LearnedSimulator):
|
|
| 469 |
|
| 470 |
|
| 471 |
def reset(self) -> np.ndarray:
|
| 472 |
-
# if ground truth physics simulator is provided,
|
| 473 |
# return the the side-by-side concatenated image
|
| 474 |
|
| 475 |
# get the initial prompt from the physics simulator if not yet set
|
|
@@ -480,7 +480,7 @@ class GenieSimulator(LearnedSimulator):
|
|
| 480 |
action_prompt = np.zeros(
|
| 481 |
(self.prompt_horizon, self.action_stride, self.gt_phys_sim.action_dimension)
|
| 482 |
).astype(np.float32)
|
| 483 |
-
else:
|
| 484 |
assert self.init_prompt is not None, "Initial state is not set."
|
| 485 |
image_prompt, action_prompt = self.init_prompt
|
| 486 |
|
|
@@ -498,7 +498,7 @@ class GenieSimulator(LearnedSimulator):
|
|
| 498 |
], axis=0)
|
| 499 |
|
| 500 |
if self.resize_image:
|
| 501 |
-
current_image = cv2.resize(current_image,
|
| 502 |
(self.resize_image_resolution, self.resize_image_resolution))
|
| 503 |
|
| 504 |
if self.gt_phys_sim is not None:
|
|
|
|
| 27 |
@torch.inference_mode()
|
| 28 |
def step(self, action):
|
| 29 |
raise NotImplementedError
|
| 30 |
+
|
| 31 |
def reset(self):
|
| 32 |
raise NotImplementedError
|
| 33 |
+
|
| 34 |
def close(self):
|
| 35 |
raise NotImplementedError
|
| 36 |
+
|
| 37 |
@property
|
| 38 |
def dt(self):
|
| 39 |
raise NotImplementedError
|
|
|
|
| 46 |
# physics engine should be able to update dt
|
| 47 |
def set_dt(self, dt):
|
| 48 |
raise NotImplementedError
|
| 49 |
+
|
| 50 |
# physics engine should be able to get scene state
|
| 51 |
# e.g., robot joint positions, object positions, etc.
|
| 52 |
def get_raw_state(self, port: Optional[str] = None):
|
| 53 |
raise NotImplementedError
|
| 54 |
+
|
| 55 |
@property
|
| 56 |
def action_dimension(self):
|
| 57 |
raise NotImplementedError
|
| 58 |
+
|
| 59 |
|
| 60 |
class LearnedSimulator(Simulator):
|
| 61 |
def __init__(self):
|
|
|
|
| 65 |
# data replayed respect physics, so we inherit from PhysicsSimulator
|
| 66 |
# it can be considered as a special case of PhysicsSimulator
|
| 67 |
class ReplaySimulator(PhysicsSimulator):
|
| 68 |
+
def __init__(self,
|
| 69 |
+
frames,
|
| 70 |
+
prompt_horizon: int = 0,
|
| 71 |
dt: Optional[float] = None
|
| 72 |
):
|
| 73 |
super().__init__()
|
|
|
|
| 76 |
assert self.frame_idx < len(self.frames)
|
| 77 |
self._dt = dt
|
| 78 |
self.prompt_horizon = prompt_horizon
|
| 79 |
+
|
| 80 |
def __len__(self):
|
| 81 |
return len(self.frames) - self.prompt_horizon
|
| 82 |
+
|
| 83 |
def step(self, action):
|
| 84 |
frame = self.frames[self.frame_idx]
|
| 85 |
assert self.frame_idx < len(self.frames)
|
|
|
|
| 87 |
return {
|
| 88 |
'pred_next_frame': frame
|
| 89 |
}
|
| 90 |
+
|
| 91 |
def reset(self): # return current frame = last frame of prompt
|
| 92 |
self.frame_idx = self.prompt_horizon
|
| 93 |
return self.prompt()[-1]
|
| 94 |
+
|
| 95 |
def prompt(self):
|
| 96 |
return self.frames[:self.prompt_horizon]
|
| 97 |
+
|
| 98 |
@property
|
| 99 |
def dt(self):
|
| 100 |
return self._dt
|
|
|
|
| 101 |
|
| 102 |
+
|
| 103 |
+
|
| 104 |
|
| 105 |
class GenieSimulator(LearnedSimulator):
|
| 106 |
|
|
|
|
| 164 |
elif backbone_type == "stmar":
|
| 165 |
inference_iterations = 2
|
| 166 |
|
| 167 |
+
# misc
|
| 168 |
self.device = torch.device(device)
|
| 169 |
self.measure_step_time = measure_step_time
|
| 170 |
self.compute_psnr = compute_psnr
|
|
|
|
| 200 |
else:
|
| 201 |
self.backbone = STMAR.from_pretrained(backbone_ckpt)
|
| 202 |
self.backbone = self.backbone.to(device=self.device).eval()
|
| 203 |
+
|
| 204 |
self.post_processor = post_processor
|
| 205 |
+
|
| 206 |
# load physics simulator if available
|
| 207 |
+
# the phys sim to get ground truth image,
|
| 208 |
# assume the phys sim has aligned prompt frames
|
| 209 |
self.gt_phys_sim = physics_simulator
|
| 210 |
self.gt_teacher_force = physics_simulator_teacher_force
|
|
|
|
| 237 |
# return: (H, W, 3)
|
| 238 |
assert self.cached_latent_frames is not None and self.cached_actions is not None, \
|
| 239 |
"Model is not prompted yet. Please call `set_initial_state` first."
|
| 240 |
+
|
| 241 |
if action.ndim == 1:
|
| 242 |
action = np.tile(action, (self.action_stride, 1))
|
| 243 |
|
|
|
|
| 273 |
start_time = time.time()
|
| 274 |
pred_next_latent_state = self.backbone.maskgit_generate(
|
| 275 |
input_latent_states,
|
| 276 |
+
out_t=input_latent_states.shape[1] - 1,,
|
| 277 |
maskgit_steps=self.inference_iterations,
|
| 278 |
temperature=self.sampling_temperature,
|
| 279 |
action_ids=input_actions,
|
|
|
|
| 310 |
# compute PSNR against ground truth
|
| 311 |
if self.compute_psnr:
|
| 312 |
psnr = skimage.metrics.peak_signal_noise_ratio(
|
| 313 |
+
image_true=gt_next_frame / 255.,
|
| 314 |
image_test=pred_next_frame / 255.,
|
| 315 |
data_range=1.0
|
| 316 |
)
|
|
|
|
| 348 |
|
| 349 |
if self.gt_teacher_force is not None and self.step_count % self.gt_teacher_force == 0:
|
| 350 |
pred_next_latent_state = self._encode_image(gt_next_frame)
|
| 351 |
+
|
| 352 |
# update history buffer
|
| 353 |
self.cached_latent_frames = torch.cat([
|
| 354 |
self.cached_latent_frames[1:], pred_next_latent_state.unsqueeze(0)
|
|
|
|
| 356 |
self.cached_actions = torch.cat([
|
| 357 |
self.cached_actions[1:], action.unsqueeze(0)
|
| 358 |
])
|
| 359 |
+
|
| 360 |
# post processing
|
| 361 |
if self.post_processor is not None:
|
| 362 |
pred_next_frame = self.post_processor(pred_next_frame, action)
|
|
|
|
| 364 |
self.step_count += 1
|
| 365 |
|
| 366 |
return step_result
|
| 367 |
+
|
| 368 |
|
| 369 |
@torch.inference_mode()
|
| 370 |
def _encode_image(self, image: np.ndarray) -> torch.Tensor:
|
|
|
|
| 422 |
decoded_image = decoded_image.squeeze(0).to(torch.float32).detach().cpu().numpy()
|
| 423 |
decoded_image = self._unnormalize_image(decoded_image).transpose(1, 2, 0)
|
| 424 |
return decoded_image
|
| 425 |
+
|
| 426 |
|
| 427 |
def _normalize_image(self, image: np.ndarray) -> np.ndarray:
|
| 428 |
# (H, W, 3) normalized to [-1, 1]
|
| 429 |
+
# if `resize`, resize the shorter side to `resized_res`
|
| 430 |
# and then do a center crop
|
| 431 |
|
| 432 |
image = np.asarray(image, dtype=np.float32)
|
|
|
|
| 435 |
|
| 436 |
# resize if asked
|
| 437 |
if self.resize_image:
|
| 438 |
+
resized_res = self.resize_image_resolution
|
| 439 |
if H < W:
|
| 440 |
Hnew, Wnew = resized_res, int(resized_res * W / H)
|
| 441 |
else:
|
|
|
|
| 469 |
|
| 470 |
|
| 471 |
def reset(self) -> np.ndarray:
|
| 472 |
+
# if ground truth physics simulator is provided,
|
| 473 |
# return the the side-by-side concatenated image
|
| 474 |
|
| 475 |
# get the initial prompt from the physics simulator if not yet set
|
|
|
|
| 480 |
action_prompt = np.zeros(
|
| 481 |
(self.prompt_horizon, self.action_stride, self.gt_phys_sim.action_dimension)
|
| 482 |
).astype(np.float32)
|
| 483 |
+
else:
|
| 484 |
assert self.init_prompt is not None, "Initial state is not set."
|
| 485 |
image_prompt, action_prompt = self.init_prompt
|
| 486 |
|
|
|
|
| 498 |
], axis=0)
|
| 499 |
|
| 500 |
if self.resize_image:
|
| 501 |
+
current_image = cv2.resize(current_image,
|
| 502 |
(self.resize_image_resolution, self.resize_image_resolution))
|
| 503 |
|
| 504 |
if self.gt_phys_sim is not None:
|