[Update] Use FlashInfer fast_decode_plan directly instead of replicating it (#34687)
Signed-off-by: Andrii <askliar@nvidia.com>
Co-authored-by: Andrii <askliar@nvidia.com>
@@ -84,6 +84,209 @@ def ref_paged_attn(
    return torch.cat(outputs, dim=0)


def _make_paged_kv_metadata(
    kv_lens: list[int],
    block_size: int,
    num_blocks: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build paged-KV metadata tensors for fast_plan_decode tests.

    Returns:
        kv_indptr – CPU int32, shape [num_seqs + 1]
        kv_indices – CUDA int32, shape [total_blocks]
        kv_last_page_lens – CPU int32, shape [num_seqs]
        block_tables – CUDA int32, shape [num_seqs, max_blocks_per_seq]
    """
    num_seqs = len(kv_lens)
    max_blocks = (max(kv_lens) + block_size - 1) // block_size
    block_tables = torch.randint(
        0, num_blocks, (num_seqs, max_blocks), dtype=torch.int32, device="cuda"
    )

    indptr_list = [0]
    indices_list: list[int] = []
    last_lens_list: list[int] = []
    for i, seq_len in enumerate(kv_lens):
        n = (seq_len + block_size - 1) // block_size
        indices_list.extend(block_tables[i, :n].cpu().tolist())
        indptr_list.append(indptr_list[-1] + n)
        last_lens_list.append(seq_len % block_size or block_size)
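        # `seq_len % block_size or block_size` maps a sequence that exactly
        # fills its last page to block_size rather than 0.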

    return (
        torch.tensor(indptr_list, dtype=torch.int32, device="cpu"),
        torch.tensor(indices_list, dtype=torch.int32, device="cuda"),
        torch.tensor(last_lens_list, dtype=torch.int32, device="cpu"),
        block_tables,
    )
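
# Worked example (hypothetical values, for illustration): with kv_lens=[5, 3]
# and block_size=4, sequence 0 needs ceil(5/4)=2 pages and sequence 1 needs 1,
# so the helper returns kv_indptr == [0, 2, 3] and kv_last_page_lens == [1, 3]
# (5 % 4 == 1, 3 % 4 == 3); kv_indices holds the randomly drawn page ids of
# both sequences back to back.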


def _make_cg_decode_wrapper(
    num_seqs: int,
    kv_indices_buffer: torch.Tensor,
    workspace_buffer: torch.Tensor,
    use_tensor_cores: bool = True,
) -> "flashinfer.BatchDecodeWithPagedKVCacheWrapper":
    """Create a cudagraph-enabled BatchDecodeWithPagedKVCacheWrapper.

    *kv_indices_buffer* is shared with the caller so that fast_plan_decode
    can avoid the device-to-device index copy on subsequent (cudagraph) calls.
    """
    return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer,
        "NHD",
        use_cuda_graph=True,
        paged_kv_indptr_buffer=torch.zeros(
            num_seqs + 1, dtype=torch.int32, device="cuda"
        ),
        paged_kv_indices_buffer=kv_indices_buffer,
        paged_kv_last_page_len_buffer=torch.zeros(
            num_seqs, dtype=torch.int32, device="cuda"
        ),
        use_tensor_cores=use_tensor_cores,
    )


def test_fast_decode_plan_importable() -> None:
    """fast_decode_plan must be importable from flashinfer.decode.

    This is a forward-compatibility smoke test: if FlashInfer reorganises its
    public API the import will fail before any other test does.
    """
    from flashinfer.decode import fast_decode_plan  # noqa: F401

    assert callable(fast_decode_plan)


@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode
def test_fast_plan_decode_warmup_uses_full_plan(dtype: torch.dtype) -> None:
    """On the first call fast_plan_decode must route through self.plan() and
    flip vllm_first_call to False on the wrapper object."""
    from unittest.mock import patch

    from vllm.v1.attention.backends.flashinfer import fast_plan_decode

    torch.set_default_device("cuda")
    set_random_seed(0)

    kv_lens = [128, 64]
    block_size = 16
    num_seqs = len(kv_lens)
    num_query_heads, num_kv_heads = 8, 2
    head_size = 128

    kv_indptr, kv_indices, kv_last_page_lens, _ = _make_paged_kv_metadata(
        kv_lens, block_size, NUM_BLOCKS
    )

    workspace = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
    wrapper = _make_cg_decode_wrapper(num_seqs, kv_indices.clone(), workspace)

    assert getattr(wrapper, "vllm_first_call", True) is True

    with patch.object(wrapper, "plan", wraps=wrapper.plan) as mock_plan:
        fast_plan_decode(
            wrapper,
            indptr_cpu=kv_indptr,
            indices=kv_indices,
            last_page_len_cpu=kv_last_page_lens,
            num_qo_heads=num_query_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_size,
            page_size=block_size,
            q_data_type=dtype,
            kv_data_type=dtype,
        )
    mock_plan.assert_called_once()

    assert wrapper.vllm_first_call is False, (
        "vllm_first_call should be False after the first fast_plan_decode call"
    )


@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode
def test_fast_plan_decode_matches_full_plan(
    kv_lens: list[int],
    num_heads: tuple[int, int],
    head_size: int,
    block_size: int,
    dtype: torch.dtype,
) -> None:
    """fast_plan_decode's cudagraph path (delegating to FlashInfer's
    fast_decode_plan) must produce attention output matching a standard
    plan() call within the test tolerances.

    Both the warmup call (self.plan) and the subsequent fast call
    (fast_decode_plan) are verified against the same reference.
    """
    from vllm.v1.attention.backends.flashinfer import fast_plan_decode

    torch.set_default_device("cuda")
    set_random_seed(0)
    num_seqs = len(kv_lens)
    num_query_heads, num_kv_heads = num_heads

    query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
    key_value_cache = torch.randn(
        NUM_BLOCKS, 2, block_size, num_kv_heads, head_size, dtype=dtype
    )

    kv_indptr, kv_indices, kv_last_page_lens, _ = _make_paged_kv_metadata(
        kv_lens, block_size, NUM_BLOCKS
    )

    # Reference output via the standard plan()
    workspace_ref = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
    ref_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_ref, "NHD", use_tensor_cores=True
    )
    ref_wrapper.plan(
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        num_query_heads,
        num_kv_heads,
        head_size,
        block_size,
        "NONE",
        q_data_type=dtype,
        kv_data_type=dtype,
    )
    ref_output = ref_wrapper.run(query, key_value_cache)

    # CUDAGraph wrapper exercised through fast_plan_decode
    kv_indices_buf = kv_indices.clone()
    workspace_cg = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
    cg_wrapper = _make_cg_decode_wrapper(num_seqs, kv_indices_buf, workspace_cg)

    plan_kwargs: dict = dict(
        indptr_cpu=kv_indptr,
        indices=kv_indices_buf,
        last_page_len_cpu=kv_last_page_lens,
        num_qo_heads=num_query_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_size,
        page_size=block_size,
        q_data_type=dtype,
        kv_data_type=dtype,
    )

    # First call – warmup path (routes through self.plan)
    fast_plan_decode(cg_wrapper, **plan_kwargs)
    warmup_output = cg_wrapper.run(query, key_value_cache)
    torch.testing.assert_close(warmup_output, ref_output, atol=1e-2, rtol=1e-2)

    # Second call – fast path (routes through fast_decode_plan from FlashInfer)
    fast_plan_decode(cg_wrapper, **plan_kwargs)
    fast_output = cg_wrapper.run(query, key_value_cache)
    torch.testing.assert_close(fast_output, ref_output, atol=1e-2, rtol=1e-2)


@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)