[BugFix] Add support for MTP num_speculative_tokens > 1 with sparse MLA (#34552)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
@@ -12,6 +12,7 @@ from vllm.utils.deep_gemm import (
    get_paged_mqa_logits_metadata,
    is_deep_gemm_supported,
)
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import num_compute_units
from vllm.v1.attention.backend import (
    AttentionBackend,
@@ -24,6 +25,7 @@ from vllm.v1.attention.backends.utils import (
    split_decodes_and_prefills,
    split_prefill_chunks,
)
from vllm.v1.worker.cp_utils import get_total_cp_world_size

logger = init_logger(__name__)

@@ -214,20 +216,39 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
            if self.vllm_config.speculative_config
            else 0
        )
        if self.num_speculative_tokens > 1:
            raise ValueError(
                "Sparse MLA only supports "
                "num_speculative_tokens <= 1 because the DeepGEMM "
                "fp8_paged_mqa_logits kernel does not support next_n > 2. "
                f"Got num_speculative_tokens={self.num_speculative_tokens}."
            )
        self.reorder_batch_threshold += self.num_speculative_tokens

        sm_count = num_compute_units(self.device.index)
        self.num_sms = sm_count

        self.decode_lens_buffer = torch.empty(
            (scheduler_config.max_num_seqs,), dtype=torch.int32, device=self.device
            (scheduler_config.max_num_batched_tokens,),
            dtype=torch.int32,
            device=self.device,
        )

        # Pre-allocated buffers for flattening (spec decode).
        self.arange_buffer = torch.arange(
            scheduler_config.max_num_seqs * (1 + self.num_speculative_tokens),
            dtype=torch.int32,
            device=self.device,
        )
        self.expanded_seq_lens_buffer = torch.zeros(
            (scheduler_config.max_num_batched_tokens,),
            dtype=torch.int32,
            device=self.device,
        )
        max_num_blocks_per_req = cdiv(
            self.vllm_config.model_config.max_model_len,
            self.kv_cache_spec.block_size * get_total_cp_world_size(),
        )
        self.expanded_block_table_buffer = torch.zeros(
            (
                scheduler_config.max_num_batched_tokens,
                max_num_blocks_per_req,
            ),
            dtype=torch.int32,
            device=self.device,
        )

        # See: DeepGMM/csrc/apis/attention.hpp
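As a side note on the buffer sizing above: the sketch below is not part of the commit, and all numeric values in it are assumed purely for illustration. It reproduces the ceil-division used for max_num_blocks_per_req and the resulting shapes of the pre-allocated expansion buffers.

# Standalone sketch of the buffer-sizing math; max_model_len, block_size,
# cp_world_size and max_num_batched_tokens are assumed example values.
def cdiv(a: int, b: int) -> int:
    # Ceiling division (what cdiv computes here).
    return -(a // -b)

max_model_len = 163_840          # assumed model context length
block_size = 64                  # assumed KV-cache block size
cp_world_size = 1                # assumed context-parallel world size
max_num_batched_tokens = 8_192   # assumed scheduler limit

max_num_blocks_per_req = cdiv(max_model_len, block_size * cp_world_size)
print(max_num_blocks_per_req)  # 2560

# After flattening there is one entry per decode token rather than per
# request, so the expanded buffers are sized by max_num_batched_tokens:
expanded_seq_lens_shape = (max_num_batched_tokens,)
expanded_block_table_shape = (max_num_batched_tokens, max_num_blocks_per_req)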
@@ -326,42 +347,97 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
                common_attn_metadata.query_start_loc_cpu[: num_decodes + 1]
            )

            # Use CPU to avoid GPU sync; breaking async scheduling
            requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item()

            # Decide which top-k kernel to use based on batch size and sequence length
            batch_size = num_decodes
            _is_large_context = common_attn_metadata.max_seq_len > 8192

            # Decision logic based on micro-benchmark results:
            # - large_context_topk wins for batch <= 128 and seq_len > 8K
            # - top_k_per_row_decode wins for batch > 128 or seq_len <= 8K
            use_large_context_topk = batch_size <= 128 and _is_large_context

            next_n = 1 + self.num_speculative_tokens
            if next_n > 1:
                offsets = torch.arange(next_n, device=self.device, dtype=torch.int32)
            else:
                offsets = None

            seq_lens = common_attn_metadata.seq_lens[:num_decodes]

            # DeepGEMM is required for the paged MQA logits on CUDA devices
            if current_platform.is_cuda() and is_deep_gemm_supported():
                self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
                    seq_lens, self.kv_cache_spec.block_size, self.num_sms
                )
            block_table = common_attn_metadata.block_table_tensor[:num_decodes, ...]

            # Padded CUDA graph requests have block_table entries of -1.
            # Clamp to 0 to prevent OOB access in the DeepGEMM kernel.
            # This is safe because padded requests have seq_lens=0, so the
            # kernel produces no meaningful output for those rows.
            block_table.clamp_(min=0)

            max_decode_len = int(decode_lens_cpu.max().item())
            if max_decode_len > 1:
                # Flatten multi-token decode requests into single-token
                # batch entries, expanding seq_lens and block tables so
                # the kernel always sees next_n=1.

                # Assume 4 requests with seq_lens [10, 7, 12, 0] (the final req is
                # padding) and decode_lens [3, 1, 4, 0] in the below example comments.
                # The context lengths are therefore
                # [10-3, 7-1, 12-4, 0-0] = [7, 6, 8, 0].

                # 3 + 1 + 4 + 0 = 8
                actual_expanded = int(decode_lens_cpu.sum().item())

                # [7, 6, 8, 0] -> [7, 7, 7, 6, 8, 8, 8, 8]
                expanded_base = torch.repeat_interleave(
                    seq_lens - decode_lens, decode_lens
                )

                # [0, 3, 4, 8] -> [0, 0, 0, 3, 4, 4, 4, 4]
                expanded_starts = torch.repeat_interleave(
                    common_attn_metadata.query_start_loc[:num_decodes], decode_lens
                )

                # [0, 1, 2, 0, 0, 1, 2, 3]
                positions_within = (
                    self.arange_buffer[:actual_expanded] - expanded_starts
                )

                # [8, 9, 10, 7, 9, 10, 11, 12, ...] where ... is unused buffer space
                self.expanded_seq_lens_buffer[:actual_expanded] = (
                    expanded_base + positions_within + 1
                )
                self.expanded_seq_lens_buffer[actual_expanded:] = 0
                seq_lens = self.expanded_seq_lens_buffer[:num_decode_tokens]

                # Give each of the flattened entries the same block table row as the
                # original request.
                self.expanded_block_table_buffer[:actual_expanded] = (
                    torch.repeat_interleave(block_table, decode_lens, dim=0)
                )
                if actual_expanded < num_decode_tokens:
                    self.expanded_block_table_buffer[
                        actual_expanded:num_decode_tokens, 0
                    ] = 0
                block_table = self.expanded_block_table_buffer[:num_decode_tokens]

                # All reqs now have decode_len=1
                self.decode_lens_buffer[:num_decode_tokens] = 1
                decode_lens = self.decode_lens_buffer[:num_decode_tokens]
                offsets = None
                batch_size = num_decode_tokens
            else:
                next_n = 1 + self.num_speculative_tokens
                if next_n > 1:
                    offsets = torch.arange(
                        next_n, device=self.device, dtype=torch.int32
                    )
                else:
                    offsets = None
                batch_size = num_decodes

            # DeepGEMM is required for the paged MQA logits on CUDA devices
            if current_platform.is_cuda() and is_deep_gemm_supported():
                self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
                    seq_lens,
                    self.kv_cache_spec.block_size,
                    self.num_sms,
                )

            # Decide which top-k kernel to use based on batch size and sequence length
            # Decision logic based on micro-benchmark results:
            # - large_context_topk wins for batch <= 128 and seq_len > 8K
            # - top_k_per_row_decode wins for batch > 128 or seq_len <= 8K
            _is_large_context = common_attn_metadata.max_seq_len > 8192
            use_large_context_topk = batch_size <= 128 and _is_large_context

            decode_metadata = DeepSeekV32IndexerDecodeMetadata(
                block_table=block_table,
                seq_lens=common_attn_metadata.seq_lens[:num_decodes],
                seq_lens=seq_lens,
                decode_lens=decode_lens,
                requires_padding=requires_padding,
                requires_padding=False,
                schedule_metadata=self.scheduler_metadata_buffer,
                use_large_context_topk=use_large_context_topk,
                offsets=offsets,
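To make the flattening step above easier to follow, here is a standalone sketch of the same expansion using plain tensors instead of the builder's pre-allocated buffers; it reuses the example numbers from the inline comments (seq_lens [10, 7, 12, 0], decode_lens [3, 1, 4, 0]). It is an illustration only, not code from this commit.

import torch

# Example from the comments above: 4 decode requests, the last one padding.
seq_lens = torch.tensor([10, 7, 12, 0], dtype=torch.int32)
decode_lens = torch.tensor([3, 1, 4, 0], dtype=torch.int32)
query_start_loc = torch.tensor([0, 3, 4, 8, 8], dtype=torch.int32)

actual_expanded = int(decode_lens.sum())  # 3 + 1 + 4 + 0 = 8

# Context lengths repeated per decode token: [7, 6, 8, 0] -> [7, 7, 7, 6, 8, 8, 8, 8]
expanded_base = torch.repeat_interleave(seq_lens - decode_lens, decode_lens)

# Start offset of each request repeated per decode token: [0, 0, 0, 3, 4, 4, 4, 4]
expanded_starts = torch.repeat_interleave(query_start_loc[:4], decode_lens)

# Position of each token within its request: [0, 1, 2, 0, 0, 1, 2, 3]
positions_within = torch.arange(actual_expanded, dtype=torch.int32) - expanded_starts

# Per-token sequence lengths the kernel sees once every entry has next_n=1:
expanded_seq_lens = expanded_base + positions_within + 1
print(expanded_seq_lens.tolist())  # [8, 9, 10, 7, 9, 10, 11, 12]

# Each flattened token reuses its request's block-table row (toy 4x2 table here).
block_table = torch.arange(8, dtype=torch.int32).reshape(4, 2)
expanded_block_table = torch.repeat_interleave(block_table, decode_lens, dim=0)
print(expanded_block_table.shape)  # torch.Size([8, 2])

With this expansion every batch entry carries exactly one new token, which is why the real code can pass the expanded seq_lens and block table to the DeepGEMM paged-MQA kernel without hitting its next_n limit.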