|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import gc |
|
|
|
import logging |
|
from lib.common.config import cfg |
|
from lib.dataset.mesh_util import ( |
|
load_checkpoint, |
|
update_mesh_shape_prior_losses, |
|
blend_rgb_norm, |
|
unwrap, |
|
remesh, |
|
tensor2variable, |
|
rot6d_to_rotmat |
|
) |
|
|
|
from lib.dataset.TestDataset import TestDataset |
|
from lib.common.render import query_color |
|
from lib.net.local_affine import LocalAffine |
|
from pytorch3d.structures import Meshes |
|
from apps.ICON import ICON |
|
|
|
from termcolor import colored |
|
import numpy as np |
|
from PIL import Image |
|
import trimesh |
|
import numpy as np |
|
from tqdm import tqdm |
|
|
|
import torch |
|
# Let cuDNN benchmark and cache the fastest convolution algorithms. This only pays
# off when input shapes stay constant across calls — presumably true for this
# inference pipeline; confirm.
torch.backends.cudnn.benchmark = True

# Suppress trimesh log messages below ERROR level.
logging.getLogger("trimesh").setLevel(logging.ERROR)
|
|
|
|
|
def _make_losses():
    """Return a fresh table of loss terms, each with a ``weight`` and a running ``value``.

    "normal" and "silhouette" drive body fitting; every other term with a
    positive weight drives cloth refinement.
    """
    return {
        "cloth": {"weight": 1e1, "value": 0.0},
        "stiffness": {"weight": 1e5, "value": 0.0},
        "rigid": {"weight": 1e5, "value": 0.0},
        "edge": {"weight": 0, "value": 0.0},
        "nc": {"weight": 0, "value": 0.0},
        "laplacian": {"weight": 1e2, "value": 0.0},
        "normal": {"weight": 1e0, "value": 0.0},
        "silhouette": {"weight": 1e0, "value": 0.0},
    }


def _normal_to_uint8(normal):
    """Convert a (C, H, W) normal-map tensor to an (H, W, C) uint8 numpy image.

    Uses the mapping (n + 1) * 255 / 2, i.e. assumes values lie in [-1, 1].
    """
    return (
        ((normal.permute(1, 2, 0) + 1.0) * 255.0 / 2.0)
        .detach()
        .cpu()
        .numpy()
        .astype(np.uint8)
    )


