[Core] Reduce TTFT with concurrent partial prefills (#10235)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Signed-off-by: Prashant Gupta <prashantgupta@us.ibm.com>
Co-authored-by: Prashant Gupta <prashantgupta@us.ibm.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
@@ -7,6 +7,9 @@ import pytest # noqa
|
||||
|
||||
from vllm.config import CacheConfig, SchedulerConfig
|
||||
from vllm.core.scheduler import Scheduler
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import Logprob, SequenceGroup
|
||||
|
||||
from .utils import create_dummy_prompt
|
||||
@@ -16,7 +19,7 @@ def get_sequence_groups(scheduler_output):
|
||||
return [s.seq_group for s in scheduler_output.scheduled_seq_groups]
|
||||
|
||||
|
||||
def append_new_token(seq_group, token_id: int):
|
||||
def append_new_token(seq_group: SequenceGroup, token_id: int):
    """Append ``token_id`` to every sequence in ``seq_group``.

    Each sequence returned by ``seq_group.get_seqs()`` receives the token
    together with a fresh single-entry logprob mapping
    ``{token_id: Logprob(token_id)}``.
    """
    sequences = seq_group.get_seqs()
    for sequence in sequences:
        logprobs = {token_id: Logprob(token_id)}
        sequence.append_token_id(token_id, logprobs)
|
||||
|
||||
@@ -123,6 +126,232 @@ def test_chunk():
|
||||
assert out.num_batched_tokens == 57
|
||||
|
||||
|
||||
def test_concurrent_chunking():
    """Check that prefills are chunked side by side when
    --max-num-partial-prefills is > 1.

    Two 60-token prompts share a 64-token batch budget: the test expects
    32 tokens each on the first step and the remaining 28 each on the
    second step.
    """
    tokens_per_block = 4
    seq_limit = 60
    model_len_limit = 2000
    batch_token_budget = 64

    sched_cfg = SchedulerConfig(
        "generate",
        batch_token_budget,
        seq_limit,
        model_len_limit,
        enable_chunked_prefill=True,
        max_num_partial_prefills=2,  # Up to 2 partial prefills at a time
    )
    kv_cfg = CacheConfig(tokens_per_block, 1.0, 1, "auto")
    kv_cfg.num_cpu_blocks = 32
    kv_cfg.num_gpu_blocks = 32
    scheduler = Scheduler(sched_cfg, kv_cfg, None)

    # Queue two identical prompts and remember them for later comparison.
    queued: List[SequenceGroup] = []
    for request_id in ("0", "1"):
        _, group = create_dummy_prompt(request_id,
                                       prompt_length=60,
                                       block_size=tokens_per_block)
        scheduler.add_seq_group(group)
        queued.append(group)

    # Step 1: the budget is split evenly, 64 / 2 = 32 tokens per request.
    metas, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(queued)
    for slot in (0, 1):
        assert metas[slot].token_chunk_size == 32
    assert out.num_prefill_groups == 2
    assert out.num_batched_tokens == 64

    # Step 2: each request has 60 - 32 = 28 prompt tokens left.
    metas, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(queued)
    for slot in (0, 1):
        assert metas[slot].token_chunk_size == 28
    assert out.num_prefill_groups == 2
    assert out.num_batched_tokens == 56
|
||||
|
||||
|
||||
def test_concurrent_chunking_large_requests():
    """Check that large prefill requests are run one at a time.

    Two 1200-token prompts are queued; the first scheduling step is
    expected to pick exactly one of them and hand it the full 64-token
    batch budget.
    """
    tokens_per_block = 4
    seq_limit = 60
    model_len_limit = 2000
    batch_token_budget = 64

    sched_cfg = SchedulerConfig(
        "generate",
        batch_token_budget,
        seq_limit,
        model_len_limit,
        enable_chunked_prefill=True,
        max_num_partial_prefills=2,  # Up to 2 partial prefills at a time
    )
    kv_cfg = CacheConfig(tokens_per_block, 1.0, 1, "auto")
    # large KV cache size for large requests
    kv_cfg.num_cpu_blocks = 3200
    kv_cfg.num_gpu_blocks = 3200
    scheduler = Scheduler(sched_cfg, kv_cfg, None)

    # Queue two very large prompts.
    for request_id in ("0", "1"):
        _, group = create_dummy_prompt(
            request_id,
            prompt_length=1200,  # Very large prompt
            block_size=tokens_per_block)
        scheduler.add_seq_group(group)

    # Only a single request should be chunked, and it gets all 64 tokens.
    metas, out = schedule_and_update_computed_tokens(scheduler)
    scheduled_groups = get_sequence_groups(out)
    assert len(scheduled_groups) == 1
    assert metas[0].token_chunk_size == 64
    assert out.num_prefill_groups == 1
    assert out.num_batched_tokens == 64
|
||||
|
||||
|
||||
def test_short_prompts_jump_long_prompts_in_queue():
    """Verify large prefill requests are punted behind smaller ones if
    another large prefill request is already running.

    Two 1200-token prompts are queued ahead of two 40-token prompts, with
    max_num_partial_prefills=2 and a 64-token batch budget. The assertions
    walk four scheduling steps and check that both short prompts finish
    prefill while the second long prompt still has 0 computed tokens.
    """
    block_size = 4
    max_seqs = 60
    max_model_len = 2000
    max_num_batched_tokens = 64
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens,
        max_seqs,
        max_model_len,
        enable_chunked_prefill=True,
        max_num_partial_prefills=2,  # Up to 2 partial prefills at a time
    )
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 3200  # large KV cache size for large requests
    cache_config.num_gpu_blocks = 3200
    scheduler = Scheduler(scheduler_config, cache_config, None)
    long_seqs: List[SequenceGroup] = []
    short_seqs: List[SequenceGroup] = []

    # Add 2 large seq groups to scheduler.
    for i in range(2):
        _, seq_group = create_dummy_prompt(
            str(i),
            prompt_length=1200,  # Very large prompt
            block_size=block_size)
        scheduler.add_seq_group(seq_group)
        long_seqs.append(seq_group)
        assert seq_group.is_prefill()

    # Add 2 small seq groups behind them
    for i in range(2):
        _, seq_group = create_dummy_prompt(
            str(i + 2),
            prompt_length=40,  # Very small prompt
            block_size=block_size)
        scheduler.add_seq_group(seq_group)
        short_seqs.append(seq_group)
        assert seq_group.is_prefill()

    # First iteration: one large req and one small req are chunked.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert seq_group_meta[0].token_chunk_size == 32  # large req gets 32 tokens
    assert seq_group_meta[1].token_chunk_size == 32  # small req gets 32 tokens

    # all 4 are prefilling
    assert long_seqs[0].is_prefill()
    assert long_seqs[1].is_prefill()
    assert short_seqs[0].is_prefill()
    assert short_seqs[1].is_prefill()
    # First short and first long sequences have been scheduled
    assert long_seqs[0].first_seq.get_num_computed_tokens() == 32
    assert long_seqs[1].first_seq.get_num_computed_tokens() == 0
    assert short_seqs[0].first_seq.get_num_computed_tokens() == 32
    assert short_seqs[1].first_seq.get_num_computed_tokens() == 0

    assert out.num_prefill_groups == 2
    assert out.num_batched_tokens == 64

    # Second iteration: the first small request has only 8 prompt tokens
    # left (40 - 32), so it finishes its prefill in this step, and the
    # other small request is scheduled alongside it.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    # the new small req got 64 - (32+8) tokens
    assert seq_group_meta[0].token_chunk_size == 24
    assert seq_group_meta[1].token_chunk_size == 32  # large req still got 32
    # the first small request had only 8 tokens left
    assert seq_group_meta[2].token_chunk_size == 8  # 40-32

    # The first small request is done prefilling and will decode next
    assert long_seqs[0].is_prefill()
    assert long_seqs[1].is_prefill()
    assert not short_seqs[0].is_prefill()
    assert short_seqs[1].is_prefill()
    # Both small requests have started in front of the second long request
    assert long_seqs[0].first_seq.get_num_computed_tokens() == 64
    assert long_seqs[1].first_seq.get_num_computed_tokens() == 0
    assert short_seqs[0].first_seq.get_num_computed_tokens() == 40
    assert short_seqs[1].first_seq.get_num_computed_tokens() == 24

    assert out.num_prefill_groups == 3
    assert out.num_batched_tokens == 64
    # the first small seq group has a new token appended.
    append_new_token(short_seqs[0], 1)

    # Third iteration: the first small request is already decoding; the
    # second small request has 16 prompt tokens left and finishes prefill.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert seq_group_meta[0].token_chunk_size == 32  # large still got 32
    # small req finished prefilling 40-24=16 tokens
    assert seq_group_meta[1].token_chunk_size == 16
    assert seq_group_meta[2].token_chunk_size == 1  # decode
    assert out.num_prefill_groups == 2
    assert out.num_batched_tokens == 49  # (32+16+1 decode)

    # both small requests have now reached decode
    assert long_seqs[0].is_prefill()
    assert long_seqs[1].is_prefill()
    assert not short_seqs[0].is_prefill()
    assert not short_seqs[1].is_prefill()
    assert long_seqs[0].first_seq.get_num_computed_tokens() == 96
    assert long_seqs[1].first_seq.get_num_computed_tokens() == 0
    assert short_seqs[0].first_seq.get_num_computed_tokens() == 41
    assert short_seqs[1].first_seq.get_num_computed_tokens() == 40

    # both the small seq groups have a new token appended
    append_new_token(short_seqs[0], 1)
    append_new_token(short_seqs[1], 1)

    # Fourth iteration: both small requests are decoding, so the large
    # request gets the rest of the budget.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)

    # large req gets 62 tokens (minus 2 for decode)
    assert seq_group_meta[0].token_chunk_size == 62
    assert seq_group_meta[1].token_chunk_size == 1  # decode
    assert seq_group_meta[2].token_chunk_size == 1  # decode
    assert out.num_prefill_groups == 1
    assert out.num_batched_tokens == 64

    # 96 tokens from the first three steps + 62 from this one
    assert long_seqs[0].first_seq.get_num_computed_tokens() == 158
