import torch |
from torch import nn |
import numpy as np |
import torch.nn.functional as F |
from .depth_to_normal import Depth2Normal |
""" |
Sampling strategies: RS (Random Sampling), EGS (Edge-Guided Sampling), and IGS (Instance-Guided Sampling) |
""" |
def randomSamplingNormal(inputs, targets, masks, sample_num):
    """Randomly draw disjoint pairs of valid pixels (Random Sampling, RS).

    The valid pixels selected by ``masks`` are shuffled; even positions of the
    shuffled order form set A and odd positions form set B, so a pixel is never
    paired with itself.

    Args:
        inputs: tensor indexed as ``inputs[:, masks]`` (leading dimension is
            kept, the remaining ones are selected by ``masks``).
        targets: tensor indexed the same way as ``inputs``.
        masks: boolean mask marking the valid pixels.
        sample_num: upper bound on the number of pairs to draw.

    Returns:
        Tuple ``(inputs_A, inputs_B, targets_A, targets_B)``; all four share the
        same number of columns, at most ``sample_num``.
    """
    # randperm needs a plain integer; the random indices are created on the
    # device of the data instead of a hard-coded "cuda" so CPU runs work too.
    num_effect_pixels = int(torch.sum(masks))
    shuffle_effect_pixels = torch.randperm(num_effect_pixels, device=inputs.device)

    valid_inputs = inputs[:, masks]
    valid_targets = targets[:, masks]

    index_A = shuffle_effect_pixels[0 : sample_num * 2 : 2]
    index_B = shuffle_effect_pixels[1 : sample_num * 2 : 2]

    # With an odd number of available pixels set A is one element longer than
    # set B; trim both to the common length so every A has a partner in B.
    num_pairs = min(index_A.shape[0], index_B.shape[0])
    index_A = index_A[:num_pairs]
    index_B = index_B[:num_pairs]

    inputs_A = valid_inputs[:, index_A]
    inputs_B = valid_inputs[:, index_B]
    targets_A = valid_targets[:, index_A]
    targets_B = valid_targets[:, index_B]
    return inputs_A, inputs_B, targets_A, targets_B
def ind2sub(idx, cols):
    """Split flat row-major indices into (row, column) subscripts.

    Args:
        idx: tensor of flat indices.
        cols: number of columns of the grid the indices refer to.

    Returns:
        Tuple ``(row, column)`` with ``idx == row * cols + column``.
    """
    row_idx = torch.div(idx, cols, rounding_mode='floor')
    # The column is what is left after removing the full rows.
    col_idx = idx - row_idx * cols
    return row_idx, col_idx
def sub2ind(r, c, cols):
    """Flatten (row, column) subscripts into row-major indices.

    Args:
        r: row subscript(s).
        c: column subscript(s).
        cols: number of columns of the grid.

    Returns:
        ``r * cols + c``.
    """
    return r * cols + c
