# Page banner captured together with this listing (not part of the program):
# "Spaces: Runtime error / Runtime error"
import os | |
import json | |
import pickle as pkl | |
import random | |
import argparse | |
import torch | |
from TADA import smplx | |
import imageio | |
import numpy as np | |
from tqdm import tqdm | |
from PIL import Image | |
from TADA.lib.common.remesh import subdivide_inorder | |
from TADA.lib.common.utils import SMPLXSeg | |
from TADA.lib.common.lbs import warp_points | |
from TADA.lib.common.obj import compute_normal | |
import trimesh | |
import pyrender | |
from shapely import geometry | |
import moviepy.editor as mpy | |
# Select the EGL backend for PyOpenGL.
# NOTE(review): PyOpenGL is understood to read this variable when it is first
# imported -- confirm that none of the imports above (e.g. pyrender) load
# OpenGL before this line runs; if one does, move this assignment to directly
# after `import os`.
os.environ['PYOPENGL_PLATFORM'] = "egl"
def _flat_index_array(indices):
    """Return *indices* as a flat NumPy integer array (torch tensors accepted)."""
    if torch.is_tensor(indices):
        # NumPy cannot view device memory, so bring the tensor to the host first.
        indices = indices.detach().cpu().numpy()
    return np.asarray(indices).reshape(-1).astype(np.intp, copy=False)


def build_new_mesh(v, f, vt, ft):
    """Re-index vertex positions by texture-map (UV) vertex index.

    Face corner ``i`` pairs geometric vertex ``f.flat[i]`` with UV vertex
    ``ft.flat[i]``.  Every UV vertex referenced by ``ft`` receives a copy of
    the position of the geometric vertex it is paired with, so a geometric
    vertex paired with several UV vertices is duplicated once per UV vertex.

    Args:
        v: Array indexed as ``v[:, vertex]`` whose slices broadcast to
            ``(v.shape[0], 3)``.
        f: Geometry face indices (NumPy array or torch tensor).
        vt: UV coordinates; only ``vt.shape[0]`` (the UV vertex count) is used.
        ft: UV face indices, with as many entries as ``f``.

    Returns:
        ``(new_v, ft)`` where ``new_v`` is a float64 array of shape
        ``(v.shape[0], vt.shape[0], 3)``.  UV vertices that ``ft`` never
        references stay at zero.  ``ft`` is returned unchanged as the face
        list of the new mesh.

    Raises:
        ValueError: If ``f`` and ``ft`` hold a different number of indices.

    Note:
        If one UV vertex is paired with several different geometric vertices
        (a malformed UV layout), the pairing of the last face corner wins.
    """
    f_flat = _flat_index_array(f)
    ft_flat = _flat_index_array(ft)
    if f_flat.size != ft_flat.size:
        raise ValueError(
            f"f and ft must hold the same number of indices, "
            f"got {f_flat.size} and {ft_flat.size}"
        )

    num_uv_vertices = vt.shape[0]

    # For each UV vertex: which geometric vertex it copies, and whether any
    # face references it at all.
    source_vertex = np.zeros(num_uv_vertices, dtype=np.intp)
    source_vertex[ft_flat] = f_flat
    referenced = np.zeros(num_uv_vertices, dtype=bool)
    referenced[ft_flat] = True

    # Gather once per referenced UV vertex rather than once per face corner,
    # which keeps the temporary array small.
    new_v = np.zeros((v.shape[0], num_uv_vertices, 3))
    new_v[:, referenced] = v[:, source_vertex[referenced]]

    # The UV faces are the faces of the re-indexed mesh.
    return new_v, ft
