[Speculative decoding] Add ngram prompt lookup decoding (#4237)
Co-authored-by: Lei Wen <wenlei03@qiyi.com>
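What "ngram prompt lookup" speculation means, in brief: instead of running a separate draft model, the last n generated tokens are matched against the tokens already in the context, and the tokens that followed the most recent match are proposed as draft tokens for the target model to verify in a single pass. The sketch below only illustrates that idea; ngram_lookup_proposal and its signature are hypothetical and are not the code added by this PR.

# Illustrative-only sketch of n-gram prompt lookup proposals
# (hypothetical helper, not vLLM's implementation).
from typing import List


def ngram_lookup_proposal(token_ids: List[int], ngram_size: int,
                          num_proposal_tokens: int) -> List[int]:
    """Match the trailing n-gram against earlier context and copy the
    tokens that followed the most recent match as draft tokens."""
    if len(token_ids) <= ngram_size:
        return []
    tail = token_ids[-ngram_size:]
    # Scan for the most recent earlier occurrence of the trailing n-gram.
    for start in range(len(token_ids) - ngram_size - 1, -1, -1):
        if token_ids[start:start + ngram_size] == tail:
            return token_ids[start + ngram_size:start + ngram_size +
                             num_proposal_tokens]
    # No match: propose nothing.
    return []


# The trailing 2-gram [7, 8] also occurs earlier, so [9, 5] is proposed.
assert ngram_lookup_proposal([3, 7, 8, 9, 5, 7, 8], ngram_size=2,
                             num_proposal_tokens=2) == [9, 5]

When the lookup misses, no draft tokens are proposed and the sequence simply decodes normally for that step.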
@@ -6,8 +6,8 @@ import torch
 
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import SamplerOutput
-from vllm.spec_decode.multi_step_worker import (DraftModelTop1Proposer,
-                                                MultiStepWorker)
+from vllm.spec_decode.multi_step_worker import MultiStepWorker
+from vllm.spec_decode.top1_proposer import Top1Proposer
 from vllm.worker.worker import Worker
 
 from .utils import (assert_logprobs_dict_allclose, create_batch,
@@ -117,8 +117,8 @@ def test_same_output_for_single_step():
 
     zero_kv_cache(multi_step_worker.cache_engine)
     set_random_seed(seed)
-    actual_output = multi_step_worker.execute_model_multi_step(
-        **multi_step_execute_model_data.to_dict(), num_steps=num_steps)
+    actual_output, _ = multi_step_worker.sampler_output(
+        **multi_step_execute_model_data.to_dict(), sample_len=num_steps)
     assert len(actual_output) == num_steps
     actual_output = actual_output[0]
 
@@ -200,8 +200,8 @@ def test_same_output_for_multi_step():
     # Run multi-step.
     zero_kv_cache(multi_step_worker.cache_engine)
     set_random_seed(seed)
-    multi_step_output = multi_step_worker.execute_model_multi_step(
-        **execute_model_data.to_dict(), num_steps=num_steps)
+    multi_step_output, _ = multi_step_worker.sampler_output(
+        **execute_model_data.to_dict(), sample_len=num_steps)
 
     # Run single-step repeatedly.
     zero_kv_cache(worker.cache_engine)
@@ -266,7 +266,7 @@ def test_same_output_for_multi_step():
 
 @torch.inference_mode()
 def test_draft_proposals_full_speculation_len():
-    """Verify DraftModelTop1Proposer correctly handles case where all sequences
+    """Verify Top1Proposer correctly handles case where all sequences
     can speculate.
     """
     k = 10
@@ -275,13 +275,13 @@ def test_draft_proposals_full_speculation_len():
     device = 'cuda:0'
 
     draft_worker = MagicMock()
-    proposer = DraftModelTop1Proposer(
-        draft_worker=draft_worker,
+    proposer = Top1Proposer(
+        worker=draft_worker,
         device=device,
-        max_model_len=2048,
         vocab_size=vocab_size,
+        max_proposal_len=2048,
     )
-    draft_worker.execute_model_multi_step.return_value = [
+    draft_worker.sampler_output.return_value = [
         SamplerOutput(
             outputs=[],
             sampled_token_probs=torch.rand(batch_size,
@@ -294,13 +294,13 @@ def test_draft_proposals_full_speculation_len():
                                            device=device,
                                            dtype=torch.long),
         ) for _ in range(k)
-    ]
+    ], True
 
     execute_model_data, _, _ = create_batch(batch_size, k)
 
     proposals = proposer.get_proposals(
         **execute_model_data.to_dict(),
-        max_proposal_len=k,
+        proposal_len=k,
     )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
@@ -315,7 +315,7 @@ def test_draft_proposals_full_speculation_len():
 
 @torch.inference_mode()
 def test_draft_proposals_no_speculations():
-    """Verify DraftModelTop1Proposer correctly handles case where no sequences
+    """Verify Top1Proposer correctly handles case where no sequences
     can speculate.
     """
     k = 10
@@ -325,11 +325,11 @@ def test_draft_proposals_no_speculations():
     prompt_len = 10
 
     draft_worker = MagicMock()
-    proposer = DraftModelTop1Proposer(
-        draft_worker=draft_worker,
+    proposer = Top1Proposer(
+        worker=draft_worker,
         device=device,
-        max_model_len=prompt_len + k - 1,
         vocab_size=vocab_size,
+        max_proposal_len=prompt_len + k - 1,
     )
 
     execute_model_data, _, _ = create_batch(batch_size,
@@ -338,7 +338,7 @@ def test_draft_proposals_no_speculations():
 
     proposals = proposer.get_proposals(
         **execute_model_data.to_dict(),
-        max_proposal_len=k,
+        proposal_len=k,
     )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
@@ -353,7 +353,7 @@ def test_draft_proposals_no_speculations():
 
 @torch.inference_mode()
 def test_draft_proposals_mixed_k():
-    """Verify DraftModelTop1Proposer correctly handles case some sequences can
+    """Verify Top1Proposer correctly handles case some sequences can
     speculate and some can't.
     """
     k = 10
@@ -374,14 +374,14 @@ def test_draft_proposals_mixed_k():
          for _ in range(expected_num_no_proposal_seqs)] + [small_prompt_len]
 
     draft_worker = MagicMock()
-    proposer = DraftModelTop1Proposer(
-        draft_worker=draft_worker,
+    proposer = Top1Proposer(
+        worker=draft_worker,
         device=device,
-        max_model_len=long_prompt_len + prev_output_token_len + k - 1,
         vocab_size=vocab_size,
+        max_proposal_len=long_prompt_len + prev_output_token_len + k - 1,
     )
 
-    draft_worker.execute_model_multi_step.return_value = [
+    draft_worker.sampler_output.return_value = [
         SamplerOutput(
             outputs=[],
             sampled_token_probs=torch.rand(expected_num_proposal_seqs,
@@ -395,7 +395,7 @@ def test_draft_proposals_mixed_k():
                                            device=device,
                                            dtype=torch.long),
         ) for _ in range(k)
-    ]
+    ], True
 
     execute_model_data, _, _ = create_batch(
         batch_size,
@@ -406,7 +406,7 @@ def test_draft_proposals_mixed_k():
 
     proposals = proposer.get_proposals(
         **execute_model_data.to_dict(),
-        max_proposal_len=k,
+        proposal_len=k,
    )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
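Taken together, the hunks above capture one interface change: the generic Top1Proposer takes a worker= argument and a max_proposal_len= cap, the draft worker exposes sampler_output(..., sample_len=...) returning per-step outputs plus a flag, and get_proposals() takes proposal_len= instead of max_proposal_len=. A rough stand-in paraphrase follows; the Stub* names and simplified signatures are invented here for illustration and are not vLLM classes.

# Stand-in paraphrase of the proposer contract exercised by the tests above.
# StubDraftWorker / StubTop1Proposer are illustrative only, not vLLM classes.
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class StubSamplerOutput:
    sampled_token_ids: List[int]


class StubDraftWorker:
    """Plays the role of the MagicMock draft_worker in the tests."""

    def sampler_output(
            self, sample_len: int) -> Tuple[List[StubSamplerOutput], bool]:
        # One output per proposed step plus a flag, mirroring the
        # ([...], True) value stubbed onto draft_worker.sampler_output.
        return [StubSamplerOutput(sampled_token_ids=[0])
                for _ in range(sample_len)], True


class StubTop1Proposer:
    """Caps speculation (simplified here) and delegates to the worker."""

    def __init__(self, worker: StubDraftWorker, max_proposal_len: int):
        self._worker = worker
        self._max_proposal_len = max_proposal_len

    def get_proposals(self, proposal_len: int) -> List[StubSamplerOutput]:
        proposal_len = min(proposal_len, self._max_proposal_len)
        outputs, _ = self._worker.sampler_output(sample_len=proposal_len)
        return outputs


proposer = StubTop1Proposer(StubDraftWorker(), max_proposal_len=2048)
assert len(proposer.get_proposals(proposal_len=10)) == 10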