feat(spec_decode): fuse EAGLE step slot mapping and metadata updates (#33503)

Signed-off-by: sladynnunes <snunes@usc.edu>
This commit is contained in:
Sladyn
2026-03-10 21:35:33 -07:00
committed by GitHub
parent 4bf533623b
commit 4aaaf8c8ce
3 changed files with 318 additions and 59 deletions

View File

@@ -44,6 +44,7 @@ from vllm.v1.spec_decode.utils import (
copy_and_expand_eagle_inputs_kernel,
eagle_prepare_inputs_padded_kernel,
eagle_prepare_next_token_padded_kernel,
eagle_step_update_slot_mapping_and_metadata,
extend_all_queries_by_N,
)
from vllm.v1.utils import CpuGpuBuffer
@@ -533,41 +534,46 @@ class SpecDecodeBaseProposer:
common_attn_metadata._seq_lens_cpu = None
common_attn_metadata._num_computed_tokens_cpu = None
block_size = self.block_size
assert block_size > 0, "block_size has not been initialized."
for token_index in range(self.num_speculative_tokens - 1):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
input_ids = draft_token_ids_list[-1].int()
# Use fused kernel for slot mapping and metadata updates.
# Write clamped positions directly into the positions buffer to
# avoid an extra D2D copy for the common (non-mrope) case.
positions_1d = positions[0] if self.uses_mrope else positions
if self.uses_mrope:
positions += 1
# NOTE(woosuk): We should handle the case where the draft model
# generates tokens beyond the max model length.
# Since it is complex to remove such requests from the batch,
# we keep them in the batch but adjust the position ids
# and slot mappings to avoid the
# out-of-range access during the model execution.
# The draft tokens generated with this adjustment
# should be ignored.
exceeds_max_model_len = positions[0] >= self.max_model_len
# Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE.
clamped_positions = torch.where(
exceeds_max_model_len.unsqueeze(0),
torch.zeros_like(positions),
positions,
)
out_pos = self.mrope_positions[0, :batch_size]
elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
out_pos = self.xdrope_positions[0, :batch_size]
else:
positions += 1
exceeds_max_model_len = positions >= self.max_model_len
clamped_positions = torch.where(exceeds_max_model_len, 0, positions)
# For data integrity when async scheduling, we shouldn't use in place
# operations in case they are modified in next step's `prepare_input`
# of main model.
# Increment the sequence lengths.
common_attn_metadata.seq_lens += 1
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
out_pos = self.positions[:batch_size]
eagle_step_update_slot_mapping_and_metadata(
positions_1d=positions_1d,
block_table_tensor=common_attn_metadata.block_table_tensor,
seq_lens=common_attn_metadata.seq_lens,
block_size=block_size,
max_model_len=self.max_model_len,
out_clamped_positions=out_pos,
out_slot_mapping=self._slot_mapping_buffer[:input_batch_size],
input_batch_size=input_batch_size,
)
common_attn_metadata.slot_mapping = self._slot_mapping_buffer[:batch_size]
if self.uses_mrope:
self.mrope_positions[1:, :batch_size] = self.mrope_positions[
0, :batch_size
]
positions = self.mrope_positions[:, :batch_size]
elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
self.xdrope_positions[1:, :batch_size] = self.xdrope_positions[
0, :batch_size
]
positions = self.xdrope_positions[0, :batch_size]
else:
positions = self.positions[:batch_size]
# Increment the maximum sequence length. We increment max_seq_len
# unconditionally even though some seq_lens may have been capped above,
# as max_seq_len serves as an upper bound for sequence lengths.
@@ -582,33 +588,6 @@ class SpecDecodeBaseProposer:
if common_attn_metadata._num_computed_tokens_cpu is not None:
common_attn_metadata._num_computed_tokens_cpu += 1
# Compute the slot mapping.
block_size = self.block_size
assert block_size > 0, "block_size has not been initialized."
if self.uses_mrope:
# all dimensions of positions are the same
block_numbers = clamped_positions[0] // block_size
else:
block_numbers = clamped_positions // block_size
block_ids = common_attn_metadata.block_table_tensor.gather(
dim=1, index=block_numbers.view(-1, 1)
)
block_ids = block_ids.view(-1)
if self.uses_mrope:
common_attn_metadata.slot_mapping = (
block_ids * block_size + clamped_positions[0] % block_size
)
else:
common_attn_metadata.slot_mapping = (
block_ids * block_size + clamped_positions % block_size
)
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
common_attn_metadata.slot_mapping.masked_fill_(
exceeds_max_model_len, PADDING_SLOT_ID
)
# Rebuild attention metadata
for attn_group in self.draft_attn_groups:
attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
@@ -620,7 +599,6 @@ class SpecDecodeBaseProposer:
# copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids
self._set_positions(batch_size, clamped_positions)
self.hidden_states[:batch_size] = hidden_states
if self.supports_mm_inputs:
self.inputs_embeds[:batch_size] = self.model.embed_input_ids(input_ids)
@@ -646,9 +624,7 @@ class SpecDecodeBaseProposer:
num_tokens=input_batch_size,
num_tokens_across_dp=batch_size_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=self._get_slot_mapping(
input_batch_size, common_attn_metadata.slot_mapping
),
slot_mapping=self._get_slot_mapping(input_batch_size),
):
ret_hidden_states = self.model(**model_kwargs)
if not self.model_returns_tuple():

View File

