[Bugfix][MLA] Add logits size budget to sparse indexer prefill chunking (#36178)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
(cherry picked from commit eb47454987)
This commit is contained in:
@@ -4,6 +4,7 @@ from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
@@ -22,7 +23,6 @@ from vllm.v1.attention.backend import (
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
split_decodes_and_prefills,
|
||||
split_prefill_chunks,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.cp_utils import get_total_cp_world_size
|
||||
@@ -30,6 +30,55 @@ from vllm.v1.worker.cp_utils import get_total_cp_world_size
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def split_indexer_prefill_chunks(
|
||||
seq_lens_cpu: torch.Tensor,
|
||||
query_lens_cpu: torch.Tensor,
|
||||
workspace_size: int,
|
||||
max_logits_bytes: int,
|
||||
request_offset: int = 0,
|
||||
) -> list[tuple[slice, slice]]:
|
||||
"""
|
||||
Split prefill requests into chunks for the sparse indexer, respecting:
|
||||
- N constraint: total_seq_lens <= workspace_size (existing O(N) workspace)
|
||||
- Logits constraint: M * N * 4 <= max_logits_bytes
|
||||
|
||||
When a single request-level chunk still exceeds the logits budget,
|
||||
sub-chunks on the query dimension (M) to bound peak memory.
|
||||
|
||||
Returns list of (req_slice, query_slice) tuples.
|
||||
"""
|
||||
chunks: list[tuple[slice, slice]] = []
|
||||
n = len(seq_lens_cpu)
|
||||
max_logits_elems = max_logits_bytes // 4
|
||||
end = 0
|
||||
|
||||
while end < n:
|
||||
start, chunk_m, chunk_n = end, 0, 0
|
||||
|
||||
while end < n:
|
||||
q, s = query_lens_cpu[end].item(), seq_lens_cpu[end].item()
|
||||
new_m, new_n = chunk_m + q, chunk_n + s
|
||||
if new_n <= workspace_size and new_m * new_n <= max_logits_elems:
|
||||
chunk_m, chunk_n = new_m, new_n
|
||||
end += 1
|
||||
else:
|
||||
break
|
||||
|
||||
# A single request can exceed the budget, requiring sub-chunking
|
||||
# on the query dimension.
|
||||
if end == start:
|
||||
chunk_m, chunk_n = query_lens_cpu[end].item(), seq_lens_cpu[end].item()
|
||||
end += 1
|
||||
|
||||
req_slice = slice(start + request_offset, end + request_offset)
|
||||
max_q = max(1, max_logits_elems // chunk_n) if chunk_n > 0 else chunk_m
|
||||
for q_off in range(0, chunk_m, max_q):
|
||||
sub_m = min(max_q, chunk_m - q_off)
|
||||
chunks.append((req_slice, slice(q_off, q_off + sub_m)))
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
class DeepseekV32IndexerBackend(AttentionBackend):
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
@@ -81,6 +130,7 @@ class DeepseekV32IndexerPrefillChunkMetadata:
|
||||
token_start: int
|
||||
token_end: int
|
||||
num_reqs: int
|
||||
skip_kv_gather: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -271,43 +321,51 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
||||
)
|
||||
|
||||
def build_one_prefill_chunk(
|
||||
self, reqs_start, reqs_end, query_start_loc_cpu, seq_lens_cpu, block_table
|
||||
):
|
||||
self,
|
||||
req_slice: slice,
|
||||
query_slice: slice,
|
||||
query_start_loc_cpu,
|
||||
seq_lens_cpu,
|
||||
block_table,
|
||||
skip_kv_gather: bool = False,
|
||||
) -> DeepseekV32IndexerPrefillChunkMetadata:
|
||||
prefill_query_start_loc = (
|
||||
query_start_loc_cpu[reqs_start : reqs_end + 1]
|
||||
- query_start_loc_cpu[reqs_start]
|
||||
query_start_loc_cpu[req_slice.start : req_slice.stop + 1]
|
||||
- query_start_loc_cpu[req_slice.start]
|
||||
)
|
||||
cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
|
||||
prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end], self.device
|
||||
prefill_query_start_loc, seq_lens_cpu[req_slice], self.device
|
||||
)
|
||||
token_start = query_start_loc_cpu[req_slice.start].item()
|
||||
total_seq_lens = seq_lens_cpu[req_slice].sum()
|
||||
num_reqs = req_slice.stop - req_slice.start
|
||||
seq_idx = torch.arange(0, num_reqs, dtype=torch.int32)
|
||||
token_to_seq = torch.repeat_interleave(seq_idx, seq_lens_cpu[req_slice]).to(
|
||||
self.device
|
||||
)
|
||||
token_start = query_start_loc_cpu[reqs_start].item()
|
||||
token_end = query_start_loc_cpu[reqs_end].item()
|
||||
total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
|
||||
seq_idx = torch.arange(0, reqs_end - reqs_start, dtype=torch.int32)
|
||||
token_to_seq = torch.repeat_interleave(
|
||||
seq_idx, seq_lens_cpu[reqs_start:reqs_end]
|
||||
).to(self.device)
|
||||
assert total_seq_lens <= self.max_prefill_buffer_size
|
||||
cu_seq_lens = (
|
||||
torch.cat(
|
||||
[
|
||||
torch.zeros(1, dtype=torch.int32),
|
||||
seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0),
|
||||
seq_lens_cpu[req_slice].cumsum(dim=0),
|
||||
]
|
||||
)
|
||||
.to(torch.int32)
|
||||
.to(self.device)
|
||||
)
|
||||
|
||||
return DeepseekV32IndexerPrefillChunkMetadata(
|
||||
cu_seqlen_ks=cu_seqlen_ks,
|
||||
cu_seqlen_ke=cu_seqlen_ke,
|
||||
cu_seqlen_ks=cu_seqlen_ks[query_slice],
|
||||
cu_seqlen_ke=cu_seqlen_ke[query_slice],
|
||||
cu_seq_lens=cu_seq_lens,
|
||||
token_to_seq=token_to_seq,
|
||||
total_seq_lens=total_seq_lens,
|
||||
block_table=block_table[reqs_start:reqs_end],
|
||||
token_start=token_start,
|
||||
token_end=token_end,
|
||||
num_reqs=reqs_end - reqs_start,
|
||||
block_table=block_table[req_slice],
|
||||
token_start=token_start + query_slice.start,
|
||||
token_end=token_start + query_slice.stop,
|
||||
num_reqs=num_reqs,
|
||||
skip_kv_gather=skip_kv_gather,
|
||||
)
|
||||
|
||||
def build(
|
||||
@@ -333,20 +391,27 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
||||
|
||||
prefill_metadata = None
|
||||
if num_prefills > 0:
|
||||
chunk_seq_ids = split_prefill_chunks(
|
||||
prefill_query_lens_cpu = torch.diff(
|
||||
query_start_loc_cpu[num_decodes : num_decodes + num_prefills + 1]
|
||||
)
|
||||
max_logits_bytes = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024
|
||||
chunk_specs = split_indexer_prefill_chunks(
|
||||
common_attn_metadata.seq_lens_cpu[num_decodes:],
|
||||
prefill_query_lens_cpu,
|
||||
self.max_prefill_buffer_size,
|
||||
max_logits_bytes,
|
||||
request_offset=num_decodes,
|
||||
)
|
||||
chunks = [
|
||||
self.build_one_prefill_chunk(
|
||||
reqs_start,
|
||||
reqs_end,
|
||||
req_slice,
|
||||
query_slice,
|
||||
query_start_loc_cpu,
|
||||
common_attn_metadata.seq_lens_cpu,
|
||||
common_attn_metadata.block_table_tensor,
|
||||
skip_kv_gather=query_slice.start > 0,
|
||||
)
|
||||
for reqs_start, reqs_end in chunk_seq_ids
|
||||
for req_slice, query_slice in chunk_specs
|
||||
]
|
||||
prefill_metadata = DeepseekV32IndexerPrefillMetadata(
|
||||
chunks=chunks,
|
||||
|
||||
Reference in New Issue
Block a user