def edgeGuidedSampling(inputs, targets, edges_img, thetas_img, masks, h, w):
    """Sample point pairs around strong edges (Edge-Guided Sampling, EGS).

    Anchors are drawn (with replacement) from the pixels whose edge response is
    at least 10% of the maximum. For every anchor four points ``a, b, c, d`` are
    placed at random signed distances along the direction given by the anchor's
    angle (``a`` and ``b`` on the negative side, ``c`` and ``d`` on the positive
    side), and the pairs ``(a, b)``, ``(b, c)``, ``(c, d)`` are returned.

    Args:
        inputs: tensor indexed as ``inputs[:, flat_index]``; the last dimension
            enumerates the ``h * w`` pixels in row-major order.
        targets: tensor indexed the same way as ``inputs``.
        edges_img: 1-D tensor with the edge magnitude of every pixel.
        thetas_img: 1-D tensor with the edge angle of every pixel.
        masks: 1-D tensor with the validity flag of every pixel.
        h: image height, used to clamp row coordinates.
        w: image width, used to clamp column coordinates and to flatten indices.

    Returns:
        Tuple ``(inputs_A, inputs_B, targets_A, targets_B, masks_A, masks_B,
        sample_num, row, col)``. The ``*_A`` / ``*_B`` tensors hold
        ``3 * sample_num`` points each; ``row`` and ``col`` have shape
        ``(4, sample_num)`` and hold the clamped coordinates of ``a..d``.
    """
    # Random tensors are created next to the data instead of on a hard-coded
    # "cuda" device, so the function also works on CPU or another accelerator.
    device = edges_img.device

    # Keep only pixels with a sufficiently strong edge response.
    edges_mask = edges_img.ge(edges_img.max() * 0.1)
    edges_loc = edges_mask.nonzero(as_tuple=False)
    thetas_edge = torch.masked_select(thetas_img, edges_mask)

    # One anchor per edge pixel on average (anchors are drawn with replacement).
    sample_num = thetas_edge.size(0)
    index_anchors = torch.randint(
        0, sample_num, (sample_num,), dtype=torch.long, device=device
    )
    theta_anchors = torch.gather(thetas_edge, 0, index_anchors)
    row_anchors, col_anchors = ind2sub(edges_loc[index_anchors].squeeze(1), w)

    # Distances in [3, 20); the first two rows are mirrored to the other side.
    distance_matrix = torch.randint(3, 20, (4, sample_num), device=device)
    pos_or_neg = torch.ones(4, sample_num, device=device)
    pos_or_neg[:2, :] = -pos_or_neg[:2, :]
    distance_matrix = distance_matrix.float() * pos_or_neg

    col_offsets = torch.round(
        distance_matrix * torch.abs(torch.cos(theta_anchors)).unsqueeze(0)
    ).long()
    row_offsets = torch.round(
        distance_matrix * torch.abs(torch.sin(theta_anchors)).unsqueeze(0)
    ).long()
    col = col_anchors.unsqueeze(0).expand(4, sample_num).long() + col_offsets
    row = row_anchors.unsqueeze(0).expand(4, sample_num).long() + row_offsets

    # Points that fall outside the image are snapped to the border.
    col = col.clamp(0, w - 1)
    row = row.clamp(0, h - 1)

    a = sub2ind(row[0, :], col[0, :], w)
    b = sub2ind(row[1, :], col[1, :], w)
    c = sub2ind(row[2, :], col[2, :], w)
    d = sub2ind(row[3, :], col[3, :], w)
    A = torch.cat((a, b, c), 0)
    B = torch.cat((b, c, d), 0)

    inputs_A = inputs[:, A]
    inputs_B = inputs[:, B]
    targets_A = targets[:, A]
    targets_B = targets[:, B]
    masks_A = torch.gather(masks, 0, A.long())
    masks_B = torch.gather(masks, 0, B.long())

    return (
        inputs_A,
        inputs_B,
        targets_A,
        targets_B,
        masks_A,
        masks_B,
        sample_num,
        row,
        col,
    )