def _fit_smpl(data, in_tensor, dataset, model, device, hps_type, num_iters, patience, losses):
    """Optimise the body-model parameters against the network's predicted normals.

    Each iteration renders the posed body's front/back normals and silhouettes.
    The loss is the L1 difference against the normal maps from
    ``model.netG.normal_filter`` plus a silhouette term.

    Side effects: writes "T_normal_F/B", "normal_F/B" and "smpl_verts" into
    ``in_tensor``, and the "normal"/"silhouette" values into ``losses``.

    Returns:
        dict with the optimised leaf tensors ("pose", "trans", "betas",
        "orient") and the rotation matrices of the last iteration
        ("pose_mat", "orient_mat").
    """
    optimed_pose = torch.tensor(data["body_pose"], device=device, requires_grad=True)
    optimed_trans = torch.tensor(data["trans"], device=device, requires_grad=True)
    optimed_betas = torch.tensor(data["betas"], device=device, requires_grad=True)
    optimed_orient = torch.tensor(data["global_orient"], device=device, requires_grad=True)

    optimizer_smpl = torch.optim.Adam(
        [optimed_pose, optimed_trans, optimed_betas, optimed_orient],
        lr=1e-3,
        amsgrad=True,
    )
    scheduler_smpl = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer_smpl,
        mode="min",
        factor=0.5,
        verbose=0,
        min_lr=1e-5,
        patience=patience,
    )

    # Loop-invariant constants, hoisted out of the optimisation loop.
    flip_yz = torch.tensor([1.0, -1.0, -1.0]).to(device)  # axis flip applied before rendering
    flip_z = torch.tensor([1.0, 1.0, -1.0]).to(device)  # axis flip for stored smpl_verts
    bg_color = torch.Tensor([0.5, 0.5, 0.5]).unsqueeze(0).unsqueeze(0).to(device)

    optimed_pose_mat = optimed_orient_mat = None
    loop_smpl = tqdm(range(num_iters))
    for _ in loop_smpl:
        optimizer_smpl.zero_grad()

        # Pose parameters are stored as 6D rotations; convert to matrices.
        optimed_orient_mat = rot6d_to_rotmat(optimed_orient.view(-1, 6)).unsqueeze(0)
        optimed_pose_mat = rot6d_to_rotmat(optimed_pose.view(-1, 6)).unsqueeze(0)

        if hps_type != "pixie":
            smpl_out = dataset.smpl_model(
                betas=optimed_betas,
                body_pose=optimed_pose_mat,
                global_orient=optimed_orient_mat,
                pose2rot=False,
            )
            smpl_verts = (smpl_out.vertices + optimed_trans) * data["scale"]
        else:
            smpl_verts, _, _ = dataset.smpl_model(
                shape_params=optimed_betas,
                expression_params=tensor2variable(data["exp"], device),
                body_pose=optimed_pose_mat,
                global_pose=optimed_orient_mat,
                jaw_pose=tensor2variable(data["jaw_pose"], device),
                left_hand_pose=tensor2variable(data["left_hand_pose"], device),
                right_hand_pose=tensor2variable(data["right_hand_pose"], device),
            )
            smpl_verts = (smpl_verts + optimed_trans) * data["scale"]

        in_tensor["T_normal_F"], in_tensor["T_normal_B"] = dataset.render_normal(
            smpl_verts * flip_yz, in_tensor["smpl_faces"]
        )
        T_mask_F, T_mask_B = dataset.render.get_silhouette_image()

        # The normal predictor is conditioned on the current body render but is
        # not itself optimised, so no gradients are needed here.
        with torch.no_grad():
            in_tensor["normal_F"], in_tensor["normal_B"] = model.netG.normal_filter(in_tensor)

        diff_F_smpl = torch.abs(in_tensor["T_normal_F"] - in_tensor["normal_F"])
        diff_B_smpl = torch.abs(in_tensor["T_normal_B"] - in_tensor["normal_B"])
        losses["normal"]["value"] = (diff_F_smpl + diff_B_smpl).mean()

        # Silhouette loss: pixels whose predicted normal differs from the grey
        # background colour count as foreground.
        smpl_arr = torch.cat([T_mask_F, T_mask_B], dim=-1)[0]
        gt_arr = torch.cat(
            [in_tensor["normal_F"][0], in_tensor["normal_B"][0]], dim=2
        ).permute(1, 2, 0)
        gt_arr = ((gt_arr + 1.0) * 0.5).to(device)
        gt_arr = ((gt_arr - bg_color).sum(dim=-1) != 0.0).float()
        losses["silhouette"]["value"] = torch.abs(smpl_arr - gt_arr).mean()

        smpl_loss = 0.0
        pbar_desc = "Body Fitting --- "
        for k in ["normal", "silhouette"]:
            weighted = losses[k]["value"] * losses[k]["weight"]
            pbar_desc += f"{k}: {weighted:.3f} | "
            smpl_loss += weighted
        pbar_desc += f"Total: {smpl_loss:.3f}"
        loop_smpl.set_description(pbar_desc)

        smpl_loss.backward()
        optimizer_smpl.step()
        scheduler_smpl.step(smpl_loss)
        in_tensor["smpl_verts"] = smpl_verts * flip_z

    return {
        "pose": optimed_pose,
        "trans": optimed_trans,
        "betas": optimed_betas,
        "orient": optimed_orient,
        "pose_mat": optimed_pose_mat,
        "orient_mat": optimed_orient_mat,
    }


