# flashsloth/model/learnable_token.py
# Tongbo's picture
# Upload folder using huggingface_hub
# 04f8e39 verified
# raw
# history blame contribute delete
# 502 Bytes
import torch
import torch.nn as nn
import math
class LearnableToken(nn.Module):
    """A trainable matrix of tokens exposed as a module.

    The module owns a single parameter, ``token``, of shape
    ``(token_num, token_dim)``. Calling the module takes no input and
    returns that parameter itself (not a copy).
    """

    def __init__(self, token_num: int, token_dim: int) -> None:
        """Create the token parameter and give it well-defined values.

        Args:
            token_num: Number of tokens (rows of the parameter).
            token_dim: Size of each token (columns of the parameter).
        """
        super().__init__()
        self.token_num = token_num
        self.token_dim = token_dim
        # torch.empty only allocates storage; the contents are arbitrary
        # until reset_parameters() fills them in below.
        self.token = nn.Parameter(torch.empty(self.token_num, self.token_dim))
        # Without this call the parameter would keep whatever bytes were in
        # the allocation (possibly NaN/Inf) and would not follow the RNG seed.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialize ``token`` in place with Kaiming-uniform values.

        Uses ``a=sqrt(5)``, the same setting ``nn.Linear`` applies to its
        weight matrix.
        """
        nn.init.kaiming_uniform_(self.token, a=math.sqrt(5))

    def forward(self) -> nn.Parameter:
        """Return the ``token`` parameter."""
        return self.token