[BugFix] Add support for MTP num_speculative_tokens > 1 with sparse MLA (#34552)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Lucas Wilkinson
2026-03-03 10:21:57 -05:00
committed by GitHub
parent fb7fdc49c4
commit 28ef9ba399
7 changed files with 260 additions and 197 deletions

View File

@@ -476,12 +476,12 @@ def test_set_inputs_first_pass_draft_model():
proposer.max_num_tokens, dtype=torch.bool, device=device
)
# Mock the attn_metadata_builder to avoid needing the full model setup
# Mock draft_attn_groups to avoid needing the full model setup
mock_kv_cache_spec = mock.MagicMock()
mock_kv_cache_spec.block_size = block_size
mock_builder = mock.MagicMock()
mock_builder.kv_cache_spec = mock_kv_cache_spec
proposer.attn_metadata_builder = mock_builder
mock_attn_group = mock.MagicMock()
mock_attn_group.kv_cache_spec = mock_kv_cache_spec
proposer.draft_attn_groups = [mock_attn_group]
# Request 0: query_len=3 (but 1 rejected), Request 1: query_len=2
batch_spec = BatchSpec(
@@ -616,12 +616,12 @@ def test_set_inputs_first_pass_parallel_drafting():
proposer.max_num_tokens, dtype=torch.bool, device=device
)
# Mock the attn_metadata_builder
# Mock draft_attn_groups
mock_kv_cache_spec = mock.MagicMock()
mock_kv_cache_spec.block_size = block_size
mock_builder = mock.MagicMock()
mock_builder.kv_cache_spec = mock_kv_cache_spec
proposer.attn_metadata_builder = mock_builder
mock_attn_group = mock.MagicMock()
mock_attn_group.kv_cache_spec = mock_kv_cache_spec
proposer.draft_attn_groups = [mock_attn_group]
# Request 0: query_len=4 (1 rejected), Request 1: query_len=4 (all valid)
batch_spec = BatchSpec(
@@ -916,7 +916,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
proposer.model = model_mock
# Assign draft attn_layer_names since load_model is not invoked
proposer.attn_layer_names = ["layer.0"]
proposer._draft_attn_layer_names = {"layer.0"}
# Create input tensors
batch_spec = BatchSpec(
@@ -961,20 +961,18 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
attn_metadata_builder = attn_metadata_builder_cls(
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
layer_names=proposer.attn_layer_names,
layer_names=proposer._draft_attn_layer_names,
vllm_config=proposer.vllm_config,
device=device,
)
# Mock runner for attention metadata building
# Mock runner and draft_attn_groups for attention metadata building
proposer.runner = mock.MagicMock()
proposer.runner.attn_groups.append([mock.MagicMock()])
proposer.runner.attn_groups[0][
0
].get_metadata_builder.return_value = attn_metadata_builder
proposer._get_attention_metadata_builder = mock.MagicMock(
return_value=attn_metadata_builder
)
mock_attn_group = mock.MagicMock()
mock_attn_group.get_metadata_builder.return_value = attn_metadata_builder
mock_attn_group.layer_names = list(proposer._draft_attn_layer_names)
mock_attn_group.kv_cache_spec = attn_metadata_builder.kv_cache_spec
proposer.draft_attn_groups = [mock_attn_group]
result = proposer.propose(
target_token_ids=target_token_ids,
@@ -1089,7 +1087,7 @@ def test_propose_tree(spec_token_tree):
proposer.model = model_mock
# Assign draft attn_layer_names since load_model is not invoked
proposer.attn_layer_names = ["layer.0"]
proposer._draft_attn_layer_names = {"layer.0"}
# Get the tree attention metadata builder.
attn_metadata_builder_cls, _ = try_get_attention_backend(
@@ -1097,21 +1095,18 @@ def test_propose_tree(spec_token_tree):
)
attn_metadata_builder = attn_metadata_builder_cls(
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
layer_names=proposer.attn_layer_names,
layer_names=proposer._draft_attn_layer_names,
vllm_config=proposer.vllm_config,
device=device,
)
# Mock runner for attention metadata building.
# Mock runner and draft_attn_groups for attention metadata building.
proposer.runner = mock.MagicMock()
proposer.runner.attn_groups.append([mock.MagicMock()])
proposer.runner.attn_groups[0][0].metadata_builders = [attn_metadata_builder]
proposer.runner.attn_groups[0][
0
].get_metadata_builder.return_value = attn_metadata_builder
proposer._get_attention_metadata_builder = mock.MagicMock(
return_value=attn_metadata_builder
)
mock_attn_group = mock.MagicMock()
mock_attn_group.get_metadata_builder.return_value = attn_metadata_builder
mock_attn_group.layer_names = list(proposer._draft_attn_layer_names)
mock_attn_group.kv_cache_spec = attn_metadata_builder.kv_cache_spec
proposer.draft_attn_groups = [mock_attn_group]
# Setup inputs for the proposer.
target_token_ids = torch.randint(0, vocab_size, (total_tokens,), device=device)