davidafrica committed (verified)
Commit 586fe46 · 1 Parent(s): 4117800

Incomplete

checkpoints/step_7500/config.json DELETED
@@ -1,22 +0,0 @@
{
  "activation_hidden_dim": 6144,
  "architectures": [
    "PicoDecoderHF"
  ],
  "attention_n_heads": 12,
  "attention_n_kv_heads": 4,
  "auto_map": {
    "AutoConfig": "pico_decoder.PicoDecoderHFConfig",
    "AutoModelForCausalLM": "pico_decoder.PicoDecoderHF"
  },
  "batch_size": 1024,
  "d_model": 1536,
  "max_seq_len": 512,
  "model_type": "pico_decoder",
  "n_layers": 12,
  "norm_eps": 1e-06,
  "position_emb_theta": 10000.0,
  "torch_dtype": "float32",
  "transformers_version": "4.48.3",
  "vocab_size": 50281
}
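
The auto_map above is what tied this checkpoint to the custom classes in pico_decoder.py. A minimal sketch of how a directory laid out like this would have been inspected with transformers (the local path is illustrative only, and trust_remote_code is needed because of the custom config class):

from transformers import AutoConfig

# Hypothetical path to the checkpoint directory as it existed before this deletion.
ckpt_dir = "checkpoints/step_7500"

# auto_map resolves AutoConfig to pico_decoder.PicoDecoderHFConfig, so loading
# requires opting into custom code shipped with the checkpoint.
config = AutoConfig.from_pretrained(ckpt_dir, trust_remote_code=True)
print(config.d_model, config.n_layers, config.vocab_size)  # expected: 1536 12 50281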
 
checkpoints/step_7500/fabric_state/checkpoint/mp_rank_00_model_states.pt DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dd9c5745c862431946d20567566cbf5a57abe4c43d1bc8ad3810f1b8b94d5e7a
size 67138069
 
checkpoints/step_7500/model.safetensors DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:64f6da054ed07dcd0b7ea12c5871e6b8456e6f9c2af3d3fa32f3398cda9f3638
size 2267115520
 
checkpoints/step_7500/pico_decoder.py DELETED
@@ -1,623 +0,0 @@
"""
Pico Decoder: A Lightweight Causal Transformer Language Model

Pico Decoder uses a simple LLAMA-style transformer architecture, written for clarity and educational purposes.

Everything is written with a modular design for easy modification and experimentation.

Key features:
- RMSNorm for layer normalization
- Rotary Positional Embeddings (RoPE)
- Multi-head attention with KV-cache support
- SwiGLU activation function
- Residual connections throughout

- KV-cache for faster autoregressive generation

References:
- RoPE: https://arxiv.org/abs/2104.09864
- SwiGLU: https://arxiv.org/abs/2002.05202
- LLAMA: https://arxiv.org/abs/2302.13971

Adapted from:
- OLMO: https://github.com/allenai/OLMo
- LLAMA: https://github.com/meta/llama
"""

from dataclasses import asdict
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast

try:
    if TYPE_CHECKING:
        # We need to do this to avoid importing these when creating the HF-compatible models
        from src.config import ModelConfig
except ImportError:
    pass

########################################################
#
# Layer Normalization
#
########################################################


class RMSNorm(torch.nn.Module):
    """Root Mean Square Layer Normalization.

    A variant of Layer Normalization that uses RMS statistics instead of mean/variance,
    resulting in improved stability and performance.

    Args:
        config (Union[ModelConfig, PicoHFConfig]): Configuration object containing normalization parameters
            - config.norm_eps: Small constant for numerical stability
            - config.d_model: Model dimension for the weight parameter

    References:
        https://arxiv.org/abs/1910.07467
    """

    def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
        super().__init__()
        self.eps = config.norm_eps
        self.weight = nn.Parameter(torch.ones(config.d_model))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """
        Normalizes the input tensor by its RMS value.
        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Applies RMS normalization to the input tensor and scales it by the weight parameter.
        """
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


########################################################
#
# Positional Embedding
#
########################################################


class RoPE(nn.Module):
    """Rotary Positional Embeddings (RoPE).

    Implements position-dependent rotation of keys and queries in attention mechanism,
    allowing better modeling of relative positions in sequences. Uses complex number
    operations for efficient rotation.

    Args:
        config (Union[ModelConfig, PicoHFConfig]): Model configuration containing:
            - config.position_emb_theta: Base for frequency computation
            - config.d_model: Model dimension
            - config.attention_n_heads: Number of attention heads
            - config.max_seq_len: Maximum sequence length

    References:
        https://arxiv.org/abs/2104.09864
    """

    _freqs_cis_tensor: torch.Tensor | None = None

    def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
        super().__init__()

        self.theta = config.position_emb_theta
        self.dim = config.d_model // config.attention_n_heads

        max_seq_len = config.max_seq_len

        # only gets set once, and then reused for all RoPE instances
        if RoPE._freqs_cis_tensor is None:
            RoPE._freqs_cis_tensor = self._setup_freqs_cis(
                max_seq_len, self.theta, self.dim
            )

        # register _freqs_cis buffer
        # can be easily recomputed so persistent=False
        self.register_buffer("_freqs_cis", self._freqs_cis_tensor, persistent=False)

    @classmethod
    def _setup_freqs_cis(cls, seq_len: int, theta: float, dim: int) -> torch.Tensor:
        """Setup Frequency Tensor for RoPE Embeddings

        Initializes the complex frequency tensor that is used to compute the RoPE embeddings.

        Note other implementations will use cos and sin directly, but using the complex
        number representation is (probably?) more efficient:

            e^(theta * i * t) = cos(theta * t) + i * sin(theta * t) [Euler's formula]
        """
        _freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        positions = torch.arange(seq_len)
        freqs = torch.outer(positions, _freqs)
        return torch.polar(torch.ones_like(freqs), freqs)  # complex64

    def get_freqs_cis(
        self, input_shape: torch.Size, start_pos: int, end_pos: int
    ) -> torch.Tensor:
        """Reshape Frequency Tensor for RoPE Embeddings

        Makes the frequency tensor broadcastable with the input tensor.
        """
        _freqs_cis = self._freqs_cis[start_pos:end_pos]
        ndim = len(input_shape)
        assert 0 <= 1 < ndim
        assert _freqs_cis.shape == (input_shape[1], input_shape[-1])

        # TODO: Check whether this is correct (might be able to remove this)
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(input_shape)]
        return _freqs_cis.view(*shape)

    def forward(
        self,
        queries: torch.Tensor,
        keys: torch.Tensor,
        start_pos: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply RoPE Embeddings to Queries and Keys

        Applies the rotary positional embeddings to the input tensors via complex num multiplication

        NOTE: The start_pos is used if we want to use the kv_cache in the attention mechanism.
        """
        queries_ = torch.view_as_complex(
            queries.float().reshape(*queries.shape[:-1], -1, 2)
        )
        keys_ = torch.view_as_complex(keys.float().reshape(*keys.shape[:-1], -1, 2))

        input_shape = (
            queries_.shape
        )  # same as keys: (batch_size, seq_len, n_heads, head_dim/2)
        freqs_start_pos = start_pos
        freqs_end_pos = freqs_start_pos + queries_.shape[1]

        freqs_cis = self.get_freqs_cis(input_shape, freqs_start_pos, freqs_end_pos)

        queries_rotated = torch.view_as_real(queries_ * freqs_cis).flatten(3)
        keys_rotated = torch.view_as_real(keys_ * freqs_cis).flatten(3)
        return queries_rotated.type_as(queries), keys_rotated.type_as(keys)


########################################################
#
# Attention
#
########################################################


class Attention(nn.Module):
    """Multi-head Attention with Group Query Attention support.

    Implements scaled dot-product attention and supports:
    - Grouped Query Attention (GQA)
    - Key-Value caching for efficient inference
    - RoPE integration

    Args:
        config (Union[ModelConfig, PretrainedConfig]): Configuration containing:
            - config.attention_n_heads: Number of attention heads
            - config.attention_n_kv_heads: Number of key/value heads
            - config.d_model: Model dimension
            - config.batch_size: Maximum batch size
            - config.max_seq_len: Maximum sequence length

    Shape:
        - Input: (batch_size, seq_len, d_model)
        - Output: (batch_size, seq_len, d_model)
    """

    def __init__(
        self,
        config: Union["ModelConfig", "PicoDecoderHFConfig"],
    ):
        super().__init__()

        self.n_heads = config.attention_n_heads
        self.n_kv_heads = config.attention_n_kv_heads

        self.batch_size = config.batch_size
        self.max_seq_len = config.max_seq_len

        d_model = config.d_model
        self.head_dim = d_model // self.n_heads

        self.n_rep = self.n_heads // self.n_kv_heads

        self.q_proj = nn.Linear(d_model, self.n_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.n_heads * self.head_dim, d_model, bias=False)

        self.rope = RoPE(config)

    def forward(
        self,
        input: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[torch.Tensor, ...]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Forward pass for the attention mechanism.

        Computes queries, keys, and values for the attention mechanism. Applies rotary positional
        embeddings to the queries and keys, and then computes attention scores and outputs.

        For an introduction to the attention mechanism, see:
        https://arxiv.org/abs/1706.03762

        A few things to note:
        - The past_key_values is used to implement the KV cache, which is used to speed up
          generation by caching the KV pairs from previous forward passes. This is useful when doing
          tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
          modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
          its own KV cache - this KV cache is implemented as a tuple.
        """
        bsz, seq_len, _ = input.shape
        _queries, _keys, _values = (
            self.q_proj(input),
            self.k_proj(input),
            self.v_proj(input),
        )

        # Reshaping for multi-head attention
        queries = _queries.view(bsz, seq_len, self.n_heads, self.head_dim)
        keys = _keys.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
        values = _values.view(bsz, seq_len, self.n_kv_heads, self.head_dim)

        # The start position is used to apply the RoPE embeddings to only the new tokens
        # when using the kv_cache in the attention mechanism.
        # We want to start from the last position in the cache.
        start_pos = past_key_values[0].shape[1] if past_key_values is not None else 0

        # apply rotary positional embeddings
        queries, keys = self.rope(queries, keys, start_pos)

        if past_key_values is not None:
            keys = torch.cat([past_key_values[0], keys], dim=1)
            values = torch.cat([past_key_values[1], values], dim=1)

        if use_cache:
            cached_keys = keys
            cached_values = values
        else:
            cached_keys = None
            cached_values = None

        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        apply_gqa = self.n_rep > 1
        if apply_gqa and queries.device.type == "mps":
            # NOTE: MPS does not support GQA in the SDPA kernel, but we can repeat the keys and values
            # outside of the kernel to get the same effect.
            # See: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
            keys = keys.repeat_interleave(self.n_rep, dim=-3)
            values = values.repeat_interleave(self.n_rep, dim=-3)
            apply_gqa = False

        backends = [SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]

        with sdpa_kernel(backends=backends):
            attn_output = F.scaled_dot_product_attention(
                queries.contiguous(),
                keys.contiguous(),
                values.contiguous(),
                attn_mask=mask.to(queries.dtype),
                enable_gqa=apply_gqa,
            )

        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
        output = self.o_proj(attn_output)

        return output, (cached_keys, cached_values)


########################################################
#
# SwiGLU (Combines MLP and Activation)
#
########################################################


class SwiGLU(nn.Module):
    """SwiGLU Activation Function with Linear Projections.

    Implements the SwiGLU activation function combined with linear transformations,
    serving as the feed-forward network in transformer blocks.

    Args:
        config (Union[ModelConfig, PicoDecoderHFConfig]): Configuration containing:
            - config.d_model: Model dimension
            - config.activation_hidden_dim: Hidden dimension (typically 4 * d_model)

    References:
        https://arxiv.org/abs/2002.05202
    """

    def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
        super().__init__()

        model_dim = config.d_model
        act_hidden_dim = config.activation_hidden_dim  # usually 4 * d_model

        self.w_0 = nn.Linear(model_dim, act_hidden_dim, bias=False)
        self.w_1 = nn.Linear(model_dim, act_hidden_dim, bias=False)
        self.w_2 = nn.Linear(act_hidden_dim, model_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_2(F.silu(self.w_0(x)) * self.w_1(x))


########################################################
#
# PicoDecoderBlock
#
########################################################


class PicoDecoderBlock(nn.Module):
    """Single Transformer Block with Attention and Feed-forward layers.

    Implements a standard transformer block with:
    - Multi-head attention with normalization and residual connection
    - SwiGLU feed-forward network with normalization and residual connection

    Args:
        config (Union[ModelConfig, PicoDecoderHFConfig]): Model configuration; either a dataclass or
            a HuggingFace PicoDecoderHFConfig
    """

    def __init__(
        self,
        config: Union["ModelConfig", "PicoDecoderHFConfig"],
    ):
        super().__init__()

        self.attention = Attention(config)
        self.swiglu = SwiGLU(config)
        self.attention_norm = RMSNorm(config)
        self.swiglu_norm = RMSNorm(config)

    def forward(
        self,
        input: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        attention_output, cached_key_values = self.attention(
            self.attention_norm(input),
            mask=mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )
        # NOTE: cached_key_values is None if use_cache is False

        h = input + attention_output
        out = h + self.swiglu(self.swiglu_norm(h))
        return out, cached_key_values


########################################################
#
# Pico Decoder (Causal Transformer Model)
#
########################################################


class PicoDecoder(nn.Module):
    """
    Pico Decoder: combines the embedding, causal decoder blocks, and output projection into a
    single autoregressive model.

    For more information on the model, see the classes for the modules that make up the model.
    """

    def __init__(
        self,
        model_config: Union["ModelConfig", "PicoDecoderHFConfig"],
    ):
        super().__init__()
        self.config = model_config

        self.embedding_proj = nn.Embedding(self.config.vocab_size, self.config.d_model)
        self.layers = nn.ModuleList(
            [PicoDecoderBlock(self.config) for _ in range(self.config.n_layers)]
        )
        self.output_norm = RMSNorm(self.config)
        self.de_embedding_proj = nn.Linear(
            self.config.d_model, self.config.vocab_size, bias=False
        )

    def convert_to_hf_model(self) -> "PicoDecoderHF":
        """Convert the Lightning model to a HuggingFace model."""
        # Build HF config
        hf_config = PicoDecoderHFConfig.from_dataclass(self.config)

        # Instantiate the HF-wrapped model
        hf_model = PicoDecoderHF(hf_config)

        # Grab our full state dict, prefixing module names
        raw_state = self.state_dict(prefix="pico_decoder.")

        # Only keep keys that exist in the HF model (drops classifier_head, etc.)
        hf_keys = set(hf_model.state_dict().keys())
        filtered_state = {k: v for k, v in raw_state.items() if k in hf_keys}

        # Load into HF model, ignore any missing keys
        hf_model.load_state_dict(filtered_state, strict=False)

        return hf_model

    def forward(
        self,
        input_ids: torch.Tensor,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        use_cache: bool = False,
        return_hidden: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[Tuple[torch.Tensor, torch.Tensor]]]]:
        """
        This is the forward pass for the entire Pico model. It boils down to:
        - Embedding the input ids
        - Creating a causal mask
        - Processing through the pico layers
        - Projecting the output to logits

        NOTE: One feature that might be confusing is the KV cache. The KV cache is used to speed up
        generation by caching the KV pairs from previous forward passes. This is useful when doing
        tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
        modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
        its own KV cache which is stored as a tuple. The whole model then stores a tuple of these
        KV caches (so a tuple of tuples).
        """

        seq_len = input_ids.shape[-1]
        h = self.embedding_proj(input_ids)

        # Calculate start position from past cached KV pairs. Remember that each layer has its
        # own KV Cache. So when we index past_key_values, we need to index into the KV pairs for the
        # correct layer and then for either the keys or values.
        start_pos = 0 if past_key_values is None else past_key_values[0][0].shape[1]

        # Create causal mask for current sequence
        mask = None
        if seq_len > 1:
            mask = torch.full((seq_len, seq_len), float("-inf"))
            mask = torch.triu(mask, diagonal=1)

            # If using KV cache, extend mask to cover cached sequence length
            if past_key_values is not None:
                # Add zeros for cached tokens (we can attend to all of them)
                mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])

            mask = mask.to(h.device)

        # NOTE: If we are using the cache, we need to store the cached KV pairs for each layer
        # in a tuple. Each layer will have its own cached KV pair which we aggregate in a tuple.
        cached_key_values = () if use_cache else None

        # Process through transformer blocks
        for idx, layer in enumerate(self.layers):
            layer_past_key_values = (
                past_key_values[idx] if past_key_values is not None else None
            )

            h, layer_cached_key_values = layer(
                h, mask=mask, past_key_values=layer_past_key_values, use_cache=use_cache
            )

            if use_cache:
                cached_key_values += (layer_cached_key_values,)

        # Final norm and projection
        h = self.output_norm(h)

        if return_hidden:
            return h, cached_key_values

        logits = self.de_embedding_proj(h).float()

        return logits, cached_key_values


########################################################
#
# HuggingFace Wrapper for the Pico Decoder model.
#
########################################################


class PicoDecoderHFConfig(PretrainedConfig):
    """Config class for the Pico Decoder HuggingFace wrapper."""

    model_type = "pico_decoder"

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PicoDecoderHFConfig":
        # NOTE The typical from_dict method doesn't actually set the attributes unless they are
        # defined in the constructor.

        pico_config = cls(**kwargs)

        # Because this class is just a wrapper around the ModelConfig dataclass, we need to do
        # a little extra work to ensure that the attributes are actually set.
        for key, value in config_dict.items():
            setattr(pico_config, key, value)

        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
        unused_kwargs = {
            key: value for key, value in kwargs.items() if not hasattr(pico_config, key)
        }

        if return_unused_kwargs:
            return pico_config, unused_kwargs
        return pico_config

    @classmethod
    def from_dataclass(cls, model_config: "ModelConfig"):
        """Initialise from our custom config dataclass."""
        return cls.from_dict(asdict(model_config))


class PicoDecoderHF(PreTrainedModel):
    """
    HuggingFace wrapper for the Pico model.

    Many evaluation frameworks require a model be setup as a HuggingFace model, so we provide a simple
    wrapper that does just that. When we save checkpoints of the Pico model, we save both the normal
    Pico model as well as the model wrapped in this HuggingFace class.

    This also lets you do cool things like:

    `model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint")`
    """

    config_class = PicoDecoderHFConfig
    _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]

    def __init__(self, config: PicoDecoderHFConfig):
        super().__init__(config)
        self.pico_decoder = PicoDecoder(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        use_cache: bool = False,
        **kwargs,
    ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
        """HuggingFace forward pass wrapper.

        Forwards pass for the HuggingFace version of the Pico Model. Basic wrapper around the
        Pico model's forward pass, and returns the output as a HuggingFace CausalLMOutput.
        """
        logits, past_key_values = self.pico_decoder(
            input_ids, past_key_values, use_cache
        )
        if use_cache:
            return CausalLMOutputWithPast(
                logits=logits,
                past_key_values=past_key_values,
            )
        else:
            return CausalLMOutput(
                logits=logits,
            )


# Register for auto classes
PicoDecoderHFConfig.register_for_auto_class()
PicoDecoderHF.register_for_auto_class("AutoModel")
PicoDecoderHF.register_for_auto_class("AutoModelForCausalLM")
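
As the PicoDecoderHF docstring notes, checkpoints saved in this layout load through the standard Auto classes. A hedged sketch of a single cached forward pass against such a checkpoint (path and prompt are illustrative only, since this commit removes the files; when use_cache=True the wrapper returns one (keys, values) pair per decoder layer):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt_dir = "checkpoints/step_7500"  # illustrative; removed by this commit
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
model = AutoModelForCausalLM.from_pretrained(ckpt_dir, trust_remote_code=True)
model.eval()

input_ids = tokenizer("Pico is a lightweight decoder", return_tensors="pt").input_ids

with torch.no_grad():
    # use_cache=True makes PicoDecoderHF return CausalLMOutputWithPast, whose
    # past_key_values holds the per-layer KV cache described in the docstrings.
    out = model(input_ids, use_cache=True)

next_token_id = out.logits[:, -1].argmax(dim=-1)  # greedy pick for the next token
print(len(out.past_key_values), tokenizer.decode(next_token_id))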
 
checkpoints/step_7500/special_tokens_map.json DELETED
@@ -1,23 +0,0 @@
{
  "eos_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "mask_token": {
    "content": "[MASK]",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "<|padding|>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
 
checkpoints/step_7500/tokenizer.json DELETED
The diff for this file is too large to render. See raw diff
 
checkpoints/step_7500/tokenizer_config.json DELETED
@@ -1,248 +0,0 @@
{
  "add_bos_token": false,
  "add_eos_token": false,
  "add_prefix_space": false,
  "added_tokens_decoder": {
    "0": {
      "content": "|||IP_ADDRESS|||",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "1": {
      "content": "<|padding|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "50254": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50255": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50256": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50257": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50258": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50259": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50260": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50261": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50262": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50263": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50264": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50265": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50266": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50267": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50268": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50269": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50270": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50271": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50272": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50273": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50274": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50275": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50276": {
      "content": " ",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50277": {
      "content": "|||EMAIL_ADDRESS|||",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50278": {
      "content": "|||PHONE_NUMBER|||",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": false
    },
    "50279": {
      "content": "<|endoftext|>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "50280": {
      "content": "[MASK]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "bos_token": null,
  "clean_up_tokenization_spaces": true,
  "eos_token": "<|endoftext|>",
  "extra_special_tokens": {},
  "mask_token": "[MASK]",
  "model_max_length": 1000000000000000019884624838656,
  "pad_token": "<|padding|>",
  "tokenizer_class": "GPTNeoXTokenizer",
  "unk_token": null
}
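
Together with special_tokens_map.json above, the added_tokens_decoder pins the special tokens to fixed ids (<|padding|> at 1, <|endoftext|> at 50279, [MASK] at 50280). A small sketch of checking this after loading the tokenizer (the local path is illustrative only):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("checkpoints/step_7500")  # illustrative path

# Ids as listed in the deleted added_tokens_decoder above.
assert tokenizer.pad_token_id == 1       # <|padding|>
assert tokenizer.eos_token_id == 50279   # <|endoftext|>
assert tokenizer.mask_token_id == 50280  # [MASK]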