|
import math |
|
import time |
|
import torch |
|
import random |
|
from loguru import logger |
|
import numpy as np |
|
import matplotlib as mpl |
|
import matplotlib.pyplot as plt |
|
from matplotlib.patches import Patch |
|
from mpl_toolkits.mplot3d.art3d import Poly3DCollection |
|
|
|
from hymm_sp.diffusion import load_diffusion_pipeline |
|
from hymm_sp.helpers import get_nd_rotary_pos_embed_new |
|
from hymm_sp.inference import Inference |
|
from hymm_sp.diffusion.schedulers import FlowMatchDiscreteScheduler |
|
from packaging import version as pver |
|
|
|
ACTION_DICT = {"w": "forward", "a": "left", "d": "right", "s": "backward"} |
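# ACTION_DICT maps WASD-style keys to the camera motion names consumed by ActionToPoseFromID.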
|
|
|
def custom_meshgrid(*args): |
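    """torch.meshgrid wrapper that uses 'ij' indexing on PyTorch >= 1.10 (older versions already default to it)."""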
|
|
|
if pver.parse(torch.__version__) < pver.parse('1.10'): |
|
return torch.meshgrid(*args) |
|
else: |
|
return torch.meshgrid(*args, indexing='ij') |
|
|
|
def get_relative_pose(cam_params): |
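    """Re-express absolute camera extrinsics relative to the first camera (which becomes identity) and scale the translations by 10."""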
|
abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params] |
|
abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params] |
|
source_cam_c2w = abs_c2ws[0] |
|
cam_to_origin = 0 |
|
target_cam_c2w = np.array([ |
|
[1, 0, 0, 0], |
|
[0, 1, 0, -cam_to_origin], |
|
[0, 0, 1, 0], |
|
[0, 0, 0, 1] |
|
]) |
|
abs2rel = target_cam_c2w @ abs_w2cs[0] |
|
ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]] |
|
for pose in ret_poses: |
|
pose[:3, -1:] *= 10 |
|
ret_poses = np.array(ret_poses, dtype=np.float32) |
|
return ret_poses |
|
|
|
def ray_condition(K, c2w, H, W, device, flip_flag=None): |
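    """
    Build per-pixel Plücker ray embeddings from camera intrinsics and poses.

    Args:
        K: (B, V, 4) per-view intrinsics as (fx, fy, cx, cy) in pixels.
        c2w: (B, V, 4, 4) camera-to-world matrices.
        H, W: target spatial resolution.
        flip_flag: optional (V,) bool tensor; flagged views use a horizontally
            mirrored pixel grid.

    Returns:
        (B, V, H, W, 6) tensor holding [o x d, d] per pixel.
    """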
|
|
|
|
|
|
|
B, V = K.shape[:2] |
|
|
|
j, i = custom_meshgrid( |
|
torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype), |
|
torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype), |
|
) |
|
i = i.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5 |
|
j = j.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5 |
|
|
|
n_flip = torch.sum(flip_flag).item() if flip_flag is not None else 0 |
|
if n_flip > 0: |
|
j_flip, i_flip = custom_meshgrid( |
|
torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype), |
|
torch.linspace(W - 1, 0, W, device=device, dtype=c2w.dtype) |
|
) |
|
i_flip = i_flip.reshape([1, 1, H * W]).expand(B, 1, H * W) + 0.5 |
|
j_flip = j_flip.reshape([1, 1, H * W]).expand(B, 1, H * W) + 0.5 |
|
i[:, flip_flag, ...] = i_flip |
|
j[:, flip_flag, ...] = j_flip |
|
|
|
fx, fy, cx, cy = K.chunk(4, dim=-1) |
|
|
|
zs = torch.ones_like(i) |
|
xs = (i - cx) / fx * zs |
|
ys = (j - cy) / fy * zs |
|
zs = zs.expand_as(ys) |
|
|
|
directions = torch.stack((xs, ys, zs), dim=-1) |
|
directions = directions / directions.norm(dim=-1, keepdim=True) |
|
|
|
rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) |
|
rays_o = c2w[..., :3, 3] |
|
rays_o = rays_o[:, :, None].expand_as(rays_d) |
|
|
|
    rays_dxo = torch.cross(rays_o, rays_d, dim=-1)
|
plucker = torch.cat([rays_dxo, rays_d], dim=-1) |
|
plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) |
|
|
|
return plucker |
|
|
|
def get_c2w(w2cs, transform_matrix, relative_c2w): |
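    """
    Invert world-to-camera matrices into camera-to-world poses.

    If relative_c2w is True, poses are expressed relative to the first camera
    (which becomes identity) and translations are scaled by 2; otherwise each
    inverted pose is left-multiplied by transform_matrix.
    """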
|
if relative_c2w: |
|
target_cam_c2w = np.array([ |
|
[1, 0, 0, 0], |
|
[0, 1, 0, 0], |
|
[0, 0, 1, 0], |
|
[0, 0, 0, 1] |
|
]) |
|
abs2rel = target_cam_c2w @ w2cs[0] |
|
ret_poses = [target_cam_c2w, ] + [abs2rel @ np.linalg.inv(w2c) for w2c in w2cs[1:]] |
|
for pose in ret_poses: |
|
pose[:3, -1:] *= 2 |
|
|
|
|
|
else: |
|
ret_poses = [np.linalg.inv(w2c) for w2c in w2cs] |
|
ret_poses = [transform_matrix @ x for x in ret_poses] |
|
return np.array(ret_poses, dtype=np.float32) |
|
|
|
def generate_motion_segment(current_pose, |
|
motion_type: str, |
|
value: float, |
|
duration: int = 30): |
|
""" |
|
Parameters: |
|
motion_type: ('forward', 'backward', 'left', 'right', |
|
'rotate_left', 'rotate_right', 'rotate_up', 'rotate_down') |
|
value: Translation(m) or Rotation(degree) |
|
duration: frames |
|
|
|
Return: |
|
positions: [np.array(x,y,z), ...] |
|
rotations: [np.array(pitch,yaw,roll), ...] |
|
""" |
|
positions = [] |
|
rotations = [] |
|
|
|
if motion_type in ['forward', 'backward']: |
|
yaw_rad = np.radians(current_pose['rotation'][1]) |
|
pitch_rad = np.radians(current_pose['rotation'][0]) |
|
|
|
forward_vec = np.array([ |
|
-math.sin(yaw_rad) * math.cos(pitch_rad), |
|
math.sin(pitch_rad), |
|
-math.cos(yaw_rad) * math.cos(pitch_rad) |
|
]) |
|
|
|
direction = 1 if motion_type == 'forward' else -1 |
|
total_move = forward_vec * value * direction |
|
step = total_move / duration |
|
|
|
for i in range(1, duration+1): |
|
new_pos = current_pose['position'] + step * i |
|
positions.append(new_pos.copy()) |
|
rotations.append(current_pose['rotation'].copy()) |
|
|
|
current_pose['position'] = positions[-1] |
|
|
|
elif motion_type in ['left', 'right']: |
|
yaw_rad = np.radians(current_pose['rotation'][1]) |
|
right_vec = np.array([math.cos(yaw_rad), 0, -math.sin(yaw_rad)]) |
|
|
|
direction = -1 if motion_type == 'right' else 1 |
|
total_move = right_vec * value * direction |
|
step = total_move / duration |
|
|
|
for i in range(1, duration+1): |
|
new_pos = current_pose['position'] + step * i |
|
positions.append(new_pos.copy()) |
|
rotations.append(current_pose['rotation'].copy()) |
|
|
|
current_pose['position'] = positions[-1] |
|
|
|
    elif motion_type.startswith('rotate_') or motion_type.endswith('_rot'):
        # accept both the documented 'rotate_<axis>' names and '<axis>_rot' names
        parts = motion_type.split('_')
        axis = parts[1] if motion_type.startswith('rotate_') else parts[0]
|
total_rotation = np.zeros(3) |
|
|
|
if axis == 'left': |
|
total_rotation[0] = value |
|
elif axis == 'right': |
|
total_rotation[0] = -value |
|
elif axis == 'up': |
|
total_rotation[2] = -value |
|
elif axis == 'down': |
|
total_rotation[2] = value |
|
|
|
step = total_rotation / duration |
|
|
|
for i in range(1, duration+1): |
|
positions.append(current_pose['position'].copy()) |
|
new_rot = current_pose['rotation'] + step * i |
|
rotations.append(new_rot.copy()) |
|
|
|
current_pose['rotation'] = rotations[-1] |
|
|
|
return positions, rotations, current_pose |
|
|
|
def euler_to_quaternion(angles): |
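    """Convert (pitch, yaw, roll) Euler angles in degrees to a quaternion [qw, qx, qy, qz]."""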
|
pitch, yaw, roll = np.radians(angles) |
|
|
|
cy = math.cos(yaw * 0.5) |
|
sy = math.sin(yaw * 0.5) |
|
cp = math.cos(pitch * 0.5) |
|
sp = math.sin(pitch * 0.5) |
|
cr = math.cos(roll * 0.5) |
|
sr = math.sin(roll * 0.5) |
|
|
|
qw = cy * cp * cr + sy * sp * sr |
|
qx = cy * cp * sr - sy * sp * cr |
|
qy = sy * cp * sr + cy * sp * cr |
|
qz = sy * cp * cr - cy * sp * sr |
|
|
|
return [qw, qx, qy, qz] |
|
|
|
def quaternion_to_rotation_matrix(q): |
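    """Convert a unit quaternion [qw, qx, qy, qz] into a 3x3 rotation matrix."""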
|
qw, qx, qy, qz = q |
|
return np.array([ |
|
[1 - 2*(qy**2 + qz**2), 2*(qx*qy - qw*qz), 2*(qx*qz + qw*qy)], |
|
[2*(qx*qy + qw*qz), 1 - 2*(qx**2 + qz**2), 2*(qy*qz - qw*qx)], |
|
[2*(qx*qz - qw*qy), 2*(qy*qz + qw*qx), 1 - 2*(qx**2 + qy**2)] |
|
]) |
|
|
|
def ActionToPoseFromID(action_id, value=0.2, duration=33): |
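    """
    Build a camera trajectory for a single WASD action.

    Returns a list of whitespace-separated pose rows in the same text format
    that GetPoseEmbedsFromTxt/Camera expect: frame index, normalized
    intrinsics (fx, fy, cx, cy), two zero padding values, then the flattened
    3x4 extrinsic. The first row is an identity pose, so
    ActionToPoseFromID('w', value=0.2) yields 1 + duration = 34 rows.
    """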
|
|
|
all_positions = [] |
|
all_rotations = [] |
|
current_pose = { |
|
'position': np.array([0.0, 0.0, 0.0]), |
|
'rotation': np.array([0.0, 0.0, 0.0]) |
|
} |
|
intrinsic = [0.50505, 0.8979, 0.5, 0.5] |
|
motion_type = ACTION_DICT[action_id] |
|
positions, rotations, current_pose = generate_motion_segment(current_pose, motion_type, value, duration) |
|
all_positions.extend(positions) |
|
all_rotations.extend(rotations) |
|
|
|
pose_list = [] |
|
|
|
row = [0] + intrinsic + [0, 0] + [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0] |
|
first_row = " ".join(map(str, row)) |
|
pose_list.append(first_row) |
|
for i, (pos, rot) in enumerate(zip(all_positions, all_rotations)): |
|
quat = euler_to_quaternion(rot) |
|
R = quaternion_to_rotation_matrix(quat) |
|
extrinsic = np.hstack([R, pos.reshape(3, 1)]) |
|
|
|
row = [i] + intrinsic + [0, 0] + extrinsic.flatten().tolist() |
|
pose_list.append(" ".join(map(str, row))) |
|
|
|
return pose_list |
|
|
|
class Camera(object): |
|
def __init__(self, entry): |
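        """Parse one pose row: entry[1:5] holds the intrinsics (fx, fy, cx, cy) and entry[7:] the flattened 3x4 world-to-camera matrix."""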
|
fx, fy, cx, cy = entry[1:5] |
|
self.fx = fx |
|
self.fy = fy |
|
self.cx = cx |
|
self.cy = cy |
|
w2c_mat = np.array(entry[7:]).reshape(3, 4) |
|
w2c_mat_4x4 = np.eye(4) |
|
w2c_mat_4x4[:3, :] = w2c_mat |
|
self.w2c_mat = w2c_mat_4x4 |
|
self.c2w_mat = np.linalg.inv(w2c_mat_4x4) |
|
|
|
class CameraPoseVisualizer: |
|
def __init__(self, xlim, ylim, zlim): |
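        """Set up a 3D matplotlib axes with the given x/y/z limits for plotting camera frusta."""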
|
self.fig = plt.figure(figsize=(7, 7)) |
|
self.ax = self.fig.add_subplot(projection='3d') |
|
|
|
self.plotly_data = None |
|
self.ax.set_aspect("auto") |
|
self.ax.set_xlim(xlim) |
|
self.ax.set_ylim(ylim) |
|
self.ax.set_zlim(zlim) |
|
self.ax.set_xlabel('x') |
|
self.ax.set_ylabel('y') |
|
self.ax.set_zlabel('z') |
|
print('initialize camera pose visualizer') |
|
|
|
def extrinsic2pyramid(self, extrinsic, color_map='red', hw_ratio=9/16, base_xval=1, zval=3): |
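        """Draw one camera as a pyramid: apex at the camera center, rectangular base at depth zval, transformed by the 4x4 extrinsic."""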
|
vertex_std = np.array([[0, 0, 0, 1], |
|
[base_xval, -base_xval * hw_ratio, zval, 1], |
|
[base_xval, base_xval * hw_ratio, zval, 1], |
|
[-base_xval, base_xval * hw_ratio, zval, 1], |
|
[-base_xval, -base_xval * hw_ratio, zval, 1]]) |
|
vertex_transformed = vertex_std @ extrinsic.T |
|
        meshes = [[vertex_transformed[0, :-1], vertex_transformed[1, :-1], vertex_transformed[2, :-1]],
|
[vertex_transformed[0, :-1], vertex_transformed[2, :-1], vertex_transformed[3, :-1]], |
|
[vertex_transformed[0, :-1], vertex_transformed[3, :-1], vertex_transformed[4, :-1]], |
|
[vertex_transformed[0, :-1], vertex_transformed[4, :-1], vertex_transformed[1, :-1]], |
|
[vertex_transformed[1, :-1], vertex_transformed[2, :-1], vertex_transformed[3, :-1], |
|
vertex_transformed[4, :-1]]] |
|
|
|
color = color_map if isinstance(color_map, str) else plt.cm.rainbow(color_map) |
|
|
|
self.ax.add_collection3d( |
|
Poly3DCollection(meshes, facecolors=color, linewidths=0.3, edgecolors=color, alpha=0.35)) |
|
|
|
def customize_legend(self, list_label): |
|
list_handle = [] |
|
for idx, label in enumerate(list_label): |
|
color = plt.cm.rainbow(idx / len(list_label)) |
|
patch = Patch(color=color, label=label) |
|
list_handle.append(patch) |
|
plt.legend(loc='right', bbox_to_anchor=(1.8, 0.5), handles=list_handle) |
|
|
|
def colorbar(self, max_frame_length): |
|
cmap = mpl.cm.rainbow |
|
norm = mpl.colors.Normalize(vmin=0, vmax=max_frame_length) |
|
self.fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), |
|
ax=self.ax, orientation='vertical', label='Frame Number') |
|
|
|
def show(self, file_name): |
|
plt.title('Extrinsic Parameters') |
|
|
|
plt.savefig(file_name) |
|
|
|
|
|
def align_to(value, alignment): |
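    """Round value up to the nearest multiple of alignment, e.g. align_to(721, 16) == 736."""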
|
return int(math.ceil(value / alignment) * alignment) |
|
|
|
|
|
def GetPoseEmbedsFromPoses(poses, h, w, target_length, flip=False, start_index=None): |
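    """
    Build Plücker pose embeddings from an in-memory list of pose rows
    (e.g. the output of ActionToPoseFromID).

    Takes target_length consecutive rows starting at start_index, rescales the
    intrinsics to (h, w), and returns (plucker_embedding,
    uncond_plucker_embedding, poses); each embedding has shape
    (target_length, 6, h, w). The flip argument is accepted for symmetry with
    GetPoseEmbedsFromTxt but is not applied here.
    """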
|
|
|
poses = [pose.split(' ') for pose in poses] |
|
|
|
start_idx = start_index |
|
sample_id = [start_idx + i for i in range(target_length)] |
|
|
|
poses = [poses[i] for i in sample_id] |
|
|
|
frame = len(poses) |
|
w2cs = [np.asarray([float(p) for p in pose[7:]]).reshape(3, 4) for pose in poses] |
|
transform_matrix = np.asarray([[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]]).reshape(4, 4) |
|
last_row = np.zeros((1, 4)) |
|
last_row[0, -1] = 1.0 |
|
w2cs = [np.concatenate((w2c, last_row), axis=0) for w2c in w2cs] |
|
c2ws = get_c2w(w2cs, transform_matrix, relative_c2w=True) |
|
|
|
cam_params = [[float(x) for x in pose] for pose in poses] |
|
assert len(cam_params) == target_length |
|
cam_params = [Camera(cam_param) for cam_param in cam_params] |
|
|
|
monst3r_w = cam_params[0].cx * 2 |
|
monst3r_h = cam_params[0].cy * 2 |
|
ratio_w, ratio_h = w/monst3r_w, h/monst3r_h |
|
intrinsics = np.asarray([[cam_param.fx * ratio_w, |
|
cam_param.fy * ratio_h, |
|
cam_param.cx * ratio_w, |
|
cam_param.cy * ratio_h] |
|
for cam_param in cam_params], dtype=np.float32) |
|
intrinsics = torch.as_tensor(intrinsics)[None] |
|
relative_pose = True |
|
if relative_pose: |
|
c2w_poses = get_relative_pose(cam_params) |
|
else: |
|
c2w_poses = np.array([cam_param.c2w_mat for cam_param in cam_params], dtype=np.float32) |
|
c2w = torch.as_tensor(c2w_poses)[None] |
|
uncond_c2w = torch.zeros_like(c2w) |
|
uncond_c2w[:, :] = torch.eye(4, device=c2w.device) |
|
flip_flag = torch.zeros(target_length, dtype=torch.bool, device=c2w.device) |
|
plucker_embedding = ray_condition(intrinsics, c2w, h, w, device='cpu', |
|
flip_flag=flip_flag)[0].permute(0, 3, 1, 2).contiguous() |
|
uncond_plucker_embedding = ray_condition(intrinsics, uncond_c2w, h, w, device='cpu', |
|
flip_flag=flip_flag)[0].permute(0, 3, 1, 2).contiguous() |
|
|
|
return plucker_embedding, uncond_plucker_embedding, poses |
|
|
|
def GetPoseEmbedsFromTxt(pose_dir, h, w, target_length, flip=False, start_index=None, step=1): |
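    """
    Build Plücker pose embeddings from a pose text file (one header line
    followed by one whitespace-separated pose row per frame).

    Samples target_length rows starting at start_index with the given stride
    `step`, rescales the intrinsics to (h, w), and returns (plucker_embedding,
    uncond_plucker_embedding, poses); each embedding has shape
    (target_length, 6, h, w). If flip is True, the pixel grid is mirrored
    horizontally for every sampled view.
    """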
|
|
|
with open(pose_dir, 'r') as f: |
|
poses = f.readlines() |
|
poses = [pose.strip().split(' ') for pose in poses[1:]] |
|
start_idx = start_index |
|
sample_id = [start_idx + i*step for i in range(target_length)] |
|
poses = [poses[i] for i in sample_id] |
|
|
|
cam_params = [[float(x) for x in pose] for pose in poses] |
|
assert len(cam_params) == target_length |
|
cam_params = [Camera(cam_param) for cam_param in cam_params] |
|
|
|
monst3r_w = cam_params[0].cx * 2 |
|
monst3r_h = cam_params[0].cy * 2 |
|
ratio_w, ratio_h = w/monst3r_w, h/monst3r_h |
|
intrinsics = np.asarray([[cam_param.fx * ratio_w, |
|
cam_param.fy * ratio_h, |
|
cam_param.cx * ratio_w, |
|
cam_param.cy * ratio_h] |
|
for cam_param in cam_params], dtype=np.float32) |
|
intrinsics = torch.as_tensor(intrinsics)[None] |
|
relative_pose = True |
|
if relative_pose: |
|
c2w_poses = get_relative_pose(cam_params) |
|
else: |
|
c2w_poses = np.array([cam_param.c2w_mat for cam_param in cam_params], dtype=np.float32) |
|
c2w = torch.as_tensor(c2w_poses)[None] |
|
uncond_c2w = torch.zeros_like(c2w) |
|
uncond_c2w[:, :] = torch.eye(4, device=c2w.device) |
|
if flip: |
|
flip_flag = torch.ones(target_length, dtype=torch.bool, device=c2w.device) |
|
else: |
|
flip_flag = torch.zeros(target_length, dtype=torch.bool, device=c2w.device) |
|
plucker_embedding = ray_condition(intrinsics, c2w, h, w, device='cpu', |
|
flip_flag=flip_flag)[0].permute(0, 3, 1, 2).contiguous() |
|
uncond_plucker_embedding = ray_condition(intrinsics, uncond_c2w, h, w, device='cpu', |
|
flip_flag=flip_flag)[0].permute(0, 3, 1, 2).contiguous() |
|
|
|
return plucker_embedding, uncond_plucker_embedding, poses |
|
|
|
|
|
class HunyuanVideoSampler(Inference): |
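    """
    Inference wrapper that runs the HunyuanVideo diffusion pipeline with
    camera-pose (Plücker) conditioning.

    Rough usage sketch (arguments are illustrative, not exhaustive):

        sampler = HunyuanVideoSampler(args, vae, vae_kwargs, text_encoder, model)
        out = sampler.predict(prompt="a walk through a forest", action_id="w",
                              action_speed=0.2, size=(704, 1216), seed=42)
        videos = out["samples"]
    """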
|
def __init__(self, args, vae, vae_kwargs, text_encoder, model, text_encoder_2=None, pipeline=None, |
|
device=0, logger=None): |
|
super().__init__(args, vae, vae_kwargs, text_encoder, model, text_encoder_2=text_encoder_2, |
|
pipeline=pipeline, device=device, logger=logger) |
|
|
|
self.args = args |
|
self.pipeline = load_diffusion_pipeline( |
|
args, 0, self.vae, self.text_encoder, self.text_encoder_2, self.model, |
|
device=self.device) |
|
        print('Loaded hunyuan model successfully...')
|
|
|
def get_rotary_pos_embed(self, video_length, height, width, concat_dict={}): |
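        """Compute rotary position embedding tables (cos, sin) for the latent grid implied by the given frame count and resolution."""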
|
target_ndim = 3 |
|
ndim = 5 - 2 |
|
        if '884' in self.args.vae:
            latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
        else:
            latents_size = [video_length, height // 8, width // 8]
|
|
|
if isinstance(self.model.patch_size, int): |
|
assert all(s % self.model.patch_size == 0 for s in latents_size), \ |
|
f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), " \ |
|
f"but got {latents_size}." |
|
rope_sizes = [s // self.model.patch_size for s in latents_size] |
|
elif isinstance(self.model.patch_size, list): |
|
assert all(s % self.model.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), \ |
|
f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), " \ |
|
f"but got {latents_size}." |
|
rope_sizes = [s // self.model.patch_size[idx] for idx, s in enumerate(latents_size)] |
|
|
|
if len(rope_sizes) != target_ndim: |
|
rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes |
|
head_dim = self.model.hidden_size // self.model.num_heads |
|
rope_dim_list = self.model.rope_dim_list |
|
if rope_dim_list is None: |
|
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)] |
|
        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) must equal the attention head_dim"
|
freqs_cos, freqs_sin = get_nd_rotary_pos_embed_new(rope_dim_list, |
|
rope_sizes, |
|
theta=self.args.rope_theta, |
|
use_real=True, |
|
theta_rescale_factor=1, |
|
concat_dict=concat_dict) |
|
return freqs_cos, freqs_sin |
|
|
|
@torch.no_grad() |
|
def predict(self, |
|
prompt, |
|
is_image=True, |
|
size=(720, 1280), |
|
video_length=129, |
|
seed=None, |
|
negative_prompt=None, |
|
infer_steps=50, |
|
guidance_scale=6.0, |
|
flow_shift=5.0, |
|
batch_size=1, |
|
num_videos_per_prompt=1, |
|
verbose=1, |
|
output_type="pil", |
|
**kwargs): |
|
""" |
|
Predict the image from the given text. |
|
|
|
Args: |
|
prompt (str or List[str]): The input text. |
|
kwargs: |
|
size (int): The (height, width) of the output image/video. Default is (256, 256). |
|
video_length (int): The frame number of the output video. Default is 1. |
|
seed (int or List[str]): The random seed for the generation. Default is a random integer. |
|
negative_prompt (str or List[str]): The negative text prompt. Default is an empty string. |
|
infer_steps (int): The number of inference steps. Default is 100. |
|
guidance_scale (float): The guidance scale for the generation. Default is 6.0. |
|
num_videos_per_prompt (int): The number of videos per prompt. Default is 1. |
|
verbose (int): 0 for no log, 1 for all log, 2 for fewer log. Default is 1. |
|
output_type (str): The output type of the image, can be one of `pil`, `np`, `pt`, `latent`. |
|
Default is 'pil'. |
|
""" |
|
|
|
out_dict = dict() |
|
|
|
|
|
|
|
|
|
prompt_embeds = kwargs.get("prompt_embeds", None) |
|
attention_mask = kwargs.get("attention_mask", None) |
|
negative_prompt_embeds = kwargs.get("negative_prompt_embeds", None) |
|
negative_attention_mask = kwargs.get("negative_attention_mask", None) |
|
ref_latents = kwargs.get("ref_latents", None) |
|
uncond_ref_latents = kwargs.get("uncond_ref_latents", None) |
|
return_latents = kwargs.get("return_latents", False) |
|
negative_prompt = kwargs.get("negative_prompt", None) |
|
|
|
action_id = kwargs.get("action_id", None) |
|
action_speed = kwargs.get("action_speed", None) |
|
start_index = kwargs.get("start_index", None) |
|
last_latents = kwargs.get("last_latents", None) |
|
|
input_pose = kwargs.get("input_pose", None) |
|
step = kwargs.get("step", 1) |
|
use_sage = kwargs.get("use_sage", False) |
|
|
|
size = self.parse_size(size) |
|
target_height = align_to(size[0], 16) |
|
target_width = align_to(size[1], 16) |
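
        # Camera conditioning below is built either from a pose .txt file (input_pose) or from a WASD action id.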
|
|
|
|
|
if input_pose is not None: |
|
pose_embeds, uncond_pose_embeds, poses = GetPoseEmbedsFromTxt(input_pose, |
|
target_height, |
|
target_width, |
|
33, |
|
kwargs.get("flip", False), |
|
start_index, |
|
step) |
|
else: |
|
pose = ActionToPoseFromID(action_id, value=action_speed) |
|
pose_embeds, uncond_pose_embeds, poses = GetPoseEmbedsFromPoses(pose, |
|
target_height, |
|
target_width, |
|
33, |
|
kwargs.get("flip", False), |
|
0) |
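
        # Target clip length: 34 frames for image-conditioned generation, 66 for video-conditioned generation.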
|
|
|
if is_image: |
|
target_length = 34 |
|
else: |
|
target_length = 66 |
|
|
|
out_dict['frame'] = target_length |
|
|
|
|
|
pose_embeds = pose_embeds.unsqueeze(0).to(torch.bfloat16).to('cuda') |
|
uncond_pose_embeds = uncond_pose_embeds.unsqueeze(0).to(torch.bfloat16).to('cuda') |
|
|
|
|
|
|
|
cpu_offload = kwargs.get("cpu_offload", 0) |
|
use_deepcache = kwargs.get("use_deepcache", 1) |
|
denoise_strength = kwargs.get("denoise_strength", 1.0) |
|
init_latents = kwargs.get("init_latents", None) |
|
mask = kwargs.get("mask", None) |
|
        if prompt is None:
            assert prompt_embeds is not None
            prompt = None
            negative_prompt = None
            batch_size = prompt_embeds.shape[0]
|
else: |
|
|
|
|
|
if isinstance(prompt, str): |
|
batch_size = 1 |
|
prompt = [prompt] |
|
elif isinstance(prompt, (list, tuple)): |
|
batch_size = len(prompt) |
|
else: |
|
raise ValueError(f"Prompt must be a string or a list of strings, got {prompt}.") |
|
|
|
if negative_prompt is None: |
|
negative_prompt = [""] * batch_size |
|
if isinstance(negative_prompt, str): |
|
negative_prompt = [negative_prompt] * batch_size |
|
|
|
|
|
|
|
|
|
scheduler = FlowMatchDiscreteScheduler(shift=flow_shift, |
|
reverse=self.args.flow_reverse, |
|
solver=self.args.flow_solver, |
|
) |
|
self.pipeline.scheduler = scheduler |
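
        # Next, resolve seeds: one torch.Generator per requested video.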
|
|
|
|
|
|
|
|
|
|
|
if isinstance(seed, torch.Tensor): |
|
seed = seed.tolist() |
|
if seed is None: |
|
seeds = [random.randint(0, 1_000_000) for _ in range(batch_size * num_videos_per_prompt)] |
|
elif isinstance(seed, int): |
|
seeds = [seed + i for _ in range(batch_size) for i in range(num_videos_per_prompt)] |
|
elif isinstance(seed, (list, tuple)): |
|
if len(seed) == batch_size: |
|
seeds = [int(seed[i]) + j for i in range(batch_size) for j in range(num_videos_per_prompt)] |
|
elif len(seed) == batch_size * num_videos_per_prompt: |
|
seeds = [int(s) for s in seed] |
|
else: |
|
raise ValueError( |
|
f"Length of seed must be equal to number of prompt(batch_size) or " |
|
f"batch_size * num_videos_per_prompt ({batch_size} * {num_videos_per_prompt}), got {seed}." |
|
) |
|
else: |
|
raise ValueError(f"Seed must be an integer, a list of integers, or None, got {seed}.") |
|
generator = [torch.Generator(self.device).manual_seed(seed) for seed in seeds] |
|
|
|
|
|
|
|
|
|
|
|
|
|
out_dict['size'] = (target_height, target_width) |
|
out_dict['video_length'] = target_length |
|
out_dict['seeds'] = seeds |
|
out_dict['negative_prompt'] = negative_prompt |
|
|
|
|
|
|
|
|
|
concat_dict = {'mode': 'timecat', 'bias': -1} |
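
        # Note: the RoPE frame count here (37 image-conditioned / 69 video-conditioned) differs from target_length (34 / 66).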
|
if is_image: |
|
freqs_cos, freqs_sin = self.get_rotary_pos_embed(37, target_height, target_width) |
|
else: |
|
freqs_cos, freqs_sin = self.get_rotary_pos_embed(69, target_height, target_width) |
|
|
|
n_tokens = freqs_cos.shape[0] |
|
|
|
|
|
|
|
|
|
output_dir = kwargs.get("output_dir", None) |
|
|
|
if verbose == 1: |
|
debug_str = f""" |
|
size: {out_dict['size']} |
|
video_length: {target_length} |
|
prompt: {prompt} |
|
neg_prompt: {negative_prompt} |
|
seed: {seed} |
|
infer_steps: {infer_steps} |
|
denoise_strength: {denoise_strength} |
|
use_deepcache: {use_deepcache} |
|
use_sage: {use_sage} |
|
cpu_offload: {cpu_offload} |
|
num_images_per_prompt: {num_videos_per_prompt} |
|
guidance_scale: {guidance_scale} |
|
n_tokens: {n_tokens} |
|
flow_shift: {flow_shift} |
|
output: {output_dir}""" |
|
self.logger.info(debug_str) |
|
|
|
start_time = time.time() |
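
        # Run the diffusion pipeline with text, reference-latent, and camera (Plücker) conditioning.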
|
samples = self.pipeline(prompt=prompt, |
|
last_latents=last_latents, |
|
cam_latents=pose_embeds, |
|
uncond_cam_latents=uncond_pose_embeds, |
|
height=target_height, |
|
width=target_width, |
|
video_length=target_length, |
|
gt_latents = ref_latents, |
|
num_inference_steps=infer_steps, |
|
guidance_scale=guidance_scale, |
|
negative_prompt=negative_prompt, |
|
num_videos_per_prompt=num_videos_per_prompt, |
|
generator=generator, |
|
prompt_embeds=prompt_embeds, |
|
ref_latents=ref_latents, |
|
latents=init_latents, |
|
denoise_strength=denoise_strength, |
|
mask=mask, |
|
uncond_ref_latents=uncond_ref_latents, |
|
ip_cfg_scale=self.args.ip_cfg_scale, |
|
use_deepcache=use_deepcache, |
|
attention_mask=attention_mask, |
|
negative_prompt_embeds=negative_prompt_embeds, |
|
negative_attention_mask=negative_attention_mask, |
|
output_type=output_type, |
|
freqs_cis=(freqs_cos, freqs_sin), |
|
n_tokens=n_tokens, |
|
data_type='video' if target_length > 1 else 'image', |
|
is_progress_bar=True, |
|
vae_ver=self.args.vae, |
|
enable_tiling=self.args.vae_tiling, |
|
cpu_offload=cpu_offload, |
|
return_latents=return_latents, |
|
use_sage=use_sage, |
|
) |
|
if samples is None: |
|
return None |
|
out_dict['samples'] = [] |
|
out_dict["prompts"] = prompt |
|
out_dict['pose'] = poses |
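
        # With return_latents, the pipeline also returns denoised latents, timesteps, last-frame latents and
        # reference latents; last_latents can be fed back into a subsequent predict call.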
|
|
|
if return_latents: |
|
print("return_latents | TRUE") |
|
latents, timesteps, last_latents, ref_latents = samples[1], samples[2], samples[3], samples[4] |
|
|
|
if samples[0] is not None and len(samples[0]) > 0: |
|
samples = samples[0][0] |
|
else: |
|
samples = None |
|
out_dict["denoised_lantents"] = latents |
|
out_dict["timesteps"] = timesteps |
|
out_dict["ref_latents"] = ref_latents |
|
out_dict["last_latents"] = last_latents |
|
|
|
else: |
|
samples = samples[0] |
|
|
|
        if samples is not None:
            for sample in samples:
                out_dict['samples'].append(sample.unsqueeze(0))
|
|
|
|
|
|
|
gen_time = time.time() - start_time |
|
logger.info(f"Success, time: {gen_time}") |
|
return out_dict |
|
|
|
|