# Commit b7eedf7 (author: ThunderVVV): "add thirdparty"
import torch
import numpy as np
from collections import OrderedDict
import lietorch
from data_readers.rgbd_utils import compute_distance_matrix_flow, compute_distance_matrix_flow2
def graph_to_edge_list(graph):
    """Flatten an adjacency mapping into three parallel edge-index tensors.

    Args:
        graph: mapping from a source node ``u`` to an iterable of target
            nodes ``v``.

    Returns:
        Tuple ``(ii, jj, kk)`` of 1-D ``torch.long`` tensors with one entry
        per edge. ``ii`` holds the source node, ``jj`` the target node, and
        ``kk`` the position of the source node in ``graph``'s iteration
        order.
    """
    ii, jj, kk = [], [], []
    for s, u in enumerate(graph):
        for v in graph[u]:
            ii.append(u)
            jj.append(v)
            kk.append(s)

    # Explicit dtype: torch.as_tensor([]) would otherwise produce float32,
    # which cannot be used as an index tensor when the graph has no edges.
    ii = torch.as_tensor(ii, dtype=torch.long)
    jj = torch.as_tensor(jj, dtype=torch.long)
    kk = torch.as_tensor(kk, dtype=torch.long)
    return ii, jj, kk
def keyframe_indicies(graph):
    """Return the nodes (keys) of ``graph`` as a 1-D ``torch.long`` tensor.

    Nodes appear in ``graph``'s iteration order. The explicit dtype keeps an
    empty graph from yielding a float32 tensor.
    """
    return torch.as_tensor(list(graph), dtype=torch.long)
def meshgrid(m, n, device='cuda'):
    """Enumerate every cell of an ``m`` x ``n`` grid in row-major order.

    Returns:
        Two 1-D integer tensors of length ``m * n`` moved to ``device``:
        the row index and the column index of each cell.
    """
    # Row r repeats n times consecutively; columns cycle 0..n-1 once per row.
    rows = torch.arange(m).repeat_interleave(n)
    cols = torch.arange(n).repeat(m)
    return rows.to(device), cols.to(device)
def neighbourhood_graph(n, r, device='cuda'):
    """Return all ordered frame pairs ``(i, j)`` with ``1 <= |i - j| <= r``.

    Args:
        n: number of frames; indices run over ``0..n-1``.
        r: neighbourhood radius.
        device: device for the returned tensors. The default ``'cuda'``
            matches the previous behaviour; pass e.g. ``'cpu'`` on machines
            without a GPU.

    Returns:
        Two 1-D index tensors ``(ii, jj)``, listing the qualifying pairs in
        row-major order of ``(i, j)``.
    """
    ii, jj = meshgrid(n, n, device=device)
    d = (ii - jj).abs()
    keep = (d >= 1) & (d <= r)
    return ii[keep], jj[keep]
