[Attention] Use sparse prefill kernel for fp8 kv-cache in DeepSeek-v3.2 (#27532)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson
2025-12-12 08:57:47 -05:00
committed by GitHub
parent 91401c7a26
commit 3e41992fec
30 changed files with 1372 additions and 256 deletions

View File

@@ -22,10 +22,14 @@ from tests.v1.attention.utils import (
)
from vllm import _custom_ops as ops
from vllm.attention.ops import flashmla
from vllm.config import set_current_vllm_config
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.mla.flashmla_sparse import FlashMLASparseBackend
from vllm.v1.attention.backends.mla.indexer import split_prefill_chunks
from vllm.v1.attention.backends.mla.flashmla_sparse import (
FlashMLASparseBackend,
triton_convert_req_index_to_global_index,
)
from vllm.v1.attention.backends.utils import split_prefill_chunks
SPARSE_BACKEND_BATCH_SPECS = {
name: BATCH_SPECS[name]
@@ -114,8 +118,12 @@ def _quantize_dequantize_fp8_ds_mla(
@pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys()))
@pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
@pytest.mark.skipif(
torch.cuda.get_device_capability() < (9, 0),
reason="FlashMLASparseBackend requires CUDA 9.0 or higher",
)
def test_sparse_backend_decode_correctness(
dist_init, batch_name, kv_cache_dtype, tensor_parallel_size
dist_init, batch_name, kv_cache_dtype, tensor_parallel_size, workspace_init
):
if not torch.cuda.is_available():
pytest.skip("CUDA is required for sparse MLA decode test")
@@ -320,28 +328,29 @@ def test_sparse_backend_decode_correctness(
mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T.contiguous())
impl_cls = FlashMLASparseBackend.get_impl_cls()
impl = impl_cls(
num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=1,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
logits_soft_cap=None,
attn_type="decoder",
kv_sharing_target_layer_name=None,
q_lora_rank=None,
kv_lora_rank=kv_lora_rank,
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
v_head_dim=v_head_dim,
kv_b_proj=mock_kv_b_proj,
indexer=mock_indexer,
)
with set_current_vllm_config(vllm_config):
impl = impl_cls(
num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=1,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
logits_soft_cap=None,
attn_type="decoder",
kv_sharing_target_layer_name=None,
q_lora_rank=None,
kv_lora_rank=kv_lora_rank,
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
v_head_dim=v_head_dim,
kv_b_proj=mock_kv_b_proj,
indexer=mock_indexer,
)
impl.process_weights_after_loading(dtype)
impl.process_weights_after_loading(dtype)
layer = MockAttentionLayer(device)
out_buffer = torch.empty(
@@ -366,22 +375,192 @@ def test_sparse_backend_decode_correctness(
torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.5, atol=0.5)
def _triton_convert_reference_impl(
req_ids: torch.Tensor,
block_table: torch.Tensor,
token_indices: torch.Tensor,
block_size: int,
num_topk_tokens: int,
HAS_PREFILL_WORKSPACE: bool = False,
prefill_workspace_request_ids: torch.Tensor | None = None,
prefill_workspace_starts: torch.Tensor | None = None,
) -> torch.Tensor:
"""Reference implementation for triton_convert_req_index_to_global_index."""
num_tokens = req_ids.shape[0]
max_blocks_per_req = block_table.shape[1]
result = torch.empty(
num_tokens, num_topk_tokens, dtype=torch.int32, device=req_ids.device
)
for token_id in range(num_tokens):
req_id = req_ids[token_id].item()
# Determine if this token uses workspace or paged cache
use_prefill_workspace = False
workspace_start = 0
if HAS_PREFILL_WORKSPACE and prefill_workspace_request_ids is not None:
assert prefill_workspace_starts is not None
prefill_req_id = prefill_workspace_request_ids[token_id].item()
if prefill_req_id >= 0:
use_prefill_workspace = True
workspace_start = prefill_workspace_starts[prefill_req_id].item()
for idx_id in range(num_topk_tokens):
token_idx = token_indices[token_id, idx_id].item()
if token_idx == -1:
result[token_id, idx_id] = -1
elif use_prefill_workspace:
# Prefill + using prefill workspace: map to workspace offset
result[token_id, idx_id] = workspace_start + token_idx
else:
# Decode: map to paged cache
block_id = token_idx // block_size
if block_id >= max_blocks_per_req:
result[token_id, idx_id] = -1
else:
block_num = block_table[req_id, block_id].item()
offset = token_idx % block_size
result[token_id, idx_id] = block_num * block_size + offset
return result
@pytest.mark.parametrize("block_size", [16, 64, 128])
@pytest.mark.parametrize("num_topk_tokens", [128, 256, 512])
@pytest.mark.skipif(
torch.cuda.get_device_capability() < (9, 0),
reason="FlashMLASparseBackend requires CUDA 9.0 or higher",
)
def test_triton_convert_req_index_to_global_index_decode_only(
block_size, num_topk_tokens
):
device = torch.device("cuda")
num_tokens = 8
num_requests = 4
max_blocks_per_req = 10
req_id = torch.randint(
0, num_requests, (num_tokens,), dtype=torch.int32, device=device
)
block_table = torch.randint(
0, 100, (num_requests, max_blocks_per_req), dtype=torch.int32, device=device
)
token_indices = torch.randint(
0,
block_size * max_blocks_per_req,
(num_tokens, num_topk_tokens),
dtype=torch.int32,
device=device,
)
# Set some to -1 to test masking
token_indices[0, :10] = -1
token_indices[3, 50:60] = -1
# Set some to out of bounds
token_indices[2, 100:110] = max_blocks_per_req * block_size
token_indices[6, 150:160] = max_blocks_per_req * block_size
result = triton_convert_req_index_to_global_index(
req_id,
block_table,
token_indices,
BLOCK_SIZE=block_size,
NUM_TOPK_TOKENS=num_topk_tokens,
)
reference_result = _triton_convert_reference_impl(
req_id,
block_table,
token_indices,
block_size,
num_topk_tokens,
)
torch.testing.assert_close(result, reference_result, rtol=0, atol=0)
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.skipif(
torch.cuda.get_device_capability() < (9, 0),
reason="FlashMLASparseBackend requires CUDA 9.0 or higher",
)
def test_triton_convert_req_index_to_global_index_with_prefill_workspace(block_size):
device = torch.device("cuda")
num_requests = 4
max_blocks_per_req = 8
num_topk_tokens = 128
# First 6 tokens are decode (reqs 0, 1), last 6 are prefill (reqs 2, 3)
req_id = torch.tensor(
[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], dtype=torch.int32, device=device
)
prefill_workspace_request_ids = torch.tensor(
[-1, -1, -1, -1, -1, -1, 0, 0, 0, 1, 1, 1], dtype=torch.int32, device=device
)
# Workspace starts for the 2 prefill reqs: req 2 starts at 0, req 3 starts at 100
prefill_workspace_starts = torch.tensor([0, 100], dtype=torch.int32, device=device)
block_table = torch.randint(
0, 50, (num_requests, max_blocks_per_req), dtype=torch.int32, device=device
)
token_indices = torch.randint(
0,
block_size * max_blocks_per_req,
(req_id.shape[0], num_topk_tokens),
dtype=torch.int32,
device=device,
)
# Set some to -1 to test masking
token_indices[0, :10] = -1
token_indices[3, 50:60] = -1
# Set some to out of bounds
token_indices[2, 100:110] = max_blocks_per_req * block_size
token_indices[6, 150:160] = max_blocks_per_req * block_size
result = triton_convert_req_index_to_global_index(
req_id,
block_table,
token_indices,
BLOCK_SIZE=block_size,
NUM_TOPK_TOKENS=num_topk_tokens,
HAS_PREFILL_WORKSPACE=True,
prefill_workspace_request_ids=prefill_workspace_request_ids,
prefill_workspace_starts=prefill_workspace_starts,
)
reference_result = _triton_convert_reference_impl(
req_id,
block_table,
token_indices,
block_size,
num_topk_tokens,
HAS_PREFILL_WORKSPACE=True,
prefill_workspace_request_ids=prefill_workspace_request_ids,
prefill_workspace_starts=prefill_workspace_starts,
)
torch.testing.assert_close(result, reference_result, rtol=0, atol=0)
@pytest.mark.parametrize(
"seq_lens,max_buf,start,expected",
"seq_lens,max_buf,expected",
[
# Basic split: totals per chunk ≤ max_buf
(torch.tensor([2, 3, 4, 2]), 5, 0, [(0, 2), (2, 3), (3, 4)]),
# Non-zero start index
(torch.tensor([2, 3, 4, 2]), 5, 1, [(1, 2), (2, 3), (3, 4)]),
# Exact fits should split between items when adding the next would
# overflow
(torch.tensor([5, 5, 5]), 5, 0, [(0, 1), (1, 2), (2, 3)]),
(torch.tensor([2, 3, 4, 2]), 5, [(0, 2), (2, 3), (3, 4)]),
# Exact fits should split between items when adding the next would overflow
(torch.tensor([5, 5, 5]), 5, [(0, 1), (1, 2), (2, 3)]),
# All requests fit in a single chunk
(torch.tensor([1, 1, 1]), 10, 0, [(0, 3)]),
# Large buffer with non-zero start
(torch.tensor([4, 4, 4]), 100, 1, [(1, 3)]),
(torch.tensor([1, 1, 1]), 10, [(0, 3)]),
# Large buffer
(torch.tensor([4, 4, 4]), 100, [(0, 3)]),
],
)
def test_split_prefill_chunks(seq_lens, max_buf, start, expected):
out = split_prefill_chunks(seq_lens, max_buf, start)
def test_split_prefill_chunks(seq_lens, max_buf, expected):
out = split_prefill_chunks(seq_lens, max_buf)
assert out == expected