# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from mmdetection (https://github.com/open-mmlab/mmdetection)
# Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------
#  Modified by Shihao Wang
# ------------------------------------------------------------------------
import math

import numpy as np
import torch


def pos2posemb3d(pos, num_pos_feats=128, temperature=10000):
    scale = 2 * math.pi
    pos = pos * scale
    dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device)
    dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats)
    pos_x = pos[..., 0, None] / dim_t
    pos_y = pos[..., 1, None] / dim_t
    pos_z = pos[..., 2, None] / dim_t
    pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)
    pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2)
    pos_z = torch.stack((pos_z[..., 0::2].sin(), pos_z[..., 1::2].cos()), dim=-1).flatten(-2)
    posemb = torch.cat((pos_y, pos_x, pos_z), dim=-1)
    return posemb
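

# Illustrative usage of pos2posemb3d (the shapes below are assumptions for the
# example, not requirements of the function; any leading dimensions work as long
# as the last dimension holds x, y, z, typically normalized to [0, 1]):
#
#     ref_points = torch.rand(2, 900, 3)                 # (batch, num_query, xyz)
#     emb = pos2posemb3d(ref_points, num_pos_feats=128)
#     emb.shape                                          # torch.Size([2, 900, 384]) == 3 * num_pos_feats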


def pos2posemb1d(pos, num_pos_feats=256, temperature=10000):
    scale = 2 * math.pi
    pos = pos * scale
    dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device)
    dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats)
    pos_x = pos[..., 0, None] / dim_t
    pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)
    return pos_x
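

# Illustrative usage of pos2posemb1d (the shape is an assumption for the example;
# the last dimension only needs at least one entry, e.g. a normalized timestamp):
#
#     t = torch.rand(2, 900, 1)
#     emb = pos2posemb1d(t, num_pos_feats=256)
#     emb.shape                               # torch.Size([2, 900, 256])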


def nerf_positional_encoding(
        tensor, num_encoding_functions=6, include_input=False, log_sampling=True
) -> torch.Tensor:
    r"""Apply NeRF-style positional encoding to the input.

    Args:
        tensor (torch.Tensor): Input tensor to be positionally encoded.
        num_encoding_functions (int, optional): Number of frequency bands used to
            compute the positional encoding (default: 6).
        include_input (bool, optional): Whether or not to prepend the raw input to
            the positional encoding (default: False).
        log_sampling (bool, optional): If True, frequencies are sampled as powers of
            two; otherwise they are spaced linearly between 1 and
            2 ** (num_encoding_functions - 1) (default: True).

    Returns:
        (torch.Tensor): Positional encoding of the input tensor.
    """
    # The raw input tensor is only prepended when include_input is True.
    encoding = [tensor] if include_input else []
    if log_sampling:
        frequency_bands = 2.0 ** torch.linspace(
            0.0,
            num_encoding_functions - 1,
            num_encoding_functions,
            dtype=tensor.dtype,
            device=tensor.device,
        )
    else:
        frequency_bands = torch.linspace(
            2.0 ** 0.0,
            2.0 ** (num_encoding_functions - 1),
            num_encoding_functions,
            dtype=tensor.dtype,
            device=tensor.device,
        )

    for freq in frequency_bands:
        for func in [torch.sin, torch.cos]:
            encoding.append(func(tensor * freq))

    # Special case, for no positional encoding
    if len(encoding) == 1:
        return encoding[0]
    else:
        return torch.cat(encoding, dim=-1)
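

# Illustrative usage of nerf_positional_encoding (the input shape is an assumption
# for the example): with the defaults (6 frequency bands, include_input=False) a
# (..., D) input yields a (..., D * 2 * 6) output.
#
#     xy = torch.rand(100, 2)
#     enc = nerf_positional_encoding(xy)
#     enc.shape                               # torch.Size([100, 24])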


def traj2nerf(traj):
    result = torch.cat(
        [
            nerf_positional_encoding(traj[..., :2]),
            torch.cos(traj[..., -1])[..., None],
            torch.sin(traj[..., -1])[..., None],
        ], dim=-1
    )
    return result
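

# Illustrative usage of traj2nerf (the shape is an assumption for the example; the
# last dimension is expected to hold x, y and a heading angle in radians):
#
#     traj = torch.rand(4096, 3)
#     nerf = traj2nerf(traj)
#     nerf.shape                              # torch.Size([4096, 26]) == 2 * 2 * 6 + 2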


def nerf2traj(nerf, num_encoding_functions=6, include_input=False, log_sampling=True):
    # Dimensionality of the encoded position (x, y)
    original_dim = 2

    # Length of the positional encoding produced for the 2D position
    if include_input:
        encoding_length = original_dim * (2 * num_encoding_functions + 1)
    else:
        encoding_length = original_dim * 2 * num_encoding_functions

    # Extract the positional encoding of the 2D position from the full vector
    positional_encoding = nerf[..., :encoding_length]

    # Frequency bands used by nerf_positional_encoding; only the lowest band
    # (2 ** 0 == 1 in both sampling modes) is needed to invert the encoding
    if log_sampling:
        frequency_bands = 2.0 ** torch.linspace(
            0.0,
            num_encoding_functions - 1,
            num_encoding_functions,
            dtype=nerf.dtype,
            device=nerf.device,
        )
    else:
        frequency_bands = torch.linspace(
            2.0 ** 0.0,
            2.0 ** (num_encoding_functions - 1),
            num_encoding_functions,
            dtype=nerf.dtype,
            device=nerf.device,
        )

    if include_input:
        # The raw 2D position was stored in front of the sinusoidal terms
        original_position = positional_encoding[..., :original_dim]
    else:
        # nerf_positional_encoding stores, per frequency band, sin of all
        # coordinates followed by cos of all coordinates, so the first
        # 2 * original_dim entries are [sin(x * f0), sin(y * f0), cos(x * f0),
        # cos(y * f0)]. atan2 on each sin/cos pair recovers the coordinates;
        # this is exact only while |coordinate * f0| <= pi, otherwise the
        # result wraps modulo 2 * pi / f0.
        sin_part = positional_encoding[..., :original_dim]
        cos_part = positional_encoding[..., original_dim:2 * original_dim]
        original_position = torch.atan2(sin_part, cos_part) / frequency_bands[0]

    # Extract the sine and cosine of the angle
    cos_angle = nerf[..., -2]
    sin_angle = nerf[..., -1]

    # Reconstruct the angle using atan2
    angle = torch.atan2(sin_angle, cos_angle)

    # Combine the original position and the angle to form the trajectory
    traj = torch.cat([original_position, angle[..., None]], dim=-1)
    return traj


if __name__ == '__main__':
    traj = torch.from_numpy(np.load('/mnt/f/e2e/navsim_ours/traj_final/test_4096_kmeans.npy'))
    nerf = traj2nerf(traj)
    traj_2 = nerf2traj(nerf)
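
    # Hedged round-trip check: the atan2-based inverse in nerf2traj is exact only
    # when the encoded x/y values lie in (-pi, pi]; for metric coordinates outside
    # that range the recovered positions (and heading) wrap modulo 2 * pi.
    print('max round-trip error:', (traj - traj_2).abs().max().item())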