[BugFix] Add support for MTP num_speculative_tokens > 1 with sparse MLA (#34552)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Lucas Wilkinson
2026-03-03 10:21:57 -05:00
committed by GitHub
parent fb7fdc49c4
commit 28ef9ba399
7 changed files with 260 additions and 197 deletions

View File

@@ -162,7 +162,7 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch):
model_mock.compute_logits.side_effect = logits_returns
proposer.model = model_mock
proposer.attn_layer_names = ["layer.0"]
proposer._draft_attn_layer_names = {"layer.0"}
# Prepare inputs
batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens)
@@ -190,13 +190,17 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch):
attn_metadata_builder = attn_metadata_builder_cls(
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
layer_names=proposer.attn_layer_names,
layer_names=list(proposer._draft_attn_layer_names),
vllm_config=proposer.vllm_config,
device=device,
)
proposer.runner = mock.MagicMock()
proposer.attn_metadata_builder = attn_metadata_builder
mock_attn_group = mock.MagicMock()
mock_attn_group.get_metadata_builder.return_value = attn_metadata_builder
mock_attn_group.layer_names = list(proposer._draft_attn_layer_names)
mock_attn_group.kv_cache_spec = attn_metadata_builder.kv_cache_spec
proposer.draft_attn_groups = [mock_attn_group]
# Run propose
result = proposer.propose(