|
||||
|
||||
|
||||
def test_complex():
|
||||
block_size = 4
|
||||
max_seqs = 60
|
||||
@@ -508,7 +737,7 @@ def test_chunked_prefill_max_seqs():
|
||||
assert not running[1].is_prefill()
|
||||
|
||||
|
||||
def test_perfix_caching():
|
||||
def test_prefix_caching():
|
||||
"""Verify allocating full blocks when prefix caching is enabled."""
|
||||
block_size = 4
|
||||
max_seqs = 10
|
||||
@@ -548,3 +777,86 @@ def test_perfix_caching():
|
||||
assert seq_group_meta[1].token_chunk_size == 12
|
||||
assert out.num_prefill_groups == 2
|
||||
assert out.num_batched_tokens == 62
|
||||
|
||||
|
||||
def test_prefix_caching_with_concurrent_partial_prefills():
    """Check that full blocks are allocated when prefix caching is enabled
    with --max-num-partial-prefills > 1.

    Two 50-token prompts share a 60-token batch budget. The test expects
    the first chunk of each to be 28 tokens (a multiple of the block size)
    and the second chunk to be the remaining 22 tokens.
    """
    tokens_per_block = 4
    seq_limit = 10
    model_len_limit = 8000
    # With two slots, each slot will get 30 tokens
    batch_token_budget = 60

    sched_cfg = SchedulerConfig(
        "generate",
        batch_token_budget,
        seq_limit,
        model_len_limit,
        enable_chunked_prefill=True,
        max_num_partial_prefills=2,
    )
    kv_cfg = CacheConfig(
        tokens_per_block,
        1.0,
        1,
        "auto",
        enable_prefix_caching=True,
    )
    kv_cfg.num_cpu_blocks = 0
    kv_cfg.num_gpu_blocks = 32
    scheduler = Scheduler(sched_cfg, kv_cfg, None)

    # Queue two identical prompts and remember them for later comparison.
    queued: List[SequenceGroup] = []
    for request_id in ("0", "1"):
        _, group = create_dummy_prompt(request_id,
                                       block_size=tokens_per_block,
                                       prompt_length=50)
        scheduler.add_seq_group(group)
        queued.append(group)

    metas, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(queued)
    # To partially prefill both sequences, both can chunk up to 30 tokens
    # But the next lowest multiple of the block size (4) is 28
    for slot in (0, 1):
        assert metas[slot].token_chunk_size == 28
    assert out.num_prefill_groups == 2
    assert out.num_batched_tokens == 56

    # On the next iteration, both sequences should finish prefill
    metas, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(queued)
    # Both sequences have 50 - 28 = 22 tokens left to prefill.
    # This is not a multiple of the block size, but we don't care since we
    # don't cache the final partial block of prefix sequences
    for slot in (0, 1):
        assert metas[slot].token_chunk_size == 22
    assert out.num_prefill_groups == 2
    assert out.num_batched_tokens == 44
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("max_num_partial_prefills", [2, 4, 8])
def test_chunked_prefill_with_actual_engine(model: str,
                                            max_num_partial_prefills: int):
    """Smoke-test a real engine with concurrent partial prefills enabled.

    Submits ``max_num_partial_prefills`` copies of the same prompt, runs a
    single engine step, and checks that no request outputs are returned
    and that every request sits in the first scheduler's running queue.
    """
    prompt = "hello" * 40

    engine = LLMEngine.from_engine_args(
        EngineArgs(
            model=model,
            max_num_partial_prefills=max_num_partial_prefills,
            max_num_batched_tokens=40,
            max_num_seqs=8,
            enable_chunked_prefill=True,
            gpu_memory_utilization=0.8,
        ))
    greedy_params = SamplingParams(temperature=0)

    request_index = 0
    while request_index < max_num_partial_prefills:
        engine.add_request(f"{request_index}", prompt, greedy_params)
        request_index += 1

    # first step
    step_outputs = engine.step()
    # an empty result means all are prefilling
    assert len(step_outputs) == 0
    running_queue = engine.scheduler[0].running
    assert len(running_queue) == max_num_partial_prefills
|
||||
|
||||
Reference in New Issue
Block a user