[Feature] [Spec decode]: Combine chunked prefill with speculative decoding (#9291)
Signed-off-by: NickLucche <nlucches@redhat.com>
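This change adds a test covering SpecDecodeWorker behaviour when chunked prefill requests appear in the batch. For orientation, below is a minimal usage sketch of the combined feature; it is an assumption-laden example, not part of the patch: it presumes the engine arguments speculative_model, num_speculative_tokens and enable_chunked_prefill (names may differ across vLLM versions) and uses placeholder model names.

from vllm import LLM, SamplingParams

# Hypothetical configuration: model names are placeholders; argument names
# assume this era's EngineArgs and may differ in other vLLM versions.
llm = LLM(
    model="facebook/opt-6.7b",              # target model (placeholder)
    speculative_model="facebook/opt-125m",  # small draft model (placeholder)
    num_speculative_tokens=3,               # 'k' proposal tokens per decode step
    enable_chunked_prefill=True,            # combined with spec decode by this feature
)
outputs = llm.generate(["The future of inference is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)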
@@ -10,6 +10,7 @@ import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, SequenceOutput
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
                                      SpecDecodeWorkerMetrics)
@@ -819,3 +820,84 @@ def test_handle_finished_requests():
    # and 'request-3' are removed from seq_with_bonus_token_in_last_step.
    assert worker._seq_with_bonus_token_in_last_step == \
        {4,5,10}


@pytest.mark.parametrize('k', [3])
@pytest.mark.parametrize('batch_size', [2, 32])
@pytest.mark.parametrize("batch_composition",
                         ["prefill_only", "decode_only", "mixed"])
@torch.inference_mode()
def test_chunked_prefill_flow(k: int, batch_size: int, batch_composition: str):
    """
    Verify SpecDecodeWorker calls match the expected flow when the batch
    mixes prefill chunks and decodes.
    """
    vocab_size = 32_000
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    worker = SpecDecodeWorker(draft_worker,
                              target_worker,
                              mock_spec_decode_sampler("rejection_sampler"),
                              disable_logprobs=False,
                              metrics_collector=metrics_collector)
    exception_secret = 'artificial stop'
    worker.scorer = mock_worker(BatchExpansionTop1Scorer)
    worker.scorer.score_proposals.side_effect = ValueError(exception_secret)
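    # The scorer is replaced with a mock whose score_proposals raises the
    # sentinel error above: any speculative (decode) step therefore stops
    # right after proposals are drafted, so the assertions below can check
    # which workers were invoked without running a full verification step.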

    # Create batch with combination of terminal/non-terminal prefill chunks
    # and decodes (different seq_ids).
    decodes, _, _ = create_batch(batch_size, k)
    # Pre-chunking here, get 'batch_size' chunks.
    prefill, _, _ = create_batch(batch_size,
                                 k,
                                 prefill_chunk_size=4,
                                 seq_ids=list(range(batch_size,
                                                    batch_size * 2)))
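    # NOTE: prefill_chunk_size=4 relies on the create_batch helper (extended
    # by this change) to pre-chunk each prefill request, so 'prefill' already
    # holds chunk-sized seq-group metadata; the shifted seq_ids keep these
    # entries distinct from the decode sequences created above.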

    if batch_composition == "prefill_only":
        n_prefills = batch_size
    elif batch_composition == "decode_only":
        n_prefills = 0
    else:
        n_prefills = random.randint(1, batch_size - 1)
    n_decodes = batch_size - n_prefills
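    # In the mixed case, 1 <= n_prefills <= batch_size - 1, so the batch is
    # guaranteed to contain at least one prefill chunk and one decode.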

    prefill = random.sample(prefill, n_prefills)
    decodes = random.sample(decodes, n_decodes)
    target_group_metadata_list = prefill + decodes
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=target_group_metadata_list,
        num_lookahead_slots=k)
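    # num_lookahead_slots=k mirrors the number of speculative tokens per
    # decode step (descriptive note: the actual slot accounting happens in
    # the scheduler/worker, not in this test).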

    target_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(1, batch_size * (k + 1)),
                                     dtype=torch.int64,
                                     device='cuda')
    target_token_probs = torch.rand(1,
                                    batch_size * (k + 1),
                                    vocab_size,
                                    dtype=torch.float32,
                                    device='cuda')
    target_token_logprobs = torch.rand(1,
                                       batch_size * (k + 1),
                                       vocab_size,
                                       dtype=torch.float32,
                                       device='cuda')
    target_output = create_sampler_output_list(target_token_ids,
                                               target_token_probs,
                                               target_token_logprobs)
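    # Tensors are sized batch_size * (k + 1): with batch expansion the target
    # model scores each sequence's k proposed tokens plus one bonus token in
    # a single forward pass.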

    target_worker.execute_model.return_value = [target_output[0]]

    if not len(decodes):
        worker.execute_model(execute_model_req=execute_model_req)
        # no spec run (prefill only)
        draft_worker.execute_model.assert_called_once_with(execute_model_req)
        target_worker.execute_model.assert_called_once_with(execute_model_req)
    else:
        # Decode-only run OR mixed batch, scorer call fails (it's mocked)
        with pytest.raises(ValueError, match=exception_secret):
            worker.execute_model(execute_model_req=execute_model_req)
        # but first draft still counted
        assert draft_worker.get_spec_proposals.call_count == 1