|
from __future__ import annotations |
|
|
|
import inspect |
|
import math |
|
import time |
|
import warnings |
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
|
from dataclasses import dataclass |
|
|
|
from einops import rearrange, repeat |
|
import PIL.Image |
|
import numpy as np |
|
import torch |
|
from torch import nn |
|
import torch.nn.functional as F |
|
|
|
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer |
|
from diffusers.pipelines.controlnet.pipeline_controlnet import (
    EXAMPLE_DOC_STRING,
)
|
from diffusers.pipelines.controlnet.pipeline_controlnet_img2img import ( |
|
StableDiffusionControlNetImg2ImgPipeline as DiffusersStableDiffusionControlNetImg2ImgPipeline, |
|
) |
|
from diffusers.configuration_utils import FrozenDict |
|
from diffusers.models import AutoencoderKL, ControlNetModel |
|
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel |
|
from diffusers.pipelines.stable_diffusion.safety_checker import ( |
|
StableDiffusionSafetyChecker, |
|
) |
|
|
|
from diffusers.schedulers import KarrasDiffusionSchedulers |
|
from diffusers.utils import ( |
|
deprecate, |
|
logging, |
|
BaseOutput, |
|
replace_example_docstring, |
|
) |
|
from diffusers.utils.torch_utils import is_compiled_module |
|
from diffusers.loaders import TextualInversionLoaderMixin |
|
from diffusers.models.attention import ( |
|
BasicTransformerBlock as DiffusersBasicTransformerBlock, |
|
) |
|
from mmcm.vision.process.correct_color import ( |
|
hist_match_color_video_batch, |
|
hist_match_video_bcthw, |
|
) |
|
|
|
from ..models.attention import BasicTransformerBlock |
|
from ..models.unet_3d_condition import UNet3DConditionModel |
|
from ..utils.noise_util import random_noise, video_fusion_noise |
|
from ..data.data_util import ( |
|
adaptive_instance_normalization, |
|
align_repeat_tensor_single_dim, |
|
batch_adain_conditioned_tensor, |
|
batch_concat_two_tensor_with_index, |
|
batch_index_select, |
|
fuse_part_tensor, |
|
) |
|
from ..utils.text_emb_util import encode_weighted_prompt |
|
from ..utils.tensor_util import his_match |
|
from ..utils.timesteps_util import generate_parameters_with_timesteps |
|
from .context import get_context_scheduler, prepare_global_context |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
@dataclass |
|
class VideoPipelineOutput(BaseOutput): |
|
videos: Union[torch.Tensor, np.ndarray] |
|
latents: Union[torch.Tensor, np.ndarray] |
|
videos_mid: Union[torch.Tensor, np.ndarray] |
|
    down_block_res_samples: Optional[Tuple[torch.FloatTensor]] = None
    mid_block_res_samples: Optional[torch.FloatTensor] = None
    up_block_res_samples: Optional[torch.FloatTensor] = None
    mid_video_latents: Optional[List[torch.FloatTensor]] = None
    mid_video_noises: Optional[List[torch.FloatTensor]] = None
|
|
|
|
|
def torch_dfs(model: torch.nn.Module) -> List[torch.nn.Module]:
    """Depth-first traversal returning the module itself followed by all of its sub-modules."""
    result = [model]
    for child in model.children():
        result += torch_dfs(child)
    return result
|
|
|
|
|
def prepare_image( |
|
image, |
|
batch_size, |
|
device, |
|
dtype, |
|
image_processor: Callable, |
|
num_images_per_prompt: int = 1, |
|
width=None, |
|
height=None, |
|
): |
|
    if isinstance(image, list) and isinstance(image[0], str):
        raise NotImplementedError
    if isinstance(image, list) and isinstance(image[0], np.ndarray):
        image = np.concatenate(image, axis=0)
    if isinstance(image, np.ndarray):
        image = torch.from_numpy(image)
|
if image.ndim == 5: |
|
image = rearrange(image, "b c t h w-> (b t) c h w") |
|
if height is None: |
|
height = image.shape[-2] |
|
if width is None: |
|
width = image.shape[-1] |
|
width, height = (x - x % image_processor.vae_scale_factor for x in (width, height)) |
|
if height != image.shape[-2] or width != image.shape[-1]: |
|
image = torch.nn.functional.interpolate( |
|
image, size=(height, width), mode="bilinear" |
|
) |
|
image = image.to(dtype=torch.float32) / 255.0 |
|
do_normalize = image_processor.config.do_normalize |
|
if image.min() < 0: |
|
        warnings.warn(
            "Passing `image` as a torch tensor with values in [-1, 1] is deprecated. The expected value range is [0, 1] "
            f"when passing a torch tensor or NumPy array. You passed `image` with value range [{image.min()}, {image.max()}].",
            FutureWarning,
        )
        do_normalize = False
|
|
|
if do_normalize: |
|
image = image_processor.normalize(image) |
|
|
|
image_batch_size = image.shape[0] |
|
|
|
if image_batch_size == 1: |
|
repeat_by = batch_size |
|
else: |
|
|
|
repeat_by = num_images_per_prompt |
|
|
|
image = image.repeat_interleave(repeat_by, dim=0) |
|
|
|
image = image.to(device=device, dtype=dtype) |
|
return image |
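
# Illustrative usage sketch (not executed; `pipe.image_processor` is the diffusers
# `VaeImageProcessor` attached to the pipeline):
#
#   frames = np.random.randint(0, 255, (2, 3, 8, 512, 512), dtype=np.uint8)  # b c t h w
#   pixels = prepare_image(
#       frames,
#       batch_size=2,
#       device="cuda",
#       dtype=torch.float16,
#       image_processor=pipe.image_processor,
#   )
#   # -> ((b*t), c, h, w) float tensor scaled to [0, 1] (and normalized if the processor
#   #    is configured with do_normalize=True)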
|
|
|
|
|
class MusevControlNetPipeline( |
|
DiffusersStableDiffusionControlNetImg2ImgPipeline, TextualInversionLoaderMixin |
|
): |
|
""" |
|
a union diffusers pipeline, support |
|
1. text2image model only, or text2video model, by setting skip_temporal_layer |
|
2. text2video, image2video, video2video; |
|
3. multi controlnet |
|
4. IPAdapter |
|
5. referencenet |
|
6. IPAdapterFaceID |
|
""" |
|
|
|
_optional_components = [ |
|
"safety_checker", |
|
"feature_extractor", |
|
] |
|
print_idx = 0 |
|
|
|
def __init__( |
|
self, |
|
vae: AutoencoderKL, |
|
unet: UNet3DConditionModel, |
|
scheduler: KarrasDiffusionSchedulers, |
|
controlnet: ControlNetModel |
|
| List[ControlNetModel] |
|
| Tuple[ControlNetModel] |
|
| MultiControlNetModel, |
|
text_encoder: CLIPTextModel, |
|
tokenizer: CLIPTokenizer, |
|
safety_checker: StableDiffusionSafetyChecker, |
|
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = False,
|
referencenet: nn.Module = None, |
|
vision_clip_extractor: nn.Module = None, |
|
ip_adapter_image_proj: nn.Module = None, |
|
face_emb_extractor: nn.Module = None, |
|
facein_image_proj: nn.Module = None, |
|
ip_adapter_face_emb_extractor: nn.Module = None, |
|
ip_adapter_face_image_proj: nn.Module = None, |
|
pose_guider: nn.Module = None, |
|
): |
|
super().__init__( |
|
vae, |
|
text_encoder, |
|
tokenizer, |
|
unet, |
|
controlnet, |
|
scheduler, |
|
safety_checker, |
|
feature_extractor, |
|
requires_safety_checker, |
|
) |
|
self.referencenet = referencenet |
|
|
|
|
|
if isinstance(vision_clip_extractor, nn.Module): |
|
vision_clip_extractor.to(dtype=self.unet.dtype, device=self.unet.device) |
|
self.vision_clip_extractor = vision_clip_extractor |
|
if isinstance(ip_adapter_image_proj, nn.Module): |
|
ip_adapter_image_proj.to(dtype=self.unet.dtype, device=self.unet.device) |
|
self.ip_adapter_image_proj = ip_adapter_image_proj |
|
|
|
|
|
if isinstance(face_emb_extractor, nn.Module): |
|
face_emb_extractor.to(dtype=self.unet.dtype, device=self.unet.device) |
|
self.face_emb_extractor = face_emb_extractor |
|
if isinstance(facein_image_proj, nn.Module): |
|
facein_image_proj.to(dtype=self.unet.dtype, device=self.unet.device) |
|
self.facein_image_proj = facein_image_proj |
|
|
|
|
|
if isinstance(ip_adapter_face_emb_extractor, nn.Module): |
|
ip_adapter_face_emb_extractor.to( |
|
dtype=self.unet.dtype, device=self.unet.device |
|
) |
|
self.ip_adapter_face_emb_extractor = ip_adapter_face_emb_extractor |
|
if isinstance(ip_adapter_face_image_proj, nn.Module): |
|
ip_adapter_face_image_proj.to( |
|
dtype=self.unet.dtype, device=self.unet.device |
|
) |
|
self.ip_adapter_face_image_proj = ip_adapter_face_image_proj |
|
|
|
if isinstance(pose_guider, nn.Module): |
|
pose_guider.to(dtype=self.unet.dtype, device=self.unet.device) |
|
self.pose_guider = pose_guider |
|
|
|
    def decode_latents(self, latents):
        """Decode latents of shape (b, c, f, h, w) into a pixel-space video of shape (b, c, f, H, W)."""
        batch_size = latents.shape[0]
        latents = rearrange(latents, "b c f h w -> (b f) c h w")
        video = super().decode_latents(latents=latents)
        video = rearrange(video, "(b f) h w c -> b c f h w", b=batch_size)
        return video
|
|
|
def prepare_latents( |
|
self, |
|
batch_size: int, |
|
num_channels_latents: int, |
|
video_length: int, |
|
height: int, |
|
width: int, |
|
dtype: torch.dtype, |
|
device: torch.device, |
|
generator: torch.Generator, |
|
latents: torch.Tensor = None, |
|
w_ind_noise: float = 0.5, |
|
image: torch.Tensor = None, |
|
timestep: int = None, |
|
initial_common_latent: torch.Tensor = None, |
|
noise_type: str = "random", |
|
add_latents_noise: bool = False, |
|
need_img_based_video_noise: bool = False, |
|
condition_latents: torch.Tensor = None, |
|
img_weight=1e-3, |
|
) -> torch.Tensor: |
|
""" |
|
支持多种情况下的latens: |
|
img_based_latents: 当Image t=1,latents=None时,使用image赋值到shape,然后加噪;适用于text2video、middle2video。 |
|
video_based_latents:image =shape或Latents!=None时,加噪,适用于video2video; |
|
noise_latents:当image 和latents都为None时,生成随机噪声,适用于text2video |
|
|
|
support multi latents condition: |
|
img_based_latents: when Image t=1, latents=None, use image to assign to shape, then add noise; suitable for text2video, middle2video. |
|
video_based_latents: image =shape or Latents!=None, add noise, suitable for video2video; |
|
noise_laten: when image and latents are both None, generate random noise, suitable for text2video |
|
|
|
Args: |
|
batch_size (int): _description_ |
|
num_channels_latents (int): _description_ |
|
video_length (int): _description_ |
|
height (int): _description_ |
|
width (int): _description_ |
|
dtype (torch.dtype): _description_ |
|
device (torch.device): _description_ |
|
generator (torch.Generator): _description_ |
|
latents (torch.Tensor, optional): _description_. Defaults to None. |
|
w_ind_noise (float, optional): _description_. Defaults to 0.5. |
|
image (torch.Tensor, optional): _description_. Defaults to None. |
|
timestep (int, optional): _description_. Defaults to None. |
|
initial_common_latent (torch.Tensor, optional): _description_. Defaults to None. |
|
noise_type (str, optional): _description_. Defaults to "random". |
|
add_latents_noise (bool, optional): _description_. Defaults to False. |
|
need_img_based_video_noise (bool, optional): _description_. Defaults to False. |
|
condition_latents (torch.Tensor, optional): _description_. Defaults to None. |
|
img_weight (_type_, optional): _description_. Defaults to 1e-3. |
|
|
|
Raises: |
|
ValueError: _description_ |
|
ValueError: _description_ |
|
ValueError: _description_ |
|
|
|
Returns: |
|
torch.Tensor: latents |
|
""" |
|
|
|
|
|
|
|
shape = ( |
|
batch_size, |
|
num_channels_latents, |
|
video_length, |
|
height // self.vae_scale_factor, |
|
width // self.vae_scale_factor, |
|
) |
|
        if latents is None or add_latents_noise:
            if noise_type == "random":
                noise = random_noise(
                    shape=shape, dtype=dtype, device=device, generator=generator
                )
            elif noise_type == "video_fusion":
                noise = video_fusion_noise(
                    shape=shape,
                    dtype=dtype,
                    device=device,
                    generator=generator,
                    w_ind_noise=w_ind_noise,
                    initial_common_noise=initial_common_latent,
                )
            else:
                raise ValueError(
                    f"noise_type must be 'random' or 'video_fusion', but got {noise_type}"
                )
|
if ( |
|
need_img_based_video_noise |
|
and condition_latents is not None |
|
and image is None |
|
and latents is None |
|
): |
|
if self.print_idx == 0: |
|
logger.debug( |
|
( |
|
f"need_img_based_video_noise, condition_latents={condition_latents.shape}," |
|
f"batch_size={batch_size}, noise={noise.shape}, video_length={video_length}" |
|
) |
|
) |
|
condition_latents = condition_latents.mean(dim=2, keepdim=True) |
|
condition_latents = repeat( |
|
condition_latents, "b c t h w->b c (t x) h w", x=video_length |
|
) |
|
noise = ( |
|
img_weight**0.5 * condition_latents |
|
+ (1 - img_weight) ** 0.5 * noise |
|
) |
|
if self.print_idx == 0: |
|
logger.debug(f"noise={noise.shape}") |
|
|
|
if image is not None: |
|
if image.ndim == 5: |
|
image = rearrange(image, "b c t h w->(b t) c h w") |
|
image = image.to(device=device, dtype=dtype) |
|
if isinstance(generator, list) and len(generator) != batch_size: |
|
raise ValueError( |
|
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" |
|
f" size of {batch_size}. Make sure the batch size matches the length of the generators." |
|
) |
|
|
|
if isinstance(generator, list): |
|
init_latents = [ |
|
|
|
self.vae.encode(image[i : i + 1]).latent_dist.mean |
|
for i in range(batch_size) |
|
] |
|
init_latents = torch.cat(init_latents, dim=0) |
|
else: |
|
|
|
init_latents = self.vae.encode(image).latent_dist.mean |
|
init_latents = self.vae.config.scaling_factor * init_latents |
|
|
|
if ( |
|
batch_size > init_latents.shape[0] |
|
and batch_size % init_latents.shape[0] == 0 |
|
): |
|
|
|
deprecation_message = ( |
|
f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" |
|
" images (`image`). Initial images are now duplicating to match the number of text prompts. Note" |
|
" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" |
|
" your script to pass as many initial images as text prompts to suppress this warning." |
|
) |
|
deprecate( |
|
"len(prompt) != len(image)", |
|
"1.0.0", |
|
deprecation_message, |
|
standard_warn=False, |
|
) |
|
additional_image_per_prompt = batch_size // init_latents.shape[0] |
|
init_latents = torch.cat( |
|
[init_latents] * additional_image_per_prompt, dim=0 |
|
) |
|
elif ( |
|
batch_size > init_latents.shape[0] |
|
and batch_size % init_latents.shape[0] != 0 |
|
): |
|
raise ValueError( |
|
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." |
|
) |
|
else: |
|
init_latents = torch.cat([init_latents], dim=0) |
|
            if init_latents.shape[2] != shape[3] or init_latents.shape[3] != shape[4]:
|
init_latents = torch.nn.functional.interpolate( |
|
init_latents, |
|
size=(shape[3], shape[4]), |
|
mode="bilinear", |
|
) |
|
init_latents = rearrange( |
|
init_latents, "(b t) c h w-> b c t h w", t=video_length |
|
) |
|
if self.print_idx == 0: |
|
logger.debug(f"init_latensts={init_latents.shape}") |
|
if latents is None: |
|
if image is None: |
|
latents = noise * self.scheduler.init_noise_sigma |
|
else: |
|
if self.print_idx == 0: |
|
logger.debug(f"prepare latents, image is not None") |
|
latents = self.scheduler.add_noise(init_latents, noise, timestep) |
|
else: |
|
if isinstance(latents, np.ndarray): |
|
latents = torch.from_numpy(latents) |
|
latents = latents.to(device=device, dtype=dtype) |
|
if add_latents_noise: |
|
latents = self.scheduler.add_noise(latents, noise, timestep) |
|
else: |
|
latents = latents * self.scheduler.init_noise_sigma |
|
if latents.shape != shape: |
|
raise ValueError( |
|
f"Unexpected latents shape, got {latents.shape}, expected {shape}" |
|
) |
|
latents = latents.to(device, dtype=dtype) |
|
return latents |
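
    # Latent-initialisation sketch (illustrative; keyword values are placeholders):
    #
    #   # text2video: no image, no latents -> pure noise scaled by init_noise_sigma
    #   latents = pipe.prepare_latents(b, c, t, h, w, dtype, device, generator)
    #
    #   # image2video-style: blend a single condition frame into the noise
    #   latents = pipe.prepare_latents(b, c, t, h, w, dtype, device, generator,
    #                                  need_img_based_video_noise=True,
    #                                  condition_latents=first_frame_latents, img_weight=1e-3)
    #
    #   # video2video-style: re-noise existing latents up to `timestep`
    #   latents = pipe.prepare_latents(b, c, t, h, w, dtype, device, generator,
    #                                  latents=prev_latents, add_latents_noise=True,
    #                                  timestep=latent_timestep)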
|
|
|
def prepare_image( |
|
self, |
|
image, |
|
batch_size, |
|
num_images_per_prompt, |
|
device, |
|
dtype, |
|
width=None, |
|
height=None, |
|
): |
|
return prepare_image( |
|
image=image, |
|
batch_size=batch_size, |
|
num_images_per_prompt=num_images_per_prompt, |
|
device=device, |
|
dtype=dtype, |
|
width=width, |
|
height=height, |
|
image_processor=self.image_processor, |
|
) |
|
|
|
def prepare_control_image( |
|
self, |
|
image, |
|
width, |
|
height, |
|
batch_size, |
|
num_images_per_prompt, |
|
device, |
|
dtype, |
|
do_classifier_free_guidance=False, |
|
guess_mode=False, |
|
): |
|
image = prepare_image( |
|
image=image, |
|
batch_size=batch_size, |
|
num_images_per_prompt=num_images_per_prompt, |
|
device=device, |
|
dtype=dtype, |
|
width=width, |
|
height=height, |
|
image_processor=self.control_image_processor, |
|
) |
|
if do_classifier_free_guidance and not guess_mode: |
|
image = torch.cat([image] * 2) |
|
return image |
|
|
|
def check_inputs( |
|
self, |
|
prompt, |
|
image, |
|
callback_steps, |
|
negative_prompt=None, |
|
prompt_embeds=None, |
|
negative_prompt_embeds=None, |
|
controlnet_conditioning_scale=1, |
|
control_guidance_start=0, |
|
control_guidance_end=1, |
|
): |
|
|
|
if image is not None: |
|
return super().check_inputs( |
|
prompt, |
|
image, |
|
callback_steps, |
|
negative_prompt, |
|
prompt_embeds, |
|
negative_prompt_embeds, |
|
controlnet_conditioning_scale, |
|
control_guidance_start, |
|
control_guidance_end, |
|
) |
|
|
|
def hist_match_with_vis_cond( |
|
self, video: np.ndarray, target: np.ndarray |
|
) -> np.ndarray: |
|
""" |
|
video: b c t1 h w |
|
target: b c t2(=1) h w |
|
""" |
|
video = hist_match_video_bcthw(video, target, value=255.0) |
|
return video |
|
|
|
def get_facein_image_emb( |
|
self, refer_face_image, device, dtype, batch_size, do_classifier_free_guidance |
|
): |
|
|
|
if self.print_idx == 0: |
|
logger.debug( |
|
f"face_emb_extractor={type(self.face_emb_extractor)}, facein_image_proj={type(self.facein_image_proj)}, refer_face_image={type(refer_face_image)}, " |
|
) |
|
if ( |
|
self.face_emb_extractor is not None |
|
and self.facein_image_proj is not None |
|
and refer_face_image is not None |
|
): |
|
if self.print_idx == 0: |
|
logger.debug(f"refer_face_image={refer_face_image.shape}") |
|
if isinstance(refer_face_image, np.ndarray): |
|
refer_face_image = torch.from_numpy(refer_face_image) |
|
refer_face_image_facein = refer_face_image |
|
n_refer_face_image = refer_face_image_facein.shape[2] |
|
refer_face_image_facein = rearrange( |
|
refer_face_image, "b c t h w-> (b t) h w c" |
|
) |
|
|
|
( |
|
refer_face_image_emb, |
|
refer_align_face_image, |
|
) = self.face_emb_extractor.extract_images( |
|
refer_face_image_facein, return_type="torch" |
|
) |
|
refer_face_image_emb = refer_face_image_emb.to(device=device, dtype=dtype) |
|
if self.print_idx == 0: |
|
logger.debug(f"refer_face_image_emb={refer_face_image_emb.shape}") |
|
            if refer_face_image_emb.ndim == 2:
                refer_face_image_emb = rearrange(refer_face_image_emb, "bt d-> bt 1 d")
            elif refer_face_image_emb.ndim == 4:
                refer_face_image_emb = rearrange(
                    refer_face_image_emb, "bt h w d-> bt (h w) d"
                )
|
refer_face_image_emb_bk = refer_face_image_emb |
|
refer_face_image_emb = self.facein_image_proj(refer_face_image_emb) |
|
|
|
refer_face_image_emb = rearrange( |
|
refer_face_image_emb, |
|
"(b t) n q-> b (t n) q", |
|
t=n_refer_face_image, |
|
) |
|
refer_face_image_emb = align_repeat_tensor_single_dim( |
|
refer_face_image_emb, target_length=batch_size, dim=0 |
|
) |
|
if do_classifier_free_guidance: |
|
|
|
|
|
uncond_refer_face_image_emb = self.facein_image_proj( |
|
torch.zeros_like(refer_face_image_emb_bk).to( |
|
device=device, dtype=dtype |
|
) |
|
) |
|
|
|
|
|
uncond_refer_face_image_emb = rearrange( |
|
uncond_refer_face_image_emb, |
|
"(b t) n q-> b (t n) q", |
|
t=n_refer_face_image, |
|
) |
|
uncond_refer_face_image_emb = align_repeat_tensor_single_dim( |
|
uncond_refer_face_image_emb, target_length=batch_size, dim=0 |
|
) |
|
if self.print_idx == 0: |
|
logger.debug( |
|
f"uncond_refer_face_image_emb, {uncond_refer_face_image_emb.shape}" |
|
) |
|
logger.debug(f"refer_face_image_emb, {refer_face_image_emb.shape}") |
|
refer_face_image_emb = torch.concat( |
|
[ |
|
uncond_refer_face_image_emb, |
|
refer_face_image_emb, |
|
], |
|
) |
|
else: |
|
refer_face_image_emb = None |
|
if self.print_idx == 0: |
|
logger.debug(f"refer_face_image_emb={type(refer_face_image_emb)}") |
|
|
|
return refer_face_image_emb |
|
|
|
def get_ip_adapter_face_emb( |
|
self, refer_face_image, device, dtype, batch_size, do_classifier_free_guidance |
|
): |
|
|
|
if self.print_idx == 0: |
|
logger.debug( |
|
f"face_emb_extractor={type(self.face_emb_extractor)}, ip_adapter__image_proj={type(self.facein_image_proj)}, refer_face_image={type(refer_face_image)}, " |
|
) |
|
if ( |
|
self.ip_adapter_face_emb_extractor is not None |
|
and self.ip_adapter_face_image_proj is not None |
|
and refer_face_image is not None |
|
): |
|
if self.print_idx == 0: |
|
logger.debug(f"refer_face_image={refer_face_image.shape}") |
|
if isinstance(refer_face_image, np.ndarray): |
|
refer_face_image = torch.from_numpy(refer_face_image) |
|
refer_ip_adapter_face_image = refer_face_image |
|
n_refer_face_image = refer_ip_adapter_face_image.shape[2] |
|
refer_ip_adapter_face_image = rearrange( |
|
refer_ip_adapter_face_image, "b c t h w-> (b t) h w c" |
|
) |
|
|
|
( |
|
refer_face_image_emb, |
|
refer_align_face_image, |
|
) = self.ip_adapter_face_emb_extractor.extract_images( |
|
refer_ip_adapter_face_image, return_type="torch" |
|
) |
|
refer_face_image_emb = refer_face_image_emb.to(device=device, dtype=dtype) |
|
if self.print_idx == 0: |
|
logger.debug(f"refer_face_image_emb={refer_face_image_emb.shape}") |
|
            if refer_face_image_emb.ndim == 2:
                refer_face_image_emb = rearrange(refer_face_image_emb, "bt d-> bt 1 d")
            elif refer_face_image_emb.ndim == 4:
                refer_face_image_emb = rearrange(
                    refer_face_image_emb, "bt h w d-> bt (h w) d"
                )
|
refer_face_image_emb_bk = refer_face_image_emb |
|
refer_face_image_emb = self.ip_adapter_face_image_proj(refer_face_image_emb) |
|
|
|
refer_face_image_emb = rearrange( |
|
refer_face_image_emb, |
|
"(b t) n q-> b (t n) q", |
|
t=n_refer_face_image, |
|
) |
|
refer_face_image_emb = align_repeat_tensor_single_dim( |
|
refer_face_image_emb, target_length=batch_size, dim=0 |
|
) |
|
if do_classifier_free_guidance: |
|
|
|
|
|
uncond_refer_face_image_emb = self.ip_adapter_face_image_proj( |
|
torch.zeros_like(refer_face_image_emb_bk).to( |
|
device=device, dtype=dtype |
|
) |
|
) |
|
|
|
|
|
uncond_refer_face_image_emb = rearrange( |
|
uncond_refer_face_image_emb, |
|
"(b t) n q-> b (t n) q", |
|
t=n_refer_face_image, |
|
) |
|
uncond_refer_face_image_emb = align_repeat_tensor_single_dim( |
|
uncond_refer_face_image_emb, target_length=batch_size, dim=0 |
|
) |
|
if self.print_idx == 0: |
|
logger.debug( |
|
f"uncond_refer_face_image_emb, {uncond_refer_face_image_emb.shape}" |
|
) |
|
logger.debug(f"refer_face_image_emb, {refer_face_image_emb.shape}") |
|
refer_face_image_emb = torch.concat( |
|
[ |
|
uncond_refer_face_image_emb, |
|
refer_face_image_emb, |
|
], |
|
) |
|
else: |
|
refer_face_image_emb = None |
|
if self.print_idx == 0: |
|
logger.debug(f"ip_adapter_face_emb={type(refer_face_image_emb)}") |
|
|
|
return refer_face_image_emb |
|
|
|
def get_ip_adapter_image_emb( |
|
self, |
|
ip_adapter_image, |
|
device, |
|
dtype, |
|
batch_size, |
|
do_classifier_free_guidance, |
|
height, |
|
width, |
|
): |
|
|
|
if self.print_idx == 0: |
|
logger.debug( |
|
f"vision_clip_extractor={type(self.vision_clip_extractor)}," |
|
f"ip_adapter_image_proj={type(self.ip_adapter_image_proj)}," |
|
f"ip_adapter_image={type(ip_adapter_image)}," |
|
) |
|
if self.vision_clip_extractor is not None and ip_adapter_image is not None: |
|
if self.print_idx == 0: |
|
logger.debug(f"ip_adapter_image={ip_adapter_image.shape}") |
|
if isinstance(ip_adapter_image, np.ndarray): |
|
ip_adapter_image = torch.from_numpy(ip_adapter_image) |
|
|
|
n_ip_adapter_image = ip_adapter_image.shape[2] |
|
ip_adapter_image = rearrange(ip_adapter_image, "b c t h w-> (b t) h w c") |
|
ip_adapter_image_emb = self.vision_clip_extractor.extract_images( |
|
ip_adapter_image, |
|
target_height=height, |
|
target_width=width, |
|
return_type="torch", |
|
) |
|
if ip_adapter_image_emb.ndim == 2: |
|
ip_adapter_image_emb = rearrange(ip_adapter_image_emb, "b q-> b 1 q") |
|
|
|
ip_adapter_image_emb_bk = ip_adapter_image_emb |
|
|
|
|
|
            if self.ip_adapter_image_proj is not None:
                logger.debug("ip_adapter_image_proj is not None, project ip_adapter_image_emb")
                ip_adapter_image_emb = self.ip_adapter_image_proj(ip_adapter_image_emb)
|
|
|
|
|
ip_adapter_image_emb = rearrange( |
|
ip_adapter_image_emb, |
|
"(b t) n q-> b (t n) q", |
|
t=n_ip_adapter_image, |
|
) |
|
ip_adapter_image_emb = align_repeat_tensor_single_dim( |
|
ip_adapter_image_emb, target_length=batch_size, dim=0 |
|
) |
|
if do_classifier_free_guidance: |
|
|
|
|
|
if self.ip_adapter_image_proj is not None: |
|
uncond_ip_adapter_image_emb = self.ip_adapter_image_proj( |
|
torch.zeros_like(ip_adapter_image_emb_bk).to( |
|
device=device, dtype=dtype |
|
) |
|
) |
|
if self.print_idx == 0: |
|
logger.debug( |
|
f"uncond_ip_adapter_image_emb use ip_adapter_image_proj(zero_like)" |
|
) |
|
else: |
|
uncond_ip_adapter_image_emb = torch.zeros_like(ip_adapter_image_emb) |
|
if self.print_idx == 0: |
|
logger.debug(f"uncond_ip_adapter_image_emb use zero_like") |
|
|
|
|
|
uncond_ip_adapter_image_emb = rearrange( |
|
uncond_ip_adapter_image_emb, |
|
"(b t) n q-> b (t n) q", |
|
t=n_ip_adapter_image, |
|
) |
|
uncond_ip_adapter_image_emb = align_repeat_tensor_single_dim( |
|
uncond_ip_adapter_image_emb, target_length=batch_size, dim=0 |
|
) |
|
if self.print_idx == 0: |
|
logger.debug( |
|
f"uncond_ip_adapter_image_emb, {uncond_ip_adapter_image_emb.shape}" |
|
) |
|
logger.debug(f"ip_adapter_image_emb, {ip_adapter_image_emb.shape}") |
|
|
|
ip_adapter_image_emb = torch.concat( |
|
[ |
|
uncond_ip_adapter_image_emb, |
|
ip_adapter_image_emb, |
|
], |
|
) |
|
|
|
else: |
|
ip_adapter_image_emb = None |
|
if self.print_idx == 0: |
|
logger.debug(f"ip_adapter_image_emb={type(ip_adapter_image_emb)}") |
|
return ip_adapter_image_emb |
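
    # Shape note: an IP-Adapter image of shape (b, c, t, h, w) is flattened to (b*t) frames,
    # embedded by the CLIP vision extractor, optionally projected, and regrouped to
    # (b, t*n_tokens, d); with classifier-free guidance the zero-image (unconditional) embedding
    # is concatenated in front, doubling the batch dimension.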
|
|
|
def get_referencenet_image_vae_emb( |
|
self, |
|
refer_image, |
|
batch_size, |
|
num_videos_per_prompt, |
|
device, |
|
dtype, |
|
do_classifier_free_guidance, |
|
width: int = None, |
|
height: int = None, |
|
): |
|
|
|
if self.print_idx == 0: |
|
logger.debug( |
|
f"referencenet={type(self.referencenet)}, refer_image={type(refer_image)}" |
|
) |
|
if self.referencenet is not None and refer_image is not None: |
|
n_refer_image = refer_image.shape[2] |
|
refer_image_vae = self.prepare_image( |
|
refer_image, |
|
batch_size=batch_size * num_videos_per_prompt, |
|
num_images_per_prompt=num_videos_per_prompt, |
|
device=device, |
|
dtype=dtype, |
|
width=width, |
|
height=height, |
|
) |
|
|
|
refer_image_vae_emb = self.vae.encode(refer_image_vae).latent_dist.mean |
|
refer_image_vae_emb = self.vae.config.scaling_factor * refer_image_vae_emb |
|
|
|
logger.debug(f"refer_image_vae_emb={refer_image_vae_emb.shape}") |
|
|
|
            if do_classifier_free_guidance:
                # reuse the conditional reference latent for the unconditional branch
                uncond_refer_image_vae_emb = refer_image_vae_emb
|
|
|
uncond_refer_image_vae_emb = rearrange( |
|
uncond_refer_image_vae_emb, |
|
"(b t) c h w-> b c t h w", |
|
t=n_refer_image, |
|
) |
|
|
|
refer_image_vae_emb = rearrange( |
|
refer_image_vae_emb, "(b t) c h w-> b c t h w", t=n_refer_image |
|
) |
|
refer_image_vae_emb = torch.concat( |
|
[uncond_refer_image_vae_emb, refer_image_vae_emb], dim=0 |
|
) |
|
refer_image_vae_emb = rearrange( |
|
refer_image_vae_emb, "b c t h w-> (b t) c h w" |
|
) |
|
logger.debug(f"refer_image_vae_emb={refer_image_vae_emb.shape}") |
|
else: |
|
refer_image_vae_emb = None |
|
return refer_image_vae_emb |
|
|
|
def get_referencenet_emb( |
|
self, |
|
refer_image_vae_emb, |
|
refer_image, |
|
batch_size, |
|
num_videos_per_prompt, |
|
device, |
|
dtype, |
|
ip_adapter_image_emb, |
|
do_classifier_free_guidance, |
|
prompt_embeds, |
|
ref_timestep_int: int = 0, |
|
): |
|
|
|
if self.print_idx == 0: |
|
logger.debug( |
|
f"referencenet={type(self.referencenet)}, refer_image={type(refer_image)}" |
|
) |
|
if ( |
|
self.referencenet is not None |
|
and refer_image_vae_emb is not None |
|
and refer_image is not None |
|
): |
|
n_refer_image = refer_image.shape[2] |
|
|
|
|
|
|
|
|
|
            # `ref_timestep_int` may be a plain int; wrap it so `zeros_like` receives a tensor
            ref_timestep = torch.zeros_like(
                torch.as_tensor(ref_timestep_int, device=refer_image_vae_emb.device)
            )
|
|
|
if ip_adapter_image_emb is not None: |
|
refer_prompt_embeds = ip_adapter_image_emb |
|
else: |
|
refer_prompt_embeds = prompt_embeds |
|
if self.print_idx == 0: |
|
logger.debug( |
|
f"use referencenet: n_refer_image={n_refer_image}, refer_image_vae_emb={refer_image_vae_emb.shape}, ref_timestep={ref_timestep.shape}" |
|
) |
|
if prompt_embeds is not None: |
|
logger.debug(f"prompt_embeds={prompt_embeds.shape},") |
|
|
|
|
|
|
|
|
|
|
|
|
|
referencenet_params = { |
|
"sample": refer_image_vae_emb, |
|
"encoder_hidden_states": refer_prompt_embeds, |
|
"timestep": ref_timestep, |
|
"num_frames": n_refer_image, |
|
"return_ndim": 5, |
|
} |
|
( |
|
down_block_refer_embs, |
|
mid_block_refer_emb, |
|
refer_self_attn_emb, |
|
            ) = self.referencenet(**referencenet_params)
|
else: |
|
down_block_refer_embs = None |
|
mid_block_refer_emb = None |
|
refer_self_attn_emb = None |
|
if self.print_idx == 0: |
|
logger.debug(f"down_block_refer_embs={type(down_block_refer_embs)}") |
|
logger.debug(f"mid_block_refer_emb={type(mid_block_refer_emb)}") |
|
logger.debug(f"refer_self_attn_emb={type(refer_self_attn_emb)}") |
|
return down_block_refer_embs, mid_block_refer_emb, refer_self_attn_emb |
|
|
|
def prepare_condition_latents_and_index( |
|
self, |
|
condition_images, |
|
condition_latents, |
|
video_length, |
|
batch_size, |
|
dtype, |
|
device, |
|
latent_index, |
|
vision_condition_latent_index, |
|
): |
|
|
|
if condition_images is not None and condition_latents is None: |
|
|
|
condition_latents = self.vae.encode(condition_images).latent_dist.mean |
|
condition_latents = self.vae.config.scaling_factor * condition_latents |
|
condition_latents = rearrange( |
|
condition_latents, "(b t) c h w-> b c t h w", b=batch_size |
|
) |
|
if self.print_idx == 0: |
|
logger.debug( |
|
f"condition_latents from condition_images, shape is condition_latents={condition_latents.shape}", |
|
) |
|
if condition_latents is not None: |
|
total_frames = condition_latents.shape[2] + video_length |
|
if isinstance(condition_latents, np.ndarray): |
|
condition_latents = torch.from_numpy(condition_latents) |
|
condition_latents = condition_latents.to(dtype=dtype, device=device) |
|
|
|
if vision_condition_latent_index is not None: |
|
|
|
|
|
vision_condition_latent_index_lst = [ |
|
i_v if i_v != -1 else total_frames - 1 |
|
for i_v in vision_condition_latent_index |
|
] |
|
vision_condition_latent_index = torch.LongTensor( |
|
vision_condition_latent_index_lst, |
|
).to(device=device) |
|
if self.print_idx == 0: |
|
logger.debug( |
|
f"vision_condition_latent_index {type(vision_condition_latent_index)}, {vision_condition_latent_index}" |
|
) |
|
else: |
|
|
|
vision_condition_latent_index = torch.arange( |
|
condition_latents.shape[2], dtype=torch.long, device=device |
|
) |
|
vision_condition_latent_index_lst = ( |
|
vision_condition_latent_index.tolist() |
|
) |
|
if latent_index is None: |
|
|
|
latent_index_lst = sorted( |
|
list( |
|
set(range(total_frames)) |
|
- set(vision_condition_latent_index_lst) |
|
) |
|
) |
|
latent_index = torch.LongTensor( |
|
latent_index_lst, |
|
).to(device=device) |
|
|
|
if vision_condition_latent_index is not None: |
|
vision_condition_latent_index = vision_condition_latent_index.to( |
|
device=device |
|
) |
|
if self.print_idx == 0: |
|
logger.debug( |
|
f"pipeline vision_condition_latent_index ={vision_condition_latent_index.shape}, {vision_condition_latent_index}" |
|
) |
|
if latent_index is not None: |
|
latent_index = latent_index.to(device=device) |
|
if self.print_idx == 0: |
|
logger.debug( |
|
f"pipeline latent_index ={latent_index.shape}, {latent_index}" |
|
) |
|
logger.debug(f"condition_latents={type(condition_latents)}") |
|
logger.debug(f"latent_index={type(latent_index)}") |
|
logger.debug( |
|
f"vision_condition_latent_index={type(vision_condition_latent_index)}" |
|
) |
|
return condition_latents, latent_index, vision_condition_latent_index |
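
    # Index semantics (illustrative): with one condition frame and video_length=4,
    # total_frames = 1 + 4 = 5, so by default
    #   vision_condition_latent_index -> tensor([0])
    #   latent_index                  -> tensor([1, 2, 3, 4])
    # i.e. frame 0 carries the vision condition and frames 1..4 are generated.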
|
|
|
def prepare_controlnet_and_guidance_parameter( |
|
self, control_guidance_start, control_guidance_end |
|
): |
|
controlnet = ( |
|
self.controlnet._orig_mod |
|
if is_compiled_module(self.controlnet) |
|
else self.controlnet |
|
) |
|
|
|
|
|
if not isinstance(control_guidance_start, list) and isinstance( |
|
control_guidance_end, list |
|
): |
|
control_guidance_start = len(control_guidance_end) * [ |
|
control_guidance_start |
|
] |
|
elif not isinstance(control_guidance_end, list) and isinstance( |
|
control_guidance_start, list |
|
): |
|
control_guidance_end = len(control_guidance_start) * [control_guidance_end] |
|
elif not isinstance(control_guidance_start, list) and not isinstance( |
|
control_guidance_end, list |
|
): |
|
mult = ( |
|
len(controlnet.nets) |
|
if isinstance(controlnet, MultiControlNetModel) |
|
else 1 |
|
) |
|
control_guidance_start, control_guidance_end = mult * [ |
|
control_guidance_start |
|
], mult * [control_guidance_end] |
|
return controlnet, control_guidance_start, control_guidance_end |
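
    # Broadcasting sketch (illustrative): for a MultiControlNetModel holding 2 nets and scalar
    # inputs control_guidance_start=0.0, control_guidance_end=1.0, this returns the unwrapped
    # controlnet together with ([0.0, 0.0], [1.0, 1.0]).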
|
|
|
def prepare_controlnet_guess_mode(self, controlnet, guess_mode): |
|
global_pool_conditions = ( |
|
controlnet.config.global_pool_conditions |
|
if isinstance(controlnet, ControlNetModel) |
|
else controlnet.nets[0].config.global_pool_conditions |
|
) |
|
guess_mode = guess_mode or global_pool_conditions |
|
return guess_mode |
|
|
|
def prepare_controlnet_image_and_latents( |
|
self, |
|
controlnet, |
|
width, |
|
height, |
|
batch_size, |
|
num_videos_per_prompt, |
|
device, |
|
dtype, |
|
controlnet_latents=None, |
|
controlnet_condition_latents=None, |
|
control_image=None, |
|
controlnet_condition_images=None, |
|
guess_mode=False, |
|
do_classifier_free_guidance=False, |
|
): |
|
if isinstance(controlnet, ControlNetModel): |
|
if controlnet_latents is not None: |
|
if isinstance(controlnet_latents, np.ndarray): |
|
controlnet_latents = torch.from_numpy(controlnet_latents) |
|
if controlnet_condition_latents is not None: |
|
if isinstance(controlnet_condition_latents, np.ndarray): |
|
controlnet_condition_latents = torch.from_numpy( |
|
controlnet_condition_latents |
|
) |
|
|
|
controlnet_latents = torch.concat( |
|
[controlnet_condition_latents, controlnet_latents], dim=2 |
|
) |
|
if not guess_mode and do_classifier_free_guidance: |
|
controlnet_latents = torch.concat([controlnet_latents] * 2, dim=0) |
|
controlnet_latents = rearrange( |
|
controlnet_latents, "b c t h w->(b t) c h w" |
|
) |
|
controlnet_latents = controlnet_latents.to(device=device, dtype=dtype) |
|
if self.print_idx == 0: |
|
logger.debug( |
|
f"call, controlnet_latents.shape, f{controlnet_latents.shape}" |
|
) |
|
else: |
|
|
|
if isinstance(control_image, np.ndarray): |
|
control_image = torch.from_numpy(control_image) |
|
if controlnet_condition_images is not None: |
|
if isinstance(controlnet_condition_images, np.ndarray): |
|
controlnet_condition_images = torch.from_numpy( |
|
controlnet_condition_images |
|
) |
|
control_image = torch.concatenate( |
|
[controlnet_condition_images, control_image], dim=2 |
|
) |
|
control_image = self.prepare_control_image( |
|
image=control_image, |
|
width=width, |
|
height=height, |
|
batch_size=batch_size * num_videos_per_prompt, |
|
num_images_per_prompt=num_videos_per_prompt, |
|
device=device, |
|
dtype=controlnet.dtype, |
|
do_classifier_free_guidance=do_classifier_free_guidance, |
|
guess_mode=guess_mode, |
|
) |
|
height, width = control_image.shape[-2:] |
|
if self.print_idx == 0: |
|
logger.debug(f"call, control_image.shape , {control_image.shape}") |
|
|
|
elif isinstance(controlnet, MultiControlNetModel): |
|
control_images = [] |
|
|
|
if ( |
|
controlnet_latents is not None |
|
and controlnet_condition_latents is not None |
|
): |
|
raise NotImplementedError |
|
for i, control_image_ in enumerate(control_image): |
|
if controlnet_condition_images is not None and isinstance( |
|
controlnet_condition_images, list |
|
): |
|
if isinstance(controlnet_condition_images[i], np.ndarray): |
|
control_image_ = np.concatenate( |
|
[controlnet_condition_images[i], control_image_], axis=2 |
|
) |
|
control_image_ = self.prepare_control_image( |
|
image=control_image_, |
|
width=width, |
|
height=height, |
|
batch_size=batch_size * num_videos_per_prompt, |
|
num_images_per_prompt=num_videos_per_prompt, |
|
device=device, |
|
dtype=controlnet.dtype, |
|
do_classifier_free_guidance=do_classifier_free_guidance, |
|
guess_mode=guess_mode, |
|
) |
|
|
|
control_images.append(control_image_) |
|
|
|
control_image = control_images |
|
height, width = control_image[0].shape[-2:] |
|
        else:
            raise ValueError(
                f"controlnet must be a ControlNetModel or MultiControlNetModel, but got {type(controlnet)}"
            )
|
if control_image is not None: |
|
if not isinstance(control_image, list): |
|
if self.print_idx == 0: |
|
logger.debug(f"control_image shape is {control_image.shape}") |
|
else: |
|
if self.print_idx == 0: |
|
logger.debug(f"control_image shape is {control_image[0].shape}") |
|
|
|
return control_image, controlnet_latents |
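
    # Concatenation sketch (illustrative): a `controlnet_condition_images` tensor of shape
    # (b, c, 1, h, w) is prepended to `control_image` of shape (b, c, t, h, w) along the t axis,
    # so the ControlNet sees t+1 frames that line up with the vision-condition frame plus the
    # generated frames on the UNet side.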
|
|
|
def get_controlnet_emb( |
|
self, |
|
run_controlnet, |
|
guess_mode, |
|
do_classifier_free_guidance, |
|
latents, |
|
prompt_embeds, |
|
latent_model_input, |
|
controlnet_keep, |
|
controlnet_conditioning_scale, |
|
control_image, |
|
controlnet_latents, |
|
i, |
|
t, |
|
): |
|
if run_controlnet and self.pose_guider is None: |
|
|
|
if guess_mode and do_classifier_free_guidance: |
|
|
|
control_model_input = latents |
|
control_model_input = self.scheduler.scale_model_input( |
|
control_model_input, t |
|
) |
|
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] |
|
else: |
|
control_model_input = latent_model_input |
|
controlnet_prompt_embeds = prompt_embeds |
|
if isinstance(controlnet_keep[i], list): |
|
cond_scale = [ |
|
c * s |
|
for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i]) |
|
] |
|
else: |
|
cond_scale = controlnet_conditioning_scale * controlnet_keep[i] |
|
control_model_input_reshape = rearrange( |
|
control_model_input, "b c t h w -> (b t) c h w" |
|
) |
|
logger.debug( |
|
f"control_model_input_reshape={control_model_input_reshape.shape}, controlnet_prompt_embeds={controlnet_prompt_embeds.shape}" |
|
) |
|
encoder_hidden_states_repeat = align_repeat_tensor_single_dim( |
|
controlnet_prompt_embeds, |
|
target_length=control_model_input_reshape.shape[0], |
|
dim=0, |
|
) |
|
|
|
if self.print_idx == 0: |
|
logger.debug( |
|
f"control_model_input_reshape={control_model_input_reshape.shape}, " |
|
f"encoder_hidden_states_repeat={encoder_hidden_states_repeat.shape}, " |
|
) |
|
down_block_res_samples, mid_block_res_sample = self.controlnet( |
|
control_model_input_reshape, |
|
t, |
|
encoder_hidden_states_repeat, |
|
controlnet_cond=control_image, |
|
controlnet_cond_latents=controlnet_latents, |
|
conditioning_scale=cond_scale, |
|
guess_mode=guess_mode, |
|
return_dict=False, |
|
) |
|
if self.print_idx == 0: |
|
logger.debug( |
|
f"controlnet, len(down_block_res_samples, {len(down_block_res_samples)}", |
|
) |
|
for i_tmp, tmp in enumerate(down_block_res_samples): |
|
logger.debug( |
|
f"controlnet down_block_res_samples i={i_tmp}, down_block_res_sample={tmp.shape}" |
|
) |
|
logger.debug( |
|
f"controlnet mid_block_res_sample, {mid_block_res_sample.shape}" |
|
) |
|
            if guess_mode and do_classifier_free_guidance:
                # ControlNet was inferred only on the conditional batch; prepend zeros so the
                # unconditional batch passes through unchanged.
                down_block_res_samples = [
|
torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples |
|
] |
|
mid_block_res_sample = torch.cat( |
|
[ |
|
torch.zeros_like(mid_block_res_sample), |
|
mid_block_res_sample, |
|
] |
|
) |
|
else: |
|
down_block_res_samples = None |
|
mid_block_res_sample = None |
|
|
|
return down_block_res_samples, mid_block_res_sample |
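
    # Conditioning-scale sketch (illustrative): `controlnet_keep[i]` is 1.0 while step i lies
    # inside [control_guidance_start, control_guidance_end] and 0.0 outside, so e.g.
    # control_guidance_end=0.5 multiplies `controlnet_conditioning_scale` by zero for the second
    # half of denoising and effectively switches the ControlNet residuals off.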
|
|
|
@torch.no_grad() |
|
@replace_example_docstring(EXAMPLE_DOC_STRING) |
|
def __call__( |
|
self, |
|
video_length: Optional[int], |
|
prompt: Union[str, List[str]] = None, |
|
|
|
image: Union[ |
|
torch.FloatTensor, |
|
PIL.Image.Image, |
|
np.ndarray, |
|
List[torch.FloatTensor], |
|
List[PIL.Image.Image], |
|
List[np.ndarray], |
|
] = None, |
|
control_image: Union[ |
|
torch.FloatTensor, |
|
PIL.Image.Image, |
|
np.ndarray, |
|
List[torch.FloatTensor], |
|
List[PIL.Image.Image], |
|
List[np.ndarray], |
|
] = None, |
|
|
|
condition_images: Optional[torch.FloatTensor] = None, |
|
condition_latents: Optional[torch.FloatTensor] = None, |
|
latents: Optional[torch.FloatTensor] = None, |
|
add_latents_noise: bool = False, |
|
height: Optional[int] = None, |
|
width: Optional[int] = None, |
|
strength: float = 0.8, |
|
num_inference_steps: int = 50, |
|
guidance_scale: float = 7.5, |
|
guidance_scale_end: float = None, |
|
guidance_scale_method: str = "linear", |
|
negative_prompt: Optional[Union[str, List[str]]] = None, |
|
num_videos_per_prompt: Optional[int] = 1, |
|
eta: float = 0.0, |
|
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, |
|
|
|
controlnet_condition_images: Optional[torch.FloatTensor] = None, |
|
|
|
controlnet_condition_latents: Optional[torch.FloatTensor] = None, |
|
controlnet_latents: Union[torch.FloatTensor, np.ndarray] = None, |
|
prompt_embeds: Optional[torch.FloatTensor] = None, |
|
negative_prompt_embeds: Optional[torch.FloatTensor] = None, |
|
output_type: Optional[str] = "tensor", |
|
return_dict: bool = True, |
|
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, |
|
callback_steps: int = 1, |
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None, |
|
controlnet_conditioning_scale: Union[float, List[float]] = 1.0, |
|
guess_mode: bool = False, |
|
control_guidance_start: Union[float, List[float]] = 0.0, |
|
control_guidance_end: Union[float, List[float]] = 1.0, |
|
need_middle_latents: bool = False, |
|
w_ind_noise: float = 0.5, |
|
initial_common_latent: Optional[torch.FloatTensor] = None, |
|
latent_index: torch.LongTensor = None, |
|
vision_condition_latent_index: torch.LongTensor = None, |
|
|
|
noise_type: str = "random", |
|
need_img_based_video_noise: bool = False, |
|
skip_temporal_layer: bool = False, |
|
img_weight: float = 1e-3, |
|
need_hist_match: bool = False, |
|
motion_speed: float = 8.0, |
|
        refer_image: Optional[Union[torch.Tensor, np.ndarray]] = None,
        ip_adapter_image: Optional[Union[torch.Tensor, np.ndarray]] = None,
        refer_face_image: Optional[Union[torch.Tensor, np.ndarray]] = None,
        ip_adapter_scale: float = 1.0,
        facein_scale: float = 1.0,
        ip_adapter_face_scale: float = 1.0,
        ip_adapter_face_image: Optional[Union[torch.Tensor, np.ndarray]] = None,
|
prompt_only_use_image_prompt: bool = False, |
|
|
|
record_mid_video_noises: bool = False, |
|
last_mid_video_noises: List[torch.Tensor] = None, |
|
record_mid_video_latents: bool = False, |
|
        last_mid_video_latents: List[torch.Tensor] = None,
|
video_overlap: int = 1, |
|
|
|
|
|
|
|
context_schedule="uniform", |
|
context_frames=12, |
|
context_stride=1, |
|
context_overlap=4, |
|
context_batch_size=1, |
|
interpolation_factor=1, |
|
|
|
decoder_t_segment: int = 200, |
|
): |
|
r""" |
|
旨在兼容text2video、text2image、img2img、video2video、是否有controlnet等的通用pipeline。目前仅不支持img2img、video2video。 |
|
支持多片段同时denoise,交叉部分加权平均 |
|
|
|
当 skip_temporal_layer 为 False 时, unet 起 video 生成作用;skip_temporal_layer为True时,unet起原image作用。 |
|
当controlnet的所有入参为None,等价于走的是text2video pipeline; |
|
当 condition_latents、controlnet_condition_images、controlnet_condition_latents为None时,表示不走首帧条件生成的时序condition pipeline |
|
现在没有考虑对 `num_videos_per_prompt` 的兼容性,不是1可能报错; |
|
|
|
if skip_temporal_layer is False, unet motion layer works, else unet only run text2image layers. |
|
if parameters about controlnet are None, means text2video pipeline; |
|
if ondition_latents、controlnet_condition_images、controlnet_condition_latents are None, means only run text2video without vision condition images. |
|
By now, code works well with `num_videos_per_prpmpt=1`, !=1 may be wrong. |
|
|
|
Args: |
|
prompt (`str` or `List[str]`, *optional*): |
|
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. |
|
instead. |
|
            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`, *optional*):
                The initial image(s) used to prepare the latents (img2img-style). If height and/or
                width are passed, `image` is resized accordingly.
            control_image (same types as `image`, *optional*):
                The ControlNet input condition. ControlNet uses this input to generate guidance for the
                UNet. If the type is `torch.FloatTensor`, it is passed to ControlNet as is; `PIL.Image.Image`
                is also accepted. If multiple ControlNets are specified in `__init__`, images must be passed
                as a list so that each element can be batched correctly for its ControlNet.
|
            condition_latents (`torch.FloatTensor`, *optional*):
                Vision-condition latents corresponding to `latents`, usually the first frame,
                with shape `b c t(=1) ho wo`.
            controlnet_latents (`torch.FloatTensor` or `np.ndarray`, *optional*):
                Mutually exclusive with `control_image`; if `control_image` is given, it is converted
                into `controlnet_latents`.
            controlnet_condition_images (`torch.FloatTensor`, *optional*):
                Shape `b c t(=1) ho wo`. Concatenated with `control_image` along the t dimension and
                then converted into `controlnet_latents`.
            controlnet_condition_latents (`torch.FloatTensor`, *optional*):
                Shape `b c t(=1) ho wo`. Concatenated with `controlnet_latents` along the t dimension.
|
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): |
|
The height in pixels of the generated image. |
|
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): |
|
The width in pixels of the generated image. |
|
num_inference_steps (`int`, *optional*, defaults to 50): |
|
The number of denoising steps. More denoising steps usually lead to a higher quality image at the |
|
expense of slower inference. |
|
guidance_scale (`float`, *optional*, defaults to 7.5): |
|
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). |
|
`guidance_scale` is defined as `w` of equation 2. of [Imagen |
|
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > |
|
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, |
|
usually at the expense of lower image quality. |
|
negative_prompt (`str` or `List[str]`, *optional*): |
|
The prompt or prompts not to guide the image generation. If not defined, one has to pass |
|
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is |
|
less than `1`). |
|
strength (`float`, *optional*, defaults to 0.8): |
|
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a |
|
starting point and more noise is added the higher the `strength`. The number of denoising steps depends |
|
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising |
|
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 |
|
essentially ignores `image`. |
|
num_images_per_prompt (`int`, *optional*, defaults to 1): |
|
The number of images to generate per prompt. |
|
eta (`float`, *optional*, defaults to 0.0): |
|
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to |
|
[`schedulers.DDIMScheduler`], will be ignored for others. |
|
generator (`torch.Generator` or `List[torch.Generator]`, *optional*): |
|
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) |
|
to make generation deterministic. |
|
latents (`torch.FloatTensor`, *optional*): |
|
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image |
|
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents |
|
                tensor will be generated by sampling using the supplied random `generator`.
|
prompt_embeds (`torch.FloatTensor`, *optional*): |
|
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not |
|
provided, text embeddings will be generated from `prompt` input argument. |
|
negative_prompt_embeds (`torch.FloatTensor`, *optional*): |
|
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt |
|
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input |
|
argument. |
|
output_type (`str`, *optional*, defaults to `"pil"`): |
|
The output format of the generate image. Choose between |
|
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. |
|
return_dict (`bool`, *optional*, defaults to `True`): |
|
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a |
|
plain tuple. |
|
callback (`Callable`, *optional*): |
|
A function that will be called every `callback_steps` steps during inference. The function will be |
|
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. |
|
callback_steps (`int`, *optional*, defaults to 1): |
|
The frequency at which the `callback` function will be called. If not specified, the callback will be |
|
called at every step. |
|
cross_attention_kwargs (`dict`, *optional*): |
|
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under |
|
`self.processor` in |
|
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). |
|
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): |
|
The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added |
|
to the residual in the original unet. If multiple ControlNets are specified in init, you can set the |
|
corresponding scale as a list. |
|
guess_mode (`bool`, *optional*, defaults to `False`): |
|
In this mode, the ControlNet encoder will try best to recognize the content of the input image even if |
|
you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. |
|
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): |
|
The percentage of total steps at which the controlnet starts applying. |
|
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): |
|
The percentage of total steps at which the controlnet stops applying. |
|
            skip_temporal_layer (`bool`, defaults to `False`):
                If False, the UNet runs its temporal (motion) blocks and generates video; if True,
                the temporal blocks are skipped and the UNet behaves like the original image model.
            need_img_based_video_noise (`bool`, defaults to `False`):
                Whether to expand the first-frame latents into video noise when only first-frame
                latents are available.
            num_videos_per_prompt: only 1 is supported for now.
|
|
|
Examples: |
|
|
|
        Returns:
            [`VideoPipelineOutput`] or `tuple`:
                [`VideoPipelineOutput`] if `return_dict` is True, otherwise a `tuple` whose elements are
                the generated videos, the final latents and the intermediate results.
|
""" |
|
run_controlnet = control_image is not None or controlnet_latents is not None |
|
|
|
if run_controlnet: |
|
( |
|
controlnet, |
|
control_guidance_start, |
|
control_guidance_end, |
|
) = self.prepare_controlnet_and_guidance_parameter( |
|
control_guidance_start=control_guidance_start, |
|
control_guidance_end=control_guidance_end, |
|
) |
|
|
|
|
|
self.check_inputs( |
|
prompt, |
|
control_image, |
|
callback_steps, |
|
negative_prompt, |
|
prompt_embeds, |
|
negative_prompt_embeds, |
|
controlnet_conditioning_scale, |
|
control_guidance_start, |
|
control_guidance_end, |
|
) |
|
|
|
|
|
if prompt is not None and isinstance(prompt, str): |
|
batch_size = 1 |
|
elif prompt is not None and isinstance(prompt, list): |
|
batch_size = len(prompt) |
|
else: |
|
batch_size = prompt_embeds.shape[0] |
|
|
|
device = self._execution_device |
|
dtype = self.unet.dtype |
|
|
|
|
|
|
|
|
|
do_classifier_free_guidance = guidance_scale > 1.0 |
|
|
|
if run_controlnet: |
|
if isinstance(controlnet, MultiControlNetModel) and isinstance( |
|
controlnet_conditioning_scale, float |
|
): |
|
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len( |
|
controlnet.nets |
|
) |
|
guess_mode = self.prepare_controlnet_guess_mode( |
|
controlnet=controlnet, |
|
guess_mode=guess_mode, |
|
) |
|
|
|
|
|
text_encoder_lora_scale = ( |
|
cross_attention_kwargs.get("scale", None) |
|
if cross_attention_kwargs is not None |
|
else None |
|
) |
|
if self.text_encoder is not None: |
|
prompt_embeds = encode_weighted_prompt( |
|
self, |
|
prompt, |
|
device, |
|
num_videos_per_prompt, |
|
do_classifier_free_guidance, |
|
negative_prompt, |
|
prompt_embeds=prompt_embeds, |
|
negative_prompt_embeds=negative_prompt_embeds, |
|
|
|
) |
|
logger.debug(f"use text_encoder prepare prompt_emb={prompt_embeds.shape}") |
|
else: |
|
prompt_embeds = None |
|
if image is not None: |
|
image = self.prepare_image( |
|
image, |
|
width=width, |
|
height=height, |
|
batch_size=batch_size * num_videos_per_prompt, |
|
num_images_per_prompt=num_videos_per_prompt, |
|
device=device, |
|
dtype=dtype, |
|
) |
|
if self.print_idx == 0: |
|
logger.debug(f"image={image.shape}") |
|
if condition_images is not None: |
|
condition_images = self.prepare_image( |
|
condition_images, |
|
width=width, |
|
height=height, |
|
batch_size=batch_size * num_videos_per_prompt, |
|
num_images_per_prompt=num_videos_per_prompt, |
|
device=device, |
|
dtype=dtype, |
|
) |
|
if self.print_idx == 0: |
|
logger.debug(f"condition_images={condition_images.shape}") |
|
|
|
if run_controlnet: |
|
( |
|
control_image, |
|
controlnet_latents, |
|
) = self.prepare_controlnet_image_and_latents( |
|
controlnet=controlnet, |
|
width=width, |
|
height=height, |
|
batch_size=batch_size, |
|
num_videos_per_prompt=num_videos_per_prompt, |
|
device=device, |
|
dtype=dtype, |
|
controlnet_condition_latents=controlnet_condition_latents, |
|
control_image=control_image, |
|
controlnet_condition_images=controlnet_condition_images, |
|
guess_mode=guess_mode, |
|
do_classifier_free_guidance=do_classifier_free_guidance, |
|
controlnet_latents=controlnet_latents, |
|
) |
|
|
|
|
|
self.scheduler.set_timesteps(num_inference_steps, device=device) |
|
if strength and (image is not None and latents is not None): |
|
if self.print_idx == 0: |
|
logger.debug( |
|
f"prepare timesteps, with get_timesteps strength={strength}, num_inference_steps={num_inference_steps}" |
|
) |
|
timesteps, num_inference_steps = self.get_timesteps( |
|
num_inference_steps, strength, device |
|
) |
|
else: |
|
if self.print_idx == 0: |
|
logger.debug(f"prepare timesteps, without get_timesteps") |
|
timesteps = self.scheduler.timesteps |
|
latent_timestep = timesteps[:1].repeat( |
|
batch_size * num_videos_per_prompt |
|
) |
|
|
|
( |
|
condition_latents, |
|
latent_index, |
|
vision_condition_latent_index, |
|
) = self.prepare_condition_latents_and_index( |
|
condition_images=condition_images, |
|
condition_latents=condition_latents, |
|
video_length=video_length, |
|
batch_size=batch_size, |
|
dtype=dtype, |
|
device=device, |
|
latent_index=latent_index, |
|
vision_condition_latent_index=vision_condition_latent_index, |
|
) |
|
if vision_condition_latent_index is None: |
|
n_vision_cond = 0 |
|
else: |
|
n_vision_cond = vision_condition_latent_index.shape[0] |
|
|
|
num_channels_latents = self.unet.config.in_channels |
|
if self.print_idx == 0: |
|
logger.debug(f"pipeline controlnet, start prepare latents") |
|
|
|
latents = self.prepare_latents( |
|
batch_size=batch_size * num_videos_per_prompt, |
|
num_channels_latents=num_channels_latents, |
|
video_length=video_length, |
|
height=height, |
|
width=width, |
|
dtype=dtype, |
|
device=device, |
|
generator=generator, |
|
latents=latents, |
|
image=image, |
|
timestep=latent_timestep, |
|
w_ind_noise=w_ind_noise, |
|
initial_common_latent=initial_common_latent, |
|
noise_type=noise_type, |
|
add_latents_noise=add_latents_noise, |
|
need_img_based_video_noise=need_img_based_video_noise, |
|
condition_latents=condition_latents, |
|
img_weight=img_weight, |
|
) |
|
if self.print_idx == 0: |
|
logger.debug(f"pipeline controlnet, finish prepare latents={latents.shape}") |
|
|
|
|
|
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) |
|
|
|
if noise_type == "video_fusion" and "noise_type" in set( |
|
inspect.signature(self.scheduler.step).parameters.keys() |
|
): |
|
extra_step_kwargs["w_ind_noise"] = w_ind_noise |
|
extra_step_kwargs["noise_type"] = noise_type |
|
|
|
|
|
|
|
        if run_controlnet:
            controlnet_keep = []
            for i in range(len(timesteps)):
                keeps = [
                    1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                    for s, e in zip(control_guidance_start, control_guidance_end)
                ]
                controlnet_keep.append(
                    keeps[0] if isinstance(controlnet, ControlNetModel) else keeps
                )
        else:
            controlnet_keep = None
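
        # Progress-bar warmup bookkeeping, optional temporal-layer skipping, and a
        # per-step guidance-scale schedule from guidance_scale to guidance_scale_end.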
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        if skip_temporal_layer:
            self.unet.set_skip_temporal_layers(True)

        n_timesteps = len(timesteps)
        guidance_scale_lst = generate_parameters_with_timesteps(
            start=guidance_scale,
            stop=guidance_scale_end,
            num=n_timesteps,
            method=guidance_scale_method,
        )
        if self.print_idx == 0:
            logger.debug(
                f"guidance_scale_lst, {guidance_scale_method}, {guidance_scale}, {guidance_scale_end}, {guidance_scale_lst}"
            )
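
        # Encode the reference inputs: IP-Adapter image embedding, FaceIn / IP-Adapter
        # face embeddings, and the reference image's VAE embedding for ReferenceNet.
        # If only the image prompt should drive cross-attention, the text prompt
        # embedding is replaced by the IP-Adapter embedding below.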
        ip_adapter_image_emb = self.get_ip_adapter_image_emb(
            ip_adapter_image=ip_adapter_image,
            batch_size=batch_size,
            device=device,
            dtype=dtype,
            do_classifier_free_guidance=do_classifier_free_guidance,
            height=height,
            width=width,
        )

        if (
            ip_adapter_image_emb is not None
            and prompt_only_use_image_prompt
            and not self.unet.ip_adapter_cross_attn
        ):
            prompt_embeds = ip_adapter_image_emb
            logger.debug("use ip_adapter_image_emb to replace prompt_embeds")

        refer_face_image_emb = self.get_facein_image_emb(
            refer_face_image=refer_face_image,
            batch_size=batch_size,
            device=device,
            dtype=dtype,
            do_classifier_free_guidance=do_classifier_free_guidance,
        )

        ip_adapter_face_emb = self.get_ip_adapter_face_emb(
            refer_face_image=ip_adapter_face_image,
            batch_size=batch_size,
            device=device,
            dtype=dtype,
            do_classifier_free_guidance=do_classifier_free_guidance,
        )
        refer_image_vae_emb = self.get_referencenet_image_vae_emb(
            refer_image=refer_image,
            device=device,
            dtype=dtype,
            do_classifier_free_guidance=do_classifier_free_guidance,
            num_videos_per_prompt=num_videos_per_prompt,
            batch_size=batch_size,
            width=width,
            height=height,
        )
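
        # If a pose guider is configured, encode the control images with it to get a
        # pose embedding that is passed to the UNet alongside the ControlNet residuals.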
        if self.pose_guider is not None and control_image is not None:
            if self.print_idx == 0:
                logger.debug(f"pose_guider, controlnet_image={control_image.shape}")
            control_image = rearrange(
                control_image, " (b t) c h w->b c t h w", t=video_length
            )
            pose_guider_emb = self.pose_guider(control_image)
            pose_guider_emb = rearrange(pose_guider_emb, "b c t h w-> (b t) c h w")
        else:
            pose_guider_emb = None
        logger.debug(f"prompt_embeds={prompt_embeds.shape}")
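
        # Rearrange control images / controlnet latents from (b*t, c, h, w) to
        # (b, c, t, h, w) so they can be sliced per temporal context window below.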
        if control_image is not None:
            if isinstance(control_image, list):
                logger.debug(f"control_image is list, num={len(control_image)}")
                control_image = [
                    rearrange(
                        control_image_tmp,
                        " (b t) c h w->b c t h w",
                        b=(int(do_classifier_free_guidance) * 1 + 1) * batch_size,
                    )
                    for control_image_tmp in control_image
                ]
            else:
                logger.debug(f"control_image={control_image.shape}, before")
                control_image = rearrange(
                    control_image,
                    " (b t) c h w->b c t h w",
                    b=(int(do_classifier_free_guidance) * 1 + 1) * batch_size,
                )
                logger.debug(f"control_image={control_image.shape}, after")

        if controlnet_latents is not None:
            if isinstance(controlnet_latents, list):
                logger.debug(
                    f"controlnet_latents is list, num={len(controlnet_latents)}"
                )
                controlnet_latents = [
                    rearrange(
                        controlnet_latents_tmp,
                        " (b t) c h w->b c t h w",
                        b=(int(do_classifier_free_guidance) * 1 + 1) * batch_size,
                    )
                    for controlnet_latents_tmp in controlnet_latents
                ]
            else:
                logger.debug(f"controlnet_latents={controlnet_latents.shape}, before")
                controlnet_latents = rearrange(
                    controlnet_latents,
                    " (b t) c h w->b c t h w",
                    b=(int(do_classifier_free_guidance) * 1 + 1) * batch_size,
                )
                logger.debug(f"controlnet_latents={controlnet_latents.shape}, after")
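
        # Split the frame axis into overlapping context windows following
        # context_schedule; each denoising step runs the UNet window by window and
        # averages the predictions where windows overlap.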
        videos_mid = []
        mid_video_noises = [] if record_mid_video_noises else None
        mid_video_latents = [] if record_mid_video_latents else None

        global_context = prepare_global_context(
            context_schedule=context_schedule,
            num_inference_steps=num_inference_steps,
            time_size=latents.shape[2],
            context_frames=context_frames,
            context_stride=context_stride,
            context_overlap=context_overlap,
            context_batch_size=context_batch_size,
        )
        logger.debug(
            f"context_schedule={context_schedule}, time_size={latents.shape[2]}, context_frames={context_frames}, context_stride={context_stride}, context_overlap={context_overlap}, context_batch_size={context_batch_size}"
        )
        logger.debug(f"global_context={global_context}")
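
        # Main denoising loop: per timestep, optionally blend in latents/noise recorded
        # from a previous clip (long-video continuation), compute ReferenceNet features
        # once at the first step, then denoise each context window and fuse the results.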
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if i == 0:
                    if record_mid_video_latents:
                        mid_video_latents.append(latents[:, :, -video_overlap:])
                    if record_mid_video_noises:
                        mid_video_noises.append(None)
                    if (
                        last_mid_video_latents is not None
                        and len(last_mid_video_latents) > 0
                    ):
                        if self.print_idx == 1:
                            logger.debug(
                                f"{i}, last_mid_video_latents={last_mid_video_latents[i].shape}"
                            )
                        latents = fuse_part_tensor(
                            last_mid_video_latents[0],
                            latents,
                            video_overlap,
                            weight=0.1,
                            skip_step=0,
                        )
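
                # Buffers that accumulate the per-window noise predictions and count
                # how many windows contributed to each frame, so overlaps can be averaged.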
                noise_pred = torch.zeros(
                    (
                        latents.shape[0] * (2 if do_classifier_free_guidance else 1),
                        *latents.shape[1:],
                    ),
                    device=latents.device,
                    dtype=latents.dtype,
                )
                counter = torch.zeros(
                    (1, 1, latents.shape[2], 1, 1),
                    device=latents.device,
                    dtype=latents.dtype,
                )
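
                # ReferenceNet features are computed once at the first timestep and
                # reused for every subsequent step and context window.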
                if i == 0:
                    (
                        down_block_refer_embs,
                        mid_block_refer_emb,
                        refer_self_attn_emb,
                    ) = self.get_referencenet_emb(
                        refer_image_vae_emb=refer_image_vae_emb,
                        refer_image=refer_image,
                        device=device,
                        dtype=dtype,
                        do_classifier_free_guidance=do_classifier_free_guidance,
                        num_videos_per_prompt=num_videos_per_prompt,
                        prompt_embeds=prompt_embeds,
                        ip_adapter_image_emb=ip_adapter_image_emb,
                        batch_size=batch_size,
                        ref_timestep_int=t,
                    )
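
                # Denoise this timestep window by window.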
                for context in global_context:
                    latents_c = torch.cat([latents[:, :, c] for c in context])
                    latent_index_c = (
                        torch.cat([latent_index[c] for c in context])
                        if latent_index is not None
                        else None
                    )
                    latent_model_input = latents_c.to(device).repeat(
                        2 if do_classifier_free_guidance else 1, 1, 1, 1, 1
                    )
                    latent_model_input = self.scheduler.scale_model_input(
                        latent_model_input, t
                    )
                    sub_latent_index_c = (
                        torch.LongTensor(
                            torch.arange(latent_index_c.shape[-1]) + n_vision_cond
                        ).to(device=latents_c.device)
                        if latent_index is not None
                        else None
                    )
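
                    # Splice the clean condition-frame latents back into the noisy model
                    # input along the time axis, so the UNet always sees the condition
                    # frames at their original positions.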
                    if condition_latents is not None:
                        latent_model_condition = (
                            torch.cat([condition_latents] * 2)
                            if do_classifier_free_guidance
                            else condition_latents
                        )

                        if self.print_idx == 0:
                            logger.debug(
                                f"vision_condition_latent_index, {vision_condition_latent_index.shape}, {vision_condition_latent_index}"
                            )
                            logger.debug(
                                f"latent_model_condition, {latent_model_condition.shape}"
                            )
                            logger.debug(f"latent_index, {latent_index_c.shape}")
                            logger.debug(
                                f"latent_model_input, {latent_model_input.shape}"
                            )
                            logger.debug(f"sub_latent_index_c, {sub_latent_index_c}")
                        latent_model_input = batch_concat_two_tensor_with_index(
                            data1=latent_model_condition,
                            data1_index=vision_condition_latent_index,
                            data2=latent_model_input,
                            data2_index=sub_latent_index_c,
                            dim=2,
                        )
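
                    # Slice the ControlNet conditioning (images or latents) to the frames
                    # of this context window, prepending the vision-condition frame
                    # indices when condition frames are part of the input.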
                    if control_image is not None:
                        if vision_condition_latent_index is not None:
                            controlnet_condition_latent_index = (
                                vision_condition_latent_index.clone().cpu().tolist()
                            )
                            if self.print_idx == 0:
                                logger.debug(
                                    f"context={context}, controlnet_condition_latent_index={controlnet_condition_latent_index}"
                                )
                            controlnet_context = [
                                controlnet_condition_latent_index
                                + [c_i + n_vision_cond for c_i in c]
                                for c in context
                            ]
                        else:
                            controlnet_context = context
                        if self.print_idx == 0:
                            logger.debug(
                                f"controlnet_context={controlnet_context}, latent_model_input={latent_model_input.shape}"
                            )
                        if isinstance(control_image, list):
                            control_image_c = [
                                torch.cat(
                                    [
                                        control_image_tmp[:, :, c]
                                        for c in controlnet_context
                                    ]
                                )
                                for control_image_tmp in control_image
                            ]
                            control_image_c = [
                                rearrange(control_image_tmp, " b c t h w-> (b t) c h w")
                                for control_image_tmp in control_image_c
                            ]
                        else:
                            control_image_c = torch.cat(
                                [control_image[:, :, c] for c in controlnet_context]
                            )
                            control_image_c = rearrange(
                                control_image_c, " b c t h w-> (b t) c h w"
                            )
                    else:
                        control_image_c = None
                    if controlnet_latents is not None:
                        if vision_condition_latent_index is not None:
                            controlnet_condition_latent_index = (
                                vision_condition_latent_index.clone().cpu().tolist()
                            )
                            if self.print_idx == 0:
                                logger.debug(
                                    f"context={context}, controlnet_condition_latent_index={controlnet_condition_latent_index}"
                                )
                            controlnet_context = [
                                controlnet_condition_latent_index
                                + [c_i + n_vision_cond for c_i in c]
                                for c in context
                            ]
                        else:
                            controlnet_context = context
                        if self.print_idx == 0:
                            logger.debug(
                                f"controlnet_context={controlnet_context}, controlnet_latents={controlnet_latents.shape}, latent_model_input={latent_model_input.shape}"
                            )
                        controlnet_latents_c = torch.cat(
                            [controlnet_latents[:, :, c] for c in controlnet_context]
                        )
                        controlnet_latents_c = rearrange(
                            controlnet_latents_c, " b c t h w-> (b t) c h w"
                        )
                    else:
                        controlnet_latents_c = None
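
                    # Run the ControlNet branch to get the down/mid residuals that are
                    # injected into the UNet for this window.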
                    (
                        down_block_res_samples,
                        mid_block_res_sample,
                    ) = self.get_controlnet_emb(
                        run_controlnet=run_controlnet,
                        guess_mode=guess_mode,
                        do_classifier_free_guidance=do_classifier_free_guidance,
                        latents=latents_c,
                        prompt_embeds=prompt_embeds,
                        latent_model_input=latent_model_input,
                        control_image=control_image_c,
                        controlnet_latents=controlnet_latents_c,
                        controlnet_keep=controlnet_keep,
                        t=t,
                        i=i,
                        controlnet_conditioning_scale=controlnet_conditioning_scale,
                    )
                    if self.print_idx == 0:
                        logger.debug(
                            f"{i}, latent_model_input={latent_model_input.shape}, sub_latent_index_c={sub_latent_index_c}, "
                            f"vision_condition_latent_index={vision_condition_latent_index}"
                        )
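
                    # UNet forward pass for this window: ReferenceNet features,
                    # IP-Adapter / face embeddings, pose-guider embedding and the
                    # ControlNet residuals are all passed in as extra conditioning.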
                    noise_pred_c = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=prompt_embeds,
                        cross_attention_kwargs=cross_attention_kwargs,
                        down_block_additional_residuals=down_block_res_samples,
                        mid_block_additional_residual=mid_block_res_sample,
                        return_dict=False,
                        sample_index=sub_latent_index_c,
                        vision_conditon_frames_sample_index=vision_condition_latent_index,
                        sample_frame_rate=motion_speed,
                        down_block_refer_embs=down_block_refer_embs,
                        mid_block_refer_emb=mid_block_refer_emb,
                        refer_self_attn_emb=refer_self_attn_emb,
                        vision_clip_emb=ip_adapter_image_emb,
                        face_emb=refer_face_image_emb,
                        ip_adapter_scale=ip_adapter_scale,
                        facein_scale=facein_scale,
                        ip_adapter_face_emb=ip_adapter_face_emb,
                        ip_adapter_face_scale=ip_adapter_face_scale,
                        do_classifier_free_guidance=do_classifier_free_guidance,
                        pose_guider_emb=pose_guider_emb,
                    )[0]
                    if condition_latents is not None:
                        noise_pred_c = batch_index_select(
                            noise_pred_c, dim=2, index=sub_latent_index_c
                        ).contiguous()
                    if self.print_idx == 0:
                        logger.debug(
                            f"{i}, latent_model_input={latent_model_input.shape}, noise_pred_c={noise_pred_c.shape}, {len(context)}, {len(context[0])}"
                        )
                    for j, c in enumerate(context):
                        noise_pred[:, :, c] = noise_pred[:, :, c] + noise_pred_c
                        counter[:, :, c] = counter[:, :, c] + 1
                noise_pred = noise_pred / counter
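
                # When continuing from a previous clip, blend the overlap frames' noise
                # prediction with the noise recorded from that clip (only during the
                # first half of the steps), and record this clip's overlap noise.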
                if (
                    last_mid_video_noises is not None
                    and len(last_mid_video_noises) > 0
                    and i <= num_inference_steps // 2
                ):
                    if self.print_idx == 1:
                        logger.debug(
                            f"{i}, last_mid_video_noises={last_mid_video_noises[i].shape}"
                        )
                    noise_pred = fuse_part_tensor(
                        last_mid_video_noises[i + 1],
                        noise_pred,
                        video_overlap,
                        weight=0.01,
                        skip_step=1,
                    )
                if record_mid_video_noises:
                    mid_video_noises.append(noise_pred[:, :, -video_overlap:])

                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale_lst[i] * (
                        noise_pred_text - noise_pred_uncond
                    )

                if self.print_idx == 0:
                    logger.debug(
                        f"before step, noise_pred={noise_pred.shape}, {noise_pred.device}, latents={latents.shape}, {latents.device}, t={t}"
                    )

                latents = self.scheduler.step(
                    noise_pred,
                    t,
                    latents,
                    **extra_step_kwargs,
                ).prev_sample
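
                # Likewise blend the overlap frames of the updated latents with the
                # previous clip's recorded latents (first two steps only), record the
                # overlap latents if requested, optionally decode intermediate videos,
                # and update the progress bar / callback.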
                if (
                    last_mid_video_latents is not None
                    and len(last_mid_video_latents) > 0
                    and i <= 1
                ):
                    if self.print_idx == 1:
                        logger.debug(
                            f"{i}, last_mid_video_latents={last_mid_video_latents[i].shape}"
                        )
                    latents = fuse_part_tensor(
                        last_mid_video_latents[i + 1],
                        latents,
                        video_overlap,
                        weight=0.1,
                        skip_step=0,
                    )
                if record_mid_video_latents:
                    mid_video_latents.append(latents[:, :, -video_overlap:])

                if need_middle_latents is True:
                    videos_mid.append(self.decode_latents(latents))

                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)
                self.print_idx += 1
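
        # Put the clean condition-frame latents back at their positions in the final
        # latent video before decoding.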
        if condition_latents is not None:
            latents = batch_concat_two_tensor_with_index(
                data1=condition_latents,
                data1_index=vision_condition_latent_index,
                data2=latents,
                data2_index=latent_index,
                dim=2,
            )
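
        # Decode the latents to video in temporal segments so the VAE never decodes
        # more than decoder_t_segment frames at once.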
        b, c, t, h, w = latents.shape
        num_segments = (t + decoder_t_segment - 1) // decoder_t_segment

        video_segments = []

        for i in range(num_segments):
            logger.debug(f"Decoding segment {i}")
            start_t = i * decoder_t_segment
            end_t = min((i + 1) * decoder_t_segment, t)
            latents_segment = latents[:, :, start_t:end_t, :, :]
            video_segment = self.decode_latents(latents_segment)
            video_segments.append(video_segment)
        video_segments_np = np.concatenate(video_segments, axis=2)
        video = torch.from_numpy(video_segments_np)
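
        # Restore the temporal layers if they were skipped, and optionally
        # histogram-match the generated frames to the vision-condition frames.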
        if skip_temporal_layer:
            self.unet.set_skip_temporal_layers(False)
        if need_hist_match:
            video[:, :, latent_index, :, :] = self.hist_match_with_vis_cond(
                batch_index_select(video, index=latent_index, dim=2),
                batch_index_select(video, index=vision_condition_latent_index, dim=2),
            )

        if output_type == "tensor":
            videos_mid = [torch.from_numpy(x) for x in videos_mid]
            if not isinstance(video, torch.Tensor):
                video = torch.from_numpy(video)
        else:
            latents = latents.cpu().numpy()

        if not return_dict:
            return (
                video,
                latents,
                videos_mid,
                mid_video_latents,
                mid_video_noises,
            )

        return VideoPipelineOutput(
            videos=video,
            latents=latents,
            videos_mid=videos_mid,
            mid_video_latents=mid_video_latents,
            mid_video_noises=mid_video_noises,
        )
