# Do0rMaMu's picture
# Upload folder using huggingface_hub
# e45d058 verified
# Copyright (c) 2023, Tri Dao.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
import gc
import time
from collections import namedtuple
from dataclasses import dataclass, field
from functools import partial
from typing import Callable, Optional, Sequence, Union
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import Tensor
from torch.profiler import ProfilerActivity, profile, record_function
try:
    from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput
except ImportError:
    # transformers is optional here: fall back to lightweight namedtuples exposing the same
    # two fields (`sequences`, `scores`) that the decoding functions below construct.
    GreedySearchDecoderOnlyOutput = namedtuple("GreedySearchDecoderOnlyOutput", "sequences scores")
    SampleDecoderOnlyOutput = namedtuple("SampleDecoderOnlyOutput", "sequences scores")
@dataclass
class InferenceParams:
    """Mutable decoding state passed to the model during incremental inference.

    It carries the capacity limits, the current position in the sequence
    (``seqlen_offset``, advanced by the decoding loops) and a dict of cached
    per-layer buffers (``key_value_memory_dict``) so context does not have to be
    recomputed at every step.
    """

    max_seqlen: int
    max_batch_size: int
    seqlen_offset: int = 0
    batch_size_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)
    lengths_per_sample: Optional[Tensor] = None

    def reset(self, max_seqlen, max_batch_size):
        """Prepare for a new generation run.

        Sets the new capacity limits, rewinds ``seqlen_offset`` to 0 and zeroes
        ``lengths_per_sample`` in place when it exists. ``batch_size_offset`` and
        the contents of ``key_value_memory_dict`` are left untouched.
        """
        self.max_seqlen, self.max_batch_size = max_seqlen, max_batch_size
        self.seqlen_offset = 0
        per_sample = self.lengths_per_sample
        if per_sample is None:
            return
        per_sample.zero_()
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
def modify_logits_for_top_k_filtering(logits, top_k):
    """In place, set to -inf every logit that is not among the ``top_k`` largest (last dim)."""
    # The k-th largest value along the last dimension is the threshold; values tied with it survive.
    kth_largest = torch.topk(logits, top_k)[0][..., -1:]
    logits.masked_fill_(logits < kth_largest, float("-Inf"))
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
def modify_logits_for_top_p_filtering(logits, top_p):
    """Set the logits outside the top-p (nucleus) set to -inf. Done in-place.

    Filtering runs along the last dimension, so ``logits`` may be ``(batch, vocab)`` or have
    extra leading dimensions such as ``(batch, seqlen, vocab)`` (as passed by
    ``sample_speculative``). A ``top_p`` outside the open interval (0, 1) disables filtering.
    """
    if top_p <= 0.0 or top_p >= 1.0:
        return
    # First sort (ascending) and calculate cumulative sum of probabilities.
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Remove the low-probability tail whose cumulative mass is <= 1 - top_p.
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # Map the mask back to the original vocabulary order. Scatter along the same (last)
    # dimension that was sorted; a hard-coded dim 1 is only correct for 2-D inputs.
    indices_to_remove = sorted_indices_to_remove.scatter(
        -1, sorted_indices, sorted_indices_to_remove
    )
    logits.masked_fill_(indices_to_remove, float("-inf"))
