[Bugfix][MLA] Add logits size budget to sparse indexer prefill chunking (#36178)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson
2026-04-01 00:15:53 -04:00
committed by GitHub
parent 116f4be405
commit eb47454987
4 changed files with 191 additions and 31 deletions

View File

@@ -42,6 +42,7 @@ from vllm.v1.attention.backends.mla.flashmla_sparse import (
FlashMLASparseBackend,
triton_convert_req_index_to_global_index,
)
from vllm.v1.attention.backends.mla.indexer import split_indexer_prefill_chunks
from vllm.v1.attention.backends.utils import split_prefill_chunks
from vllm.v1.attention.ops import flashmla
@@ -716,6 +717,81 @@ def test_split_prefill_chunks(seq_lens, max_buf, expected):
assert out == expected
@pytest.mark.parametrize(
"seq_lens,query_lens,workspace_size,max_logits_bytes,expected",
[
# Logits constraint triggers split (M*N exceeds budget)
# req0: M=10, N=100 -> 1000 elems (4000 bytes) - fits in 5000
# req1: adding M=10, N=100 -> new_M=20, new_N=200 -> 4000 elems > 1250
(
torch.tensor([100, 100, 100]),
torch.tensor([10, 10, 10]),
1000, # workspace allows all
5000, # 1250 float32 elems -> forces split
[
(slice(0, 1), slice(0, 10)),
(slice(1, 2), slice(0, 10)),
(slice(2, 3), slice(0, 10)),
],
),
# Both constraints satisfied - all fit in one chunk
(
torch.tensor([10, 10, 10]),
torch.tensor([5, 5, 5]),
100,
10000, # 2500 elems, M*N = 15*30 = 450 < 2500
[(slice(0, 3), slice(0, 15))],
),
# Workspace constraint triggers first
(
torch.tensor([50, 50, 50]),
torch.tensor([1, 1, 1]),
50, # workspace only fits one at a time
1000000, # logits budget is huge
[
(slice(0, 1), slice(0, 1)),
(slice(1, 2), slice(0, 1)),
(slice(2, 3), slice(0, 1)),
],
),
# Greedy filling: first two fit, third doesn't
# req0: M=5, N=10 -> 50 elems
# req0+1: M=10, N=20 -> 200 elems <= 250
# req0+1+2: M=15, N=30 -> 450 elems > 250
(
torch.tensor([10, 10, 10]),
torch.tensor([5, 5, 5]),
100,
1000, # 250 elems
[(slice(0, 2), slice(0, 10)), (slice(2, 3), slice(0, 5))],
),
],
)
def test_split_indexer_prefill_chunks(
seq_lens, query_lens, workspace_size, max_logits_bytes, expected
):
out = split_indexer_prefill_chunks(
seq_lens,
query_lens,
workspace_size,
max_logits_bytes,
)
assert out == expected
def test_split_indexer_prefill_chunks_single_request_overflow():
"""Test that single request exceeding budget is sub-chunked on query dim."""
seq_lens = torch.tensor([1000, 50])
query_lens = torch.tensor([100, 5])
out = split_indexer_prefill_chunks(seq_lens, query_lens, 2000, 1000)
# max_logits_elems = 250, N=1000 -> max_q = 1 -> 100 query sub-chunks
expected = [(slice(0, 1), slice(i, i + 1)) for i in range(100)]
# req1: M=5, N=50 -> 250 elems fits budget
expected.append((slice(1, 2), slice(0, 5)))
assert out == expected
def test_triton_convert_returns_valid_counts():
"""Test that return_valid_counts correctly counts non-negative indices."""
device = torch.device("cuda")

View File

@@ -55,6 +55,7 @@ if TYPE_CHECKING:
VLLM_CPU_INT4_W4A8: bool = True
VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
VLLM_XLA_CHECK_RECOMPILATION: bool = False
VLLM_SPARSE_INDEXER_MAX_LOGITS_MB: int = 512
VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto"
VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False
VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True
@@ -861,6 +862,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
),
# Enable SPMD mode for TPU backend.
"VLLM_XLA_USE_SPMD": lambda: bool(int(os.getenv("VLLM_XLA_USE_SPMD", "0"))),
# Maximum size (in MB) for logits tensor in sparse MLA indexer prefill chunks.
# Bounds the [M, N] float32 logits tensor to prevent CUDA OOM.
# Default: 512 MB
"VLLM_SPARSE_INDEXER_MAX_LOGITS_MB": lambda: int(
os.getenv("VLLM_SPARSE_INDEXER_MAX_LOGITS_MB", "512")
),
# If set, the OpenAI API server will stay alive even after the underlying
# AsyncLLMEngine errors and stops serving requests
"VLLM_KEEP_ALIVE_ON_ENGINE_DEATH": lambda: bool(

View File

@@ -4,6 +4,7 @@
import torch
import vllm.envs as envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
@@ -51,6 +52,14 @@ def sparse_attn_indexer(
((total_seq_lens, head_dim), torch.float8_e4m3fn),
((total_seq_lens, 4), torch.uint8),
)
# Dummy allocation to simulate for peak logits tensor memory during inference.
# FP8 elements so elements == bytes
max_logits_elems = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024
_ = torch.empty(
max_logits_elems, dtype=torch.uint8, device=hidden_states.device
)
return sparse_attn_indexer_fake(
hidden_states,
k_cache_prefix,
@@ -101,13 +110,16 @@ def sparse_attn_indexer(
for chunk in prefill_metadata.chunks:
k_fp8 = k_fp8_full[: chunk.total_seq_lens]
k_scale = k_scale_full[: chunk.total_seq_lens]
ops.cp_gather_indexer_k_quant_cache(
kv_cache,
k_fp8,
k_scale,
chunk.block_table,
chunk.cu_seq_lens,
)
if not chunk.skip_kv_gather:
ops.cp_gather_indexer_k_quant_cache(
kv_cache,
k_fp8,
k_scale,
chunk.block_table,
chunk.cu_seq_lens,
)
logits = fp8_mqa_logits(
q_fp8[chunk.token_start : chunk.token_end],
(k_fp8, k_scale.view(torch.float32).flatten()),

View File

@@ -4,6 +4,7 @@ from dataclasses import dataclass
import torch
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
@@ -22,7 +23,6 @@ from vllm.v1.attention.backend import (
)
from vllm.v1.attention.backends.utils import (
split_decodes_and_prefills,
split_prefill_chunks,
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.cp_utils import get_total_cp_world_size
@@ -30,6 +30,55 @@ from vllm.v1.worker.cp_utils import get_total_cp_world_size
logger = init_logger(__name__)
def split_indexer_prefill_chunks(
seq_lens_cpu: torch.Tensor,
query_lens_cpu: torch.Tensor,
workspace_size: int,
max_logits_bytes: int,
request_offset: int = 0,
) -> list[tuple[slice, slice]]:
"""
Split prefill requests into chunks for the sparse indexer, respecting:
- N constraint: total_seq_lens <= workspace_size (existing O(N) workspace)
- Logits constraint: M * N * 4 <= max_logits_bytes
When a single request-level chunk still exceeds the logits budget,
sub-chunks on the query dimension (M) to bound peak memory.
Returns list of (req_slice, query_slice) tuples.
"""
chunks: list[tuple[slice, slice]] = []
n = len(seq_lens_cpu)
max_logits_elems = max_logits_bytes // 4
end = 0
while end < n:
start, chunk_m, chunk_n = end, 0, 0
while end < n:
q, s = query_lens_cpu[end].item(), seq_lens_cpu[end].item()
new_m, new_n = chunk_m + q, chunk_n + s
if new_n <= workspace_size and new_m * new_n <= max_logits_elems:
chunk_m, chunk_n = new_m, new_n
end += 1
else:
break
# A single request can exceed the budget, requiring sub-chunking
# on the query dimension.
if end == start:
chunk_m, chunk_n = query_lens_cpu[end].item(), seq_lens_cpu[end].item()
end += 1
req_slice = slice(start + request_offset, end + request_offset)
max_q = max(1, max_logits_elems // chunk_n) if chunk_n > 0 else chunk_m
for q_off in range(0, chunk_m, max_q):
sub_m = min(max_q, chunk_m - q_off)
chunks.append((req_slice, slice(q_off, q_off + sub_m)))
return chunks
class DeepseekV32IndexerBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
@@ -81,6 +130,7 @@ class DeepseekV32IndexerPrefillChunkMetadata:
token_start: int
token_end: int
num_reqs: int
skip_kv_gather: bool = False
@dataclass
@@ -271,43 +321,51 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
)
def build_one_prefill_chunk(
self, reqs_start, reqs_end, query_start_loc_cpu, seq_lens_cpu, block_table
):
self,
req_slice: slice,
query_slice: slice,
query_start_loc_cpu,
seq_lens_cpu,
block_table,
skip_kv_gather: bool = False,
) -> DeepseekV32IndexerPrefillChunkMetadata:
prefill_query_start_loc = (
query_start_loc_cpu[reqs_start : reqs_end + 1]
- query_start_loc_cpu[reqs_start]
query_start_loc_cpu[req_slice.start : req_slice.stop + 1]
- query_start_loc_cpu[req_slice.start]
)
cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end], self.device
prefill_query_start_loc, seq_lens_cpu[req_slice], self.device
)
token_start = query_start_loc_cpu[req_slice.start].item()
total_seq_lens = seq_lens_cpu[req_slice].sum()
num_reqs = req_slice.stop - req_slice.start
seq_idx = torch.arange(0, num_reqs, dtype=torch.int32)
token_to_seq = torch.repeat_interleave(seq_idx, seq_lens_cpu[req_slice]).to(
self.device
)
token_start = query_start_loc_cpu[reqs_start].item()
token_end = query_start_loc_cpu[reqs_end].item()
total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
seq_idx = torch.arange(0, reqs_end - reqs_start, dtype=torch.int32)
token_to_seq = torch.repeat_interleave(
seq_idx, seq_lens_cpu[reqs_start:reqs_end]
).to(self.device)
assert total_seq_lens <= self.max_prefill_buffer_size
cu_seq_lens = (
torch.cat(
[
torch.zeros(1, dtype=torch.int32),
seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0),
seq_lens_cpu[req_slice].cumsum(dim=0),
]
)
.to(torch.int32)
.to(self.device)
)
return DeepseekV32IndexerPrefillChunkMetadata(
cu_seqlen_ks=cu_seqlen_ks,
cu_seqlen_ke=cu_seqlen_ke,
cu_seqlen_ks=cu_seqlen_ks[query_slice],
cu_seqlen_ke=cu_seqlen_ke[query_slice],
cu_seq_lens=cu_seq_lens,
token_to_seq=token_to_seq,
total_seq_lens=total_seq_lens,
block_table=block_table[reqs_start:reqs_end],
token_start=token_start,
token_end=token_end,
num_reqs=reqs_end - reqs_start,
block_table=block_table[req_slice],
token_start=token_start + query_slice.start,
token_end=token_start + query_slice.stop,
num_reqs=num_reqs,
skip_kv_gather=skip_kv_gather,
)
def build(
@@ -333,20 +391,27 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
prefill_metadata = None
if num_prefills > 0:
chunk_seq_ids = split_prefill_chunks(
prefill_query_lens_cpu = torch.diff(
query_start_loc_cpu[num_decodes : num_decodes + num_prefills + 1]
)
max_logits_bytes = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024
chunk_specs = split_indexer_prefill_chunks(
common_attn_metadata.seq_lens_cpu[num_decodes:],
prefill_query_lens_cpu,
self.max_prefill_buffer_size,
max_logits_bytes,
request_offset=num_decodes,
)
chunks = [
self.build_one_prefill_chunk(
reqs_start,
reqs_end,
req_slice,
query_slice,
query_start_loc_cpu,
common_attn_metadata.seq_lens_cpu,
common_attn_metadata.block_table_tensor,
skip_kv_gather=query_slice.start > 0,
)
for reqs_start, reqs_end in chunk_seq_ids
for req_slice, query_slice in chunk_specs
]
prefill_metadata = DeepseekV32IndexerPrefillMetadata(
chunks=chunks,