[Bugfix][MLA] Add logits size budget to sparse indexer prefill chunking (#36178)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
(cherry picked from commit eb47454987)
This commit is contained in:
Lucas Wilkinson
2026-04-01 00:15:53 -04:00
committed by khluu
parent 268bed9cf3
commit 0ee3b7fc3d
4 changed files with 191 additions and 31 deletions

View File

@@ -42,6 +42,7 @@ from vllm.v1.attention.backends.mla.flashmla_sparse import (
FlashMLASparseBackend,
triton_convert_req_index_to_global_index,
)
from vllm.v1.attention.backends.mla.indexer import split_indexer_prefill_chunks
from vllm.v1.attention.backends.utils import split_prefill_chunks
from vllm.v1.attention.ops import flashmla
@@ -716,6 +717,81 @@ def test_split_prefill_chunks(seq_lens, max_buf, expected):
assert out == expected
@pytest.mark.parametrize(
    "seq_lens,query_lens,workspace_size,max_logits_bytes,expected",
    [
        # Case 1: the logits budget alone forces one request per chunk.
        # req0 alone: M=10, N=100 -> 1000 float32 elems (4000 bytes), fits.
        # req0+req1: M=20, N=200 -> 4000 elems, over the 1250-elem budget.
        (
            torch.tensor([100, 100, 100]),
            torch.tensor([10, 10, 10]),
            1000,  # workspace is large enough for everything
            5000,  # 5000 bytes / 4 = 1250 float32 elems -> forces split
            [
                (slice(0, 1), slice(0, 10)),
                (slice(1, 2), slice(0, 10)),
                (slice(2, 3), slice(0, 10)),
            ],
        ),
        # Case 2: neither constraint binds -> everything in a single chunk.
        (
            torch.tensor([10, 10, 10]),
            torch.tensor([5, 5, 5]),
            100,
            10000,  # 2500 elems budget; combined M*N = 15*30 = 450 fits
            [(slice(0, 3), slice(0, 15))],
        ),
        # Case 3: the workspace constraint binds before the logits budget.
        (
            torch.tensor([50, 50, 50]),
            torch.tensor([1, 1, 1]),
            50,  # workspace only holds one request at a time
            1000000,  # logits budget effectively unlimited
            [
                (slice(0, 1), slice(0, 1)),
                (slice(1, 2), slice(0, 1)),
                (slice(2, 3), slice(0, 1)),
            ],
        ),
        # Case 4: greedy packing — first two requests share a chunk, the
        # third overflows the logits budget and starts a new one.
        # req0:       M=5,  N=10 -> 50 elems
        # req0+1:     M=10, N=20 -> 200 elems <= 250
        # req0+1+2:   M=15, N=30 -> 450 elems  > 250
        (
            torch.tensor([10, 10, 10]),
            torch.tensor([5, 5, 5]),
            100,
            1000,  # 1000 bytes / 4 = 250 elems budget
            [(slice(0, 2), slice(0, 10)), (slice(2, 3), slice(0, 5))],
        ),
    ],
)
def test_split_indexer_prefill_chunks(
    seq_lens, query_lens, workspace_size, max_logits_bytes, expected
):
    """Chunking respects both the workspace and the logits-size budgets."""
    chunks = split_indexer_prefill_chunks(
        seq_lens, query_lens, workspace_size, max_logits_bytes
    )
    assert chunks == expected
def test_split_indexer_prefill_chunks_single_request_overflow():
    """A single request that exceeds the logits budget is split on the query dim."""
    seq_lens = torch.tensor([1000, 50])
    query_lens = torch.tensor([100, 5])
    chunks = split_indexer_prefill_chunks(seq_lens, query_lens, 2000, 1000)
    # Budget is 1000 bytes / 4 = 250 float32 elems. req0 has N=1000, so at
    # most one query row fits per sub-chunk -> 100 single-row sub-chunks.
    expected = [(slice(0, 1), slice(q, q + 1)) for q in range(100)]
    # req1 (M=5, N=50 -> 250 elems) fits the budget in one chunk.
    expected.append((slice(1, 2), slice(0, 5)))
    assert chunks == expected
def test_triton_convert_returns_valid_counts():
"""Test that return_valid_counts correctly counts non-negative indices."""
device = torch.device("cuda")