def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
    """Draw one token id per row from ``logits``.

    Arguments:
        logits: Tensor of shape (batch_size, vocab_size)
        top_k: 1 means greedy (argmax); 0 (or negative) means no top-k restriction;
            otherwise only the ``top_k`` highest logits are candidates.
        top_p: nucleus threshold in (0, 1]; values <= 0 disable it.
        temperature: logits are divided by it before sampling (skipped when equal to 1).
    Returns:
        Tensor of shape (batch_size,) holding the chosen token ids.
    """
    # Greedy decoding needs no randomness at all.
    if top_k == 1:
        return logits.argmax(dim=-1)
    if top_p > 0.0:
        assert top_p <= 1.0, "top-p should be in (0, 1]."
    if top_k <= 0:
        # No top-k restriction. Work on a fresh tensor so top-p masking never touches the
        # caller's logits (the division already allocates a new tensor).
        scaled = logits.clone() if temperature == 1.0 else logits / temperature
        modify_logits_for_top_p_filtering(scaled, top_p)
        return torch.multinomial(torch.softmax(scaled, dim=-1), num_samples=1).squeeze(dim=-1)
    # Restrict candidates to the k best logits, sample a position among them, then map that
    # position back to its vocabulary index.
    k = min(top_k, logits.size(-1))
    candidate_logits, candidate_ids = torch.topk(logits, k, dim=-1)
    if temperature != 1.0:
        candidate_logits /= temperature
    modify_logits_for_top_p_filtering(candidate_logits, top_p)
    rows = torch.arange(candidate_ids.shape[0], device=candidate_ids.device)
    picked = torch.multinomial(torch.softmax(candidate_logits, dim=-1), num_samples=1).squeeze(
        dim=-1
    )
    return candidate_ids[rows, picked]
