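# Page-scrape residue kept as a comment (not part of the module):
#   author avatar: ThunderVVV; commit message: "add thirdparty"; commit: b7eedf7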
from collections import OrderedDict
import numpy as np
import torch
from lietorch import SO3, SE3, Sim3
from .graph_utils import graph_to_edge_list
from .projective_ops import projective_transform
def pose_metrics(dE):
    """Rotation / translation / scale error magnitudes of a Sim3 element.

    Args:
        dE: object whose ``.data`` has a trailing dimension that splits as
            3 + 4 + 1 (translation, rotation components fed to ``SO3``,
            scale).

    Returns:
        Tuple ``(r_err, t_err, s_err)``:
        rotation angle in degrees (norm of the SO3 log map),
        norm of the translation part, and ``|scale - 1|``.
    """
    trans, quat, scale = dE.data.split([3, 4, 1], -1)

    # Angle of the rotation in radians, then expressed in degrees.
    angle_rad = SO3(quat).log().norm(dim=-1)
    rot_deg = (180 / np.pi) * angle_rad

    trans_norm = trans.norm(dim=-1)
    scale_dev = (scale - 1.0).abs()
    return rot_deg, trans_norm, scale_dev
def fit_scale(Ps, Gs):
    """Per-batch least-squares scale ``s`` such that ``t(Ps) ~ s * t(Gs)``.

    Only the first three components of ``.data`` (the translation part) are
    used; they are detached, so no gradient flows through the fitted scale.

    Args:
        Ps: reference poses; ``Ps.shape[0]`` is the batch size.
        Gs: poses whose translations are to be rescaled.

    Returns:
        Tensor with one scale value per batch element.
    """
    batch = Ps.shape[0]

    # Flatten all translations of one batch element into a single vector.
    ref_trans = Ps.data[..., :3].detach().reshape(batch, -1)
    est_trans = Gs.data[..., :3].detach().reshape(batch, -1)

    numer = (ref_trans * est_trans).sum(-1)
    # Small epsilon guards against division by zero for all-zero translations.
    denom = (est_trans * est_trans).sum(-1) + 1e-8
    return numer / denom
def geodesic_loss(Ps, Gs, graph, gamma=0.9, do_scale=True):
    """Weighted pose loss over a sequence of pose estimates.

    For every edge of ``graph`` the relative pose from ``Ps`` is compared with
    the relative pose from each entry of ``Gs``; the log map of the
    discrepancy is penalised. Entry ``i`` of ``Gs`` is weighted by
    ``gamma ** (len(Gs) - i - 1)``, so the last entry has weight 1.

    Args:
        Ps: reference poses, indexed as ``Ps[:, idx]``.
        Gs: sequence of pose estimates, each indexed like ``Ps``.
        graph: graph passed to ``graph_to_edge_list`` to obtain edge indices.
        gamma: base of the per-entry weight.
        do_scale: if true, rescale each estimated relative pose with
            ``fit_scale`` before computing the error.

    Returns:
        Tuple ``(loss, metrics)``. ``metrics`` is computed from the last
        entry of ``Gs`` only: mean rotation error (degrees) and translation
        error, plus ``'bad_rot'`` / ``'bad_tr'``, which hold the fraction of
        edges with rotation error below 0.1 and translation error below 0.01.
    """
    ii, jj, _ = graph_to_edge_list(graph)

    # Relative pose of each edge according to the reference poses.
    dP = Ps[:, jj] * Ps[:, ii].inv()

    num_estimates = len(Gs)
    weights = [gamma ** (num_estimates - k - 1) for k in range(num_estimates)]

    total = 0.0
    for k, weight in enumerate(weights):
        estimate = Gs[k]
        dG = estimate[:, jj] * estimate[:, ii].inv()

        if do_scale:
            dG = dG.scale(fit_scale(dP, dG)[:, None])

        # Pose error in the tangent space.
        err = (dG * dP.inv()).log()

        if isinstance(dG, SE3):
            tau, phi = err.split([3, 3], dim=-1)
            total += weight * (
                tau.norm(dim=-1).mean() +
                phi.norm(dim=-1).mean())

        elif isinstance(dG, Sim3):
            tau, phi, sig = err.split([3, 3, 1], dim=-1)
            total += weight * (
                tau.norm(dim=-1).mean() +
                phi.norm(dim=-1).mean() +
                0.05 * sig.norm(dim=-1).mean())

    # Metrics use the relative pose left over from the final loop iteration.
    dE = Sim3(dG * dP.inv()).detach()
    r_err, t_err, s_err = pose_metrics(dE)

    metrics = {
        'rot_error': r_err.mean().item(),
        'tr_error': t_err.mean().item(),
        'bad_rot': (r_err < .1).float().mean().item(),
        'bad_tr': (t_err < .01).float().mean().item(),
    }
    return total, metrics
def residual_loss(residuals, gamma=0.9):
    """Weighted mean-absolute loss on system residuals.

    Entry ``i`` of ``residuals`` contributes ``mean(|residuals[i]|)`` weighted
    by ``gamma ** (len(residuals) - i - 1)``, so the last entry has weight 1.

    Args:
        residuals: sequence of tensors.
        gamma: base of the per-entry weight.

    Returns:
        Tuple ``(loss, metrics)`` where ``metrics['residual']`` is the loss as
        a Python float. For an empty ``residuals`` the loss is the float
        ``0.0`` (previously this case raised ``AttributeError`` because the
        float accumulator has no ``.item()``).
    """
    n = len(residuals)
    if n == 0:
        # Nothing to accumulate: the loss is an empty sum.
        return 0.0, {'residual': 0.0}

    total = 0.0
    for i in range(n):
        weight = gamma ** (n - i - 1)
        total += weight * residuals[i].abs().mean()

    return total, {'residual': total.item()}
def flow_loss(Ps, disps, poses_est, disps_est, intrinsics, graph, gamma=0.9):
    """Optical-flow style loss between reference and estimated projections.

    Coordinates from ``projective_transform`` using the reference poses and
    disparities are compared against those from each estimate; the masked
    per-pixel distance is averaged. Estimate ``i`` is weighted by
    ``gamma ** (len(poses_est) - i - 1)``.

    NOTE: the ``graph`` argument is ignored. A graph linking every frame to
    its immediate neighbours (``|i - j| == 1``) is always built instead.

    Args:
        Ps: reference poses; ``Ps.shape[1]`` is taken as the number of frames.
        disps: reference disparities, indexed as ``disps[:, ii]``.
        poses_est: sequence of pose estimates.
        disps_est: sequence of disparity estimates, indexed like ``poses_est``.
        intrinsics: forwarded to ``projective_transform``.
        graph: unused (see note above).
        gamma: base of the per-estimate weight.

    Returns:
        Tuple ``(loss, metrics)``. ``metrics`` is computed from the last
        estimate only, over entries whose validity mask exceeds 0.5:
        ``'f_error'`` is the mean distance and ``'1px'`` the fraction of
        distances below 1.0.
    """
    num_frames = Ps.shape[1]

    # Neighbour graph: each frame is connected to the frames right before/after.
    neighbour_graph = OrderedDict(
        (i, [j for j in (i - 1, i + 1) if 0 <= j < num_frames])
        for i in range(num_frames)
    )
    ii, jj, _ = graph_to_edge_list(neighbour_graph)

    coords_ref, valid_ref = projective_transform(Ps, disps, intrinsics, ii, jj)
    # Additionally mask out entries with non-positive reference disparity.
    positive_disp = (disps[:, ii] > 0).float().unsqueeze(dim=-1)
    valid_ref = valid_ref * positive_disp

    num_estimates = len(poses_est)
    weights = [gamma ** (num_estimates - k - 1) for k in range(num_estimates)]

    total = 0.0
    for k, weight in enumerate(weights):
        coords_est, valid_est = projective_transform(
            poses_est[k], disps_est[k], intrinsics, ii, jj)

        v = (valid_ref * valid_est).squeeze(dim=-1)
        epe = v * (coords_est - coords_ref).norm(dim=-1)
        total += weight * epe.mean()

    # Metrics use the errors and mask left over from the final iteration.
    epe = epe.reshape(-1)[v.reshape(-1) > 0.5]

    metrics = {
        'f_error': epe.mean().item(),
        '1px': (epe<1.0).float().mean().item(),
    }
    return total, metrics