[Feature] [Spec decode]: Enable MLPSpeculator/Medusa and prompt_logprobs with ChunkedPrefill (#10132)

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: wallashss <wallashss@ibm.com>
Co-authored-by: wallashss <wallashss@ibm.com>
This commit is contained in:
Nicolò Lucchesi
2025-01-27 22:38:35 +01:00
committed by GitHub
parent 2bc3fbba0c
commit 6116ca8cd7
16 changed files with 468 additions and 165 deletions

View File

@@ -2,6 +2,7 @@ from itertools import cycle
from typing import List, Optional, Sequence, Tuple, Union
import pytest
import torch
from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
@@ -154,6 +155,8 @@ def _check_logprobs_when_output_disabled(
spec_pos_logprob) = next(iter(spec_pos_logprobs.items()))
assert spec_pos_logprob.rank == -1
assert spec_pos_logprob.logprob == 0.0
if isinstance(spec_pos_logprob_token_id, torch.Tensor):
spec_pos_logprob_token_id = spec_pos_logprob_token_id.item()
assert spec_pos_logprob_token_id in baseline_pos_logprobs
@@ -244,7 +247,8 @@ def run_equality_correctness_test_tp(model,
batch_size: int,
max_output_len: int,
seed: int = 0,
temperature: float = 0.0):
temperature: float = 0.0,
logprobs: Optional[int] = None):
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
the same when temperature is zero.
@@ -257,7 +261,6 @@ def run_equality_correctness_test_tp(model,
results = []
prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]
for args, env in ((arg1, env1), (arg2, env2)):
with RemoteOpenAIServer(model,
args,
@@ -269,12 +272,14 @@ def run_equality_correctness_test_tp(model,
prompt=prompts,
max_tokens=max_output_len,
seed=seed,
temperature=temperature)
temperature=temperature,
logprobs=logprobs)
results.append({
"test":
"seeded_sampling",
"text": [choice.text for choice in completion.choices],
"logprobs": [choice.logprobs for choice in completion.choices],
"finish_reason":
[choice.finish_reason for choice in completion.choices],
"usage":
@@ -284,7 +289,15 @@ def run_equality_correctness_test_tp(model,
n = len(results) // 2
arg1_results = results[:n]
arg2_results = results[n:]
# Separate logprobs to avoid asserting exact equality.
arg1_logprobs = [r.pop("logprobs") for r in arg1_results]
arg2_logprobs = [r.pop("logprobs") for r in arg2_results]
for arg1_result, arg2_result in zip(arg1_results, arg2_results):
assert arg1_result == arg2_result, (
f"Results for {model=} are not the same with {arg1=} and {arg2=}. "
f"{arg1_result=} != {arg2_result=}")
if logprobs:
for logs1, logs2 in zip(arg1_logprobs, arg2_logprobs):
for l1, l2 in zip(logs1, logs2):
assert l1.tokens == l2.tokens

View File

@@ -2,6 +2,8 @@
tensor parallelism.
"""
from typing import Optional
import pytest
import torch
@@ -154,15 +156,20 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
"--speculative-draft-tensor-parallel-size",
"1",
])])
@pytest.mark.parametrize("logprobs", [None, 2])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
logprobs: Optional[int],
batch_size: int, seed: int):
"""Verify spec decode works well with same and different TP size for
the draft model with chunked prefill.
"""
if logprobs:
test_llm_kwargs.extend(
["--disable_logprobs_during_spec_decoding", "False"])
run_equality_correctness_test_tp(model,
common_llm_kwargs,
per_test_common_llm_kwargs,
@@ -171,4 +178,5 @@ def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
batch_size,
max_output_len=32,
seed=seed,
temperature=0.0)
temperature=0.0,
logprobs=logprobs)

View File

@@ -4,26 +4,27 @@ import pytest
from vllm import SamplingParams
from ..utils import maybe_enable_chunked_prefill
from .conftest import run_equality_correctness_test
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": "JackFram/llama-68m",
"model_name": "JackFram/llama-160m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
"enforce_eager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-160m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False,
}, {
"speculative_model": "JackFram/llama-160m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": True,
}])
@@ -36,12 +37,15 @@ from .conftest import run_equality_correctness_test
])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4, 12])
def test_logprobs_equality(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int, logprobs: int):
"""Verify output logprobs are equal with and without speculative decoding.
seed: int, logprobs: int, prefill_chunk_size: int):
"""Verify output logprobs are equal with and without speculative decoding,
as well as with and without chunked prefill.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,

View File

@@ -21,6 +21,7 @@ correctess for the target model outputs.
import pytest
from ..utils import maybe_enable_chunked_prefill
from .conftest import run_equality_correctness_test
# main model
@@ -67,12 +68,14 @@ PRECISION = "float32"
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int):
seed: int, prefill_chunk_size: int):
"""Verify greedy equality with different batch size."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
@@ -119,12 +122,15 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int, logprobs: int):
seed: int, logprobs: int,
prefill_chunk_size: int):
"""Verify greedy equality with different batch size."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
@@ -167,12 +173,14 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_medusa_e2e_greedy_correctness_cuda_graph(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
seed: int, prefill_chunk_size: int):
"""Verify greedy equality with cuda graph enabled and different
batch sizes."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
@@ -217,13 +225,15 @@ def test_medusa_e2e_greedy_correctness_cuda_graph(
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_medusa_e2e_greedy_correctness_with_preemption(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
seed: int, prefill_chunk_size: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
@@ -267,13 +277,15 @@ def test_medusa_e2e_greedy_correctness_with_preemption(
32,
])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_medusa_different_k(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
seed: int, prefill_chunk_size: int):
"""Verify that medusa speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
@@ -313,14 +325,17 @@ def test_medusa_different_k(vllm_runner, common_llm_kwargs,
32,
])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
output_len: int, seed: int,
prefill_chunk_size: int):
"""Verify that medusa speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
@@ -361,12 +376,14 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
32,
])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
output_len: int, seed: int, prefill_chunk_size: int):
"""Verify that speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,

View File

@@ -25,6 +25,7 @@ import pytest
from vllm.model_executor.layers.vocab_parallel_embedding import pad_vocab_size
from ..utils import maybe_enable_chunked_prefill
from .conftest import run_equality_correctness_test
# main model
@@ -66,14 +67,16 @@ PRECISION = "float32"
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("batch_size", [4, 32])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int):
seed: int, prefill_chunk_size: int):
"""Verify greedy equality with different batch size."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
@@ -116,12 +119,19 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int,
logprobs: int):
logprobs: int, prefill_chunk_size: int):
"""Verify greedy equality with different batch size."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
# NOTE Test is sensitive enough st if we don't enable chunked prefill
# scheduling on baseline too, we get slightly different logprobs, ending
# up sampling different tokens at the tail (ie top tokens don't change).
# TL;DR: sd+cp == org+cp but sd+cp != org..is this expected?
maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
@@ -162,12 +172,15 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("output_len", [2048])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int):
batch_size: int, output_len: int,
prefill_chunk_size: int, seed: int):
"""Verify acceptance rate with different batch size and large output
length."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
@@ -204,13 +217,17 @@ def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("output_len", [64])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("temperature", [1.0])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
temperature: float, seed: int):
temperature: float,
prefill_chunk_size: int, seed: int):
"""Verify seeded runs produce the same output."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
@@ -266,14 +283,16 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
128,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_greedy_correctness_with_preemption(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
prefill_chunk_size: int, seed: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
@@ -317,12 +336,14 @@ def test_mlp_e2e_greedy_correctness_with_preemption(
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
def test_mlp_e2e_greedy_correctness_with_padding(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
prefill_chunk_size: int, seed: int):
"""Verify greedy equality when the vocab dimension is padded
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
# Default pad_to is 64, test model has vocab_size of 32000
def patched_pad_vocab_size(vocab_size, pad_to=None):
@@ -373,14 +394,16 @@ def test_mlp_e2e_greedy_correctness_with_padding(
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
@pytest.mark.parametrize("seed", [1])
def test_mlp_different_k(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, seed: int,
output_len: int):
test_llm_kwargs, batch_size: int,
prefill_chunk_size: int, seed: int, output_len: int):
"""Verify that mlp speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
@@ -418,15 +441,21 @@ def test_mlp_different_k(vllm_runner, common_llm_kwargs,
# Use smaller output len for fast test.
32,
])
# Speculative decoding is disabled when sequences reach decoding and the batch
# consists of single-token requests. Hence we set `max_num_seqs`
# >= `speculative_disable_by_batch_size` to test feature interaction.
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
@pytest.mark.parametrize("seed", [1])
def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, seed: int,
test_llm_kwargs, batch_size: int,
prefill_chunk_size: int, seed: int,
output_len: int):
"""Verify that mlp speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
@@ -460,13 +489,15 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
@pytest.mark.parametrize("seed", [1])
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
output_len: int, prefill_chunk_size: int, seed: int):
"""Verify that speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,

View File

@@ -147,20 +147,20 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
},
])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
"disable_logprobs_during_spec_decoding": False
}, {
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
"disable_logprobs_during_spec_decoding": False
}])
@pytest.mark.parametrize(
"output_len",
[
@@ -192,6 +192,9 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
batch_size,
max_output_len=output_len,
seed=seed,
prompt_logprobs=2,
logprobs=2,
disable_logprobs=False,
temperature=0.0,
ensure_all_accepted=ensure_all_accepted)

View File

@@ -26,6 +26,7 @@ for the target model outputs.
import pytest
from ..utils import maybe_enable_chunked_prefill
from .conftest import run_equality_correctness_test
@@ -49,11 +50,13 @@ from .conftest import run_equality_correctness_test
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"speculative_disable_mqa_scorer": False,
},
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"speculative_disable_mqa_scorer": True,
},
])
@pytest.mark.parametrize("output_len", [
@@ -68,15 +71,7 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
batch_size: int, output_len: int,
prefill_chunk_size: int, seed: int):
"""Verify greedy equality on a tiny model with different batch size."""
if prefill_chunk_size > 0:
common_llm_kwargs.update(
**{
"enable_chunked_prefill": True,
"max_num_batched_tokens": prefill_chunk_size,
"max_num_seqs": prefill_chunk_size
})
else:
common_llm_kwargs["enable_chunked_prefill"] = False
maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,