[Speculative Decoding] Add ProposerWorkerBase abstract class (#5252)
This commit is contained in:
@@ -307,9 +307,10 @@ def test_draft_proposals_full_speculation_len():
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
|
||||
|
||||
proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k), )
|
||||
proposals = proposer.get_spec_proposals(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k), )
|
||||
|
||||
assert torch.is_tensor(proposals.proposal_token_ids)
|
||||
assert torch.is_tensor(proposals.proposal_probs)
|
||||
@@ -344,9 +345,10 @@ def test_draft_proposals_no_speculations():
|
||||
k,
|
||||
prompt_len=prompt_len)
|
||||
|
||||
proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k), )
|
||||
proposals = proposer.get_spec_proposals(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k), )
|
||||
|
||||
assert torch.is_tensor(proposals.proposal_token_ids)
|
||||
assert torch.is_tensor(proposals.proposal_probs)
|
||||
@@ -415,9 +417,10 @@ def test_draft_proposals_mixed_k():
|
||||
prev_output_token_len=prev_output_token_len,
|
||||
)
|
||||
|
||||
proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k), )
|
||||
proposals = proposer.get_spec_proposals(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k), )
|
||||
|
||||
assert torch.is_tensor(proposals.proposal_token_ids)
|
||||
assert torch.is_tensor(proposals.proposal_probs)
|
||||
|
||||
Reference in New Issue
Block a user