[V1][Usage] Refactor speculative decoding configuration and tests (#14434)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
@@ -56,7 +56,7 @@ def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
def maybe_assert_ngram_worker(llm):
|
||||
# Verify the proposer worker is ngram if ngram is specified.
|
||||
if (llm.llm_engine.speculative_config is not None
|
||||
and llm.llm_engine.speculative_config.ngram_prompt_lookup_max > 0):
|
||||
and llm.llm_engine.speculative_config.method == "ngram"):
|
||||
from vllm.spec_decode.ngram_worker import NGramWorker
|
||||
assert isinstance(
|
||||
llm.llm_engine.model_executor.driver_worker.proposer_worker,
|
||||
|
||||
@@ -7,28 +7,39 @@ from vllm import SamplingParams
|
||||
from .conftest import get_output_from_llm_generator
|
||||
|
||||
|
||||
@pytest.mark.parametrize("common_llm_kwargs", [{
|
||||
"model": "meta-llama/Llama-3.2-1B-Instruct",
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
}])
|
||||
@pytest.mark.parametrize("common_llm_kwargs",
|
||||
[{
|
||||
"model": "meta-llama/Llama-3.2-1B-Instruct",
|
||||
}])
|
||||
@pytest.mark.parametrize(
|
||||
"per_test_common_llm_kwargs",
|
||||
[
|
||||
{
|
||||
# Speculative max model len > overridden max model len should raise.
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"max_model_len": 129,
|
||||
},
|
||||
"max_model_len": 128,
|
||||
"speculative_max_model_len": 129,
|
||||
},
|
||||
{
|
||||
# Speculative max model len > draft max model len should raise.
|
||||
# https://huggingface.co/JackFram/llama-68m/blob/3b606af5198a0b26762d589a3ee3d26ee6fa6c85/config.json#L12
|
||||
"speculative_max_model_len": 2048 + 1,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"max_model_len": 2048 + 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
# Speculative max model len > target max model len should raise.
|
||||
# https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct/blob/9213176726f574b556790deb65791e0c5aa438b6/config.json#L18
|
||||
"speculative_max_model_len": 131072 + 1,
|
||||
# https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/blob/9213176726f574b556790deb65791e0c5aa438b6/config.json#L18
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"max_model_len": 131072 + 1,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{}])
|
||||
|
||||
@@ -57,8 +57,10 @@ PRECISION = "float32"
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
@@ -95,18 +97,19 @@ def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_logprobs_during_spec_decoding": False,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
}, {
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_logprobs_during_spec_decoding": True,
|
||||
"disable_logprobs": True,
|
||||
},
|
||||
])
|
||||
}])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
128,
|
||||
])
|
||||
@@ -119,18 +122,19 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
batch_size: int, output_len: int, seed: int,
|
||||
logprobs: int):
|
||||
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs[
|
||||
'disable_logprobs_during_spec_decoding'])
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||
["disable_logprobs"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -151,8 +155,10 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
@@ -193,8 +199,10 @@ def test_eagle_e2e_greedy_correctness_cuda_graph(
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
@@ -236,8 +244,10 @@ def test_eagle_e2e_greedy_correctness_with_preemption(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"num_speculative_tokens": k,
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": k,
|
||||
},
|
||||
}
|
||||
# Try a range of num. speculative tokens
|
||||
for k in range(1, 1 + MAX_SPEC_TOKENS)
|
||||
@@ -277,12 +287,13 @@ def test_eagle_different_k(vllm_runner, common_llm_kwargs,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs",
|
||||
[{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"speculative_disable_by_batch_size": 4
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_by_batch_size": 4,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
@@ -324,8 +335,10 @@ def test_eagle_disable_queue(vllm_runner, common_llm_kwargs,
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": "yuhuili/EAGLE-llama2-chat-7B",
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"speculative_config": {
|
||||
"model": "yuhuili/EAGLE-llama2-chat-7B",
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
@@ -372,8 +385,10 @@ def test_llama2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"speculative_config": {
|
||||
"model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
@@ -420,8 +435,10 @@ def test_llama3_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": "yuhuili/EAGLE-Qwen2-7B-Instruct",
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"speculative_config": {
|
||||
"model": "yuhuili/EAGLE-Qwen2-7B-Instruct",
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -23,8 +23,10 @@ MAIN_MODEL = "JackFram/llama-68m"
|
||||
[
|
||||
{
|
||||
# Identical models.
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@@ -57,26 +59,33 @@ def test_spec_decode_cuda_graph(vllm_runner, common_llm_kwargs,
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [])
|
||||
@pytest.mark.parametrize(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
# Explicitly specify draft model quantization
|
||||
{
|
||||
"speculative_model_quantization": "gptq",
|
||||
"speculative_config": {
|
||||
"model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
|
||||
"num_speculative_tokens": 5,
|
||||
"quantization": "gptq",
|
||||
},
|
||||
},
|
||||
# Explicitly specify GPTQ-based draft model to use marlin quantization
|
||||
{
|
||||
"speculative_model_quantization": "marlin",
|
||||
"speculative_config": {
|
||||
"model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
|
||||
"num_speculative_tokens": 5,
|
||||
"quantization": "marlin",
|
||||
},
|
||||
},
|
||||
# Not explicitly specify draft model quantization
|
||||
{
|
||||
"speculative_model_quantization": None,
|
||||
"speculative_config": {
|
||||
"model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
|
||||
"num_speculative_tokens": 5,
|
||||
"quantization": None,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@@ -107,15 +116,16 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs",
|
||||
[{
|
||||
"speculative_disable_mqa_scorer": True,
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_mqa_scorer": True,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
@@ -127,7 +137,7 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
|
||||
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
|
||||
output_len: int, seed: int):
|
||||
"""Verify that ngram speculative decoding generates the same output
|
||||
"""Verify that speculative decoding generates the same output
|
||||
with batch expansion scorer and mqa scorer.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
|
||||
@@ -27,18 +27,19 @@ from .conftest import run_equality_correctness_test_tp
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
[
|
||||
"--speculative-model",
|
||||
"JackFram/llama-68m",
|
||||
"--num-speculative-tokens",
|
||||
"3",
|
||||
"--speculative_config",
|
||||
str({
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
}),
|
||||
],
|
||||
[
|
||||
"--speculative-model",
|
||||
"[ngram]",
|
||||
"--num-speculative-tokens",
|
||||
"5",
|
||||
"--ngram-prompt-lookup-max",
|
||||
"3",
|
||||
"--speculative_config",
|
||||
str({
|
||||
"model": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
}),
|
||||
],
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@@ -83,23 +84,24 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
]])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [[]])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
|
||||
@pytest.mark.parametrize("model, test_llm_kwargs",
|
||||
[("JackFram/llama-68m", [
|
||||
"--speculative-model",
|
||||
"JackFram/llama-68m",
|
||||
"--num_speculative-tokens",
|
||||
"5",
|
||||
"--speculative-draft-tensor-parallel-size",
|
||||
"1",
|
||||
]),
|
||||
("ibm-granite/granite-3b-code-instruct", [
|
||||
"--speculative-model",
|
||||
"ibm-granite/granite-3b-code-instruct",
|
||||
"--num_speculative-tokens",
|
||||
"5",
|
||||
"--speculative-draft-tensor-parallel-size",
|
||||
"1",
|
||||
])])
|
||||
@pytest.mark.parametrize(
|
||||
"model, test_llm_kwargs",
|
||||
[("JackFram/llama-68m", [
|
||||
"--speculative_config",
|
||||
str({
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"draft_tensor_parallel_size": 1,
|
||||
}),
|
||||
]),
|
||||
("ibm-granite/granite-3b-code-instruct", [
|
||||
"--speculative_config",
|
||||
str({
|
||||
"model": "ibm-granite/granite-3b-code-instruct",
|
||||
"num_speculative_tokens": 5,
|
||||
"draft_tensor_parallel_size": 1,
|
||||
}),
|
||||
])])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
|
||||
@@ -144,18 +146,19 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
|
||||
@pytest.mark.parametrize("model, test_llm_kwargs",
|
||||
[("JackFram/llama-68m", [
|
||||
"--speculative-model",
|
||||
"JackFram/llama-68m",
|
||||
"--num_speculative-tokens",
|
||||
"3",
|
||||
"--speculative_config",
|
||||
str({
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
}),
|
||||
]),
|
||||
("JackFram/llama-68m", [
|
||||
"--speculative-model",
|
||||
"JackFram/llama-68m",
|
||||
"--num_speculative-tokens",
|
||||
"3",
|
||||
"--speculative-draft-tensor-parallel-size",
|
||||
"1",
|
||||
"--speculative_config",
|
||||
str({
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
"draft_tensor_parallel_size": 1,
|
||||
}),
|
||||
])])
|
||||
@pytest.mark.parametrize("logprobs", [None, 2])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
|
||||
@@ -24,12 +24,7 @@ SPEC_MODEL = "JackFram/llama-68m"
|
||||
"4",
|
||||
]])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
|
||||
[
|
||||
"--speculative-model",
|
||||
f"{SPEC_MODEL}",
|
||||
"--num-speculative-tokens",
|
||||
"5",
|
||||
],
|
||||
[],
|
||||
])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
|
||||
@pytest.mark.parametrize(
|
||||
@@ -37,8 +32,12 @@ SPEC_MODEL = "JackFram/llama-68m"
|
||||
[
|
||||
#TODO(wooyeon): add spec_draft_dp=2 case
|
||||
[
|
||||
"--speculative-draft-tensor-parallel-size",
|
||||
"1",
|
||||
"--speculative_config",
|
||||
str({
|
||||
"model": f"{SPEC_MODEL}",
|
||||
"num_speculative_tokens": 5,
|
||||
"draft_tensor_parallel_size": 1,
|
||||
}),
|
||||
],
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@@ -78,15 +77,14 @@ def test_draft_model_tp_lt_target_model_tp4(common_llm_kwargs,
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
[
|
||||
"--speculative-model",
|
||||
f"{SPEC_MODEL}",
|
||||
"--num-speculative-tokens",
|
||||
"5",
|
||||
|
||||
# Artificially limit the draft model max model len; this forces vLLM
|
||||
# to skip speculation once the sequences grow beyond 32-k tokens.
|
||||
"--speculative-max-model-len",
|
||||
"32",
|
||||
"--speculative_config",
|
||||
str({
|
||||
"model": f"{SPEC_MODEL}",
|
||||
"num_speculative_tokens": 5,
|
||||
"max_model_len": 32,
|
||||
}),
|
||||
],
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
|
||||
@@ -20,16 +20,19 @@ from .conftest import run_equality_correctness_test
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs",
|
||||
[{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs_during_spec_decoding": False,
|
||||
}, {
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs_during_spec_decoding": True,
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
}, {
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs": True,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
@@ -48,19 +51,20 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
|
||||
as well as with and without chunked prefill.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs[
|
||||
'disable_logprobs_during_spec_decoding'])
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||
["disable_logprobs"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -73,16 +77,19 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs",
|
||||
[{
|
||||
"speculative_model": "JackFram/llama-160m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs_during_spec_decoding": False,
|
||||
}, {
|
||||
"speculative_model": "JackFram/llama-160m",
|
||||
"num_speculative_tokens": 6,
|
||||
"disable_logprobs_during_spec_decoding": False,
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-160m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
}, {
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-160m",
|
||||
"num_speculative_tokens": 6,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
@@ -98,18 +105,19 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
|
||||
output_len: int, seed: int, logprobs: int):
|
||||
"""Veriy logprob greedy equality with different speculation lens.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs[
|
||||
'disable_logprobs_during_spec_decoding'])
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||
["disable_logprobs"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -125,13 +133,15 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
|
||||
@pytest.mark.parametrize(
|
||||
"test_llm_kwargs",
|
||||
[{
|
||||
"speculative_model": "JackFram/llama-160m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs_during_spec_decoding": False,
|
||||
|
||||
# Artificially limit the draft model max model len; this forces vLLM
|
||||
# to skip speculation once the sequences grow beyond 32-k tokens.
|
||||
"speculative_max_model_len": 32,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-160m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs": False,
|
||||
# Artificially limit the draft model max model len; this forces
|
||||
# vLLM to skip speculation once the sequences grow beyond 32-k
|
||||
# tokens.
|
||||
"max_model_len": 32,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize(
|
||||
@@ -149,18 +159,19 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
|
||||
seed: int, logprobs: int):
|
||||
"""Verify logprobs greedy equality when some sequences skip speculation.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs[
|
||||
'disable_logprobs_during_spec_decoding'])
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||
["disable_logprobs"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -173,12 +184,13 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs",
|
||||
[{
|
||||
"speculative_model": "JackFram/llama-160m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs_during_spec_decoding": False,
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-160m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
@@ -248,12 +260,13 @@ def test_logprobs_temp_1(vllm_runner, common_llm_kwargs,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs",
|
||||
[{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs_during_spec_decoding": True,
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs": True,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize(
|
||||
@@ -270,15 +283,16 @@ def test_logprobs_disabled(vllm_runner, common_llm_kwargs,
|
||||
"""Check the behavior when logprobs are disabled.
|
||||
Token choices should match with the base model.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs[
|
||||
'disable_logprobs_during_spec_decoding'])
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||
["disable_logprobs"])
|
||||
|
||||
@@ -60,8 +60,10 @@ PRECISION = "float32"
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
@@ -107,14 +109,18 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_logprobs_during_spec_decoding": False,
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_logprobs_during_spec_decoding": True,
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_logprobs": True,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
@@ -132,19 +138,20 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
prefill_chunk_size: int):
|
||||
"""Verify greedy equality with different batch size."""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs[
|
||||
'disable_logprobs_during_spec_decoding'])
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||
["disable_logprobs"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -165,8 +172,10 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
@@ -214,8 +223,10 @@ def test_medusa_e2e_greedy_correctness_cuda_graph(
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
@@ -264,8 +275,10 @@ def test_medusa_e2e_greedy_correctness_with_preemption(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"num_speculative_tokens": k,
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": k,
|
||||
},
|
||||
}
|
||||
# Try a range of num. speculative tokens
|
||||
for k in range(1, 1 + MAX_SPEC_TOKENS)
|
||||
@@ -312,12 +325,13 @@ def test_medusa_different_k(vllm_runner, common_llm_kwargs,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs",
|
||||
[{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"speculative_disable_by_batch_size": 4
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_by_batch_size": 4,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
@@ -359,16 +373,17 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"speculative_disable_by_batch_size": 4
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs",
|
||||
[{
|
||||
"speculative_disable_mqa_scorer": True,
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_by_batch_size": 4,
|
||||
"disable_mqa_scorer": True,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
|
||||
@@ -62,7 +62,9 @@ PRECISION = "float32"
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
@@ -108,12 +110,16 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"disable_logprobs_during_spec_decoding": False,
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"disable_logprobs_during_spec_decoding": True,
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"disable_logprobs": True,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [8])
|
||||
@@ -133,19 +139,20 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
# up sampling different tokens at the tail (ie top tokens don't change).
|
||||
# TL;DR: sd+cp == org+cp but sd+cp != org..is this expected?
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs[
|
||||
'disable_logprobs_during_spec_decoding'])
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||
["disable_logprobs"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -167,7 +174,9 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [2048])
|
||||
@@ -209,8 +218,10 @@ def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
|
||||
# Speculative model
|
||||
"speculative_model": SPEC_MODEL,
|
||||
# Speculative config
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
|
||||
@@ -274,7 +285,9 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
@@ -326,7 +339,9 @@ def test_mlp_e2e_greedy_correctness_with_preemption(
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
@@ -382,8 +397,10 @@ def test_mlp_e2e_greedy_correctness_with_padding(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"num_speculative_tokens": k,
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": k,
|
||||
},
|
||||
}
|
||||
# Try a range of num. speculative tokens
|
||||
for k in range(1, 1 + MAX_SPEC_TOKENS)
|
||||
@@ -430,11 +447,12 @@ def test_mlp_different_k(vllm_runner, common_llm_kwargs,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs",
|
||||
[{
|
||||
"speculative_model": SPEC_MODEL,
|
||||
"speculative_disable_by_batch_size": 4
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"disable_by_batch_size": 4,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
@@ -475,14 +493,15 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
"speculative_model": SPEC_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs",
|
||||
[{
|
||||
"speculative_disable_mqa_scorer": True,
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"disable_mqa_scorer": True,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
|
||||
@@ -57,7 +57,9 @@ PRECISION = "bfloat16"
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
@@ -99,12 +101,16 @@ def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_logprobs_during_spec_decoding": False,
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_logprobs_during_spec_decoding": True,
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_logprobs": True,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
@@ -119,18 +125,19 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
batch_size: int, output_len: int, seed: int,
|
||||
logprobs: int):
|
||||
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs[
|
||||
'disable_logprobs_during_spec_decoding'])
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||
["disable_logprobs"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -152,7 +159,9 @@ def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
@@ -198,7 +207,9 @@ def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs,
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
@@ -243,7 +254,9 @@ def test_mtp_e2e_greedy_correctness_with_preemption(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"num_speculative_tokens": k,
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": k,
|
||||
},
|
||||
}
|
||||
# Try a range of num. speculative tokens
|
||||
for k in range(1, 1 + MAX_SPEC_TOKENS)
|
||||
@@ -286,11 +299,12 @@ def test_mtp_different_k(vllm_runner, common_llm_kwargs,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs",
|
||||
[{
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"speculative_disable_by_batch_size": 4
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_by_batch_size": 4
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
|
||||
@@ -61,15 +61,19 @@ from .conftest import (get_output_from_llm_generator,
|
||||
"per_test_common_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
# Chunked prefill enabled with small value
|
||||
# to make sure we get mixed batches.
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
@@ -148,20 +152,23 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs",
|
||||
[{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"enable_chunked_prefill": False,
|
||||
"disable_logprobs_during_spec_decoding": False
|
||||
}, {
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4,
|
||||
"disable_logprobs_during_spec_decoding": False
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
}, {
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4,
|
||||
}])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
@@ -184,7 +191,7 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
|
||||
whether all speculative tokens are accepted.
|
||||
"""
|
||||
ensure_all_accepted = per_test_common_llm_kwargs.get(
|
||||
"model_name") == test_llm_kwargs.get("speculative_model")
|
||||
"model_name") == test_llm_kwargs.get("speculative_config")["model"]
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
@@ -224,13 +231,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
@@ -283,13 +294,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
@@ -336,13 +351,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
@@ -391,13 +410,17 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
@@ -449,13 +472,17 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
@@ -514,13 +541,17 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
@@ -567,21 +598,25 @@ def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs,
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
|
||||
# Artificially limit the draft model max model len; this forces vLLM
|
||||
# to skip speculation once the sequences grow beyond 32-k tokens.
|
||||
"speculative_max_model_len": 32,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"max_model_len": 32,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"max_model_len": 32,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4,
|
||||
"speculative_max_model_len": 32,
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@@ -627,15 +662,19 @@ def test_skip_speculation(vllm_runner, common_llm_kwargs,
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_disable_by_batch_size": 2,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"disable_by_batch_size": 2,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"speculative_disable_by_batch_size": 2,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": 5,
|
||||
"disable_by_batch_size": 2,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4,
|
||||
@@ -676,15 +715,19 @@ def test_disable_speculation(vllm_runner, common_llm_kwargs,
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": k,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": k,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
}
|
||||
# Try a range of common k, as well as large speculation.
|
||||
for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
|
||||
] + [{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": k,
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": k,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4,
|
||||
@@ -729,17 +772,21 @@ def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": k,
|
||||
"spec_decoding_acceptance_method": "typical_acceptance_sampler",
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": k,
|
||||
"acceptance_method": "typical_acceptance_sampler",
|
||||
},
|
||||
"enable_chunked_prefill": False
|
||||
}
|
||||
# Try a range of common k.
|
||||
for k in [1, 2, 3]
|
||||
] + [{
|
||||
"speculative_model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": k,
|
||||
"spec_decoding_acceptance_method": "typical_acceptance_sampler",
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-68m",
|
||||
"num_speculative_tokens": k,
|
||||
"acceptance_method": "typical_acceptance_sampler",
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
|
||||
@@ -48,16 +48,20 @@ from .conftest import run_equality_correctness_test
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": "[ngram]",
|
||||
"num_speculative_tokens": 5,
|
||||
"ngram_prompt_lookup_max": 3,
|
||||
"speculative_disable_mqa_scorer": False,
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
"disable_mqa_scorer": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"speculative_model": "[ngram]",
|
||||
"num_speculative_tokens": 5,
|
||||
"ngram_prompt_lookup_max": 3,
|
||||
"speculative_disable_mqa_scorer": True,
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
"disable_mqa_scorer": True,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
@@ -101,16 +105,20 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": "[ngram]",
|
||||
"num_speculative_tokens": 5,
|
||||
"ngram_prompt_lookup_max": 3,
|
||||
"disable_logprobs_during_spec_decoding": False,
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"speculative_model": "[ngram]",
|
||||
"num_speculative_tokens": 5,
|
||||
"ngram_prompt_lookup_max": 3,
|
||||
"disable_logprobs_during_spec_decoding": True,
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
"disable_logprobs": True,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
@@ -125,19 +133,20 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
batch_size: int, output_len: int, seed: int,
|
||||
logprobs: int):
|
||||
"""Verify greedy equality on a tiny model with different batch size."""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs[
|
||||
'disable_logprobs_during_spec_decoding'])
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||
["disable_logprobs"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -159,17 +168,21 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_model": "[ngram]",
|
||||
"num_speculative_tokens": 5,
|
||||
"ngram_prompt_lookup_max": 3,
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_model": "[ngram]",
|
||||
"num_speculative_tokens": 5,
|
||||
"ngram_prompt_lookup_max": 3,
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
"disable_mqa_scorer": True,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"speculative_disable_mqa_scorer": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
},
|
||||
@@ -214,17 +227,21 @@ def test_ngram_e2e_greedy_correctness_with_preemption(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_model": "[ngram]",
|
||||
"num_speculative_tokens": k,
|
||||
"ngram_prompt_lookup_max": 3,
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": k,
|
||||
"prompt_lookup_max": 3,
|
||||
},
|
||||
}
|
||||
# Try a range of common k, as well as large speculation.
|
||||
for k in [1, 3, 5]
|
||||
] + [
|
||||
{
|
||||
"speculative_model": "[ngram]",
|
||||
"num_speculative_tokens": k,
|
||||
"ngram_prompt_lookup_max": 1,
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": k,
|
||||
"prompt_lookup_max": 1,
|
||||
},
|
||||
}
|
||||
# Try a range of common k, as well as large speculation.
|
||||
for k in [1, 3, 5]
|
||||
@@ -243,7 +260,7 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
|
||||
seed: int):
|
||||
"""Verify that ngram speculative decoding produces exact equality
|
||||
to without spec decode with many different values of k and
|
||||
different ngram_prompt_lookup_max.
|
||||
different ngram prompt_lookup_max.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
@@ -266,22 +283,25 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs",
|
||||
[{
|
||||
"speculative_model": "[ngram]",
|
||||
"num_speculative_tokens": 5,
|
||||
"ngram_prompt_lookup_max": 3,
|
||||
"speculative_disable_by_batch_size": 4
|
||||
}, {
|
||||
"speculative_model": "[ngram]",
|
||||
"num_speculative_tokens": 5,
|
||||
"ngram_prompt_lookup_max": 3,
|
||||
"speculative_disable_by_batch_size": 4,
|
||||
"enable_chunked_prefill": True,
|
||||
"speculative_disable_mqa_scorer": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
"disable_by_batch_size": 4
|
||||
},
|
||||
}, {
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
"disable_by_batch_size": 4,
|
||||
"disable_mqa_scorer": True,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
@@ -296,7 +316,7 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
|
||||
seed: int):
|
||||
"""Verify that ngram speculative decoding produces exact equality
|
||||
to without spec decode with many different values of k and
|
||||
different ngram_prompt_lookup_max.
|
||||
different ngram prompt_lookup_max.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
@@ -316,18 +336,17 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"speculative_model": "[ngram]",
|
||||
"num_speculative_tokens": 5,
|
||||
"ngram_prompt_lookup_max": 3,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs",
|
||||
[{
|
||||
"speculative_disable_mqa_scorer": True,
|
||||
}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
"disable_mqa_scorer": True,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
|
||||
@@ -19,11 +19,11 @@ SPEC_MODEL = "JackFram/llama-160m"
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# speculative model
|
||||
"speculative_model": "JackFram/llama-160m",
|
||||
|
||||
# num speculative tokens
|
||||
"num_speculative_tokens": 3,
|
||||
# speculative config
|
||||
"speculative_config": {
|
||||
"model": "JackFram/llama-160m",
|
||||
"num_speculative_tokens": 3,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
|
||||
|
||||
Reference in New Issue
Block a user