Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is | |
| # holder of all proprietary rights on this computer program. | |
| # You can only use this computer program if you have closed | |
| # a license agreement with MPG or you get the right to use the computer | |
| # program from someone who is authorized to grant you that right. | |
| # Any use of the computer program without a valid license is prohibited and | |
| # liable to prosecution. | |
| # | |
| # Copyright©2019 Max-Planck-Gesellschaft zur Förderung | |
| # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute | |
| # for Intelligent Systems. All rights reserved. | |
| # | |
| # Contact: [email protected] | |
| import os | |
| import logging | |
| from lib.common.render import query_color | |
| from lib.common.config import cfg | |
| from lib.dataset.mesh_util import ( | |
| load_checkpoint, | |
| update_mesh_shape_prior_losses, | |
| blend_rgb_norm, | |
| unwrap, | |
| remesh, | |
| tensor2variable, | |
| ) | |
| from lib.dataset.TestDataset import TestDataset | |
| from lib.net.local_affine import LocalAffine | |
| from pytorch3d.structures import Meshes | |
| from apps.ICON import ICON | |
| from termcolor import colored | |
| import numpy as np | |
| from PIL import Image | |
| import trimesh | |
| import numpy as np | |
| from tqdm import tqdm | |
| import torch | |
# trimesh logs a lot of warnings/info while building and exporting meshes;
# only surface genuine errors.
logging.getLogger(name="trimesh").setLevel(level=logging.ERROR)

# Let cuDNN benchmark convolution algorithms and reuse the fastest one.
torch.backends.cudnn.benchmark = True
def _leaf_param(value, device):
    """Return a fresh, grad-tracking leaf tensor on ``device`` holding ``value``.

    Same result as ``torch.tensor(value, device=device, requires_grad=True)``,
    but tensor inputs are copied via ``detach().clone()`` as the PyTorch docs
    recommend, instead of triggering the copy-construct warning.
    """
    if isinstance(value, torch.Tensor):
        leaf = value.detach().to(device).clone()
    else:
        leaf = torch.tensor(value, device=device)
    return leaf.requires_grad_(True)


