akhauriyash committed on
Commit d00164e · 1 Parent(s): 6896657

base files

conversion.py ADDED
@@ -0,0 +1,88 @@
+ from transformers import LlamaForCausalLM, LlamaConfig, AutoTokenizer
+ import torch
+ import os
+
+ # huggingface-cli download deepseek-ai/DeepSeek-R1-Distill-Llama-8B tokenizer_config.json --local-dir ./
+ # huggingface-cli download deepseek-ai/DeepSeek-R1-Distill-Llama-8B tokenizer.json --local-dir ./
+
+ question = "A $y$-intercept is a point on the graph that lies on the $y$-axis, so $x = 0$. Hence, the number $y$-intercepts corresponds to the number of real solutions of the quadratic equation $y^2 - 4y - 1 = 0$. The discriminant of this quadratic equation is $(-4)^2 + 4 \cdot 1 \cdot (-1) = 20$, which is positive, so the quadratic has two distinct real roots. Therefore, the number of $y$-intercepts is $\boxed{2}$. \n \n [asy] \n size(150); \n real ticklen=3; \n real tickspace=2; \n \n real ticklength=0.1cm; \n real axisarrowsize=0.14cm; \n pen axispen=black+1.3bp; \n real vectorarrowsize=0.2cm; \n real tickdown=-0.5; \n real tickdownlength=-0.15inch; \n real tickdownbase=0.3; \n real wholetickdown=tickdown; \n void rr_cartesian_axes(real xleft, real xright, real ybottom, real ytop, real xstep=1, real ystep=1, bool \n \n useticks=false, bool complexplane=false, bool usegrid=true) { \n \n import graph; \n \n real i; \n \n if(complexplane) { \n \n label('$\textnormal{Re}$',(xright,0),SE); \n \n label('$\textnormal{Im}$',(0,ytop),NW); \n \n } else { \n \n label('$x$',(xright+0.4,-0.5)); \n \n label('$y$',(-0.5,ytop+0.2)); \n \n } \n \n ylimits(ybottom,ytop); \n \n xlimits( xleft, xright); \n \n real[] TicksArrx,TicksArry; \n \n for(i=xleft+xstep; i<xright; i+=xstep) { \n \n if(abs(i) >0.1) { \n \n TicksArrx.push(i); \n \n } \n \n } \n \n for(i=ybottom+ystep; i<ytop; i+=ystep) { \n \n if(abs(i) >0.1) { \n \n TicksArry.push(i); \n \n } \n \n } \n \n if(usegrid) {"
+ predictor_load_path = "/home/ya255/projects/TokenButler/expt_model/TrainTokenButler_42_finetune_None_None_500_llama_deepseek-ai_DeepSeek-R1-Distill-Llama-8B_L3_8B_R1_1K.csv_L3_8B_R1_1K_False_False_2000_False_redpajama_1024_1_1_20_0.001_1024/16_False_4_1000_ExpPred_fixed_40pc_True_False_0_None_False_False_4_8_2_32_1024_False_False_True_32_0.3875000000000002__best.pt"
+ base_model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
+
+ def get_producer_layers(model):
+     """
+     Traverses the model to find the producer layers (layer_idx == 0).
+     """
+     producer_modules = []
+     for module in model.modules():
+         if module.__class__.__name__.endswith("AttentionExperimental") and module.layer_idx == 0:
+             producer_modules.append(module)
+     return producer_modules
+
+ # 1) Load the base model from HF
+ base_model = LlamaForCausalLM.from_pretrained(base_model_name, device_map="auto")
+ tokenizer = AutoTokenizer.from_pretrained(base_model_name)
+ inputs = tokenizer(question, return_tensors="pt")
+ inputs = {k: v.to(base_model.device) for k, v in inputs.items()}
+ question_length = inputs['attention_mask'].shape[1]
+
+ with torch.no_grad():
+     base_output_ids = base_model.generate(
+         **inputs,
+         max_new_tokens=200,
+         do_sample=True,
+         top_p=0.95,
+         temperature=0.7,
+     )
+ base_output_text = tokenizer.decode(base_output_ids[0][question_length:], skip_special_tokens=True)
+
+ # Remove the base model from the GPU
+ base_model_device = base_model.device
+ base_model.to("cpu")
+ base_state_dict = base_model.state_dict()
+ del base_model
+ torch.cuda.empty_cache()
+
+ from modeling_llama_butler import LlamaButlerConfig, LlamaButlerForCausalLM
+ butler_config = LlamaButlerConfig.from_pretrained('config.json')
+
+ butler_model = LlamaButlerForCausalLM(butler_config)
+ butler_model.load_state_dict(base_state_dict, strict=False)
+
+ model_producer_layers = get_producer_layers(butler_model)
+ producer_layer_weights = torch.load(predictor_load_path)
+ for idx, producer_layer_weight in enumerate(producer_layer_weights):
+     try:
+         model_producer_layers[idx].load_state_dict(producer_layer_weight, strict=False)
+     except Exception as e:
+         print(f"Error loading producer layer {idx}: {e}")
+         print("\n\nContinuing... !! Bad Perf If Unintentional !!\n\n")
+
+
+ butler_model.to(base_model_device)
+ butler_model.eval()
+
+ with torch.no_grad():
+     butler_output_ids = butler_model.generate(
+         **inputs,
+         max_new_tokens=200,
+         do_sample=True,
+         top_p=0.95,
+         temperature=0.7,
+     )
+
+ butler_output_text = tokenizer.decode(butler_output_ids[0][question_length:], skip_special_tokens=True)
+
+ print("\n=== Base Model Output (Newlines Removed For Brevity) ===\n")
+ print(base_output_text.replace("\n", ""))
+ print("\n")
+ print("=== Butler Model Output (Newlines Removed For Brevity) ===\n")
+ print(butler_output_text.replace("\n", ""))
+ print("\n")
+
+ OUTPUT_DIR = "."
+ print(f"\nSaving final merged model to: {OUTPUT_DIR}")
+ butler_model.save_pretrained(OUTPUT_DIR, safe_serialization=False)
+
+ # tokenizer.save_pretrained(OUTPUT_DIR)
+ print("\nAll done! The folder should now have `pytorch_model.bin` and the updated `config.json`.\n")
generation_config.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 128000,
+   "eos_token_id": 128001,
+   "transformers_version": "4.48.3"
+ }
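
The `generation_config.json` above only pins the BOS/EOS token ids used at generation time. As a rough sketch (standard `transformers` API; the local path "." is an assumption), it is picked up like this:

    from transformers import GenerationConfig

    gen_cfg = GenerationConfig.from_pretrained(".")  # reads ./generation_config.json
    assert gen_cfg.bos_token_id == 128000 and gen_cfg.eos_token_id == 128001
    # model.generate(..., generation_config=gen_cfg) stops once token id 128001 is produced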
modeling_llama_butler.py ADDED
@@ -0,0 +1,1434 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Dict
4
+ from transformers import LlamaForCausalLM, LlamaConfig
5
+ from transformers.generation.utils import GenerationConfig
6
+ import os
7
+ import pdb
8
+ import copy
9
+ import math
10
+ import numpy as np
11
+ from dataclasses import dataclass
12
+ from typing import Any, Dict, List, Optional, Tuple, Union
13
+ import gc
14
+
15
+ import traceback
16
+ import torch
17
+ from torch import nn
18
+ import torch.utils.checkpoint
19
+ import torch.nn.functional as F
20
+ from torch.cuda.amp import autocast
21
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
22
+
23
+ from transformers.models.llama.configuration_llama import LlamaConfig
24
+ from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, LlamaAttention, apply_rotary_pos_emb
25
+
26
+ from transformers.cache_utils import DynamicCache
27
+
28
+ class PredictorDynamicCache(DynamicCache):
29
+ def __init__(self):
30
+ super().__init__()
31
+ self.predictor_primary_key: List[Optional[torch.Tensor]] = []
32
+ self.predictor_primary_value: List[Optional[torch.Tensor]] = []
33
+ self.predictor_importance_key: List[Optional[torch.Tensor]] = []
34
+
35
+ def update_predictor_primary(
36
+ self,
37
+ key_states: torch.Tensor,
38
+ value_states: torch.Tensor,
39
+ layer_idx: int,
40
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
41
+ """
42
+ Append or create the predictor's "primary" K/V states for `layer_idx`.
43
+
44
+ The expected shape of key_states and value_states is [batch_size, num_heads, seq_len, head_dim].
45
+ """
46
+ # Extend the lists so that `predictor_primary_key[layer_idx]` and
47
+ # `predictor_primary_value[layer_idx]` exist.
48
+ self._ensure_list_capacity(
49
+ self.predictor_primary_key, layer_idx, fill=None
50
+ )
51
+ self._ensure_list_capacity(
52
+ self.predictor_primary_value, layer_idx, fill=None
53
+ )
54
+
55
+ # If this is the very first time we are updating that layer's predictor cache, just assign
56
+ if self.predictor_primary_key[layer_idx] is None:
57
+ self.predictor_primary_key[layer_idx] = key_states
58
+ self.predictor_primary_value[layer_idx] = value_states
59
+ else:
60
+ # Otherwise, concatenate along the seq_len dimension (=-2 or =2 depending on your shape).
61
+ self.predictor_primary_key[layer_idx] = torch.cat(
62
+ [self.predictor_primary_key[layer_idx], key_states], dim=2
63
+ )
64
+ self.predictor_primary_value[layer_idx] = torch.cat(
65
+ [self.predictor_primary_value[layer_idx], value_states], dim=2
66
+ )
67
+
68
+ return (
69
+ self.predictor_primary_key[layer_idx],
70
+ self.predictor_primary_value[layer_idx],
71
+ )
72
+
73
+ def update_predictor_importance(
74
+ self,
75
+ key_states: torch.Tensor,
76
+ layer_idx: int,
77
+ ) -> torch.Tensor:
78
+ """
79
+ Append or create the predictor's "importance" key for `layer_idx`.
80
+ """
81
+ self._ensure_list_capacity(
82
+ self.predictor_importance_key, layer_idx, fill=None
83
+ )
84
+
85
+ if self.predictor_importance_key[layer_idx] is None:
86
+ self.predictor_importance_key[layer_idx] = key_states
87
+ else:
88
+ self.predictor_importance_key[layer_idx] = torch.cat(
89
+ [self.predictor_importance_key[layer_idx], key_states], dim=2
90
+ )
91
+ return self.predictor_importance_key[layer_idx]
92
+
93
+ def crop(self, max_length: int):
94
+ super().crop(max_length)
95
+ # Now also crop predictor caches
96
+ for idx in range(len(self.predictor_primary_key)):
97
+ if self.predictor_primary_key[idx] is not None:
98
+ self.predictor_primary_key[idx] = self.predictor_primary_key[idx][..., :max_length, :]
99
+ self.predictor_primary_value[idx] = self.predictor_primary_value[idx][..., :max_length, :]
100
+
101
+ for idx in range(len(self.predictor_importance_key)):
102
+ if self.predictor_importance_key[idx] is not None:
103
+ self.predictor_importance_key[idx] = self.predictor_importance_key[idx][..., :max_length, :]
104
+
105
+ # Remember to adjust self._seen_tokens accordingly
106
+ self._seen_tokens = min(self._seen_tokens, max_length)
107
+
108
+ def batch_split(
109
+ self, full_batch_size: int, split_size: int, num_hidden_layers: int = None
110
+ ) -> List["PredictorDynamicCache"]:
111
+ # Use the base split logic for the standard K/V
112
+ base_splits = super().batch_split(full_batch_size, split_size, num_hidden_layers)
113
+ # `base_splits` is now a list of new DynamicCache objects. But we *actually*
114
+ # want them to be PredictorDynamicCache so we can store the predictor states.
115
+ # Easiest: we can cast and fill them.
116
+ out: List[PredictorDynamicCache] = []
117
+
118
+ for split_i, base_split in enumerate(base_splits):
119
+ # Construct an empty PredictorDynamicCache
120
+ new_cache = PredictorDynamicCache()
121
+ # Copy over the underlying fields from base_split
122
+ new_cache.key_cache = base_split.key_cache
123
+ new_cache.value_cache = base_split.value_cache
124
+ new_cache._seen_tokens = base_split._seen_tokens
125
+
126
+ # Now also slice our predictor fields
127
+ # The slice in batch dim is [i:i+split_size].
128
+ b_start = split_i * split_size
129
+ b_end = min(full_batch_size, b_start + split_size)
130
+
131
+ new_cache.predictor_primary_key = self._slice_list_tensors(
132
+ self.predictor_primary_key, b_start, b_end
133
+ )
134
+ new_cache.predictor_primary_value = self._slice_list_tensors(
135
+ self.predictor_primary_value, b_start, b_end
136
+ )
137
+ new_cache.predictor_importance_key = self._slice_list_tensors(
138
+ self.predictor_importance_key, b_start, b_end
139
+ )
140
+
141
+ out.append(new_cache)
142
+
143
+ return out
144
+
145
+ @classmethod
146
+ def from_batch_splits(cls, splits: List["PredictorDynamicCache"], num_hidden_layers: int = None) -> "PredictorDynamicCache":
147
+ # Let the base class handle the normal K/V merges
148
+ base_merged = DynamicCache.from_batch_splits(splits, num_hidden_layers=num_hidden_layers)
149
+ merged = cls()
150
+ merged.key_cache = base_merged.key_cache
151
+ merged.value_cache = base_merged.value_cache
152
+ merged._seen_tokens = base_merged._seen_tokens
153
+
154
+ # Now unify predictor states by concatenating along batch dim=0
155
+ merged.predictor_primary_key = cls._merge_list_tensors(
156
+ [split.predictor_primary_key for split in splits]
157
+ )
158
+ merged.predictor_primary_value = cls._merge_list_tensors(
159
+ [split.predictor_primary_value for split in splits]
160
+ )
161
+ merged.predictor_importance_key = cls._merge_list_tensors(
162
+ [split.predictor_importance_key for split in splits]
163
+ )
164
+
165
+ return merged
166
+
167
+ def batch_repeat_interleave(self, repeats: int):
168
+ super().batch_repeat_interleave(repeats)
169
+ self.predictor_primary_key = self._repeat_list_tensors(
170
+ self.predictor_primary_key, repeats
171
+ )
172
+ self.predictor_primary_value = self._repeat_list_tensors(
173
+ self.predictor_primary_value, repeats
174
+ )
175
+ self.predictor_importance_key = self._repeat_list_tensors(
176
+ self.predictor_importance_key, repeats
177
+ )
178
+
179
+ def batch_select_indices(self, indices: torch.Tensor):
180
+ super().batch_select_indices(indices)
181
+ self.predictor_primary_key = self._select_list_tensors(
182
+ self.predictor_primary_key, indices
183
+ )
184
+ self.predictor_primary_value = self._select_list_tensors(
185
+ self.predictor_primary_value, indices
186
+ )
187
+ self.predictor_importance_key = self._select_list_tensors(
188
+ self.predictor_importance_key, indices
189
+ )
190
+
191
+ @staticmethod
192
+ def _ensure_list_capacity(lst: list, idx: int, fill=None):
193
+ if len(lst) <= idx:
194
+ lst.extend([fill] * (idx + 1 - len(lst)))
195
+
196
+ @staticmethod
197
+ def _slice_list_tensors(
198
+ tensor_list: List[Optional[torch.Tensor]], start: int, end: int
199
+ ) -> List[Optional[torch.Tensor]]:
200
+ out = []
201
+ for t in tensor_list:
202
+ if t is None:
203
+ out.append(None)
204
+ else:
205
+ out.append(t[start:end, ...])
206
+ return out
207
+
208
+ @classmethod
209
+ def _merge_list_tensors(
210
+ cls, list_of_lists: List[List[Optional[torch.Tensor]]]
211
+ ) -> List[Optional[torch.Tensor]]:
212
+ # If no splits, return empty
213
+ if not list_of_lists:
214
+ return []
215
+
216
+ # Number of layers is length of the sub-list from the first split
217
+ max_len = len(list_of_lists[0])
218
+ merged = [None] * max_len
219
+
220
+ for layer_idx in range(max_len):
221
+ # collect that layer_idx from each split
222
+ chunk_tensors = []
223
+ for split in list_of_lists:
224
+ t = split[layer_idx] if layer_idx < len(split) else None
225
+ if t is not None:
226
+ chunk_tensors.append(t)
227
+ if len(chunk_tensors) == 0:
228
+ merged[layer_idx] = None
229
+ else:
230
+ merged[layer_idx] = torch.cat(chunk_tensors, dim=0)
231
+ return merged
232
+
233
+ @staticmethod
234
+ def _repeat_list_tensors(
235
+ tensor_list: List[Optional[torch.Tensor]], repeats: int
236
+ ) -> List[Optional[torch.Tensor]]:
237
+ out = []
238
+ for t in tensor_list:
239
+ if t is None:
240
+ out.append(None)
241
+ else:
242
+ out.append(t.repeat_interleave(repeats, dim=0))
243
+ return out
244
+
245
+ @staticmethod
246
+ def _select_list_tensors(
247
+ tensor_list: List[Optional[torch.Tensor]], indices: torch.Tensor
248
+ ) -> List[Optional[torch.Tensor]]:
249
+ out = []
250
+ for t in tensor_list:
251
+ if t is None:
252
+ out.append(None)
253
+ else:
254
+ out.append(t.index_select(0, indices))
255
+ return out
256
+
257
+
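
As a small illustration of the append semantics of `update_predictor_primary` above, the following standalone sketch (dummy shapes, assuming `PredictorDynamicCache` from this file) shows the predictor K/V cache growing along the sequence dimension:

    import torch

    cache = PredictorDynamicCache()
    k0 = torch.randn(1, 8, 5, 16)  # [batch, heads, prompt_len, head_dim]
    v0 = torch.randn(1, 8, 5, 16)
    k, v = cache.update_predictor_primary(k0, v0, layer_idx=0)  # k.shape[2] == 5
    k1 = torch.randn(1, 8, 1, 16)  # one decode step
    v1 = torch.randn(1, 8, 1, 16)
    k, v = cache.update_predictor_primary(k1, v1, layer_idx=0)  # k.shape[2] == 6 (concatenated on dim=2)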
258
+ class TokenImportancePredictorAttentive(nn.Module):
259
+ def __init__(self, config, pred_hid_size, num_heads, num_hidden_layers, dDash, intdim, \
260
+ attn_reduce_factor, dropout=0.1):
261
+ """
262
+ Optimized Token Importance Predictor with parallel Q-K projections and simplified mapping.
263
+
264
+ Args:
265
+ config: Configuration object containing model parameters.
266
+ pred_hid_size (int): Hidden size for the predictor's attention layer.
267
+ num_heads (int): Number of attention heads.
268
+ num_hidden_layers (int): Number of transformer layers to predict.
269
+ dropout (float): Dropout probability.
270
+ q_downscale (int): Factor to downscale the Q dimension for efficiency.
271
+ intermediate_dim (int): Intermediate dimension for non-linear transformations in projections.
272
+ """
273
+ super().__init__()
274
+ self.config = config
275
+ self.hidden_size = pred_hid_size
276
+ self.num_heads = num_heads
277
+ self.num_hidden_layers = num_hidden_layers
278
+ self.dropout = dropout
279
+ self.head_dim = pred_hid_size // (num_heads * 4) # Predictor head dimension is not the same as the model head dimension.
280
+ self.rope_theta = config.rope_theta
281
+ self.dDash = dDash
282
+ self.intermediate_dim = intdim
283
+ self.attn_reduce_factor = attn_reduce_factor
284
+ self.max_position_embeddings = config.max_position_embeddings
285
+ self.flash_attn = False
286
+ assert pred_hid_size % (num_heads * 4) == 0, "pred_hid_size must be divisible by num_heads * 4."
287
+
288
+ # Reduce the hidden size for attention computations
289
+ self.hidden_size_reduced = self.hidden_size // self.attn_reduce_factor # For example, reduce to 1/4th
290
+ assert self.hidden_size_reduced % self.num_heads == 0, "Reduced hidden size must be divisible by num_heads"
291
+ self.attn_head_dim = self.hidden_size_reduced // self.num_heads
292
+
293
+ # Input projection to reduce hidden size
294
+ self.input_proj = nn.Linear(self.hidden_size, self.hidden_size_reduced, bias=False)
295
+
296
+ # Query, Key, Value projections for attention
297
+ self.q_proj_attn = nn.Linear(self.hidden_size_reduced, self.hidden_size_reduced, bias=False)
298
+ self.k_proj_attn = nn.Linear(self.hidden_size_reduced, self.hidden_size_reduced, bias=False)
299
+ self.v_proj_attn = nn.Linear(self.hidden_size_reduced, self.hidden_size_reduced, bias=False)
300
+ # Output projection to restore hidden size
301
+ # self.o_proj_attn = nn.Linear(self.hidden_size_reduced, self.hidden_size_reduced, bias=False)
302
+ self.attn_dropout = nn.Dropout(self.dropout)
303
+
304
+ # LayerNorm and Feed-forward network
305
+ self.norm1 = nn.LayerNorm(self.hidden_size_reduced)
306
+ self.norm2 = nn.LayerNorm(self.hidden_size)
307
+
308
+ self.ffn_hidden_size = 2 * self.hidden_size_reduced # Typical FFN hidden size
309
+ self.ffn = nn.Sequential(
310
+ nn.Linear(self.hidden_size_reduced, self.ffn_hidden_size),
311
+ nn.GELU(),
312
+ nn.Linear(self.ffn_hidden_size, self.hidden_size),
313
+ nn.Dropout(self.dropout)
314
+ )
315
+ # Add extra LayerNorm for the importance branch when not using the old design.
316
+ self.norm_importance = nn.LayerNorm(self.hidden_size)
317
+
318
+ # Define Q and K projection layers for all layers in parallel, with a non-linearity.
319
+ # Output shape: [B, L, N * H * D']
320
+ self.q_proj_importance = nn.Sequential(
321
+ nn.Linear(pred_hid_size, self.intermediate_dim, bias=False),
322
+ nn.SiLU(),
323
+ nn.Linear(self.intermediate_dim, num_hidden_layers * num_heads * self.dDash, bias=False)
324
+ )
325
+ self.k_proj_importance = nn.Sequential(
326
+ nn.Linear(pred_hid_size, self.intermediate_dim, bias=False),
327
+ nn.SiLU(),
328
+ nn.Linear(self.intermediate_dim, num_hidden_layers * num_heads * self.dDash, bias=False)
329
+ )
330
+
331
+ # Initialize rotary positional embeddings
332
+ self._init_rope()
333
+ self._initialize_weights()
334
+ self.device = None
335
+
336
+ def _initialize_weights(self):
337
+ for name, module in self.named_modules():
338
+ if isinstance(module, nn.Linear):
339
+ nn.init.xavier_uniform_(module.weight) # Xavier initialization for linear layers
340
+ if module.bias is not None:
341
+ nn.init.constant_(module.bias, 0)
342
+ elif isinstance(module, nn.LayerNorm):
343
+ nn.init.constant_(module.weight, 1.0)
344
+ nn.init.constant_(module.bias, 0.0)
345
+ elif isinstance(module, nn.MultiheadAttention):
346
+ # Initialize in_proj_weight
347
+ nn.init.xavier_uniform_(module.in_proj_weight)
348
+ if module.in_proj_bias is not None:
349
+ nn.init.constant_(module.in_proj_bias, 0)
350
+
351
+ # Initialize out_proj
352
+ nn.init.xavier_uniform_(module.out_proj.weight)
353
+ if module.out_proj.bias is not None:
354
+ nn.init.constant_(module.out_proj.bias, 0)
355
+
356
+ def _init_rope(self):
357
+
358
+ # Pass a deep copy of self.config with head_dim overridden, so only these rotary embeddings see the modified value.
359
+ config_copy = copy.deepcopy(self.config)
360
+ config_copy.rope_scaling = {
361
+ "factor": 32.0,
362
+ "high_freq_factor": 4.0,
363
+ "low_freq_factor": 1.0,
364
+ "original_max_position_embeddings": 8192,
365
+ "rope_type": "llama3"
366
+ }
367
+ config_copy.head_dim = self.attn_head_dim
368
+
369
+ # Rotary embedding for attention layer
370
+ self.rotary_emb_attn = LlamaRotaryEmbedding(
371
+ config_copy
372
+ )
373
+
374
+ config_copy.head_dim = self.dDash
375
+ # Rotary embedding for importance projection
376
+ self.rotary_emb_importance = LlamaRotaryEmbedding(
377
+ config_copy
378
+ )
379
+
380
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False, layer_idx=None):
381
+ """
382
+ Forward pass for the Optimized Token Importance Predictor.
383
+
384
+ Args:
385
+ hidden_states (torch.Tensor): Input tensor of shape [B, L, HQ].
386
+ attention_mask (torch.Tensor, optional): Attention mask of shape [B, 1, 1, L] or [B, 1, L, L].
387
+ position_ids (torch.Tensor, optional): Position IDs.
388
+ past_key_value (tuple, optional): Past key and value states.
389
+ use_cache (bool, optional): Whether to use cache.
390
+
391
+ Returns:
392
+ Tuple[torch.Tensor, torch.Tensor]: q_importance and k_importance projections used to score token importance.
393
+ """
394
+ layer_idx = 0 # Guaranteed to be 0, as we only have one predictor!
395
+
396
+ # Set device if not already set
397
+ if self.device != hidden_states.device:
398
+ self.device = hidden_states.device
399
+ self.to(self.device)
400
+
401
+ B, L, E = hidden_states.size()
402
+
403
+ # Reduce hidden size
404
+ hidden_states = hidden_states.to(self.input_proj.weight.dtype)
405
+ hidden_states_reduced = self.input_proj(hidden_states) # [B, L, hidden_size_reduced]
406
+ # Compute q, k, v for attention
407
+ q = self.q_proj_attn(hidden_states_reduced) # [B, L, hidden_size_reduced]
408
+ k = self.k_proj_attn(hidden_states_reduced) # [B, L, hidden_size_reduced]
409
+ v = self.v_proj_attn(hidden_states_reduced) # [B, L, hidden_size_reduced]
410
+ # Reshape q, k, v to [B, num_heads, L, attn_head_dim]
411
+ q = q.view(B, L, self.num_heads, self.attn_head_dim).transpose(1, 2) # [B, num_heads, L, attn_head_dim]
412
+ k = k.view(B, L, self.num_heads, self.attn_head_dim).transpose(1, 2) # [B, num_heads, L, attn_head_dim]
413
+ v = v.view(B, L, self.num_heads, self.attn_head_dim).transpose(1, 2) # [B, num_heads, L, attn_head_dim]
414
+ if (past_key_value is not None
415
+ and layer_idx < len(past_key_value.predictor_primary_key)
416
+ and past_key_value.predictor_primary_key[layer_idx] is not None):
417
+ offset = past_key_value.predictor_primary_key[layer_idx].shape[2] # old_k.shape[2]
418
+ else:
419
+ offset = 0
420
+
421
+ # total seq length for new + old
422
+ kv_seq_len = offset + L
423
+
424
+ # Step 2: build position_ids for just the new chunk [offset..offset+L-1]
425
+ if position_ids is None:
426
+ # shape [B, L], e.g. [0..(offset+L-1)]
427
+ position_ids = torch.arange(offset, offset + L, dtype=torch.long, device=self.device)
428
+ position_ids = position_ids.unsqueeze(0).expand(B, L)
429
+
430
+ # Step 3: apply rotary to just the new chunk k,v with the correct offset
431
+ cos, sin = self.rotary_emb_attn(v, position_ids)
432
+ q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
433
+
434
+ # Step 4: ask the cache to append them. Then re‐assign k, v to the full cat
435
+ if use_cache and past_key_value is not None:
436
+ k, v = past_key_value.update_predictor_primary(k.detach(), v.detach(), layer_idx)
437
+ kv_seq_len = k.size(2) # now includes old + new
438
+
439
+ attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
440
+ attn_output = attn_output.to(q.dtype)
441
+ attn_output = attn_output.transpose(1, 2).contiguous().view(B, L, self.hidden_size_reduced)
442
+ attn_output = self.norm1(attn_output)
443
+ ffn_output = self.ffn(attn_output)
444
+ # Temporary measure until the old predictor is fully deprecated
445
+ hidden_states = self.norm2(hidden_states + ffn_output)
446
+
447
+ B, L, E = hidden_states.size()
448
+ # Importance projections
449
+ H = self.num_heads
450
+ N = self.num_hidden_layers
451
+
452
+ hidden_states_for_importance = self.norm_importance(hidden_states)
453
+ q_importance = self.q_proj_importance(hidden_states_for_importance)
454
+ k_importance = self.k_proj_importance(hidden_states_for_importance)
455
+
456
+ # Reshape and permute to [B, N, H, L, D']
457
+ q_importance = q_importance.view(B, L, N, H, self.dDash).permute(0, 2, 3, 1, 4).contiguous() # [B, N, H, L, D']
458
+ k_importance = k_importance.view(B, L, N, H, self.dDash).permute(0, 2, 3, 1, 4).contiguous() # [B, N, H, L, D']
459
+
460
+ # Flatten N and H for efficient computation
461
+ q_importance = q_importance.view(B * N * H, L, self.dDash) # [BNH, L, D']
462
+ k_importance = k_importance.view(B * N * H, L, self.dDash) # [BNH, L, D']
463
+
464
+ # Apply rotary positional embeddings
465
+ cos, sin = self.rotary_emb_importance(k_importance, position_ids)
466
+ q_importance, k_importance = apply_rotary_pos_emb(q_importance, k_importance, cos, sin, position_ids)
467
+
468
+ if use_cache and past_key_value is not None:
469
+ k_importance = past_key_value.update_predictor_importance(k_importance.detach(), layer_idx)
470
+
471
+ k_importance = k_importance.view(B * H, N, -1, self.dDash) # [B*H, N, L, D']
472
+ q_importance = q_importance.view(B * H, N, -1, self.dDash) # [B*H, N, L, D']
473
+ return q_importance, k_importance
474
+
475
+
476
+
477
+ class HeadImportancePredictor(nn.Module):
478
+ def __init__(self, config, pred_hid_size, num_heads, num_hidden_layers, dDash, intdim, \
479
+ attn_reduce_factor, dropout=0.1):
480
+ """
481
+ Optimized Head Importance Predictor with parallel Q-K projections and simplified mapping.
482
+
483
+ Args:
484
+ config: Configuration object containing model parameters.
485
+ pred_hid_size (int): Hidden size for the predictor's attention layer.
486
+ num_heads (int): Number of attention heads.
487
+ num_hidden_layers (int): Number of transformer layers to predict.
488
+ dropout (float): Dropout probability.
489
+ q_downscale (int): Factor to downscale the Q dimension for efficiency.
490
+ intermediate_dim (int): Intermediate dimension for non-linear transformations in projections.
491
+ """
492
+ super().__init__()
493
+ self.is_head_predictor = None
494
+ self.config = config
495
+ self.hidden_size = pred_hid_size
496
+ self.num_heads = num_heads
497
+ self.num_hidden_layers = num_hidden_layers
498
+ self.dropout = dropout
499
+ self.head_dim = pred_hid_size // (num_heads * 4)
500
+ self.rope_theta = config.rope_theta
501
+ self.dDash = dDash
502
+ self.intermediate_dim = intdim
503
+ self.attn_reduce_factor = attn_reduce_factor
504
+ self.max_position_embeddings = config.max_position_embeddings
505
+ self.flash_attn = False
506
+
507
+ # Reduce the hidden size for attention computations
508
+ self.hidden_size_reduced = self.hidden_size // self.attn_reduce_factor # For example, reduce to 1/4th
509
+ assert self.hidden_size_reduced % self.num_heads == 0, "Reduced hidden size must be divisible by num_heads"
510
+ self.attn_head_dim = self.hidden_size_reduced // self.num_heads
511
+
512
+ # Input projection to reduce hidden size
513
+ self.input_proj = nn.Linear(self.hidden_size, self.hidden_size_reduced, bias=False)
514
+
515
+ # Query, Key, Value projections for attention
516
+ self.q_proj_attn = nn.Linear(self.hidden_size_reduced, self.hidden_size_reduced, bias=False)
517
+ self.k_proj_attn = nn.Linear(self.hidden_size_reduced, self.hidden_size_reduced, bias=False)
518
+ self.v_proj_attn = nn.Linear(self.hidden_size_reduced, self.hidden_size_reduced, bias=False)
519
+ # Output projection to restore hidden size
520
+ # self.o_proj_attn = nn.Linear(self.hidden_size_reduced, self.hidden_size_reduced, bias=False)
521
+ self.attn_dropout = nn.Dropout(self.dropout)
522
+
523
+ # LayerNorm and Feed-forward network
524
+ self.norm1 = nn.LayerNorm(self.hidden_size_reduced)
525
+ self.norm2 = nn.LayerNorm(self.hidden_size)
526
+
527
+ self.ffn_hidden_size = 4 * self.hidden_size_reduced # Typical FFN hidden size
528
+ self.ffn = nn.Sequential(
529
+ nn.Linear(self.hidden_size_reduced, self.ffn_hidden_size),
530
+ nn.GELU(),
531
+ nn.Linear(self.ffn_hidden_size, self.num_heads * self.num_hidden_layers),
532
+ )
533
+
534
+ # Initialize rotary positional embeddings
535
+ self._init_rope()
536
+ self._initialize_weights()
537
+ self.device = None
538
+
539
+ def _initialize_weights(self):
540
+ for name, module in self.named_modules():
541
+ if isinstance(module, nn.Linear):
542
+ nn.init.xavier_uniform_(module.weight) # Xavier initialization for linear layers
543
+ if module.bias is not None:
544
+ nn.init.constant_(module.bias, 0)
545
+ elif isinstance(module, nn.LayerNorm):
546
+ nn.init.constant_(module.weight, 1.0)
547
+ nn.init.constant_(module.bias, 0.0)
548
+ elif isinstance(module, nn.MultiheadAttention):
549
+ # Initialize in_proj_weight
550
+ nn.init.xavier_uniform_(module.in_proj_weight)
551
+ if module.in_proj_bias is not None:
552
+ nn.init.constant_(module.in_proj_bias, 0)
553
+
554
+ # Initialize out_proj
555
+ nn.init.xavier_uniform_(module.out_proj.weight)
556
+ if module.out_proj.bias is not None:
557
+ nn.init.constant_(module.out_proj.bias, 0)
558
+
559
+ def _init_rope(self):
560
+ config_copy = copy.deepcopy(self.config)
561
+ config_copy.head_dim = self.attn_head_dim
562
+ # Rotary embedding for attention layer
563
+ self.rotary_emb_attn = LlamaRotaryEmbedding(
564
+ config_copy
565
+ )
566
+ # Rotary embedding for importance projection
567
+ self.rotary_emb_importance = LlamaRotaryEmbedding(
568
+ config_copy
569
+ )
570
+
571
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
572
+ """
573
+ Forward pass for the Optimized Token Importance Predictor.
574
+
575
+ Args:
576
+ hidden_states (torch.Tensor): Input tensor of shape [B, L, HQ].
577
+ attention_mask (torch.Tensor, optional): Attention mask of shape [B, 1, 1, L] or [B, 1, L, L].
578
+ position_ids (torch.Tensor, optional): Position IDs.
579
+ past_key_value (tuple, optional): Past key and value states.
580
+ use_cache (bool, optional): Whether to use cache.
581
+
582
+ Returns:
583
+ Tuple[torch.Tensor, dict]: Head importance logits of shape [B, L, N * H] and the updated cache.
584
+ """
585
+ # Set device if not already set
586
+ if self.device != hidden_states.device:
587
+ self.device = hidden_states.device
588
+ self.to(self.device)
589
+
590
+ B, L, E = hidden_states.size()
591
+ if past_key_value is None:
592
+ past_key_value = {}
593
+ # if L == 1:
594
+ # import pdb; pdb.set_trace()
595
+ past_primary = past_key_value.get('primary', None)
596
+ # Reduce hidden size
597
+ hidden_states = hidden_states.to(self.input_proj.weight.dtype)
598
+ hidden_states_reduced = self.input_proj(hidden_states) # [B, L, hidden_size_reduced]
599
+ # Compute q, k, v for attention
600
+ q = self.q_proj_attn(hidden_states_reduced) # [B, L, hidden_size_reduced]
601
+ k = self.k_proj_attn(hidden_states_reduced) # [B, L, hidden_size_reduced]
602
+ v = self.v_proj_attn(hidden_states_reduced) # [B, L, hidden_size_reduced]
603
+ # Reshape q, k, v to [B, num_heads, L, attn_head_dim]
604
+ q = q.view(B, L, self.num_heads, self.attn_head_dim).transpose(1, 2) # [B, num_heads, L, attn_head_dim]
605
+ k = k.view(B, L, self.num_heads, self.attn_head_dim).transpose(1, 2) # [B, num_heads, L, attn_head_dim]
606
+ v = v.view(B, L, self.num_heads, self.attn_head_dim).transpose(1, 2) # [B, num_heads, L, attn_head_dim]
607
+ # Compute kv_seq_len before concatenation
608
+ if past_primary is not None:
609
+ past_L = past_primary[0].shape[2]
610
+ kv_seq_len = past_L + L
611
+ else:
612
+ kv_seq_len = L
613
+
614
+ # Apply rotary positional embeddings based on kv_seq_len
615
+ cos, sin = self.rotary_emb_attn(v, position_ids)
616
+ if position_ids is None:
617
+ position_ids = torch.arange(kv_seq_len, dtype=torch.long, device=self.device)
618
+ position_ids = position_ids.unsqueeze(0).expand(B, kv_seq_len)
619
+
620
+ if past_primary is not None:
621
+ # Concatenate past k and v
622
+ k = torch.cat([past_primary[0], k], dim=2) # [B, num_heads, past_L + L, attn_head_dim]
623
+ v = torch.cat([past_primary[1], v], dim=2) # [B, num_heads, past_L + L, attn_head_dim]
624
+
625
+ # Apply rotary embeddings after concatenation
626
+ q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
627
+
628
+ # Update cache if use_cache is True
629
+ if use_cache:
630
+ past_key_value['primary'] = (k.detach(), v.detach())
631
+
632
+ # if self.flash_attn:
633
+ # sm_scale = 1.0 / math.sqrt(self.attn_head_dim)
634
+ # attn_output = attention(q.contiguous().to(torch.float16), k.contiguous().to(torch.float16), v.contiguous().to(torch.float16), True, sm_scale).to(q.dtype)
635
+ # else:
636
+ # attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
637
+ attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
638
+ attn_output = attn_output.to(q.dtype)
639
+ attn_output = attn_output.transpose(1, 2).contiguous().view(B, L, self.hidden_size_reduced)
640
+ attn_output = self.norm1(attn_output)
641
+ head_importances = self.ffn(attn_output)
642
+ return head_importances, past_key_value
643
+
644
+ def calculate_hit_metrics(estimated_importance: torch.Tensor,
645
+ true_importance: torch.Tensor,
646
+ top_k_ratio: float = 0.5) -> Tuple[float, float, float]:
647
+ """
648
+ Calculate hit accuracy, mean, and max rank correlation between estimated and true importance tensors.
649
+ We compute metrics along the last dimension of the input tensors.
650
+
651
+ Shapes:
652
+ - 4D token-importance: [B, H, L, L]. We slice the last query (index -1) => [B, H, L].
653
+ - 3D head-importance: [B, L, H]. We use all of it as-is => [B, L, H].
654
+
655
+ Args:
656
+ estimated_importance (torch.Tensor): [B, H, L, L] or [B, L, H]
657
+ true_importance (torch.Tensor): [B, H, L, L] or [B, L, H]
658
+ top_k_ratio (float): Fraction of top-k elements to consider for hit accuracy (default=0.5).
659
+
660
+ Returns:
661
+ (hit_accuracy, mean_corr, max_corr):
662
+ hit_accuracy (float): Intersection ratio of top-k sets (0..1).
663
+ mean_corr (float): Average Spearman rank correlation over all [B, ...].
664
+ max_corr (float): Maximum Spearman rank correlation among all [B, ...].
665
+ """
666
+
667
+ # 1) Standardize shapes so the last dimension is what we rank over.
668
+ if estimated_importance.dim() == 4:
669
+ # Shape is [B, H, L, L] => slice to keep only the last query => [B, H, L]
670
+ estimated_importance = estimated_importance[:, :, -1, :]
671
+ true_importance = true_importance[:, :, -1, :]
672
+ # after slicing: [B, H, L]
673
+ # For intersection denominator => top_k * B * H
674
+ denom_for_hits = estimated_importance.size(0) * estimated_importance.size(1)
675
+ elif estimated_importance.dim() == 3:
676
+ # Shape is [B, L, H], the last dimension is H
677
+ # For intersection denominator => top_k * B * L
678
+ denom_for_hits = estimated_importance.size(0) * estimated_importance.size(1)
679
+ else:
680
+ raise ValueError("Tensors must be either 4D [B,H,L,L] or 3D [B,L,H].")
681
+
682
+ # 2) Compute Spearman rank correlation along the last dimension.
683
+ # Sort indices in descending order => get 'ranks' for correlation.
684
+ _, sorted_esti = torch.sort(estimated_importance, dim=-1, descending=True)
685
+ _, sorted_true = torch.sort(true_importance, dim=-1, descending=True)
686
+
687
+ # Spearman's rho = 1 - 6 * sum(d^2) / [n*(n^2 - 1)]
688
+ n = sorted_esti.shape[-1]
689
+ d = sorted_esti.float() - sorted_true.float()
690
+ d_squared = d ** 2
691
+ sum_d_squared = d_squared.sum(dim=-1)
692
+ rank_corr = 1 - (6 * sum_d_squared) / (n * (n**2 - 1)) # shape: [B,H] or [B,L]
693
+
694
+ mean_corr = rank_corr.mean().item()
695
+ max_corr = rank_corr.max().item()
696
+
697
+ # 3) Compute top-k hit accuracy along the last dimension.
698
+ top_k = max(1, int(n * top_k_ratio))
699
+ _, top_esti_indices = torch.topk(estimated_importance, top_k, dim=-1)
700
+ _, top_true_indices = torch.topk(true_importance, top_k, dim=-1)
701
+
702
+ # top_esti_indices => [B,H,top_k] or [B,L,top_k]
703
+ # top_true_indices => [B,H,top_k] or [B,L,top_k]
704
+ # matches => [B,H,top_k,top_k] or [B,L,top_k,top_k]
705
+ matches = (top_esti_indices.unsqueeze(-1) == top_true_indices.unsqueeze(-2))
706
+ intersection = matches.any(dim=-1).sum(dim=-1) # => [B,H] or [B,L]
707
+
708
+ # Each [B,H] or [B,L] element can have at most 'top_k' matches, so total is top_k * denom_for_hits.
709
+ total_possible = top_k * denom_for_hits
710
+ hit_accuracy = intersection.sum().item() / total_possible # => 0..1
711
+
712
+ return hit_accuracy, mean_corr, max_corr
713
+
714
+
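
A quick usage sketch for `calculate_hit_metrics` above, with random tensors purely for illustration (the 4D path slices the last query position before ranking):

    import torch

    est = torch.rand(2, 4, 16, 16)   # [B, H, L, L] predictor scores
    ref = torch.rand(2, 4, 16, 16)   # [B, H, L, L] true attention weights
    hit_acc, mean_corr, max_corr = calculate_hit_metrics(est, ref, top_k_ratio=0.5)
    print(f"hit@50%: {hit_acc:.2f}, mean rank corr: {mean_corr:.2f}, max rank corr: {max_corr:.2f}")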
715
+ def threshold_to_mask(unadj_importance_mask, perhead_thresholds, min_sparse_index, bsz, q_len, key_len):
716
+ """
717
+ Create a mask tensor based on per-head thresholds, setting values below the threshold to -inf.
718
+
719
+ Args:
720
+ - unadj_importance_mask: torch.Tensor of shape [B, H, Lq, Lk].
721
+ - perhead_thresholds: torch.Tensor of shape [H], per-head thresholds.
722
+ - min_sparse_index: Minimum index for sparsity; values below this index will not be masked.
723
+ - bsz: Batch size.
724
+ - q_len: Query length (Lq).
725
+ - key_len: Key length (Lk).
726
+
727
+ Returns:
728
+ - mask_tensor: torch.Tensor of shape [B, H, Lq, Lk], with values below threshold as -inf.
729
+ """
730
+ # Ensure perhead_thresholds is in the correct shape for broadcasting
731
+ thresholds_broadcast = perhead_thresholds.view(1, -1, 1, 1) # [1, H, 1, 1]
732
+
733
+ # Compare unadj_importance_mask with thresholds to create a mask
734
+ mask_tensor = torch.where(
735
+ unadj_importance_mask >= thresholds_broadcast,
736
+ torch.zeros_like(unadj_importance_mask),
737
+ torch.full_like(unadj_importance_mask, float('-inf'))
738
+ ) # [B, H, Lq, Lk]
739
+
740
+ # Ensure mask_tensor has mask_tensor[:, :, :, :min_sparse_index] = 0
741
+ mask_tensor[:, :, :, :min_sparse_index] = 0.0
742
+
743
+ return mask_tensor
744
+
745
+ class SlidingWindowCache:
746
+ def __init__(self, max_seq_len, sliding_window, device):
747
+ self.sliding_window = sliding_window
748
+ self.device = device
749
+ if sliding_window is None:
750
+ self.max_seq_len = 0
751
+ self.window = None
752
+ else:
753
+ self.max_seq_len = max_seq_len
754
+ self.window = self._create_window(self.max_seq_len)
755
+
756
+ def _create_window(self, seq_len):
757
+ idx = torch.arange(seq_len, device=self.device)
758
+ query = idx.unsqueeze(1) # [seq_len, 1]
759
+ key = idx.unsqueeze(0) # [1, seq_len]
760
+ win = (key >= (query - self.sliding_window + 1)) & (key <= query)
761
+ return win.unsqueeze(0).unsqueeze(0) # [1,1,seq_len,seq_len]
762
+
763
+ def get_window(self, q_len, key_len):
764
+ if self.sliding_window is None:
765
+ return None
766
+ req = max(q_len, key_len)
767
+ if req > self.max_seq_len:
768
+ self.max_seq_len = req
769
+ self.window = self._create_window(self.max_seq_len)
770
+ return self.window[:, :, :q_len, :key_len]
771
+
772
+ def enforce_sliding_window(mask_tensor, window):
773
+ if window is None:
774
+ return mask_tensor
775
+ return mask_tensor.masked_fill(window, 0.0)
776
+
777
+
778
+ def sorted_index_to_mask(
779
+ sorted_indices,
780
+ attention_mask,
781
+ min_sparse_index,
782
+ bsz,
783
+ q_len,
784
+ key_len,
785
+ sparse_aggression,
786
+ sliding_window=None
787
+ ):
788
+ """
789
+ sorted_indices: [B, H, q_len, key_len]
790
+ attention_mask: [1, 1, q_len, key_len] (True = keep, False = mask out, or vice versa)
791
+ min_sparse_index: guaranteed front region to keep
792
+ sliding_window: guaranteed trailing region (for each query) to keep
793
+ sparse_aggression: float in [0,1], fraction of keys to drop or keep
794
+ """
795
+ device = sorted_indices.device
796
+ dtype = sorted_indices.dtype
797
+
798
+ # Step 1: Compute base K
799
+ if q_len == 1:
800
+ query_positions = torch.arange(q_len, device=device).view(1, 1, q_len, 1).float()
801
+ query_positions[0] = key_len + 1
802
+ else:
803
+ query_positions = torch.arange(q_len, device=device).view(1, 1, q_len, 1).float() + 1.0
804
+ K_original = torch.ceil(query_positions * sparse_aggression).long() # [1,1,q_len,1]
805
+ K_original = torch.clamp(K_original, max=key_len)
806
+
807
+ # Step 1b: Incorporate guaranteed region
808
+ guaranteed = min_sparse_index
809
+ if sliding_window is not None:
810
+ guaranteed += sliding_window
811
+ # Subtract guaranteed from the original K
812
+ K_adjusted = K_original - guaranteed
813
+ # Ensure K_adjusted is at least 0
814
+ K_adjusted = torch.clamp(K_adjusted, min=0, max=key_len)
815
+
816
+ # Step 2: Expand attention_mask to [B,H,q_len,key_len]
817
+ attention_mask_expanded = attention_mask.expand(bsz, -1, -1, -1)
818
+ attention_mask_expanded = attention_mask_expanded.expand(-1, sorted_indices.size(1), -1, -1)
819
+ # Convert True -> 1, False -> 0
820
+ attention_mask_expanded = (~attention_mask_expanded.bool()).int()
821
+
822
+ # Step 3: Gather (reorder) mask by sorted_indices
823
+ gathered_mask = torch.gather(attention_mask_expanded, dim=-1, index=sorted_indices)
824
+
825
+ # Step 4: cumsum along sorted dimension
826
+ gathered_mask_float = gathered_mask.float()
827
+ cum_sum = torch.cumsum(gathered_mask_float, dim=-1) # [B,H,q_len,key_len]
828
+
829
+ # Step 5: Compare cumsum <= K_adjusted
830
+ # Expand K_adjusted to [B,H,q_len,key_len] for broadcast
831
+ K_broadcast = K_adjusted.view(1, 1, q_len, 1).expand_as(cum_sum)
832
+ selected_mask = (cum_sum <= K_broadcast)
833
+
834
+ # Step 6: Prepare final mask_tensor with -inf by default
835
+ mask_tensor = torch.full_like(attention_mask_expanded.float(), float('-inf'))
836
+
837
+ # Step 7: Scatter 0 where selected, -inf otherwise
838
+ scatter_values = torch.zeros_like(gathered_mask_float)
839
+ scatter_values = scatter_values.masked_fill(~selected_mask, float('-inf'))
840
+ mask_tensor.scatter_(-1, sorted_indices, scatter_values)
841
+
842
+ # Step 8: Force the guaranteed front region unmasked
843
+ mask_tensor[:, :, :, :min_sparse_index] = 0.0
844
+
845
+ # We do NOT forcibly unmask the trailing `sliding_window` here,
846
+ # because we typically do it with a separate function that
847
+ # ensures the last `sliding_window` positions are unmasked for each query.
848
+ # Replace with self.sliding_window where referenced
849
+ # Where not referenced, reduce budget in calculation.
850
+
851
+ return mask_tensor
852
+
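
To make the per-query key budget in `sorted_index_to_mask` concrete, a small worked example of `K_original` and `K_adjusted` (illustrative numbers only):

    import math

    sparse_aggression = 0.6                 # keep roughly 60% of keys per query
    min_sparse_index, sliding_window = 4, 8
    guaranteed = min_sparse_index + sliding_window
    for pos in (16, 64, 256):               # 1-indexed query positions
        k_original = math.ceil(pos * sparse_aggression)
        k_adjusted = max(k_original - guaranteed, 0)
        print(pos, k_original, k_adjusted)  # 16 -> (10, 0), 64 -> (39, 27), 256 -> (154, 142)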
853
+ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
854
+ """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
855
+
856
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0, config=None):
857
+ self.scaling_factor = scaling_factor
858
+ super().__init__(config)
859
+
860
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
861
+ self.max_seq_len_cached = seq_len
862
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
863
+ t = t / self.scaling_factor
864
+
865
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
866
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
867
+ emb = torch.cat((freqs, freqs), dim=-1)
868
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
869
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
870
+
871
+
872
+ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
873
+ """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
874
+
875
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0, config=None):
876
+ self.scaling_factor = scaling_factor
877
+ super().__init__(config)
878
+
879
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
880
+ self.max_seq_len_cached = seq_len
881
+
882
+ if seq_len > self.max_position_embeddings:
883
+ base = self.base * (
884
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
885
+ ) ** (self.dim / (self.dim - 2))
886
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
887
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
888
+
889
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
890
+
891
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
892
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
893
+ emb = torch.cat((freqs, freqs), dim=-1)
894
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
895
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
896
+
897
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
898
+ """
899
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
900
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
901
+ """
902
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
903
+ if n_rep == 1:
904
+ return hidden_states
905
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
906
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
907
+
908
+
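
A shape-only sketch of `repeat_kv` above for a GQA layout (dummy tensor; 8 KV heads expanded 4x to 32 query heads):

    import torch

    kv = torch.randn(1, 8, 128, 64)    # [batch, num_key_value_heads, seq_len, head_dim]
    expanded = repeat_kv(kv, n_rep=4)  # each KV head repeated 4 times
    print(expanded.shape)              # torch.Size([1, 32, 128, 64])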
909
+ class LlamaAttentionExperimental(nn.Module):
910
+ def __init__(self, config: LlamaConfig, producer=None, layer_idx=0):
911
+ super().__init__()
912
+ self.config = config
913
+ self.hidden_size = config.hidden_size
914
+ self.num_hidden_layers = config.num_hidden_layers
915
+ self.num_heads = config.num_attention_heads
916
+ self.head_dim = self.hidden_size // self.num_heads
917
+ self.num_key_value_heads = config.num_key_value_heads
918
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
919
+ self.max_position_embeddings = config.max_position_embeddings
920
+ self.rope_theta = config.rope_theta
921
+ self.inference_mode = False
922
+ self.producer = producer
923
+ self.layer_idx = layer_idx
924
+ self.token_sparse_method = None
925
+ self.sparse_aggression = None
926
+ self.stream_llm_start_size = None
927
+ self.dDash = None
928
+ self.intdim = None
929
+ self.attn_reduce_factor = None
930
+ self.head_attn_reduce_factor = None
931
+ self.effective_sparsity = None
932
+ self.min_sparse_index = None
933
+ self.pred_hid_size = self.hidden_size
934
+ self.num_tok_per_page = None
935
+ self.calc_hitrates = False
936
+ self.flash_attn = False
937
+ self.train_headpredictor = False
938
+ self.calibrate_thresholds = False
939
+ self.test_with_thresholds = False
940
+ self.old_predictor = None
941
+
942
+ if self.layer_idx > 0:
943
+ self.mseloss = MSELoss(reduction='none')
944
+ self.msemagn_loss = None
945
+ self.headmseloss = MSELoss(reduction='none')
946
+ self.headmsemagn_loss = None
947
+
948
+ if self.producer is None: # This is the producer layer
949
+ self.q_importance = None # Shared mask across layers during inference
950
+ self.k_importance = None
951
+ self.head_importances = None
952
+ self.actmagn_masklist = {}
953
+ self.available_tokens = {}
954
+
955
+ # Attention setup
956
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
957
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
958
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
959
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
960
+ self._init_rope()
961
+
962
+ def update_predictor(self):
963
+ self.sparse_token_predictor = TokenImportancePredictorAttentive(
964
+ self.config, self.pred_hid_size, self.num_heads, self.num_layers_pred, dropout=0.1, dDash = self.dDash, \
965
+ intdim = self.intdim, attn_reduce_factor=self.attn_reduce_factor
966
+ ).to('cuda:0')
967
+ self.sparse_token_predictor.flash_attn = self.flash_attn
968
+ if self.train_headpredictor:
969
+ self.sparse_head_predictor = HeadImportancePredictor(
970
+ self.config, self.pred_hid_size, self.num_heads, self.num_layers_pred, dropout=0.1, dDash = self.dDash, \
971
+ intdim = self.intdim, attn_reduce_factor=self.head_attn_reduce_factor
972
+ ).to('cuda:0')
973
+ self.sparse_head_predictor.flash_attn = self.flash_attn
974
+
975
+ def set_token_sparsity(self):
976
+ assert self.token_sparse_method is not None, "Set token sparse method first!"
977
+ if self.token_sparse_method is not None:
978
+ try:
979
+ mname = self.config._name_or_path.split("/")[-1]
980
+ read_path = f"threshold_calibs/{mname}/{self.token_sparse_method}.pkl"
981
+ threshold_model_dictionary = torch.load(read_path)
982
+ self.tok_calibration_set = threshold_model_dictionary
983
+ except:
984
+ pass
985
+ if self.token_sparse_method == "LazyLLM":
986
+ if self.layer_idx <= 9:
987
+ self.sparse_aggression = 1
988
+ elif self.layer_idx <= 19:
989
+ self.sparse_aggression = 0.7
990
+ elif self.layer_idx <= 28:
991
+ self.sparse_aggression = 0.4
992
+ else:
993
+ self.sparse_aggression = 0.1
994
+ elif "fixed" in self.token_sparse_method:
995
+ if self.layer_idx == 0:
996
+ self.sparse_aggression = 1
997
+ else:
998
+ self.sparse_aggression = 1 - float(self.token_sparse_method.split("_")[1].split("pc")[0])/100.
999
+ elif "progressive" in self.token_sparse_method:
1000
+ pc_drop = float(self.token_sparse_method.split("_")[1].split("pc")[0])/100.
1001
+ self.sparse_aggression = (1 - pc_drop) ** (self.layer_idx) # (x% per layer, progressive_xpc style)
1002
+ else:
1003
+ raise ValueError(f"Unknown token sparsity method {self.token_sparse_method}")
1004
+
1005
+
1006
+ def _init_rope(self):
1007
+ if self.config.rope_scaling is None:
1008
+ self.rotary_emb = LlamaRotaryEmbedding(
1009
+ self.config
1010
+ )
1011
+ else:
1012
+ scaling_type = self.config.rope_scaling.get("type") or self.config.rope_scaling.get("rope_type")
1013
+ scaling_factor = self.config.rope_scaling["factor"]
1014
+ if scaling_type == "linear" or scaling_type == 'llama3':
1015
+ self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
1016
+ self.head_dim,
1017
+ max_position_embeddings=self.max_position_embeddings,
1018
+ scaling_factor=scaling_factor,
1019
+ base=self.rope_theta,
1020
+ config=self.config
1021
+ )
1022
+ elif scaling_type == "dynamic":
1023
+ self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
1024
+ self.head_dim,
1025
+ max_position_embeddings=self.max_position_embeddings,
1026
+ scaling_factor=scaling_factor,
1027
+ base=self.rope_theta,
1028
+ config=self.config
1029
+ )
1030
+ else:
1031
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
1032
+
1033
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
1034
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
1035
+
1036
+ def forward(
1037
+ self,
1038
+ hidden_states: torch.Tensor,
1039
+ attention_mask: Optional[torch.Tensor] = None,
1040
+ position_ids: Optional[torch.LongTensor] = None,
1041
+ past_key_value: Optional[Union[DynamicCache, PredictorDynamicCache]] = None,
1042
+ output_attentions: bool = False,
1043
+ use_cache: bool = False,
1044
+ padding_mask: Optional[torch.LongTensor] = None,
1045
+ cache_position: Optional[torch.LongTensor] = None,
1046
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
1047
+ **kwargs,
1048
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[PredictorDynamicCache]]:
1049
+ bsz, q_len, _ = hidden_states.size()
1050
+ Ltrack = hidden_states.size(1)
1051
+
1052
+ if self.config.pretraining_tp > 1:
1053
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
1054
+ query_slices = self.q_proj.weight.split(
1055
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
1056
+ )
1057
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
1058
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
1059
+
1060
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
1061
+ query_states = torch.cat(query_states, dim=-1)
1062
+
1063
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
1064
+ key_states = torch.cat(key_states, dim=-1)
1065
+
1066
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
1067
+ value_states = torch.cat(value_states, dim=-1)
1068
+ else:
1069
+ query_states = self.q_proj(hidden_states)
1070
+ key_states = self.k_proj(hidden_states)
1071
+ value_states = self.v_proj(hidden_states)
1072
+
1073
+ evalmode = self.eval_llm_mode
1074
+ num_tokens_to_keep = int(q_len * self.sparse_aggression)
1075
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
1076
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
1077
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
1078
+
1079
+ # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) # AHMED: Modified this to use the newer version.
1080
+ cos, sin = position_embeddings
1081
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
1082
+
1083
+ if use_cache:
1084
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
1085
+
1086
+ kv_seq_len = key_states.shape[-2]
1087
+ final_mask = None
1088
+
1089
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
1090
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
1091
+
1092
+ key_len = key_states.size(2)
1093
+ bsz, q_len = query_states.size(0), query_states.size(2)
1094
+
1095
+ if attention_mask is None:
1096
+ # We want a [q_len, kv_seq_len] boolean upper-triangular mask
1097
+ causal_mask_2d = torch.ones(q_len, kv_seq_len,
1098
+ device=hidden_states.device,
1099
+ dtype=torch.bool).triu(diagonal=1)
1100
+ # Then shape it to [bsz, 1, q_len, kv_seq_len]
1101
+ causal_mask_4d = causal_mask_2d.unsqueeze(0).expand(bsz, 1, q_len, kv_seq_len)
1102
+ # Now fill -inf where the mask is True
1103
+ attention_mask = torch.full_like(causal_mask_4d, 0, dtype=hidden_states.dtype)
1104
+ if q_len != 1:
1105
+ attention_mask = attention_mask.masked_fill(causal_mask_4d, float("-inf"))
1106
+
1107
+ if self.inference_mode:
1108
+ min_sparse_index = self.min_sparse_index
1109
+ with torch.no_grad():
1110
+ if evalmode == "ExpPred":
1111
+ if self.layer_idx > 0:
1112
+ q_importance_tensor = self.producer.q_importance[:, self.layer_idx % self.producer_frequency, :, :].float().to(query_states.device) # [BH, Lq, D']
1113
+ k_importance_tensor = self.producer.k_importance[:, self.layer_idx % self.producer_frequency, :, :].float().to(key_states.device) # [BH, Lk, D']
1114
+ importance_mask = torch.bmm(q_importance_tensor, k_importance_tensor.transpose(-2, -1)) / math.sqrt(self.dDash) # [BH, Lq, Lk]
1115
+ importance_mask = importance_mask.view(bsz, self.num_heads, q_len, key_len) # [B, H, Lq, Lk]
1116
+ attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(self.head_dim)
1117
+ if self.calc_hitrates:
1118
+ self.tok_hit_acc, self.tok_mean_rank_corr, self.tok_max_rank_corr = calculate_hit_metrics(
1119
+ estimated_importance=importance_mask,
1120
+ true_importance=attn_weights,
1121
+ top_k_ratio=0.5
1122
+ )
1123
+ if self.calibrate_thresholds:
1124
+ ### Threshold variance investigation
1125
+ unadj_importance_mask = importance_mask.clone()
1126
+ importance_mask = torch.softmax(importance_mask + attention_mask, dim=-1)
1127
+ sorted_indices = torch.argsort(importance_mask, dim=-1, descending=True)
1128
+ sorted_indices = sorted_indices[:, :, -q_len:, :]
1129
+ sorted_values, sorted_ix = torch.sort(importance_mask, dim=-1)
1130
+ sorted_true_values, _ = torch.sort(torch.gather(unadj_importance_mask, dim=-1, index=sorted_ix), dim=-1)
1131
+ true_thresholds = sorted_true_values[:, :, :, int(importance_mask.size(-1) * self.sparse_aggression)]
1132
+ thresholds = sorted_values[:, :, :, int(importance_mask.size(-1) * self.sparse_aggression)]
1133
+ self.true_threshmean = true_thresholds
1134
+ self.threshmean = thresholds
1135
+ if self.test_with_thresholds:
1136
+ unadj_importance_mask = importance_mask.clone()
1137
+ perhead_thresholds = self.tok_calibration_set[self.layer_idx - 1].to(unadj_importance_mask.device)  # layer 0 (the producer) has no calibration data, hence the -1 offset.
1138
+ mask_tensor = threshold_to_mask(unadj_importance_mask, perhead_thresholds, min_sparse_index, bsz, q_len, key_len)
1139
+ else:
1140
+ importance_mask = torch.softmax(importance_mask + attention_mask, dim=-1)
1141
+ sorted_indices = torch.argsort(importance_mask, dim=-1, descending=True)
1142
+ sorted_indices = sorted_indices[:, :, -q_len:, :]
1143
+ mask_tensor = sorted_index_to_mask(sorted_indices, attention_mask, min_sparse_index, bsz, q_len, key_len, self.sparse_aggression, self.sliding_window)
1144
+ ### Threshold variance investigation
1145
+ if self.sliding_window is not None:
1146
+ if not hasattr(self, "window_cache"):
1147
+ self.window_cache = SlidingWindowCache(max_seq_len=1024,
1148
+ sliding_window=self.sliding_window,
1149
+ device=mask_tensor.device)
1150
+ window = self.window_cache.get_window(q_len, key_len)
1151
+ mask_tensor = enforce_sliding_window(mask_tensor, window)
1152
+ final_mask = mask_tensor
1153
+
1154
+ self.final_mask_investigate = final_mask
1155
+ attn_weights = attn_weights + mask_tensor + attention_mask
1156
+ else:
1157
+ attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(self.head_dim)
1158
+ attn_weights = attn_weights + attention_mask
1159
+ else:
1160
+ raise ValueError(f"Unknown eval mode {evalmode}")
1161
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
1162
+ attn_output = torch.matmul(attn_weights, value_states)
1163
+
1164
+ else:
1165
+ attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(self.head_dim)
1166
+ if self.layer_idx > 0:
1167
+ q_importance_tensor = self.producer.q_importance[:, self.layer_idx % self.producer_frequency, :, :].float().to(query_states.device) # [BH, Lq, D']
1168
+ k_importance_tensor = self.producer.k_importance[:, self.layer_idx % self.producer_frequency, :, :].float().to(key_states.device) # [BH, Lk, D']
1169
+ importance_mask = torch.bmm(q_importance_tensor, k_importance_tensor.transpose(-2, -1)) / math.sqrt(self.dDash) # [BH, Lq, Lk]
1170
+ importance_mask = importance_mask.view(bsz, self.num_heads, q_len, key_len) # [B, H, Lq, Lk]
1171
+
1172
+ if self.lookahead == 0:
1173
+ self.msemagn_loss = self.mseloss(attn_weights, importance_mask)
1174
+ else:
1175
+ self.msemagn_loss = self.mseloss(attn_weights[:, :, self.lookahead:, :], importance_mask[:, :, :-self.lookahead, :])
1176
+ self.msemagn_loss = (self.msemagn_loss).mean(dim=(-1, -2))
1177
+ self.msemagn_loss = self.msemagn_loss.mean()
1178
+
1179
+ if self.calc_hitrates:
1180
+ self.tok_hit_acc, self.tok_mean_rank_corr, self.tok_max_rank_corr = calculate_hit_metrics(
1181
+ estimated_importance=importance_mask,
1182
+ true_importance=attn_weights,
1183
+ top_k_ratio=0.5
1184
+ )
1185
+
1186
+ if attention_mask is not None:
1187
+ attn_weights = attn_weights + attention_mask
1188
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
1189
+ attn_output = torch.matmul(attn_weights, value_states)
1190
+
1191
+ if self.layer_idx > 0 and self.train_headpredictor:
1192
+ head_importance_tensor = self.producer.head_importances[:, :, :, self.layer_idx % self.producer_frequency].float().to(attn_output.device)
1193
+ attn_head_weights = attn_output.mean(dim=-1).permute(0, 2, 1)
1194
+ self.headmsemagn_loss = self.headmseloss(attn_head_weights, head_importance_tensor).mean()
1195
+
1196
+ if self.calc_hitrates:
1197
+ self.head_hit_acc, self.head_mean_rank_corr, self.head_max_rank_corr = calculate_hit_metrics(
1198
+ estimated_importance=head_importance_tensor,
1199
+ true_importance=attn_head_weights,
1200
+ top_k_ratio=0.5
1201
+ )
1202
+ else:
1203
+ self.headmsemagn_loss = 0
1204
+ if self.calc_hitrates:
1205
+ self.head_hit_acc, self.head_mean_rank_corr, self.head_max_rank_corr = 0, 0, 0
1206
+
1207
+
1208
+ checkeverytime = hasattr(self, 'test_with_thresholds')
1209
+ if checkeverytime:
1210
+ checkeverytime = self.test_with_thresholds
1211
+ if final_mask is not None:
1212
+ if self.effective_sparsity is None or checkeverytime:
1213
+ true_mask = final_mask + attention_mask
1214
+ num_deact = true_mask.bool().sum(dim=-1) # Number of tokens disabled.
1215
+ causally_deact = (attention_mask.bool()).sum(dim=-1).expand_as(num_deact) # Number of tokens disabled causally anyway
1216
+ additional_deact = (num_deact - causally_deact)
1217
+ num_active = (~attention_mask.bool()).sum(dim=-1).expand_as(num_deact) # Number of tokens active at this position if zero-sparsity
1218
+ effective_sparsity = 100 * (additional_deact.float() / num_active.float()).mean().item()
1219
+ self.effective_sparsity = effective_sparsity
1220
+ print("Effective Sparsity:", effective_sparsity, "%\t Sequence Length:", q_len)
1221
+ if self.layer_idx == 0:
1222
+ if self.effective_sparsity is None:
1223
+ self.effective_sparsity = 0.0
1224
+
1225
+ attn_output = attn_output.transpose(1, 2).contiguous()
1226
+ attn_output = attn_output.view(bsz, -1, self.hidden_size)
1227
+
1228
+ if self.config.pretraining_tp > 1:
1229
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
1230
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
1231
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
1232
+ else:
1233
+ attn_output = self.o_proj(attn_output)
1234
+
1235
+ if self.producer is None:
1236
+ try:
1237
+ q_importance, k_importance = self.sparse_token_predictor(
1238
+ hidden_states,
1239
+ attention_mask=attention_mask,
1240
+ position_ids=position_ids,
1241
+ past_key_value=past_key_value, # the same single cache
1242
+ use_cache=use_cache,
1243
+ layer_idx=self.layer_idx, # or pass 0
1244
+ )
1245
+ if self.train_headpredictor:
1246
+ head_importances, past_key_value_hp = self.sparse_head_predictor(
1247
+ hidden_states,
1248
+ attention_mask=attention_mask,
1249
+ position_ids=position_ids,
1250
+ past_key_value=past_key_value_hp,
1251
+ use_cache=use_cache
1252
+ )
1253
+ head_importances = head_importances.view(bsz, q_len, self.num_heads, self.num_hidden_layers) # [B L H N]
1254
+ q_len = attn_output.size(1)
1255
+ k_len = k_importance.size(-1)
1256
+ except Exception:
1257
+ print(traceback.format_exc())
1258
+ import pdb; pdb.set_trace()
1259
+
1260
+ self.q_importance = q_importance
1261
+ self.k_importance = k_importance
1262
+
1263
+ if self.train_headpredictor:
1264
+ if self.head_importances is None:
1265
+ self.head_importances = head_importances
1266
+ else:
1267
+ self.head_importances = torch.cat([self.head_importances, head_importances], dim=1)
1268
+
1269
+ # if self.layer_idx == 31:
1270
+ # if q_len == 1:
1271
+ # self.dtok += 1
1272
+ # print(f"Primary Key-Value Shape: {past_key_value.predictor_primary_key[0].shape}, Importance: {past_key_value.predictor_importance_key[0].shape}, Tok-Decoded: {self.dtok}")
1273
+ # else:
1274
+ # self.dtok = 0
1275
+
1276
+ if not output_attentions:
1277
+ attn_weights = None
1278
+ return attn_output, attn_weights
1279
+
1280
+ def convert_kvcache_experimental(model, config, producer_frequency):
1281
+ producer_layer = None
1282
+ producer_layer_device = None
1283
+ layer_counter = {'idx': 0}
1284
+
1285
+ def recurse_convert(parent_module):
1286
+ nonlocal producer_layer
1287
+ nonlocal producer_layer_device
1288
+ for name, module in parent_module._modules.items():
1289
+ if len(list(module.children())) > 0:
1290
+ recurse_convert(module)
1291
+ if isinstance(module, LlamaAttention):
1292
+ device = next(module.parameters()).device
1293
+ dtype = next(module.parameters()).dtype
1294
+ if layer_counter['idx'] % producer_frequency == 0:
1295
+ new_module = LlamaAttentionExperimental(config).to(dtype).to(device)
1296
+ producer_layer = new_module
1297
+ producer_layer_device = device
1298
+ else:
1299
+ new_module = LlamaAttentionExperimental(
1300
+ config,
1301
+ producer=producer_layer,
1302
+ layer_idx=layer_counter['idx']
1303
+ ).to(dtype).to(device)
1304
+ new_module.load_state_dict(module.state_dict(), strict=False)
1305
+ is_producer = layer_counter['idx'] % producer_frequency == 0
1306
+ if is_producer:
1307
+ print(f"Converted Producer layer '{name}' to LlamaAttentionExperimental at layer index {layer_counter['idx']}")
1308
+ else:
1309
+ print(f"Converted layer '{name}' to LlamaAttentionExperimental at layer index {layer_counter['idx']}")
1310
+ parent_module._modules[name] = new_module
1311
+ layer_counter['idx'] += 1
1312
+ recurse_convert(model)
1313
+ producer_layer = producer_layer.to(producer_layer_device)
1314
+ return model
1315
+
1316
+
1317
+ # ---------------------------------------------------------------------
1318
+ # 1) Custom Config subclass
1319
+ # ---------------------------------------------------------------------
1320
+ class LlamaButlerConfig(LlamaConfig):
1321
+ """
1322
+ Extends HF's LlamaConfig to hold optional extra parameters for the "Butler" logic.
1323
+ You can store your custom attributes here, so they can be serialized in config.json.
1324
+ """
1325
+
1326
+ model_type = "llama_butler"
1327
+
1328
+ def __init__(
1329
+ self,
1330
+ eval_llm_mode="ExpPred",
1331
+ token_sparse_method="fixed_50pc",
1332
+ producer_frequency=8,
1333
+ dDash=16,
1334
+ attn_reduce_factor=4,
1335
+ head_attn_reduce_factor=4,
1336
+ intdim=256,
1337
+ flash_attn=False,
1338
+ train_headpredictor=False,
1339
+ min_sparse_index=5,
1340
+ lookahead=0,
1341
+ sliding_window=None,
1342
+ **kwargs
1343
+ ):
1344
+ super().__init__(**kwargs)
1345
+ self.eval_llm_mode = eval_llm_mode
1346
+ self.token_sparse_method = token_sparse_method
1347
+ self.producer_frequency = producer_frequency
1348
+ self.dDash = dDash
1349
+ self.attn_reduce_factor = attn_reduce_factor
1350
+ self.head_attn_reduce_factor = head_attn_reduce_factor
1351
+ self.intdim = intdim
1352
+ self.flash_attn = flash_attn
1353
+ self.train_headpredictor = train_headpredictor
1354
+ self.min_sparse_index = min_sparse_index
1355
+ self.lookahead = lookahead
1356
+ self.sliding_window = sliding_window
1357
+
1358
+
1359
+ # ---------------------------------------------------------------------
1360
+ # 2) The main Butler model class
1361
+ # ---------------------------------------------------------------------
1362
+ class LlamaButlerForCausalLM(LlamaForCausalLM):
1363
+ """
1364
+ A subclass of HF's LlamaForCausalLM that:
1365
+ - Patches each LlamaAttention to your LlamaAttentionExperimental
1366
+ - Sets specialized attributes (eval_llm_mode, etc.)
1367
+ - Overrides _prepare_cache_for_generation to inject PredictorDynamicCache
1368
+ """
1369
+
1370
+ # Let HF auto-detect this config class from config.json:
1371
+ config_class = LlamaButlerConfig
1372
+
1373
+ def __init__(self, config: LlamaButlerConfig):
1374
+ super().__init__(config)
1375
+ """
1376
+ HF's LlamaForCausalLM initializes:
1377
+ self.model = LlamaModel(config)
1378
+ self.lm_head = nn.Linear(...)
1379
+ """
1380
+
1381
+ # 1) Patch the underlying LlamaModel to replace LlamaAttention with LlamaAttentionExperimental
1382
+ self.model = convert_kvcache_experimental(
1383
+ self.model,
1384
+ config,
1385
+ config.producer_frequency
1386
+ )
1387
+
1388
+ # 2) Optionally, set per-module attributes so each LlamaAttentionExperimental knows about them:
1389
+ for module in self.model.modules():
1390
+ if module.__class__.__name__.endswith("AttentionExperimental"):
1391
+ # Set these from your config. Or you can hardcode them if you prefer.
1392
+ module.eval_llm_mode = config.eval_llm_mode
1393
+ module.token_sparse_method = config.token_sparse_method
1394
+ module.set_token_sparsity() # e.g. sets module.sparse_aggression
1395
+
1396
+ module.producer_frequency = config.producer_frequency
1397
+ module.dDash = config.dDash
1398
+ module.attn_reduce_factor = config.attn_reduce_factor
1399
+ module.head_attn_reduce_factor = config.head_attn_reduce_factor
1400
+ module.intdim = config.intdim
1401
+ module.flash_attn = config.flash_attn
1402
+ module.train_headpredictor = config.train_headpredictor
1403
+ module.min_sparse_index = config.min_sparse_index
1404
+ module.lookahead = config.lookahead
1405
+ module.sliding_window = config.sliding_window
1406
+ module.num_layers_pred = config.producer_frequency  # number of consumer layers each producer predicts for
1407
+
1408
+ # If this is a "producer layer" (mod.layer_idx % freq == 0), run update_predictor():
1409
+ if hasattr(module, "layer_idx") and (module.layer_idx % config.producer_frequency == 0):
1410
+ module.update_predictor()
1411
+
1412
+ # 3) Patch past_key_values creation so generation uses PredictorDynamicCache for the supported eval modes:
1413
+ if config.eval_llm_mode in ["ExpPred", "ReplAttn"]:
1414
+ self._prepare_cache_for_generation = self._patched_prepare_cache_for_generation.__get__(
1415
+ self, self.__class__
1416
+ )
1417
+
1418
+ # -----------------------------------------------------------------
1419
+ # 3) The custom `_prepare_cache_for_generation` override
1420
+ # -----------------------------------------------------------------
1421
+ def _patched_prepare_cache_for_generation(
1422
+ self,
1423
+ generation_config: GenerationConfig,
1424
+ model_kwargs: Dict,
1425
+ *args,
1426
+ **kwargs
1427
+ ):
1428
+ """
1429
+ This override injects a PredictorDynamicCache
1430
+ in place of the standard 'past_key_values'.
1431
+ """
1432
+ if "past_key_values" not in model_kwargs or model_kwargs["past_key_values"] is None:
1433
+ model_kwargs["past_key_values"] = PredictorDynamicCache()
1434
+ return model_kwargs
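The classes above wire the TokenButler attention into a standard Hugging Face causal LM, so the intended entry point is ordinary `from_pretrained` / `generate` usage. The following is a minimal, hedged sketch of that flow; the repo id is hypothetical, and it assumes the published config.json carries the usual `auto_map` entries so `trust_remote_code=True` can resolve `LlamaButlerConfig` / `LlamaButlerForCausalLM` (this diff does not show that file).

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

repo_id = "example-org/llama-butler-demo"  # hypothetical repo id

# Butler-specific knobs are plain config attributes, so they can be overridden at load time.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
config.eval_llm_mode = "ExpPred"           # enables the patched PredictorDynamicCache path
config.token_sparse_method = "fixed_50pc"  # matches the default in LlamaButlerConfig

model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    config=config,
    trust_remote_code=True,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(repo_id)

inputs = tokenizer("Briefly explain rotary position embeddings.", return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=64, do_sample=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

Because `eval_llm_mode` is "ExpPred", the overridden `_prepare_cache_for_generation` above injects a `PredictorDynamicCache` automatically, so no cache object needs to be passed to `generate`.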
pytorch_model.bin.index.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,35 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "bos_token": {
5
+ "__type": "AddedToken",
6
+ "content": "<|begin▁of▁sentence|>",
7
+ "lstrip": false,
8
+ "normalized": true,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "clean_up_tokenization_spaces": false,
13
+ "eos_token": {
14
+ "__type": "AddedToken",
15
+ "content": "<|end▁of▁sentence|>",
16
+ "lstrip": false,
17
+ "normalized": true,
18
+ "rstrip": false,
19
+ "single_word": false
20
+ },
21
+ "legacy": true,
22
+ "model_max_length": 16384,
23
+ "pad_token": {
24
+ "__type": "AddedToken",
25
+ "content": "<|end▁of▁sentence|>",
26
+ "lstrip": false,
27
+ "normalized": true,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ },
31
+ "sp_model_kwargs": {},
32
+ "unk_token": null,
33
+ "tokenizer_class": "LlamaTokenizerFast",
34
+ "chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|><think>\\n'}}{% endif %}"
35
+ }
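The `chat_template` added above is what `AutoTokenizer.apply_chat_template` executes, so a quick way to sanity-check the tokenizer files in this commit is to render a conversation and confirm the DeepSeek-style special tokens appear. A minimal sketch follows (the repo id is hypothetical; any local checkout containing these tokenizer files works the same way):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("example-org/llama-butler-demo")  # hypothetical repo id

messages = [
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "What is 7 * 6?"},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)
# Per the template above, this should render roughly as:
# <|begin▁of▁sentence|>You are a concise assistant.<|User|>What is 7 * 6?<|Assistant|><think>

Note that for assistant turns the template keeps only the text after '</think>', so prior reasoning traces are dropped when a multi-turn conversation is re-serialized.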