# NOTE: removed non-code artifact lines ("Spaces:", "Paused", "Paused") left by a broken extraction.
import torch

from .utils import hash_state_dict_keys
from .wan_video_dit import DiTBlock
class VaceWanAttentionBlock(DiTBlock):
    """DiT block variant that carries a VACE control stream.

    Each block processes the running control state through the inherited DiT
    block, then records a projected skip connection. The skips accumulate in a
    stacked tensor that later blocks receive and extend.
    """

    def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
        super().__init__(has_image_input, dim, num_heads, ffn_dim, eps=eps)
        self.block_id = block_id
        # Only the first block (id 0) fuses the main hidden states into the
        # control stream, so only it needs the input projection.
        if block_id == 0:
            self.before_proj = torch.nn.Linear(self.dim, self.dim)
        self.after_proj = torch.nn.Linear(self.dim, self.dim)

    def forward(self, c, x, context, t_mod, freqs):
        if self.block_id == 0:
            # First block: `c` is the raw control embedding; inject the main
            # DiT hidden states and start an empty skip list.
            carried = []
            current = self.before_proj(c) + x
        else:
            # Later blocks: `c` is a stack of [skip_1, ..., skip_k, state];
            # peel off the running state and keep the earlier skips.
            carried = list(torch.unbind(c))
            current = carried.pop()
        # NOTE(review): assumes the parent DiTBlock.forward returns a tuple
        # whose first element is the updated hidden states — confirm upstream.
        current, _ = super().forward(current, context, t_mod, freqs)
        skip = self.after_proj(current)
        carried.extend((skip, current))
        # Re-stack so the next block (or the caller) can unbind again.
        return torch.stack(carried)
class VaceWanModel(torch.nn.Module):
    """Side network producing per-layer VACE hints for a Wan DiT base model.

    The control video is patch-embedded, padded to the main sequence length,
    run through a chain of ``VaceWanAttentionBlock``s, and the accumulated
    skip projections are returned as hints — one per entry in ``vace_layers``.

    Fix vs. previous revision: ``state_dict_converter`` was declared without
    ``self`` but not marked ``@staticmethod``, so calling it on an instance
    raised ``TypeError``. It is now a proper static method.
    """

    def __init__(
        self,
        vace_layers=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28),
        vace_in_dim=96,
        patch_size=(1, 2, 2),
        has_image_input=False,
        dim=1536,
        num_heads=12,
        ffn_dim=8960,
        eps=1e-6,
    ):
        """Build the VACE side network.

        Args:
            vace_layers: Base-model layer indices that receive a hint.
            vace_in_dim: Channel count of the control video input.
            patch_size: Conv3d kernel/stride used to patchify the control video.
            has_image_input, dim, num_heads, ffn_dim, eps: Forwarded to the
                per-layer ``VaceWanAttentionBlock``s.
        """
        super().__init__()
        self.vace_layers = vace_layers
        self.vace_in_dim = vace_in_dim
        # Base-model layer index -> position of its hint in the output tuple.
        self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}
        # One block per hooked layer. block_id is the base layer index, so only
        # the block for layer 0 gets the before_proj input fusion.
        self.vace_blocks = torch.nn.ModuleList(
            [
                VaceWanAttentionBlock(
                    has_image_input, dim, num_heads, ffn_dim, eps, block_id=i
                )
                for i in self.vace_layers
            ]
        )
        # Patchify the control video into token embeddings of width `dim`.
        self.vace_patch_embedding = torch.nn.Conv3d(
            vace_in_dim, dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(
        self,
        x,
        vace_context,
        context,
        t_mod,
        freqs,
        use_gradient_checkpointing: bool = False,
        use_gradient_checkpointing_offload: bool = False,
    ):
        """Encode ``vace_context`` and return one hint tensor per VACE layer.

        Args:
            x: Main-model hidden states; only ``x.shape[1]`` (sequence length)
                is used here, for zero-padding the control tokens.
            vace_context: Iterable of control videos, each patch-embedded
                independently.
            context, t_mod, freqs: Conditioning passed through to each block.
            use_gradient_checkpointing: Recompute block activations in backward
                to save memory.
            use_gradient_checkpointing_offload: Additionally stash saved
                tensors on CPU during checkpointing.

        Returns:
            Tuple of hint tensors (the per-block skip projections; the final
            running state is dropped).
        """
        # Patchify each control video and flatten to (1, seq, dim) tokens.
        c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
        c = [u.flatten(2).transpose(1, 2) for u in c]
        # Zero-pad every control sequence to the main sequence length, then
        # batch them along dim 0. Assumes u.size(1) <= x.shape[1].
        c = torch.cat(
            [
                torch.cat([u, u.new_zeros(1, x.shape[1] - u.size(1), u.size(2))], dim=1)
                for u in c
            ]
        )

        def create_custom_forward(module):
            # Wrapper so torch.utils.checkpoint sees a plain callable.
            def custom_forward(*inputs):
                return module(*inputs)

            return custom_forward

        for block in self.vace_blocks:
            if use_gradient_checkpointing_offload:
                # Offload checkpointed activations to CPU as well.
                with torch.autograd.graph.save_on_cpu():
                    c = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        c,
                        x,
                        context,
                        t_mod,
                        freqs,
                        use_reentrant=False,
                    )
            elif use_gradient_checkpointing:
                c = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    c,
                    x,
                    context,
                    t_mod,
                    freqs,
                    use_reentrant=False,
                )
            else:
                c = block(c, x, context, t_mod, freqs)
        # `c` is stacked as [skip_0, ..., skip_{n-1}, final_state]; the hints
        # are the skips, so drop the trailing running state.
        hints = torch.unbind(c)[:-1]
        return hints

    @staticmethod
    def state_dict_converter():
        """Return the converter used to load external checkpoints."""
        return VaceWanModelDictConverter()
class VaceWanModelDictConverter:
    """Filters external checkpoints down to the VACE side-network weights and
    supplies a matching constructor config when the checkpoint is recognized."""

    def __init__(self):
        pass

    def from_civitai(self, state_dict):
        """Return ``(filtered_state_dict, config)`` for a civitai checkpoint.

        Only entries whose key starts with ``"vace"`` are kept. If the
        key-set hash identifies the 14B VACE checkpoint, the corresponding
        model hyperparameters are returned; otherwise the config is empty.
        """
        # Keep only the VACE side-network weights.
        vace_weights = {
            key: value for key, value in state_dict.items() if key.startswith("vace")
        }
        # Recognize the 14B VACE checkpoint by hashing its key set.
        is_vace_14b = (
            hash_state_dict_keys(vace_weights) == "3b2726384e4f64837bdf216eea3f310d"
        )
        if is_vace_14b:
            config = {
                "vace_layers": (0, 5, 10, 15, 20, 25, 30, 35),
                "vace_in_dim": 96,
                "patch_size": (1, 2, 2),
                "has_image_input": False,
                "dim": 5120,
                "num_heads": 40,
                "ffn_dim": 13824,
                "eps": 1e-06,
            }
        else:
            config = {}
        return vace_weights, config