def _fit_smpl(model, dataset, data, in_tensor, optimed, losses, n_iters,
              patience, hps_type, device):
    """Optimise body pose/shape/translation against predicted normals and silhouette.

    Args:
        model: ICON model; only ``model.netG.normal_filter`` is used (under no_grad).
        dataset: TestDataset providing ``smpl_model``, ``render_normal`` and ``render``.
        data: one dataset item (reads ``scale`` and, for pixie, ``exp``/``jaw_pose``/hand poses).
        in_tensor: network input dict; ``T_normal_F/B`` and ``normal_F/B`` are written into it.
        optimed: ``(pose, trans, betas, orient)`` leaf tensors, updated in place by SGD.
        losses: loss table; the ``normal`` and ``silhouette`` values are written into it.
        n_iters: number of optimisation iterations (the caller always passes >= 1).
        patience: ReduceLROnPlateau patience.
        hps_type: ``"pixie"`` selects the pixie body-model call signature.
        device: torch device used for all tensors.

    Returns:
        Body vertices from the last iteration (computed before that iteration's
        optimiser step), translated and multiplied by ``data["scale"]``.
    """
    optimed_pose, optimed_trans, optimed_betas, optimed_orient = optimed
    optimizer_smpl = torch.optim.SGD(
        [optimed_pose, optimed_trans, optimed_betas, optimed_orient],
        lr=1e-3,
        momentum=0.9,
    )
    # `verbose` is omitted: it defaults to off and is deprecated in recent PyTorch.
    scheduler_smpl = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer_smpl,
        mode="min",
        factor=0.5,
        min_lr=1e-5,
        patience=patience,
    )

    # Loop-invariant constants, built once instead of on every iteration.
    render_flip = torch.tensor([1.0, -1.0, -1.0]).to(device)
    bg_color = torch.Tensor([0.5, 0.5, 0.5]).unsqueeze(0).unsqueeze(0).to(device)

    smpl_verts = None
    loop_smpl = tqdm(range(n_iters))
    for _ in loop_smpl:
        optimizer_smpl.zero_grad()

        if hps_type != "pixie":
            smpl_out = dataset.smpl_model(
                betas=optimed_betas,
                body_pose=optimed_pose,
                global_orient=optimed_orient,
                pose2rot=False,
            )
            smpl_verts = (smpl_out.vertices + optimed_trans) * data["scale"]
        else:
            smpl_verts, _, _ = dataset.smpl_model(
                shape_params=optimed_betas,
                expression_params=tensor2variable(data["exp"], device),
                body_pose=optimed_pose,
                global_pose=optimed_orient,
                jaw_pose=tensor2variable(data["jaw_pose"], device),
                left_hand_pose=tensor2variable(data["left_hand_pose"], device),
                right_hand_pose=tensor2variable(data["right_hand_pose"], device),
            )
            smpl_verts = (smpl_verts + optimed_trans) * data["scale"]

        # Render normals of the current body estimate (normal, T_normal, image in [-1, 1]).
        in_tensor["T_normal_F"], in_tensor["T_normal_B"] = dataset.render_normal(
            smpl_verts * render_flip, in_tensor["smpl_faces"]
        )
        T_mask_F, T_mask_B = dataset.render.get_silhouette_image()

        # The normal network is not optimised here; gradients flow only through the render.
        with torch.no_grad():
            in_tensor["normal_F"], in_tensor["normal_B"] = model.netG.normal_filter(in_tensor)

        # Normal loss: rendered body normals vs. predicted normals.
        diff_F_smpl = torch.abs(in_tensor["T_normal_F"] - in_tensor["normal_F"])
        diff_B_smpl = torch.abs(in_tensor["T_normal_B"] - in_tensor["normal_B"])
        losses["normal"]["value"] = (diff_F_smpl + diff_B_smpl).mean()

        # Silhouette loss: body mask vs. foreground of the predicted normal maps
        # (pixels that differ from the 0.5 grey background).
        smpl_arr = torch.cat([T_mask_F, T_mask_B], dim=-1)[0]
        gt_arr = torch.cat(
            [in_tensor["normal_F"][0], in_tensor["normal_B"][0]], dim=2
        ).permute(1, 2, 0)
        gt_arr = ((gt_arr + 1.0) * 0.5).to(device)
        gt_arr = ((gt_arr - bg_color).sum(dim=-1) != 0.0).float()
        losses["silhouette"]["value"] = torch.abs(smpl_arr - gt_arr).mean()

        # Weighted sum of the body losses.
        smpl_loss = 0.0
        pbar_desc = "Body Fitting --- "
        for k in ("normal", "silhouette"):
            weighted = losses[k]["value"] * losses[k]["weight"]
            pbar_desc += f"{k}: {weighted:.3f} | "
            smpl_loss += weighted
        pbar_desc += f"Total: {smpl_loss:.3f}"
        loop_smpl.set_description(pbar_desc)

        smpl_loss.backward()
        optimizer_smpl.step()
        scheduler_smpl.step(smpl_loss)

    return smpl_verts


