[Feature] [Spec decode]: Combine chunked prefill with speculative decoding (#9291)

Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Nicolò Lucchesi
2024-11-07 17:15:14 +01:00
committed by GitHub
parent ae62fd17c0
commit 9d43afcc53
17 changed files with 476 additions and 146 deletions

View File

@@ -5,40 +5,6 @@ from vllm import SamplingParams
from .conftest import get_output_from_llm_generator
@pytest.mark.parametrize("common_llm_kwargs", [{
"model": "JackFram/llama-68m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"enable_chunked_prefill": True,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_chunked_prefill(test_llm_generator):
"""Verify that speculative decoding with chunked prefill fails.
"""
output_len = 128
temperature = 0.0
prompts = [
"Hello, my name is",
]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)
with pytest.raises(ValueError,
match="Speculative decoding and chunked prefill"):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)
@pytest.mark.parametrize("common_llm_kwargs", [{
"model": "meta-llama/Llama-2-7b-chat-hf",
"speculative_model": "JackFram/llama-68m",

View File

@@ -62,6 +62,16 @@ from .conftest import (get_output_from_llm_generator,
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
# Chunked prefill enabled with small value
# to make sure we get mixed batches.
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
{
# Verify the detokenizer assertions in the test work when spec
@@ -141,6 +151,14 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
},
])
@pytest.mark.parametrize(
@@ -204,6 +222,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize(
@@ -255,6 +281,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize("max_output_len", [
@@ -300,6 +334,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize("batch_size", [1])
@@ -347,6 +389,14 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize("batch_size", [32])
@@ -397,6 +447,14 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize(
@@ -454,6 +512,14 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize("batch_size", [2])
@@ -503,6 +569,15 @@ def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs,
# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
"speculative_max_model_len": 32,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
"speculative_max_model_len": 32,
},
])
@pytest.mark.parametrize("batch_size", [8])
@@ -551,6 +626,15 @@ def test_skip_speculation(vllm_runner, common_llm_kwargs,
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_disable_by_batch_size": 2,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_disable_by_batch_size": 2,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
},
])
@pytest.mark.parametrize("batch_size", [8])
@@ -590,10 +674,17 @@ def test_disable_speculation(vllm_runner, common_llm_kwargs,
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"enable_chunked_prefill": False,
}
# Try a range of common k, as well as large speculation.
for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
])
] + [{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
} for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
@@ -636,11 +727,19 @@ def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"spec_decoding_acceptance_method": "typical_acceptance_sampler"
"spec_decoding_acceptance_method": "typical_acceptance_sampler",
"enable_chunked_prefill": False
}
# Try a range of common k.
for k in [1, 2, 3]
])
] + [{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"spec_decoding_acceptance_method": "typical_acceptance_sampler",
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
} for k in [1, 2, 3]])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize(
"output_len",

View File

@@ -50,18 +50,33 @@ from .conftest import run_equality_correctness_test
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
},
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
},
])
@pytest.mark.parametrize("output_len", [
256,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
@pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int):
prefill_chunk_size: int, seed: int):
"""Verify greedy equality on a tiny model with different batch size."""
if prefill_chunk_size > 0:
common_llm_kwargs.update(
**{
"enable_chunked_prefill": True,
"max_num_batched_tokens": prefill_chunk_size,
"max_num_seqs": prefill_chunk_size
})
else:
common_llm_kwargs["enable_chunked_prefill"] = False
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
@@ -151,6 +166,16 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"enable_chunked_prefill": False,
},
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"enable_chunked_prefill": True,
"speculative_disable_mqa_scorer": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize(
@@ -251,6 +276,15 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"speculative_disable_by_batch_size": 4
}, {
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"speculative_disable_by_batch_size": 4,
"enable_chunked_prefill": True,
"speculative_disable_mqa_scorer": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(

View File

@@ -118,7 +118,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens)
for sg in seq_group_metadata_list:
sg.is_prompt = False
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
@@ -147,7 +148,7 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
def test_ngram_algo_correctness_for_batches_match_all():
"""Verify our ngram algo find the right candidate in the prompt
For the scenario find candidate in all batchs
For the scenario find candidate in all batches
"""
block_size = 32
@@ -192,6 +193,10 @@ def test_ngram_algo_correctness_for_batches_match_all():
block_size,
final_prompt_lens=final_prompt_lens)
# Normally drafter is run on decode requests only; here we check the output
# of the ngram worker as it is the sole proposer that has no forward.
for sg in seq_group_metadata_list:
sg.is_prompt = False
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,

