[1/n][Chunked Prefill] Refactor input query shapes (#3236)
This commit is contained in:
@@ -10,7 +10,7 @@ from .utils import create_dummy_prompt
|
||||
|
||||
def test_scheduler_add_seq_group():
|
||||
block_size = 4
|
||||
scheduler_config = SchedulerConfig(100, 64, 1, 256)
|
||||
scheduler_config = SchedulerConfig(100, 64, 1)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = 4
|
||||
cache_config.num_gpu_blocks = 4
|
||||
@@ -26,7 +26,7 @@ def test_scheduler_add_seq_group():
|
||||
|
||||
def test_scheduler_abort_seq_group():
|
||||
block_size = 4
|
||||
scheduler_config = SchedulerConfig(100, 64, 1, 256)
|
||||
scheduler_config = SchedulerConfig(100, 64, 1)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = 4
|
||||
cache_config.num_gpu_blocks = 4
|
||||
@@ -50,7 +50,7 @@ def test_scheduler_schedule_simple():
|
||||
block_size = 4
|
||||
num_seq_group = 4
|
||||
max_model_len = 16
|
||||
scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len, 256)
|
||||
scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = 8
|
||||
cache_config.num_gpu_blocks = 8
|
||||
@@ -64,10 +64,10 @@ def test_scheduler_schedule_simple():
|
||||
running.append(seq_group)
|
||||
|
||||
# Schedule seq groups prompts.
|
||||
num_tokens = block_size * num_seq_group
|
||||
seq_group_meta, out = scheduler.schedule()
|
||||
assert set(out.scheduled_seq_groups) == set(running)
|
||||
assert out.num_batched_tokens == num_seq_group * seq_group.get_seqs(
|
||||
)[0].get_len()
|
||||
assert out.num_batched_tokens == num_tokens
|
||||
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
|
||||
and not out.blocks_to_swap_out)
|
||||
assert len(seq_group_meta) == num_seq_group
|
||||
@@ -84,7 +84,7 @@ def test_scheduler_schedule_simple():
|
||||
def test_scheduler_schedule_preempt_abort():
|
||||
block_size = 4
|
||||
max_model_len = 16
|
||||
scheduler_config = SchedulerConfig(64, 2, max_model_len, 256)
|
||||
scheduler_config = SchedulerConfig(64, 2, max_model_len)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = 2
|
||||
cache_config.num_gpu_blocks = 2
|
||||
@@ -99,7 +99,7 @@ def test_scheduler_schedule_preempt_abort():
|
||||
# Schedule seq groups prompts.
|
||||
seq_group_meta, out = scheduler.schedule()
|
||||
assert out.scheduled_seq_groups == [seq_group_a, seq_group_b]
|
||||
assert out.num_batched_tokens == seq_group_a.get_seqs()[0].get_len() * 2
|
||||
assert out.num_batched_tokens == block_size * 2 # seq_a and seq_b
|
||||
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
|
||||
and not out.blocks_to_swap_out)
|
||||
assert len(seq_group_meta) == 2
|
||||
@@ -124,7 +124,7 @@ def test_scheduler_schedule_preempt_abort():
|
||||
scheduler.abort_seq_group("1")
|
||||
seq_group_meta, out = scheduler.schedule()
|
||||
assert out.scheduled_seq_groups == [seq_group_b]
|
||||
assert out.num_batched_tokens == seq_group_b.get_seqs()[0].get_len()
|
||||
assert out.num_batched_tokens == 5 # 4 prompt + 1 generation.
|
||||
assert (not out.blocks_to_copy and not out.blocks_to_swap_in
|
||||
and not out.blocks_to_swap_out)
|
||||
assert len(seq_group_meta) == 1
|
||||
@@ -136,7 +136,7 @@ def test_scheduler_max_seqs():
|
||||
num_seq_group = 4
|
||||
max_seq_group = 2
|
||||
max_model_len = 16
|
||||
scheduler_config = SchedulerConfig(64, max_seq_group, max_model_len, 256)
|
||||
scheduler_config = SchedulerConfig(64, max_seq_group, max_model_len)
|
||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||
cache_config.num_cpu_blocks = 8
|
||||
cache_config.num_gpu_blocks = 8
|
||||
|
||||
Reference in New Issue
Block a user