FlameF0X committed
Commit 6c6142e · verified · 1 parent: 06b8354

Upload 3 files

__init__.py CHANGED
@@ -1,5 +1,2 @@
- from .modeling_snowflake_core import SnowflakeCoreG1, SnowflakeCoreG1Config
- from transformers import AutoConfig, AutoModelForCausalLM
-
- AutoConfig.register("snowflake_core", SnowflakeCoreG1Config)
- AutoModelForCausalLM.register(SnowflakeCoreG1Config, SnowflakeCoreG1)
+ from .modeling_snowflake_core import SnowflakeCoreG1
+ from .configuration_snowflake_core import SnowflakeCoreConfig
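With the AutoConfig / AutoModelForCausalLM registration removed from __init__.py, callers that load the model through the Auto classes now have to register it themselves (or rely on trust_remote_code). A minimal sketch of that manual registration, assuming the two modules above are importable as a package named snowflake_core (the package name is illustrative):

    from transformers import AutoConfig, AutoModelForCausalLM
    from snowflake_core import SnowflakeCoreG1, SnowflakeCoreConfig  # hypothetical package layout

    # Register the custom classes so the Auto factories can resolve model_type "snowflake_core"
    AutoConfig.register("snowflake_core", SnowflakeCoreConfig)
    AutoModelForCausalLM.register(SnowflakeCoreConfig, SnowflakeCoreG1)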
 
 
 
configuration_snowflake_core.py ADDED
@@ -0,0 +1,34 @@
+ from transformers import PretrainedConfig
+
+ class SnowflakeCoreConfig(PretrainedConfig):
+     model_type = "snowflake_core"
+
+     def __init__(
+         self,
+         vocab_size=50257,
+         embed_dim=1024,
+         num_heads=16,
+         num_layers=24,
+         max_length=2048,
+         ffn_dim=4096,
+         pad_token_id=50256,
+         eos_token_id=50256,
+         bos_token_id=None,
+         unk_token_id=None,
+         dropout=0.1,
+         **kwargs
+     ):
+         super().__init__(
+             pad_token_id=pad_token_id,
+             eos_token_id=eos_token_id,
+             bos_token_id=bos_token_id,
+             unk_token_id=unk_token_id,
+             **kwargs
+         )
+         self.vocab_size = vocab_size
+         self.embed_dim = embed_dim
+         self.num_heads = num_heads
+         self.num_layers = num_layers
+         self.max_length = max_length
+         self.ffn_dim = ffn_dim
+         self.dropout = dropout
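For reference, the new configuration behaves like any other PretrainedConfig; a minimal sketch of instantiating and serializing it (the overridden sizes and output directory are illustrative):

    from configuration_snowflake_core import SnowflakeCoreConfig

    config = SnowflakeCoreConfig(num_layers=12, embed_dim=768, num_heads=12)  # override defaults as needed
    config.save_pretrained("./snowflake-core-g1")  # writes config.json with model_type "snowflake_core"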
modeling_snowflake_core.py CHANGED
@@ -1,186 +1,130 @@
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torch import Tensor
- from transformers import PretrainedConfig, PreTrainedModel
- from transformers.modeling_outputs import CausalLMOutput
- from transformers.utils import logging
- from typing import Optional, Tuple, Dict, Any
-
- logger = logging.get_logger(__name__)
-
-
- # ===== Custom Attention and Transformer Block =====
- class FusedSelfAttention(nn.Module):
-     def __init__(self, embed_dim, num_heads):
-         super().__init__()
-         self.num_heads = num_heads
-         self.head_dim = embed_dim // num_heads
-         self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)
-         self.out_proj = nn.Linear(embed_dim, embed_dim)
-
-     def forward(self, x, attn_mask=None, key_padding_mask=None):
-         B, T, C = x.size()
-         qkv = self.qkv_proj(x)
-         qkv = qkv.view(B, T, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-         q, k, v = qkv[0], qkv[1], qkv[2]
-
-         scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
-
-         if attn_mask is not None:
-             scores += attn_mask.unsqueeze(0).unsqueeze(0).to(dtype=scores.dtype)
-
-         if key_padding_mask is not None:
-             scores = scores.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"))
-
-         attn_probs = F.softmax(scores, dim=-1)
-         context = torch.matmul(attn_probs, v)
-         context = context.transpose(1, 2).contiguous().view(B, T, C)
-         return self.out_proj(context)
-
-
- class GPTBlock(nn.Module):
-     def __init__(self, embed_dim, num_heads, dropout=0.1):
-         super().__init__()
-         self.ln1 = nn.LayerNorm(embed_dim)
-         self.attn = FusedSelfAttention(embed_dim, num_heads)
-         self.dropout1 = nn.Dropout(dropout)
-         self.ln2 = nn.LayerNorm(embed_dim)
-         self.mlp = nn.Sequential(
-             nn.Linear(embed_dim, 4 * embed_dim),
-             nn.GELU(),
-             nn.Dropout(dropout),
-             nn.Linear(4 * embed_dim, embed_dim),
-         )
-         self.dropout2 = nn.Dropout(dropout)
-
-     def forward(self, x, attn_mask=None, key_padding_mask=None):
-         x = x + self.dropout1(self.attn(self.ln1(x), attn_mask, key_padding_mask))
-         x = x + self.dropout2(self.mlp(self.ln2(x)))
-         return x
-
-
- # ===== Config =====
- class SnowflakeCoreG1Config(PretrainedConfig):
-     model_type = "snowflake_core"
-
-     def __init__(
-         self,
-         vocab_size=50257,
-         embed_dim=1024,
-         num_heads=16,
-         num_layers=24,
-         max_length=2048,
-         ffn_dim=4096,
-         dropout=0.1,
-         pad_token_id=50256,
-         eos_token_id=50256,
-         bos_token_id=None,
-         unk_token_id=None,
-         tie_word_embeddings=False,
-         **kwargs,
-     ):
-         self.vocab_size = vocab_size
-         self.embed_dim = embed_dim
-         self.num_heads = num_heads
-         self.num_layers = num_layers
-         self.max_length = max_length
-         self.ffn_dim = ffn_dim
-         self.dropout = dropout
-
-         super().__init__(
-             pad_token_id=pad_token_id,
-             eos_token_id=eos_token_id,
-             bos_token_id=bos_token_id,
-             unk_token_id=unk_token_id,
-             tie_word_embeddings=tie_word_embeddings,
-             **kwargs,
-         )
-
-
- # ===== Model =====
- class SnowflakeCoreG1(PreTrainedModel):
-     config_class = SnowflakeCoreG1Config
-     base_model_prefix = "snowflake_core"
-
-     def __init__(self, config: SnowflakeCoreG1Config):
-         super().__init__(config)
-
-         self.embed = nn.Embedding(config.vocab_size, config.embed_dim)
-         self.pos_embed = nn.Embedding(config.max_length, config.embed_dim)
-         self.dropout = nn.Dropout(config.dropout)
-
-         self.blocks = nn.ModuleList([
-             GPTBlock(config.embed_dim, config.num_heads, config.dropout)
-             for _ in range(config.num_layers)
-         ])
-
-         self.ln_f = nn.LayerNorm(config.embed_dim)
-         self.head = nn.Linear(config.embed_dim, config.vocab_size, bias=False)
-
-         self.post_init()
-
-     def forward(
-         self,
-         input_ids: Tensor,
-         attention_mask: Optional[Tensor] = None,
-         labels: Optional[Tensor] = None,
-         **kwargs
-     ) -> CausalLMOutput:
-
-         B, T = input_ids.shape
-         if T > self.config.max_length:
-             logger.warning("Input truncated to max_length.")
-             input_ids = input_ids[:, -self.config.max_length:]
-             T = self.config.max_length
-
-         pos = torch.arange(0, T, device=input_ids.device).unsqueeze(0)
-         x = self.embed(input_ids) + self.pos_embed(pos)
-         x = self.dropout(x)
-
-         causal_mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
-         causal_mask = causal_mask.masked_fill(causal_mask, float('-inf'))
-
-         key_padding_mask = None
-         if attention_mask is not None:
-             attention_mask = attention_mask[:, :T]
-             key_padding_mask = attention_mask == 0
-
-         for block in self.blocks:
-             x = block(x, attn_mask=causal_mask, key_padding_mask=key_padding_mask)
-
-         x = self.ln_f(x)
-         logits = self.head(x)
-
-         loss = None
-         if labels is not None:
-             shift_logits = logits[..., :-1, :].contiguous()
-             shift_labels = labels[..., 1:].contiguous()
-             loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-
-         return CausalLMOutput(
-             loss=loss,
-             logits=logits,
-             past_key_values=None,
-             hidden_states=None,
-             attentions=None,
-         )
-
-     def prepare_inputs_for_generation(
-         self,
-         input_ids: Tensor,
-         past_key_values: Optional[Tuple] = None,
-         attention_mask: Optional[Tensor] = None,
-         **kwargs
-     ) -> Dict[str, Any]:
-         return {
-             "input_ids": input_ids[:, -1:] if past_key_values is not None else input_ids,
-             "attention_mask": attention_mask,
-             "past_key_values": past_key_values,
-         }
-
-     def get_output_embeddings(self):
-         return self.head
-
-     def set_output_embeddings(self, new_embeddings):
-         self.head = new_embeddings
 
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import PreTrainedModel, PretrainedConfig
+ from typing import Optional
+
+ # Optional: import custom config if present
+ try:
+     from .configuration_snowflake_core import SnowflakeCoreConfig
+ except ImportError:
+     SnowflakeCoreConfig = PretrainedConfig
+
+ class FusedSelfAttention(nn.Module):
+     def __init__(self, embed_dim, num_heads):
+         super().__init__()
+         self.num_heads = num_heads
+         self.head_dim = embed_dim // num_heads
+         assert (
+             self.head_dim * num_heads == embed_dim
+         ), "embed_dim must be divisible by num_heads"
+         self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)
+         self.out_proj = nn.Linear(embed_dim, embed_dim)
+
+     def forward(self, x, attn_mask=None, key_padding_mask=None):
+         B, T, C = x.size()
+         qkv = self.qkv_proj(x)  # [B, T, 3 * C]
+         qkv = qkv.reshape(B, T, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+         q, k, v = qkv[0], qkv[1], qkv[2]  # each: [B, num_heads, T, head_dim]
+
+         attn_scores = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)  # [B, num_heads, T, T]
+         if attn_mask is not None:
+             attn_scores = attn_scores + attn_mask.unsqueeze(0).unsqueeze(0).to(attn_scores.dtype)
+         if key_padding_mask is not None:
+             attn_scores = attn_scores.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf'))
+         attn_probs = F.softmax(attn_scores, dim=-1)
+         attn_output = attn_probs @ v  # [B, num_heads, T, head_dim]
+         attn_output = attn_output.transpose(1, 2).reshape(B, T, C)
+         return self.out_proj(attn_output)
+
+ class GPTBlock(nn.Module):
+     def __init__(self, embed_dim, num_heads, dropout=0.1):
+         super().__init__()
+         self.ln1 = nn.LayerNorm(embed_dim)
+         self.attn = FusedSelfAttention(embed_dim, num_heads)
+         self.dropout1 = nn.Dropout(dropout)
+         self.ln2 = nn.LayerNorm(embed_dim)
+         self.mlp = nn.Sequential(
+             nn.Linear(embed_dim, 4 * embed_dim),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(4 * embed_dim, embed_dim),
+         )
+         self.dropout2 = nn.Dropout(dropout)
+     def forward(self, x, attn_mask=None, key_padding_mask=None):
+         h = self.ln1(x)
+         attn_output = self.attn(h, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
+         x = x + self.dropout1(attn_output)
+         x = x + self.dropout2(self.mlp(self.ln2(x)))
+         return x
+
+ class SnowflakeCoreG1(PreTrainedModel):
+     config_class = SnowflakeCoreConfig
+     supports_gradient_checkpointing = True
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.vocab_size = config.vocab_size
+         self.embed_dim = config.embed_dim
+         self.num_heads = config.num_heads
+         self.num_layers = config.num_layers
+         self.max_length = config.max_length
+         self.ffn_dim = getattr(config, 'ffn_dim', 4 * config.embed_dim)
+         self.dropout = getattr(config, 'dropout', 0.1)
+
+         self.embed = nn.Embedding(self.vocab_size, self.embed_dim)
+         self.pos_embed = nn.Embedding(self.max_length, self.embed_dim)
+         self.dropout_layer = nn.Dropout(self.dropout)
+         self.blocks = nn.ModuleList([
+             GPTBlock(self.embed_dim, self.num_heads, self.dropout) for _ in range(self.num_layers)
+         ])
+         self.ln_f = nn.LayerNorm(self.embed_dim)
+         self.lm_head = nn.Linear(self.embed_dim, self.vocab_size, bias=False)
+
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.embed
+
+     def set_input_embeddings(self, value):
+         self.embed = value
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head = new_embeddings
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         **kwargs
+     ) -> dict:
+         B, T = input_ids.size()
+         pos = torch.arange(0, T, device=input_ids.device).unsqueeze(0)
+         x = self.embed(input_ids) + self.pos_embed(pos)
+         x = self.dropout_layer(x)
+         causal_bool = torch.triu(torch.ones(T, T, device=input_ids.device, dtype=torch.bool), diagonal=1)
+         causal_mask = torch.zeros(T, T, device=input_ids.device).masked_fill(causal_bool, float('-inf'))  # additive mask: -inf above the diagonal
+         key_padding_mask = None
+         if attention_mask is not None:
+             key_padding_mask = attention_mask == 0
+         for block in self.blocks:
+             x = block(x, attn_mask=causal_mask, key_padding_mask=key_padding_mask)
+         x = self.ln_f(x)
+         logits = self.lm_head(x)
+
+         loss = None
+         if labels is not None:
+             shift_logits = logits[:, :-1, :].contiguous().view(-1, self.vocab_size)
+             shift_labels = labels[:, 1:].contiguous().view(-1)
+             loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=self.config.pad_token_id)
+         if loss is not None:
+             return {"loss": loss, "logits": logits}
+         return {"logits": logits}
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, config=None, **kwargs):
+         return super().from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
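
A quick smoke test of the new model class; the tiny sizes and flat import paths are illustrative, not part of the commit:

    import torch
    from configuration_snowflake_core import SnowflakeCoreConfig
    from modeling_snowflake_core import SnowflakeCoreG1

    config = SnowflakeCoreConfig(num_layers=2, embed_dim=128, num_heads=4, max_length=64)  # small sizes for a fast check
    model = SnowflakeCoreG1(config)

    input_ids = torch.randint(0, config.vocab_size, (1, 16))
    attention_mask = torch.ones_like(input_ids)
    out = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
    print(out["loss"].item(), out["logits"].shape)  # scalar loss and torch.Size([1, 16, 50257])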