@torch.inference_mode()
def decode(
    input_ids,
    model,
    max_length,
    top_k=1,
    top_p=0.0,
    temperature=1.0,
    eos_token_id=None,
    teacher_outputs=None,
    vocab_size=None,
    tensor_parallel=1,
    cg=False,
    enable_timing=False,
):
    """Decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
    then top-p.
    We assume that all sequences in the same batch have the same length.
    Arguments:
        input_ids: (batch, seq_len)
        model: called as model(input_ids, position_ids=..., inference_params=...,
            num_last_tokens=1); the result must have a ``logits`` attribute.
        max_length: int, total length (prompt + generated tokens) at which decoding stops.
        eos_token_id (optional): decoding stops early once every sequence's latest token
            equals it at the same step.
        teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
            logits, the next token is taken from the teacher_outputs (indexed by absolute
            position, so the prompt occupies its first columns). Useful for testing.
        vocab_size (optional): if set, logits are truncated to the first ``vocab_size`` entries.
        tensor_parallel: forwarded to ``update_graph_cache``; also triggers distributed
            barriers around timing when > 1.
        cg: use (and lazily build) CUDA graphs cached on ``model._decoding_cache``.
        enable_timing: print the end-to-end time measured with CUDA events.
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length) when the prompt is shorter than max_length; fewer
            columns if decoding stopped early on eos_token_id.
        scores: tuple of (batch, vocab_size) tensors, one per generated token.
    """
    batch_size, seqlen_og = input_ids.shape
    teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
    if cg:
        if not hasattr(model, "_decoding_cache"):
            model._decoding_cache = None
        model._decoding_cache = update_graph_cache(
            model,
            model._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
            tensor_parallel=tensor_parallel,
        )
        inference_params = model._decoding_cache.inference_params
        inference_params.reset(max_length, batch_size)
    else:
        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)

    def get_logits(input_ids, inference_params):
        # seqlen_offset > 0 means the prompt has already been processed (incremental step).
        decoding = inference_params.seqlen_offset > 0
        if decoding:
            position_ids = torch.full(
                (batch_size, 1),
                inference_params.seqlen_offset,
                dtype=torch.long,
                device=input_ids.device,
            )
        else:
            position_ids = None
        if not cg or not decoding:
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=1,
            ).logits.squeeze(dim=1)
        else:
            logits = model._decoding_cache.run(
                input_ids, position_ids, inference_params.seqlen_offset
            ).squeeze(dim=1)
        return logits[..., :vocab_size] if vocab_size is not None else logits

    def sample_tokens(logits, inference_params):
        if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
            token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
        else:
            token = teacher_outputs[:, inference_params.seqlen_offset]
        return token.unsqueeze(1)

    def should_stop(current_token, inference_params):
        if inference_params.seqlen_offset == 0:
            return False
        if eos_token_id is not None and (current_token == eos_token_id).all():
            return True
        if inference_params.seqlen_offset >= max_length - 1:
            return True
        return False

    if enable_timing:
        # Create CUDA events only when timing was requested: constructing torch.cuda.Event
        # fails on PyTorch builds without CUDA support, which would otherwise break plain
        # (untimed) decoding on such builds.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        if tensor_parallel > 1:
            torch.distributed.barrier()
        start.record()
    scores, sequences = [], [input_ids]
    while not should_stop(sequences[-1], inference_params):
        scores.append(get_logits(sequences[-1], inference_params))
        inference_params.seqlen_offset += sequences[-1].shape[1]
        sequences.append(sample_tokens(scores[-1], inference_params))
    if enable_timing:
        end.record()
        if tensor_parallel > 1:
            torch.distributed.barrier()
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
def sample_speculative(logits, logits_draft, tokens_draft, top_k=1, top_p=0.0, temperature=1.0):
    """Algorithm 1 from [1]
    [1] Fast Inference from Transformers via Speculative Decoding
    Yaniv Leviathan, Matan Kalman, Yossi Matias
    https://arxiv.org/abs/2211.17192

    Arguments:
        logits: Tensor of shape (batch_size, seqlen + 1, vocab_size)
        logits_draft: Tensor of shape (batch_size, seqlen, vocab_size)
        tokens_draft: Tensor of shape (batch_size, seqlen)
        top_k / top_p / temperature: applied identically to both logits tensors (top-k first).
    Return:
        tokens: Tensor of shape (batch_size, seqlen + 1)
        num_generated_tokens: Tensor of shape (batch_size), with value in [1, seqlen + 1].
            For each sequence in the batch, the number of valid tokens that were sampled by
            speculative sampling.
    """
    batch, seqlen_p_1, vocab_size = logits.shape
    seqlen = seqlen_p_1 - 1
    assert logits_draft.shape == (batch, seqlen, vocab_size)
    assert tokens_draft.shape == (batch, seqlen)
    assert tokens_draft.dtype in [torch.int64, torch.int32]
    # TODO: if top_k = 1 we can simplify things and only work with indices
    if top_p > 0.0:
        assert top_p <= 1.0, "top-p should be in (0, 1]."
    # Clone so that the in-place top-k / top-p masking doesn't change the caller's logits
    logits = logits / temperature if temperature != 1.0 else logits.clone()
    logits_draft = logits_draft / temperature if temperature != 1.0 else logits_draft.clone()
    if top_k > 0:
        top_k = min(top_k, logits.size(-1))  # Safety check
        modify_logits_for_top_k_filtering(logits, top_k)
        modify_logits_for_top_k_filtering(logits_draft, top_k)
    if top_p > 0.0:
        # (top-p filtering is a no-op for top_p <= 0, so skipping the call changes nothing)
        modify_logits_for_top_p_filtering(logits, top_p)
        modify_logits_for_top_p_filtering(logits_draft, top_p)
    probs = torch.softmax(logits, dim=-1)
    probs_draft = torch.softmax(logits_draft, dim=-1)

    def gather_token_probs(p, tokens):
        # p: (..., vocab_size), tokens: (...) -> probability assigned to each token, shape (...)
        return p.gather(dim=-1, index=tokens.long().unsqueeze(-1)).squeeze(-1)

    # Accept draft token i with probability min(1, p_i / q_i): draw r ~ U[0, 1) and accept when
    # r * q_i <= p_i (avoids dividing by q_i). Shape: (batch, seqlen)
    accepted = torch.rand(batch, seqlen, device=probs.device) * gather_token_probs(
        probs_draft, tokens_draft
    ) <= gather_token_probs(probs[:, :-1], tokens_draft)
    accepted_all = accepted.all(dim=-1)
    # (batch,) index of the first rejected draft token, or seqlen if all were accepted.
    # argmin over the 0/1 acceptance mask returns the first 0, i.e. the first rejection.
    first_rejected_idx = torch.where(accepted_all, seqlen, accepted.int().argmin(dim=-1))
    probs_diff = torch.clamp(probs[:, :-1] - probs_draft, min=0.0)
    # torch.multinomial can deal with unnormalized probabilities
    # probs_diff /= probs_diff.sum(dim=-1, keepdim=True)
    # Position `seqlen` (every draft accepted) samples from the main model's last distribution.
    resample_probs = torch.cat([probs_diff, probs[:, -1:]], dim=1)  # (batch, seqlen + 1, vocab)
    batch_idx = torch.arange(batch, device=probs.device)
    resample_probs = resample_probs[batch_idx, first_rejected_idx]  # (batch, vocab_size)
    resample = torch.multinomial(resample_probs, num_samples=1).squeeze(dim=-1)  # (batch,)
    tokens = F.pad(tokens_draft, (0, 1))
    # Write each row's resampled token at that row's own first-rejected position. Indexing as
    # tokens[:, first_rejected_idx] would broadcast every row's sample into every row.
    tokens[batch_idx, first_rejected_idx] = resample
    return tokens, first_rejected_idx + 1
