[MTP][Sparse MLA] Take advantage of native MTP support in indexer when possible (#36982)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Author: Matthew Bonanni
Date: 2026-03-16 13:51:21 -04:00
Committed by: GitHub
Parent: 9f9ecff4cd
Commit: c88ea8338b
2 changed files with 24 additions and 13 deletions


@@ -206,6 +206,8 @@ def get_max_prefill_buffer_size(vllm_config: VllmConfig):
 class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
     reorder_batch_threshold: int = 1
+    natively_supported_next_n: list[int] = [1, 2]
+    # TODO (matt): integrate kernel with next_n = 4 support
     @classmethod
     def get_cudagraph_support(
@@ -231,7 +233,9 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
             if self.vllm_config.speculative_config
             else 0
         )
+        next_n = self.num_speculative_tokens + 1
         self.reorder_batch_threshold += self.num_speculative_tokens
+        self.use_flattening = next_n not in self.natively_supported_next_n
         sm_count = num_compute_units(self.device.index)
         self.num_sms = sm_count
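Editor's note: next_n is the number of query tokens each decode request contributes per step (the target token plus the MTP draft tokens). A minimal standalone sketch of the decision above, with the builder class elided and names mirroring the diff:

    # Minimal sketch of the native-vs-flattening decision; values illustrative.
    natively_supported_next_n = [1, 2]

    def needs_flattening(num_speculative_tokens: int) -> bool:
        next_n = num_speculative_tokens + 1  # draft tokens plus one target token
        return next_n not in natively_supported_next_n

    assert not needs_flattening(0)  # next_n = 1: plain decode, native
    assert not needs_flattening(1)  # next_n = 2: MTP with one draft token, native
    assert needs_flattening(3)      # next_n = 4: not yet native (see TODO above)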
@@ -241,10 +245,11 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
             dtype=torch.int32,
             device=self.device,
         )
+        # Pre-allocated buffers for flattening (spec decode).
+        self.offsets_buffer = torch.arange(
+            next_n, device=self.device, dtype=torch.int32
+        )
         self.arange_buffer = torch.arange(
-            scheduler_config.max_num_seqs * (1 + self.num_speculative_tokens),
+            scheduler_config.max_num_seqs * next_n,
             dtype=torch.int32,
             device=self.device,
         )
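Editor's note: to make the buffer shapes concrete, a small illustration with assumed values (max_num_seqs = 4, num_speculative_tokens = 1; CPU tensors for brevity):

    import torch

    max_num_seqs = 4            # assumed scheduler_config.max_num_seqs
    num_speculative_tokens = 1  # assumed, so next_n = 2
    next_n = 1 + num_speculative_tokens

    # Per-request token offsets within one decode step.
    offsets_buffer = torch.arange(next_n, dtype=torch.int32)  # [0, 1]
    # One slot per decode token across the whole batch.
    arange_buffer = torch.arange(max_num_seqs * next_n, dtype=torch.int32)
    print(arange_buffer.tolist())  # [0, 1, 2, 3, 4, 5, 6, 7]

Note that the arange_buffer change is a pure rename: max_num_seqs * (1 + num_speculative_tokens) and max_num_seqs * next_n are the same size.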
@@ -323,7 +328,9 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
             split_decodes_and_prefills(
-                common_attn_metadata, decode_threshold=self.reorder_batch_threshold
+                common_attn_metadata,
+                decode_threshold=self.reorder_batch_threshold,
+                require_uniform=not self.use_flattening,
             )
         )
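Editor's note: require_uniform is passed as not self.use_flattening because the native path needs every decode request to contribute the same q_len, so the kernel can view the decode portion of the batch as [num_decodes, next_n]. A toy sketch of the effect (not vLLM's implementation; the scan logic and values are illustrative):

    # Toy illustration of require_uniform in decode/prefill splitting.
    def split_decodes(query_lens: list[int], decode_threshold: int,
                      require_uniform: bool) -> int:
        """Return the number of leading requests classified as decodes."""
        num_decodes = 0
        for q in query_lens:
            if q > decode_threshold:
                break
            # With require_uniform, every decode must match the first
            # decode's q_len, so a mixed prefix ends the scan early.
            if require_uniform and num_decodes > 0 and q != query_lens[0]:
                break
            num_decodes += 1
        return num_decodes

    print(split_decodes([2, 2, 2, 9], decode_threshold=2, require_uniform=True))   # 3
    print(split_decodes([2, 1, 2, 9], decode_threshold=2, require_uniform=True))   # 1
    print(split_decodes([2, 1, 2, 9], decode_threshold=2, require_uniform=False))  # 3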
@@ -372,11 +379,21 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
             block_table.clamp_(min=0)
             max_decode_len = int(decode_lens_cpu.max().item())
-            if max_decode_len > 1:
+            next_n = 1 + self.num_speculative_tokens
+            use_native = not self.use_flattening and max_decode_len == next_n
+            if use_native and next_n > 1:
+                offsets = self.offsets_buffer
+                batch_size = num_decodes
+            elif max_decode_len > 1:
+                # Flatten multi-token decode requests into single-token
+                # batch entries, expanding seq_lens and block tables so
+                # the kernel always sees next_n=1.
+                # Also handles the edge case where use_flattening=False
+                # but max_decode_len != next_n (e.g. a batch containing some
+                # short prefills (q_len < next_n) and no true decodes).
                 # Assume 4 requests with seq_lens [10, 7, 12, 0] (the final req is
                 # padding) and decode_lens [3, 1, 4, 0] in the below example comments.
                 # The context lengths are therefore
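Editor's note: the worked example in the comments can be reproduced with a short sketch of the seq_lens expansion (toy code, not the builder's actual tensor ops). Each request with decode_len d becomes d single-token entries whose seq_lens count up from context_len + 1 to seq_len:

    import torch

    seq_lens = torch.tensor([10, 7, 12, 0])   # from the example above
    decode_lens = torch.tensor([3, 1, 4, 0])  # final request is padding
    context_lens = seq_lens - decode_lens     # [7, 6, 8, 0]

    flat_seq_lens = torch.cat([
        torch.arange(int(c) + 1, int(s) + 1)
        for c, s in zip(context_lens, seq_lens)
    ])
    print(flat_seq_lens.tolist())  # [8, 9, 10, 7, 9, 10, 11, 12]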
@@ -428,13 +445,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
                 offsets = None
                 batch_size = num_decode_tokens
             else:
-                next_n = 1 + self.num_speculative_tokens
-                if next_n > 1:
-                    offsets = torch.arange(
-                        next_n, device=self.device, dtype=torch.int32
-                    )
-                else:
-                    offsets = None
+                offsets = None
                 batch_size = num_decodes
         # DeepGEMM is required for the paged MQA logits on CUDA devices
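Editor's note on the hunk above: with the native path in place, the final else branch is only reached when max_decode_len <= 1, i.e. every decode is single-token, so per-token offsets are meaningless and offsets can be None unconditionally. The next_n > 1 case now takes the use_native branch and reuses the pre-allocated self.offsets_buffer instead of constructing a fresh arange tensor on every build call.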