[Speculative decoding] Add ngram prompt lookup decoding (#4237)

Co-authored-by: Lei Wen <wenlei03@qiyi.com>
This commit is contained in:
leiwen83
2024-05-02 02:13:03 +08:00
committed by GitHub
parent 8b798eec75
commit b38e42fbca
14 changed files with 1003 additions and 318 deletions

View File

@@ -1,4 +1,5 @@
import asyncio
from itertools import cycle
from typing import List, Optional, Tuple, Union
import pytest
@@ -185,3 +186,60 @@ def get_output_from_llm_generator(
del llm
return tokens, token_ids
def run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len,
force_output_len: bool,
print_tokens: bool = False):
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, i.e. that the outputs are exactly
the same when temperature is zero.
"""
temperature = 0.0
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
"San Francisco is know for its",
"Facebook was created in 2004 by",
"Curious George is a",
"Python 3.11 brings improvements to its",
]
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
# If the test requires that we generate max_output_len tokens, then set the
# sampling params to ignore the eos token.
ignore_eos = force_output_len
sampling_params = SamplingParams(
max_tokens=max_output_len,
ignore_eos=ignore_eos,
temperature=temperature,
)
spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator(
test_llm_generator, prompts, sampling_params)
(baseline_batch_tokens,
baseline_batch_token_ids) = get_output_from_llm_generator(
baseline_llm_generator, prompts, sampling_params)
assert len(baseline_batch_token_ids) == len(prompts)
assert len(spec_batch_token_ids) == len(prompts)
for i, (baseline_token_ids, baseline_tokens, spec_token_ids,
spec_tokens) in enumerate(
zip(baseline_batch_token_ids, baseline_batch_tokens,
spec_batch_token_ids, spec_batch_tokens)):
if print_tokens:
print(f'{i=} {baseline_tokens=}')
print(f'{i=} {spec_tokens=}')
print(f'{i=} {baseline_token_ids=}')
print(f'{i=} {spec_token_ids=}')
assert baseline_token_ids == spec_token_ids

View File

@@ -35,7 +35,8 @@ from transformers import AutoTokenizer
from vllm import SamplingParams
from .conftest import get_output_from_llm_generator
from .conftest import (get_output_from_llm_generator,
run_greedy_equality_correctness_test)
@pytest.mark.parametrize(
@@ -545,60 +546,3 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
batch_size,
max_output_len=output_len,
force_output_len=True)
def run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len,
force_output_len: bool,
print_tokens: bool = False):
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, i.e. that the outputs are exactly
the same when temperature is zero.
"""
temperature = 0.0
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
"San Francisco is know for its",
"Facebook was created in 2004 by",
"Curious George is a",
"Python 3.11 brings improvements to its",
]
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
# If the test requires that we generate max_output_len tokens, then set the
# sampling params to ignore the eos token.
ignore_eos = force_output_len
sampling_params = SamplingParams(
max_tokens=max_output_len,
ignore_eos=ignore_eos,
temperature=temperature,
)
spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator(
test_llm_generator, prompts, sampling_params)
(baseline_batch_tokens,
baseline_batch_token_ids) = get_output_from_llm_generator(
baseline_llm_generator, prompts, sampling_params)
assert len(baseline_batch_token_ids) == len(prompts)
assert len(spec_batch_token_ids) == len(prompts)
for i, (baseline_token_ids, baseline_tokens, spec_token_ids,
spec_tokens) in enumerate(
zip(baseline_batch_token_ids, baseline_batch_tokens,
spec_batch_token_ids, spec_batch_tokens)):
if print_tokens:
print(f'{i=} {baseline_tokens=}')
print(f'{i=} {spec_tokens=}')
print(f'{i=} {baseline_token_ids=}')
print(f'{i=} {spec_token_ids=}')
assert baseline_token_ids == spec_token_ids

View File

@@ -0,0 +1,172 @@
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
The ngram lookup idea comes from https://github.com/apoorvumang/prompt-lookup-decoding,
and has been merged into the transformers code base: https://github.com/huggingface/transformers/pull/27775.
Since no model is needed to generate the proposals, the test cases can be made
much simpler than the multi-step draft-model ones.
However, we still need to verify that the following scenarios pass:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Greedy equality under preemption
* Greedy equality under various ngram sizes / speculative sizes
With these tests, we can say that, at a minimum, ngram speculative decoding does
not break the correctness of the target model outputs.
"""
import pytest
from .conftest import run_greedy_equality_correctness_test
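# Editor's illustrative sketch (not part of this commit): the matching rule
# behind prompt lookup is simple enough to show in a few lines. The helper
# below, with the invented name ngram_propose, is a hypothetical simplification
# of the idea described in the module docstring above; it is not the
# NGramWorker implementation added by this PR, which additionally has to handle
# batching, tensors, and per-sequence proposal lengths.
def ngram_propose(token_ids, min_n, max_n, proposal_len):
    """Propose tokens by matching the trailing ngram earlier in the context."""
    # Prefer the largest window first, then fall back to smaller ones.
    for n in range(max_n, min_n - 1, -1):
        if n <= 0 or n >= len(token_ids):
            continue
        pattern = token_ids[-n:]
        # Scan earlier positions, most recent first, for the same ngram.
        for start in range(len(token_ids) - n - 1, -1, -1):
            if token_ids[start:start + n] == pattern:
                # Propose the tokens that followed the earlier occurrence.
                return token_ids[start + n:start + n + proposal_len]
    return []  # no candidate found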
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"model": "JackFram/llama-68m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
},
])
@pytest.mark.parametrize("output_len", [
256,
])
@pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness(baseline_llm_generator,
test_llm_generator, batch_size: int,
output_len: int):
"""Verify greedy equality on a tiny model with different batch size."""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"block_size": 8,
# 2 blocks for the small prompt, 256 // 8 blocks for the generated tokens;
# the tight block budget forces preemption during this test.
"num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"model": "JackFram/llama-160m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
256,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
test_llm_generator,
batch_size: int,
output_len: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_model": "[ngram]",
"num_speculative_tokens": k,
"ngram_prompt_lookup_max": 3,
}
# Try a range of common k values.
for k in [1, 3, 5]
] + [
{
"speculative_model": "[ngram]",
"num_speculative_tokens": k,
"ngram_prompt_lookup_max": 1,
}
# Try a range of common k values.
for k in [1, 3, 5]
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_ngram_different_k(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and
different ngram_prompt_lookup_max.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
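For context on what these fixtures configure: speculative_model="[ngram]",
num_speculative_tokens, and ngram_prompt_lookup_max are passed through as engine
arguments. The snippet below is a rough usage sketch, assuming (as the fixtures
suggest) that vllm.LLM accepts these keyword arguments directly; the exact
argument surface may differ between vLLM versions.

from vllm import LLM, SamplingParams

# Sketch only: kwargs mirror the test_llm_kwargs above; treat the exact names
# and defaults as assumptions rather than a stable API reference.
llm = LLM(
    model="JackFram/llama-68m",
    speculative_model="[ngram]",   # no draft model, use prompt lookup instead
    num_speculative_tokens=5,      # propose up to 5 tokens per step
    ngram_prompt_lookup_max=3,     # match trailing ngrams of size up to 3
    use_v2_block_manager=True,     # required for spec decode
    enforce_eager=True,
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)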

View File

@@ -6,8 +6,8 @@ import torch
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplerOutput
from vllm.spec_decode.multi_step_worker import (DraftModelTop1Proposer,
MultiStepWorker)
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker
from .utils import (assert_logprobs_dict_allclose, create_batch,
@@ -117,8 +117,8 @@ def test_same_output_for_single_step():
zero_kv_cache(multi_step_worker.cache_engine)
set_random_seed(seed)
actual_output = multi_step_worker.execute_model_multi_step(
**multi_step_execute_model_data.to_dict(), num_steps=num_steps)
actual_output, _ = multi_step_worker.sampler_output(
**multi_step_execute_model_data.to_dict(), sample_len=num_steps)
assert len(actual_output) == num_steps
actual_output = actual_output[0]
@@ -200,8 +200,8 @@ def test_same_output_for_multi_step():
# Run multi-step.
zero_kv_cache(multi_step_worker.cache_engine)
set_random_seed(seed)
multi_step_output = multi_step_worker.execute_model_multi_step(
**execute_model_data.to_dict(), num_steps=num_steps)
multi_step_output, _ = multi_step_worker.sampler_output(
**execute_model_data.to_dict(), sample_len=num_steps)
# Run single-step repeatedly.
zero_kv_cache(worker.cache_engine)
@@ -266,7 +266,7 @@ def test_same_output_for_multi_step():
@torch.inference_mode()
def test_draft_proposals_full_speculation_len():
"""Verify DraftModelTop1Proposer correctly handles case where all sequences
"""Verify Top1Proposer correctly handles case where all sequences
can speculate.
"""
k = 10
@@ -275,13 +275,13 @@ def test_draft_proposals_full_speculation_len():
device = 'cuda:0'
draft_worker = MagicMock()
proposer = DraftModelTop1Proposer(
draft_worker=draft_worker,
proposer = Top1Proposer(
worker=draft_worker,
device=device,
max_model_len=2048,
vocab_size=vocab_size,
max_proposal_len=2048,
)
draft_worker.execute_model_multi_step.return_value = [
draft_worker.sampler_output.return_value = [
SamplerOutput(
outputs=[],
sampled_token_probs=torch.rand(batch_size,
@@ -294,13 +294,13 @@ def test_draft_proposals_full_speculation_len():
device=device,
dtype=torch.long),
) for _ in range(k)
]
], True
execute_model_data, _, _ = create_batch(batch_size, k)
proposals = proposer.get_proposals(
**execute_model_data.to_dict(),
max_proposal_len=k,
proposal_len=k,
)
assert torch.is_tensor(proposals.proposal_token_ids)
@@ -315,7 +315,7 @@ def test_draft_proposals_full_speculation_len():
@torch.inference_mode()
def test_draft_proposals_no_speculations():
"""Verify DraftModelTop1Proposer correctly handles case where no sequences
"""Verify Top1Proposer correctly handles case where no sequences
can speculate.
"""
k = 10
@@ -325,11 +325,11 @@ def test_draft_proposals_no_speculations():
prompt_len = 10
draft_worker = MagicMock()
proposer = DraftModelTop1Proposer(
draft_worker=draft_worker,
proposer = Top1Proposer(
worker=draft_worker,
device=device,
max_model_len=prompt_len + k - 1,
vocab_size=vocab_size,
max_proposal_len=prompt_len + k - 1,
)
execute_model_data, _, _ = create_batch(batch_size,
@@ -338,7 +338,7 @@ def test_draft_proposals_no_speculations():
proposals = proposer.get_proposals(
**execute_model_data.to_dict(),
max_proposal_len=k,
proposal_len=k,
)
assert torch.is_tensor(proposals.proposal_token_ids)
@@ -353,7 +353,7 @@ def test_draft_proposals_no_speculations():
@torch.inference_mode()
def test_draft_proposals_mixed_k():
"""Verify DraftModelTop1Proposer correctly handles case some sequences can
"""Verify Top1Proposer correctly handles case some sequences can
speculate and some can't.
"""
k = 10
@@ -374,14 +374,14 @@ def test_draft_proposals_mixed_k():
for _ in range(expected_num_no_proposal_seqs)] + [small_prompt_len]
draft_worker = MagicMock()
proposer = DraftModelTop1Proposer(
draft_worker=draft_worker,
proposer = Top1Proposer(
worker=draft_worker,
device=device,
max_model_len=long_prompt_len + prev_output_token_len + k - 1,
vocab_size=vocab_size,
max_proposal_len=long_prompt_len + prev_output_token_len + k - 1,
)
draft_worker.execute_model_multi_step.return_value = [
draft_worker.sampler_output.return_value = [
SamplerOutput(
outputs=[],
sampled_token_probs=torch.rand(expected_num_proposal_seqs,
@@ -395,7 +395,7 @@ def test_draft_proposals_mixed_k():
device=device,
dtype=torch.long),
) for _ in range(k)
]
], True
execute_model_data, _, _ = create_batch(
batch_size,
@@ -406,7 +406,7 @@ def test_draft_proposals_mixed_k():
proposals = proposer.get_proposals(
**execute_model_data.to_dict(),
max_proposal_len=k,
proposal_len=k,
)
assert torch.is_tensor(proposals.proposal_token_ids)

View File

@@ -0,0 +1,206 @@
import torch
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.top1_proposer import Top1Proposer
from .utils import (create_execute_model_data,
create_seq_group_metadata_from_prompts, create_worker)
def test_ngram_algo_correctness_for_single_no_match():
"""Verify our ngram algo find the right candidate in the prompt
For the scenario cannot find any candidate in one single batch
"""
block_size = 32
num_gpu_blocks = 2048 // block_size
seed = 100
model_name = 'JackFram/llama-68m'
vocab_size = 32_000
device = 'cuda:0'
ngram_worker = create_worker(
NGramWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
proposer = Top1Proposer(
worker=ngram_worker,
device=device,
vocab_size=vocab_size,
max_proposal_len=20,
)
# set ngram window (0, 3], which is window=1/2/3
ngram_worker.set_ngram_window_size(0, 3)
prompts = [
# shall find no candidate
[1, 2, 3, 4, 5, 6, 7],
]
proposal_len = 5
final_seq_lens = [len(prompt) + proposal_len for prompt in prompts]
ngram_sampler_output_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size,
final_seq_lens=final_seq_lens))
proposals = proposer.get_proposals(
**ngram_sampler_output_data.to_dict(),
proposal_len=proposal_len,
)
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
assert proposals.proposal_token_ids.shape == torch.Size([1, proposal_len])
assert proposals.proposal_probs.shape[:-1] == torch.Size([1, proposal_len])
assert proposals.proposal_lens.shape == torch.Size([1])
assert proposals.proposal_lens.tolist() == [0]
def test_ngram_algo_correctness_for_batches_not_match_all():
"""Verify our ngram algo find the right candidate in the prompt
For the scenario find some candidate not full in batchs
"""
block_size = 32
num_gpu_blocks = 2048 // block_size
seed = 100
model_name = 'JackFram/llama-68m'
vocab_size = 32_000
device = 'cuda:0'
ngram_worker = create_worker(
NGramWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
proposer = Top1Proposer(
worker=ngram_worker,
device=device,
vocab_size=vocab_size,
max_proposal_len=20,
)
# set ngram window (0, 3], which is window=1/2/3
ngram_worker.set_ngram_window_size(0, 3)
prompts = [
# shall find no candidate
[1, 2, 3, 4, 5, 6, 7],
# shall find candidate 12,13,14,15,16
[11, 12, 13, 14, 15, 16, 11],
# shall find candidate 23,24,25,26,21
[21, 21, 22, 23, 24, 25, 26, 21, 22],
# shall find candidate 34,35,36,37,38
[31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33],
# shall find no candidate, as this prompt exceeds max_proposal_len
[
31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 33, 34, 35, 36, 37,
38, 31, 32, 33
],
]
proposal_len = 5
final_seq_lens = [len(prompt) + proposal_len for prompt in prompts]
ngram_sampler_output_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size,
final_seq_lens=final_seq_lens))
proposals = proposer.get_proposals(
**ngram_sampler_output_data.to_dict(),
proposal_len=proposal_len,
)
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
assert proposals.proposal_token_ids.shape == torch.Size([5, proposal_len])
assert proposals.proposal_probs.shape[:-1] == torch.Size([5, proposal_len])
assert proposals.proposal_lens.shape == torch.Size([5])
assert proposals.proposal_lens.tolist(
) == [proposal_len for _ in range(4)] + [0]
for i in range(proposal_len):
assert proposals.proposal_token_ids[0][i] == 0
assert proposals.proposal_token_ids[1][i] == prompts[1][i + 1]
assert proposals.proposal_token_ids[2][i] == prompts[2][i + 3]
assert proposals.proposal_token_ids[3][i] == prompts[3][i + 5]
assert proposals.proposal_token_ids[4][i] == -1
def test_ngram_algo_correctness_for_batches_match_all():
"""Verify our ngram algo find the right candidate in the prompt
For the scenario find candidate in all batchs
"""
block_size = 32
num_gpu_blocks = 2048 // block_size
seed = 100
model_name = 'JackFram/llama-68m'
vocab_size = 32_000
device = 'cuda:0'
ngram_worker = create_worker(
NGramWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
proposer = Top1Proposer(
worker=ngram_worker,
device=device,
vocab_size=vocab_size,
max_proposal_len=20,
)
# set ngram window (0, 3], which is window=1/2/3
ngram_worker.set_ngram_window_size(0, 3)
prompts = [
# shall find candidate 12,13,14,15,16
[11, 12, 13, 14, 15, 16, 11],
# shall find candidate 23,24,25,26,21
[21, 21, 22, 23, 24, 25, 26, 21, 22],
# shall find candidate 34,35,36,37,38
[31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33],
]
proposal_len = 5
final_seq_lens = [len(prompt) + proposal_len for prompt in prompts]
ngram_sampler_output_data = create_execute_model_data(
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size,
final_seq_lens=final_seq_lens))
proposals = proposer.get_proposals(
**ngram_sampler_output_data.to_dict(),
proposal_len=proposal_len,
)
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
assert proposals.proposal_token_ids.shape == torch.Size([3, proposal_len])
assert proposals.proposal_probs.shape[:-1] == torch.Size([3, proposal_len])
assert proposals.proposal_lens.shape == torch.Size([3])
assert proposals.proposal_lens.tolist() == [proposal_len for _ in range(3)]
for i in range(proposal_len):
assert proposals.proposal_token_ids[0][i] == prompts[0][i + 1]
assert proposals.proposal_token_ids[1][i] == prompts[1][i + 3]
assert proposals.proposal_token_ids[2][i] == prompts[2][i + 5]
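The expected candidates asserted above follow directly from the prompt-lookup
rule: the trailing 1-, 2-, or 3-token window of each prompt is searched for
earlier in the prompt, and the tokens after the match become the proposal.
Checking them against the illustrative ngram_propose helper sketched earlier in
this section (an editor's illustration, not this commit's code):

# Assumes the hypothetical ngram_propose sketch defined earlier; windows 1-3
# and proposal length 5, matching the set_ngram_window_size(0, 3) setup above.
assert ngram_propose([11, 12, 13, 14, 15, 16, 11], 1, 3, 5) == [12, 13, 14, 15, 16]
assert ngram_propose([21, 21, 22, 23, 24, 25, 26, 21, 22], 1, 3, 5) == [23, 24, 25, 26, 21]
assert ngram_propose([31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33],
                     1, 3, 5) == [34, 35, 36, 37, 38]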