Spaces:
Runtime error
Runtime error
File size: 517 Bytes
6b59850 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
import math
import torch
class SinusoidalPosEmb(torch.nn.Module):
    """Map a batch of scalar values to sinusoidal embeddings.

    Each input value is multiplied by 1000 and then projected onto
    ``dim // 2`` geometrically spaced frequencies (from 1 down to 1/10000).
    The sines and cosines of those products are concatenated along the last
    axis.

    Args:
        dim: Requested embedding width. The produced width is
            ``2 * (dim // 2)``, so an odd ``dim`` yields ``dim - 1`` columns.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Return the embedding of ``x``.

        Args:
            x: Tensor holding one scalar per batch element. Singleton axes
                are squeezed away, so ``(B,)``, ``(B, 1)`` and a 0-d scalar
                are all accepted.

        Returns:
            Tensor of shape ``(B, 2 * (dim // 2))``: sines in the first half
            of the last axis, cosines in the second half.

        Raises:
            ValueError: If ``x`` has more than one non-singleton axis.
        """
        x = x.squeeze()
        if x.dim() == 0:
            # squeeze() collapses a batch of one to a 0-d tensor; restore the
            # batch axis so single-sample inputs are not rejected.
            x = x.reshape(1)
        if x.dim() != 1:
            raise ValueError(
                "expected one scalar per batch element, got shape "
                f"{tuple(x.shape)} after squeezing"
            )
        if not (x.is_floating_point() or x.is_complex()):
            # Integer inputs would otherwise force the frequency table to an
            # integer dtype below, truncating almost every frequency to 0.
            x = x.to(torch.get_default_dtype())
        x = x * 1000

        half_dim = self.dim // 2
        # max(..., 1) avoids dividing by zero when half_dim == 1; the single
        # frequency is exp(0) == 1 regardless of the step in that case.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        # Build the table directly on x's device to avoid a host->device copy.
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -log_step)
        freqs = freqs.type_as(x)

        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
|