[BugFix] Add support for MTP num_speculative_tokens > 1 with sparse MLA (#34552)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
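This enables MTP draft models to propose more than one speculative token per step when the target model uses sparse MLA attention. The tests below are updated to mock the proposer's new draft_attn_groups list in place of the removed single attn_metadata_builder attribute.

A hedged sketch of the configuration this fix targets; the model name and the "mtp" method string are illustrative assumptions, not taken from this commit:

    # Assumed model/method names; check your vLLM version's docs for
    # the exact values it accepts.
    from vllm import LLM

    llm = LLM(
        model="deepseek-ai/DeepSeek-V3.2-Exp",  # assumed sparse-MLA model
        speculative_config={
            "method": "mtp",              # assumed MTP method name
            "num_speculative_tokens": 2,  # >1 is what this fix enables
        },
    )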
@@ -476,12 +476,12 @@ def test_set_inputs_first_pass_draft_model():
         proposer.max_num_tokens, dtype=torch.bool, device=device
     )
 
-    # Mock the attn_metadata_builder to avoid needing the full model setup
+    # Mock draft_attn_groups to avoid needing the full model setup
     mock_kv_cache_spec = mock.MagicMock()
     mock_kv_cache_spec.block_size = block_size
-    mock_builder = mock.MagicMock()
-    mock_builder.kv_cache_spec = mock_kv_cache_spec
-    proposer.attn_metadata_builder = mock_builder
+    mock_attn_group = mock.MagicMock()
+    mock_attn_group.kv_cache_spec = mock_kv_cache_spec
+    proposer.draft_attn_groups = [mock_attn_group]
 
     # Request 0: query_len=3 (but 1 rejected), Request 1: query_len=2
     batch_spec = BatchSpec(
@@ -616,12 +616,12 @@ def test_set_inputs_first_pass_parallel_drafting():
         proposer.max_num_tokens, dtype=torch.bool, device=device
     )
 
-    # Mock the attn_metadata_builder
+    # Mock draft_attn_groups
    mock_kv_cache_spec = mock.MagicMock()
     mock_kv_cache_spec.block_size = block_size
-    mock_builder = mock.MagicMock()
-    mock_builder.kv_cache_spec = mock_kv_cache_spec
-    proposer.attn_metadata_builder = mock_builder
+    mock_attn_group = mock.MagicMock()
+    mock_attn_group.kv_cache_spec = mock_kv_cache_spec
+    proposer.draft_attn_groups = [mock_attn_group]
 
     # Request 0: query_len=4 (1 rejected), Request 1: query_len=4 (all valid)
     batch_spec = BatchSpec(
@@ -916,7 +916,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
     proposer.model = model_mock
 
     # Assign draft attn_layer_names since load_model is not invoked
-    proposer.attn_layer_names = ["layer.0"]
+    proposer._draft_attn_layer_names = {"layer.0"}
 
     # Create input tensors
     batch_spec = BatchSpec(
@@ -961,20 +961,18 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
 
     attn_metadata_builder = attn_metadata_builder_cls(
         kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
-        layer_names=proposer.attn_layer_names,
+        layer_names=proposer._draft_attn_layer_names,
         vllm_config=proposer.vllm_config,
         device=device,
     )
 
-    # Mock runner for attention metadata building
+    # Mock runner and draft_attn_groups for attention metadata building
     proposer.runner = mock.MagicMock()
-    proposer.runner.attn_groups.append([mock.MagicMock()])
-    proposer.runner.attn_groups[0][
-        0
-    ].get_metadata_builder.return_value = attn_metadata_builder
-    proposer._get_attention_metadata_builder = mock.MagicMock(
-        return_value=attn_metadata_builder
-    )
+    mock_attn_group = mock.MagicMock()
+    mock_attn_group.get_metadata_builder.return_value = attn_metadata_builder
+    mock_attn_group.layer_names = list(proposer._draft_attn_layer_names)
+    mock_attn_group.kv_cache_spec = attn_metadata_builder.kv_cache_spec
+    proposer.draft_attn_groups = [mock_attn_group]
 
     result = proposer.propose(
         target_token_ids=target_token_ids,
@@ -1089,7 +1087,7 @@ def test_propose_tree(spec_token_tree):
     proposer.model = model_mock
 
     # Assign draft attn_layer_names since load_model is not invoked
-    proposer.attn_layer_names = ["layer.0"]
+    proposer._draft_attn_layer_names = {"layer.0"}
 
     # Get the tree attention metadata builder.
     attn_metadata_builder_cls, _ = try_get_attention_backend(
@@ -1097,21 +1095,18 @@ def test_propose_tree(spec_token_tree):
     )
     attn_metadata_builder = attn_metadata_builder_cls(
         kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
-        layer_names=proposer.attn_layer_names,
+        layer_names=proposer._draft_attn_layer_names,
         vllm_config=proposer.vllm_config,
         device=device,
     )
 
-    # Mock runner for attention metadata building.
+    # Mock runner and draft_attn_groups for attention metadata building.
     proposer.runner = mock.MagicMock()
-    proposer.runner.attn_groups.append([mock.MagicMock()])
-    proposer.runner.attn_groups[0][0].metadata_builders = [attn_metadata_builder]
-    proposer.runner.attn_groups[0][
-        0
-    ].get_metadata_builder.return_value = attn_metadata_builder
-    proposer._get_attention_metadata_builder = mock.MagicMock(
-        return_value=attn_metadata_builder
-    )
+    mock_attn_group = mock.MagicMock()
+    mock_attn_group.get_metadata_builder.return_value = attn_metadata_builder
+    mock_attn_group.layer_names = list(proposer._draft_attn_layer_names)
+    mock_attn_group.kv_cache_spec = attn_metadata_builder.kv_cache_spec
+    proposer.draft_attn_groups = [mock_attn_group]
 
     # Setup inputs for the proposer.
     target_token_ids = torch.randint(0, vocab_size, (total_tokens,), device=device)
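For reference, a minimal self-contained sketch of the new test-side mocking pattern applied in the hunks above. The builder and proposer here are stand-ins; the real tests construct them from the vLLM config via the attention backend under test:

    from unittest import mock

    # Stand-ins; the real tests build these from the vLLM config.
    attn_metadata_builder = mock.MagicMock()
    proposer = mock.MagicMock()
    proposer._draft_attn_layer_names = {"layer.0"}

    # New pattern: the proposer reads per-group metadata builders and
    # KV-cache specs from a draft_attn_groups list instead of a single
    # attn_metadata_builder attribute.
    mock_attn_group = mock.MagicMock()
    mock_attn_group.get_metadata_builder.return_value = attn_metadata_builder
    mock_attn_group.layer_names = list(proposer._draft_attn_layer_names)
    mock_attn_group.kv_cache_spec = attn_metadata_builder.kv_cache_spec
    proposer.draft_attn_groups = [mock_attn_group]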