[V1][Usage] Refactor speculative decoding configuration and tests (#14434)

Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
Author: shangmingc
Date: 2025-03-23 13:28:10 +08:00
Committed by: GitHub
Parent: 0661cfef7a
Commit: 50c9636d87
20 changed files with 1055 additions and 802 deletions
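In summary, the flat speculative_model / num_speculative_tokens / speculative_* engine arguments are folded into a single nested speculative_config dict. A minimal before/after sketch of the offline API, assuming the LLM constructor forwards the dict to the engine the same way the tests below do (model names are the ones used in the tests):

from vllm import LLM

# Before this commit: flat, individually prefixed keyword arguments.
# llm = LLM(
#     model="meta-llama/Llama-3.2-1B-Instruct",
#     speculative_model="JackFram/llama-68m",
#     num_speculative_tokens=5,
# )

# After this commit: one nested dict carries every speculative setting.
llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",
    speculative_config={
        "model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
    },
)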

View File

@@ -56,7 +56,7 @@ def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
def maybe_assert_ngram_worker(llm):
# Verify the proposer worker is ngram if ngram is specified.
if (llm.llm_engine.speculative_config is not None
and llm.llm_engine.speculative_config.ngram_prompt_lookup_max > 0):
and llm.llm_engine.speculative_config.method == "ngram"):
from vllm.spec_decode.ngram_worker import NGramWorker
assert isinstance(
llm.llm_engine.model_executor.driver_worker.proposer_worker,

View File

@@ -7,28 +7,39 @@ from vllm import SamplingParams
from .conftest import get_output_from_llm_generator
@pytest.mark.parametrize("common_llm_kwargs", [{
"model": "meta-llama/Llama-3.2-1B-Instruct",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
}])
@pytest.mark.parametrize("common_llm_kwargs",
[{
"model": "meta-llama/Llama-3.2-1B-Instruct",
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
# Speculative max model len > overridden max model len should raise.
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"max_model_len": 129,
},
"max_model_len": 128,
"speculative_max_model_len": 129,
},
{
# Speculative max model len > draft max model len should raise.
# https://huggingface.co/JackFram/llama-68m/blob/3b606af5198a0b26762d589a3ee3d26ee6fa6c85/config.json#L12
"speculative_max_model_len": 2048 + 1,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"max_model_len": 2048 + 1,
},
},
{
# Speculative max model len > target max model len should raise.
# https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct/blob/9213176726f574b556790deb65791e0c5aa438b6/config.json#L18
"speculative_max_model_len": 131072 + 1,
# https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/blob/9213176726f574b556790deb65791e0c5aa438b6/config.json#L18
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"max_model_len": 131072 + 1,
},
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])

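Each case above nests the draft-model max_model_len inside speculative_config, and each is expected to fail engine construction. A sketch of exercising the first case directly, assuming the misconfiguration surfaces as a ValueError (the exact exception type is an assumption here):

import pytest

from vllm import LLM

# Draft max_model_len (129) exceeds the overridden target
# max_model_len (128), so building the engine should raise.
with pytest.raises(ValueError):
    LLM(
        model="meta-llama/Llama-3.2-1B-Instruct",
        max_model_len=128,
        speculative_config={
            "model": "JackFram/llama-68m",
            "num_speculative_tokens": 5,
            "max_model_len": 129,
        },
    )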
View File

@@ -57,8 +57,10 @@ PRECISION = "float32"
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
@@ -95,18 +97,19 @@ def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": False,
"disable_logprobs": False,
},
{
"speculative_model": SPEC_MODEL,
}, {
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": True,
"disable_logprobs": True,
},
])
}])
@pytest.mark.parametrize("output_len", [
128,
])
@@ -119,18 +122,19 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
batch_size: int, output_len: int, seed: int,
logprobs: int):
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize(
@@ -151,8 +155,10 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
@@ -193,8 +199,10 @@ def test_eagle_e2e_greedy_correctness_cuda_graph(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize(
@@ -236,8 +244,10 @@ def test_eagle_e2e_greedy_correctness_with_preemption(
"test_llm_kwargs",
[
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": k,
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": k,
},
}
# Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS)
@@ -277,12 +287,13 @@ def test_eagle_different_k(vllm_runner, common_llm_kwargs,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_disable_by_batch_size": 4
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_by_batch_size": 4,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
@@ -324,8 +335,10 @@ def test_eagle_disable_queue(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "yuhuili/EAGLE-llama2-chat-7B",
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"model": "yuhuili/EAGLE-llama2-chat-7B",
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize(
@@ -372,8 +385,10 @@ def test_llama2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize(
@@ -420,8 +435,10 @@ def test_llama3_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "yuhuili/EAGLE-Qwen2-7B-Instruct",
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"model": "yuhuili/EAGLE-Qwen2-7B-Instruct",
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize(

View File

@@ -23,8 +23,10 @@ MAIN_MODEL = "JackFram/llama-68m"
[
{
# Identical models.
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -57,26 +59,33 @@ def test_spec_decode_cuda_graph(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"speculative_model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
# Explicitly specify draft model quantization
{
"speculative_model_quantization": "gptq",
"speculative_config": {
"model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
"num_speculative_tokens": 5,
"quantization": "gptq",
},
},
# Explicitly specify GPTQ-based draft model to use marlin quantization
{
"speculative_model_quantization": "marlin",
"speculative_config": {
"model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
"num_speculative_tokens": 5,
"quantization": "marlin",
},
},
# Not explicitly specify draft model quantization
{
"speculative_model_quantization": None,
"speculative_config": {
"model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
"num_speculative_tokens": 5,
"quantization": None,
},
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -107,15 +116,16 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_disable_mqa_scorer": True,
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_mqa_scorer": True,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
@@ -127,7 +137,7 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
"""Verify that ngram speculative decoding generates the same output
"""Verify that speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
run_equality_correctness_test(vllm_runner,

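Draft-model quantization moves into the dict as well. A short sketch of the three shapes this test sweeps, assuming the nested "quantization" key accepts the same values the old speculative_model_quantization flag did:

# None lets vLLM infer quantization from the draft model's own config;
# "gptq" and "marlin" select the respective kernels explicitly.
speculative_config = {
    "model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
    "num_speculative_tokens": 5,
    "quantization": "gptq",  # or "marlin", or None
}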
View File

@@ -27,18 +27,19 @@ from .conftest import run_equality_correctness_test_tp
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize("test_llm_kwargs", [
[
"--speculative-model",
"JackFram/llama-68m",
"--num-speculative-tokens",
"3",
"--speculative_config",
str({
"model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
}),
],
[
"--speculative-model",
"[ngram]",
"--num-speculative-tokens",
"5",
"--ngram-prompt-lookup-max",
"3",
"--speculative_config",
str({
"model": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
}),
],
])
@pytest.mark.parametrize("batch_size", [2])
@@ -83,23 +84,24 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
]])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [[]])
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize("model, test_llm_kwargs",
[("JackFram/llama-68m", [
"--speculative-model",
"JackFram/llama-68m",
"--num_speculative-tokens",
"5",
"--speculative-draft-tensor-parallel-size",
"1",
]),
("ibm-granite/granite-3b-code-instruct", [
"--speculative-model",
"ibm-granite/granite-3b-code-instruct",
"--num_speculative-tokens",
"5",
"--speculative-draft-tensor-parallel-size",
"1",
])])
@pytest.mark.parametrize(
"model, test_llm_kwargs",
[("JackFram/llama-68m", [
"--speculative_config",
str({
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"draft_tensor_parallel_size": 1,
}),
]),
("ibm-granite/granite-3b-code-instruct", [
"--speculative_config",
str({
"model": "ibm-granite/granite-3b-code-instruct",
"num_speculative_tokens": 5,
"draft_tensor_parallel_size": 1,
}),
])])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
@@ -144,18 +146,19 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize("model, test_llm_kwargs",
[("JackFram/llama-68m", [
"--speculative-model",
"JackFram/llama-68m",
"--num_speculative-tokens",
"3",
"--speculative_config",
str({
"model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
}),
]),
("JackFram/llama-68m", [
"--speculative-model",
"JackFram/llama-68m",
"--num_speculative-tokens",
"3",
"--speculative-draft-tensor-parallel-size",
"1",
"--speculative_config",
str({
"model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"draft_tensor_parallel_size": 1,
}),
])])
@pytest.mark.parametrize("logprobs", [None, 2])
@pytest.mark.parametrize("batch_size", [2])

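On the command line, the whole dict is now passed as one stringified --speculative_config argument. A sketch of building such an argument list the same way the tests above do (the server side parses the string back into a config dict):

# The tests serialize the dict with str(); the single string replaces
# the old --speculative-model / --num-speculative-tokens /
# --speculative-draft-tensor-parallel-size flags.
spec_decode_args = [
    "--speculative_config",
    str({
        "model": "JackFram/llama-68m",
        "num_speculative_tokens": 3,
        "draft_tensor_parallel_size": 1,
    }),
]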
View File

@@ -24,12 +24,7 @@ SPEC_MODEL = "JackFram/llama-68m"
"4",
]])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
[
"--speculative-model",
f"{SPEC_MODEL}",
"--num-speculative-tokens",
"5",
],
[],
])
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize(
@@ -37,8 +32,12 @@ SPEC_MODEL = "JackFram/llama-68m"
[
#TODO(wooyeon): add spec_draft_dp=2 case
[
"--speculative-draft-tensor-parallel-size",
"1",
"--speculative_config",
str({
"model": f"{SPEC_MODEL}",
"num_speculative_tokens": 5,
"draft_tensor_parallel_size": 1,
}),
],
])
@pytest.mark.parametrize("batch_size", [2])
@@ -78,15 +77,14 @@ def test_draft_model_tp_lt_target_model_tp4(common_llm_kwargs,
"test_llm_kwargs",
[
[
"--speculative-model",
f"{SPEC_MODEL}",
"--num-speculative-tokens",
"5",
# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
"--speculative-max-model-len",
"32",
"--speculative_config",
str({
"model": f"{SPEC_MODEL}",
"num_speculative_tokens": 5,
"max_model_len": 32,
}),
],
])
@pytest.mark.parametrize("batch_size", [8])

View File

@@ -20,16 +20,19 @@ from .conftest import run_equality_correctness_test
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False,
}, {
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": True,
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_logprobs": False,
},
}, {
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_logprobs": True,
},
}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
"output_len",
@@ -48,19 +51,20 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
as well as with and without chunked prefill.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize(
@@ -73,16 +77,19 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False,
}, {
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 6,
"disable_logprobs_during_spec_decoding": False,
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs": False,
},
}, {
"speculative_config": {
"model": "JackFram/llama-160m",
"num_speculative_tokens": 6,
"disable_logprobs": False,
},
}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
"output_len",
@@ -98,18 +105,19 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
output_len: int, seed: int, logprobs: int):
"""Veriy logprob greedy equality with different speculation lens.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize(
@@ -125,13 +133,15 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize(
"test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False,
# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
"speculative_max_model_len": 32,
"speculative_config": {
"model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs": False,
# Artificially limit the draft model max model len; this forces
# vLLM to skip speculation once the sequences grow beyond 32-k
# tokens.
"max_model_len": 32,
},
}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
@@ -149,18 +159,19 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
seed: int, logprobs: int):
"""Verify logprobs greedy equality when some sequences skip speculation.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize(
@@ -173,12 +184,13 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False,
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs": False,
},
}])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize(
"output_len",
@@ -248,12 +260,13 @@ def test_logprobs_temp_1(vllm_runner, common_llm_kwargs,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": True,
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_logprobs": True,
},
}])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize(
@@ -270,15 +283,16 @@ def test_logprobs_disabled(vllm_runner, common_llm_kwargs,
"""Check the behavior when logprobs are disabled.
Token choices should match with the base model.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])

View File

@@ -60,8 +60,10 @@ PRECISION = "float32"
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
@@ -107,14 +109,18 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": False,
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs": False,
},
},
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": True,
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs": True,
},
},
])
@pytest.mark.parametrize("output_len", [
@@ -132,19 +138,20 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
prefill_chunk_size: int):
"""Verify greedy equality with different batch size."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize(
@@ -165,8 +172,10 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
@@ -214,8 +223,10 @@ def test_medusa_e2e_greedy_correctness_cuda_graph(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize(
@@ -264,8 +275,10 @@ def test_medusa_e2e_greedy_correctness_with_preemption(
"test_llm_kwargs",
[
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": k,
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": k,
},
}
# Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS)
@@ -312,12 +325,13 @@ def test_medusa_different_k(vllm_runner, common_llm_kwargs,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_disable_by_batch_size": 4
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_by_batch_size": 4,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
@@ -359,16 +373,17 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
# Main model
"model_name": MAIN_MODEL,
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_disable_by_batch_size": 4
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_disable_mqa_scorer": True,
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_by_batch_size": 4,
"disable_mqa_scorer": True,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",

View File

@@ -62,7 +62,9 @@ PRECISION = "float32"
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"speculative_config": {
"model": SPEC_MODEL,
},
},
])
@pytest.mark.parametrize("output_len", [
@@ -108,12 +110,16 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"disable_logprobs_during_spec_decoding": False,
"speculative_config": {
"model": SPEC_MODEL,
"disable_logprobs": False,
},
},
{
"speculative_model": SPEC_MODEL,
"disable_logprobs_during_spec_decoding": True,
"speculative_config": {
"model": SPEC_MODEL,
"disable_logprobs": True,
},
},
])
@pytest.mark.parametrize("output_len", [8])
@@ -133,19 +139,20 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
# up sampling different tokens at the tail (ie top tokens don't change).
# TL;DR: sd+cp == org+cp but sd+cp != org..is this expected?
maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize(
@@ -167,7 +174,9 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"speculative_config": {
"model": SPEC_MODEL,
},
},
])
@pytest.mark.parametrize("output_len", [2048])
@@ -209,8 +218,10 @@ def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
# Main model
"model_name": MAIN_MODEL,
# Speculative model
"speculative_model": SPEC_MODEL,
# Speculative config
"speculative_config": {
"model": SPEC_MODEL,
},
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
@@ -274,7 +285,9 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"speculative_config": {
"model": SPEC_MODEL,
},
},
])
@pytest.mark.parametrize(
@@ -326,7 +339,9 @@ def test_mlp_e2e_greedy_correctness_with_preemption(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"speculative_config": {
"model": SPEC_MODEL,
},
},
])
@pytest.mark.parametrize(
@@ -382,8 +397,10 @@ def test_mlp_e2e_greedy_correctness_with_padding(
"test_llm_kwargs",
[
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": k,
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": k,
},
}
# Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS)
@@ -430,11 +447,12 @@ def test_mlp_different_k(vllm_runner, common_llm_kwargs,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": SPEC_MODEL,
"speculative_disable_by_batch_size": 4
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": SPEC_MODEL,
"disable_by_batch_size": 4,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
@@ -475,14 +493,15 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
"speculative_model": SPEC_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_disable_mqa_scorer": True,
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": SPEC_MODEL,
"disable_mqa_scorer": True,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",

View File

@@ -57,7 +57,9 @@ PRECISION = "bfloat16"
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
@@ -99,12 +101,16 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": False,
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs": False,
},
},
{
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": True,
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs": True,
},
},
])
@pytest.mark.parametrize("output_len", [
@@ -119,18 +125,19 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
batch_size: int, output_len: int, seed: int,
logprobs: int):
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize(
@@ -152,7 +159,9 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
@@ -198,7 +207,9 @@ def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize(
@@ -243,7 +254,9 @@ def test_mtp_e2e_greedy_correctness_with_preemption(
"test_llm_kwargs",
[
{
"num_speculative_tokens": k,
"speculative_config": {
"num_speculative_tokens": k,
},
}
# Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS)
@@ -286,11 +299,12 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_disable_by_batch_size": 4
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_by_batch_size": 4
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",

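Note that the MTP cases above set no "model" key at all: the draft layers ship inside the target checkpoint, so only the token count is configured and the speculative method is derived from the main model. A minimal sketch:

# No separate draft model is named for MTP; the method is inferred
# from the target model itself (as in the tests above).
speculative_config = {
    "num_speculative_tokens": 1,  # MAX_SPEC_TOKENS in the tests
}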
View File

@@ -61,15 +61,19 @@ from .conftest import (get_output_from_llm_generator,
"per_test_common_llm_kwargs",
[
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False,
},
{
# Chunked prefill enabled with small value
# to make sure we get mixed batches.
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
@@ -148,20 +152,23 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
"disable_logprobs_during_spec_decoding": False
}, {
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
"disable_logprobs_during_spec_decoding": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"disable_logprobs": False,
},
"enable_chunked_prefill": False,
}, {
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_logprobs": False,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
}])
@pytest.mark.parametrize(
"output_len",
[
@@ -184,7 +191,7 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
whether all speculative tokens are accepted.
"""
ensure_all_accepted = per_test_common_llm_kwargs.get(
"model_name") == test_llm_kwargs.get("speculative_model")
"model_name") == test_llm_kwargs.get("speculative_config")["model"]
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
@@ -224,13 +231,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
@@ -283,13 +294,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
@@ -336,13 +351,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
@@ -391,13 +410,17 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
@@ -449,13 +472,17 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
@@ -514,13 +541,17 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
@@ -567,21 +598,25 @@ def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs,
"test_llm_kwargs",
[
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
"speculative_max_model_len": 32,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"max_model_len": 32,
},
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"max_model_len": 32,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
"speculative_max_model_len": 32,
},
])
@pytest.mark.parametrize("batch_size", [8])
@@ -627,15 +662,19 @@ def test_skip_speculation(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_disable_by_batch_size": 2,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"disable_by_batch_size": 2,
},
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_disable_by_batch_size": 2,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"disable_by_batch_size": 2,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
@@ -676,15 +715,19 @@ def test_disable_speculation(vllm_runner, common_llm_kwargs,
"test_llm_kwargs",
[
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": k,
},
"enable_chunked_prefill": False,
}
# Try a range of common k, as well as large speculation.
for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
] + [{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": k,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
@@ -729,17 +772,21 @@ def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
"test_llm_kwargs",
[
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"spec_decoding_acceptance_method": "typical_acceptance_sampler",
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"acceptance_method": "typical_acceptance_sampler",
},
"enable_chunked_prefill": False
}
# Try a range of common k.
for k in [1, 2, 3]
] + [{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"spec_decoding_acceptance_method": "typical_acceptance_sampler",
"speculative_config": {
"model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"acceptance_method": "typical_acceptance_sampler",
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4

View File

@@ -48,16 +48,20 @@ from .conftest import run_equality_correctness_test
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"speculative_disable_mqa_scorer": False,
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_mqa_scorer": False,
},
},
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"speculative_disable_mqa_scorer": True,
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_mqa_scorer": True,
},
},
])
@pytest.mark.parametrize("output_len", [
@@ -101,16 +105,20 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"disable_logprobs_during_spec_decoding": False,
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_logprobs": False,
},
},
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"disable_logprobs_during_spec_decoding": True,
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_logprobs": True,
},
},
])
@pytest.mark.parametrize("output_len", [
@@ -125,19 +133,20 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
batch_size: int, output_len: int, seed: int,
logprobs: int):
"""Verify greedy equality on a tiny model with different batch size."""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize(
@@ -159,17 +168,21 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
},
"enable_chunked_prefill": False,
},
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_mqa_scorer": True,
},
"enable_chunked_prefill": True,
"speculative_disable_mqa_scorer": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
@@ -214,17 +227,21 @@ def test_ngram_e2e_greedy_correctness_with_preemption(
"test_llm_kwargs",
[
{
"speculative_model": "[ngram]",
"num_speculative_tokens": k,
"ngram_prompt_lookup_max": 3,
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": k,
"prompt_lookup_max": 3,
},
}
# Try a range of common k, as well as large speculation.
for k in [1, 3, 5]
] + [
{
"speculative_model": "[ngram]",
"num_speculative_tokens": k,
"ngram_prompt_lookup_max": 1,
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": k,
"prompt_lookup_max": 1,
},
}
# Try a range of common k, as well as large speculation.
for k in [1, 3, 5]
@@ -243,7 +260,7 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
seed: int):
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and
different ngram_prompt_lookup_max.
different ngram prompt_lookup_max.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
@@ -266,22 +283,25 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"speculative_disable_by_batch_size": 4
}, {
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"speculative_disable_by_batch_size": 4,
"enable_chunked_prefill": True,
"speculative_disable_mqa_scorer": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_by_batch_size": 4
},
}, {
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_by_batch_size": 4,
"disable_mqa_scorer": True,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
@@ -296,7 +316,7 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
seed: int):
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and
different ngram_prompt_lookup_max.
different ngram prompt_lookup_max.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
@@ -316,18 +336,17 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_disable_mqa_scorer": True,
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_mqa_scorer": True,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",

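Ngram speculation is now selected with "method": "ngram" instead of the old speculative_model="[ngram]" pseudo-model, and ngram_prompt_lookup_max becomes the nested prompt_lookup_max. A sketch of the new offline form, mirroring the test configs above (the main model name is illustrative):

from vllm import LLM

# With "method": "ngram" no draft model is loaded; draft tokens come
# from n-gram lookup in the prompt.
llm = LLM(
    model="JackFram/llama-68m",
    speculative_config={
        "method": "ngram",
        "num_speculative_tokens": 5,
        "prompt_lookup_max": 3,
    },
)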
View File

@@ -19,11 +19,11 @@ SPEC_MODEL = "JackFram/llama-160m"
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# speculative model
"speculative_model": "JackFram/llama-160m",
# num speculative tokens
"num_speculative_tokens": 3,
# speculative config
"speculative_config": {
"model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
},
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])