[Feature] [Spec decode]: Combine chunked prefill with speculative decoding (#9291)

Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Nicolò Lucchesi
2024-11-07 17:15:14 +01:00
committed by GitHub
parent ae62fd17c0
commit 9d43afcc53
17 changed files with 476 additions and 146 deletions

View File

@@ -46,12 +46,14 @@ def assert_score_equal(score1: SpeculativeScores,
@pytest.mark.parametrize('max_propose_len', [1, 3, 5])
@pytest.mark.parametrize('mixed_propose_len', [True])
@pytest.mark.parametrize('device', ['cuda'])
@pytest.mark.parametrize('prefill_chunking', [False, True])
def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
mixed_propose_len: bool, device: str) -> None:
mixed_propose_len: bool, device: str,
prefill_chunking: bool) -> None:
"""
Compare the batch expansion scorer and mqa scorer return the same score.
We test for both queries with the same propose length and different
propose length.
propose length, as well as mixed prefill-decode batches.
"""
seed = 0
block_size = 32
@@ -67,16 +69,37 @@ def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
if not mixed_propose_len:
propose_lens = [max_propose_len] * batch_size
else:
non_zero_cnt = random.randint(0, batch_size)
# There must be at least 1 decode request, otherwise
# we have nothing to score (`_run_no_spec`).
non_zero_cnt = random.randint(1, batch_size)
propose_lens = [max_propose_len
] * non_zero_cnt + [0] * (batch_size - non_zero_cnt)
random.shuffle(propose_lens)
proposals = create_proposal(propose_lens, vocab_size, device)
seq_group_metadatalist, _, _ = create_batch(batch_size,
max_propose_len,
block_size=block_size,
num_gpu_blocks=num_gpu_blocks)
if mixed_propose_len and prefill_chunking and (n_prefills :=
batch_size - non_zero_cnt):
prefill, _, _ = create_batch(n_prefills,
None,
prefill_chunk_size=4,
block_size=block_size,
num_gpu_blocks=num_gpu_blocks,
seq_ids=list(
range(batch_size,
batch_size + n_prefills)))
# re-order to guarantee prefill|decode order
target_group_metadatalist = [
seq_group_metadatalist[i] for i, p in enumerate(propose_lens)
if p > 0
]
seq_group_metadatalist = prefill + target_group_metadatalist
propose_lens = [0] * n_prefills + [p for p in propose_lens if p > 0]
proposals = create_proposal(propose_lens, vocab_size, device)
requests = ExecuteModelRequest(seq_group_metadatalist,
num_lookahead_slots=max_propose_len)