Prefix Cache Aware Scheduling [1/n] (#10128)
Signed-off-by: rickyx <rickyx@anyscale.com>
This commit is contained in:
@@ -12,9 +12,9 @@ from vllm.core.scheduler import Scheduler, SchedulingBudget
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import SequenceGroup
|
||||
|
||||
from .utils import (append_new_token, append_new_token_seq_group,
|
||||
create_dummy_prompt, get_sequence_groups,
|
||||
schedule_and_update_computed_tokens)
|
||||
from .utils import (append_new_token, append_new_token_seq,
|
||||
append_new_token_seq_group, create_dummy_prompt,
|
||||
get_sequence_groups, schedule_and_update_computed_tokens)
|
||||
|
||||
|
||||
def test_scheduler_add_seq_group():
|
||||
@@ -305,6 +305,8 @@ def initialize_scheduler(
|
||||
block_size=4,
|
||||
num_cpu_blocks=8,
|
||||
num_gpu_blocks=8,
|
||||
enable_prefix_caching=False,
|
||||
enable_chunked_prefill=False,
|
||||
):
|
||||
block_size = block_size
|
||||
scheduler_config = SchedulerConfig(
|
||||
@@ -312,8 +314,15 @@ def initialize_scheduler(
|
||||
max_num_batched_tokens=max_token_budget,
|
||||
max_num_seqs=max_num_seqs,
|
||||
max_model_len=max_model_len,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
)
|
||||
cache_config = CacheConfig(
|
||||
block_size,
|
||||
1.0,
|
||||
1,
|
||||
"auto",
|
||||
enable_prefix_caching=enable_prefix_caching,
|
||||
)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = num_cpu_blocks
|
||||
cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
scheduler = Scheduler(scheduler_config, cache_config, lora_config)
|
||||
@@ -800,3 +809,165 @@ def test_scheduling_budget():
|
||||
assert budget.num_curr_seqs == 0
|
||||
budget.subtract_num_seqs(seq_group.request_id, 2)
|
||||
assert budget.num_curr_seqs == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("enable_prefix_caching", [True, False])
def test_prefix_caching_aware_prefills(enable_prefix_caching):
    """
    Check that prefill scheduling accounts for prefix-cached tokens.

    Three sequences (seqA, seqB, seqC) share their first block as a prefix.

    Scenario:
    1. seqA is prefilled on its own, then runs one decode step.
    2. seqB and seqC are added together. Each prompt is 8 tokens, so both
       together (16 tokens) exceed the 12-token budget. With prefix caching
       enabled the test expects both to be scheduled in the same round, each
       reporting one already-computed block; with it disabled the test
       expects only one of them to be scheduled, with no computed blocks.
    """
    block_size = 4
    max_num_batched_tokens = 12
    max_seq_group = 3
    scheduler = initialize_scheduler(
        block_size=block_size,
        num_cpu_blocks=16,
        num_gpu_blocks=16,
        max_token_budget=max_num_batched_tokens,
        max_num_seqs=max_seq_group,
        max_model_len=max_num_batched_tokens,
        enable_prefix_caching=enable_prefix_caching,
    )

    seqA_tokens = list(range(8))
    num_shared_tokens = 4
    # seqB and seqC reuse the first `num_shared_tokens` tokens of seqA and
    # then diverge with their own unique suffixes.
    shared_prefix = seqA_tokens[:num_shared_tokens]
    seqB_tokens = shared_prefix + list(range(12, 16))
    seqC_tokens = shared_prefix + list(range(16, 20))

    # Only the sequence groups are needed here; the raw sequences are unused.
    _, seqA_group = create_dummy_prompt("0",
                                        prompt_tokens=seqA_tokens,
                                        block_size=block_size)
    _, seqB_group = create_dummy_prompt("1",
                                        prompt_tokens=seqB_tokens,
                                        block_size=block_size)
    _, seqC_group = create_dummy_prompt("2",
                                        prompt_tokens=seqC_tokens,
                                        block_size=block_size)

    # Schedule seqA prefill.
    scheduler.add_seq_group(seqA_group)
    _, out, _ = scheduler.schedule()
    # Separate asserts so a failure points at the exact condition.
    assert len(out.scheduled_seq_groups) == 1
    assert out.scheduled_seq_groups[0].seq_group == seqA_group
    assert out.scheduled_seq_groups[0].token_chunk_size == len(seqA_tokens)

    # Schedule seqA decode.
    append_new_token_seq_group(len(seqA_tokens), seqA_group, 999)
    _, out, _ = scheduler.schedule()

    assert len(out.scheduled_seq_groups) == 1
    assert out.scheduled_seq_groups[0].seq_group == seqA_group
    assert out.scheduled_seq_groups[0].token_chunk_size == 1

    # Scheduling seqB and seqC prefills together should work with prefix
    # caching.
    scheduler.add_seq_group(seqB_group)
    scheduler.add_seq_group(seqC_group)
    metas, out, _ = scheduler.schedule()

    if enable_prefix_caching:
        assert len(out.scheduled_seq_groups) == 2
        scheduled_groups = {
            out.scheduled_seq_groups[0].seq_group,
            out.scheduled_seq_groups[1].seq_group,
        }
        assert scheduled_groups == {seqB_group, seqC_group}
        assert len(metas) == 2
        for meta in metas:
            assert meta.token_chunk_size == 8
            # One block (the shared 4-token prefix) is already computed.
            assert (len(meta.computed_block_nums) == num_shared_tokens //
                    block_size)
    else:
        assert len(out.scheduled_seq_groups) == 1
        assert len(metas) == 1
        assert metas[0].token_chunk_size == 8
        assert len(metas[0].computed_block_nums) == 0  # No blocks computed.
|
||||
|
||||
|
||||
def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching(
):
    """
    Verify that no new prefills are scheduled while a chunked prefill is
    still in progress, even when new prefills sharing a prefix would fit in
    the token budget:

    - seqA is being chunk-prefilled.
    - seqB, which has the same prompt, shouldn't be scheduled for prefill
      even though there's enough token budget to prefill the cached tokens.
    - Neither should seqC be scheduled.

    - Once seqA is in the decoding phase, seqB and seqC can be scheduled.
        - All of seqB should be prefilled since it's a full prefix cache hit.
        - seqC is partially prefilled: the shared prefix plus as many of the
          remaining unique tokens as fit (rounded down to be block-size
          aligned).
    """
    block_size = 2
    max_num_batched_tokens = 4
    max_seq_group = 3
    scheduler = initialize_scheduler(
        block_size=block_size,
        num_cpu_blocks=16,
        num_gpu_blocks=16,
        max_token_budget=max_num_batched_tokens,
        max_num_seqs=max_seq_group,
        max_model_len=100,
        enable_prefix_caching=True,
        enable_chunked_prefill=True,
    )

    seqA_tokens = list(range(8))
    seqB_tokens = seqA_tokens
    seqC_shared_prefix_len = 4
    seqC_tokens = seqA_tokens[:seqC_shared_prefix_len] + list(range(12, 20))

    seqA, seqA_group = create_dummy_prompt("0",
                                           prompt_tokens=seqA_tokens,
                                           block_size=block_size)
    # Only the group is needed for seqB; the raw sequence is unused.
    _, seqB_group = create_dummy_prompt("1",
                                        prompt_tokens=seqB_tokens,
                                        block_size=block_size)

    # Chunked prefill seqA (first chunk).
    scheduler.add_seq_group(seqA_group)
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    assert out.scheduled_seq_groups[0].seq_group == seqA_group
    assert out.scheduled_seq_groups[0].token_chunk_size == 4

    # seqB should not be scheduled while seqA's prefill is still ongoing.
    scheduler.add_seq_group(seqB_group)
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    assert out.scheduled_seq_groups[0].seq_group == seqA_group
    assert out.scheduled_seq_groups[0].token_chunk_size == 4

    # Both seqB and seqC can be scheduled now that seqA's prefill is over
    # and seqA is in the decoding phase.
    append_new_token_seq(seqA, 999)
    _, seqC_group = create_dummy_prompt("2",
                                        prompt_tokens=seqC_tokens,
                                        block_size=block_size)
    scheduler.add_seq_group(seqC_group)
    metas, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 3

    metas_by_request = {meta.request_id: meta for meta in metas}
    assert metas_by_request[
        seqA_group.request_id].token_chunk_size == 1  # Decode
    assert (metas_by_request[seqB_group.request_id].token_chunk_size == 8
            )  # Fully cached prefill
    # The message is parenthesized so all three literals are concatenated
    # into the assertion message; without the parentheses only the first
    # literal is the message and the rest are no-op expression statements.
    assert metas_by_request[seqC_group.request_id].token_chunk_size == 6, (
        "A partial prefix of C (4 tokens) should be prefilled, with the "
        "remaining tokens fit into 3 token budget (4-1 from the seqA). It will "
        "then be rounded down to 2 tokens on block size, thus 6 tokens in total."
    )
|
||||
|
||||
Reference in New Issue
Block a user