[V1][Usage] Refactor speculative decoding configuration and tests (#14434)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
@@ -24,12 +24,7 @@ SPEC_MODEL = "JackFram/llama-68m"
|
||||
"4",
|
||||
]])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
|
||||
[
|
||||
"--speculative-model",
|
||||
f"{SPEC_MODEL}",
|
||||
"--num-speculative-tokens",
|
||||
"5",
|
||||
],
|
||||
[],
|
||||
])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
|
||||
@pytest.mark.parametrize(
|
||||
@@ -37,8 +32,12 @@ SPEC_MODEL = "JackFram/llama-68m"
|
||||
[
|
||||
#TODO(wooyeon): add spec_draft_dp=2 case
|
||||
[
|
||||
"--speculative-draft-tensor-parallel-size",
|
||||
"1",
|
||||
"--speculative_config",
|
||||
str({
|
||||
"model": f"{SPEC_MODEL}",
|
||||
"num_speculative_tokens": 5,
|
||||
"draft_tensor_parallel_size": 1,
|
||||
}),
|
||||
],
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@@ -78,15 +77,14 @@ def test_draft_model_tp_lt_target_model_tp4(common_llm_kwargs,
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
[
|
||||
"--speculative-model",
|
||||
f"{SPEC_MODEL}",
|
||||
"--num-speculative-tokens",
|
||||
"5",
|
||||
|
||||
# Artificially limit the draft model max model len; this forces vLLM
|
||||
# to skip speculation once the sequences grow beyond 32-k tokens.
|
||||
"--speculative-max-model-len",
|
||||
"32",
|
||||
"--speculative_config",
|
||||
str({
|
||||
"model": f"{SPEC_MODEL}",
|
||||
"num_speculative_tokens": 5,
|
||||
"max_model_len": 32,
|
||||
}),
|
||||
],
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
|
||||
Reference in New Issue
Block a user