import os

# PyOpenGL chooses its platform backend from PYOPENGL_PLATFORM when ``OpenGL``
# is first imported, and importing pyrender triggers that import. The variable
# therefore has to be set before the third-party imports below, otherwise the
# headless EGL backend is never selected.
os.environ['PYOPENGL_PLATFORM'] = "egl"

import json
import pickle as pkl
import random
import argparse

import torch
from TADA import smplx
import imageio
import numpy as np
from tqdm import tqdm
from PIL import Image
from TADA.lib.common.remesh import subdivide_inorder
from TADA.lib.common.utils import SMPLXSeg
from TADA.lib.common.lbs import warp_points
from TADA.lib.common.obj import compute_normal
import trimesh
import pyrender
from shapely import geometry
import moviepy.editor as mpy


def build_new_mesh(v, f, vt, ft):
    """Re-index per-frame vertex positions into the texture-vertex index space.

    ``f`` indexes geometric vertices while ``ft`` indexes texture vertices
    ``vt``. For every face corner, the geometric vertex ``f[k]`` is copied into
    the texture-vertex slot ``ft[k]``. The resulting vertex array can then be
    used directly with ``ft`` as its face list.

    Args:
        v: vertex positions indexed as ``v[:, vertex_id]`` with a trailing
            coordinate axis of size 3 (the caller passes ``(frames, N, 3)``).
        f: geometric face indices, same shape as ``ft``. Either a NumPy array
            or a torch tensor; a tensor is moved to CPU first.
        vt: texture coordinates; only ``vt.shape[0]`` is used.
        ft: texture face indices.

    Returns:
        ``(new_v, f_new)``. ``new_v`` is a float64 array of shape
        ``(v.shape[0], vt.shape[0], 3)``; texture vertices not referenced by
        ``ft`` stay zero. ``f_new`` is ``ft`` itself.
    """
    if isinstance(f, torch.Tensor):
        # Indexing NumPy with device tensors element by element is very slow
        # and forces a host sync per element; convert once instead.
        f = f.detach().cpu().numpy()
    f_flat = np.asarray(f).reshape(-1)
    ft_flat = np.asarray(ft).reshape(-1)

    num_uv = vt.shape[0]
    # source[j] = geometric vertex feeding texture slot j (-1 = unused slot).
    # For a consistent UV layout each texture vertex belongs to exactly one
    # geometric vertex, so repeated writes to a slot carry the same value.
    source = np.full(num_uv, -1, dtype=np.int64)
    source[ft_flat] = f_flat
    used = source >= 0

    new_v = np.zeros((v.shape[0], num_uv, 3))
    new_v[:, used] = v[:, source[used]]

    # The texture faces already index the new vertex array directly.
    f_new = ft
    return new_v, f_new