@@ -11,6 +11,114 @@ from vllm.v1.attention.backends.utils import (
# Sentinel slot id. The fused EAGLE step kernel in this file stores it into
# slot-mapping entries for cudagraph padding rows and for requests whose next
# position reaches max_model_len.
PADDING_SLOT_ID = -1
@triton.jit
def eagle_step_slot_mapping_metadata_kernel(
    positions_ptr,  # [batch_size] - current positions (1D view for M-RoPE)
    block_table_ptr,  # [batch_size, n_blocks_per_req]
    block_table_stride,  # elements between consecutive block_table rows (dim 0)
    seq_lens_ptr,  # [batch_size] - read and write
    out_clamped_positions_ptr,  # [batch_size] (output)
    out_slot_mapping_ptr,  # [input_batch_size] (output)
    block_size: tl.constexpr,
    max_model_len: tl.constexpr,
    n_blocks_per_req: tl.constexpr,
    PAD_ID: tl.constexpr,
    batch_size,
):
    """
    Fused kernel for EAGLE autoregressive step: updates positions, slot mapping,
    and sequence lengths in a single kernel to reduce launch overhead.

    One program instance runs per slot-mapping entry (the grid is indexed with
    tl.program_id(0)). Programs with req_idx >= batch_size are cudagraph
    padding slots and only write PAD_ID to out_slot_mapping.

    Each remaining program handles one request in the batch. Computes:
    - new_position = position + 1, replaced by 0 if it reaches max_model_len
    - slot_mapping from block table lookup, or PAD_ID if max_model_len is
      reached
    - seq_lens += 1 (capped at max_model_len), or 1 if max_model_len is reached

    Notes:
    - Every access is an unmasked load/store at offset req_idx, so the caller
      must size all buffers for the launched grid.
    - The block table is addressed as req_idx * block_table_stride +
      block_number, i.e. consecutive columns of a row are assumed to be
      adjacent in memory.
    - Each program loads its position before storing the clamped position, and
      only touches index req_idx, so the input and output position buffers may
      be the same memory.
    """
    req_idx = tl.program_id(0)

    # Padding entry: no request behind it, so only mark its slot as padding.
    if req_idx >= batch_size:
        tl.store(out_slot_mapping_ptr + req_idx, PAD_ID)
        return

    # Load current position and increment
    position = tl.load(positions_ptr + req_idx)
    new_position = position + 1

    # Check bounds and compute clamped position
    # (position 0 is substituted once max_model_len is reached).
    exceeds_max = new_position >= max_model_len
    clamped_position = tl.where(exceeds_max, 0, new_position)

    # Block table lookup: block_number = position // block_size
    # Clamp block_number so the read stays inside this request's row.
    block_number = clamped_position // block_size
    block_number = tl.minimum(block_number, n_blocks_per_req - 1)
    block_id = tl.load(block_table_ptr + req_idx * block_table_stride + block_number)
    slot_id = block_id * block_size + (clamped_position % block_size)
    # Requests at/after max_model_len get the padding slot instead of a real one.
    slot_id = tl.where(exceeds_max, PAD_ID, slot_id)

    # Update seq_lens: +1 normally, or 1 if exceeded
    seq_len = tl.load(seq_lens_ptr + req_idx)
    new_seq_len = tl.where(exceeds_max, 1, seq_len + 1)
    # Never report a sequence length above max_model_len.
    new_seq_len = tl.minimum(new_seq_len, max_model_len)

    # Store outputs
    tl.store(out_clamped_positions_ptr + req_idx, clamped_position)
    tl.store(out_slot_mapping_ptr + req_idx, slot_id)
    tl.store(seq_lens_ptr + req_idx, new_seq_len)


def eagle_step_update_slot_mapping_and_metadata(
positions_1d: torch.Tensor,
block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_model_len: int,
out_clamped_positions: torch.Tensor,
out_slot_mapping: torch.Tensor,
input_batch_size: int | None = None,
) -> None:
"""
Fused update of slot mapping and metadata for one EAGLE autoregressive step.
Updates seq_lens in place. Writes to out_clamped_positions and out_slot_mapping.
When input_batch_size > batch_size, threads beyond batch_size write
PADDING_SLOT_ID to out_slot_mapping for cudagraph padding.
Args:
positions_1d: [batch_size] current positions (use positions[0] for M-RoPE)
block_table_tensor: [batch_size, n_blocks_per_req]
seq_lens: [batch_size] updated in place
block_size: KV cache block size
max_model_len: max model length for clamping
out_clamped_positions: [batch_size] output buffer for clamped positions
out_slot_mapping: [input_batch_size] output buffer for slot mapping
input_batch_size: total batch size including cudagraph padding;
defaults to batch_size (no padding)
"""
batch_size = positions_1d.shape[0]
if input_batch_size is None:
input_batch_size = batch_size
n_blocks_per_req = block_table_tensor.shape[1]
eagle_step_slot_mapping_metadata_kernel[(input_batch_size,)](
positions_1d,
block_table_tensor,
block_table_tensor.stride(0),
seq_lens,
out_clamped_positions,
out_slot_mapping,
block_size=block_size,
max_model_len=max_model_len,
n_blocks_per_req=n_blocks_per_req,
PAD_ID=PADDING_SLOT_ID,
batch_size=batch_size,
)
@triton.jit
def eagle_prepare_inputs_padded_kernel(
cu_num_draft_tokens_ptr, # [num_reqs]