lkllkl's picture
Upload folder using huggingface_hub
da2e2ac verified
raw
history blame
5.75 kB
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from mmdetection (https://github.com/open-mmlab/mmdetection)
# Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------
# Modified by Shihao Wang
# ------------------------------------------------------------------------
import math
import numpy as np
import torch
def pos2posemb3d(pos, num_pos_feats=128, temperature=10000):
    """Build a sinusoidal embedding from the first three entries of ``pos``' last axis.

    Args:
        pos (torch.Tensor): Tensor whose last dimension holds at least three
            coordinates, read as indices 0 (x), 1 (y) and 2 (z). The values
            are multiplied by ``2 * pi`` before encoding (presumably they are
            normalized coordinates -- confirm against callers).
        num_pos_feats (int): Number of embedding features produced per
            coordinate. Must be even so the sine and cosine halves pair up.
        temperature (float): Base of the geometric frequency progression.

    Returns:
        torch.Tensor: Embedding with the same leading dimensions as ``pos``
        and a last dimension of ``3 * num_pos_feats``, laid out as the y
        block, then the x block, then the z block.
    """
    scaled = pos * (2 * math.pi)
    feat_idx = torch.arange(num_pos_feats, dtype=torch.float32, device=scaled.device)
    # Consecutive feature pairs (2k, 2k + 1) share one divisor.
    divisors = temperature ** (2 * torch.div(feat_idx, 2, rounding_mode='floor') / num_pos_feats)

    def _embed_axis(axis):
        # Even feature slots take the sine, odd slots the cosine; stacking on
        # a new last axis and flattening interleaves them as sin, cos, sin, ...
        phase = scaled[..., axis, None] / divisors
        interleaved = torch.stack(
            (phase[..., 0::2].sin(), phase[..., 1::2].cos()), dim=-1
        )
        return interleaved.flatten(-2)

    emb_x = _embed_axis(0)
    emb_y = _embed_axis(1)
    emb_z = _embed_axis(2)
    # The y block deliberately precedes the x block in the output.
    return torch.cat((emb_y, emb_x, emb_z), dim=-1)
def pos2posemb1d(pos, num_pos_feats=256, temperature=10000):
    """Build a sinusoidal embedding from the first entry of ``pos``' last axis.

    Args:
        pos (torch.Tensor): Tensor whose last dimension holds at least one
            value; only index 0 of that dimension is encoded. The value is
            multiplied by ``2 * pi`` before encoding.
        num_pos_feats (int): Number of embedding features produced. Must be
            even so the sine and cosine halves pair up.
        temperature (float): Base of the geometric frequency progression.

    Returns:
        torch.Tensor: Embedding with the same leading dimensions as ``pos``
        and a last dimension of ``num_pos_feats``, interleaved as
        sin, cos, sin, cos, ...
    """
    two_pi = 2 * math.pi
    feat_idx = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device)
    # Consecutive feature pairs (2k, 2k + 1) share one divisor.
    exponent = 2 * torch.div(feat_idx, 2, rounding_mode='floor') / num_pos_feats
    divisors = temperature ** exponent

    phase = (pos * two_pi)[..., 0, None] / divisors
    sin_half = phase[..., 0::2].sin()
    cos_half = phase[..., 1::2].cos()
    # Stack on a new last axis, then flatten it away to interleave sin/cos.
    return torch.stack((sin_half, cos_half), dim=-1).flatten(-2)
def nerf_positional_encoding(
    tensor, num_encoding_functions=6, include_input=False, log_sampling=True
) -> torch.Tensor:
    r"""Apply a sine/cosine positional encoding to the input.

    For every frequency ``f`` the blocks ``sin(tensor * f)`` and
    ``cos(tensor * f)`` are appended along the last dimension, in that order,
    so the layout is ``[input?, sin(f0 x), cos(f0 x), sin(f1 x), cos(f1 x), ...]``
    where each block has the width of ``tensor``'s last dimension.

    Args:
        tensor (torch.Tensor): Input tensor to be positionally encoded.
        num_encoding_functions (optional, int): Number of frequency bands
            (default: 6).
        include_input (optional, bool): Whether to prepend the raw input to
            the encoding (default: False).
        log_sampling (optional, bool): If True, frequencies are
            ``2**0, 2**1, ..., 2**(num_encoding_functions - 1)``; otherwise
            they are linearly spaced between ``2**0`` and
            ``2**(num_encoding_functions - 1)`` (default: True).

    Returns:
        (torch.Tensor): Positional encoding of the input tensor. Its last
        dimension is ``D * (2 * num_encoding_functions + int(include_input))``
        where ``D`` is the last dimension of ``tensor``. With zero frequency
        bands the result is ``tensor`` itself when ``include_input`` is True
        and an empty ``(..., 0)`` tensor otherwise.
    """
    band_options = dict(dtype=tensor.dtype, device=tensor.device)
    if log_sampling:
        frequency_bands = 2.0 ** torch.linspace(
            0.0,
            num_encoding_functions - 1,
            num_encoding_functions,
            **band_options,
        )
    else:
        frequency_bands = torch.linspace(
            2.0 ** 0.0,
            2.0 ** (num_encoding_functions - 1),
            num_encoding_functions,
            **band_options,
        )

    # Trivially, the input tensor is added to the positional encoding.
    encoding = [tensor] if include_input else []
    for freq in frequency_bands:
        for func in (torch.sin, torch.cos):
            encoding.append(func(tensor * freq))

    if not encoding:
        # No bands and no raw input: torch.cat would raise on an empty list,
        # so return a zero-width encoding with matching leading dimensions.
        return tensor.new_empty((*tensor.shape[:-1], 0))
    # Special case, for no positional encoding
    if len(encoding) == 1:
        return encoding[0]
    return torch.cat(encoding, dim=-1)
