diff --git a/tests/v1/spec_decode/test_eagle_step_kernel.py b/tests/v1/spec_decode/test_eagle_step_kernel.py
new file mode 100644
index 000000000..319ab4a33
--- /dev/null
+++ b/tests/v1/spec_decode/test_eagle_step_kernel.py
@@ -0,0 +1,175 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Unit tests for the fused EAGLE slot mapping kernel."""
+
+import pytest
+import torch
+
+from vllm.v1.spec_decode.utils import (
+    PADDING_SLOT_ID,
+    eagle_step_update_slot_mapping_and_metadata,
+)
+
+# Skip if no CUDA - Triton kernel requires GPU
+pytest.importorskip("triton")
+if not torch.cuda.is_available():
+    pytest.skip("CUDA required for EAGLE kernel tests", allow_module_level=True)
+
+
+def _reference_eagle_step_slot_mapping(
+    positions_1d: torch.Tensor,
+    block_table_tensor: torch.Tensor,
+    seq_lens: torch.Tensor,
+    block_size: int,
+    max_model_len: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Python reference for eagle_step_update_slot_mapping_and_metadata."""
+    new_positions = positions_1d + 1
+    exceeds_max = new_positions >= max_model_len
+    clamped_positions = torch.where(
+        exceeds_max, torch.zeros_like(positions_1d), new_positions
+    )
+    block_numbers = (clamped_positions // block_size).clamp(
+        max=block_table_tensor.shape[1] - 1
+    )
+    block_ids = block_table_tensor[
+        torch.arange(positions_1d.shape[0], device=positions_1d.device),
+        block_numbers.long(),
+    ].long()
+    slot_mapping = block_ids * block_size + (clamped_positions % block_size)
+    slot_mapping = torch.where(
+        exceeds_max, torch.full_like(slot_mapping, PADDING_SLOT_ID), slot_mapping
+    )
+    new_seq_lens = torch.where(exceeds_max, torch.ones_like(seq_lens), seq_lens + 1)
+    new_seq_lens = new_seq_lens.clamp(max=max_model_len)
+    return clamped_positions, slot_mapping, new_seq_lens
+
+
+def test_eagle_step_slot_mapping_kernel():
+    """Test fused kernel matches Python reference for slot mapping and metadata."""
+    device = torch.device("cuda")
+    batch_size = 32
+    block_size = 16
+    max_model_len = 4096
+    n_blocks_per_req = (max_model_len + block_size - 1) // block_size
+
+    positions_1d = torch.randint(
+        0, max_model_len - 10, (batch_size,), dtype=torch.int64, device=device
+    )
+    block_table_tensor = torch.randint(
+        0, 1000, (batch_size, n_blocks_per_req), dtype=torch.int32, device=device
+    )
+    seq_lens = torch.randint(1, 100, (batch_size,), dtype=torch.int32, device=device)
+
+    ref_clamped, ref_slot, ref_seq_lens = _reference_eagle_step_slot_mapping(
+        positions_1d.clone(),
+        block_table_tensor,
+        seq_lens.clone(),
+        block_size,
+        max_model_len,
+    )
+
+    out_clamped = torch.zeros(batch_size, dtype=torch.int64, device=device)
+    out_slot = torch.zeros(batch_size, dtype=torch.int64, device=device)
+    seq_lens_copy = seq_lens.clone()
+    eagle_step_update_slot_mapping_and_metadata(
+        positions_1d=positions_1d,
+        block_table_tensor=block_table_tensor,
+        seq_lens=seq_lens_copy,
+        block_size=block_size,
+        max_model_len=max_model_len,
+        out_clamped_positions=out_clamped,
+        out_slot_mapping=out_slot,
+    )
+
+    assert torch.equal(out_clamped, ref_clamped), (
+        f"clamped: {out_clamped} vs {ref_clamped}"
+    )
+    assert torch.equal(out_slot, ref_slot), f"slot: {out_slot} vs {ref_slot}"
+    assert torch.equal(seq_lens_copy, ref_seq_lens), (
+        f"seq_lens: {seq_lens_copy} vs {ref_seq_lens}"
+    )
+
+
+def test_eagle_step_slot_mapping_kernel_exceeds_max():
+    """Test fused kernel when position exceeds max_model_len."""
+    device = torch.device("cuda")
+    batch_size = 4
+    block_size = 16
+    max_model_len = 100
+    n_blocks_per_req = (max_model_len + block_size - 1) // block_size
+
+    positions_1d = torch.tensor([50, 98, 99, 100], dtype=torch.int64, device=device)
+    block_table_tensor = torch.randint(
+        0, 100, (batch_size, n_blocks_per_req), dtype=torch.int32, device=device
+    )
+    seq_lens = torch.tensor([51, 99, 100, 101], dtype=torch.int32, device=device)
+
+    out_clamped = torch.zeros(batch_size, dtype=torch.int64, device=device)
+    out_slot = torch.zeros(batch_size, dtype=torch.int64, device=device)
+    eagle_step_update_slot_mapping_and_metadata(
+        positions_1d=positions_1d,
+        block_table_tensor=block_table_tensor,
+        seq_lens=seq_lens,
+        block_size=block_size,
+        max_model_len=max_model_len,
+        out_clamped_positions=out_clamped,
+        out_slot_mapping=out_slot,
+    )
+
+    assert out_clamped[0].item() == 51
+    assert out_clamped[1].item() == 99
+    assert out_clamped[2].item() == 0
+    assert out_clamped[3].item() == 0
+    assert out_slot[2].item() == PADDING_SLOT_ID
+    assert out_slot[3].item() == PADDING_SLOT_ID
+    assert seq_lens[2].item() == 1
+    assert seq_lens[3].item() == 1
+
+
+def test_eagle_step_slot_mapping_kernel_cudagraph_padding():
+    """Test that padding threads write PADDING_SLOT_ID when
+    input_batch_size > batch_size (cudagraph padding)."""
+    device = torch.device("cuda")
+    batch_size = 4
+    input_batch_size = 8
+    block_size = 16
+    max_model_len = 4096
+    n_blocks_per_req = (max_model_len + block_size - 1) // block_size
+
+    positions_1d = torch.tensor([10, 20, 30, 40], dtype=torch.int64, device=device)
+    block_table_tensor = torch.randint(
+        0, 100, (batch_size, n_blocks_per_req), dtype=torch.int32, device=device
+    )
+    seq_lens = torch.tensor([11, 21, 31, 41], dtype=torch.int32, device=device)
+
+    ref_clamped, ref_slot, ref_seq_lens = _reference_eagle_step_slot_mapping(
+        positions_1d.clone(),
+        block_table_tensor,
+        seq_lens.clone(),
+        block_size,
+        max_model_len,
+    )
+
+    out_clamped = torch.zeros(batch_size, dtype=torch.int64, device=device)
+    out_slot = torch.full((input_batch_size,), -999, dtype=torch.int64, device=device)
+    seq_lens_copy = seq_lens.clone()
+    eagle_step_update_slot_mapping_and_metadata(
+        positions_1d=positions_1d,
+        block_table_tensor=block_table_tensor,
+        seq_lens=seq_lens_copy,
+        block_size=block_size,
+        max_model_len=max_model_len,
+        out_clamped_positions=out_clamped,
+        out_slot_mapping=out_slot,
+        input_batch_size=input_batch_size,
+    )
+
+    # Real slots should match the reference
+    assert torch.equal(out_clamped, ref_clamped)
+    assert torch.equal(out_slot[:batch_size], ref_slot)
+    assert torch.equal(seq_lens_copy, ref_seq_lens)
+
+    # Padding slots should be PADDING_SLOT_ID
+    for i in range(batch_size, input_batch_size):
+        assert out_slot[i].item() == PADDING_SLOT_ID
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 89c9c80ce..a5554d99f 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -44,6 +44,7 @@ from vllm.v1.spec_decode.utils import (
     copy_and_expand_eagle_inputs_kernel,
     eagle_prepare_inputs_padded_kernel,
     eagle_prepare_next_token_padded_kernel,
+    eagle_step_update_slot_mapping_and_metadata,
     extend_all_queries_by_N,
 )
 from vllm.v1.utils import CpuGpuBuffer
@@ -533,41 +534,46 @@ class SpecDecodeBaseProposer:
         common_attn_metadata._seq_lens_cpu = None
         common_attn_metadata._num_computed_tokens_cpu = None
 
+        block_size = self.block_size
+        assert block_size > 0, "block_size has not been initialized."
         for token_index in range(self.num_speculative_tokens - 1):
             # Update the inputs.
             # cast to int32 is crucial when eagle model is compiled.
             # tensor.argmax() returns int64 by default.
             input_ids = draft_token_ids_list[-1].int()
+            # Use fused kernel for slot mapping and metadata updates.
+            # Write clamped positions directly into the positions buffer to
+            # avoid an extra D2D copy for the common (non-mrope) case.
+            positions_1d = positions[0] if self.uses_mrope else positions
             if self.uses_mrope:
-                positions += 1
-                # NOTE(woosuk): We should handle the case where the draft model
-                # generates tokens beyond the max model length.
-                # Since it is complex to remove such requests from the batch,
-                # we keep them in the batch but adjust the position ids
-                # and slot mappings to avoid the
-                # out-of-range access during the model execution.
-                # The draft tokens generated with this adjustment
-                # should be ignored.
-                exceeds_max_model_len = positions[0] >= self.max_model_len
-                # Mask out the position ids that exceed the max model length.
-                # Otherwise, we may get out-of-range error in RoPE.
-                clamped_positions = torch.where(
-                    exceeds_max_model_len.unsqueeze(0),
-                    torch.zeros_like(positions),
-                    positions,
-                )
+                out_pos = self.mrope_positions[0, :batch_size]
+            elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
+                out_pos = self.xdrope_positions[0, :batch_size]
             else:
-                positions += 1
-                exceeds_max_model_len = positions >= self.max_model_len
-                clamped_positions = torch.where(exceeds_max_model_len, 0, positions)
-            # For data integrity when async scheduling, we shouldn't use in place
-            # operations in case they are modified in next step's `prepare_input`
-            # of main model.
-            # Increment the sequence lengths.
-            common_attn_metadata.seq_lens += 1
-            # For the requests that exceed the max model length, we set the
-            # sequence length to 1 to minimize their overheads in attention.
-            common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
+                out_pos = self.positions[:batch_size]
+            eagle_step_update_slot_mapping_and_metadata(
+                positions_1d=positions_1d,
+                block_table_tensor=common_attn_metadata.block_table_tensor,
+                seq_lens=common_attn_metadata.seq_lens,
+                block_size=block_size,
+                max_model_len=self.max_model_len,
+                out_clamped_positions=out_pos,
+                out_slot_mapping=self._slot_mapping_buffer[:input_batch_size],
+                input_batch_size=input_batch_size,
+            )
+            common_attn_metadata.slot_mapping = self._slot_mapping_buffer[:batch_size]
+            if self.uses_mrope:
+                self.mrope_positions[1:, :batch_size] = self.mrope_positions[
+                    0, :batch_size
+                ]
+                positions = self.mrope_positions[:, :batch_size]
+            elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
+                self.xdrope_positions[1:, :batch_size] = self.xdrope_positions[
+                    0, :batch_size
+                ]
+                positions = self.xdrope_positions[0, :batch_size]
+            else:
+                positions = self.positions[:batch_size]
             # Increment the maximum sequence length. We increment max_seq_len
             # unconditionally even though some seq_lens may have been capped above,
             # as max_seq_len serves as an upper bound for sequence lengths.
@@ -582,33 +588,6 @@ class SpecDecodeBaseProposer:
             if common_attn_metadata._num_computed_tokens_cpu is not None:
                 common_attn_metadata._num_computed_tokens_cpu += 1
 
-            # Compute the slot mapping.
-            block_size = self.block_size
-            assert block_size > 0, "block_size has not been initialized."
-            if self.uses_mrope:
-                # all dimensions of positions are the same
-                block_numbers = clamped_positions[0] // block_size
-            else:
-                block_numbers = clamped_positions // block_size
-            block_ids = common_attn_metadata.block_table_tensor.gather(
-                dim=1, index=block_numbers.view(-1, 1)
-            )
-            block_ids = block_ids.view(-1)
-            if self.uses_mrope:
-                common_attn_metadata.slot_mapping = (
-                    block_ids * block_size + clamped_positions[0] % block_size
-                )
-            else:
-                common_attn_metadata.slot_mapping = (
-                    block_ids * block_size + clamped_positions % block_size
-                )
-            # Mask out the slot mappings that exceed the max model length.
-            # Otherwise, the KV cache will be inadvertently updated with the
-            # padding tokens.
-            common_attn_metadata.slot_mapping.masked_fill_(
-                exceeds_max_model_len, PADDING_SLOT_ID
-            )
-
             # Rebuild attention metadata
             for attn_group in self.draft_attn_groups:
                 attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
@@ -620,7 +599,6 @@ class SpecDecodeBaseProposer:
 
             # copy inputs to buffer for cudagraph
             self.input_ids[:batch_size] = input_ids
-            self._set_positions(batch_size, clamped_positions)
             self.hidden_states[:batch_size] = hidden_states
             if self.supports_mm_inputs:
                 self.inputs_embeds[:batch_size] = self.model.embed_input_ids(input_ids)
@@ -646,9 +624,7 @@ class SpecDecodeBaseProposer:
                 num_tokens=input_batch_size,
                 num_tokens_across_dp=batch_size_across_dp,
                 cudagraph_runtime_mode=cudagraph_runtime_mode,
-                slot_mapping=self._get_slot_mapping(
-                    input_batch_size, common_attn_metadata.slot_mapping
-                ),
+                slot_mapping=self._get_slot_mapping(input_batch_size),
             ):
                 ret_hidden_states = self.model(**model_kwargs)
                 if not self.model_returns_tuple():
diff --git a/vllm/v1/spec_decode/utils.py b/vllm/v1/spec_decode/utils.py
index 387c6df9b..cfc30c3e6 100644
--- a/vllm/v1/spec_decode/utils.py
+++ b/vllm/v1/spec_decode/utils.py
@@ -11,6 +11,114 @@ from vllm.v1.attention.backends.utils import (
 PADDING_SLOT_ID = -1
 
 
+@triton.jit
+def eagle_step_slot_mapping_metadata_kernel(
+    positions_ptr,  # [batch_size] - current positions (1D view for M-RoPE)
+    block_table_ptr,  # [batch_size, n_blocks_per_req]
+    block_table_stride,  # row stride of block_table (stride of dim 0)
+    seq_lens_ptr,  # [batch_size] - read and write
+    out_clamped_positions_ptr,  # [batch_size] (output)
+    out_slot_mapping_ptr,  # [input_batch_size] (output)
+    block_size: tl.constexpr,
+    max_model_len: tl.constexpr,
+    n_blocks_per_req: tl.constexpr,
+    PAD_ID: tl.constexpr,
+    batch_size,
+):
+    """
+    Fused kernel for the EAGLE autoregressive step: updates positions, slot
+    mapping, and sequence lengths in a single kernel to reduce launch overhead.
+
+    Launched as a grid of input_batch_size programs. Programs with
+    req_idx >= batch_size are cudagraph padding slots and only write
+    PADDING_SLOT_ID.
+
+    Each real program handles one request in the batch. Computes:
+      - new_position = position + 1, clamped if it exceeds max_model_len
+      - slot_mapping from block table lookup
+      - seq_lens += 1, or 1 if position exceeds max
+    """
+    req_idx = tl.program_id(0)
+
+    if req_idx >= batch_size:
+        tl.store(out_slot_mapping_ptr + req_idx, PAD_ID)
+        return
+
+    # Load current position and increment
+    position = tl.load(positions_ptr + req_idx)
+    new_position = position + 1
+
+    # Check bounds and compute clamped position
+    exceeds_max = new_position >= max_model_len
+    clamped_position = tl.where(exceeds_max, 0, new_position)
+
+    # Block table lookup: block_number = clamped_position // block_size.
+    # Clamp block_number to avoid OOB when the position is at the max.
+    block_number = clamped_position // block_size
+    block_number = tl.minimum(block_number, n_blocks_per_req - 1)
+
+    block_id = tl.load(block_table_ptr + req_idx * block_table_stride + block_number)
+    slot_id = block_id * block_size + (clamped_position % block_size)
+    slot_id = tl.where(exceeds_max, PAD_ID, slot_id)
+
+    # Update seq_lens: +1 normally, or 1 if exceeded
+    seq_len = tl.load(seq_lens_ptr + req_idx)
+    new_seq_len = tl.where(exceeds_max, 1, seq_len + 1)
+    new_seq_len = tl.minimum(new_seq_len, max_model_len)
+
+    # Store outputs
+    tl.store(out_clamped_positions_ptr + req_idx, clamped_position)
+    tl.store(out_slot_mapping_ptr + req_idx, slot_id)
+    tl.store(seq_lens_ptr + req_idx, new_seq_len)
+
+
+def eagle_step_update_slot_mapping_and_metadata(
+    positions_1d: torch.Tensor,
+    block_table_tensor: torch.Tensor,
+    seq_lens: torch.Tensor,
+    block_size: int,
+    max_model_len: int,
+    out_clamped_positions: torch.Tensor,
+    out_slot_mapping: torch.Tensor,
+    input_batch_size: int | None = None,
+) -> None:
+    """
+    Fused update of slot mapping and metadata for one EAGLE autoregressive step.
+    Updates seq_lens in place. Writes to out_clamped_positions and
+    out_slot_mapping.
+
+    When input_batch_size > batch_size, programs beyond batch_size write
+    PADDING_SLOT_ID to out_slot_mapping for cudagraph padding.
+
+    Args:
+        positions_1d: [batch_size] current positions (use positions[0] for M-RoPE)
+        block_table_tensor: [batch_size, n_blocks_per_req]
+        seq_lens: [batch_size] updated in place
+        block_size: KV cache block size
+        max_model_len: max model length for clamping
+        out_clamped_positions: [batch_size] output buffer for clamped positions
+        out_slot_mapping: [input_batch_size] output buffer for slot mapping
+        input_batch_size: total batch size including cudagraph padding;
+            defaults to batch_size (no padding)
+    """
+    batch_size = positions_1d.shape[0]
+    if input_batch_size is None:
+        input_batch_size = batch_size
+    n_blocks_per_req = block_table_tensor.shape[1]
+
+    eagle_step_slot_mapping_metadata_kernel[(input_batch_size,)](
+        positions_1d,
+        block_table_tensor,
+        block_table_tensor.stride(0),
+        seq_lens,
+        out_clamped_positions,
+        out_slot_mapping,
+        block_size=block_size,
+        max_model_len=max_model_len,
+        n_blocks_per_req=n_blocks_per_req,
+        PAD_ID=PADDING_SLOT_ID,
+        batch_size=batch_size,
+    )
+
+
 @triton.jit
 def eagle_prepare_inputs_padded_kernel(
     cu_num_draft_tokens_ptr,  # [num_reqs]
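
Addendum (not part of the patch): a minimal hand-check of the new helper, assuming a CUDA device with Triton available. It uses only the names introduced above, walks one request through a single step, and also exercises the cudagraph padding path. The expected slot follows the kernel arithmetic: the next position 32 sits in block_number 32 // 16 = 2, so slot = block_table[0][2] * 16 + 32 % 16 = 144.

    import torch

    from vllm.v1.spec_decode.utils import (
        PADDING_SLOT_ID,
        eagle_step_update_slot_mapping_and_metadata,
    )

    device = torch.device("cuda")
    block_size = 16

    # One request currently at position 31: the next draft token lands at
    # position 32, i.e. block_number 2 with in-block offset 0.
    positions = torch.tensor([31], dtype=torch.int64, device=device)
    block_table = torch.tensor([[7, 8, 9, 10]], dtype=torch.int32, device=device)
    seq_lens = torch.tensor([32], dtype=torch.int32, device=device)
    out_pos = torch.empty(1, dtype=torch.int64, device=device)
    # Slot buffer sized to 4, as if the batch were padded for cudagraph capture.
    out_slot = torch.empty(4, dtype=torch.int64, device=device)

    eagle_step_update_slot_mapping_and_metadata(
        positions_1d=positions,
        block_table_tensor=block_table,
        seq_lens=seq_lens,  # updated in place to [33]
        block_size=block_size,
        max_model_len=64,
        out_clamped_positions=out_pos,
        out_slot_mapping=out_slot,
        input_batch_size=4,
    )

    assert out_pos.item() == 32
    assert out_slot[0].item() == 9 * block_size  # block 9 * 16 + offset 0 = 144
    assert (out_slot[1:] == PADDING_SLOT_ID).all()  # padding programs wrote PAD_ID
    assert seq_lens.item() == 33

Note that seq_lens is mutated in place by the kernel, which is why the unit tests above pass a clone whenever they still need the original values for the reference computation.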