[Bugfix][MLA] Add logits size budget to sparse indexer prefill chunking (#36178)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
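The change below bounds the [M, N] float32 logits tensor that the sparse indexer materializes per prefill chunk (M = query tokens, N = gathered key tokens). As a rough sketch of the arithmetic behind the new budget, with made-up chunk sizes that are not taken from this diff:

# Illustration only: the budget check this commit adds, with hypothetical sizes.
max_logits_bytes = 512 * 1024 * 1024      # VLLM_SPARSE_INDEXER_MAX_LOGITS_MB default
max_logits_elems = max_logits_bytes // 4  # float32 logits -> 4 bytes per element

m, n = 8192, 65536                        # query tokens x key tokens in one chunk
print(m * n <= max_logits_elems)          # False: 536870912 > 134217728 -> chunk must be split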
@@ -42,6 +42,7 @@ from vllm.v1.attention.backends.mla.flashmla_sparse import (
     FlashMLASparseBackend,
     triton_convert_req_index_to_global_index,
 )
+from vllm.v1.attention.backends.mla.indexer import split_indexer_prefill_chunks
 from vllm.v1.attention.backends.utils import split_prefill_chunks
 from vllm.v1.attention.ops import flashmla
 
@@ -716,6 +717,81 @@ def test_split_prefill_chunks(seq_lens, max_buf, expected):
     assert out == expected
 
 
+@pytest.mark.parametrize(
+    "seq_lens,query_lens,workspace_size,max_logits_bytes,expected",
+    [
+        # Logits constraint triggers split (M*N exceeds budget)
+        # req0: M=10, N=100 -> 1000 elems (4000 bytes) - fits in 5000
+        # req1: adding M=10, N=100 -> new_M=20, new_N=200 -> 4000 elems > 1250
+        (
+            torch.tensor([100, 100, 100]),
+            torch.tensor([10, 10, 10]),
+            1000,  # workspace allows all
+            5000,  # 1250 float32 elems -> forces split
+            [
+                (slice(0, 1), slice(0, 10)),
+                (slice(1, 2), slice(0, 10)),
+                (slice(2, 3), slice(0, 10)),
+            ],
+        ),
+        # Both constraints satisfied - all fit in one chunk
+        (
+            torch.tensor([10, 10, 10]),
+            torch.tensor([5, 5, 5]),
+            100,
+            10000,  # 2500 elems, M*N = 15*30 = 450 < 2500
+            [(slice(0, 3), slice(0, 15))],
+        ),
+        # Workspace constraint triggers first
+        (
+            torch.tensor([50, 50, 50]),
+            torch.tensor([1, 1, 1]),
+            50,  # workspace only fits one at a time
+            1000000,  # logits budget is huge
+            [
+                (slice(0, 1), slice(0, 1)),
+                (slice(1, 2), slice(0, 1)),
+                (slice(2, 3), slice(0, 1)),
+            ],
+        ),
+        # Greedy filling: first two fit, third doesn't
+        # req0: M=5, N=10 -> 50 elems
+        # req0+1: M=10, N=20 -> 200 elems <= 250
+        # req0+1+2: M=15, N=30 -> 450 elems > 250
+        (
+            torch.tensor([10, 10, 10]),
+            torch.tensor([5, 5, 5]),
+            100,
+            1000,  # 250 elems
+            [(slice(0, 2), slice(0, 10)), (slice(2, 3), slice(0, 5))],
+        ),
+    ],
+)
+def test_split_indexer_prefill_chunks(
+    seq_lens, query_lens, workspace_size, max_logits_bytes, expected
+):
+    out = split_indexer_prefill_chunks(
+        seq_lens,
+        query_lens,
+        workspace_size,
+        max_logits_bytes,
+    )
+    assert out == expected
+
+
+def test_split_indexer_prefill_chunks_single_request_overflow():
+    """Test that single request exceeding budget is sub-chunked on query dim."""
+    seq_lens = torch.tensor([1000, 50])
+    query_lens = torch.tensor([100, 5])
+
+    out = split_indexer_prefill_chunks(seq_lens, query_lens, 2000, 1000)
+    # max_logits_elems = 250, N=1000 -> max_q = 1 -> 100 query sub-chunks
+    expected = [(slice(0, 1), slice(i, i + 1)) for i in range(100)]
+    # req1: M=5, N=50 -> 250 elems fits budget
+    expected.append((slice(1, 2), slice(0, 5)))
+    assert out == expected
+
+
 def test_triton_convert_returns_valid_counts():
     """Test that return_valid_counts correctly counts non-negative indices."""
     device = torch.device("cuda")
@@ -55,6 +55,7 @@ if TYPE_CHECKING:
     VLLM_CPU_INT4_W4A8: bool = True
     VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
     VLLM_XLA_CHECK_RECOMPILATION: bool = False
+    VLLM_SPARSE_INDEXER_MAX_LOGITS_MB: int = 512
     VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto"
     VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False
     VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True
@@ -861,6 +862,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
     ),
     # Enable SPMD mode for TPU backend.
     "VLLM_XLA_USE_SPMD": lambda: bool(int(os.getenv("VLLM_XLA_USE_SPMD", "0"))),
+    # Maximum size (in MB) for logits tensor in sparse MLA indexer prefill chunks.
+    # Bounds the [M, N] float32 logits tensor to prevent CUDA OOM.
+    # Default: 512 MB
+    "VLLM_SPARSE_INDEXER_MAX_LOGITS_MB": lambda: int(
+        os.getenv("VLLM_SPARSE_INDEXER_MAX_LOGITS_MB", "512")
+    ),
     # If set, the OpenAI API server will stay alive even after the underlying
     # AsyncLLMEngine errors and stops serving requests
     "VLLM_KEEP_ALIVE_ON_ENGINE_DEATH": lambda: bool(
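A hedged usage note: like the neighboring variables, this knob is read through vllm.envs, so in practice you would export it before launching vLLM; the snippet below only illustrates how a non-default value (256 MB here, an arbitrary example) maps onto the element budget the chunker compares M * N against.

import os

# Example override; normally exported in the environment before vLLM starts.
os.environ["VLLM_SPARSE_INDEXER_MAX_LOGITS_MB"] = "256"

budget_bytes = int(os.environ["VLLM_SPARSE_INDEXER_MAX_LOGITS_MB"]) * 1024 * 1024
budget_elems = budget_bytes // 4  # float32 elements the [M, N] logits may occupy
print(budget_bytes, budget_elems)  # 268435456 67108864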
@@ -4,6 +4,7 @@
 
 import torch
 
+import vllm.envs as envs
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
@@ -51,6 +52,14 @@ def sparse_attn_indexer(
             ((total_seq_lens, head_dim), torch.float8_e4m3fn),
             ((total_seq_lens, 4), torch.uint8),
         )
+
+        # Dummy allocation to simulate peak logits tensor memory during inference.
+        # Allocated as uint8 so elements == bytes.
+        max_logits_elems = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024
+        _ = torch.empty(
+            max_logits_elems, dtype=torch.uint8, device=hidden_states.device
+        )
+
         return sparse_attn_indexer_fake(
             hidden_states,
             k_cache_prefix,
@@ -101,13 +110,16 @@ def sparse_attn_indexer(
         for chunk in prefill_metadata.chunks:
             k_fp8 = k_fp8_full[: chunk.total_seq_lens]
             k_scale = k_scale_full[: chunk.total_seq_lens]
-            ops.cp_gather_indexer_k_quant_cache(
-                kv_cache,
-                k_fp8,
-                k_scale,
-                chunk.block_table,
-                chunk.cu_seq_lens,
-            )
+
+            if not chunk.skip_kv_gather:
+                ops.cp_gather_indexer_k_quant_cache(
+                    kv_cache,
+                    k_fp8,
+                    k_scale,
+                    chunk.block_table,
+                    chunk.cu_seq_lens,
+                )
+
             logits = fp8_mqa_logits(
                 q_fp8[chunk.token_start : chunk.token_end],
                 (k_fp8, k_scale.view(torch.float32).flatten()),
@@ -4,6 +4,7 @@ from dataclasses import dataclass
 
 import torch
 
+import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -22,7 +23,6 @@ from vllm.v1.attention.backend import (
 )
 from vllm.v1.attention.backends.utils import (
     split_decodes_and_prefills,
-    split_prefill_chunks,
 )
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.cp_utils import get_total_cp_world_size
@@ -30,6 +30,55 @@ from vllm.v1.worker.cp_utils import get_total_cp_world_size
 logger = init_logger(__name__)
 
 
+def split_indexer_prefill_chunks(
+    seq_lens_cpu: torch.Tensor,
+    query_lens_cpu: torch.Tensor,
+    workspace_size: int,
+    max_logits_bytes: int,
+    request_offset: int = 0,
+) -> list[tuple[slice, slice]]:
+    """
+    Split prefill requests into chunks for the sparse indexer, respecting:
+    - N constraint: total_seq_lens <= workspace_size (existing O(N) workspace)
+    - Logits constraint: M * N * 4 <= max_logits_bytes
+
+    When a single request-level chunk still exceeds the logits budget, it is
+    sub-chunked on the query dimension (M) to bound peak memory.
+
+    Returns a list of (req_slice, query_slice) tuples.
+    """
+    chunks: list[tuple[slice, slice]] = []
+    n = len(seq_lens_cpu)
+    max_logits_elems = max_logits_bytes // 4
+    end = 0
+
+    while end < n:
+        start, chunk_m, chunk_n = end, 0, 0
+
+        while end < n:
+            q, s = query_lens_cpu[end].item(), seq_lens_cpu[end].item()
+            new_m, new_n = chunk_m + q, chunk_n + s
+            if new_n <= workspace_size and new_m * new_n <= max_logits_elems:
+                chunk_m, chunk_n = new_m, new_n
+                end += 1
+            else:
+                break
+
+        # A single request can exceed the budget, requiring sub-chunking
+        # on the query dimension.
+        if end == start:
+            chunk_m, chunk_n = query_lens_cpu[end].item(), seq_lens_cpu[end].item()
+            end += 1
+
+        req_slice = slice(start + request_offset, end + request_offset)
+        max_q = max(1, max_logits_elems // chunk_n) if chunk_n > 0 else chunk_m
+        for q_off in range(0, chunk_m, max_q):
+            sub_m = min(max_q, chunk_m - q_off)
+            chunks.append((req_slice, slice(q_off, q_off + sub_m)))
+
+    return chunks
+
+
 class DeepseekV32IndexerBackend(AttentionBackend):
     @staticmethod
     def get_name() -> str:
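To make the helper's contract concrete, here is a small standalone check (not part of the diff; it assumes a vLLM build containing this change is importable) that every (req_slice, query_slice) pair respects both constraints, modulo the single-request fallback in the helper, where a lone request is taken even if it alone exceeds a limit:

import torch

from vllm.v1.attention.backends.mla.indexer import split_indexer_prefill_chunks


def check_chunks(seq_lens, query_lens, workspace_size, max_logits_bytes):
    elems = max_logits_bytes // 4
    for req_slice, query_slice in split_indexer_prefill_chunks(
        seq_lens, query_lens, workspace_size, max_logits_bytes
    ):
        n = int(seq_lens[req_slice].sum())
        m = query_slice.stop - query_slice.start
        # A lone request may exceed the workspace; otherwise N is bounded.
        assert n <= workspace_size or req_slice.stop - req_slice.start == 1
        # A single query row may still exceed the logits budget; otherwise M*N is bounded.
        assert m * n <= elems or m == 1


check_chunks(torch.tensor([1000, 50]), torch.tensor([100, 5]), 2000, 1000)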
@@ -81,6 +130,7 @@ class DeepseekV32IndexerPrefillChunkMetadata:
     token_start: int
     token_end: int
     num_reqs: int
+    skip_kv_gather: bool = False
 
 
 @dataclass
@@ -271,43 +321,51 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
         )
 
     def build_one_prefill_chunk(
-        self, reqs_start, reqs_end, query_start_loc_cpu, seq_lens_cpu, block_table
-    ):
+        self,
+        req_slice: slice,
+        query_slice: slice,
+        query_start_loc_cpu,
+        seq_lens_cpu,
+        block_table,
+        skip_kv_gather: bool = False,
+    ) -> DeepseekV32IndexerPrefillChunkMetadata:
         prefill_query_start_loc = (
-            query_start_loc_cpu[reqs_start : reqs_end + 1]
-            - query_start_loc_cpu[reqs_start]
+            query_start_loc_cpu[req_slice.start : req_slice.stop + 1]
+            - query_start_loc_cpu[req_slice.start]
         )
         cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
-            prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end], self.device
+            prefill_query_start_loc, seq_lens_cpu[req_slice], self.device
         )
+        token_start = query_start_loc_cpu[req_slice.start].item()
+        total_seq_lens = seq_lens_cpu[req_slice].sum()
+        num_reqs = req_slice.stop - req_slice.start
+        seq_idx = torch.arange(0, num_reqs, dtype=torch.int32)
+        token_to_seq = torch.repeat_interleave(seq_idx, seq_lens_cpu[req_slice]).to(
+            self.device
+        )
-        token_start = query_start_loc_cpu[reqs_start].item()
-        token_end = query_start_loc_cpu[reqs_end].item()
-        total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
-        seq_idx = torch.arange(0, reqs_end - reqs_start, dtype=torch.int32)
-        token_to_seq = torch.repeat_interleave(
-            seq_idx, seq_lens_cpu[reqs_start:reqs_end]
-        ).to(self.device)
         assert total_seq_lens <= self.max_prefill_buffer_size
         cu_seq_lens = (
             torch.cat(
                 [
                     torch.zeros(1, dtype=torch.int32),
-                    seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0),
+                    seq_lens_cpu[req_slice].cumsum(dim=0),
                 ]
             )
             .to(torch.int32)
             .to(self.device)
         )
+
         return DeepseekV32IndexerPrefillChunkMetadata(
-            cu_seqlen_ks=cu_seqlen_ks,
-            cu_seqlen_ke=cu_seqlen_ke,
+            cu_seqlen_ks=cu_seqlen_ks[query_slice],
+            cu_seqlen_ke=cu_seqlen_ke[query_slice],
             cu_seq_lens=cu_seq_lens,
             token_to_seq=token_to_seq,
             total_seq_lens=total_seq_lens,
-            block_table=block_table[reqs_start:reqs_end],
-            token_start=token_start,
-            token_end=token_end,
-            num_reqs=reqs_end - reqs_start,
+            block_table=block_table[req_slice],
+            token_start=token_start + query_slice.start,
+            token_end=token_start + query_slice.stop,
+            num_reqs=num_reqs,
+            skip_kv_gather=skip_kv_gather,
         )
 
     def build(
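The token bookkeeping above is the subtle part of the rewrite: cu_seqlen_ks / cu_seqlen_ke are sliced per query sub-chunk, token_start / token_end are offset into the request block's query range, and only the first sub-chunk (query_slice.start == 0) gathers KV, which is why the builder below passes skip_kv_gather=query_slice.start > 0. A toy trace of that arithmetic with hypothetical numbers (not vLLM code):

import torch

# Hypothetical cumulative query lengths for 4 requests.
query_start_loc = torch.tensor([0, 12, 25, 37, 47])
req_slice = slice(2, 4)                               # chunk covers requests 2 and 3
token_start = int(query_start_loc[req_slice.start])   # 25

for query_slice in (slice(0, 12), slice(12, 22)):     # two query-dim sub-chunks
    chunk_token_start = token_start + query_slice.start
    chunk_token_end = token_start + query_slice.stop
    print((chunk_token_start, chunk_token_end))        # (25, 37), then (37, 47)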
@@ -333,20 +391,27 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
 
         prefill_metadata = None
         if num_prefills > 0:
-            chunk_seq_ids = split_prefill_chunks(
+            prefill_query_lens_cpu = torch.diff(
+                query_start_loc_cpu[num_decodes : num_decodes + num_prefills + 1]
+            )
+            max_logits_bytes = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024
+            chunk_specs = split_indexer_prefill_chunks(
                 common_attn_metadata.seq_lens_cpu[num_decodes:],
+                prefill_query_lens_cpu,
                 self.max_prefill_buffer_size,
+                max_logits_bytes,
+                request_offset=num_decodes,
             )
             chunks = [
                 self.build_one_prefill_chunk(
-                    reqs_start,
-                    reqs_end,
+                    req_slice,
+                    query_slice,
                     query_start_loc_cpu,
                     common_attn_metadata.seq_lens_cpu,
                     common_attn_metadata.block_table_tensor,
+                    skip_kv_gather=query_slice.start > 0,
                 )
-                for reqs_start, reqs_end in chunk_seq_ids
+                for req_slice, query_slice in chunk_specs
             ]
             prefill_metadata = DeepseekV32IndexerPrefillMetadata(
                 chunks=chunks,