from typing import List

import torch


def get_redrafter_specific_tensor_names() -> List[str]:
    """Return the names of the tensors that are specific to ReDrafter.

    The list holds the input tensor names first, followed by the output
    tensor names.
    """
    return [
        # inputs
        'device_request_types',
        'draft_tokens',
        'draft_indices',
        'draft_probs',
        'redrafter_inverted_temperature',
        'rand_data_sample',
        'rand_data_validation',
        'position_ids_base',
        # outputs
        'next_spec_decoding_generation_lengths',
        'next_spec_decoding_position_offsets',
        'spec_decoding_mask',
        'next_draft_tokens',
        'next_draft_indices',
        'next_draft_probs',
        'next_flat_tokens',
        'num_accepted_tokens',
        'accepted_beam_index',
        'max_gen_token',
        'total_gen_token',
        'packed_position_ids',
    ]


def get_redrafter_tensor_names() -> List[str]:
    """Return all ReDrafter tensor names.

    This is the three ``spec_decoding_*`` input names followed by everything
    returned by ``get_redrafter_specific_tensor_names()``.
    """
    return [
        # inputs
        'spec_decoding_generation_lengths',
        'spec_decoding_position_offsets',
        'spec_decoding_packed_mask',
    ] + get_redrafter_specific_tensor_names()


