[BugFix] Add support for MTP num_speculative_tokens > 1 with sparse MLA (#34552)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Lucas Wilkinson
2026-03-03 10:21:57 -05:00
committed by GitHub
parent fb7fdc49c4
commit 28ef9ba399
7 changed files with 260 additions and 197 deletions

View File

@@ -79,6 +79,12 @@ def sparse_attn_indexer(
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
# During speculative decoding, k may be padded to the CUDA graph batch
# size while slot_mapping only covers actual tokens. Truncate k to avoid
# out-of-bounds reads in the kernel.
num_tokens = slot_mapping.shape[0]
k = k[:num_tokens]
ops.indexer_k_quant_and_cache(
k,
kv_cache,