[Bugfix] Fix block_size for hybrid model MTP (#36036)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
Benjamin Chislett
2026-03-05 01:10:54 -05:00
committed by GitHub
parent d106bf39f5
commit 57c629e9c1
2 changed files with 25 additions and 17 deletions

View File

@@ -37,6 +37,8 @@ eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
 eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
 ar_draft_model_dir = "amd/PARD-Llama-3.2-1B"  # Compatible with parallel and AR drafting
+
+BLOCK_SIZE = 16


 def _create_proposer(
     method: str,
@@ -91,9 +93,11 @@ def _create_proposer(
)
if "eagle" in method:
return EagleProposer(vllm_config=vllm_config, device=device)
proposer = EagleProposer(vllm_config=vllm_config, device=device)
else:
return DraftModelProposer(vllm_config=vllm_config, device=device)
proposer = DraftModelProposer(vllm_config=vllm_config, device=device)
proposer.block_size = BLOCK_SIZE
return proposer
def test_prepare_next_token_ids():
@@ -163,7 +167,7 @@ def test_prepare_next_token_ids():
     common_attn_metadata = create_common_attn_metadata(
         batch_spec,
-        block_size=16,
+        block_size=BLOCK_SIZE,
         device=device,
     )
@@ -207,7 +211,7 @@ def test_prepare_inputs():
     common_attn_metadata = create_common_attn_metadata(
         batch_spec,
-        block_size=16,
+        block_size=BLOCK_SIZE,
         device=device,
     )
@@ -302,7 +306,7 @@ def test_prepare_inputs_padded():
     common_attn_metadata = create_common_attn_metadata(
         batch_spec,
-        block_size=16,
+        block_size=BLOCK_SIZE,
         device=device,
     )
@@ -371,7 +375,7 @@ def test_set_inputs_first_pass_default_eagle():
     common_attn_metadata = create_common_attn_metadata(
         batch_spec,
-        block_size=16,
+        block_size=BLOCK_SIZE,
         device=device,
     )
@@ -462,7 +466,7 @@ def test_set_inputs_first_pass_draft_model():
     device = torch.device(current_platform.device_type)
     num_speculative_tokens = 2
-    block_size = 16
+    block_size = BLOCK_SIZE

     # Create a proposer configured as a draft model (pass_hidden_states=False)
     # We need to mock this since _create_proposer defaults to EAGLE
@@ -600,7 +604,7 @@ def test_set_inputs_first_pass_parallel_drafting():
     device = torch.device(current_platform.device_type)
     num_speculative_tokens = 3
-    block_size = 16
+    block_size = BLOCK_SIZE

     proposer = _create_proposer("eagle", num_speculative_tokens, parallel_drafting=True)
@@ -926,7 +930,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
     common_attn_metadata = create_common_attn_metadata(
         batch_spec,
-        block_size=16,
+        block_size=BLOCK_SIZE,
         device=device,
     )
@@ -1123,7 +1127,7 @@ def test_propose_tree(spec_token_tree):
     )
     common_attn_metadata = create_common_attn_metadata(
         batch_spec,
-        block_size=16,
+        block_size=BLOCK_SIZE,
         device=device,
     )
     sampling_metadata = mock.MagicMock()