[V1][Usage] Refactor speculative decoding configuration and tests (#14434)

Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
Author: shangmingc
Date: 2025-03-23 13:28:10 +08:00
Committed by: GitHub
Parent: 0661cfef7a
Commit: 50c9636d87
20 changed files with 1055 additions and 802 deletions
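For readers migrating their own code, a minimal before/after sketch of the configuration change (hypothetical usage, not part of the diff: the option names are taken from the removed and added test kwargs, while the target model name and the `LLM` constructor call are illustrative assumptions):

    from vllm import LLM

    # Before this refactor: flat, speculation-specific engine kwargs
    # (commented out, kept for comparison):
    # llm = LLM(
    #     model="JackFram/llama-160m",  # illustrative target model
    #     speculative_model="JackFram/llama-68m",
    #     num_speculative_tokens=3,
    #     disable_logprobs_during_spec_decoding=False,
    # )

    # After this refactor: the same settings grouped into one nested dict,
    # mirroring the `test_llm_kwargs` parametrizations below.
    llm = LLM(
        model="JackFram/llama-160m",  # illustrative target model
        speculative_config={
            "model": "JackFram/llama-68m",
            "num_speculative_tokens": 3,
            "disable_logprobs": False,
        },
    )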


@@ -20,16 +20,19 @@ from .conftest import run_equality_correctness_test
 }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
-@pytest.mark.parametrize("test_llm_kwargs",
-                         [{
-                             "speculative_model": "JackFram/llama-68m",
-                             "num_speculative_tokens": 3,
-                             "disable_logprobs_during_spec_decoding": False,
-                         }, {
-                             "speculative_model": "JackFram/llama-68m",
-                             "num_speculative_tokens": 3,
-                             "disable_logprobs_during_spec_decoding": True,
-                         }])
+@pytest.mark.parametrize("test_llm_kwargs", [{
+    "speculative_config": {
+        "model": "JackFram/llama-68m",
+        "num_speculative_tokens": 3,
+        "disable_logprobs": False,
+    },
+}, {
+    "speculative_config": {
+        "model": "JackFram/llama-68m",
+        "num_speculative_tokens": 3,
+        "disable_logprobs": True,
+    },
+}])
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize(
     "output_len",
@@ -48,19 +51,20 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
     as well as with and without chunked prefill.
     """
     maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs)
-    run_equality_correctness_test(vllm_runner,
-                                  common_llm_kwargs,
-                                  per_test_common_llm_kwargs,
-                                  baseline_llm_kwargs,
-                                  test_llm_kwargs,
-                                  batch_size,
-                                  output_len,
-                                  seed,
-                                  temperature=0.0,
-                                  logprobs=logprobs,
-                                  prompt_logprobs=logprobs,
-                                  disable_logprobs=test_llm_kwargs[
-                                      'disable_logprobs_during_spec_decoding'])
+    run_equality_correctness_test(
+        vllm_runner,
+        common_llm_kwargs,
+        per_test_common_llm_kwargs,
+        baseline_llm_kwargs,
+        test_llm_kwargs,
+        batch_size,
+        output_len,
+        seed,
+        temperature=0.0,
+        logprobs=logprobs,
+        prompt_logprobs=logprobs,
+        disable_logprobs=test_llm_kwargs["speculative_config"]
+        ["disable_logprobs"])
 
 
 @pytest.mark.parametrize(
@@ -73,16 +77,19 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
 }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
-@pytest.mark.parametrize("test_llm_kwargs",
-                         [{
-                             "speculative_model": "JackFram/llama-160m",
-                             "num_speculative_tokens": 3,
-                             "disable_logprobs_during_spec_decoding": False,
-                         }, {
-                             "speculative_model": "JackFram/llama-160m",
-                             "num_speculative_tokens": 6,
-                             "disable_logprobs_during_spec_decoding": False,
-                         }])
+@pytest.mark.parametrize("test_llm_kwargs", [{
+    "speculative_config": {
+        "model": "JackFram/llama-160m",
+        "num_speculative_tokens": 3,
+        "disable_logprobs": False,
+    },
+}, {
+    "speculative_config": {
+        "model": "JackFram/llama-160m",
+        "num_speculative_tokens": 6,
+        "disable_logprobs": False,
+    },
+}])
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize(
     "output_len",
@@ -98,18 +105,19 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
                              output_len: int, seed: int, logprobs: int):
     """Verify logprob greedy equality with different speculation lengths.
     """
-    run_equality_correctness_test(vllm_runner,
-                                  common_llm_kwargs,
-                                  per_test_common_llm_kwargs,
-                                  baseline_llm_kwargs,
-                                  test_llm_kwargs,
-                                  batch_size,
-                                  output_len,
-                                  seed,
-                                  temperature=0.0,
-                                  logprobs=logprobs,
-                                  disable_logprobs=test_llm_kwargs[
-                                      'disable_logprobs_during_spec_decoding'])
+    run_equality_correctness_test(
+        vllm_runner,
+        common_llm_kwargs,
+        per_test_common_llm_kwargs,
+        baseline_llm_kwargs,
+        test_llm_kwargs,
+        batch_size,
+        output_len,
+        seed,
+        temperature=0.0,
+        logprobs=logprobs,
+        disable_logprobs=test_llm_kwargs["speculative_config"]
+        ["disable_logprobs"])
 
 
 @pytest.mark.parametrize(
@@ -125,13 +133,15 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
 @pytest.mark.parametrize(
     "test_llm_kwargs",
     [{
-        "speculative_model": "JackFram/llama-160m",
-        "num_speculative_tokens": 3,
-        "disable_logprobs_during_spec_decoding": False,
-
-        # Artificially limit the draft model max model len; this forces vLLM
-        # to skip speculation once the sequences grow beyond 32-k tokens.
-        "speculative_max_model_len": 32,
+        "speculative_config": {
+            "model": "JackFram/llama-160m",
+            "num_speculative_tokens": 3,
+            "disable_logprobs": False,
+            # Artificially limit the draft model max model len; this forces
+            # vLLM to skip speculation once the sequences grow beyond 32-k
+            # tokens.
+            "max_model_len": 32,
+        },
 }])
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize(
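(Aside, not part of the commit: the "32-k" in the comment above reads as 32 minus k. A quick check of the arithmetic under that reading:)

    # The draft model must fit the current sequence plus k proposed tokens
    # inside its max_model_len, so speculation is skipped once a sequence
    # grows beyond max_model_len - k tokens.
    max_model_len = 32
    k = 3  # num_speculative_tokens in this test
    skip_threshold = max_model_len - k  # speculation skipped beyond 29 tokens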
@@ -149,18 +159,19 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
                                        seed: int, logprobs: int):
     """Verify logprobs greedy equality when some sequences skip speculation.
     """
-    run_equality_correctness_test(vllm_runner,
-                                  common_llm_kwargs,
-                                  per_test_common_llm_kwargs,
-                                  baseline_llm_kwargs,
-                                  test_llm_kwargs,
-                                  batch_size,
-                                  output_len,
-                                  seed,
-                                  temperature=0.0,
-                                  logprobs=logprobs,
-                                  disable_logprobs=test_llm_kwargs[
-                                      'disable_logprobs_during_spec_decoding'])
+    run_equality_correctness_test(
+        vllm_runner,
+        common_llm_kwargs,
+        per_test_common_llm_kwargs,
+        baseline_llm_kwargs,
+        test_llm_kwargs,
+        batch_size,
+        output_len,
+        seed,
+        temperature=0.0,
+        logprobs=logprobs,
+        disable_logprobs=test_llm_kwargs["speculative_config"]
+        ["disable_logprobs"])
 
 
 @pytest.mark.parametrize(
@@ -173,12 +184,13 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
 }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
-@pytest.mark.parametrize("test_llm_kwargs",
-                         [{
-                             "speculative_model": "JackFram/llama-160m",
-                             "num_speculative_tokens": 3,
-                             "disable_logprobs_during_spec_decoding": False,
-                         }])
+@pytest.mark.parametrize("test_llm_kwargs", [{
+    "speculative_config": {
+        "model": "JackFram/llama-160m",
+        "num_speculative_tokens": 3,
+        "disable_logprobs": False,
+    },
+}])
 @pytest.mark.parametrize("batch_size", [1])
 @pytest.mark.parametrize(
     "output_len",
@@ -248,12 +260,13 @@ def test_logprobs_temp_1(vllm_runner, common_llm_kwargs,
 }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
-@pytest.mark.parametrize("test_llm_kwargs",
-                         [{
-                             "speculative_model": "JackFram/llama-68m",
-                             "num_speculative_tokens": 3,
-                             "disable_logprobs_during_spec_decoding": True,
-                         }])
+@pytest.mark.parametrize("test_llm_kwargs", [{
+    "speculative_config": {
+        "model": "JackFram/llama-68m",
+        "num_speculative_tokens": 3,
+        "disable_logprobs": True,
+    },
+}])
 @pytest.mark.parametrize("seed", [1])
 @pytest.mark.parametrize("batch_size", [4])
 @pytest.mark.parametrize(
@@ -270,15 +283,16 @@ def test_logprobs_disabled(vllm_runner, common_llm_kwargs,
     """Check the behavior when logprobs are disabled.
     Token choices should match with the base model.
     """
-    run_equality_correctness_test(vllm_runner,
-                                  common_llm_kwargs,
-                                  per_test_common_llm_kwargs,
-                                  baseline_llm_kwargs,
-                                  test_llm_kwargs,
-                                  batch_size,
-                                  output_len,
-                                  seed,
-                                  temperature=0.0,
-                                  logprobs=logprobs,
-                                  disable_logprobs=test_llm_kwargs[
-                                      'disable_logprobs_during_spec_decoding'])
+    run_equality_correctness_test(
+        vllm_runner,
+        common_llm_kwargs,
+        per_test_common_llm_kwargs,
+        baseline_llm_kwargs,
+        test_llm_kwargs,
+        batch_size,
+        output_len,
+        seed,
+        temperature=0.0,
+        logprobs=logprobs,
+        disable_logprobs=test_llm_kwargs["speculative_config"]
+        ["disable_logprobs"])