# Source: NON_WORKING_matrix_game_2/pipeline/causal_inference.py
# Uploaded by jbilcke-hf (HF Staff) in commit eb94d89 ("Upload 91 files", verified).
from typing import List, Optional
import numpy as np
import torch
import time
import copy
from einops import rearrange
from utils.wan_wrapper import WanDiffusionWrapper, WanVAEWrapper
from utils.visualize import process_video
import torch.nn.functional as F
from demo_utils.constant import ZERO_VAE_CACHE
from tqdm import tqdm
def get_current_action(mode="universal", device="cuda"):
    """Prompt on stdin for the next action and encode it as condition tensors.

    The prompt is repeated until the typed keys are valid for ``mode``.
    Input is stripped and lower-cased before it is looked up.

    Args:
        mode: One of ``'universal'``, ``'gta_drive'`` or ``'templerun'``.
            ``'universal'`` asks for a camera key and a movement key
            separately; ``'gta_drive'`` asks for one space-separated line that
            may contain at most one steering key and one throttle key;
            ``'templerun'`` asks for a single action key.
        device: Device on which the returned tensors are created. Defaults to
            ``"cuda"``, which matches the previous hard-coded ``.cuda()``.

    Returns:
        ``{"mouse": tensor, "keyboard": tensor}`` for ``'universal'`` and
        ``'gta_drive'``; ``{"keyboard": tensor}`` for ``'templerun'``.

    Raises:
        ValueError: If ``mode`` is not one of the supported modes.
        EOFError: If stdin is closed while waiting for input.
        KeyboardInterrupt: If the user interrupts the prompt. Neither of these
            is swallowed, so the prompt can always be aborted.
    """
    CAM_VALUE = 0.1
    rule = '-' * 30

    if mode == 'universal':
        print()
        print(rule)
        print("PRESS [I, K, J, L, U] FOR CAMERA TRANSFORM\n (I: up, K: down, J: left, L: right, U: no move)")
        print("PRESS [W, S, A, D, Q] FOR MOVEMENT\n (W: forward, S: back, A: left, D: right, Q: no move)")
        print(rule)
        camera_value_map = {
            "i": [CAM_VALUE, 0],
            "k": [-CAM_VALUE, 0],
            "j": [0, -CAM_VALUE],
            "l": [0, CAM_VALUE],
            "u": [0, 0],
        }
        keyboard_idx = {
            "w": [1, 0, 0, 0], "s": [0, 1, 0, 0], "a": [0, 0, 1, 0], "d": [0, 0, 0, 1],
            "q": [0, 0, 0, 0],
        }
        # Both answers are always requested; the pair is re-asked as a whole
        # when either one is not a known key.
        while True:
            idx_mouse = input('Please input the mouse action (e.g. `U`):\n').strip().lower()
            idx_keyboard = input('Please input the keyboard action (e.g. `W`):\n').strip().lower()
            if idx_mouse in camera_value_map and idx_keyboard in keyboard_idx:
                break
        return {
            "mouse": torch.tensor(camera_value_map[idx_mouse], device=device),
            "keyboard": torch.tensor(keyboard_idx[idx_keyboard], device=device),
        }

    if mode == 'gta_drive':
        print()
        print(rule)
        print("PRESS [W, S, A, D, Q] FOR MOVEMENT\n (W: forward, S: back, A: left, D: right, Q: no move)")
        print(rule)
        camera_value_map = {
            "a": [0, -CAM_VALUE],
            "d": [0, CAM_VALUE],
            "q": [0, 0],
        }
        keyboard_idx = {
            "w": [1, 0], "s": [0, 1],
            "q": [0, 0],
        }
        while True:
            tokens = input('Please input the actions (split with ` `):\n(e.g. `W` for forward, `W A` for forward and left)\n').strip().lower().split(' ')
            # A key present in both tables (`q`) counts as a steering key only;
            # tokens found in neither table are ignored.
            steering = [tok for tok in tokens if tok in camera_value_map]
            throttle = [tok for tok in tokens
                        if tok not in camera_value_map and tok in keyboard_idx]
            # A missing group defaults to "no move".
            steering = steering or ['q']
            throttle = throttle or ['q']
            # Exactly one key per group is accepted; anything else re-prompts.
            if len(steering) == 1 and len(throttle) == 1:
                break
        return {
            "mouse": torch.tensor(camera_value_map[steering[0]], device=device),
            "keyboard": torch.tensor(keyboard_idx[throttle[0]], device=device),
        }

    if mode == 'templerun':
        print()
        print(rule)
        print("PRESS [W, S, A, D, Z, C, Q] FOR ACTIONS\n (W: jump, S: slide, A: left side, D: right side, Z: turn left, C: turn right, Q: no move)")
        print(rule)
        keyboard_idx = {
            "w": [0, 1, 0, 0, 0, 0, 0], "s": [0, 0, 1, 0, 0, 0, 0],
            "a": [0, 0, 0, 0, 0, 1, 0], "d": [0, 0, 0, 0, 0, 0, 1],
            "z": [0, 0, 0, 1, 0, 0, 0], "c": [0, 0, 0, 0, 1, 0, 0],
            "q": [1, 0, 0, 0, 0, 0, 0],
        }
        while True:
            idx_keyboard = input('Please input the action: \n(e.g. `W` for forward, `Z` for turning left)\n').strip().lower()
            if idx_keyboard in keyboard_idx:
                break
        return {
            "keyboard": torch.tensor(keyboard_idx[idx_keyboard], device=device),
        }

    raise ValueError(
        f"Unsupported mode {mode!r}; expected 'universal', 'gta_drive' or 'templerun'")
