import torch


class VerletStandardizer:
    """Optional Verlet re-parameterization of flattened trajectories.

    By default (``use_verlet=False``) both transforms are the identity, which
    preserves the behaviour this class has always had.  With
    ``use_verlet=True`` a trajectory is encoded as scaled Verlet residuals::

        a_t = x_{t+1} - (2 * x_t - x_{t-1})

    i.e. the deviation of each step from constant-velocity extrapolation, and
    ``untransform_features`` integrates those residuals back into positions.

    Args:
        max_dist: Multiplicative scale applied to the residuals in
            ``transform_features`` and divided out again in
            ``untransform_features``.  Unused when ``use_verlet`` is False.
        use_verlet: Enable the Verlet encoding.  Off by default so existing
            callers keep the identity behaviour.
    """

    def __init__(self, max_dist=50, *, use_verlet=False):
        super().__init__()
        self.max_dist = max_dist
        self.use_verlet = use_verlet

    @staticmethod
    def _seed_frames(history):
        """Return the last two history frames, which seed the Verlet recursion."""
        if history.dim() != 3 or history.shape[1] < 2:
            raise ValueError(
                "history must have shape (batch, steps >= 2, dim); "
                f"got {tuple(history.shape)}"
            )
        # Only the two most recent frames matter; using them (rather than the
        # first two) keeps encode and decode consistent for any history length.
        return history[:, -2:]

    def transform_features(self, trajectory, history):
        """Encode a trajectory as scaled Verlet residuals.

        Args:
            trajectory: Future positions, reshapeable to ``(batch, T, dim)``
                (typically passed flattened as ``(batch, T * dim)``).
            history: Past positions of shape ``(batch, H, dim)`` with
                ``H >= 2``; only the last two frames are used.  Ignored when
                ``use_verlet`` is False.

        Returns:
            ``trajectory`` itself when ``use_verlet`` is False; otherwise a
            ``(batch, T * dim)`` tensor of residuals multiplied by ``max_dist``.

        Raises:
            ValueError: if ``use_verlet`` is True and ``history`` is not 3-D
                with at least two frames.
        """
        if not self.use_verlet:
            return trajectory
        seed = self._seed_frames(history)
        traj = trajectory.reshape(trajectory.shape[0], -1, seed.shape[-1])
        full = torch.cat([seed, traj], dim=1)
        # Constant-velocity prediction of frame t+1: x_t + (x_t - x_{t-1}).
        predicted = 2 * full[:, 1:-1] - full[:, :-2]
        actions = (full[:, 2:] - predicted) * self.max_dist
        return actions.reshape(actions.shape[0], -1)

    def untransform_features(self, actions, history):
        """Decode scaled Verlet residuals back into positions.

        Args:
            actions: Residuals as produced by ``transform_features``,
                reshapeable to ``(batch, T, dim)``.
            history: Past positions of shape ``(batch, H, dim)`` with
                ``H >= 2``; only the last two frames are used.  Ignored when
                ``use_verlet`` is False.

        Returns:
            ``actions`` itself when ``use_verlet`` is False; otherwise the
            reconstructed ``(batch, T * dim)`` trajectory.

        Raises:
            ValueError: if ``use_verlet`` is True and ``history`` is not 3-D
                with at least two frames.
        """
        if not self.use_verlet:
            return actions
        seed = self._seed_frames(history)
        acc = actions.reshape(actions.shape[0], -1, seed.shape[-1]) / self.max_dist
        prev, last = seed[:, 0:1], seed[:, 1:2]
        # x_{t+1} = 2 x_t - x_{t-1} + a_t  <=>  v_{t+1} = v_t + a_t and
        # x_{t+1} = x_t + v_{t+1}; two cumulative sums replace a step loop.
        velocity = (last - prev) + torch.cumsum(acc, dim=1)
        trajectory = last + torch.cumsum(velocity, dim=1)
        return trajectory.reshape(trajectory.shape[0], -1)