|
import torch |
|
import torch.nn as nn |
|
from transformers import LlamaForCausalLM, LlamaConfig |
|
from transformers.generation.utils import GenerationConfig |
|
import os |
|
import copy |
|
import math |
|
import numpy as np |
|
from dataclasses import dataclass |
|
from typing import Any, Dict, List, Optional, Tuple, Union |
|
import gc |
|
|
|
import traceback |
|
import torch.utils.checkpoint |
|
import torch.nn.functional as F |
|
from torch.cuda.amp import autocast |
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
|
|
|
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, LlamaAttention, apply_rotary_pos_emb |
|
|
|
from transformers.cache_utils import DynamicCache |
|
|
|
class PredictorDynamicCache(DynamicCache): |
|
def __init__(self): |
|
super().__init__() |
|
self.predictor_primary_key: List[Optional[torch.Tensor]] = [] |
|
self.predictor_primary_value: List[Optional[torch.Tensor]] = [] |
|
self.predictor_importance_key: List[Optional[torch.Tensor]] = [] |
|
|
|
def update_predictor_primary( |
|
self, |
|
key_states: torch.Tensor, |
|
value_states: torch.Tensor, |
|
layer_idx: int, |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
""" |
|
Append or create the predictor's "primary" K/V states for `layer_idx`. |
|
|
|
        Shapes of key_states / value_states are typically [batch_size, num_heads, seq_len, head_dim].
|
""" |
|
|
|
|
|
self._ensure_list_capacity( |
|
self.predictor_primary_key, layer_idx, fill=None |
|
) |
|
self._ensure_list_capacity( |
|
self.predictor_primary_value, layer_idx, fill=None |
|
) |
|
|
|
|
|
if self.predictor_primary_key[layer_idx] is None: |
|
self.predictor_primary_key[layer_idx] = key_states |
|
self.predictor_primary_value[layer_idx] = value_states |
|
else: |
|
|
|
self.predictor_primary_key[layer_idx] = torch.cat( |
|
[self.predictor_primary_key[layer_idx], key_states], dim=2 |
|
) |
|
self.predictor_primary_value[layer_idx] = torch.cat( |
|
[self.predictor_primary_value[layer_idx], value_states], dim=2 |
|
) |
|
|
|
return ( |
|
self.predictor_primary_key[layer_idx], |
|
self.predictor_primary_value[layer_idx], |
|
) |
|
|
|
def update_predictor_importance( |
|
self, |
|
key_states: torch.Tensor, |
|
layer_idx: int, |
|
) -> torch.Tensor: |
|
""" |
|
Append or create the predictor's "importance" key for `layer_idx`. |
|
""" |
|
self._ensure_list_capacity( |
|
self.predictor_importance_key, layer_idx, fill=None |
|
) |
|
|
|
if self.predictor_importance_key[layer_idx] is None: |
|
self.predictor_importance_key[layer_idx] = key_states |
|
else: |
|
self.predictor_importance_key[layer_idx] = torch.cat( |
|
[self.predictor_importance_key[layer_idx], key_states], dim=2 |
|
) |
|
return self.predictor_importance_key[layer_idx] |
|
|
|
def crop(self, max_length: int): |
|
super().crop(max_length) |
|
|
|
for idx in range(len(self.predictor_primary_key)): |
|
if self.predictor_primary_key[idx] is not None: |
|
self.predictor_primary_key[idx] = self.predictor_primary_key[idx][..., :max_length, :] |
|
self.predictor_primary_value[idx] = self.predictor_primary_value[idx][..., :max_length, :] |
|
|
|
for idx in range(len(self.predictor_importance_key)): |
|
if self.predictor_importance_key[idx] is not None: |
|
self.predictor_importance_key[idx] = self.predictor_importance_key[idx][..., :max_length, :] |
|
|
|
|
|
self._seen_tokens = min(self._seen_tokens, max_length) |
|
|
|
def batch_split( |
|
self, full_batch_size: int, split_size: int, num_hidden_layers: int = None |
|
) -> List["PredictorDynamicCache"]: |
|
|
|
base_splits = super().batch_split(full_batch_size, split_size, num_hidden_layers) |
|
|
|
|
|
|
|
out: List[PredictorDynamicCache] = [] |
|
|
|
for split_i, base_split in enumerate(base_splits): |
|
|
|
new_cache = PredictorDynamicCache() |
|
|
|
new_cache.key_cache = base_split.key_cache |
|
new_cache.value_cache = base_split.value_cache |
|
new_cache._seen_tokens = base_split._seen_tokens |
|
|
|
|
|
|
|
b_start = split_i * split_size |
|
b_end = min(full_batch_size, b_start + split_size) |
|
|
|
new_cache.predictor_primary_key = self._slice_list_tensors( |
|
self.predictor_primary_key, b_start, b_end |
|
) |
|
new_cache.predictor_primary_value = self._slice_list_tensors( |
|
self.predictor_primary_value, b_start, b_end |
|
) |
|
new_cache.predictor_importance_key = self._slice_list_tensors( |
|
self.predictor_importance_key, b_start, b_end |
|
) |
|
|
|
out.append(new_cache) |
|
|
|
return out |
|
|
|
@classmethod |
|
def from_batch_splits(cls, splits: List["PredictorDynamicCache"], num_hidden_layers: int = None) -> "PredictorDynamicCache": |
|
|
|
base_merged = DynamicCache.from_batch_splits(splits, num_hidden_layers=num_hidden_layers) |
|
merged = cls() |
|
merged.key_cache = base_merged.key_cache |
|
merged.value_cache = base_merged.value_cache |
|
merged._seen_tokens = base_merged._seen_tokens |
|
|
|
|
|
merged.predictor_primary_key = cls._merge_list_tensors( |
|
[split.predictor_primary_key for split in splits] |
|
) |
|
merged.predictor_primary_value = cls._merge_list_tensors( |
|
[split.predictor_primary_value for split in splits] |
|
) |
|
merged.predictor_importance_key = cls._merge_list_tensors( |
|
[split.predictor_importance_key for split in splits] |
|
) |
|
|
|
return merged |
|
|
|
def batch_repeat_interleave(self, repeats: int): |
|
super().batch_repeat_interleave(repeats) |
|
self.predictor_primary_key = self._repeat_list_tensors( |
|
self.predictor_primary_key, repeats |
|
) |
|
self.predictor_primary_value = self._repeat_list_tensors( |
|
self.predictor_primary_value, repeats |
|
) |
|
self.predictor_importance_key = self._repeat_list_tensors( |
|
self.predictor_importance_key, repeats |
|
) |
|
|
|
def batch_select_indices(self, indices: torch.Tensor): |
|
super().batch_select_indices(indices) |
|
self.predictor_primary_key = self._select_list_tensors( |
|
self.predictor_primary_key, indices |
|
) |
|
self.predictor_primary_value = self._select_list_tensors( |
|
self.predictor_primary_value, indices |
|
) |
|
self.predictor_importance_key = self._select_list_tensors( |
|
self.predictor_importance_key, indices |
|
) |
|
|
|
@staticmethod |
|
def _ensure_list_capacity(lst: list, idx: int, fill=None): |
|
if len(lst) <= idx: |
|
lst.extend([fill] * (idx + 1 - len(lst))) |
|
|
|
@staticmethod |
|
def _slice_list_tensors( |
|
tensor_list: List[Optional[torch.Tensor]], start: int, end: int |
|
) -> List[Optional[torch.Tensor]]: |
|
out = [] |
|
for t in tensor_list: |
|
if t is None: |
|
out.append(None) |
|
else: |
|
out.append(t[start:end, ...]) |
|
return out |
|
|
|
@classmethod |
|
def _merge_list_tensors( |
|
cls, list_of_lists: List[List[Optional[torch.Tensor]]] |
|
) -> List[Optional[torch.Tensor]]: |
|
|
|
if not list_of_lists: |
|
return [] |
|
|
|
|
|
max_len = len(list_of_lists[0]) |
|
merged = [None] * max_len |
|
|
|
for layer_idx in range(max_len): |
|
|
|
chunk_tensors = [] |
|
for split in list_of_lists: |
|
t = split[layer_idx] if layer_idx < len(split) else None |
|
if t is not None: |
|
chunk_tensors.append(t) |
|
if len(chunk_tensors) == 0: |
|
merged[layer_idx] = None |
|
else: |
|
merged[layer_idx] = torch.cat(chunk_tensors, dim=0) |
|
return merged |
|
|
|
@staticmethod |
|
def _repeat_list_tensors( |
|
tensor_list: List[Optional[torch.Tensor]], repeats: int |
|
) -> List[Optional[torch.Tensor]]: |
|
out = [] |
|
for t in tensor_list: |
|
if t is None: |
|
out.append(None) |
|
else: |
|
out.append(t.repeat_interleave(repeats, dim=0)) |
|
return out |
|
|
|
@staticmethod |
|
def _select_list_tensors( |
|
tensor_list: List[Optional[torch.Tensor]], indices: torch.Tensor |
|
) -> List[Optional[torch.Tensor]]: |
|
out = [] |
|
for t in tensor_list: |
|
if t is None: |
|
out.append(None) |
|
else: |
|
out.append(t.index_select(0, indices)) |
|
return out |
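

# Minimal usage sketch for PredictorDynamicCache (illustrative only; the batch size,
# head count, sequence length and head_dim below are arbitrary assumptions). The first
# update creates the per-layer slot, subsequent updates append along the sequence dim.
def _example_predictor_dynamic_cache():
    cache = PredictorDynamicCache()
    bsz, n_heads, seq, head_dim = 2, 4, 8, 16
    k = torch.randn(bsz, n_heads, seq, head_dim)
    v = torch.randn(bsz, n_heads, seq, head_dim)
    cache.update_predictor_primary(k, v, layer_idx=0)
    k_all, v_all = cache.update_predictor_primary(k, v, layer_idx=0)
    assert k_all.shape == (bsz, n_heads, 2 * seq, head_dim)
    # The importance-key stream is tracked separately and stores keys only.
    k_imp = cache.update_predictor_importance(torch.randn(bsz, n_heads, seq, head_dim), layer_idx=0)
    # Cropping trims both the base DynamicCache and the predictor-specific tensors.
    cache.crop(max_length=seq)
    return k_all, v_all, k_imp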
|
|
|
|
|
class TokenImportancePredictorAttentive(nn.Module): |
|
def __init__(self, config, pred_hid_size, num_heads, num_hidden_layers, dDash, intdim, \ |
|
attn_reduce_factor, dropout=0.1): |
|
""" |
|
Optimized Token Importance Predictor with parallel Q-K projections and simplified mapping. |
|
|
|
Args: |
|
config: Configuration object containing model parameters. |
|
pred_hid_size (int): Hidden size for the predictor's attention layer. |
|
num_heads (int): Number of attention heads. |
|
num_hidden_layers (int): Number of transformer layers to predict. |
|
dropout (float): Dropout probability. |
|
q_downscale (int): Factor to downscale the Q dimension for efficiency. |
|
intermediate_dim (int): Intermediate dimension for non-linear transformations in projections. |
|
""" |
|
super().__init__() |
|
self.config = config |
|
self.hidden_size = pred_hid_size |
|
self.num_heads = num_heads |
|
self.num_hidden_layers = num_hidden_layers |
|
self.dropout = dropout |
|
self.head_dim = pred_hid_size // (num_heads * 4) |
|
self.rope_theta = config.rope_theta |
|
self.dDash = dDash |
|
self.intermediate_dim = intdim |
|
self.attn_reduce_factor = attn_reduce_factor |
|
self.max_position_embeddings = config.max_position_embeddings |
|
self.flash_attn = False |
|
assert pred_hid_size % (num_heads * 4) == 0, "pred_hid_size must be divisible by num_heads * 4." |
|
|
|
|
|
self.hidden_size_reduced = self.hidden_size // self.attn_reduce_factor |
|
assert self.hidden_size_reduced % self.num_heads == 0, "Reduced hidden size must be divisible by num_heads" |
|
self.attn_head_dim = self.hidden_size_reduced // self.num_heads |
|
|
|
|
|
self.input_proj = nn.Linear(self.hidden_size, self.hidden_size_reduced, bias=False) |
|
|
|
|
|
self.q_proj_attn = nn.Linear(self.hidden_size_reduced, self.hidden_size_reduced, bias=False) |
|
self.k_proj_attn = nn.Linear(self.hidden_size_reduced, self.hidden_size_reduced, bias=False) |
|
self.v_proj_attn = nn.Linear(self.hidden_size_reduced, self.hidden_size_reduced, bias=False) |
|
|
|
|
|
self.attn_dropout = nn.Dropout(self.dropout) |
|
|
|
|
|
self.norm1 = nn.LayerNorm(self.hidden_size_reduced) |
|
self.norm2 = nn.LayerNorm(self.hidden_size) |
|
|
|
self.ffn_hidden_size = 2 * self.hidden_size_reduced |
|
self.ffn = nn.Sequential( |
|
nn.Linear(self.hidden_size_reduced, self.ffn_hidden_size), |
|
nn.GELU(), |
|
nn.Linear(self.ffn_hidden_size, self.hidden_size), |
|
nn.Dropout(self.dropout) |
|
) |
|
|
|
self.norm_importance = nn.LayerNorm(self.hidden_size) |
|
|
|
|
|
|
|
self.q_proj_importance = nn.Sequential( |
|
nn.Linear(pred_hid_size, self.intermediate_dim, bias=False), |
|
nn.SiLU(), |
|
nn.Linear(self.intermediate_dim, num_hidden_layers * num_heads * self.dDash, bias=False) |
|
) |
|
self.k_proj_importance = nn.Sequential( |
|
nn.Linear(pred_hid_size, self.intermediate_dim, bias=False), |
|
nn.SiLU(), |
|
nn.Linear(self.intermediate_dim, num_hidden_layers * num_heads * self.dDash, bias=False) |
|
) |
|
|
|
|
|
self._init_rope() |
|
self._initialize_weights() |
|
self.device = None |
|
|
|
def _initialize_weights(self): |
|
for name, module in self.named_modules(): |
|
if isinstance(module, nn.Linear): |
|
nn.init.xavier_uniform_(module.weight) |
|
if module.bias is not None: |
|
nn.init.constant_(module.bias, 0) |
|
elif isinstance(module, nn.LayerNorm): |
|
nn.init.constant_(module.weight, 1.0) |
|
nn.init.constant_(module.bias, 0.0) |
|
elif isinstance(module, nn.MultiheadAttention): |
|
|
|
nn.init.xavier_uniform_(module.in_proj_weight) |
|
if module.in_proj_bias is not None: |
|
nn.init.constant_(module.in_proj_bias, 0) |
|
|
|
|
|
nn.init.xavier_uniform_(module.out_proj.weight) |
|
if module.out_proj.bias is not None: |
|
nn.init.constant_(module.out_proj.bias, 0) |
|
|
|
def _init_rope(self): |
|
|
|
|
|
config_copy = copy.deepcopy(self.config) |
|
config_copy.rope_scaling = { |
|
"factor": 32.0, |
|
"high_freq_factor": 4.0, |
|
"low_freq_factor": 1.0, |
|
"original_max_position_embeddings": 8192, |
|
"rope_type": "llama3" |
|
} |
|
config_copy.head_dim = self.attn_head_dim |
|
|
|
|
|
self.rotary_emb_attn = LlamaRotaryEmbedding( |
|
config_copy |
|
) |
|
|
|
config_copy.head_dim = self.dDash |
|
|
|
self.rotary_emb_importance = LlamaRotaryEmbedding( |
|
config_copy |
|
) |
|
|
|
def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False, layer_idx=None): |
|
""" |
|
Forward pass for the Optimized Token Importance Predictor. |
|
|
|
Args: |
|
hidden_states (torch.Tensor): Input tensor of shape [B, L, HQ]. |
|
attention_mask (torch.Tensor, optional): Attention mask of shape [B, 1, 1, L] or [B, 1, L, L]. |
|
position_ids (torch.Tensor, optional): Position IDs. |
|
past_key_value (tuple, optional): Past key and value states. |
|
use_cache (bool, optional): Whether to use cache. |
|
|
|
Returns: |
|
torch.Tensor: Importance scores of shape [B, N, H, L, L]. |
|
""" |
|
        layer_idx = 0  # the predictor keeps a single cache slot of its own, so any passed layer_idx is ignored here
|
|
|
|
|
if self.device != hidden_states.device: |
|
self.device = hidden_states.device |
|
self.to(self.device) |
|
|
|
B, L, E = hidden_states.size() |
|
|
|
|
|
hidden_states = hidden_states.to(self.input_proj.weight.dtype) |
|
hidden_states_reduced = self.input_proj(hidden_states) |
|
|
|
q = self.q_proj_attn(hidden_states_reduced) |
|
k = self.k_proj_attn(hidden_states_reduced) |
|
v = self.v_proj_attn(hidden_states_reduced) |
|
|
|
q = q.view(B, L, self.num_heads, self.attn_head_dim).transpose(1, 2) |
|
k = k.view(B, L, self.num_heads, self.attn_head_dim).transpose(1, 2) |
|
v = v.view(B, L, self.num_heads, self.attn_head_dim).transpose(1, 2) |
|
if (past_key_value is not None |
|
and layer_idx < len(past_key_value.predictor_primary_key) |
|
and past_key_value.predictor_primary_key[layer_idx] is not None): |
|
offset = past_key_value.predictor_primary_key[layer_idx].shape[2] |
|
else: |
|
offset = 0 |
|
|
|
|
|
kv_seq_len = offset + L |
|
|
|
|
|
if position_ids is None: |
|
|
|
position_ids = torch.arange(offset, offset + L, dtype=torch.long, device=self.device) |
|
position_ids = position_ids.unsqueeze(0).expand(B, L) |
|
|
|
|
|
cos, sin = self.rotary_emb_attn(v, position_ids) |
|
q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids) |
|
|
|
|
|
if use_cache and past_key_value is not None: |
|
k, v = past_key_value.update_predictor_primary(k.detach(), v.detach(), layer_idx) |
|
kv_seq_len = k.size(2) |
|
|
|
        # is_causal=True is only correct when query and key lengths match; during cached
        # single-token decoding (q_len < kv_len) the cached positions must stay visible.
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, is_causal=(q.size(2) == k.size(2)))
|
attn_output = attn_output.to(q.dtype) |
|
attn_output = attn_output.transpose(1, 2).contiguous().view(B, L, self.hidden_size_reduced) |
|
attn_output = self.norm1(attn_output) |
|
ffn_output = self.ffn(attn_output) |
|
|
|
hidden_states = self.norm2(hidden_states + ffn_output) |
|
|
|
B, L, E = hidden_states.size() |
|
|
|
H = self.num_heads |
|
N = self.num_hidden_layers |
|
|
|
hidden_states_for_importance = self.norm_importance(hidden_states) |
|
q_importance = self.q_proj_importance(hidden_states_for_importance) |
|
k_importance = self.k_proj_importance(hidden_states_for_importance) |
|
|
|
|
|
q_importance = q_importance.view(B, L, N, H, self.dDash).permute(0, 2, 3, 1, 4).contiguous() |
|
k_importance = k_importance.view(B, L, N, H, self.dDash).permute(0, 2, 3, 1, 4).contiguous() |
|
|
|
|
|
q_importance = q_importance.view(B * N * H, L, self.dDash) |
|
k_importance = k_importance.view(B * N * H, L, self.dDash) |
|
|
|
|
|
cos, sin = self.rotary_emb_importance(k_importance, position_ids) |
|
q_importance, k_importance = apply_rotary_pos_emb(q_importance, k_importance, cos, sin, position_ids) |
|
|
|
if use_cache and past_key_value is not None: |
|
k_importance = past_key_value.update_predictor_importance(k_importance.detach(), layer_idx) |
|
|
|
k_importance = k_importance.view(B * H, N, -1, self.dDash) |
|
q_importance = q_importance.view(B * H, N, -1, self.dDash) |
|
return q_importance, k_importance |
|
|
|
|
|
|
|
class HeadImportancePredictor(nn.Module): |
|
def __init__(self, config, pred_hid_size, num_heads, num_hidden_layers, dDash, intdim, \ |
|
attn_reduce_factor, dropout=0.1): |
|
""" |
|
Optimized Token Importance Predictor with parallel Q-K projections and simplified mapping. |
|
|
|
Args: |
|
config: Configuration object containing model parameters. |
|
pred_hid_size (int): Hidden size for the predictor's attention layer. |
|
num_heads (int): Number of attention heads. |
|
num_hidden_layers (int): Number of transformer layers to predict. |
|
dropout (float): Dropout probability. |
|
q_downscale (int): Factor to downscale the Q dimension for efficiency. |
|
intermediate_dim (int): Intermediate dimension for non-linear transformations in projections. |
|
""" |
|
super().__init__() |
|
self.is_head_predictor = None |
|
self.config = config |
|
self.hidden_size = pred_hid_size |
|
self.num_heads = num_heads |
|
self.num_hidden_layers = num_hidden_layers |
|
self.dropout = dropout |
|
self.head_dim = pred_hid_size // (num_heads * 4) |
|
self.rope_theta = config.rope_theta |
|
self.dDash = dDash |
|
self.intermediate_dim = intdim |
|
self.attn_reduce_factor = attn_reduce_factor |
|
self.max_position_embeddings = config.max_position_embeddings |
|
self.flash_attn = False |
|
|
|
|
|
self.hidden_size_reduced = self.hidden_size // self.attn_reduce_factor |
|
assert self.hidden_size_reduced % self.num_heads == 0, "Reduced hidden size must be divisible by num_heads" |
|
self.attn_head_dim = self.hidden_size_reduced // self.num_heads |
|
|
|
|
|
self.input_proj = nn.Linear(self.hidden_size, self.hidden_size_reduced, bias=False) |
|
|
|
|
|
self.q_proj_attn = nn.Linear(self.hidden_size_reduced, self.hidden_size_reduced, bias=False) |
|
self.k_proj_attn = nn.Linear(self.hidden_size_reduced, self.hidden_size_reduced, bias=False) |
|
self.v_proj_attn = nn.Linear(self.hidden_size_reduced, self.hidden_size_reduced, bias=False) |
|
|
|
|
|
self.attn_dropout = nn.Dropout(self.dropout) |
|
|
|
|
|
self.norm1 = nn.LayerNorm(self.hidden_size_reduced) |
|
self.norm2 = nn.LayerNorm(self.hidden_size) |
|
|
|
self.ffn_hidden_size = 4 * self.hidden_size_reduced |
|
self.ffn = nn.Sequential( |
|
nn.Linear(self.hidden_size_reduced, self.ffn_hidden_size), |
|
nn.GELU(), |
|
nn.Linear(self.ffn_hidden_size, self.num_heads * self.num_hidden_layers), |
|
) |
|
|
|
|
|
self._init_rope() |
|
self._initialize_weights() |
|
self.device = None |
|
|
|
def _initialize_weights(self): |
|
for name, module in self.named_modules(): |
|
if isinstance(module, nn.Linear): |
|
nn.init.xavier_uniform_(module.weight) |
|
if module.bias is not None: |
|
nn.init.constant_(module.bias, 0) |
|
elif isinstance(module, nn.LayerNorm): |
|
nn.init.constant_(module.weight, 1.0) |
|
nn.init.constant_(module.bias, 0.0) |
|
elif isinstance(module, nn.MultiheadAttention): |
|
|
|
nn.init.xavier_uniform_(module.in_proj_weight) |
|
if module.in_proj_bias is not None: |
|
nn.init.constant_(module.in_proj_bias, 0) |
|
|
|
|
|
nn.init.xavier_uniform_(module.out_proj.weight) |
|
if module.out_proj.bias is not None: |
|
nn.init.constant_(module.out_proj.bias, 0) |
|
|
|
def _init_rope(self): |
|
config_copy = copy.deepcopy(self.config) |
|
config_copy.head_dim = self.attn_head_dim |
|
|
|
self.rotary_emb_attn = LlamaRotaryEmbedding( |
|
config_copy |
|
) |
|
|
|
self.rotary_emb_importance = LlamaRotaryEmbedding( |
|
config_copy |
|
) |
|
|
|
def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False): |
|
""" |
|
Forward pass for the Optimized Token Importance Predictor. |
|
|
|
Args: |
|
hidden_states (torch.Tensor): Input tensor of shape [B, L, HQ]. |
|
attention_mask (torch.Tensor, optional): Attention mask of shape [B, 1, 1, L] or [B, 1, L, L]. |
|
position_ids (torch.Tensor, optional): Position IDs. |
|
past_key_value (tuple, optional): Past key and value states. |
|
use_cache (bool, optional): Whether to use cache. |
|
|
|
Returns: |
|
torch.Tensor: Importance scores of shape [B, N, H, L, L]. |
|
""" |
|
|
|
if self.device != hidden_states.device: |
|
self.device = hidden_states.device |
|
self.to(self.device) |
|
|
|
B, L, E = hidden_states.size() |
|
if past_key_value is None: |
|
past_key_value = {} |
|
|
|
|
|
past_primary = past_key_value.get('primary', None) |
|
|
|
hidden_states = hidden_states.to(self.input_proj.weight.dtype) |
|
hidden_states_reduced = self.input_proj(hidden_states) |
|
|
|
q = self.q_proj_attn(hidden_states_reduced) |
|
k = self.k_proj_attn(hidden_states_reduced) |
|
v = self.v_proj_attn(hidden_states_reduced) |
|
|
|
q = q.view(B, L, self.num_heads, self.attn_head_dim).transpose(1, 2) |
|
k = k.view(B, L, self.num_heads, self.attn_head_dim).transpose(1, 2) |
|
v = v.view(B, L, self.num_heads, self.attn_head_dim).transpose(1, 2) |
|
|
|
if past_primary is not None: |
|
past_L = past_primary[0].shape[2] |
|
kv_seq_len = past_L + L |
|
else: |
|
kv_seq_len = L |
|
|
|
|
|
        # Build default position ids before calling the rotary embedding, which needs them.
        if position_ids is None:
            position_ids = torch.arange(kv_seq_len, dtype=torch.long, device=self.device)
            position_ids = position_ids.unsqueeze(0).expand(B, kv_seq_len)

        cos, sin = self.rotary_emb_attn(v, position_ids)
|
|
|
if past_primary is not None: |
|
|
|
k = torch.cat([past_primary[0], k], dim=2) |
|
v = torch.cat([past_primary[1], v], dim=2) |
|
|
|
|
|
q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids) |
|
|
|
|
|
if use_cache: |
|
past_key_value['primary'] = (k.detach(), v.detach()) |
|
|
|
|
|
|
|
|
|
|
|
|
|
        # As above: only request SDPA's built-in causal mask when query and key lengths match,
        # so cached positions stay visible during single-token decoding.
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, is_causal=(q.size(2) == k.size(2)))
|
attn_output = attn_output.to(q.dtype) |
|
attn_output = attn_output.transpose(1, 2).contiguous().view(B, L, self.hidden_size_reduced) |
|
attn_output = self.norm1(attn_output) |
|
head_importances = self.ffn(attn_output) |
|
return head_importances, past_key_value |
|
|
|
def calculate_hit_metrics(estimated_importance: torch.Tensor, |
|
true_importance: torch.Tensor, |
|
top_k_ratio: float = 0.5) -> Tuple[float, float, float]: |
|
""" |
|
Calculate hit accuracy, mean, and max rank correlation between estimated and true importance tensors. |
|
We compute metrics along the last dimension of the input tensors. |
|
|
|
Shapes: |
|
- 4D token-importance: [B, H, L, L]. We slice the last query (index -1) => [B, H, L]. |
|
- 3D head-importance: [B, L, H]. We use all of it as-is => [B, L, H]. |
|
|
|
Args: |
|
estimated_importance (torch.Tensor): [B, H, L, L] or [B, L, H] |
|
true_importance (torch.Tensor): [B, H, L, L] or [B, L, H] |
|
top_k_ratio (float): Fraction of top-k elements to consider for hit accuracy (default=0.5). |
|
|
|
Returns: |
|
(hit_accuracy, mean_corr, max_corr): |
|
hit_accuracy (float): Intersection ratio of top-k sets (0..1). |
|
mean_corr (float): Average Spearman rank correlation over all [B, ...]. |
|
max_corr (float): Maximum Spearman rank correlation among all [B, ...]. |
|
""" |
|
|
|
|
|
if estimated_importance.dim() == 4: |
|
|
|
estimated_importance = estimated_importance[:, :, -1, :] |
|
true_importance = true_importance[:, :, -1, :] |
|
|
|
|
|
denom_for_hits = estimated_importance.size(0) * estimated_importance.size(1) |
|
elif estimated_importance.dim() == 3: |
|
|
|
|
|
denom_for_hits = estimated_importance.size(0) * estimated_importance.size(1) |
|
else: |
|
raise ValueError("Tensors must be either 4D [B,H,L,L] or 3D [B,L,H].") |
|
|
|
|
|
|
|
_, sorted_esti = torch.sort(estimated_importance, dim=-1, descending=True) |
|
_, sorted_true = torch.sort(true_importance, dim=-1, descending=True) |
|
|
|
|
|
    # Spearman's rho via 1 - 6 * sum(d^2) / (n * (n^2 - 1)). The differences are taken
    # between the two argsort index arrays; because argsort positions and ranks are
    # inverse permutations, the sum of squared differences matches the rank-based one
    # (assuming no ties).
    n = sorted_esti.shape[-1]
    d = sorted_esti.float() - sorted_true.float()
    d_squared = d ** 2
    sum_d_squared = d_squared.sum(dim=-1)
    rank_corr = 1 - (6 * sum_d_squared) / (n * (n**2 - 1))
|
|
|
mean_corr = rank_corr.mean().item() |
|
max_corr = rank_corr.max().item() |
|
|
|
|
|
top_k = max(1, int(n * top_k_ratio)) |
|
_, top_esti_indices = torch.topk(estimated_importance, top_k, dim=-1) |
|
_, top_true_indices = torch.topk(true_importance, top_k, dim=-1) |
|
|
|
|
|
|
|
|
|
matches = (top_esti_indices.unsqueeze(-1) == top_true_indices.unsqueeze(-2)) |
|
intersection = matches.any(dim=-1).sum(dim=-1) |
|
|
|
|
|
total_possible = top_k * denom_for_hits |
|
hit_accuracy = intersection.sum().item() / total_possible |
|
|
|
return hit_accuracy, mean_corr, max_corr |
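

# Hedged sanity-check sketch for calculate_hit_metrics (shapes below are arbitrary
# assumptions). Identical inputs give a hit accuracy of 1.0 and a rank correlation of 1.0;
# the noisy call shows the typical way the function compares predicted vs. true attention.
def _example_calculate_hit_metrics():
    torch.manual_seed(0)
    scores = torch.randn(2, 4, 8, 8)  # 4D token-importance layout [B, H, L, L]
    hit, mean_corr, max_corr = calculate_hit_metrics(scores, scores, top_k_ratio=0.5)
    assert abs(hit - 1.0) < 1e-6 and abs(mean_corr - 1.0) < 1e-6
    noisy = scores + 0.1 * torch.randn_like(scores)
    return calculate_hit_metrics(scores, noisy, top_k_ratio=0.5)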
|
|
|
|
|
def threshold_to_mask(unadj_importance_mask, perhead_thresholds, min_sparse_index, bsz, q_len, key_len): |
|
""" |
|
Create a mask tensor based on per-head thresholds, setting values below the threshold to -inf. |
|
|
|
Args: |
|
- unadj_importance_mask: torch.Tensor of shape [B, H, Lq, Lk]. |
|
- perhead_thresholds: torch.Tensor of shape [H], per-head thresholds. |
|
- min_sparse_index: Minimum index for sparsity; values below this index will not be masked. |
|
- bsz: Batch size. |
|
- q_len: Query length (Lq). |
|
- key_len: Key length (Lk). |
|
|
|
Returns: |
|
- mask_tensor: torch.Tensor of shape [B, H, Lq, Lk], with values below threshold as -inf. |
|
""" |
|
|
|
thresholds_broadcast = perhead_thresholds.view(1, -1, 1, 1) |
|
|
|
|
|
mask_tensor = torch.where( |
|
unadj_importance_mask >= thresholds_broadcast, |
|
torch.zeros_like(unadj_importance_mask), |
|
torch.full_like(unadj_importance_mask, float('-inf')) |
|
) |
|
|
|
|
|
mask_tensor[:, :, :, :min_sparse_index] = 0.0 |
|
|
|
return mask_tensor |
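

# Small illustrative call for threshold_to_mask (the shapes and the 0.5 threshold are
# arbitrary assumptions): scores at or above a head's threshold stay at 0.0 (kept),
# the rest become -inf, and the first min_sparse_index key positions are always kept.
def _example_threshold_to_mask():
    bsz, n_heads, q_len, key_len = 1, 2, 4, 6
    scores = torch.rand(bsz, n_heads, q_len, key_len)
    thresholds = torch.full((n_heads,), 0.5)
    mask = threshold_to_mask(scores, thresholds, min_sparse_index=1,
                             bsz=bsz, q_len=q_len, key_len=key_len)
    return mask  # [1, 2, 4, 6] additive mask: 0.0 where kept, -inf where dropped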
|
|
|
class SlidingWindowCache: |
|
def __init__(self, max_seq_len, sliding_window, device): |
|
self.sliding_window = sliding_window |
|
self.device = device |
|
if sliding_window is None: |
|
self.max_seq_len = 0 |
|
self.window = None |
|
else: |
|
self.max_seq_len = max_seq_len |
|
self.window = self._create_window(self.max_seq_len) |
|
|
|
def _create_window(self, seq_len): |
|
idx = torch.arange(seq_len, device=self.device) |
|
query = idx.unsqueeze(1) |
|
key = idx.unsqueeze(0) |
|
win = (key >= (query - self.sliding_window + 1)) & (key <= query) |
|
return win.unsqueeze(0).unsqueeze(0) |
|
|
|
def get_window(self, q_len, key_len): |
|
if self.sliding_window is None: |
|
return None |
|
req = max(q_len, key_len) |
|
if req > self.max_seq_len: |
|
self.max_seq_len = req |
|
self.window = self._create_window(self.max_seq_len) |
|
return self.window[:, :, :q_len, :key_len] |
|
|
|
def enforce_sliding_window(mask_tensor, window): |
|
if window is None: |
|
return mask_tensor |
|
return mask_tensor.masked_fill(window, 0.0) |
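

# Sketch of how the sliding-window helpers compose (window size and shapes are arbitrary
# assumptions): the boolean window marks, for each query, itself and the preceding
# sliding_window - 1 keys, and enforce_sliding_window forces those positions back to 0.0
# (kept) in an additive mask.
def _example_sliding_window():
    cache = SlidingWindowCache(max_seq_len=16, sliding_window=4, device=torch.device("cpu"))
    window = cache.get_window(q_len=8, key_len=8)        # [1, 1, 8, 8] bool
    mask = torch.full((1, 1, 8, 8), float("-inf"))       # start with everything masked
    return enforce_sliding_window(mask, window)          # trailing window re-opened to 0.0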
|
|
|
|
|
def sorted_index_to_mask( |
|
sorted_indices, |
|
attention_mask, |
|
min_sparse_index, |
|
bsz, |
|
q_len, |
|
key_len, |
|
sparse_aggression, |
|
sliding_window=None |
|
): |
|
""" |
|
sorted_indices: [B, H, q_len, key_len] |
|
attention_mask: [1, 1, q_len, key_len] (True = keep, False = mask out, or vice versa) |
|
min_sparse_index: guaranteed front region to keep |
|
sliding_window: guaranteed trailing region (for each query) to keep |
|
sparse_aggression: float in [0,1], fraction of keys to drop or keep |
|
""" |
|
device = sorted_indices.device |
|
dtype = sorted_indices.dtype |
|
|
|
|
|
    # Per-query key budget: a query at (1-indexed) position p may keep ceil(p * sparse_aggression) keys.
    if q_len == 1:
        # Single-token decode step: the lone query sits at the end of the key sequence.
        query_positions = torch.arange(q_len, device=device).view(1, 1, q_len, 1).float()
        query_positions[0] = key_len + 1
    else:
        query_positions = torch.arange(q_len, device=device).view(1, 1, q_len, 1).float() + 1.0
    K_original = torch.ceil(query_positions * sparse_aggression).long()
    K_original = torch.clamp(K_original, max=key_len)
|
|
|
|
|
guaranteed = min_sparse_index |
|
if sliding_window is not None: |
|
guaranteed += sliding_window |
|
|
|
K_adjusted = K_original - guaranteed |
|
|
|
K_adjusted = torch.clamp(K_adjusted, min=0, max=key_len) |
|
|
|
|
|
attention_mask_expanded = attention_mask.expand(bsz, -1, -1, -1) |
|
attention_mask_expanded = attention_mask_expanded.expand(-1, sorted_indices.size(1), -1, -1) |
|
|
|
attention_mask_expanded = (~attention_mask_expanded.bool()).int() |
|
|
|
|
|
gathered_mask = torch.gather(attention_mask_expanded, dim=-1, index=sorted_indices) |
|
|
|
|
|
gathered_mask_float = gathered_mask.float() |
|
cum_sum = torch.cumsum(gathered_mask_float, dim=-1) |
|
|
|
|
|
|
|
K_broadcast = K_adjusted.view(1, 1, q_len, 1).expand_as(cum_sum) |
|
selected_mask = (cum_sum <= K_broadcast) |
|
|
|
|
|
mask_tensor = torch.full_like(attention_mask_expanded.float(), float('-inf')) |
|
|
|
|
|
scatter_values = torch.zeros_like(gathered_mask_float) |
|
scatter_values = scatter_values.masked_fill(~selected_mask, float('-inf')) |
|
mask_tensor.scatter_(-1, sorted_indices, scatter_values) |
|
|
|
|
|
mask_tensor[:, :, :, :min_sparse_index] = 0.0 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return mask_tensor |
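

# End-to-end sketch for sorted_index_to_mask under assumed shapes: an additive causal
# mask (0.0 = visible, -inf = future) plus random importance scores produce a per-query
# top-K additive sparsity mask, with the first min_sparse_index keys and the sliding
# window always kept.
def _example_sorted_index_to_mask():
    bsz, n_heads, q_len, key_len = 1, 2, 8, 8
    causal = torch.triu(torch.ones(q_len, key_len, dtype=torch.bool), diagonal=1)
    attention_mask = torch.zeros(1, 1, q_len, key_len).masked_fill(causal, float("-inf"))
    scores = torch.rand(bsz, n_heads, q_len, key_len)
    sorted_indices = torch.argsort(scores, dim=-1, descending=True)
    return sorted_index_to_mask(
        sorted_indices, attention_mask, min_sparse_index=2,
        bsz=bsz, q_len=q_len, key_len=key_len,
        sparse_aggression=0.5, sliding_window=2,
    )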
|
|
|
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): |
|
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" |
|
|
|
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0, config=None): |
|
self.scaling_factor = scaling_factor |
|
super().__init__(config) |
|
|
|
def _set_cos_sin_cache(self, seq_len, device, dtype): |
|
self.max_seq_len_cached = seq_len |
|
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) |
|
t = t / self.scaling_factor |
|
|
|
freqs = torch.einsum("i,j->ij", t, self.inv_freq) |
|
|
|
emb = torch.cat((freqs, freqs), dim=-1) |
|
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) |
|
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) |
|
|
|
|
|
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): |
|
"""LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" |
|
|
|
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0, config=None): |
|
self.scaling_factor = scaling_factor |
|
super().__init__(config) |
|
|
|
def _set_cos_sin_cache(self, seq_len, device, dtype): |
|
self.max_seq_len_cached = seq_len |
|
|
|
if seq_len > self.max_position_embeddings: |
|
base = self.base * ( |
|
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) |
|
) ** (self.dim / (self.dim - 2)) |
|
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) |
|
self.register_buffer("inv_freq", inv_freq, persistent=False) |
|
|
|
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) |
|
|
|
freqs = torch.einsum("i,j->ij", t, self.inv_freq) |
|
|
|
emb = torch.cat((freqs, freqs), dim=-1) |
|
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) |
|
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) |
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: |
|
""" |
|
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, |
|
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) |
|
""" |
|
batch, num_key_value_heads, slen, head_dim = hidden_states.shape |
|
if n_rep == 1: |
|
return hidden_states |
|
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) |
|
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) |
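

# Tiny shape check for repeat_kv (values are arbitrary): 2 KV heads repeated 4x yield
# the 8 query heads expected under grouped-query attention.
def _example_repeat_kv():
    kv = torch.randn(1, 2, 5, 16)              # [batch, num_kv_heads, seq_len, head_dim]
    out = repeat_kv(kv, n_rep=4)
    assert out.shape == (1, 8, 5, 16)
    return out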
|
|
|
|
|
class LlamaAttentionExperimental(nn.Module): |
|
def __init__(self, config: LlamaConfig, producer=None, layer_idx=0): |
|
super().__init__() |
|
self.config = config |
|
self.hidden_size = config.hidden_size |
|
self.num_hidden_layers = config.num_hidden_layers |
|
self.num_heads = config.num_attention_heads |
|
self.head_dim = self.hidden_size // self.num_heads |
|
self.num_key_value_heads = config.num_key_value_heads |
|
self.num_key_value_groups = self.num_heads // self.num_key_value_heads |
|
self.max_position_embeddings = config.max_position_embeddings |
|
self.rope_theta = config.rope_theta |
|
self.inference_mode = False |
|
self.producer = producer |
|
self.layer_idx = layer_idx |
|
self.token_sparse_method = None |
|
self.sparse_aggression = None |
|
self.stream_llm_start_size = None |
|
self.dDash = None |
|
self.intdim = None |
|
self.attn_reduce_factor = None |
|
self.head_attn_reduce_factor = None |
|
self.effective_sparsity = None |
|
self.min_sparse_index = None |
|
self.pred_hid_size = self.hidden_size |
|
self.num_tok_per_page = None |
|
self.calc_hitrates = False |
|
self.flash_attn = False |
|
self.train_headpredictor = False |
|
self.calibrate_thresholds = False |
|
self.test_with_thresholds = False |
|
self.old_predictor = None |
|
|
|
if self.layer_idx > 0: |
|
self.mseloss = MSELoss(reduction='none') |
|
self.msemagn_loss = None |
|
self.headmseloss = MSELoss(reduction='none') |
|
self.headmsemagn_loss = None |
|
|
|
if self.producer is None: |
|
self.q_importance = None |
|
self.k_importance = None |
|
self.head_importances = None |
|
self.actmagn_masklist = {} |
|
self.available_tokens = {} |
|
|
|
|
|
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) |
|
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) |
|
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) |
|
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) |
|
self._init_rope() |
|
|
|
def update_predictor(self): |
|
self.sparse_token_predictor = TokenImportancePredictorAttentive( |
|
self.config, self.pred_hid_size, self.num_heads, self.num_layers_pred, dropout=0.1, dDash = self.dDash, \ |
|
intdim = self.intdim, attn_reduce_factor=self.attn_reduce_factor |
|
).to('cuda:0') |
|
self.sparse_token_predictor.flash_attn = self.flash_attn |
|
if self.train_headpredictor: |
|
self.sparse_head_predictor = HeadImportancePredictor( |
|
self.config, self.pred_hid_size, self.num_heads, self.num_layers_pred, dropout=0.1, dDash = self.dDash, \ |
|
intdim = self.intdim, attn_reduce_factor=self.head_attn_reduce_factor |
|
).to('cuda:0') |
|
self.sparse_head_predictor.flash_attn = self.flash_attn |
|
|
|
def set_token_sparsity(self): |
|
assert self.token_sparse_method is not None, "Set token sparse method first!" |
|
if self.token_sparse_method is not None: |
|
try: |
|
mname = self.config._name_or_path.split("/")[-1] |
|
read_path = f"threshold_calibs/{mname}/{self.token_sparse_method}.pkl" |
|
threshold_model_dictionary = torch.load(read_path) |
|
self.tok_calibration_set = threshold_model_dictionary |
|
except: |
|
pass |
|
if self.token_sparse_method == "LazyLLM": |
|
if self.layer_idx <= 9: |
|
self.sparse_aggression = 1 |
|
elif self.layer_idx <= 19: |
|
self.sparse_aggression = 0.7 |
|
elif self.layer_idx <= 28: |
|
self.sparse_aggression = 0.4 |
|
else: |
|
self.sparse_aggression = 0.1 |
|
elif "fixed" in self.token_sparse_method: |
|
if self.layer_idx == 0: |
|
self.sparse_aggression = 1 |
|
else: |
|
self.sparse_aggression = 1 - float(self.token_sparse_method.split("_")[1].split("pc")[0])/100. |
|
elif "progressive" in self.token_sparse_method: |
|
pc_drop = float(self.token_sparse_method.split("_")[1].split("pc")[0])/100. |
|
self.sparse_aggression = (1 - pc_drop) ** (self.layer_idx) |
|
else: |
|
raise ValueError(f"Unknown token sparsity method {self.token_sparse_method}") |
|
|
|
|
|
def _init_rope(self): |
|
if self.config.rope_scaling is None: |
|
self.rotary_emb = LlamaRotaryEmbedding( |
|
self.config |
|
) |
|
else: |
|
scaling_type = self.config.rope_scaling.get("type") or self.config.rope_scaling.get("rope_type") |
|
scaling_factor = self.config.rope_scaling["factor"] |
|
if scaling_type == "linear" or scaling_type == 'llama3': |
|
self.rotary_emb = LlamaLinearScalingRotaryEmbedding( |
|
self.head_dim, |
|
max_position_embeddings=self.max_position_embeddings, |
|
scaling_factor=scaling_factor, |
|
base=self.rope_theta, |
|
config=self.config |
|
) |
|
elif scaling_type == "dynamic": |
|
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( |
|
self.head_dim, |
|
max_position_embeddings=self.max_position_embeddings, |
|
scaling_factor=scaling_factor, |
|
base=self.rope_theta, |
|
config=self.config |
|
) |
|
else: |
|
raise ValueError(f"Unknown RoPE scaling type {scaling_type}") |
|
|
|
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): |
|
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_value: Optional[Union[DynamicCache, PredictorDynamicCache]] = None, |
|
output_attentions: bool = False, |
|
use_cache: bool = False, |
|
padding_mask: Optional[torch.LongTensor] = None, |
|
cache_position: Optional[torch.LongTensor] = None, |
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
|
**kwargs, |
|
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
bsz, q_len, _ = hidden_states.size() |
|
Ltrack = hidden_states.size(1) |
|
|
|
if self.config.pretraining_tp > 1: |
|
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp |
|
query_slices = self.q_proj.weight.split( |
|
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 |
|
) |
|
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) |
|
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) |
|
|
|
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] |
|
query_states = torch.cat(query_states, dim=-1) |
|
|
|
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] |
|
key_states = torch.cat(key_states, dim=-1) |
|
|
|
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] |
|
value_states = torch.cat(value_states, dim=-1) |
|
else: |
|
query_states = self.q_proj(hidden_states) |
|
key_states = self.k_proj(hidden_states) |
|
value_states = self.v_proj(hidden_states) |
|
|
|
evalmode = self.eval_llm_mode |
|
num_tokens_to_keep = int(q_len * self.sparse_aggression) |
|
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
|
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
|
|
|
|
cos, sin = position_embeddings |
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) |
|
|
|
if use_cache: |
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx) |
|
|
|
kv_seq_len = key_states.shape[-2] |
|
final_mask = None |
|
|
|
key_states = repeat_kv(key_states, self.num_key_value_groups) |
|
value_states = repeat_kv(value_states, self.num_key_value_groups) |
|
|
|
key_len = key_states.size(2) |
|
bsz, q_len = query_states.size(0), query_states.size(2) |
|
|
|
if attention_mask is None: |
|
|
|
causal_mask_2d = torch.ones(q_len, kv_seq_len, |
|
device=hidden_states.device, |
|
dtype=torch.bool).triu(diagonal=1) |
|
|
|
causal_mask_4d = causal_mask_2d.unsqueeze(0).expand(bsz, 1, q_len, kv_seq_len) |
|
|
|
attention_mask = torch.full_like(causal_mask_4d, 0, dtype=hidden_states.dtype) |
|
if q_len != 1: |
|
attention_mask = attention_mask.masked_fill(causal_mask_4d, float("-inf")) |
|
|
|
if self.inference_mode: |
|
min_sparse_index = self.min_sparse_index |
|
with torch.no_grad(): |
|
if evalmode == "ExpPred": |
|
if self.layer_idx > 0: |
|
q_importance_tensor = self.producer.q_importance[:, self.layer_idx % self.producer_frequency, :, :].float().to(query_states.device) |
|
k_importance_tensor = self.producer.k_importance[:, self.layer_idx % self.producer_frequency, :, :].float().to(key_states.device) |
|
importance_mask = torch.bmm(q_importance_tensor, k_importance_tensor.transpose(-2, -1)) / math.sqrt(self.dDash) |
|
importance_mask = importance_mask.view(bsz, self.num_heads, q_len, key_len) |
|
attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(self.head_dim) |
|
if self.calc_hitrates: |
|
self.tok_hit_acc, self.tok_mean_rank_corr, self.tok_max_rank_corr = calculate_hit_metrics( |
|
estimated_importance=importance_mask, |
|
true_importance=attn_weights, |
|
top_k_ratio=0.5 |
|
) |
|
if self.calibrate_thresholds: |
|
|
|
unadj_importance_mask = importance_mask.clone() |
|
importance_mask = torch.softmax(importance_mask + attention_mask, dim=-1) |
|
sorted_indices = torch.argsort(importance_mask, dim=-1, descending=True) |
|
sorted_indices = sorted_indices[:, :, -q_len:, :] |
|
sorted_values, sorted_ix = torch.sort(importance_mask, dim=-1) |
|
sorted_true_values, _ = torch.sort(torch.gather(unadj_importance_mask, dim=-1, index=sorted_ix), dim=-1) |
|
true_thresholds = sorted_true_values[:, :, :, int(importance_mask.size(-1) * self.sparse_aggression)] |
|
thresholds = sorted_values[:, :, :, int(importance_mask.size(-1) * self.sparse_aggression)] |
|
self.true_threshmean = true_thresholds |
|
self.threshmean = thresholds |
|
if self.test_with_thresholds: |
|
unadj_importance_mask = importance_mask.clone() |
|
perhead_thresholds = self.tok_calibration_set[self.layer_idx - 1].to(unadj_importance_mask.device) |
|
mask_tensor = threshold_to_mask(unadj_importance_mask, perhead_thresholds, min_sparse_index, bsz, q_len, key_len) |
|
else: |
|
importance_mask = torch.softmax(importance_mask + attention_mask, dim=-1) |
|
sorted_indices = torch.argsort(importance_mask, dim=-1, descending=True) |
|
sorted_indices = sorted_indices[:, :, -q_len:, :] |
|
mask_tensor = sorted_index_to_mask(sorted_indices, attention_mask, min_sparse_index, bsz, q_len, key_len, self.sparse_aggression, self.sliding_window) |
|
|
|
if self.sliding_window is not None: |
|
if not hasattr(self, "window_cache"): |
|
self.window_cache = SlidingWindowCache(max_seq_len=1024, |
|
sliding_window=self.sliding_window, |
|
device=mask_tensor.device) |
|
window = self.window_cache.get_window(q_len, key_len) |
|
mask_tensor = enforce_sliding_window(mask_tensor, window) |
|
final_mask = mask_tensor |
|
|
|
self.final_mask_investigate = final_mask |
|
attn_weights = attn_weights + mask_tensor + attention_mask |
|
else: |
|
attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(self.head_dim) |
|
attn_weights = attn_weights + attention_mask |
|
else: |
|
raise ValueError(f"Unknown eval mode {evalmode}") |
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype) |
|
attn_output = torch.matmul(attn_weights, value_states) |
|
|
|
else: |
|
attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(self.head_dim) |
|
if self.layer_idx > 0: |
|
q_importance_tensor = self.producer.q_importance[:, self.layer_idx % self.producer_frequency, :, :].float().to(query_states.device) |
|
k_importance_tensor = self.producer.k_importance[:, self.layer_idx % self.producer_frequency, :, :].float().to(key_states.device) |
|
importance_mask = torch.bmm(q_importance_tensor, k_importance_tensor.transpose(-2, -1)) / math.sqrt(self.dDash) |
|
importance_mask = importance_mask.view(bsz, self.num_heads, q_len, key_len) |
|
|
|
if self.lookahead == 0: |
|
self.msemagn_loss = self.mseloss(attn_weights, importance_mask) |
|
else: |
|
self.msemagn_loss = self.mseloss(attn_weights[:, :, self.lookahead:, :], importance_mask[:, :, :-self.lookahead, :]) |
|
self.msemagn_loss = (self.msemagn_loss).mean(dim=(-1, -2)) |
|
self.msemagn_loss = self.msemagn_loss.mean() |
|
|
|
if self.calc_hitrates: |
|
self.tok_hit_acc, self.tok_mean_rank_corr, self.tok_max_rank_corr = calculate_hit_metrics( |
|
estimated_importance=importance_mask, |
|
true_importance=attn_weights, |
|
top_k_ratio=0.5 |
|
) |
|
|
|
if attention_mask is not None: |
|
attn_weights = attn_weights + attention_mask |
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype) |
|
attn_output = torch.matmul(attn_weights, value_states) |
|
|
|
if self.layer_idx > 0 and self.train_headpredictor: |
|
head_importance_tensor = self.producer.head_importances[:, :, :, self.layer_idx % self.producer_frequency].float().to(attn_output.device) |
|
attn_head_weights = attn_output.mean(dim=-1).permute(0, 2, 1) |
|
self.headmsemagn_loss = self.headmseloss(attn_head_weights, head_importance_tensor).mean() |
|
|
|
if self.calc_hitrates: |
|
self.head_hit_acc, self.head_mean_rank_corr, self.head_max_rank_corr = calculate_hit_metrics( |
|
estimated_importance=head_importance_tensor, |
|
true_importance=attn_head_weights, |
|
top_k_ratio=0.5 |
|
) |
|
else: |
|
self.headmsemagn_loss = 0 |
|
if self.calc_hitrates: |
|
self.head_hit_acc, self.head_mean_rank_corr, self.head_max_rank_corr = 0, 0, 0 |
|
|
|
|
|
        checkeverytime = getattr(self, 'test_with_thresholds', False)
|
if final_mask is not None: |
|
if self.effective_sparsity is None or checkeverytime: |
|
true_mask = final_mask + attention_mask |
|
num_deact = true_mask.bool().sum(dim=-1) |
|
causally_deact = (attention_mask.bool()).sum(dim=-1).expand_as(num_deact) |
|
additional_deact = (num_deact - causally_deact) |
|
num_active = (~attention_mask.bool()).sum(dim=-1).expand_as(num_deact) |
|
effective_sparsity = 100 * (additional_deact.float() / num_active.float()).mean().item() |
|
self.effective_sparsity = effective_sparsity |
|
print("Effective Sparsity:", effective_sparsity, "%\t Sequence Length:", q_len) |
|
if self.layer_idx == 0: |
|
if self.effective_sparsity is None: |
|
self.effective_sparsity = 0.0 |
|
|
|
attn_output = attn_output.transpose(1, 2).contiguous() |
|
attn_output = attn_output.view(bsz, -1, self.hidden_size) |
|
|
|
if self.config.pretraining_tp > 1: |
|
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) |
|
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) |
|
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) |
|
else: |
|
attn_output = self.o_proj(attn_output) |
|
|
|
if self.producer is None: |
|
try: |
|
q_importance, k_importance = self.sparse_token_predictor( |
|
hidden_states, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
past_key_value=past_key_value, |
|
use_cache=use_cache, |
|
layer_idx=self.layer_idx, |
|
) |
|
                if self.train_headpredictor:
                    # The head predictor keeps its own lightweight dict cache; it is not
                    # threaded through PredictorDynamicCache, so start it fresh here.
                    head_importances, past_key_value_hp = self.sparse_head_predictor(
                        hidden_states,
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                        past_key_value=None,
                        use_cache=use_cache
                    )
|
head_importances = head_importances.view(bsz, q_len, self.num_heads, self.num_hidden_layers) |
|
q_len = attn_output.size(1) |
|
k_len = k_importance.size(-1) |
|
            except Exception:
                print(traceback.format_exc())
                raise
|
|
|
self.q_importance = q_importance |
|
self.k_importance = k_importance |
|
|
|
if self.train_headpredictor: |
|
if self.head_importances is None: |
|
self.head_importances = head_importances |
|
else: |
|
self.head_importances = torch.cat([self.head_importances, head_importances], dim=1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not output_attentions: |
|
attn_weights = None |
|
return attn_output, attn_weights |
|
|
|
def convert_kvcache_experimental(model, config, producer_frequency): |
|
producer_layer = None |
|
producer_layer_device = None |
|
layer_counter = {'idx': 0} |
|
|
|
def recurse_convert(parent_module): |
|
nonlocal producer_layer |
|
nonlocal producer_layer_device |
|
for name, module in parent_module._modules.items(): |
|
if len(list(module.children())) > 0: |
|
recurse_convert(module) |
|
if isinstance(module, LlamaAttention): |
|
device = next(module.parameters()).device |
|
dtype = next(module.parameters()).dtype |
|
if layer_counter['idx'] % producer_frequency == 0: |
|
new_module = LlamaAttentionExperimental(config).to(dtype).to(device) |
|
producer_layer = new_module |
|
producer_layer_device = device |
|
else: |
|
new_module = LlamaAttentionExperimental( |
|
config, |
|
producer=producer_layer, |
|
layer_idx=layer_counter['idx'] |
|
).to(dtype).to(device) |
|
new_module.load_state_dict(module.state_dict(), strict=False) |
|
is_producer = layer_counter['idx'] % producer_frequency == 0 |
|
if is_producer: |
|
print(f"Converted Producer layer '{name}' to LlamaAttentionExperimental at layer index {layer_counter['idx']}") |
|
else: |
|
print(f"Converted layer '{name}' to LlamaAttentionExperimental at layer index {layer_counter['idx']}") |
|
parent_module._modules[name] = new_module |
|
layer_counter['idx'] += 1 |
|
recurse_convert(model) |
|
producer_layer = producer_layer.to(producer_layer_device) |
|
return model |
|
|
|
|
|
|
|
|
|
|
|
class LlamaButlerConfig(LlamaConfig): |
|
""" |
|
Extends HF's LlamaConfig to hold optional extra parameters for the "Butler" logic. |
|
You can store your custom attributes here, so they can be serialized in config.json. |
|
""" |
|
|
|
model_type = "llama_butler" |
|
|
|
def __init__( |
|
self, |
|
eval_llm_mode="ExpPred", |
|
token_sparse_method="fixed_50pc", |
|
producer_frequency=8, |
|
dDash=16, |
|
attn_reduce_factor=4, |
|
head_attn_reduce_factor=4, |
|
intdim=256, |
|
flash_attn=False, |
|
train_headpredictor=False, |
|
min_sparse_index=5, |
|
lookahead=0, |
|
sliding_window=None, |
|
**kwargs |
|
): |
|
super().__init__(**kwargs) |
|
self.eval_llm_mode = eval_llm_mode |
|
self.token_sparse_method = token_sparse_method |
|
self.producer_frequency = producer_frequency |
|
self.dDash = dDash |
|
self.attn_reduce_factor = attn_reduce_factor |
|
self.head_attn_reduce_factor = head_attn_reduce_factor |
|
self.intdim = intdim |
|
self.flash_attn = flash_attn |
|
self.train_headpredictor = train_headpredictor |
|
self.min_sparse_index = min_sparse_index |
|
self.lookahead = lookahead |
|
self.sliding_window = sliding_window |
|
|
|
|
|
|
|
|
|
|
|
class LlamaButlerForCausalLM(LlamaForCausalLM): |
|
""" |
|
A subclass of HF's LlamaForCausalLM that: |
|
- Patches each LlamaAttention to your LlamaAttentionExperimental |
|
- Sets specialized attributes (eval_llm_mode, etc.) |
|
- Overrides _prepare_cache_for_generation to inject PredictorDynamicCache |
|
""" |
|
|
|
|
|
config_class = LlamaButlerConfig |
|
|
|
def __init__(self, config: LlamaButlerConfig): |
|
super().__init__(config) |
|
""" |
|
HF's LlamaForCausalLM initializes: |
|
self.model = LlamaModel(config) |
|
self.lm_head = nn.Linear(...) |
|
""" |
|
|
|
|
|
self.model = convert_kvcache_experimental( |
|
self.model, |
|
config, |
|
config.producer_frequency |
|
) |
|
|
|
|
|
for module in self.model.modules(): |
|
if module.__class__.__name__.endswith("AttentionExperimental"): |
|
|
|
module.eval_llm_mode = config.eval_llm_mode |
|
module.token_sparse_method = config.token_sparse_method |
|
module.set_token_sparsity() |
|
|
|
module.producer_frequency = config.producer_frequency |
|
module.dDash = config.dDash |
|
module.attn_reduce_factor = config.attn_reduce_factor |
|
module.head_attn_reduce_factor = config.head_attn_reduce_factor |
|
module.intdim = config.intdim |
|
module.flash_attn = config.flash_attn |
|
module.train_headpredictor = config.train_headpredictor |
|
module.min_sparse_index = config.min_sparse_index |
|
module.lookahead = config.lookahead |
|
module.sliding_window = config.sliding_window |
|
module.num_layers_pred = config.producer_frequency |
|
|
|
|
|
if hasattr(module, "layer_idx") and (module.layer_idx % config.producer_frequency == 0): |
|
module.update_predictor() |
|
|
|
|
|
if config.eval_llm_mode in ["ExpPred", "ReplAttn"]: |
|
self._prepare_cache_for_generation = self._patched_prepare_cache_for_generation.__get__( |
|
self, self.__class__ |
|
) |
|
|
|
|
|
|
|
|
|
def _patched_prepare_cache_for_generation( |
|
self, |
|
generation_config: GenerationConfig, |
|
model_kwargs: Dict, |
|
*args, |
|
**kwargs |
|
): |
|
""" |
|
This override injects a PredictorDynamicCache |
|
in place of the standard 'past_key_values'. |
|
""" |
|
if "past_key_values" not in model_kwargs or model_kwargs["past_key_values"] is None: |
|
model_kwargs["past_key_values"] = PredictorDynamicCache() |
|
return model_kwargs |
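

# Hedged loading sketch (not part of the module's own API): register the Butler classes
# with the Auto* factories so a checkpoint whose config.json carries model_type
# "llama_butler" resolves to these classes. The checkpoint path argument is a placeholder
# assumption supplied by the caller.
def _example_load_butler(checkpoint_path: str):
    from transformers import AutoConfig, AutoModelForCausalLM

    AutoConfig.register("llama_butler", LlamaButlerConfig)
    AutoModelForCausalLM.register(LlamaButlerConfig, LlamaButlerForCausalLM)
    return AutoModelForCausalLM.from_pretrained(checkpoint_path)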