[V1][Usage] Refactor speculative decoding configuration and tests (#14434)

Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
shangmingc
2025-03-23 13:28:10 +08:00
committed by GitHub
parent 0661cfef7a
commit 50c9636d87
20 changed files with 1055 additions and 802 deletions

View File

@@ -27,18 +27,19 @@ from .conftest import run_equality_correctness_test_tp
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize("test_llm_kwargs", [
[
"--speculative-model",
"JackFram/llama-68m",
"--num-speculative-tokens",
"3",
"--speculative_config",
str({
"model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
}),
],
[
"--speculative-model",
"[ngram]",
"--num-speculative-tokens",
"5",
"--ngram-prompt-lookup-max",
"3",
"--speculative_config",
str({
"model": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
}),
],
])
@pytest.mark.parametrize("batch_size", [2])
@@ -83,23 +84,24 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
]])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [[]])
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize("model, test_llm_kwargs",
[("JackFram/llama-68m", [
"--speculative-model",
"JackFram/llama-68m",
"--num_speculative-tokens",
"5",
"--speculative-draft-tensor-parallel-size",
"1",
]),
("ibm-granite/granite-3b-code-instruct", [
"--speculative-model",
"ibm-granite/granite-3b-code-instruct",
"--num_speculative-tokens",
"5",
"--speculative-draft-tensor-parallel-size",
"1",
])])
@pytest.mark.parametrize(
"model, test_llm_kwargs",
[("JackFram/llama-68m", [
"--speculative_config",
str({
"model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"draft_tensor_parallel_size": 1,
}),
]),
("ibm-granite/granite-3b-code-instruct", [
"--speculative_config",
str({
"model": "ibm-granite/granite-3b-code-instruct",
"num_speculative_tokens": 5,
"draft_tensor_parallel_size": 1,
}),
])])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
@@ -144,18 +146,19 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize("model, test_llm_kwargs",
[("JackFram/llama-68m", [
"--speculative-model",
"JackFram/llama-68m",
"--num_speculative-tokens",
"3",
"--speculative_config",
str({
"model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
}),
]),
("JackFram/llama-68m", [
"--speculative-model",
"JackFram/llama-68m",
"--num_speculative-tokens",
"3",
"--speculative-draft-tensor-parallel-size",
"1",
"--speculative_config",
str({
"model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"draft_tensor_parallel_size": 1,
}),
])])
@pytest.mark.parametrize("logprobs", [None, 2])
@pytest.mark.parametrize("batch_size", [2])