[MTP][Sparse MLA] Take advantage of native MTP support in indexer when possible (#36982)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
@@ -206,6 +206,8 @@ def get_max_prefill_buffer_size(vllm_config: VllmConfig):
 
 class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
     reorder_batch_threshold: int = 1
+    natively_supported_next_n: list[int] = [1, 2]
+    # TODO (matt): integrate kernel with next_n = 4 support
 
     @classmethod
     def get_cudagraph_support(
@@ -231,7 +233,9 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
             if self.vllm_config.speculative_config
             else 0
         )
+        next_n = self.num_speculative_tokens + 1
         self.reorder_batch_threshold += self.num_speculative_tokens
+        self.use_flattening = next_n not in self.natively_supported_next_n
 
         sm_count = num_compute_units(self.device.index)
         self.num_sms = sm_count
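The gist of the change: the indexer kernel natively handles next_n values of 1 and 2, so flattening is only needed for deeper speculation. A minimal standalone sketch of the decision added in this hunk (module-level constant stands in for the class attribute):

# Sketch of the native-vs-flattening decision introduced above.
# next_n counts the verified token plus the MTP draft tokens.
natively_supported_next_n = [1, 2]

def needs_flattening(num_speculative_tokens: int) -> bool:
    next_n = num_speculative_tokens + 1
    # next_n = 1 (no spec decode) and next_n = 2 (one draft token)
    # hit the kernel directly; anything larger falls back to
    # flattening each decode into single-token batch entries.
    return next_n not in natively_supported_next_n

assert not needs_flattening(1)  # one draft token: native path
assert needs_flattening(3)      # next_n = 4: flattened until kernel support lands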
@@ -241,10 +245,11 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
             dtype=torch.int32,
             device=self.device,
         )
 
         # Pre-allocated buffers for flattening (spec decode).
+        self.offsets_buffer = torch.arange(
+            next_n, device=self.device, dtype=torch.int32
+        )
         self.arange_buffer = torch.arange(
-            scheduler_config.max_num_seqs * (1 + self.num_speculative_tokens),
+            scheduler_config.max_num_seqs * next_n,
             dtype=torch.int32,
             device=self.device,
         )
@@ -323,7 +328,9 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
             split_decodes_and_prefills(
-                common_attn_metadata, decode_threshold=self.reorder_batch_threshold
+                common_attn_metadata,
+                decode_threshold=self.reorder_batch_threshold,
+                require_uniform=not self.use_flattening,
             )
         )
 
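Passing require_uniform=not self.use_flattening means the split only classifies requests as decodes when their query lengths all match, since the native path views the decode batch as [num_decodes, next_n]. split_decodes_and_prefills is vLLM's helper; the sketch below only illustrates the uniformity idea, it is not that function's implementation:

def count_uniform_decodes(query_lens: list[int], decode_threshold: int,
                          require_uniform: bool) -> int:
    # Illustrative only. Requests are assumed reordered so decodes
    # come first. With require_uniform=True a request counts as a
    # decode only if its query length matches the first decode's,
    # so the decode batch can be viewed as [num_decodes, next_n].
    num_decodes = 0
    first_len = None
    for q_len in query_lens:
        if q_len > decode_threshold:
            break
        if require_uniform and first_len is not None and q_len != first_len:
            break
        first_len = q_len if first_len is None else first_len
        num_decodes += 1
    return num_decodes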
@@ -372,11 +379,21 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
             block_table.clamp_(min=0)
 
             max_decode_len = int(decode_lens_cpu.max().item())
-            if max_decode_len > 1:
+            next_n = 1 + self.num_speculative_tokens
+            use_native = not self.use_flattening and max_decode_len == next_n
+
+            if use_native and next_n > 1:
+                offsets = self.offsets_buffer
+                batch_size = num_decodes
+            elif max_decode_len > 1:
                 # Flatten multi-token decode requests into single-token
                 # batch entries, expanding seq_lens and block tables so
                 # the kernel always sees next_n=1.
+
+                # Also handles the edge case where use_flattening=False
+                # but max_decode_len != next_n (e.g. a batch containing some
+                # short prefills (q_len < next_n) and no true decodes).
 
                 # Assume 4 requests with seq_lens [10, 7, 12, 0] (the final req is
                 # padding) and decode_lens [3, 1, 4, 0] in the below example comments.
                 # The context lengths are therefore
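The comments in the flattening branch walk through the expansion with concrete numbers. The transformation they describe can be sketched in a few lines of torch, using the diff's own example values; this is a standalone illustration, not the builder's buffer-based code:

import torch

# Example from the comments above: seq_lens [10, 7, 12, 0] and
# decode_lens [3, 1, 4, 0]; the last request is padding.
seq_lens = torch.tensor([10, 7, 12, 0])
decode_lens = torch.tensor([3, 1, 4, 0])

# Context visible before the decode tokens of each request: [7, 6, 8, 0].
context_lens = seq_lens - decode_lens

# One batch entry per decode token: token i of a request attends to
# context_len + i + 1 tokens. The builder does this with the
# pre-allocated arange_buffer; this is the unbuffered equivalent.
base = torch.repeat_interleave(context_lens, decode_lens)
within = torch.cat([torch.arange(1, int(n) + 1) for n in decode_lens])
flat_seq_lens = base + within
print(flat_seq_lens.tolist())  # [8, 9, 10, 7, 9, 10, 11, 12]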
@@ -428,13 +445,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
                 offsets = None
                 batch_size = num_decode_tokens
             else:
-                next_n = 1 + self.num_speculative_tokens
-                if next_n > 1:
-                    offsets = torch.arange(
-                        next_n, device=self.device, dtype=torch.int32
-                    )
-                else:
-                    offsets = None
+                offsets = None
                 batch_size = num_decodes
 
             # DeepGEMM is required for the paged MQA logits on CUDA devices
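Together with the previous hunk, the decode path now has three layouts. A compressed restatement (names from the diff; this is a summary sketch, not the builder's exact code):

# Summary of the three decode layouts after this change.
if use_native and next_n > 1:
    # Native MTP: kernel sees [num_decodes, next_n] with per-step offsets.
    offsets, batch_size = self.offsets_buffer, num_decodes
elif max_decode_len > 1:
    # Flattened: every decode token becomes its own batch entry.
    offsets, batch_size = None, num_decode_tokens
else:
    # Ordinary single-token decode.
    offsets, batch_size = None, num_decodes

The deleted else branch also explains the offsets_buffer added in __init__: the old code built torch.arange(next_n, ...) on every metadata build, while the new code allocates that tensor once and reuses it.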