File size: 502 Bytes
04f8e39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch
import torch.nn as nn
import math

class LearnableToken(nn.Module):
    """A bank of learnable token embeddings exposed as a single parameter.

    The module owns one ``nn.Parameter`` of shape ``(token_num, token_dim)``
    and returns it from ``forward``; ``forward`` takes no inputs.

    Args:
        token_num: Number of tokens (rows of the parameter).
        token_dim: Dimensionality of each token (columns of the parameter).
    """

    def __init__(self, token_num, token_dim):
        super().__init__()
        self.token_num = token_num
        self.token_dim = token_dim
        # torch.empty (like the legacy torch.Tensor(*sizes) constructor) returns
        # uninitialized memory, so the parameter must be initialized explicitly;
        # otherwise it can hold arbitrary values, including NaN/inf.
        self.token = nn.Parameter(torch.empty(self.token_num, self.token_dim))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize ``self.token`` in place.

        Uses the same scheme as ``nn.Linear``'s weight: Kaiming-uniform with
        ``a=sqrt(5)``, which draws values from
        U(-1/sqrt(token_dim), 1/sqrt(token_dim)) since fan-in is ``token_dim``.
        """
        nn.init.kaiming_uniform_(self.token, a=math.sqrt(5))

    def forward(self):
        """Return the token parameter, shape ``(token_num, token_dim)``.

        The parameter itself is returned (not a copy), so gradients computed
        downstream accumulate into ``self.token``.
        """
        return self.token