Spaces:
Running
on
Zero
Running
on
Zero
| # -*- coding: utf-8 -*- | |
| # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is | |
| # holder of all proprietary rights on this computer program. | |
| # You can only use this computer program if you have closed | |
| # a license agreement with MPG or you get the right to use the computer | |
| # program from someone who is authorized to grant you that right. | |
| # Any use of the computer program without a valid license is prohibited and | |
| # liable to prosecution. | |
| # | |
| # Copyright©2019 Max-Planck-Gesellschaft zur Förderung | |
| # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute | |
| # for Intelligent Systems. All rights reserved. | |
| # | |
| # Contact: ps-license@tue.mpg.de | |
| from lib.dataset.mesh_util import projection | |
| from lib.common.render import Render | |
| import numpy as np | |
| import torch | |
| import os.path as osp | |
| from torchvision.utils import make_grid | |
| from pytorch3d.io import IO | |
| from pytorch3d.ops import sample_points_from_meshes | |
| from pytorch3d.loss.point_mesh_distance import _PointFaceDistance | |
| from pytorch3d.structures import Pointclouds | |
| from PIL import Image | |
def point_mesh_distance(meshes, pcls):
    """Return the batch-averaged point-to-mesh distance.

    Every point of ``pcls`` is measured against the faces of the mesh at the
    same batch index. Per-point values are square-rooted, averaged within
    each cloud (each point is weighted by one over its cloud's point count),
    and the per-cloud means are then averaged over the batch.

    Args:
        meshes: Batch of meshes exposing the packed accessors used below
            (``verts_packed``, ``faces_packed``,
            ``mesh_to_faces_packed_first_idx``).
        pcls: Batch of point clouds with the same length as ``meshes``.

    Returns:
        A scalar tensor holding the mean distance.

    Raises:
        ValueError: If the two batches differ in length.
    """
    batch_size = len(meshes)
    if batch_size != len(pcls):
        raise ValueError("meshes and pointclouds must be equal sized batches")

    # Faces of the whole batch as explicit triangles: (T, 3, 3).
    triangles = meshes.verts_packed()[meshes.faces_packed()]
    first_tri_idx = meshes.mesh_to_faces_packed_first_idx()

    # Points of the whole batch in packed form: (P, 3).
    packed_points = pcls.points_packed()
    first_point_idx = pcls.cloud_to_packed_first_idx()
    points_per_cloud = pcls.num_points_per_cloud()  # (N,)
    largest_cloud = points_per_cloud.max().item()

    # One value per packed point, shape (P,). The result is square-rooted
    # below, so it is presumably a squared distance -- TODO confirm.
    # NOTE(review): the trailing 5e-3 looks like PyTorch3D's minimum
    # triangle area argument; verify against the installed version.
    raw_dist = _PointFaceDistance.apply(
        packed_points,
        first_point_idx,
        triangles,
        first_tri_idx,
        largest_cloud,
        5e-3,
    )

    # Weight each point by the inverse size of the cloud it belongs to, so
    # that summing yields one mean per cloud.
    cloud_of_point = pcls.packed_to_cloud_idx()  # (sum(P_i),)
    inv_cloud_size = 1.0 / points_per_cloud.gather(0, cloud_of_point).float()

    weighted = torch.sqrt(raw_dist) * inv_cloud_size
    return weighted.sum() / batch_size
