[Speculative Decoding] Enabling bonus token in speculative decoding for KV cache based models (#5765)
This commit is contained in:
@@ -53,7 +53,8 @@ def test_ngram_algo_correctness_for_single_no_match():
|
||||
proposals = proposer.get_spec_proposals(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=proposal_len), )
|
||||
num_lookahead_slots=proposal_len),
|
||||
seq_ids_with_bonus_token_in_last_step=None)
|
||||
|
||||
assert torch.is_tensor(proposals.proposal_token_ids)
|
||||
assert torch.is_tensor(proposals.proposal_probs)
|
||||
@@ -121,7 +122,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
|
||||
proposals = proposer.get_spec_proposals(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=proposal_len), )
|
||||
num_lookahead_slots=proposal_len),
|
||||
seq_ids_with_bonus_token_in_last_step=None)
|
||||
|
||||
assert torch.is_tensor(proposals.proposal_token_ids)
|
||||
assert torch.is_tensor(proposals.proposal_probs)
|
||||
@@ -193,7 +195,8 @@ def test_ngram_algo_correctness_for_batches_match_all():
|
||||
proposals = proposer.get_spec_proposals(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=proposal_len), )
|
||||
num_lookahead_slots=proposal_len),
|
||||
seq_ids_with_bonus_token_in_last_step=None)
|
||||
|
||||
assert torch.is_tensor(proposals.proposal_token_ids)
|
||||
assert torch.is_tensor(proposals.proposal_probs)
|
||||
|
||||
Reference in New Issue
Block a user