Spaces:
Running
Running
import torch | |
import numpy as np | |
from collections import OrderedDict | |
import lietorch | |
from data_readers.rgbd_utils import compute_distance_matrix_flow, compute_distance_matrix_flow2 | |
def graph_to_edge_list(graph):
    """Flatten an adjacency mapping into three parallel index tensors.

    Args:
        graph: mapping ``node -> iterable of neighbour nodes``. Iteration
            order of the mapping defines each node's position.

    Returns:
        Tuple ``(ii, jj, kk)`` of 1-D ``torch.long`` tensors with one entry
        per edge: ``ii`` is the source node, ``jj`` the destination node and
        ``kk`` the position of the source node in the mapping's iteration
        order. Nodes without neighbours contribute no entries but still
        occupy a position.
    """
    sources, targets, positions = [], [], []
    for position, node in enumerate(graph):
        for neighbour in graph[node]:
            sources.append(node)
            targets.append(neighbour)
            positions.append(position)

    # Pin the dtype: without it an edge-less graph would produce float32
    # tensors (the default for an empty list), which cannot be used as indices.
    ii = torch.as_tensor(sources, dtype=torch.long)
    jj = torch.as_tensor(targets, dtype=torch.long)
    kk = torch.as_tensor(positions, dtype=torch.long)
    return ii, jj, kk
def keyframe_indicies(graph):
    """Return the graph's node keys, in iteration order, as a 1-D long tensor.

    The dtype is pinned to ``torch.long`` so that an empty graph yields an
    empty index tensor rather than the float32 tensor ``torch.as_tensor([])``
    would produce.
    """
    return torch.as_tensor(list(graph), dtype=torch.long)
def meshgrid(m, n, device='cuda'):
    """Return flattened row and column indices of an ``m x n`` grid.

    Args:
        m: number of rows.
        n: number of columns.
        device: device the two result tensors are moved to.

    Returns:
        Tuple ``(ii, jj)`` of 1-D tensors of length ``m * n`` in row-major
        order: ``ii`` holds the row index and ``jj`` the column index of
        every grid cell.
    """
    rows = torch.arange(m)
    cols = torch.arange(n)
    grid_rows, grid_cols = torch.meshgrid(rows, cols, indexing='ij')
    flat_rows = grid_rows.reshape(-1).to(device)
    flat_cols = grid_cols.reshape(-1).to(device)
    return flat_rows, flat_cols
def neighbourhood_graph(n, r):
    """Return all ordered index pairs ``(i, j)`` with ``1 <= |i - j| <= r``.

    Pairs are drawn from ``0 <= i, j < n`` via ``meshgrid(n, n)``, so the
    returned tensors live on that helper's default device.

    Returns:
        Tuple ``(ii, jj)`` of 1-D tensors holding the kept pairs.
    """
    src, dst = meshgrid(n, n)
    gap = torch.abs(src - dst)
    # Drop self-pairs (gap 0) and pairs further apart than the radius.
    within_radius = (gap >= 1) & (gap <= r)
    return src[within_radius], dst[within_radius]
def build_frame_graph(poses, disps, intrinsics, num=16, thresh=24.0, r=2):
    """Construct a frame graph between co-visible frames.

    Every frame is first linked to its temporal neighbours within ``r``
    frames. Further edges are then added greedily, smallest pairwise
    distance first, until the total edge count reaches ``num`` or the
    smallest remaining distance is no longer below ``thresh``.

    Args:
        poses: tensor whose second dimension is the number of frames; only
            the first element along the leading dimension is used.
        disps: tensor; the first element along the leading dimension is
            subsampled with ``[:, 3::8, 3::8]`` before use.
        intrinsics: tensor; the first element along the leading dimension is
            divided by 8.0 before use (matches the disparity subsampling
            stride).
        num: target total number of edges, neighbourhood edges included.
        thresh: upper bound (exclusive) on the distance of an added edge.
        r: temporal radius of the initial neighbourhood links.

    Returns:
        ``OrderedDict`` mapping each frame index to its list of neighbours.
    """
    N = poses.shape[1]

    pose_arr = poses[0].cpu().numpy()
    disp_arr = disps[0][:, 3::8, 3::8].cpu().numpy()
    intr_arr = intrinsics[0].cpu().numpy() / 8.0
    d = compute_distance_matrix_flow(pose_arr, disp_arr, intr_arr)

    graph = OrderedDict()
    count = 0

    # Seed the graph with temporal neighbours and mask those entries (and the
    # diagonal) with inf so the greedy pass below never selects them again.
    for i in range(N):
        d[i, i] = np.inf
        neighbours = [
            j for j in range(i - r, i + r + 1)
            if 0 <= j < N and j != i
        ]
        for j in neighbours:
            d[i, j] = np.inf
        graph[i] = neighbours
        count += len(neighbours)

    # Greedily add the closest remaining pair until the budget is spent.
    while count < num:
        i, j = divmod(np.argmin(d), N)
        if not d[i, j] < thresh:
            break
        graph[i].append(j)
        d[i, j] = np.inf
        count += 1

    return graph
def build_frame_graph_v2(poses, disps, intrinsics, num=16, thresh=24.0, r=2):
    """Construct a frame graph between co-visible frames.

    Every frame is first linked to its temporal neighbours within ``r``
    frames. Further edges are then added greedily, smallest pairwise
    distance first, for as long as the smallest remaining distance is below
    ``thresh``. After each added edge ``(i, j)`` the surrounding 3x3 block of
    the distance matrix is masked, so adjacent pairs are not selected too.

    Unlike ``build_frame_graph``, the inputs are handed to
    ``compute_distance_matrix_flow2`` unmodified, and ``num`` is not used to
    limit the number of edges; it is accepted but never read.

    Args:
        poses: object whose ``shape[1]`` is the number of frames.
        disps: forwarded to ``compute_distance_matrix_flow2``.
        intrinsics: forwarded to ``compute_distance_matrix_flow2``.
        num: unused.
        thresh: upper bound (exclusive) on the distance of an added edge.
        r: temporal radius of the initial neighbourhood links.

    Returns:
        ``OrderedDict`` mapping each frame index to its list of neighbours.
    """
    N = poses.shape[1]
    d = compute_distance_matrix_flow2(poses, disps, intrinsics)

    graph = OrderedDict()

    # Seed the graph with temporal neighbours and mask those entries (and the
    # diagonal) with inf so the greedy pass below never selects them again.
    for i in range(N):
        d[i, i] = np.inf
        neighbours = [
            j for j in range(i - r, i + r + 1)
            if 0 <= j < N and j != i
        ]
        for j in neighbours:
            d[i, j] = np.inf
        graph[i] = neighbours

    while True:
        i, j = divmod(np.argmin(d), N)
        if not d[i, j] < thresh:
            break
        graph[i].append(j)

        # Mask the in-bounds 3x3 block centred on (i, j); this always
        # includes (i, j) itself, which guarantees the loop makes progress.
        rows = [a for a in (i - 1, i, i + 1) if 0 <= a < N]
        cols = [b for b in (j - 1, j, j + 1) if 0 <= b < N]
        for a in rows:
            for b in cols:
                d[a, b] = np.inf

    return graph