class EdgeguidedNormalLoss(nn.Module):
    """Pair-wise surface-normal loss on point pairs sampled around image edges.

    Point pairs are drawn with ``edgeGuidedSampling`` from the edges of the
    input image, and the loss penalises ``|1 - cos|`` between the predicted
    normals of the two points of every valid pair.
    """

    def __init__(
        self,
        point_pairs=10000,
        cos_theta1=0.25,
        cos_theta2=0.98,
        cos_theta3=0.5,
        cos_theta4=0.86,
        mask_value=1e-8,
        loss_weight=1.0,
        data_type=None,
        **kwargs
    ):
        """Store the configuration.

        Args:
            point_pairs: stored on the instance as ``self.point_pairs``.
            cos_theta1, cos_theta2, cos_theta3, cos_theta4: stored on the
                instance under the same names.
            mask_value: stored on the instance as ``self.mask_value``.
            loss_weight: factor the final loss is multiplied with.
            data_type: list of data-type names stored as ``self.data_type``;
                ``None`` selects the default list.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        # A None sentinel avoids sharing one mutable default list between
        # every instance of the class.
        if data_type is None:
            data_type = ['stereo', 'denselidar', 'denselidar_nometric', 'denselidar_syn']
        self.point_pairs = point_pairs
        self.mask_value = mask_value
        self.cos_theta1 = cos_theta1
        self.cos_theta2 = cos_theta2
        self.cos_theta3 = cos_theta3
        self.cos_theta4 = cos_theta4
        self.depth2normal = Depth2Normal()
        self.loss_weight = loss_weight
        self.data_type = data_type
        self.eps = 1e-6

    @staticmethod
    def _sobel_kernels(reference, channels=1):
        """Build the horizontal and vertical Sobel kernels next to ``reference``.

        The kernels are created on the device of ``reference`` (and with its
        dtype when it is a floating-point tensor) and have shape
        ``(channels, 1, 3, 3)``.
        """
        dtype = reference.dtype if reference.is_floating_point() else torch.float32
        kernel_x = torch.tensor(
            [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=dtype, device=reference.device
        )
        kernel_y = torch.tensor(
            [[1, 2, 1], [0, 0, 0], [-1, -2, -1]], dtype=dtype, device=reference.device
        )
        kernel_x = kernel_x.view(1, 1, 3, 3).repeat(channels, 1, 1, 1)
        kernel_y = kernel_y.view(1, 1, 3, 3).repeat(channels, 1, 1, 1)
        return kernel_x, kernel_y

    def getEdge(self, images):
        """Compute Sobel edge magnitude and angle of a batch of images.

        Args:
            images: tensor of shape ``(n, c, h, w)``. For ``c == 3`` only the
                first channel is filtered; otherwise the tensor is filtered
                as is with a single-channel kernel.

        Returns:
            Tuple ``(edges, thetas)``, both of shape ``(n, 1, h, w)``; the
            one-pixel border lost by the 3x3 filter is zero-padded.
        """
        kernel_x, kernel_y = self._sobel_kernels(images)
        if images.size(1) == 3:
            images = images[:, 0, :, :].unsqueeze(1)
        gradient_x = F.conv2d(images, kernel_x)
        gradient_y = F.conv2d(images, kernel_y)
        edges = torch.sqrt(torch.pow(gradient_x, 2) + torch.pow(gradient_y, 2))
        edges = F.pad(edges, (1, 1, 1, 1), "constant", 0)
        thetas = torch.atan2(gradient_y, gradient_x)
        thetas = F.pad(thetas, (1, 1, 1, 1), "constant", 0)
        return edges, thetas

    def getNormalEdge(self, normals):
        """Compute Sobel edge magnitude and angle of a batch of normal maps.

        Every channel is filtered separately; the absolute responses are
        averaged over the channels before magnitude and angle are computed.

        Args:
            normals: tensor of shape ``(n, c, h, w)``.

        Returns:
            Tuple ``(edges, thetas)``, both of shape ``(n, 1, h, w)``; the
            one-pixel border lost by the 3x3 filter is zero-padded.
        """
        c = normals.size(1)
        # torch.tensor (not the legacy torch.Tensor constructor) is required
        # here: only the factory function accepts dtype/device keywords.
        kernel_x, kernel_y = self._sobel_kernels(normals, channels=c)
        gradient_x = torch.abs(F.conv2d(normals, kernel_x, groups=c))
        gradient_y = torch.abs(F.conv2d(normals, kernel_y, groups=c))
        gradient_x = gradient_x.mean(dim=1, keepdim=True)
        gradient_y = gradient_y.mean(dim=1, keepdim=True)
        edges = torch.sqrt(torch.pow(gradient_x, 2) + torch.pow(gradient_y, 2))
        edges = F.pad(edges, (1, 1, 1, 1), "constant", 0)
        thetas = torch.atan2(gradient_y, gradient_x)
        thetas = F.pad(thetas, (1, 1, 1, 1), "constant", 0)
        return edges, thetas

    def visual_check(self, rgb, samples):
        """Debug helper: paint the sampled points onto the image and save it.

        Args:
            rgb: normalised image tensor; it is squeezed and de-normalised with
                the fixed per-channel mean/std below before being saved.
            samples: four array masks ``(mask_A, mask_B, mask_C, mask_D)``
                marking the pixels to paint red, green, blue and yellow.

        The result is written to ``test_ranking/<random number>.png``.
        """
        import os
        import matplotlib.pyplot as plt

        rgb = rgb.detach().cpu().squeeze().numpy()

        mean = np.array([123.675, 116.28, 103.53])[:, np.newaxis, np.newaxis]
        std = np.array([58.395, 57.12, 57.375])[:, np.newaxis, np.newaxis]

        rgb = ((rgb * std) + mean).astype(np.uint8).transpose((1, 2, 0))
        mask_A, mask_B, mask_C, mask_D = samples
        # The np.bool alias was removed from NumPy; the builtin bool is the
        # equivalent dtype.
        rgb[mask_A.astype(bool)] = [255, 0, 0]
        rgb[mask_B.astype(bool)] = [0, 255, 0]
        rgb[mask_C.astype(bool)] = [0, 0, 255]
        rgb[mask_D.astype(bool)] = [255, 255, 0]

        filename = str(np.random.randint(10000))
        save_path = os.path.join('test_ranking', filename + '.png')
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        plt.imsave(save_path, rgb)

    def forward(self, prediction, target, mask, input, intrinsic, **kwargs):
        """Compute the loss; see ``get_loss`` for the arguments."""
        return self.get_loss(prediction, target, mask, input, intrinsic, **kwargs)

    def get_loss(self, prediction, target, mask, input, intrinsic, **kwargs):
        """Compute the edge-guided pair-wise normal loss.

        Args:
            prediction: predicted depth; converted to normals with
                ``self.depth2normal`` unless normals are passed via ``kwargs``.
            target: ground-truth depth; converted to normals the same way.
            mask: boolean validity mask, combined with the target-normal mask.
            input: rgb images; point pairs are sampled around their edges.
            intrinsic: camera intrinsics forwarded to ``self.depth2normal``.
            **kwargs: when ``'predictions_normals'`` is present, the keys
                ``'predictions_normals'``, ``'targets_normals'`` and
                ``'targets_normals_masks'`` are used instead of converting
                the depths.

        Returns:
            Scalar loss tensor, already multiplied by ``self.loss_weight``.
        """
        if 'predictions_normals' not in kwargs:
            predictions_normals, _ = self.depth2normal(prediction, intrinsic, mask)
            targets_normals, targets_normals_masks = self.depth2normal(target, intrinsic, mask)
        else:
            predictions_normals = kwargs['predictions_normals']
            targets_normals = kwargs['targets_normals']
            targets_normals_masks = kwargs['targets_normals_masks']
        masks_normals = mask & targets_normals_masks

        # Only the image edges drive the sampling.
        edges_img, thetas_img = self.getEdge(input)

        # Flatten the spatial dimensions so pixels can be addressed by one index.
        n, c, h, w = targets_normals.size()
        predictions_normals = predictions_normals.contiguous().view(n, c, -1)
        targets_normals = targets_normals.contiguous().view(n, c, -1)
        masks_normals = masks_normals.contiguous().view(n, -1)
        edges_img = edges_img.contiguous().view(n, -1)
        thetas_img = thetas_img.contiguous().view(n, -1)

        losses = 0.0
        valid_samples = 0.0
        for i in range(n):
            (
                inputs_A,
                inputs_B,
                targets_A,
                targets_B,
                masks_A,
                masks_B,
                _sample_num,
                _row_img,
                _col_img,
            ) = edgeGuidedSampling(
                predictions_normals[i],
                targets_normals[i],
                edges_img[i],
                thetas_img[i],
                masks_normals[i],
                h,
                w,
            )

            # A pair only counts when both of its points are valid.
            consistency_mask = masks_A & masks_B

            # target_cos only provides shape and dtype of the constant 1 the
            # predicted cosine is compared against.
            target_cos = torch.sum(targets_A * targets_B, dim=0)
            input_cos = torch.sum(inputs_A * inputs_B, dim=0)

            losses += torch.sum(
                torch.abs(torch.ones_like(target_cos) - input_cos) * consistency_mask.float()
            )
            valid_samples += torch.sum(consistency_mask.float())

        loss = (losses / (valid_samples + self.eps)) * self.loss_weight
        if not torch.isfinite(loss).item():
            # Fall back to a zero that is still connected to the graph.
            loss = 0 * torch.sum(prediction)
            print(f'Pair-wise Normal Regression Loss NAN error, {loss}, valid pix: {valid_samples}')
        return loss
def tmp_check_normal(normals, masks, depth):
    """Debug helper: dump the first two normal and depth maps of a batch.

    Args:
        normals: batch of normal maps; entries 0 and 1 are permuted to
            channels-last and rendered with ``vis_surface_normal``.
        masks: batch of masks handed to ``vis_surface_normal`` with the normals.
        depth: batch of depth maps; entries 0 and 1 are saved with matplotlib.

    Files are written to ``test_normal/`` as ``<name>.png``, ``<name + 1>.png``,
    ``<name>_d.png`` and ``<name + 1>_d.png`` for a random integer ``name``.
    """
    import matplotlib.pyplot as plt
    import os
    import cv2
    from mono.utils.visualization import vis_surface_normal

    # Render everything first, then write the files.
    normal_images = [
        vis_surface_normal(
            normals[k, ...].permute(1, 2, 0).detach(),
            masks[k, ...].detach().squeeze(),
        )
        for k in (0, 1)
    ]
    depth_images = [depth[k, ...].detach().cpu().squeeze().numpy() for k in (0, 1)]

    name = np.random.randint(100000)
    stems = (name, name + 1)
    os.makedirs('test_normal', exist_ok=True)
    for stem, image in zip(stems, normal_images):
        cv2.imwrite(f'test_normal/{stem}.png', image)
    for stem, image in zip(stems, depth_images):
        plt.imsave(f'test_normal/{stem}_d.png', image)
if __name__ == '__main__':
    # Smoke test on random data (requires a CUDA device).
    ENL = EdgeguidedNormalLoss()
    depth = np.random.randn(2, 1, 20, 22)
    intrin = np.array([[300, 0, 10], [0, 300, 10], [0, 0, 1]])
    prediction = np.random.randn(2, 1, 20, 22)
    imgs = np.random.randn(2, 3, 20, 22)
    intrinsics = np.stack([intrin, intrin], axis=0)

    depth_t = torch.from_numpy(depth).cuda().float()
    prediction = torch.from_numpy(prediction).cuda().float()
    intrinsics = torch.from_numpy(intrinsics).cuda().float()
    imgs = torch.from_numpy(imgs).cuda().float()
    # All depths are made non-positive, so the mask below marks nothing as valid.
    depth_t = -1 * torch.abs(depth_t)

    # forward() names these arguments `mask` and `input`; the former keywords
    # `masks` / `images` ended up in **kwargs and left the required ones unset.
    loss = ENL(prediction, depth_t, mask=depth_t > 0, input=imgs, intrinsic=intrinsics)
    print(loss)