#!/usr/bin/env python
# coding=utf-8
import numpy as np
import open3d as o3d
import os
import argparse
import torch
import trimesh
import pyrender
from copy import deepcopy
import torch.nn.functional as F
from help_func import auto_orient_and_center_poses
def extract_depth_from_mesh(mesh,
                            c2w_list,
                            H, W, fx, fy, cx, cy,
                            far=20.0):
    """Render per-view depth maps of `mesh` at the given camera poses.

    Adapted from Go-Surf: https://github.com/JingwenWang95/go-surf
    """
    os.environ['PYOPENGL_PLATFORM'] = 'egl'  # allows for GPU-accelerated offscreen rendering
    scene = pyrender.Scene()
    mesh = pyrender.Mesh.from_trimesh(mesh)
    scene.add(mesh)
""" | |
import glob | |
for f in glob.glob("/home/yuzh/mnt/A100/Projects/sdfstudio/tmp_meshes/*.ply"): | |
mesh = trimesh.load(f) | |
mesh = pyrender.Mesh.from_trimesh(mesh) | |
scene.add(mesh) | |
print(f) | |
""" | |
    camera = pyrender.IntrinsicsCamera(fx=fx, fy=fy, cx=cx, cy=cy, znear=0.01, zfar=far)
    camera_node = pyrender.Node(camera=camera, matrix=np.eye(4))
    scene.add_node(camera_node)
    renderer = pyrender.OffscreenRenderer(W, H)
    flags = pyrender.RenderFlags.OFFSCREEN | pyrender.RenderFlags.DEPTH_ONLY | pyrender.RenderFlags.SKIP_CULL_FACES
    depths = []
    for c2w in c2w_list:
        c2w = c2w.detach().cpu().numpy()
        # pyrender expects OpenGL camera coordinates; see
        # https://pyrender.readthedocs.io/en/latest/examples/cameras.html
        # nerfstudio's .json poses are already OpenGL, so no axis flip is
        # needed here. For OpenCV-convention poses, flip y and z instead:
        # c2w_gl[:3, 1] *= -1
        # c2w_gl[:3, 2] *= -1
        c2w_gl = deepcopy(c2w)
        scene.set_pose(camera_node, c2w_gl)
        # with the DEPTH_ONLY flag, render() returns only the depth buffer
        depth = renderer.render(scene, flags)
        depths.append(torch.from_numpy(depth))
    renderer.delete()
    return depths
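# Usage sketch for extract_depth_from_mesh (hypothetical mesh path and
# intrinsics, shown for illustration only): render depth for one identity
# camera pose in the OpenGL convention.
#
#   mesh = trimesh.load("scene.ply")                  # hypothetical input
#   pose = [torch.eye(4)]                             # one c2w pose
#   depths = extract_depth_from_mesh(mesh, pose, H=480, W=640,
#                                    fx=500.0, fy=500.0, cx=320.0, cy=240.0)
#   print(depths[0].shape)                            # torch.Size([480, 640])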
class Mesher(object):
    def __init__(self, H, W, fx, fy, cx, cy, far, points_batch_size=5e5):
        """
        Mesher class: given a mesh and a camera trajectory, culls the mesh to
        the observed region.
        Args:
            H, W: (int), image height and width in pixels
            fx, fy, cx, cy: (float), pinhole camera intrinsics
            far: (float), far plane used when rendering depth
            points_batch_size: (int), maximum number of points queried in one
                batch, used to limit GPU memory usage. Defaults to 5e5.
        """
        self.points_batch_size = int(points_batch_size)
        self.scale = 1.0
        self.device = 'cuda:0'
        self.far = far
        self.forecast_radius = 0
        self.H, self.W, self.fx, self.fy, self.cx, self.cy = H, W, fx, fy, cx, cy
        self.resolution = 256
        self.level_set = 0.0
        self.remove_small_geometry_threshold = 0.2
        self.get_largest_components = True
        self.verbose = True
    def point_masks(self,
                    input_points,
                    depth_list,
                    estimate_c2w_list):
        """
        Split the input points into seen and forecast regions, according to
        the estimated camera poses and rendered depth maps.
        Args:
            input_points: (Tensor or ndarray), [N, 3] query points
            depth_list: (list of Tensor), one rendered depth map per camera
            estimate_c2w_list: (list of Tensor), estimated camera-to-world poses
        Returns:
            mask: (ndarray), boolean mask for the seen area
            forecast_mask: (ndarray), boolean mask for the forecast area
        """
        H, W, fx, fy, cx, cy = self.H, self.W, self.fx, self.fy, self.cx, self.cy
        device = self.device
        if not isinstance(input_points, torch.Tensor):
            input_points = torch.from_numpy(input_points)
        input_points = input_points.clone().detach().float()
        mask = []
        forecast_mask = []
        # this eps should be tuned per scene
        eps = 0.005
        for pnts in torch.split(input_points, self.points_batch_size, dim=0):
            n_pts, _ = pnts.shape
            valid = torch.zeros(n_pts).to(device).bool()
            valid_num = torch.zeros(n_pts).to(device).int()
            valid_forecast = torch.zeros(n_pts).to(device).bool()
            r = self.forecast_radius
            K = torch.from_numpy(
                np.array([[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]])
            ).to(device).float()
            for i in range(len(estimate_c2w_list)):
                points = pnts.to(device).float()
                c2w = estimate_c2w_list[i].to(device).float()
                # to convert to OpenCV coordinates (nerfstudio .json poses are
                # OpenGL), flip the y and z axes here:
                # c2w[:3, 1:3] *= -1
                depth = depth_list[i].to(device)
                w2c = torch.inverse(c2w).to(device).float()
                ones = torch.ones_like(points[:, 0]).reshape(-1, 1).to(device)
                homo_points = torch.cat([points, ones], dim=1).reshape(-1, 4, 1).to(device).float()
                cam_cord_homo = w2c @ homo_points
                cam_cord = cam_cord_homo[:, :3, :]  # [N, 3, 1]
                uv = K @ cam_cord.float()
                z = uv[:, -1:] + 1e-8
                uv = uv[:, :2] / z  # [N, 2, 1]
                u, v = uv[:, 0, 0].float(), uv[:, 1, 0].float()
                z = z[:, 0, 0].float()
                in_frustum = (u >= 0) & (u <= W - 1) & (v >= 0) & (v <= H - 1) & (z > 0)
                forecast_frustum = (u >= -r) & (u <= W - 1 + r) & (v >= -r) & (v <= H - 1 + r) & (z > 0)
                depth = depth.reshape(1, 1, H, W)
                vgrid = uv.reshape(1, 1, -1, 2)
                # normalize pixel coordinates to [-1, 1] for grid_sample
                vgrid[..., 0] = vgrid[..., 0] / (W - 1) * 2.0 - 1.0
                vgrid[..., 1] = vgrid[..., 1] / (H - 1) * 2.0 - 1.0
                depth_sample = F.grid_sample(depth, vgrid, padding_mode='border', align_corners=True)
                depth_sample = depth_sample.reshape(-1)
                # a point is a front face if it is not behind the rendered
                # depth (within eps); where no depth was rendered, keep it
                is_front_face = torch.where(depth_sample > 0.0, z < (depth_sample + eps), torch.ones_like(z).bool())
                # the forecast test currently uses the same criterion
                is_forecast_face = torch.where(depth_sample > 0.0, z < (depth_sample + eps), torch.ones_like(z).bool())
                in_frustum = in_frustum & is_front_face
                valid = valid | in_frustum.bool()
                valid_num = valid_num + in_frustum.int()
                forecast_frustum = forecast_frustum & is_forecast_face
                forecast_frustum = in_frustum | forecast_frustum
                valid_forecast = valid_forecast | forecast_frustum.bool()
            # keep only points observed as front faces in at least 20 views
            valid = valid_num >= 20
            mask.append(valid.cpu().numpy())
            forecast_mask.append(valid_forecast.cpu().numpy())
        mask = np.concatenate(mask, axis=0)
        forecast_mask = np.concatenate(forecast_mask, axis=0)
        return mask, forecast_mask
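    # A minimal sketch of the projection test above, on one toy point: all
    # numbers (intrinsics, point) are made up for illustration. With the
    # OpenCV-style pinhole model used here, a point 2 m straight ahead of the
    # camera projects to the principal point with positive depth.
    #
    #   K = np.array([[500.0, 0.0, 320.0],
    #                 [0.0, 500.0, 240.0],
    #                 [0.0, 0.0, 1.0]])
    #   p_cam = np.array([0.0, 0.0, 2.0])    # point in camera coordinates
    #   uvz = K @ p_cam                       # -> [640., 480., 2.]
    #   u, v, z = uvz[0] / uvz[2], uvz[1] / uvz[2], uvz[2]
    #   # u == 320.0 (== cx), v == 240.0 (== cy), z == 2.0 > 0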
    def get_connected_mesh(self, mesh, get_largest_components=False):
        components = mesh.split(only_watertight=False)
        if get_largest_components:
            areas = np.array([c.area for c in components], dtype=np.float64)
            mesh = components[areas.argmax()]
        else:
            new_components = []
            global_area = mesh.area
            for comp in components:
                if comp.area > self.remove_small_geometry_threshold * global_area:
                    new_components.append(comp)
            mesh = trimesh.util.concatenate(new_components)
        return mesh
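    # Minimal example of the component split performed by get_connected_mesh,
    # on a toy scene of two disjoint boxes (illustrative only): keeping the
    # largest component drops the smaller box.
    #
    #   big = trimesh.creation.box(extents=(2, 2, 2))
    #   small = trimesh.creation.box(extents=(0.1, 0.1, 0.1))
    #   small.apply_translation((5, 0, 0))
    #   parts = trimesh.util.concatenate([big, small]).split(only_watertight=False)
    #   largest = max(parts, key=lambda c: c.area)   # the 2x2x2 box survives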
    def cull_mesh(self,
                  mesh,
                  estimate_c2w_list):
        """
        Cull the mesh against the estimated camera trajectory, keeping only
        geometry observed from the given views.
        Args:
            mesh: (trimesh.Trimesh), input mesh
            estimate_c2w_list: (list of Tensor), estimated camera-to-world
                poses, [N, 4, 4], in the same coordinate convention as the
                depth renderer (OpenGL here)
        Returns:
            cull_mesh: (trimesh.Trimesh), the culled mesh
            forecast_mesh: (trimesh.Trimesh), the culled mesh extended by the
                forecast radius
        """
        step = 1
        print('Start Mesh Culling', end='')
        # cull with 3d projection
        print(f' --->> {step}(Projection)', end='')
        forward_depths = extract_depth_from_mesh(
            mesh, estimate_c2w_list, H=self.H, W=self.W, fx=self.fx, fy=self.fy, cx=self.cx, cy=self.cy, far=self.far
        )
""" | |
backward_mesh = deepcopy(mesh) | |
backward_mesh.faces[:, [1, 2]] = backward_mesh.faces[:, [2, 1]] # make the mesh faces from, e.g., facing inside to outside | |
backward_depths = extract_depth_from_mesh( | |
backward_mesh, estimate_c2w_list, H=self.H, W=self.W, fx=self.fx, fy=self.fy, cx=self.cx, cy=self.cy, far=20.0 | |
) | |
depth_list = [] | |
for i in range(len(forward_depths)): | |
depth = torch.where(forward_depths[i] > 0, forward_depths[i], backward_depths[i]) | |
depth = torch.where((backward_depths[i] > 0) & (backward_depths[i] < depth), backward_depths[i], depth) | |
depth_list.append(depth) | |
""" | |
depth_list = forward_depths | |
print("in point masks") | |
vertices = mesh.vertices[:, :3] | |
mask, forecast_mask = self.point_masks( | |
vertices, depth_list, estimate_c2w_list | |
) | |
print(mask.shape, forecast_mask.shape, mask.mean()) | |
face_mask = mask[mesh.faces].all(axis=1) | |
mesh_with_hole = deepcopy(mesh) | |
mesh_with_hole.update_faces(face_mask) | |
mesh_with_hole.remove_unreferenced_vertices() | |
#mesh_with_hole.process(validate=True) | |
step += 1 | |
print("compute componet") | |
# cull by computing connected components | |
print(f' --->> {step}(Component)', end='') | |
#cull_mesh = self.get_connected_mesh(mesh_with_hole, self.get_largest_components) | |
cull_mesh = mesh_with_hole | |
print("after compute componet") | |
step += 1 | |
        if abs(self.forecast_radius) > 0:
            # extend the culled region by the forecast radius
            print(f' --->> {step}(Forecast:{self.forecast_radius})', end='')
            forecast_face_mask = forecast_mask[mesh.faces].all(axis=1)
            forecast_mesh = deepcopy(mesh)
            forecast_mesh.update_faces(forecast_face_mask)
            forecast_mesh.remove_unreferenced_vertices()
            cull_pc = o3d.geometry.PointCloud(
                o3d.utility.Vector3dVector(np.array(cull_mesh.vertices))
            )
            aabb = cull_pc.get_oriented_bounding_box()
            indices = aabb.get_point_indices_within_bounding_box(
                o3d.utility.Vector3dVector(np.array(forecast_mesh.vertices))
            )
            bound_mask = np.zeros(len(forecast_mesh.vertices))
            bound_mask[indices] = 1.0
            bound_mask = bound_mask.astype(np.bool_)
            forecast_face_mask = bound_mask[forecast_mesh.faces].all(axis=1)
            forecast_mesh.update_faces(forecast_face_mask)
            forecast_mesh.remove_unreferenced_vertices()
            forecast_mesh = self.get_connected_mesh(forecast_mesh, self.get_largest_components)
            step += 1
        else:
            forecast_mesh = deepcopy(cull_mesh)
        print(' --->> Done!')
        return cull_mesh, forecast_mesh
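    # Usage sketch for cull_mesh (hypothetical values): cull a loaded mesh
    # against a short list of poses.
    #
    #   mesher = Mesher(H=480, W=640, fx=500.0, fy=500.0, cx=320.0, cy=240.0, far=20.0)
    #   poses = [torch.eye(4) for _ in range(30)]     # placeholder c2w poses
    #   culled, forecast = mesher.cull_mesh(trimesh.load("scene.ply"), poses)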
    def __call__(self, mesh_path, estimate_c2w_list):
        print(f'Loading mesh from {mesh_path}...')
        mesh = trimesh.load(mesh_path, process=True)
        mesh.merge_vertices()
""" | |
print(f'Mesh loaded from {mesh_path}!') | |
mask = np.linalg.norm(mesh.vertices, axis=-1) < 1.0 | |
print(mask.shape, mask.mean()) | |
face_mask = mask[mesh.faces].all(axis=1) | |
mesh_with_hole = deepcopy(mesh) | |
mesh_with_hole.update_faces(face_mask) | |
mesh_with_hole.remove_unreferenced_vertices() | |
mesh = mesh_with_hole | |
print(f'Mesh clear from {mesh_path}!') | |
""" | |
        mesh_out_file = mesh_path.replace('.ply', '_cull.ply')
        cull_mesh, forecast_mesh = self.cull_mesh(
            mesh=mesh,
            estimate_c2w_list=estimate_c2w_list,
        )
        cull_mesh.export(mesh_out_file)
        if self.verbose:
            print("\nINFO: Saved mesh at {}!\n".format(mesh_out_file))
        torch.cuda.empty_cache()
def read_trajectory(filename):
    """Read a TanksAndTemples-style .log trajectory file."""
    traj = []
    with open(filename, "r") as f:
        metastr = f.readline()
        while metastr:
            metadata = list(map(int, metastr.split()))  # frame metadata triplet (unused)
            mat = np.zeros(shape=(4, 4))
            for i in range(4):
                matstr = f.readline()
                mat[i, :] = np.array(matstr.split(), dtype=np.float64)
            traj.append(mat)
            metastr = f.readline()
    return traj
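# The reader above expects a TanksAndTemples-style .log trajectory: one
# metadata line of integers per frame, followed by a 4x4 camera-to-world
# matrix. A one-frame file would look like (illustrative values):
#
#   0 0 1
#   1.0 0.0 0.0 0.0
#   0.0 1.0 0.0 0.0
#   0.0 0.0 1.0 0.0
#   0.0 0.0 0.0 1.0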
def get_traj(traj_path):
    print(f'Loading trajectory from {traj_path}.')
    traj_to_register = []
    if traj_path.endswith('.npy'):
        ld = np.load(traj_path)
        for i in range(len(ld)):
            traj_to_register.append(ld[i])
    elif traj_path.endswith('.json'):  # instant-ngp or sdfstudio format
        import json
        with open(traj_path, encoding='UTF-8') as f:
            meta = json.load(f)
        poses_dict = {}
        for i, frame in enumerate(meta['frames']):
            filepath = frame['file_path']
            # assumes a fixed file naming pattern with a 5-digit, 1-based frame index
            new_i = int(filepath[13:18]) - 1
            poses_dict[new_i] = np.array(frame['transform_matrix'])
        poses = []
        for i in range(len(poses_dict)):
            poses.append(poses_dict[i])
        poses = torch.from_numpy(np.array(poses).astype(np.float32))
        poses, _ = auto_orient_and_center_poses(poses, method='up', center_poses=True)
        scale_factor = 1.0 / float(torch.max(torch.abs(poses[:, :3, 3])))
        poses[:, :3, 3] *= scale_factor
        poses = poses.numpy()
        for i in range(len(poses)):
            traj_to_register.append(poses[i])
    else:
        traj_to_register = read_trajectory(traj_path)
# with open("test.xyz","w") as file_object: | |
# for m in traj_to_register: | |
# # p = - m[:3,:3].T @ m[:3,3:] | |
# # p = p[:,0] | |
# p = m[:3,-1] | |
# print("%f %f %f"%(p[0],p[1],p[2]),file=file_object) | |
    # pad any [3, 4] pose to a homogeneous [4, 4] matrix
    for i in range(len(traj_to_register)):
        c2w = torch.from_numpy(traj_to_register[i]).float()
        if c2w.shape == (3, 4):
            c2w = torch.cat([
                c2w,
                torch.tensor([[0, 0, 0, 1]]).float()
            ], dim=0)
        traj_to_register[i] = c2w  # [4, 4]
    print(f'Trajectory loaded from {traj_path}, including {len(traj_to_register)} camera views.')
    return traj_to_register
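# Usage sketch (hypothetical file): load a trajectory and inspect one pose.
#
#   poses = get_traj("data/traj.log")     # list of [4, 4] torch tensors
#   print(poses[0].shape)                 # torch.Size([4, 4])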
if __name__ == "__main__":
    print('Start culling...')
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--traj-path",
        type=str,
        required=True,
        help="path to trajectory file. See `convert_to_logfile.py` to create this file.",
    )
    parser.add_argument(
        "--ply-path",
        type=str,
        required=True,
        help="path to reconstruction ply file",
    )
    args = parser.parse_args()
    estimate_c2w_list = get_traj(args.traj_path)
    # intrinsics for the TanksandTemples dataset
    H, W = 1080, 1920
    fx = 1163.8678928442187
    fy = 1172.793101201448
    cx = 962.3120628412543
    cy = 542.0667209577691
    far = 20.0
    mesher = Mesher(H, W, fx, fy, cx, cy, far, points_batch_size=5e5)
    # mesher = Mesher(H*2, W*2, fx*2, fy*2, cx*2, cy*2, far, points_batch_size=5e5)
    mesher(args.ply_path, estimate_c2w_list)
    print('Done!')
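# Example invocation (hypothetical paths; the script filename is assumed):
#
#   python cull_mesh.py \
#       --traj-path data/Barn_COLMAP_SfM.log \
#       --ply-path meshes/Barn.ply
#
# The culled mesh is written next to the input, e.g. meshes/Barn_cull.ply.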