"""Pose, residual and optical-flow losses together with summary metrics."""

from collections import OrderedDict

import numpy as np
import torch
from lietorch import SO3, SE3, Sim3

from .graph_utils import graph_to_edge_list
from .projective_ops import projective_transform


def pose_metrics(dE):
    """Translation/rotation/scaling metrics from a Sim3 pose error.

    Args:
        dE: pose error whose ``data`` splits along the last dimension into
            3 translation, 4 rotation (fed to ``SO3``) and 1 scale values.

    Returns:
        Tuple ``(r_err, t_err, s_err)``: rotation error in degrees,
        translation norm, and absolute deviation of the scale from 1.
    """
    t, q, s = dE.data.split([3, 4, 1], -1)
    ang = SO3(q).log().norm(dim=-1)

    # convert radians to degrees
    r_err = (180 / np.pi) * ang
    t_err = t.norm(dim=-1)
    s_err = (s - 1.0).abs()
    return r_err, t_err, s_err


def fit_scale(Ps, Gs):
    """Fit one scale factor per batch element between two sets of poses.

    Uses the first three components of each pose's ``data`` (the part that
    ``pose_metrics`` treats as translation), flattened per batch element,
    and returns the least-squares ``s`` minimising ``|t1 - s * t2|``.
    Both inputs are detached, so no gradient flows through the scale.

    Args:
        Ps: poses providing the target translations ``t1``.
        Gs: poses providing the translations ``t2`` to be rescaled.

    Returns:
        Tensor with one scale value per leading (batch) index of ``Ps``.
    """
    b = Ps.shape[0]
    t1 = Ps.data[..., :3].detach().reshape(b, -1)
    t2 = Gs.data[..., :3].detach().reshape(b, -1)

    # closed-form least squares; the epsilon guards against all-zero t2
    s = (t1 * t2).sum(-1) / ((t2 * t2).sum(-1) + 1e-8)
    return s


def geodesic_loss(Ps, Gs, graph, gamma=0.9, do_scale=True):
    """Loss function for training network.

    Compares relative poses along the edges of ``graph`` between the
    reference poses ``Ps`` and each entry of ``Gs``.

    Args:
        Ps: reference poses, indexed as ``Ps[:, idx]``.
        Gs: sequence of pose estimates, each indexed like ``Ps``
            (presumably one entry per refinement iteration; confirm with
            caller). Must be non-empty, because the metrics are taken from
            the last entry.
        graph: graph passed to ``graph_to_edge_list`` to obtain the edges.
        gamma: weighting base; entry ``i`` of ``n`` gets the weight
            ``gamma ** (n - i - 1)``, so the last entry has weight 1.
        do_scale: when true, rescale each estimated relative pose with
            ``fit_scale`` before measuring the error.

    Returns:
        Tuple ``(loss, metrics)``. ``metrics`` is computed from the last
        entry of ``Gs`` only. NOTE(review): ``'bad_rot'`` / ``'bad_tr'``
        hold the fraction of errors *below* 0.1 degrees / 0.01, despite
        the key names.
    """
    # relative pose
    ii, jj, _ = graph_to_edge_list(graph)
    dP = Ps[:, jj] * Ps[:, ii].inv()

    n = len(Gs)
    total = 0.0

    for i in range(n):
        w = gamma ** (n - i - 1)
        dG = Gs[i][:, jj] * Gs[i][:, ii].inv()

        if do_scale:
            s = fit_scale(dP, dG)
            dG = dG.scale(s[:, None])

        # pose error
        d = (dG * dP.inv()).log()

        if isinstance(dG, SE3):
            tau, phi = d.split([3, 3], dim=-1)
            total += w * (
                tau.norm(dim=-1).mean() +
                phi.norm(dim=-1).mean())

        elif isinstance(dG, Sim3):
            tau, phi, sig = d.split([3, 3, 1], dim=-1)
            total += w * (
                tau.norm(dim=-1).mean() +
                phi.norm(dim=-1).mean() +
                0.05 * sig.norm(dim=-1).mean())

        # recomputed every pass; only the last pass feeds the metrics below
        dE = Sim3(dG * dP.inv()).detach()
        r_err, t_err, _ = pose_metrics(dE)

    metrics = {
        'rot_error': r_err.mean().item(),
        'tr_error': t_err.mean().item(),
        'bad_rot': (r_err < .1).float().mean().item(),
        'bad_tr': (t_err < .01).float().mean().item(),
    }

    return total, metrics


def residual_loss(residuals, gamma=0.9):
    """Loss on system residuals.

    Args:
        residuals: non-empty sequence of residual tensors; entry ``i`` of
            ``n`` is weighted by ``gamma ** (n - i - 1)``.
        gamma: weighting base.

    Returns:
        Tuple ``(loss, metrics)`` where ``metrics['residual']`` is the loss
        as a Python number.
    """
    total = 0.0
    n = len(residuals)

    for i in range(n):
        w = gamma ** (n - i - 1)
        total += w * residuals[i].abs().mean()

    return total, {'residual': total.item()}


def flow_loss(Ps, disps, poses_est, disps_est, intrinsics, graph, gamma=0.9):
    """Optical flow loss.

    Projects points with the reference poses/disparities and with each
    estimate, and penalises the distance between the projected coordinates
    wherever both projections are flagged valid.

    Args:
        Ps: reference poses; ``Ps.shape[1]`` is used as the frame count.
        disps: reference disparities, indexed as ``disps[:, ii]``.
        poses_est: sequence of estimated poses. Must be non-empty, because
            the metrics are taken from the last entry.
        disps_est: sequence of estimated disparities, parallel to
            ``poses_est``.
        intrinsics: intrinsics forwarded to ``projective_transform``.
        graph: accepted for interface compatibility but ignored; it is
            replaced by a graph linking each frame to its adjacent frames.
        gamma: weighting base; entry ``i`` of ``n`` gets the weight
            ``gamma ** (n - i - 1)``.

    Returns:
        Tuple ``(loss, metrics)``. ``metrics`` comes from the last estimate
        only: the mean end-point error over valid entries and the fraction
        of those with an error below 1.0.
    """
    N = Ps.shape[1]

    # the incoming graph is discarded: connect every frame to its neighbours
    graph = OrderedDict()
    for i in range(N):
        graph[i] = [j for j in range(N) if abs(i - j) == 1]

    ii, jj, _ = graph_to_edge_list(graph)
    coords0, val0 = projective_transform(Ps, disps, intrinsics, ii, jj)
    # additionally invalidate entries whose reference disparity is not positive
    val0 = val0 * (disps[:, ii] > 0).float().unsqueeze(dim=-1)

    n = len(poses_est)
    total = 0.0

    for i in range(n):
        w = gamma ** (n - i - 1)
        coords1, val1 = projective_transform(
            poses_est[i], disps_est[i], intrinsics, ii, jj)

        v = (val0 * val1).squeeze(dim=-1)
        epe = v * (coords1 - coords0).norm(dim=-1)
        total += w * epe.mean()

        # keep only valid entries; the last pass feeds the metrics below
        epe = epe.reshape(-1)[v.reshape(-1) > 0.5]

    metrics = {
        'f_error': epe.mean().item(),
        '1px': (epe < 1.0).float().mean().item(),
    }

    return total, metrics