[Feature] [Spec decode]: Combine chunked prefill with speculative decoding (#9291)
Signed-off-by: NickLucche <nlucches@redhat.com>
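For context, a minimal usage sketch (not part of this commit) of the two features enabled together. The engine kwargs mirror the ones exercised by the tests below; the exact constructor arguments are an assumption based on those tests, not taken from this diff.

# Hedged sketch: kwargs taken from the test parametrizations in this diff.
# The tiny max_num_batched_tokens / max_num_seqs values are only used in the
# tests to force mixed prefill/decode batches.
from vllm import LLM, SamplingParams

llm = LLM(
    model="JackFram/llama-68m",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=5,
    enable_chunked_prefill=True,  # now allowed together with speculative decoding
    max_num_batched_tokens=4,
    max_num_seqs=4,
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=128))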
@@ -5,40 +5,6 @@ from vllm import SamplingParams
from .conftest import get_output_from_llm_generator


@pytest.mark.parametrize("common_llm_kwargs", [{
    "model": "JackFram/llama-68m",
    "speculative_model": "JackFram/llama-68m",
    "num_speculative_tokens": 5,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
    {
        "enable_chunked_prefill": True,
    },
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_chunked_prefill(test_llm_generator):
    """Verify that speculative decoding with chunked prefill fails.
    """
    output_len = 128
    temperature = 0.0

    prompts = [
        "Hello, my name is",
    ]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
    )

    with pytest.raises(ValueError,
                       match="Speculative decoding and chunked prefill"):
        get_output_from_llm_generator(test_llm_generator, prompts,
                                      sampling_params)


@pytest.mark.parametrize("common_llm_kwargs", [{
    "model": "meta-llama/Llama-2-7b-chat-hf",
    "speculative_model": "JackFram/llama-68m",
@@ -62,6 +62,16 @@ from .conftest import (get_output_from_llm_generator,
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
        "enable_chunked_prefill": False,
    },
    {
        # Chunked prefill enabled with small value
        # to make sure we get mixed batches.
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
        "enable_chunked_prefill": True,
        "max_num_batched_tokens": 4,
        "max_num_seqs": 4
    },
    {
        # Verify the detokenizer assertions in the test work when spec
@@ -141,6 +151,14 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
        "enable_chunked_prefill": False,
    },
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
        "enable_chunked_prefill": True,
        "max_num_batched_tokens": 4,
        "max_num_seqs": 4,
    },
])
@pytest.mark.parametrize(
@@ -204,6 +222,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
        "enable_chunked_prefill": False,
    },
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
        "enable_chunked_prefill": True,
        "max_num_batched_tokens": 4,
        "max_num_seqs": 4
    },
])
@pytest.mark.parametrize(
@@ -255,6 +281,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
        "enable_chunked_prefill": False,
    },
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
        "enable_chunked_prefill": True,
        "max_num_batched_tokens": 4,
        "max_num_seqs": 4
    },
])
@pytest.mark.parametrize("max_output_len", [
@@ -300,6 +334,14 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
        "enable_chunked_prefill": False,
    },
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
        "enable_chunked_prefill": True,
        "max_num_batched_tokens": 4,
        "max_num_seqs": 4
    },
])
@pytest.mark.parametrize("batch_size", [1])
@@ -347,6 +389,14 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
        "enable_chunked_prefill": False,
    },
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
        "enable_chunked_prefill": True,
        "max_num_batched_tokens": 4,
        "max_num_seqs": 4
    },
])
@pytest.mark.parametrize("batch_size", [32])
@@ -397,6 +447,14 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
        "enable_chunked_prefill": False,
    },
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
        "enable_chunked_prefill": True,
        "max_num_batched_tokens": 4,
        "max_num_seqs": 4
    },
])
@pytest.mark.parametrize(
@@ -454,6 +512,14 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
        "enable_chunked_prefill": False,
    },
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
        "enable_chunked_prefill": True,
        "max_num_batched_tokens": 4,
        "max_num_seqs": 4
    },
])
@pytest.mark.parametrize("batch_size", [2])
@@ -503,6 +569,15 @@ def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs,
        # Artificially limit the draft model max model len; this forces vLLM
        # to skip speculation once the sequences grow beyond 32-k tokens.
        "speculative_max_model_len": 32,
        "enable_chunked_prefill": False,
    },
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
        "enable_chunked_prefill": True,
        "max_num_batched_tokens": 4,
        "max_num_seqs": 4,
        "speculative_max_model_len": 32,
    },
])
@pytest.mark.parametrize("batch_size", [8])
@@ -551,6 +626,15 @@ def test_skip_speculation(vllm_runner, common_llm_kwargs,
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
        "speculative_disable_by_batch_size": 2,
        "enable_chunked_prefill": False,
    },
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
        "speculative_disable_by_batch_size": 2,
        "enable_chunked_prefill": True,
        "max_num_batched_tokens": 4,
        "max_num_seqs": 4,
    },
])
@pytest.mark.parametrize("batch_size", [8])
@@ -590,10 +674,17 @@ def test_disable_speculation(vllm_runner, common_llm_kwargs,
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": k,
        "enable_chunked_prefill": False,
    }
    # Try a range of common k, as well as large speculation.
    for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
])
] + [{
    "speculative_model": "JackFram/llama-68m",
    "num_speculative_tokens": k,
    "enable_chunked_prefill": True,
    "max_num_batched_tokens": 4,
    "max_num_seqs": 4,
} for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
    "output_len",
@@ -636,11 +727,19 @@ def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": k,
        "spec_decoding_acceptance_method": "typical_acceptance_sampler"
        "spec_decoding_acceptance_method": "typical_acceptance_sampler",
        "enable_chunked_prefill": False
    }
    # Try a range of common k.
    for k in [1, 2, 3]
])
] + [{
    "speculative_model": "JackFram/llama-68m",
    "num_speculative_tokens": k,
    "spec_decoding_acceptance_method": "typical_acceptance_sampler",
    "enable_chunked_prefill": True,
    "max_num_batched_tokens": 4,
    "max_num_seqs": 4
} for k in [1, 2, 3]])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize(
    "output_len",
@@ -50,18 +50,33 @@ from .conftest import run_equality_correctness_test
        "num_speculative_tokens": 5,
        "ngram_prompt_lookup_max": 3,
    },
    {
        "speculative_model": "[ngram]",
        "num_speculative_tokens": 5,
        "ngram_prompt_lookup_max": 3,
    },
])
@pytest.mark.parametrize("output_len", [
    256,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
@pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
                                      per_test_common_llm_kwargs,
                                      baseline_llm_kwargs, test_llm_kwargs,
                                      batch_size: int, output_len: int,
                                      seed: int):
                                      prefill_chunk_size: int, seed: int):
    """Verify greedy equality on a tiny model with different batch size."""
    if prefill_chunk_size > 0:
        common_llm_kwargs.update(
            **{
                "enable_chunked_prefill": True,
                "max_num_batched_tokens": prefill_chunk_size,
                "max_num_seqs": prefill_chunk_size
            })
    else:
        common_llm_kwargs["enable_chunked_prefill"] = False
    run_equality_correctness_test(vllm_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
@@ -151,6 +166,16 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
        "speculative_model": "[ngram]",
        "num_speculative_tokens": 5,
        "ngram_prompt_lookup_max": 3,
        "enable_chunked_prefill": False,
    },
    {
        "speculative_model": "[ngram]",
        "num_speculative_tokens": 5,
        "ngram_prompt_lookup_max": 3,
        "enable_chunked_prefill": True,
        "speculative_disable_mqa_scorer": True,
        "max_num_batched_tokens": 4,
        "max_num_seqs": 4
    },
])
@pytest.mark.parametrize(
@@ -251,6 +276,15 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
    "num_speculative_tokens": 5,
    "ngram_prompt_lookup_max": 3,
    "speculative_disable_by_batch_size": 4
}, {
    "speculative_model": "[ngram]",
    "num_speculative_tokens": 5,
    "ngram_prompt_lookup_max": 3,
    "speculative_disable_by_batch_size": 4,
    "enable_chunked_prefill": True,
    "speculative_disable_mqa_scorer": True,
    "max_num_batched_tokens": 4,
    "max_num_seqs": 4
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
@@ -118,7 +118,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
        num_gpu_blocks,
        block_size,
        final_prompt_lens=final_prompt_lens)

    for sg in seq_group_metadata_list:
        sg.is_prompt = False
    proposals = proposer.get_spec_proposals(
        execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
@@ -147,7 +148,7 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
def test_ngram_algo_correctness_for_batches_match_all():
    """Verify our ngram algo find the right candidate in the prompt

    For the scenario find candidate in all batchs
    For the scenario find candidate in all batches
    """

    block_size = 32
@@ -192,6 +193,10 @@ def test_ngram_algo_correctness_for_batches_match_all():
        block_size,
        final_prompt_lens=final_prompt_lens)

    # Normally drafter is run on decode requests only; here we check the output
    # of the ngram worker as it is the sole proposer that has no forward.
    for sg in seq_group_metadata_list:
        sg.is_prompt = False
    proposals = proposer.get_spec_proposals(
        execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
@@ -46,12 +46,14 @@ def assert_score_equal(score1: SpeculativeScores,
@pytest.mark.parametrize('max_propose_len', [1, 3, 5])
@pytest.mark.parametrize('mixed_propose_len', [True])
@pytest.mark.parametrize('device', ['cuda'])
@pytest.mark.parametrize('prefill_chunking', [False, True])
def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
                mixed_propose_len: bool, device: str) -> None:
                mixed_propose_len: bool, device: str,
                prefill_chunking: bool) -> None:
    """
    Compare the batch expansion scorer and mqa scorer return the same score.
    We test for both queries with the same propose length and different
    propose length.
    propose length, as well as mixed prefill-decode batches.
    """
    seed = 0
    block_size = 32
@@ -67,16 +69,37 @@ def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
    if not mixed_propose_len:
        propose_lens = [max_propose_len] * batch_size
    else:
        non_zero_cnt = random.randint(0, batch_size)
        # There must be at least 1 decode request, otherwise
        # we have nothing to score (`_run_no_spec`).
        non_zero_cnt = random.randint(1, batch_size)
        propose_lens = [max_propose_len
                        ] * non_zero_cnt + [0] * (batch_size - non_zero_cnt)
        random.shuffle(propose_lens)

    proposals = create_proposal(propose_lens, vocab_size, device)
    seq_group_metadatalist, _, _ = create_batch(batch_size,
                                                max_propose_len,
                                                block_size=block_size,
                                                num_gpu_blocks=num_gpu_blocks)

    if mixed_propose_len and prefill_chunking and (n_prefills :=
                                                   batch_size - non_zero_cnt):
        prefill, _, _ = create_batch(n_prefills,
                                     None,
                                     prefill_chunk_size=4,
                                     block_size=block_size,
                                     num_gpu_blocks=num_gpu_blocks,
                                     seq_ids=list(
                                         range(batch_size,
                                               batch_size + n_prefills)))
        # re-order to guarantee prefill|decode order
        target_group_metadatalist = [
            seq_group_metadatalist[i] for i, p in enumerate(propose_lens)
            if p > 0
        ]
        seq_group_metadatalist = prefill + target_group_metadatalist
        propose_lens = [0] * n_prefills + [p for p in propose_lens if p > 0]

    proposals = create_proposal(propose_lens, vocab_size, device)
    requests = ExecuteModelRequest(seq_group_metadatalist,
                                   num_lookahead_slots=max_propose_len)
@@ -10,6 +10,7 @@ import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, SequenceOutput
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
                                      SpecDecodeWorkerMetrics)
@@ -819,3 +820,84 @@ def test_handle_finished_requests():
    # and 'request-3' are removed from seq_with_bonus_token_in_last_step.
    assert worker._seq_with_bonus_token_in_last_step == \
        {4,5,10}


@pytest.mark.parametrize('k', [3])
@pytest.mark.parametrize('batch_size', [2, 32])
@pytest.mark.parametrize("batch_composition",
                         ["prefill_only", "decode_only", "mixed"])
@torch.inference_mode()
def test_chunked_prefill_flow(k: int, batch_size: int, batch_composition: str):
    """
    Verify SpecDecodeWorker calls match the expected flow.
    """
    vocab_size = 32_000
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    worker = SpecDecodeWorker(draft_worker,
                              target_worker,
                              mock_spec_decode_sampler("rejection_sampler"),
                              disable_logprobs=False,
                              metrics_collector=metrics_collector)
    exception_secret = 'artificial stop'
    worker.scorer = mock_worker(BatchExpansionTop1Scorer)
    worker.scorer.score_proposals.side_effect = ValueError(exception_secret)

    # Create batch with combination of terminal/non-terminal prefill chunks
    # and decodes (different seq_ids).
    decodes, _, _ = create_batch(batch_size, k)
    # Pre-chunking here, get 'batch_size' chunks.
    prefill, _, _ = create_batch(batch_size,
                                 k,
                                 prefill_chunk_size=4,
                                 seq_ids=list(range(batch_size,
                                                    batch_size * 2)))

    if batch_composition == "prefill_only":
        n_prefills = batch_size
    elif batch_composition == "decode_only":
        n_prefills = 0
    else:
        n_prefills = random.randint(1, batch_size - 1)
    n_decodes = batch_size - n_prefills

    prefill = random.sample(prefill, n_prefills)
    decodes = random.sample(decodes, n_decodes)
    target_group_metadata_list = prefill + decodes
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=target_group_metadata_list,
        num_lookahead_slots=k)

    target_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(1, batch_size * (k + 1)),
                                     dtype=torch.int64,
                                     device='cuda')
    target_token_probs = torch.rand(1,
                                    batch_size * (k + 1),
                                    vocab_size,
                                    dtype=torch.float32,
                                    device='cuda')
    target_token_logprobs = torch.rand(1,
                                       batch_size * (k + 1),
                                       vocab_size,
                                       dtype=torch.float32,
                                       device='cuda')
    target_output = create_sampler_output_list(target_token_ids,
                                               target_token_probs,
                                               target_token_logprobs)

    target_worker.execute_model.return_value = [target_output[0]]

    if not len(decodes):
        worker.execute_model(execute_model_req=execute_model_req)
        # no spec run (prefill only)
        draft_worker.execute_model.assert_called_once_with(execute_model_req)
        target_worker.execute_model.assert_called_once_with(execute_model_req)
    else:
        # Decode-only run OR mixed batch, scorer call fails (it's mocked)
        with pytest.raises(ValueError, match=exception_secret):
            worker.execute_model(execute_model_req=execute_model_req)
        # but first draft still counted
        assert draft_worker.get_spec_proposals.call_count == 1
@@ -146,6 +146,41 @@ def create_seq_group_metadata_from_prompts(
    return seq_grou_metadata_list


def create_chunked_seq_group_metadata_from_prompt(
        prompt: List[int],
        num_gpu_blocks: int,
        chunk_size: int,
        block_size: int,
        seq_id: Optional[int] = None) -> List[SequenceGroupMetadata]:

    if seq_id is None:
        seq_id = 0

    free_gpu_blocks = list(range(num_gpu_blocks))

    block_allocations = [
        free_gpu_blocks.pop()
        for _ in range(round_up_to_next_block(len(prompt), block_size))
    ]

    seq_group_metadata_list = []
    for i, idx in enumerate(range(0, len(prompt), chunk_size)):
        chunk_ids = prompt[idx:idx + chunk_size]
        data = SequenceData.from_seqs(prompt)
        data.update_num_computed_tokens(idx)
        seq_data = {i: data}
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=str(seq_id),
                is_prompt=True,
                do_sample=idx + chunk_size >= len(prompt),  # terminal chunk
                seq_data=seq_data,
                sampling_params=SamplingParams(temperature=0.0),
                block_tables={i: block_allocations},
                token_chunk_size=len(chunk_ids)))
    return seq_group_metadata_list


def assert_logprobs_dict_allclose(
        actual_logprobs: List[Dict[int, Logprob]],
        expected_logprobs: List[Dict[int, Logprob]]) -> None:
@@ -198,7 +233,8 @@ def create_batch(batch_size,
                 prev_output_token_len: int = 10,
                 seq_ids: Optional[List[int]] = None,
                 num_gpu_blocks: Optional[int] = None,
                 block_size: Optional[int] = None):
                 block_size: Optional[int] = None,
                 prefill_chunk_size: Optional[int] = None):
    if block_size is None:
        block_size = 8
@@ -213,15 +249,28 @@ def create_batch(batch_size,
        prompt_lens = prompt_len

    prompts = [[next(iterator) for _ in range(p_len)] for p_len in prompt_lens]
    prev_output_tokens = [[
        next(iterator) for _ in range(prev_output_token_len)
    ] for _ in range(batch_size)]
    final_prompt_lens = [
        len(prompt) + len(prev_output_token) + k + 1
        for prompt, prev_output_token in zip(prompts, prev_output_tokens)
    ]

    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
        prompts, num_gpu_blocks, block_size, final_prompt_lens,
        prev_output_tokens, seq_ids)
    if prefill_chunk_size:
        # Create a batch of chunked prompts.
        if not seq_ids:
            seq_ids = list(range(len(prompts)))
        seq_group_metadata_list = []
        for p, sid in zip(prompts, seq_ids):
            seq_group_metadata_list += \
                create_chunked_seq_group_metadata_from_prompt(
                    p, num_gpu_blocks, prefill_chunk_size, block_size, sid)
        seq_group_metadata_list = seq_group_metadata_list[:batch_size]
        prev_output_tokens = []
    else:
        prev_output_tokens = [[
            next(iterator) for _ in range(prev_output_token_len)
        ] for _ in range(batch_size)]
        final_prompt_lens = [
            len(prompt) + len(prev_output_token) + k + 1
            for prompt, prev_output_token in zip(prompts, prev_output_tokens)
        ]

        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
            prompts, num_gpu_blocks, block_size, final_prompt_lens,
            prev_output_tokens, seq_ids)
    return seq_group_metadata_list, prompts, prev_output_tokens
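As a closing reference, a condensed, hypothetical sketch of how the tests above drive the new prefill_chunk_size path of create_batch; the values and seq_ids handling mirror test_chunked_prefill_flow, and the import path is an assumption.

# Assumed import path for the helper modified in the hunk above.
from tests.spec_decode.utils import create_batch

batch_size, k = 4, 3

# Plain decode sequences, unchanged behaviour.
decodes, _, _ = create_batch(batch_size, k)

# Prompt-only sequences pre-chunked into 4-token chunks; distinct seq_ids keep
# them from clashing with the decode sequences.
prefill, _, _ = create_batch(batch_size,
                             k,
                             prefill_chunk_size=4,
                             seq_ids=list(range(batch_size, batch_size * 2)))

# A mixed batch is then simply prefill chunks followed by decodes.
mixed_batch = prefill + decodes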