@torch.inference_mode()
def decode_speculative(
    input_ids,
    model,
    model_draft,
    max_length,
    speculative_lookahead=3,
    top_k=1,
    top_p=0.0,
    temperature=1.0,
    eos_token_id=None,
    vocab_size=None,
    tensor_parallel=1,
    cg=False,
    enable_timing=False,
    debug=False,
):
    """
    TD: WIP, for my own understanding, lightly tested. Only support batch_size == 1 for now.
    Speculative decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
    then top-p.
    We assume that all sequences in the same batch have the same length.

    Each round, ``model_draft`` proposes up to ``speculative_lookahead`` tokens which ``model``
    scores in a single forward pass; ``sample_speculative`` decides how many to keep.

    Arguments:
        input_ids: (batch, seq_len); batch must be 1.
        max_length: int, total length (prompt + generated tokens) to produce.
        speculative_lookahead: maximum number of draft tokens proposed per main-model call.
        eos_token_id: must be None (not supported yet).
        vocab_size (optional): if set, logits are truncated to the first ``vocab_size`` entries.
        cg: use (and lazily build) CUDA graphs for both models.
        enable_timing: print the wall-clock time, the number of main-model calls and the
            draft acceptance rate.
        debug: compare incremental logits against full forward passes and print the errors.
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: tensor of shape (batch, number of generated tokens, vocab_size)
    """
    batch_size, seqlen_og = input_ids.shape
    assert batch_size == 1, "Speculative decoding implementation only supports batch_size=1"
    assert eos_token_id is None, "Speculative decoding implementation doesn't support eos_token_id"
    if cg:
        if not hasattr(model_draft, "_decoding_cache"):
            model_draft._decoding_cache = None
        model_draft._decoding_cache = update_graph_cache(
            model_draft,
            model_draft._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
            # draft model needs to process either 1 or 2 tokens at a time
            decoding_seqlens=(1, 2),
            tensor_parallel=tensor_parallel,
        )
        inference_params_draft = model_draft._decoding_cache.inference_params
        inference_params_draft.reset(max_length, batch_size)
        if not hasattr(model, "_decoding_cache"):
            model._decoding_cache = None
        model._decoding_cache = update_graph_cache(
            model,
            model._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
            decoding_seqlens=range(1, speculative_lookahead + 2),
            tensor_parallel=tensor_parallel,
        )
        inference_params = model._decoding_cache.inference_params
        inference_params.reset(max_length, batch_size)
    else:
        inference_params_draft = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)

    def get_logits(input_ids, inference_params, model, num_last_tokens=1, cg=False):
        decoding = inference_params.seqlen_offset > 0
        if decoding:
            seqlen = input_ids.shape[1]
            # TODO: in the case of batched decoding where each sequence has a different length,
            # we need to compute the position_ids for each sequence using lengths_per_sample
            cache_seqlens = torch.full(
                (input_ids.shape[0],),
                inference_params.seqlen_offset,
                dtype=torch.int32,
                device=input_ids.device,
            )
            position_ids = cache_seqlens[:, None] + torch.arange(
                seqlen, dtype=torch.long, device=input_ids.device
            )
        else:
            position_ids = None
        if not cg or not decoding:
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=num_last_tokens,
            ).logits
        else:
            # NOTE: careful, CUDA graph is set to have num_last_tokens=input_ids.shape[1].
            # This might not be compatible the num_last_tokens used here.
            assert num_last_tokens <= input_ids.shape[1]
            logits = model._decoding_cache.run(
                input_ids, position_ids, inference_params.seqlen_offset
            )[:, -num_last_tokens:]
        return logits[..., :vocab_size] if vocab_size is not None else logits

    def sample_tokens(input_ids, get_logits_fn, inference_params, sample_fn, num_tokens=1):
        """Sample `num_tokens` tokens from the model, given the previous logits.
        Also return the logits of the sampled tokens.
        Arguments:
            input_ids: (batch, seqlen)
        Return:
            tokens: (batch, num_tokens)
            scores: (batch, num_tokens, vocab_size): the logits from which each returned
                token was sampled. The logits of the last sampled token aren't computed.
        """
        assert num_tokens >= 1
        sequences, scores = [input_ids], []
        for _ in range(num_tokens):
            scores.append(get_logits_fn(sequences[-1], inference_params)[:, -1])
            inference_params.seqlen_offset += sequences[-1].shape[1]
            sequences.append(sample_fn(scores[-1]).unsqueeze(1))
        return torch.cat(sequences[1:], dim=1), torch.stack(scores, dim=1)

    sampling_kwargs = dict(top_k=top_k, top_p=top_p, temperature=temperature)
    sample_fn = partial(sample, **sampling_kwargs)
    get_logits_main = partial(get_logits, model=model, cg=cg)
    get_logits_draft = partial(get_logits, model=model_draft, cg=cg)
    sample_tokens_main = partial(
        sample_tokens,
        get_logits_fn=get_logits_main,
        sample_fn=sample_fn,
        inference_params=inference_params,
    )
    sample_tokens_draft = partial(
        sample_tokens,
        get_logits_fn=get_logits_draft,
        sample_fn=sample_fn,
        inference_params=inference_params_draft,
    )

    if debug:
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained("gpt2")
    if enable_timing:
        if tensor_parallel > 1:
            torch.distributed.barrier()
        torch.cuda.synchronize()
        start = time.time()

    sequences, scores = [input_ids], []
    num_main_model_calls = 0
    num_draft_tokens = 0
    num_accepted_tokens_history = []
    if seqlen_og >= max_length - 1:
        # Don't do speculative sampling, just sample 1 token from the model
        tokens, scores_new = sample_tokens_main(input_ids, num_tokens=1)
        sequences.append(tokens)
        scores.append(scores_new)
    else:
        # Sample from draft model, which produces @n_spec_tokens, and @model
        # will then use to produce between 1 and 1 + @n_spec_tokens tokens.
        # We want seqlen_og + 1 + @n_spec_tokens to be <= @max_length.
        n_spec_tokens = min(speculative_lookahead, max_length - seqlen_og - 1)
        tokens_draft, scores_draft = sample_tokens_draft(input_ids, num_tokens=n_spec_tokens)
        num_draft_tokens += n_spec_tokens
        if debug:
            scores_draft_ref = model_draft(
                torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
            ).logits
            print((scores_draft - scores_draft_ref[:, :-1]).abs().max())

        # Evaluate the draft tokens with the model
        logits = get_logits_main(
            torch.cat([input_ids, tokens_draft], dim=1),
            inference_params,
            num_last_tokens=n_spec_tokens + 1,
        )
        num_main_model_calls += 1
        if debug:
            logits_ref = model(
                torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
            ).logits
            print((logits - logits_ref).abs().max())
        tokens, num_generated_tokens = sample_speculative(
            logits, scores_draft, tokens_draft, **sampling_kwargs
        )
        num_accepted_tokens_history.append(num_generated_tokens - 1)
        if debug:
            print(tokens)
            print(num_generated_tokens)
        # TODO: we're using the fact that batch_size == 1
        # TODO: check eos_token_id
        sequences.append(tokens[:1, : num_generated_tokens[0]])
        scores.append(logits[:1, : num_generated_tokens[0]])
        # Note that @model has not evaluated the last sampled token yet, so we'll need to pass
        # that in the next time we call @model.
        num_generated = num_generated_tokens[0].item()
        inference_params.seqlen_offset = seqlen_og + num_generated - 1
        inference_params_draft.seqlen_offset = (
            inference_params.seqlen_offset - 1
            if num_generated > 1
            else inference_params.seqlen_offset
        )
        if debug:
            cur_ids = torch.cat([input_ids, sequences[-1]], dim=1)
            scores_ref = model(cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1).logits
            print((scores[-1] - scores_ref[:, :-1]).abs().max())

        while True:
            # seqlen_offset is total length generated - 1
            if inference_params.seqlen_offset >= max_length - 1:
                break
            if inference_params.seqlen_offset >= max_length - 2:
                # Don't do speculative sampling, just sample 1 token from the model
                tokens, scores_new = sample_tokens_main(sequences[-1][:, -1:], num_tokens=1)
                sequences.append(tokens)
                scores.append(scores_new)
                break
            # Sample from draft model
            n_spec_tokens = min(
                speculative_lookahead, max_length - inference_params_draft.seqlen_offset - 2
            )
            # If the main model accepts all the draft tokens, plus it samples one new token,
            # then at the next iteration the draft model need to evaluate the logits of the last draft
            # token and the logits of the newly sampled token. So here we pass in the last 2 tokens
            # of sequences[-1].
            # This exception is when the main model rejects all the draft tokens, in which case we
            # will only have 1 token to pass in.
            tokens_draft, scores_draft = sample_tokens_draft(
                sequences[-1][:, -2:], num_tokens=n_spec_tokens
            )
            num_draft_tokens += n_spec_tokens
            if debug:
                scores_draft_ref = model_draft(
                    torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
                ).logits
                print((scores_draft - scores_draft_ref[:, :-1]).abs().max())
            # Evaluate the draft tokens with the model
            logits = get_logits_main(
                torch.cat([sequences[-1][:, -1:], tokens_draft], dim=1),
                inference_params,
                num_last_tokens=n_spec_tokens + 1,
            )  # (batch, n_spec_tokens + 1, vocab_size)
            num_main_model_calls += 1
            if debug:
                logits_ref = model(
                    torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
                ).logits
                print((logits - logits_ref).abs().max())
            tokens, num_generated_tokens = sample_speculative(
                logits, scores_draft, tokens_draft, **sampling_kwargs
            )
            num_accepted_tokens_history.append(num_generated_tokens - 1)
            if debug:
                print(tokens)
                print(num_generated_tokens)
            sequences.append(tokens[:1, : num_generated_tokens[0]])
            scores.append(logits[:1, : num_generated_tokens[0]])
            # We've evaluated 1 token from sequences[-1][:, -1:] above, plus
            # num_generated_tokens[0].item() - 1 tokens from the draft model.
            num_generated = num_generated_tokens[0].item()
            inference_params.seqlen_offset += num_generated
            inference_params_draft.seqlen_offset = (
                inference_params.seqlen_offset - 1
                if num_generated > 1
                else inference_params.seqlen_offset
            )
            if debug:
                cur_ids = torch.cat([cur_ids, sequences[-1]], dim=1)
                scores_ref = model(cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1).logits
                print((scores[-1] - scores_ref[:, :-1]).abs().max())

    if enable_timing:
        if tensor_parallel > 1:
            torch.distributed.barrier()
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
        print(f"Number of calls to main model: {num_main_model_calls}")
        if num_draft_tokens > 0:
            num_accepted = torch.cat(num_accepted_tokens_history).sum().item()
            print(f"Acceptance rate: {num_accepted / num_draft_tokens * 100:.2f}%")
        else:
            # The prompt already reached max_length - 1, so no draft token was ever proposed;
            # torch.cat([]) and the division would both fail here.
            print("Acceptance rate: N/A (no draft tokens were proposed)")
    sequences = torch.cat(sequences, dim=1)
    scores = torch.cat(scores, dim=1)
    if debug:
        scores_ref = model(sequences).logits
        print((scores - scores_ref[:, seqlen_og - 1 : -1]).abs().max())
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(sequences=sequences, scores=scores)
class GenerationMixin:
    """Mixin giving a model class a ``generate`` method backed by ``decode``.

    Subclasses should override ``allocate_inference_cache``: ``update_graph_cache`` (used when
    generating with ``cg=True``) calls it whenever the model defines it, and the default here
    raises NotImplementedError.
    """

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Allocate the model's inference cache; must be implemented by subclasses."""
        raise NotImplementedError

    def generate(
        self,
        input_ids,
        max_length,
        top_k=1,
        top_p=0.0,
        temperature=1.0,
        return_dict_in_generate=False,
        output_scores=False,
        **kwargs,
    ):
        """Generate continuations of ``input_ids`` with ``decode``.

        Extra keyword arguments are forwarded to ``decode``. Returns the full output object
        (``sequences`` and ``scores``, the latter ``None`` unless ``output_scores``) when
        ``return_dict_in_generate`` is set, otherwise only the ``sequences`` tensor.
        """
        output = decode(
            input_ids, self, max_length, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs
        )
        if not output_scores:
            if hasattr(output, "_replace"):
                # namedtuple fallback (transformers not installed) is immutable: build a copy.
                output = output._replace(scores=None)
            else:
                output.scores = None
        return output if return_dict_in_generate else output.sequences
