|
from typing import List |
|
|
|
import torch |
|
|
|
|
|
def get_redrafter_specific_tensor_names() -> List[str]:
    """Return the ReDrafter-specific tensor names, in a fixed order.

    A fresh list is created on every call, so the caller is free to mutate
    the result.
    """
    # NOTE(review): the two groups follow the blank-line gap in the original
    # listing; presumably engine inputs vs. engine outputs -- confirm.
    first_group = (
        'device_request_types',
        'draft_tokens',
        'draft_indices',
        'draft_probs',
        'redrafter_inverted_temperature',
        'rand_data_sample',
        'rand_data_validation',
        'position_ids_base',
    )
    second_group = (
        'next_spec_decoding_generation_lengths',
        'next_spec_decoding_position_offsets',
        'spec_decoding_mask',
        'next_draft_tokens',
        'next_draft_indices',
        'next_draft_probs',
        'next_flat_tokens',
        'num_accepted_tokens',
        'accepted_beam_index',
        'max_gen_token',
        'total_gen_token',
        'packed_position_ids',
    )
    return [*first_group, *second_group]
|
|
|
|
|
def get_redrafter_tensor_names() -> List[str]:
    """Return every tensor name used with ReDrafter.

    The result starts with the three ``spec_decoding_*`` names listed here
    and continues with everything returned by
    ``get_redrafter_specific_tensor_names()``.
    """
    names = [
        'spec_decoding_generation_lengths',
        'spec_decoding_position_offsets',
        'spec_decoding_packed_mask',
    ]
    names.extend(get_redrafter_specific_tensor_names())
    return names
