ThomasTheMaker committed
Commit 73ac787 · verified · 1 Parent(s): d3d58af

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50)
  1. pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/config.json +22 -0
  2. pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/fabric_state/checkpoint.pt +3 -0
  3. pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/generation_config.json +4 -0
  4. pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/learning_dynamics/train_activations.pt +3 -0
  5. pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/learning_dynamics/train_data/data-00000-of-00001.arrow +3 -0
  6. pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/learning_dynamics/train_data/dataset_info.json +19 -0
  7. pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/learning_dynamics/train_data/state.json +13 -0
  8. pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/learning_dynamics/train_gradients.pt +3 -0
  9. pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/learning_dynamics/train_weights.pt +3 -0
  10. pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/model.safetensors +3 -0
  11. pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/pico_decoder.py +911 -0
  12. pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/special_tokens_map.json +16 -0
  13. pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/tokenizer.json +0 -0
  14. pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/tokenizer_config.json +239 -0
  15. pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/config.json +22 -0
  16. pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/fabric_state/checkpoint.pt +3 -0
  17. pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/generation_config.json +4 -0
  18. pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/learning_dynamics/train_activations.pt +3 -0
  19. pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/learning_dynamics/train_data/data-00000-of-00001.arrow +3 -0
  20. pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/learning_dynamics/train_data/dataset_info.json +19 -0
  21. pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/learning_dynamics/train_data/state.json +13 -0
  22. pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/learning_dynamics/train_gradients.pt +3 -0
  23. pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/learning_dynamics/train_weights.pt +3 -0
  24. pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/model.safetensors +3 -0
  25. pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/pico_decoder.py +911 -0
  26. pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/special_tokens_map.json +16 -0
  27. pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/tokenizer.json +0 -0
  28. pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/tokenizer_config.json +239 -0
  29. pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/config.json +22 -0
  30. pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/fabric_state/checkpoint.pt +3 -0
  31. pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/generation_config.json +4 -0
  32. pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/learning_dynamics/train_activations.pt +3 -0
  33. pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/learning_dynamics/train_data/data-00000-of-00001.arrow +3 -0
  34. pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/learning_dynamics/train_data/dataset_info.json +19 -0
  35. pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/learning_dynamics/train_data/state.json +13 -0
  36. pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/learning_dynamics/train_gradients.pt +3 -0
  37. pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/learning_dynamics/train_weights.pt +3 -0
  38. pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/model.safetensors +3 -0
  39. pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/pico_decoder.py +911 -0
  40. pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/special_tokens_map.json +16 -0
  41. pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/tokenizer.json +0 -0
  42. pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/tokenizer_config.json +239 -0
  43. pico-decoder-tiny-dolma20M-v1/checkpoints/step_100000/config.json +22 -0
  44. pico-decoder-tiny-dolma20M-v1/checkpoints/step_100000/fabric_state/checkpoint.pt +3 -0
  45. pico-decoder-tiny-dolma20M-v1/checkpoints/step_100000/generation_config.json +4 -0
  46. pico-decoder-tiny-dolma20M-v1/checkpoints/step_100000/model.safetensors +3 -0
  47. pico-decoder-tiny-dolma20M-v1/checkpoints/step_100000/pico_decoder.py +911 -0
  48. pico-decoder-tiny-dolma20M-v1/checkpoints/step_100000/special_tokens_map.json +16 -0
  49. pico-decoder-tiny-dolma20M-v1/checkpoints/step_100000/tokenizer.json +0 -0
  50. pico-decoder-tiny-dolma20M-v1/checkpoints/step_100000/tokenizer_config.json +239 -0
pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/config.json ADDED
@@ -0,0 +1,22 @@
+ {
+ "activation_hidden_dim": 384,
+ "architectures": [
+ "PicoDecoderHF"
+ ],
+ "attention_n_heads": 12,
+ "attention_n_kv_heads": 4,
+ "auto_map": {
+ "AutoConfig": "pico_decoder.PicoDecoderHFConfig",
+ "AutoModelForCausalLM": "pico_decoder.PicoDecoderHF"
+ },
+ "batch_size": 1024,
+ "d_model": 96,
+ "max_seq_len": 2048,
+ "model_type": "pico_decoder",
+ "n_layers": 12,
+ "norm_eps": 1e-06,
+ "position_emb_theta": 10000.0,
+ "torch_dtype": "float32",
+ "transformers_version": "4.48.3",
+ "vocab_size": 50304
+ }
pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/fabric_state/checkpoint.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d01a6a79f53f412afc600ef5825ba1ce606eacf5d8808aa0c83de62b2b42ef28
+ size 45187997
pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/generation_config.json ADDED
@@ -0,0 +1,4 @@
+ {
+ "transformers_version": "4.48.3",
+ "vocab_size": 50304
+ }
pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/learning_dynamics/train_activations.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:77cf61aa7d133a73efa174773297c7654113483bd3505224f5b0725336a2b479
+ size 98331
pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/learning_dynamics/train_data/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:71dea5221a5b809d03b575f9a437c3772951dc4d8c202e5af09d005a23791b3a
+ size 271568
pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/learning_dynamics/train_data/dataset_info.json ADDED
@@ -0,0 +1,19 @@
+ {
+ "citation": "",
+ "description": "",
+ "features": {
+ "input_ids": {
+ "feature": {
+ "dtype": "int32",
+ "_type": "Value"
+ },
+ "_type": "Sequence"
+ },
+ "text": {
+ "dtype": "string",
+ "_type": "Value"
+ }
+ },
+ "homepage": "",
+ "license": ""
+ }
pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/learning_dynamics/train_data/state.json ADDED
@@ -0,0 +1,13 @@
+ {
+ "_data_files": [
+ {
+ "filename": "data-00000-of-00001.arrow"
+ }
+ ],
+ "_fingerprint": "6cc12b19e292c1f8",
+ "_format_columns": null,
+ "_format_kwargs": {},
+ "_format_type": null,
+ "_output_all_columns": false,
+ "_split": null
+ }
pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/learning_dynamics/train_gradients.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9cd1f0ba01d7b6d8f8c470ba7065f7ba7251409f02235127fb5952480aec233a
+ size 2371527
pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/learning_dynamics/train_weights.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c029ef92a6494ae121c847e432e52e6a8ff3bf7d9fef3e61bef871c1e9a9aa02
+ size 2371443
pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1852515eb5c8556533445f22edf523884b9f8cc44812379a6a951668a4ffa3a3
+ size 45143592
pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/pico_decoder.py ADDED
@@ -0,0 +1,911 @@
1
+ """
2
+ Pico Decoder: A Lightweight Causal Transformer Language Model
3
+
4
+ Pico Decoder uses a simple LLAMA-style transformer architecture, written for clarity and educational purposes.
5
+
6
+ Everything is written with a modular design for easy modification and experimentation.
7
+
8
+ Key features:
9
+ - RMSNorm for layer normalization
10
+ - Rotary Positional Embeddings (RoPE)
11
+ - Multi-head attention with KV-cache support
12
+ - SwiGLU activation function
13
+ - Residual connections throughout
14
+
15
+ - KV-cache for faster autoregressive generation
16
+
17
+ References:
18
+ - RoPE: https://arxiv.org/abs/2104.09864
19
+ - SwiGLU: https://arxiv.org/abs/2002.05202
20
+ - LLAMA: https://arxiv.org/abs/2302.13971
21
+
22
+ Adapted from:
23
+ - OLMO: https://github.com/allenai/OLMo
24
+ - LLAMA: https://github.com/meta/llama
25
+ """
26
+
27
+ from dataclasses import asdict
28
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
29
+
30
+ import torch
31
+ import torch.nn as nn
32
+ import torch.nn.functional as F
33
+
34
+ # Handle PyTorch version compatibility for attention backend
35
+ try:
36
+ from torch.nn.attention import SDPBackend, sdpa_kernel
37
+
38
+ HAS_TORCH_ATTENTION = True
39
+ except ImportError:
40
+ # Fallback for older PyTorch versions
41
+ HAS_TORCH_ATTENTION = False
42
+ SDPBackend = None
43
+ sdpa_kernel = None
44
+
45
+ from transformers import GenerationMixin, PretrainedConfig, PreTrainedModel
46
+ from transformers.generation import GenerationConfig
47
+ from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
48
+
49
+ try:
50
+ if TYPE_CHECKING:
51
+ # We need to do this to avoid importing these when creating the HF-compatible models
52
+ from src.config import ModelConfig
53
+ except ImportError:
54
+ pass
55
+
56
+ ########################################################
57
+ #
58
+ # Layer Normalization
59
+ #
60
+ ########################################################
61
+
62
+
63
+ class RMSNorm(torch.nn.Module):
64
+ """Root Mean Square Layer Normalization.
65
+
66
+ A variant of Layer Normalization that uses RMS statistics instead of mean/variance,
67
+ resulting in improved stability and performance.
68
+
69
+ Args:
70
+ config (Union[ModelConfig, PicoHFConfig]): Configuration object containing normalization parameters
71
+ - config.norm_eps: Small constant for numerical stability
72
+ - config.d_model: Model dimension for the weight parameter
73
+
74
+ References:
75
+ https://arxiv.org/abs/1910.07467
76
+ """
77
+
78
+ def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
79
+ super().__init__()
80
+ self.eps = config.norm_eps
81
+ self.weight = nn.Parameter(torch.ones(config.d_model))
82
+
83
+ def _norm(self, x: torch.Tensor) -> torch.Tensor:
84
+ """
85
+ Normalizes the input tensor by its RMS value.
86
+ """
87
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
88
+
89
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
90
+ """
91
+ Applies RMS normalization to the input tensor and scales it by the weight parameter.
92
+ """
93
+ output = self._norm(x.float()).type_as(x)
94
+ return output * self.weight
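# --- Illustrative sketch (editorial addition, not part of the committed file) ---
# Quick sanity check of RMSNorm: before the learned per-dimension weight
# (initialized to ones) is applied, every feature vector is rescaled to roughly
# unit RMS. The SimpleNamespace here stands in for the real config object.
import torch
from types import SimpleNamespace

cfg = SimpleNamespace(d_model=96, norm_eps=1e-6)
norm = RMSNorm(cfg)
x = torch.randn(2, 4, cfg.d_model)
rms = norm(x).pow(2).mean(-1).sqrt()
print(rms)  # all entries close to 1.0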
95
+
96
+
97
+ ########################################################
98
+ #
99
+ # Positional Embedding
100
+ #
101
+ ########################################################
102
+
103
+
104
+ class RoPE(nn.Module):
105
+ """Rotary Positional Embeddings (RoPE).
106
+
107
+ Implements position-dependent rotation of keys and queries in attention mechanism,
108
+ allowing better modeling of relative positions in sequences. Uses complex number
109
+ operations for efficient rotation.
110
+
111
+ Args:
112
+ config (Union[ModelConfig, PicoHFConfig]): Model configuration containing:
113
+ - config.position_emb_theta: Base for frequency computation
114
+ - config.d_model: Model dimension
115
+ - config.attention_n_heads: Number of attention heads
116
+ - config.max_seq_len: Maximum sequence length
117
+
118
+ References:
119
+ https://arxiv.org/abs/2104.09864
120
+ """
121
+
122
+ _freqs_cis_tensor: torch.Tensor | None = None
123
+
124
+ def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
125
+ super().__init__()
126
+
127
+ self.theta = config.position_emb_theta
128
+ self.dim = config.d_model // config.attention_n_heads
129
+
130
+ max_seq_len = config.max_seq_len
131
+
132
+ # only gets set once, and then reused for all RoPE instances
133
+ if RoPE._freqs_cis_tensor is None:
134
+ RoPE._freqs_cis_tensor = self._setup_freqs_cis(
135
+ max_seq_len, self.theta, self.dim
136
+ )
137
+
138
+ # register _freqs_cis buffer
139
+ # can be easily recomputed so persistent=False
140
+ self.register_buffer("_freqs_cis", self._freqs_cis_tensor, persistent=False)
141
+
142
+ @classmethod
143
+ def _setup_freqs_cis(cls, seq_len: int, theta: float, dim: int) -> torch.Tensor:
144
+ """Setup Frequency Tensor for RoPE Embeddings
145
+
146
+ Initializes the complex frequency tensor that is used to compute the RoPE embeddings.
147
+
148
+ Note other implementations will use cos and sin directly, but using the complex
149
+ number representation is (probably) more efficient:
150
+
151
+ e^(theta * i * t) = cos(theta * t) + i * sin(theta * t) [Euler's formula]
152
+ """
153
+ _freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
154
+ positions = torch.arange(seq_len)
155
+ freqs = torch.outer(positions, _freqs)
156
+ return torch.polar(torch.ones_like(freqs), freqs) # complex64
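# --- Illustrative sketch (editorial addition, not part of the committed file) ---
# The complex frequency tensor is cos/sin in disguise: torch.polar(1, angle)
# returns cos(angle) + i*sin(angle), so multiplying a (q_even + i*q_odd) pair by
# it rotates that pair by the position-dependent angle.
import torch

theta, dim, seq_len = 10000.0, 8, 4
inv_freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
angles = torch.outer(torch.arange(seq_len).float(), inv_freqs)  # (seq_len, dim/2)
freqs_cis = torch.polar(torch.ones_like(angles), angles)
assert torch.allclose(freqs_cis.real, torch.cos(angles))
assert torch.allclose(freqs_cis.imag, torch.sin(angles))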
157
+
158
+ def get_freqs_cis(
159
+ self, input_shape: torch.Size, start_pos: int, end_pos: int
160
+ ) -> torch.Tensor:
161
+ """Reshape Frequency Tensor for RoPE Embeddings
162
+
163
+ Makes the frequency tensor broadcastable with the input tensor.
164
+ """
165
+ _freqs_cis = self._freqs_cis[start_pos:end_pos]
166
+ ndim = len(input_shape)
167
+ assert 0 <= 1 < ndim
168
+ assert _freqs_cis.shape == (input_shape[1], input_shape[-1])
169
+
170
+ # TODO: Check whether this is correct (might be able to remove this)
171
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(input_shape)]
172
+ return _freqs_cis.view(*shape)
173
+
174
+ def forward(
175
+ self,
176
+ queries: torch.Tensor,
177
+ keys: torch.Tensor,
178
+ start_pos: int = 0,
179
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
180
+ """Apply RoPE Embeddings to Queries and Keys
181
+
182
+ Applies the rotary positional embeddings to the input tensors via complex num multiplication
183
+
184
+ NOTE: The start_pos is used if we want to use the kv_cache in the attention mechanism.
185
+ """
186
+ queries_ = torch.view_as_complex(
187
+ queries.float().reshape(*queries.shape[:-1], -1, 2)
188
+ )
189
+ keys_ = torch.view_as_complex(keys.float().reshape(*keys.shape[:-1], -1, 2))
190
+
191
+ input_shape = (
192
+ queries_.shape
193
+ ) # same as keys: (batch_size, seq_len, n_heads, head_dim/2)
194
+ freqs_start_pos = start_pos
195
+ freqs_end_pos = freqs_start_pos + queries_.shape[1]
196
+
197
+ freqs_cis = self.get_freqs_cis(input_shape, freqs_start_pos, freqs_end_pos)
198
+
199
+ queries_rotated = torch.view_as_real(queries_ * freqs_cis).flatten(3)
200
+ keys_rotated = torch.view_as_real(keys_ * freqs_cis).flatten(3)
201
+ return queries_rotated.type_as(queries), keys_rotated.type_as(keys)
202
+
203
+
204
+ ########################################################
205
+ #
206
+ # Attention
207
+ #
208
+ ########################################################
209
+
210
+
211
+ class Attention(nn.Module):
212
+ """Multi-head Attention with Group Query Attention support.
213
+
214
+ Implements scaled dot-product attention and supports:
215
+ - Grouped Query Attention (GQA)
216
+ - Key-Value caching for efficient inference
217
+ - RoPE integration
218
+
219
+ Args:
220
+ config (Union[ModelConfig, PretrainedConfig]): Configuration containing:
221
+ - config.attention_n_heads: Number of attention heads
222
+ - config.attention_n_kv_heads: Number of key/value heads
223
+ - config.d_model: Model dimension
224
+ - config.batch_size: Maximum batch size
225
+ - config.max_seq_len: Maximum sequence length
226
+
227
+ Shape:
228
+ - Input: (batch_size, seq_len, d_model)
229
+ - Output: (batch_size, seq_len, d_model)
230
+ """
231
+
232
+ def __init__(
233
+ self,
234
+ config: Union["ModelConfig", "PicoDecoderHFConfig"],
235
+ ):
236
+ super().__init__()
237
+
238
+ self.n_heads = config.attention_n_heads
239
+ self.n_kv_heads = config.attention_n_kv_heads
240
+
241
+ self.batch_size = config.batch_size
242
+ self.max_seq_len = config.max_seq_len
243
+
244
+ d_model = config.d_model
245
+ self.head_dim = d_model // self.n_heads
246
+
247
+ self.n_rep = self.n_heads // self.n_kv_heads
248
+
249
+ self.q_proj = nn.Linear(d_model, self.n_heads * self.head_dim, bias=False)
250
+ self.k_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
251
+ self.v_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
252
+ self.o_proj = nn.Linear(self.n_heads * self.head_dim, d_model, bias=False)
253
+
254
+ self.rope = RoPE(config)
255
+
256
+ def forward(
257
+ self,
258
+ input: torch.Tensor,
259
+ mask: Optional[torch.Tensor] = None,
260
+ past_key_values: Optional[Tuple[torch.Tensor, ...]] = None,
261
+ use_cache: bool = False,
262
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
263
+ """Forward pass for the attention mechanism.
264
+
265
+ Computes queries, keys, and values for the attention mechanism. Applies rotary positional
266
+ embeddings to the queries and keys, and then computes attention scores and outputs.
267
+
268
+ For an introduction to the attention mechanism, see:
269
+ https://arxiv.org/abs/1706.03762
270
+
271
+ A few things to note:
272
+ - The past_key_values is used to implement the KV cache, which is used to speed up
273
+ generation by caching the KV pairs from previous forward passes. This is useful when doing
274
+ tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
275
+ modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
276
+ its own KV cache - this KV cache is implemented as a tuple.
277
+ """
278
+ bsz, seq_len, _ = input.shape
279
+ _queries, _keys, _values = (
280
+ self.q_proj(input),
281
+ self.k_proj(input),
282
+ self.v_proj(input),
283
+ )
284
+
285
+ # Reshaping for multi-head attention
286
+ queries = _queries.view(bsz, seq_len, self.n_heads, self.head_dim)
287
+ keys = _keys.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
288
+ values = _values.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
289
+
290
+ # The start position is used to apply the RoPE embeddings to only the new tokens
291
+ # when using the kv_cache in the attention mechanism.
292
+ # We want to start from the last position in the cache.
293
+ start_pos = 0
294
+ if past_key_values is not None and past_key_values[0] is not None:
295
+ start_pos = past_key_values[0].shape[1]
296
+
297
+ # apply rotary positional embeddings
298
+ queries, keys = self.rope(queries, keys, start_pos)
299
+
300
+ if (
301
+ past_key_values is not None
302
+ and past_key_values[0] is not None
303
+ and past_key_values[1] is not None
304
+ ):
305
+ keys = torch.cat([past_key_values[0], keys], dim=1)
306
+ values = torch.cat([past_key_values[1], values], dim=1)
307
+
308
+ if use_cache:
309
+ cached_keys = keys
310
+ cached_values = values
311
+ else:
312
+ cached_keys = None
313
+ cached_values = None
314
+
315
+ queries = queries.transpose(1, 2)
316
+ keys = keys.transpose(1, 2)
317
+ values = values.transpose(1, 2)
318
+
319
+ apply_gqa = self.n_rep > 1
320
+ if apply_gqa and queries.device.type == "mps":
321
+ # NOTE: MPS does not support GQA in the SDPA kernel, but we can repeat the keys and values
322
+ # outside of the kernel to get the same effect.
323
+ # See: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
324
+ keys = keys.repeat_interleave(self.n_rep, dim=-3)
325
+ values = values.repeat_interleave(self.n_rep, dim=-3)
326
+ apply_gqa = False
327
+
328
+ if HAS_TORCH_ATTENTION:
329
+ backends = [SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]
330
+ with sdpa_kernel(backends=backends):
331
+ attn_output = F.scaled_dot_product_attention(
332
+ queries.contiguous(),
333
+ keys.contiguous(),
334
+ values.contiguous(),
335
+ attn_mask=mask.to(queries.dtype) if mask is not None else None,
336
+ enable_gqa=apply_gqa,
337
+ )
338
+ else:
339
+ # Fallback for older PyTorch versions - use default backend
340
+ attn_output = F.scaled_dot_product_attention(
341
+ queries.contiguous(),
342
+ keys.contiguous(),
343
+ values.contiguous(),
344
+ attn_mask=mask.to(queries.dtype) if mask is not None else None,
345
+ enable_gqa=apply_gqa,
346
+ )
347
+
348
+ attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
349
+ output = self.o_proj(attn_output)
350
+
351
+ return output, (cached_keys, cached_values)
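# --- Illustrative sketch (editorial addition, not part of the committed file) ---
# The per-layer cache is just a (keys, values) tuple. Decoding one new token with
# the cache should reproduce the last position of a full forward pass, because the
# cached keys/values are mask-independent projections of the earlier tokens.
# Assumes a recent PyTorch build that accepts the `enable_gqa` SDPA argument.
import torch
from types import SimpleNamespace

cfg = SimpleNamespace(d_model=96, attention_n_heads=12, attention_n_kv_heads=4,
                      batch_size=1, max_seq_len=2048, position_emb_theta=10000.0)
attn = Attention(cfg).eval()
x = torch.randn(1, 5, cfg.d_model)
with torch.no_grad():
    _, cache = attn(x[:, :4], use_cache=True)                   # "prefill" first 4 tokens
    step_out, _ = attn(x[:, 4:], past_key_values=cache, use_cache=True)
    full_out, _ = attn(x)                                       # full pass, no mask
print(torch.allclose(step_out, full_out[:, 4:], atol=1e-5))     # True (up to numerics)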
352
+
353
+
354
+ ########################################################
355
+ #
356
+ # SwiGLU (Combines MLP and Activation)
357
+ #
358
+ ########################################################
359
+
360
+
361
+ class SwiGLU(nn.Module):
362
+ """SwiGLU Activation Function with Linear Projections.
363
+
364
+ Implements the SwiGLU activation function combined with linear transformations,
365
+ serving as the feed-forward network in transformer blocks.
366
+
367
+ Args:
368
+ config (Union[ModelConfig, PicoDecoderHFConfig]): Configuration containing:
369
+ - config.d_model: Model dimension
370
+ - config.activation_hidden_dim: Hidden dimension (typically 4 * d_model)
371
+
372
+ References:
373
+ https://arxiv.org/abs/2002.05202
374
+ """
375
+
376
+ def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
377
+ super().__init__()
378
+
379
+ model_dim = config.d_model
380
+ act_hidden_dim = config.activation_hidden_dim # usually 4 * d_model
381
+
382
+ self.w_0 = nn.Linear(model_dim, act_hidden_dim, bias=False)
383
+ self.w_1 = nn.Linear(model_dim, act_hidden_dim, bias=False)
384
+ self.w_2 = nn.Linear(act_hidden_dim, model_dim, bias=False)
385
+
386
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
387
+ return self.w_2(F.silu(self.w_0(x)) * self.w_1(x))
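# --- Illustrative sketch (editorial addition, not part of the committed file) ---
# SwiGLU is a gated MLP: the w_0 path is squashed through SiLU and gates the w_1
# path elementwise before the down-projection w_2, so the block maps
# (..., d_model) -> (..., d_model) through a hidden width of activation_hidden_dim.
import torch
from types import SimpleNamespace

cfg = SimpleNamespace(d_model=96, activation_hidden_dim=384)
ff = SwiGLU(cfg)
print(ff(torch.randn(2, 3, cfg.d_model)).shape)  # torch.Size([2, 3, 96])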
388
+
389
+
390
+ ########################################################
391
+ #
392
+ # PicoDecoderBlock
393
+ #
394
+ ########################################################
395
+
396
+
397
+ class PicoDecoderBlock(nn.Module):
398
+ """Single Transformer Block with Attention and Feed-forward layers.
399
+
400
+ Implements a standard transformer block with:
401
+ - Multi-head attention with normalization and residual connection
402
+ - SwiGLU feed-forward network with normalization and residual connection
403
+
404
+ Args:
405
+ config (Union[ModelConfig, PicoDecoderHFConfig]): Model configuration; either a dataclass or
406
+ a HuggingFace PicoDecoderHFConfig
407
+ """
408
+
409
+ def __init__(
410
+ self,
411
+ config: Union["ModelConfig", "PicoDecoderHFConfig"],
412
+ ):
413
+ super().__init__()
414
+
415
+ self.attention = Attention(config)
416
+ self.swiglu = SwiGLU(config)
417
+ self.attention_norm = RMSNorm(config)
418
+ self.swiglu_norm = RMSNorm(config)
419
+
420
+ def forward(
421
+ self,
422
+ input: torch.Tensor,
423
+ mask: Optional[torch.Tensor] = None,
424
+ past_key_values: Optional[Tuple[torch.Tensor]] = None,
425
+ use_cache: bool = False,
426
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
427
+ attention_output, cached_key_values = self.attention(
428
+ self.attention_norm(input),
429
+ mask=mask,
430
+ past_key_values=past_key_values,
431
+ use_cache=use_cache,
432
+ )
433
+ # NOTE: cached_key_values is None if use_cache is False
434
+
435
+ h = input + attention_output
436
+ out = h + self.swiglu(self.swiglu_norm(h))
437
+ return out, cached_key_values
438
+
439
+
440
+ ########################################################
441
+ #
442
+ # Pico Decoder (Causal Transformer Model)
443
+ #
444
+ ########################################################
445
+
446
+
447
+ class PicoDecoder(nn.Module):
448
+ """
449
+ Pico Decoder: combines the embedding, causal decoder blocks, and output projection into a
450
+ single autoregressive model.
451
+
452
+ For more information on the model, see the classes for the modules that make up the model.
453
+ """
454
+
455
+ def __init__(
456
+ self,
457
+ model_config: Union["ModelConfig", "PicoDecoderHFConfig"],
458
+ ):
459
+ super().__init__()
460
+ self.config = model_config
461
+
462
+ self.embedding_proj = nn.Embedding(self.config.vocab_size, self.config.d_model)
463
+ self.layers = nn.ModuleList(
464
+ [PicoDecoderBlock(self.config) for _ in range(self.config.n_layers)]
465
+ )
466
+ self.output_norm = RMSNorm(self.config)
467
+ self.de_embedding_proj = nn.Linear(
468
+ self.config.d_model, self.config.vocab_size, bias=False
469
+ )
470
+
471
+ def convert_to_hf_model(self) -> "PicoDecoderHF":
472
+ """Convert the Lightning model to a HuggingFace model."""
473
+ # Create HF config without fabric-specific settings
474
+ hf_config = PicoDecoderHFConfig.from_dataclass(self.config)
475
+
476
+ # Create new HF model
477
+ hf_model = PicoDecoderHF(hf_config)
478
+
479
+ # Copy state dict, excluding fabric-specific keys
480
+ hf_model.load_state_dict(self.state_dict(prefix="pico_decoder."))
481
+
482
+ return hf_model
483
+
484
+ def forward(
485
+ self,
486
+ input_ids: torch.Tensor,
487
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
488
+ use_cache: bool = False,
489
+ ) -> Tuple[torch.Tensor, Optional[Tuple[Tuple[torch.Tensor, torch.Tensor]]]]:
490
+ """
491
+ This is the forward pass for the entire Pico model. It boils down to:
492
+ - Embedding the input ids
493
+ - Creating a causal mask
494
+ - Processing through the pico layers
495
+ - Projecting the output to logits
496
+
497
+ NOTE: One feature that might be confusing is the KV cache. The KV cache is used to speed up
498
+ generation by caching the KV pairs from previous forward passes. This is useful when doing
499
+ tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
500
+ modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
501
+ its own KV cache which is stored as a tuple. The whole model then stores a tuple of these
502
+ KV caches (so a tuple of tuples).
503
+ """
504
+
505
+ seq_len = input_ids.shape[-1]
506
+ h = self.embedding_proj(input_ids)
507
+
508
+ # Calculate start position from past cached KV pairs. Remember that each layer has its
509
+ # own KV Cache. So when we index past_key_values, we need to index into the KV pairs for the
510
+ # correct layer and then for either the keys or values.
511
+ start_pos = 0
512
+ if (
513
+ past_key_values is not None
514
+ and past_key_values[0] is not None
515
+ and past_key_values[0][0] is not None
516
+ ):
517
+ start_pos = past_key_values[0][0].shape[1]
518
+
519
+ # Create causal mask for current sequence
520
+ mask = None
521
+ if seq_len > 1:
522
+ mask = torch.full((seq_len, seq_len), float("-inf"))
523
+ mask = torch.triu(mask, diagonal=1)
524
+
525
+ # If using KV cache, extend mask to cover cached sequence length
526
+ if past_key_values is not None:
527
+ # Add zeros for cached tokens (we can attend to all of them)
528
+ mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
529
+
530
+ mask = mask.to(h.device)
531
+
532
+ # NOTE: If we are using the cache, we need to store the cached KV pairs for each layer
533
+ # in a tuple. Each layer will have its own cached KV pair which we aggregate in a tuple.
534
+ cached_key_values = () if use_cache else None
535
+
536
+ # Process through transformer blocks
537
+ for idx, layer in enumerate(self.layers):
538
+ layer_past_key_values = None
539
+ if past_key_values is not None:
540
+ try:
541
+ # Handle both tuple-based cache and HuggingFace cache objects
542
+ if hasattr(past_key_values, "__getitem__") and idx < len(
543
+ past_key_values
544
+ ):
545
+ layer_past_key_values = past_key_values[idx]
546
+ except (KeyError, IndexError, TypeError):
547
+ # If we can't access the cache properly, just skip it
548
+ layer_past_key_values = None
549
+
550
+ h, layer_cached_key_values = layer(
551
+ h, mask=mask, past_key_values=layer_past_key_values, use_cache=use_cache
552
+ )
553
+
554
+ if use_cache:
555
+ cached_key_values += (layer_cached_key_values,)
556
+
557
+ # Final norm and projection
558
+ h = self.output_norm(h)
559
+ logits = self.de_embedding_proj(h).float()
560
+
561
+ return logits, cached_key_values
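# --- Illustrative sketch (editorial addition, not part of the committed file) ---
# Minimal greedy decoding over the bare PicoDecoder, showing how the
# tuple-of-tuples KV cache is threaded from one call to the next: a full
# "prefill" pass over the prompt, then one-token steps that pass the cache back in.
import torch

def greedy_decode(model: "PicoDecoder", prompt_ids: torch.Tensor, max_new_tokens: int = 8) -> torch.Tensor:
    model.eval()
    ids = prompt_ids
    with torch.no_grad():
        logits, cache = model(ids, use_cache=True)  # prefill
        for _ in range(max_new_tokens):
            next_id = logits[:, -1].argmax(dim=-1, keepdim=True)
            ids = torch.cat([ids, next_id], dim=-1)
            logits, cache = model(next_id, past_key_values=cache, use_cache=True)
    return ids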
562
+
563
+
564
+ ########################################################
565
+ #
566
+ # HuggingFace Wrapper for the Pico Decoder model.
567
+ #
568
+ ########################################################
569
+
570
+
571
+ class PicoDecoderHFConfig(PretrainedConfig):
572
+ """Config class for the Pico Decoder HuggingFace wrapper."""
573
+
574
+ model_type = "pico_decoder"
575
+
576
+ @classmethod
577
+ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PicoDecoderHFConfig":
578
+ """
579
+ Initialize config from a dictionary. Note that no kwargs are passed to the constructor --
580
+ this is because with some kwargs special handling is required and can make this class
581
+ brittle.
582
+ """
583
+ pico_config = cls(**config_dict)
584
+
585
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
586
+ unused_kwargs = {
587
+ key: value for key, value in kwargs.items() if not hasattr(pico_config, key)
588
+ }
589
+
590
+ if return_unused_kwargs:
591
+ return pico_config, unused_kwargs
592
+ return pico_config
593
+
594
+ @classmethod
595
+ def from_dataclass(cls, model_config: "ModelConfig"):
596
+ """Initialise from our custom config dataclass."""
597
+ return cls.from_dict(asdict(model_config))
598
+
599
+
600
+ class PicoDecoderHF(PreTrainedModel, GenerationMixin):
601
+ """
602
+ HuggingFace wrapper for the Pico model with generation support.
603
+
604
+ Many evaluation frameworks require a model be setup as a HuggingFace model, so we provide a simple
605
+ wrapper that does just that. When we save checkpoints of the Pico model, we save both the normal
606
+ Pico model as well as the model wrapped in this HuggingFace class.
607
+
608
+ This also lets you do cool things like:
609
+
610
+ `model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint")`
611
+ """
612
+
613
+ config_class = PicoDecoderHFConfig
614
+ _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
615
+ main_input_name = "input_ids"
616
+
617
+ def __init__(self, config: PicoDecoderHFConfig):
618
+ super().__init__(config)
619
+ self.pico_decoder = PicoDecoder(config)
620
+ # Initialize generation config with defaults
621
+ self.generation_config = GenerationConfig()
622
+ # Set some reasonable defaults for the model
623
+ if hasattr(config, "max_position_embeddings"):
624
+ self.generation_config.max_length = config.max_position_embeddings
625
+ if hasattr(config, "vocab_size"):
626
+ self.generation_config.vocab_size = config.vocab_size
627
+
628
+ def forward(
629
+ self,
630
+ input_ids: torch.Tensor,
631
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
632
+ use_cache: bool = False,
633
+ **kwargs,
634
+ ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
635
+ """HuggingFace forward pass wrapper.
636
+
637
+ Forwards pass for the HuggingFace version of the Pico Model. Basic wrapper around the
638
+ Pico model's forward pass, and returns the output as a HuggingFace CausalLMOutput.
639
+ """
640
+ logits, past_key_values = self.pico_decoder(
641
+ input_ids, past_key_values, use_cache
642
+ )
643
+ if use_cache:
644
+ return CausalLMOutputWithPast(
645
+ logits=logits,
646
+ past_key_values=past_key_values,
647
+ )
648
+ else:
649
+ return CausalLMOutput(
650
+ logits=logits,
651
+ )
652
+
653
+ def prepare_inputs_for_generation(
654
+ self,
655
+ input_ids: torch.LongTensor,
656
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
657
+ attention_mask: Optional[torch.LongTensor] = None,
658
+ **kwargs,
659
+ ) -> Dict[str, Any]:
660
+ """
661
+ Prepare inputs for generation.
662
+
663
+ Args:
664
+ input_ids: Input token IDs
665
+ past_key_values: Cached key-value pairs from previous forward passes
666
+ attention_mask: Attention mask for the input
667
+ **kwargs: Additional arguments
668
+
669
+ Returns:
670
+ Dictionary containing prepared inputs
671
+ """
672
+ # If we have past_key_values, we only need the last token
673
+ if past_key_values is not None:
674
+ input_ids = input_ids[:, -1:]
675
+
676
+ return {
677
+ "input_ids": input_ids,
678
+ "past_key_values": past_key_values,
679
+ "use_cache": True,
680
+ }
681
+
682
+ def get_input_embeddings(self):
683
+ """Get the input embeddings layer."""
684
+ return self.pico_decoder.embedding_proj
685
+
686
+ def set_input_embeddings(self, value):
687
+ """Set the input embeddings layer."""
688
+ self.pico_decoder.embedding_proj = value
689
+
690
+ def get_output_embeddings(self):
691
+ """Get the output embeddings layer."""
692
+ return self.pico_decoder.de_embedding_proj
693
+
694
+ def set_output_embeddings(self, value):
695
+ """Set the output embeddings layer."""
696
+ self.pico_decoder.de_embedding_proj = value
697
+
698
+ def get_lm_head(self):
699
+ """Get the language model head."""
700
+ return self.pico_decoder.de_embedding_proj
701
+
702
+ def can_generate(self) -> bool:
703
+ """Check if the model can generate text."""
704
+ return True
705
+
706
+ @property
707
+ def is_encoder_decoder(self) -> bool:
708
+ """Check if the model is an encoder-decoder model."""
709
+ return False
710
+
711
+ @property
712
+ def can_use_cache(self) -> bool:
713
+ """Check if the model can use KV cache."""
714
+ return True
715
+
716
+ def resize_token_embeddings(
717
+ self, new_num_tokens: Optional[int] = None
718
+ ) -> torch.nn.Embedding:
719
+ """Resize token embeddings."""
720
+ old_embeddings = self.get_input_embeddings()
721
+ if new_num_tokens is None:
722
+ new_num_tokens = old_embeddings.num_embeddings
723
+
724
+ new_embeddings = torch.nn.Embedding(
725
+ new_num_tokens, old_embeddings.embedding_dim
726
+ )
727
+ new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
728
+ old_embeddings.weight.data
729
+ )
730
+
731
+ self.pico_decoder.embedding_proj = new_embeddings
732
+ self.pico_decoder.de_embedding_proj = torch.nn.Linear(
733
+ old_embeddings.embedding_dim, new_num_tokens, bias=False
734
+ )
735
+
736
+ return new_embeddings
737
+
738
+
739
+ # Register for auto classes
740
+ PicoDecoderHFConfig.register_for_auto_class()
741
+ PicoDecoderHF.register_for_auto_class("AutoModel")
742
+ PicoDecoderHF.register_for_auto_class("AutoModelForCausalLM")
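# --- Illustrative usage sketch (editorial addition, not part of the committed file) ---
# Because the config/model classes are registered above and referenced via the
# "auto_map" entries in config.json, an uploaded checkpoint folder from this commit
# can be loaded with the standard auto classes. trust_remote_code=True is required
# because this pico_decoder.py module ships alongside the weights.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    ckpt = "pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000"
    tokenizer = AutoTokenizer.from_pretrained(ckpt)
    model = AutoModelForCausalLM.from_pretrained(ckpt, trust_remote_code=True)
    inputs = tokenizer("The tiny decoder", return_tensors="pt")
    out = model.generate(**inputs, max_new_tokens=16)
    print(tokenizer.decode(out[0]))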
743
+
744
+
745
+ ########################################################
746
+ #
747
+ # New PicoDecoderForCausalLM class for generation support
748
+ #
749
+ ########################################################
750
+
751
+
752
+ class PicoDecoderForCausalLM(PreTrainedModel, GenerationMixin):
753
+ """
754
+ PicoDecoderForCausalLM: A HuggingFace-compatible model that properly supports generation.
755
+
756
+ This class is designed to work with existing checkpoints and provides full generation support.
757
+ It inherits from the right base classes that HuggingFace expects for text generation.
758
+ """
759
+
760
+ config_class = PicoDecoderHFConfig
761
+ _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
762
+ main_input_name = "input_ids"
763
+
764
+ def __init__(self, config: PicoDecoderHFConfig):
765
+ super().__init__(config)
766
+ self.pico_decoder = PicoDecoder(config)
767
+ # Initialize generation config with defaults
768
+ self.generation_config = GenerationConfig()
769
+ # Set some reasonable defaults for the model
770
+ if hasattr(config, "max_position_embeddings"):
771
+ self.generation_config.max_length = config.max_position_embeddings
772
+ if hasattr(config, "vocab_size"):
773
+ self.generation_config.vocab_size = config.vocab_size
774
+
775
+ def forward(
776
+ self,
777
+ input_ids: torch.Tensor,
778
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
779
+ use_cache: bool = False,
780
+ **kwargs,
781
+ ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
782
+ """Forward pass for text generation."""
783
+ logits, past_key_values = self.pico_decoder(
784
+ input_ids, past_key_values, use_cache
785
+ )
786
+ if use_cache:
787
+ return CausalLMOutputWithPast(
788
+ logits=logits,
789
+ past_key_values=past_key_values,
790
+ )
791
+ else:
792
+ return CausalLMOutput(
793
+ logits=logits,
794
+ )
795
+
796
+ def prepare_inputs_for_generation(
797
+ self,
798
+ input_ids: torch.LongTensor,
799
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
800
+ attention_mask: Optional[torch.LongTensor] = None,
801
+ **kwargs,
802
+ ) -> Dict[str, Any]:
803
+ """Prepare inputs for generation."""
804
+ # If we have past_key_values, we only need the last token
805
+ if past_key_values is not None:
806
+ input_ids = input_ids[:, -1:]
807
+
808
+ return {
809
+ "input_ids": input_ids,
810
+ "past_key_values": past_key_values,
811
+ "use_cache": True,
812
+ }
813
+
814
+ def get_input_embeddings(self):
815
+ """Get the input embeddings layer."""
816
+ return self.pico_decoder.embedding_proj
817
+
818
+ def set_input_embeddings(self, value):
819
+ """Set the input embeddings layer."""
820
+ self.pico_decoder.embedding_proj = value
821
+
822
+ def get_output_embeddings(self):
823
+ """Get the output embeddings layer."""
824
+ return self.pico_decoder.de_embedding_proj
825
+
826
+ def set_output_embeddings(self, value):
827
+ """Set the output embeddings layer."""
828
+ self.pico_decoder.de_embedding_proj = value
829
+
830
+ def get_lm_head(self):
831
+ """Get the language model head."""
832
+ return self.pico_decoder.de_embedding_proj
833
+
834
+ def can_generate(self) -> bool:
835
+ """Check if the model can generate text."""
836
+ return True
837
+
838
+ @property
839
+ def is_encoder_decoder(self) -> bool:
840
+ """Check if the model is an encoder-decoder model."""
841
+ return False
842
+
843
+ @property
844
+ def can_use_cache(self) -> bool:
845
+ """Check if the model can use KV cache."""
846
+ return True
847
+
848
+ def resize_token_embeddings(
849
+ self, new_num_tokens: Optional[int] = None
850
+ ) -> torch.nn.Embedding:
851
+ """Resize token embeddings."""
852
+ old_embeddings = self.get_input_embeddings()
853
+ if new_num_tokens is None:
854
+ new_num_tokens = old_embeddings.num_embeddings
855
+
856
+ new_embeddings = torch.nn.Embedding(
857
+ new_num_tokens, old_embeddings.embedding_dim
858
+ )
859
+ new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
860
+ old_embeddings.weight.data
861
+ )
862
+
863
+ self.pico_decoder.embedding_proj = new_embeddings
864
+ self.pico_decoder.de_embedding_proj = torch.nn.Linear(
865
+ old_embeddings.embedding_dim, new_num_tokens, bias=False
866
+ )
867
+
868
+ return new_embeddings
869
+
870
+ @classmethod
871
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
872
+ """
873
+ Load a pretrained model from a checkpoint.
874
+
875
+ This method handles loading from both the old PicoDecoderHF format and the new format.
876
+ """
877
+ # First try to load with the new class
878
+ try:
879
+ return super().from_pretrained(
880
+ pretrained_model_name_or_path, *model_args, **kwargs
881
+ )
882
+ except Exception as e:
883
+ print(f"Failed to load with new class: {e}")
884
+ print("Attempting to load with legacy class and convert...")
885
+
886
+ # Try to load with the old class and convert
887
+ try:
888
+ from transformers import AutoModel
889
+
890
+ old_model = AutoModel.from_pretrained(
891
+ pretrained_model_name_or_path,
892
+ trust_remote_code=True,
893
+ *model_args,
894
+ **kwargs,
895
+ )
896
+
897
+ # Create new model instance
898
+ new_model = cls(old_model.config)
899
+
900
+ # Copy state dict
901
+ new_model.load_state_dict(old_model.state_dict(), strict=False)
902
+
903
+ return new_model
904
+
905
+ except Exception as e2:
906
+ print(f"Failed to convert from legacy format: {e2}")
907
+ raise e
908
+
909
+
910
+ # Register the new class
911
+ PicoDecoderForCausalLM.register_for_auto_class("AutoModelForCausalLM")
pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/special_tokens_map.json ADDED
@@ -0,0 +1,16 @@
+ {
+ "eos_token": {
+ "content": "<|endoftext|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "pad_token": {
+ "content": "<|padding|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ }
+ }
pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
pico-decoder-tiny-dolma20M-v1/checkpoints/step_0/tokenizer_config.json ADDED
@@ -0,0 +1,239 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": false,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "|||IP_ADDRESS|||",
8
+ "lstrip": false,
9
+ "normalized": true,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": false
13
+ },
14
+ "1": {
15
+ "content": "<|padding|>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "50254": {
23
+ "content": " ",
24
+ "lstrip": false,
25
+ "normalized": true,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": false
29
+ },
30
+ "50255": {
31
+ "content": " ",
32
+ "lstrip": false,
33
+ "normalized": true,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": false
37
+ },
38
+ "50256": {
39
+ "content": " ",
40
+ "lstrip": false,
41
+ "normalized": true,
42
+ "rstrip": false,
43
+ "single_word": false,
44
+ "special": false
45
+ },
46
+ "50257": {
47
+ "content": " ",
48
+ "lstrip": false,
49
+ "normalized": true,
50
+ "rstrip": false,
51
+ "single_word": false,
52
+ "special": false
53
+ },
54
+ "50258": {
55
+ "content": " ",
56
+ "lstrip": false,
57
+ "normalized": true,
58
+ "rstrip": false,
59
+ "single_word": false,
60
+ "special": false
61
+ },
62
+ "50259": {
63
+ "content": " ",
64
+ "lstrip": false,
65
+ "normalized": true,
66
+ "rstrip": false,
67
+ "single_word": false,
68
+ "special": false
69
+ },
70
+ "50260": {
71
+ "content": " ",
72
+ "lstrip": false,
73
+ "normalized": true,
74
+ "rstrip": false,
75
+ "single_word": false,
76
+ "special": false
77
+ },
78
+ "50261": {
79
+ "content": " ",
80
+ "lstrip": false,
81
+ "normalized": true,
82
+ "rstrip": false,
83
+ "single_word": false,
84
+ "special": false
85
+ },
86
+ "50262": {
87
+ "content": " ",
88
+ "lstrip": false,
89
+ "normalized": true,
90
+ "rstrip": false,
91
+ "single_word": false,
92
+ "special": false
93
+ },
94
+ "50263": {
95
+ "content": " ",
96
+ "lstrip": false,
97
+ "normalized": true,
98
+ "rstrip": false,
99
+ "single_word": false,
100
+ "special": false
101
+ },
102
+ "50264": {
103
+ "content": " ",
104
+ "lstrip": false,
105
+ "normalized": true,
106
+ "rstrip": false,
107
+ "single_word": false,
108
+ "special": false
109
+ },
110
+ "50265": {
111
+ "content": " ",
112
+ "lstrip": false,
113
+ "normalized": true,
114
+ "rstrip": false,
115
+ "single_word": false,
116
+ "special": false
117
+ },
118
+ "50266": {
119
+ "content": " ",
120
+ "lstrip": false,
121
+ "normalized": true,
122
+ "rstrip": false,
123
+ "single_word": false,
124
+ "special": false
125
+ },
126
+ "50267": {
127
+ "content": " ",
128
+ "lstrip": false,
129
+ "normalized": true,
130
+ "rstrip": false,
131
+ "single_word": false,
132
+ "special": false
133
+ },
134
+ "50268": {
135
+ "content": " ",
136
+ "lstrip": false,
137
+ "normalized": true,
138
+ "rstrip": false,
139
+ "single_word": false,
140
+ "special": false
141
+ },
142
+ "50269": {
143
+ "content": " ",
144
+ "lstrip": false,
145
+ "normalized": true,
146
+ "rstrip": false,
147
+ "single_word": false,
148
+ "special": false
149
+ },
150
+ "50270": {
151
+ "content": " ",
152
+ "lstrip": false,
153
+ "normalized": true,
154
+ "rstrip": false,
155
+ "single_word": false,
156
+ "special": false
157
+ },
158
+ "50271": {
159
+ "content": " ",
160
+ "lstrip": false,
161
+ "normalized": true,
162
+ "rstrip": false,
163
+ "single_word": false,
164
+ "special": false
165
+ },
166
+ "50272": {
167
+ "content": " ",
168
+ "lstrip": false,
169
+ "normalized": true,
170
+ "rstrip": false,
171
+ "single_word": false,
172
+ "special": false
173
+ },
174
+ "50273": {
175
+ "content": " ",
176
+ "lstrip": false,
177
+ "normalized": true,
178
+ "rstrip": false,
179
+ "single_word": false,
180
+ "special": false
181
+ },
182
+ "50274": {
183
+ "content": " ",
184
+ "lstrip": false,
185
+ "normalized": true,
186
+ "rstrip": false,
187
+ "single_word": false,
188
+ "special": false
189
+ },
190
+ "50275": {
191
+ "content": " ",
192
+ "lstrip": false,
193
+ "normalized": true,
194
+ "rstrip": false,
195
+ "single_word": false,
196
+ "special": false
197
+ },
198
+ "50276": {
199
+ "content": " ",
200
+ "lstrip": false,
201
+ "normalized": true,
202
+ "rstrip": false,
203
+ "single_word": false,
204
+ "special": false
205
+ },
206
+ "50277": {
207
+ "content": "|||EMAIL_ADDRESS|||",
208
+ "lstrip": false,
209
+ "normalized": true,
210
+ "rstrip": false,
211
+ "single_word": false,
212
+ "special": false
213
+ },
214
+ "50278": {
215
+ "content": "|||PHONE_NUMBER|||",
216
+ "lstrip": false,
217
+ "normalized": true,
218
+ "rstrip": false,
219
+ "single_word": false,
220
+ "special": false
221
+ },
222
+ "50279": {
223
+ "content": "<|endoftext|>",
224
+ "lstrip": false,
225
+ "normalized": false,
226
+ "rstrip": false,
227
+ "single_word": false,
228
+ "special": true
229
+ }
230
+ },
231
+ "bos_token": null,
232
+ "clean_up_tokenization_spaces": true,
233
+ "eos_token": "<|endoftext|>",
234
+ "extra_special_tokens": {},
235
+ "model_max_length": 1000000000000000019884624838656,
236
+ "pad_token": "<|padding|>",
237
+ "tokenizer_class": "GPTNeoXTokenizer",
238
+ "unk_token": null
239
+ }
pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/config.json ADDED
@@ -0,0 +1,22 @@
+ {
+ "activation_hidden_dim": 384,
+ "architectures": [
+ "PicoDecoderHF"
+ ],
+ "attention_n_heads": 12,
+ "attention_n_kv_heads": 4,
+ "auto_map": {
+ "AutoConfig": "pico_decoder.PicoDecoderHFConfig",
+ "AutoModelForCausalLM": "pico_decoder.PicoDecoderHF"
+ },
+ "batch_size": 1024,
+ "d_model": 96,
+ "max_seq_len": 2048,
+ "model_type": "pico_decoder",
+ "n_layers": 12,
+ "norm_eps": 1e-06,
+ "position_emb_theta": 10000.0,
+ "torch_dtype": "float32",
+ "transformers_version": "4.48.3",
+ "vocab_size": 50304
+ }
pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/fabric_state/checkpoint.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:92fa51a0afd806b08b0d199e7d2ff4555923904d7ef132046182de3335c38e8e
+ size 135543171
pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/generation_config.json ADDED
@@ -0,0 +1,4 @@
+ {
+ "transformers_version": "4.48.3",
+ "vocab_size": 50304
+ }
pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/learning_dynamics/train_activations.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d847298841bc229f66a0842b7938bf594ff1509b77e2d668d17238105391ed1f
+ size 98331
pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/learning_dynamics/train_data/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a7dc168d01315589299bb5c8857c28a085e3fef703290a2b99d551bc33a6fdf0
+ size 277160
pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/learning_dynamics/train_data/dataset_info.json ADDED
@@ -0,0 +1,19 @@
+ {
+ "citation": "",
+ "description": "",
+ "features": {
+ "input_ids": {
+ "feature": {
+ "dtype": "int32",
+ "_type": "Value"
+ },
+ "_type": "Sequence"
+ },
+ "text": {
+ "dtype": "string",
+ "_type": "Value"
+ }
+ },
+ "homepage": "",
+ "license": ""
+ }
pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/learning_dynamics/train_data/state.json ADDED
@@ -0,0 +1,13 @@
+ {
+ "_data_files": [
+ {
+ "filename": "data-00000-of-00001.arrow"
+ }
+ ],
+ "_fingerprint": "1d11f8d9010f1e26",
+ "_format_columns": null,
+ "_format_kwargs": {},
+ "_format_type": null,
+ "_output_all_columns": false,
+ "_split": null
+ }
pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/learning_dynamics/train_gradients.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:067570bee041e27414f53e1579b8269c0122124602a4b61263453baed7b22cb9
+ size 2371527
pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/learning_dynamics/train_weights.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ca9e9ed20e2ee9b41c6999b5300990a19c499db05b6dcf0de03c17627480f2b5
+ size 2371443
pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:594904d9e9d616c61f90057c2e32dfd5323b1994996434891eacb57abf9193f1
+ size 45143592
pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/pico_decoder.py ADDED
@@ -0,0 +1,911 @@
1
+ """
2
+ Pico Decoder: A Lightweight Causal Transformer Language Model
3
+
4
+ Pico Decoder uses a simple LLAMA-style transformer architecture, written for clarity and educational purposes.
5
+
6
+ Everything is written with a modular design for easy modification and experimentation.
7
+
8
+ Key features:
9
+ - RMSNorm for layer normalization
10
+ - Rotary Positional Embeddings (RoPE)
11
+ - Multi-head attention with KV-cache support
12
+ - SwiGLU activation function
13
+ - Residual connections throughout
14
+
15
+ - KV-cache for faster autoregressive generation
16
+
17
+ References:
18
+ - RoPE: https://arxiv.org/abs/2104.09864
19
+ - SwiGLU: https://arxiv.org/abs/2002.05202
20
+ - LLAMA: https://arxiv.org/abs/2302.13971
21
+
22
+ Adapted from:
23
+ - OLMO: https://github.com/allenai/OLMo
24
+ - LLAMA: https://github.com/meta/llama
25
+ """
26
+
27
+ from dataclasses import asdict
28
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
29
+
30
+ import torch
31
+ import torch.nn as nn
32
+ import torch.nn.functional as F
33
+
34
+ # Handle PyTorch version compatibility for attention backend
35
+ try:
36
+ from torch.nn.attention import SDPBackend, sdpa_kernel
37
+
38
+ HAS_TORCH_ATTENTION = True
39
+ except ImportError:
40
+ # Fallback for older PyTorch versions
41
+ HAS_TORCH_ATTENTION = False
42
+ SDPBackend = None
43
+ sdpa_kernel = None
44
+
45
+ from transformers import GenerationMixin, PretrainedConfig, PreTrainedModel
46
+ from transformers.generation import GenerationConfig
47
+ from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
48
+
49
+ try:
50
+ if TYPE_CHECKING:
51
+ # We need to do this to avoid importing these when creating the HF-compatible models
52
+ from src.config import ModelConfig
53
+ except ImportError:
54
+ pass
55
+
56
+ ########################################################
57
+ #
58
+ # Layer Normalization
59
+ #
60
+ ########################################################
61
+
62
+
63
+ class RMSNorm(torch.nn.Module):
64
+ """Root Mean Square Layer Normalization.
65
+
66
+ A variant of Layer Normalization that uses RMS statistics instead of mean/variance,
67
+ resulting in improved stability and performance.
68
+
69
+ Args:
70
+ config (Union[ModelConfig, PicoHFConfig]): Configuration object containing normalization parameters
71
+ - config.norm_eps: Small constant for numerical stability
72
+ - config.d_model: Model dimension for the weight parameter
73
+
74
+ References:
75
+ https://arxiv.org/abs/1910.07467
76
+ """
77
+
78
+ def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
79
+ super().__init__()
80
+ self.eps = config.norm_eps
81
+ self.weight = nn.Parameter(torch.ones(config.d_model))
82
+
83
+ def _norm(self, x: torch.Tensor) -> torch.Tensor:
84
+ """
85
+ Normalizes the input tensor by its RMS value.
86
+ """
87
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
88
+
89
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
90
+ """
91
+ Applies RMS normalization to the input tensor and scales it by the weight parameter.
92
+ """
93
+ output = self._norm(x.float()).type_as(x)
94
+ return output * self.weight
95
+
96
+
97
+ ########################################################
98
+ #
99
+ # Positional Embedding
100
+ #
101
+ ########################################################
102
+
103
+
104
+ class RoPE(nn.Module):
105
+ """Rotary Positional Embeddings (RoPE).
106
+
107
+ Implements position-dependent rotation of keys and queries in attention mechanism,
108
+ allowing better modeling of relative positions in sequences. Uses complex number
109
+ operations for efficient rotation.
110
+
111
+ Args:
112
+ config (Union[ModelConfig, PicoDecoderHFConfig]): Model configuration containing:
113
+ - config.position_emb_theta: Base for frequency computation
114
+ - config.d_model: Model dimension
115
+ - config.attention_n_heads: Number of attention heads
116
+ - config.max_seq_len: Maximum sequence length
117
+
118
+ References:
119
+ https://arxiv.org/abs/2104.09864
120
+ """
121
+
122
+ _freqs_cis_tensor: torch.Tensor | None = None
123
+
124
+ def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
125
+ super().__init__()
126
+
127
+ self.theta = config.position_emb_theta
128
+ self.dim = config.d_model // config.attention_n_heads
129
+
130
+ max_seq_len = config.max_seq_len
131
+
132
+ # only gets set once, and then reused for all RoPE instances
133
+ if RoPE._freqs_cis_tensor is None:
134
+ RoPE._freqs_cis_tensor = self._setup_freqs_cis(
135
+ max_seq_len, self.theta, self.dim
136
+ )
137
+
138
+ # register _freqs_cis buffer
139
+ # can be easily recomputed so persistent=False
140
+ self.register_buffer("_freqs_cis", self._freqs_cis_tensor, persistent=False)
141
+
142
+ @classmethod
143
+ def _setup_freqs_cis(cls, seq_len: int, theta: float, dim: int) -> torch.Tensor:
144
+ """Setup Frequency Tensor for RoPE Embeddings
145
+
146
+ Initializes the complex frequency tensor that is used to compute the RoPE embeddings.
147
+
148
+ Note other implementations will use cos and sin directly, but using the complex
149
+ number representation is (probably) more efficient:
150
+
151
+ e^(theta * i * t) = cos(theta * t) + i * sin(theta * t) [Euler's formula]
152
+ """
153
+ _freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
154
+ positions = torch.arange(seq_len)
155
+ freqs = torch.outer(positions, _freqs)
156
+ return torch.polar(torch.ones_like(freqs), freqs) # complex64
157
+
158
+ def get_freqs_cis(
159
+ self, input_shape: torch.Size, start_pos: int, end_pos: int
160
+ ) -> torch.Tensor:
161
+ """Reshape Frequency Tensor for RoPE Embeddings
162
+
163
+ Makes the frequency tensor broadcastable with the input tensor.
164
+ """
165
+ _freqs_cis = self._freqs_cis[start_pos:end_pos]
166
+ ndim = len(input_shape)
167
+ assert 0 <= 1 < ndim
168
+ assert _freqs_cis.shape == (input_shape[1], input_shape[-1])
169
+
170
+ # TODO: Check whether this is correct (might be able to remove this)
171
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(input_shape)]
172
+ return _freqs_cis.view(*shape)
173
+
174
+ def forward(
175
+ self,
176
+ queries: torch.Tensor,
177
+ keys: torch.Tensor,
178
+ start_pos: int = 0,
179
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
180
+ """Apply RoPE Embeddings to Queries and Keys
181
+
182
+ Applies the rotary positional embeddings to the input tensors via complex number multiplication.
183
+
184
+ NOTE: The start_pos is used if we want to use the kv_cache in the attention mechanism.
185
+ """
186
+ queries_ = torch.view_as_complex(
187
+ queries.float().reshape(*queries.shape[:-1], -1, 2)
188
+ )
189
+ keys_ = torch.view_as_complex(keys.float().reshape(*keys.shape[:-1], -1, 2))
190
+
191
+ input_shape = (
192
+ queries_.shape
193
+ ) # same as keys: (batch_size, seq_len, n_heads, head_dim/2)
194
+ freqs_start_pos = start_pos
195
+ freqs_end_pos = freqs_start_pos + queries_.shape[1]
196
+
197
+ freqs_cis = self.get_freqs_cis(input_shape, freqs_start_pos, freqs_end_pos)
198
+
199
+ queries_rotated = torch.view_as_real(queries_ * freqs_cis).flatten(3)
200
+ keys_rotated = torch.view_as_real(keys_ * freqs_cis).flatten(3)
201
+ return queries_rotated.type_as(queries), keys_rotated.type_as(keys)
202
+
203
+
204
+ ########################################################
205
+ #
206
+ # Attention
207
+ #
208
+ ########################################################
209
+
210
+
211
+ class Attention(nn.Module):
212
+ """Multi-head Attention with Group Query Attention support.
213
+
214
+ Implements scaled dot-product attention and supports:
215
+ - Grouped Query Attention (GQA)
216
+ - Key-Value caching for efficient inference
217
+ - RoPE integration
218
+
219
+ Args:
220
+ config (Union[ModelConfig, PretrainedConfig]): Configuration containing:
221
+ - config.attention_n_heads: Number of attention heads
222
+ - config.attention_n_kv_heads: Number of key/value heads
223
+ - config.d_model: Model dimension
224
+ - config.batch_size: Maximum batch size
225
+ - config.max_seq_len: Maximum sequence length
226
+
227
+ Shape:
228
+ - Input: (batch_size, seq_len, d_model)
229
+ - Output: (batch_size, seq_len, d_model)
230
+ """
231
+
232
+ def __init__(
233
+ self,
234
+ config: Union["ModelConfig", "PicoDecoderHFConfig"],
235
+ ):
236
+ super().__init__()
237
+
238
+ self.n_heads = config.attention_n_heads
239
+ self.n_kv_heads = config.attention_n_kv_heads
240
+
241
+ self.batch_size = config.batch_size
242
+ self.max_seq_len = config.max_seq_len
243
+
244
+ d_model = config.d_model
245
+ self.head_dim = d_model // self.n_heads
246
+
247
+ self.n_rep = self.n_heads // self.n_kv_heads
248
+
249
+ self.q_proj = nn.Linear(d_model, self.n_heads * self.head_dim, bias=False)
250
+ self.k_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
251
+ self.v_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
252
+ self.o_proj = nn.Linear(self.n_heads * self.head_dim, d_model, bias=False)
253
+
254
+ self.rope = RoPE(config)
255
+
256
+ def forward(
257
+ self,
258
+ input: torch.Tensor,
259
+ mask: Optional[torch.Tensor] = None,
260
+ past_key_values: Optional[Tuple[torch.Tensor, ...]] = None,
261
+ use_cache: bool = False,
262
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
263
+ """Forward pass for the attention mechanism.
264
+
265
+ Computes queries, keys, and values for the attention mechanism. Applies rotary positional
266
+ embeddings to the queries and keys, and then computes attention scores and outputs.
267
+
268
+ For an introduction to the attention mechanism, see:
269
+ https://arxiv.org/abs/1706.03762
270
+
271
+ A few things to note:
272
+ - The past_key_values is used to implement the KV cache, which is used to speed up
273
+ generation by caching the KV pairs from previous forward passes. This is useful when doing
274
+ tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
275
+ modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
276
+ its own KV cache - this KV cache is implemented as a tuple.
277
+ """
278
+ bsz, seq_len, _ = input.shape
279
+ _queries, _keys, _values = (
280
+ self.q_proj(input),
281
+ self.k_proj(input),
282
+ self.v_proj(input),
283
+ )
284
+
285
+ # Reshaping for multi-head attention
286
+ queries = _queries.view(bsz, seq_len, self.n_heads, self.head_dim)
287
+ keys = _keys.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
288
+ values = _values.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
289
+
290
+ # The start position is used to apply the RoPE embeddings to only the new tokens
291
+ # when using the kv_cache in the attention mechanism.
292
+ # We want to start from the last position in the cache.
293
+ start_pos = 0
294
+ if past_key_values is not None and past_key_values[0] is not None:
295
+ start_pos = past_key_values[0].shape[1]
296
+
297
+ # apply rotary positional embeddings
298
+ queries, keys = self.rope(queries, keys, start_pos)
299
+
300
+ if (
301
+ past_key_values is not None
302
+ and past_key_values[0] is not None
303
+ and past_key_values[1] is not None
304
+ ):
305
+ keys = torch.cat([past_key_values[0], keys], dim=1)
306
+ values = torch.cat([past_key_values[1], values], dim=1)
307
+
308
+ if use_cache:
309
+ cached_keys = keys
310
+ cached_values = values
311
+ else:
312
+ cached_keys = None
313
+ cached_values = None
314
+
315
+ queries = queries.transpose(1, 2)
316
+ keys = keys.transpose(1, 2)
317
+ values = values.transpose(1, 2)
318
+
319
+ apply_gqa = self.n_rep > 1
320
+ if apply_gqa and queries.device.type == "mps":
321
+ # NOTE: MPS does not support GQA in the SDPA kernel, but we can repeat the keys and values
322
+ # outside of the kernel to get the same effect.
323
+ # See: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
324
+ keys = keys.repeat_interleave(self.n_rep, dim=-3)
325
+ values = values.repeat_interleave(self.n_rep, dim=-3)
326
+ apply_gqa = False
327
+
328
+ if HAS_TORCH_ATTENTION:
329
+ backends = [SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]
330
+ with sdpa_kernel(backends=backends):
331
+ attn_output = F.scaled_dot_product_attention(
332
+ queries.contiguous(),
333
+ keys.contiguous(),
334
+ values.contiguous(),
335
+ attn_mask=mask.to(queries.dtype) if mask is not None else None,
336
+ enable_gqa=apply_gqa,
337
+ )
338
+ else:
339
+ # Fallback for older PyTorch versions - use default backend
340
+ attn_output = F.scaled_dot_product_attention(
341
+ queries.contiguous(),
342
+ keys.contiguous(),
343
+ values.contiguous(),
344
+ attn_mask=mask.to(queries.dtype) if mask is not None else None,
345
+ enable_gqa=apply_gqa,
346
+ )
347
+
348
+ attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
349
+ output = self.o_proj(attn_output)
350
+
351
+ return output, (cached_keys, cached_values)
352
+
353
+
354
+ ########################################################
355
+ #
356
+ # SwiGLU (Combines MLP and Activation)
357
+ #
358
+ ########################################################
359
+
360
+
361
+ class SwiGLU(nn.Module):
362
+ """SwiGLU Activation Function with Linear Projections.
363
+
364
+ Implements the SwiGLU activation function combined with linear transformations,
365
+ serving as the feed-forward network in transformer blocks.
366
+
367
+ Args:
368
+ config (Union[ModelConfig, PicoDecoderHFConfig]): Configuration containing:
369
+ - config.d_model: Model dimension
370
+ - config.activation_hidden_dim: Hidden dimension (typically 4 * d_model)
371
+
372
+ References:
373
+ https://arxiv.org/abs/2002.05202
374
+ """
375
+
376
+ def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
377
+ super().__init__()
378
+
379
+ model_dim = config.d_model
380
+ act_hidden_dim = config.activation_hidden_dim # usually 4 * d_model
381
+
382
+ self.w_0 = nn.Linear(model_dim, act_hidden_dim, bias=False)
383
+ self.w_1 = nn.Linear(model_dim, act_hidden_dim, bias=False)
384
+ self.w_2 = nn.Linear(act_hidden_dim, model_dim, bias=False)
385
+
386
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
387
+ return self.w_2(F.silu(self.w_0(x)) * self.w_1(x))
388
+
389
+
390
+ ########################################################
391
+ #
392
+ # PicoDecoderBlock
393
+ #
394
+ ########################################################
395
+
396
+
397
+ class PicoDecoderBlock(nn.Module):
398
+ """Single Transformer Block with Attention and Feed-forward layers.
399
+
400
+ Implements a standard transformer block with:
401
+ - Multi-head attention with normalization and residual connection
402
+ - SwiGLU feed-forward network with normalization and residual connection
403
+
404
+ Args:
405
+ config (Union[ModelConfig, PicoDecoderHFConfig]): Model configuration; either a dataclass or
406
+ a HuggingFace PicoDecoderHFConfig
407
+ """
408
+
409
+ def __init__(
410
+ self,
411
+ config: Union["ModelConfig", "PicoDecoderHFConfig"],
412
+ ):
413
+ super().__init__()
414
+
415
+ self.attention = Attention(config)
416
+ self.swiglu = SwiGLU(config)
417
+ self.attention_norm = RMSNorm(config)
418
+ self.swiglu_norm = RMSNorm(config)
419
+
420
+ def forward(
421
+ self,
422
+ input: torch.Tensor,
423
+ mask: Optional[torch.Tensor] = None,
424
+ past_key_values: Optional[Tuple[torch.Tensor]] = None,
425
+ use_cache: bool = False,
426
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
427
+ attention_output, cached_key_values = self.attention(
428
+ self.attention_norm(input),
429
+ mask=mask,
430
+ past_key_values=past_key_values,
431
+ use_cache=use_cache,
432
+ )
433
+ # NOTE: cached_key_values is None if use_cache is False
434
+
435
+ h = input + attention_output
436
+ out = h + self.swiglu(self.swiglu_norm(h))
437
+ return out, cached_key_values
438
+
439
+
440
+ ########################################################
441
+ #
442
+ # Pico Decoder (Causal Transformer Model)
443
+ #
444
+ ########################################################
445
+
446
+
447
+ class PicoDecoder(nn.Module):
448
+ """
449
+ Pico Decoder: combines the embedding, causal decoder blocks, and output projection into a
450
+ single autoregressive model.
451
+
452
+ For more information on the model, see the classes for the modules that make up the model.
453
+ """
454
+
455
+ def __init__(
456
+ self,
457
+ model_config: Union["ModelConfig", "PicoDecoderHFConfig"],
458
+ ):
459
+ super().__init__()
460
+ self.config = model_config
461
+
462
+ self.embedding_proj = nn.Embedding(self.config.vocab_size, self.config.d_model)
463
+ self.layers = nn.ModuleList(
464
+ [PicoDecoderBlock(self.config) for _ in range(self.config.n_layers)]
465
+ )
466
+ self.output_norm = RMSNorm(self.config)
467
+ self.de_embedding_proj = nn.Linear(
468
+ self.config.d_model, self.config.vocab_size, bias=False
469
+ )
470
+
471
+ def convert_to_hf_model(self) -> "PicoDecoderHF":
472
+ """Convert the Lightning model to a HuggingFace model."""
473
+ # Create HF config without fabric-specific settings
474
+ hf_config = PicoDecoderHFConfig.from_dataclass(self.config)
475
+
476
+ # Create new HF model
477
+ hf_model = PicoDecoderHF(hf_config)
478
+
479
+ # Copy state dict, excluding fabric-specific keys
480
+ hf_model.load_state_dict(self.state_dict(prefix="pico_decoder."))
481
+
482
+ return hf_model
483
+
484
+ def forward(
485
+ self,
486
+ input_ids: torch.Tensor,
487
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
488
+ use_cache: bool = False,
489
+ ) -> Tuple[torch.Tensor, Optional[Tuple[Tuple[torch.Tensor, torch.Tensor]]]]:
490
+ """
491
+ This is the forward pass for the entire Pico model. It boils down to:
492
+ - Embedding the input ids
493
+ - Creating a causal mask
494
+ - Processing through the pico layers
495
+ - Projecting the output to logits
496
+
497
+ NOTE: One feature that might be confusing is the KV cache. The KV cache is used to speed up
498
+ generation by caching the KV pairs from previous forward passes. This is useful when doing
499
+ tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
500
+ modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
501
+ its own KV cache which is stored as a tuple. The whole model then stores a tuple of these
502
+ KV caches (so a tuple of tuples).
503
+ """
504
+
505
+ seq_len = input_ids.shape[-1]
506
+ h = self.embedding_proj(input_ids)
507
+
508
+ # Calculate start position from past cached KV pairs. Remember that each layer has its
509
+ # own KV Cache. So when we index past_key_values, we need to index into the KV pairs for the
510
+ # correct layer and then for either the keys or values.
511
+ start_pos = 0
512
+ if (
513
+ past_key_values is not None
514
+ and past_key_values[0] is not None
515
+ and past_key_values[0][0] is not None
516
+ ):
517
+ start_pos = past_key_values[0][0].shape[1]
518
+
519
+ # Create causal mask for current sequence
520
+ mask = None
521
+ if seq_len > 1:
522
+ mask = torch.full((seq_len, seq_len), float("-inf"))
523
+ mask = torch.triu(mask, diagonal=1)
524
+
525
+ # If using KV cache, extend mask to cover cached sequence length
526
+ if past_key_values is not None:
527
+ # Add zeros for cached tokens (we can attend to all of them)
528
+ mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
529
+
530
+ mask = mask.to(h.device)
531
+
532
+ # NOTE: If we are using the cache, we need to store the cached KV pairs for each layer
533
+ # in a tuple. Each layer will have its own cached KV pair which we aggregate in a tuple.
534
+ cached_key_values = () if use_cache else None
535
+
536
+ # Process through transformer blocks
537
+ for idx, layer in enumerate(self.layers):
538
+ layer_past_key_values = None
539
+ if past_key_values is not None:
540
+ try:
541
+ # Handle both tuple-based cache and HuggingFace cache objects
542
+ if hasattr(past_key_values, "__getitem__") and idx < len(
543
+ past_key_values
544
+ ):
545
+ layer_past_key_values = past_key_values[idx]
546
+ except (KeyError, IndexError, TypeError):
547
+ # If we can't access the cache properly, just skip it
548
+ layer_past_key_values = None
549
+
550
+ h, layer_cached_key_values = layer(
551
+ h, mask=mask, past_key_values=layer_past_key_values, use_cache=use_cache
552
+ )
553
+
554
+ if use_cache:
555
+ cached_key_values += (layer_cached_key_values,)
556
+
557
+ # Final norm and projection
558
+ h = self.output_norm(h)
559
+ logits = self.de_embedding_proj(h).float()
560
+
561
+ return logits, cached_key_values
562
+
563
+
564
+ ########################################################
565
+ #
566
+ # HuggingFace Wrapper for the Pico Decoder model.
567
+ #
568
+ ########################################################
569
+
570
+
571
+ class PicoDecoderHFConfig(PretrainedConfig):
572
+ """Config class for the Pico Decoder HuggingFace wrapper."""
573
+
574
+ model_type = "pico_decoder"
575
+
576
+ @classmethod
577
+ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PicoDecoderHFConfig":
578
+ """
579
+ Initialize config from a dictionary. Note that no kwargs are passed to the constructor --
580
+ this is because with some kwargs special handling is required and can make this class
581
+ brittle.
582
+ """
583
+ pico_config = cls(**config_dict)
584
+
585
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
586
+ unused_kwargs = {
587
+ key: value for key, value in kwargs.items() if not hasattr(pico_config, key)
588
+ }
589
+
590
+ if return_unused_kwargs:
591
+ return pico_config, unused_kwargs
592
+ return pico_config
593
+
594
+ @classmethod
595
+ def from_dataclass(cls, model_config: "ModelConfig"):
596
+ """Initialise from our custom config dataclass."""
597
+ return cls.from_dict(asdict(model_config))
598
+
599
+
600
+ class PicoDecoderHF(PreTrainedModel, GenerationMixin):
601
+ """
602
+ HuggingFace wrapper for the Pico model with generation support.
603
+
604
+ Many evaluation frameworks require a model be setup as a HuggingFace model, so we provide a simple
605
+ wrapper that does just that. When we save checkpoints of the Pico model, we save both the normal
606
+ Pico model as well as the model wrapped in this HuggingFace class.
607
+
608
+ This also lets you do cool things like:
609
+
610
+ `model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint")`
611
+ """
612
+
613
+ config_class = PicoDecoderHFConfig
614
+ _no_split_modules = ["PicoDecoderBlock", "Attention", "SwiGLU", "RMSNorm"]
615
+ main_input_name = "input_ids"
616
+
617
+ def __init__(self, config: PicoDecoderHFConfig):
618
+ super().__init__(config)
619
+ self.pico_decoder = PicoDecoder(config)
620
+ # Initialize generation config with defaults
621
+ self.generation_config = GenerationConfig()
622
+ # Set some reasonable defaults for the model
623
+ if hasattr(config, "max_position_embeddings"):
624
+ self.generation_config.max_length = config.max_position_embeddings
625
+ if hasattr(config, "vocab_size"):
626
+ self.generation_config.vocab_size = config.vocab_size
627
+
628
+ def forward(
629
+ self,
630
+ input_ids: torch.Tensor,
631
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
632
+ use_cache: bool = False,
633
+ **kwargs,
634
+ ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
635
+ """HuggingFace forward pass wrapper.
636
+
637
+ Forward pass for the HuggingFace version of the Pico model. It is a basic wrapper around the
638
+ Pico model's forward pass, and returns the output as a HuggingFace CausalLMOutput.
639
+ """
640
+ logits, past_key_values = self.pico_decoder(
641
+ input_ids, past_key_values, use_cache
642
+ )
643
+ if use_cache:
644
+ return CausalLMOutputWithPast(
645
+ logits=logits,
646
+ past_key_values=past_key_values,
647
+ )
648
+ else:
649
+ return CausalLMOutput(
650
+ logits=logits,
651
+ )
652
+
653
+ def prepare_inputs_for_generation(
654
+ self,
655
+ input_ids: torch.LongTensor,
656
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
657
+ attention_mask: Optional[torch.LongTensor] = None,
658
+ **kwargs,
659
+ ) -> Dict[str, Any]:
660
+ """
661
+ Prepare inputs for generation.
662
+
663
+ Args:
664
+ input_ids: Input token IDs
665
+ past_key_values: Cached key-value pairs from previous forward passes
666
+ attention_mask: Attention mask for the input
667
+ **kwargs: Additional arguments
668
+
669
+ Returns:
670
+ Dictionary containing prepared inputs
671
+ """
672
+ # If we have past_key_values, we only need the last token
673
+ if past_key_values is not None:
674
+ input_ids = input_ids[:, -1:]
675
+
676
+ return {
677
+ "input_ids": input_ids,
678
+ "past_key_values": past_key_values,
679
+ "use_cache": True,
680
+ }
681
+
682
+ def get_input_embeddings(self):
683
+ """Get the input embeddings layer."""
684
+ return self.pico_decoder.embedding_proj
685
+
686
+ def set_input_embeddings(self, value):
687
+ """Set the input embeddings layer."""
688
+ self.pico_decoder.embedding_proj = value
689
+
690
+ def get_output_embeddings(self):
691
+ """Get the output embeddings layer."""
692
+ return self.pico_decoder.de_embedding_proj
693
+
694
+ def set_output_embeddings(self, value):
695
+ """Set the output embeddings layer."""
696
+ self.pico_decoder.de_embedding_proj = value
697
+
698
+ def get_lm_head(self):
699
+ """Get the language model head."""
700
+ return self.pico_decoder.de_embedding_proj
701
+
702
+ def can_generate(self) -> bool:
703
+ """Check if the model can generate text."""
704
+ return True
705
+
706
+ @property
707
+ def is_encoder_decoder(self) -> bool:
708
+ """Check if the model is an encoder-decoder model."""
709
+ return False
710
+
711
+ @property
712
+ def can_use_cache(self) -> bool:
713
+ """Check if the model can use KV cache."""
714
+ return True
715
+
716
+ def resize_token_embeddings(
717
+ self, new_num_tokens: Optional[int] = None
718
+ ) -> torch.nn.Embedding:
719
+ """Resize token embeddings."""
720
+ old_embeddings = self.get_input_embeddings()
721
+ if new_num_tokens is None:
722
+ new_num_tokens = old_embeddings.num_embeddings
723
+
724
+ new_embeddings = torch.nn.Embedding(
725
+ new_num_tokens, old_embeddings.embedding_dim
726
+ )
727
+ new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
728
+ old_embeddings.weight.data
729
+ )
730
+
731
+ self.pico_decoder.embedding_proj = new_embeddings
732
+ self.pico_decoder.de_embedding_proj = torch.nn.Linear(
733
+ old_embeddings.embedding_dim, new_num_tokens, bias=False
734
+ )
735
+
736
+ return new_embeddings
737
+
738
+
739
+ # Register for auto classes
740
+ PicoDecoderHFConfig.register_for_auto_class()
741
+ PicoDecoderHF.register_for_auto_class("AutoModel")
742
+ PicoDecoderHF.register_for_auto_class("AutoModelForCausalLM")
743
+
744
+
745
+ ########################################################
746
+ #
747
+ # New PicoDecoderForCausalLM class for generation support
748
+ #
749
+ ########################################################
750
+
751
+
752
+ class PicoDecoderForCausalLM(PreTrainedModel, GenerationMixin):
753
+ """
754
+ PicoDecoderForCausalLM: A HuggingFace-compatible model that properly supports generation.
755
+
756
+ This class is designed to work with existing checkpoints and provides full generation support.
757
+ It inherits from the right base classes that HuggingFace expects for text generation.
758
+ """
759
+
760
+ config_class = PicoDecoderHFConfig
761
+ _no_split_modules = ["PicoDecoderBlock", "Attention", "SwiGLU", "RMSNorm"]
762
+ main_input_name = "input_ids"
763
+
764
+ def __init__(self, config: PicoDecoderHFConfig):
765
+ super().__init__(config)
766
+ self.pico_decoder = PicoDecoder(config)
767
+ # Initialize generation config with defaults
768
+ self.generation_config = GenerationConfig()
769
+ # Set some reasonable defaults for the model
770
+ if hasattr(config, "max_position_embeddings"):
771
+ self.generation_config.max_length = config.max_position_embeddings
772
+ if hasattr(config, "vocab_size"):
773
+ self.generation_config.vocab_size = config.vocab_size
774
+
775
+ def forward(
776
+ self,
777
+ input_ids: torch.Tensor,
778
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
779
+ use_cache: bool = False,
780
+ **kwargs,
781
+ ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
782
+ """Forward pass for text generation."""
783
+ logits, past_key_values = self.pico_decoder(
784
+ input_ids, past_key_values, use_cache
785
+ )
786
+ if use_cache:
787
+ return CausalLMOutputWithPast(
788
+ logits=logits,
789
+ past_key_values=past_key_values,
790
+ )
791
+ else:
792
+ return CausalLMOutput(
793
+ logits=logits,
794
+ )
795
+
796
+ def prepare_inputs_for_generation(
797
+ self,
798
+ input_ids: torch.LongTensor,
799
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
800
+ attention_mask: Optional[torch.LongTensor] = None,
801
+ **kwargs,
802
+ ) -> Dict[str, Any]:
803
+ """Prepare inputs for generation."""
804
+ # If we have past_key_values, we only need the last token
805
+ if past_key_values is not None:
806
+ input_ids = input_ids[:, -1:]
807
+
808
+ return {
809
+ "input_ids": input_ids,
810
+ "past_key_values": past_key_values,
811
+ "use_cache": True,
812
+ }
813
+
814
+ def get_input_embeddings(self):
815
+ """Get the input embeddings layer."""
816
+ return self.pico_decoder.embedding_proj
817
+
818
+ def set_input_embeddings(self, value):
819
+ """Set the input embeddings layer."""
820
+ self.pico_decoder.embedding_proj = value
821
+
822
+ def get_output_embeddings(self):
823
+ """Get the output embeddings layer."""
824
+ return self.pico_decoder.de_embedding_proj
825
+
826
+ def set_output_embeddings(self, value):
827
+ """Set the output embeddings layer."""
828
+ self.pico_decoder.de_embedding_proj = value
829
+
830
+ def get_lm_head(self):
831
+ """Get the language model head."""
832
+ return self.pico_decoder.de_embedding_proj
833
+
834
+ def can_generate(self) -> bool:
835
+ """Check if the model can generate text."""
836
+ return True
837
+
838
+ @property
839
+ def is_encoder_decoder(self) -> bool:
840
+ """Check if the model is an encoder-decoder model."""
841
+ return False
842
+
843
+ @property
844
+ def can_use_cache(self) -> bool:
845
+ """Check if the model can use KV cache."""
846
+ return True
847
+
848
+ def resize_token_embeddings(
849
+ self, new_num_tokens: Optional[int] = None
850
+ ) -> torch.nn.Embedding:
851
+ """Resize token embeddings."""
852
+ old_embeddings = self.get_input_embeddings()
853
+ if new_num_tokens is None:
854
+ new_num_tokens = old_embeddings.num_embeddings
855
+
856
+ new_embeddings = torch.nn.Embedding(
857
+ new_num_tokens, old_embeddings.embedding_dim
858
+ )
859
+ new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
860
+ old_embeddings.weight.data
861
+ )
862
+
863
+ self.pico_decoder.embedding_proj = new_embeddings
864
+ self.pico_decoder.de_embedding_proj = torch.nn.Linear(
865
+ old_embeddings.embedding_dim, new_num_tokens, bias=False
866
+ )
867
+
868
+ return new_embeddings
869
+
870
+ @classmethod
871
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
872
+ """
873
+ Load a pretrained model from a checkpoint.
874
+
875
+ This method handles loading from both the old PicoDecoderHF format and the new format.
876
+ """
877
+ # First try to load with the new class
878
+ try:
879
+ return super().from_pretrained(
880
+ pretrained_model_name_or_path, *model_args, **kwargs
881
+ )
882
+ except Exception as e:
883
+ print(f"Failed to load with new class: {e}")
884
+ print("Attempting to load with legacy class and convert...")
885
+
886
+ # Try to load with the old class and convert
887
+ try:
888
+ from transformers import AutoModel
889
+
890
+ old_model = AutoModel.from_pretrained(
891
+ pretrained_model_name_or_path,
892
+ trust_remote_code=True,
893
+ *model_args,
894
+ **kwargs,
895
+ )
896
+
897
+ # Create new model instance
898
+ new_model = cls(old_model.config)
899
+
900
+ # Copy state dict
901
+ new_model.load_state_dict(old_model.state_dict(), strict=False)
902
+
903
+ return new_model
904
+
905
+ except Exception as e2:
906
+ print(f"Failed to convert from legacy format: {e2}")
907
+ raise e
908
+
909
+
910
+ # Register the new class
911
+ PicoDecoderForCausalLM.register_for_auto_class("AutoModelForCausalLM")
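
Because config.json maps AutoConfig and AutoModelForCausalLM to this pico_decoder.py module, a checkpoint directory from this repo can be loaded straight through the transformers auto classes, as the docstring above advertises. A minimal usage sketch; the checkpoint path, prompt, and generation settings below are illustrative and not part of the checkpoint itself:

from transformers import AutoModelForCausalLM, AutoTokenizer

# Any checkpoint directory from this repo works; step_10000 is used here as an example path.
checkpoint_dir = "pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000"

# trust_remote_code=True lets transformers import pico_decoder.py and resolve
# PicoDecoderHFConfig / PicoDecoderHF via the auto_map entries in config.json.
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
model = AutoModelForCausalLM.from_pretrained(checkpoint_dir, trust_remote_code=True)

inputs = tokenizer("Once upon a time", return_tensors="pt")
# use_cache=True exercises the per-layer tuple KV cache implemented above.
outputs = model.generate(inputs["input_ids"], max_new_tokens=20, use_cache=True)
print(tokenizer.decode(outputs[0]))
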
pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/special_tokens_map.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "eos_token": {
3
+ "content": "<|endoftext|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "pad_token": {
10
+ "content": "<|padding|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ }
16
+ }
pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
pico-decoder-tiny-dolma20M-v1/checkpoints/step_1000/tokenizer_config.json ADDED
@@ -0,0 +1,239 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": false,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "|||IP_ADDRESS|||",
8
+ "lstrip": false,
9
+ "normalized": true,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": false
13
+ },
14
+ "1": {
15
+ "content": "<|padding|>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "50254": {
23
+ "content": " ",
24
+ "lstrip": false,
25
+ "normalized": true,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": false
29
+ },
30
+ "50255": {
31
+ "content": " ",
32
+ "lstrip": false,
33
+ "normalized": true,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": false
37
+ },
38
+ "50256": {
39
+ "content": " ",
40
+ "lstrip": false,
41
+ "normalized": true,
42
+ "rstrip": false,
43
+ "single_word": false,
44
+ "special": false
45
+ },
46
+ "50257": {
47
+ "content": " ",
48
+ "lstrip": false,
49
+ "normalized": true,
50
+ "rstrip": false,
51
+ "single_word": false,
52
+ "special": false
53
+ },
54
+ "50258": {
55
+ "content": " ",
56
+ "lstrip": false,
57
+ "normalized": true,
58
+ "rstrip": false,
59
+ "single_word": false,
60
+ "special": false
61
+ },
62
+ "50259": {
63
+ "content": " ",
64
+ "lstrip": false,
65
+ "normalized": true,
66
+ "rstrip": false,
67
+ "single_word": false,
68
+ "special": false
69
+ },
70
+ "50260": {
71
+ "content": " ",
72
+ "lstrip": false,
73
+ "normalized": true,
74
+ "rstrip": false,
75
+ "single_word": false,
76
+ "special": false
77
+ },
78
+ "50261": {
79
+ "content": " ",
80
+ "lstrip": false,
81
+ "normalized": true,
82
+ "rstrip": false,
83
+ "single_word": false,
84
+ "special": false
85
+ },
86
+ "50262": {
87
+ "content": " ",
88
+ "lstrip": false,
89
+ "normalized": true,
90
+ "rstrip": false,
91
+ "single_word": false,
92
+ "special": false
93
+ },
94
+ "50263": {
95
+ "content": " ",
96
+ "lstrip": false,
97
+ "normalized": true,
98
+ "rstrip": false,
99
+ "single_word": false,
100
+ "special": false
101
+ },
102
+ "50264": {
103
+ "content": " ",
104
+ "lstrip": false,
105
+ "normalized": true,
106
+ "rstrip": false,
107
+ "single_word": false,
108
+ "special": false
109
+ },
110
+ "50265": {
111
+ "content": " ",
112
+ "lstrip": false,
113
+ "normalized": true,
114
+ "rstrip": false,
115
+ "single_word": false,
116
+ "special": false
117
+ },
118
+ "50266": {
119
+ "content": " ",
120
+ "lstrip": false,
121
+ "normalized": true,
122
+ "rstrip": false,
123
+ "single_word": false,
124
+ "special": false
125
+ },
126
+ "50267": {
127
+ "content": " ",
128
+ "lstrip": false,
129
+ "normalized": true,
130
+ "rstrip": false,
131
+ "single_word": false,
132
+ "special": false
133
+ },
134
+ "50268": {
135
+ "content": " ",
136
+ "lstrip": false,
137
+ "normalized": true,
138
+ "rstrip": false,
139
+ "single_word": false,
140
+ "special": false
141
+ },
142
+ "50269": {
143
+ "content": " ",
144
+ "lstrip": false,
145
+ "normalized": true,
146
+ "rstrip": false,
147
+ "single_word": false,
148
+ "special": false
149
+ },
150
+ "50270": {
151
+ "content": " ",
152
+ "lstrip": false,
153
+ "normalized": true,
154
+ "rstrip": false,
155
+ "single_word": false,
156
+ "special": false
157
+ },
158
+ "50271": {
159
+ "content": " ",
160
+ "lstrip": false,
161
+ "normalized": true,
162
+ "rstrip": false,
163
+ "single_word": false,
164
+ "special": false
165
+ },
166
+ "50272": {
167
+ "content": " ",
168
+ "lstrip": false,
169
+ "normalized": true,
170
+ "rstrip": false,
171
+ "single_word": false,
172
+ "special": false
173
+ },
174
+ "50273": {
175
+ "content": " ",
176
+ "lstrip": false,
177
+ "normalized": true,
178
+ "rstrip": false,
179
+ "single_word": false,
180
+ "special": false
181
+ },
182
+ "50274": {
183
+ "content": " ",
184
+ "lstrip": false,
185
+ "normalized": true,
186
+ "rstrip": false,
187
+ "single_word": false,
188
+ "special": false
189
+ },
190
+ "50275": {
191
+ "content": " ",
192
+ "lstrip": false,
193
+ "normalized": true,
194
+ "rstrip": false,
195
+ "single_word": false,
196
+ "special": false
197
+ },
198
+ "50276": {
199
+ "content": " ",
200
+ "lstrip": false,
201
+ "normalized": true,
202
+ "rstrip": false,
203
+ "single_word": false,
204
+ "special": false
205
+ },
206
+ "50277": {
207
+ "content": "|||EMAIL_ADDRESS|||",
208
+ "lstrip": false,
209
+ "normalized": true,
210
+ "rstrip": false,
211
+ "single_word": false,
212
+ "special": false
213
+ },
214
+ "50278": {
215
+ "content": "|||PHONE_NUMBER|||",
216
+ "lstrip": false,
217
+ "normalized": true,
218
+ "rstrip": false,
219
+ "single_word": false,
220
+ "special": false
221
+ },
222
+ "50279": {
223
+ "content": "<|endoftext|>",
224
+ "lstrip": false,
225
+ "normalized": false,
226
+ "rstrip": false,
227
+ "single_word": false,
228
+ "special": true
229
+ }
230
+ },
231
+ "bos_token": null,
232
+ "clean_up_tokenization_spaces": true,
233
+ "eos_token": "<|endoftext|>",
234
+ "extra_special_tokens": {},
235
+ "model_max_length": 1000000000000000019884624838656,
236
+ "pad_token": "<|padding|>",
237
+ "tokenizer_class": "GPTNeoXTokenizer",
238
+ "unk_token": null
239
+ }
pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/config.json ADDED
@@ -0,0 +1,22 @@
1
+ {
2
+ "activation_hidden_dim": 384,
3
+ "architectures": [
4
+ "PicoDecoderHF"
5
+ ],
6
+ "attention_n_heads": 12,
7
+ "attention_n_kv_heads": 4,
8
+ "auto_map": {
9
+ "AutoConfig": "pico_decoder.PicoDecoderHFConfig",
10
+ "AutoModelForCausalLM": "pico_decoder.PicoDecoderHF"
11
+ },
12
+ "batch_size": 1024,
13
+ "d_model": 96,
14
+ "max_seq_len": 2048,
15
+ "model_type": "pico_decoder",
16
+ "n_layers": 12,
17
+ "norm_eps": 1e-06,
18
+ "position_emb_theta": 10000.0,
19
+ "torch_dtype": "float32",
20
+ "transformers_version": "4.48.3",
21
+ "vocab_size": 50304
22
+ }
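
The hyperparameters above are enough to estimate the size of each checkpoint. A back-of-the-envelope sketch, assuming the layer shapes defined in pico_decoder.py (separate input and output embeddings, grouped-query attention, SwiGLU with hidden dim 384):

# Values copied from config.json above.
d_model, n_layers, vocab = 96, 12, 50304
n_heads, n_kv_heads, hidden = 12, 4, 384
head_dim = d_model // n_heads            # 8
kv_dim = n_kv_heads * head_dim           # 32

embed = vocab * d_model                  # embedding_proj
unembed = vocab * d_model                # de_embedding_proj (not tied)
attn = 2 * d_model * d_model + 2 * d_model * kv_dim   # q/o plus k/v projections
mlp = 3 * d_model * hidden               # w_0, w_1, w_2 of SwiGLU
norms = 2 * d_model                      # two RMSNorm weights per block
per_layer = attn + mlp + norms

total = embed + unembed + n_layers * per_layer + d_model  # + final RMSNorm
print(total)  # ~11.3M parameters, i.e. roughly 45 MB in float32,
              # consistent with the model.safetensors files in these checkpoints
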
pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/fabric_state/checkpoint.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f63cfb2d211ae93a79dd41503954f66c62499535e964b2024491642521ef8c55
3
+ size 135543171
pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/generation_config.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "transformers_version": "4.48.3",
3
+ "vocab_size": 50304
4
+ }
pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/learning_dynamics/train_activations.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:190349fc2f5403189b1fe8012235f07261aedad560cf90d70dc1268760ea9ea5
3
+ size 98331
pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/learning_dynamics/train_data/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e9565d0b7f27a7a6e121ff10214d8530cf1999016e90a567eac83d2a26bdeb3e
3
+ size 274672
pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/learning_dynamics/train_data/dataset_info.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "citation": "",
3
+ "description": "",
4
+ "features": {
5
+ "input_ids": {
6
+ "feature": {
7
+ "dtype": "int32",
8
+ "_type": "Value"
9
+ },
10
+ "_type": "Sequence"
11
+ },
12
+ "text": {
13
+ "dtype": "string",
14
+ "_type": "Value"
15
+ }
16
+ },
17
+ "homepage": "",
18
+ "license": ""
19
+ }
pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/learning_dynamics/train_data/state.json ADDED
@@ -0,0 +1,13 @@
1
+ {
2
+ "_data_files": [
3
+ {
4
+ "filename": "data-00000-of-00001.arrow"
5
+ }
6
+ ],
7
+ "_fingerprint": "80201725dca773a1",
8
+ "_format_columns": null,
9
+ "_format_kwargs": {},
10
+ "_format_type": null,
11
+ "_output_all_columns": false,
12
+ "_split": null
13
+ }
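
The train_data directory above (dataset_info.json, state.json, and a single Arrow shard) matches the on-disk layout produced by datasets.Dataset.save_to_disk, so the examples captured for the learning-dynamics analysis should be reloadable directly. A small sketch, assuming the checkpoint has been downloaded locally; the path is illustrative:

from datasets import load_from_disk

# Directory containing state.json and data-00000-of-00001.arrow (local path is illustrative).
train_data = load_from_disk(
    "pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/learning_dynamics/train_data"
)
print(train_data.features)   # input_ids: Sequence(int32), text: string -- as declared in dataset_info.json
print(len(train_data))
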
pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/learning_dynamics/train_gradients.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2e58cd384c1152b182faa36e56cee0ca5f28b1ddf786b8e1b68d90bba8539e9f
3
+ size 2371527
pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/learning_dynamics/train_weights.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e4de550e5abc314db55e3074d9218d203a86b523c616ed13157d48954a7fac76
3
+ size 2371443
pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d849318b797c81059d055fe1e1bf7dd20699b24fa9038c056724c49de447915a
3
+ size 45143592
pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/pico_decoder.py ADDED
@@ -0,0 +1,911 @@
1
+ """
2
+ Pico Decoder: A Lightweight Causal Transformer Language Model
3
+
4
+ Pico Decoder uses a simple LLAMA-style transformer architecture, written for clarity and educational purposes.
5
+
6
+ Everything is written with a modular design for easy modification and experimentation.
7
+
8
+ Key features:
9
+ - RMSNorm for layer normalization
10
+ - Rotary Positional Embeddings (RoPE)
11
+ - Multi-head attention with KV-cache support
12
+ - SwiGLU activation function
13
+ - Residual connections throughout
14
+
15
+ - KV-cache for faster autoregressive generation
16
+
17
+ References:
18
+ - RoPE: https://arxiv.org/abs/2104.09864
19
+ - SwiGLU: https://arxiv.org/abs/2002.05202
20
+ - LLAMA: https://arxiv.org/abs/2302.13971
21
+
22
+ Adapted from:
23
+ - OLMO: https://github.com/allenai/OLMo
24
+ - LLAMA: https://github.com/meta/llama
25
+ """
26
+
27
+ from dataclasses import asdict
28
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
29
+
30
+ import torch
31
+ import torch.nn as nn
32
+ import torch.nn.functional as F
33
+
34
+ # Handle PyTorch version compatibility for attention backend
35
+ try:
36
+ from torch.nn.attention import SDPBackend, sdpa_kernel
37
+
38
+ HAS_TORCH_ATTENTION = True
39
+ except ImportError:
40
+ # Fallback for older PyTorch versions
41
+ HAS_TORCH_ATTENTION = False
42
+ SDPBackend = None
43
+ sdpa_kernel = None
44
+
45
+ from transformers import GenerationMixin, PretrainedConfig, PreTrainedModel
46
+ from transformers.generation import GenerationConfig
47
+ from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
48
+
49
+ try:
50
+ if TYPE_CHECKING:
51
+ # We need to do this to avoid importing these when creating the HF-compatible models
52
+ from src.config import ModelConfig
53
+ except ImportError:
54
+ pass
55
+
56
+ ########################################################
57
+ #
58
+ # Layer Normalization
59
+ #
60
+ ########################################################
61
+
62
+
63
+ class RMSNorm(torch.nn.Module):
64
+ """Root Mean Square Layer Normalization.
65
+
66
+ A variant of Layer Normalization that uses RMS statistics instead of mean/variance,
67
+ resulting in improved stability and performance.
68
+
69
+ Args:
70
+ config (Union[ModelConfig, PicoDecoderHFConfig]): Configuration object containing normalization parameters
71
+ - config.norm_eps: Small constant for numerical stability
72
+ - config.d_model: Model dimension for the weight parameter
73
+
74
+ References:
75
+ https://arxiv.org/abs/1910.07467
76
+ """
77
+
78
+ def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
79
+ super().__init__()
80
+ self.eps = config.norm_eps
81
+ self.weight = nn.Parameter(torch.ones(config.d_model))
82
+
83
+ def _norm(self, x: torch.Tensor) -> torch.Tensor:
84
+ """
85
+ Normalizes the input tensor by its RMS value.
86
+ """
87
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
88
+
89
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
90
+ """
91
+ Applies RMS normalization to the input tensor and scales it by the weight parameter.
92
+ """
93
+ output = self._norm(x.float()).type_as(x)
94
+ return output * self.weight
95
+
96
+
97
+ ########################################################
98
+ #
99
+ # Positional Embedding
100
+ #
101
+ ########################################################
102
+
103
+
104
+ class RoPE(nn.Module):
105
+ """Rotary Positional Embeddings (RoPE).
106
+
107
+ Implements position-dependent rotation of keys and queries in attention mechanism,
108
+ allowing better modeling of relative positions in sequences. Uses complex number
109
+ operations for efficient rotation.
110
+
111
+ Args:
112
+ config (Union[ModelConfig, PicoDecoderHFConfig]): Model configuration containing:
113
+ - config.position_emb_theta: Base for frequency computation
114
+ - config.d_model: Model dimension
115
+ - config.attention_n_heads: Number of attention heads
116
+ - config.max_seq_len: Maximum sequence length
117
+
118
+ References:
119
+ https://arxiv.org/abs/2104.09864
120
+ """
121
+
122
+ _freqs_cis_tensor: torch.Tensor | None = None
123
+
124
+ def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
125
+ super().__init__()
126
+
127
+ self.theta = config.position_emb_theta
128
+ self.dim = config.d_model // config.attention_n_heads
129
+
130
+ max_seq_len = config.max_seq_len
131
+
132
+ # only gets set once, and then reused for all RoPE instances
133
+ if RoPE._freqs_cis_tensor is None:
134
+ RoPE._freqs_cis_tensor = self._setup_freqs_cis(
135
+ max_seq_len, self.theta, self.dim
136
+ )
137
+
138
+ # register _freqs_cis buffer
139
+ # can be easily recomputed so persistent=False
140
+ self.register_buffer("_freqs_cis", self._freqs_cis_tensor, persistent=False)
141
+
142
+ @classmethod
143
+ def _setup_freqs_cis(cls, seq_len: int, theta: float, dim: int) -> torch.Tensor:
144
+ """Setup Frequency Tensor for RoPE Embeddings
145
+
146
+ Initializes the complex frequency tensor that is used to compute the RoPE embeddings.
147
+
148
+ Note other implementations will use cos and sin directly, but using the complex
149
+ number representation is (probably) more efficient:
150
+
151
+ e^(theta * i * t) = cos(theta * t) + i * sin(theta * t) [Euler's formula]
152
+ """
153
+ _freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
154
+ positions = torch.arange(seq_len)
155
+ freqs = torch.outer(positions, _freqs)
156
+ return torch.polar(torch.ones_like(freqs), freqs) # complex64
157
+
158
+ def get_freqs_cis(
159
+ self, input_shape: torch.Size, start_pos: int, end_pos: int
160
+ ) -> torch.Tensor:
161
+ """Reshape Frequency Tensor for RoPE Embeddings
162
+
163
+ Makes the frequency tensor broadcastable with the input tensor.
164
+ """
165
+ _freqs_cis = self._freqs_cis[start_pos:end_pos]
166
+ ndim = len(input_shape)
167
+ assert 0 <= 1 < ndim
168
+ assert _freqs_cis.shape == (input_shape[1], input_shape[-1])
169
+
170
+ # TODO: Check whether this is correct (might be able to remove this)
171
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(input_shape)]
172
+ return _freqs_cis.view(*shape)
173
+
174
+ def forward(
175
+ self,
176
+ queries: torch.Tensor,
177
+ keys: torch.Tensor,
178
+ start_pos: int = 0,
179
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
180
+ """Apply RoPE Embeddings to Queries and Keys
181
+
182
+ Applies the rotary positional embeddings to the input tensors via complex number multiplication.
183
+
184
+ NOTE: The start_pos is used if we want to use the kv_cache in the attention mechanism.
185
+ """
186
+ queries_ = torch.view_as_complex(
187
+ queries.float().reshape(*queries.shape[:-1], -1, 2)
188
+ )
189
+ keys_ = torch.view_as_complex(keys.float().reshape(*keys.shape[:-1], -1, 2))
190
+
191
+ input_shape = (
192
+ queries_.shape
193
+ ) # same as keys: (batch_size, seq_len, n_heads, head_dim/2)
194
+ freqs_start_pos = start_pos
195
+ freqs_end_pos = freqs_start_pos + queries_.shape[1]
196
+
197
+ freqs_cis = self.get_freqs_cis(input_shape, freqs_start_pos, freqs_end_pos)
198
+
199
+ queries_rotated = torch.view_as_real(queries_ * freqs_cis).flatten(3)
200
+ keys_rotated = torch.view_as_real(keys_ * freqs_cis).flatten(3)
201
+ return queries_rotated.type_as(queries), keys_rotated.type_as(keys)
202
+
203
+
204
+ ########################################################
205
+ #
206
+ # Attention
207
+ #
208
+ ########################################################
209
+
210
+
211
+ class Attention(nn.Module):
212
+ """Multi-head Attention with Group Query Attention support.
213
+
214
+ Implements scaled dot-product attention and supports:
215
+ - Grouped Query Attention (GQA)
216
+ - Key-Value caching for efficient inference
217
+ - RoPE integration
218
+
219
+ Args:
220
+ config (Union[ModelConfig, PretrainedConfig]): Configuration containing:
221
+ - config.attention_n_heads: Number of attention heads
222
+ - config.attention_n_kv_heads: Number of key/value heads
223
+ - config.d_model: Model dimension
224
+ - config.batch_size: Maximum batch size
225
+ - config.max_seq_len: Maximum sequence length
226
+
227
+ Shape:
228
+ - Input: (batch_size, seq_len, d_model)
229
+ - Output: (batch_size, seq_len, d_model)
230
+ """
231
+
232
+ def __init__(
233
+ self,
234
+ config: Union["ModelConfig", "PicoDecoderHFConfig"],
235
+ ):
236
+ super().__init__()
237
+
238
+ self.n_heads = config.attention_n_heads
239
+ self.n_kv_heads = config.attention_n_kv_heads
240
+
241
+ self.batch_size = config.batch_size
242
+ self.max_seq_len = config.max_seq_len
243
+
244
+ d_model = config.d_model
245
+ self.head_dim = d_model // self.n_heads
246
+
247
+ self.n_rep = self.n_heads // self.n_kv_heads
248
+
249
+ self.q_proj = nn.Linear(d_model, self.n_heads * self.head_dim, bias=False)
250
+ self.k_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
251
+ self.v_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
252
+ self.o_proj = nn.Linear(self.n_heads * self.head_dim, d_model, bias=False)
253
+
254
+ self.rope = RoPE(config)
255
+
256
+ def forward(
257
+ self,
258
+ input: torch.Tensor,
259
+ mask: Optional[torch.Tensor] = None,
260
+ past_key_values: Optional[Tuple[torch.Tensor, ...]] = None,
261
+ use_cache: bool = False,
262
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
263
+ """Forward pass for the attention mechanism.
264
+
265
+ Computes queries, keys, and values for the attention mechanism. Applies rotary positional
266
+ embeddings to the queries and keys, and then computes attention scores and outputs.
267
+
268
+ For an introduction to the attention mechanism, see:
269
+ https://arxiv.org/abs/1706.03762
270
+
271
+ A few things to note:
272
+ - The past_key_values is used to implement the KV cache, which is used to speed up
273
+ generation by caching the KV pairs from previous forward passes. This is useful when doing
274
+ tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
275
+ modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
276
+ its own KV cache - this KV cache is implemented as a tuple.
277
+ """
278
+ bsz, seq_len, _ = input.shape
279
+ _queries, _keys, _values = (
280
+ self.q_proj(input),
281
+ self.k_proj(input),
282
+ self.v_proj(input),
283
+ )
284
+
285
+ # Reshaping for multi-head attention
286
+ queries = _queries.view(bsz, seq_len, self.n_heads, self.head_dim)
287
+ keys = _keys.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
288
+ values = _values.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
289
+
290
+ # The start position is used to apply the RoPE embeddings to only the new tokens
291
+ # when using the kv_cache in the attention mechanism.
292
+ # We want to start from the last position in the cache.
293
+ start_pos = 0
294
+ if past_key_values is not None and past_key_values[0] is not None:
295
+ start_pos = past_key_values[0].shape[1]
296
+
297
+ # apply rotary positional embeddings
298
+ queries, keys = self.rope(queries, keys, start_pos)
299
+
300
+ if (
301
+ past_key_values is not None
302
+ and past_key_values[0] is not None
303
+ and past_key_values[1] is not None
304
+ ):
305
+ keys = torch.cat([past_key_values[0], keys], dim=1)
306
+ values = torch.cat([past_key_values[1], values], dim=1)
307
+
308
+ if use_cache:
309
+ cached_keys = keys
310
+ cached_values = values
311
+ else:
312
+ cached_keys = None
313
+ cached_values = None
314
+
315
+ queries = queries.transpose(1, 2)
316
+ keys = keys.transpose(1, 2)
317
+ values = values.transpose(1, 2)
318
+
319
+ apply_gqa = self.n_rep > 1
320
+ if apply_gqa and queries.device.type == "mps":
321
+ # NOTE: MPS does not support GQA in the SDPA kernel, but we can repeat the keys and values
322
+ # outside of the kernel to get the same effect.
323
+ # See: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
324
+ keys = keys.repeat_interleave(self.n_rep, dim=-3)
325
+ values = values.repeat_interleave(self.n_rep, dim=-3)
326
+ apply_gqa = False
327
+
328
+ if HAS_TORCH_ATTENTION:
329
+ backends = [SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]
330
+ with sdpa_kernel(backends=backends):
331
+ attn_output = F.scaled_dot_product_attention(
332
+ queries.contiguous(),
333
+ keys.contiguous(),
334
+ values.contiguous(),
335
+ attn_mask=mask.to(queries.dtype) if mask is not None else None,
336
+ enable_gqa=apply_gqa,
337
+ )
338
+ else:
339
+ # Fallback for older PyTorch versions - use default backend
340
+ attn_output = F.scaled_dot_product_attention(
341
+ queries.contiguous(),
342
+ keys.contiguous(),
343
+ values.contiguous(),
344
+ attn_mask=mask.to(queries.dtype) if mask is not None else None,
345
+ enable_gqa=apply_gqa,
346
+ )
347
+
348
+ attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
349
+ output = self.o_proj(attn_output)
350
+
351
+ return output, (cached_keys, cached_values)
352
+
353
+
354
+ ########################################################
355
+ #
356
+ # SwiGLU (Combines MLP and Activation)
357
+ #
358
+ ########################################################
359
+
360
+
361
+ class SwiGLU(nn.Module):
362
+ """SwiGLU Activation Function with Linear Projections.
363
+
364
+ Implements the SwiGLU activation function combined with linear transformations,
365
+ serving as the feed-forward network in transformer blocks.
366
+
367
+ Args:
368
+ config (Union[ModelConfig, PicoDecoderHFConfig]): Configuration containing:
369
+ - config.d_model: Model dimension
370
+ - config.activation_hidden_dim: Hidden dimension (typically 4 * d_model)
371
+
372
+ References:
373
+ https://arxiv.org/abs/2002.05202
374
+ """
375
+
376
+ def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
377
+ super().__init__()
378
+
379
+ model_dim = config.d_model
380
+ act_hidden_dim = config.activation_hidden_dim # usually 4 * d_model
381
+
382
+ self.w_0 = nn.Linear(model_dim, act_hidden_dim, bias=False)
383
+ self.w_1 = nn.Linear(model_dim, act_hidden_dim, bias=False)
384
+ self.w_2 = nn.Linear(act_hidden_dim, model_dim, bias=False)
385
+
386
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
387
+ return self.w_2(F.silu(self.w_0(x)) * self.w_1(x))
388
+
389
+
390
+ ########################################################
391
+ #
392
+ # PicoDecoderBlock
393
+ #
394
+ ########################################################
395
+
396
+
397
+ class PicoDecoderBlock(nn.Module):
398
+ """Single Transformer Block with Attention and Feed-forward layers.
399
+
400
+ Implements a standard transformer block with:
401
+ - Multi-head attention with normalization and residual connection
402
+ - SwiGLU feed-forward network with normalization and residual connection
403
+
404
+ Args:
405
+ config (Union[ModelConfig, PicoDecoderHFConfig]): Model configuration; either a dataclass or
406
+ a HuggingFace PicoDecoderHFConfig
407
+ """
408
+
409
+ def __init__(
410
+ self,
411
+ config: Union["ModelConfig", "PicoDecoderHFConfig"],
412
+ ):
413
+ super().__init__()
414
+
415
+ self.attention = Attention(config)
416
+ self.swiglu = SwiGLU(config)
417
+ self.attention_norm = RMSNorm(config)
418
+ self.swiglu_norm = RMSNorm(config)
419
+
420
+ def forward(
421
+ self,
422
+ input: torch.Tensor,
423
+ mask: Optional[torch.Tensor] = None,
424
+ past_key_values: Optional[Tuple[torch.Tensor]] = None,
425
+ use_cache: bool = False,
426
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
427
+ attention_output, cached_key_values = self.attention(
428
+ self.attention_norm(input),
429
+ mask=mask,
430
+ past_key_values=past_key_values,
431
+ use_cache=use_cache,
432
+ )
433
+ # NOTE: cached_key_values is None if use_cache is False
434
+
435
+ h = input + attention_output
436
+ out = h + self.swiglu(self.swiglu_norm(h))
437
+ return out, cached_key_values
438
+
439
+
440
+ ########################################################
441
+ #
442
+ # Pico Decoder (Causal Transformer Model)
443
+ #
444
+ ########################################################
445
+
446
+
447
+ class PicoDecoder(nn.Module):
448
+ """
449
+ Pico Decoder: combines the embedding, causal decoder blocks, and output projection into a
450
+ single autoregressive model.
451
+
452
+ For more information on the model, see the classes for the modules that make up the model.
453
+ """
454
+
455
+ def __init__(
456
+ self,
457
+ model_config: Union["ModelConfig", "PicoDecoderHFConfig"],
458
+ ):
459
+ super().__init__()
460
+ self.config = model_config
461
+
462
+ self.embedding_proj = nn.Embedding(self.config.vocab_size, self.config.d_model)
463
+ self.layers = nn.ModuleList(
464
+ [PicoDecoderBlock(self.config) for _ in range(self.config.n_layers)]
465
+ )
466
+ self.output_norm = RMSNorm(self.config)
467
+ self.de_embedding_proj = nn.Linear(
468
+ self.config.d_model, self.config.vocab_size, bias=False
469
+ )
470
+
471
+ def convert_to_hf_model(self) -> "PicoDecoderHF":
472
+ """Convert the Lightning model to a HuggingFace model."""
473
+ # Create HF config without fabric-specific settings
474
+ hf_config = PicoDecoderHFConfig.from_dataclass(self.config)
475
+
476
+ # Create new HF model
477
+ hf_model = PicoDecoderHF(hf_config)
478
+
479
+ # Copy the state dict; the "pico_decoder." prefix matches the HF wrapper's submodule layout
480
+ hf_model.load_state_dict(self.state_dict(prefix="pico_decoder."))
481
+
482
+ return hf_model
483
+
484
+ def forward(
485
+ self,
486
+ input_ids: torch.Tensor,
487
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
488
+ use_cache: bool = False,
489
+ ) -> Tuple[torch.Tensor, Optional[Tuple[Tuple[torch.Tensor, torch.Tensor]]]]:
490
+ """
491
+ This is the forward pass for the entire Pico model. It boils down to:
492
+ - Embedding the input ids
493
+ - Creating a causal mask
494
+ - Processing through the pico layers
495
+ - Projecting the output to logits
496
+
497
+ NOTE: One feature that might be confusing is the KV cache. The KV cache is used to speed up
498
+ generation by caching the KV pairs from previous forward passes. This is useful when doing
499
+ tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
500
+ modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
501
+ its own KV cache which is stored as a tuple. The whole model then stores a tuple of these
502
+ KV caches (so a tuple of tuples).
503
+ """
504
+
505
+ seq_len = input_ids.shape[-1]
506
+ h = self.embedding_proj(input_ids)
507
+
508
+ # Calculate start position from past cached KV pairs. Remember that each layer has its
509
+ # own KV Cache. So when we index past_key_values, we need to index into the KV pairs for the
510
+ # correct layer and then for either the keys or values.
511
+ start_pos = 0
512
+ if (
513
+ past_key_values is not None
514
+ and past_key_values[0] is not None
515
+ and past_key_values[0][0] is not None
516
+ ):
517
+ start_pos = past_key_values[0][0].shape[1]
518
+
519
+ # Create causal mask for current sequence
520
+ mask = None
521
+ if seq_len > 1:
522
+ mask = torch.full((seq_len, seq_len), float("-inf"))
523
+ mask = torch.triu(mask, diagonal=1)
524
+
525
+ # If using KV cache, extend mask to cover cached sequence length
526
+ if past_key_values is not None:
527
+ # Add zeros for cached tokens (we can attend to all of them)
528
+ mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
529
+
530
+ mask = mask.to(h.device)
531
+
532
+ # NOTE: If we are using the cache, we need to store the cached KV pairs for each layer
533
+ # in a tuple. Each layer will have its own cached KV pair which we aggregate in a tuple.
534
+ cached_key_values = () if use_cache else None
535
+
536
+ # Process through transformer blocks
537
+ for idx, layer in enumerate(self.layers):
538
+ layer_past_key_values = None
539
+ if past_key_values is not None:
540
+ try:
541
+ # Handle both tuple-based cache and HuggingFace cache objects
542
+ if hasattr(past_key_values, "__getitem__") and idx < len(
543
+ past_key_values
544
+ ):
545
+ layer_past_key_values = past_key_values[idx]
546
+ except (KeyError, IndexError, TypeError):
547
+ # If we can't access the cache properly, just skip it
548
+ layer_past_key_values = None
549
+
550
+ h, layer_cached_key_values = layer(
551
+ h, mask=mask, past_key_values=layer_past_key_values, use_cache=use_cache
552
+ )
553
+
554
+ if use_cache:
555
+ cached_key_values += (layer_cached_key_values,)
556
+
557
+ # Final norm and projection
558
+ h = self.output_norm(h)
559
+ logits = self.de_embedding_proj(h).float()
560
+
561
+ return logits, cached_key_values
562
+
563
+
564
+ ########################################################
565
+ #
566
+ # HuggingFace Wrapper for the Pico Decoder model.
567
+ #
568
+ ########################################################
569
+
570
+
571
+ class PicoDecoderHFConfig(PretrainedConfig):
572
+ """Config class for the Pico Decoder HuggingFace wrapper."""
573
+
574
+ model_type = "pico_decoder"
575
+
576
+ @classmethod
577
+ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PicoDecoderHFConfig":
578
+ """
579
+ Initialize config from a dictionary. Note that no kwargs are passed to the constructor --
580
+ this is because some kwargs require special handling, which can make this class
581
+ brittle.
582
+ """
583
+ pico_config = cls(**config_dict)
584
+
585
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
586
+ unused_kwargs = {
587
+ key: value for key, value in kwargs.items() if not hasattr(pico_config, key)
588
+ }
589
+
590
+ if return_unused_kwargs:
591
+ return pico_config, unused_kwargs
592
+ return pico_config
593
+
594
+ @classmethod
595
+ def from_dataclass(cls, model_config: "ModelConfig"):
596
+ """Initialise from our custom config dataclass."""
597
+ return cls.from_dict(asdict(model_config))
598
+
599
+
600
+ class PicoDecoderHF(PreTrainedModel, GenerationMixin):
601
+ """
602
+ HuggingFace wrapper for the Pico model with generation support.
603
+
604
+ Many evaluation frameworks require a model to be set up as a HuggingFace model, so we provide a simple
605
+ wrapper that does just that. When we save checkpoints of the Pico model, we save both the normal
606
+ Pico model as well as the model wrapped in this HuggingFace class.
607
+
608
+ This also lets you do cool things like:
609
+
610
+ `model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint")`
611
+ """
612
+
613
+ config_class = PicoDecoderHFConfig
614
+ _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
615
+ main_input_name = "input_ids"
616
+
617
+ def __init__(self, config: PicoDecoderHFConfig):
618
+ super().__init__(config)
619
+ self.pico_decoder = PicoDecoder(config)
620
+ # Initialize generation config with defaults
621
+ self.generation_config = GenerationConfig()
622
+ # Set some reasonable defaults for the model
623
+ if hasattr(config, "max_position_embeddings"):
624
+ self.generation_config.max_length = config.max_position_embeddings
625
+ if hasattr(config, "vocab_size"):
626
+ self.generation_config.vocab_size = config.vocab_size
627
+
628
+ def forward(
629
+ self,
630
+ input_ids: torch.Tensor,
631
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
632
+ use_cache: bool = False,
633
+ **kwargs,
634
+ ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
635
+ """HuggingFace forward pass wrapper.
636
+
637
+ Forward pass for the HuggingFace version of the Pico model: a thin wrapper around the
638
+ Pico model's forward pass that returns the output as a HuggingFace CausalLMOutput.
639
+ """
640
+ logits, past_key_values = self.pico_decoder(
641
+ input_ids, past_key_values, use_cache
642
+ )
643
+ if use_cache:
644
+ return CausalLMOutputWithPast(
645
+ logits=logits,
646
+ past_key_values=past_key_values,
647
+ )
648
+ else:
649
+ return CausalLMOutput(
650
+ logits=logits,
651
+ )
652
+
653
+ def prepare_inputs_for_generation(
654
+ self,
655
+ input_ids: torch.LongTensor,
656
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
657
+ attention_mask: Optional[torch.LongTensor] = None,
658
+ **kwargs,
659
+ ) -> Dict[str, Any]:
660
+ """
661
+ Prepare inputs for generation.
662
+
663
+ Args:
664
+ input_ids: Input token IDs
665
+ past_key_values: Cached key-value pairs from previous forward passes
666
+ attention_mask: Attention mask for the input
667
+ **kwargs: Additional arguments
668
+
669
+ Returns:
670
+ Dictionary containing prepared inputs
671
+ """
672
+ # If we have past_key_values, we only need the last token
673
+ if past_key_values is not None:
674
+ input_ids = input_ids[:, -1:]
675
+
676
+ return {
677
+ "input_ids": input_ids,
678
+ "past_key_values": past_key_values,
679
+ "use_cache": True,
680
+ }
681
+
682
+ def get_input_embeddings(self):
683
+ """Get the input embeddings layer."""
684
+ return self.pico_decoder.embedding_proj
685
+
686
+ def set_input_embeddings(self, value):
687
+ """Set the input embeddings layer."""
688
+ self.pico_decoder.embedding_proj = value
689
+
690
+ def get_output_embeddings(self):
691
+ """Get the output embeddings layer."""
692
+ return self.pico_decoder.de_embedding_proj
693
+
694
+ def set_output_embeddings(self, value):
695
+ """Set the output embeddings layer."""
696
+ self.pico_decoder.de_embedding_proj = value
697
+
698
+ def get_lm_head(self):
699
+ """Get the language model head."""
700
+ return self.pico_decoder.de_embedding_proj
701
+
702
+ def can_generate(self) -> bool:
703
+ """Check if the model can generate text."""
704
+ return True
705
+
706
+ @property
707
+ def is_encoder_decoder(self) -> bool:
708
+ """Check if the model is an encoder-decoder model."""
709
+ return False
710
+
711
+ @property
712
+ def can_use_cache(self) -> bool:
713
+ """Check if the model can use KV cache."""
714
+ return True
715
+
716
+ def resize_token_embeddings(
717
+ self, new_num_tokens: Optional[int] = None
718
+ ) -> torch.nn.Embedding:
719
+ """Resize token embeddings."""
720
+ old_embeddings = self.get_input_embeddings()
721
+ if new_num_tokens is None:
722
+ new_num_tokens = old_embeddings.num_embeddings
723
+
724
+ new_embeddings = torch.nn.Embedding(
725
+ new_num_tokens, old_embeddings.embedding_dim
726
+ )
727
+ new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
728
+ old_embeddings.weight.data
729
+ )
730
+
731
+ self.pico_decoder.embedding_proj = new_embeddings
732
+ self.pico_decoder.de_embedding_proj = torch.nn.Linear(
733
+ old_embeddings.embedding_dim, new_num_tokens, bias=False
734
+ )
735
+
736
+ return new_embeddings
737
+
738
+
739
+ # Register for auto classes
740
+ PicoDecoderHFConfig.register_for_auto_class()
741
+ PicoDecoderHF.register_for_auto_class("AutoModel")
742
+ PicoDecoderHF.register_for_auto_class("AutoModelForCausalLM")
743
+
744
+
745
+ ########################################################
746
+ #
747
+ # New PicoDecoderForCausalLM class for generation support
748
+ #
749
+ ########################################################
750
+
751
+
752
+ class PicoDecoderForCausalLM(PreTrainedModel, GenerationMixin):
753
+ """
754
+ PicoDecoderForCausalLM: A HuggingFace-compatible model that properly supports generation.
755
+
756
+ This class is designed to work with existing checkpoints and provides full generation support.
757
+ It inherits from the right base classes that HuggingFace expects for text generation.
758
+ """
759
+
760
+ config_class = PicoDecoderHFConfig
761
+ _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
762
+ main_input_name = "input_ids"
763
+
764
+ def __init__(self, config: PicoDecoderHFConfig):
765
+ super().__init__(config)
766
+ self.pico_decoder = PicoDecoder(config)
767
+ # Initialize generation config with defaults
768
+ self.generation_config = GenerationConfig()
769
+ # Set some reasonable defaults for the model
770
+ if hasattr(config, "max_position_embeddings"):
771
+ self.generation_config.max_length = config.max_position_embeddings
772
+ if hasattr(config, "vocab_size"):
773
+ self.generation_config.vocab_size = config.vocab_size
774
+
775
+ def forward(
776
+ self,
777
+ input_ids: torch.Tensor,
778
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
779
+ use_cache: bool = False,
780
+ **kwargs,
781
+ ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
782
+ """Forward pass for text generation."""
783
+ logits, past_key_values = self.pico_decoder(
784
+ input_ids, past_key_values, use_cache
785
+ )
786
+ if use_cache:
787
+ return CausalLMOutputWithPast(
788
+ logits=logits,
789
+ past_key_values=past_key_values,
790
+ )
791
+ else:
792
+ return CausalLMOutput(
793
+ logits=logits,
794
+ )
795
+
796
+ def prepare_inputs_for_generation(
797
+ self,
798
+ input_ids: torch.LongTensor,
799
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
800
+ attention_mask: Optional[torch.LongTensor] = None,
801
+ **kwargs,
802
+ ) -> Dict[str, Any]:
803
+ """Prepare inputs for generation."""
804
+ # If we have past_key_values, we only need the last token
805
+ if past_key_values is not None:
806
+ input_ids = input_ids[:, -1:]
807
+
808
+ return {
809
+ "input_ids": input_ids,
810
+ "past_key_values": past_key_values,
811
+ "use_cache": True,
812
+ }
813
+
814
+ def get_input_embeddings(self):
815
+ """Get the input embeddings layer."""
816
+ return self.pico_decoder.embedding_proj
817
+
818
+ def set_input_embeddings(self, value):
819
+ """Set the input embeddings layer."""
820
+ self.pico_decoder.embedding_proj = value
821
+
822
+ def get_output_embeddings(self):
823
+ """Get the output embeddings layer."""
824
+ return self.pico_decoder.de_embedding_proj
825
+
826
+ def set_output_embeddings(self, value):
827
+ """Set the output embeddings layer."""
828
+ self.pico_decoder.de_embedding_proj = value
829
+
830
+ def get_lm_head(self):
831
+ """Get the language model head."""
832
+ return self.pico_decoder.de_embedding_proj
833
+
834
+ def can_generate(self) -> bool:
835
+ """Check if the model can generate text."""
836
+ return True
837
+
838
+ @property
839
+ def is_encoder_decoder(self) -> bool:
840
+ """Check if the model is an encoder-decoder model."""
841
+ return False
842
+
843
+ @property
844
+ def can_use_cache(self) -> bool:
845
+ """Check if the model can use KV cache."""
846
+ return True
847
+
848
+ def resize_token_embeddings(
849
+ self, new_num_tokens: Optional[int] = None
850
+ ) -> torch.nn.Embedding:
851
+ """Resize token embeddings."""
852
+ old_embeddings = self.get_input_embeddings()
853
+ if new_num_tokens is None:
854
+ new_num_tokens = old_embeddings.num_embeddings
855
+
856
+ new_embeddings = torch.nn.Embedding(
857
+ new_num_tokens, old_embeddings.embedding_dim
858
+ )
859
+ new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
860
+ old_embeddings.weight.data
861
+ )
862
+
863
+ self.pico_decoder.embedding_proj = new_embeddings
864
+ self.pico_decoder.de_embedding_proj = torch.nn.Linear(
865
+ old_embeddings.embedding_dim, new_num_tokens, bias=False
866
+ )
867
+
868
+ return new_embeddings
869
+
870
+ @classmethod
871
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
872
+ """
873
+ Load a pretrained model from a checkpoint.
874
+
875
+ This method handles loading from both the old PicoDecoderHF format and the new format.
876
+ """
877
+ # First try to load with the new class
878
+ try:
879
+ return super().from_pretrained(
880
+ pretrained_model_name_or_path, *model_args, **kwargs
881
+ )
882
+ except Exception as e:
883
+ print(f"Failed to load with new class: {e}")
884
+ print("Attempting to load with legacy class and convert...")
885
+
886
+ # Try to load with the old class and convert
887
+ try:
888
+ from transformers import AutoModel
889
+
890
+ old_model = AutoModel.from_pretrained(
891
+ pretrained_model_name_or_path,
892
+ trust_remote_code=True,
893
+ *model_args,
894
+ **kwargs,
895
+ )
896
+
897
+ # Create new model instance
898
+ new_model = cls(old_model.config)
899
+
900
+ # Copy state dict
901
+ new_model.load_state_dict(old_model.state_dict(), strict=False)
902
+
903
+ return new_model
904
+
905
+ except Exception as e2:
906
+ print(f"Failed to convert from legacy format: {e2}")
907
+ raise e
908
+
909
+
910
+ # Register the new class
911
+ PicoDecoderForCausalLM.register_for_auto_class("AutoModelForCausalLM")
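For reference, a minimal usage sketch for loading one of these checkpoints through the auto classes registered above (the checkpoint path, prompt, and generation settings are illustrative; this assumes `transformers` is installed and the directory is laid out as in this upload):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Any checkpoint directory from this upload should work, e.g. step_10000.
checkpoint_dir = "pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000"

# trust_remote_code=True lets transformers import the pico_decoder.py module shipped
# with the checkpoint (see the auto_map entries in config.json).
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
model = AutoModelForCausalLM.from_pretrained(checkpoint_dir, trust_remote_code=True)

inputs = tokenizer("Once upon a time", return_tensors="pt")
# generate() goes through prepare_inputs_for_generation and the KV cache defined above.
output_ids = model.generate(inputs["input_ids"], max_new_tokens=20)
print(tokenizer.decode(output_ids[0]))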
pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/special_tokens_map.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "eos_token": {
3
+ "content": "<|endoftext|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "pad_token": {
10
+ "content": "<|padding|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ }
16
+ }
pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000/tokenizer_config.json ADDED
@@ -0,0 +1,239 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": false,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "|||IP_ADDRESS|||",
8
+ "lstrip": false,
9
+ "normalized": true,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": false
13
+ },
14
+ "1": {
15
+ "content": "<|padding|>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "50254": {
23
+ "content": " ",
24
+ "lstrip": false,
25
+ "normalized": true,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": false
29
+ },
30
+ "50255": {
31
+ "content": " ",
32
+ "lstrip": false,
33
+ "normalized": true,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": false
37
+ },
38
+ "50256": {
39
+ "content": " ",
40
+ "lstrip": false,
41
+ "normalized": true,
42
+ "rstrip": false,
43
+ "single_word": false,
44
+ "special": false
45
+ },
46
+ "50257": {
47
+ "content": " ",
48
+ "lstrip": false,
49
+ "normalized": true,
50
+ "rstrip": false,
51
+ "single_word": false,
52
+ "special": false
53
+ },
54
+ "50258": {
55
+ "content": " ",
56
+ "lstrip": false,
57
+ "normalized": true,
58
+ "rstrip": false,
59
+ "single_word": false,
60
+ "special": false
61
+ },
62
+ "50259": {
63
+ "content": " ",
64
+ "lstrip": false,
65
+ "normalized": true,
66
+ "rstrip": false,
67
+ "single_word": false,
68
+ "special": false
69
+ },
70
+ "50260": {
71
+ "content": " ",
72
+ "lstrip": false,
73
+ "normalized": true,
74
+ "rstrip": false,
75
+ "single_word": false,
76
+ "special": false
77
+ },
78
+ "50261": {
79
+ "content": " ",
80
+ "lstrip": false,
81
+ "normalized": true,
82
+ "rstrip": false,
83
+ "single_word": false,
84
+ "special": false
85
+ },
86
+ "50262": {
87
+ "content": " ",
88
+ "lstrip": false,
89
+ "normalized": true,
90
+ "rstrip": false,
91
+ "single_word": false,
92
+ "special": false
93
+ },
94
+ "50263": {
95
+ "content": " ",
96
+ "lstrip": false,
97
+ "normalized": true,
98
+ "rstrip": false,
99
+ "single_word": false,
100
+ "special": false
101
+ },
102
+ "50264": {
103
+ "content": " ",
104
+ "lstrip": false,
105
+ "normalized": true,
106
+ "rstrip": false,
107
+ "single_word": false,
108
+ "special": false
109
+ },
110
+ "50265": {
111
+ "content": " ",
112
+ "lstrip": false,
113
+ "normalized": true,
114
+ "rstrip": false,
115
+ "single_word": false,
116
+ "special": false
117
+ },
118
+ "50266": {
119
+ "content": " ",
120
+ "lstrip": false,
121
+ "normalized": true,
122
+ "rstrip": false,
123
+ "single_word": false,
124
+ "special": false
125
+ },
126
+ "50267": {
127
+ "content": " ",
128
+ "lstrip": false,
129
+ "normalized": true,
130
+ "rstrip": false,
131
+ "single_word": false,
132
+ "special": false
133
+ },
134
+ "50268": {
135
+ "content": " ",
136
+ "lstrip": false,
137
+ "normalized": true,
138
+ "rstrip": false,
139
+ "single_word": false,
140
+ "special": false
141
+ },
142
+ "50269": {
143
+ "content": " ",
144
+ "lstrip": false,
145
+ "normalized": true,
146
+ "rstrip": false,
147
+ "single_word": false,
148
+ "special": false
149
+ },
150
+ "50270": {
151
+ "content": " ",
152
+ "lstrip": false,
153
+ "normalized": true,
154
+ "rstrip": false,
155
+ "single_word": false,
156
+ "special": false
157
+ },
158
+ "50271": {
159
+ "content": " ",
160
+ "lstrip": false,
161
+ "normalized": true,
162
+ "rstrip": false,
163
+ "single_word": false,
164
+ "special": false
165
+ },
166
+ "50272": {
167
+ "content": " ",
168
+ "lstrip": false,
169
+ "normalized": true,
170
+ "rstrip": false,
171
+ "single_word": false,
172
+ "special": false
173
+ },
174
+ "50273": {
175
+ "content": " ",
176
+ "lstrip": false,
177
+ "normalized": true,
178
+ "rstrip": false,
179
+ "single_word": false,
180
+ "special": false
181
+ },
182
+ "50274": {
183
+ "content": " ",
184
+ "lstrip": false,
185
+ "normalized": true,
186
+ "rstrip": false,
187
+ "single_word": false,
188
+ "special": false
189
+ },
190
+ "50275": {
191
+ "content": " ",
192
+ "lstrip": false,
193
+ "normalized": true,
194
+ "rstrip": false,
195
+ "single_word": false,
196
+ "special": false
197
+ },
198
+ "50276": {
199
+ "content": " ",
200
+ "lstrip": false,
201
+ "normalized": true,
202
+ "rstrip": false,
203
+ "single_word": false,
204
+ "special": false
205
+ },
206
+ "50277": {
207
+ "content": "|||EMAIL_ADDRESS|||",
208
+ "lstrip": false,
209
+ "normalized": true,
210
+ "rstrip": false,
211
+ "single_word": false,
212
+ "special": false
213
+ },
214
+ "50278": {
215
+ "content": "|||PHONE_NUMBER|||",
216
+ "lstrip": false,
217
+ "normalized": true,
218
+ "rstrip": false,
219
+ "single_word": false,
220
+ "special": false
221
+ },
222
+ "50279": {
223
+ "content": "<|endoftext|>",
224
+ "lstrip": false,
225
+ "normalized": false,
226
+ "rstrip": false,
227
+ "single_word": false,
228
+ "special": true
229
+ }
230
+ },
231
+ "bos_token": null,
232
+ "clean_up_tokenization_spaces": true,
233
+ "eos_token": "<|endoftext|>",
234
+ "extra_special_tokens": {},
235
+ "model_max_length": 1000000000000000019884624838656,
236
+ "pad_token": "<|padding|>",
237
+ "tokenizer_class": "GPTNeoXTokenizer",
238
+ "unk_token": null
239
+ }
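The tokenizer configuration above is GPT-NeoX style: no BOS token is added, "<|endoftext|>" (id 50279 in added_tokens_decoder) is the EOS token, and "<|padding|>" (id 1) is the padding token. A small sketch of how this behaves once loaded (checkpoint path and sample text are illustrative):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(
    "pico-decoder-tiny-dolma20M-v1/checkpoints/step_10000"
)

ids = tok("Hello world", return_tensors="pt")["input_ids"]
print(ids.shape)                        # (1, n_tokens); add_bos_token is false, so no BOS is prepended
print(tok.eos_token, tok.eos_token_id)  # "<|endoftext|>", 50279 per the config above
print(tok.pad_token, tok.pad_token_id)  # "<|padding|>", 1 per the config above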
pico-decoder-tiny-dolma20M-v1/checkpoints/step_100000/config.json ADDED
@@ -0,0 +1,22 @@
1
+ {
2
+ "activation_hidden_dim": 384,
3
+ "architectures": [
4
+ "PicoDecoderHF"
5
+ ],
6
+ "attention_n_heads": 12,
7
+ "attention_n_kv_heads": 4,
8
+ "auto_map": {
9
+ "AutoConfig": "pico_decoder.PicoDecoderHFConfig",
10
+ "AutoModelForCausalLM": "pico_decoder.PicoDecoderHF"
11
+ },
12
+ "batch_size": 1024,
13
+ "d_model": 96,
14
+ "max_seq_len": 2048,
15
+ "model_type": "pico_decoder",
16
+ "n_layers": 12,
17
+ "norm_eps": 1e-06,
18
+ "position_emb_theta": 10000.0,
19
+ "torch_dtype": "float32",
20
+ "transformers_version": "4.48.3",
21
+ "vocab_size": 50304
22
+ }
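These hyperparameters determine every tensor shape in pico_decoder.py: with d_model=96 and 12 attention heads the per-head dimension is 8, the 4 KV heads give a GQA repetition factor (n_rep) of 3, and activation_hidden_dim=384 is 4 * d_model. A quick sanity check against the config.json shown above (the path is illustrative):

import json

with open("pico-decoder-tiny-dolma20M-v1/checkpoints/step_100000/config.json") as f:
    cfg = json.load(f)

head_dim = cfg["d_model"] // cfg["attention_n_heads"]            # 96 // 12 = 8
n_rep = cfg["attention_n_heads"] // cfg["attention_n_kv_heads"]  # 12 // 4 = 3
assert cfg["activation_hidden_dim"] == 4 * cfg["d_model"]        # 384 == 4 * 96
print(head_dim, n_rep)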
pico-decoder-tiny-dolma20M-v1/checkpoints/step_100000/fabric_state/checkpoint.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9f1a4640d6d2b47183e24f7da485cdb5c819d35a03f56f6d2a14a0c8953a324f
3
+ size 135543171
pico-decoder-tiny-dolma20M-v1/checkpoints/step_100000/generation_config.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "transformers_version": "4.48.3",
3
+ "vocab_size": 50304
4
+ }
pico-decoder-tiny-dolma20M-v1/checkpoints/step_100000/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ddba0ef3f06d7dfc662d560efc8fddec48da59a6bfb0f826238c9050429eb1e
3
+ size 45143592
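The model weights are stored as a safetensors file; the three lines above are the git-lfs pointer (oid and size), not the weights themselves. A sketch for inspecting the raw tensors directly, assuming the `safetensors` package is installed and the file has been pulled from LFS:

from safetensors.torch import load_file

state_dict = load_file(
    "pico-decoder-tiny-dolma20M-v1/checkpoints/step_100000/model.safetensors"
)
# Keys are expected to follow the PicoDecoderHF module tree (e.g. pico_decoder.layers.0...).
for name, tensor in list(state_dict.items())[:5]:
    print(name, tuple(tensor.shape), tensor.dtype)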
pico-decoder-tiny-dolma20M-v1/checkpoints/step_100000/pico_decoder.py ADDED
@@ -0,0 +1,911 @@
1
+ """
2
+ Pico Decoder: A Lightweight Causal Transformer Language Model
3
+
4
+ Pico Decoder uses a simple LLAMA-style transformer architecture, written for clarity and educational purposes.
5
+
6
+ Everything is written with a modular design for easy modification and experimentation.
7
+
8
+ Key features:
9
+ - RMSNorm for layer normalization
10
+ - Rotary Positional Embeddings (RoPE)
11
+ - Multi-head attention with KV-cache support
12
+ - SwiGLU activation function
13
+ - Residual connections throughout
14
+
15
+ - KV-cache for faster autoregressive generation
16
+
17
+ References:
18
+ - RoPE: https://arxiv.org/abs/2104.09864
19
+ - SwiGLU: https://arxiv.org/abs/2002.05202
20
+ - LLAMA: https://arxiv.org/abs/2302.13971
21
+
22
+ Adapted from:
23
+ - OLMO: https://github.com/allenai/OLMo
24
+ - LLAMA: https://github.com/meta/llama
25
+ """
26
+
27
+ from dataclasses import asdict
28
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
29
+
30
+ import torch
31
+ import torch.nn as nn
32
+ import torch.nn.functional as F
33
+
34
+ # Handle PyTorch version compatibility for attention backend
35
+ try:
36
+ from torch.nn.attention import SDPBackend, sdpa_kernel
37
+
38
+ HAS_TORCH_ATTENTION = True
39
+ except ImportError:
40
+ # Fallback for older PyTorch versions
41
+ HAS_TORCH_ATTENTION = False
42
+ SDPBackend = None
43
+ sdpa_kernel = None
44
+
45
+ from transformers import GenerationMixin, PretrainedConfig, PreTrainedModel
46
+ from transformers.generation import GenerationConfig
47
+ from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
48
+
49
+ try:
50
+ if TYPE_CHECKING:
51
+ # We need to do this to avoid importing these when creating the HF-compatible models
52
+ from src.config import ModelConfig
53
+ except ImportError:
54
+ pass
55
+
56
+ ########################################################
57
+ #
58
+ # Layer Normalization
59
+ #
60
+ ########################################################
61
+
62
+
63
+ class RMSNorm(torch.nn.Module):
64
+ """Root Mean Square Layer Normalization.
65
+
66
+ A variant of Layer Normalization that uses RMS statistics instead of mean/variance,
67
+ resulting in improved stability and performance.
68
+
69
+ Args:
70
+ config (Union[ModelConfig, PicoDecoderHFConfig]): Configuration object containing normalization parameters
71
+ - config.norm_eps: Small constant for numerical stability
72
+ - config.d_model: Model dimension for the weight parameter
73
+
74
+ References:
75
+ https://arxiv.org/abs/1910.07467
76
+ """
77
+
78
+ def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
79
+ super().__init__()
80
+ self.eps = config.norm_eps
81
+ self.weight = nn.Parameter(torch.ones(config.d_model))
82
+
83
+ def _norm(self, x: torch.Tensor) -> torch.Tensor:
84
+ """
85
+ Normalizes the input tensor by its RMS value.
86
+ """
87
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
88
+
89
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
90
+ """
91
+ Applies RMS normalization to the input tensor and scales it by the weight parameter.
92
+ """
93
+ output = self._norm(x.float()).type_as(x)
94
+ return output * self.weight
95
+
96
+
97
+ ########################################################
98
+ #
99
+ # Positional Embedding
100
+ #
101
+ ########################################################
102
+
103
+
104
+ class RoPE(nn.Module):
105
+ """Rotary Positional Embeddings (RoPE).
106
+
107
+ Implements position-dependent rotation of keys and queries in attention mechanism,
108
+ allowing better modeling of relative positions in sequences. Uses complex number
109
+ operations for efficient rotation.
110
+
111
+ Args:
112
+ config (Union[ModelConfig, PicoDecoderHFConfig]): Model configuration containing:
113
+ - config.position_emb_theta: Base for frequency computation
114
+ - config.d_model: Model dimension
115
+ - config.attention_n_heads: Number of attention heads
116
+ - config.max_seq_len: Maximum sequence length
117
+
118
+ References:
119
+ https://arxiv.org/abs/2104.09864
120
+ """
121
+
122
+ _freqs_cis_tensor: torch.Tensor | None = None
123
+
124
+ def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
125
+ super().__init__()
126
+
127
+ self.theta = config.position_emb_theta
128
+ self.dim = config.d_model // config.attention_n_heads
129
+
130
+ max_seq_len = config.max_seq_len
131
+
132
+ # only gets set once, and then reused for all RoPE instances
133
+ if RoPE._freqs_cis_tensor is None:
134
+ RoPE._freqs_cis_tensor = self._setup_freqs_cis(
135
+ max_seq_len, self.theta, self.dim
136
+ )
137
+
138
+ # register _freqs_cis buffer
139
+ # can be easily recomputed so persistent=False
140
+ self.register_buffer("_freqs_cis", self._freqs_cis_tensor, persistent=False)
141
+
142
+ @classmethod
143
+ def _setup_freqs_cis(cls, seq_len: int, theta: float, dim: int) -> torch.Tensor:
144
+ """Setup Frequency Tensor for RoPE Embeddings
145
+
146
+ Initializes the complex frequency tensor that is used to compute the RoPE embeddings.
147
+
148
+ Note other implementations will use cos and sin directly, but using the complex
149
+ number representation is (probably) more efficient:
150
+
151
+ e^(theta * i * t) = cos(theta * t) + i * sin(theta * t) [Euler's formula]
152
+ """
153
+ _freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
154
+ positions = torch.arange(seq_len)
155
+ freqs = torch.outer(positions, _freqs)
156
+ return torch.polar(torch.ones_like(freqs), freqs) # complex64
157
+
158
+ def get_freqs_cis(
159
+ self, input_shape: torch.Size, start_pos: int, end_pos: int
160
+ ) -> torch.Tensor:
161
+ """Reshape Frequency Tensor for RoPE Embeddings
162
+
163
+ Makes the frequency tensor broadcastable with the input tensor.
164
+ """
165
+ _freqs_cis = self._freqs_cis[start_pos:end_pos]
166
+ ndim = len(input_shape)
167
+ assert 0 <= 1 < ndim
168
+ assert _freqs_cis.shape == (input_shape[1], input_shape[-1])
169
+
170
+ # TODO: Check whether this is correct (might be able to remove this)
171
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(input_shape)]
172
+ return _freqs_cis.view(*shape)
173
+
174
+ def forward(
175
+ self,
176
+ queries: torch.Tensor,
177
+ keys: torch.Tensor,
178
+ start_pos: int = 0,
179
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
180
+ """Apply RoPE Embeddings to Queries and Keys
181
+
182
+ Applies the rotary positional embeddings to the input tensors via complex-number multiplication.
183
+
184
+ NOTE: The start_pos is used if we want to use the kv_cache in the attention mechanism.
185
+ """
186
+ queries_ = torch.view_as_complex(
187
+ queries.float().reshape(*queries.shape[:-1], -1, 2)
188
+ )
189
+ keys_ = torch.view_as_complex(keys.float().reshape(*keys.shape[:-1], -1, 2))
190
+
191
+ input_shape = (
192
+ queries_.shape
193
+ ) # same as keys: (batch_size, seq_len, n_heads, head_dim/2)
194
+ freqs_start_pos = start_pos
195
+ freqs_end_pos = freqs_start_pos + queries_.shape[1]
196
+
197
+ freqs_cis = self.get_freqs_cis(input_shape, freqs_start_pos, freqs_end_pos)
198
+
199
+ queries_rotated = torch.view_as_real(queries_ * freqs_cis).flatten(3)
200
+ keys_rotated = torch.view_as_real(keys_ * freqs_cis).flatten(3)
201
+ return queries_rotated.type_as(queries), keys_rotated.type_as(keys)
202
+
203
+
204
+ ########################################################
205
+ #
206
+ # Attention
207
+ #
208
+ ########################################################
209
+
210
+
211
+ class Attention(nn.Module):
212
+ """Multi-head Attention with Group Query Attention support.
213
+
214
+ Implements scaled dot-product attention and supports:
215
+ - Grouped Query Attention (GQA)
216
+ - Key-Value caching for efficient inference
217
+ - RoPE integration
218
+
219
+ Args:
220
+ config (Union[ModelConfig, PretrainedConfig]): Configuration containing:
221
+ - config.attention_n_heads: Number of attention heads
222
+ - config.attention_n_kv_heads: Number of key/value heads
223
+ - config.d_model: Model dimension
224
+ - config.batch_size: Maximum batch size
225
+ - config.max_seq_len: Maximum sequence length
226
+
227
+ Shape:
228
+ - Input: (batch_size, seq_len, d_model)
229
+ - Output: (batch_size, seq_len, d_model)
230
+ """
231
+
232
+ def __init__(
233
+ self,
234
+ config: Union["ModelConfig", "PicoDecoderHFConfig"],
235
+ ):
236
+ super().__init__()
237
+
238
+ self.n_heads = config.attention_n_heads
239
+ self.n_kv_heads = config.attention_n_kv_heads
240
+
241
+ self.batch_size = config.batch_size
242
+ self.max_seq_len = config.max_seq_len
243
+
244
+ d_model = config.d_model
245
+ self.head_dim = d_model // self.n_heads
246
+
247
+ self.n_rep = self.n_heads // self.n_kv_heads
248
+
249
+ self.q_proj = nn.Linear(d_model, self.n_heads * self.head_dim, bias=False)
250
+ self.k_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
251
+ self.v_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
252
+ self.o_proj = nn.Linear(self.n_heads * self.head_dim, d_model, bias=False)
253
+
254
+ self.rope = RoPE(config)
255
+
256
+ def forward(
257
+ self,
258
+ input: torch.Tensor,
259
+ mask: Optional[torch.Tensor] = None,
260
+ past_key_values: Optional[Tuple[torch.Tensor, ...]] = None,
261
+ use_cache: bool = False,
262
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
263
+ """Forward pass for the attention mechanism.
264
+
265
+ Computes queries, keys, and values for the attention mechanism. Applies rotary positional
266
+ embeddings to the queries and keys, and then computes attention scores and outputs.
267
+
268
+ For an introduction to the attention mechanism, see:
269
+ https://arxiv.org/abs/1706.03762
270
+
271
+ A few things to note:
272
+ - The past_key_values is used to implement the KV cache, which is used to speed up
273
+ generation by caching the KV pairs from previous forward passes. This is useful when doing
274
+ tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
275
+ modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
276
+ its own KV cache - this KV cache is implemented as a tuple.
277
+ """
278
+ bsz, seq_len, _ = input.shape
279
+ _queries, _keys, _values = (
280
+ self.q_proj(input),
281
+ self.k_proj(input),
282
+ self.v_proj(input),
283
+ )
284
+
285
+ # Reshaping for multi-head attention
286
+ queries = _queries.view(bsz, seq_len, self.n_heads, self.head_dim)
287
+ keys = _keys.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
288
+ values = _values.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
289
+
290
+ # The start position is used to apply the RoPE embeddings to only the new tokens
291
+ # when using the kv_cache in the attention mechanism.
292
+ # We want to start from the last position in the cache.
293
+ start_pos = 0
294
+ if past_key_values is not None and past_key_values[0] is not None:
295
+ start_pos = past_key_values[0].shape[1]
296
+
297
+ # apply rotary positional embeddings
298
+ queries, keys = self.rope(queries, keys, start_pos)
299
+
300
+ if (
301
+ past_key_values is not None
302
+ and past_key_values[0] is not None
303
+ and past_key_values[1] is not None
304
+ ):
305
+ keys = torch.cat([past_key_values[0], keys], dim=1)
306
+ values = torch.cat([past_key_values[1], values], dim=1)
307
+
308
+ if use_cache:
309
+ cached_keys = keys
310
+ cached_values = values
311
+ else:
312
+ cached_keys = None
313
+ cached_values = None
314
+
315
+ queries = queries.transpose(1, 2)
316
+ keys = keys.transpose(1, 2)
317
+ values = values.transpose(1, 2)
318
+
319
+ apply_gqa = self.n_rep > 1
320
+ if apply_gqa and queries.device.type == "mps":
321
+ # NOTE: MPS does not support GQA in the SDPA kernel, but we can repeat the keys and values
322
+ # outside of the kernel to get the same effect.
323
+ # See: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
324
+ keys = keys.repeat_interleave(self.n_rep, dim=-3)
325
+ values = values.repeat_interleave(self.n_rep, dim=-3)
326
+ apply_gqa = False
327
+
328
+ if HAS_TORCH_ATTENTION:
329
+ backends = [SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]
330
+ with sdpa_kernel(backends=backends):
331
+ attn_output = F.scaled_dot_product_attention(
332
+ queries.contiguous(),
333
+ keys.contiguous(),
334
+ values.contiguous(),
335
+ attn_mask=mask.to(queries.dtype) if mask is not None else None,
336
+ enable_gqa=apply_gqa,
337
+ )
338
+ else:
339
+ # Fallback for older PyTorch versions - use default backend
340
+ attn_output = F.scaled_dot_product_attention(
341
+ queries.contiguous(),
342
+ keys.contiguous(),
343
+ values.contiguous(),
344
+ attn_mask=mask.to(queries.dtype) if mask is not None else None,
345
+ enable_gqa=apply_gqa,
346
+ )
347
+
348
+ attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
349
+ output = self.o_proj(attn_output)
350
+
351
+ return output, (cached_keys, cached_values)
352
+
353
+
354
+ ########################################################
355
+ #
356
+ # SwiGLU (Combines MLP and Activation)
357
+ #
358
+ ########################################################
359
+
360
+
361
+ class SwiGLU(nn.Module):
362
+ """SwiGLU Activation Function with Linear Projections.
363
+
364
+ Implements the SwiGLU activation function combined with linear transformations,
365
+ serving as the feed-forward network in transformer blocks.
366
+
367
+ Args:
368
+ config (Union[ModelConfig, PicoDecoderHFConfig]): Configuration containing:
369
+ - config.d_model: Model dimension
370
+ - config.activation_hidden_dim: Hidden dimension (typically 4 * d_model)
371
+
372
+ References:
373
+ https://arxiv.org/abs/2002.05202
374
+ """
375
+
376
+ def __init__(self, config: Union["ModelConfig", "PicoDecoderHFConfig"]):
377
+ super().__init__()
378
+
379
+ model_dim = config.d_model
380
+ act_hidden_dim = config.activation_hidden_dim # usually 4 * d_model
381
+
382
+ self.w_0 = nn.Linear(model_dim, act_hidden_dim, bias=False)
383
+ self.w_1 = nn.Linear(model_dim, act_hidden_dim, bias=False)
384
+ self.w_2 = nn.Linear(act_hidden_dim, model_dim, bias=False)
385
+
386
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
387
+ return self.w_2(F.silu(self.w_0(x)) * self.w_1(x))
388
+
389
+
390
+ ########################################################
391
+ #
392
+ # PicoDecoderBlock
393
+ #
394
+ ########################################################
395
+
396
+
397
+ class PicoDecoderBlock(nn.Module):
398
+ """Single Transformer Block with Attention and Feed-forward layers.
399
+
400
+ Implements a standard transformer block with:
401
+ - Multi-head attention with normalization and residual connection
402
+ - SwiGLU feed-forward network with normalization and residual connection
403
+
404
+ Args:
405
+ config (Union[ModelConfig, PicoDecoderHFConfig]): Model configuration; either a dataclass or
406
+ a HuggingFace PicoDecoderHFConfig
407
+ """
408
+
409
+ def __init__(
410
+ self,
411
+ config: Union["ModelConfig", "PicoDecoderHFConfig"],
412
+ ):
413
+ super().__init__()
414
+
415
+ self.attention = Attention(config)
416
+ self.swiglu = SwiGLU(config)
417
+ self.attention_norm = RMSNorm(config)
418
+ self.swiglu_norm = RMSNorm(config)
419
+
420
+ def forward(
421
+ self,
422
+ input: torch.Tensor,
423
+ mask: Optional[torch.Tensor] = None,
424
+ past_key_values: Optional[Tuple[torch.Tensor]] = None,
425
+ use_cache: bool = False,
426
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
427
+ attention_output, cached_key_values = self.attention(
428
+ self.attention_norm(input),
429
+ mask=mask,
430
+ past_key_values=past_key_values,
431
+ use_cache=use_cache,
432
+ )
433
+ # NOTE: cached_key_values is None if use_cache is False
434
+
435
+ h = input + attention_output
436
+ out = h + self.swiglu(self.swiglu_norm(h))
437
+ return out, cached_key_values
438
+
439
+
440
+ ########################################################
441
+ #
442
+ # Pico Decoder (Causal Transformer Model)
443
+ #
444
+ ########################################################
445
+
446
+
447
+ class PicoDecoder(nn.Module):
448
+ """
449
+ Pico Decoder: combines the embedding, causal decoder blocks, and output projection into a
450
+ single autoregressive model.
451
+
452
+ For more information on the model, see the classes for the modules that make up the model.
453
+ """
454
+
455
+ def __init__(
456
+ self,
457
+ model_config: Union["ModelConfig", "PicoDecoderHFConfig"],
458
+ ):
459
+ super().__init__()
460
+ self.config = model_config
461
+
462
+ self.embedding_proj = nn.Embedding(self.config.vocab_size, self.config.d_model)
463
+ self.layers = nn.ModuleList(
464
+ [PicoDecoderBlock(self.config) for _ in range(self.config.n_layers)]
465
+ )
466
+ self.output_norm = RMSNorm(self.config)
467
+ self.de_embedding_proj = nn.Linear(
468
+ self.config.d_model, self.config.vocab_size, bias=False
469
+ )
470
+
471
+ def convert_to_hf_model(self) -> "PicoDecoderHF":
472
+ """Convert the Lightning model to a HuggingFace model."""
473
+ # Create HF config without fabric-specific settings
474
+ hf_config = PicoDecoderHFConfig.from_dataclass(self.config)
475
+
476
+ # Create new HF model
477
+ hf_model = PicoDecoderHF(hf_config)
478
+
479
+ # Copy the state dict; the "pico_decoder." prefix matches the HF wrapper's submodule layout
480
+ hf_model.load_state_dict(self.state_dict(prefix="pico_decoder."))
481
+
482
+ return hf_model
483
+
484
+ def forward(
485
+ self,
486
+ input_ids: torch.Tensor,
487
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
488
+ use_cache: bool = False,
489
+ ) -> Tuple[torch.Tensor, Optional[Tuple[Tuple[torch.Tensor, torch.Tensor]]]]:
490
+ """
491
+ This is the forward pass for the entire Pico model. It boils down to:
492
+ - Embedding the input ids
493
+ - Creating a causal mask
494
+ - Processing through the pico layers
495
+ - Projecting the output to logits
496
+
497
+ NOTE: One feature that might be confusing is the KV cache. The KV cache is used to speed up
498
+ generation by caching the KV pairs from previous forward passes. This is useful when doing
499
+ tasks that require generating multiple tokens conditioned on previous tokens (e.g. language
500
+ modeling, text generation, etc.). The way the KV cache is implemented is that each layer has
501
+ its own KV cache which is stored as a tuple. The whole model then stores a tuple of these
502
+ KV caches (so a tuple of tuples).
503
+ """
504
+
505
+ seq_len = input_ids.shape[-1]
506
+ h = self.embedding_proj(input_ids)
507
+
508
+ # Calculate start position from past cached KV pairs. Remember that each layer has its
509
+ # own KV Cache. So when we index past_key_values, we need to index into the KV pairs for the
510
+ # correct layer and then for either the keys or values.
511
+ start_pos = 0
512
+ if (
513
+ past_key_values is not None
514
+ and past_key_values[0] is not None
515
+ and past_key_values[0][0] is not None
516
+ ):
517
+ start_pos = past_key_values[0][0].shape[1]
518
+
519
+ # Create causal mask for current sequence
520
+ mask = None
521
+ if seq_len > 1:
522
+ mask = torch.full((seq_len, seq_len), float("-inf"))
523
+ mask = torch.triu(mask, diagonal=1)
524
+
525
+ # If using KV cache, extend mask to cover cached sequence length
526
+ if past_key_values is not None:
527
+ # Add zeros for cached tokens (we can attend to all of them)
528
+ mask = torch.hstack([torch.zeros((seq_len, start_pos)), mask])
529
+
530
+ mask = mask.to(h.device)
531
+
532
+ # NOTE: If we are using the cache, we need to store the cached KV pairs for each layer
533
+ # in a tuple. Each layer will have its own cached KV pair which we aggregate in a tuple.
534
+ cached_key_values = () if use_cache else None
535
+
536
+ # Process through transformer blocks
537
+ for idx, layer in enumerate(self.layers):
538
+ layer_past_key_values = None
539
+ if past_key_values is not None:
540
+ try:
541
+ # Handle both tuple-based cache and HuggingFace cache objects
542
+ if hasattr(past_key_values, "__getitem__") and idx < len(
543
+ past_key_values
544
+ ):
545
+ layer_past_key_values = past_key_values[idx]
546
+ except (KeyError, IndexError, TypeError):
547
+ # If we can't access the cache properly, just skip it
548
+ layer_past_key_values = None
549
+
550
+ h, layer_cached_key_values = layer(
551
+ h, mask=mask, past_key_values=layer_past_key_values, use_cache=use_cache
552
+ )
553
+
554
+ if use_cache:
555
+ cached_key_values += (layer_cached_key_values,)
556
+
557
+ # Final norm and projection
558
+ h = self.output_norm(h)
559
+ logits = self.de_embedding_proj(h).float()
560
+
561
+ return logits, cached_key_values
562
+
563
+
564
+ ########################################################
565
+ #
566
+ # HuggingFace Wrapper for the Pico Decoder model.
567
+ #
568
+ ########################################################
569
+
570
+
571
+ class PicoDecoderHFConfig(PretrainedConfig):
572
+ """Config class for the Pico Decoder HuggingFace wrapper."""
573
+
574
+ model_type = "pico_decoder"
575
+
576
+ @classmethod
577
+ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PicoDecoderHFConfig":
578
+ """
579
+ Initialize config from a dictionary. Note that no kwargs are passed to the constructor --
580
+ this is because some kwargs require special handling, which can make this class
581
+ brittle.
582
+ """
583
+ pico_config = cls(**config_dict)
584
+
585
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
586
+ unused_kwargs = {
587
+ key: value for key, value in kwargs.items() if not hasattr(pico_config, key)
588
+ }
589
+
590
+ if return_unused_kwargs:
591
+ return pico_config, unused_kwargs
592
+ return pico_config
593
+
594
+ @classmethod
595
+ def from_dataclass(cls, model_config: "ModelConfig"):
596
+ """Initialise from our custom config dataclass."""
597
+ return cls.from_dict(asdict(model_config))
598
+
599
+
600
+ class PicoDecoderHF(PreTrainedModel, GenerationMixin):
601
+ """
602
+ HuggingFace wrapper for the Pico model with generation support.
603
+
604
+ Many evaluation frameworks require a model to be set up as a HuggingFace model, so we provide a simple
605
+ wrapper that does just that. When we save checkpoints of the Pico model, we save both the normal
606
+ Pico model as well as the model wrapped in this HuggingFace class.
607
+
608
+ This also lets you do cool things like:
609
+
610
+ `model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint")`
611
+ """
612
+
613
+ config_class = PicoDecoderHFConfig
614
+ _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
615
+ main_input_name = "input_ids"
616
+
617
+ def __init__(self, config: PicoDecoderHFConfig):
618
+ super().__init__(config)
619
+ self.pico_decoder = PicoDecoder(config)
620
+ # Initialize generation config with defaults
621
+ self.generation_config = GenerationConfig()
622
+ # Set some reasonable defaults for the model
623
+ if hasattr(config, "max_position_embeddings"):
624
+ self.generation_config.max_length = config.max_position_embeddings
625
+ if hasattr(config, "vocab_size"):
626
+ self.generation_config.vocab_size = config.vocab_size
627
+
628
+ def forward(
629
+ self,
630
+ input_ids: torch.Tensor,
631
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
632
+ use_cache: bool = False,
633
+ **kwargs,
634
+ ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
635
+ """HuggingFace forward pass wrapper.
636
+
637
+ Forward pass for the HuggingFace version of the Pico model: a thin wrapper around the
638
+ Pico model's forward pass that returns the output as a HuggingFace CausalLMOutput.
639
+ """
640
+ logits, past_key_values = self.pico_decoder(
641
+ input_ids, past_key_values, use_cache
642
+ )
643
+ if use_cache:
644
+ return CausalLMOutputWithPast(
645
+ logits=logits,
646
+ past_key_values=past_key_values,
647
+ )
648
+ else:
649
+ return CausalLMOutput(
650
+ logits=logits,
651
+ )
652
+
653
+ def prepare_inputs_for_generation(
654
+ self,
655
+ input_ids: torch.LongTensor,
656
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
657
+ attention_mask: Optional[torch.LongTensor] = None,
658
+ **kwargs,
659
+ ) -> Dict[str, Any]:
660
+ """
661
+ Prepare inputs for generation.
662
+
663
+ Args:
664
+ input_ids: Input token IDs
665
+ past_key_values: Cached key-value pairs from previous forward passes
666
+ attention_mask: Attention mask for the input
667
+ **kwargs: Additional arguments
668
+
669
+ Returns:
670
+ Dictionary containing prepared inputs
671
+ """
672
+ # If we have past_key_values, we only need the last token
673
+ if past_key_values is not None:
674
+ input_ids = input_ids[:, -1:]
675
+
676
+ return {
677
+ "input_ids": input_ids,
678
+ "past_key_values": past_key_values,
679
+ "use_cache": True,
680
+ }
681
+
682
+ def get_input_embeddings(self):
683
+ """Get the input embeddings layer."""
684
+ return self.pico_decoder.embedding_proj
685
+
686
+ def set_input_embeddings(self, value):
687
+ """Set the input embeddings layer."""
688
+ self.pico_decoder.embedding_proj = value
689
+
690
+ def get_output_embeddings(self):
691
+ """Get the output embeddings layer."""
692
+ return self.pico_decoder.de_embedding_proj
693
+
694
+ def set_output_embeddings(self, value):
695
+ """Set the output embeddings layer."""
696
+ self.pico_decoder.de_embedding_proj = value
697
+
698
+ def get_lm_head(self):
699
+ """Get the language model head."""
700
+ return self.pico_decoder.de_embedding_proj
701
+
702
+ def can_generate(self) -> bool:
703
+ """Check if the model can generate text."""
704
+ return True
705
+
706
+ @property
707
+ def is_encoder_decoder(self) -> bool:
708
+ """Check if the model is an encoder-decoder model."""
709
+ return False
710
+
711
+ @property
712
+ def can_use_cache(self) -> bool:
713
+ """Check if the model can use KV cache."""
714
+ return True
715
+
716
+ def resize_token_embeddings(
717
+ self, new_num_tokens: Optional[int] = None
718
+ ) -> torch.nn.Embedding:
719
+ """Resize token embeddings."""
720
+ old_embeddings = self.get_input_embeddings()
721
+ if new_num_tokens is None:
722
+ new_num_tokens = old_embeddings.num_embeddings
723
+
724
+ new_embeddings = torch.nn.Embedding(
725
+ new_num_tokens, old_embeddings.embedding_dim
726
+ )
727
+ new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
728
+ old_embeddings.weight.data
729
+ )
730
+
731
+ self.pico_decoder.embedding_proj = new_embeddings
732
+ self.pico_decoder.de_embedding_proj = torch.nn.Linear(
733
+ old_embeddings.embedding_dim, new_num_tokens, bias=False
734
+ )
735
+
736
+ return new_embeddings
737
+
738
+
739
+ # Register for auto classes
740
+ PicoDecoderHFConfig.register_for_auto_class()
741
+ PicoDecoderHF.register_for_auto_class("AutoModel")
742
+ PicoDecoderHF.register_for_auto_class("AutoModelForCausalLM")
743
+
744
+
745
+ ########################################################
746
+ #
747
+ # New PicoDecoderForCausalLM class for generation support
748
+ #
749
+ ########################################################
750
+
751
+
752
+ class PicoDecoderForCausalLM(PreTrainedModel, GenerationMixin):
753
+ """
754
+ PicoDecoderForCausalLM: A HuggingFace-compatible model that properly supports generation.
755
+
756
+ This class is designed to work with existing checkpoints and provides full generation support.
757
+ It inherits from the right base classes that HuggingFace expects for text generation.
758
+ """
759
+
760
+ config_class = PicoDecoderHFConfig
761
+ _no_split_modules = ["PicoBlock", "Attention", "SwiGLU", "RMSNorm"]
762
+ main_input_name = "input_ids"
763
+
764
+ def __init__(self, config: PicoDecoderHFConfig):
765
+ super().__init__(config)
766
+ self.pico_decoder = PicoDecoder(config)
767
+ # Initialize generation config with defaults
768
+ self.generation_config = GenerationConfig()
769
+ # Set some reasonable defaults for the model
770
+ if hasattr(config, "max_position_embeddings"):
771
+ self.generation_config.max_length = config.max_position_embeddings
772
+ if hasattr(config, "vocab_size"):
773
+ self.generation_config.vocab_size = config.vocab_size
774
+
775
+ def forward(
776
+ self,
777
+ input_ids: torch.Tensor,
778
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
779
+ use_cache: bool = False,
780
+ **kwargs,
781
+ ) -> Union[CausalLMOutput, CausalLMOutputWithPast]:
782
+ """Forward pass for text generation."""
783
+ logits, past_key_values = self.pico_decoder(
784
+ input_ids, past_key_values, use_cache
785
+ )
786
+ if use_cache:
787
+ return CausalLMOutputWithPast(
788
+ logits=logits,
789
+ past_key_values=past_key_values,
790
+ )
791
+ else:
792
+ return CausalLMOutput(
793
+ logits=logits,
794
+ )
795
+
796
+ def prepare_inputs_for_generation(
797
+ self,
798
+ input_ids: torch.LongTensor,
799
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
800
+ attention_mask: Optional[torch.LongTensor] = None,
801
+ **kwargs,
802
+ ) -> Dict[str, Any]:
803
+ """Prepare inputs for generation."""
804
+ # If we have past_key_values, we only need the last token
805
+ if past_key_values is not None:
806
+ input_ids = input_ids[:, -1:]
807
+
808
+ return {
809
+ "input_ids": input_ids,
810
+ "past_key_values": past_key_values,
811
+ "use_cache": True,
812
+ }
813
+
814
+ def get_input_embeddings(self):
815
+ """Get the input embeddings layer."""
816
+ return self.pico_decoder.embedding_proj
817
+
818
+ def set_input_embeddings(self, value):
819
+ """Set the input embeddings layer."""
820
+ self.pico_decoder.embedding_proj = value
821
+
822
+ def get_output_embeddings(self):
823
+ """Get the output embeddings layer."""
824
+ return self.pico_decoder.de_embedding_proj
825
+
826
+ def set_output_embeddings(self, value):
827
+ """Set the output embeddings layer."""
828
+ self.pico_decoder.de_embedding_proj = value
829
+
830
+ def get_lm_head(self):
831
+ """Get the language model head."""
832
+ return self.pico_decoder.de_embedding_proj
833
+
834
+ def can_generate(self) -> bool:
835
+ """Check if the model can generate text."""
836
+ return True
837
+
838
+ @property
839
+ def is_encoder_decoder(self) -> bool:
840
+ """Check if the model is an encoder-decoder model."""
841
+ return False
842
+
843
+ @property
844
+ def can_use_cache(self) -> bool:
845
+ """Check if the model can use KV cache."""
846
+ return True
847
+
848
+ def resize_token_embeddings(
849
+ self, new_num_tokens: Optional[int] = None
850
+ ) -> torch.nn.Embedding:
851
+ """Resize token embeddings."""
852
+ old_embeddings = self.get_input_embeddings()
853
+ if new_num_tokens is None:
854
+ new_num_tokens = old_embeddings.num_embeddings
855
+
856
+ new_embeddings = torch.nn.Embedding(
857
+ new_num_tokens, old_embeddings.embedding_dim
858
+ )
859
+ new_embeddings.weight.data[: old_embeddings.num_embeddings] = (
860
+ old_embeddings.weight.data
861
+ )
862
+
863
+ self.pico_decoder.embedding_proj = new_embeddings
864
+ self.pico_decoder.de_embedding_proj = torch.nn.Linear(
865
+ old_embeddings.embedding_dim, new_num_tokens, bias=False
866
+ )
867
+
868
+ return new_embeddings
869
+
870
+ @classmethod
871
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
872
+ """
873
+ Load a pretrained model from a checkpoint.
874
+
875
+ This method handles loading from both the old PicoDecoderHF format and the new format.
876
+ """
877
+ # First try to load with the new class
878
+ try:
879
+ return super().from_pretrained(
880
+ pretrained_model_name_or_path, *model_args, **kwargs
881
+ )
882
+ except Exception as e:
883
+ print(f"Failed to load with new class: {e}")
884
+ print("Attempting to load with legacy class and convert...")
885
+
886
+ # Try to load with the old class and convert
887
+ try:
888
+ from transformers import AutoModel
889
+
890
+ old_model = AutoModel.from_pretrained(
891
+ pretrained_model_name_or_path,
892
+ trust_remote_code=True,
893
+ *model_args,
894
+ **kwargs,
895
+ )
896
+
897
+ # Create new model instance
898
+ new_model = cls(old_model.config)
899
+
900
+ # Copy state dict
901
+ new_model.load_state_dict(old_model.state_dict(), strict=False)
902
+
903
+ return new_model
904
+
905
+ except Exception as e2:
906
+ print(f"Failed to convert from legacy format: {e2}")
907
+ raise e
908
+
909
+
910
+ # Register the new class
911
+ PicoDecoderForCausalLM.register_for_auto_class("AutoModelForCausalLM")
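
For reference, a minimal usage sketch for the class defined above, assuming one of the checkpoint directories uploaded in this commit (step_100000 is used here) and that transformers imports pico_decoder.py via trust_remote_code; the prompt text and sampling parameters are illustrative only.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Local checkpoint path from this commit; any step_* directory should work.
ckpt = "pico-decoder-tiny-dolma20M-v1/checkpoints/step_100000"

tokenizer = AutoTokenizer.from_pretrained(ckpt)
# trust_remote_code=True lets transformers load pico_decoder.py and resolve the
# PicoDecoderForCausalLM auto-class registered at the end of that file.
model = AutoModelForCausalLM.from_pretrained(ckpt, trust_remote_code=True)
model.eval()

inputs = tokenizer("Once upon a time", return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(
        inputs["input_ids"],
        max_new_tokens=32,
        do_sample=True,
        top_k=50,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

Because forward returns CausalLMOutputWithPast when use_cache is set, generate can reuse the KV cache and only feed the last token on each step, as prepare_inputs_for_generation above arranges.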
pico-decoder-tiny-dolma20M-v1/checkpoints/step_100000/special_tokens_map.json ADDED
@@ -0,0 +1,16 @@
+ {
+   "eos_token": {
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "<|padding|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
pico-decoder-tiny-dolma20M-v1/checkpoints/step_100000/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
pico-decoder-tiny-dolma20M-v1/checkpoints/step_100000/tokenizer_config.json ADDED
@@ -0,0 +1,239 @@
+ {
+   "add_bos_token": false,
+   "add_eos_token": false,
+   "add_prefix_space": false,
+   "added_tokens_decoder": {
+     "0": {
+       "content": "|||IP_ADDRESS|||",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "1": {
+       "content": "<|padding|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "50254": {
+       "content": " ",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "50255": {
+       "content": " ",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "50256": {
+       "content": " ",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "50257": {
+       "content": " ",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "50258": {
+       "content": " ",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "50259": {
+       "content": " ",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "50260": {
+       "content": " ",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "50261": {
+       "content": " ",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "50262": {
+       "content": " ",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "50263": {
+       "content": " ",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "50264": {
+       "content": " ",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "50265": {
+       "content": " ",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "50266": {
+       "content": " ",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "50267": {
+       "content": " ",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "50268": {
+       "content": " ",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "50269": {
+       "content": " ",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "50270": {
+       "content": " ",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "50271": {
+       "content": " ",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "50272": {
+       "content": " ",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "50273": {
+       "content": " ",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "50274": {
+       "content": " ",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "50275": {
+       "content": " ",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "50276": {
+       "content": " ",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "50277": {
+       "content": "|||EMAIL_ADDRESS|||",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "50278": {
+       "content": "|||PHONE_NUMBER|||",
+       "lstrip": false,
+       "normalized": true,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "50279": {
+       "content": "<|endoftext|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "bos_token": null,
+   "clean_up_tokenization_spaces": true,
+   "eos_token": "<|endoftext|>",
+   "extra_special_tokens": {},
+   "model_max_length": 1000000000000000019884624838656,
+   "pad_token": "<|padding|>",
+   "tokenizer_class": "GPTNeoXTokenizer",
+   "unk_token": null
+ }
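
For reference, a minimal sketch of loading the tokenizer described by the special_tokens_map.json and tokenizer_config.json above, assuming the step_100000 checkpoint directory from this commit; the sample text is illustrative only.

from transformers import AutoTokenizer

# Local checkpoint path from this commit containing the tokenizer files above.
tok = AutoTokenizer.from_pretrained(
    "pico-decoder-tiny-dolma20M-v1/checkpoints/step_100000"
)

ids = tok("Hello world")["input_ids"]
print(ids)
# eos/pad tokens come from the configuration shown above:
print(tok.eos_token, tok.eos_token_id)  # <|endoftext|> 50279
print(tok.pad_token, tok.pad_token_id)  # <|padding|> 1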