Upload folder using huggingface_hub

Files changed:
- config.json +38 -96
- eagle.py +538 -0
- generation_config.json +4 -0
- model.safetensors +2 -2

config.json CHANGED
@@ -1,125 +1,67 @@
 {
+  "architectures": [
+    "EagleSpeculator"
+  ],
+  "auto_map": {
+    "": "eagle.EagleSpeculatorConfig"
+  },
+  "fusion_bias": false,
   "has_no_defaults_at_init": false,
-  "
-  "speculators_version": "0.1.0.dev18",
+  "layernorms": true,
   "speculators_config": {
     "algorithm": "eagle",
+    "default_proposal_method": "greedy",
     "proposal_methods": [
       {
+        "accept_tolerance": 0.0,
         "proposal_type": "greedy",
         "speculative_tokens": 5,
-        "verifier_accept_k": 1
-        "accept_tolerance": 0.0
+        "verifier_accept_k": 1
       }
     ],
-    "default_proposal_method": "greedy",
     "verifier": {
-      "name_or_path": "meta-llama/Meta-Llama-3.1-8B-Instruct",
       "architectures": [
         "LlamaForCausalLM"
-      ]
+      ],
+      "name_or_path": "meta-llama/Meta-Llama-3.1-8B-Instruct"
     }
   },
-  "
-
-
-  ],
+  "speculators_model_type": "eagle",
+  "speculators_version": "0.1.0.dev13",
+  "torch_dtype": "float32",
   "transformer_layer_architecture": "LlamaDecoderLayer",
   "transformer_layer_config": {
-    "
-    "
+    "attention_bias": false,
+    "attention_dropout": 0.0,
+    "bos_token_id": 128000,
+    "eos_token_id": [
+      128001,
+      128008,
+      128009
+    ],
+    "head_dim": 128,
+    "hidden_act": "silu",
     "hidden_size": 4096,
+    "initializer_range": 0.02,
     "intermediate_size": 14336,
-    "
+    "max_position_embeddings": 131072,
+    "mlp_bias": false,
+    "model_type": "llama",
     "num_attention_heads": 32,
+    "num_hidden_layers": 1,
     "num_key_value_heads": 8,
-    "hidden_act": "silu",
-    "initializer_range": 0.02,
-    "rms_norm_eps": 1e-05,
     "pretraining_tp": 1,
-    "
-    "rope_theta": 500000.0,
+    "rms_norm_eps": 1e-05,
     "rope_scaling": {
       "factor": 8.0,
-      "low_freq_factor": 1.0,
       "high_freq_factor": 4.0,
+      "low_freq_factor": 1.0,
       "original_max_position_embeddings": 8192,
       "rope_type": "llama3"
     },
-    "
-    "
-    "
-    "head_dim": 128,
-    "tie_word_embeddings": false,
-    "bos_token_id": 128000,
-    "eos_token_id": [
-      128001,
-      128008,
-      128009
-    ],
-    "transformers_version": "4.52.4",
-    "model_type": "llama"
-  },
-  "layernorms": true,
-  "fusion_bias": false,
-  "_name_or_path": "",
-  "transformers_version": "4.52.4",
-  "return_dict": true,
-  "output_hidden_states": false,
-  "output_attentions": false,
-  "torchscript": false,
-  "torch_dtype": null,
-  "use_bfloat16": false,
-  "tf_legacy_loss": false,
-  "pruned_heads": {},
-  "tie_word_embeddings": true,
-  "chunk_size_feed_forward": 0,
-  "is_encoder_decoder": false,
-  "is_decoder": false,
-  "cross_attention_hidden_size": null,
-  "add_cross_attention": false,
-  "tie_encoder_decoder": false,
-  "max_length": 20,
-  "min_length": 0,
-  "do_sample": false,
-  "early_stopping": false,
-  "num_beams": 1,
-  "num_beam_groups": 1,
-  "diversity_penalty": 0.0,
-  "temperature": 1.0,
-  "top_k": 50,
-  "top_p": 1.0,
-  "typical_p": 1.0,
-  "repetition_penalty": 1.0,
-  "length_penalty": 1.0,
-  "no_repeat_ngram_size": 0,
-  "encoder_no_repeat_ngram_size": 0,
-  "bad_words_ids": null,
-  "num_return_sequences": 1,
-  "output_scores": false,
-  "return_dict_in_generate": false,
-  "forced_bos_token_id": null,
-  "forced_eos_token_id": null,
-  "remove_invalid_values": false,
-  "exponential_decay_length_penalty": null,
-  "suppress_tokens": null,
-  "begin_suppress_tokens": null,
-  "finetuning_task": null,
-  "id2label": {
-    "0": "LABEL_0",
-    "1": "LABEL_1"
-  },
-  "label2id": {
-    "LABEL_0": 0,
-    "LABEL_1": 1
+    "rope_theta": 500000.0,
+    "use_cache": true,
+    "vocab_size": 128256
   },
-  "
-
-  "bos_token_id": null,
-  "pad_token_id": null,
-  "eos_token_id": null,
-  "sep_token_id": null,
-  "decoder_start_token_id": null,
-  "task_specific_params": null,
-  "problem_type": null
-}
+  "transformers_version": "4.52.4"
+}
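For reference, the new config.json is the serialized form of the `EagleSpeculatorConfig` defined in eagle.py below, added in this same commit. A minimal sketch of constructing an equivalent config object in Python, assuming `GreedyTokenProposalConfig` accepts the `speculative_tokens`, `accept_tolerance`, and `verifier_accept_k` fields that appear under the greedy proposal method in the JSON:

```python
# Sketch: build the EagleSpeculatorConfig that serializes to the config.json above.
# The GreedyTokenProposalConfig keyword arguments are inferred from the JSON and
# are an assumption about that class's fields.
from transformers import AutoConfig

from speculators import SpeculatorsConfig, VerifierConfig
from speculators.models import EagleSpeculatorConfig
from speculators.proposals import GreedyTokenProposalConfig

config = EagleSpeculatorConfig(
    transformer_layer_config=AutoConfig.from_pretrained(
        "meta-llama/Meta-Llama-3.1-8B-Instruct"
    ),
    layernorms=True,    # matches "layernorms": true in the serialized config
    fusion_bias=False,  # matches "fusion_bias": false
    speculators_config=SpeculatorsConfig(
        algorithm="eagle",
        proposal_methods=[
            GreedyTokenProposalConfig(
                speculative_tokens=5,
                accept_tolerance=0.0,
                verifier_accept_k=1,
            ),
        ],
        default_proposal_method="greedy",
        verifier=VerifierConfig(
            name_or_path="meta-llama/Meta-Llama-3.1-8B-Instruct",
            architectures=["LlamaForCausalLM"],
        ),
    ),
)
```

The serialized `transformer_layer_config` records `"num_hidden_layers": 1`, which matches the fact that `EagleSpeculator` instantiates exactly one decoder layer.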
eagle.py ADDED
@@ -0,0 +1,538 @@
"""
Speculators implementations providing a unified implementation
for EAGLE v1, EAGLE v2, and HASS variants for spec decoding:
- Eagle / Eagle v1: https://arxiv.org/abs/2401.15077
- Eagle v2: https://arxiv.org/abs/2406.16858
- HASS: https://arxiv.org/abs/2408.15766

Classes:
    EagleSpeculatorConfig: Configuration class for EAGLE/HASS model variants
    EagleSpeculator: Main model implementation for EAGLE/HASS speculators
"""

import os
from typing import Any, ClassVar, Literal, Optional, Union

import torch
from pydantic import Field, field_serializer, field_validator, model_validator
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaRMSNorm,
)
from typing_extensions import Self

from speculators import SpeculatorModel, SpeculatorModelConfig

__all__ = [
    "EagleSpeculator",
    "EagleSpeculatorConfig",
]


@SpeculatorModelConfig.register("eagle")
class EagleSpeculatorConfig(SpeculatorModelConfig):
    """
    A SpeculatorModelConfig implementation to be used with the EagleSpeculator
    for EAGLE and HASS variants for spec decoding:
    - Eagle / Eagle v1: https://arxiv.org/abs/2401.15077
    - Eagle v2: https://arxiv.org/abs/2406.16858
    - HASS: https://arxiv.org/abs/2408.15766

    Model Configurations:
    - EAGLE1: layernorms=False, fusion_bias=False
    - EAGLE2: layernorms=False, fusion_bias=False
    - HASS: layernorms=False, fusion_bias=True

    Example:
    ```python
    from speculators import SpeculatorsConfig, VerifierConfig
    from speculators.models import EagleSpeculatorConfig
    from speculators.proposals import GreedyTokenProposalConfig
    from transformers import AutoConfig

    config = EagleSpeculatorConfig(
        transformer_layer_config=AutoConfig.from_pretrained("meta-llama/Llama-3.1-8B-Instruct"),
        speculators_config=SpeculatorsConfig(
            algorithm="eagle",
            proposal_methods=[
                GreedyTokenProposalConfig(),
            ],
            default_proposal_method="greedy",
            verifier=VerifierConfig(
                name_or_path="meta-llama/Llama-3.1-8B-Instruct",
                architectures=["LlamaForCausalLM"],
            )
        )
    )
    ```
    """

    speculators_model_type: Literal["eagle"] = "eagle"
    architectures: list[str] = Field(
        default_factory=lambda: ["EagleSpeculator"],
        description=(
            "List of model architectures that can be used with the model "
            "pretrained weights. Automatically includes the transformer layer "
            "architecture to ensure compatibility during model loading and "
            "validation."
        ),
    )

    transformer_layer_architecture: str = Field(
        default="LlamaDecoderLayer",
        description=(
            "The architecture class name of the transformer layer to use for "
            "the speculator's decoder layer. Must correspond to a valid "
            "transformer decoder layer class (e.g., 'LlamaDecoderLayer')."
        ),
    )
    transformer_layer_config: PretrainedConfig = Field(
        default_factory=LlamaConfig,
        description=(
            "Configuration object for the transformer layer architecture. "
            "Must be a PretrainedConfig instance that matches the requirements "
            "of the transformer_layer_architecture. Contains parameters such as "
            "hidden_size, num_attention_heads, intermediate_size, vocab_size, "
            "and other architecture-specific settings."
        ),
    )
    layernorms: bool = Field(
        default=False,
        description=(
            "Whether to include additional layer normalization layers in the "
            "model architecture. When True, adds RMSNorm layers after the "
            "verifier's hidden state (embedding_layernorm), after the fusion "
            "layer output, and before the language model head (pre_lm_head_layernorm). "
            "When False, these layers are not included and the output layernorm "
            "within the transformer architecture is removed as well. "
            "Standard EAGLE1, EAGLE2, and HASS implementations use False."
        ),
    )
    fusion_bias: bool = Field(
        default=False,
        description=(
            "Whether to add a learnable bias term to the fusion (fully connected) "
            "layer that combines input embeddings with verifier hidden states. "
            "The fusion layer concatenates input embeddings and hidden states, "
            "then projects to hidden_size dimensions. Standard EAGLE1 and EAGLE2 "
            "use False, while HASS uses True."
        ),
    )

    @model_validator(mode="after")
    def check_add_architectures(self) -> Self:
        """
        Automatically adds the transformer layer architecture to the
        architectures list if it's not already present.

        :return: The validated configuration instance with updated architectures
        """
        if self.transformer_layer_architecture not in self.architectures:
            self.architectures.append(self.transformer_layer_architecture)

        return self

    @field_serializer("transformer_layer_config")
    def serialize_transformer_layer_config(self, value: PretrainedConfig) -> dict:
        """
        Serialize the transformer_layer_config to a dictionary for JSON storage.

        Converts the PretrainedConfig object to its dictionary representation
        using to_diff_dict() to only include non-default values.

        :param value: The PretrainedConfig instance to serialize
        :return: Dictionary representation of the transformer layer configuration
        """
        return value.to_diff_dict()

    @field_validator("transformer_layer_config", mode="before")
    @classmethod
    def validate_transformer_layer_config(cls, value: Any) -> PretrainedConfig:
        """
        Validate and convert transformer_layer_config to a PretrainedConfig instance.

        Accepts either a dictionary that can be converted to a PretrainedConfig
        or an existing PretrainedConfig instance.

        :param value: The value to validate (dict or PretrainedConfig)
        :return: A validated PretrainedConfig instance
        :raises ValueError: If the value cannot be converted to a PretrainedConfig
        """
        if isinstance(value, dict):
            return PretrainedConfig.from_dict(value)

        if isinstance(value, PretrainedConfig):
            return value

        raise ValueError(
            "transformer_layer_config must be a PretrainedConfig instance or a "
            "dictionary that can be converted to a PretrainedConfig."
        )


@SpeculatorModel.register("eagle")
class EagleSpeculator(SpeculatorModel):
    """
    A SpeculatorModel implementation for EAGLE and HASS variants for spec decoding:
    - Eagle / Eagle v1: https://arxiv.org/abs/2401.15077
    - Eagle v2: https://arxiv.org/abs/2406.16858
    - HASS: https://arxiv.org/abs/2408.15766

    Architecture Overview:
    The EAGLE speculator consists of:
    1. Input embedding layer (shared with verifier)
    2. Optional embedding layer normalization
    3. Fusion layer: Concatenates and projects input embeddings + verifier hidden
       states to a latent space of hidden_size
    4. Single transformer decoder layer for candidate token generation
    5. Optional pre-LM head layer normalization
    6. Language model head (shared with verifier)

    Speculative Decoding Process:
    1. Verifier model processes input and generates hidden states
    2. EAGLE speculator uses these hidden states + input embeddings to predict
       next tokens
    3. Multiple candidate tokens generated in parallel using token proposal methods
    4. Verifier validates candidates and accepts/rejects based on probability
       thresholds
    5. Process continues iteratively for multi-token speculation

    Example:
    ```python
    from speculators import SpeculatorsConfig, VerifierConfig
    from speculators.models import EagleSpeculator, EagleSpeculatorConfig
    from speculators.proposals import GreedyTokenProposalConfig
    from transformers import AutoConfig, AutoTokenizer

    config = EagleSpeculatorConfig(
        transformer_layer_config=AutoConfig.from_pretrained("meta-llama/Llama-3.1-8B-Instruct"),
        speculators_config=SpeculatorsConfig(
            algorithm="eagle",
            proposal_methods=[
                GreedyTokenProposalConfig(),
            ],
            default_proposal_method="greedy",
            verifier=VerifierConfig(
                name_or_path="meta-llama/Llama-3.1-8B-Instruct",
                architectures=["LlamaForCausalLM"],
            )
        )
    )
    speculator = EagleSpeculator(
        config, verifier=verifier, verifier_attachment_mode="full"
    )
    ```
    """

    # PreTrainedModel settings
    config_class: ClassVar[type[EagleSpeculatorConfig]] = EagleSpeculatorConfig  # type: ignore[misc]
    _keys_to_ignore_on_load_missing: ClassVar[list[str]] = [  # type: ignore[misc]
        "verifier*",
        "embed_tokens*",
        "lm_head*",
    ]
    _keys_to_ignore_on_save: ClassVar[list[str]] = [  # type: ignore[assignment,misc]
        "embed_tokens.weight",
        "lm_head.weight",
        "lm_head.bias",
    ]

    def __init__(
        self,
        config: EagleSpeculatorConfig,
        verifier: Optional[Union[str, os.PathLike, PreTrainedModel]] = None,
        verifier_attachment_mode: Optional[
            Literal["detached", "full", "train_only"]
        ] = None,
    ):
        """
        Initializes an EAGLE speculator architecture with configurable components based
        on the provided configuration. The model starts with verifier-dependent layers
        (embed_tokens, rotary_emb, lm_head) set to None until a verifier is attached.

        :param config: Configuration object specifying model architecture, layer
            settings, and speculative decoding parameters. Must be an instance of
            EagleSpeculatorConfig containing transformer layer configuration and
            EAGLE-specific settings.
        :param verifier: Optional verifier model to attach for speculative decoding.
            Can be a path to a model directory, Hugging Face model identifier, or
            PreTrainedModel instance. If None, must be attached later via
            attach_verifier() before using the model.
        :param verifier_attachment_mode: Mode for verifier attachment. "detached"
            prevents attachment even if verifier is provided. "full" enables
            complete integration for both training and generation. "train_only"
            attaches only components needed for training, optimizing memory usage.
        """
        if not isinstance(config, EagleSpeculatorConfig):
            raise ValueError(
                "config must be an instance of EagleSpeculatorConfig, "
                f"got {type(config)} instead."
            )

        # Initialize model parameters from config
        self.vocab_size = config.transformer_layer_config.vocab_size
        self.hidden_size = config.transformer_layer_config.hidden_size
        self.padding_idx = config.transformer_layer_config.pad_token_id

        # Set layers pulled from the verifier to None until attach is called
        self.embed_tokens: Optional[nn.Embedding] = None
        self.rotary_emb: Optional[nn.Module] = None
        self.lm_head: Optional[nn.Linear] = None

        # Delayed initialization to ensure everything needed for attach_verifier is set
        super().__init__(
            config=config,
            verifier=verifier,
            verifier_attachment_mode=verifier_attachment_mode,
        )

        # Initialize layers based on the configuration
        self.embedding_layernorm: Optional[nn.Module] = self._create_layernorm()
        self.fusion_fc: nn.Linear = nn.Linear(
            2 * self.hidden_size,
            self.hidden_size,
            bias=config.fusion_bias,
        )
        self.transformer: nn.Module = self._create_transformer_layer()
        self.pre_lm_head_layernorm: Optional[nn.Module] = self._create_layernorm()

        self.post_init()  # type: ignore[attr-defined]

    def attach_verifier(
        self,
        verifier: Union[str, os.PathLike, PreTrainedModel],
        mode: Optional[Literal["full", "train_only"]] = None,
    ) -> PreTrainedModel:
        """
        Attach a verifier model to the EagleSpeculator for speculative decoding.
        Utilizes the verifier's embed_tokens, rotary_emb, and lm_head layers
        for the speculator's forward pass and generation methods.
        Additionally, for `generate`, it uses the verifier's hidden states
        to generate speculative token predictions.

        If mode is "full", the verifier is fully integrated for use with
        both `generate` and `forward` methods.

        If mode is "train_only", only the verifier's layers required for a forward pass
        are attached, allowing for better resource utilization during training.
        `generate` will not be available until a full verifier is attached.

        Example:
        ```python
        # Load and attach a verifier
        verifier = EagleSpeculator(...)

        # For generation
        speculator.attach_verifier(verifier)
        outputs = speculator.generate(input_ids)
        speculator.detach_verifier()

        # For training
        speculator.attach_verifier(verifier, mode="train_only")
        outputs = speculator(input_ids, hidden_states)
        speculator.detach_verifier()
        ```

        :param verifier: The verifier model to attach. This can be a path to a local
            model directory, a Hugging Face model identifier, or an instance of
            PreTrainedModel. If a path or identifier is provided, the model will be
            loaded automatically. If an instance is provided, it will be used directly.
        :param mode: The mode for attaching the verifier. Can be "full" or "train_only".
            If None, defaults to "full". In "train_only" mode, only the layers
            required for a forward pass are attached, and the speculator cannot
            perform generation until a full verifier is attached.
        :return: The PreTrainedModel instance for the verifier that was attached.
        """
        verifier = super().attach_verifier(
            verifier=verifier,
            mode=mode,
        )

        # Extract layers from the verifier model

        if hasattr(verifier, "model"):
            self.embed_tokens = verifier.model.embed_tokens  # type: ignore[assignment]
            self.rotary_emb = verifier.model.rotary_emb  # type: ignore[assignment]
        else:
            # Bare model structure
            self.embed_tokens = verifier.embed_tokens  # type: ignore[assignment]
            self.rotary_emb = verifier.rotary_emb  # type: ignore[assignment]

        # lm_head is always at the top level of the verifier
        self.lm_head = verifier.lm_head

        return verifier

    def detach_verifier(self):
        """
        Removes the reference to the attached verifier model and frees up the
        associated memory. After calling this method, the speculator will not
        be able to perform forward passes or generation until a new verifier
        is attached.
        """
        super().detach_verifier()

        del self.embed_tokens
        self.embed_tokens = None
        del self.rotary_emb
        self.rotary_emb = None
        del self.lm_head
        self.lm_head = None

    def forward(
        self,
        input_ids: torch.LongTensor,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,  # noqa: ARG002
        return_dict: Optional[bool] = None,
    ) -> Union[torch.FloatTensor, CausalLMOutputWithPast]:
        """
        Execute the forward pass for speculative token generation.

        Processes input tokens and verifier hidden states through the EAGLE architecture
        to generate candidate tokens for speculative decoding. The method combines input
        embeddings with verifier hidden states via a fusion layer, processes them
        through a transformer decoder layer, and produces logits for next token
        prediction.

        :param input_ids: Token IDs for the current input sequence. Shape: (batch_size,
            sequence_length). These represent the tokens that will be converted to
            embeddings and combined with verifier hidden states.
        :param hidden_states: Hidden state representations from the verifier model
            corresponding to the input sequence. Shape: (batch_size, sequence_length,
            hidden_size). These capture the verifier's understanding of the context.
        :param attention_mask: Optional attention mask to avoid attending to padding
            tokens. Shape: (batch_size, sequence_length) for 2D or (batch_size, 1,
            sequence_length, sequence_length) for 4D causal mask.
        :param position_ids: Optional position indices for tokens in the sequence.
            Shape: (batch_size, sequence_length). If None, auto-generated based on
            sequence length and past key values.
        :param past_key_values: Optional cached key-value states from previous forward
            passes for efficient generation. Tuple of layer key-value pairs.
        :param use_cache: Whether to return key-value states for caching in subsequent
            forward passes. Useful for autoregressive generation efficiency.
        :param output_attentions: Whether to return attention weights from the
            transformer layer. Used for analysis and visualization.
        :param output_hidden_states: Whether to return hidden states from the
            transformer layer. Currently not implemented in this model.
        :param return_dict: Whether to return structured CausalLMOutputWithPast instead
            of raw logits. If None, uses config.use_return_dict default.
        :return: Either raw logits tensor (batch_size, sequence_length, vocab_size) if
            return_dict=False, or CausalLMOutputWithPast containing logits, past key
            values, and optional attention weights.
        :raises ValueError: If verifier components (embed_tokens, rotary_emb, lm_head)
            are not attached. Call attach_verifier() before using forward().
        """
        if self.embed_tokens is None or self.rotary_emb is None or self.lm_head is None:
            raise ValueError(
                "Verifier model layers not initialized. "
                "Call `attach_verifier` to set up the model before using forward."
            )

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        inputs_embeds = self.embed_tokens(input_ids)
        if self.embedding_layernorm is not None:
            inputs_embeds = self.embedding_layernorm(inputs_embeds)

        hidden_states = self.fusion_fc(
            torch.cat([inputs_embeds, hidden_states], dim=-1)
        )
        hidden_states, attention_mask, position_ids = self._prepare_decoder_inputs(
            hidden_states, attention_mask, position_ids, past_key_values
        )

        cos, sin = self.rotary_emb(hidden_states, position_ids)
        layer_outputs = self.transformer(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_values[0] if past_key_values else None,
            output_attentions=output_attentions,
            use_cache=use_cache,
            position_embeddings=(cos, sin),
        )
        hidden_states = layer_outputs[0]

        if self.pre_lm_head_layernorm is not None:
            hidden_states = self.pre_lm_head_layernorm(hidden_states)

        logits = self.lm_head(hidden_states)

        if not return_dict:
            return logits

        return CausalLMOutputWithPast(
            logits=logits,
            past_key_values=layer_outputs[1] if use_cache else None,
            hidden_states=None,
            attentions=None,
        )

    def _prepare_decoder_inputs(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.Tensor],
        position_ids: Optional[torch.LongTensor],
        past_key_values: Optional[tuple[tuple[torch.FloatTensor]]],
    ) -> tuple[torch.FloatTensor, Optional[torch.Tensor], Optional[torch.LongTensor]]:
        batch_size, seq_length = hidden_states.shape[:2]

        if position_ids is None:
            device = hidden_states.device
            position_ids = (
                torch.arange(seq_length, dtype=torch.long, device=device)  # type: ignore[assignment]
                .unsqueeze(0)
                .expand(batch_size, -1)
            )

        if attention_mask is not None and attention_mask.dim() == 2:  # noqa: PLR2004
            past_key_values_length = (
                past_key_values[0][0].shape[2] if past_key_values else 0
            )
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                hidden_states,
                past_key_values_length,
                sliding_window=getattr(self.config, "sliding_window", None),
            )

        return hidden_states, attention_mask, position_ids

    def _create_layernorm(self) -> Optional[nn.Module]:
        if not self.config.layernorms:
            return None

        return self._layernorm_class()(
            self.hidden_size, eps=self.config.transformer_layer_config.rms_norm_eps
        )

    def _create_transformer_layer(self) -> nn.Module:
        layer_class = self._transformer_layer_class()
        layer = layer_class(
            self.config.transformer_layer_config,
            layer_idx=0,
        )

        if not self.config.layernorms:
            # Replace input_layernorm with Identity if layernorms are not used
            layer.input_layernorm = nn.Identity()

        return layer

    def _layernorm_class(self) -> type[nn.Module]:
        return LlamaRMSNorm

    def _transformer_layer_class(self) -> type[nn.Module]:
        return LlamaDecoderLayer
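Putting the pieces together, a minimal usage sketch: it reuses the `config` object from the earlier sketch, loads a stock `transformers` Llama verifier, and assumes the verifier's final hidden states are an acceptable stand-in for the features a real training or serving pipeline would feed into `forward`.

```python
# Minimal sketch: attach a verifier and run one forward pass through the speculator.
# Assumes `config` from the EagleSpeculatorConfig sketch above; weights are kept in
# float32 to match the speculator's "torch_dtype": "float32".
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from speculators.models import EagleSpeculator

verifier_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(verifier_id)
verifier = AutoModelForCausalLM.from_pretrained(verifier_id)

# "train_only" borrows embed_tokens, rotary_emb, and lm_head for forward();
# use "full" if generate() is needed as well.
speculator = EagleSpeculator(
    config, verifier=verifier, verifier_attachment_mode="train_only"
)

inputs = tokenizer("Speculative decoding lets a small draft model", return_tensors="pt")

with torch.no_grad():
    # Assumption: the verifier's last hidden states serve as the `hidden_states`
    # input that the fusion layer concatenates with the token embeddings.
    verifier_out = verifier(**inputs, output_hidden_states=True)
    features = verifier_out.hidden_states[-1]

    draft = speculator(
        input_ids=inputs["input_ids"],
        hidden_states=features,
        attention_mask=inputs["attention_mask"],
        return_dict=True,
    )

print(draft.logits.shape)  # (batch, seq_len, vocab_size) logits for draft tokens
```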
generation_config.json ADDED
@@ -0,0 +1,4 @@
{
  "_from_model_config": true,
  "transformers_version": "4.52.4"
}
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:25c4dba5bf103b1ca1879b4eea7804871ad8a8b69310f69418b8680bc08f7312
+size 1006699800
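The safetensors weights are stored through Git LFS, so the repository only tracks a pointer file recording the blob's SHA-256 and byte size. A small sketch for checking a locally downloaded copy against the pointer above; the `model.safetensors` path is a placeholder for wherever the weights were fetched:

```python
# Sketch: verify a downloaded model.safetensors against the Git LFS pointer values.
import hashlib
from pathlib import Path

EXPECTED_OID = "25c4dba5bf103b1ca1879b4eea7804871ad8a8b69310f69418b8680bc08f7312"
EXPECTED_SIZE = 1006699800  # bytes, from the pointer's "size" line

path = Path("model.safetensors")  # placeholder path to the downloaded weights
digest = hashlib.sha256()
with path.open("rb") as fh:
    for chunk in iter(lambda: fh.read(1 << 20), b""):
        digest.update(chunk)

assert path.stat().st_size == EXPECTED_SIZE, "size mismatch with LFS pointer"
assert digest.hexdigest() == EXPECTED_OID, "sha256 mismatch with LFS pointer"
print("model.safetensors matches the LFS pointer")
```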