# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from mmdetection (https://github.com/open-mmlab/mmdetection)
# Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------
# Modified by Shihao Wang
# ------------------------------------------------------------------------
import math
import numpy as np
import torch


def pos2posemb3d(pos, num_pos_feats=128, temperature=10000):
    """Sine/cosine positional embedding of 3D points; output dim is 3 * num_pos_feats."""
    scale = 2 * math.pi
    pos = pos * scale
    dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device)
    dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats)
    pos_x = pos[..., 0, None] / dim_t
    pos_y = pos[..., 1, None] / dim_t
    pos_z = pos[..., 2, None] / dim_t
    pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)
    pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2)
    pos_z = torch.stack((pos_z[..., 0::2].sin(), pos_z[..., 1::2].cos()), dim=-1).flatten(-2)
    posemb = torch.cat((pos_y, pos_x, pos_z), dim=-1)
    return posemb
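
# Minimal usage sketch (illustrative shapes only, not from the original repo):
#   pts = torch.rand(900, 3)                    # e.g. normalized 3D reference points
#   emb = pos2posemb3d(pts, num_pos_feats=128)  # emb.shape == (900, 384)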


def pos2posemb1d(pos, num_pos_feats=256, temperature=10000):
    """Sine/cosine positional embedding of a 1D coordinate; output dim is num_pos_feats."""
    scale = 2 * math.pi
    pos = pos * scale
    dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device)
    dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats)
    pos_x = pos[..., 0, None] / dim_t
    pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)
    return pos_x
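
# Likewise for the 1D variant (illustrative shapes only):
#   t = torch.rand(900, 1)                      # e.g. one normalized scalar per query
#   emb1d = pos2posemb1d(t, num_pos_feats=256)  # emb1d.shape == (900, 256)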


def nerf_positional_encoding(
    tensor, num_encoding_functions=6, include_input=False, log_sampling=True
) -> torch.Tensor:
    r"""Apply NeRF-style positional encoding to the input.

    Args:
        tensor (torch.Tensor): Input tensor to be positionally encoded.
        num_encoding_functions (int, optional): Number of frequency bands used to
            compute the positional encoding (default: 6).
        include_input (bool, optional): Whether or not to prepend the raw input to the
            positional encoding (default: False).
        log_sampling (bool, optional): Whether to sample the frequency bands in log
            space rather than linearly (default: True).

    Returns:
        (torch.Tensor): Positional encoding of the input tensor.
    """
    # TESTED
    # Optionally keep the raw input as the first part of the encoding.
    encoding = [tensor] if include_input else []
    if log_sampling:
        frequency_bands = 2.0 ** torch.linspace(
            0.0,
            num_encoding_functions - 1,
            num_encoding_functions,
            dtype=tensor.dtype,
            device=tensor.device,
        )
    else:
        frequency_bands = torch.linspace(
            2.0 ** 0.0,
            2.0 ** (num_encoding_functions - 1),
            num_encoding_functions,
            dtype=tensor.dtype,
            device=tensor.device,
        )
    for freq in frequency_bands:
        for func in [torch.sin, torch.cos]:
            encoding.append(func(tensor * freq))
    # Special case, for no positional encoding
    if len(encoding) == 1:
        return encoding[0]
    else:
        return torch.cat(encoding, dim=-1)
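
# Usage sketch (illustrative shapes only): with the defaults above, a (..., D) input
# maps to a (..., 2 * num_encoding_functions * D) encoding.
#   xy = torch.rand(16, 2)
#   enc = nerf_positional_encoding(xy)  # enc.shape == (16, 24)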


def traj2nerf(traj):
    """Encode a trajectory point (x, y, angle) as [NeRF encoding of (x, y), cos(angle), sin(angle)]."""
    result = torch.cat(
        [
            nerf_positional_encoding(traj[..., :2]),
            torch.cos(traj[..., -1])[..., None],
            torch.sin(traj[..., -1])[..., None],
        ], dim=-1
    )
    return result
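
# Usage sketch (illustrative shapes only; assumes traj[..., :2] holds x/y and
# traj[..., -1] the heading angle):
#   traj = torch.rand(4096, 3)
#   nerf = traj2nerf(traj)  # nerf.shape == (4096, 26): 24 encoding dims + cos + sin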


def nerf2traj(nerf, num_encoding_functions=6, include_input=False, log_sampling=True):
    """Invert traj2nerf, recovering an (x, y, angle) trajectory point from its encoding."""
    # Length of the original 2D position tensor
    original_dim = 2
    # Length of the positional encoding of the 2D position tensor
    if include_input:
        encoding_length = original_dim * (2 * num_encoding_functions + 1)
    else:
        encoding_length = original_dim * 2 * num_encoding_functions
    # Extract the positional encoding of the original 2D position tensor
    positional_encoding = nerf[..., :encoding_length]
    # Reverse the positional encoding
    if include_input:
        # Clone to avoid in-place modification of the input below.
        original_position = positional_encoding[..., :original_dim].clone()
        positional_encoding = positional_encoding[..., original_dim:]
    else:
        original_position = torch.zeros(
            (*nerf.shape[:-1], original_dim), dtype=nerf.dtype, device=nerf.device
        )
    if log_sampling:
        frequency_bands = 2.0 ** torch.linspace(
            0.0,
            num_encoding_functions - 1,
            num_encoding_functions,
            dtype=nerf.dtype,
            device=nerf.device,
        )
    else:
        frequency_bands = torch.linspace(
            2.0 ** 0.0,
            2.0 ** (num_encoding_functions - 1),
            num_encoding_functions,
            dtype=nerf.dtype,
            device=nerf.device,
        )
    for i, freq in enumerate(frequency_bands):
        for j, func in enumerate([torch.sin, torch.cos]):
            original_position += func(positional_encoding[..., (2 * i + j)::2 * num_encoding_functions]) / freq
    # Extract the sine and cosine of the angle
    cos_angle = nerf[..., -2]
    sin_angle = nerf[..., -1]
    # Reconstruct the angle using atan2
    angle = torch.atan2(sin_angle, cos_angle)
    # Combine the original position and the angle to form the trajectory
    traj = torch.cat([original_position, angle[..., None]], dim=-1)
    return traj
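
# Round-trip sketch (illustrative shapes only): nerf2traj is intended to map the
# encoded representation back to (x, y, angle).
#   rec = nerf2traj(traj2nerf(torch.rand(4096, 3)))  # rec.shape == (4096, 3)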


if __name__ == '__main__':
    # Quick round-trip check of the trajectory encoding/decoding on a local .npy file.
    traj = torch.from_numpy(np.load('/mnt/f/e2e/navsim_ours/traj_final/test_4096_kmeans.npy'))
    nerf = traj2nerf(traj)
    traj_2 = nerf2traj(nerf)