import torch
import torch.nn as nn
import numpy as np
class VNLoss(nn.Module):
    """
    Virtual Normal Loss (VNL).

    Samples random triplets of valid pixels, back-projects them into 3D with the
    camera intrinsics, and penalizes the L1 difference between the resulting
    "virtual normals" of the ground-truth and predicted point clouds
    (Yin et al., "Enforcing Geometric Constraints of Virtual Normals for Depth
    Prediction", ICCV 2019).
    """
def __init__(self,
delta_cos=0.867, delta_diff_x=0.01,
delta_diff_y=0.01, delta_diff_z=0.01,
delta_z=1e-5, sample_ratio=0.15,
loss_weight=1.0, data_type=['sfm', 'stereo', 'lidar', 'denselidar', 'denselidar_nometric', 'denselidar_syn'], **kwargs):
super(VNLoss, self).__init__()
self.delta_cos = delta_cos
self.delta_diff_x = delta_diff_x
self.delta_diff_y = delta_diff_y
self.delta_diff_z = delta_diff_z
self.delta_z = delta_z
self.sample_ratio = sample_ratio
self.loss_weight = loss_weight
self.data_type = data_type
self.eps = 1e-6
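        # Threshold semantics (our reading of the code below): delta_cos rejects
        # nearly collinear point groups; delta_diff_{x,y,z} flag points that lie
        # too close together along every axis; delta_z is the minimum depth
        # treated as valid; sample_ratio is the fraction of H*W pixels drawn for
        # each of p1, p2, p3. Note that select_points_groups currently calls
        # filter_mask with hardcoded thresholds, overriding the delta_diff_*
        # values stored here.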
    def init_image_coor(self, intrinsic, height, width):
        """Precompute per-pixel offsets (u - u0, v - v0) from the principal point and cache them as buffers."""
        device = intrinsic.device
        u0 = intrinsic[:, 0, 2][:, None, None, None]
        v0 = intrinsic[:, 1, 2][:, None, None, None]
        y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=device),
                               torch.arange(0, width, dtype=torch.float32, device=device)],
                              indexing='ij')
u_m_u0 = x[None, None, :, :] - u0
v_m_v0 = y[None, None, :, :] - v0
self.register_buffer('v_m_v0', v_m_v0, persistent=False)
self.register_buffer('u_m_u0', u_m_u0, persistent=False)
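    # Pinhole back-projection used by transfer_xyz below, as a minimal sketch
    # (standard pinhole model; fx = fy = f is assumed, matching the single
    # focal_length this class reads from intrinsic[:, 0, 0]):
    #     X = (u - u0) * Z / f
    #     Y = (v - v0) * Z / f
    #     Z = depth
    # e.g. with f = 100 and (u0, v0) = (200, 200), the pixel (u, v) = (300, 200)
    # at depth 2.0 maps to (X, Y, Z) = (2.0, 0.0, 2.0).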
def transfer_xyz(self, depth, focal_length, u_m_u0, v_m_v0):
x = u_m_u0 * depth / focal_length
y = v_m_v0 * depth / focal_length
z = depth
pw = torch.cat([x, y, z], 1).permute(0, 2, 3, 1).contiguous() # [b, h, w, c]
return pw
    def select_index(self, B, H, W, mask):
        """
        Randomly sample three index sets (p1, p2, p3) over valid pixels; each
        aligned triple later forms one virtual-normal point group.
        """
p1 = []
p2 = []
p3 = []
        pix_idx_mat = torch.arange(H * W, device=mask.device).reshape((H, W))
for i in range(B):
inputs_index = torch.masked_select(pix_idx_mat, mask[i, ...].gt(self.eps))
num_effect_pixels = len(inputs_index)
intend_sample_num = int(H * W * self.sample_ratio)
sample_num = intend_sample_num if num_effect_pixels >= intend_sample_num else num_effect_pixels
            shuffle_effect_pixels = torch.randperm(num_effect_pixels, device=mask.device)
            p1i = inputs_index[shuffle_effect_pixels[:sample_num]]
            shuffle_effect_pixels = torch.randperm(num_effect_pixels, device=mask.device)
            p2i = inputs_index[shuffle_effect_pixels[:sample_num]]
            shuffle_effect_pixels = torch.randperm(num_effect_pixels, device=mask.device)
            p3i = inputs_index[shuffle_effect_pixels[:sample_num]]
            # Pad with index 0 so every batch item yields the same number of samples;
            # padded groups are filtered out later by filter_mask.
            cat_null = torch.tensor([0] * (intend_sample_num - sample_num), dtype=torch.long, device=mask.device)
            p1i = torch.cat([p1i, cat_null])
            p2i = torch.cat([p2i, cat_null])
            p3i = torch.cat([p3i, cat_null])
p1.append(p1i)
p2.append(p2i)
p3.append(p3i)
p1 = torch.stack(p1, dim=0)
p2 = torch.stack(p2, dim=0)
p3 = torch.stack(p3, dim=0)
p1_x = p1 % W
p1_y = torch.div(p1, W, rounding_mode='trunc').long() # p1 // W
p2_x = p2 % W
p2_y = torch.div(p2, W, rounding_mode='trunc').long() # p2 // W
p3_x = p3 % W
p3_y = torch.div(p3, W, rounding_mode='trunc').long() # p3 // W
p123 = {'p1_x': p1_x, 'p1_y': p1_y, 'p2_x': p2_x, 'p2_y': p2_y, 'p3_x': p3_x, 'p3_y': p3_y}
return p123
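    # Shape sketch (illustrative numbers only): with B=2, H=480, W=640 and
    # sample_ratio=0.15, intend_sample_num = 46080, so each entry of p123
    # (e.g. p123['p1_x']) is a [2, 46080] LongTensor of pixel coordinates
    # decoded from flat indices via x = idx % W, y = idx // W.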
def form_pw_groups(self, p123, pw):
"""
Form 3D points groups, with 3 points in each grouup.
:param p123: points index
:param pw: 3D points
:return:
"""
B, _, _, _ = pw.shape
p1_x = p123['p1_x']
p1_y = p123['p1_y']
p2_x = p123['p2_x']
p2_y = p123['p2_y']
p3_x = p123['p3_x']
p3_y = p123['p3_y']
pw_groups = []
for i in range(B):
pw1 = pw[i, p1_y[i], p1_x[i], :]
pw2 = pw[i, p2_y[i], p2_x[i], :]
pw3 = pw[i, p3_y[i], p3_x[i], :]
pw_bi = torch.stack([pw1, pw2, pw3], dim=2)
pw_groups.append(pw_bi)
# [B, N, 3(x,y,z), 3(p1,p2,p3)]
pw_groups = torch.stack(pw_groups, dim=0)
return pw_groups
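    # Gather semantics (a hypothetical example): if p1 selects pixel (y=2, x=5)
    # of batch item i, then pw_groups[i, n, :, 0] is the (x, y, z) point
    # back-projected from that pixel; columns 1 and 2 hold the p2 and p3 points.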
    def filter_mask(self, p123, gt_xyz, delta_cos=0.867,
                    delta_diff_x=0.005,
                    delta_diff_y=0.005,
                    delta_diff_z=0.005):
        """Filter out groups that are nearly collinear, too close together, or built on invalid depth."""
        pw = self.form_pw_groups(p123, gt_xyz)
        pw12 = pw[:, :, :, 1] - pw[:, :, :, 0]
        pw13 = pw[:, :, :, 2] - pw[:, :, :, 0]
        pw23 = pw[:, :, :, 2] - pw[:, :, :, 1]
        # Ignore nearly collinear groups via pairwise cosine similarity of the
        # three difference vectors.
        pw_diff = torch.cat([pw12[:, :, :, None], pw13[:, :, :, None], pw23[:, :, :, None]],
                            3)  # [b, n, 3, 3]
        m_batchsize, groups, coords, index = pw_diff.shape
        proj_query = pw_diff.view(m_batchsize * groups, -1, index).permute(0, 2, 1).contiguous()  # [bn, 3(p123), 3(xyz)]
        proj_key = pw_diff.contiguous().view(m_batchsize * groups, -1, index)  # [bn, 3(xyz), 3(p123)]
        q_norm = proj_query.norm(2, dim=2)
        nm = torch.bmm(q_norm.contiguous().view(m_batchsize * groups, index, 1), q_norm.view(m_batchsize * groups, 1, index))
        energy = torch.bmm(proj_query, proj_key)  # pairwise dot products, [bn, 3(p123), 3(p123)]
        norm_energy = energy / (nm + self.eps)  # pairwise cosine similarities
        norm_energy = norm_energy.contiguous().view(m_batchsize * groups, -1)
        # The three diagonal self-similarities are 1 for non-degenerate vectors, so a
        # sum above 3 means at least one off-diagonal pair is nearly (anti-)parallel.
        mask_cos = torch.sum((norm_energy > delta_cos) + (norm_energy < -delta_cos), 1) > 3
        mask_cos = mask_cos.contiguous().view(m_batchsize, groups)
        # Ignore padded groups and invalid depth.
        mask_pad = torch.sum(pw[:, :, 2, :] > self.delta_z, 2) == 3
        # Ignore groups whose points lie too close together along all three axes.
        mask_x = torch.sum(torch.abs(pw_diff[:, :, 0, :]) < delta_diff_x, 2) > 0
        mask_y = torch.sum(torch.abs(pw_diff[:, :, 1, :]) < delta_diff_y, 2) > 0
        mask_z = torch.sum(torch.abs(pw_diff[:, :, 2, :]) < delta_diff_z, 2) > 0
        mask_ignore = (mask_x & mask_y & mask_z) | mask_cos
        mask_near = ~mask_ignore
        mask = mask_pad & mask_near
        return mask, pw
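    # Threshold intuition (our reading, not stated in the source): delta_cos = 0.867
    # is approximately cos(30 deg) ~= 0.866, so two difference vectors within about
    # 30 degrees of parallel (or anti-parallel) mark a group as collinear, and it is
    # dropped from the loss.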
def select_points_groups(self, gt_depth, pred_depth, intrinsic, mask):
B, C, H, W = gt_depth.shape
focal_length = intrinsic[:, 0, 0][:, None, None, None]
        u_m_u0, v_m_v0 = self.u_m_u0, self.v_m_v0  # cached by init_image_coor
pw_gt = self.transfer_xyz(gt_depth, focal_length, u_m_u0, v_m_v0)
pw_pred = self.transfer_xyz(pred_depth, focal_length, u_m_u0, v_m_v0)
p123 = self.select_index(B, H, W, mask)
        # mask: [b, n], pw_groups_gt: [b, n, 3(x,y,z), 3(p1,p2,p3)].
        # Note: these thresholds are hardcoded here and take precedence over the
        # delta_diff_* values passed to the constructor.
        mask, pw_groups_gt = self.filter_mask(p123, pw_gt,
                                              delta_cos=0.867,
                                              delta_diff_x=0.005,
                                              delta_diff_y=0.005,
                                              delta_diff_z=0.005)
        # [b, n, 3, 3]
        pw_groups_pred = self.form_pw_groups(p123, pw_pred)
        # Clamp exactly-zero predicted depths so degenerate points cannot yield
        # zero-length normals. (Index the z row explicitly: the original
        # pw_groups_pred[pw_groups_pred[:, :, 2, :] == 0] aligned the [b, n, 3]
        # mask with the coordinate axis rather than the point axis.)
        pred_z = pw_groups_pred[:, :, 2, :]
        pred_z[pred_z == 0] = 0.0001
        # Broadcast the [b, n] group mask over the 3 coordinates x 3 points.
        mask_broadcast = mask.repeat(1, 9).reshape(B, 3, 3, -1).permute(0, 3, 1, 2).contiguous()
        pw_groups_pred_not_ignore = pw_groups_pred[mask_broadcast].reshape(1, -1, 3, 3)
        pw_groups_gt_not_ignore = pw_groups_gt[mask_broadcast].reshape(1, -1, 3, 3)
        return pw_groups_gt_not_ignore, pw_groups_pred_not_ignore
    def forward(self, prediction, target, mask, intrinsic, select=True, **kwargs):
        """
        Virtual normal loss.
        :param prediction: predicted depth map, [B, C, H, W]
        :param target: ground-truth depth map, [B, C, H, W]
        :param mask: valid-pixel mask, [B, C, H, W]
        :param intrinsic: camera intrinsics, [B, 3, 3]
        :return: scalar loss
        """
loss = self.get_loss(prediction, target, mask, intrinsic, select, **kwargs)
return loss
    def get_loss(self, prediction, target, mask, intrinsic, select=True, **kwargs):
        """
        Build virtual normals n = (p2 - p1) x (p3 - p1) / ||.|| for matched GT and
        predicted point groups, then average |n_gt - n_pred| over the (optionally
        hardest 75% of) groups.
        """
        B, _, H, W = target.shape
        # Re-create the cached pixel-offset buffers if the batch or image size changed.
        if 'u_m_u0' not in self._buffers or 'v_m_v0' not in self._buffers \
                or self.u_m_u0.shape != torch.Size([B, 1, H, W]) or self.v_m_v0.shape != torch.Size([B, 1, H, W]):
            self.init_image_coor(intrinsic, H, W)
gt_points, pred_points = self.select_points_groups(target, prediction, intrinsic, mask)
gt_p12 = gt_points[:, :, :, 1] - gt_points[:, :, :, 0]
gt_p13 = gt_points[:, :, :, 2] - gt_points[:, :, :, 0]
pred_p12 = pred_points[:, :, :, 1] - pred_points[:, :, :, 0]
pred_p13 = pred_points[:, :, :, 2] - pred_points[:, :, :, 0]
gt_normal = torch.cross(gt_p12, gt_p13, dim=2)
pred_normal = torch.cross(pred_p12, pred_p13, dim=2)
pred_norm = torch.norm(pred_normal, 2, dim=2, keepdim=True)
gt_norm = torch.norm(gt_normal, 2, dim=2, keepdim=True)
        # Guard against zero-length normals (degenerate triangles) before normalizing.
        pred_mask = (pred_norm == 0.0).to(torch.float32) * self.eps
        gt_mask = (gt_norm == 0.0).to(torch.float32) * self.eps
        gt_norm = gt_norm + gt_mask
        pred_norm = pred_norm + pred_mask
        gt_normal = gt_normal / gt_norm
        pred_normal = pred_normal / pred_norm
loss = torch.abs(gt_normal - pred_normal)
loss = torch.sum(torch.sum(loss, dim=2), dim=0)
        if select:
            # Keep the hardest 75% of groups: sort ascending and drop the smallest 25%.
            loss, _ = torch.sort(loss, dim=0, descending=False)
            loss = loss[int(loss.size(0) * 0.25):]
        loss = torch.sum(loss) / (loss.numel() + self.eps)
        if torch.isnan(loss).item() or torch.isinf(loss).item():
            # Fall back to a zero loss that still keeps prediction in the graph.
            loss = 0 * torch.sum(prediction)
            print(f'VNL NaN error, {loss}')
        return loss * self.loss_weight
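    # Worked example (hypothetical numbers): for GT points p1 = (0, 0, 1),
    # p2 = (1, 0, 1), p3 = (0, 1, 1): p12 = (1, 0, 0) and p13 = (0, 1, 0), so the
    # virtual normal is p12 x p13 = (0, 0, 1). A predicted normal of (0, 0, -1)
    # contributes |(0, 0, 1) - (0, 0, -1)| summed over xyz = 2 to the per-group loss.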
if __name__ == '__main__':
    vnl_loss = VNLoss()
    pred_depth = np.random.random([2, 1, 480, 640])
    gt_depth = np.zeros_like(pred_depth)  # all-zero GT exercises the empty-mask path
    intrinsic = [[[100, 0, 200], [0, 100, 200], [0, 0, 1]],
                 [[100, 0, 200], [0, 100, 200], [0, 0, 1]]]
gt_depth = torch.tensor(np.array(gt_depth, np.float32)).cuda()
pred_depth = torch.tensor(np.array(pred_depth, np.float32)).cuda()
intrinsic = torch.tensor(np.array(intrinsic, np.float32)).cuda()
mask = gt_depth > 0
loss1 = vnl_loss(pred_depth, gt_depth, mask, intrinsic)
loss2 = vnl_loss(pred_depth, gt_depth, mask, intrinsic)
print(loss1, loss2)
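    # A minimal CPU smoke test (our addition; it assumes the device fixes above,
    # which derive tensor devices from the inputs instead of hardcoding "cuda"):
    vnl_loss_cpu = VNLoss()
    loss3 = vnl_loss_cpu(pred_depth.cpu(), gt_depth.cpu(), mask.cpu(), intrinsic.cpu())
    print(loss3)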