RelaxingSnorlax committed (verified)
Commit 51331d3 · Parent(s): f62b690

Upload folder using huggingface_hub

Files changed (4):
  1. config.json +38 -96
  2. eagle.py +538 -0
  3. generation_config.json +4 -0
  4. model.safetensors +2 -2
config.json CHANGED
@@ -1,125 +1,67 @@
  {
    "has_no_defaults_at_init": false,
-   "speculators_model_type": "eagle",
-   "speculators_version": "0.1.0.dev18",
    "speculators_config": {
      "algorithm": "eagle",
      "proposal_methods": [
        {
          "proposal_type": "greedy",
          "speculative_tokens": 5,
-         "verifier_accept_k": 1,
-         "accept_tolerance": 0.0
        }
      ],
-     "default_proposal_method": "greedy",
      "verifier": {
-       "name_or_path": "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "architectures": [
          "LlamaForCausalLM"
-       ]
      }
    },
-   "architectures": [
-     "EagleSpeculator",
-     "LlamaDecoderLayer"
-   ],
    "transformer_layer_architecture": "LlamaDecoderLayer",
    "transformer_layer_config": {
-     "vocab_size": 128256,
-     "max_position_embeddings": 131072,
      "hidden_size": 4096,
      "intermediate_size": 14336,
-     "num_hidden_layers": 1,
      "num_attention_heads": 32,
      "num_key_value_heads": 8,
-     "hidden_act": "silu",
-     "initializer_range": 0.02,
-     "rms_norm_eps": 1e-05,
      "pretraining_tp": 1,
-     "use_cache": true,
-     "rope_theta": 500000.0,
      "rope_scaling": {
        "factor": 8.0,
-       "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3"
      },
-     "attention_bias": false,
-     "attention_dropout": 0.0,
-     "mlp_bias": false,
-     "head_dim": 128,
-     "tie_word_embeddings": false,
-     "bos_token_id": 128000,
-     "eos_token_id": [
-       128001,
-       128008,
-       128009
-     ],
-     "transformers_version": "4.52.4",
-     "model_type": "llama"
-   },
-   "layernorms": true,
-   "fusion_bias": false,
-   "_name_or_path": "",
-   "transformers_version": "4.52.4",
-   "return_dict": true,
-   "output_hidden_states": false,
-   "output_attentions": false,
-   "torchscript": false,
-   "torch_dtype": null,
-   "use_bfloat16": false,
-   "tf_legacy_loss": false,
-   "pruned_heads": {},
-   "tie_word_embeddings": true,
-   "chunk_size_feed_forward": 0,
-   "is_encoder_decoder": false,
-   "is_decoder": false,
-   "cross_attention_hidden_size": null,
-   "add_cross_attention": false,
-   "tie_encoder_decoder": false,
-   "max_length": 20,
-   "min_length": 0,
-   "do_sample": false,
-   "early_stopping": false,
-   "num_beams": 1,
-   "num_beam_groups": 1,
-   "diversity_penalty": 0.0,
-   "temperature": 1.0,
-   "top_k": 50,
-   "top_p": 1.0,
-   "typical_p": 1.0,
-   "repetition_penalty": 1.0,
-   "length_penalty": 1.0,
-   "no_repeat_ngram_size": 0,
-   "encoder_no_repeat_ngram_size": 0,
-   "bad_words_ids": null,
-   "num_return_sequences": 1,
-   "output_scores": false,
-   "return_dict_in_generate": false,
-   "forced_bos_token_id": null,
-   "forced_eos_token_id": null,
-   "remove_invalid_values": false,
-   "exponential_decay_length_penalty": null,
-   "suppress_tokens": null,
-   "begin_suppress_tokens": null,
-   "finetuning_task": null,
-   "id2label": {
-     "0": "LABEL_0",
-     "1": "LABEL_1"
-   },
-   "label2id": {
-     "LABEL_0": 0,
-     "LABEL_1": 1
    },
-   "tokenizer_class": null,
-   "prefix": null,
-   "bos_token_id": null,
-   "pad_token_id": null,
-   "eos_token_id": null,
-   "sep_token_id": null,
-   "decoder_start_token_id": null,
-   "task_specific_params": null,
-   "problem_type": null
- }
 
  {
+   "architectures": [
+     "EagleSpeculator"
+   ],
+   "auto_map": {
+     "": "eagle.EagleSpeculatorConfig"
+   },
+   "fusion_bias": false,
    "has_no_defaults_at_init": false,
+   "layernorms": true,
    "speculators_config": {
      "algorithm": "eagle",
+     "default_proposal_method": "greedy",
      "proposal_methods": [
        {
+         "accept_tolerance": 0.0,
          "proposal_type": "greedy",
          "speculative_tokens": 5,
+         "verifier_accept_k": 1
        }
      ],
      "verifier": {
        "architectures": [
          "LlamaForCausalLM"
+       ],
+       "name_or_path": "meta-llama/Meta-Llama-3.1-8B-Instruct"
      }
    },
+   "speculators_model_type": "eagle",
+   "speculators_version": "0.1.0.dev13",
+   "torch_dtype": "float32",
    "transformer_layer_architecture": "LlamaDecoderLayer",
    "transformer_layer_config": {
+     "attention_bias": false,
+     "attention_dropout": 0.0,
+     "bos_token_id": 128000,
+     "eos_token_id": [
+       128001,
+       128008,
+       128009
+     ],
+     "head_dim": 128,
+     "hidden_act": "silu",
      "hidden_size": 4096,
+     "initializer_range": 0.02,
      "intermediate_size": 14336,
+     "max_position_embeddings": 131072,
+     "mlp_bias": false,
+     "model_type": "llama",
      "num_attention_heads": 32,
+     "num_hidden_layers": 1,
      "num_key_value_heads": 8,
      "pretraining_tp": 1,
+     "rms_norm_eps": 1e-05,
      "rope_scaling": {
        "factor": 8.0,
        "high_freq_factor": 4.0,
+       "low_freq_factor": 1.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3"
      },
+     "rope_theta": 500000.0,
+     "use_cache": true,
+     "vocab_size": 128256
    },
+   "transformers_version": "4.52.4"
+ }
 
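The rewritten config.json maps one-to-one onto the EagleSpeculatorConfig fields defined in eagle.py below. As a minimal sketch, patterned on the docstring example in eagle.py, this is how the same configuration could be built in code; the `GreedyTokenProposalConfig(speculative_tokens=5)` keyword is an assumption that mirrors the JSON key, and the layernorms / fusion_bias / verifier values are taken from this checkpoint:

```python
from speculators import SpeculatorsConfig, VerifierConfig
from speculators.models import EagleSpeculatorConfig
from speculators.proposals import GreedyTokenProposalConfig
from transformers import AutoConfig

# Mirrors the new config.json: layernorms=true, fusion_bias=false, one greedy
# proposal with 5 speculative tokens, and a Llama 3.1 8B Instruct verifier.
config = EagleSpeculatorConfig(
    transformer_layer_config=AutoConfig.from_pretrained(
        "meta-llama/Meta-Llama-3.1-8B-Instruct"
    ),
    layernorms=True,
    fusion_bias=False,
    speculators_config=SpeculatorsConfig(
        algorithm="eagle",
        proposal_methods=[
            # Keyword name assumed to match the JSON key shown above.
            GreedyTokenProposalConfig(speculative_tokens=5),
        ],
        default_proposal_method="greedy",
        verifier=VerifierConfig(
            name_or_path="meta-llama/Meta-Llama-3.1-8B-Instruct",
            architectures=["LlamaForCausalLM"],
        ),
    ),
)
```

Note that the stored transformer_layer_config pins num_hidden_layers to 1; either way, EagleSpeculator instantiates exactly one decoder layer (see _create_transformer_layer in eagle.py).
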
eagle.py ADDED
@@ -0,0 +1,538 @@
+ """
+ Speculators implementations providing a unified implementation
+ for EAGLE v1, EAGLE v2, and HASS variants for spec decoding:
+ - Eagle / Eagle v1: https://arxiv.org/abs/2401.15077
+ - Eagle v2: https://arxiv.org/abs/2406.16858
+ - HASS: https://arxiv.org/abs/2408.15766
+
+ Classes:
+     EagleSpeculatorConfig: Configuration class for EAGLE/HASS model variants
+     EagleSpeculator: Main model implementation for EAGLE/HASS speculators
+ """
+
+ import os
+ from typing import Any, ClassVar, Literal, Optional, Union
+
+ import torch
+ from pydantic import Field, field_serializer, field_validator, model_validator
+ from torch import nn
+ from transformers import PretrainedConfig, PreTrainedModel
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ from transformers.models.llama.configuration_llama import LlamaConfig
+ from transformers.models.llama.modeling_llama import (
+     LlamaDecoderLayer,
+     LlamaRMSNorm,
+ )
+ from typing_extensions import Self
+
+ from speculators import SpeculatorModel, SpeculatorModelConfig
+
+ __all__ = [
+     "EagleSpeculator",
+     "EagleSpeculatorConfig",
+ ]
+
+
+ @SpeculatorModelConfig.register("eagle")
+ class EagleSpeculatorConfig(SpeculatorModelConfig):
+     """
+     A SpeculatorModelConfig implementation to be used with the EagleSpeculator
+     for EAGLE and HASS variants for spec decoding:
+     - Eagle / Eagle v1: https://arxiv.org/abs/2401.15077
+     - Eagle v2: https://arxiv.org/abs/2406.16858
+     - HASS: https://arxiv.org/abs/2408.15766
+
+     Model Configurations:
+     - EAGLE1: layernorms=False, fusion_bias=False
+     - EAGLE2: layernorms=False, fusion_bias=False
+     - HASS: layernorms=False, fusion_bias=True
+
+     Example:
+     ```python
+     from speculators import SpeculatorsConfig, VerifierConfig
+     from speculators.models import EagleSpeculatorConfig
+     from speculators.proposals import GreedyTokenProposalConfig
+     from transformers import AutoConfig
+
+     config = EagleSpeculatorConfig(
+         transformer_layer_config=AutoConfig.from_pretrained("meta-llama/Llama-3.1-8B-Instruct"),
+         speculators_config=SpeculatorsConfig(
+             algorithm="eagle",
+             proposal_methods=[
+                 GreedyTokenProposalConfig(),
+             ],
+             default_proposal_method="greedy",
+             verifier=VerifierConfig(
+                 name_or_path="meta-llama/Llama-3.1-8B-Instruct",
+                 architectures=["LlamaForCausalLM"],
+             ),
+         ),
+     )
+     ```
+     """
+
+     speculators_model_type: Literal["eagle"] = "eagle"
+     architectures: list[str] = Field(
+         default_factory=lambda: ["EagleSpeculator"],
+         description=(
+             "List of model architectures that can be used with the model "
+             "pretrained weights. Automatically includes the transformer layer "
+             "architecture to ensure compatibility during model loading and "
+             "validation."
+         ),
+     )
+
+     transformer_layer_architecture: str = Field(
+         default="LlamaDecoderLayer",
+         description=(
+             "The architecture class name of the transformer layer to use for "
+             "the speculator's decoder layer. Must correspond to a valid "
+             "transformer decoder layer class (e.g., 'LlamaDecoderLayer')."
+         ),
+     )
+     transformer_layer_config: PretrainedConfig = Field(
+         default_factory=LlamaConfig,
+         description=(
+             "Configuration object for the transformer layer architecture. "
+             "Must be a PretrainedConfig instance that matches the requirements "
+             "of the transformer_layer_architecture. Contains parameters such as "
+             "hidden_size, num_attention_heads, intermediate_size, vocab_size, "
+             "and other architecture-specific settings."
+         ),
+     )
+     layernorms: bool = Field(
+         default=False,
+         description=(
+             "Whether to include additional layer normalization layers in the "
+             "model architecture. When True, adds RMSNorm layers after the "
+             "verifier's hidden state (embedding_layernorm), after the fusion "
+             "layer output, and before the language model head (pre_lm_head_layernorm). "
+             "When False, these layers are not included and the output layernorm "
+             "within the transformer architecture is removed as well. "
+             "Standard EAGLE1, EAGLE2, and HASS implementations use False."
+         ),
+     )
+     fusion_bias: bool = Field(
+         default=False,
+         description=(
+             "Whether to add a learnable bias term to the fusion (fully connected) "
+             "layer that combines input embeddings with verifier hidden states. "
+             "The fusion layer concatenates input embeddings and hidden states, "
+             "then projects to hidden_size dimensions. Standard EAGLE1 and EAGLE2 "
+             "use False, while HASS uses True."
+         ),
+     )
+
+     @model_validator(mode="after")
+     def check_add_architectures(self) -> Self:
+         """
+         Automatically adds the transformer layer architecture to the
+         architectures list if it's not already present.
+
+         :return: The validated configuration instance with updated architectures
+         """
+         if self.transformer_layer_architecture not in self.architectures:
+             self.architectures.append(self.transformer_layer_architecture)
+
+         return self
+
+     @field_serializer("transformer_layer_config")
+     def serialize_transformer_layer_config(self, value: PretrainedConfig) -> dict:
+         """
+         Serialize the transformer_layer_config to a dictionary for JSON storage.
+
+         Converts the PretrainedConfig object to its dictionary representation
+         using to_diff_dict() to only include non-default values.
+
+         :param value: The PretrainedConfig instance to serialize
+         :return: Dictionary representation of the transformer layer configuration
+         """
+         return value.to_diff_dict()
+
+     @field_validator("transformer_layer_config", mode="before")
+     @classmethod
+     def validate_transformer_layer_config(cls, value: Any) -> PretrainedConfig:
+         """
+         Validate and convert transformer_layer_config to a PretrainedConfig instance.
+
+         Accepts either a dictionary that can be converted to a PretrainedConfig
+         or an existing PretrainedConfig instance.
+
+         :param value: The value to validate (dict or PretrainedConfig)
+         :return: A validated PretrainedConfig instance
+         :raises ValueError: If the value cannot be converted to a PretrainedConfig
+         """
+         if isinstance(value, dict):
+             return PretrainedConfig.from_dict(value)
+
+         if isinstance(value, PretrainedConfig):
+             return value
+
+         raise ValueError(
+             "transformer_layer_config must be a PretrainedConfig instance or a "
+             "dictionary that can be converted to a PretrainedConfig."
+         )
+
+
+ @SpeculatorModel.register("eagle")
+ class EagleSpeculator(SpeculatorModel):
+     """
+     A SpeculatorModel implementation for EAGLE and HASS variants for spec decoding:
+     - Eagle / Eagle v1: https://arxiv.org/abs/2401.15077
+     - Eagle v2: https://arxiv.org/abs/2406.16858
+     - HASS: https://arxiv.org/abs/2408.15766
+
+     Architecture Overview:
+     The EAGLE speculator consists of:
+     1. Input embedding layer (shared with verifier)
+     2. Optional embedding layer normalization
+     3. Fusion layer: Concatenates and projects input embeddings + verifier hidden
+        states to a latent space of hidden_size
+     4. Single transformer decoder layer for candidate token generation
+     5. Optional pre-LM head layer normalization
+     6. Language model head (shared with verifier)
+
+     Speculative Decoding Process:
+     1. Verifier model processes input and generates hidden states
+     2. EAGLE speculator uses these hidden states + input embeddings to predict
+        next tokens
+     3. Multiple candidate tokens generated in parallel using token proposal methods
+     4. Verifier validates candidates and accepts/rejects based on probability
+        thresholds
+     5. Process continues iteratively for multi-token speculation
+
+     Example:
+     ```python
+     from speculators import SpeculatorsConfig, VerifierConfig
+     from speculators.models import EagleSpeculator, EagleSpeculatorConfig
+     from speculators.proposals import GreedyTokenProposalConfig
+     from transformers import AutoConfig, AutoTokenizer
+
+     config = EagleSpeculatorConfig(
+         transformer_layer_config=AutoConfig.from_pretrained("meta-llama/Llama-3.1-8B-Instruct"),
+         speculators_config=SpeculatorsConfig(
+             algorithm="eagle",
+             proposal_methods=[
+                 GreedyTokenProposalConfig(),
+             ],
+             default_proposal_method="greedy",
+             verifier=VerifierConfig(
+                 name_or_path="meta-llama/Llama-3.1-8B-Instruct",
+                 architectures=["LlamaForCausalLM"],
+             ),
+         ),
+     )
+     speculator = EagleSpeculator(
+         config, verifier=verifier, verifier_attachment_mode="full"
+     )
+     ```
+     """
+
+     # PreTrainedModel settings
+     config_class: ClassVar[type[EagleSpeculatorConfig]] = EagleSpeculatorConfig  # type: ignore[misc]
+     _keys_to_ignore_on_load_missing: ClassVar[list[str]] = [  # type: ignore[misc]
+         "verifier*",
+         "embed_tokens*",
+         "lm_head*",
+     ]
+     _keys_to_ignore_on_save: ClassVar[list[str]] = [  # type: ignore[assignment,misc]
+         "embed_tokens.weight",
+         "lm_head.weight",
+         "lm_head.bias",
+     ]
+
+     def __init__(
+         self,
+         config: EagleSpeculatorConfig,
+         verifier: Optional[Union[str, os.PathLike, PreTrainedModel]] = None,
+         verifier_attachment_mode: Optional[
+             Literal["detached", "full", "train_only"]
+         ] = None,
+     ):
+         """
+         Initializes an EAGLE speculator architecture with configurable components based
+         on the provided configuration. The model starts with verifier-dependent layers
+         (embed_tokens, rotary_emb, lm_head) set to None until a verifier is attached.
+
+         :param config: Configuration object specifying model architecture, layer
+             settings, and speculative decoding parameters. Must be an instance of
+             EagleSpeculatorConfig containing transformer layer configuration and
+             EAGLE-specific settings.
+         :param verifier: Optional verifier model to attach for speculative decoding.
+             Can be a path to a model directory, Hugging Face model identifier, or
+             PreTrainedModel instance. If None, must be attached later via
+             attach_verifier() before using the model.
+         :param verifier_attachment_mode: Mode for verifier attachment. "detached"
+             prevents attachment even if verifier is provided. "full" enables
+             complete integration for both training and generation. "train_only"
+             attaches only components needed for training, optimizing memory usage.
+         """
+         if not isinstance(config, EagleSpeculatorConfig):
+             raise ValueError(
+                 "config must be an instance of EagleSpeculatorConfig, "
+                 f"got {type(config)} instead."
+             )
+
+         # Initialize model parameters from config
+         self.vocab_size = config.transformer_layer_config.vocab_size
+         self.hidden_size = config.transformer_layer_config.hidden_size
+         self.padding_idx = config.transformer_layer_config.pad_token_id
+
+         # Set layers pulled from the verifier to None until attach is called
+         self.embed_tokens: Optional[nn.Embedding] = None
+         self.rotary_emb: Optional[nn.Module] = None
+         self.lm_head: Optional[nn.Linear] = None
+
+         # Delayed initialization to ensure everything needed for attach_verifier is set
+         super().__init__(
+             config=config,
+             verifier=verifier,
+             verifier_attachment_mode=verifier_attachment_mode,
+         )
+
+         # Initialize layers based on the configuration
+         self.embedding_layernorm: Optional[nn.Module] = self._create_layernorm()
+         self.fusion_fc: nn.Linear = nn.Linear(
+             2 * self.hidden_size,
+             self.hidden_size,
+             bias=config.fusion_bias,
+         )
+         self.transformer: nn.Module = self._create_transformer_layer()
+         self.pre_lm_head_layernorm: Optional[nn.Module] = self._create_layernorm()
+
+         self.post_init()  # type: ignore[attr-defined]
+
+     def attach_verifier(
+         self,
+         verifier: Union[str, os.PathLike, PreTrainedModel],
+         mode: Optional[Literal["full", "train_only"]] = None,
+     ) -> PreTrainedModel:
+         """
+         Attach a verifier model to the EagleSpeculator for speculative decoding.
+         Utilizes the verifier's embed_tokens, rotary_emb, and lm_head layers
+         for the speculator's forward pass and generation methods.
+         Additionally, for `generate`, it uses the verifier's hidden states
+         to generate speculative token predictions.
+
+         If mode is "full", the verifier is fully integrated for use with
+         both `generate` and `forward` methods.
+
+         If mode is "train_only", only the verifier's layers required for a forward pass
+         are attached, allowing for better resource utilization during training.
+         `generate` will not be available until a full verifier is attached.
+
+         Example:
+         ```python
+         # Load and attach a verifier
+         verifier = EagleSpeculator(...)
+
+         # For generation
+         speculator.attach_verifier(verifier)
+         outputs = speculator.generate(input_ids)
+         speculator.detach_verifier()
+
+         # For training
+         speculator.attach_verifier(verifier, mode="train_only")
+         outputs = speculator(input_ids, hidden_states)
+         speculator.detach_verifier()
+         ```
+
+         :param verifier: The verifier model to attach. This can be a path to a local
+             model directory, a Hugging Face model identifier, or an instance of
+             PreTrainedModel. If a path or identifier is provided, the model will be
+             loaded automatically. If an instance is provided, it will be used directly.
+         :param mode: The mode for attaching the verifier. Can be "full" or "train_only".
+             If None, defaults to "full". In "train_only" mode, only the layers
+             required for a forward pass are attached, and the speculator cannot
+             perform generation until a full verifier is attached.
+         :return: The PreTrainedModel instance for the verifier that was attached.
+         """
+         verifier = super().attach_verifier(
+             verifier=verifier,
+             mode=mode,
+         )
+
+         # Extract layers from the verifier model
+
+         if hasattr(verifier, "model"):
+             self.embed_tokens = verifier.model.embed_tokens  # type: ignore[assignment]
+             self.rotary_emb = verifier.model.rotary_emb  # type: ignore[assignment]
+         else:
+             # Bare model structure
+             self.embed_tokens = verifier.embed_tokens  # type: ignore[assignment]
+             self.rotary_emb = verifier.rotary_emb  # type: ignore[assignment]
+
+         # lm_head is always at the top level of the verifier
+         self.lm_head = verifier.lm_head
+
+         return verifier
+
+     def detach_verifier(self):
+         """
+         Removes the reference to the attached verifier model and frees up the
+         associated memory. After calling this method, the speculator will not
+         be able to perform forward passes or generation until a new verifier
+         is attached.
+         """
+         super().detach_verifier()
+
+         del self.embed_tokens
+         self.embed_tokens = None
+         del self.rotary_emb
+         self.rotary_emb = None
+         del self.lm_head
+         self.lm_head = None
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor,
+         hidden_states: torch.FloatTensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,  # noqa: ARG002
+         return_dict: Optional[bool] = None,
+     ) -> Union[torch.FloatTensor, CausalLMOutputWithPast]:
+         """
+         Execute the forward pass for speculative token generation.
+
+         Processes input tokens and verifier hidden states through the EAGLE architecture
+         to generate candidate tokens for speculative decoding. The method combines input
+         embeddings with verifier hidden states via a fusion layer, processes them
+         through a transformer decoder layer, and produces logits for next token
+         prediction.
+
+         :param input_ids: Token IDs for the current input sequence. Shape: (batch_size,
+             sequence_length). These represent the tokens that will be converted to
+             embeddings and combined with verifier hidden states.
+         :param hidden_states: Hidden state representations from the verifier model
+             corresponding to the input sequence. Shape: (batch_size, sequence_length,
+             hidden_size). These capture the verifier's understanding of the context.
+         :param attention_mask: Optional attention mask to avoid attending to padding
+             tokens. Shape: (batch_size, sequence_length) for 2D or (batch_size, 1,
+             sequence_length, sequence_length) for 4D causal mask.
+         :param position_ids: Optional position indices for tokens in the sequence.
+             Shape: (batch_size, sequence_length). If None, auto-generated based on
+             sequence length and past key values.
+         :param past_key_values: Optional cached key-value states from previous forward
+             passes for efficient generation. Tuple of layer key-value pairs.
+         :param use_cache: Whether to return key-value states for caching in subsequent
+             forward passes. Useful for autoregressive generation efficiency.
+         :param output_attentions: Whether to return attention weights from the
+             transformer layer. Used for analysis and visualization.
+         :param output_hidden_states: Whether to return hidden states from the
+             transformer layer. Currently not implemented in this model.
+         :param return_dict: Whether to return structured CausalLMOutputWithPast instead
+             of raw logits. If None, uses config.use_return_dict default.
+         :return: Either raw logits tensor (batch_size, sequence_length, vocab_size) if
+             return_dict=False, or CausalLMOutputWithPast containing logits, past key
+             values, and optional attention weights.
+         :raises ValueError: If verifier components (embed_tokens, rotary_emb, lm_head)
+             are not attached. Call attach_verifier() before using forward().
+         """
+         if self.embed_tokens is None or self.rotary_emb is None or self.lm_head is None:
+             raise ValueError(
+                 "Verifier model layers not initialized. "
+                 "Call `attach_verifier` to set up the model before using forward."
+             )
+
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         inputs_embeds = self.embed_tokens(input_ids)
+         if self.embedding_layernorm is not None:
+             inputs_embeds = self.embedding_layernorm(inputs_embeds)
+
+         hidden_states = self.fusion_fc(
+             torch.cat([inputs_embeds, hidden_states], dim=-1)
+         )
+         hidden_states, attention_mask, position_ids = self._prepare_decoder_inputs(
+             hidden_states, attention_mask, position_ids, past_key_values
+         )
+
+         cos, sin = self.rotary_emb(hidden_states, position_ids)
+         layer_outputs = self.transformer(
+             hidden_states,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_value=past_key_values[0] if past_key_values else None,
+             output_attentions=output_attentions,
+             use_cache=use_cache,
+             position_embeddings=(cos, sin),
+         )
+         hidden_states = layer_outputs[0]
+
+         if self.pre_lm_head_layernorm is not None:
+             hidden_states = self.pre_lm_head_layernorm(hidden_states)
+
+         logits = self.lm_head(hidden_states)
+
+         if not return_dict:
+             return logits
+
+         return CausalLMOutputWithPast(
+             logits=logits,
+             past_key_values=layer_outputs[1] if use_cache else None,
+             hidden_states=None,
+             attentions=None,
+         )
+
+     def _prepare_decoder_inputs(
+         self,
+         hidden_states: torch.FloatTensor,
+         attention_mask: Optional[torch.Tensor],
+         position_ids: Optional[torch.LongTensor],
+         past_key_values: Optional[tuple[tuple[torch.FloatTensor]]],
+     ) -> tuple[torch.FloatTensor, Optional[torch.Tensor], Optional[torch.LongTensor]]:
+         batch_size, seq_length = hidden_states.shape[:2]
+
+         if position_ids is None:
+             device = hidden_states.device
+             position_ids = (
+                 torch.arange(seq_length, dtype=torch.long, device=device)  # type: ignore[assignment]
+                 .unsqueeze(0)
+                 .expand(batch_size, -1)
+             )
+
+         if attention_mask is not None and attention_mask.dim() == 2:  # noqa: PLR2004
+             past_key_values_length = (
+                 past_key_values[0][0].shape[2] if past_key_values else 0
+             )
+             attention_mask = _prepare_4d_causal_attention_mask(
+                 attention_mask,
+                 (batch_size, seq_length),
+                 hidden_states,
+                 past_key_values_length,
+                 sliding_window=getattr(self.config, "sliding_window", None),
+             )
+
+         return hidden_states, attention_mask, position_ids
+
+     def _create_layernorm(self) -> Optional[nn.Module]:
+         if not self.config.layernorms:
+             return None
+
+         return self._layernorm_class()(
+             self.hidden_size, eps=self.config.transformer_layer_config.rms_norm_eps
+         )
+
+     def _create_transformer_layer(self) -> nn.Module:
+         layer_class = self._transformer_layer_class()
+         layer = layer_class(
+             self.config.transformer_layer_config,
+             layer_idx=0,
+         )
+
+         if not self.config.layernorms:
+             # Replace input_layernorm with Identity if layernorms are not used
+             layer.input_layernorm = nn.Identity()
+
+         return layer
+
+     def _layernorm_class(self) -> type[nn.Module]:
+         return LlamaRMSNorm
+
+     def _transformer_layer_class(self) -> type[nn.Module]:
+         return LlamaDecoderLayer
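
Taken together, attach_verifier() and forward() can be exercised outside of generate() roughly as follows. This is a sketch under stated assumptions: EagleSpeculator.from_pretrained is available through the SpeculatorModel / PreTrainedModel machinery, the local checkpoint path "." is a placeholder, and the verifier's final-layer hidden state stands in for the hidden_states argument (the exact features used to train this checkpoint may differ).

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from eagle import EagleSpeculator  # the module added in this commit

verifier_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
verifier = AutoModelForCausalLM.from_pretrained(verifier_id)
tokenizer = AutoTokenizer.from_pretrained(verifier_id)

# Loading the drafter from the local checkpoint directory is assumed here.
speculator = EagleSpeculator.from_pretrained(".")
speculator.attach_verifier(verifier, mode="train_only")  # borrows embed_tokens, rotary_emb, lm_head

inputs = tokenizer("Speculative decoding drafts tokens with a small", return_tensors="pt")
with torch.no_grad():
    # The verifier supplies the hidden states that the fusion layer expects.
    verifier_out = verifier(**inputs, output_hidden_states=True)
    hidden = verifier_out.hidden_states[-1]  # (batch, seq_len, hidden_size)

    # return_dict=False returns raw logits of shape (batch, seq_len, vocab_size).
    logits = speculator(
        input_ids=inputs["input_ids"],
        hidden_states=hidden,
        return_dict=False,
    )

draft_token = logits[:, -1].argmax(dim=-1)  # greedy one-token draft
speculator.detach_verifier()
```

For end-to-end speculative decoding, the attach_verifier docstring above points to the default full attachment mode plus speculator.generate() instead.
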
generation_config.json ADDED
@@ -0,0 +1,4 @@
+ {
+   "_from_model_config": true,
+   "transformers_version": "4.52.4"
+ }
model.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:4151564cb2c9bc7f762af53d68a94e5924f60848992c1814128bafce576a8704
- size 1006699768

  version https://git-lfs.github.com/spec/v1
+ oid sha256:25c4dba5bf103b1ca1879b4eea7804871ad8a8b69310f69418b8680bc08f7312
+ size 1006699800