import numpy as np |
import torch |
import torch.nn as nn |
class Backprojection(nn.Module):
    """Layer to backproject a depth image given the camera intrinsics.

    Attributes:
        xy (1x3x(HxW)): homogeneous pixel coordinates (x, y, 1) on a regular
            grid, flattened row by row; registered as a non-persistent buffer
        ones (1x1x(HxW)): row of ones used to build homogeneous 3D points;
            registered as a non-persistent buffer
        dir_in_virt_cam (3x(virt_H*virt_W) numpy array): unit direction for
            every (vertical angle, horizontal angle) sample of the angular grid
    """

    def __init__(self, height, width, device=None):
        """
        Args:
            height (int): image height
            width (int): image width
            device (str or torch.device, optional): where the pixel-grid
                buffers are created. Defaults to "cuda" when CUDA is
                available, otherwise "cpu".
        """
        super(Backprojection, self).__init__()
        if device is None:
            # Keep GPU placement when a GPU is present, but remain usable on
            # CPU-only hosts instead of failing at construction time.
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.height = height
        self.width = width

        # Pixel grid: channel 0 holds x (column index), channel 1 holds y (row index).
        meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy')
        id_coords = np.stack(meshgrid, axis=0).astype(np.float32)
        id_coords = torch.tensor(id_coords, device=device)

        ones = torch.ones(1, 1, self.height * self.width, device=device)
        xy = id_coords.reshape(2, -1).unsqueeze(0)
        xy = torch.cat([xy, ones], 1)

        # Non-persistent: the buffers follow .to()/.cuda() but are not saved in state_dict.
        self.register_buffer('xy', xy, persistent=False)
        self.register_buffer('ones', ones, persistent=False)

        # Angular sampling grid, in degrees, listed as [start, stop].
        self.horizontal_angle_range = [195.0, -15.0]
        self.vertical_angle_range = [150.0, 0.0]
        self.horizontal_sample_num = 641
        self.vertical_sample_num = 481

        self.horizontal_step = (
            (self.horizontal_angle_range[1] - self.horizontal_angle_range[0])
            / (self.horizontal_sample_num - 1)
        )
        self.vertical_step = (
            (self.vertical_angle_range[1] - self.vertical_angle_range[0])
            / (self.vertical_sample_num - 1)
        )

        # np.arange excludes the stop angle, so each axis holds one sample
        # fewer than *_sample_num (640 horizontal, 480 vertical).
        self.horizontal_samples = np.arange(
            self.horizontal_angle_range[0], self.horizontal_angle_range[1], self.horizontal_step)
        self.vertical_samples = np.arange(
            self.vertical_angle_range[0], self.vertical_angle_range[1], self.vertical_step)

        horizontal_samples_in_rad = self.horizontal_samples / 180.0 * np.pi
        vertical_samples_in_rad = self.vertical_samples / 180.0 * np.pi

        virt_H = len(self.vertical_samples)
        virt_W = len(self.horizontal_samples)
        self.virt_H, self.virt_W = virt_H, virt_W

        # theta (vertical angle) varies along rows, phi (horizontal angle) along columns.
        cos_theta = np.tile(np.cos(vertical_samples_in_rad).reshape(-1, 1), (1, virt_W))
        sin_theta = np.tile(np.sin(vertical_samples_in_rad).reshape(-1, 1), (1, virt_W))
        cos_phi = np.tile(np.cos(horizontal_samples_in_rad).reshape(1, -1), (virt_H, 1))
        sin_phi = np.tile(np.sin(horizontal_samples_in_rad).reshape(1, -1), (virt_H, 1))

        # Spherical-to-Cartesian conversion with y as the polar axis.
        x = (sin_theta * cos_phi).reshape(1, virt_H, virt_W)
        y = cos_theta.reshape(1, virt_H, virt_W)
        z = (sin_theta * sin_phi).reshape(1, virt_H, virt_W)

        self.dir_in_virt_cam = np.concatenate((x, y, z), axis=0)
        self.dir_in_virt_cam = self.dir_in_virt_cam.reshape(3, self.virt_H * self.virt_W)

    def forward(self, depth, inv_K, img_like_out=False):
        """
        Args:
            depth (Nx1xHxW): depth map
            inv_K (Nx4x4): inverse camera intrinsics; only the top-left 3x3 block is used
            img_like_out (bool): if True, the output shape is Nx4xHxW; else Nx4x(HxW)
        Returns:
            points (Nx4x(HxW) or Nx4xHxW): 3D points in homogeneous coordinates
        """
        batch = depth.shape[0]
        depth = depth.contiguous()

        # matmul broadcasts the single pixel grid over the batch dimension.
        points = torch.matmul(inv_K[:, :3, :3], self.xy)
        points = depth.view(batch, 1, -1) * points

        # Append the homogeneous coordinate.
        ones = self.ones.expand(batch, -1, -1)
        points = torch.cat([points, ones], 1)

        if img_like_out:
            points = points.reshape(batch, 4, self.height, self.width)
        return points