def allocate_inference_cache(
    max_batch_size,
    max_seqlen,
    nheads,
    headdim,
    layers: Union[int, Sequence],
    device,
    dtype=torch.float16,
):
    """Allocate one uninitialized key/value buffer per layer.

    Arguments:
        layers: either a layer count (buffers keyed 0..layers-1) or an explicit sequence of
            layer indices to use as keys.
        dtype: one of float16, bfloat16 or float32 (asserted).
    Returns:
        dict mapping layer index -> uninitialized tensor of shape
        (max_batch_size, max_seqlen, 2, nheads, headdim) on ``device``.
    """
    assert dtype in [torch.float16, torch.bfloat16, torch.float32]
    layer_ids = range(layers) if isinstance(layers, int) else layers
    buffer_shape = (max_batch_size, max_seqlen, 2, nheads, headdim)
    buffers = {}
    for layer_id in layer_ids:
        buffers[layer_id] = torch.empty(buffer_shape, device=device, dtype=dtype)
    return buffers
@dataclass
class DecodingCGCache:
    """Per-model cache of captured CUDA-graph decoders, maintained by ``update_graph_cache``."""

    # Capacity the cached graphs/buffers were built for; exceeding either forces a rebuild.
    max_batch_size: int = 0
    max_seqlen: int = 0
    # NOTE: `device`, `dtype` and `mempool` are unannotated, so they are plain class attributes
    # rather than dataclass fields (not constructor arguments); update_graph_cache assigns them
    # on the instance.
    device = None
    dtype = None
    # Captured graph runners keyed by (batch_size, decoding_seqlen).
    callables: dict = field(default_factory=dict)
    # CUDA graph memory-pool handle shared by every graph captured into this cache.
    mempool = None
    inference_params: Optional[InferenceParams] = None
    # Dispatcher set by update_graph_cache: run(input_ids, position_ids, seqlen) -> logits.
    run: Optional[Callable] = None
