# lkllkl's picture
# Upload folder using huggingface_hub
# da2e2ac verified
import torch
class VerletStandardizer:
    """Optionally re-parameterize trajectories as scaled Verlet residuals.

    By default (``use_verlet=False``) both transforms return their input
    unchanged, preserving this class's historical identity behavior.

    With ``use_verlet=True``, each future state ``x_t`` is encoded as the
    residual of a constant-velocity (Verlet) prediction from the two states
    before it, ``x_t - (2 * x_{t-1} - x_{t-2})``. The residual is multiplied
    by ``max_dist``. ``untransform_features`` inverts this exactly.

    Args:
        max_dist: Scale factor applied to the Verlet residuals. Used only
            when ``use_verlet`` is true. The default of 50 is an untuned
            ("magic") constant carried over from the original code.
        use_verlet: Enable the Verlet parameterization. Defaults to False so
            existing callers keep the identity behavior.
    """

    def __init__(self, max_dist=50, use_verlet=False):
        super().__init__()
        self.max_dist = max_dist  # magic constant; see class docstring
        self.use_verlet = use_verlet

    def transform_features(self, trajectory, history):
        """Map a flat trajectory to model features.

        Args:
            trajectory: Tensor reshapeable to ``(batch, T, 3)``.
            history: Tensor of shape ``(batch, H, 3)`` with ``H >= 2``.
                Only its last two steps are read, to seed the prediction.
                Presumably on the same device/dtype as ``trajectory`` —
                confirm against callers.

        Returns:
            ``trajectory`` itself when Verlet mode is off. Otherwise a
            ``(batch, 3 * T)`` tensor of residuals scaled by ``max_dist``.
        """
        if not self.use_verlet:
            return trajectory

        batch = trajectory.shape[0]
        steps = trajectory.reshape(batch, -1, 3)
        # Only the last two history states are needed to predict the first
        # future state. Earlier history would yield extra residuals that
        # untransform_features could not map back onto the trajectory.
        full = torch.cat([history[:, -2:], steps], dim=1)
        # x_t - (2 x_{t-1} - x_{t-2}) is the second difference of the sequence.
        residuals = full[:, 2:] - 2 * full[:, 1:-1] + full[:, :-2]
        return (residuals * self.max_dist).reshape(batch, -1)

    def untransform_features(self, actions, history):
        """Invert ``transform_features``, recovering the flat trajectory.

        Args:
            actions: Tensor reshapeable to ``(batch, T, 3)``.
            history: The same ``(batch, H, 3)`` history that was passed to
                ``transform_features``. Only its last two steps are read.

        Returns:
            ``actions`` itself when Verlet mode is off. Otherwise the
            reconstructed trajectory with shape ``(batch, 3 * T)``.
        """
        if not self.use_verlet:
            return actions

        batch = actions.shape[0]
        residuals = actions.reshape(batch, -1, 3) / self.max_dist
        prev, curr = history[:, -2], history[:, -1]
        states = []
        # Each step depends on the two before it, so this is a sequential scan.
        for t in range(residuals.shape[1]):
            prev, curr = curr, 2 * curr - prev + residuals[:, t]
            states.append(curr)
        return torch.stack(states, dim=1).reshape(batch, -1)