xizaoqu committed
Commit c7542a3 · 1 Parent(s): a7ea928

update

Browse files:
- algorithms/worldmem/df_video.py +2 -5
- app.py +142 -59
algorithms/worldmem/df_video.py CHANGED

@@ -615,8 +615,6 @@ class WorldMemMinecraft(DiffusionForcingBase):
         for _ in range(condition_similar_length):
             overlap_ratio = ((in_fov1.bool() & in_fov_list).sum(1)) / in_fov1.sum()
 
-            # if curr_frame == 54:
-            # import pdb;pdb.set_trace()
             confidence = overlap_ratio + (curr_frame - frame_idx[:curr_frame]) / curr_frame * (-0.2)
 
             if len(random_idx) > 0:
@@ -624,10 +622,11 @@ class WorldMemMinecraft(DiffusionForcingBase):
             _, r_idx = torch.topk(confidence, k=1, dim=0)
             random_idx.append(r_idx[0])
 
+            # choice 1: directly remove overlapping region
             occupied_mask = in_fov_list[r_idx[0, range(in_fov1.shape[-1])], :, range(in_fov1.shape[-1])].permute(1,0)
-
             in_fov1 = in_fov1 & ~occupied_mask
 
+            # choice 2: apply similarity filter
             # cos_sim = F.cosine_similarity(xs_pred.to(r_idx.device)[r_idx[:, range(in_fov1.shape[1])],
             #                               range(in_fov1.shape[1])], xs_pred.to(r_idx.device)[:curr_frame], dim=2)
             # cos_sim = cos_sim.mean((-2,-1))
@@ -637,8 +636,6 @@ class WorldMemMinecraft(DiffusionForcingBase):
 
         random_idx = torch.stack(random_idx).cpu()
 
-        print(random_idx)
-
         return random_idx
 
     def _prepare_conditions(self,
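Note: the hunk above edits WorldMemMinecraft's memory-retrieval loop, which scores every past frame by its field-of-view overlap with the current view, applies a linear recency penalty, and greedily masks out the region each selected frame already covers ("choice 1") so later picks favour complementary frames. Below is a minimal single-view sketch of that logic with toy shapes; the function name and shapes are illustrative, not the repo's API (the real tensors carry an extra per-view dimension).

import torch

def select_memory_frames(in_fov1, in_fov_list, frame_idx, curr_frame, k):
    # in_fov1: (P,) bool, visible points of the current view
    # in_fov_list: (T, P) bool, visibility of each point in each past frame
    chosen = []
    for _ in range(k):
        # fraction of the still-uncovered view each past frame sees
        overlap_ratio = (in_fov1 & in_fov_list).sum(1) / in_fov1.sum().clamp(min=1)
        # linear recency penalty: older frames score lower
        confidence = overlap_ratio - 0.2 * (curr_frame - frame_idx[:curr_frame]) / curr_frame
        _, r_idx = torch.topk(confidence, k=1, dim=0)
        chosen.append(r_idx[0])
        # "choice 1": remove the region the picked frame already covers
        in_fov1 = in_fov1 & ~in_fov_list[r_idx[0]]
    return torch.stack(chosen)

# toy usage: 10 past frames, 64 sampled view points
idx = select_memory_frames(torch.ones(64, dtype=torch.bool),
                           torch.rand(10, 64) > 0.5,
                           frame_idx=torch.arange(10), curr_frame=10, k=3)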
app.py CHANGED

@@ -70,6 +70,13 @@ KEY_TO_ACTION = {
     "1": ("hotbar.1", 1),
 }
 
+example_images = [
+    ["1", "assets/ice_plains.png", "turn right+go backward+look up+turn left+look down+turn right+go forward+turn left", 20, 3, 8],
+    ["2", "assets/place.png", "put item+go backward+put item+go backward+go around", 20, 3, 8],
+    ["3", "assets/rain_sunflower_plains.png", "turn right+look up+turn right+look down+turn left+go backward+turn left", 20, 3, 8],
+    ["4", "assets/desert.png", "turn 360 degree+turn right+go forward+turn left", 20, 3, 8],
+]
+
 def load_custom_checkpoint(algo, checkpoint_path):
     hf_ckpt = str(checkpoint_path).split('/')
     repo_id = '/'.join(hf_ckpt[:2])
@@ -156,7 +163,6 @@ def enable_amp(model, precision="16-mixed"):
     return model
 
 memory_frames = []
-memory_curr_frame = 0
 input_history = ""
 ICE_PLAINS_IMAGE = "assets/ice_plains.png"
 DESERT_IMAGE = "assets/desert.png"
@@ -166,7 +172,6 @@ PLACE_IMAGE = "assets/place.png"
 SUNFLOWERS_IMAGE = "assets/sunflower_plains.png"
 SUNFLOWERS_RAIN_IMAGE = "assets/rain_sunflower_plains.png"
 
-DEFAULT_IMAGE = ICE_PLAINS_IMAGE
 device = torch.device('cuda')
 
 def save_video(frames, path="output.mp4", fps=10):
@@ -193,13 +198,6 @@ worldmem = enable_amp(worldmem, precision="16-mixed")
 actions = np.zeros((1, 25), dtype=np.float32)
 poses = np.zeros((1, 5), dtype=np.float32)
 
-memory_frames = load_image_as_tensor(DEFAULT_IMAGE)[None].numpy()
-
-self_frames = None
-self_actions = None
-self_poses = None
-self_memory_c2w = None
-self_frame_idx = None
 
 
 def get_duration_single_image_to_long_video(first_frame, action, first_pose, device, self_frames, self_actions,
@@ -240,17 +238,8 @@ def set_memory_length(memory_length, sampling_memory_length_state):
     print("set memory length to", worldmem.condition_similar_length)
     return sampling_memory_length_state
 
-def generate(keys):
-    # print("algo frame:", len(worldmem.frames))
+def generate(keys, input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx):
     input_actions = parse_input_to_tensor(keys)
-    global input_history
-    global memory_frames
-    global memory_curr_frame
-    global self_frames
-    global self_actions
-    global self_poses
-    global self_memory_c2w
-    global self_frame_idx
 
     if self_frames is None:
         new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = run_interactive(memory_frames[0],
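Note: the signature change above is the heart of the commit: generate() stops reaching for global and instead takes every piece of per-session state as an argument and hands the updated values back in its return tuple. A hypothetical micro-example of that convention (names are illustrative):

def step(keys, input_history, frames):
    # pure update: state in, updated state out, no module-level globals
    input_history = input_history + keys
    frames = frames + [keys]
    return frames[-1], input_history, frames

latest, input_history, frames = step("WW", "", [])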
@@ -282,25 +271,34 @@ def generate(keys):
     temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
     save_video(out_video, temporal_video_path)
 
+    now = datetime.now()
+    folder_name = now.strftime("%Y-%m-%d_%H-%M-%S")
+    folder_path = os.path.join("/mnt/xiaozeqi/worldmem/output_material", folder_name)
+    os.makedirs(folder_path, exist_ok=True)
     input_history += keys
-    return out_video[-1], temporal_video_path, input_history
-
-def reset():
-    global memory_curr_frame
-    global input_history
-    global memory_frames
-    global self_frames
-    global self_actions
-    global self_poses
-    global self_memory_c2w
-    global self_frame_idx
 
+    data_dict = {
+        "input_history": input_history,
+        "memory_frames": memory_frames,
+        "self_frames": self_frames,
+        "self_actions": self_actions,
+        "self_poses": self_poses,
+        "self_memory_c2w": self_memory_c2w,
+        "self_frame_idx": self_frame_idx,
+    }
+
+    np.savez(os.path.join(folder_path, "data_bundle.npz"), **data_dict)
+
+    return out_video[-1], temporal_video_path, input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
+
+def reset(selected_image):
     self_frames = None
     self_poses = None
     self_actions = None
     self_memory_c2w = None
     self_frame_idx = None
-    memory_frames = load_image_as_tensor(DEFAULT_IMAGE)[None].numpy()
+    memory_frames = load_image_as_tensor(selected_image).numpy()[None]
     input_history = ""
 
     new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = run_interactive(memory_frames[0],
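Note: generate() now checkpoints the whole session into a single .npz bundle, which set_memory() later reloads. A small round-trip sketch with hypothetical values; np.savez stores the string as a 0-d array, which is why the loader calls .item() on it:

import numpy as np

bundle = {
    "input_history": "WWAD",  # str is saved as a 0-d unicode array
    "memory_frames": np.zeros((2, 3, 64, 64), dtype=np.float32),
}
np.savez("data_bundle.npz", **bundle)

data = np.load("data_bundle.npz")
assert data["input_history"].item() == "WWAD"  # .item() recovers the str
assert data["memory_frames"].shape == (2, 3, 64, 64)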
@@ -313,14 +311,58 @@ def reset():
                                          self_memory_c2w=self_memory_c2w,
                                          self_frame_idx=self_frame_idx)
 
-    return input_history,
-
-
-
-    reset()
-    return SELECTED_IMAGE
+    return input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
+
+def on_image_click(selected_image):
+    input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = reset(selected_image)
+    return input_history, selected_image, selected_image, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
+
+def set_memory(examples_case, image_display, log_output, slider_denoising_step, slider_context_length, slider_memory_length):
+    if examples_case == '1':
+        data_bundle = np.load("/mnt/xiaozeqi/worldmem/output_material/2025-04-11_16-01-49/data_bundle.npz")
+        input_history = data_bundle['input_history'].item()
+        memory_frames = data_bundle['memory_frames']
+        self_frames = data_bundle['self_frames']
+        self_actions = data_bundle['self_actions']
+        self_poses = data_bundle['self_poses']
+        self_memory_c2w = data_bundle['self_memory_c2w']
+        self_frame_idx = data_bundle['self_frame_idx']
+    elif examples_case == '2':
+        data_bundle = np.load("/mnt/xiaozeqi/worldmem/output_material/2025-04-12_10-42-04/data_bundle.npz")
+        input_history = data_bundle['input_history'].item()
+        memory_frames = data_bundle['memory_frames']
+        self_frames = data_bundle['self_frames']
+        self_actions = data_bundle['self_actions']
+        self_poses = data_bundle['self_poses']
+        self_memory_c2w = data_bundle['self_memory_c2w']
+        self_frame_idx = data_bundle['self_frame_idx']
+    elif examples_case == '3':
+        data_bundle = np.load("/mnt/xiaozeqi/worldmem/output_material/2025-04-12_10-56-57/data_bundle.npz")
+        input_history = data_bundle['input_history'].item()
+        memory_frames = data_bundle['memory_frames']
+        self_frames = data_bundle['self_frames']
+        self_actions = data_bundle['self_actions']
+        self_poses = data_bundle['self_poses']
+        self_memory_c2w = data_bundle['self_memory_c2w']
+        self_frame_idx = data_bundle['self_frame_idx']
+    elif examples_case == '4':
+        data_bundle = np.load("/mnt/xiaozeqi/worldmem/output_material/2025-04-11_16-07-19/data_bundle.npz")
+        input_history = data_bundle['input_history'].item()
+        memory_frames = data_bundle['memory_frames']
+        self_frames = data_bundle['self_frames']
+        self_actions = data_bundle['self_actions']
+        self_poses = data_bundle['self_poses']
+        self_memory_c2w = data_bundle['self_memory_c2w']
+        self_frame_idx = data_bundle['self_frame_idx']
 
+    out_video = memory_frames.transpose(0,2,3,1)
+    out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
+    out_video = (out_video * 255).astype(np.uint8)
 
+    temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
+    save_video(out_video, temporal_video_path)
+
+    return input_history, out_video[-1], temporal_video_path, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
 
 css = """
 h1 {
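Note: the four branches of set_memory() are identical except for the bundle path. A table-driven sketch of the same behaviour (a possible refactor, not what the commit ships):

import numpy as np

EXAMPLE_BUNDLES = {
    '1': "/mnt/xiaozeqi/worldmem/output_material/2025-04-11_16-01-49/data_bundle.npz",
    '2': "/mnt/xiaozeqi/worldmem/output_material/2025-04-12_10-42-04/data_bundle.npz",
    '3': "/mnt/xiaozeqi/worldmem/output_material/2025-04-12_10-56-57/data_bundle.npz",
    '4': "/mnt/xiaozeqi/worldmem/output_material/2025-04-11_16-07-19/data_bundle.npz",
}

def load_example(examples_case):
    # one lookup replaces the repeated if/elif blocks
    data = np.load(EXAMPLE_BUNDLES[examples_case])
    state_keys = ('memory_frames', 'self_frames', 'self_actions',
                  'self_poses', 'self_memory_c2w', 'self_frame_idx')
    return (data['input_history'].item(), *(data[k] for k in state_keys))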
@@ -329,6 +371,10 @@ h1 {
 }
 """
 
+def on_select(evt: gr.SelectData):
+    selected_index = evt.index
+    return examples[selected_index]
+
 with gr.Blocks(css=css) as demo:
     gr.Markdown(
         """
@@ -358,13 +404,18 @@ with gr.Blocks(css=css) as demo:
     # </a>
     # </div>
 
-    example_actions =
-
+    example_actions = {"turn left + turn right": "AAAAAAAAAAAADDDDDDDDDDDD",
+                       "turn 360 degree": "AAAAAAAAAAAAAAAAAAAAAAAA",
+                       "turn right+go backward+look up+turn left+look down": "DDDDDDDDEEEEEEEEEESSSAAAAAAAAWWW",
+                       "turn right+go forward+turn left": "DDDDDDDDDDDDQQQQQQQQQQQQQQQDDDDDDDDDDDD",
+                       "turn right+look up+turn right+look down": "DDDDWWWDDDDDDDDDDDDDDDDDDDDSSS",
+                       "put item+go backward+put item+go backward": "SSUNNWWEEEEEEEEEAAASSUNNWWEEEEEEEEE"}
 
+    selected_image = gr.State(ICE_PLAINS_IMAGE)
 
     with gr.Row(variant="panel"):
         video_display = gr.Video(autoplay=True, loop=True)
-        image_display = gr.Image(value=
+        image_display = gr.Image(value=selected_image.value, interactive=False, label="Current Frame")
@@ -374,17 +425,17 @@ with gr.Blocks(css=css) as demo:
     gr.Markdown("### Action sequence examples.")
     with gr.Row():
         buttons = []
-        for
-            with gr.Column(scale=len(
-                buttons.append(gr.Button(
+        for action_key in list(example_actions.keys())[:2]:
+            with gr.Column(scale=len(action_key)):
+                buttons.append(gr.Button(action_key))
     with gr.Row():
-        for
-            with gr.Column(scale=len(
-                buttons.append(gr.Button(
+        for action_key in list(example_actions.keys())[2:4]:
+            with gr.Column(scale=len(action_key)):
+                buttons.append(gr.Button(action_key))
     with gr.Row():
-        for
-            with gr.Column(scale=len(
-                buttons.append(gr.Button(
+        for action_key in list(example_actions.keys())[4:6]:
+            with gr.Column(scale=len(action_key)):
+                buttons.append(gr.Button(action_key))
 
     with gr.Column(scale=1):
         slider_denoising_step = gr.Slider(minimum=10, maximum=50, value=worldmem.sampling_timesteps, step=1, label="Denoising Steps")
@@ -397,6 +448,12 @@ with gr.Blocks(css=css) as demo:
     sampling_context_length_state = gr.State(worldmem.n_tokens)
     sampling_memory_length_state = gr.State(worldmem.condition_similar_length)
 
+    memory_frames = gr.State(load_image_as_tensor(selected_image.value)[None].numpy())
+    self_frames = gr.State()
+    self_actions = gr.State()
+    self_poses = gr.State()
+    self_memory_c2w = gr.State()
+    self_frame_idx = gr.State()
 
     def set_action(action):
         return action
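Note: these gr.State components are the per-session replacement for the deleted module-level globals: each browser session gets its own copy, and every handler threads the values through its inputs and outputs. A self-contained sketch of the pattern with a toy handler (not the app's):

import gradio as gr

with gr.Blocks() as demo:
    counter = gr.State(0)           # one copy per session, unlike a global
    out = gr.Textbox(label="Clicks")
    btn = gr.Button("Click")

    def bump(n):
        n = n + 1
        return n, str(n)            # the first output writes the State back

    btn.click(bump, inputs=[counter], outputs=[counter, out])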
@@ -404,8 +461,8 @@ with gr.Blocks(css=css) as demo:
     # gr.Markdown("### Action sequence examples.")
 
 
-    for button,
-        button.click(set_action, inputs=[gr.State(value=
+    for button, action_key in zip(buttons, list(example_actions.keys())):
+        button.click(set_action, inputs=[gr.State(value=example_actions[action_key])], outputs=input_box)
 
 
     gr.Markdown("### Click on the images below to reset the sequence and generate from the new image.")
@@ -418,6 +475,32 @@ with gr.Blocks(css=css) as demo:
     image_display_5 = gr.Image(value=SUNFLOWERS_RAIN_IMAGE, interactive=False, label="Rainy Sunflower Plains")
     image_display_6 = gr.Image(value=PLACE_IMAGE, interactive=False, label="Place")
 
+    gr.Markdown("### Click the examples below for a quick review, and continue generating based on them.")
+
+    example_case = gr.Textbox(label="Case", visible=False)
+    image_output = gr.Image(visible=False)
+
+    # gr.Examples(examples=example_images,
+    #             inputs=[example_case, image_output, log_output, slider_denoising_step, slider_context_length, slider_memory_length],
+    #             fn=set_memory,
+    #             outputs=[log_output, image_display, video_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx],
+    #             cache_examples=True
+    #             )
+
+    examples = gr.Examples(
+        examples=example_images,
+        inputs=[example_case, image_output, log_output, slider_denoising_step, slider_context_length, slider_memory_length],
+        cache_examples=False
+    )
+
+    example_case.change(
+        fn=set_memory,
+        inputs=[example_case, image_output, log_output, slider_denoising_step, slider_context_length, slider_memory_length],
+        outputs=[log_output, image_display, video_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx]
+    )
+
+
+
     gr.Markdown(
         """
         ## Instructions & Notes:
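Note: gr.Examples is built without fn, so clicking a row only populates its inputs; the hidden example_case textbox then fires .change, which runs set_memory with the selected case. A self-contained sketch of that indirection with a toy loader:

import gradio as gr

with gr.Blocks() as demo:
    case_box = gr.Textbox(label="Case", visible=False)
    result = gr.Textbox(label="Result")
    # clicking a row only fills case_box; no function runs at that point
    gr.Examples(examples=[["1"], ["2"]], inputs=[case_box], cache_examples=False)
    # the populated textbox then fires .change, which runs the real loader
    case_box.change(fn=lambda c: "loaded case " + c, inputs=[case_box], outputs=[result])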
@@ -441,14 +524,14 @@ with gr.Blocks(css=css) as demo:
     """
     )
     # input_box.submit(update_image_and_log, inputs=[input_box], outputs=[image_display, video_display, log_output])
-    submit_button.click(generate, inputs=[input_box], outputs=[image_display, video_display, log_output])
-    reset_btn.click(reset, outputs=[log_output,
-    image_display_1.select(lambda: on_image_click(SUNFLOWERS_IMAGE), outputs=image_display)
-    image_display_2.select(lambda: on_image_click(DESERT_IMAGE), outputs=image_display)
-    image_display_3.select(lambda: on_image_click(SAVANNA_IMAGE), outputs=image_display)
-    image_display_4.select(lambda: on_image_click(ICE_PLAINS_IMAGE), outputs=image_display)
-    image_display_5.select(lambda: on_image_click(SUNFLOWERS_RAIN_IMAGE), outputs=image_display)
-    image_display_6.select(lambda: on_image_click(PLACE_IMAGE), outputs=image_display)
+    submit_button.click(generate, inputs=[input_box, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx], outputs=[image_display, video_display, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
+    reset_btn.click(reset, inputs=[selected_image], outputs=[log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
+    image_display_1.select(lambda: on_image_click(SUNFLOWERS_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
+    image_display_2.select(lambda: on_image_click(DESERT_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
+    image_display_3.select(lambda: on_image_click(SAVANNA_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
+    image_display_4.select(lambda: on_image_click(ICE_PLAINS_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
+    image_display_5.select(lambda: on_image_click(SUNFLOWERS_RAIN_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
+    image_display_6.select(lambda: on_image_click(PLACE_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
 
     slider_denoising_step.change(fn=set_denoising_steps, inputs=[slider_denoising_step, sampling_timesteps_state], outputs=sampling_timesteps_state)
     slider_context_length.change(fn=set_context_length, inputs=[slider_context_length, sampling_context_length_state], outputs=sampling_context_length_state)