[Speculative Decoding] Add ProposerWorkerBase abstract class (#5252)

This commit is contained in:
Nick Hill
2024-06-05 14:53:05 -07:00
committed by GitHub
parent f270a39537
commit faf71bcd4b
9 changed files with 91 additions and 60 deletions

View File

@@ -307,9 +307,10 @@ def test_draft_proposals_full_speculation_len():
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
-proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
-seq_group_metadata_list=seq_group_metadata_list,
-num_lookahead_slots=k), )
+proposals = proposer.get_spec_proposals(
+execute_model_req=ExecuteModelRequest(
+seq_group_metadata_list=seq_group_metadata_list,
+num_lookahead_slots=k), )
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
@@ -344,9 +345,10 @@ def test_draft_proposals_no_speculations():
k,
prompt_len=prompt_len)
-proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
-seq_group_metadata_list=seq_group_metadata_list,
-num_lookahead_slots=k), )
+proposals = proposer.get_spec_proposals(
+execute_model_req=ExecuteModelRequest(
+seq_group_metadata_list=seq_group_metadata_list,
+num_lookahead_slots=k), )
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
@@ -415,9 +417,10 @@ def test_draft_proposals_mixed_k():
prev_output_token_len=prev_output_token_len,
)
-proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
-seq_group_metadata_list=seq_group_metadata_list,
-num_lookahead_slots=k), )
+proposals = proposer.get_spec_proposals(
+execute_model_req=ExecuteModelRequest(
+seq_group_metadata_list=seq_group_metadata_list,
+num_lookahead_slots=k), )
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)