from transformers import PretrainedConfig


class SnowflakeCoreConfig(PretrainedConfig):
    """Configuration class for the SnowflakeCore model.

    Holds the hyperparameters needed to instantiate the model:
    vocabulary size, embedding width, attention heads, layer count,
    maximum sequence length, feed-forward width, and dropout.
    """

    model_type = "snowflake_core"

    def __init__(
        self,
        vocab_size=50257,
        embed_dim=1024,
        num_heads=16,
        num_layers=24,
        max_length=2048,
        ffn_dim=4096,
        pad_token_id=50256,
        eos_token_id=50256,
        bos_token_id=None,
        unk_token_id=None,
        dropout=0.1,
        **kwargs,
    ):
        # Special-token ids are forwarded to PretrainedConfig so that
        # generation utilities and tokenizer integration pick them up.
        super().__init__(
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            bos_token_id=bos_token_id,
            unk_token_id=unk_token_id,
            **kwargs,
        )
        self.vocab_size = vocab_size    # size of the token vocabulary
        self.embed_dim = embed_dim      # hidden / embedding dimension
        self.num_heads = num_heads      # attention heads per layer
        self.num_layers = num_layers    # number of transformer blocks
        self.max_length = max_length    # maximum sequence length (positions)
        self.ffn_dim = ffn_dim          # feed-forward inner dimension
        self.dropout = dropout          # dropout probability
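

# A minimal usage sketch (illustrative only, not part of the model code):
# build a config with overridden hyperparameters and round-trip it through
# save_pretrained / from_pretrained. The "snowflake-core-config" directory
# name below is a hypothetical example path.
if __name__ == "__main__":
    config = SnowflakeCoreConfig(num_layers=12, embed_dim=768)
    config.save_pretrained("snowflake-core-config")
    reloaded = SnowflakeCoreConfig.from_pretrained("snowflake-core-config")
    print(reloaded.num_layers, reloaded.embed_dim)  # 12 768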