[Bugfix][MTP][Sparse MLA] Allow sparse MLA with MTP to run with FULL cudagraphs (#34457)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Matthew Bonanni
2026-02-17 14:01:27 -05:00
committed by GitHub
parent 1e4a084c8e
commit dc5fa77a4e
2 changed files with 17 additions and 6 deletions

View File

@@ -182,6 +182,7 @@ The following table lists backends that support full CUDA Graphs at the time of
| FlashInfer | `UNIFORM_SINGLE_TOKEN_DECODE` | Will be set to `UNIFORM_BATCH` when using TRTLLM attention on Blackwell |
| FlashMLA | `UNIFORM_BATCH` | |
| FlashInferMLA | `UNIFORM_BATCH` | |
| FlashInferMLASparse | `UNIFORM_BATCH` | |
| AITER MLA | `UNIFORM_SINGLE_TOKEN_DECODE` | |
| CUTLASS MLA | `UNIFORM_SINGLE_TOKEN_DECODE` | |
| Mamba attention| `UNIFORM_SINGLE_TOKEN_DECODE` | |

View File

@@ -196,9 +196,7 @@ def get_max_prefill_buffer_size(vllm_config: VllmConfig):
class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
_cudagraph_support: ClassVar[AttentionCGSupport] = (
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
)
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
reorder_batch_threshold: int = 1
@@ -212,8 +210,14 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
if self.vllm_config.speculative_config
else 0
)
# Now deepgemm fp8_paged_mqa_logits does not support next_n > 2
self.reorder_batch_threshold += min(self.num_speculative_tokens, 1)
if self.num_speculative_tokens > 1:
raise ValueError(
"Sparse MLA only supports "
"num_speculative_tokens <= 1 because the DeepGEMM "
"fp8_paged_mqa_logits kernel does not support next_n > 2. "
f"Got num_speculative_tokens={self.num_speculative_tokens}."
)
self.reorder_batch_threshold += self.num_speculative_tokens
props = torch.cuda.get_device_properties(self.device)
sm_count = props.multi_processor_count
@@ -342,8 +346,14 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
seq_lens, self.kv_cache_spec.block_size, self.num_sms
)
block_table = common_attn_metadata.block_table_tensor[:num_decodes, ...]
# Padded CUDA graph requests have block_table entries of -1.
# Clamp to 0 to prevent OOB access in the DeepGEMM kernel.
# This is safe because padded requests have seq_lens=0, so the
# kernel produces no meaningful output for those rows.
block_table.clamp_(min=0)
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
block_table=common_attn_metadata.block_table_tensor[:num_decodes, ...],
block_table=block_table,
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
decode_lens=decode_lens,
requires_padding=requires_padding,