# Hunyuan-GameCraft / hymm_sp/sample_inference.py
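"""Sampling utilities for Hunyuan-GameCraft video generation.

This module converts keyboard actions or pose text files into camera
trajectories, encodes them as per-pixel Plücker ray embeddings, and wraps the
diffusion pipeline in `HunyuanVideoSampler` for camera/action-conditioned
video inference.
"""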
import math
import time
import torch
import random
from loguru import logger
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from hymm_sp.diffusion import load_diffusion_pipeline
from hymm_sp.helpers import get_nd_rotary_pos_embed_new
from hymm_sp.inference import Inference
from hymm_sp.diffusion.schedulers import FlowMatchDiscreteScheduler
from packaging import version as pver
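# Keyboard action ids (WASD) mapped to the camera motion type they trigger.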
ACTION_DICT = {"w": "forward", "a": "left", "d": "right", "s": "backward"}
def custom_meshgrid(*args):
# ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
if pver.parse(torch.__version__) < pver.parse('1.10'):
return torch.meshgrid(*args)
else:
return torch.meshgrid(*args, indexing='ij')
def get_relative_pose(cam_params):
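    """Convert absolute camera poses into c2w poses relative to the first camera,
    with translations scaled by a factor of 10."""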
abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
source_cam_c2w = abs_c2ws[0]
cam_to_origin = 0
target_cam_c2w = np.array([
[1, 0, 0, 0],
[0, 1, 0, -cam_to_origin],
[0, 0, 1, 0],
[0, 0, 0, 1]
])
abs2rel = target_cam_c2w @ abs_w2cs[0]
ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
for pose in ret_poses:
pose[:3, -1:] *= 10
ret_poses = np.array(ret_poses, dtype=np.float32)
return ret_poses
def ray_condition(K, c2w, H, W, device, flip_flag=None):
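    """Build per-pixel Plücker embeddings (ray moment and ray direction, 6 channels)
    from camera intrinsics K [B, V, 4] and camera-to-world matrices c2w [B, V, 4, 4].

    Returns a tensor of shape [B, V, H, W, 6]. Views marked in `flip_flag` use a
    horizontally mirrored pixel grid.
    """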
# c2w: B, V, 4, 4
# K: B, V, 4
B, V = K.shape[:2]
j, i = custom_meshgrid(
torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
)
i = i.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5 # [B, V, HxW]
j = j.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5 # [B, V, HxW]
n_flip = torch.sum(flip_flag).item() if flip_flag is not None else 0
if n_flip > 0:
j_flip, i_flip = custom_meshgrid(
torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
torch.linspace(W - 1, 0, W, device=device, dtype=c2w.dtype)
)
i_flip = i_flip.reshape([1, 1, H * W]).expand(B, 1, H * W) + 0.5
j_flip = j_flip.reshape([1, 1, H * W]).expand(B, 1, H * W) + 0.5
i[:, flip_flag, ...] = i_flip
j[:, flip_flag, ...] = j_flip
fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
zs = torch.ones_like(i) # [B, V, HxW]
xs = (i - cx) / fx * zs
ys = (j - cy) / fy * zs
zs = zs.expand_as(ys)
directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, HW, 3
rays_o = c2w[..., :3, 3] # B, V, 3
rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, HW, 3
    # Plücker moment: o x d for each ray
    rays_dxo = torch.cross(rays_o, rays_d, dim=-1)  # B, V, HW, 3
plucker = torch.cat([rays_dxo, rays_d], dim=-1)
plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
# plucker = plucker.permute(0, 1, 4, 2, 3)
return plucker
def get_c2w(w2cs, transform_matrix, relative_c2w):
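    """Convert a list of 4x4 w2c matrices into c2w poses.

    If `relative_c2w` is True, poses are expressed relative to the first camera and
    their translations are scaled by 2; `transform_matrix` changes the coordinate
    convention of the resulting poses.
    """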
if relative_c2w:
target_cam_c2w = np.array([
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1]
])
abs2rel = target_cam_c2w @ w2cs[0]
ret_poses = [target_cam_c2w, ] + [abs2rel @ np.linalg.inv(w2c) for w2c in w2cs[1:]]
for pose in ret_poses:
pose[:3, -1:] *= 2
# ret_poses = [poses[:, :3]*2 for poses in ret_poses]
# ret_poses[:, :, :3] *= 2
else:
ret_poses = [np.linalg.inv(w2c) for w2c in w2cs]
ret_poses = [transform_matrix @ x for x in ret_poses]
return np.array(ret_poses, dtype=np.float32)
def generate_motion_segment(current_pose,
motion_type: str,
value: float,
duration: int = 30):
"""
Parameters:
motion_type: ('forward', 'backward', 'left', 'right',
'rotate_left', 'rotate_right', 'rotate_up', 'rotate_down')
value: Translation(m) or Rotation(degree)
duration: frames
Return:
positions: [np.array(x,y,z), ...]
rotations: [np.array(pitch,yaw,roll), ...]
"""
positions = []
rotations = []
if motion_type in ['forward', 'backward']:
yaw_rad = np.radians(current_pose['rotation'][1])
pitch_rad = np.radians(current_pose['rotation'][0])
forward_vec = np.array([
-math.sin(yaw_rad) * math.cos(pitch_rad),
math.sin(pitch_rad),
-math.cos(yaw_rad) * math.cos(pitch_rad)
])
direction = 1 if motion_type == 'forward' else -1
total_move = forward_vec * value * direction
step = total_move / duration
for i in range(1, duration+1):
new_pos = current_pose['position'] + step * i
positions.append(new_pos.copy())
rotations.append(current_pose['rotation'].copy())
current_pose['position'] = positions[-1]
elif motion_type in ['left', 'right']:
yaw_rad = np.radians(current_pose['rotation'][1])
right_vec = np.array([math.cos(yaw_rad), 0, -math.sin(yaw_rad)])
direction = -1 if motion_type == 'right' else 1
total_move = right_vec * value * direction
step = total_move / duration
for i in range(1, duration+1):
new_pos = current_pose['position'] + step * i
positions.append(new_pos.copy())
rotations.append(current_pose['rotation'].copy())
current_pose['position'] = positions[-1]
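    # Rotation segments are named '<axis>_rot' (e.g. 'left_rot', 'up_rot').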
elif motion_type.endswith('rot'):
axis = motion_type.split('_')[0]
total_rotation = np.zeros(3)
if axis == 'left':
total_rotation[0] = value
elif axis == 'right':
total_rotation[0] = -value
elif axis == 'up':
total_rotation[2] = -value
elif axis == 'down':
total_rotation[2] = value
step = total_rotation / duration
for i in range(1, duration+1):
positions.append(current_pose['position'].copy())
new_rot = current_pose['rotation'] + step * i
rotations.append(new_rot.copy())
current_pose['rotation'] = rotations[-1]
return positions, rotations, current_pose
def euler_to_quaternion(angles):
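    """Convert (pitch, yaw, roll) Euler angles in degrees to a quaternion (w, x, y, z)."""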
pitch, yaw, roll = np.radians(angles)
cy = math.cos(yaw * 0.5)
sy = math.sin(yaw * 0.5)
cp = math.cos(pitch * 0.5)
sp = math.sin(pitch * 0.5)
cr = math.cos(roll * 0.5)
sr = math.sin(roll * 0.5)
qw = cy * cp * cr + sy * sp * sr
qx = cy * cp * sr - sy * sp * cr
qy = sy * cp * sr + cy * sp * cr
qz = sy * cp * cr - cy * sp * sr
return [qw, qx, qy, qz]
def quaternion_to_rotation_matrix(q):
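    """Convert a unit quaternion (w, x, y, z) to a 3x3 rotation matrix."""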
qw, qx, qy, qz = q
return np.array([
[1 - 2*(qy**2 + qz**2), 2*(qx*qy - qw*qz), 2*(qx*qz + qw*qy)],
[2*(qx*qy + qw*qz), 1 - 2*(qx**2 + qz**2), 2*(qy*qz - qw*qx)],
[2*(qx*qz - qw*qy), 2*(qy*qz + qw*qx), 1 - 2*(qx**2 + qy**2)]
])
def ActionToPoseFromID(action_id, value=0.2, duration=33):
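    """Turn a WASD action id into a text pose list (one line per frame).

    Each line holds a frame index, the fixed intrinsics, two padding zeros and a
    flattened 3x4 extrinsic; the first line is an identity pose.

    Example (minimal sketch using this function's own defaults):
        pose_lines = ActionToPoseFromID("w", value=0.2, duration=33)
        # -> 34 lines: one identity row plus 33 frames moving forward
    """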
all_positions = []
all_rotations = []
current_pose = {
'position': np.array([0.0, 0.0, 0.0]), # XYZ
'rotation': np.array([0.0, 0.0, 0.0]) # (pitch, yaw, roll)
}
intrinsic = [0.50505, 0.8979, 0.5, 0.5]
motion_type = ACTION_DICT[action_id]
positions, rotations, current_pose = generate_motion_segment(current_pose, motion_type, value, duration)
all_positions.extend(positions)
all_rotations.extend(rotations)
pose_list = []
row = [0] + intrinsic + [0, 0] + [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]
first_row = " ".join(map(str, row))
pose_list.append(first_row)
for i, (pos, rot) in enumerate(zip(all_positions, all_rotations)):
quat = euler_to_quaternion(rot)
R = quaternion_to_rotation_matrix(quat)
extrinsic = np.hstack([R, pos.reshape(3, 1)])
row = [i] + intrinsic + [0, 0] + extrinsic.flatten().tolist()
pose_list.append(" ".join(map(str, row)))
return pose_list
class Camera(object):
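    """Camera parsed from one pose-file row: entry[1:5] holds fx, fy, cx, cy and
    entry[7:] the flattened 3x4 world-to-camera matrix."""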
def __init__(self, entry):
fx, fy, cx, cy = entry[1:5]
self.fx = fx
self.fy = fy
self.cx = cx
self.cy = cy
w2c_mat = np.array(entry[7:]).reshape(3, 4)
w2c_mat_4x4 = np.eye(4)
w2c_mat_4x4[:3, :] = w2c_mat
self.w2c_mat = w2c_mat_4x4
self.c2w_mat = np.linalg.inv(w2c_mat_4x4)
class CameraPoseVisualizer:
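    """Matplotlib 3D visualizer that draws camera extrinsics as view frusta."""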
def __init__(self, xlim, ylim, zlim):
self.fig = plt.figure(figsize=(7, 7))
self.ax = self.fig.add_subplot(projection='3d')
# self.ax.view_init(elev=25, azim=-90)
self.plotly_data = None # plotly data traces
self.ax.set_aspect("auto")
self.ax.set_xlim(xlim)
self.ax.set_ylim(ylim)
self.ax.set_zlim(zlim)
self.ax.set_xlabel('x')
self.ax.set_ylabel('y')
self.ax.set_zlabel('z')
        print('Initialized camera pose visualizer')
def extrinsic2pyramid(self, extrinsic, color_map='red', hw_ratio=9/16, base_xval=1, zval=3):
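        """Add one camera to the plot as a pyramid whose apex is the camera center,
        placed by the 4x4 `extrinsic`."""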
vertex_std = np.array([[0, 0, 0, 1],
[base_xval, -base_xval * hw_ratio, zval, 1],
[base_xval, base_xval * hw_ratio, zval, 1],
[-base_xval, base_xval * hw_ratio, zval, 1],
[-base_xval, -base_xval * hw_ratio, zval, 1]])
vertex_transformed = vertex_std @ extrinsic.T
meshes = [[vertex_transformed[0, :-1], vertex_transformed[1][:-1], vertex_transformed[2, :-1]],
[vertex_transformed[0, :-1], vertex_transformed[2, :-1], vertex_transformed[3, :-1]],
[vertex_transformed[0, :-1], vertex_transformed[3, :-1], vertex_transformed[4, :-1]],
[vertex_transformed[0, :-1], vertex_transformed[4, :-1], vertex_transformed[1, :-1]],
[vertex_transformed[1, :-1], vertex_transformed[2, :-1], vertex_transformed[3, :-1],
vertex_transformed[4, :-1]]]
color = color_map if isinstance(color_map, str) else plt.cm.rainbow(color_map)
self.ax.add_collection3d(
Poly3DCollection(meshes, facecolors=color, linewidths=0.3, edgecolors=color, alpha=0.35))
def customize_legend(self, list_label):
list_handle = []
for idx, label in enumerate(list_label):
color = plt.cm.rainbow(idx / len(list_label))
patch = Patch(color=color, label=label)
list_handle.append(patch)
plt.legend(loc='right', bbox_to_anchor=(1.8, 0.5), handles=list_handle)
def colorbar(self, max_frame_length):
cmap = mpl.cm.rainbow
norm = mpl.colors.Normalize(vmin=0, vmax=max_frame_length)
self.fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
ax=self.ax, orientation='vertical', label='Frame Number')
def show(self, file_name):
plt.title('Extrinsic Parameters')
# plt.show()
plt.savefig(file_name)
def align_to(value, alignment):
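    """Round `value` up to the nearest multiple of `alignment`, e.g. align_to(721, 16) == 736."""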
return int(math.ceil(value / alignment) * alignment)
def GetPoseEmbedsFromPoses(poses, h, w, target_length, flip=False, start_index=None):
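    """Build Plücker pose embeddings from an in-memory pose list (as produced by
    ActionToPoseFromID), starting at `start_index` and covering `target_length` frames.

    Returns (plucker_embedding, uncond_plucker_embedding, poses); the embeddings have
    shape [target_length, 6, h, w] and the unconditional branch uses identity c2w.
    """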
poses = [pose.split(' ') for pose in poses]
start_idx = start_index
sample_id = [start_idx + i for i in range(target_length)]
poses = [poses[i] for i in sample_id]
frame = len(poses)
w2cs = [np.asarray([float(p) for p in pose[7:]]).reshape(3, 4) for pose in poses]
transform_matrix = np.asarray([[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]]).reshape(4, 4)
last_row = np.zeros((1, 4))
last_row[0, -1] = 1.0
w2cs = [np.concatenate((w2c, last_row), axis=0) for w2c in w2cs]
c2ws = get_c2w(w2cs, transform_matrix, relative_c2w=True)
cam_params = [[float(x) for x in pose] for pose in poses]
assert len(cam_params) == target_length
cam_params = [Camera(cam_param) for cam_param in cam_params]
monst3r_w = cam_params[0].cx * 2
monst3r_h = cam_params[0].cy * 2
ratio_w, ratio_h = w/monst3r_w, h/monst3r_h
intrinsics = np.asarray([[cam_param.fx * ratio_w,
cam_param.fy * ratio_h,
cam_param.cx * ratio_w,
cam_param.cy * ratio_h]
for cam_param in cam_params], dtype=np.float32)
intrinsics = torch.as_tensor(intrinsics)[None] # [1, n_frame, 4]
relative_pose = True
if relative_pose:
c2w_poses = get_relative_pose(cam_params)
else:
c2w_poses = np.array([cam_param.c2w_mat for cam_param in cam_params], dtype=np.float32)
c2w = torch.as_tensor(c2w_poses)[None] # [1, n_frame, 4, 4]
uncond_c2w = torch.zeros_like(c2w)
uncond_c2w[:, :] = torch.eye(4, device=c2w.device)
flip_flag = torch.zeros(target_length, dtype=torch.bool, device=c2w.device)
plucker_embedding = ray_condition(intrinsics, c2w, h, w, device='cpu',
flip_flag=flip_flag)[0].permute(0, 3, 1, 2).contiguous()
uncond_plucker_embedding = ray_condition(intrinsics, uncond_c2w, h, w, device='cpu',
flip_flag=flip_flag)[0].permute(0, 3, 1, 2).contiguous()
return plucker_embedding, uncond_plucker_embedding, poses
def GetPoseEmbedsFromTxt(pose_dir, h, w, target_length, flip=False, start_index=None, step=1):
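    """Build Plücker pose embeddings from a pose text file (one camera per line,
    header line skipped), sampling `target_length` frames from `start_index` with
    stride `step`. Returns (plucker_embedding, uncond_plucker_embedding, poses)."""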
# get camera pose
with open(pose_dir, 'r') as f:
poses = f.readlines()
poses = [pose.strip().split(' ') for pose in poses[1:]]
start_idx = start_index
sample_id = [start_idx + i*step for i in range(target_length)]
poses = [poses[i] for i in sample_id]
cam_params = [[float(x) for x in pose] for pose in poses]
assert len(cam_params) == target_length
cam_params = [Camera(cam_param) for cam_param in cam_params]
monst3r_w = cam_params[0].cx * 2
monst3r_h = cam_params[0].cy * 2
ratio_w, ratio_h = w/monst3r_w, h/monst3r_h
intrinsics = np.asarray([[cam_param.fx * ratio_w,
cam_param.fy * ratio_h,
cam_param.cx * ratio_w,
cam_param.cy * ratio_h]
for cam_param in cam_params], dtype=np.float32)
intrinsics = torch.as_tensor(intrinsics)[None] # [1, n_frame, 4]
relative_pose = True
if relative_pose:
c2w_poses = get_relative_pose(cam_params)
else:
c2w_poses = np.array([cam_param.c2w_mat for cam_param in cam_params], dtype=np.float32)
c2w = torch.as_tensor(c2w_poses)[None] # [1, n_frame, 4, 4]
uncond_c2w = torch.zeros_like(c2w)
uncond_c2w[:, :] = torch.eye(4, device=c2w.device)
if flip:
flip_flag = torch.ones(target_length, dtype=torch.bool, device=c2w.device)
else:
flip_flag = torch.zeros(target_length, dtype=torch.bool, device=c2w.device)
plucker_embedding = ray_condition(intrinsics, c2w, h, w, device='cpu',
flip_flag=flip_flag)[0].permute(0, 3, 1, 2).contiguous()
uncond_plucker_embedding = ray_condition(intrinsics, uncond_c2w, h, w, device='cpu',
flip_flag=flip_flag)[0].permute(0, 3, 1, 2).contiguous()
return plucker_embedding, uncond_plucker_embedding, poses
class HunyuanVideoSampler(Inference):
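    """Inference wrapper that loads the Hunyuan diffusion pipeline and generates
    camera/action-conditioned video clips via `predict`."""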
def __init__(self, args, vae, vae_kwargs, text_encoder, model, text_encoder_2=None, pipeline=None,
device=0, logger=None):
super().__init__(args, vae, vae_kwargs, text_encoder, model, text_encoder_2=text_encoder_2,
pipeline=pipeline, device=device, logger=logger)
self.args = args
self.pipeline = load_diffusion_pipeline(
args, 0, self.vae, self.text_encoder, self.text_encoder_2, self.model,
device=self.device)
        print('Loaded Hunyuan model successfully.')
def get_rotary_pos_embed(self, video_length, height, width, concat_dict={}):
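        """Compute RoPE cos/sin frequency tables for the latent grid implied by
        (video_length, height, width), the VAE stride and the model patch size."""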
target_ndim = 3
ndim = 5 - 2
if '884' in self.args.vae:
latents_size = [(video_length-1)//4+1 , height//8, width//8]
else:
latents_size = [video_length , height//8, width//8]
if isinstance(self.model.patch_size, int):
            assert all(s % self.model.patch_size == 0 for s in latents_size), \
                f"Latent size (last {ndim} dimensions) should be divisible by patch size ({self.model.patch_size}), " \
                f"but got {latents_size}."
rope_sizes = [s // self.model.patch_size for s in latents_size]
elif isinstance(self.model.patch_size, list):
            assert all(s % self.model.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), \
                f"Latent size (last {ndim} dimensions) should be divisible by patch size ({self.model.patch_size}), " \
                f"but got {latents_size}."
rope_sizes = [s // self.model.patch_size[idx] for idx, s in enumerate(latents_size)]
if len(rope_sizes) != target_ndim:
rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
head_dim = self.model.hidden_size // self.model.num_heads
rope_dim_list = self.model.rope_dim_list
if rope_dim_list is None:
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal the head_dim of the attention layer"
freqs_cos, freqs_sin = get_nd_rotary_pos_embed_new(rope_dim_list,
rope_sizes,
theta=self.args.rope_theta,
use_real=True,
theta_rescale_factor=1,
concat_dict=concat_dict)
return freqs_cos, freqs_sin
@torch.no_grad()
def predict(self,
prompt,
is_image=True,
size=(720, 1280),
video_length=129,
seed=None,
negative_prompt=None,
infer_steps=50,
guidance_scale=6.0,
flow_shift=5.0,
batch_size=1,
num_videos_per_prompt=1,
verbose=1,
output_type="pil",
**kwargs):
"""
Predict the image from the given text.
Args:
prompt (str or List[str]): The input text.
kwargs:
size (int): The (height, width) of the output image/video. Default is (256, 256).
video_length (int): The frame number of the output video. Default is 1.
seed (int or List[str]): The random seed for the generation. Default is a random integer.
negative_prompt (str or List[str]): The negative text prompt. Default is an empty string.
infer_steps (int): The number of inference steps. Default is 100.
guidance_scale (float): The guidance scale for the generation. Default is 6.0.
num_videos_per_prompt (int): The number of videos per prompt. Default is 1.
verbose (int): 0 for no log, 1 for all log, 2 for fewer log. Default is 1.
output_type (str): The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
Default is 'pil'.
"""
out_dict = dict()
# ---------------------------------
# Prompt
# ---------------------------------
prompt_embeds = kwargs.get("prompt_embeds", None)
attention_mask = kwargs.get("attention_mask", None)
negative_prompt_embeds = kwargs.get("negative_prompt_embeds", None)
negative_attention_mask = kwargs.get("negative_attention_mask", None)
ref_latents = kwargs.get("ref_latents", None)
uncond_ref_latents = kwargs.get("uncond_ref_latents", None)
return_latents = kwargs.get("return_latents", False)
negative_prompt = kwargs.get("negative_prompt", None)
action_id = kwargs.get("action_id", None)
action_speed = kwargs.get("action_speed", None)
start_index = kwargs.get("start_index", None)
last_latents = kwargs.get("last_latents", None)
ref_latents = kwargs.get("ref_latents", None)
input_pose = kwargs.get("input_pose", None)
step = kwargs.get("step", 1)
use_sage = kwargs.get("use_sage", False)
size = self.parse_size(size)
target_height = align_to(size[0], 16)
target_width = align_to(size[1], 16)
# target_video_length = video_length
if input_pose is not None:
pose_embeds, uncond_pose_embeds, poses = GetPoseEmbedsFromTxt(input_pose,
target_height,
target_width,
33,
kwargs.get("flip", False),
start_index,
step)
else:
pose = ActionToPoseFromID(action_id, value=action_speed)
pose_embeds, uncond_pose_embeds, poses = GetPoseEmbedsFromPoses(pose,
target_height,
target_width,
33,
kwargs.get("flip", False),
0)
if is_image:
target_length = 34
else:
target_length = 66
out_dict['frame'] = target_length
# print("pose embeds: ", pose_embeds.shape, uncond_pose_embeds.shape)
pose_embeds = pose_embeds.unsqueeze(0).to(torch.bfloat16).to('cuda')
uncond_pose_embeds = uncond_pose_embeds.unsqueeze(0).to(torch.bfloat16).to('cuda')
cpu_offload = kwargs.get("cpu_offload", 0)
use_deepcache = kwargs.get("use_deepcache", 1)
denoise_strength = kwargs.get("denoise_strength", 1.0)
init_latents = kwargs.get("init_latents", None)
mask = kwargs.get("mask", None)
if prompt is None:
# prompt_embeds, attention_mask, negative_prompt_embeds and negative_attention_mask should not be None
# pipeline will help to check this
prompt = None
negative_prompt = None
batch_size = prompt_embeds.shape[0]
assert prompt_embeds is not None
else:
# prompt_embeds, attention_mask, negative_prompt_embeds and negative_attention_mask should be None
# pipeline will help to check this
if isinstance(prompt, str):
batch_size = 1
prompt = [prompt]
elif isinstance(prompt, (list, tuple)):
batch_size = len(prompt)
else:
raise ValueError(f"Prompt must be a string or a list of strings, got {prompt}.")
if negative_prompt is None:
negative_prompt = [""] * batch_size
if isinstance(negative_prompt, str):
negative_prompt = [negative_prompt] * batch_size
# ---------------------------------
# Other arguments
# ---------------------------------
scheduler = FlowMatchDiscreteScheduler(shift=flow_shift,
reverse=self.args.flow_reverse,
solver=self.args.flow_solver,
)
self.pipeline.scheduler = scheduler
# ---------------------------------
# Random seed
# ---------------------------------
if isinstance(seed, torch.Tensor):
seed = seed.tolist()
if seed is None:
seeds = [random.randint(0, 1_000_000) for _ in range(batch_size * num_videos_per_prompt)]
elif isinstance(seed, int):
seeds = [seed + i for _ in range(batch_size) for i in range(num_videos_per_prompt)]
elif isinstance(seed, (list, tuple)):
if len(seed) == batch_size:
seeds = [int(seed[i]) + j for i in range(batch_size) for j in range(num_videos_per_prompt)]
elif len(seed) == batch_size * num_videos_per_prompt:
seeds = [int(s) for s in seed]
else:
raise ValueError(
f"Length of seed must be equal to number of prompt(batch_size) or "
f"batch_size * num_videos_per_prompt ({batch_size} * {num_videos_per_prompt}), got {seed}."
)
else:
raise ValueError(f"Seed must be an integer, a list of integers, or None, got {seed}.")
generator = [torch.Generator(self.device).manual_seed(seed) for seed in seeds]
# ---------------------------------
# Image/Video size and frame
# ---------------------------------
out_dict['size'] = (target_height, target_width)
out_dict['video_length'] = target_length
out_dict['seeds'] = seeds
out_dict['negative_prompt'] = negative_prompt
# ---------------------------------
# Build RoPE
# ---------------------------------
concat_dict = {'mode': 'timecat', 'bias': -1}
if is_image:
freqs_cos, freqs_sin = self.get_rotary_pos_embed(37, target_height, target_width)
else:
freqs_cos, freqs_sin = self.get_rotary_pos_embed(69, target_height, target_width)
n_tokens = freqs_cos.shape[0]
# ---------------------------------
# Inference
# ---------------------------------
output_dir = kwargs.get("output_dir", None)
if verbose == 1:
debug_str = f"""
size: {out_dict['size']}
video_length: {target_length}
prompt: {prompt}
neg_prompt: {negative_prompt}
seed: {seed}
infer_steps: {infer_steps}
denoise_strength: {denoise_strength}
use_deepcache: {use_deepcache}
use_sage: {use_sage}
cpu_offload: {cpu_offload}
num_images_per_prompt: {num_videos_per_prompt}
guidance_scale: {guidance_scale}
n_tokens: {n_tokens}
flow_shift: {flow_shift}
output: {output_dir}"""
self.logger.info(debug_str)
start_time = time.time()
samples = self.pipeline(prompt=prompt,
last_latents=last_latents,
cam_latents=pose_embeds,
uncond_cam_latents=uncond_pose_embeds,
height=target_height,
width=target_width,
video_length=target_length,
                                gt_latents=ref_latents,
num_inference_steps=infer_steps,
guidance_scale=guidance_scale,
negative_prompt=negative_prompt,
num_videos_per_prompt=num_videos_per_prompt,
generator=generator,
prompt_embeds=prompt_embeds,
ref_latents=ref_latents,
latents=init_latents,
denoise_strength=denoise_strength,
mask=mask,
uncond_ref_latents=uncond_ref_latents,
ip_cfg_scale=self.args.ip_cfg_scale,
use_deepcache=use_deepcache,
attention_mask=attention_mask,
negative_prompt_embeds=negative_prompt_embeds,
negative_attention_mask=negative_attention_mask,
output_type=output_type,
freqs_cis=(freqs_cos, freqs_sin),
n_tokens=n_tokens,
data_type='video' if target_length > 1 else 'image',
is_progress_bar=True,
vae_ver=self.args.vae,
enable_tiling=self.args.vae_tiling,
cpu_offload=cpu_offload,
return_latents=return_latents,
use_sage=use_sage,
)
if samples is None:
return None
out_dict['samples'] = []
out_dict["prompts"] = prompt
out_dict['pose'] = poses
if return_latents:
print("return_latents | TRUE")
latents, timesteps, last_latents, ref_latents = samples[1], samples[2], samples[3], samples[4]
# samples = samples[0][0]
if samples[0] is not None and len(samples[0]) > 0:
samples = samples[0][0]
else:
samples = None
out_dict["denoised_lantents"] = latents
out_dict["timesteps"] = timesteps
out_dict["ref_latents"] = ref_latents
out_dict["last_latents"] = last_latents
else:
samples = samples[0]
        if samples is not None:
            # Only saved on tp rank 0; keep each sample as a batched [1, ...] tensor.
            for sample in samples:
                out_dict['samples'].append(sample.unsqueeze(0))
# visualize pose
gen_time = time.time() - start_time
logger.info(f"Success, time: {gen_time}")
return out_dict