def cond_current(conditional_dict, current_start_frame, num_frame_per_block, replace=None, mode='universal'):
    """Build the conditioning dict for the block starting at ``current_start_frame``.

    ``cond_concat`` is sliced along dim 2 to the current block, while the
    action conditions (``keyboard_cond`` and, outside ``'templerun'`` mode,
    ``mouse_cond``) are sliced along dim 1 from index 0 up to
    ``1 + 4 * (current_start_frame + num_frame_per_block - 1)``.
    ``visual_context`` is passed through unchanged.

    When ``replace`` is given, the action entries that belong to the current
    block are first overwritten IN PLACE inside ``conditional_dict`` with the
    single action in ``replace``, repeated for every entry of the block.

    Args:
        conditional_dict: Dict holding ``"cond_concat"``, ``"visual_context"``,
            ``"keyboard_cond"`` and (unless ``mode == 'templerun'``)
            ``"mouse_cond"``.
        current_start_frame: Index of the first frame of the block, in the
            same units as dim 2 of ``cond_concat``.
        num_frame_per_block: Number of frames in the block.
        replace: Optional dict with a 1-D ``"keyboard"`` tensor and, unless
            ``mode == 'templerun'``, a 1-D ``"mouse"`` tensor.
        mode: ``'templerun'`` disables all mouse handling.

    Returns:
        The new conditioning dict when ``replace`` is ``None``; otherwise the
        tuple ``(new_cond, conditional_dict)`` where ``conditional_dict`` is
        the same (mutated) object that was passed in.
    """
    uses_mouse = mode != 'templerun'
    # Exclusive end of the action entries covered once this block is done.
    final_frame = 1 + 4 * (current_start_frame + num_frame_per_block - 1)

    if replace is not None:
        # The very first block covers 1 + 4 * (n - 1) action entries, every
        # later block covers 4 * n.
        if current_start_frame == 0:
            last_frame_num = 1 + 4 * (num_frame_per_block - 1)
        else:
            last_frame_num = 4 * num_frame_per_block
        block_begin = final_frame - last_frame_num
        if uses_mouse:
            conditional_dict["mouse_cond"][:, block_begin:final_frame] = \
                replace['mouse'][None, None, :].repeat(1, last_frame_num, 1)
        conditional_dict["keyboard_cond"][:, block_begin:final_frame] = \
            replace['keyboard'][None, None, :].repeat(1, last_frame_num, 1)

    new_cond = {
        "cond_concat": conditional_dict["cond_concat"][
            :, :, current_start_frame:current_start_frame + num_frame_per_block],
        "visual_context": conditional_dict["visual_context"],
    }
    if uses_mouse:
        new_cond["mouse_cond"] = conditional_dict["mouse_cond"][:, :final_frame]
    new_cond["keyboard_cond"] = conditional_dict["keyboard_cond"][:, :final_frame]

    if replace is not None:
        return new_cond, conditional_dict
    return new_cond
