# zach — initial commit based on github repo (3ef1661)
import torch
import torch.nn as nn
import numpy as np
class PWNPlanesLoss(nn.Module):
    """Plane-wise normal (PWN) co-planarity loss.

    For every instance plane in a mask, groups of three 3D points are
    sampled, each group defines a "virtual normal" (cross product of two
    in-plane edges), and the loss penalizes the angular spread of those
    normals around their mean direction — i.e. it constrains points that
    belong to the same instance plane to be co-planar.
    """
    def __init__(self, delta_cos=0.867, delta_diff_x=0.007,
                 delta_diff_y=0.007, sample_groups=5000, loss_weight=1.0, data_type=None, **kwargs):
        """
        :param delta_cos: cosine threshold used to reject groups whose three
            points are (nearly) collinear and thus define no stable plane
        :param delta_diff_x: minimum spread among the three points along the x axis
        :param delta_diff_y: minimum spread among the three points along the y axis
        :param sample_groups: number of sampled groups; each group with 3
            points constructs one plane
        :param loss_weight: scalar weight applied to the final loss
        :param data_type: dataset flavors this loss applies to
            (defaults to ['lidar', 'denselidar'])
        """
        super(PWNPlanesLoss, self).__init__()
        self.delta_cos = delta_cos
        self.delta_diff_x = delta_diff_x
        self.delta_diff_y = delta_diff_y
        self.sample_groups = sample_groups
        self.loss_weight = loss_weight
        # None-sentinel instead of a shared mutable default argument.
        self.data_type = ['lidar', 'denselidar'] if data_type is None else data_type

    def init_image_coor(self, B, H, W, device=None):
        """Cache homogeneous pixel coordinates in ``self.uv``, shape [B, 3, H, W].

        NOTE(review): the first coordinate spans H (rows) and the second W
        (columns); the intrinsics passed to ``upproj_pcd`` must follow the
        same (row, col) convention — confirm against the caller.
        """
        if device is None:
            # Preserve the historical GPU default for external callers.
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        u = torch.arange(0, H, dtype=torch.float32, device=device).contiguous().view(1, H, 1).expand(1, H, W)  # [1, H, W]
        v = torch.arange(0, W, dtype=torch.float32, device=device).contiguous().view(1, 1, W).expand(1, H, W)  # [1, H, W]
        ones = torch.ones((1, H, W), dtype=torch.float32, device=device)
        pixel_coords = torch.stack((u, v, ones), dim=1).expand(B, 3, H, W)  # [B, 3, H, W]
        self.uv = pixel_coords

    def upproj_pcd(self, depth, intrinsics_inv):
        """Back-project pixel coordinates to the camera frame.

        :param depth: depth maps, [B, 1, H, W]
        :param intrinsics_inv: inverse intrinsics per batch element, [B, 3, 3]
        :return: camera-frame point cloud, [B, 3, H, W]
        """
        b, _, h, w = depth.size()
        assert self.uv.shape[0] == b
        current_pixel_coords = self.uv.reshape(b, 3, -1)  # [B, 3, H*W]
        cam_coords = (intrinsics_inv @ current_pixel_coords)
        cam_coords = cam_coords.reshape(b, 3, h, w)
        out = depth * cam_coords
        return out

    def select_index(self, mask_kp):
        """Sample ``3 * sample_groups`` pixel indices inside each plane mask.

        :param mask_kp: boolean plane masks, [x, 1, H, W] (one channel per plane)
        :return: dict with point coordinates 'p{1,2,3}_{x,y}', each [x, N],
            plus 'valid_batch' [x, 1] marking planes with enough valid pixels
        """
        x, _, h, w = mask_kp.shape
        device = mask_kp.device
        select_size = int(3 * self.sample_groups)
        p1_x = []
        p1_y = []
        p2_x = []
        p2_y = []
        p3_x = []
        p3_y = []
        valid_batch = torch.ones((x, 1), dtype=torch.bool, device=device)
        for i in range(x):
            mask_kp_i = mask_kp[i, 0, :, :]
            valid_points = torch.nonzero(mask_kp_i)
            if valid_points.shape[0] < select_size * 0.6:
                # Too few in-plane pixels: mark this plane invalid and sample
                # placeholder indices from the whole image so downstream
                # indexing/stacking stays well-formed (the placeholder groups
                # are masked out later via 'valid_batch').
                # (The previous `~mask_kp_i.to(torch.uint8)` was a bitwise NOT
                # on uint8 — 255/254 — which also selected every pixel.)
                valid_points = torch.nonzero(torch.ones_like(mask_kp_i))
                valid_batch[i, :] = False
            elif valid_points.shape[0] < select_size:
                # Not quite enough points: pad by re-sampling existing ones.
                repeat_idx = torch.randperm(valid_points.shape[0], device=device)[:select_size - valid_points.shape[0]]
                valid_repeat = valid_points[repeat_idx]
                valid_points = torch.cat((valid_points, valid_repeat), 0)
            select_indx = torch.randperm(valid_points.size(0), device=device)
            p1 = valid_points[select_indx[0:select_size:3]]
            p2 = valid_points[select_indx[1:select_size:3]]
            p3 = valid_points[select_indx[2:select_size:3]]
            # nonzero() returns (row, col) pairs: column 1 is x, column 0 is y.
            p1_x.append(p1[:, 1])
            p1_y.append(p1[:, 0])
            p2_x.append(p2[:, 1])
            p2_y.append(p2[:, 0])
            p3_x.append(p3[:, 1])
            p3_y.append(p3[:, 0])
        p123 = {'p1_x': torch.stack(p1_x), 'p1_y': torch.stack(p1_y),
                'p2_x': torch.stack(p2_x), 'p2_y': torch.stack(p2_y),
                'p3_x': torch.stack(p3_x), 'p3_y': torch.stack(p3_y),
                'valid_batch': valid_batch}
        return p123

    def form_pw_groups(self, p123, pw):
        """Gather 3D points into groups of three.

        :param p123: dict of sampled pixel indices, each entry [x, N]
        :param pw: 3D point cloud, [1, H, W, 3]
        :return: grouped points, [x, N, 3(x,y,z), 3(p1,p2,p3)]
        """
        p1_x = p123['p1_x']
        p1_y = p123['p1_y']
        p2_x = p123['p2_x']
        p2_y = p123['p2_y']
        p3_x = p123['p3_x']
        p3_y = p123['p3_y']
        batch_list = torch.arange(0, p1_x.shape[0], device=p1_x.device)[:, None]
        # Broadcast the single point cloud across all planes.
        pw = pw.repeat((p1_x.shape[0], 1, 1, 1))
        pw1 = pw[batch_list, p1_y, p1_x, :]
        pw2 = pw[batch_list, p2_y, p2_x, :]
        pw3 = pw[batch_list, p3_y, p3_x, :]
        # [x, N, 3(x,y,z), 3(p1,p2,p3)]
        pw_groups = torch.cat([pw1[:, :, :, None], pw2[:, :, :, None], pw3[:, :, :, None]], 3)
        return pw_groups

    def filter_mask(self, pw_pred):
        """Filter out degenerate point groups.

        A group is rejected when its three points are nearly collinear
        (cosine test against ``delta_cos``) or too close to each other
        along both image axes (``delta_diff_x`` / ``delta_diff_y``).

        :param pw_pred: point groups, [B, N, 3(x,y,z), 3(p1,p2,p3)]
        :return: boolean mask of valid groups, [B, N]
        """
        xy12 = pw_pred[:, :, 0:2, 1] - pw_pred[:, :, 0:2, 0]
        xy13 = pw_pred[:, :, 0:2, 2] - pw_pred[:, :, 0:2, 0]
        xy23 = pw_pred[:, :, 0:2, 2] - pw_pred[:, :, 0:2, 1]
        # Ignore collinear groups.
        xy_diff = torch.cat([xy12[:, :, :, np.newaxis], xy13[:, :, :, np.newaxis], xy23[:, :, :, np.newaxis]],
                            3)  # [b, n, 2(xy), 3]
        m_batchsize, groups, coords, index = xy_diff.shape
        proj_query = xy_diff.contiguous().view(m_batchsize * groups, -1, index).permute(0, 2, 1).contiguous()  # [bn, 3(p123), 2(xy)]
        proj_key = xy_diff.contiguous().view(m_batchsize * groups, -1, index)  # [bn, 2(xy), 3(p123)]
        q_norm = proj_query.norm(2, dim=2)  # [bn, 3(p123)]
        nm = torch.bmm(q_norm.contiguous().view(m_batchsize * groups, index, 1), q_norm.contiguous().view(m_batchsize * groups, 1, index))
        energy = torch.bmm(proj_query, proj_key)  # pairwise dot products, [bn, 3(p123), 3(p123)]
        norm_energy = energy / (nm + 1e-8)  # pairwise cosines
        norm_energy = norm_energy.contiguous().view(m_batchsize * groups, -1)  # [bn, 9]
        # The 3 diagonal self-cosines are always 1, so "> 3" fires only when
        # some off-diagonal pair is (anti-)parallel beyond delta_cos.
        mask_cos = torch.sum((norm_energy > self.delta_cos) | (norm_energy < -self.delta_cos), 1) > 3
        mask_cos = mask_cos.contiguous().view(m_batchsize, groups)  # [b, n]
        # Ignore groups whose points are too close along both axes.
        mask_x = torch.sum(torch.abs(xy_diff[:, :, 0, :]) < self.delta_diff_x, 2) > 0
        mask_y = torch.sum(torch.abs(xy_diff[:, :, 1, :]) < self.delta_diff_y, 2) > 0
        mask_near = mask_x & mask_y
        mask_valid_pts = ~(mask_cos | mask_near)
        return mask_valid_pts

    def select_points_groups(self, pcd_bi, mask_kp):
        """Sample and filter point groups for one batch element.

        :param pcd_bi: point cloud of one image, [1, 3, H, W]
        :param mask_kp: boolean plane masks, [x, 1, H, W]
        :return: (groups [x, N, 3(x,y,z), 3(p1,p2,p3)], validity mask [x, N])
        """
        p123 = self.select_index(mask_kp)  # p1_x: [x, n]
        pcd_bi = pcd_bi.permute((0, 2, 3, 1)).contiguous()  # [1, h, w, 3(xyz)]
        groups_pred = self.form_pw_groups(p123, pcd_bi)  # [x, N, 3(x,y,z), 3(p1,p2,p3)]
        mask_valid_pts = (self.filter_mask(groups_pred)).to(torch.bool)  # [x, n]
        mask_valid_batch = p123['valid_batch'].repeat(1, mask_valid_pts.shape[1])  # [x, n]
        # A group counts only if it is well-conditioned AND its plane had
        # enough real (non-placeholder) pixels.
        mask_valid = mask_valid_pts & mask_valid_batch  # [x, n]
        return groups_pred, mask_valid

    def constrain_a_plane_loss(self, pw_groups_pre_i, mask_valid_i):
        """Co-planarity loss for a single instance plane.

        :param pw_groups_pre_i: sampled point groups, [N, 3(x,y,z), 3(p1,p2,p3)]
        :param mask_valid_i: boolean validity mask over the groups, [N]
        :return: (loss_sum, valid_num) — summed cosine deviation and the
            number of groups that contributed
        """
        if torch.sum(mask_valid_i) < 2:
            # Fewer than two usable groups: no spread to measure. Multiply by
            # zero instead of returning a constant so the graph stays connected.
            return 0.0 * torch.sum(pw_groups_pre_i), 0
        pw_groups_pred_i = pw_groups_pre_i[mask_valid_i]  # [n, 3, 3]
        p12 = pw_groups_pred_i[:, :, 1] - pw_groups_pred_i[:, :, 0]
        p13 = pw_groups_pred_i[:, :, 2] - pw_groups_pred_i[:, :, 0]
        virtual_normal = torch.cross(p12, p13, dim=1)  # [n, 3]
        norm = torch.norm(virtual_normal, 2, dim=1, keepdim=True)
        virtual_normal = virtual_normal / (norm + 1e-8)
        # Re-orient normals consistently (flip those pointing away from the
        # camera along the ray to p1). squeeze() is safe here: the guard
        # above guarantees n >= 2, so no dim collapses.
        orient_mask = torch.sum(torch.squeeze(virtual_normal) * torch.squeeze(pw_groups_pred_i[:, :, 0]), dim=1) > 0
        virtual_normal[orient_mask] *= -1
        aver_normal = torch.sum(virtual_normal, dim=0)
        aver_norm = torch.norm(aver_normal, 2, dim=0, keepdim=True)
        aver_normal = aver_normal / (aver_norm + 1e-5)  # [3]
        # 1 - cos(angle between each normal and the mean normal).
        cos_diff = 1.0 - torch.sum(virtual_normal * aver_normal, dim=1)
        loss_sum = torch.sum(cos_diff, dim=0)
        valid_num = cos_diff.numel()
        return loss_sum, valid_num

    def get_loss(self, pred_depth, gt_depth, ins_planes_mask, intrinsic=None):
        """Co-plane loss: enforce points on the same instance plane to be co-planar.

        :param pred_depth: predicted depth map, [B, C, H, W]
        :param gt_depth: ground-truth depth map, [B, C, H, W] (not used here)
        :param ins_planes_mask: plane mask, [B, C, H, W]; each plane is
            labelled with a distinct non-zero value, 0 is background
        :param intrinsic: camera intrinsics, [B, 3, 3]
        :return: scalar loss tensor
        """
        if pred_depth.ndim == 3:
            pred_depth = pred_depth[None, ...]
        if gt_depth.ndim == 3:
            gt_depth = gt_depth[None, ...]
        if ins_planes_mask.ndim == 3:
            ins_planes_mask = ins_planes_mask[None, ...]
        B, _, H, W = pred_depth.shape
        device = pred_depth.device
        loss_sum = torch.tensor(0.0, device=device)
        valid_planes_num = 0
        self.init_image_coor(B, H, W, device=device)
        pcd = self.upproj_pcd(pred_depth, intrinsic.inverse())
        for i in range(B):
            mask_i = ins_planes_mask[i, :][None, :, :]
            unique_planes = torch.unique(mask_i)
            planes = [mask_i == m for m in unique_planes if m != 0]  # [x, 1, h, w], x is the plane count
            if len(planes) == 0:
                continue
            mask_planes = torch.cat(planes, dim=0)
            pcd_grops_pred, mask_valid = self.select_points_groups(pcd[i, ...][None, :, :, :], mask_planes)  # [x, N, 3(x,y,z), 3(p1,p2,p3)]
            # Iterate over the actual number of planes; the previous
            # `unique_planes.numel() - 1` silently dropped one plane whenever
            # label 0 was absent from the mask.
            for j in range(mask_planes.shape[0]):
                mask_valid_j = mask_valid[j, :]
                pcd_grops_pred_j = pcd_grops_pred[j, :]
                loss_tmp, valid_angles = self.constrain_a_plane_loss(pcd_grops_pred_j, mask_valid_j)
                valid_planes_num += valid_angles
                loss_sum += loss_tmp
        loss = loss_sum / (valid_planes_num + 1e-6) * self.loss_weight
        if torch.isnan(loss).item() | torch.isinf(loss).item():
            # Fall back to a graph-connected zero so training can continue.
            loss = torch.sum(pred_depth) * 0
            print(f'PWNPlane NAN error, {loss}')
        return loss

    def forward(self, prediction, target, mask, intrinsic, **kwargs):
        """
        :param prediction: predicted depth map, [B, C, H, W]
        :param target: ground-truth depth map, [B, C, H, W]
        :param mask: unused here; plane masks come from kwargs['sem_mask']
        :param intrinsic: camera intrinsics, [B, 3, 3]
        :param kwargs: must contain 'dataset' (array of dataset names) and
            'sem_mask' ([B, C, H, W] instance-plane mask)
        :return: scalar loss; zero when the batch has no Taskonomy samples
        """
        dataset = kwargs['dataset']
        # Only Taskonomy samples carry instance-plane annotations.
        batch_mask = np.array(dataset) == 'Taskonomy'
        if np.sum(batch_mask) == 0:
            return torch.sum(prediction) * 0.0
        ins_planes_mask = kwargs['sem_mask']
        assert ins_planes_mask.ndim == 4
        loss = self.get_loss(
            prediction[batch_mask],
            target[batch_mask],
            ins_planes_mask[batch_mask],
            intrinsic[batch_mask],
        )
        return loss
if __name__ == '__main__':
    # Smoke test: run the loss on random inputs.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    vnl_loss = PWNPlanesLoss()
    pred_depth = torch.rand([2, 1, 385, 513], device=device)
    gt_depth = torch.rand([2, 1, 385, 513], device=device)
    gt_depth[:, :, 3:20, 40:60] = 0
    # Two fake "planes" labelled 1 and 2; batch element 1 has no planes.
    mask_kp1 = pred_depth > 0.9
    mask_kp2 = pred_depth < 0.5
    mask = 1 * mask_kp1 + 2 * mask_kp2
    mask[1, ...] = 0
    intrinsic = torch.tensor([[100, 0, 50], [0, 100, 50], [0, 0, 1]], device=device).float()
    intrins = torch.stack([intrinsic, intrinsic], dim=0)
    # forward() reads the plane masks from the `sem_mask` kwarg; the original
    # call omitted it and raised a KeyError.
    loss = vnl_loss(gt_depth, gt_depth, mask, intrins,
                    dataset=np.array(['Taskonomy', 'Taskonomy']), sem_mask=mask)
    print(loss)