Spaces:
Running
Running
from collections import OrderedDict | |
import numpy as np | |
import torch | |
from lietorch import SO3, SE3, Sim3 | |
from .graph_utils import graph_to_edge_list | |
from .projective_ops import projective_transform | |
def pose_metrics(dE):
    """Split a Sim3 error transform into rotation, translation and scale errors.

    Args:
        dE: Sim3 error; its ``.data`` is split along the last axis into
            translation (3), quaternion (4) and scale (1) components.

    Returns:
        ``(r_err, t_err, s_err)``:
          * ``r_err`` - rotation angle of the quaternion part, in degrees;
          * ``t_err`` - Euclidean norm of the translation part;
          * ``s_err`` - ``|s - 1|``; keeps the trailing size-1 axis produced
            by the split (it is not reduced).
    """
    trans, quat, scale = dE.data.split([3, 4, 1], -1)
    # SO3.log() yields an axis-angle vector whose norm is the angle in radians.
    rot_deg = SO3(quat).log().norm(dim=-1) * (180 / np.pi)
    return rot_deg, trans.norm(dim=-1), (scale - 1.0).abs()
def fit_scale(Ps, Gs):
    """Least-squares scale that best maps the translations of ``Gs`` onto ``Ps``.

    For every batch element this solves ``min_s ||t_P - s * t_G||^2``, where
    ``t_P`` / ``t_G`` stack the first three ``.data`` entries of every pose in
    that batch element. Both inputs are detached, so no gradient flows through
    the returned scale.

    Returns:
        Tensor of shape ``(batch,)`` with one scale factor per batch element.
    """
    batch = Ps.shape[0]
    # Flatten all translations of one batch element into a single vector.
    trans_p = Ps.data[..., :3].detach().reshape(batch, -1)
    trans_g = Gs.data[..., :3].detach().reshape(batch, -1)
    numer = torch.sum(trans_p * trans_g, dim=-1)
    # Small epsilon keeps the division finite when Gs has zero translation.
    denom = torch.sum(trans_g * trans_g, dim=-1) + 1e-8
    return numer / denom
def geodesic_loss(Ps, Gs, graph, gamma=0.9, do_scale=True):
    """Weighted geodesic loss between reference and estimated relative poses.

    For every edge ``(ii, jj)`` of ``graph`` the relative pose
    ``P[jj] * P[ii]^-1`` of the reference ``Ps`` is compared with the same
    relative pose of each estimate ``Gs[i]``. Estimate ``i`` of ``n`` is
    weighted by ``gamma ** (n - i - 1)``, so the last one has weight 1.

    Args:
        Ps: reference poses (lietorch group), indexed as ``Ps[:, idx]``.
        Gs: non-empty sequence of pose estimates of the same group type.
        graph: frame graph; converted to edges with ``graph_to_edge_list``.
        gamma: decay factor applied to earlier estimates.
        do_scale: if True, rescale each estimated relative pose by the
            least-squares scale from ``fit_scale`` before comparing.

    Returns:
        ``(loss, metrics)``. ``metrics`` is computed from the LAST estimate
        in ``Gs`` only.

    Raises:
        ValueError: if ``Gs`` is empty.
        TypeError: if the relative poses are neither SE3 nor Sim3.
    """
    n = len(Gs)
    if n == 0:
        raise ValueError("geodesic_loss requires at least one pose estimate in Gs")

    # relative pose along each edge of the reference trajectory
    ii, jj, _ = graph_to_edge_list(graph)
    dP = Ps[:, jj] * Ps[:, ii].inv()
    dP_inv = dP.inv()  # loop-invariant; previously recomputed twice per iteration

    loss = 0.0
    for i in range(n):
        w = gamma ** (n - i - 1)
        dG = Gs[i][:, jj] * Gs[i][:, ii].inv()

        if do_scale:
            s = fit_scale(dP, dG)
            dG = dG.scale(s[:, None])

        # pose error, mapped to the tangent space with log()
        err = dG * dP_inv
        d = err.log()

        if isinstance(dG, SE3):
            tau, phi = d.split([3, 3], dim=-1)
            loss += w * (
                tau.norm(dim=-1).mean() +
                phi.norm(dim=-1).mean())
        elif isinstance(dG, Sim3):
            tau, phi, sig = d.split([3, 3, 1], dim=-1)
            loss += w * (
                tau.norm(dim=-1).mean() +
                phi.norm(dim=-1).mean() +
                0.05 * sig.norm(dim=-1).mean())
        else:
            # Previously this case silently contributed nothing to the loss.
            raise TypeError(
                "geodesic_loss supports SE3 or Sim3 poses, got %s" % type(dG).__name__)

    # Metrics for the final estimate only; detached so they stay out of autograd.
    dE = Sim3(err).detach()
    r_err, t_err, _ = pose_metrics(dE)

    # NOTE(review): despite their names, 'bad_rot' / 'bad_tr' are the fraction of
    # edges whose error is BELOW the threshold (0.1 deg / 0.01). Keys and values
    # are kept unchanged for compatibility with existing logs; confirm intent.
    metrics = {
        'rot_error': r_err.mean().item(),
        'tr_error': t_err.mean().item(),
        'bad_rot': (r_err < .1).float().mean().item(),
        'bad_tr': (t_err < .01).float().mean().item(),
    }

    return loss, metrics
