[Speculative Decoding] Enabling bonus token in speculative decoding for KV cache based models (#5765)

This commit is contained in:
sroy745
2024-07-10 16:02:47 -07:00
committed by GitHub
parent 44cc76610d
commit ae151d73be
14 changed files with 645 additions and 80 deletions

View File

@@ -70,14 +70,17 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int,
if queue_size < disable_by_batch_size:
# Should raise exception when executing the mocked draft model.
with pytest.raises(ValueError, match=exception_secret):
proposer.get_spec_proposals(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k), )
proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k),
seq_ids_with_bonus_token_in_last_step=set())
else:
# Should not execute the draft model because spec decode is disabled
# for all requests. Accordingly, the proposal length should be 0.
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k), )
num_lookahead_slots=k),
seq_ids_with_bonus_token_in_last_step=set())
assert proposals.proposal_lens.tolist() == [0] * batch_size