feat(spec_decode): fuse EAGLE step slot mapping and metadata updates (#33503)
Signed-off-by: sladynnunes <snunes@usc.edu>
This commit is contained in:
@@ -44,6 +44,7 @@ from vllm.v1.spec_decode.utils import (
|
||||
copy_and_expand_eagle_inputs_kernel,
|
||||
eagle_prepare_inputs_padded_kernel,
|
||||
eagle_prepare_next_token_padded_kernel,
|
||||
eagle_step_update_slot_mapping_and_metadata,
|
||||
extend_all_queries_by_N,
|
||||
)
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
@@ -533,41 +534,46 @@ class SpecDecodeBaseProposer:
|
||||
common_attn_metadata._seq_lens_cpu = None
|
||||
common_attn_metadata._num_computed_tokens_cpu = None
|
||||
|
||||
block_size = self.block_size
|
||||
assert block_size > 0, "block_size has not been initialized."
|
||||
for token_index in range(self.num_speculative_tokens - 1):
|
||||
# Update the inputs.
|
||||
# cast to int32 is crucial when eagle model is compiled.
|
||||
# tensor.argmax() returns int64 by default.
|
||||
input_ids = draft_token_ids_list[-1].int()
|
||||
# Use fused kernel for slot mapping and metadata updates.
|
||||
# Write clamped positions directly into the positions buffer to
|
||||
# avoid an extra D2D copy for the common (non-mrope) case.
|
||||
positions_1d = positions[0] if self.uses_mrope else positions
|
||||
if self.uses_mrope:
|
||||
positions += 1
|
||||
# NOTE(woosuk): We should handle the case where the draft model
|
||||
# generates tokens beyond the max model length.
|
||||
# Since it is complex to remove such requests from the batch,
|
||||
# we keep them in the batch but adjust the position ids
|
||||
# and slot mappings to avoid the
|
||||
# out-of-range access during the model execution.
|
||||
# The draft tokens generated with this adjustment
|
||||
# should be ignored.
|
||||
exceeds_max_model_len = positions[0] >= self.max_model_len
|
||||
# Mask out the position ids that exceed the max model length.
|
||||
# Otherwise, we may get out-of-range error in RoPE.
|
||||
clamped_positions = torch.where(
|
||||
exceeds_max_model_len.unsqueeze(0),
|
||||
torch.zeros_like(positions),
|
||||
positions,
|
||||
)
|
||||
out_pos = self.mrope_positions[0, :batch_size]
|
||||
elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
|
||||
out_pos = self.xdrope_positions[0, :batch_size]
|
||||
else:
|
||||
positions += 1
|
||||
exceeds_max_model_len = positions >= self.max_model_len
|
||||
clamped_positions = torch.where(exceeds_max_model_len, 0, positions)
|
||||
# For data integrity when async scheduling, we shouldn't use in place
|
||||
# operations in case they are modified in next step's `prepare_input`
|
||||
# of main model.
|
||||
# Increment the sequence lengths.
|
||||
common_attn_metadata.seq_lens += 1
|
||||
# For the requests that exceed the max model length, we set the
|
||||
# sequence length to 1 to minimize their overheads in attention.
|
||||
common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
|
||||
out_pos = self.positions[:batch_size]
|
||||
eagle_step_update_slot_mapping_and_metadata(
|
||||
positions_1d=positions_1d,
|
||||
block_table_tensor=common_attn_metadata.block_table_tensor,
|
||||
seq_lens=common_attn_metadata.seq_lens,
|
||||
block_size=block_size,
|
||||
max_model_len=self.max_model_len,
|
||||
out_clamped_positions=out_pos,
|
||||
out_slot_mapping=self._slot_mapping_buffer[:input_batch_size],
|
||||
input_batch_size=input_batch_size,
|
||||
)
|
||||
common_attn_metadata.slot_mapping = self._slot_mapping_buffer[:batch_size]
|
||||
if self.uses_mrope:
|
||||
self.mrope_positions[1:, :batch_size] = self.mrope_positions[
|
||||
0, :batch_size
|
||||
]
|
||||
positions = self.mrope_positions[:, :batch_size]
|
||||
elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
|
||||
self.xdrope_positions[1:, :batch_size] = self.xdrope_positions[
|
||||
0, :batch_size
|
||||
]
|
||||
positions = self.xdrope_positions[0, :batch_size]
|
||||
else:
|
||||
positions = self.positions[:batch_size]
|
||||
# Increment the maximum sequence length. We increment max_seq_len
|
||||
# unconditionally even though some seq_lens may have been capped above,
|
||||
# as max_seq_len serves as an upper bound for sequence lengths.
|
||||
@@ -582,33 +588,6 @@ class SpecDecodeBaseProposer:
|
||||
if common_attn_metadata._num_computed_tokens_cpu is not None:
|
||||
common_attn_metadata._num_computed_tokens_cpu += 1
|
||||
|
||||
# Compute the slot mapping.
|
||||
block_size = self.block_size
|
||||
assert block_size > 0, "block_size has not been initialized."
|
||||
if self.uses_mrope:
|
||||
# all dimensions of positions are the same
|
||||
block_numbers = clamped_positions[0] // block_size
|
||||
else:
|
||||
block_numbers = clamped_positions // block_size
|
||||
block_ids = common_attn_metadata.block_table_tensor.gather(
|
||||
dim=1, index=block_numbers.view(-1, 1)
|
||||
)
|
||||
block_ids = block_ids.view(-1)
|
||||
if self.uses_mrope:
|
||||
common_attn_metadata.slot_mapping = (
|
||||
block_ids * block_size + clamped_positions[0] % block_size
|
||||
)
|
||||
else:
|
||||
common_attn_metadata.slot_mapping = (
|
||||
block_ids * block_size + clamped_positions % block_size
|
||||
)
|
||||
# Mask out the slot mappings that exceed the max model length.
|
||||
# Otherwise, the KV cache will be inadvertently updated with the
|
||||
# padding tokens.
|
||||
common_attn_metadata.slot_mapping.masked_fill_(
|
||||
exceeds_max_model_len, PADDING_SLOT_ID
|
||||
)
|
||||
|
||||
# Rebuild attention metadata
|
||||
for attn_group in self.draft_attn_groups:
|
||||
attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
|
||||
@@ -620,7 +599,6 @@ class SpecDecodeBaseProposer:
|
||||
|
||||
# copy inputs to buffer for cudagraph
|
||||
self.input_ids[:batch_size] = input_ids
|
||||
self._set_positions(batch_size, clamped_positions)
|
||||
self.hidden_states[:batch_size] = hidden_states
|
||||
if self.supports_mm_inputs:
|
||||
self.inputs_embeds[:batch_size] = self.model.embed_input_ids(input_ids)
|
||||
@@ -646,9 +624,7 @@ class SpecDecodeBaseProposer:
|
||||
num_tokens=input_batch_size,
|
||||
num_tokens_across_dp=batch_size_across_dp,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
slot_mapping=self._get_slot_mapping(
|
||||
input_batch_size, common_attn_metadata.slot_mapping
|
||||
),
|
||||
slot_mapping=self._get_slot_mapping(input_batch_size),
|
||||
):
|
||||
ret_hidden_states = self.model(**model_kwargs)
|
||||
if not self.model_returns_tuple():
|
||||
|
||||
@@ -11,6 +11,114 @@ from vllm.v1.attention.backends.utils import (
|
||||
PADDING_SLOT_ID = -1
|
||||
|
||||
|
||||
@triton.jit
def eagle_step_slot_mapping_metadata_kernel(
    positions_ptr,  # [batch_size] - current positions (1D view for M-RoPE)
    block_table_ptr,  # [batch_size, n_blocks_per_req]
    block_table_stride,  # row stride of block_table (elements between requests)
    seq_lens_ptr,  # [batch_size] - read and write
    out_clamped_positions_ptr,  # [batch_size] (output)
    out_slot_mapping_ptr,  # [input_batch_size] (output)
    block_size: tl.constexpr,
    max_model_len: tl.constexpr,
    n_blocks_per_req: tl.constexpr,
    PAD_ID: tl.constexpr,
    batch_size,
):
    """
    Fused kernel for EAGLE autoregressive step: updates positions, slot mapping,
    and sequence lengths in a single kernel to reduce launch overhead.

    One program instance runs per entry of out_slot_mapping. Programs whose
    index is >= batch_size are cudagraph padding slots: they only write PAD_ID
    to out_slot_mapping and never touch positions or seq_lens.

    Each remaining program handles one request in the batch. Computes:
    - new_position = position + 1, replaced by 0 if it reaches max_model_len
    - slot_mapping from a block table lookup, or PAD_ID if the new position
      reaches max_model_len
    - seq_lens = seq_lens + 1 (capped at max_model_len), or 1 if the new
      position reaches max_model_len
    """
    req_idx = tl.program_id(0)

    # Padding program: mark the slot as padding and stop.
    if req_idx >= batch_size:
        tl.store(out_slot_mapping_ptr + req_idx, PAD_ID)
        return

    # Load current position and increment
    position = tl.load(positions_ptr + req_idx)
    new_position = position + 1

    # Check bounds and compute clamped position
    exceeds_max = new_position >= max_model_len
    clamped_position = tl.where(exceeds_max, 0, new_position)

    # Block table lookup: block_number = position // block_size
    # Clamp block_number to avoid OOB when position is at max
    block_number = clamped_position // block_size
    block_number = tl.minimum(block_number, n_blocks_per_req - 1)

    # block_table_stride is the distance between consecutive request rows,
    # so req_idx * block_table_stride selects this request's row.
    block_id = tl.load(block_table_ptr + req_idx * block_table_stride + block_number)
    # Widen before multiplying so the product is not computed in the block
    # table's (possibly narrower) element type.
    slot_id = block_id.to(tl.int64) * block_size + (clamped_position % block_size)
    # Requests past the max model length must not write into the KV cache.
    slot_id = tl.where(exceeds_max, PAD_ID, slot_id)

    # Update seq_lens: +1 normally, or 1 if exceeded
    seq_len = tl.load(seq_lens_ptr + req_idx)
    new_seq_len = tl.where(exceeds_max, 1, seq_len + 1)
    new_seq_len = tl.minimum(new_seq_len, max_model_len)

    # Store outputs
    tl.store(out_clamped_positions_ptr + req_idx, clamped_position)
    tl.store(out_slot_mapping_ptr + req_idx, slot_id)
    tl.store(seq_lens_ptr + req_idx, new_seq_len)
|
||||
|
||||
|
||||
def eagle_step_update_slot_mapping_and_metadata(
    positions_1d: torch.Tensor,
    block_table_tensor: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_model_len: int,
    out_clamped_positions: torch.Tensor,
    out_slot_mapping: torch.Tensor,
    input_batch_size: int | None = None,
) -> None:
    """
    Launch the fused EAGLE step kernel for one autoregressive draft step.

    The kernel modifies seq_lens in place and fills out_clamped_positions and
    out_slot_mapping. Nothing is returned.

    One kernel program is launched per entry of the (possibly padded) input
    batch. Programs with an index >= batch_size write PADDING_SLOT_ID to
    out_slot_mapping so that cudagraph padding slots are marked as padding.

    Args:
        positions_1d: [batch_size] current positions (use positions[0] for M-RoPE)
        block_table_tensor: [batch_size, n_blocks_per_req]
        seq_lens: [batch_size] updated in place
        block_size: KV cache block size
        max_model_len: max model length for clamping
        out_clamped_positions: [batch_size] output buffer for clamped positions
        out_slot_mapping: [input_batch_size] output buffer for slot mapping
        input_batch_size: total batch size including cudagraph padding;
            defaults to batch_size (no padding)
    """
    # The number of real requests is taken from the positions tensor.
    num_reqs = positions_1d.shape[0]
    # Without an explicit padded size, launch exactly one program per request.
    num_programs = num_reqs if input_batch_size is None else input_batch_size

    row_stride = block_table_tensor.stride(0)
    blocks_per_req = block_table_tensor.shape[1]

    launch_grid = (num_programs,)
    eagle_step_slot_mapping_metadata_kernel[launch_grid](
        positions_1d,
        block_table_tensor,
        row_stride,
        seq_lens,
        out_clamped_positions,
        out_slot_mapping,
        block_size=block_size,
        max_model_len=max_model_len,
        n_blocks_per_req=blocks_per_req,
        PAD_ID=PADDING_SLOT_ID,
        batch_size=num_reqs,
    )
|
||||
|
||||
|
||||
@triton.jit
|
||||
def eagle_prepare_inputs_padded_kernel(
|
||||
cu_num_draft_tokens_ptr, # [num_reqs]
|
||||
|
||||
Reference in New Issue
Block a user