import torch | |
class VerletStandardizer:
    """Optionally re-parameterize 3-D trajectories as scaled Verlet residuals.

    By default (``use_verlet=False``) both ``transform_features`` and
    ``untransform_features`` return their input unchanged, which preserves the
    class's historical pass-through behavior.

    With ``use_verlet=True`` each future position ``x[t]`` is compared against
    the constant-velocity (Verlet) prediction ``2 * x[t-1] - x[t-2]``. The
    residual, multiplied by ``max_dist``, is the "action" handed to the model;
    ``untransform_features`` inverts this exactly.

    Shape assumptions (inferred from the reshapes below -- TODO confirm with
    callers): ``trajectory``/``actions`` are ``(batch, T * 3)`` or reshapeable
    to ``(batch, T, 3)``; ``history`` is ``(batch, H, 3)`` with ``H >= 2``.
    Only the last two history frames are used as the Verlet seed.
    """

    # Class-level default so instances created without __init__ (e.g. unpickled
    # objects saved before this flag existed) keep the pass-through behavior.
    use_verlet = False

    def __init__(self, max_dist=50, *, use_verlet=False):
        """Store configuration.

        Args:
            max_dist: Scale factor for Verlet residuals (multiplied in
                ``transform_features``, divided out in
                ``untransform_features``). Only used when ``use_verlet``.
            use_verlet: Enable the Verlet parameterization. Defaults to
                ``False``, i.e. both transforms are identity.
        """
        super().__init__()
        self.max_dist = max_dist
        self.use_verlet = use_verlet

    def transform_features(self, trajectory, history):
        """Map absolute future positions to scaled Verlet residuals.

        Args:
            trajectory: Future positions, reshapeable to ``(batch, T, 3)``.
            history: Observed positions, indexed as ``history[:, -2:]``.

        Returns:
            ``trajectory`` unchanged when ``use_verlet`` is off; otherwise a
            ``(batch, T * 3)`` tensor of residuals times ``max_dist``.
        """
        if not self.use_verlet:
            return trajectory
        batch = trajectory.shape[0]
        positions = trajectory.reshape(batch, -1, 3)
        # Prepend the two most recent observed frames so every future step has
        # the two predecessors the Verlet prediction needs.
        full = torch.cat([history[:, -2:], positions], dim=1)
        # x[t] - (2 * x[t-1] - x[t-2]): deviation from constant velocity.
        residuals = full[:, 2:] - 2 * full[:, 1:-1] + full[:, :-2]
        return (residuals * self.max_dist).flatten(start_dim=1)

    def untransform_features(self, actions, history):
        """Invert ``transform_features``: rebuild positions from residuals.

        Args:
            actions: Scaled residuals, reshapeable to ``(batch, T, 3)``.
            history: Observed positions, indexed as ``history[:, -2:]``.

        Returns:
            ``actions`` unchanged when ``use_verlet`` is off; otherwise the
            reconstructed ``(batch, T * 3)`` future positions.
        """
        if not self.use_verlet:
            return actions
        batch = actions.shape[0]
        residuals = actions.reshape(batch, -1, 3) / self.max_dist
        last = history[:, -1]
        # Verlet step x[t] = 2*x[t-1] - x[t-2] + a[t] is equivalent to
        # v[t] = v[t-1] + a[t] and x[t] = x[t-1] + v[t], so both reduce to
        # cumulative sums seeded by the last observed velocity and position.
        velocity0 = last - history[:, -2]
        velocities = velocity0.unsqueeze(1) + torch.cumsum(residuals, dim=1)
        positions = last.unsqueeze(1) + torch.cumsum(velocities, dim=1)
        return positions.flatten(start_dim=1)