|
|
|
|
|
def init_allocate_redrafter_tensors(session, batch_size):
    """Allocate the ReDrafter working tensors and register them on ``session``.

    Every tensor is created on ``session.device``, stored as an attribute of
    ``session`` and (for most of them) also published in ``session.buffer``
    under the key listed at the bottom of this function.

    Args:
        session: Object providing ``device``, ``dtype``, ``vocab_size``,
            ``max_draft_tokens``, ``_model_config.redrafter_num_beams``,
            ``_model_config.redrafter_draft_len_per_beam`` and a mutable
            mapping ``buffer``.
        batch_size: Number of requests the buffers are sized for.

    Returns:
        None. The function works purely through side effects on ``session``.

    Side effects:
        Calls ``torch.manual_seed(0)`` before drawing the two random buffers,
        which resets the global PyTorch RNG state.
    """
    device = session.device
    float_dtype = session.dtype
    config = session._model_config
    num_beams = config.redrafter_num_beams
    draft_len = config.redrafter_draft_len_per_beam
    tokens_per_request = session.max_draft_tokens + 1
    flat_len = batch_size * tokens_per_request

    def int_zeros(shape):
        # Zero-filled int32 tensor on the session device.
        return torch.zeros(shape, dtype=torch.int32, device=device)

    def float_zeros(shape):
        # Zero-filled tensor in the session's floating dtype.
        return torch.zeros(shape, dtype=float_dtype, device=device)

    beam_token_shape = [batch_size, num_beams, draft_len + 1]
    beam_prob_shape = [batch_size, num_beams, draft_len, session.vocab_size]

    # Flattened per-token buffers: one slot per (request, token) pair.
    session.flat_tokens = int_zeros([flat_len])
    session.next_flat_tokens = int_zeros([flat_len])
    session.position_ids_base = int_zeros([batch_size])
    session.packed_position_ids = int_zeros([flat_len])
    # Starts at one (not zero) for every request.
    session.accept_lengths = torch.ones([batch_size],
                                        dtype=torch.int32,
                                        device=device)

    # Per-beam draft buffers, current and "next" copies.
    session.draft_tokens = int_zeros(beam_token_shape)
    session.draft_indices = int_zeros(beam_token_shape)
    session.draft_probs = float_zeros(beam_prob_shape)
    session.next_draft_tokens = int_zeros(beam_token_shape)
    session.next_draft_indices = int_zeros(beam_token_shape)
    session.next_draft_probs = float_zeros(beam_prob_shape)

    session.next_spec_decoding_position_offsets = int_zeros(
        [batch_size, tokens_per_request])
    session.next_spec_decoding_generation_lengths = int_zeros([batch_size])

    session.spec_decoding_generation_lengths = int_zeros([batch_size])
    session.spec_decoding_mask = torch.zeros(
        [batch_size, tokens_per_request, tokens_per_request],
        dtype=torch.bool,
        device=device)
    # One packed row per (request, token) pair, each row holding the mask
    # bits in 32-bit words. The row count must be
    # batch_size * (max_draft_tokens + 1), matching the other flattened
    # buffers; the previous expression `batch_size * max_draft_tokens + 1`
    # was missing the parentheses and came up batch_size - 1 rows short.
    # NOTE(review): confirm the engine's shape profile for
    # 'spec_decoding_packed_mask' admits this row count.
    session.spec_decoding_packed_mask = int_zeros(
        [flat_len, (tokens_per_request + 31) // 32])
    session.spec_decoding_position_offsets = int_zeros(
        [batch_size, tokens_per_request])
    session.accepted_beam_index = int_zeros([batch_size])
    session.max_gen_token = int_zeros(1)
    session.total_gen_token = int_zeros(1)

    # Fixed seed, and sample-before-validation order, keep the random
    # buffers reproducible from run to run.
    torch.manual_seed(0)
    session.rand_data_sample = torch.rand([batch_size],
                                          dtype=float_dtype,
                                          device=device)
    session.rand_data_validation = torch.rand(
        [batch_size, num_beams, draft_len],
        dtype=float_dtype,
        device=device)

    # Publish the tensors under their buffer keys. Note that
    # 'num_accepted_tokens' is backed by `session.accept_lengths`.
    registrations = (
        ('flat_tokens', session.flat_tokens),
        ('next_flat_tokens', session.next_flat_tokens),
        ('num_accepted_tokens', session.accept_lengths),
        ('draft_tokens', session.draft_tokens),
        ('draft_indices', session.draft_indices),
        ('draft_probs', session.draft_probs),
        ('accepted_beam_index', session.accepted_beam_index),
        ('spec_decoding_generation_lengths',
         session.spec_decoding_generation_lengths),
        ('spec_decoding_mask', session.spec_decoding_mask),
        ('spec_decoding_position_offsets',
         session.spec_decoding_position_offsets),
        ('spec_decoding_packed_mask', session.spec_decoding_packed_mask),
        ('rand_data_sample', session.rand_data_sample),
        ('rand_data_validation', session.rand_data_validation),
        ('next_spec_decoding_generation_lengths',
         session.next_spec_decoding_generation_lengths),
        ('next_draft_tokens', session.next_draft_tokens),
        ('next_draft_indices', session.next_draft_indices),
        ('next_draft_probs', session.next_draft_probs),
        ('next_spec_decoding_position_offsets',
         session.next_spec_decoding_position_offsets),
        ('max_gen_token', session.max_gen_token),
        ('total_gen_token', session.total_gen_token),
        ('position_ids_base', session.position_ids_base),
        ('packed_position_ids', session.packed_position_ids),
    )
    for key, tensor in registrations:
        session.buffer[key] = tensor
|
|
|
|
|
def set_redrafter_ctx_tensors(session, add_tensor, add_tensor_with_bs):
    """Hand the ReDrafter buffers to the two registration callbacks.

    Args:
        session: Object whose ``buffer`` mapping holds the tensors by name.
        add_tensor: Called as ``add_tensor(tensor, name)``.
        add_tensor_with_bs: Called as ``add_tensor_with_bs(tensor, name, 0)``;
            the trailing ``0`` is always passed by this function.

    Returns:
        None.
    """
    leading_plain = (
        'next_spec_decoding_generation_lengths',
        'next_spec_decoding_position_offsets',
        'spec_decoding_mask',
        'next_flat_tokens',
        'next_draft_tokens',
        'next_draft_indices',
        'next_draft_probs',
        'num_accepted_tokens',
        'accepted_beam_index',
        'packed_position_ids',
    )
    with_zero_bs = (
        'spec_decoding_generation_lengths',
        'spec_decoding_position_offsets',
        'spec_decoding_packed_mask',
        'draft_tokens',
        'draft_indices',
        'draft_probs',
        'rand_data_validation',
    )
    trailing_plain = (
        'rand_data_sample',
        'redrafter_inverted_temperature',
        'max_gen_token',
        'total_gen_token',
        'position_ids_base',
    )

    # The three groups are issued back to back, so the callbacks observe the
    # names in exactly the order they are written above.
    for key in leading_plain:
        add_tensor(session.buffer[key], key)
    for key in with_zero_bs:
        add_tensor_with_bs(session.buffer[key], key, 0)
    for key in trailing_plain:
        add_tensor(session.buffer[key], key)
|
|
|
|
|
def set_redrafter_gen_tensors(session, batch_size, add_tensor,
                              add_tensor_with_shape):
    """Hand the ReDrafter buffers to the registration callbacks.

    All buffers are passed through ``add_tensor(tensor, name)`` except
    ``'spec_decoding_position_offsets'``, which is first trimmed to
    ``(batch_size, session.host_max_gen_token)`` and passed through
    ``add_tensor_with_shape(tensor, name, shape)``.

    The whole call is wrapped in an NVTX range named
    ``set_redrafter_gen_tensors``.

    Args:
        session: Object providing the ``buffer`` mapping and
            ``host_max_gen_token``.
        batch_size: Number of rows kept from the position-offsets buffer.
        add_tensor: Callback taking ``(tensor, name)``.
        add_tensor_with_shape: Callback taking ``(tensor, name, shape)``.

    Returns:
        None.
    """
    nvtx = torch.cuda.nvtx
    nvtx.range_push("set_redrafter_gen_tensors")

    before_offsets = (
        'next_spec_decoding_generation_lengths',
        'next_flat_tokens',
        'next_draft_tokens',
        'next_draft_indices',
        'next_draft_probs',
        'next_spec_decoding_position_offsets',
        'spec_decoding_mask',
        'num_accepted_tokens',
        'accepted_beam_index',
        'packed_position_ids',
        'spec_decoding_generation_lengths',
    )
    for key in before_offsets:
        add_tensor(session.buffer[key], key)

    # Keep only the first batch_size * gen_len entries of the flattened
    # offsets buffer and present them as a (batch_size, gen_len) view.
    gen_len = session.host_max_gen_token
    flat_offsets = session.buffer['spec_decoding_position_offsets'].view(-1)
    used_offsets = flat_offsets[:batch_size * gen_len]
    add_tensor_with_shape(used_offsets.view(batch_size, gen_len),
                          'spec_decoding_position_offsets',
                          (batch_size, gen_len))

    after_offsets = (
        'spec_decoding_packed_mask',
        'draft_tokens',
        'draft_indices',
        'draft_probs',
        'rand_data_validation',
        'rand_data_sample',
        'redrafter_inverted_temperature',
        'max_gen_token',
        'total_gen_token',
        'position_ids_base',
    )
    for key in after_offsets:
        add_tensor(session.buffer[key], key)

    nvtx.range_pop()
|
|
|
|
|
def redrafter_convert_spec_decoding_mask_to_packed_mask(
        session, spec_decoding_generation_lengths):
    """Run the ``convert_spec_decoding_mask_to_packed_mask`` custom op.

    The op is invoked inside an NVTX range named ``mask_conversion`` with the
    given generation lengths plus ``session.spec_decoding_mask``,
    ``session.max_draft_tokens`` and ``session.spec_decoding_packed_mask``.

    Args:
        session: Object providing the three attributes named above.
        spec_decoding_generation_lengths: Forwarded unchanged as the op's
            first argument.

    Returns:
        None.
    """
    nvtx = torch.cuda.nvtx
    nvtx.range_push("mask_conversion")
    convert_op = (
        torch.ops.tensorrt_llm.convert_spec_decoding_mask_to_packed_mask)
    # NOTE(review): presumably the op fills session.spec_decoding_packed_mask
    # in place and the trailing None is an optional argument -- confirm
    # against the op's registration.
    convert_op(
        spec_decoding_generation_lengths,
        session.spec_decoding_mask,
        session.max_draft_tokens,
        session.spec_decoding_packed_mask,
        None,
    )
    nvtx.range_pop()
|
|
|
|
|
def exchange_redrafter_buffers(session):
    """Refresh the host-side counters and swap each buffer with its ``next_`` twin.

    ``session.host_max_gen_token`` and ``session.host_total_gen_token`` are
    set from the ``'max_gen_token'`` and ``'total_gen_token'`` buffers (copied
    to CPU and converted to Python scalars). Then, for each name listed
    below, the entries ``session.buffer[name]`` and
    ``session.buffer['next_' + name]`` trade places.

    The whole call is wrapped in an NVTX range named
    ``exchange_redrafter_buffers``.

    Returns:
        None.
    """
    nvtx = torch.cuda.nvtx
    nvtx.range_push("exchange_redrafter_buffers")

    session.host_max_gen_token = session.buffer['max_gen_token'].cpu().item()
    session.host_total_gen_token = (
        session.buffer['total_gen_token'].cpu().item())

    swapped_names = (
        'spec_decoding_generation_lengths',
        'spec_decoding_position_offsets',
        'draft_probs',
        'draft_indices',
        'draft_tokens',
        "flat_tokens",
    )
    for current_key in swapped_names:
        upcoming_key = "next_" + current_key
        # Read both entries before writing either, so the exchange is a
        # true swap rather than a copy.
        upcoming = session.buffer[upcoming_key]
        current = session.buffer[current_key]
        session.buffer[current_key] = upcoming
        session.buffer[upcoming_key] = current

    nvtx.range_pop()
|
|
|
|
|
def process_redrafter_outputs(session, step, batch_size, last_draft_tokens,
                              new_draft_tokens):
    """Assemble the tokens accepted in this step into ``session.new_tokens``.

    For each request ``i`` the accepted sequence is
    ``new_draft_tokens[i, 0, :1]``. When ``step > 0`` it is preceded by
    ``last_draft_tokens[i, beam, 1:count]``, where ``beam`` and ``count`` come
    from the ``'accepted_beam_index'`` and ``'num_accepted_tokens'`` buffers.
    The per-request sequences are combined into one int32 tensor, padded with
    ``session.end_ids[0]``, and stored as ``session.new_tokens``.
    ``session.accept_lengths`` is rebound to the ``'num_accepted_tokens'``
    buffer.

    The call is wrapped in an NVTX range named ``process_redrafter_outputs``,
    with a nested ``accept_tokens_<i>`` range per request.

    Args:
        session: Object providing ``buffer`` and ``end_ids``.
        step: Only compared against zero.
        batch_size: Number of requests to process.
        last_draft_tokens: Indexed as ``[request, beam, token]``.
        new_draft_tokens: Indexed as ``[request, beam, token]``.

    Returns:
        Tuple ``(accepted_beam_index, num_accepted_tokens)`` -- the two
        buffer entries themselves, not copies.
    """
    nvtx = torch.cuda.nvtx
    nvtx.range_push("process_redrafter_outputs")

    beam_choice = session.buffer["accepted_beam_index"]
    accepted_counts = session.buffer["num_accepted_tokens"]
    session.accept_lengths = accepted_counts

    per_request = []
    for idx in range(batch_size):
        nvtx.range_push(f"accept_tokens_{idx}")
        # The first token of beam 0 in the new drafts is always taken.
        tokens = new_draft_tokens[idx, 0, :1]
        if step > 0:
            # Prepend the previously drafted tokens on the accepted beam,
            # skipping position 0 and stopping before the accepted count.
            kept = last_draft_tokens[idx, beam_choice[idx],
                                     1:accepted_counts[idx]]
            tokens = torch.concat([kept, tokens])
        per_request.append(tokens)
        nvtx.range_pop()

    # Requests may accept different numbers of tokens, so go through a
    # nested tensor and pad the short rows.
    ragged = torch.nested.nested_tensor(per_request, dtype=torch.int32)
    session.new_tokens = torch.nested.to_padded_tensor(ragged,
                                                       session.end_ids[0])
    nvtx.range_pop()
    return beam_choice, accepted_counts
|
|