def residual_loss(residuals, gamma=0.9):
    """Weighted mean-absolute loss over a sequence of system residuals.

    Residual ``i`` of ``n`` contributes ``gamma ** (n - i - 1) * mean(|r_i|)``,
    so the last residual has weight 1.

    Args:
        residuals: non-empty sequence of tensors.
        gamma: decay factor applied to earlier residuals.

    Returns:
        ``(loss, {'residual': loss.item()})``.

    Raises:
        ValueError: if ``residuals`` is empty (otherwise the loss would stay a
            plain float and ``.item()`` would fail with an AttributeError).
    """
    n = len(residuals)
    if n == 0:
        raise ValueError("residual_loss requires at least one residual tensor")

    loss = 0.0
    for i in range(n):
        w = gamma ** (n - i - 1)
        loss += w * residuals[i].abs().mean()

    return loss, {'residual': loss.item()}
def flow_loss(Ps, disps, poses_est, disps_est, intrinsics, graph, gamma=0.9):
    """Weighted end-point-error loss between induced optical flows.

    The flow induced by ``(Ps, disps)`` is compared with the flow induced by
    each estimate ``(poses_est[i], disps_est[i])``. Both flows are obtained
    with ``projective_transform`` over the same edge set. Estimate ``i`` of
    ``n`` is weighted by ``gamma ** (n - i - 1)``.

    NOTE(review): the ``graph`` argument is ignored. The edges are always
    rebuilt so that each frame is linked to its immediate neighbours
    (``|i - j| == 1``); presumably the parameter is kept for call-site
    compatibility, so confirm before relying on it.

    Returns:
        ``(loss, metrics)``. ``metrics`` uses the LAST estimate only,
        restricted to pixels whose combined validity weight exceeds 0.5. If no
        pixel qualifies, the metric means are NaN (mean of an empty tensor).

    Raises:
        ValueError: if ``poses_est`` is empty.
    """
    n = len(poses_est)
    if n == 0:
        # Fail before the (expensive) projection instead of on an unbound `epe`.
        raise ValueError("flow_loss requires at least one pose estimate in poses_est")

    # Chain graph over the N frames: every frame linked to its direct neighbours.
    num_frames = Ps.shape[1]
    graph = OrderedDict(
        (i, [j for j in range(num_frames) if abs(i - j) == 1])
        for i in range(num_frames)
    )

    ii, jj, _ = graph_to_edge_list(graph)
    coords0, val0 = projective_transform(Ps, disps, intrinsics, ii, jj)
    # Zero the weight wherever the source-frame disparity is not positive.
    val0 = val0 * (disps[:, ii] > 0).float().unsqueeze(dim=-1)

    loss = 0.0
    for i in range(n):
        w = gamma ** (n - i - 1)
        coords1, val1 = projective_transform(poses_est[i], disps_est[i], intrinsics, ii, jj)

        # val0/val1 appear to be per-pixel validity weights -- TODO confirm
        # against projective_transform.
        v = (val0 * val1).squeeze(dim=-1)
        epe = v * (coords1 - coords0).norm(dim=-1)
        loss += w * epe.mean()

    # Metrics: final estimate, restricted to pixels with weight above 0.5.
    epe = epe.reshape(-1)[v.reshape(-1) > 0.5]
    metrics = {
        'f_error': epe.mean().item(),
        '1px': (epe < 1.0).float().mean().item(),
    }

    return loss, metrics