|
import torch |
|
import torch.nn as nn |
|
import numpy as np |
|
|
|
|
|
class PWNPlanesLoss(nn.Module):
    """
    Pair-wise normal (PWN) planes loss.

    Back-projects a predicted depth map into a camera-frame point cloud, samples
    3-point groups inside every instance-plane mask, and penalises the spread of
    the groups' normals around their mean normal, so that pixels sharing a plane
    label are pushed onto a common 3D plane.
    """

    def __init__(self, delta_cos=0.867, delta_diff_x=0.007,
                 delta_diff_y=0.007, sample_groups=5000, loss_weight=1.0, data_type=None, **kwargs):
        """
        :param delta_cos: a group is discarded when any two of its (x, y) difference
            vectors have |cos| above this value, i.e. the three points are (nearly)
            collinear in x/y (0.867 is roughly cos 30 degrees).
        :param delta_diff_x: a group is discarded when some pair of its points is
            closer than this along camera x AND some pair is closer than
            delta_diff_y along camera y.
        :param delta_diff_y: see delta_diff_x.
        :param sample_groups: number of 3-point groups sampled per plane.
        :param loss_weight: scalar multiplier applied to the final loss.
        :param data_type: stored as ``self.data_type``; defaults to a fresh
            ``['lidar', 'denselidar']`` list per instance. No method of this class
            reads it.
        :param kwargs: accepted and ignored.
        """
        super(PWNPlanesLoss, self).__init__()
        self.delta_cos = delta_cos
        self.delta_diff_x = delta_diff_x
        self.delta_diff_y = delta_diff_y
        self.sample_groups = sample_groups
        self.loss_weight = loss_weight
        # Build a new list per instance instead of sharing a mutable default.
        self.data_type = ['lidar', 'denselidar'] if data_type is None else data_type

    def init_image_coor(self, B, H, W, device="cuda"):
        """Cache homogeneous pixel coordinates ``(x=column, y=row, 1)`` in ``self.uv``.

        :param B: batch size the grid is expanded to.
        :param H: image height.
        :param W: image width.
        :param device: device for the grid (kept defaulting to "cuda" for
            backward compatibility; ``get_loss`` passes the input's device).
        Result: ``self.uv`` of shape [B, 3, H, W], float32.
        """
        rows = torch.arange(0, H, dtype=torch.float32, device=device).view(1, H, 1).expand(1, H, W)
        cols = torch.arange(0, W, dtype=torch.float32, device=device).view(1, 1, W).expand(1, H, W)
        ones = torch.ones((1, H, W), dtype=torch.float32, device=device)
        # Pinhole convention: the first homogeneous coordinate is the column (paired
        # with fx/cx in K), the second is the row (paired with fy/cy).
        self.uv = torch.stack((cols, rows, ones), dim=1).expand(B, 3, H, W)

    def upproj_pcd(self, depth, intrinsics_inv):
        """Back-project a depth map to a camera-frame point cloud.

        Requires ``init_image_coor`` to have been called with the same B, H, W.

        Args:
            depth: depth maps -- [B, 1, H, W]
            intrinsics_inv: inverse intrinsic matrix for each batch element -- [B, 3, 3]
        Returns:
            camera-frame points ``depth * K^-1 [x, y, 1]^T`` -- [B, 3, H, W]
        """
        b, _, h, w = depth.size()
        assert self.uv.shape[0] == b
        pixel_coords = self.uv.reshape(b, 3, -1)
        # Match the cached float32 grid so float64 / float16 intrinsics do not fail in matmul.
        cam_coords = intrinsics_inv.to(pixel_coords.dtype) @ pixel_coords
        return depth * cam_coords.reshape(b, 3, h, w)

    @staticmethod
    def _sample_indices(num_items, num_samples, device):
        """Return exactly ``num_samples`` random indices into ``range(num_items)``.

        Sampling is without replacement when ``num_items >= num_samples``; otherwise
        a full random permutation is padded with indices drawn with replacement.
        """
        perm = torch.randperm(num_items, device=device)
        if num_items >= num_samples:
            return perm[:num_samples]
        pad = torch.randint(0, num_items, (num_samples - num_items,), device=device)
        return torch.cat((perm, pad), 0)

    def select_index(self, mask_kp):
        """Randomly pick ``sample_groups`` point triplets inside every plane mask.

        :param mask_kp: per-plane masks, [P, C, H, W]; only channel 0 is read.
        :return: dict with 'p1_x', 'p1_y', 'p2_x', 'p2_y', 'p3_x', 'p3_y'
            ([P, sample_groups] long tensors; x = column, y = row) and
            'valid_batch' ([P, 1] bool). A plane with fewer than 60% of the
            ``3 * sample_groups`` required pixels is marked invalid and sampled
            from the whole image, so that all planes yield tensors of equal shape.
        """
        num_planes, _, h, w = mask_kp.shape
        device = mask_kp.device
        select_size = int(3 * self.sample_groups)
        coords = {key: [] for key in ('p1_x', 'p1_y', 'p2_x', 'p2_y', 'p3_x', 'p3_y')}
        valid_batch = torch.ones((num_planes, 1), dtype=torch.bool, device=device)

        for i in range(num_planes):
            valid_points = torch.nonzero(mask_kp[i, 0, :, :])  # [n, 2] as (row, col)
            n = valid_points.shape[0]
            if n < select_size * 0.6:
                # Too few points: sample anywhere in the image; the plane's groups
                # are masked out later through 'valid_batch'.
                valid_points = torch.nonzero(torch.ones((h, w), dtype=torch.bool, device=device))
                valid_batch[i, :] = False
            elif n < select_size:
                # Top up with a random subset of duplicates to reach select_size.
                repeat_idx = torch.randperm(n, device=device)[:select_size - n]
                valid_points = torch.cat((valid_points, valid_points[repeat_idx]), 0)

            chosen = valid_points[self._sample_indices(valid_points.shape[0], select_size, device)]
            for k, p in enumerate((chosen[0::3], chosen[1::3], chosen[2::3]), start=1):
                coords[f'p{k}_x'].append(p[:, 1])
                coords[f'p{k}_y'].append(p[:, 0])

        p123 = {key: torch.stack(vals) for key, vals in coords.items()}
        p123['valid_batch'] = valid_batch
        return p123

    def form_pw_groups(self, p123, pw):
        """
        Gather 3D point groups, with 3 points in each group.

        :param p123: sampled pixel indices as returned by ``select_index`` (P planes).
        :param pw: 3D points, [1, h, w, c]
        :return: point groups, [P, N, c, 3(p1, p2, p3)]
        """
        p1_x, p1_y = p123['p1_x'], p123['p1_y']
        p2_x, p2_y = p123['p2_x'], p123['p2_y']
        p3_x, p3_y = p123['p3_x'], p123['p3_y']
        num_planes = p1_x.shape[0]
        batch_list = torch.arange(0, num_planes, device=p1_x.device)[:, None]
        # expand() gives one view per plane without copying the point cloud P times.
        if pw.shape[0] == 1:
            pw = pw.expand(num_planes, -1, -1, -1)
        else:
            pw = pw.repeat((num_planes, 1, 1, 1))
        pw1 = pw[batch_list, p1_y, p1_x, :]
        pw2 = pw[batch_list, p2_y, p2_x, :]
        pw3 = pw[batch_list, p3_y, p3_x, :]
        return torch.stack([pw1, pw2, pw3], dim=3)

    def filter_mask(self, pw_pred):
        """Mark point groups that are usable for estimating a normal.

        :param pw_pred: 3D point groups, [P, N, 3(x, y, z), 3(p1, p2, p3)]; only
            the x and y coordinates are read.
        :return: bool mask [P, N], True for kept groups. A group is dropped when
            two of its x/y difference vectors are nearly (anti-)parallel
            (|cos| > delta_cos), or when some pair is closer than delta_diff_x in x
            and some (possibly different) pair is closer than delta_diff_y in y.
        """
        xy = pw_pred[:, :, 0:2, :]
        xy12 = xy[..., 1] - xy[..., 0]
        xy13 = xy[..., 2] - xy[..., 0]
        xy23 = xy[..., 2] - xy[..., 1]
        xy_diff = torch.stack([xy12, xy13, xy23], dim=3)  # [P, N, 2, 3]

        m_batchsize, groups, coords, index = xy_diff.shape
        vectors = xy_diff.reshape(m_batchsize * groups, coords, index)  # [P*N, 2, 3]
        query = vectors.permute(0, 2, 1)                                 # [P*N, 3, 2]
        q_norm = query.norm(2, dim=2)                                    # [P*N, 3]
        nm = q_norm.unsqueeze(2) * q_norm.unsqueeze(1)                   # outer product [P*N, 3, 3]
        norm_energy = (torch.bmm(query, vectors) / (nm + 1e-8)).reshape(m_batchsize * groups, -1)
        # The 3 diagonal (self) cosines are ~1 for non-degenerate vectors, so more
        # than 3 entries above the threshold means some pair is (anti-)parallel.
        parallel = (norm_energy > self.delta_cos) | (norm_energy < -self.delta_cos)
        mask_cos = (parallel.sum(1) > 3).reshape(m_batchsize, groups)

        mask_x = torch.sum(torch.abs(xy_diff[:, :, 0, :]) < self.delta_diff_x, 2) > 0
        mask_y = torch.sum(torch.abs(xy_diff[:, :, 1, :]) < self.delta_diff_y, 2) > 0
        mask_near = mask_x & mask_y
        return ~(mask_cos | mask_near)

    def select_points_groups(self, pcd_bi, mask_kp):
        """Sample point groups for every plane and compute their validity mask.

        :param pcd_bi: point cloud of one image, [1, 3, H, W].
        :param mask_kp: per-plane masks, [P, 1, H, W].
        :return: (groups [P, N, 3, 3], valid mask [P, N] bool).
        """
        p123 = self.select_index(mask_kp)
        pcd_bi = pcd_bi.permute((0, 2, 3, 1)).contiguous()
        groups_pred = self.form_pw_groups(p123, pcd_bi)
        mask_valid_pts = self.filter_mask(groups_pred).to(torch.bool)
        # Planes that fell back to whole-image sampling are invalid in every group.
        mask_valid_batch = p123['valid_batch'].repeat(1, mask_valid_pts.shape[1])
        return groups_pred, mask_valid_pts & mask_valid_batch

    def constrain_a_plane_loss(self, pw_groups_pre_i, mask_valid_i):
        """Sum of normal deviations for the point groups of one plane.

        :param pw_groups_pre_i: point groups of one plane, [N, 3(x,y,z), 3(p1,p2,p3)]
        :param mask_valid_i: bool mask [N] selecting the groups to use.
        :return: (loss_sum, valid_num): the sum over valid groups of
            ``1 - cos(group normal, mean normal)`` and the number of groups summed.
            With fewer than 2 valid groups, returns a zero that stays attached to
            the autograd graph, and 0.
        """
        if torch.sum(mask_valid_i) < 2:
            return 0.0 * torch.sum(pw_groups_pre_i), 0
        groups = pw_groups_pre_i[mask_valid_i]
        p12 = groups[:, :, 1] - groups[:, :, 0]
        p13 = groups[:, :, 2] - groups[:, :, 0]
        normal = torch.cross(p12, p13, dim=1)
        normal = normal / (torch.norm(normal, 2, dim=1, keepdim=True) + 1e-8)

        # Orient every normal so that normal . p1 <= 0 (facing the frame origin);
        # otherwise the sign of the cross product would be arbitrary per group.
        facing_away = torch.sum(normal * groups[:, :, 0], dim=1) > 0
        normal = torch.where(facing_away[:, None], -normal, normal)

        mean_normal = torch.sum(normal, dim=0)
        mean_normal = mean_normal / (torch.norm(mean_normal, 2, dim=0, keepdim=True) + 1e-5)
        cos_diff = 1.0 - torch.sum(normal * mean_normal, dim=1)
        return torch.sum(cos_diff, dim=0), cos_diff.numel()

    def get_loss(self, pred_depth, gt_depth, ins_planes_mask, intrinsic=None):
        """
        Co-plane loss. Enforce points residing on the same instance plane to be co-planar.

        :param pred_depth: predicted depth map, [B, 1, H, W] (a 3-D input gets a leading batch dim).
        :param gt_depth: not used by the loss computation.
        :param ins_planes_mask: per-pixel plane labels, [B, 1, H, W]; 0 means
            "no plane", every other distinct value is one plane.
        :param intrinsic: camera intrinsic matrices, [B, 3, 3]; required.
        :return: scalar loss (0 with a gradient path when it would be NaN/inf).
        :raises ValueError: if ``intrinsic`` is None.
        """
        if intrinsic is None:
            raise ValueError("PWNPlanesLoss.get_loss requires camera intrinsics.")
        if pred_depth.ndim == 3:
            pred_depth = pred_depth[None, ...]
        if gt_depth.ndim == 3:
            gt_depth = gt_depth[None, ...]
        if ins_planes_mask.ndim == 3:
            ins_planes_mask = ins_planes_mask[None, ...]

        B, _, H, W = pred_depth.shape
        device = pred_depth.device
        loss_sum = torch.tensor(0.0, device=device)
        valid_planes_num = 0

        self.init_image_coor(B, H, W, device=device)
        pcd = self.upproj_pcd(pred_depth, intrinsic.inverse())

        for i in range(B):
            mask_i = ins_planes_mask[i, :][None, :, :]
            labels = torch.unique(mask_i)
            # Drop the background label explicitly; it may or may not be present.
            labels = labels[labels != 0]
            if labels.numel() == 0:
                continue
            mask_planes = mask_i == labels.view((-1,) + (1,) * (mask_i.ndim - 1))  # [P, 1, H, W]
            groups_pred, mask_valid = self.select_points_groups(pcd[i, ...][None, :, :, :], mask_planes)

            for j in range(mask_planes.shape[0]):
                loss_tmp, valid_angles = self.constrain_a_plane_loss(groups_pred[j], mask_valid[j])
                valid_planes_num += valid_angles
                loss_sum = loss_sum + loss_tmp

        loss = loss_sum / (valid_planes_num + 1e-6) * self.loss_weight
        if not torch.isfinite(loss):
            print(f'PWNPlane NAN error, {loss}')
            loss = torch.sum(pred_depth) * 0
        return loss

    def forward(self, prediction, target, mask, intrinsic, **kwargs):
        """
        Planes loss over the 'Taskonomy' samples of a batch.

        :param prediction: predicted depth map, [B, 1, H, W]
        :param target: ground-truth depth, [B, 1, H, W] (forwarded, not used by the loss)
        :param mask: unused.
        :param intrinsic: camera intrinsics, [B, 3, 3]
        :param kwargs: must contain 'dataset' (per-sample dataset names) and,
            when any sample is from 'Taskonomy', 'sem_mask' (plane labels, [B, 1, H, W]).
        :return: scalar loss; a zero with a gradient path when no sample is from 'Taskonomy'.
        """
        dataset = kwargs['dataset']
        batch_mask = torch.as_tensor(np.asarray(dataset) == 'Taskonomy')
        if not bool(batch_mask.any()):
            return torch.sum(prediction) * 0.0
        ins_planes_mask = kwargs['sem_mask']
        assert ins_planes_mask.ndim == 4
        return self.get_loss(
            prediction[batch_mask],
            target[batch_mask],
            ins_planes_mask[batch_mask],
            intrinsic[batch_mask],
        )
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: random depth maps; image 0 carries plane labels 1 and 2,
    # image 1 has no planes (all labels 0).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    vnl_loss = PWNPlanesLoss()
    pred_depth = torch.rand([2, 1, 385, 513], device=device)
    gt_depth = torch.rand([2, 1, 385, 513], device=device)
    gt_depth[:, :, 3:20, 40:60] = 0
    mask_kp1 = pred_depth > 0.9
    mask_kp2 = pred_depth < 0.5
    mask = 1 * mask_kp1 + 2 * mask_kp2
    mask[1, ...] = 0

    intrinsic = torch.tensor([[100, 0, 50], [0, 100, 50], [0, 0, 1]], device=device).float()
    intrins = torch.stack([intrinsic, intrinsic], dim=0)
    # forward() reads the instance-plane labels from the `sem_mask` keyword,
    # not from the positional `mask` argument.
    loss = vnl_loss(gt_depth, gt_depth, mask, intrins,
                    dataset=np.array(['Taskonomy', 'Taskonomy']), sem_mask=mask)
    print(loss)
|
|