[Feature] [Spec decode]: Combine chunked prefill with speculative decoding (#9291)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
@@ -46,12 +46,14 @@ def assert_score_equal(score1: SpeculativeScores,
|
||||
@pytest.mark.parametrize('max_propose_len', [1, 3, 5])
|
||||
@pytest.mark.parametrize('mixed_propose_len', [True])
|
||||
@pytest.mark.parametrize('device', ['cuda'])
|
||||
@pytest.mark.parametrize('prefill_chunking', [False, True])
|
||||
def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
|
||||
mixed_propose_len: bool, device: str) -> None:
|
||||
mixed_propose_len: bool, device: str,
|
||||
prefill_chunking: bool) -> None:
|
||||
"""
|
||||
Compare the batch expansion scorer and mqa scorer return the same score.
|
||||
We test for both queries with the same propose length and different
|
||||
propose length.
|
||||
propose length, as well as mixed prefill-decode batches.
|
||||
"""
|
||||
seed = 0
|
||||
block_size = 32
|
||||
@@ -67,16 +69,37 @@ def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
|
||||
if not mixed_propose_len:
|
||||
propose_lens = [max_propose_len] * batch_size
|
||||
else:
|
||||
non_zero_cnt = random.randint(0, batch_size)
|
||||
# There must be at least 1 decode request, otherwise
|
||||
# we have nothing to score (`_run_no_spec`).
|
||||
non_zero_cnt = random.randint(1, batch_size)
|
||||
propose_lens = [max_propose_len
|
||||
] * non_zero_cnt + [0] * (batch_size - non_zero_cnt)
|
||||
random.shuffle(propose_lens)
|
||||
|
||||
proposals = create_proposal(propose_lens, vocab_size, device)
|
||||
seq_group_metadatalist, _, _ = create_batch(batch_size,
|
||||
max_propose_len,
|
||||
block_size=block_size,
|
||||
num_gpu_blocks=num_gpu_blocks)
|
||||
|
||||
if mixed_propose_len and prefill_chunking and (n_prefills :=
|
||||
batch_size - non_zero_cnt):
|
||||
prefill, _, _ = create_batch(n_prefills,
|
||||
None,
|
||||
prefill_chunk_size=4,
|
||||
block_size=block_size,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
seq_ids=list(
|
||||
range(batch_size,
|
||||
batch_size + n_prefills)))
|
||||
# re-order to guarantee prefill|decode order
|
||||
target_group_metadatalist = [
|
||||
seq_group_metadatalist[i] for i, p in enumerate(propose_lens)
|
||||
if p > 0
|
||||
]
|
||||
seq_group_metadatalist = prefill + target_group_metadatalist
|
||||
propose_lens = [0] * n_prefills + [p for p in propose_lens if p > 0]
|
||||
|
||||
proposals = create_proposal(propose_lens, vocab_size, device)
|
||||
requests = ExecuteModelRequest(seq_group_metadatalist,
|
||||
num_lookahead_slots=max_propose_len)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user