def init_allocate_redrafter_tensors(session, batch_size):
    """Allocate the ReDrafter buffers and register them in ``session.buffer``.

    Every tensor is created on ``session.device``, stored as an attribute of
    ``session`` and then published in ``session.buffer`` under its tensor
    name (``session.accept_lengths`` is published as
    ``'num_accepted_tokens'``).

    Args:
        session: Object providing ``device``, ``dtype``, ``vocab_size``,
            ``max_draft_tokens``, ``_model_config.redrafter_num_beams``,
            ``_model_config.redrafter_draft_len_per_beam`` and a mutable
            ``buffer`` mapping.
        batch_size: Number of requests the buffers are sized for.

    Note:
        Calls ``torch.manual_seed(0)``, which resets the global PyTorch RNG,
        before drawing ``rand_data_sample`` and ``rand_data_validation``.
        ``'device_request_types'`` and ``'redrafter_inverted_temperature'``
        are not created here.
    """
    device = session.device
    num_beams = session._model_config.redrafter_num_beams
    draft_len = session._model_config.redrafter_draft_len_per_beam
    tokens_per_request = session.max_draft_tokens + 1
    flat_token_count = batch_size * tokens_per_request

    def zeros_int32(shape):
        return torch.zeros(shape, dtype=torch.int32, device=device)

    def zeros_probs():
        return torch.zeros([batch_size, num_beams, draft_len,
                            session.vocab_size],
                           dtype=session.dtype,
                           device=device)

    beam_token_shape = [batch_size, num_beams, draft_len + 1]

    # define the buffers for ReDrafter
    session.flat_tokens = zeros_int32([flat_token_count])
    session.next_flat_tokens = zeros_int32([flat_token_count])
    session.position_ids_base = zeros_int32([batch_size])
    session.packed_position_ids = zeros_int32([flat_token_count])
    session.accept_lengths = torch.ones([batch_size],
                                        dtype=torch.int32,
                                        device=device)
    session.draft_tokens = zeros_int32(beam_token_shape)
    session.draft_indices = zeros_int32(beam_token_shape)
    session.draft_probs = zeros_probs()
    session.next_draft_tokens = zeros_int32(beam_token_shape)
    session.next_draft_indices = zeros_int32(beam_token_shape)
    session.next_draft_probs = zeros_probs()
    session.next_spec_decoding_position_offsets = zeros_int32(
        [batch_size, tokens_per_request])
    session.next_spec_decoding_generation_lengths = zeros_int32([batch_size])
    session.spec_decoding_generation_lengths = zeros_int32([batch_size])
    session.spec_decoding_mask = torch.zeros(
        [batch_size, tokens_per_request, tokens_per_request],
        dtype=torch.bool,
        device=device)
    # One row per flattened token, like the other flat buffers above. The
    # row count used to be `batch_size * max_draft_tokens + 1`, which is too
    # small for batch_size > 1 because of operator precedence.
    # Each row packs the mask bits into 32-bit words (rounded up).
    session.spec_decoding_packed_mask = zeros_int32(
        [flat_token_count, (tokens_per_request + 31) // 32])
    session.spec_decoding_position_offsets = zeros_int32(
        [batch_size, tokens_per_request])
    session.accepted_beam_index = zeros_int32([batch_size])
    session.max_gen_token = zeros_int32(1)
    session.total_gen_token = zeros_int32(1)

    torch.manual_seed(0)  # use seed=0 for context
    # Keep this draw order: it determines the values produced by the RNG.
    session.rand_data_sample = torch.rand([batch_size],
                                          dtype=session.dtype,
                                          device=device)
    session.rand_data_validation = torch.rand(
        [batch_size, num_beams, draft_len],
        dtype=session.dtype,
        device=device)

    # Buffer keys whose backing attribute has a different name.
    attribute_for_key = {'num_accepted_tokens': 'accept_lengths'}
    for key in (
            'flat_tokens',
            'next_flat_tokens',
            'num_accepted_tokens',
            'draft_tokens',
            'draft_indices',
            'draft_probs',
            'accepted_beam_index',
            'spec_decoding_generation_lengths',
            'spec_decoding_mask',
            'spec_decoding_position_offsets',
            'spec_decoding_packed_mask',
            'rand_data_sample',
            'rand_data_validation',
            'next_spec_decoding_generation_lengths',
            'next_draft_tokens',
            'next_draft_indices',
            'next_draft_probs',
            'next_spec_decoding_position_offsets',
            'max_gen_token',
            'total_gen_token',
            'position_ids_base',
            'packed_position_ids',
    ):
        session.buffer[key] = getattr(session,
                                      attribute_for_key.get(key, key))
    # NOTE: device_request_types is created with host_request_types


def set_redrafter_ctx_tensors(session, add_tensor, add_tensor_with_bs):
    """Hand the ReDrafter tensors to the caller's registration callbacks.

    This is the ``ctx`` variant (presumably the context phase -- confirm
    against the caller).

    Args:
        session: Object whose ``buffer`` mapping holds the ReDrafter tensors,
            including ``'redrafter_inverted_temperature'``, which must have
            been stored by the caller beforehand.
        add_tensor: Callable invoked as ``add_tensor(tensor, name)``.
        add_tensor_with_bs: Callable invoked as
            ``add_tensor_with_bs(tensor, name, 0)``; the meaning of the
            trailing ``0`` is defined by the caller.
    """
    buffers = session.buffer

    # Add all output tensors
    for name in (
            'next_spec_decoding_generation_lengths',
            'next_spec_decoding_position_offsets',
            'spec_decoding_mask',
            'next_flat_tokens',
            'next_draft_tokens',
            'next_draft_indices',
            'next_draft_probs',
            'num_accepted_tokens',
            'accepted_beam_index',
            'packed_position_ids',
    ):
        add_tensor(buffers[name], name)

    # add all input tensors
    for name in (
            'spec_decoding_generation_lengths',
            'spec_decoding_position_offsets',
            'spec_decoding_packed_mask',
            'draft_tokens',
            'draft_indices',
            'draft_probs',
            'rand_data_validation',
    ):
        add_tensor_with_bs(buffers[name], name, 0)

    for name in (
            'rand_data_sample',
            'redrafter_inverted_temperature',
            'max_gen_token',
            'total_gen_token',
            'position_ids_base',
    ):
        add_tensor(buffers[name], name)


def set_redrafter_gen_tensors(session, batch_size, add_tensor,
                              add_tensor_with_shape):
    """Hand the ReDrafter tensors to the caller's registration callbacks.

    This is the ``gen`` variant (presumably the generation phase -- confirm
    against the caller). It reads ``session.host_max_gen_token``, which
    ``exchange_redrafter_buffers`` sets.

    Args:
        session: Object whose ``buffer`` mapping holds the ReDrafter tensors.
        batch_size: Number of requests; used to shape the position offsets.
        add_tensor: Callable invoked as ``add_tensor(tensor, name)``.
        add_tensor_with_shape: Callable invoked as
            ``add_tensor_with_shape(tensor, name, shape)``.
    """
    torch.cuda.nvtx.range_push("set_redrafter_gen_tensors")
    buffers = session.buffer

    # add output tensors
    for name in (
            'next_spec_decoding_generation_lengths',
            'next_flat_tokens',
            'next_draft_tokens',
            'next_draft_indices',
            'next_draft_probs',
            'next_spec_decoding_position_offsets',
            'spec_decoding_mask',
            'num_accepted_tokens',
            'accepted_beam_index',
            'packed_position_ids',
    ):
        add_tensor(buffers[name], name)

    # add all input tensors
    add_tensor(buffers['spec_decoding_generation_lengths'],
               'spec_decoding_generation_lengths')

    # position offsets vary for ReDrafter and should already be updated at
    # this point. Just need to provide the updated shape for ReDrafter.
    max_gen_len = session.host_max_gen_token
    offsets_shape = (batch_size, max_gen_len)
    # Reinterpret the leading batch_size * max_gen_len elements of the flat
    # buffer as a (batch_size, max_gen_len) view.
    flat_offsets = buffers['spec_decoding_position_offsets'].view(-1)
    position_offsets = flat_offsets[:batch_size * max_gen_len]
    add_tensor_with_shape(position_offsets.view(batch_size, max_gen_len),
                          'spec_decoding_position_offsets', offsets_shape)

    for name in (
            'spec_decoding_packed_mask',
            'draft_tokens',
            'draft_indices',
            'draft_probs',
            'rand_data_validation',
            'rand_data_sample',
            'redrafter_inverted_temperature',
            'max_gen_token',
            'total_gen_token',
            'position_ids_base',
    ):
        add_tensor(buffers[name], name)
    torch.cuda.nvtx.range_pop()


def redrafter_convert_spec_decoding_mask_to_packed_mask(
        session, spec_decoding_generation_lengths):
    """Run the ``convert_spec_decoding_mask_to_packed_mask`` custom op.

    The op receives ``session.spec_decoding_mask``,
    ``session.max_draft_tokens`` and ``session.spec_decoding_packed_mask``
    (the attributes, not the ``session.buffer`` entries); presumably it
    writes its result into the packed mask in place -- confirm against the op
    definition.

    Args:
        session: Object providing the attributes listed above.
        spec_decoding_generation_lengths: Tensor forwarded as the op's first
            argument.
    """
    torch.cuda.nvtx.range_push("mask_conversion")
    torch.ops.tensorrt_llm.convert_spec_decoding_mask_to_packed_mask(
        spec_decoding_generation_lengths, session.spec_decoding_mask,
        session.max_draft_tokens, session.spec_decoding_packed_mask, None)
    torch.cuda.nvtx.range_pop()


def exchange_redrafter_buffers(session):
    """Swap each ``name`` / ``next_name`` pair in ``session.buffer``.

    Also copies the ``'max_gen_token'`` and ``'total_gen_token'`` buffer
    values to the host and stores them as ``session.host_max_gen_token`` and
    ``session.host_total_gen_token``.

    Only the ``session.buffer`` entries are exchanged; attributes of the same
    name on ``session`` are left untouched.
    """

    # NOTE: shouldn't incur any copies
    def swap_buffers(name: str):
        next_name = "next_" + name
        session.buffer[name], session.buffer[next_name] = (
            session.buffer[next_name], session.buffer[name])

    torch.cuda.nvtx.range_push("exchange_redrafter_buffers")
    session.host_max_gen_token = session.buffer['max_gen_token'].cpu().item()
    session.host_total_gen_token = session.buffer['total_gen_token'].cpu(
    ).item()
    for name in (
            'spec_decoding_generation_lengths',
            'spec_decoding_position_offsets',
            'draft_probs',
            'draft_indices',
            'draft_tokens',
            'flat_tokens',
    ):
        swap_buffers(name)
    torch.cuda.nvtx.range_pop()


def process_redrafter_outputs(session, step, batch_size, last_draft_tokens,
                              new_draft_tokens):
    """Collect the accepted tokens of every request into ``session.new_tokens``.

    For each request the accepted sequence is ``new_draft_tokens[b, 0, :1]``.
    When ``step > 0`` it is preceded by
    ``last_draft_tokens[b, beam, 1:length]``, where ``beam`` and ``length``
    come from the ``'accepted_beam_index'`` and ``'num_accepted_tokens'``
    buffers. The per-request sequences are padded with ``session.end_ids[0]``
    into one tensor.

    Args:
        session: Object providing ``buffer`` and ``end_ids``. Its
            ``accept_lengths`` and ``new_tokens`` attributes are overwritten.
        step: Current step; draft tokens are only consulted when ``step > 0``.
        batch_size: Number of requests to process.
        last_draft_tokens: Tensor indexed as ``[request, beam, position]``.
        new_draft_tokens: Tensor indexed as ``[request, beam, position]``.

    Returns:
        Tuple ``(best_path, best_path_lengths)``: the
        ``'accepted_beam_index'`` and ``'num_accepted_tokens'`` buffers.
    """
    torch.cuda.nvtx.range_push("process_redrafter_outputs")
    best_path = session.buffer["accepted_beam_index"]
    session.accept_lengths = best_path_lengths = session.buffer[
        "num_accepted_tokens"]

    if step > 0:
        # Fetch the indices once instead of converting one tensor element to
        # a Python index per request inside the loop.
        beam_ids = best_path.tolist()
        accepted_lengths = best_path_lengths.tolist()

    accepted_tokens = [None] * batch_size
    for b in range(batch_size):
        torch.cuda.nvtx.range_push(f"accept_tokens_{b}")
        # use new beam0 to get latest true token
        accepted_tokens[b] = new_draft_tokens[b, 0, :1]
        if step > 0:
            verified_tokens = last_draft_tokens[b, beam_ids[b],
                                                1:accepted_lengths[b]]
            accepted_tokens[b] = torch.concat(
                [verified_tokens, accepted_tokens[b]])
        torch.cuda.nvtx.range_pop()

    session.new_tokens = torch.nested.to_padded_tensor(
        torch.nested.nested_tensor(accepted_tokens, dtype=torch.int32),
        session.end_ids[0])  # FIXME end id padding.
    torch.cuda.nvtx.range_pop()
    return best_path, best_path_lengths