def _refine_cloth(recon_path, in_tensor, dataset, device, num_iters, patience, losses):
    """Remesh the reconstruction at ``recon_path`` and refine it against the predicted normals.

    A ``LocalAffine`` deformation is optimised on the cloth-normal term, the
    affine stiffness/rigidity terms, and whichever mesh-prior terms
    ``update_mesh_shape_prior_losses`` writes into ``losses``.

    Returns:
        The (possibly refined) pytorch3d ``Meshes``. When ``num_iters`` <= 0 this
        is the remeshed reconstruction without refinement.
    """
    verts_refine, faces_refine = remesh(recon_path, 0.5, device)
    mesh_pr = Meshes(verts_refine, faces_refine).to(device)
    if num_iters <= 0:
        return mesh_pr

    local_affine_model = LocalAffine(
        mesh_pr.verts_padded().shape[1],
        mesh_pr.verts_padded().shape[0],
        mesh_pr.edges_packed(),
    ).to(device)
    optimizer_cloth = torch.optim.Adam(
        [{"params": local_affine_model.parameters()}], lr=1e-4, amsgrad=True
    )
    scheduler_cloth = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer_cloth,
        mode="min",
        factor=0.1,
        verbose=0,
        min_lr=1e-5,
        patience=patience,
    )

    # The deformation is always applied to the original remeshed vertices.
    source_verts = verts_refine.to(device)

    loop_cloth = tqdm(range(num_iters))
    for _ in loop_cloth:
        optimizer_cloth.zero_grad()

        deformed_verts, stiffness, rigid = local_affine_model(source_verts, return_stiff=True)
        mesh_pr = mesh_pr.update_padded(deformed_verts)

        update_mesh_shape_prior_losses(mesh_pr, losses)

        in_tensor["P_normal_F"], in_tensor["P_normal_B"] = dataset.render_normal(
            mesh_pr.verts_padded(), mesh_pr.faces_padded()
        )
        diff_F_cloth = torch.abs(in_tensor["P_normal_F"] - in_tensor["normal_F"])
        diff_B_cloth = torch.abs(in_tensor["P_normal_B"] - in_tensor["normal_B"])

        losses["cloth"]["value"] = (diff_F_cloth + diff_B_cloth).mean()
        losses["stiffness"]["value"] = torch.mean(stiffness)
        losses["rigid"]["value"] = torch.mean(rigid)

        # Sum every positively weighted term except the body-fitting ones.
        cloth_loss = torch.tensor(0.0, requires_grad=True).to(device)
        pbar_desc = "Cloth Refinement --- "
        for k, term in losses.items():
            if k not in ["normal", "silhouette"] and term["weight"] > 0.0:
                cloth_loss = cloth_loss + term["value"] * term["weight"]
                pbar_desc += f"{k}:{term['value'] * term['weight']:.5f} | "
        pbar_desc += f"Total: {cloth_loss:.5f}"
        loop_cloth.set_description(pbar_desc)

        cloth_loss.backward()
        optimizer_cloth.step()
        scheduler_cloth.step(cloth_loss)

    return mesh_pr


