[Attention] Use sparse prefill kernel for fp8 kv-cache in DeepSeek-v3.2 (#27532)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
@@ -22,10 +22,14 @@ from tests.v1.attention.utils import (
|
||||
)
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.ops import flashmla
|
||||
from vllm.config import set_current_vllm_config
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backends.mla.flashmla_sparse import FlashMLASparseBackend
|
||||
from vllm.v1.attention.backends.mla.indexer import split_prefill_chunks
|
||||
from vllm.v1.attention.backends.mla.flashmla_sparse import (
|
||||
FlashMLASparseBackend,
|
||||
triton_convert_req_index_to_global_index,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import split_prefill_chunks
|
||||
|
||||
SPARSE_BACKEND_BATCH_SPECS = {
|
||||
name: BATCH_SPECS[name]
|
||||
@@ -114,8 +118,12 @@ def _quantize_dequantize_fp8_ds_mla(
|
||||
@pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys()))
|
||||
@pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"])
|
||||
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
|
||||
@pytest.mark.skipif(
|
||||
torch.cuda.get_device_capability() < (9, 0),
|
||||
reason="FlashMLASparseBackend requires CUDA 9.0 or higher",
|
||||
)
|
||||
def test_sparse_backend_decode_correctness(
|
||||
dist_init, batch_name, kv_cache_dtype, tensor_parallel_size
|
||||
dist_init, batch_name, kv_cache_dtype, tensor_parallel_size, workspace_init
|
||||
):
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA is required for sparse MLA decode test")
|
||||
@@ -320,28 +328,29 @@ def test_sparse_backend_decode_correctness(
|
||||
mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T.contiguous())
|
||||
|
||||
impl_cls = FlashMLASparseBackend.get_impl_cls()
|
||||
impl = impl_cls(
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
scale=scale,
|
||||
num_kv_heads=1,
|
||||
alibi_slopes=None,
|
||||
sliding_window=None,
|
||||
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
|
||||
logits_soft_cap=None,
|
||||
attn_type="decoder",
|
||||
kv_sharing_target_layer_name=None,
|
||||
q_lora_rank=None,
|
||||
kv_lora_rank=kv_lora_rank,
|
||||
qk_nope_head_dim=qk_nope_head_dim,
|
||||
qk_rope_head_dim=qk_rope_head_dim,
|
||||
qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
|
||||
v_head_dim=v_head_dim,
|
||||
kv_b_proj=mock_kv_b_proj,
|
||||
indexer=mock_indexer,
|
||||
)
|
||||
with set_current_vllm_config(vllm_config):
|
||||
impl = impl_cls(
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
scale=scale,
|
||||
num_kv_heads=1,
|
||||
alibi_slopes=None,
|
||||
sliding_window=None,
|
||||
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
|
||||
logits_soft_cap=None,
|
||||
attn_type="decoder",
|
||||
kv_sharing_target_layer_name=None,
|
||||
q_lora_rank=None,
|
||||
kv_lora_rank=kv_lora_rank,
|
||||
qk_nope_head_dim=qk_nope_head_dim,
|
||||
qk_rope_head_dim=qk_rope_head_dim,
|
||||
qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
|
||||
v_head_dim=v_head_dim,
|
||||
kv_b_proj=mock_kv_b_proj,
|
||||
indexer=mock_indexer,
|
||||
)
|
||||
|
||||
impl.process_weights_after_loading(dtype)
|
||||
impl.process_weights_after_loading(dtype)
|
||||
|
||||
layer = MockAttentionLayer(device)
|
||||
out_buffer = torch.empty(
|
||||
@@ -366,22 +375,192 @@ def test_sparse_backend_decode_correctness(
|
||||
torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.5, atol=0.5)
|
||||
|
||||
|
||||
def _triton_convert_reference_impl(
|
||||
req_ids: torch.Tensor,
|
||||
block_table: torch.Tensor,
|
||||
token_indices: torch.Tensor,
|
||||
block_size: int,
|
||||
num_topk_tokens: int,
|
||||
HAS_PREFILL_WORKSPACE: bool = False,
|
||||
prefill_workspace_request_ids: torch.Tensor | None = None,
|
||||
prefill_workspace_starts: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Reference implementation for triton_convert_req_index_to_global_index."""
|
||||
num_tokens = req_ids.shape[0]
|
||||
max_blocks_per_req = block_table.shape[1]
|
||||
result = torch.empty(
|
||||
num_tokens, num_topk_tokens, dtype=torch.int32, device=req_ids.device
|
||||
)
|
||||
|
||||
for token_id in range(num_tokens):
|
||||
req_id = req_ids[token_id].item()
|
||||
|
||||
# Determine if this token uses workspace or paged cache
|
||||
use_prefill_workspace = False
|
||||
workspace_start = 0
|
||||
if HAS_PREFILL_WORKSPACE and prefill_workspace_request_ids is not None:
|
||||
assert prefill_workspace_starts is not None
|
||||
prefill_req_id = prefill_workspace_request_ids[token_id].item()
|
||||
if prefill_req_id >= 0:
|
||||
use_prefill_workspace = True
|
||||
workspace_start = prefill_workspace_starts[prefill_req_id].item()
|
||||
|
||||
for idx_id in range(num_topk_tokens):
|
||||
token_idx = token_indices[token_id, idx_id].item()
|
||||
|
||||
if token_idx == -1:
|
||||
result[token_id, idx_id] = -1
|
||||
elif use_prefill_workspace:
|
||||
# Prefill + using prefill workspace: map to workspace offset
|
||||
result[token_id, idx_id] = workspace_start + token_idx
|
||||
else:
|
||||
# Decode: map to paged cache
|
||||
block_id = token_idx // block_size
|
||||
if block_id >= max_blocks_per_req:
|
||||
result[token_id, idx_id] = -1
|
||||
else:
|
||||
block_num = block_table[req_id, block_id].item()
|
||||
offset = token_idx % block_size
|
||||
result[token_id, idx_id] = block_num * block_size + offset
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block_size", [16, 64, 128])
|
||||
@pytest.mark.parametrize("num_topk_tokens", [128, 256, 512])
|
||||
@pytest.mark.skipif(
|
||||
torch.cuda.get_device_capability() < (9, 0),
|
||||
reason="FlashMLASparseBackend requires CUDA 9.0 or higher",
|
||||
)
|
||||
def test_triton_convert_req_index_to_global_index_decode_only(
|
||||
block_size, num_topk_tokens
|
||||
):
|
||||
device = torch.device("cuda")
|
||||
num_tokens = 8
|
||||
num_requests = 4
|
||||
max_blocks_per_req = 10
|
||||
|
||||
req_id = torch.randint(
|
||||
0, num_requests, (num_tokens,), dtype=torch.int32, device=device
|
||||
)
|
||||
block_table = torch.randint(
|
||||
0, 100, (num_requests, max_blocks_per_req), dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
token_indices = torch.randint(
|
||||
0,
|
||||
block_size * max_blocks_per_req,
|
||||
(num_tokens, num_topk_tokens),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Set some to -1 to test masking
|
||||
token_indices[0, :10] = -1
|
||||
token_indices[3, 50:60] = -1
|
||||
|
||||
# Set some to out of bounds
|
||||
token_indices[2, 100:110] = max_blocks_per_req * block_size
|
||||
token_indices[6, 150:160] = max_blocks_per_req * block_size
|
||||
|
||||
result = triton_convert_req_index_to_global_index(
|
||||
req_id,
|
||||
block_table,
|
||||
token_indices,
|
||||
BLOCK_SIZE=block_size,
|
||||
NUM_TOPK_TOKENS=num_topk_tokens,
|
||||
)
|
||||
|
||||
reference_result = _triton_convert_reference_impl(
|
||||
req_id,
|
||||
block_table,
|
||||
token_indices,
|
||||
block_size,
|
||||
num_topk_tokens,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(result, reference_result, rtol=0, atol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
@pytest.mark.skipif(
|
||||
torch.cuda.get_device_capability() < (9, 0),
|
||||
reason="FlashMLASparseBackend requires CUDA 9.0 or higher",
|
||||
)
|
||||
def test_triton_convert_req_index_to_global_index_with_prefill_workspace(block_size):
|
||||
device = torch.device("cuda")
|
||||
num_requests = 4
|
||||
max_blocks_per_req = 8
|
||||
num_topk_tokens = 128
|
||||
|
||||
# First 6 tokens are decode (reqs 0, 1), last 6 are prefill (reqs 2, 3)
|
||||
req_id = torch.tensor(
|
||||
[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], dtype=torch.int32, device=device
|
||||
)
|
||||
prefill_workspace_request_ids = torch.tensor(
|
||||
[-1, -1, -1, -1, -1, -1, 0, 0, 0, 1, 1, 1], dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
# Workspace starts for the 2 prefill reqs: req 2 starts at 0, req 3 starts at 100
|
||||
prefill_workspace_starts = torch.tensor([0, 100], dtype=torch.int32, device=device)
|
||||
|
||||
block_table = torch.randint(
|
||||
0, 50, (num_requests, max_blocks_per_req), dtype=torch.int32, device=device
|
||||
)
|
||||
token_indices = torch.randint(
|
||||
0,
|
||||
block_size * max_blocks_per_req,
|
||||
(req_id.shape[0], num_topk_tokens),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Set some to -1 to test masking
|
||||
token_indices[0, :10] = -1
|
||||
token_indices[3, 50:60] = -1
|
||||
|
||||
# Set some to out of bounds
|
||||
token_indices[2, 100:110] = max_blocks_per_req * block_size
|
||||
token_indices[6, 150:160] = max_blocks_per_req * block_size
|
||||
|
||||
result = triton_convert_req_index_to_global_index(
|
||||
req_id,
|
||||
block_table,
|
||||
token_indices,
|
||||
BLOCK_SIZE=block_size,
|
||||
NUM_TOPK_TOKENS=num_topk_tokens,
|
||||
HAS_PREFILL_WORKSPACE=True,
|
||||
prefill_workspace_request_ids=prefill_workspace_request_ids,
|
||||
prefill_workspace_starts=prefill_workspace_starts,
|
||||
)
|
||||
|
||||
reference_result = _triton_convert_reference_impl(
|
||||
req_id,
|
||||
block_table,
|
||||
token_indices,
|
||||
block_size,
|
||||
num_topk_tokens,
|
||||
HAS_PREFILL_WORKSPACE=True,
|
||||
prefill_workspace_request_ids=prefill_workspace_request_ids,
|
||||
prefill_workspace_starts=prefill_workspace_starts,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(result, reference_result, rtol=0, atol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"seq_lens,max_buf,start,expected",
|
||||
"seq_lens,max_buf,expected",
|
||||
[
|
||||
# Basic split: totals per chunk ≤ max_buf
|
||||
(torch.tensor([2, 3, 4, 2]), 5, 0, [(0, 2), (2, 3), (3, 4)]),
|
||||
# Non-zero start index
|
||||
(torch.tensor([2, 3, 4, 2]), 5, 1, [(1, 2), (2, 3), (3, 4)]),
|
||||
# Exact fits should split between items when adding the next would
|
||||
# overflow
|
||||
(torch.tensor([5, 5, 5]), 5, 0, [(0, 1), (1, 2), (2, 3)]),
|
||||
(torch.tensor([2, 3, 4, 2]), 5, [(0, 2), (2, 3), (3, 4)]),
|
||||
# Exact fits should split between items when adding the next would overflow
|
||||
(torch.tensor([5, 5, 5]), 5, [(0, 1), (1, 2), (2, 3)]),
|
||||
# All requests fit in a single chunk
|
||||
(torch.tensor([1, 1, 1]), 10, 0, [(0, 3)]),
|
||||
# Large buffer with non-zero start
|
||||
(torch.tensor([4, 4, 4]), 100, 1, [(1, 3)]),
|
||||
(torch.tensor([1, 1, 1]), 10, [(0, 3)]),
|
||||
# Large buffer
|
||||
(torch.tensor([4, 4, 4]), 100, [(0, 3)]),
|
||||
],
|
||||
)
|
||||
def test_split_prefill_chunks(seq_lens, max_buf, start, expected):
|
||||
out = split_prefill_chunks(seq_lens, max_buf, start)
|
||||
def test_split_prefill_chunks(seq_lens, max_buf, expected):
|
||||
out = split_prefill_chunks(seq_lens, max_buf)
|
||||
assert out == expected
|
||||
|
||||
Reference in New Issue
Block a user