[Speculative decoding] Add ngram prompt lookup decoding (#4237)

Co-authored-by: Lei Wen <wenlei03@qiyi.com>
Author: leiwen83
Date: 2024-05-02 02:13:03 +08:00
Committed by: GitHub
Parent: 8b798eec75
Commit: b38e42fbca

14 changed files with 1003 additions and 318 deletions
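The headline feature, n-gram prompt lookup decoding, drafts speculative tokens without running a separate draft model: the last few generated tokens are matched against the earlier context, and the tokens that followed the match are proposed for verification. The file shown below is only the accompanying test refactor (DraftModelTop1Proposer is generalized into a reusable Top1Proposer, evidently so a non-model proposer can plug in); the sketch here is a from-scratch illustration of the lookup idea, not vLLM's implementation:

```python
# From-scratch illustration of prompt-lookup drafting; not vLLM code.
from typing import List

def ngram_lookup_propose(token_ids: List[int], ngram_size: int,
                         num_proposal_tokens: int) -> List[int]:
    """Match the last `ngram_size` tokens against earlier context and
    propose the tokens that followed the most recent match."""
    if len(token_ids) <= ngram_size:
        return []
    tail = token_ids[-ngram_size:]
    # Scan right to left so the most recent occurrence wins.
    for start in range(len(token_ids) - ngram_size - 1, -1, -1):
        if token_ids[start:start + ngram_size] == tail:
            follow = token_ids[start + ngram_size:
                               start + ngram_size + num_proposal_tokens]
            if follow:
                return follow
    return []  # no match: propose nothing, fall back to normal decoding

# The trailing token 1 also appears at position 0, followed by 2, 3, 4.
print(ngram_lookup_propose([1, 2, 3, 4, 5, 1], ngram_size=1,
                           num_proposal_tokens=3))  # [2, 3, 4]
```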


@@ -6,8 +6,8 @@ import torch
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import SamplerOutput
-from vllm.spec_decode.multi_step_worker import (DraftModelTop1Proposer,
-                                                MultiStepWorker)
+from vllm.spec_decode.multi_step_worker import MultiStepWorker
+from vllm.spec_decode.top1_proposer import Top1Proposer
 from vllm.worker.worker import Worker

 from .utils import (assert_logprobs_dict_allclose, create_batch,
@@ -117,8 +117,8 @@ def test_same_output_for_single_step():
     zero_kv_cache(multi_step_worker.cache_engine)
     set_random_seed(seed)
-    actual_output = multi_step_worker.execute_model_multi_step(
-        **multi_step_execute_model_data.to_dict(), num_steps=num_steps)
+    actual_output, _ = multi_step_worker.sampler_output(
+        **multi_step_execute_model_data.to_dict(), sample_len=num_steps)

     assert len(actual_output) == num_steps
     actual_output = actual_output[0]
@@ -200,8 +200,8 @@ def test_same_output_for_multi_step():
     # Run multi-step.
     zero_kv_cache(multi_step_worker.cache_engine)
     set_random_seed(seed)
-    multi_step_output = multi_step_worker.execute_model_multi_step(
-        **execute_model_data.to_dict(), num_steps=num_steps)
+    multi_step_output, _ = multi_step_worker.sampler_output(
+        **execute_model_data.to_dict(), sample_len=num_steps)

     # Run single-step repeatedly.
     zero_kv_cache(worker.cache_engine)
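The two hunks above make the same mechanical change: the worker's multi-step entry point is renamed from execute_model_multi_step(num_steps=...) to sampler_output(sample_len=...), and it now returns a tuple, so the call sites unpack and discard the second element. A minimal stand-in showing the new call shape (names follow the diff, but the boolean's meaning is not visible here, so treat this as an illustration rather than vLLM's real signature):

```python
# Illustrative stub only; the real method lives on vLLM's MultiStepWorker.
from typing import List, Tuple

def sampler_output(sample_len: int) -> Tuple[List[str], bool]:
    # One entry per speculative step, plus a flag the tests discard.
    return [f"step-{i}" for i in range(sample_len)], True

outputs, _ = sampler_output(sample_len=4)  # tests unpack and drop the flag
assert len(outputs) == 4
```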
@@ -266,7 +266,7 @@ def test_same_output_for_multi_step():
 @torch.inference_mode()
 def test_draft_proposals_full_speculation_len():
-    """Verify DraftModelTop1Proposer correctly handles case where all sequences
+    """Verify Top1Proposer correctly handles case where all sequences
     can speculate.
     """
     k = 10
@@ -275,13 +275,13 @@ def test_draft_proposals_full_speculation_len():
     device = 'cuda:0'

     draft_worker = MagicMock()
-    proposer = DraftModelTop1Proposer(
-        draft_worker=draft_worker,
+    proposer = Top1Proposer(
+        worker=draft_worker,
         device=device,
-        max_model_len=2048,
         vocab_size=vocab_size,
+        max_proposal_len=2048,
     )
-    draft_worker.execute_model_multi_step.return_value = [
+    draft_worker.sampler_output.return_value = [
         SamplerOutput(
             outputs=[],
             sampled_token_probs=torch.rand(batch_size,
@@ -294,13 +294,13 @@ def test_draft_proposals_full_speculation_len():
                 device=device,
                 dtype=torch.long),
         ) for _ in range(k)
-    ]
+    ], True

     execute_model_data, _, _ = create_batch(batch_size, k)

     proposals = proposer.get_proposals(
         **execute_model_data.to_dict(),
-        max_proposal_len=k,
+        proposal_len=k,
     )

     assert torch.is_tensor(proposals.proposal_token_ids)
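The proposer construction changes shape as well: DraftModelTop1Proposer(draft_worker=..., max_model_len=...) becomes Top1Proposer(worker=..., max_proposal_len=...) imported from its own module, the mocked sampler_output now returns a (outputs, flag) tuple, and get_proposals takes proposal_len instead of max_proposal_len. A minimal sketch of the new construction, mirroring the test above (kwargs taken from the diff; running it requires a vLLM checkout at this commit):

```python
from unittest.mock import MagicMock

from vllm.spec_decode.top1_proposer import Top1Proposer

proposer = Top1Proposer(
    worker=MagicMock(),     # was draft_worker= on DraftModelTop1Proposer
    device='cuda:0',
    vocab_size=32_000,      # placeholder vocab size for illustration
    max_proposal_len=2048,  # replaces the old max_model_len kwarg
)
```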
@@ -315,7 +315,7 @@ def test_draft_proposals_full_speculation_len():
 @torch.inference_mode()
 def test_draft_proposals_no_speculations():
-    """Verify DraftModelTop1Proposer correctly handles case where no sequences
+    """Verify Top1Proposer correctly handles case where no sequences
     can speculate.
     """
     k = 10
@@ -325,11 +325,11 @@ def test_draft_proposals_no_speculations():
     prompt_len = 10

     draft_worker = MagicMock()
-    proposer = DraftModelTop1Proposer(
-        draft_worker=draft_worker,
+    proposer = Top1Proposer(
+        worker=draft_worker,
         device=device,
-        max_model_len=prompt_len + k - 1,
         vocab_size=vocab_size,
+        max_proposal_len=prompt_len + k - 1,
     )

     execute_model_data, _, _ = create_batch(batch_size,
@@ -338,7 +338,7 @@ def test_draft_proposals_no_speculations():
     proposals = proposer.get_proposals(
         **execute_model_data.to_dict(),
-        max_proposal_len=k,
+        proposal_len=k,
     )

     assert torch.is_tensor(proposals.proposal_token_ids)
@@ -353,7 +353,7 @@ def test_draft_proposals_no_speculations():
 @torch.inference_mode()
 def test_draft_proposals_mixed_k():
-    """Verify DraftModelTop1Proposer correctly handles case some sequences can
+    """Verify Top1Proposer correctly handles case some sequences can
     speculate and some can't.
     """
     k = 10
@@ -374,14 +374,14 @@ def test_draft_proposals_mixed_k():
                    for _ in range(expected_num_no_proposal_seqs)] + [small_prompt_len]

     draft_worker = MagicMock()
-    proposer = DraftModelTop1Proposer(
-        draft_worker=draft_worker,
+    proposer = Top1Proposer(
+        worker=draft_worker,
         device=device,
-        max_model_len=long_prompt_len + prev_output_token_len + k - 1,
         vocab_size=vocab_size,
+        max_proposal_len=long_prompt_len + prev_output_token_len + k - 1,
     )
-    draft_worker.execute_model_multi_step.return_value = [
+    draft_worker.sampler_output.return_value = [
         SamplerOutput(
             outputs=[],
             sampled_token_probs=torch.rand(expected_num_proposal_seqs,
@@ -395,7 +395,7 @@ def test_draft_proposals_mixed_k():
                 device=device,
                 dtype=torch.long),
         ) for _ in range(k)
-    ]
+    ], True

     execute_model_data, _, _ = create_batch(
         batch_size,
@@ -406,7 +406,7 @@ def test_draft_proposals_mixed_k():
     proposals = proposer.get_proposals(
         **execute_model_data.to_dict(),
-        max_proposal_len=k,
+        proposal_len=k,
     )

     assert torch.is_tensor(proposals.proposal_token_ids)
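Taken together, the three Top1Proposer tests pin down when a sequence receives draft tokens: every sequence speculates when max_proposal_len comfortably exceeds its length, none do when the limit is prompt_len + k - 1, and a mixed batch splits accordingly. A hedged paraphrase of that policy, inferred from the tests' arithmetic rather than copied from vLLM source:

```python
# Boundary condition inferred from the `prompt_len + k - 1` constructions
# in the tests above; not code from vllm/spec_decode/top1_proposer.py.
def proposal_len_for_seq(seq_len: int, k: int, max_proposal_len: int) -> int:
    """Return k if the sequence may grow by k tokens without exceeding
    the proposer's limit, else 0 (no speculation for that sequence)."""
    return k if seq_len + k <= max_proposal_len else 0

# no_speculations test: a limit of prompt_len + k - 1 forbids speculation.
assert proposal_len_for_seq(seq_len=10, k=10, max_proposal_len=19) == 0
assert proposal_len_for_seq(seq_len=10, k=10, max_proposal_len=20) == 10
```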