def _process_sample(data, dataset, model, device, config_dict, hps_type):
    """Reconstruct one dataset sample and write all of its output files.

    Writes, under ``<out_dir>/<cfg.name>/``:
      - png/<name>_overlap.png;
      - obj/<name>_smpl.{obj,glb,npy};
      - obj/<name>_recon.obj;
      - obj/<name>_refine.{obj,glb};
      - vid/<name>_cloth.mp4.

    All tensors created here become unreachable when this function returns,
    so the caller can reclaim GPU memory.

    Returns:
        [smpl_glb, smpl_obj, smpl_npy, refine_glb, refine_obj, video, video, overlap]
    """
    out_root = os.path.join(config_dict["out_dir"], cfg.name)
    for sub_dir in ("refinement", "vid", "png", "obj"):
        os.makedirs(os.path.join(out_root, sub_dir), exist_ok=True)

    name = data["name"]
    obj_dir = os.path.join(out_root, "obj")
    smpl_obj_path = os.path.join(obj_dir, f"{name}_smpl.obj")
    smpl_glb_path = os.path.join(obj_dir, f"{name}_smpl.glb")
    smpl_npy_path = os.path.join(obj_dir, f"{name}_smpl.npy")
    recon_path = os.path.join(obj_dir, f"{name}_recon.obj")
    refine_obj_path = os.path.join(obj_dir, f"{name}_refine.obj")
    refine_glb_path = os.path.join(obj_dir, f"{name}_refine.glb")
    video_path = os.path.join(out_root, "vid", f"{name}_cloth.mp4")
    overlap_path = os.path.join(out_root, "png", f"{name}_overlap.png")

    losses = _make_losses()
    in_tensor = {"smpl_faces": data["smpl_faces"], "image": data["image"]}

    # ---- body fitting ----
    smpl_params = _fit_smpl(
        data, in_tensor, dataset, model, device, hps_type,
        config_dict["loop_smpl"], config_dict["patience"], losses,
    )

    # ---- predicted normals blended onto the original image ----
    norm_orig_F = unwrap(_normal_to_uint8(in_tensor["normal_F"][0]), data)
    norm_orig_B = unwrap(_normal_to_uint8(in_tensor["normal_B"][0]), data)
    mask_orig = unwrap(
        np.repeat(
            data["mask"].permute(1, 2, 0).detach().cpu().numpy(), 3, axis=2
        ).astype(np.uint8),
        data,
    )
    rgb_norm_F = blend_rgb_norm(data["ori_image"], norm_orig_F, mask_orig)
    rgb_norm_B = blend_rgb_norm(data["ori_image"], norm_orig_B, mask_orig)
    Image.fromarray(
        np.concatenate(
            [data["ori_image"].astype(np.uint8), rgb_norm_F, rgb_norm_B], axis=1
        )
    ).save(overlap_path)

    # ---- fitted body mesh + parameters ----
    smpl_obj = trimesh.Trimesh(
        in_tensor["smpl_verts"].detach().cpu()[0] * torch.tensor([1.0, -1.0, 1.0]),
        in_tensor["smpl_faces"].detach().cpu()[0],
        process=False,
        maintains_order=True,
    )
    smpl_obj.visual.vertex_colors = (smpl_obj.vertex_normals + 1.0) * 255.0 * 0.5
    smpl_obj.export(smpl_obj_path)
    smpl_obj.export(smpl_glb_path)

    smpl_info = {
        "betas": smpl_params["betas"],
        "pose": smpl_params["pose_mat"],
        "orient": smpl_params["orient_mat"],
        "trans": smpl_params["trans"],
    }
    np.save(smpl_npy_path, smpl_info, allow_pickle=True)

    # ---- implicit reconstruction ----
    in_tensor.update(
        dataset.compute_vis_cmap(in_tensor["smpl_verts"][0], in_tensor["smpl_faces"][0])
    )
    if cfg.net.prior_type == "pamir":
        in_tensor.update(
            dataset.compute_voxel_verts(
                smpl_params["pose"],
                smpl_params["orient"],
                smpl_params["betas"],
                smpl_params["trans"],
                data["scale"],
            )
        )

    with torch.no_grad():
        verts_pr, faces_pr, _ = model.test_single(in_tensor)

    recon_obj = trimesh.Trimesh(verts_pr, faces_pr, process=False, maintains_order=True)
    recon_obj.visual.vertex_colors = (recon_obj.vertex_normals + 1.0) * 255.0 * 0.5
    recon_obj.export(recon_path)

    # ---- cloth refinement ----
    mesh_pr = _refine_cloth(
        recon_path, in_tensor, dataset, device,
        config_dict["loop_cloth"], config_dict["patience"], losses,
    )

    # Built unconditionally: previously this mesh existed only when at least
    # one refinement iteration ran, so loop_cloth == 0 crashed on export.
    final_verts = mesh_pr.verts_packed().detach().squeeze(0).cpu()
    final_faces = mesh_pr.faces_packed().detach().squeeze(0).cpu()
    final = trimesh.Trimesh(final_verts, final_faces, process=False, maintains_order=True)

    tex_colors = query_color(final_verts, final_faces, in_tensor["image"], device=device)
    norm_colors = (mesh_pr.verts_normals_padded().squeeze(0).detach().cpu() + 1.0) * 0.5 * 255.0

    final.visual.vertex_colors = tex_colors
    final.export(refine_obj_path)
    final.visual.vertex_colors = norm_colors
    final.export(refine_glb_path)

    # ---- turntable video of body + refined mesh ----
    dataset.render.load_meshes(
        [smpl_obj.vertices, final.vertices], [smpl_obj.faces, final.faces]
    )
    dataset.render.get_rendered_video(
        [data["ori_image"], rgb_norm_F, rgb_norm_B], video_path
    )

    return [smpl_glb_path, smpl_obj_path, smpl_npy_path,
            refine_glb_path, refine_obj_path,
            video_path, video_path, overlap_path]