class Evaluator:
    """Compare a reconstructed mesh against its ground truth.

    Call :meth:`set_mesh` first; the metric and export methods read the
    ``src_mesh`` (prediction) and ``tgt_mesh`` (ground truth) it creates.
    """

    def __init__(self, device):
        """Create the 512-pixel renderer used for the normal-map metric.

        Args:
            device: Device handed to ``Render`` and stored on the instance.
        """
        self.render = Render(size=512, device=device)
        self.device = device

    def set_mesh(self, result_dict):
        """Load one prediction/ground-truth pair and build both meshes.

        Every key of ``result_dict`` is copied onto the instance as an
        attribute. The method then reads ``verts_pr``, ``faces_pr``,
        ``verts_gt``, ``faces_gt``, ``recon_size`` and ``calib``, so the
        dictionary has to provide all of them.

        The predicted vertices are shifted and scaled by half of
        ``recon_size``. This is done out of place so the caller's
        ``verts_pr`` is left untouched and calling ``set_mesh`` twice with
        the same dictionary does not normalise twice.

        Args:
            result_dict: Mapping of attribute names to values.
        """
        for key, value in result_dict.items():
            setattr(self, key, value)

        half_size = self.recon_size / 2.0
        self.verts_pr = (self.verts_pr - half_size) / half_size

        self.verts_gt = projection(self.verts_gt, self.calib)
        # Flip the second coordinate of the projected ground truth.
        self.verts_gt[:, 1] *= -1

        self.src_mesh = self.render.VF2Mesh(self.verts_pr, self.faces_pr)
        self.tgt_mesh = self.render.VF2Mesh(self.verts_gt, self.faces_gt)

    @staticmethod
    def _normalized_normal_error(src_normals, tgt_normals):
        """Unit-normalise two normal maps along dim 0 and score them.

        Zero-length vectors are left at zero instead of being divided by
        zero. Both maps are then remapped with ``(x + 1) / 2``, and the error
        is the squared difference summed over dim 0, averaged over the
        remaining dims and multiplied by 4.

        Returns:
            Tuple ``(src_remapped, tgt_remapped, error)``.
        """
        src_len = torch.norm(src_normals, dim=0, keepdim=True)
        tgt_len = torch.norm(tgt_normals, dim=0, keepdim=True)
        src_len[src_len == 0.0] = 1.0
        tgt_len[tgt_len == 0.0] = 1.0

        src_remapped = (src_normals / src_len + 1.0) * 0.5
        tgt_remapped = (tgt_normals / tgt_len + 1.0) * 0.5

        error = ((src_remapped - tgt_remapped) ** 2).sum(dim=0).mean() * 4.0
        return src_remapped, tgt_remapped, error

    def calculate_normal_consist(self, normal_path):
        """Render both meshes from four cameras and compare the renderings.

        The per-camera images of each mesh are tiled into one grid, the
        grids are normalised and compared, and a picture of the predicted
        grid concatenated with the ground-truth grid (along dim 1) is
        written to ``normal_path``.

        Args:
            normal_path: Destination file for the comparison image.

        Returns:
            The scalar error of the tiled grids. If the renderer returns
            more than four images, a list with one error per image is
            returned instead.
        """
        self.render.meshes = self.src_mesh
        src_normal_imgs = self.render.get_rgb_image(cam_ids=[0, 1, 2, 3],
                                                    bg='black')
        self.render.meshes = self.tgt_mesh
        tgt_normal_imgs = self.render.get_rgb_image(cam_ids=[0, 1, 2, 3],
                                                    bg='black')

        src_grid = make_grid(torch.cat(src_normal_imgs, dim=0),
                             nrow=4, padding=0)
        tgt_grid = make_grid(torch.cat(tgt_normal_imgs, dim=0),
                             nrow=4, padding=0)
        src_grid, tgt_grid, grid_error = self._normalized_normal_error(
            src_grid, tgt_grid)

        # Values are in [0, 1] after the remap, so scale to 8-bit for saving.
        side_by_side = torch.cat([src_grid, tgt_grid], dim=1).permute(1, 2, 0)
        pixels = (side_by_side.detach().cpu().numpy() * 255.0).astype(np.uint8)
        Image.fromarray(pixels).save(normal_path)

        if len(src_normal_imgs) > 4:
            # NOTE(review): with four camera ids this branch presumably never
            # runs. It normalises each rendered image along its own dim 0;
            # confirm that dim 0 is the channel axis before relying on it.
            return [
                self._normalized_normal_error(src_img, tgt_img)[2]
                for src_img, tgt_img in zip(src_normal_imgs, tgt_normal_imgs)
            ]

        return grid_error

    def export_mesh(self, dir, name):
        """Write both meshes to ``dir`` as ``{name}_src.obj`` / ``{name}_tgt.obj``.

        Args:
            dir: Output directory.
            name: File-name prefix shared by the two files.
        """
        IO().save_mesh(self.src_mesh, osp.join(dir, f"{name}_src.obj"))
        IO().save_mesh(self.tgt_mesh, osp.join(dir, f"{name}_tgt.obj"))

    def calculate_chamfer_p2s(self, num_samples=1000):
        """Compute Chamfer and point-to-surface distances, both scaled by 100.

        ``num_samples`` points are sampled from each mesh. The
        point-to-surface value measures ground-truth samples against the
        predicted mesh; the Chamfer value is the mean of that and the
        opposite direction.

        Args:
            num_samples: Number of points sampled per mesh.

        Returns:
            Tuple ``(chamfer_dist, p2s_dist)``.
        """
        tgt_points = Pointclouds(
            sample_points_from_meshes(self.tgt_mesh, num_samples))
        src_points = Pointclouds(
            sample_points_from_meshes(self.src_mesh, num_samples))

        p2s_dist = point_mesh_distance(self.src_mesh, tgt_points) * 100.0
        s2p_dist = point_mesh_distance(self.tgt_mesh, src_points) * 100.0
        chamfer_dist = (s2p_dist + p2s_dist) * 0.5

        return chamfer_dist, p2s_dist

    def calc_acc(self, output, target, thres=0.5, use_sdf=False):
        """Score a thresholded prediction against a target.

        Both masks are derived from the raw values in a single pass, so any
        threshold works. Results are unchanged for thresholds in ``[0, 1)``;
        other thresholds (for example a negative level with ``use_sdf``) no
        longer collapse to an all-ones or all-zeros mask.

        Args:
            output: Predicted values.
            target: Reference values; compared as-is for the accuracy term
                unless ``use_sdf`` is set.
            thres: Values strictly above it count as positive.
            use_sdf: Also binarise ``target`` at ``thres`` before computing
                the accuracy.

        Returns:
            Tuple ``(accuracy, iou, precision, recall)`` of scalar tensors.
        """
        with torch.no_grad():
            pred_pos = output > thres
            pred_neg = output < thres
            # Values exactly equal to ``thres`` keep their original value.
            output = output.masked_fill(pred_neg, 0.0).masked_fill(
                pred_pos, 1.0)

            gt_pos = target > thres
            if use_sdf:
                gt_neg = target < thres
                target = target.masked_fill(gt_neg, 0.0).masked_fill(
                    gt_pos, 1.0)

            acc = output.eq(target).float().mean()

            # Counts are floored at 1 so the ratios never divide by zero.
            # NOTE(review): the intersection is floored as well, so a
            # prediction with no overlap reports 1 / union rather than 0.
            # Kept as-is so existing numbers stay comparable.
            union = (pred_pos | gt_pos).sum().float().clamp(min=1.0)
            true_pos = (pred_pos & gt_pos).sum().float().clamp(min=1.0)
            vol_pred = pred_pos.sum().float().clamp(min=1.0)
            vol_gt = gt_pos.sum().float().clamp(min=1.0)

            return (acc, true_pos / union, true_pos / vol_pred,
                    true_pos / vol_gt)