[Deepseek v3.2] Support indexer prefill chunking (#25999)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
Author: Chen Zhang
Date: 2025-10-02 10:29:12 -07:00
Committed by: simon-mo
Parent: 9d9a2b77f1
Commit: c75c2e70d6
3 changed files with 149 additions and 79 deletions

@@ -22,6 +22,7 @@ from vllm.utils import cdiv
 from vllm.v1.attention.backends.mla.flashmla_sparse import (
     FlashMLASparseBackend, FlashMLASparseDecodeAndContextMetadata,
     FlashMLASparseImpl, FlashMLASparseMetadata)
+from vllm.v1.attention.backends.mla.indexer import split_prefill_chunks
 
 SPARSE_BACKEND_BATCH_SPECS = {
     name: BATCH_SPECS[name]
@@ -424,3 +425,24 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,
                             sdpa_reference,
                             rtol=0.5,
                             atol=0.5)
+
+
+@pytest.mark.parametrize(
+    "seq_lens,max_buf,start,expected",
+    [
+        # Basic split: totals per chunk ≤ max_buf
+        (torch.tensor([2, 3, 4, 2]), 5, 0, [(0, 2), (2, 3), (3, 4)]),
+        # Non-zero start index
+        (torch.tensor([2, 3, 4, 2]), 5, 1, [(1, 2), (2, 3), (3, 4)]),
+        # Exact fits should split between items when adding the next
+        # would overflow
+        (torch.tensor([5, 5, 5]), 5, 0, [(0, 1), (1, 2), (2, 3)]),
+        # All requests fit in a single chunk
+        (torch.tensor([1, 1, 1]), 10, 0, [(0, 3)]),
+        # Large buffer with non-zero start
+        (torch.tensor([4, 4, 4]), 100, 1, [(1, 3)]),
+    ],
+)
+def test_split_prefill_chunks(seq_lens, max_buf, start, expected):
+    out = split_prefill_chunks(seq_lens, max_buf, start)
+    assert out == expected
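
For reference, the parametrized cases above pin down the contract of split_prefill_chunks: it greedily packs consecutive prefill requests into chunks whose summed sequence lengths fit within an indexer workspace of max_buf tokens, returning half-open (start, end) request-index ranges beginning at a given start index. Below is a minimal sketch reconstructed from those test expectations only; it is an illustration of the chunking policy, not the actual implementation in vllm/v1/attention/backends/mla/indexer.py, which may differ.

from typing import List, Tuple

import torch


def split_prefill_chunks(seq_lens: torch.Tensor, max_buf: int,
                         start_idx: int = 0) -> List[Tuple[int, int]]:
    """Greedy chunking sketch reconstructed from the tests above.

    Packs consecutive requests (from start_idx onward) into chunks
    whose total sequence length stays within max_buf tokens, and
    returns half-open (start, end) request-index ranges.
    """
    chunks: List[Tuple[int, int]] = []
    chunk_start = start_idx
    budget = max_buf
    for i in range(start_idx, len(seq_lens)):
        seq_len = int(seq_lens[i])
        # Close the current chunk when adding request i would overflow
        # the buffer (but never emit an empty chunk).
        if seq_len > budget and i > chunk_start:
            chunks.append((chunk_start, i))
            chunk_start = i
            budget = max_buf
        budget -= seq_len
    if chunk_start < len(seq_lens):
        chunks.append((chunk_start, len(seq_lens)))
    return chunks

Tracing the first test case: split_prefill_chunks(torch.tensor([2, 3, 4, 2]), 5, 0) walks the lengths left to right; 2 + 3 fills the 5-token budget exactly, 4 opens a new chunk, and the final 2 would push that chunk to 6, so it gets its own chunk, giving [(0, 2), (2, 3), (3, 4)]. This matches the greedy policy in the third case as well: a request that exactly exhausts the budget closes its chunk, so per-chunk totals never exceed max_buf.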