diff --git a/tests/kernels/attention/test_flashinfer.py b/tests/kernels/attention/test_flashinfer.py index 570bf7fc8..9a0847697 100644 --- a/tests/kernels/attention/test_flashinfer.py +++ b/tests/kernels/attention/test_flashinfer.py @@ -84,6 +84,209 @@ def ref_paged_attn( return torch.cat(outputs, dim=0) +def _make_paged_kv_metadata( + kv_lens: list[int], + block_size: int, + num_blocks: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Build paged-KV metadata tensors for fast_plan_decode tests. + + Returns: + kv_indptr – CPU int32, shape [num_seqs + 1] + kv_indices – CUDA int32, shape [total_blocks] + kv_last_page_lens – CPU int32, shape [num_seqs] + block_tables – CUDA int32, shape [num_seqs, max_blocks_per_seq] + """ + num_seqs = len(kv_lens) + max_blocks = (max(kv_lens) + block_size - 1) // block_size + block_tables = torch.randint( + 0, num_blocks, (num_seqs, max_blocks), dtype=torch.int32, device="cuda" + ) + + indptr_list = [0] + indices_list: list[int] = [] + last_lens_list: list[int] = [] + for i, seq_len in enumerate(kv_lens): + n = (seq_len + block_size - 1) // block_size + indices_list.extend(block_tables[i, :n].cpu().tolist()) + indptr_list.append(indptr_list[-1] + n) + last_lens_list.append(seq_len % block_size or block_size) + + return ( + torch.tensor(indptr_list, dtype=torch.int32, device="cpu"), + torch.tensor(indices_list, dtype=torch.int32, device="cuda"), + torch.tensor(last_lens_list, dtype=torch.int32, device="cpu"), + block_tables, + ) + + +def _make_cg_decode_wrapper( + num_seqs: int, + kv_indices_buffer: torch.Tensor, + workspace_buffer: torch.Tensor, + use_tensor_cores: bool = True, +) -> "flashinfer.BatchDecodeWithPagedKVCacheWrapper": + """Create a cudagraph-enabled BatchDecodeWithPagedKVCacheWrapper. + + *kv_indices_buffer* is shared with the caller so that fast_plan_decode + can avoid the device-to-device index copy on subsequent (cudagraph) calls. + """ + return flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, + "NHD", + use_cuda_graph=True, + paged_kv_indptr_buffer=torch.zeros( + num_seqs + 1, dtype=torch.int32, device="cuda" + ), + paged_kv_indices_buffer=kv_indices_buffer, + paged_kv_last_page_len_buffer=torch.zeros( + num_seqs, dtype=torch.int32, device="cuda" + ), + use_tensor_cores=use_tensor_cores, + ) + + +def test_fast_decode_plan_importable() -> None: + """fast_decode_plan must be importable from flashinfer.decode. + + This is a forward-compatibility smoke test: if FlashInfer reorganises its + public API the import will fail before any other test does. 
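+
+    vLLM's fast_plan_decode delegates to fast_decode_plan on its cudagraph
+    path, so if the import breaks, this test fails with a clear message
+    instead of the failure surfacing deep inside attention planning.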
+ """ + from flashinfer.decode import fast_decode_plan # noqa: F401 + + assert callable(fast_decode_plan) + + +@pytest.mark.parametrize("dtype", DTYPES) +@torch.inference_mode +def test_fast_plan_decode_warmup_uses_full_plan(dtype: torch.dtype) -> None: + """On the first call fast_plan_decode must route through self.plan() and + flip vllm_first_call to False on the wrapper object.""" + from unittest.mock import patch + + from vllm.v1.attention.backends.flashinfer import fast_plan_decode + + torch.set_default_device("cuda") + set_random_seed(0) + + kv_lens = [128, 64] + block_size = 16 + num_seqs = len(kv_lens) + num_query_heads, num_kv_heads = 8, 2 + head_size = 128 + + kv_indptr, kv_indices, kv_last_page_lens, _ = _make_paged_kv_metadata( + kv_lens, block_size, NUM_BLOCKS + ) + + workspace = torch.empty(128 * 1024 * 1024, dtype=torch.int8) + wrapper = _make_cg_decode_wrapper(num_seqs, kv_indices.clone(), workspace) + + assert getattr(wrapper, "vllm_first_call", True) is True + + with patch.object(wrapper, "plan", wraps=wrapper.plan) as mock_plan: + fast_plan_decode( + wrapper, + indptr_cpu=kv_indptr, + indices=kv_indices, + last_page_len_cpu=kv_last_page_lens, + num_qo_heads=num_query_heads, + num_kv_heads=num_kv_heads, + head_dim=head_size, + page_size=block_size, + q_data_type=dtype, + kv_data_type=dtype, + ) + mock_plan.assert_called_once() + + assert wrapper.vllm_first_call is False, ( + "vllm_first_call should be False after the first fast_plan_decode call" + ) + + +@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@torch.inference_mode +def test_fast_plan_decode_matches_full_plan( + kv_lens: list[int], + num_heads: tuple[int, int], + head_size: int, + block_size: int, + dtype: torch.dtype, +) -> None: + """fast_plan_decode's cudagraph path (delegating to FlashInfer's + fast_decode_plan) must produce attention output numerically identical to + a standard plan() call. + + Both the warmup call (self.plan) and the subsequent fast call + (fast_decode_plan) are verified against the same reference. 
+ """ + from vllm.v1.attention.backends.flashinfer import fast_plan_decode + + torch.set_default_device("cuda") + set_random_seed(0) + num_seqs = len(kv_lens) + num_query_heads, num_kv_heads = num_heads + + query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) + key_value_cache = torch.randn( + NUM_BLOCKS, 2, block_size, num_kv_heads, head_size, dtype=dtype + ) + + kv_indptr, kv_indices, kv_last_page_lens, _ = _make_paged_kv_metadata( + kv_lens, block_size, NUM_BLOCKS + ) + + # Reference output via the standard plan() + workspace_ref = torch.empty(128 * 1024 * 1024, dtype=torch.int8) + ref_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_ref, "NHD", use_tensor_cores=True + ) + ref_wrapper.plan( + kv_indptr, + kv_indices, + kv_last_page_lens, + num_query_heads, + num_kv_heads, + head_size, + block_size, + "NONE", + q_data_type=dtype, + kv_data_type=dtype, + ) + ref_output = ref_wrapper.run(query, key_value_cache) + + # CUDAGraph wrapper exercised through fast_plan_decode + kv_indices_buf = kv_indices.clone() + workspace_cg = torch.empty(128 * 1024 * 1024, dtype=torch.int8) + cg_wrapper = _make_cg_decode_wrapper(num_seqs, kv_indices_buf, workspace_cg) + + plan_kwargs: dict = dict( + indptr_cpu=kv_indptr, + indices=kv_indices_buf, + last_page_len_cpu=kv_last_page_lens, + num_qo_heads=num_query_heads, + num_kv_heads=num_kv_heads, + head_dim=head_size, + page_size=block_size, + q_data_type=dtype, + kv_data_type=dtype, + ) + + # First call – warmup path (routes through self.plan) + fast_plan_decode(cg_wrapper, **plan_kwargs) + warmup_output = cg_wrapper.run(query, key_value_cache) + torch.testing.assert_close(warmup_output, ref_output, atol=1e-2, rtol=1e-2) + + # Second call – fast path (routes through fast_decode_plan from FlashInfer) + fast_plan_decode(cg_wrapper, **plan_kwargs) + fast_output = cg_wrapper.run(query, key_value_cache) + torch.testing.assert_close(fast_output, ref_output, atol=1e-2, rtol=1e-2) + + @pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 80297720d..5300cf56c 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -13,7 +13,7 @@ from flashinfer import ( BatchPrefillWithRaggedKVCacheWrapper, MultiLevelCascadeAttentionWrapper, ) -from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache +from flashinfer.decode import fast_decode_plan, trtllm_batch_decode_with_kv_cache from flashinfer.prefill import trtllm_batch_context_with_kv_cache from flashinfer.utils import FP4Tensor from typing_extensions import override @@ -199,14 +199,14 @@ class BatchDCPPrefillWrapper: ): """Plan the prefill operation with given parameters.""" self._context.plan( - qo_indptr_cpu, - paged_kv_indptr_cpu, - paged_kv_indices, - paged_kv_last_page_len_cpu, - num_qo_heads * dcp_world_size, - num_kv_heads, - head_dim, - page_size, + qo_indptr=qo_indptr_cpu, + paged_kv_indptr=paged_kv_indptr_cpu, + paged_kv_indices=paged_kv_indices, + paged_kv_last_page_len=paged_kv_last_page_len_cpu, + num_qo_heads=num_qo_heads * dcp_world_size, + num_kv_heads=num_kv_heads, + head_dim_qk=head_dim, + page_size=page_size, causal=False, # This is context run sm_scale=sm_scale, window_left=window_left, @@ -818,6 +818,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): 
page_size, paged_kv_last_page_len_np, ) + self.paged_kv_last_page_len.gpu[:num_reqs].copy_( + self.paged_kv_last_page_len.cpu[:num_reqs], non_blocking=True + ) return paged_kv_indices def build( @@ -999,14 +1002,17 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): attn_metadata.cascade_wrapper = self._get_cascade_wrapper() attn_metadata.cascade_wrapper.plan( - [shared_qo_indptr_cpu, qo_indptr_cpu], - [shared_kv_page_indptr_cpu, paged_kv_indptr_cpu], - [shared_kv_page_indices_cpu, paged_kv_indices], - [shared_kv_last_page_len_cpu, paged_kv_last_page_len_cpu], - self.num_qo_heads, - self.num_kv_heads, - self.head_dim, - self.page_size, + qo_indptr_arr=[shared_qo_indptr_cpu, qo_indptr_cpu], + paged_kv_indptr_arr=[shared_kv_page_indptr_cpu, paged_kv_indptr_cpu], + paged_kv_indices_arr=[shared_kv_page_indices_cpu, paged_kv_indices], + paged_kv_last_page_len=[ + shared_kv_last_page_len_cpu, + paged_kv_last_page_len_cpu, + ], + num_qo_heads=self.num_qo_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + page_size=self.page_size, causal=True, sm_scale=self.sm_scale, window_left=self.window_left, @@ -1084,14 +1090,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): BatchPrefillWithPagedKVCacheWrapper, ) prefill_wrapper.plan( - qo_indptr_prefill_cpu, - paged_kv_indptr_prefill_cpu, - paged_kv_indices, - paged_kv_last_page_len_prefill_cpu, - self.num_qo_heads, - self.num_kv_heads, - self.head_dim, - self.page_size, + qo_indptr=qo_indptr_prefill_cpu, + paged_kv_indptr=paged_kv_indptr_prefill_cpu, + paged_kv_indices=paged_kv_indices, + paged_kv_last_page_len=paged_kv_last_page_len_prefill_cpu, + num_qo_heads=self.num_qo_heads, + num_kv_heads=self.num_kv_heads, + head_dim_qk=self.head_dim, + page_size=self.page_size, causal=True, sm_scale=self.sm_scale, window_left=self.window_left, @@ -1132,14 +1138,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): # in atten_metadata when using cudagraph. fast_plan_decode( decode_wrapper, - self.paged_kv_indptr.cpu[: num_input_tokens + 1], - paged_kv_indices, - self.paged_kv_last_page_len.cpu[:num_input_tokens], - seq_lens_cpu[:num_input_tokens], - self.num_qo_heads * self.dcp_world_size, - self.num_kv_heads, - self.head_dim, - self.page_size, + indptr_cpu=self.paged_kv_indptr.cpu[: num_input_tokens + 1], + indices=paged_kv_indices, + last_page_len_cpu=self.paged_kv_last_page_len.cpu[ + :num_input_tokens + ], + num_qo_heads=self.num_qo_heads * self.dcp_world_size, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + page_size=self.page_size, # Disable flashinfer's pos encoding and use vllm's rope. pos_encoding_mode="NONE", sm_scale=self.sm_scale, @@ -1617,7 +1624,6 @@ def fast_plan_decode( indptr_cpu: torch.Tensor, indices: torch.Tensor, last_page_len_cpu: torch.Tensor, - seq_lens_cpu: torch.Tensor, num_qo_heads: int, num_kv_heads: int, head_dim: int, @@ -1654,110 +1660,56 @@ def fast_plan_decode( # this warm up is to generate the _cached_module for the decode wrapper. 
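+    # On later calls the cudagraph branch below delegates to FlashInfer's
+    # lightweight fast_decode_plan instead of re-running the full plan().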
if not self.is_cuda_graph_enabled or getattr(self, "vllm_first_call", True): self.plan( - indptr_cpu, - indices, - last_page_len_cpu, - num_qo_heads, - num_kv_heads, - head_dim, - page_size, - pos_encoding_mode, - window_left, - logits_soft_cap, - q_data_type, - kv_data_type, - o_data_type, - data_type, - sm_scale, - rope_scale, - rope_theta, - non_blocking, - None, # block_tables - None, # seq_lens - fixed_split_size, - disable_split_kv, + indptr=indptr_cpu, + indices=indices, + last_page_len=last_page_len_cpu, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + page_size=page_size, + pos_encoding_mode=pos_encoding_mode, + window_left=window_left, + logits_soft_cap=logits_soft_cap, + q_data_type=q_data_type, + kv_data_type=kv_data_type, + o_data_type=o_data_type, + data_type=data_type, + sm_scale=sm_scale, + rope_scale=rope_scale, + rope_theta=rope_theta, + non_blocking=non_blocking, + block_tables=None, + seq_lens=None, + fixed_split_size=fixed_split_size, + disable_split_kv=disable_split_kv, ) self.vllm_first_call = False return assert self.is_cuda_graph_enabled, "Should be cudagraph only here" - batch_size = len(last_page_len_cpu) - if logits_soft_cap is None: - logits_soft_cap = 0.0 - - # Handle data types consistently - if data_type is not None: - if q_data_type is None: - q_data_type = data_type - if kv_data_type is None: - kv_data_type = data_type - elif q_data_type is None: - q_data_type = "float16" - - if kv_data_type is None: - kv_data_type = q_data_type - q_data_type = ( - getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type + fast_decode_plan( + self, + indptr=indptr_cpu, + indices=indices, + last_page_len=last_page_len_cpu, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + page_size=page_size, + pos_encoding_mode=pos_encoding_mode, + window_left=window_left, + logits_soft_cap=logits_soft_cap, + q_data_type=q_data_type, + kv_data_type=kv_data_type, + data_type=data_type, + sm_scale=sm_scale, + rope_scale=rope_scale, + rope_theta=rope_theta, + non_blocking=non_blocking, + fixed_split_size=fixed_split_size, + disable_split_kv=disable_split_kv, ) - kv_data_type = ( - getattr(torch, kv_data_type) if isinstance(kv_data_type, str) else kv_data_type - ) - - if batch_size != self._fixed_batch_size: - raise ValueError( - "The batch size should be fixed in cudagraph mode, the runtime " - "batch size {} mismatches the batch size set during " - "initialization {}".format(batch_size, self._fixed_batch_size) - ) - if len(indices) > len(self._paged_kv_indices_buf): - raise ValueError( - "The size of indices should be less than or equal to the allocated buffer" - ) - - # host-to-device copy for the indptr buffer - self._paged_kv_indptr_buf.copy_(indptr_cpu, non_blocking=True) - # host-to-device copy for the last_page_len buffer - self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu, non_blocking=True) - - qo_indptr_host = _get_range_buf(batch_size + 1, "cpu") - - try: - # Make sure we pass exactly 19 arguments for tensor core version - args = [ - self._float_workspace_buffer, - self._int_workspace_buffer, - self._pin_memory_int_workspace_buffer, - qo_indptr_host, - indptr_cpu, - seq_lens_cpu, - batch_size, # total_num_rows - batch_size, - num_qo_heads, - num_kv_heads, - page_size, - self.is_cuda_graph_enabled, - head_dim, - head_dim, - False, # causal - window_left, - ] - if self._backend == "fa2": - args.append(fixed_split_size) - args.append(disable_split_kv) - args.append(0) # num_colocated_ctas - 
self._plan_info = self._cached_module.plan( - *args, - ) - except Exception as e: - raise RuntimeError(f"Error in tensor core plan: {e}") from e - - self._pos_encoding_mode = pos_encoding_mode - self._window_left = window_left - self._logits_soft_cap = logits_soft_cap - self._sm_scale = sm_scale - self._rope_scale = rope_scale - self._rope_theta = rope_theta @triton.jit