class Animation:
    """Poses a textured, offset-augmented SMPL-X avatar over motion sequences."""

    def __init__(self, ckpt_path, workspace_dir, device="cuda"):
        """Load body-model assets from ``workspace_dir`` and the avatar checkpoint.

        Args:
            ckpt_path: sub-directory of ``<workspace_dir>/MESH`` containing
                ``params.pt`` and ``mesh_albedo.png``.
            workspace_dir: root directory holding ``init_body/data.npz``, the
                SMPL-X model files and the ``MESH`` checkpoints.
            device: torch device for the body model and loaded tensors.
        """
        self.device = device
        self.SMPLXSeg = SMPLXSeg(workspace_dir)

        # load data; the context manager closes the .npz file handle
        with np.load(os.path.join(workspace_dir, "init_body/data.npz")) as init_data:
            self.dense_faces = torch.as_tensor(init_data['dense_faces'], device=self.device)
            self.dense_lbs_weights = torch.as_tensor(init_data['dense_lbs_weights'], device=self.device)
            self.unique = init_data['unique']
            self.vt = init_data['vt']
            self.ft = init_data['ft']

        model_params = dict(
            model_path=os.path.join(workspace_dir, "smplx/SMPLX_NEUTRAL_2020.npz"),
            model_type='smplx',
            create_global_orient=True,
            create_body_pose=True,
            create_betas=True,
            create_left_hand_pose=True,
            create_right_hand_pose=True,
            create_jaw_pose=True,
            create_leye_pose=True,
            create_reye_pose=True,
            create_expression=True,
            create_transl=False,
            use_pca=False,
            flat_hand_mean=False,
            num_betas=300,
            num_expression_coeffs=100,
            num_pca_comps=12,
            dtype=torch.float32,
            batch_size=1,
        )
        self.body_model = smplx.create(**model_params).to(device=self.device)
        self.smplx_face = self.body_model.faces.astype(np.int32)

        ckpt_file = os.path.join(workspace_dir, "MESH", ckpt_path, "params.pt")
        albedo_path = os.path.join(workspace_dir, "MESH", ckpt_path, "mesh_albedo.png")
        self.load_ckpt_data(ckpt_file, albedo_path)

    def load_ckpt_data(self, ckpt_file, albedo_path):
        """Load shape/offset parameters and the albedo texture for the avatar.

        Sets ``expression``/``jaw_pose`` (``None`` when absent from the
        checkpoint), ``betas`` and ``v_offsets``. Offsets on the eyeball and
        hand vertices are zeroed. Also builds ``raw_albedo``, a ``(C, H, W)``
        tensor scaled to ``[0, 1]``, and ``trimesh_visual``, a
        ``TextureVisuals`` using V-flipped UVs.
        """
        model_data = torch.load(ckpt_file, map_location=self.device)
        self.expression = model_data.get("expression")
        self.jaw_pose = model_data.get("jaw_pose")
        self.betas = model_data['betas']
        self.v_offsets = model_data['v_offsets']
        self.v_offsets[self.SMPLXSeg.eyeball_ids] = 0.
        self.v_offsets[self.SMPLXSeg.hands_ids] = 0.

        # tex to trimesh texture: flip V so the UVs match trimesh's image convention
        vt = self.vt.copy()
        vt[:, 1] = 1 - vt[:, 1]
        albedo = Image.open(albedo_path)
        self.raw_albedo = torch.from_numpy(np.array(albedo))
        self.raw_albedo = self.raw_albedo / 255.0
        self.raw_albedo = self.raw_albedo.permute(2, 0, 1)
        self.trimesh_visual = trimesh.visual.TextureVisuals(
            uv=vt,
            image=albedo,
            material=trimesh.visual.texture.SimpleMaterial(
                image=albedo,
                diffuse=[255, 255, 255, 255],
                ambient=[255, 255, 255, 255],
                specular=[0, 0, 0, 255],
                glossiness=0)
        )

    def forward_mdm(self, motion):
        """Pose the dense avatar mesh for every frame of a motion sequence.

        Args:
            motion: either a mapping with ``"poses"`` (per-frame pose values,
                reshaped to ``(frames, joints, 3)``) and ``"trans"`` (NumPy
                per-frame translation), or a 2-D NumPy array whose last three
                columns are the translation and whose remaining columns are
                the pose. Joint 0 is used as the global orientation and joints
                1-21 as the body pose.

        Returns:
            ``(vertices, faces)``: float32 vertices of shape
            ``(frames, len(self.vt), 3)`` laid out in texture-vertex order (see
            ``build_new_mesh``), and the texture faces ``self.ft``. Vertices
            are translated relative to the first frame's translation.
        """
        if isinstance(motion, np.ndarray):
            translate = torch.from_numpy(motion[:, -3:])
            mdm_body_pose = motion[:, :-3]
        else:
            mdm_body_pose = motion["poses"]
            translate = torch.from_numpy(motion["trans"])
        mdm_body_pose = mdm_body_pose.reshape(mdm_body_pose.shape[0], -1, 3)
        translate = translate.to(self.device)

        scan_v_posed = []
        # The result is detached below, so no autograd graph is needed; this
        # avoids keeping one graph per frame alive until the final concat.
        with torch.no_grad():
            for pose, t in tqdm(zip(mdm_body_pose, translate), total=len(mdm_body_pose)):
                # Match the body model's float32 parameters regardless of motion dtype.
                body_pose = torch.as_tensor(pose[None, 1:22, :], device=self.device, dtype=torch.float32)
                global_orient = torch.as_tensor(pose[None, :1, :], device=self.device, dtype=torch.float32)
                output = self.body_model(
                    betas=self.betas,
                    global_orient=global_orient,
                    jaw_pose=self.jaw_pose,
                    body_pose=body_pose,
                    expression=self.expression,
                    return_verts=True
                )
                v_cano = output.v_posed[0]
                # re-mesh
                v_cano_dense = subdivide_inorder(v_cano, self.smplx_face[self.SMPLXSeg.remesh_mask], self.unique).squeeze(0)
                # add offsets along the vertex normals
                vn = compute_normal(v_cano_dense, self.dense_faces)[0]
                v_cano_dense += self.v_offsets * vn
                # do LBS
                v_posed_dense = warp_points(v_cano_dense, self.dense_lbs_weights, output.joints_transform[:, :55])
                # translate relative to the first frame
                v_posed_dense += t - translate[0]

                scan_v_posed.append(v_posed_dense)

        # assumes warp_points returns (1, N, 3) per frame, so cat yields (frames, N, 3)
        scan_v_posed = torch.cat(scan_v_posed).detach().cpu().numpy()
        new_scan_v_posed, new_face = build_new_mesh(scan_v_posed, self.dense_faces, self.vt, self.ft)
        new_scan_v_posed = new_scan_v_posed.astype(np.float32)
        return new_scan_v_posed, new_face