# Metric3D: mono/model/losses/depth_to_normal.py
import numpy as np
import torch
import torch.nn as nn
class Backprojection(nn.Module):
"""Layer to backproject a depth image given the camera intrinsics
    Attributes:
        xy (1x3x(HxW)): homogeneous pixel coordinates on a regular grid
    """
def __init__(self, height, width):
"""
Args:
height (int): image height
width (int): image width
"""
super(Backprojection, self).__init__()
self.height = height
self.width = width
        # generate a regular pixel grid
        meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy')
        id_coords = np.stack(meshgrid, axis=0).astype(np.float32)
        id_coords = torch.from_numpy(id_coords)

        # generate homogeneous pixel coordinates, shape [1, 3, H*W]
        ones = torch.ones(1, 1, self.height * self.width)
        xy = torch.unsqueeze(
            torch.stack([id_coords[0].view(-1), id_coords[1].view(-1)], 0),
            0
        )
        xy = torch.cat([xy, ones], 1)
        # register as non-persistent buffers so they follow the module across devices
        self.register_buffer('xy', xy, persistent=False)
        self.register_buffer('ones', ones, persistent=False)
        # for the virtual camera only: spherical direction samples
        horizontal_angle_range = [195.0, -15.0]
        vertical_angle_range = [150.0, 0.0]
        horizontal_sample_num = 641
        vertical_sample_num = 481
        self.horizontal_angle_range = horizontal_angle_range
        self.vertical_angle_range = vertical_angle_range
        self.horizontal_sample_num = horizontal_sample_num
        self.vertical_sample_num = vertical_sample_num
        self.horizontal_step = (self.horizontal_angle_range[1] - self.horizontal_angle_range[0]) / (
            self.horizontal_sample_num - 1)
        self.vertical_step = (self.vertical_angle_range[1] - self.vertical_angle_range[0]) / (
            self.vertical_sample_num - 1)
        # note: np.arange excludes the end point, so the actual grid size is
        # (vertical_sample_num - 1) x (horizontal_sample_num - 1)
        self.horizontal_samples = np.arange(self.horizontal_angle_range[0], self.horizontal_angle_range[1],
                                            self.horizontal_step)
        self.vertical_samples = np.arange(self.vertical_angle_range[0], self.vertical_angle_range[1],
                                          self.vertical_step)
        horizontal_samples_in_rad = self.horizontal_samples / 180.0 * np.pi
        vertical_samples_in_rad = self.vertical_samples / 180.0 * np.pi
        virt_H = len(self.vertical_samples)
        virt_W = len(self.horizontal_samples)
        self.virt_H, self.virt_W = virt_H, virt_W
        # unit directions on the sphere: theta is the polar angle (vertical
        # samples), phi the azimuth (horizontal samples)
        cos_theta = np.tile(np.cos(vertical_samples_in_rad).reshape(-1, 1), (1, virt_W))
        sin_theta = np.tile(np.sin(vertical_samples_in_rad).reshape(-1, 1), (1, virt_W))
        cos_phi = np.tile(np.cos(horizontal_samples_in_rad).reshape(1, -1), (virt_H, 1))
        sin_phi = np.tile(np.sin(horizontal_samples_in_rad).reshape(1, -1), (virt_H, 1))
        x = (sin_theta * cos_phi).reshape(1, virt_H, virt_W)
        y = cos_theta.reshape(1, virt_H, virt_W)
        z = (sin_theta * sin_phi).reshape(1, virt_H, virt_W)
        self.dir_in_virt_cam = np.concatenate((x, y, z), axis=0)
        self.dir_in_virt_cam = self.dir_in_virt_cam.reshape(3, self.virt_H * self.virt_W)
def forward(self, depth, inv_K, img_like_out=False):
"""
Args:
depth (Nx1xHxW): depth map
inv_K (Nx4x4): inverse camera intrinsics
img_like_out (bool): if True, the output shape is Nx4xHxW; else Nx4x(HxW)
Returns:
points (Nx4x(HxW)): 3D points in homogeneous coordinates
"""
        depth = depth.contiguous()
        # expand the cached pixel grid to the batch size
        xy = self.xy.repeat(depth.shape[0], 1, 1)
        ones = self.ones.repeat(depth.shape[0], 1, 1)
        # unproject: X = D * K^-1 * [u, v, 1]^T
        points = torch.matmul(inv_K[:, :3, :3], xy)
        points = depth.view(depth.shape[0], 1, -1) * points
        points = torch.cat([points, ones], 1)
if img_like_out:
points = points.reshape(depth.shape[0], 4, self.height, self.width)
return points
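
# A minimal usage sketch for Backprojection (illustrative only: the image
# size, focal length, and principal point below are made-up values, and a
# CUDA device is assumed):
#
#   backproj = Backprojection(height=480, width=640).cuda()
#   K = torch.eye(4).unsqueeze(0).cuda()
#   K[:, 0, 0] = K[:, 1, 1] = 500.0          # focal lengths fx, fy
#   K[:, 0, 2], K[:, 1, 2] = 320.0, 240.0    # principal point cx, cy
#   depth = torch.rand(1, 1, 480, 640).cuda()
#   points = backproj(depth, torch.inverse(K))  # [1, 4, 480*640]
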
def get_surface_normalv2(xyz, patch_size=5, mask_valid=None):
"""
xyz: xyz coordinates, in [b, h, w, c]
patch: [p1, p2, p3,
p4, p5, p6,
p7, p8, p9]
surface_normal = [(p9-p1) x (p3-p7)] + [(p6-p4) - (p8-p2)]
return: normal [h, w, 3, b]
"""
b, h, w, c = xyz.shape
half_patch = patch_size // 2
    if mask_valid is None:
        mask_valid = xyz[:, :, :, 2] > 0  # [b, h, w]
    # zero-pad the mask and the point map so neighbor lookups stay in bounds
    mask_pad = torch.zeros((b, h + patch_size - 1, w + patch_size - 1), device=mask_valid.device).bool()
    mask_pad[:, half_patch:-half_patch, half_patch:-half_patch] = mask_valid
    xyz_pad = torch.zeros((b, h + patch_size - 1, w + patch_size - 1, c), dtype=xyz.dtype, device=xyz.device)
    xyz_pad[:, half_patch:-half_patch, half_patch:-half_patch, :] = xyz
    # neighbor points at the half-patch offset
    xyz_left = xyz_pad[:, half_patch:half_patch + h, :w, :]  # p4
    xyz_right = xyz_pad[:, half_patch:half_patch + h, -w:, :]  # p6
    xyz_top = xyz_pad[:, :h, half_patch:half_patch + w, :]  # p2
    xyz_bottom = xyz_pad[:, -h:, half_patch:half_patch + w, :]  # p8
    xyz_horizon = xyz_left - xyz_right  # p4p6
    xyz_vertical = xyz_top - xyz_bottom  # p2p8
    # the same neighbors at an inner offset (1 pixel for the default patch_size=5)
    xyz_left_in = xyz_pad[:, half_patch:half_patch + h, 1:w + 1, :]  # p4
    xyz_right_in = xyz_pad[:, half_patch:half_patch + h, patch_size - 1:patch_size - 1 + w, :]  # p6
    xyz_top_in = xyz_pad[:, 1:h + 1, half_patch:half_patch + w, :]  # p2
    xyz_bottom_in = xyz_pad[:, patch_size - 1:patch_size - 1 + h, half_patch:half_patch + w, :]  # p8
    xyz_horizon_in = xyz_left_in - xyz_right_in  # p4p6
    xyz_vertical_in = xyz_top_in - xyz_bottom_in  # p2p8
n_img_1 = torch.cross(xyz_horizon_in, xyz_vertical_in, dim=3)
n_img_2 = torch.cross(xyz_horizon, xyz_vertical, dim=3)
# re-orient normals consistently
orient_mask = torch.sum(n_img_1 * xyz, dim=3) > 0
n_img_1[orient_mask] *= -1
orient_mask = torch.sum(n_img_2 * xyz, dim=3) > 0
n_img_2[orient_mask] *= -1
n_img1_L2 = torch.sqrt(torch.sum(n_img_1 ** 2, dim=3, keepdim=True) + 1e-4)
n_img1_norm = n_img_1 / (n_img1_L2 + 1e-8)
n_img2_L2 = torch.sqrt(torch.sum(n_img_2 ** 2, dim=3, keepdim=True) + 1e-4)
n_img2_norm = n_img_2 / (n_img2_L2 + 1e-8)
# average 2 norms
n_img_aver = n_img1_norm + n_img2_norm
n_img_aver_L2 = torch.sqrt(torch.sum(n_img_aver ** 2, dim=3, keepdim=True) + 1e-4)
n_img_aver_norm = n_img_aver / (n_img_aver_L2 + 1e-8)
# re-orient normals consistently
orient_mask = torch.sum(n_img_aver_norm * xyz, dim=3) > 0
n_img_aver_norm[orient_mask] *= -1
# get mask for normals
mask_p4p6 = mask_pad[:, half_patch:half_patch + h, :w] & mask_pad[:, half_patch:half_patch + h, -w:]
mask_p2p8 = mask_pad[:, :h, half_patch:half_patch + w] & mask_pad[:, -h:, half_patch:half_patch + w]
mask_normal = mask_p2p8 & mask_p4p6
n_img_aver_norm[~mask_normal] = 0
    return n_img_aver_norm.permute(0, 3, 1, 2).contiguous(), mask_normal[:, None, :, :]  # [b, 3, h, w], [b, 1, h, w]
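
# A minimal sketch of calling get_surface_normalv2 directly on a point map
# (shapes only; the random tensor is a stand-in for real backprojected points):
#
#   xyz = torch.rand(2, 480, 640, 3)       # [b, h, w, 3] camera-space points
#   normals, mask = get_surface_normalv2(xyz)
#   # normals: [2, 3, 480, 640], mask: [2, 1, 480, 640]
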
class Depth2Normal(nn.Module):
    """Layer to compute surface normals from a depth map."""

    def __init__(self):
        super(Depth2Normal, self).__init__()
    def init_img_coor(self, height, width, device="cuda"):
        """Build and cache homogeneous pixel coordinates for one resolution.
        Args:
            height (int): image height
            width (int): image width
            device: device on which to cache the pixel grid
        """
        y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=device),
                               torch.arange(0, width, dtype=torch.float32, device=device)], indexing='ij')
        meshgrid = torch.stack((x, y))
        # homogeneous pixel coordinates, shape [1, 3, H*W]
        ones = torch.ones((1, 1, height * width), device=device)
        xy = meshgrid.reshape(2, -1).unsqueeze(0)
        xy = torch.cat([xy, ones], 1)
        self.register_buffer('xy', xy, persistent=False)
    def back_projection(self, depth, inv_K, img_like_out=False, scale=1.0):
        """
        Args:
            depth (Nx1xHxW): depth map
            inv_K (Nx3x3 or Nx4x4): inverse camera intrinsics
            img_like_out (bool): if True, the output shape is Nx3xHxW; else Nx3x(HxW)
            scale (float): factor by which the input depth was scaled; the
                recovered z is divided by it to de-scale
        Returns:
            points (Nx3x(HxW)): 3D points in camera coordinates
        """
        B, C, H, W = depth.shape
        depth = depth.contiguous()
        xy = self.xy  # cached [1, 3, H*W] grid, broadcast over the batch
        # unproject: X = D * K^-1 * [u, v, 1]^T
        points = torch.matmul(inv_K[:, :3, :3], xy)
        points = depth.view(depth.shape[0], 1, -1) * points
        # undo the depth scaling on the z component only
        depth_descale = points[:, 2:3, :] / scale
        points = torch.cat((points[:, 0:2, :], depth_descale), dim=1)
        if img_like_out:
            points = points.reshape(depth.shape[0], 3, H, W)
        return points
    def forward(self, depth, intrinsics, masks, scale):
        """
        Args:
            depth (Nx1xHxW): depth map
            intrinsics (Nx3x3): camera intrinsic matrices
            masks (Nx1xHxW): boolean valid-depth mask
            scale (float): factor by which the input depth was scaled
        Returns:
            normals (Nx3xHxW): normalized surface normals
            normal_masks (Nx1xHxW): valid mask for the surface normals
        """
        B, C, H, W = depth.shape
        # (re)build the cached pixel grid if the resolution changed
        if 'xy' not in self._buffers or self.xy.shape[-1] != H * W:
            self.init_img_coor(height=H, width=W, device=depth.device)
        # compute the 3D point cloud
        inv_K = intrinsics.inverse()
        xyz = self.back_projection(depth, inv_K, scale=scale)  # [N, 3, H*W]
        xyz = xyz.view(B, 3, H, W).permute(0, 2, 3, 1).contiguous()  # [b, h, w, c]
        normals, normal_masks = get_surface_normalv2(xyz, mask_valid=masks.squeeze(1))
        normal_masks = normal_masks & masks
        return normals, normal_masks
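
# A minimal end-to-end sketch (the __main__ block below runs a similar check;
# the mask and scale values here are illustrative assumptions):
#
#   d2n = Depth2Normal()
#   depth = torch.rand(2, 1, 480, 640).cuda() + 0.1   # strictly positive depths
#   K = torch.eye(3).unsqueeze(0).repeat(2, 1, 1).cuda()
#   K[:, 0, 0] = K[:, 1, 1] = 500.0
#   masks = torch.ones(2, 1, 480, 640, dtype=torch.bool, device="cuda")
#   normals, normal_masks = d2n(depth, K, masks, scale=1.0)
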
if __name__ == '__main__':
    d2n = Depth2Normal()
    depth = np.abs(np.random.randn(2, 1, 20, 22))  # keep depths positive
    intrin = np.array([[300, 0, 10], [0, 300, 10], [0, 0, 1]])
    intrinsics = np.stack([intrin, intrin], axis=0)
    depth_t = torch.from_numpy(depth).cuda().float()
    intrinsics_t = torch.from_numpy(intrinsics).cuda().float()
    masks = torch.ones(2, 1, 20, 22, dtype=torch.bool, device="cuda")
    normals, normal_masks = d2n(depth_t, intrinsics_t, masks, scale=1.0)
    # a second call reuses the cached pixel grid
    normals2, _ = d2n(depth_t, intrinsics_t, masks, scale=1.0)
    print(normals)