YuWang0103's picture
Upload 41 files
6b59850 verified
raw
history blame
517 Bytes
import math
import torch
class SinusoidalPosEmb(torch.nn.Module):
    """Sinusoidal embedding of a batch of scalar positions / timesteps.

    Each input value ``t`` is multiplied by ``scale`` and projected onto
    ``dim // 2`` geometrically spaced frequencies running from 1 down to
    ``1 / max_period``; the output is ``[sin(...), cos(...)]`` concatenated
    along the last axis.

    Args:
        dim: Requested embedding size. Must be at least 4, because the
            frequency spacing divides by ``dim // 2 - 1``. For odd ``dim``
            the output has ``dim - 1`` columns (only ``dim // 2``
            frequencies are used).
        scale: Multiplier applied to the inputs before embedding. Defaults
            to 1000, the previously hard-coded value.
        max_period: Sets the lowest frequency, ``1 / max_period``. Defaults
            to 10000, the previously hard-coded value.

    Raises:
        ValueError: if ``dim < 4``.
    """

    def __init__(self, dim, scale=1000, max_period=10000):
        super().__init__()
        if dim // 2 < 2:
            # half_dim - 1 would be <= 0 in forward(): division by zero for
            # dim 2/3 and a meaningless empty embedding for dim 0/1.
            raise ValueError(f"SinusoidalPosEmb needs dim >= 4, got {dim}")
        self.dim = dim
        self.scale = scale
        self.max_period = max_period

    def forward(self, x):
        """Embed ``x``.

        Args:
            x: Tensor holding one value per batch element. Any shape that
                squeezes to at most one dimension is accepted, e.g. ``()``,
                ``(B,)``, ``(B, 1)`` or ``(1, B)``. A batch of one (which
                squeezes to a 0-d tensor) is treated as shape ``(1,)``.

        Returns:
            Tensor of shape ``(B, 2 * (dim // 2))`` on ``x``'s device. Its
            dtype is ``x``'s dtype, or the default float dtype when ``x`` is
            an integer tensor.

        Raises:
            ValueError: if ``x`` has more than one non-singleton dimension.
        """
        # squeeze() drops *every* size-1 axis, so a batch of one collapses to
        # a 0-d tensor; atleast_1d restores the batch axis.
        x = torch.atleast_1d(x.squeeze())
        if x.dim() != 1:
            raise ValueError(
                f"expected x to squeeze to a 1-D tensor, got shape {tuple(x.shape)}"
            )
        # An integer x would force the frequencies below to an integer dtype
        # (truncating them to 0/1) when they are cast to x.dtype.
        if not x.is_floating_point():
            x = x.to(torch.get_default_dtype())
        x = x * self.scale

        half_dim = self.dim // 2
        exponent = math.log(self.max_period) / (half_dim - 1)
        # Build the frequencies on x's device directly rather than on CPU
        # followed by a copy on every call.
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -exponent)
        freqs = freqs.to(x.dtype)
        args = x[:, None] * freqs[None, :]
        return torch.cat((args.sin(), args.cos()), dim=-1)