|
import torch |
|
from ..models import SDUNet, SDMotionModel |
|
from ..models.sd_unet import PushBlock, PopBlock, ResnetBlock, AttentionBlock |
|
from ..models.tiler import TileWorker |
|
from ..controlnets import MultiControlNetManager |
|
|
|
|
|
def _collect_controlnet_residuals(
    controlnet,
    sample,
    timestep,
    encoder_hidden_states,
    controlnet_frames,
    controlnet_batch_size,
    tiled,
    tile_size,
    tile_stride,
    vram_limit_level,
):
    """Run ``controlnet`` over ``sample`` in chunks and merge the per-chunk outputs.

    ``sample`` and ``encoder_hidden_states`` are sliced along dim 0 and
    ``controlnet_frames`` along dim 1, ``controlnet_batch_size`` entries at a
    time. Each call is expected to return a sequence of tensors; the i-th
    tensors of all chunks are concatenated along dim 0.

    Returns:
        A list with one concatenated tensor per position of the sequence
        returned by ``controlnet``. The tensors live on the CPU when
        ``vram_limit_level >= 1``.
    """
    num_samples = sample.shape[0]
    per_chunk_stacks = []
    for start in range(0, num_samples, controlnet_batch_size):
        stop = min(start + controlnet_batch_size, num_samples)
        chunk_stack = controlnet(
            sample[start:stop],
            timestep,
            encoder_hidden_states[start:stop],
            controlnet_frames[:, start:stop],
            tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
        )
        if vram_limit_level >= 1:
            # Park the residuals on the CPU until they are merged into the UNet.
            chunk_stack = [res.cpu() for res in chunk_stack]
        per_chunk_stacks.append(chunk_stack)

    # Transpose "list of chunks, each a list of residuals" into
    # "list of residuals, each a tuple of chunks" and stitch the chunks together.
    return [torch.concat(chunks, dim=0) for chunks in zip(*per_chunk_stacks)]


def lets_dance(
    unet: SDUNet,
    motion_modules: SDMotionModel = None,
    controlnet: MultiControlNetManager = None,
    sample = None,
    timestep = None,
    encoder_hidden_states = None,
    controlnet_frames = None,
    unet_batch_size = 1,
    controlnet_batch_size = 1,
    cross_frame_attention = False,
    tiled=False,
    tile_size=64,
    tile_stride=32,
    device = "cuda",
    vram_limit_level = 0,
):
    """Run one UNet forward pass with optional ControlNet and motion modules.

    Args:
        unet: Model providing ``time_proj``, ``time_embedding``, ``conv_in``,
            ``blocks``, ``conv_norm_out``, ``conv_act`` and ``conv_out``.
        motion_modules: Optional container exposing ``call_block_id`` (mapping a
            UNet block index to a motion-module index) and ``motion_modules``.
            When given, the mapped module is called right after the matching
            UNet block with ``batch_size=1``.
        controlnet: Optional callable returning a sequence of residual tensors.
            Only used when ``controlnet_frames`` is given as well.
        sample: Input tensor; dim 0 is the axis that is split into chunks.
        timestep: Passed to ``unet.time_proj`` as ``timestep[None]``
            (presumably a scalar tensor -- confirm against the caller).
        encoder_hidden_states: Conditioning tensor, sliced along dim 0 in step
            with ``sample``.
        controlnet_frames: ControlNet inputs, sliced along dim 1 in step with
            ``sample``.
        unet_batch_size: Chunk size along dim 0 for blocks that are neither
            ``PushBlock`` nor ``PopBlock``.
        controlnet_batch_size: Chunk size along dim 0 for the ControlNet call.
        cross_frame_attention: Forwarded to the chunked UNet blocks.
        tiled: Forwarded to the ControlNet and to the chunked UNet blocks.
        tile_size: Forwarded together with ``tiled``.
        tile_stride: Forwarded together with ``tiled``.
        device: Device tensors are moved to when they are needed for compute
            after having been parked on the CPU.
        vram_limit_level: When ``>= 1``, ControlNet residuals and residual-stack
            entries are kept on the CPU between uses.

    Returns:
        The tensor produced by ``unet.conv_out`` after the final norm and
        activation.
    """
    offload_to_cpu = vram_limit_level >= 1

    # 1. ControlNet
    # Index of the UNet block after which the ControlNet residuals are merged
    # (NOTE(review): presumably tied to the layout of ``unet.blocks`` -- confirm).
    controlnet_insert_block_id = 30
    if controlnet is not None and controlnet_frames is not None:
        additional_res_stack = _collect_controlnet_residuals(
            controlnet,
            sample,
            timestep,
            encoder_hidden_states,
            controlnet_frames,
            controlnet_batch_size,
            tiled,
            tile_size,
            tile_stride,
            vram_limit_level,
        )
    else:
        additional_res_stack = None

    # 2. Time embedding
    time_emb = unet.time_proj(timestep[None]).to(sample.dtype)
    time_emb = unet.time_embedding(time_emb)

    # 3. Pre-process
    hidden_states = unet.conv_in(sample)
    text_emb = encoder_hidden_states
    res_stack = [hidden_states.cpu() if offload_to_cpu else hidden_states]

    # 4. Blocks
    num_samples = sample.shape[0]
    for block_id, block in enumerate(unet.blocks):
        # 4.1 UNet block
        if isinstance(block, PushBlock):
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
            if offload_to_cpu:
                # Keep the top of the residual stack off the GPU until it is needed.
                res_stack[-1] = res_stack[-1].cpu()
        elif isinstance(block, PopBlock):
            if offload_to_cpu:
                # Bring the top of the residual stack back before the block runs.
                res_stack[-1] = res_stack[-1].to(device)
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
        else:
            # Every other block is evaluated chunk by chunk along dim 0; only
            # the hidden states of each call are kept.
            hidden_states_input = hidden_states
            hidden_states_output = []
            for start in range(0, num_samples, unet_batch_size):
                stop = min(start + unet_batch_size, num_samples)
                chunk_output, _, _, _ = block(
                    hidden_states_input[start:stop],
                    time_emb,
                    text_emb[start:stop],
                    res_stack,
                    cross_frame_attention=cross_frame_attention,
                    tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
                )
                hidden_states_output.append(chunk_output)
            hidden_states = torch.concat(hidden_states_output, dim=0)

        # 4.2 Motion module attached to this block, if any
        if motion_modules is not None and block_id in motion_modules.call_block_id:
            motion_module_id = motion_modules.call_block_id[block_id]
            hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id](
                hidden_states, time_emb, text_emb, res_stack,
                batch_size=1,
            )

        # 4.3 Merge the ControlNet residuals
        if block_id == controlnet_insert_block_id and additional_res_stack is not None:
            # The last residual is added to the current hidden states (in place);
            # the remaining ones are added pairwise to the residual stack.
            hidden_states += additional_res_stack.pop().to(device)
            if offload_to_cpu:
                res_stack = [
                    (res.to(device) + additional_res.to(device)).cpu()
                    for res, additional_res in zip(res_stack, additional_res_stack)
                ]
            else:
                res_stack = [
                    res + additional_res
                    for res, additional_res in zip(res_stack, additional_res_stack)
                ]

    # 5. Output
    hidden_states = unet.conv_norm_out(hidden_states)
    hidden_states = unet.conv_act(hidden_states)
    hidden_states = unet.conv_out(hidden_states)

    return hidden_states
|
|