def traj2nerf(traj):
    """Encode a trajectory tensor as positional encoding plus angle cos/sin.

    The first two entries of the last axis are passed through
    ``nerf_positional_encoding`` with its default settings; the last entry is
    treated as an angle and appended as ``cos(angle)`` followed by
    ``sin(angle)``.

    Args:
        traj (torch.Tensor): Tensor whose last dimension holds a 2D position
            in its first two entries and an angle in its last entry.

    Returns:
        torch.Tensor: Concatenation along the last dimension of the encoded
        position, the cosine of the angle and the sine of the angle.
    """
    position = traj[..., :2]
    angle = traj[..., -1]
    pieces = (
        nerf_positional_encoding(position),
        # Trailing unit axis lets the scalars concatenate with the encoding.
        torch.cos(angle)[..., None],
        torch.sin(angle)[..., None],
    )
    return torch.cat(pieces, dim=-1)
def nerf2traj(nerf, num_encoding_functions=6, include_input=False, log_sampling=True):
    """Recover ``(x, y, angle)`` from an encoding laid out like ``traj2nerf``'s output.

    Expected layout of the last dimension of ``nerf``:
    ``[xy (only if include_input), sin(f0*xy), cos(f0*xy), sin(f1*xy),
    cos(f1*xy), ..., cos(angle), sin(angle)]`` with two entries per
    ``xy`` block.

    Position recovery:
        * ``include_input=True``: the raw position stored at the front of the
          encoding is returned unchanged.
        * ``include_input=False``: the position is decoded from the first
          sine/cosine pair with ``atan2``. The first frequency is ``2**0 == 1``
          for both sampling schemes, so this is exact only for coordinates in
          ``(-pi, pi]``; values outside that range come back wrapped, because
          a sinusoidal encoding alone cannot represent them uniquely.
        * ``num_encoding_functions < 1`` without the raw input leaves nothing
          to decode, so the position is returned as zeros.

    The input tensor is never modified.

    Args:
        nerf (torch.Tensor): Encoded trajectory; the last two entries of the
            last dimension are ``cos(angle)`` and ``sin(angle)``.
        num_encoding_functions (int): Number of frequency bands used by the
            encoder.
        include_input (bool): Whether the encoder prepended the raw position.
        log_sampling (bool): Kept for interface compatibility with the
            encoder; it does not affect decoding because the first band has
            frequency 1 under both schemes.

    Returns:
        torch.Tensor: Tensor with the leading dimensions of ``nerf`` and a
        last dimension of 3 holding ``(x, y, angle)``; ``angle`` lies in
        ``[-pi, pi]``.
    """
    # Width of the encoded 2D position.
    original_dim = 2

    if include_input:
        # The exact position was stored verbatim ahead of the sinusoids.
        original_position = nerf[..., :original_dim]
    elif num_encoding_functions < 1:
        original_position = torch.zeros(
            (*nerf.shape[:-1], original_dim), dtype=nerf.dtype, device=nerf.device
        )
    else:
        # First band (frequency 1): sin block, then cos block.
        sin_xy = nerf[..., :original_dim]
        cos_xy = nerf[..., original_dim:2 * original_dim]
        original_position = torch.atan2(sin_xy, cos_xy)

    # Extract the cosine and sine of the angle and rebuild it with atan2.
    cos_angle = nerf[..., -2]
    sin_angle = nerf[..., -1]
    angle = torch.atan2(sin_angle, cos_angle)

    # torch.cat allocates a new tensor, so the result never aliases ``nerf``.
    return torch.cat([original_position, angle[..., None]], dim=-1)
if __name__ == '__main__':
    # Manual smoke run: load an array from disk, encode it, then decode it.
    traj = np.load('/mnt/f/e2e/navsim_ours/traj_final/test_4096_kmeans.npy')
    traj = torch.from_numpy(traj)
    nerf = traj2nerf(traj)
    traj_2 = nerf2traj(nerf)