import math
import time
import torch
import random
from loguru import logger
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from hymm_sp.diffusion import load_diffusion_pipeline
from hymm_sp.helpers import get_nd_rotary_pos_embed_new
from hymm_sp.inference import Inference
from hymm_sp.diffusion.schedulers import FlowMatchDiscreteScheduler
from packaging import version as pver
ACTION_DICT = {"w": "forward", "a": "left", "d": "right", "s": "backward"}
def custom_meshgrid(*args):
    # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
    if pver.parse(torch.__version__) < pver.parse('1.10'):
        return torch.meshgrid(*args)
    else:
        return torch.meshgrid(*args, indexing='ij')
def get_relative_pose(cam_params):
    abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
    abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
    source_cam_c2w = abs_c2ws[0]
    cam_to_origin = 0
    target_cam_c2w = np.array([
        [1, 0, 0, 0],
        [0, 1, 0, -cam_to_origin],
        [0, 0, 1, 0],
        [0, 0, 0, 1]
    ])
    abs2rel = target_cam_c2w @ abs_w2cs[0]
    ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
    for pose in ret_poses:
        pose[:3, -1:] *= 10
    ret_poses = np.array(ret_poses, dtype=np.float32)
    return ret_poses
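# get_relative_pose re-expresses every camera-to-world matrix relative to the
# first camera (which is mapped to the identity pose above) and scales the
# translation column by 10 before stacking the poses into a float32 array.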
def ray_condition(K, c2w, H, W, device, flip_flag=None):
    # c2w: B, V, 4, 4
    # K: B, V, 4
    B, V = K.shape[:2]
    j, i = custom_meshgrid(
        torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
        torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
    )
    i = i.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5  # [B, V, HxW]
    j = j.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5  # [B, V, HxW]
    n_flip = torch.sum(flip_flag).item() if flip_flag is not None else 0
    if n_flip > 0:
        j_flip, i_flip = custom_meshgrid(
            torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
            torch.linspace(W - 1, 0, W, device=device, dtype=c2w.dtype)
        )
        i_flip = i_flip.reshape([1, 1, H * W]).expand(B, 1, H * W) + 0.5
        j_flip = j_flip.reshape([1, 1, H * W]).expand(B, 1, H * W) + 0.5
        i[:, flip_flag, ...] = i_flip
        j[:, flip_flag, ...] = j_flip
    fx, fy, cx, cy = K.chunk(4, dim=-1)  # B, V, 1
    zs = torch.ones_like(i)  # [B, V, HxW]
    xs = (i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    zs = zs.expand_as(ys)
    directions = torch.stack((xs, ys, zs), dim=-1)  # B, V, HW, 3
    directions = directions / directions.norm(dim=-1, keepdim=True)  # B, V, HW, 3
    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)  # B, V, HW, 3
    rays_o = c2w[..., :3, 3]  # B, V, 3
    rays_o = rays_o[:, :, None].expand_as(rays_d)  # B, V, HW, 3
    # c2w @ directions
    rays_dxo = torch.cross(rays_o, rays_d, dim=-1)  # B, V, HW, 3
    plucker = torch.cat([rays_dxo, rays_d], dim=-1)
    plucker = plucker.reshape(B, c2w.shape[1], H, W, 6)  # B, V, H, W, 6
    # plucker = plucker.permute(0, 1, 4, 2, 3)
    return plucker
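# ray_condition builds a per-pixel Plucker embedding: for every pixel it forms
# a unit ray direction d from the intrinsics, rotates it into world space with
# the camera-to-world rotation, and concatenates the moment o x d (o = camera
# origin) with d itself, giving a [B, V, H, W, 6] conditioning map.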
def get_c2w(w2cs, transform_matrix, relative_c2w):
    if relative_c2w:
        target_cam_c2w = np.array([
            [1, 0, 0, 0],
            [0, 1, 0, 0],
            [0, 0, 1, 0],
            [0, 0, 0, 1]
        ])
        abs2rel = target_cam_c2w @ w2cs[0]
        ret_poses = [target_cam_c2w, ] + [abs2rel @ np.linalg.inv(w2c) for w2c in w2cs[1:]]
        for pose in ret_poses:
            pose[:3, -1:] *= 2
        # ret_poses = [poses[:, :3]*2 for poses in ret_poses]
        # ret_poses[:, :, :3] *= 2
    else:
        ret_poses = [np.linalg.inv(w2c) for w2c in w2cs]
    ret_poses = [transform_matrix @ x for x in ret_poses]
    return np.array(ret_poses, dtype=np.float32)
def generate_motion_segment(current_pose,
                            motion_type: str,
                            value: float,
                            duration: int = 30):
    """
    Parameters:
        motion_type: ('forward', 'backward', 'left', 'right',
                      'rotate_left', 'rotate_right', 'rotate_up', 'rotate_down')
        value: translation (m) or rotation (degrees)
        duration: number of frames
    Return:
        positions: [np.array(x, y, z), ...]
        rotations: [np.array(pitch, yaw, roll), ...]
    """
    positions = []
    rotations = []
    if motion_type in ['forward', 'backward']:
        yaw_rad = np.radians(current_pose['rotation'][1])
        pitch_rad = np.radians(current_pose['rotation'][0])
        forward_vec = np.array([
            -math.sin(yaw_rad) * math.cos(pitch_rad),
            math.sin(pitch_rad),
            -math.cos(yaw_rad) * math.cos(pitch_rad)
        ])
        direction = 1 if motion_type == 'forward' else -1
        total_move = forward_vec * value * direction
        step = total_move / duration
        for i in range(1, duration + 1):
            new_pos = current_pose['position'] + step * i
            positions.append(new_pos.copy())
            rotations.append(current_pose['rotation'].copy())
        current_pose['position'] = positions[-1]
    elif motion_type in ['left', 'right']:
        yaw_rad = np.radians(current_pose['rotation'][1])
        right_vec = np.array([math.cos(yaw_rad), 0, -math.sin(yaw_rad)])
        direction = -1 if motion_type == 'right' else 1
        total_move = right_vec * value * direction
        step = total_move / duration
        for i in range(1, duration + 1):
            new_pos = current_pose['position'] + step * i
            positions.append(new_pos.copy())
            rotations.append(current_pose['rotation'].copy())
        current_pose['position'] = positions[-1]
    elif motion_type.startswith('rotate'):
        # 'rotate_left' / 'rotate_right' / 'rotate_up' / 'rotate_down'
        axis = motion_type.split('_')[1]
        total_rotation = np.zeros(3)
        if axis == 'left':
            total_rotation[0] = value
        elif axis == 'right':
            total_rotation[0] = -value
        elif axis == 'up':
            total_rotation[2] = -value
        elif axis == 'down':
            total_rotation[2] = value
        step = total_rotation / duration
        for i in range(1, duration + 1):
            positions.append(current_pose['position'].copy())
            new_rot = current_pose['rotation'] + step * i
            rotations.append(new_rot.copy())
        current_pose['rotation'] = rotations[-1]
    return positions, rotations, current_pose
def euler_to_quaternion(angles):
    pitch, yaw, roll = np.radians(angles)
    cy = math.cos(yaw * 0.5)
    sy = math.sin(yaw * 0.5)
    cp = math.cos(pitch * 0.5)
    sp = math.sin(pitch * 0.5)
    cr = math.cos(roll * 0.5)
    sr = math.sin(roll * 0.5)
    qw = cy * cp * cr + sy * sp * sr
    qx = cy * cp * sr - sy * sp * cr
    qy = sy * cp * sr + cy * sp * cr
    qz = sy * cp * cr - cy * sp * sr
    return [qw, qx, qy, qz]
def quaternion_to_rotation_matrix(q):
    qw, qx, qy, qz = q
    return np.array([
        [1 - 2*(qy**2 + qz**2), 2*(qx*qy - qw*qz), 2*(qx*qz + qw*qy)],
        [2*(qx*qy + qw*qz), 1 - 2*(qx**2 + qz**2), 2*(qy*qz - qw*qx)],
        [2*(qx*qz - qw*qy), 2*(qy*qz + qw*qx), 1 - 2*(qx**2 + qy**2)]
    ])
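# Quick sanity check (illustrative sketch, not used by the pipeline): the
# conversion above yields a unit quaternion, so the resulting matrix should be
# orthonormal.
#   q = euler_to_quaternion(np.array([10.0, 20.0, 0.0]))
#   R = quaternion_to_rotation_matrix(q)
#   assert np.allclose(R @ R.T, np.eye(3), atol=1e-6)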
def ActionToPoseFromID(action_id, value=0.2, duration=33):
    all_positions = []
    all_rotations = []
    current_pose = {
        'position': np.array([0.0, 0.0, 0.0]),  # XYZ
        'rotation': np.array([0.0, 0.0, 0.0])   # (pitch, yaw, roll)
    }
    intrinsic = [0.50505, 0.8979, 0.5, 0.5]
    motion_type = ACTION_DICT[action_id]
    positions, rotations, current_pose = generate_motion_segment(current_pose, motion_type, value, duration)
    all_positions.extend(positions)
    all_rotations.extend(rotations)
    pose_list = []
    row = [0] + intrinsic + [0, 0] + [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]
    first_row = " ".join(map(str, row))
    pose_list.append(first_row)
    for i, (pos, rot) in enumerate(zip(all_positions, all_rotations)):
        quat = euler_to_quaternion(rot)
        R = quaternion_to_rotation_matrix(quat)
        extrinsic = np.hstack([R, pos.reshape(3, 1)])
        row = [i] + intrinsic + [0, 0] + extrinsic.flatten().tolist()
        pose_list.append(" ".join(map(str, row)))
    return pose_list
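# Example usage (sketch; the action id and speed below are illustrative values):
#   pose_lines = ActionToPoseFromID("w", value=0.2, duration=33)
#   # Each line holds: frame index, fx fy cx cy, two zeros, then a flattened
#   # 3x4 [R | t] camera matrix, matching what Camera / GetPoseEmbedsFromPoses
#   # parse below.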
class Camera(object):
    def __init__(self, entry):
        fx, fy, cx, cy = entry[1:5]
        self.fx = fx
        self.fy = fy
        self.cx = cx
        self.cy = cy
        w2c_mat = np.array(entry[7:]).reshape(3, 4)
        w2c_mat_4x4 = np.eye(4)
        w2c_mat_4x4[:3, :] = w2c_mat
        self.w2c_mat = w2c_mat_4x4
        self.c2w_mat = np.linalg.inv(w2c_mat_4x4)
class CameraPoseVisualizer:
    def __init__(self, xlim, ylim, zlim):
        self.fig = plt.figure(figsize=(7, 7))
        self.ax = self.fig.add_subplot(projection='3d')
        # self.ax.view_init(elev=25, azim=-90)
        self.plotly_data = None  # plotly data traces
        self.ax.set_aspect("auto")
        self.ax.set_xlim(xlim)
        self.ax.set_ylim(ylim)
        self.ax.set_zlim(zlim)
        self.ax.set_xlabel('x')
        self.ax.set_ylabel('y')
        self.ax.set_zlabel('z')
        print('initialize camera pose visualizer')
    def extrinsic2pyramid(self, extrinsic, color_map='red', hw_ratio=9/16, base_xval=1, zval=3):
        vertex_std = np.array([[0, 0, 0, 1],
                               [base_xval, -base_xval * hw_ratio, zval, 1],
                               [base_xval, base_xval * hw_ratio, zval, 1],
                               [-base_xval, base_xval * hw_ratio, zval, 1],
                               [-base_xval, -base_xval * hw_ratio, zval, 1]])
        vertex_transformed = vertex_std @ extrinsic.T
        meshes = [[vertex_transformed[0, :-1], vertex_transformed[1, :-1], vertex_transformed[2, :-1]],
                  [vertex_transformed[0, :-1], vertex_transformed[2, :-1], vertex_transformed[3, :-1]],
                  [vertex_transformed[0, :-1], vertex_transformed[3, :-1], vertex_transformed[4, :-1]],
                  [vertex_transformed[0, :-1], vertex_transformed[4, :-1], vertex_transformed[1, :-1]],
                  [vertex_transformed[1, :-1], vertex_transformed[2, :-1], vertex_transformed[3, :-1],
                   vertex_transformed[4, :-1]]]
        color = color_map if isinstance(color_map, str) else plt.cm.rainbow(color_map)
        self.ax.add_collection3d(
            Poly3DCollection(meshes, facecolors=color, linewidths=0.3, edgecolors=color, alpha=0.35))
    def customize_legend(self, list_label):
        list_handle = []
        for idx, label in enumerate(list_label):
            color = plt.cm.rainbow(idx / len(list_label))
            patch = Patch(color=color, label=label)
            list_handle.append(patch)
        plt.legend(loc='right', bbox_to_anchor=(1.8, 0.5), handles=list_handle)
    def colorbar(self, max_frame_length):
        cmap = mpl.cm.rainbow
        norm = mpl.colors.Normalize(vmin=0, vmax=max_frame_length)
        self.fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
                          ax=self.ax, orientation='vertical', label='Frame Number')
    def show(self, file_name):
        plt.title('Extrinsic Parameters')
        # plt.show()
        plt.savefig(file_name)
def align_to(value, alignment):
    return int(math.ceil(value / alignment) * alignment)
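# e.g. align_to(713, 16) == 720, align_to(1280, 16) == 1280 (already aligned).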
def GetPoseEmbedsFromPoses(poses, h, w, target_length, flip=False, start_index=None):
    poses = [pose.split(' ') for pose in poses]
    start_idx = start_index
    sample_id = [start_idx + i for i in range(target_length)]
    poses = [poses[i] for i in sample_id]
    frame = len(poses)
    w2cs = [np.asarray([float(p) for p in pose[7:]]).reshape(3, 4) for pose in poses]
    transform_matrix = np.asarray([[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]]).reshape(4, 4)
    last_row = np.zeros((1, 4))
    last_row[0, -1] = 1.0
    w2cs = [np.concatenate((w2c, last_row), axis=0) for w2c in w2cs]
    c2ws = get_c2w(w2cs, transform_matrix, relative_c2w=True)
    cam_params = [[float(x) for x in pose] for pose in poses]
    assert len(cam_params) == target_length
    cam_params = [Camera(cam_param) for cam_param in cam_params]
    monst3r_w = cam_params[0].cx * 2
    monst3r_h = cam_params[0].cy * 2
    ratio_w, ratio_h = w / monst3r_w, h / monst3r_h
    intrinsics = np.asarray([[cam_param.fx * ratio_w,
                              cam_param.fy * ratio_h,
                              cam_param.cx * ratio_w,
                              cam_param.cy * ratio_h]
                             for cam_param in cam_params], dtype=np.float32)
    intrinsics = torch.as_tensor(intrinsics)[None]  # [1, n_frame, 4]
    relative_pose = True
    if relative_pose:
        c2w_poses = get_relative_pose(cam_params)
    else:
        c2w_poses = np.array([cam_param.c2w_mat for cam_param in cam_params], dtype=np.float32)
    c2w = torch.as_tensor(c2w_poses)[None]  # [1, n_frame, 4, 4]
    uncond_c2w = torch.zeros_like(c2w)
    uncond_c2w[:, :] = torch.eye(4, device=c2w.device)
    flip_flag = torch.zeros(target_length, dtype=torch.bool, device=c2w.device)
    plucker_embedding = ray_condition(intrinsics, c2w, h, w, device='cpu',
                                      flip_flag=flip_flag)[0].permute(0, 3, 1, 2).contiguous()
    uncond_plucker_embedding = ray_condition(intrinsics, uncond_c2w, h, w, device='cpu',
                                             flip_flag=flip_flag)[0].permute(0, 3, 1, 2).contiguous()
    return plucker_embedding, uncond_plucker_embedding, poses
def GetPoseEmbedsFromTxt(pose_dir, h, w, target_length, flip=False, start_index=None, step=1):
    # get camera pose
    with open(pose_dir, 'r') as f:
        poses = f.readlines()
    poses = [pose.strip().split(' ') for pose in poses[1:]]
    start_idx = start_index
    sample_id = [start_idx + i * step for i in range(target_length)]
    poses = [poses[i] for i in sample_id]
    cam_params = [[float(x) for x in pose] for pose in poses]
    assert len(cam_params) == target_length
    cam_params = [Camera(cam_param) for cam_param in cam_params]
    monst3r_w = cam_params[0].cx * 2
    monst3r_h = cam_params[0].cy * 2
    ratio_w, ratio_h = w / monst3r_w, h / monst3r_h
    intrinsics = np.asarray([[cam_param.fx * ratio_w,
                              cam_param.fy * ratio_h,
                              cam_param.cx * ratio_w,
                              cam_param.cy * ratio_h]
                             for cam_param in cam_params], dtype=np.float32)
    intrinsics = torch.as_tensor(intrinsics)[None]  # [1, n_frame, 4]
    relative_pose = True
    if relative_pose:
        c2w_poses = get_relative_pose(cam_params)
    else:
        c2w_poses = np.array([cam_param.c2w_mat for cam_param in cam_params], dtype=np.float32)
    c2w = torch.as_tensor(c2w_poses)[None]  # [1, n_frame, 4, 4]
    uncond_c2w = torch.zeros_like(c2w)
    uncond_c2w[:, :] = torch.eye(4, device=c2w.device)
    if flip:
        flip_flag = torch.ones(target_length, dtype=torch.bool, device=c2w.device)
    else:
        flip_flag = torch.zeros(target_length, dtype=torch.bool, device=c2w.device)
    plucker_embedding = ray_condition(intrinsics, c2w, h, w, device='cpu',
                                      flip_flag=flip_flag)[0].permute(0, 3, 1, 2).contiguous()
    uncond_plucker_embedding = ray_condition(intrinsics, uncond_c2w, h, w, device='cpu',
                                             flip_flag=flip_flag)[0].permute(0, 3, 1, 2).contiguous()
    return plucker_embedding, uncond_plucker_embedding, poses
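# Note on the pose .txt layout assumed above: the first line is treated as a
# header and skipped, and every following line is expected to use the same
# column layout that ActionToPoseFromID emits (frame index, fx fy cx cy, two
# zeros, flattened 3x4 camera matrix), which is what Camera(entry) parses via
# entry[1:5] and entry[7:].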
class HunyuanVideoSampler(Inference):
    def __init__(self, args, vae, vae_kwargs, text_encoder, model, text_encoder_2=None, pipeline=None,
                 device=0, logger=None):
        super().__init__(args, vae, vae_kwargs, text_encoder, model, text_encoder_2=text_encoder_2,
                         pipeline=pipeline, device=device, logger=logger)
        self.args = args
        self.pipeline = load_diffusion_pipeline(
            args, 0, self.vae, self.text_encoder, self.text_encoder_2, self.model,
            device=self.device)
        print('Loaded hunyuan model successfully...')
    def get_rotary_pos_embed(self, video_length, height, width, concat_dict={}):
        target_ndim = 3
        ndim = 5 - 2
        if '884' in self.args.vae:
            latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
        else:
            latents_size = [video_length, height // 8, width // 8]
        if isinstance(self.model.patch_size, int):
            assert all(s % self.model.patch_size == 0 for s in latents_size), \
                f"Latent size (last {ndim} dimensions) should be divisible by patch size ({self.model.patch_size}), " \
                f"but got {latents_size}."
            rope_sizes = [s // self.model.patch_size for s in latents_size]
        elif isinstance(self.model.patch_size, list):
            assert all(s % self.model.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), \
                f"Latent size (last {ndim} dimensions) should be divisible by patch size ({self.model.patch_size}), " \
                f"but got {latents_size}."
            rope_sizes = [s // self.model.patch_size[idx] for idx, s in enumerate(latents_size)]
        if len(rope_sizes) != target_ndim:
            rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes  # pad the time axis
        head_dim = self.model.hidden_size // self.model.num_heads
        rope_dim_list = self.model.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal the head_dim of the attention layer"
        freqs_cos, freqs_sin = get_nd_rotary_pos_embed_new(rope_dim_list,
                                                           rope_sizes,
                                                           theta=self.args.rope_theta,
                                                           use_real=True,
                                                           theta_rescale_factor=1,
                                                           concat_dict=concat_dict)
        return freqs_cos, freqs_sin
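    # Note: the '884' branch above assumes a VAE with 4x temporal and 8x8
    # spatial compression, so a clip of N frames maps to (N - 1) // 4 + 1
    # latent frames at H/8 x W/8 resolution.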
    def predict(self,
                prompt,
                is_image=True,
                size=(720, 1280),
                video_length=129,
                seed=None,
                negative_prompt=None,
                infer_steps=50,
                guidance_scale=6.0,
                flow_shift=5.0,
                batch_size=1,
                num_videos_per_prompt=1,
                verbose=1,
                output_type="pil",
                **kwargs):
        """
        Predict a video (or image) from the given text.
        Args:
            prompt (str or List[str]): The input text.
            kwargs:
                size (tuple): The (height, width) of the output image/video. Default is (720, 1280).
                video_length (int): The frame number of the output video. Default is 129.
                seed (int or List[int]): The random seed for the generation. Default is a random integer.
                negative_prompt (str or List[str]): The negative text prompt. Default is an empty string.
                infer_steps (int): The number of inference steps. Default is 50.
                guidance_scale (float): The guidance scale for the generation. Default is 6.0.
                num_videos_per_prompt (int): The number of videos per prompt. Default is 1.
                verbose (int): 0 for no log, 1 for all logs, 2 for fewer logs. Default is 1.
                output_type (str): The output type of the image, one of `pil`, `np`, `pt`, `latent`.
                    Default is 'pil'.
        """
        out_dict = dict()
        # ---------------------------------
        # Prompt
        # ---------------------------------
        prompt_embeds = kwargs.get("prompt_embeds", None)
        attention_mask = kwargs.get("attention_mask", None)
        negative_prompt_embeds = kwargs.get("negative_prompt_embeds", None)
        negative_attention_mask = kwargs.get("negative_attention_mask", None)
        ref_latents = kwargs.get("ref_latents", None)
        uncond_ref_latents = kwargs.get("uncond_ref_latents", None)
        return_latents = kwargs.get("return_latents", False)
        negative_prompt = kwargs.get("negative_prompt", None)
        action_id = kwargs.get("action_id", None)
        action_speed = kwargs.get("action_speed", None)
        start_index = kwargs.get("start_index", None)
        last_latents = kwargs.get("last_latents", None)
        input_pose = kwargs.get("input_pose", None)
        step = kwargs.get("step", 1)
        use_sage = kwargs.get("use_sage", False)
        size = self.parse_size(size)
        target_height = align_to(size[0], 16)
        target_width = align_to(size[1], 16)
        # target_video_length = video_length
        if input_pose is not None:
            pose_embeds, uncond_pose_embeds, poses = GetPoseEmbedsFromTxt(input_pose,
                                                                          target_height,
                                                                          target_width,
                                                                          33,
                                                                          kwargs.get("flip", False),
                                                                          start_index,
                                                                          step)
        else:
            pose = ActionToPoseFromID(action_id, value=action_speed)
            pose_embeds, uncond_pose_embeds, poses = GetPoseEmbedsFromPoses(pose,
                                                                            target_height,
                                                                            target_width,
                                                                            33,
                                                                            kwargs.get("flip", False),
                                                                            0)
        if is_image:
            target_length = 34
        else:
            target_length = 66
        out_dict['frame'] = target_length
        # print("pose embeds: ", pose_embeds.shape, uncond_pose_embeds.shape)
        pose_embeds = pose_embeds.unsqueeze(0).to(torch.bfloat16).to('cuda')
        uncond_pose_embeds = uncond_pose_embeds.unsqueeze(0).to(torch.bfloat16).to('cuda')
        cpu_offload = kwargs.get("cpu_offload", 0)
        use_deepcache = kwargs.get("use_deepcache", 1)
        denoise_strength = kwargs.get("denoise_strength", 1.0)
        init_latents = kwargs.get("init_latents", None)
        mask = kwargs.get("mask", None)
        if prompt is None:
            # prompt_embeds, attention_mask, negative_prompt_embeds and negative_attention_mask should not be None
            # pipeline will help to check this
            prompt = None
            negative_prompt = None
            batch_size = prompt_embeds.shape[0]
            assert prompt_embeds is not None
        else:
            # prompt_embeds, attention_mask, negative_prompt_embeds and negative_attention_mask should be None
            # pipeline will help to check this
            if isinstance(prompt, str):
                batch_size = 1
                prompt = [prompt]
            elif isinstance(prompt, (list, tuple)):
                batch_size = len(prompt)
            else:
                raise ValueError(f"Prompt must be a string or a list of strings, got {prompt}.")
            if negative_prompt is None:
                negative_prompt = [""] * batch_size
            if isinstance(negative_prompt, str):
                negative_prompt = [negative_prompt] * batch_size
        # ---------------------------------
        # Other arguments
        # ---------------------------------
        scheduler = FlowMatchDiscreteScheduler(shift=flow_shift,
                                               reverse=self.args.flow_reverse,
                                               solver=self.args.flow_solver,
                                               )
        self.pipeline.scheduler = scheduler
        # ---------------------------------
        # Random seed
        # ---------------------------------
        if isinstance(seed, torch.Tensor):
            seed = seed.tolist()
        if seed is None:
            seeds = [random.randint(0, 1_000_000) for _ in range(batch_size * num_videos_per_prompt)]
        elif isinstance(seed, int):
            seeds = [seed + i for _ in range(batch_size) for i in range(num_videos_per_prompt)]
        elif isinstance(seed, (list, tuple)):
            if len(seed) == batch_size:
                seeds = [int(seed[i]) + j for i in range(batch_size) for j in range(num_videos_per_prompt)]
            elif len(seed) == batch_size * num_videos_per_prompt:
                seeds = [int(s) for s in seed]
            else:
                raise ValueError(
                    f"Length of seed must equal the number of prompts (batch_size) or "
                    f"batch_size * num_videos_per_prompt ({batch_size} * {num_videos_per_prompt}), got {seed}."
                )
        else:
            raise ValueError(f"Seed must be an integer, a list of integers, or None, got {seed}.")
        generator = [torch.Generator(self.device).manual_seed(seed) for seed in seeds]
        # ---------------------------------
        # Image/Video size and frame
        # ---------------------------------
        out_dict['size'] = (target_height, target_width)
        out_dict['video_length'] = target_length
        out_dict['seeds'] = seeds
        out_dict['negative_prompt'] = negative_prompt
        # ---------------------------------
        # Build RoPE
        # ---------------------------------
        concat_dict = {'mode': 'timecat', 'bias': -1}
        if is_image:
            freqs_cos, freqs_sin = self.get_rotary_pos_embed(37, target_height, target_width)
        else:
            freqs_cos, freqs_sin = self.get_rotary_pos_embed(69, target_height, target_width)
        n_tokens = freqs_cos.shape[0]
        # ---------------------------------
        # Inference
        # ---------------------------------
        output_dir = kwargs.get("output_dir", None)
        if verbose == 1:
            debug_str = f"""
                size: {out_dict['size']}
                video_length: {target_length}
                prompt: {prompt}
                neg_prompt: {negative_prompt}
                seed: {seed}
                infer_steps: {infer_steps}
                denoise_strength: {denoise_strength}
                use_deepcache: {use_deepcache}
                use_sage: {use_sage}
                cpu_offload: {cpu_offload}
                num_videos_per_prompt: {num_videos_per_prompt}
                guidance_scale: {guidance_scale}
                n_tokens: {n_tokens}
                flow_shift: {flow_shift}
                output: {output_dir}"""
            self.logger.info(debug_str)
        start_time = time.time()
        samples = self.pipeline(prompt=prompt,
                                last_latents=last_latents,
                                cam_latents=pose_embeds,
                                uncond_cam_latents=uncond_pose_embeds,
                                height=target_height,
                                width=target_width,
                                video_length=target_length,
                                gt_latents=ref_latents,
                                num_inference_steps=infer_steps,
                                guidance_scale=guidance_scale,
                                negative_prompt=negative_prompt,
                                num_videos_per_prompt=num_videos_per_prompt,
                                generator=generator,
                                prompt_embeds=prompt_embeds,
                                ref_latents=ref_latents,
                                latents=init_latents,
                                denoise_strength=denoise_strength,
                                mask=mask,
                                uncond_ref_latents=uncond_ref_latents,
                                ip_cfg_scale=self.args.ip_cfg_scale,
                                use_deepcache=use_deepcache,
                                attention_mask=attention_mask,
                                negative_prompt_embeds=negative_prompt_embeds,
                                negative_attention_mask=negative_attention_mask,
                                output_type=output_type,
                                freqs_cis=(freqs_cos, freqs_sin),
                                n_tokens=n_tokens,
                                data_type='video' if target_length > 1 else 'image',
                                is_progress_bar=True,
                                vae_ver=self.args.vae,
                                enable_tiling=self.args.vae_tiling,
                                cpu_offload=cpu_offload,
                                return_latents=return_latents,
                                use_sage=use_sage,
                                )
        if samples is None:
            return None
        out_dict['samples'] = []
        out_dict["prompts"] = prompt
        out_dict['pose'] = poses
        if return_latents:
            print("return_latents | TRUE")
            latents, timesteps, last_latents, ref_latents = samples[1], samples[2], samples[3], samples[4]
            # samples = samples[0][0]
            if samples[0] is not None and len(samples[0]) > 0:
                samples = samples[0][0]
            else:
                samples = None
            out_dict["denoised_lantents"] = latents
            out_dict["timesteps"] = timesteps
            out_dict["ref_latents"] = ref_latents
            out_dict["last_latents"] = last_latents
        else:
            samples = samples[0]
        if samples is not None:
            for i, sample in enumerate(samples):
                sample = samples[i].unsqueeze(0)
                sub_samples = []
                sub_samples.append(sample)
                sample_num = len(sub_samples)
                sub_samples = torch.concat(sub_samples)
                # only save in tp rank 0
                out_dict['samples'].append(sub_samples)
        # visualize pose
        gen_time = time.time() - start_time
        logger.info(f"Success, time: {gen_time}")
        return out_dict
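# Minimal usage sketch (illustrative only; it assumes a HunyuanVideoSampler has
# already been constructed with the repo's args/VAE/text encoders/model, and the
# argument values below are placeholders rather than tuned defaults):
#   out = sampler.predict(
#       prompt="a slow walk through a medieval town at dusk",
#       is_image=True,
#       size=(704, 1216),
#       seed=42,
#       infer_steps=50,
#       guidance_scale=6.0,
#       flow_shift=5.0,
#       action_id="w",        # one of ACTION_DICT: w / a / s / d
#       action_speed=0.2,
#       return_latents=True,
#   )
#   # out['samples'] holds the decoded clips, out['pose'] the pose rows used,
#   # and out['last_latents'] can be passed back in as last_latents when
#   # generating the next action segment.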