import os
import time
import random
import functools
from typing import List, Optional, Tuple, Union
from pathlib import Path
from loguru import logger

import torch
import torch.distributed as dist

from hyvideo.constants import PROMPT_TEMPLATE, NEGATIVE_PROMPT, PRECISION_TO_TYPE
from hyvideo.vae import load_vae
from hyvideo.modules import load_model
from hyvideo.text_encoder import TextEncoder
from hyvideo.utils.data_utils import align_to
from hyvideo.modules.posemb_layers import get_nd_rotary_pos_embed
from hyvideo.modules.fp8_optimization import convert_fp8_linear
from hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler
from hyvideo.diffusion.pipelines import HunyuanVideoPipeline
try:
    import xfuser
    from xfuser.core.distributed import (
        get_sequence_parallel_world_size,
        get_sequence_parallel_rank,
        get_sp_group,
        initialize_model_parallel,
        init_distributed_environment,
    )
except ImportError:
    # xfuser is optional; it is only required for multi-GPU sequence parallelism.
    xfuser = None
    get_sequence_parallel_world_size = None
    get_sequence_parallel_rank = None
    get_sp_group = None
    initialize_model_parallel = None
    init_distributed_environment = None
def parallelize_transformer(pipe):
    transformer = pipe.transformer
    original_forward = transformer.forward

    def new_forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,  # Should be in range(0, 1000).
        text_states: torch.Tensor = None,
        text_mask: torch.Tensor = None,  # Now we don't use it.
        text_states_2: Optional[torch.Tensor] = None,  # Text embedding for modulation.
        freqs_cos: Optional[torch.Tensor] = None,
        freqs_sin: Optional[torch.Tensor] = None,
        guidance: torch.Tensor = None,  # Guidance for modulation, should be cfg_scale x 1000.
        return_dict: bool = True,
    ):
        # Decide which spatial dimension can be sharded evenly across sequence-parallel ranks.
        if x.shape[-2] // 2 % get_sequence_parallel_world_size() == 0:
            # Split the latent along the height dimension.
            split_dim = -2
        elif x.shape[-1] // 2 % get_sequence_parallel_world_size() == 0:
            # Split the latent along the width dimension.
            split_dim = -1
        else:
            raise ValueError(
                f"Cannot split video sequence into ulysses_degree x ring_degree "
                f"({get_sequence_parallel_world_size()}) parts evenly"
            )

        # Patch sizes for the temporal, height, and width dimensions are 1, 2, and 2.
        temporal_size, h, w = x.shape[2], x.shape[3] // 2, x.shape[4] // 2

        x = torch.chunk(x, get_sequence_parallel_world_size(), dim=split_dim)[
            get_sequence_parallel_rank()
        ]

        # Shard the rotary embeddings the same way as the latent.
        dim_thw = freqs_cos.shape[-1]
        freqs_cos = freqs_cos.reshape(temporal_size, h, w, dim_thw)
        freqs_cos = torch.chunk(
            freqs_cos, get_sequence_parallel_world_size(), dim=split_dim - 1
        )[get_sequence_parallel_rank()]
        freqs_cos = freqs_cos.reshape(-1, dim_thw)

        dim_thw = freqs_sin.shape[-1]
        freqs_sin = freqs_sin.reshape(temporal_size, h, w, dim_thw)
        freqs_sin = torch.chunk(
            freqs_sin, get_sequence_parallel_world_size(), dim=split_dim - 1
        )[get_sequence_parallel_rank()]
        freqs_sin = freqs_sin.reshape(-1, dim_thw)

        from xfuser.core.long_ctx_attention import xFuserLongContextAttention

        for block in transformer.double_blocks + transformer.single_blocks:
            block.hybrid_seq_parallel_attn = xFuserLongContextAttention()

        output = original_forward(
            x,
            t,
            text_states,
            text_mask,
            text_states_2,
            freqs_cos,
            freqs_sin,
            guidance,
            return_dict,
        )

        return_dict = not isinstance(output, tuple)
        # Gather the sharded denoised output back from all sequence-parallel ranks.
        sample = output["x"]
        sample = get_sp_group().all_gather(sample, dim=split_dim)
        output["x"] = sample

        return output

    new_forward = new_forward.__get__(transformer)
    transformer.forward = new_forward
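
# `parallelize_transformer` monkey-patches `pipe.transformer.forward` so that the
# latent and its rotary embeddings are sharded across xfuser sequence-parallel
# ranks and the denoised output is all-gathered afterwards. It is applied in
# `HunyuanVideoSampler.__init__` whenever ulysses_degree > 1 or ring_degree > 1,
# so it normally does not need to be called by hand.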

class Inference(object):
    def __init__(
        self,
        args,
        vae,
        vae_kwargs,
        text_encoder,
        model,
        text_encoder_2=None,
        pipeline=None,
        use_cpu_offload=False,
        device=None,
        logger=None,
        parallel_args=None,
    ):
        self.vae = vae
        self.vae_kwargs = vae_kwargs
        self.text_encoder = text_encoder
        self.text_encoder_2 = text_encoder_2
        self.model = model
        self.pipeline = pipeline
        self.use_cpu_offload = use_cpu_offload
        self.args = args
        self.device = (
            device
            if device is not None
            else "cuda"
            if torch.cuda.is_available()
            else "cpu"
        )
        self.logger = logger
        self.parallel_args = parallel_args
    @classmethod
    def from_pretrained(cls, pretrained_model_path, args, device=None, **kwargs):
        """
        Initialize the Inference pipeline.

        Args:
            pretrained_model_path (str or pathlib.Path): The model path, including t2v, text encoder and vae checkpoints.
            args (argparse.Namespace): The arguments for the pipeline.
            device (int): The device for inference. Default is 0.
        """
        # ========================================================================
        logger.info(f"Got text-to-video model root path: {pretrained_model_path}")

        # ==================== Initialize Distributed Environment ================
        if args.ulysses_degree > 1 or args.ring_degree > 1:
            assert xfuser is not None, \
                "Ulysses Attention and Ring Attention require the xfuser package."
            assert args.use_cpu_offload is False, \
                "Cannot enable use_cpu_offload in the distributed environment."
            dist.init_process_group("nccl")
            assert dist.get_world_size() == args.ring_degree * args.ulysses_degree, \
                "Number of GPUs should be equal to ring_degree * ulysses_degree."
            init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
            initialize_model_parallel(
                sequence_parallel_degree=dist.get_world_size(),
                ring_degree=args.ring_degree,
                ulysses_degree=args.ulysses_degree,
            )
            device = torch.device(f"cuda:{os.environ['LOCAL_RANK']}")
        else:
            if device is None:
                device = "cuda" if torch.cuda.is_available() else "cpu"

        parallel_args = {"ulysses_degree": args.ulysses_degree, "ring_degree": args.ring_degree}

        # ======================== Get the args path =============================
        # Disable gradient
        torch.set_grad_enabled(False)

        # =========================== Build main model ===========================
        logger.info("Building model...")
        factor_kwargs = {"device": device, "dtype": PRECISION_TO_TYPE[args.precision]}
        in_channels = args.latent_channels
        out_channels = args.latent_channels

        model = load_model(
            args,
            in_channels=in_channels,
            out_channels=out_channels,
            factor_kwargs=factor_kwargs,
        )
        if args.use_fp8:
            convert_fp8_linear(model, args.dit_weight, original_dtype=PRECISION_TO_TYPE[args.precision])
        model = model.to(device)
        model = Inference.load_state_dict(args, model, pretrained_model_path)
        model.eval()

        # ============================= Build extra models ========================
        # VAE
        vae, _, s_ratio, t_ratio = load_vae(
            args.vae,
            args.vae_precision,
            logger=logger,
            device=device if not args.use_cpu_offload else "cpu",
        )
        vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}

        # Text encoder
        if args.prompt_template_video is not None:
            crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get(
                "crop_start", 0
            )
        elif args.prompt_template is not None:
            crop_start = PROMPT_TEMPLATE[args.prompt_template].get("crop_start", 0)
        else:
            crop_start = 0
        max_length = args.text_len + crop_start

        # prompt_template
        prompt_template = (
            PROMPT_TEMPLATE[args.prompt_template]
            if args.prompt_template is not None
            else None
        )

        # prompt_template_video
        prompt_template_video = (
            PROMPT_TEMPLATE[args.prompt_template_video]
            if args.prompt_template_video is not None
            else None
        )

        text_encoder = TextEncoder(
            text_encoder_type=args.text_encoder,
            max_length=max_length,
            text_encoder_precision=args.text_encoder_precision,
            tokenizer_type=args.tokenizer,
            prompt_template=prompt_template,
            prompt_template_video=prompt_template_video,
            hidden_state_skip_layer=args.hidden_state_skip_layer,
            apply_final_norm=args.apply_final_norm,
            reproduce=args.reproduce,
            logger=logger,
            device=device if not args.use_cpu_offload else "cpu",
        )
        text_encoder_2 = None
        if args.text_encoder_2 is not None:
            text_encoder_2 = TextEncoder(
                text_encoder_type=args.text_encoder_2,
                max_length=args.text_len_2,
                text_encoder_precision=args.text_encoder_precision_2,
                tokenizer_type=args.tokenizer_2,
                reproduce=args.reproduce,
                logger=logger,
                device=device if not args.use_cpu_offload else "cpu",
            )

        return cls(
            args=args,
            vae=vae,
            vae_kwargs=vae_kwargs,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            model=model,
            use_cpu_offload=args.use_cpu_offload,
            device=device,
            logger=logger,
            parallel_args=parallel_args,
        )
    @staticmethod
    def load_state_dict(args, model, pretrained_model_path):
        load_key = args.load_key
        dit_weight = Path(args.dit_weight) if args.dit_weight is not None else None

        if dit_weight is None:
            # No explicit --dit-weight given: look inside the default model directory.
            model_dir = pretrained_model_path / f"t2v_{args.model_resolution}"
            files = list(model_dir.glob("*.pt"))
            if len(files) == 0:
                raise ValueError(f"No model weights found in {model_dir}")
            if files[0].name.startswith("pytorch_model_"):
                model_path = model_dir / f"pytorch_model_{load_key}.pt"
                bare_model = True
            elif any(str(f).endswith("_model_states.pt") for f in files):
                files = [f for f in files if str(f).endswith("_model_states.pt")]
                model_path = files[0]
                if len(files) > 1:
                    logger.warning(
                        f"Multiple model weights found in {model_dir}, using {model_path}"
                    )
                bare_model = False
            else:
                raise ValueError(
                    f"Invalid model path: {model_dir} with unrecognized weight format: "
                    f"{list(map(str, files))}. When given a directory as --dit-weight, only "
                    f"`pytorch_model_*.pt` (provided by HunyuanDiT official) and "
                    f"`*_model_states.pt` (saved by deepspeed) can be parsed. If you want to load a "
                    f"specific weight file, please provide the full path to the file."
                )
        else:
            if dit_weight.is_dir():
                files = list(dit_weight.glob("*.pt"))
                if len(files) == 0:
                    raise ValueError(f"No model weights found in {dit_weight}")
                if files[0].name.startswith("pytorch_model_"):
                    model_path = dit_weight / f"pytorch_model_{load_key}.pt"
                    bare_model = True
                elif any(str(f).endswith("_model_states.pt") for f in files):
                    files = [f for f in files if str(f).endswith("_model_states.pt")]
                    model_path = files[0]
                    if len(files) > 1:
                        logger.warning(
                            f"Multiple model weights found in {dit_weight}, using {model_path}"
                        )
                    bare_model = False
                else:
                    raise ValueError(
                        f"Invalid model path: {dit_weight} with unrecognized weight format: "
                        f"{list(map(str, files))}. When given a directory as --dit-weight, only "
                        f"`pytorch_model_*.pt` (provided by HunyuanDiT official) and "
                        f"`*_model_states.pt` (saved by deepspeed) can be parsed. If you want to load a "
                        f"specific weight file, please provide the full path to the file."
                    )
            elif dit_weight.is_file():
                model_path = dit_weight
                bare_model = "unknown"
            else:
                raise ValueError(f"Invalid model path: {dit_weight}")

        if not model_path.exists():
            raise ValueError(f"model_path does not exist: {model_path}")
        logger.info(f"Loading torch model {model_path}...")
        state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)

        if bare_model == "unknown" and ("ema" in state_dict or "module" in state_dict):
            bare_model = False
        if bare_model is False:
            if load_key in state_dict:
                state_dict = state_dict[load_key]
            else:
                raise KeyError(
                    f"Missing key: `{load_key}` in the checkpoint: {model_path}. The keys in the checkpoint "
                    f"are: {list(state_dict.keys())}."
                )
        model.load_state_dict(state_dict, strict=True)
        return model
    @staticmethod
    def parse_size(size):
        if isinstance(size, int):
            size = [size]
        if not isinstance(size, (list, tuple)):
            raise ValueError(f"Size must be an integer or (height, width), got {size}.")
        if len(size) == 1:
            size = [size[0], size[0]]
        if len(size) != 2:
            raise ValueError(f"Size must be an integer or (height, width), got {size}.")
        return size
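
    # Informal examples: parse_size(720) -> [720, 720];
    # parse_size((540, 960)) is returned unchanged as a (height, width) pair.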

class HunyuanVideoSampler(Inference):
    def __init__(
        self,
        args,
        vae,
        vae_kwargs,
        text_encoder,
        model,
        text_encoder_2=None,
        pipeline=None,
        use_cpu_offload=False,
        device=0,
        logger=None,
        parallel_args=None,
    ):
        super().__init__(
            args,
            vae,
            vae_kwargs,
            text_encoder,
            model,
            text_encoder_2=text_encoder_2,
            pipeline=pipeline,
            use_cpu_offload=use_cpu_offload,
            device=device,
            logger=logger,
            parallel_args=parallel_args,
        )

        self.pipeline = self.load_diffusion_pipeline(
            args=args,
            vae=self.vae,
            text_encoder=self.text_encoder,
            text_encoder_2=self.text_encoder_2,
            model=self.model,
            device=self.device,
        )
        self.default_negative_prompt = NEGATIVE_PROMPT

        if self.parallel_args["ulysses_degree"] > 1 or self.parallel_args["ring_degree"] > 1:
            parallelize_transformer(self.pipeline)
    def load_diffusion_pipeline(
        self,
        args,
        vae,
        text_encoder,
        text_encoder_2,
        model,
        scheduler=None,
        device=None,
        progress_bar_config=None,
        data_type="video",
    ):
        """Build the denoising scheduler and assemble the diffusion pipeline for inference."""
        if scheduler is None:
            if args.denoise_type == "flow":
                scheduler = FlowMatchDiscreteScheduler(
                    shift=args.flow_shift,
                    reverse=args.flow_reverse,
                    solver=args.flow_solver,
                )
            else:
                raise ValueError(f"Invalid denoise type {args.denoise_type}")

        pipeline = HunyuanVideoPipeline(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            transformer=model,
            scheduler=scheduler,
            progress_bar_config=progress_bar_config,
            args=args,
        )
        if self.use_cpu_offload:
            pipeline.enable_sequential_cpu_offload()
        else:
            pipeline = pipeline.to(device)
        return pipeline
    def get_rotary_pos_embed(self, video_length, height, width):
        target_ndim = 3
        ndim = 5 - 2  # number of spatio-temporal axes of a 5D latent (B, C, T, H, W)

        # The VAE name encodes its compression: "884" compresses time by 4x and
        # space by 8x, "888" compresses both time and space by 8x.
        if "884" in self.args.vae:
            latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
        elif "888" in self.args.vae:
            latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8]
        else:
            latents_size = [video_length, height // 8, width // 8]

        if isinstance(self.model.patch_size, int):
            assert all(s % self.model.patch_size == 0 for s in latents_size), (
                f"Latent size (last {ndim} dimensions) should be divisible by patch size ({self.model.patch_size}), "
                f"but got {latents_size}."
            )
            rope_sizes = [s // self.model.patch_size for s in latents_size]
        elif isinstance(self.model.patch_size, list):
            assert all(
                s % self.model.patch_size[idx] == 0
                for idx, s in enumerate(latents_size)
            ), (
                f"Latent size (last {ndim} dimensions) should be divisible by patch size ({self.model.patch_size}), "
                f"but got {latents_size}."
            )
            rope_sizes = [
                s // self.model.patch_size[idx] for idx, s in enumerate(latents_size)
            ]

        if len(rope_sizes) != target_ndim:
            rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes  # pad the time axis

        head_dim = self.model.hidden_size // self.model.heads_num
        rope_dim_list = self.model.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert (
            sum(rope_dim_list) == head_dim
        ), "sum(rope_dim_list) should equal head_dim of the attention layer"

        freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
            rope_dim_list,
            rope_sizes,
            theta=self.args.rope_theta,
            use_real=True,
            theta_rescale_factor=1,
        )
        return freqs_cos, freqs_sin
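
    # Worked example (assuming an "884" VAE and a [1, 2, 2] patch size): for a
    # 129-frame 720x1280 request, latents_size = [33, 90, 160] and
    # rope_sizes = [33, 45, 80], so the transformer sees 33 * 45 * 80 = 118,800
    # video tokens per sample, which is the n_tokens used in `predict` below.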
    def predict(
        self,
        prompt,
        height=192,
        width=336,
        video_length=129,
        seed=None,
        negative_prompt=None,
        infer_steps=50,
        guidance_scale=6,
        flow_shift=5.0,
        embedded_guidance_scale=None,
        batch_size=1,
        num_videos_per_prompt=1,
        **kwargs,
    ):
        """
        Predict the image/video from the given text.

        Args:
            prompt (str or List[str]): The input text.
            kwargs:
                height (int): The height of the output video. Default is 192.
                width (int): The width of the output video. Default is 336.
                video_length (int): The frame number of the output video. Default is 129.
                seed (int or List[int]): The random seed for the generation. Default is a random integer.
                negative_prompt (str or List[str]): The negative text prompt. Default is the built-in negative prompt.
                guidance_scale (float): The guidance scale for the generation. Default is 6.0.
                num_videos_per_prompt (int): The number of videos per prompt. Default is 1.
                infer_steps (int): The number of inference steps. Default is 50.
        """
        out_dict = dict()

        # ========================================================================
        # Arguments: seed
        # ========================================================================
        if isinstance(seed, torch.Tensor):
            seed = seed.tolist()
        if seed is None:
            seeds = [
                random.randint(0, 1_000_000)
                for _ in range(batch_size * num_videos_per_prompt)
            ]
        elif isinstance(seed, int):
            seeds = [
                seed + i
                for _ in range(batch_size)
                for i in range(num_videos_per_prompt)
            ]
        elif isinstance(seed, (list, tuple)):
            if len(seed) == batch_size:
                seeds = [
                    int(seed[i]) + j
                    for i in range(batch_size)
                    for j in range(num_videos_per_prompt)
                ]
            elif len(seed) == batch_size * num_videos_per_prompt:
                seeds = [int(s) for s in seed]
            else:
                raise ValueError(
                    f"Length of seed must be equal to the number of prompts (batch_size) or "
                    f"batch_size * num_videos_per_prompt ({batch_size} * {num_videos_per_prompt}), got {seed}."
                )
        else:
            raise ValueError(
                f"Seed must be an integer, a list of integers, or None, got {seed}."
            )
        generator = [torch.Generator(self.device).manual_seed(seed) for seed in seeds]
        out_dict["seeds"] = seeds

        # ========================================================================
        # Arguments: target_width, target_height, target_video_length
        # ========================================================================
        if width <= 0 or height <= 0 or video_length <= 0:
            raise ValueError(
                f"`height`, `width` and `video_length` must be positive integers, got height={height}, width={width}, video_length={video_length}"
            )
        if (video_length - 1) % 4 != 0:
            raise ValueError(
                f"`video_length - 1` must be a multiple of 4, got {video_length}"
            )

        logger.info(
            f"Input (height, width, video_length) = ({height}, {width}, {video_length})"
        )

        target_height = align_to(height, 16)
        target_width = align_to(width, 16)
        target_video_length = video_length

        out_dict["size"] = (target_height, target_width, target_video_length)

        # ========================================================================
        # Arguments: prompt, new_prompt, negative_prompt
        # ========================================================================
        if not isinstance(prompt, str):
            raise TypeError(f"`prompt` must be a string, but got {type(prompt)}")
        prompt = [prompt.strip()]

        # negative prompt
        if negative_prompt is None or negative_prompt == "":
            negative_prompt = self.default_negative_prompt
        if not isinstance(negative_prompt, str):
            raise TypeError(
                f"`negative_prompt` must be a string, but got {type(negative_prompt)}"
            )
        negative_prompt = [negative_prompt.strip()]

        # ========================================================================
        # Scheduler
        # ========================================================================
        scheduler = FlowMatchDiscreteScheduler(
            shift=flow_shift,
            reverse=self.args.flow_reverse,
            solver=self.args.flow_solver,
        )
        self.pipeline.scheduler = scheduler

        # ========================================================================
        # Build RoPE freqs
        # ========================================================================
        freqs_cos, freqs_sin = self.get_rotary_pos_embed(
            target_video_length, target_height, target_width
        )
        n_tokens = freqs_cos.shape[0]

        # ========================================================================
        # Print infer args
        # ========================================================================
        debug_str = f"""
                        height: {target_height}
                         width: {target_width}
                  video_length: {target_video_length}
                        prompt: {prompt}
                    neg_prompt: {negative_prompt}
                          seed: {seed}
                   infer_steps: {infer_steps}
         num_videos_per_prompt: {num_videos_per_prompt}
                guidance_scale: {guidance_scale}
                      n_tokens: {n_tokens}
                    flow_shift: {flow_shift}
       embedded_guidance_scale: {embedded_guidance_scale}"""
        logger.debug(debug_str)

        # ========================================================================
        # Pipeline inference
        # ========================================================================
        start_time = time.time()
        samples = self.pipeline(
            prompt=prompt,
            height=target_height,
            width=target_width,
            video_length=target_video_length,
            num_inference_steps=infer_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_videos_per_prompt=num_videos_per_prompt,
            generator=generator,
            output_type="pil",
            freqs_cis=(freqs_cos, freqs_sin),
            n_tokens=n_tokens,
            embedded_guidance_scale=embedded_guidance_scale,
            data_type="video" if target_video_length > 1 else "image",
            is_progress_bar=True,
            vae_ver=self.args.vae,
            enable_tiling=self.args.vae_tiling,
        )[0]
        out_dict["samples"] = samples
        out_dict["prompts"] = prompt

        gen_time = time.time() - start_time
        logger.info(f"Success, time: {gen_time}")

        return out_dict
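

# --------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). It assumes an
# `args` namespace equivalent to the repo's CLI arguments (checkpoint paths,
# precision, VAE/text-encoder settings, ulysses/ring degrees, ...) has been
# built elsewhere, e.g. by the project's own argument parser:
#
#   sampler = HunyuanVideoSampler.from_pretrained(Path("ckpts"), args=args)
#   outputs = sampler.predict(
#       prompt="A cat walks on the grass, realistic style.",
#       height=720, width=1280, video_length=129,
#       seed=42, infer_steps=50, guidance_scale=6.0, flow_shift=7.0,
#   )
#   videos = outputs["samples"]  # decoded frames returned by the pipeline
# --------------------------------------------------------------------------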