# ShortGpt / short_gpt/short_llama.py
# Uploaded by deepnet using huggingface_hub (commit 821537b, verified)
from typing import List, Optional
import numpy as np
import torch
from llama import Llama, Transformer
from metrics import *
def sample_top_p(probs: torch.Tensor, p: float):
    """
    Draw one token per row using top-p (nucleus) sampling.

    Args:
        probs (torch.Tensor): Probabilities over the vocabulary along the last dimension.
        p (float): Cumulative-probability threshold that defines the nucleus.

    Returns:
        torch.Tensor: Indices (into the original last dimension) of the sampled tokens,
        with a trailing dimension of size 1.

    Note:
        Tokens are ranked by probability. A token is kept only while the probability mass
        ranked strictly before it does not exceed p, so the highest-probability token is
        always kept. The surviving probabilities are renormalized before sampling.
        The input tensor is not modified; only the sorted copy is.
    """
    sorted_probs, sorted_indices = probs.sort(dim=-1, descending=True)
    cumulative = sorted_probs.cumsum(dim=-1)
    # Mass strictly preceding each ranked token; drop tokens once it exceeds p.
    outside_nucleus = (cumulative - sorted_probs) > p
    sorted_probs[outside_nucleus] = 0.0
    sorted_probs.div_(sorted_probs.sum(dim=-1, keepdim=True))
    choice = torch.multinomial(sorted_probs, num_samples=1)
    # Map the position within the sorted order back to the vocabulary index.
    return sorted_indices.gather(-1, choice)
class TransformerWrapper(Transformer):
    """
    A ``Transformer`` whose forward pass can also return every intermediate hidden state.

    The wrapper does not build new modules. It adopts the wrapped model's attributes
    (parameters, submodules, buffers and config) by copying its instance ``__dict__``.
    """

    def __init__(self, model):
        # Shallow copy: the wrapper shares the wrapped model's submodules and tensors
        # instead of running Transformer.__init__ and allocating fresh weights.
        self.__dict__ = model.__dict__.copy()

    @torch.inference_mode()
    def forward(
        self,
        tokens: torch.Tensor,
        start_pos: int,
        return_hiddens: Optional[bool] = False):
        """
        Run the transformer over ``tokens``, optionally collecting per-block hidden states.

        Args:
            tokens (torch.Tensor): Token indices, unpacked as (batch, seqlen).
            start_pos (int): Position of the first token, used for the rotary table slice
                and for attention caching.
            (Optional) return_hiddens (bool): Also return hidden states. Defaults to False.

        Returns:
            torch.Tensor: Float logits from the output projection after the final norm.
            (Optional) List[torch.Tensor]: ``[embeddings, block_0_out, block_1_out, ...]``.
                These are taken before the final norm, and returned only when
                ``return_hiddens`` is truthy.
        """
        _, seqlen = tokens.shape
        device = tokens.device
        h = self.tok_embeddings(tokens)
        # Keep the cached rotary table on the same device as the activations.
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        mask = None
        if seqlen > 1:
            # Causal part: -inf strictly above the diagonal of the new tokens' scores.
            causal = torch.triu(
                torch.full((seqlen, seqlen), float("-inf"), device=device),
                diagonal=1,
            )
            # With key-value caching the score matrix is (seqlen, start_pos + seqlen).
            # The start_pos already-cached positions are fully visible to every new token.
            visible_cache = torch.zeros((seqlen, start_pos), device=device)
            mask = torch.hstack([visible_cache, causal]).type_as(h)

        hiddens = [h]
        for block in self.layers:
            h = block(h, start_pos, freqs_cis, mask)
            if return_hiddens:
                hiddens.append(h)

        output = self.output(self.norm(h)).float()
        return (output, hiddens) if return_hiddens else output