def generate_model(in_path, model_type):
    """Reconstruct a clothed human mesh for the image(s) at ``in_path``.

    Pipeline per sample:
      1. fit the body model to the network-predicted normals/silhouettes;
      2. run the implicit reconstruction (``model.test_single``);
      3. remesh and refine it with a local affine deformation;
      4. export meshes, body parameters, an overlay PNG and a rendered video.

    Runs on ``cuda:0``.

    Args:
        in_path: passed to ``TestDataset`` as ``image_path``.
        model_type: "ICON" selects ``./configs/icon-filter.yaml``; any other
            value selects ``./configs/<model_type.lower()>.yaml``.

    Returns:
        For the LAST sample processed, the list
        [smpl_glb, smpl_obj, smpl_npy, refine_glb, refine_obj, video, video, overlap]
        of output file paths. The video path is repeated — presumably to feed
        two output widgets; confirm with the caller.

    Raises:
        ValueError: if the dataset yields no samples.
    """
    torch.cuda.empty_cache()

    model_type = "icon-filter" if model_type == "ICON" else model_type.lower()

    config_dict = {
        "loop_smpl": 100,
        "loop_cloth": 200,
        "patience": 5,
        "out_dir": "./results",
        "hps_type": "pymaf",
        "config": f"./configs/{model_type}.yaml",
    }

    # cfg is module-global and was frozen by any previous call; unfreeze it
    # explicitly before re-merging this call's configuration.
    cfg.defrost()
    cfg.merge_from_file(config_dict["config"])
    cfg.merge_from_file("./lib/pymaf/configs/pymaf_config.yaml")

    os.makedirs(config_dict["out_dir"], exist_ok=True)

    cfg.merge_from_list(["test_gpus", [0], "mcube_res", 256, "clean_mesh", True])
    cfg.freeze()

    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    device = torch.device("cuda:0")

    model = load_checkpoint(ICON(cfg), cfg)

    dataset_param = {
        "image_path": in_path,
        "seg_dir": None,
        "has_det": True,
        "hps_type": "pymaf",
    }
    if config_dict["hps_type"] == "pixie" and "pamir" in config_dict["config"]:
        print(colored("PIXIE isn't compatible with PaMIR, thus switch to PyMAF", "red"))
        dataset_param["hps_type"] = "pymaf"

    dataset = TestDataset(dataset_param, device)
    print(colored(f"Dataset Size: {len(dataset)}", "green"))

    outputs = None
    pbar = tqdm(dataset)
    for data in pbar:
        pbar.set_description(f"{data['name']}")
        outputs = _process_sample(
            data, dataset, model, device, config_dict, dataset_param["hps_type"]
        )
        # Per-sample tensors were local to _process_sample and are now
        # unreachable, so this actually frees GPU memory.
        gc.collect()
        torch.cuda.empty_cache()

    # Drop the network and dataset before returning. The former
    # `del locals()[name]` loop could not unbind function locals.
    del model, dataset, pbar
    gc.collect()
    torch.cuda.empty_cache()

    if outputs is None:
        raise ValueError(f"No input samples found at {in_path!r}")
    return outputs
|
|