def _unit_along_last_dim(vectors):
    """Scale [b, h, w, 3] vectors to (approximately) unit length; the epsilons keep zero vectors finite."""
    length = torch.sqrt(torch.sum(vectors ** 2, dim=3, keepdim=True) + 1e-4)
    return vectors / (length + 1e-8)


def _negate_where_positive_dot(normals, points):
    """Negate, in place, every normal whose dot product with its point is positive."""
    positive_dot = torch.sum(normals * points, dim=3) > 0
    normals[positive_dot] *= -1
    return normals


def get_surface_normalv2(xyz, patch_size=5, mask_valid=None):
    """Estimate per-pixel surface normals from a map of 3D points.

    For every pixel two normals are formed as the cross product of a
    horizontal and a vertical point difference taken inside a
    patch_size x patch_size window (an "outer" pair of taps and an "inner"
    pair). Both are normalised, summed, normalised again, and flipped so that
    their dot product with the point itself is not positive. Pixels whose four
    outer taps are not all valid get a zero normal.

    Args:
        xyz: point coordinates, in [b, h, w, c]; c must be 3 for the cross product
        patch_size (int): odd window size, at least 3
        mask_valid: optional [b, h, w] validity mask; when omitted, points
            with a positive third coordinate are treated as valid
    Returns:
        normal: [b, 3, h, w] surface normals (zero where invalid)
        mask: [b, 1, h, w] boolean mask of pixels with a valid normal
    Raises:
        ValueError: if patch_size is even or smaller than 3
    """
    if patch_size < 3 or patch_size % 2 == 0:
        # An even or too-small window makes the padded interior slice a
        # different size from the input; fail with a clear message instead.
        raise ValueError(f"patch_size must be an odd integer >= 3, got {patch_size}")

    b, h, w, c = xyz.shape
    half_patch = patch_size // 2
    if mask_valid is None:
        mask_valid = xyz[:, :, :, 2] > 0

    padded_h = h + patch_size - 1
    padded_w = w + patch_size - 1

    # Zero-pad by half_patch on every side so each tap can be read with a plain slice.
    mask_pad = torch.zeros((b, padded_h, padded_w), dtype=torch.bool, device=mask_valid.device)
    mask_pad[:, half_patch:-half_patch, half_patch:-half_patch] = mask_valid
    xyz_pad = torch.zeros((b, padded_h, padded_w, c), dtype=xyz.dtype, device=xyz.device)
    xyz_pad[:, half_patch:-half_patch, half_patch:-half_patch, :] = xyz

    rows = slice(half_patch, half_patch + h)  # rows aligned with the output pixel
    cols = slice(half_patch, half_patch + w)  # columns aligned with the output pixel

    # Outer taps: half_patch pixels to the left/right/top/bottom of the centre.
    xyz_left = xyz_pad[:, rows, :w, :]
    xyz_right = xyz_pad[:, rows, -w:, :]
    xyz_top = xyz_pad[:, :h, cols, :]
    xyz_bottom = xyz_pad[:, -h:, cols, :]
    xyz_horizon = xyz_left - xyz_right
    xyz_vertical = xyz_top - xyz_bottom

    # Inner taps.
    # NOTE(review): the left/top inner taps sit one pixel closer to the centre
    # than the outer ones, but the right/bottom inner slices select the same
    # pixels as the outer right/bottom taps — confirm this asymmetry is intended.
    xyz_left_in = xyz_pad[:, rows, 1:w + 1, :]
    xyz_right_in = xyz_pad[:, rows, patch_size - 1:patch_size - 1 + w, :]
    xyz_top_in = xyz_pad[:, 1:h + 1, cols, :]
    xyz_bottom_in = xyz_pad[:, patch_size - 1:patch_size - 1 + h, cols, :]
    xyz_horizon_in = xyz_left_in - xyz_right_in
    xyz_vertical_in = xyz_top_in - xyz_bottom_in

    n_img_1 = torch.cross(xyz_horizon_in, xyz_vertical_in, dim=3)
    n_img_2 = torch.cross(xyz_horizon, xyz_vertical, dim=3)

    # Give both estimates a consistent orientation before averaging them.
    n_img_1 = _negate_where_positive_dot(n_img_1, xyz)
    n_img_2 = _negate_where_positive_dot(n_img_2, xyz)

    n_img_aver = _unit_along_last_dim(n_img_1) + _unit_along_last_dim(n_img_2)
    n_img_aver_norm = _unit_along_last_dim(n_img_aver)
    n_img_aver_norm = _negate_where_positive_dot(n_img_aver_norm, xyz)

    # A normal is valid only when all four outer taps are valid points.
    mask_p4p6 = mask_pad[:, rows, :w] & mask_pad[:, rows, -w:]
    mask_p2p8 = mask_pad[:, :h, cols] & mask_pad[:, -h:, cols]
    mask_normal = mask_p2p8 & mask_p4p6
    n_img_aver_norm[~mask_normal] = 0

    return n_img_aver_norm.permute(0, 3, 1, 2).contiguous(), mask_normal[:, None, :, :]