class ShortLlama():
    """
    ShortGPT-style layer pruning for a ``Llama`` generator.

    The constructor wraps ``llama.model`` in ``TransformerWrapper`` so per-layer hidden
    states can be collected. ``eval_importance`` accumulates a per-layer importance score
    into ``self.importances``. ``remove_layers`` deletes layers from the wrapped model
    in place.
    """

    def __init__(self, llama: Llama, n_prune_layers: Optional[int] = None):
        """
        Args:
            llama (Llama): Generator whose ``model`` attribute is replaced by a wrapper.
            (Optional) n_prune_layers (int): Number of layers to prune. It is required for
                angular importance and for automatic layer selection in ``remove_layers``.
        """
        checkpoint = llama.model.state_dict()
        llama.model = TransformerWrapper(llama.model)  # wrap transformer to collect hidden states
        llama.model.load_state_dict(checkpoint, strict=False)
        self.llama = llama
        self.n_prune_layers = n_prune_layers
        self.importances = [0 for _ in self.llama.model.layers]  # layer-wise importance scores

    def remove_layers(
        self,
        layers_to_remove: Optional[List[int]] = None,
        angular: Optional[bool] = False
    ):
        """
        Delete transformer blocks from the wrapped model in place.

        Args:
            (Optional) layers_to_remove (List[int]): Explicit indices of layers to delete.
                If this is empty or None (and ``angular`` is False), the ``n_prune_layers``
                layers with the lowest importance are chosen instead.
            (Optional) angular (bool): If True, ignore ``layers_to_remove`` and delete the
                contiguous run of ``n_prune_layers`` layers whose start index has the
                lowest importance. Defaults to False.

        Returns:
            List[int]: The indices that were requested for removal. Returns [] if one
            of them did not exist, which can happen if this method was already called.
        """
        if layers_to_remove is None:
            layers_to_remove = []
        if angular:
            # importances is pre-filled with zeros, so a non-empty list alone does not
            # prove eval_importance() has run.
            assert any(self.importances), "Need to compute importances with eval_importance()"
            assert self.n_prune_layers, "Need number of layers to prune, set `n_prune_layers`"
            # With angular importance, importances[i] scores the span of layers
            # [i, i + n_prune_layers). Only the first len - n_prune_layers + 1 start indices
            # are valid. (The old `[:-n + 1]` slice became `[:0]`, i.e. empty, for n == 1.)
            n_starts = len(self.importances) - self.n_prune_layers + 1
            assert n_starts > 0, "`n_prune_layers` exceeds the number of layers"
            start_layer = int(np.argsort(np.array(self.importances[:n_starts]))[0])
            layers_to_remove = list(range(start_layer, start_layer + self.n_prune_layers))
        elif not layers_to_remove and self.n_prune_layers:
            assert any(self.importances), "Need to compute importances with eval_importance()"
            layers_to_remove = np.argsort(np.array(self.importances))[:self.n_prune_layers].tolist()

        # Delete in descending order so earlier deletions do not shift pending indices.
        # Duplicates are dropped; otherwise one repeated index would delete two layers.
        for layer_idx in sorted(set(layers_to_remove), reverse=True):
            try:
                del self.llama.model.layers[layer_idx]
            except IndexError:
                print(f"layer {layer_idx} does not exist, function may have already been called")
                return []
        return layers_to_remove

    def compute_bi(self, hiddens: List[torch.Tensor], angular: bool):
        """
        Add block-influence scores computed from ``hiddens`` into ``self.importances``.

        Args:
            hiddens (List[torch.Tensor]): Embedding output followed by each layer's output,
                as returned by ``TransformerWrapper.forward(..., return_hiddens=True)``.
            angular (bool): If True, compare hidden states ``n_prune_layers`` layers apart
                using only the last sequence position. Otherwise compare consecutive
                hidden states over all positions.
        """
        n = 1
        if angular:
            assert self.n_prune_layers is not None, "Set number of layers to prune to use angular importance"
            n = self.n_prune_layers

        for i in range(len(hiddens) - n):
            in_hidden = hiddens[i]
            out_hidden = hiddens[i + n]
            if angular:
                # use only last token for angular distance as described in section 3.2
                # https://arxiv.org/pdf/2403.17887.pdf
                # (axis 1 is the sequence axis, since tokens are unpacked as (bsz, seqlen))
                in_hidden = in_hidden[:, -1:]
                out_hidden = out_hidden[:, -1:]
            self.importances[i] += block_influence(
                in_hidden,
                out_hidden,
                angular=angular
            ).sum().cpu().item()

    @torch.inference_mode()
    def eval_importance(
        self,
        prompt_tokens: List[List[int]],
        max_gen_len: Optional[int] = 0,
        temperature: Optional[float] = 0.6,
        top_p: Optional[float] = 0.9,
        angular: Optional[bool] = False
    ):
        """
        Computes layer-wise importances over input tokens.

        NOTE: The ShortGPT paper performs no generation during importance computation,
        which suggests `max_gen_len` = 0. With prompts of unequal length, the shorter
        prompts are still extended by sampling up to the longest prompt's length.
        Tensors are allocated on the "cuda" device.

        Args:
            prompt_tokens (List[List[int]]): List of tokenized prompts, where each prompt
                is represented as a list of integers.
            (Optional) max_gen_len (int): Maximum length of the generated text sequence.
                Defaults to 0.
            (Optional) temperature (float): Temperature value for controlling randomness in
                sampling. Defaults to 0.6.
            (Optional) top_p (float): Top-p probability threshold for nucleus sampling.
                Defaults to 0.9.
            (Optional) angular (bool): Whether to use angular distance. Defaults to False.

        Returns:
            None
        """
        params = self.llama.model.params
        bsz = len(prompt_tokens)
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
        min_prompt_len = min(len(t) for t in prompt_tokens)
        max_prompt_len = max(len(t) for t in prompt_tokens)
        assert max_prompt_len <= params.max_seq_len
        total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)

        pad_id = self.llama.tokenizer.pad_id
        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")

        prev_pos = 0
        # Length of the prefix in which every row holds a real (prompt or generated) token.
        filled_len = total_len
        eos_reached = torch.tensor([False] * bsz, device="cuda")
        input_text_mask = tokens != pad_id
        for cur_pos in range(min_prompt_len, total_len):
            logits = self.llama.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
            if temperature > 0:
                probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits[:, -1], dim=-1)
            next_token = next_token.reshape(-1)
            # only replace token if prompt has already been generated
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token
            eos_reached |= (~input_text_mask[:, cur_pos]) & (
                next_token == self.llama.tokenizer.eos_id
            )
            prev_pos = cur_pos
            if eos_reached.all():
                # Every row has already passed the end of its prompt, so positions after
                # cur_pos were never written and still hold pad_id. Drop them so padding
                # is neither fed to the model nor scored.
                filled_len = cur_pos + 1
                break

        # compute block influence over full sequences rather than at each token
        _, hiddens = self.llama.model.forward(tokens[:, :filled_len], 0, return_hiddens=True)
        self.compute_bi(hiddens, angular=angular)
        return