def _refine_cloth(dataset, in_tensor, verts_refine, faces_refine, losses,
                  n_iters, patience, device):
    """Refine the remeshed clothed mesh with a LocalAffine vertex deformation.

    Writes ``P_normal_F/B`` into ``in_tensor`` and the cloth-stage values into
    ``losses``.

    Returns:
        The refined, vertex-coloured ``trimesh.Trimesh``, or ``None`` when
        ``n_iters <= 0`` (no refinement performed).
    """
    if n_iters <= 0:
        return None

    mesh_pr = Meshes(verts_refine, faces_refine).to(device)
    local_affine_model = LocalAffine(
        mesh_pr.verts_padded().shape[1],
        mesh_pr.verts_padded().shape[0],
        mesh_pr.edges_packed(),
    ).to(device)
    optimizer_cloth = torch.optim.Adam(
        [{"params": local_affine_model.parameters()}], lr=1e-4, amsgrad=True
    )
    scheduler_cloth = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer_cloth,
        mode="min",
        factor=0.1,
        min_lr=1e-5,
        patience=patience,
    )

    base_verts = verts_refine.to(device)  # loop-invariant input to the deformer
    loop_cloth = tqdm(range(n_iters))
    for _ in loop_cloth:
        optimizer_cloth.zero_grad()

        deformed_verts, stiffness, rigid = local_affine_model(base_verts, return_stiff=True)
        mesh_pr = mesh_pr.update_padded(deformed_verts)

        # Losses for laplacian, edge, normal consistency.
        update_mesh_shape_prior_losses(mesh_pr, losses)

        in_tensor["P_normal_F"], in_tensor["P_normal_B"] = dataset.render_normal(
            mesh_pr.verts_padded(), mesh_pr.faces_padded()
        )
        diff_F_cloth = torch.abs(in_tensor["P_normal_F"] - in_tensor["normal_F"])
        diff_B_cloth = torch.abs(in_tensor["P_normal_B"] - in_tensor["normal_B"])
        losses["cloth"]["value"] = (diff_F_cloth + diff_B_cloth).mean()
        losses["stiffness"]["value"] = torch.mean(stiffness)
        losses["rigid"]["value"] = torch.mean(rigid)

        # Weighted sum of every non-body loss that has a positive weight.
        cloth_loss = torch.tensor(0.0, requires_grad=True).to(device)
        pbar_desc = "Cloth Refinement --- "
        for k in losses.keys():
            if k not in ["normal", "silhouette"] and losses[k]["weight"] > 0.0:
                cloth_loss = cloth_loss + losses[k]["value"] * losses[k]["weight"]
                pbar_desc += f"{k}:{losses[k]['value']* losses[k]['weight']:.5f} | "
        pbar_desc += f"Total: {cloth_loss:.5f}"
        loop_cloth.set_description(pbar_desc)

        cloth_loss.backward(retain_graph=True)
        optimizer_cloth.step()
        scheduler_cloth.step(cloth_loss)

    final = trimesh.Trimesh(
        mesh_pr.verts_packed().detach().squeeze(0).cpu(),
        mesh_pr.faces_packed().detach().squeeze(0).cpu(),
        process=False,
        maintains_order=True,
    )
    final.visual.vertex_colors = query_color(
        mesh_pr.verts_packed().detach().squeeze(0).cpu(),
        mesh_pr.faces_packed().detach().squeeze(0).cpu(),
        in_tensor["image"],
        device=device,
    )
    return final


