import os
from pathlib import Path

from loguru import logger
import torch
import numpy as np
import torch.distributed
import random
import torchvision.transforms as transforms
from PIL import Image
import cv2
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader

from hymm_sp.config import parse_args
from hymm_sp.sample_inference import HunyuanVideoSampler
from hymm_sp.data_kits.video_dataset import VideoCSVDataset
from hymm_sp.data_kits.data_tools import save_videos_grid
from hymm_sp.modules.parallel_states import (
    initialize_distributed,
    nccl_info,
)
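
# Note: hymm_sp appears to be the project-local package for this script; judging by the
# imports, it provides argument parsing (config), the HunyuanVideoSampler inference wrapper
# (sample_inference), dataset and video I/O helpers (data_kits), and the distributed /
# sequence-parallel state used below (parallel_states).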


class CropResize:
    """
    Custom transform to resize and crop images to a target size while preserving aspect ratio.

    Resizes the image to ensure it covers the target dimensions, then center-crops to the exact size.
    Useful for preparing consistent input dimensions for video generation models.
    """

    def __init__(self, size=(704, 1216)):
        """
        Args:
            size (tuple): Target dimensions (height, width) for the output image
        """
        self.target_h, self.target_w = size

    def __call__(self, img):
        """
        Apply the transform to an image.

        Args:
            img (PIL.Image): Input image to transform

        Returns:
            PIL.Image: Resized and cropped image with target dimensions
        """
        # Get original image dimensions
        w, h = img.size

        # Calculate scaling factor to ensure image covers target size
        scale = max(
            self.target_w / w,  # Scale needed to cover target width
            self.target_h / h   # Scale needed to cover target height
        )

        # Resize image while preserving aspect ratio
        new_size = (int(h * scale), int(w * scale))
        resize_transform = transforms.Resize(
            new_size,
            interpolation=transforms.InterpolationMode.BILINEAR
        )
        resized_img = resize_transform(img)

        # Center-crop to exact target dimensions
        crop_transform = transforms.CenterCrop((self.target_h, self.target_w))
        return crop_transform(resized_img)
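

# Worked example (illustrative): with the default target of (704, 1216), a 1920x1080
# reference image gets scale = max(1216/1920, 704/1080) ~= 0.652, is resized to roughly
# (704, 1251), and is then center-cropped to exactly (704, 1216). Taking the max of the
# two ratios guarantees both target dimensions are covered before cropping, so no
# padding is ever needed.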


def main():
    """
    Main function for video generation using the Hunyuan multimodal model.

    Handles argument parsing, distributed setup, model loading, data preparation,
    and video generation with action-controlled transitions. Supports both
    image-to-video and video-to-video generation tasks.
    """
    # Parse command-line arguments and configuration
    args = parse_args()
    models_root_path = Path(args.ckpt)
    action_list = args.action_list
    action_speed_list = args.action_speed_list
    negative_prompt = args.add_neg_prompt

    # Initialize distributed training/evaluation environment
    logger.info("*" * 20)
    initialize_distributed(args.seed)

    # Validate model checkpoint path exists
    if not models_root_path.exists():
        raise ValueError(f"Model checkpoint path does not exist: {models_root_path}")
    logger.info("+" * 20)

    # Set up output directory
    save_path = args.save_path if args.save_path_suffix == "" else f'{args.save_path}_{args.save_path_suffix}'
    os.makedirs(save_path, exist_ok=True)
    logger.info(f"Generated videos will be saved to: {save_path}")

    # Initialize device configuration for distributed processing
    rank = 0
    device = torch.device("cuda")
    if nccl_info.sp_size > 1:
        # Use specific GPU based on process rank in distributed setup
        device = torch.device(f"cuda:{torch.distributed.get_rank()}")
        rank = torch.distributed.get_rank()

    # Load the Hunyuan video sampler model from checkpoint
    logger.info(f"Loading model from checkpoint: {args.ckpt}")
    hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(
        args.ckpt,
        args=args,
        device=device if not args.cpu_offload else torch.device("cpu")
    )
    # Update args with model-specific configurations from the checkpoint
    args = hunyuan_video_sampler.args

    # Enable CPU offloading if specified to reduce GPU memory usage
    if args.cpu_offload:
        from diffusers.hooks import apply_group_offloading
        onload_device = torch.device("cuda")
        apply_group_offloading(
            hunyuan_video_sampler.pipeline.transformer,
            onload_device=onload_device,
            offload_type="block_level",
            num_blocks_per_group=1
        )
        logger.info("Enabled CPU offloading for transformer blocks")

    # Read the prompt and the reference image path from the arguments
    prompt = args.prompt
    image_paths = [args.image_path]
    logger.info(f"Prompt: {prompt}, Image Path: {args.image_path}")

    # Use the provided seed, or fall back to a random one
    seed = args.seed if args.seed else random.randint(0, 1_000_000)

    # Define image transformation pipeline for input reference images
    closest_size = (704, 1216)
    ref_image_transform = transforms.Compose([
        CropResize(closest_size),
        transforms.CenterCrop(closest_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])  # Normalize to [-1, 1] range
    ])
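    # Note: CropResize already returns an image of exactly `closest_size`, so the extra
    # CenterCrop here acts as a safeguard rather than an additional crop. Normalize with a
    # single mean/std of 0.5 maps the [0, 1] tensor from ToTensor into the [-1, 1] range
    # used for VAE encoding below (the single value is broadcast over the RGB channels).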

    # Handle image-based generation (start from a single image)
    if args.image_start:
        # Load and preprocess reference images
        raw_ref_images = [Image.open(image_path).convert('RGB') for image_path in image_paths]

        # Apply transformations and prepare tensor for model input
        ref_images_pixel_values = [ref_image_transform(ref_image) for ref_image in raw_ref_images]
        ref_images_pixel_values = torch.cat(ref_images_pixel_values).unsqueeze(0).unsqueeze(2).to(device)
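        # Shape note: for a single reference image this yields a tensor of shape
        # (1, 3, 1, 704, 1216), i.e. batch x channels x frames x height x width,
        # so the image is treated as a one-frame video clip by the VAE below.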

        # Encode reference images to latent space using VAE
        with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
            if args.cpu_offload:
                # Move VAE components to GPU temporarily for encoding
                hunyuan_video_sampler.vae.quant_conv.to('cuda')
                hunyuan_video_sampler.vae.encoder.to('cuda')

            # Enable tiling for VAE to handle large images efficiently
            hunyuan_video_sampler.pipeline.vae.enable_tiling()

            # Encode image to latents and scale by VAE's scaling factor
            raw_last_latents = hunyuan_video_sampler.vae.encode(
                ref_images_pixel_values
            ).latent_dist.sample().to(dtype=torch.float16)  # Shape: (B, C, F, H, W)
            raw_last_latents.mul_(hunyuan_video_sampler.vae.config.scaling_factor)
            raw_ref_latents = raw_last_latents.clone()

            # Clean up
            hunyuan_video_sampler.pipeline.vae.disable_tiling()
            if args.cpu_offload:
                # Move VAE components back to CPU after encoding
                hunyuan_video_sampler.vae.quant_conv.to('cpu')
                hunyuan_video_sampler.vae.encoder.to('cpu')

    # Handle video-based generation (start from an existing video)
    else:
        from decord import VideoReader  # Lazy import for video handling

        # Validate video file exists
        video_path = args.video_path
        if not os.path.exists(video_path):
            raise RuntimeError(f"Video file not found: {video_path}")

        # Load the reference image(s) from the provided image path(s)
        raw_ref_images = [Image.open(image_path).convert('RGB') for image_path in image_paths]

        # Load video and extract frames
        ref_video = VideoReader(video_path)
        ref_frames_length = len(ref_video)
        logger.info(f"Loaded reference video with {ref_frames_length} frames")

        # Preprocess video frames
        transformed_images = []
        for index in range(ref_frames_length):
            # Convert video frame to PIL image and apply transformations
            video_image = ref_video[index].numpy()
            transformed_image = ref_image_transform(Image.fromarray(video_image))
            transformed_images.append(transformed_image)

        # Prepare tensor for model input
        transformed_images = torch.stack(transformed_images, dim=1).unsqueeze(0).to(
            device=hunyuan_video_sampler.device,
            dtype=torch.float16
        )

        # Encode video frames to latent space using VAE
        with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
            if args.cpu_offload:
                hunyuan_video_sampler.vae.quant_conv.to('cuda')
                hunyuan_video_sampler.vae.encoder.to('cuda')
            hunyuan_video_sampler.pipeline.vae.enable_tiling()

            # Encode last 33 frames of video (model-specific requirement)
            raw_last_latents = hunyuan_video_sampler.vae.encode(
                transformed_images[:, :, -33:, ...]
            ).latent_dist.sample().to(dtype=torch.float16)
            raw_last_latents.mul_(hunyuan_video_sampler.vae.config.scaling_factor)

            # Encode a single reference frame from the video
            raw_ref_latents = hunyuan_video_sampler.vae.encode(
                transformed_images[:, :, -33:-32, ...]
            ).latent_dist.sample().to(dtype=torch.float16)
            raw_ref_latents.mul_(hunyuan_video_sampler.vae.config.scaling_factor)

            # Clean up
            hunyuan_video_sampler.pipeline.vae.disable_tiling()
            if args.cpu_offload:
                hunyuan_video_sampler.vae.quant_conv.to('cpu')
                hunyuan_video_sampler.vae.encoder.to('cpu')

    # Store references for generation loop
    ref_images = raw_ref_images
    last_latents = raw_last_latents
    ref_latents = raw_ref_latents
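
    # The loop below is autoregressive: each segment is generated from `last_latents`
    # (the latents of the tail of the previously generated or encoded clip) so that
    # consecutive segments join seamlessly, while `ref_latents` keeps the appearance
    # anchored to the reference frame. The 33-frame window encoded above appears to
    # match the conditioning length the model expects, so the same hand-off works
    # whether generation starts from an image or from an existing video.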

    # Generate a video segment for each action in the action list
    for idx, action_id in enumerate(action_list):
        # Only the first segment can start from an image, and only when args.image_start is set
        is_image = (idx == 0 and args.image_start)
        logger.info(f"Generating segment {idx + 1}/{len(action_list)} with action ID: {action_id}")

        # Generate a video segment conditioned on the current action
        outputs = hunyuan_video_sampler.predict(
            prompt=prompt,
            action_id=action_id,
            action_speed=action_speed_list[idx],
            is_image=is_image,
            size=(704, 1216),
            seed=seed,
            last_latents=last_latents,  # Previous segment latents for temporal continuity
            ref_latents=ref_latents,    # Reference latents for style consistency
            video_length=args.sample_n_frames,
            guidance_scale=args.cfg_scale,
            num_images_per_prompt=args.num_images,
            negative_prompt=negative_prompt,
            infer_steps=args.infer_steps,
            flow_shift=args.flow_shift_eval_video,
            use_linear_quadratic_schedule=args.use_linear_quadratic_schedule,
            linear_schedule_end=args.linear_schedule_end,
            use_deepcache=args.use_deepcache,
            cpu_offload=args.cpu_offload,
            ref_images=ref_images,
            output_dir=save_path,
            return_latents=True,
            use_sage=args.use_sage,
        )

        # Update latents for the next iteration (maintain temporal consistency)
        ref_latents = outputs["ref_latents"]
        last_latents = outputs["last_latents"]

        # Accumulate and save generated video segments on the main process (rank 0) only
        if rank == 0:
            sub_samples = outputs['samples'][0]
            # Initialize or extend the concatenated video
            if idx == 0:
                if args.image_start:
                    out_cat = sub_samples
                else:
                    # Combine the original (denormalized) video frames with the generated frames
                    out_cat = torch.cat([(transformed_images.detach().cpu() + 1) / 2.0, sub_samples], dim=2)
            else:
                # Append the new segment to the existing video
                out_cat = torch.cat([out_cat, sub_samples], dim=2)

            # Save the combined video generated so far (rewritten after each segment)
            save_path_mp4 = f"{save_path}/{os.path.basename(args.image_path).split('.')[0]}.mp4"
            save_videos_grid(out_cat, save_path_mp4, n_rows=1, fps=24)
            logger.info(f"Saved generated video to: {save_path_mp4}")


if __name__ == "__main__":
    main()