class CausalInferencePipeline(torch.nn.Module):
    """Block-by-block causal latent generation with KV caches and VAE decoding.

    Frames are produced in blocks of ``num_frame_per_block``. Each block is
    denoised over ``denoising_step_list`` and the generator is then run once
    more on the final prediction so the KV caches hold that block as context
    for the next one.
    """

    def __init__(
        self,
        args,
        device="cuda",
        generator=None,
        vae_decoder=None,
    ):
        """Set up the generator, scheduler and caching hyper-parameters.

        Args:
            args: Config object. Reads ``denoising_step_list``,
                ``warp_denoising_step`` and (in ``inference``)
                ``context_noise``; optionally ``model_kwargs`` and
                ``num_frame_per_block`` (default 1).
            device: Accepted for interface compatibility; not used here.
            generator: Pre-built generator. A ``WanDiffusionWrapper`` is
                created from ``args.model_kwargs`` when omitted.
            vae_decoder: Callable ``(latents, *cache) -> (video, cache)`` used
                to decode each block.
        """
        super().__init__()
        # Step 1: Initialize all models
        self.generator = WanDiffusionWrapper(
            **getattr(args, "model_kwargs", {}), is_causal=True) if generator is None else generator
        self.vae_decoder = vae_decoder

        # Step 2: Initialize all causal hyperparameters
        self.scheduler = self.generator.get_scheduler()
        self.denoising_step_list = torch.tensor(
            args.denoising_step_list, dtype=torch.long)
        if args.warp_denoising_step:
            # Map the configured step indices onto the scheduler's own
            # timestep table (with a trailing 0 appended).
            timesteps = torch.cat((self.scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32)))
            self.denoising_step_list = timesteps[1000 - self.denoising_step_list]

        self.num_transformer_blocks = 30
        self.frame_seq_length = 880  # presumably tokens per latent frame -- TODO confirm

        self.kv_cache1 = None
        self.kv_cache_mouse = None
        self.kv_cache_keyboard = None
        self.args = args
        self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
        self.local_attn_size = self.generator.model.local_attn_size
        assert self.local_attn_size != -1

        print(f"KV inference with {self.num_frame_per_block} frames per block")

        if self.num_frame_per_block > 1:
            self.generator.model.num_frame_per_block = self.num_frame_per_block

    def inference(
        self,
        noise: torch.Tensor,
        conditional_dict,
        initial_latent=None,
        return_latents=False,
        mode='universal',
        profile=False,
    ):
        """Generate latents block by block and decode every block to video.

        Args:
            noise: Tensor of shape
                ``(batch_size, 16, num_frames, height, width)``.
                ``num_frames`` must be a multiple of ``num_frame_per_block``.
            conditional_dict: Conditioning dict consumed by ``cond_current``.
            initial_latent: Optional context latents laid out like ``noise``
                (frames on dim 2). They are copied to the start of the output
                and fed through the generator to fill the caches. Their frame
                count must be a multiple of ``num_frame_per_block``.
            return_latents: Return the latent tensor instead of the decoded
                chunks.
            mode: Forwarded to ``cond_current``.
            profile: Print per-block wall time and FPS using CUDA events.

        Returns:
            If ``return_latents`` is true, the latent tensor of shape
            ``(batch_size, 16, num_input_frames + num_frames, height, width)``.
            Otherwise a list with one decoded video chunk per generated block,
            exactly as returned by ``vae_decoder``.
        """
        assert noise.shape[1] == 16, "noise must have 16 channels on dim 1"
        batch_size, num_channels, num_frames, height, width = noise.shape
        assert num_frames % self.num_frame_per_block == 0, \
            "num_frames must be a multiple of num_frame_per_block"
        num_blocks = num_frames // self.num_frame_per_block

        num_input_frames = initial_latent.shape[2] if initial_latent is not None else 0
        num_output_frames = num_frames + num_input_frames  # add the initial latent frames

        output = torch.zeros(
            [batch_size, num_channels, num_output_frames, height, width],
            device=noise.device,
            dtype=noise.dtype
        )
        videos = []
        # The decoder cache starts empty: one None per slot of ZERO_VAE_CACHE.
        vae_cache = [None] * len(ZERO_VAE_CACHE)

        # Step 1: Rebuild every cache from scratch for this call.
        self._reset_caches(batch_size=batch_size, dtype=noise.dtype, device=noise.device)

        # Step 2: Cache context feature
        current_start_frame = 0
        if initial_latent is not None:
            assert num_input_frames % self.num_frame_per_block == 0, \
                "initial_latent frame count must be a multiple of num_frame_per_block"
            num_input_blocks = num_input_frames // self.num_frame_per_block
            zero_timestep = torch.zeros([batch_size, 1], device=noise.device, dtype=torch.int64)

            for _ in range(num_input_blocks):
                block_end = current_start_frame + self.num_frame_per_block
                current_ref_latents = initial_latent[:, :, current_start_frame:block_end]
                output[:, :, current_start_frame:block_end] = current_ref_latents
                self.generator(
                    noisy_image_or_video=current_ref_latents,
                    conditional_dict=cond_current(conditional_dict, current_start_frame, self.num_frame_per_block, mode=mode),
                    timestep=zero_timestep,
                    kv_cache=self.kv_cache1,
                    kv_cache_mouse=self.kv_cache_mouse,
                    kv_cache_keyboard=self.kv_cache_keyboard,
                    crossattn_cache=self.crossattn_cache,
                    current_start=current_start_frame * self.frame_seq_length,
                )
                current_start_frame = block_end

        # Step 3: Temporal denoising loop
        all_num_frames = [self.num_frame_per_block] * num_blocks
        last_step = len(self.denoising_step_list) - 1

        if profile:
            diffusion_start = torch.cuda.Event(enable_timing=True)
            diffusion_end = torch.cuda.Event(enable_timing=True)

        for current_num_frames in tqdm(all_num_frames):
            # `noise` has no entries for the context frames, hence the offset.
            noise_begin = current_start_frame - num_input_frames
            noisy_input = noise[:, :, noise_begin:noise_begin + current_num_frames]
            # The conditioning only depends on the block position, so it is
            # built once and reused for every denoising step of the block.
            block_cond = cond_current(conditional_dict, current_start_frame, self.num_frame_per_block, mode=mode)
            current_start = current_start_frame * self.frame_seq_length

            # Step 3.1: Spatial denoising loop
            if profile:
                torch.cuda.synchronize()
                diffusion_start.record()

            for index, current_timestep in enumerate(self.denoising_step_list):
                # set current timestep
                timestep = torch.ones(
                    [batch_size, current_num_frames],
                    device=noise.device,
                    dtype=torch.int64) * current_timestep

                _, denoised_pred = self.generator(
                    noisy_image_or_video=noisy_input,
                    conditional_dict=block_cond,
                    timestep=timestep,
                    kv_cache=self.kv_cache1,
                    kv_cache_mouse=self.kv_cache_mouse,
                    kv_cache_keyboard=self.kv_cache_keyboard,
                    crossattn_cache=self.crossattn_cache,
                    current_start=current_start
                )

                if index < last_step:
                    # Re-noise the prediction to the next (lower) timestep;
                    # the final step's prediction is kept as the real output.
                    next_timestep = self.denoising_step_list[index + 1]
                    flat_pred = rearrange(denoised_pred, 'b c f h w -> (b f) c h w')
                    noisy_input = self.scheduler.add_noise(
                        flat_pred,
                        torch.randn_like(flat_pred),
                        next_timestep * torch.ones(
                            [batch_size * current_num_frames], device=noise.device, dtype=torch.long)
                    )
                    noisy_input = rearrange(noisy_input, '(b f) c h w -> b c f h w', b=denoised_pred.shape[0])

            # Step 3.2: record the model's output
            output[:, :, current_start_frame:current_start_frame + current_num_frames] = denoised_pred

            # Step 3.3: rerun at the context-noise timestep to update the KV
            # caches with the final prediction of this block
            context_timestep = torch.ones_like(timestep) * self.args.context_noise
            self.generator(
                noisy_image_or_video=denoised_pred,
                conditional_dict=block_cond,
                timestep=context_timestep,
                kv_cache=self.kv_cache1,
                kv_cache_mouse=self.kv_cache_mouse,
                kv_cache_keyboard=self.kv_cache_keyboard,
                crossattn_cache=self.crossattn_cache,
                current_start=current_start,
            )

            # Step 3.4: update the start and end frame indices
            current_start_frame += current_num_frames

            # Decode this block; the decoder cache carries state to the next one.
            denoised_pred = denoised_pred.transpose(1, 2)
            video, vae_cache = self.vae_decoder(denoised_pred.half(), *vae_cache)
            videos += [video]

            if profile:
                # Record first, then wait: elapsed_time() needs both events to
                # have completed.
                diffusion_end.record()
                torch.cuda.synchronize()
                diffusion_time = diffusion_start.elapsed_time(diffusion_end)
                print(f"diffusion_time: {diffusion_time}", flush=True)
                fps = video.shape[1] * 1000 / diffusion_time
                print(f"  - FPS: {fps:.2f}")

        if return_latents:
            return output
        return videos

    def _reset_caches(self, batch_size, dtype, device):
        """Drop all caches and allocate fresh zeroed ones."""
        # Release the previous buffers before allocating replacements so both
        # generations are not alive at the same time.
        self.kv_cache1 = self.kv_cache_keyboard = self.kv_cache_mouse = self.crossattn_cache = None
        self._initialize_kv_cache(batch_size=batch_size, dtype=dtype, device=device)
        self._initialize_kv_cache_mouse_and_keyboard(batch_size=batch_size, dtype=dtype, device=device)
        self._initialize_crossattn_cache(batch_size=batch_size, dtype=dtype, device=device)

    def _initialize_kv_cache(self, batch_size, dtype, device):
        """
        Initialize a Per-GPU KV cache for the Wan model.

        Stores one dict per transformer block in ``self.kv_cache1`` with zeroed
        ``k``/``v`` buffers of shape ``(batch_size, kv_cache_size, 12, 128)``
        and zeroed ``global_end_index``/``local_end_index`` counters.
        """
        if self.local_attn_size != -1:
            # Use the local attention size to compute the KV cache size
            kv_cache_size = self.local_attn_size * self.frame_seq_length
        else:
            # Use the default KV cache size
            kv_cache_size = 15 * 1 * self.frame_seq_length

        self.kv_cache1 = [
            {
                "k": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
                "v": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
                "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
                "local_end_index": torch.tensor([0], dtype=torch.long, device=device),
            }
            for _ in range(self.num_transformer_blocks)
        ]  # always store the clean cache

    def _initialize_kv_cache_mouse_and_keyboard(self, batch_size, dtype, device):
        """
        Initialize the Per-GPU mouse and keyboard KV caches for the Wan model.

        The keyboard cache is sized ``(batch_size, kv_cache_size, 16, 64)``;
        the mouse cache uses ``batch_size * frame_seq_length`` as its leading
        dimension. Results go to ``self.kv_cache_keyboard`` and
        ``self.kv_cache_mouse``.
        """
        if self.local_attn_size != -1:
            kv_cache_size = self.local_attn_size
        else:
            kv_cache_size = 15 * 1

        def zeroed_cache(leading_dim):
            return {
                "k": torch.zeros([leading_dim, kv_cache_size, 16, 64], dtype=dtype, device=device),
                "v": torch.zeros([leading_dim, kv_cache_size, 16, 64], dtype=dtype, device=device),
                "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
                "local_end_index": torch.tensor([0], dtype=torch.long, device=device),
            }

        kv_cache_keyboard = []
        kv_cache_mouse = []
        for _ in range(self.num_transformer_blocks):
            kv_cache_keyboard.append(zeroed_cache(batch_size))
            kv_cache_mouse.append(zeroed_cache(batch_size * self.frame_seq_length))

        self.kv_cache_keyboard = kv_cache_keyboard  # always store the clean cache
        self.kv_cache_mouse = kv_cache_mouse  # always store the clean cache

    def _initialize_crossattn_cache(self, batch_size, dtype, device):
        """
        Initialize a Per-GPU cross-attention cache for the Wan model.

        Stores one dict per transformer block in ``self.crossattn_cache`` with
        zeroed ``k``/``v`` buffers of shape ``(batch_size, 257, 12, 128)`` and
        ``is_init`` set to ``False``.
        """
        self.crossattn_cache = [
            {
                "k": torch.zeros([batch_size, 257, 12, 128], dtype=dtype, device=device),
                "v": torch.zeros([batch_size, 257, 12, 128], dtype=dtype, device=device),
                "is_init": False,
            }
            for _ in range(self.num_transformer_blocks)
        ]