@torch.inference_mode()
def update_graph_cache(
    model,
    cache,
    batch_size,
    seqlen_og,
    max_seqlen,
    decoding_seqlens=(1,),
    tensor_parallel=1,
    dtype=None,
    n_warmups=2,
):
    """Create or refresh ``cache`` so it holds a CUDA graph for every requested decoding length.

    Everything (graphs, memory pool, inference state) is rebuilt when the model's device or
    dtype changed, or when ``batch_size`` / ``max_seqlen`` exceed what the cache was built for.
    A graph is then captured for each ``(batch_size, decoding_seqlen)`` pair not yet cached.

    Returns:
        The (possibly new) DecodingCGCache. Its ``run`` dispatches on the input's
        (batch, seqlen) shape, and its ``inference_params.seqlen_offset`` is reset to 0.
    """
    if cache is None:
        cache = DecodingCGCache()
    first_param = next(iter(model.parameters()))
    device = first_param.device
    dtype = first_param.dtype if dtype is None else dtype

    stale = (
        (device, dtype) != (cache.device, cache.dtype)
        or batch_size > cache.max_batch_size
        or max_seqlen > cache.max_seqlen
    )
    if stale:
        # Invalidate the cache: drop every captured graph and its pool before reallocating.
        cache.callables = {}
        cache.mempool = None
        cache.inference_params = None
        gc.collect()
        cache.device, cache.dtype = device, dtype
        cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
        if hasattr(model, "allocate_inference_cache"):
            kv_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
        else:
            config = model.config
            head_dim = getattr(
                config,
                "head_dim",
                config.hidden_size // config.num_attention_heads,
            )
            kv_cache = allocate_inference_cache(
                batch_size,
                max_seqlen,
                config.num_attention_heads // tensor_parallel,
                head_dim,
                config.num_hidden_layers,
                device,
                dtype,
            )
        per_sample_lengths = torch.full(
            (batch_size,), seqlen_og, dtype=torch.int32, device=device
        )
        cache.inference_params = InferenceParams(
            max_seqlen=max_seqlen,
            max_batch_size=batch_size,
            seqlen_offset=seqlen_og,
            key_value_memory_dict=kv_cache,
            lengths_per_sample=per_sample_lengths,
        )
        cache.mempool = torch.cuda.graphs.graph_pool_handle()

    for decoding_seqlen in decoding_seqlens:
        graph_key = (batch_size, decoding_seqlen)
        if graph_key in cache.callables:
            continue
        cache.callables[graph_key] = capture_graph(
            model,
            cache.inference_params,
            batch_size,
            max_seqlen,
            decoding_seqlen=decoding_seqlen,
            mempool=cache.mempool,
            n_warmups=n_warmups,
        )

    def dispatch(input_ids, position_ids, seqlen):
        # Pick the graph captured for this exact (batch, decoding seqlen) input shape.
        bsz, step_len = input_ids.shape[:2]
        return cache.callables[bsz, step_len](input_ids, position_ids, seqlen)

    cache.run = dispatch
    cache.inference_params.seqlen_offset = 0  # Reset so it's not confusing
    return cache