class Animation:
    """Pose a saved avatar checkpoint with a motion sequence.

    Construction loads the dense-mesh data stored under ``workspace_dir``,
    creates an SMPL-X body model and reads the checkpoint named by
    ``ckpt_path``.  :meth:`forward_mdm` then produces one posed, UV-indexed
    vertex set per motion frame.
    """

    def __init__(self, ckpt_path, workspace_dir, device="cuda"):
        """Load the dense-mesh data, the body model and one checkpoint.

        Args:
            ckpt_path: Sub-directory of ``<workspace_dir>/MESH`` that contains
                ``params.pt`` and ``mesh_albedo.png``.
            workspace_dir: Root directory holding ``init_body/data.npz``,
                ``smplx/SMPLX_NEUTRAL_2020.npz`` and the ``MESH`` folder.
            device: Torch device used for the body model and all tensors.
        """
        self.device = device
        self.SMPLXSeg = SMPLXSeg(workspace_dir)

        # load data
        # An .npz handle keeps the archive open until it is closed; every
        # array is read inside the `with` block so the file is released.
        with np.load(os.path.join(workspace_dir, "init_body/data.npz")) as init_data:
            self.dense_faces = torch.as_tensor(init_data['dense_faces'], device=self.device)
            self.dense_lbs_weights = torch.as_tensor(init_data['dense_lbs_weights'], device=self.device)
            self.unique = init_data['unique']
            self.vt = init_data['vt']
            self.ft = init_data['ft']

        model_params = dict(
            model_path=os.path.join(workspace_dir, "smplx/SMPLX_NEUTRAL_2020.npz"),
            model_type='smplx',
            create_global_orient=True,
            create_body_pose=True,
            create_betas=True,
            create_left_hand_pose=True,
            create_right_hand_pose=True,
            create_jaw_pose=True,
            create_leye_pose=True,
            create_reye_pose=True,
            create_expression=True,
            create_transl=False,
            use_pca=False,
            flat_hand_mean=False,
            num_betas=300,
            num_expression_coeffs=100,
            num_pca_comps=12,
            dtype=torch.float32,
            batch_size=1,
        )
        self.body_model = smplx.create(**model_params).to(device=self.device)
        self.smplx_face = self.body_model.faces.astype(np.int32)

        ckpt_file = os.path.join(workspace_dir, "MESH", ckpt_path, "params.pt")
        albedo_path = os.path.join(workspace_dir, "MESH", ckpt_path, "mesh_albedo.png")
        self.load_ckpt_data(ckpt_file, albedo_path)

    def load_ckpt_data(self, ckpt_file, albedo_path):
        """Read shape parameters, vertex offsets and the albedo texture.

        Sets ``expression``, ``jaw_pose`` (``None`` when absent from the
        checkpoint), ``betas``, ``v_offsets``, ``raw_albedo`` and
        ``trimesh_visual`` on the instance.

        Args:
            ckpt_file: Path of the ``torch.save``-d parameter file; it must
                provide ``betas`` and ``v_offsets``.
            albedo_path: Path of the albedo texture image.
        """
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        model_data = torch.load(ckpt_file, map_location=self.device)
        self.expression = model_data["expression"] if "expression" in model_data else None
        self.jaw_pose = model_data["jaw_pose"] if "jaw_pose" in model_data else None
        self.betas = model_data['betas']
        self.v_offsets = model_data['v_offsets']

        # Eyeballs and hands receive no learned displacement.  The in-place
        # writes run under no_grad so that they are also legal when the tensor
        # was saved with requires_grad=True.
        with torch.no_grad():
            self.v_offsets[self.SMPLXSeg.eyeball_ids] = 0.
            self.v_offsets[self.SMPLXSeg.hands_ids] = 0.

        # tex to trimesh texture
        # The second UV coordinate is flipped on a copy; self.vt is untouched.
        vt = self.vt.copy()
        vt[:, 1] = 1 - vt[:, 1]

        albedo = Image.open(albedo_path)
        # Scale to [0, 1] and move the last axis to the front
        # (presumably HWC -> CHW; assumes the image has a channel axis).
        self.raw_albedo = torch.from_numpy(np.array(albedo))
        self.raw_albedo = self.raw_albedo / 255.0
        self.raw_albedo = self.raw_albedo.permute(2, 0, 1)

        self.trimesh_visual = trimesh.visual.TextureVisuals(
            uv=vt,
            image=albedo,
            material=trimesh.visual.texture.SimpleMaterial(
                image=albedo,
                diffuse=[255, 255, 255, 255],
                ambient=[255, 255, 255, 255],
                specular=[0, 0, 0, 255],
                glossiness=0)
        )

    def forward_mdm(self, motion):
        """Pose the avatar for every frame of ``motion``.

        Args:
            motion: Either a keyed container providing ``"poses"`` and
                ``"trans"`` arrays, or a 2-D array whose last three columns
                are the translation and whose remaining columns are pose
                values.  Pose values are reshaped to ``(frames, -1, 3)``;
                entry 0 of each frame is used as the global orientation and
                entries 1..21 as the body pose.

        Returns:
            ``(vertices, faces)`` as produced by ``build_new_mesh``:
            ``vertices`` is a float32 array with one UV-indexed vertex set per
            frame, ``faces`` is ``self.ft``.  Translations are taken relative
            to the first frame.
        """
        try:
            mdm_body_pose = motion["poses"]
            translate = torch.from_numpy(motion["trans"])
        except (KeyError, IndexError, TypeError, ValueError):
            # Not a keyed container (e.g. a plain array): use the packed
            # layout with the translation in the last three columns.
            translate = torch.from_numpy(motion[:, -3:])
            mdm_body_pose = motion[:, :-3]

        mdm_body_pose = mdm_body_pose.reshape(mdm_body_pose.shape[0], -1, 3)
        translate = translate.to(self.device)

        # Identical for every frame, so select the faces once.
        remesh_faces = self.smplx_face[self.SMPLXSeg.remesh_mask]
        num_frames = min(len(mdm_body_pose), len(translate))

        scan_v_posed = []
        # Inference only: without no_grad every frame would keep its autograd
        # graph alive until the final detach().
        with torch.no_grad():
            for pose, t in tqdm(zip(mdm_body_pose, translate), total=num_frames):
                body_pose = torch.as_tensor(pose[None, 1:22, :], device=self.device)
                global_orient = torch.as_tensor(pose[None, :1, :], device=self.device)
                output = self.body_model(
                    betas=self.betas,
                    global_orient=global_orient,
                    jaw_pose=self.jaw_pose,
                    body_pose=body_pose,
                    expression=self.expression,
                    return_verts=True
                )
                v_cano = output.v_posed[0]

                # re-mesh
                v_cano_dense = subdivide_inorder(v_cano, remesh_faces, self.unique).squeeze(0)

                # add offsets along the vertex normals
                vn = compute_normal(v_cano_dense, self.dense_faces)[0]
                v_cano_dense += self.v_offsets * vn

                # do LBS with the first 55 joint transforms
                v_posed_dense = warp_points(v_cano_dense, self.dense_lbs_weights, output.joints_transform[:, :55])

                # translate relative to the first frame
                v_posed_dense += t - translate[0]
                scan_v_posed.append(v_posed_dense)

        scan_v_posed = torch.cat(scan_v_posed).detach().cpu().numpy()

        # build_new_mesh works on NumPy data, so hand it host-side faces
        # instead of a (possibly CUDA) tensor.
        dense_faces = self.dense_faces.detach().cpu().numpy()
        new_scan_v_posed, new_face = build_new_mesh(scan_v_posed, dense_faces, self.vt, self.ft)
        new_scan_v_posed = new_scan_v_posed.astype(np.float32)
        return new_scan_v_posed, new_face