[Speculative decoding] Add ngram prompt lookup decoding (#4237)
Co-authored-by: Lei Wen <wenlei03@qiyi.com>
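For context, a minimal usage sketch, assuming vllm.LLM accepts the same engine kwargs that the new tests below pass (the model and prompt are illustrative, not part of this commit):

from vllm import LLM, SamplingParams

# "[ngram]" selects prompt-lookup proposals instead of a separate draft model.
llm = LLM(
    model="JackFram/llama-68m",  # tiny model, as in the tests below
    speculative_model="[ngram]",
    num_speculative_tokens=5,
    ngram_prompt_lookup_max=3,
    use_v2_block_manager=True,  # required for spec decode, per the tests
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)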
@@ -1,4 +1,5 @@
 import asyncio
+from itertools import cycle
 from typing import List, Optional, Tuple, Union
 
 import pytest
@@ -185,3 +186,60 @@ def get_output_from_llm_generator(
     del llm
 
     return tokens, token_ids
+
+
+def run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len,
+                                         force_output_len: bool,
+                                         print_tokens: bool = False):
+    """Helper method that compares the outputs of both the baseline LLM and
+    the test LLM. It asserts greedy equality, i.e. that the outputs are
+    exactly the same when temperature is zero.
+    """
+    temperature = 0.0
+
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+        "San Francisco is know for its",
+        "Facebook was created in 2004 by",
+        "Curious George is a",
+        "Python 3.11 brings improvements to its",
+    ]
+
+    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
+
+    # If the test requires that we generate max_output_len tokens, then set
+    # the sampling params to ignore the EOS token.
+    ignore_eos = force_output_len
+
+    sampling_params = SamplingParams(
+        max_tokens=max_output_len,
+        ignore_eos=ignore_eos,
+        temperature=temperature,
+    )
+
+    spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator(
+        test_llm_generator, prompts, sampling_params)
+
+    (baseline_batch_tokens,
+     baseline_batch_token_ids) = get_output_from_llm_generator(
+         baseline_llm_generator, prompts, sampling_params)
+
+    assert len(baseline_batch_token_ids) == len(prompts)
+    assert len(spec_batch_token_ids) == len(prompts)
+
+    for i, (baseline_token_ids, baseline_tokens, spec_token_ids,
+            spec_tokens) in enumerate(
+                zip(baseline_batch_token_ids, baseline_batch_tokens,
+                    spec_batch_token_ids, spec_batch_tokens)):
+        if print_tokens:
+            print(f'{i=} {baseline_tokens=}')
+            print(f'{i=} {spec_tokens=}')
+            print(f'{i=} {baseline_token_ids=}')
+            print(f'{i=} {spec_token_ids=}')
+        assert baseline_token_ids == spec_token_ids
@@ -35,7 +35,8 @@ from transformers import AutoTokenizer
 
 from vllm import SamplingParams
 
-from .conftest import get_output_from_llm_generator
+from .conftest import (get_output_from_llm_generator,
+                       run_greedy_equality_correctness_test)
 
 
 @pytest.mark.parametrize(
@@ -545,60 +546,3 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
                                          batch_size,
                                          max_output_len=output_len,
                                          force_output_len=True)
-
-
-def run_greedy_equality_correctness_test(baseline_llm_generator,
-                                         test_llm_generator,
-                                         batch_size,
-                                         max_output_len,
-                                         force_output_len: bool,
-                                         print_tokens: bool = False):
-    """Helper method that compares the outputs of both the baseline LLM and
-    the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
-    the same when temperature is zero.
-    """
-    temperature = 0.0
-
-    prompts = [
-        "Hello, my name is",
-        "The president of the United States is",
-        "The capital of France is",
-        "The future of AI is",
-        "San Francisco is know for its",
-        "Facebook was created in 2004 by",
-        "Curious George is a",
-        "Python 3.11 brings improvements to its",
-    ]
-
-    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
-
-    # If the test requires that we generated max_output_len tokens, then set the
-    # sampling params to ignore eos token.
-    ignore_eos = force_output_len
-
-    sampling_params = SamplingParams(
-        max_tokens=max_output_len,
-        ignore_eos=ignore_eos,
-        temperature=temperature,
-    )
-
-    spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator(
-        test_llm_generator, prompts, sampling_params)
-
-    (baseline_batch_tokens,
-     baseline_batch_token_ids) = get_output_from_llm_generator(
-         baseline_llm_generator, prompts, sampling_params)
-
-    assert len(baseline_batch_token_ids) == len(prompts)
-    assert len(spec_batch_token_ids) == len(prompts)
-
-    for i, (baseline_token_ids, baseline_tokens, spec_token_ids,
-            spec_tokens) in enumerate(
-                zip(baseline_batch_token_ids, baseline_batch_tokens,
-                    spec_batch_token_ids, spec_batch_tokens)):
-        if print_tokens:
-            print(f'{i=} {baseline_tokens=}')
-            print(f'{i=} {spec_tokens=}')
-            print(f'{i=} {baseline_token_ids=}')
-            print(f'{i=} {spec_token_ids=}')
-        assert baseline_token_ids == spec_token_ids
tests/spec_decode/e2e/test_ngram_correctness.py (new file, 172 lines)
@@ -0,0 +1,172 @@
"""This docstring details important information on the testing methodology.

Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal
non-speculative decoding.

Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.

The ngram lookup idea comes from
https://github.com/apoorvumang/prompt-lookup-decoding and was merged into the
transformers code base: https://github.com/huggingface/transformers/pull/27775.
Since no model is needed to generate the proposals, the test cases can be much
simpler than the multi-step drafter ones.

However, we still need to verify that the following scenarios pass:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Greedy equality under preemption
* Greedy equality under various ngram sizes / speculative sizes

With these tests we can say, at least, that ngram spec decode does not break
the correctness of the target model outputs.
"""

import pytest

from .conftest import run_greedy_equality_correctness_test


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Print spec metrics.
        "disable_log_stats": False,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
    {
        "model": "JackFram/llama-68m",
    },
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "[ngram]",
        "num_speculative_tokens": 5,
        "ngram_prompt_lookup_max": 3,
    },
])
@pytest.mark.parametrize("output_len", [
    256,
])
@pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness(baseline_llm_generator,
                                      test_llm_generator, batch_size: int,
                                      output_len: int):
    """Verify greedy equality on a tiny model with different batch sizes."""
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "block_size": 8,
        # 2 for small prompt, 256//8 for generated.
        "num_gpu_blocks_override": 2 + 256 // 8,
        "max_model_len": (2 + 256 // 8) * 8,

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
    {
        "model": "JackFram/llama-160m",
    },
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "[ngram]",
        "num_speculative_tokens": 5,
        "ngram_prompt_lookup_max": 3,
    },
])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use small output len for fast test.
        256,
    ])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
                                                      test_llm_generator,
                                                      batch_size: int,
                                                      output_len: int):
    """Verify greedy equality, even when some sequences are preempted mid-
    generation.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_model": "[ngram]",
            "num_speculative_tokens": k,
            "ngram_prompt_lookup_max": 3,
        }
        # Try a range of common k, as well as large speculation.
        for k in [1, 3, 5]
    ] + [
        {
            "speculative_model": "[ngram]",
            "num_speculative_tokens": k,
            "ngram_prompt_lookup_max": 1,
        }
        # Try a range of common k, as well as large speculation.
        for k in [1, 3, 5]
    ])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_ngram_different_k(baseline_llm_generator, test_llm_generator,
                           batch_size: int, output_len: int):
    """Verify that ngram speculative decoding produces exact equality
    to non-speculative decoding, with many different values of k and
    different ngram_prompt_lookup_max.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)
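Aside: to make the mechanism under test concrete, here is a rough, self-contained sketch of prompt-lookup proposing. It is not vLLM's NGramWorker implementation; the function name, the largest-window-first order, and the leftmost-match rule are assumptions that happen to reproduce the candidates expected by the unit tests in test_ngram_worker.py further below.

from typing import List, Optional


def ngram_propose(token_ids: List[int], ngram_max: int,
                  num_speculative_tokens: int) -> Optional[List[int]]:
    """Match the trailing ngram of the sequence against an earlier
    occurrence and propose the tokens that followed it, trying the
    largest window first."""
    for n in range(ngram_max, 0, -1):
        if len(token_ids) <= n:
            continue
        suffix = token_ids[-n:]
        # Scan left to right for an earlier occurrence of the suffix.
        for start in range(len(token_ids) - n):
            if token_ids[start:start + n] == suffix:
                # Propose the tokens that followed the earlier occurrence.
                return token_ids[start + n:start + n +
                                 num_speculative_tokens]
    return None  # No match: no speculation for this sequence.


# E.g. the trailing [31, 32, 33] recurs at index 2, so the five tokens
# that followed it are proposed:
assert ngram_propose([31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33],
                     ngram_max=3,
                     num_speculative_tokens=5) == [34, 35, 36, 37, 38]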
@@ -6,8 +6,8 @@ import torch
 
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import SamplerOutput
-from vllm.spec_decode.multi_step_worker import (DraftModelTop1Proposer,
-                                                MultiStepWorker)
+from vllm.spec_decode.multi_step_worker import MultiStepWorker
+from vllm.spec_decode.top1_proposer import Top1Proposer
 from vllm.worker.worker import Worker
 
 from .utils import (assert_logprobs_dict_allclose, create_batch,
@@ -117,8 +117,8 @@ def test_same_output_for_single_step():
 
     zero_kv_cache(multi_step_worker.cache_engine)
     set_random_seed(seed)
-    actual_output = multi_step_worker.execute_model_multi_step(
-        **multi_step_execute_model_data.to_dict(), num_steps=num_steps)
+    actual_output, _ = multi_step_worker.sampler_output(
+        **multi_step_execute_model_data.to_dict(), sample_len=num_steps)
     assert len(actual_output) == num_steps
     actual_output = actual_output[0]
 
@@ -200,8 +200,8 @@ def test_same_output_for_multi_step():
     # Run multi-step.
     zero_kv_cache(multi_step_worker.cache_engine)
     set_random_seed(seed)
-    multi_step_output = multi_step_worker.execute_model_multi_step(
-        **execute_model_data.to_dict(), num_steps=num_steps)
+    multi_step_output, _ = multi_step_worker.sampler_output(
+        **execute_model_data.to_dict(), sample_len=num_steps)
 
     # Run single-step repeatedly.
     zero_kv_cache(worker.cache_engine)
@@ -266,7 +266,7 @@ def test_same_output_for_multi_step():
 
 @torch.inference_mode()
 def test_draft_proposals_full_speculation_len():
-    """Verify DraftModelTop1Proposer correctly handles case where all sequences
+    """Verify Top1Proposer correctly handles case where all sequences
     can speculate.
     """
     k = 10
@@ -275,13 +275,13 @@ def test_draft_proposals_full_speculation_len():
     device = 'cuda:0'
 
     draft_worker = MagicMock()
-    proposer = DraftModelTop1Proposer(
-        draft_worker=draft_worker,
+    proposer = Top1Proposer(
+        worker=draft_worker,
         device=device,
-        max_model_len=2048,
         vocab_size=vocab_size,
+        max_proposal_len=2048,
     )
-    draft_worker.execute_model_multi_step.return_value = [
+    draft_worker.sampler_output.return_value = [
         SamplerOutput(
             outputs=[],
             sampled_token_probs=torch.rand(batch_size,
@@ -294,13 +294,13 @@ def test_draft_proposals_full_speculation_len():
                 device=device,
                 dtype=torch.long),
         ) for _ in range(k)
-    ]
+    ], True
 
     execute_model_data, _, _ = create_batch(batch_size, k)
 
     proposals = proposer.get_proposals(
         **execute_model_data.to_dict(),
-        max_proposal_len=k,
+        proposal_len=k,
     )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
@@ -315,7 +315,7 @@ def test_draft_proposals_full_speculation_len():
 
 @torch.inference_mode()
 def test_draft_proposals_no_speculations():
-    """Verify DraftModelTop1Proposer correctly handles case where no sequences
+    """Verify Top1Proposer correctly handles case where no sequences
     can speculate.
     """
     k = 10
@@ -325,11 +325,11 @@ def test_draft_proposals_no_speculations():
     prompt_len = 10
 
     draft_worker = MagicMock()
-    proposer = DraftModelTop1Proposer(
-        draft_worker=draft_worker,
+    proposer = Top1Proposer(
+        worker=draft_worker,
         device=device,
-        max_model_len=prompt_len + k - 1,
         vocab_size=vocab_size,
+        max_proposal_len=prompt_len + k - 1,
     )
 
     execute_model_data, _, _ = create_batch(batch_size,
@@ -338,7 +338,7 @@ def test_draft_proposals_no_speculations():
 
     proposals = proposer.get_proposals(
         **execute_model_data.to_dict(),
-        max_proposal_len=k,
+        proposal_len=k,
     )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
@@ -353,7 +353,7 @@ def test_draft_proposals_no_speculations():
 
 @torch.inference_mode()
 def test_draft_proposals_mixed_k():
-    """Verify DraftModelTop1Proposer correctly handles case some sequences can
+    """Verify Top1Proposer correctly handles case some sequences can
     speculate and some can't.
     """
     k = 10
@@ -374,14 +374,14 @@ def test_draft_proposals_mixed_k():
                    for _ in range(expected_num_no_proposal_seqs)] + [small_prompt_len]
 
     draft_worker = MagicMock()
-    proposer = DraftModelTop1Proposer(
-        draft_worker=draft_worker,
+    proposer = Top1Proposer(
+        worker=draft_worker,
         device=device,
-        max_model_len=long_prompt_len + prev_output_token_len + k - 1,
         vocab_size=vocab_size,
+        max_proposal_len=long_prompt_len + prev_output_token_len + k - 1,
     )
 
-    draft_worker.execute_model_multi_step.return_value = [
+    draft_worker.sampler_output.return_value = [
         SamplerOutput(
             outputs=[],
             sampled_token_probs=torch.rand(expected_num_proposal_seqs,
@@ -395,7 +395,7 @@ def test_draft_proposals_mixed_k():
                 device=device,
                 dtype=torch.long),
         ) for _ in range(k)
-    ]
+    ], True
 
     execute_model_data, _, _ = create_batch(
        batch_size,
@@ -406,7 +406,7 @@ def test_draft_proposals_mixed_k():
 
     proposals = proposer.get_proposals(
         **execute_model_data.to_dict(),
-        max_proposal_len=k,
+        proposal_len=k,
     )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
tests/spec_decode/test_ngram_worker.py (new file, 206 lines)
@@ -0,0 +1,206 @@
import torch

from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.top1_proposer import Top1Proposer

from .utils import (create_execute_model_data,
                    create_seq_group_metadata_from_prompts, create_worker)


def test_ngram_algo_correctness_for_single_no_match():
    """Verify that our ngram algorithm finds the right candidate in the
    prompt, for the scenario where no candidate can be found in a single
    batch.
    """
    block_size = 32
    num_gpu_blocks = 2048 // block_size
    seed = 100
    model_name = 'JackFram/llama-68m'
    vocab_size = 32_000
    device = 'cuda:0'

    ngram_worker = create_worker(
        NGramWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )

    proposer = Top1Proposer(
        worker=ngram_worker,
        device=device,
        vocab_size=vocab_size,
        max_proposal_len=20,
    )

    # Set the ngram window to (0, 3], i.e. window sizes 1/2/3.
    ngram_worker.set_ngram_window_size(0, 3)

    prompts = [
        # Shall find no candidate.
        [1, 2, 3, 4, 5, 6, 7],
    ]

    proposal_len = 5
    final_seq_lens = [len(prompt) + proposal_len for prompt in prompts]
    ngram_sampler_output_data = create_execute_model_data(
        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
            prompts, num_gpu_blocks, block_size,
            final_seq_lens=final_seq_lens))

    proposals = proposer.get_proposals(
        **ngram_sampler_output_data.to_dict(),
        proposal_len=proposal_len,
    )

    assert torch.is_tensor(proposals.proposal_token_ids)
    assert torch.is_tensor(proposals.proposal_probs)

    assert proposals.proposal_token_ids.shape == torch.Size([1, proposal_len])
    assert proposals.proposal_probs.shape[:-1] == torch.Size([1, proposal_len])
    assert proposals.proposal_lens.shape == torch.Size([1])
    assert proposals.proposal_lens.tolist() == [0]


def test_ngram_algo_correctness_for_batches_not_match_all():
    """Verify that our ngram algorithm finds the right candidate in the
    prompt, for the scenario where some, but not all, sequences in the
    batch find a candidate.
    """
    block_size = 32
    num_gpu_blocks = 2048 // block_size
    seed = 100
    model_name = 'JackFram/llama-68m'
    vocab_size = 32_000
    device = 'cuda:0'

    ngram_worker = create_worker(
        NGramWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )

    proposer = Top1Proposer(
        worker=ngram_worker,
        device=device,
        vocab_size=vocab_size,
        max_proposal_len=20,
    )

    # Set the ngram window to (0, 3], i.e. window sizes 1/2/3.
    ngram_worker.set_ngram_window_size(0, 3)

    prompts = [
        # Shall find no candidate.
        [1, 2, 3, 4, 5, 6, 7],
        # Shall find candidate 12,13,14,15,16.
        [11, 12, 13, 14, 15, 16, 11],
        # Shall find candidate 23,24,25,26,21.
        [21, 21, 22, 23, 24, 25, 26, 21, 22],
        # Shall find candidate 34,35,36,37,38.
        [31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33],
        # Shall find no candidate, as the sequence exceeds max_proposal_len.
        [
            31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 33, 34, 35, 36, 37,
            38, 31, 32, 33
        ],
    ]

    proposal_len = 5
    final_seq_lens = [len(prompt) + proposal_len for prompt in prompts]
    ngram_sampler_output_data = create_execute_model_data(
        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
            prompts, num_gpu_blocks, block_size,
            final_seq_lens=final_seq_lens))

    proposals = proposer.get_proposals(
        **ngram_sampler_output_data.to_dict(),
        proposal_len=proposal_len,
    )

    assert torch.is_tensor(proposals.proposal_token_ids)
    assert torch.is_tensor(proposals.proposal_probs)

    assert proposals.proposal_token_ids.shape == torch.Size([5, proposal_len])
    assert proposals.proposal_probs.shape[:-1] == torch.Size([5, proposal_len])
    assert proposals.proposal_lens.shape == torch.Size([5])

    assert proposals.proposal_lens.tolist(
    ) == [proposal_len for _ in range(4)] + [0]

    for i in range(proposal_len):
        assert proposals.proposal_token_ids[0][i] == 0
        assert proposals.proposal_token_ids[1][i] == prompts[1][i + 1]
        assert proposals.proposal_token_ids[2][i] == prompts[2][i + 3]
        assert proposals.proposal_token_ids[3][i] == prompts[3][i + 5]
        assert proposals.proposal_token_ids[4][i] == -1


def test_ngram_algo_correctness_for_batches_match_all():
    """Verify that our ngram algorithm finds the right candidate in the
    prompt, for the scenario where every sequence in the batch finds a
    candidate.
    """
    block_size = 32
    num_gpu_blocks = 2048 // block_size
    seed = 100
    model_name = 'JackFram/llama-68m'
    vocab_size = 32_000
    device = 'cuda:0'

    ngram_worker = create_worker(
        NGramWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )

    proposer = Top1Proposer(
        worker=ngram_worker,
        device=device,
        vocab_size=vocab_size,
        max_proposal_len=20,
    )

    # Set the ngram window to (0, 3], i.e. window sizes 1/2/3.
    ngram_worker.set_ngram_window_size(0, 3)

    prompts = [
        # Shall find candidate 12,13,14,15,16.
        [11, 12, 13, 14, 15, 16, 11],
        # Shall find candidate 23,24,25,26,21.
        [21, 21, 22, 23, 24, 25, 26, 21, 22],
        # Shall find candidate 34,35,36,37,38.
        [31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33],
    ]

    proposal_len = 5
    final_seq_lens = [len(prompt) + proposal_len for prompt in prompts]
    ngram_sampler_output_data = create_execute_model_data(
        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
            prompts, num_gpu_blocks, block_size,
            final_seq_lens=final_seq_lens))

    proposals = proposer.get_proposals(
        **ngram_sampler_output_data.to_dict(),
        proposal_len=proposal_len,
    )

    assert torch.is_tensor(proposals.proposal_token_ids)
    assert torch.is_tensor(proposals.proposal_probs)

    assert proposals.proposal_token_ids.shape == torch.Size([3, proposal_len])
    assert proposals.proposal_probs.shape[:-1] == torch.Size([3, proposal_len])
    assert proposals.proposal_lens.shape == torch.Size([3])

    assert proposals.proposal_lens.tolist() == [proposal_len for _ in range(3)]

    for i in range(proposal_len):
        assert proposals.proposal_token_ids[0][i] == prompts[0][i + 1]
        assert proposals.proposal_token_ids[1][i] == prompts[1][i + 3]
        assert proposals.proposal_token_ids[2][i] == prompts[2][i + 5]
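As a sanity check on the expected candidates above, the toy ngram_propose sketch given earlier (a hypothetical stand-in, not NGramWorker itself) reproduces the second prompt's expectation:

# After the trailing 3-gram [26, 21, 22] finds no earlier match, the
# trailing 2-gram [21, 22] matches starting at index 1, so the five tokens
# that followed it are proposed -- consistent with the prompts[1][i + 3]
# assertion in the test above.
seq = [21, 21, 22, 23, 24, 25, 26, 21, 22]
assert ngram_propose(seq, ngram_max=3,
                     num_speculative_tokens=5) == [23, 24, 25, 26, 21]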