[BugFix] Add support for MTP num_speculative_tokens > 1 with sparse MLA (#34552)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -162,7 +162,7 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch):
|
||||
model_mock.compute_logits.side_effect = logits_returns
|
||||
|
||||
proposer.model = model_mock
|
||||
proposer.attn_layer_names = ["layer.0"]
|
||||
proposer._draft_attn_layer_names = {"layer.0"}
|
||||
|
||||
# Prepare inputs
|
||||
batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens)
|
||||
@@ -190,13 +190,17 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch):
|
||||
|
||||
attn_metadata_builder = attn_metadata_builder_cls(
|
||||
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
|
||||
layer_names=proposer.attn_layer_names,
|
||||
layer_names=list(proposer._draft_attn_layer_names),
|
||||
vllm_config=proposer.vllm_config,
|
||||
device=device,
|
||||
)
|
||||
|
||||
proposer.runner = mock.MagicMock()
|
||||
proposer.attn_metadata_builder = attn_metadata_builder
|
||||
mock_attn_group = mock.MagicMock()
|
||||
mock_attn_group.get_metadata_builder.return_value = attn_metadata_builder
|
||||
mock_attn_group.layer_names = list(proposer._draft_attn_layer_names)
|
||||
mock_attn_group.kv_cache_spec = attn_metadata_builder.kv_cache_spec
|
||||
proposer.draft_attn_groups = [mock_attn_group]
|
||||
|
||||
# Run propose
|
||||
result = proposer.propose(
|
||||
|
||||
Reference in New Issue
Block a user