[V1] Prompt logprobs + APC compatibility; prompt logprobs reqs cannot fill APC (#13949)

This commit is contained in:
afeldman-nm
2025-03-07 20:48:12 -05:00
committed by GitHub
parent 66e16a038e
commit ef64044079
9 changed files with 291 additions and 161 deletions

View File

@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import pytest
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
@@ -16,7 +18,21 @@ def create_scheduler(
model: str = "facebook/opt-125m",
max_num_seqs: int = 16,
max_num_batched_tokens: int = 8192,
enable_prefix_caching: Optional[bool] = None,
) -> Scheduler:
'''Create scheduler under test.
Args:
model: model under test
max_num_seqs: max sequences to schedule
max_num_batch_tokens: max num tokens to batch
enable_prefix_caching: optionally force APC config
(True/False) or use default
(None)
Returns:
:class:`Scheduler` instance
'''
scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
@@ -31,11 +47,16 @@ def create_scheduler(
dtype="float16",
seed=42,
)
# Cache config, optionally force APC
kwargs_cache = ({} if enable_prefix_caching is None else {
'enable_prefix_caching': enable_prefix_caching
})
cache_config = CacheConfig(
block_size=16,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
**kwargs_cache,
)
vllm_config = VllmConfig(
scheduler_config=scheduler_config,
@@ -54,16 +75,16 @@ def create_scheduler(
)
def create_requests(
num_requests: int,
num_tokens: int = 10,
mm_positions: Optional[list[PlaceholderRange]] = None,
max_tokens: int = 16,
stop_token_ids: Optional[list[int]] = None,
):
def create_requests(num_requests: int,
num_tokens: int = 10,
mm_positions: Optional[list[PlaceholderRange]] = None,
max_tokens: int = 16,
stop_token_ids: Optional[list[int]] = None,
prompt_logprobs: Optional[int] = None):
sampling_params = SamplingParams(ignore_eos=False,
max_tokens=max_tokens,
stop_token_ids=stop_token_ids)
stop_token_ids=stop_token_ids,
prompt_logprobs=prompt_logprobs)
requests = []
for i in range(num_requests):
if mm_positions is not None:
@@ -122,9 +143,18 @@ def test_get_num_unfinished_requests():
assert scheduler.get_num_unfinished_requests() == len(requests) - i - 1
def test_schedule():
scheduler = create_scheduler()
requests = create_requests(num_requests=10)
@pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [
(None, None),
(True, 5),
])
def test_schedule(enable_prefix_caching: Optional[bool],
prompt_logprobs: Optional[int]):
'''Test scheduling.
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
'''
scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching)
requests = create_requests(num_requests=10,
prompt_logprobs=prompt_logprobs)
for request in requests:
scheduler.add_request(request)
@@ -427,14 +457,21 @@ def test_stop_via_update_from_output():
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]
def test_schedule_concurrent_batches():
@pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [
(None, None),
(True, 5),
])
def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
prompt_logprobs: Optional[int]):
scheduler = create_scheduler(
max_num_batched_tokens=1024,
max_num_seqs=2,
enable_prefix_caching=enable_prefix_caching,
)
requests = create_requests(
num_requests=2,
num_tokens=512,
prompt_logprobs=prompt_logprobs,
)
# Schedule the first request.