def capture_graph(
    model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
):
    """Capture one decoding step of ``model`` into a CUDA graph and return a replay function.

    The graph is recorded for a fixed ``(batch_size, decoding_seqlen)`` input shape. It reads
    from static ``input_ids`` / ``position_ids`` buffers and uses the state held by
    ``inference_params``. ``inference_params.seqlen_offset`` is restored before returning.

    Returns:
        ``run(new_input_ids, new_position_ids, seqlen)``. It sets
        ``inference_params.lengths_per_sample`` to ``seqlen``, copies the new inputs into the
        static buffers, replays the graph and returns a clone of the logits computed for the
        last ``decoding_seqlen`` positions.
    """
    device = next(iter(model.parameters())).device
    # Static input buffers: the captured graph always reads from these exact tensors.
    input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
    position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
    seqlen_offset_og = inference_params.seqlen_offset
    # Capture with the largest offset that still fits in max_seqlen (> 0, so the model sees an
    # incremental-decoding step). NOTE(review): presumably chosen so capture covers the
    # worst-case cache length; confirm against the model's attention code.
    inference_params.seqlen_offset = max_seqlen - decoding_seqlen
    inference_params.lengths_per_sample[:] = inference_params.seqlen_offset

    # Warmup before capture
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(n_warmups):
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=decoding_seqlen,
            ).logits
        s.synchronize()
        # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
        # which requires that graph launch and non-captured launch to not overlap (I think,
        # that's how I interpret the documentation). I'm not sure if this is required.
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
    torch.cuda.current_stream().wait_stream(s)
    # Captures the graph
    # To allow capture, automatically sets a side stream as the current stream in the context
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=mempool):
        logits = model(
            input_ids,
            position_ids=position_ids,
            inference_params=inference_params,
            num_last_tokens=decoding_seqlen,
        ).logits

    def run(new_input_ids, new_position_ids, seqlen):
        inference_params.lengths_per_sample[:] = seqlen
        input_ids.copy_(new_input_ids)
        position_ids.copy_(new_position_ids)
        graph.replay()
        # `logits` is the graph's static output tensor, rewritten on every replay, so hand the
        # caller an independent copy.
        return logits.clone()

    inference_params.seqlen_offset = seqlen_offset_og
    return run