class Depth2Normal(nn.Module):
    """Layer to compute surface normals from a depth map.

    The homogeneous pixel grid `xy` is built lazily for the resolution and
    device of the incoming depth map and rebuilt whenever either changes.
    """

    def __init__(self,):
        super(Depth2Normal, self).__init__()

    def init_img_coor(self, height, width, device=None):
        """Build the 1x3x(HxW) homogeneous pixel grid and register it as buffer `xy`.

        Args:
            height (int): image height
            width (int): image width
            device (str or torch.device, optional): where the grid is created.
                Defaults to "cuda" when CUDA is available, otherwise "cpu".
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        y, x = torch.meshgrid(
            torch.arange(0, height, dtype=torch.float32, device=device),
            torch.arange(0, width, dtype=torch.float32, device=device),
            indexing='ij',
        )
        meshgrid = torch.stack((x, y))
        ones = torch.ones((1, 1, height * width), device=device)
        xy = meshgrid.reshape(2, -1).unsqueeze(0)
        xy = torch.cat([xy, ones], 1)

        self.register_buffer('xy', xy, persistent=False)
        # Remember the exact layout: H*W alone cannot tell 20x22 from 22x20.
        self._xy_hw = (height, width)

    def _ensure_img_coor(self, height, width, device):
        """Rebuild `xy` unless it already matches this resolution and device."""
        xy = self._buffers.get('xy')
        if (
            xy is None
            or getattr(self, '_xy_hw', None) != (height, width)
            or xy.device != device
        ):
            self.init_img_coor(height=height, width=width, device=device)

    def back_projection(self, depth, inv_K, img_like_out=False, scale=1.0):
        """
        Args:
            depth (Nx1xHxW): depth map
            inv_K (Nx3x3 or larger): inverse camera intrinsics; only the
                top-left 3x3 block is used
            img_like_out (bool): if True, the output shape is Nx3xHxW; else Nx3x(HxW)
            scale (float): the z component of the back-projected points is divided by it
        Returns:
            points (Nx3x(HxW) or Nx3xHxW): back-projected 3D points
        """
        B, C, H, W = depth.shape
        self._ensure_img_coor(H, W, depth.device)

        depth = depth.contiguous()
        points = torch.matmul(inv_K[:, :3, :3], self.xy)
        points = depth.view(B, 1, -1) * points

        depth_descale = points[:, 2:3, :] / scale
        points = torch.cat((points[:, 0:2, :], depth_descale), dim=1)

        if img_like_out:
            points = points.reshape(B, 3, H, W)
        return points

    def forward(self, depth, intrinsics, masks, scale):
        """
        Args:
            depth (Nx1xHxW): depth map
            intrinsics (Nx3x3 or Nx4x4): camera intrinsics; inverted here, and
                only the top-left 3x3 block of the inverse is used
            masks (Nx1xHxW): boolean validity mask for the depth map
            scale (float): the z component of the back-projected points is divided by it
        Returns:
            normal (Nx3xHxW): normalized surface normal
            mask (Nx1xHxW): valid mask for surface normal
        """
        B, C, H, W = depth.shape
        self._ensure_img_coor(H, W, depth.device)

        inv_K = intrinsics.inverse()
        xyz = self.back_projection(depth, inv_K, scale=scale)
        xyz = xyz.view(B, 3, H, W)
        xyz = xyz[:, :3].permute(0, 2, 3, 1).contiguous()

        # Drop only the channel axis; a bare squeeze() would also drop a batch
        # (or spatial) axis of size 1.
        mask_valid = masks.squeeze(1) if masks.dim() == 4 else masks
        normals, normal_masks = get_surface_normalv2(xyz, mask_valid=mask_valid)
        normal_masks = normal_masks & masks
        return normals, normal_masks
if __name__ == '__main__':
    # Smoke test: run Depth2Normal twice on the same random depth batch.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    d2n = Depth2Normal()

    depth = np.random.randn(2, 1, 20, 22)
    intrin = np.array([[300, 0, 10], [0, 300, 10], [0, 0, 1]])
    intrinsics = np.stack([intrin, intrin], axis=0)

    depth_t = torch.from_numpy(depth).float().to(device)
    intrinsics = torch.from_numpy(intrinsics).float().to(device)

    # forward() requires a validity mask and a depth scale; randn yields
    # negative depths, so only the positive ones are marked valid.
    masks = depth_t > 0
    scale = 1.0

    normal = d2n(depth_t, intrinsics, masks, scale)
    normal2 = d2n(depth_t, intrinsics, masks, scale)
    print(normal)