[Feature] [Spec decode]: Combine chunked prefill with speculative decoding (#9291)

commit 9d43afcc53 (parent ae62fd17c0)
Author: Nicolò Lucchesi
Date: 2024-11-07 17:15:14 +01:00
Committed by: GitHub
Signed-off-by: NickLucche <nlucches@redhat.com>
17 changed files with 476 additions and 146 deletions
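Before this change, vLLM rejected configurations that enabled both features at once (see the now-removed xfail test below). Here is a minimal sketch of a configuration this commit enables, using only engine kwargs that appear in the updated tests; the tiny models and values mirror the test fixtures and are illustrative, not recommended settings:

from vllm import LLM, SamplingParams

# Sketch: speculative decoding and chunked prefill enabled together.
# All kwargs below appear in the tests in this diff; values are
# illustrative only.
llm = LLM(
    model="JackFram/llama-68m",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=5,
    enable_chunked_prefill=True,
    # A small token budget forces prompts to be prefilled in chunks and
    # mixed with decode sequences in the same scheduler step.
    max_num_batched_tokens=4,
    max_num_seqs=4,
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=128, temperature=0.0))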

View File

@@ -5,40 +5,6 @@ from vllm import SamplingParams
from .conftest import get_output_from_llm_generator
@pytest.mark.parametrize("common_llm_kwargs", [{
"model": "JackFram/llama-68m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"enable_chunked_prefill": True,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_chunked_prefill(test_llm_generator):
"""Verify that speculative decoding with chunked prefill fails.
"""
output_len = 128
temperature = 0.0
prompts = [
"Hello, my name is",
]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)
with pytest.raises(ValueError,
match="Speculative decoding and chunked prefill"):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)
@pytest.mark.parametrize("common_llm_kwargs", [{
"model": "meta-llama/Llama-2-7b-chat-hf",
"speculative_model": "JackFram/llama-68m",

View File

@@ -62,6 +62,16 @@ from .conftest import (get_output_from_llm_generator,
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
# Chunked prefill enabled with small value
# to make sure we get mixed batches.
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
{
# Verify the detokenizer assertions in the test work when spec
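The max_num_batched_tokens: 4 / max_num_seqs: 4 pair recurs in every chunked variant below. With a budget of only four tokens per scheduler step, any realistic prompt must be prefilled over several steps (a 13-token prompt needs ceil(13/4) = 4 steps), which guarantees batches that mix prefill chunks with speculative decode sequences.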
@@ -141,6 +151,14 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
},
])
@pytest.mark.parametrize(
@@ -204,6 +222,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize(
@@ -255,6 +281,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize("max_output_len", [
@@ -300,6 +334,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize("batch_size", [1])
@@ -347,6 +389,14 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize("batch_size", [32])
@@ -397,6 +447,14 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize(
@@ -454,6 +512,14 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize("batch_size", [2])
@@ -503,6 +569,15 @@ def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs,
# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
"speculative_max_model_len": 32,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
"speculative_max_model_len": 32,
},
])
@pytest.mark.parametrize("batch_size", [8])
@@ -551,6 +626,15 @@ def test_skip_speculation(vllm_runner, common_llm_kwargs,
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_disable_by_batch_size": 2,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_disable_by_batch_size": 2,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
},
])
@pytest.mark.parametrize("batch_size", [8])
@@ -590,10 +674,17 @@ def test_disable_speculation(vllm_runner, common_llm_kwargs,
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"enable_chunked_prefill": False,
}
# Try a range of common k, as well as large speculation.
for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
] + [{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
} for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
@@ -636,11 +727,19 @@ def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"spec_decoding_acceptance_method": "typical_acceptance_sampler"
"spec_decoding_acceptance_method": "typical_acceptance_sampler",
"enable_chunked_prefill": False
}
# Try a range of common k.
for k in [1, 2, 3]
] + [{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"spec_decoding_acceptance_method": "typical_acceptance_sampler",
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
} for k in [1, 2, 3]])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize(
"output_len",

View File

@@ -50,18 +50,33 @@ from .conftest import run_equality_correctness_test
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
},
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
},
])
@pytest.mark.parametrize("output_len", [
256,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
@pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
prefill_chunk_size: int, seed: int):
"""Verify greedy equality on a tiny model with different batch size."""
if prefill_chunk_size > 0:
common_llm_kwargs.update(
**{
"enable_chunked_prefill": True,
"max_num_batched_tokens": prefill_chunk_size,
"max_num_seqs": prefill_chunk_size
})
else:
common_llm_kwargs["enable_chunked_prefill"] = False
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
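The prefill_chunk_size parametrization above uses -1 as a sentinel for "chunked prefill off". A hypothetical helper capturing the same toggle, written out here only for illustration (it is not part of this commit):

def maybe_enable_chunked_prefill(prefill_chunk_size: int,
                                 llm_kwargs: dict) -> None:
    # A positive chunk size turns chunking on and caps both the token
    # budget and the number of sequences per scheduler step.
    if prefill_chunk_size > 0:
        llm_kwargs.update(enable_chunked_prefill=True,
                          max_num_batched_tokens=prefill_chunk_size,
                          max_num_seqs=prefill_chunk_size)
    else:
        llm_kwargs["enable_chunked_prefill"] = False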
@@ -151,6 +166,16 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"enable_chunked_prefill": False,
},
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"enable_chunked_prefill": True,
"speculative_disable_mqa_scorer": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize(
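Note that the chunked ngram variants above also set speculative_disable_mqa_scorer: True, which makes verification fall back from the MQA scorer to the batch-expansion scorer; presumably the MQA path is not exercised together with chunked prefill in these tests.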
@@ -251,6 +276,15 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"speculative_disable_by_batch_size": 4
}, {
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"speculative_disable_by_batch_size": 4,
"enable_chunked_prefill": True,
"speculative_disable_mqa_scorer": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(