View File

@@ -46,12 +46,14 @@ def assert_score_equal(score1: SpeculativeScores,
@pytest.mark.parametrize('max_propose_len', [1, 3, 5])
@pytest.mark.parametrize('mixed_propose_len', [True])
@pytest.mark.parametrize('device', ['cuda'])
@pytest.mark.parametrize('prefill_chunking', [False, True])
def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
mixed_propose_len: bool, device: str) -> None:
mixed_propose_len: bool, device: str,
prefill_chunking: bool) -> None:
"""
Compare the batch expansion scorer and mqa scorer return the same score.
We test for both queries with the same propose length and different
propose length.
propose length, as well as mixed prefill-decode batches.
"""
seed = 0
block_size = 32
@@ -67,16 +69,37 @@ def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
if not mixed_propose_len:
propose_lens = [max_propose_len] * batch_size
else:
non_zero_cnt = random.randint(0, batch_size)
# There must be at least 1 decode request, otherwise
# we have nothing to score (`_run_no_spec`).
non_zero_cnt = random.randint(1, batch_size)
propose_lens = [max_propose_len
] * non_zero_cnt + [0] * (batch_size - non_zero_cnt)
random.shuffle(propose_lens)
proposals = create_proposal(propose_lens, vocab_size, device)
seq_group_metadatalist, _, _ = create_batch(batch_size,
max_propose_len,
block_size=block_size,
num_gpu_blocks=num_gpu_blocks)
if mixed_propose_len and prefill_chunking and (n_prefills :=
batch_size - non_zero_cnt):
prefill, _, _ = create_batch(n_prefills,
None,
prefill_chunk_size=4,
block_size=block_size,
num_gpu_blocks=num_gpu_blocks,
seq_ids=list(
range(batch_size,
batch_size + n_prefills)))
# re-order to guarantee prefill|decode order
target_group_metadatalist = [
seq_group_metadatalist[i] for i, p in enumerate(propose_lens)
if p > 0
]
seq_group_metadatalist = prefill + target_group_metadatalist
propose_lens = [0] * n_prefills + [p for p in propose_lens if p > 0]
proposals = create_proposal(propose_lens, vocab_size, device)
requests = ExecuteModelRequest(seq_group_metadatalist,
num_lookahead_slots=max_propose_len)

View File

@@ -10,6 +10,7 @@ import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, SequenceOutput
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
SpecDecodeWorkerMetrics)
@@ -819,3 +820,84 @@ def test_handle_finished_requests():
# and 'request-3' are removed from seq_with_bonus_token_in_last_step.
assert worker._seq_with_bonus_token_in_last_step == \
{4,5,10}
@pytest.mark.parametrize('k', [3])
@pytest.mark.parametrize('batch_size', [2, 32])
@pytest.mark.parametrize("batch_composition",
["prefill_only", "decode_only", "mixed"])
@torch.inference_mode()
def test_chunked_prefill_flow(k: int, batch_size: int, batch_composition: str):
"""
Verify SpecDecodeWorker calls match the expected flow.
"""
vocab_size = 32_000
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(draft_worker,
target_worker,
mock_spec_decode_sampler("rejection_sampler"),
disable_logprobs=False,
metrics_collector=metrics_collector)
exception_secret = 'artificial stop'
worker.scorer = mock_worker(BatchExpansionTop1Scorer)
worker.scorer.score_proposals.side_effect = ValueError(exception_secret)
# Create batch with combination of terminal/non-terminal prefill chunks
# and decodes (different seq_ids).
decodes, _, _ = create_batch(batch_size, k)
# Pre-chunking here, get 'batch_size' chunks.
prefill, _, _ = create_batch(batch_size,
k,
prefill_chunk_size=4,
seq_ids=list(range(batch_size,
batch_size * 2)))
if batch_composition == "prefill_only":
n_prefills = batch_size
elif batch_composition == "decode_only":
n_prefills = 0
else:
n_prefills = random.randint(1, batch_size - 1)
n_decodes = batch_size - n_prefills
prefill = random.sample(prefill, n_prefills)
decodes = random.sample(decodes, n_decodes)
target_group_metadata_list = prefill + decodes
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=target_group_metadata_list,
num_lookahead_slots=k)
target_token_ids = torch.randint(low=0,
high=vocab_size,
size=(1, batch_size * (k + 1)),
dtype=torch.int64,
device='cuda')
target_token_probs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='cuda')
target_token_logprobs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='cuda')
target_output = create_sampler_output_list(target_token_ids,
target_token_probs,
target_token_logprobs)
target_worker.execute_model.return_value = [target_output[0]]
if not len(decodes):
worker.execute_model(execute_model_req=execute_model_req)
# no spec run (prefill only)
draft_worker.execute_model.assert_called_once_with(execute_model_req)
target_worker.execute_model.assert_called_once_with(execute_model_req)
else:
# Decode-only run OR mixed batch, scorer call fails (it's mocked)
with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=execute_model_req)
# but first draft still counted
assert draft_worker.get_spec_proposals.call_count == 1