def _generate_model_impl(in_path, model_type):
    """Do the actual work of :func:`generate_model` (see its docstring)."""
    config_dict = {
        "loop_smpl": 50,
        "loop_cloth": 100,
        "patience": 5,
        "vis_freq": 10,
        "out_dir": "./results",
        "hps_type": "pymaf",
        "config": f"./configs/{model_type}.yaml",
    }

    # `cfg` is a module-level config that is freeze()-d below. Unfreeze it first,
    # otherwise every call after the first fails when merging into the frozen node.
    cfg.defrost()
    cfg.merge_from_file(config_dict["config"])
    cfg.merge_from_file("./lib/pymaf/configs/pymaf_config.yaml")
    os.makedirs(config_dict["out_dir"], exist_ok=True)
    cfg.merge_from_list([
        "test_gpus", [0],
        "mcube_res", 256,
        "clean_mesh", True,
    ])
    cfg.freeze()

    # Note: this has no effect if CUDA was already initialised in this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    device = torch.device("cuda:0")

    # Load model and dataloader.
    model = ICON(cfg)
    model = load_checkpoint(model, cfg)

    dataset_param = {
        "image_path": in_path,
        "seg_dir": None,
        "has_det": True,  # w/ or w/o detection
        "hps_type": config_dict["hps_type"],  # pymaf/pare/pixie
    }
    if config_dict["hps_type"] == "pixie" and "pamir" in config_dict["config"]:
        print(colored("PIXIE isn't compatible with PaMIR, thus switch to PyMAF", "red"))
        dataset_param["hps_type"] = "pymaf"

    dataset = TestDataset(dataset_param, device)
    print(colored(f"Dataset Size: {len(dataset)}", "green"))
    if len(dataset) == 0:
        raise ValueError(f"No input found for image_path={in_path!r}")

    # Output folders: refinement/ (optimisation visualisation), vid/ (self-rotation
    # video), png/ (rendered images / overlap), obj/ (meshes and SMPL params).
    out_root = os.path.join(config_dict["out_dir"], cfg.name)
    for sub in ("refinement", "vid", "png", "obj"):
        os.makedirs(os.path.join(out_root, sub), exist_ok=True)
    obj_dir = os.path.join(out_root, "obj")

    paths = None
    pbar = tqdm(dataset)
    for data in pbar:
        name = data["name"]
        pbar.set_description(f"{name}")

        smpl_path = os.path.join(obj_dir, f"{name}_smpl.glb")
        smpl_npy_path = os.path.join(obj_dir, f"{name}_smpl.npy")
        recon_obj_path = os.path.join(obj_dir, f"{name}_recon.obj")
        recon_path = os.path.join(obj_dir, f"{name}_recon.glb")
        refine_path = os.path.join(obj_dir, f"{name}_refine.glb")
        video_path = os.path.join(out_root, "vid", f"{name}_cloth.mp4")
        overlap_path = os.path.join(out_root, "png", f"{name}_overlap.png")

        in_tensor = {"smpl_faces": data["smpl_faces"], "image": data["image"]}

        # Optimisable body parameters (shapes as annotated upstream).
        optimed_pose = _leaf_param(data["body_pose"], device)  # [1,23,3,3]
        optimed_trans = _leaf_param(data["trans"], device)  # [3]
        optimed_betas = _leaf_param(data["betas"], device)  # [1,10]
        optimed_orient = _leaf_param(data["global_orient"], device)  # [1,1,3,3]

        losses = {
            # Cloth: Normal_recon - Normal_pred
            "cloth": {"weight": 1e1, "value": 0.0},
            # Cloth: [RT]_v1 - [RT]_v2 (v1-edge-v2)
            "stiffness": {"weight": 1e5, "value": 0.0},
            # Cloth: det(R) = 1
            "rigid": {"weight": 1e5, "value": 0.0},
            # Cloth: edge length
            "edge": {"weight": 0, "value": 0.0},
            # Cloth: normal consistency
            "nc": {"weight": 0, "value": 0.0},
            # Cloth: laplacian smoothing
            "laplacian": {"weight": 1e2, "value": 0.0},
            # Body: Normal_pred - Normal_smpl
            "normal": {"weight": 1e0, "value": 0.0},
            # Body: Silhouette_pred - Silhouette_smpl
            "silhouette": {"weight": 1e0, "value": 0.0},
        }

        # ---- body (SMPL) fitting --------------------------------------------
        n_smpl_iters = config_dict["loop_smpl"] if cfg.net.prior_type != "pifu" else 1
        smpl_verts = _fit_smpl(
            model, dataset, data, in_tensor,
            (optimed_pose, optimed_trans, optimed_betas, optimed_orient),
            losses, n_smpl_iters, config_dict["patience"],
            dataset_param["hps_type"], device,
        )
        in_tensor["smpl_verts"] = smpl_verts * torch.tensor([1.0, 1.0, -1.0]).to(device)

        # ---- overlap image: original image next to the blended predicted normal --
        norm_pred = (
            ((in_tensor["normal_F"][0].permute(1, 2, 0) + 1.0) * 255.0 / 2.0)
            .detach()
            .cpu()
            .numpy()
            .astype(np.uint8)
        )
        norm_orig = unwrap(norm_pred, data)
        mask_orig = unwrap(
            np.repeat(
                data["mask"].permute(1, 2, 0).detach().cpu().numpy(), 3, axis=2
            ).astype(np.uint8),
            data,
        )
        rgb_norm = blend_rgb_norm(data["ori_image"], norm_orig, mask_orig)
        Image.fromarray(
            np.concatenate([data["ori_image"].astype(np.uint8), rgb_norm], axis=1)
        ).save(overlap_path)

        # ---- fitted body mesh and parameters ---------------------------------
        smpl_obj = trimesh.Trimesh(
            in_tensor["smpl_verts"].detach().cpu()[0] * torch.tensor([1.0, -1.0, 1.0]),
            in_tensor["smpl_faces"].detach().cpu()[0],
            process=False,
            maintains_order=True,
        )
        smpl_obj.export(smpl_path)
        # Store detached CPU copies so the pickle can be loaded without a GPU.
        smpl_info = {
            "betas": optimed_betas.detach().cpu(),
            "pose": optimed_pose.detach().cpu(),
            "orient": optimed_orient.detach().cpu(),
            "trans": optimed_trans.detach().cpu(),
        }
        np.save(smpl_npy_path, smpl_info, allow_pickle=True)

        # ---- clothed reconstruction -------------------------------------------
        in_tensor.update(
            dataset.compute_vis_cmap(in_tensor["smpl_verts"][0], in_tensor["smpl_faces"][0])
        )
        if cfg.net.prior_type == "pamir":
            in_tensor.update(
                dataset.compute_voxel_verts(
                    optimed_pose,
                    optimed_orient,
                    optimed_betas,
                    optimed_trans,
                    data["scale"],
                )
            )
        with torch.no_grad():
            verts_pr, faces_pr, _ = model.test_single(in_tensor)
        recon_obj = trimesh.Trimesh(verts_pr, faces_pr, process=False, maintains_order=True)
        recon_obj.export(recon_obj_path)
        recon_obj.export(recon_path)

        # ---- isotropic explicit remeshing + cloth refinement ------------------
        verts_refine, faces_refine = remesh(recon_obj_path, 0.5, device)
        final = _refine_cloth(
            dataset, in_tensor, verts_refine, faces_refine, losses,
            config_dict["loop_cloth"], config_dict["patience"], device,
        )
        if final is not None:
            final.export(refine_path)
            verts_lst = [verts_pr, final.vertices]
            faces_lst = [faces_pr, final.faces]
        else:
            # No refined mesh was written; point callers at the reconstruction.
            refine_path = recon_path
            verts_lst = [verts_pr]
            faces_lst = [faces_pr]

        # ---- self-rotation video (always exported) ----------------------------
        dataset.render.load_meshes(verts_lst, faces_lst)
        dataset.render.get_rendered_video([data["ori_image"], rgb_norm], video_path)

        paths = [smpl_path, smpl_path, smpl_npy_path, recon_path, recon_path,
                 refine_path, refine_path, video_path, overlap_path]

    return paths


def generate_model(in_path, model_type):
    """Reconstruct a clothed human with ICON and export meshes, a video and an overlay.

    For each item of ``TestDataset(image_path=in_path)``, body parameters are
    fitted, a clothed mesh is reconstructed and refined, and the results are
    written under ``./results/<cfg.name>/{obj,png,vid}``.

    Args:
        in_path: forwarded to ``TestDataset`` as ``image_path``.
        model_type: selects the config file ``./configs/{model_type}.yaml``.

    Returns:
        Nine file paths for the LAST processed item:
        ``[smpl.glb, smpl.glb, smpl.npy, recon.glb, recon.glb, refine.glb,
        refine.glb, cloth.mp4, overlap.png]`` (duplicates kept for
        compatibility with the original return layout). If cloth refinement
        is disabled, the refine entries point to the recon ``.glb``.

    Raises:
        ValueError: if the dataset built from ``in_path`` is empty.
    """
    import gc

    torch.cuda.empty_cache()
    try:
        return _generate_model_impl(in_path, model_type)
    finally:
        # The model, dataset, optimisers and tensors were locals of the helper and
        # are unreferenced now; collect them before releasing cached GPU memory.
        gc.collect()
        torch.cuda.empty_cache()