xizaoqu
commited on
Commit
·
f07d258
1
Parent(s):
cef86dc
update
Browse files
app.py
CHANGED
|
@@ -132,26 +132,6 @@ def load_image_as_tensor(image_path: str) -> torch.Tensor:
|
|
| 132 |
])
|
| 133 |
return transform(image)
|
| 134 |
|
| 135 |
-
def run_local(cfg: DictConfig):
|
| 136 |
-
# delay some imports in case they are not needed in non-local envs for submission
|
| 137 |
-
from experiments import build_experiment
|
| 138 |
-
|
| 139 |
-
# Get yaml names
|
| 140 |
-
hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
|
| 141 |
-
cfg_choice = OmegaConf.to_container(hydra_cfg.runtime.choices)
|
| 142 |
-
|
| 143 |
-
with open_dict(cfg):
|
| 144 |
-
if cfg_choice["experiment"] is not None:
|
| 145 |
-
cfg.experiment._name = cfg_choice["experiment"]
|
| 146 |
-
if cfg_choice["dataset"] is not None:
|
| 147 |
-
cfg.dataset._name = cfg_choice["dataset"]
|
| 148 |
-
if cfg_choice["algorithm"] is not None:
|
| 149 |
-
cfg.algorithm._name = cfg_choice["algorithm"]
|
| 150 |
-
|
| 151 |
-
# launch experiment
|
| 152 |
-
experiment = build_experiment(cfg, None, None)
|
| 153 |
-
return experiment.exec_interactive(cfg.experiment.tasks[0])
|
| 154 |
-
|
| 155 |
def enable_amp(model, precision="16-mixed"):
|
| 156 |
original_forward = model.forward
|
| 157 |
|
|
@@ -193,7 +173,7 @@ load_custom_checkpoint(algo=worldmem.diffusion_model, checkpoint_path=cfg.diffus
|
|
| 193 |
load_custom_checkpoint(algo=worldmem.vae, checkpoint_path=cfg.vae_path)
|
| 194 |
load_custom_checkpoint(algo=worldmem.pose_prediction_model, checkpoint_path=cfg.pose_predictor_path)
|
| 195 |
worldmem.to("cuda").eval()
|
| 196 |
-
worldmem = enable_amp(worldmem, precision="16-mixed")
|
| 197 |
|
| 198 |
actions = np.zeros((1, 25), dtype=np.float32)
|
| 199 |
poses = np.zeros((1, 5), dtype=np.float32)
|
|
@@ -555,7 +535,6 @@ with gr.Blocks(css=css) as demo:
|
|
| 555 |
)
|
| 556 |
|
| 557 |
|
| 558 |
-
# input_box.submit(update_image_and_log, inputs=[input_box], outputs=[image_display, video_display, log_output])
|
| 559 |
submit_button.click(generate, inputs=[input_box, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx], outputs=[image_display, video_display, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
|
| 560 |
reset_btn.click(reset, inputs=[selected_image], outputs=[log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
|
| 561 |
image_display_1.select(lambda: on_image_click(SUNFLOWERS_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
|
|
|
|
| 132 |
])
|
| 133 |
return transform(image)
|
| 134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
def enable_amp(model, precision="16-mixed"):
|
| 136 |
original_forward = model.forward
|
| 137 |
|
|
|
|
| 173 |
load_custom_checkpoint(algo=worldmem.vae, checkpoint_path=cfg.vae_path)
|
| 174 |
load_custom_checkpoint(algo=worldmem.pose_prediction_model, checkpoint_path=cfg.pose_predictor_path)
|
| 175 |
worldmem.to("cuda").eval()
|
| 176 |
+
# worldmem = enable_amp(worldmem, precision="16-mixed")
|
| 177 |
|
| 178 |
actions = np.zeros((1, 25), dtype=np.float32)
|
| 179 |
poses = np.zeros((1, 5), dtype=np.float32)
|
|
|
|
| 535 |
)
|
| 536 |
|
| 537 |
|
|
|
|
| 538 |
submit_button.click(generate, inputs=[input_box, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx], outputs=[image_display, video_display, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
|
| 539 |
reset_btn.click(reset, inputs=[selected_image], outputs=[log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
|
| 540 |
image_display_1.select(lambda: on_image_click(SUNFLOWERS_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
|