Spaces:
Sleeping
Sleeping
File size: 502 Bytes
04f8e39 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
import torch
import torch.nn as nn
import math
class LearnableToken(nn.Module):
    """A trainable matrix of ``token_num`` tokens, each of size ``token_dim``.

    The module takes no input: calling it returns the parameter itself so
    that callers can use the learned tokens however they need.

    Args:
        token_num: Number of tokens (rows of the parameter).
        token_dim: Size of each token (columns of the parameter).
    """

    def __init__(self, token_num: int, token_dim: int) -> None:
        super().__init__()
        self.token_num = token_num
        self.token_dim = token_dim
        # torch.empty only allocates storage; its contents are arbitrary
        # until reset_parameters() fills them in.
        self.token = nn.Parameter(torch.empty(self.token_num, self.token_dim))
        # Initialize right away: leaving the parameter as raw memory can
        # expose NaN/Inf or huge values to the first forward/backward pass.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialize the tokens in place with Kaiming-uniform values.

        With ``a=sqrt(5)`` on a 2-D tensor this samples uniformly from
        ``[-1/sqrt(token_dim), 1/sqrt(token_dim)]``.
        """
        nn.init.kaiming_uniform_(self.token, a=math.sqrt(5))

    def forward(self) -> nn.Parameter:
        """Return the token parameter of shape ``(token_num, token_dim)``.

        The parameter object itself is returned, not a copy.
        """
        return self.token
|