[2/N] Chunked prefill data update (#3538)
This commit is contained in:
@@ -10,6 +10,10 @@ from vllm.sequence import Logprob, SequenceGroup
|
||||
from .utils import create_dummy_prompt
|
||||
|
||||
|
||||
def get_sequence_groups(scheduler_output):
    """Collect the ``seq_group`` of every scheduled entry, in order.

    Iterates ``scheduler_output.scheduled_seq_groups`` and returns a new
    list holding each entry's ``seq_group`` attribute, so tests can compare
    scheduler output directly against the sequence groups they created.
    """
    seq_groups = []
    for scheduled in scheduler_output.scheduled_seq_groups:
        seq_groups.append(scheduled.seq_group)
    return seq_groups
|
||||
|
||||
|
||||
def test_scheduler_add_seq_group():
|
||||
block_size = 4
|
||||
scheduler_config = SchedulerConfig(100, 64, 1)
|
||||
@@ -57,9 +61,9 @@ def test_scheduler_schedule_simple():
|
||||
cache_config.num_cpu_blocks = 8
|
||||
cache_config.num_gpu_blocks = 8
|
||||
scheduler = Scheduler(scheduler_config, cache_config, None)
|
||||
running: List[SequenceGroup] = []
|
||||
|
||||
# Add seq groups to scheduler.
|
||||
running: List[SequenceGroup] = []
|
||||
for i in range(num_seq_group):
|
||||
_, seq_group = create_dummy_prompt(str(i), prompt_length=block_size)
|
||||
scheduler.add_seq_group(seq_group)
|
||||
@@ -68,7 +72,7 @@ def test_scheduler_schedule_simple():
|
||||
# Schedule seq groups prompts.
|
||||
num_tokens = block_size * num_seq_group
|
||||
seq_group_meta, out = scheduler.schedule()
|
||||
assert set(out.scheduled_seq_groups) == set(running)
|
||||
assert set(get_sequence_groups(out)) == set(running)
|
||||
assert out.num_batched_tokens == num_tokens
|
||||
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
|
||||
and not out.blocks_to_swap_out)
|
||||
@@ -76,7 +80,7 @@ def test_scheduler_schedule_simple():
|
||||
|
||||
# Schedule seq groups generation.
|
||||
seq_group_meta, out = scheduler.schedule()
|
||||
assert set(out.scheduled_seq_groups) == set(running)
|
||||
assert set(get_sequence_groups(out)) == set(running)
|
||||
assert out.num_batched_tokens == num_seq_group
|
||||
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
|
||||
and not out.blocks_to_swap_out)
|
||||
@@ -100,7 +104,7 @@ def test_scheduler_schedule_preempt_abort():
|
||||
|
||||
# Schedule seq groups prompts.
|
||||
seq_group_meta, out = scheduler.schedule()
|
||||
assert out.scheduled_seq_groups == [seq_group_a, seq_group_b]
|
||||
assert get_sequence_groups(out) == [seq_group_a, seq_group_b]
|
||||
assert out.num_batched_tokens == block_size * 2 # seq_a and seq_b
|
||||
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
|
||||
and not out.blocks_to_swap_out)
|
||||
@@ -115,7 +119,7 @@ def test_scheduler_schedule_preempt_abort():
|
||||
|
||||
# Schedule seq groups generation and preempt seq group b.
|
||||
seq_group_meta, out = scheduler.schedule()
|
||||
assert out.scheduled_seq_groups == [seq_group_a]
|
||||
assert get_sequence_groups(out) == [seq_group_a]
|
||||
assert out.num_batched_tokens == 1
|
||||
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
|
||||
and not out.blocks_to_swap_out)
|
||||
@@ -125,7 +129,7 @@ def test_scheduler_schedule_preempt_abort():
|
||||
# Abort seq group a. Re-schedule seq group b prompt with recomputation.
|
||||
scheduler.abort_seq_group("1")
|
||||
seq_group_meta, out = scheduler.schedule()
|
||||
assert out.scheduled_seq_groups == [seq_group_b]
|
||||
assert get_sequence_groups(out) == [seq_group_b]
|
||||
assert out.num_batched_tokens == 5 # 4 prompt + 1 generation.
|
||||
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
|
||||
and not out.blocks_to_swap_out)
|
||||
@@ -155,11 +159,11 @@ def test_scheduler_max_seqs():
|
||||
|
||||
# Schedule seq groups prompts.
|
||||
_, out = scheduler.schedule()
|
||||
assert set(out.scheduled_seq_groups) == set([all_seq_groups[0]])
|
||||
assert set(get_sequence_groups(out)) == set([all_seq_groups[0]])
|
||||
|
||||
# Schedule seq groups generation.
|
||||
_, out = scheduler.schedule()
|
||||
assert set(out.scheduled_seq_groups) == set([all_seq_groups[0]])
|
||||
assert set(get_sequence_groups(out)) == set([all_seq_groups[0]])
|
||||
|
||||
# Append 2 more seq group
|
||||
scheduler.add_seq_group(all_seq_groups[1])
|
||||
@@ -169,7 +173,7 @@ def test_scheduler_max_seqs():
|
||||
# Only 1 seq group should be scheduled since max_seq_group is 2
|
||||
# and one is prompting.
|
||||
_, out = scheduler.schedule()
|
||||
assert set(out.scheduled_seq_groups) == set([all_seq_groups[1]])
|
||||
assert set(get_sequence_groups(out)) == set([all_seq_groups[1]])
|
||||
|
||||
|
||||
def test_scheduler_delay_factor():
|
||||
|
||||
Reference in New Issue
Block a user