View File

@@ -146,6 +146,41 @@ def create_seq_group_metadata_from_prompts(
return seq_grou_metadata_list
def create_chunked_seq_group_metadata_from_prompt(
prompt: List[int],
num_gpu_blocks: int,
chunk_size: int,
block_size: int,
seq_id: Optional[int] = None) -> List[SequenceGroupMetadata]:
if seq_id is None:
seq_id = 0
free_gpu_blocks = list(range(num_gpu_blocks))
block_allocations = [
free_gpu_blocks.pop()
for _ in range(round_up_to_next_block(len(prompt), block_size))
]
seq_group_metadata_list = []
for i, idx in enumerate(range(0, len(prompt), chunk_size)):
chunk_ids = prompt[idx:idx + chunk_size]
data = SequenceData.from_seqs(prompt)
data.update_num_computed_tokens(idx)
seq_data = {i: data}
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=str(seq_id),
is_prompt=True,
do_sample=idx + chunk_size >= len(prompt), # terminal chunk
seq_data=seq_data,
sampling_params=SamplingParams(temperature=0.0),
block_tables={i: block_allocations},
token_chunk_size=len(chunk_ids)))
return seq_group_metadata_list
def assert_logprobs_dict_allclose(
actual_logprobs: List[Dict[int, Logprob]],
expected_logprobs: List[Dict[int, Logprob]]) -> None:
@@ -198,7 +233,8 @@ def create_batch(batch_size,
prev_output_token_len: int = 10,
seq_ids: Optional[List[int]] = None,
num_gpu_blocks: Optional[int] = None,
block_size: Optional[int] = None):
block_size: Optional[int] = None,
prefill_chunk_size: Optional[int] = None):
if block_size is None:
block_size = 8
@@ -213,15 +249,28 @@ def create_batch(batch_size,
prompt_lens = prompt_len
prompts = [[next(iterator) for _ in range(p_len)] for p_len in prompt_lens]
prev_output_tokens = [[
next(iterator) for _ in range(prev_output_token_len)
] for _ in range(batch_size)]
final_prompt_lens = [
len(prompt) + len(prev_output_token) + k + 1
for prompt, prev_output_token in zip(prompts, prev_output_tokens)
]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size, final_prompt_lens,
prev_output_tokens, seq_ids)
if prefill_chunk_size:
# Create a batch of chunked prompts.
if not seq_ids:
seq_ids = list(range(len(prompts)))
seq_group_metadata_list = []
for p, sid in zip(prompts, seq_ids):
seq_group_metadata_list += \
create_chunked_seq_group_metadata_from_prompt(
p, num_gpu_blocks, prefill_chunk_size, block_size, sid)
seq_group_metadata_list = seq_group_metadata_list[:batch_size]
prev_output_tokens = []
else:
prev_output_tokens = [[
next(iterator) for _ in range(prev_output_token_len)
] for _ in range(batch_size)]
final_prompt_lens = [
len(prompt) + len(prev_output_token) + k + 1
for prompt, prev_output_token in zip(prompts, prev_output_tokens)
]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size, final_prompt_lens,
prev_output_tokens, seq_ids)
return seq_group_metadata_list, prompts, prev_output_tokens