class CausalInferenceStreamingPipeline(torch.nn.Module):
    """Interactive variant of the causal pipeline.

    Before each block the user is asked (via ``get_current_action``) for the
    action to apply; after each block a preview video is written and the user
    may stop generation.
    """

    def __init__(
        self,
        args,
        device="cuda",
        vae_decoder=None,
        generator=None,
    ):
        """Set up the generator, scheduler and caching hyper-parameters.

        Args:
            args: Config object. Reads ``denoising_step_list``,
                ``warp_denoising_step`` and (in ``inference``)
                ``context_noise``; optionally ``model_kwargs`` and
                ``num_frame_per_block`` (default 1).
            device: Accepted for interface compatibility; not used here.
            vae_decoder: Callable ``(latents, *cache) -> (video, cache)`` used
                to decode each block.
            generator: Pre-built generator. A ``WanDiffusionWrapper`` is
                created from ``args.model_kwargs`` when omitted.
        """
        super().__init__()
        # Step 1: Initialize all models
        self.generator = WanDiffusionWrapper(
            **getattr(args, "model_kwargs", {}), is_causal=True) if generator is None else generator
        self.vae_decoder = vae_decoder

        # Step 2: Initialize all causal hyperparameters
        self.scheduler = self.generator.get_scheduler()
        self.denoising_step_list = torch.tensor(
            args.denoising_step_list, dtype=torch.long)
        if args.warp_denoising_step:
            # Map the configured step indices onto the scheduler's own
            # timestep table (with a trailing 0 appended).
            timesteps = torch.cat((self.scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32)))
            self.denoising_step_list = timesteps[1000 - self.denoising_step_list]

        self.num_transformer_blocks = 30
        self.frame_seq_length = 880  # 1590 # HW/4

        self.kv_cache1 = None
        self.kv_cache_mouse = None
        self.kv_cache_keyboard = None
        self.args = args
        self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
        self.local_attn_size = self.generator.model.local_attn_size
        assert self.local_attn_size != -1

        print(f"KV inference with {self.num_frame_per_block} frames per block")

        if self.num_frame_per_block > 1:
            self.generator.model.num_frame_per_block = self.num_frame_per_block

    def inference(
        self,
        noise: torch.Tensor,
        conditional_dict,
        initial_latent: Optional[torch.Tensor] = None,
        return_latents: bool = False,
        output_folder=None,
        name=None,
        mode='universal'
    ):
        """Interactively generate video, one user-controlled block at a time.

        For every block the user is prompted for an action, the block is
        denoised and decoded, and a preview is written to
        ``{output_folder}/{name}_current.mp4``. When generation ends (all
        blocks done or the user answers ``n``), the concatenated video is
        written to ``{output_folder}/{name}_icon.mp4`` and
        ``{output_folder}/{name}.mp4``.

        Args:
            noise: Tensor of shape
                ``(batch_size, 16, num_frames, height, width)``.
                ``num_frames`` must be a multiple of ``num_frame_per_block``.
            conditional_dict: Conditioning dict consumed by ``cond_current``.
                Its action entries are overwritten in place with the actions
                the user enters.
            initial_latent: Optional context latents laid out like ``noise``
                (frames on dim 2); used to fill the caches before generation.
            return_latents: Return the latent tensor instead of the frames.
            output_folder: Directory for the written videos. Required.
            name: File-name stem for the written videos.
            mode: Forwarded to ``get_current_action``, ``cond_current`` and
                ``process_video``.

        Returns:
            If ``return_latents`` is true, the latent tensor of shape
            ``(batch_size, 16, num_input_frames + num_frames, height, width)``
            (blocks skipped by an early stop remain zero). Otherwise the
            frames of the first batch element as a ``uint8`` NumPy array laid
            out ``(T, H, W, C)``.

        Raises:
            ValueError: If ``output_folder`` is ``None``.
        """
        if output_folder is None:
            # Fail before any generation work instead of at the first write.
            raise ValueError("output_folder is required to write the generated videos")

        assert noise.shape[1] == 16, "noise must have 16 channels on dim 1"
        batch_size, num_channels, num_frames, height, width = noise.shape
        assert num_frames % self.num_frame_per_block == 0, \
            "num_frames must be a multiple of num_frame_per_block"
        num_blocks = num_frames // self.num_frame_per_block

        num_input_frames = initial_latent.shape[2] if initial_latent is not None else 0
        num_output_frames = num_frames + num_input_frames  # add the initial latent frames

        output = torch.zeros(
            [batch_size, num_channels, num_output_frames, height, width],
            device=noise.device,
            dtype=noise.dtype
        )
        videos = []
        # The decoder cache starts empty: one None per slot of ZERO_VAE_CACHE.
        vae_cache = [None] * len(ZERO_VAE_CACHE)
        mouse_icon = 'assets/images/mouse.png'

        # Step 1: Rebuild every cache from scratch for this call.
        self._reset_caches(batch_size=batch_size, dtype=noise.dtype, device=noise.device)

        # Step 2: Cache context feature
        current_start_frame = 0
        if initial_latent is not None:
            assert num_input_frames % self.num_frame_per_block == 0, \
                "initial_latent frame count must be a multiple of num_frame_per_block"
            num_input_blocks = num_input_frames // self.num_frame_per_block
            zero_timestep = torch.zeros([batch_size, 1], device=noise.device, dtype=torch.int64)

            for _ in range(num_input_blocks):
                block_end = current_start_frame + self.num_frame_per_block
                current_ref_latents = initial_latent[:, :, current_start_frame:block_end]
                output[:, :, current_start_frame:block_end] = current_ref_latents
                self.generator(
                    noisy_image_or_video=current_ref_latents,
                    # No `replace` here: the context blocks keep the actions
                    # already stored in conditional_dict.
                    conditional_dict=cond_current(conditional_dict, current_start_frame, self.num_frame_per_block, mode=mode),
                    timestep=zero_timestep,
                    kv_cache=self.kv_cache1,
                    kv_cache_mouse=self.kv_cache_mouse,
                    kv_cache_keyboard=self.kv_cache_keyboard,
                    crossattn_cache=self.crossattn_cache,
                    current_start=current_start_frame * self.frame_seq_length,
                )
                current_start_frame = block_end

        # Step 3: Temporal denoising loop
        all_num_frames = [self.num_frame_per_block] * num_blocks
        last_step = len(self.denoising_step_list) - 1

        for current_num_frames in all_num_frames:
            # `noise` has no entries for the context frames, hence the offset.
            noise_begin = current_start_frame - num_input_frames
            noisy_input = noise[:, :, noise_begin:noise_begin + current_num_frames]
            current_start = current_start_frame * self.frame_seq_length

            # Ask the user for this block's action and write it into the
            # conditioning (conditional_dict is updated in place).
            current_actions = get_current_action(mode=mode)
            new_act, conditional_dict = cond_current(conditional_dict, current_start_frame, self.num_frame_per_block, replace=current_actions, mode=mode)

            # Step 3.1: Spatial denoising loop
            for index, current_timestep in enumerate(self.denoising_step_list):
                # set current timestep
                timestep = torch.ones(
                    [batch_size, current_num_frames],
                    device=noise.device,
                    dtype=torch.int64) * current_timestep

                _, denoised_pred = self.generator(
                    noisy_image_or_video=noisy_input,
                    conditional_dict=new_act,
                    timestep=timestep,
                    kv_cache=self.kv_cache1,
                    kv_cache_mouse=self.kv_cache_mouse,
                    kv_cache_keyboard=self.kv_cache_keyboard,
                    crossattn_cache=self.crossattn_cache,
                    current_start=current_start
                )

                if index < last_step:
                    # Re-noise the prediction to the next (lower) timestep;
                    # the final step's prediction is kept as the real output.
                    next_timestep = self.denoising_step_list[index + 1]
                    flat_pred = rearrange(denoised_pred, 'b c f h w -> (b f) c h w')
                    noisy_input = self.scheduler.add_noise(
                        flat_pred,
                        torch.randn_like(flat_pred),
                        next_timestep * torch.ones(
                            [batch_size * current_num_frames], device=noise.device, dtype=torch.long)
                    )
                    noisy_input = rearrange(noisy_input, '(b f) c h w -> b c f h w', b=denoised_pred.shape[0])

            # Step 3.2: record the model's output
            output[:, :, current_start_frame:current_start_frame + current_num_frames] = denoised_pred

            # Step 3.3: rerun at the context-noise timestep to update the KV
            # caches with the final prediction of this block
            context_timestep = torch.ones_like(timestep) * self.args.context_noise
            self.generator(
                noisy_image_or_video=denoised_pred,
                conditional_dict=new_act,
                timestep=context_timestep,
                kv_cache=self.kv_cache1,
                kv_cache_mouse=self.kv_cache_mouse,
                kv_cache_keyboard=self.kv_cache_keyboard,
                crossattn_cache=self.crossattn_cache,
                current_start=current_start,
            )

            # Decode this block; the decoder cache carries state to the next one.
            denoised_pred = denoised_pred.transpose(1, 2)
            video, vae_cache = self.vae_decoder(denoised_pred.half(), *vae_cache)
            videos += [video]

            # Write a preview of the block that was just generated.
            video = self._to_uint8_frames(video)
            config = self._action_config(conditional_dict, current_start_frame, mode)
            # astype() hands process_video its own copy of the frames.
            process_video(video.astype(np.uint8), output_folder+f'/{name}_current.mp4', config, mouse_icon, mouse_scale=0.1, process_icon=False, mode=mode)

            # Step 3.4: update the start and end frame indices
            current_start_frame += current_num_frames

            if input("Continue? (Press `n` to break)").strip() == "n":
                break

        video = self._to_uint8_frames(torch.cat(videos, dim=1))
        # NOTE(review): current_start_frame already points past the last
        # generated block here, so this slice asks for one block more of
        # actions than were generated (slicing clamps to what exists).
        # Behaviour kept as-is -- confirm against process_video.
        config = self._action_config(conditional_dict, current_start_frame, mode)
        # Each call gets its own copy so one rendering cannot affect the other.
        process_video(video.astype(np.uint8), output_folder+f'/{name}_icon.mp4', config, mouse_icon, mouse_scale=0.1, mode=mode)
        process_video(video.astype(np.uint8), output_folder+f'/{name}.mp4', config, mouse_icon, mouse_scale=0.1, process_icon=False, mode=mode)

        if return_latents:
            return output
        return video

    @staticmethod
    def _to_uint8_frames(video):
        """Convert a ``(B, T, C, H, W)`` video in [-1, 1] to ``(T, H, W, C)`` uint8 frames of batch item 0."""
        channels_last = video.permute(0, 1, 3, 4, 2)  # B T C H W -> B T H W C
        frames = ((channels_last.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)[0]
        return np.ascontiguousarray(frames)

    def _action_config(self, conditional_dict, current_start_frame, mode):
        """Collect the action arrays of batch item 0 up to the end of the block at ``current_start_frame``.

        Returns ``(keyboard, mouse)`` as NumPy arrays, or the bare keyboard
        array (not a tuple) when ``mode == 'templerun'``.
        """
        num_action_frames = 1 + 4 * (current_start_frame + self.num_frame_per_block - 1)
        keyboard = conditional_dict["keyboard_cond"][0, :num_action_frames].float().cpu().numpy()
        if mode == 'templerun':
            return keyboard
        mouse = conditional_dict["mouse_cond"][0, :num_action_frames].float().cpu().numpy()
        return (keyboard, mouse)

    def _reset_caches(self, batch_size, dtype, device):
        """Drop all caches and allocate fresh zeroed ones."""
        # Release the previous buffers before allocating replacements so both
        # generations are not alive at the same time.
        self.kv_cache1 = self.kv_cache_keyboard = self.kv_cache_mouse = self.crossattn_cache = None
        self._initialize_kv_cache(batch_size=batch_size, dtype=dtype, device=device)
        self._initialize_kv_cache_mouse_and_keyboard(batch_size=batch_size, dtype=dtype, device=device)
        self._initialize_crossattn_cache(batch_size=batch_size, dtype=dtype, device=device)

    def _initialize_kv_cache(self, batch_size, dtype, device):
        """
        Initialize a Per-GPU KV cache for the Wan model.

        Stores one dict per transformer block in ``self.kv_cache1`` with zeroed
        ``k``/``v`` buffers of shape ``(batch_size, kv_cache_size, 12, 128)``
        and zeroed ``global_end_index``/``local_end_index`` counters.
        """
        if self.local_attn_size != -1:
            # Use the local attention size to compute the KV cache size
            kv_cache_size = self.local_attn_size * self.frame_seq_length
        else:
            # Use the default KV cache size
            kv_cache_size = 15 * 1 * self.frame_seq_length

        self.kv_cache1 = [
            {
                "k": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
                "v": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
                "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
                "local_end_index": torch.tensor([0], dtype=torch.long, device=device),
            }
            for _ in range(self.num_transformer_blocks)
        ]  # always store the clean cache

    def _initialize_kv_cache_mouse_and_keyboard(self, batch_size, dtype, device):
        """
        Initialize the Per-GPU mouse and keyboard KV caches for the Wan model.

        The keyboard cache is sized ``(batch_size, kv_cache_size, 16, 64)``;
        the mouse cache uses ``batch_size * frame_seq_length`` as its leading
        dimension. Results go to ``self.kv_cache_keyboard`` and
        ``self.kv_cache_mouse``.
        """
        if self.local_attn_size != -1:
            kv_cache_size = self.local_attn_size
        else:
            kv_cache_size = 15 * 1

        def zeroed_cache(leading_dim):
            return {
                "k": torch.zeros([leading_dim, kv_cache_size, 16, 64], dtype=dtype, device=device),
                "v": torch.zeros([leading_dim, kv_cache_size, 16, 64], dtype=dtype, device=device),
                "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
                "local_end_index": torch.tensor([0], dtype=torch.long, device=device),
            }

        kv_cache_keyboard = []
        kv_cache_mouse = []
        for _ in range(self.num_transformer_blocks):
            kv_cache_keyboard.append(zeroed_cache(batch_size))
            kv_cache_mouse.append(zeroed_cache(batch_size * self.frame_seq_length))

        self.kv_cache_keyboard = kv_cache_keyboard  # always store the clean cache
        self.kv_cache_mouse = kv_cache_mouse  # always store the clean cache

    def _initialize_crossattn_cache(self, batch_size, dtype, device):
        """
        Initialize a Per-GPU cross-attention cache for the Wan model.

        Stores one dict per transformer block in ``self.crossattn_cache`` with
        zeroed ``k``/``v`` buffers of shape ``(batch_size, 257, 12, 128)`` and
        ``is_init`` set to ``False``.
        """
        self.crossattn_cache = [
            {
                "k": torch.zeros([batch_size, 257, 12, 128], dtype=dtype, device=device),
                "v": torch.zeros([batch_size, 257, 12, 128], dtype=dtype, device=device),
                "is_init": False,
            }
            for _ in range(self.num_transformer_blocks)
        ]