[Bugfix][MLA] Add logits size budget to sparse indexer prefill chunking (#36178)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
(cherry picked from commit eb47454987)
This commit is contained in:
@@ -42,6 +42,7 @@ from vllm.v1.attention.backends.mla.flashmla_sparse import (
|
||||
FlashMLASparseBackend,
|
||||
triton_convert_req_index_to_global_index,
|
||||
)
|
||||
from vllm.v1.attention.backends.mla.indexer import split_indexer_prefill_chunks
|
||||
from vllm.v1.attention.backends.utils import split_prefill_chunks
|
||||
from vllm.v1.attention.ops import flashmla
|
||||
|
||||
@@ -716,6 +717,81 @@ def test_split_prefill_chunks(seq_lens, max_buf, expected):
|
||||
assert out == expected
|
||||
|
||||
|
||||
# Cases for the indexer prefill chunking helper, as
# (seq_lens, query_lens, workspace_size, max_logits_bytes, expected) tuples.
_INDEXER_CHUNK_CASES = [
    # Logits budget forces one request per chunk.
    # req0: M=10, N=100 -> 1000 elems (4000 bytes) - fits in 5000
    # req1: adding M=10, N=100 -> new_M=20, new_N=200 -> 4000 elems > 1250
    (
        torch.tensor([100, 100, 100]),
        torch.tensor([10, 10, 10]),
        1000,  # workspace allows all
        5000,  # 1250 float32 elems -> forces split
        [
            (slice(0, 1), slice(0, 10)),
            (slice(1, 2), slice(0, 10)),
            (slice(2, 3), slice(0, 10)),
        ],
    ),
    # Both budgets large enough: everything lands in a single chunk.
    (
        torch.tensor([10, 10, 10]),
        torch.tensor([5, 5, 5]),
        100,
        10000,  # 2500 elems, M*N = 15*30 = 450 < 2500
        [(slice(0, 3), slice(0, 15))],
    ),
    # Workspace budget is the binding constraint, not logits.
    (
        torch.tensor([50, 50, 50]),
        torch.tensor([1, 1, 1]),
        50,  # workspace only fits one at a time
        1000000,  # logits budget is huge
        [
            (slice(0, 1), slice(0, 1)),
            (slice(1, 2), slice(0, 1)),
            (slice(2, 3), slice(0, 1)),
        ],
    ),
    # Greedy packing: the first two requests share a chunk, the third spills.
    # req0: M=5, N=10 -> 50 elems
    # req0+1: M=10, N=20 -> 200 elems <= 250
    # req0+1+2: M=15, N=30 -> 450 elems > 250
    (
        torch.tensor([10, 10, 10]),
        torch.tensor([5, 5, 5]),
        100,
        1000,  # 250 elems
        [(slice(0, 2), slice(0, 10)), (slice(2, 3), slice(0, 5))],
    ),
]


@pytest.mark.parametrize(
    "seq_lens,query_lens,workspace_size,max_logits_bytes,expected",
    _INDEXER_CHUNK_CASES,
)
def test_split_indexer_prefill_chunks(
    seq_lens, query_lens, workspace_size, max_logits_bytes, expected
):
    """Chunking respects both the workspace and the logits-size budgets."""
    chunks = split_indexer_prefill_chunks(
        seq_lens,
        query_lens,
        workspace_size,
        max_logits_bytes,
    )
    assert chunks == expected
|
||||
|
||||
|
||||
def test_split_indexer_prefill_chunks_single_request_overflow():
    """A single request that exceeds the logits budget is sub-chunked on the query dim."""
    seq_lens = torch.tensor([1000, 50])
    query_lens = torch.tensor([100, 5])

    chunks = split_indexer_prefill_chunks(seq_lens, query_lens, 2000, 1000)

    # Budget of 1000 bytes = 250 float32 logits. For req0, N=1000 keys
    # leaves room for max_q = 1 query per sub-chunk, so its 100 queries
    # become 100 single-query sub-chunks.
    expected = [(slice(0, 1), slice(q, q + 1)) for q in range(100)]
    # req1 fits whole: M=5, N=50 -> 250 elems, exactly at the budget.
    expected += [(slice(1, 2), slice(0, 5))]
    assert chunks == expected
|
||||
|
||||
|
||||
def test_triton_convert_returns_valid_counts():
|
||||
"""Test that return_valid_counts correctly counts non-negative indices."""
|
||||
device = torch.device("cuda")
|
||||
|
||||
Reference in New Issue
Block a user