[BugFix] Add support for MTP num_speculative_tokens > 1 with sparse MLA (#34552)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
@@ -79,6 +79,12 @@ def sparse_attn_indexer(
     has_prefill = attn_metadata.num_prefills > 0
     num_decode_tokens = attn_metadata.num_decode_tokens

+    # During speculative decoding, k may be padded to the CUDA graph batch
+    # size while slot_mapping only covers actual tokens. Truncate k to avoid
+    # out-of-bounds reads in the kernel.
+    num_tokens = slot_mapping.shape[0]
+    k = k[:num_tokens]
+
     ops.indexer_k_quant_and_cache(
         k,
         kv_cache,
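
The effect of the truncation can be illustrated with a minimal pure-Python sketch. It is not the vLLM kernel: `write_k_to_cache` and `truncate_to_slot_mapping` are hypothetical names, plain lists stand in for tensors, and the sketch assumes the cache write does one `slot_mapping` lookup per row of `k`, which is one plausible way a padded `k` could lead to an out-of-bounds read.

```python
def write_k_to_cache(k, kv_cache, slot_mapping):
    """Hypothetical stand-in for a cache-write kernel (assumption: one
    slot_mapping lookup per row of k)."""
    for row_idx, row in enumerate(k):
        # If k has more rows than slot_mapping has entries, this lookup
        # runs past the end of slot_mapping.
        slot = slot_mapping[row_idx]
        kv_cache[slot] = row


def truncate_to_slot_mapping(k, slot_mapping):
    """Mirror of the fix: keep only the rows that slot_mapping covers."""
    num_tokens = len(slot_mapping)
    return k[:num_tokens]


# Three real tokens padded to an assumed graph batch size of four.
k_padded = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [0.0, 0.0]]
slot_mapping = [5, 2, 7]  # covers only the three real tokens

# Without truncation, the fourth (padding) row has no slot to look up.
try:
    write_k_to_cache(k_padded, [None] * 8, slot_mapping)
    overran = False
except IndexError:
    overran = True

# With truncation, every remaining row has a matching slot.
kv_cache = [None] * 8
k_trimmed = truncate_to_slot_mapping(k_padded, slot_mapping)
write_k_to_cache(k_trimmed, kv_cache, slot_mapping)
```

In Python the overrun surfaces as an `IndexError`; a GPU kernel typically has no such check, which is presumably why the commit trims `k` before the call rather than relying on the kernel to guard against it.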