[Bugfix][MTP][Sparse MLA] Allow sparse MLA with MTP to run with FULL cudagraphs (#34457)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -182,6 +182,7 @@ The following table lists backends that support full CUDA Graphs at the time of
|
||||
| FlashInfer | `UNIFORM_SINGLE_TOKEN_DECODE` | Will be set to `UNIFORM_BATCH` when using TRTLLM attention on Blackwell |
|
||||
| FlashMLA | `UNIFORM_BATCH` | |
|
||||
| FlashInferMLA | `UNIFORM_BATCH` | |
|
||||
| FlashInferMLASparse | `UNIFORM_BATCH` | |
|
||||
| AITER MLA | `UNIFORM_SINGLE_TOKEN_DECODE` | |
|
||||
| CUTLASS MLA | `UNIFORM_SINGLE_TOKEN_DECODE` | |
|
||||
| Mamba attention | `UNIFORM_SINGLE_TOKEN_DECODE` | |
|
||||
|
||||
@@ -196,9 +196,7 @@ def get_max_prefill_buffer_size(vllm_config: VllmConfig):
|
||||
|
||||
|
||||
class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = (
|
||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
)
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
||||
|
||||
reorder_batch_threshold: int = 1
|
||||
|
||||
@@ -212,8 +210,14 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
||||
if self.vllm_config.speculative_config
|
||||
else 0
|
||||
)
|
||||
# The DeepGEMM fp8_paged_mqa_logits kernel currently does not support next_n > 2
|
||||
self.reorder_batch_threshold += min(self.num_speculative_tokens, 1)
|
||||
if self.num_speculative_tokens > 1:
|
||||
raise ValueError(
|
||||
"Sparse MLA only supports "
|
||||
"num_speculative_tokens <= 1 because the DeepGEMM "
|
||||
"fp8_paged_mqa_logits kernel does not support next_n > 2. "
|
||||
f"Got num_speculative_tokens={self.num_speculative_tokens}."
|
||||
)
|
||||
self.reorder_batch_threshold += self.num_speculative_tokens
|
||||
|
||||
props = torch.cuda.get_device_properties(self.device)
|
||||
sm_count = props.multi_processor_count
|
||||
@@ -342,8 +346,14 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
||||
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
|
||||
seq_lens, self.kv_cache_spec.block_size, self.num_sms
|
||||
)
|
||||
block_table = common_attn_metadata.block_table_tensor[:num_decodes, ...]
|
||||
# Padded CUDA graph requests have block_table entries of -1.
|
||||
# Clamp to 0 to prevent OOB access in the DeepGEMM kernel.
|
||||
# This is safe because padded requests have seq_lens=0, so the
|
||||
# kernel produces no meaningful output for those rows.
|
||||
block_table.clamp_(min=0)
|
||||
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
|
||||
block_table=common_attn_metadata.block_table_tensor[:num_decodes, ...],
|
||||
block_table=block_table,
|
||||
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
|
||||
decode_lens=decode_lens,
|
||||
requires_padding=requires_padding,
|
||||
|
||||
Reference in New Issue
Block a user