def build_frame_graph(poses, disps, intrinsics, num=16, thresh=24.0, r=2):
    """Construct a frame graph between co-visible frames.

    Each frame is first connected to every other frame within ``r``
    positions of it. Extra edges are then added greedily, smallest distance
    first. A pair qualifies if its entry in the matrix returned by
    ``compute_distance_matrix_flow`` is below ``thresh``. Adding stops once
    the total edge count (temporal edges included) reaches ``num``, or no
    remaining pair qualifies.

    Args:
        poses: batched poses. ``poses.shape[1]`` is taken as the number of
            frames ``N``, and only batch element 0 is used.
        disps: batched disparities. Batch element 0 is subsampled with
            stride 8 (offset 3) along its last two axes.
        intrinsics: batched intrinsics. Batch element 0 is divided by 8 to
            match the subsampled disparities.
        num: target total number of directed edges.
        thresh: a distance must be strictly below this to add an edge.
        r: temporal radius of the always-present edges.

    Returns:
        ``OrderedDict`` mapping every frame index ``0..N-1`` to the list of
        frame indices (Python ints) it has an outgoing edge to.
    """
    N = poses.shape[1]
    poses = poses[0].cpu().numpy()
    disps = disps[0][:, 3::8, 3::8].cpu().numpy()
    intrinsics = intrinsics[0].cpu().numpy() / 8.0
    d = compute_distance_matrix_flow(poses, disps, intrinsics)

    # np.argmin selects a NaN ahead of any finite value, which would end the
    # greedy search below early. Treat NaN distances as "never connect".
    # (Assumes d is a NumPy array, as the np.argmin indexing below implies.)
    d[np.isnan(d)] = np.inf

    count = 0
    graph = OrderedDict()

    # Temporal edges: connect i to every j with 0 < |i - j| <= r, and remove
    # those pairs (and the diagonal) from the candidate pool.
    for i in range(N):
        graph[i] = []
        d[i, i] = np.inf
        for j in range(max(0, i - r), min(N, i + r + 1)):
            if j != i:
                graph[i].append(j)
                d[i, j] = np.inf
                count += 1

    # np.argmin raises on an empty array; with no frames there is nothing
    # more to add.
    if N == 0:
        return graph

    # Greedy co-visibility edges, smallest remaining distance first.
    while count < num:
        # int() keeps graph entries as Python ints, like the keys above.
        i, j = divmod(int(np.argmin(d)), N)
        if not d[i, j] < thresh:
            break
        graph[i].append(j)
        d[i, j] = np.inf
        count += 1

    return graph
def build_frame_graph_v2(poses, disps, intrinsics, num=16, thresh=24.0, r=2):
    """Construct a frame graph between co-visible frames.

    Each frame is first connected to every other frame within ``r``
    positions of it. Then the pair with the smallest remaining distance is
    repeatedly connected, as long as that distance is below ``thresh``.
    Distances come from ``compute_distance_matrix_flow2``. After each such
    edge, the surrounding 3x3 block of the distance matrix is cleared, so
    adjacent pairs such as ``(i±1, j±1)`` are not also selected. There is
    no cap on the number of edges added this way.

    Args:
        poses: passed unchanged to ``compute_distance_matrix_flow2``;
            ``poses.shape[1]`` is taken as the number of frames ``N``.
        disps: passed unchanged to ``compute_distance_matrix_flow2``.
        intrinsics: passed unchanged to ``compute_distance_matrix_flow2``.
        num: unused; kept for signature compatibility.
        thresh: a distance must be strictly below this to add an edge.
        r: temporal radius of the always-present edges.

    Returns:
        ``OrderedDict`` mapping every frame index ``0..N-1`` to the list of
        frame indices (Python ints) it has an outgoing edge to.
    """
    N = poses.shape[1]
    d = compute_distance_matrix_flow2(poses, disps, intrinsics)

    # np.argmin selects a NaN ahead of any finite value, which would end the
    # search below early. Treat NaN distances as "never connect".
    # (Assumes d is a NumPy array, as the np.argmin indexing below implies.)
    d[np.isnan(d)] = np.inf

    graph = OrderedDict()

    # Temporal edges: connect i to every j with 0 < |i - j| <= r, and remove
    # those pairs (and the diagonal) from the candidate pool.
    for i in range(N):
        graph[i] = []
        d[i, i] = np.inf
        for j in range(max(0, i - r), min(N, i + r + 1)):
            if j != i:
                graph[i].append(j)
                d[i, j] = np.inf

    # np.argmin raises on an empty array; with no frames there is nothing
    # more to add.
    if N == 0:
        return graph

    # Each iteration sets at least d[i, j] to inf, so the loop terminates.
    while True:
        # int() keeps graph entries as Python ints, like the keys above.
        i, j = divmod(int(np.argmin(d)), N)
        if not d[i, j] < thresh:
            break
        graph[i].append(j)
        # Clear the 3x3 block around (i, j). The lower bounds are clamped,
        # since negative slice starts would wrap; the upper bounds clip
        # automatically.
        d[max(0, i - 1):i + 2, max(0, j - 1):j + 2] = np.inf

    return graph