[V1] Prompt logprobs + APC compatibility; prompt logprobs reqs cannot fill APC (#13949)
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.sampling_params import SamplingParams
|
||||
@@ -16,7 +18,21 @@ def create_scheduler(
|
||||
model: str = "facebook/opt-125m",
|
||||
max_num_seqs: int = 16,
|
||||
max_num_batched_tokens: int = 8192,
|
||||
enable_prefix_caching: Optional[bool] = None,
|
||||
) -> Scheduler:
|
||||
'''Create scheduler under test.
|
||||
|
||||
Args:
|
||||
model: model under test
|
||||
max_num_seqs: max sequences to schedule
|
||||
max_num_batch_tokens: max num tokens to batch
|
||||
enable_prefix_caching: optionally force APC config
|
||||
(True/False) or use default
|
||||
(None)
|
||||
|
||||
Returns:
|
||||
:class:`Scheduler` instance
|
||||
'''
|
||||
scheduler_config = SchedulerConfig(
|
||||
max_num_seqs=max_num_seqs,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
@@ -31,11 +47,16 @@ def create_scheduler(
|
||||
dtype="float16",
|
||||
seed=42,
|
||||
)
|
||||
# Cache config, optionally force APC
|
||||
kwargs_cache = ({} if enable_prefix_caching is None else {
|
||||
'enable_prefix_caching': enable_prefix_caching
|
||||
})
|
||||
cache_config = CacheConfig(
|
||||
block_size=16,
|
||||
gpu_memory_utilization=0.9,
|
||||
swap_space=0,
|
||||
cache_dtype="auto",
|
||||
**kwargs_cache,
|
||||
)
|
||||
vllm_config = VllmConfig(
|
||||
scheduler_config=scheduler_config,
|
||||
@@ -54,16 +75,16 @@ def create_scheduler(
|
||||
)
|
||||
|
||||
|
||||
def create_requests(
|
||||
num_requests: int,
|
||||
num_tokens: int = 10,
|
||||
mm_positions: Optional[list[PlaceholderRange]] = None,
|
||||
max_tokens: int = 16,
|
||||
stop_token_ids: Optional[list[int]] = None,
|
||||
):
|
||||
def create_requests(num_requests: int,
|
||||
num_tokens: int = 10,
|
||||
mm_positions: Optional[list[PlaceholderRange]] = None,
|
||||
max_tokens: int = 16,
|
||||
stop_token_ids: Optional[list[int]] = None,
|
||||
prompt_logprobs: Optional[int] = None):
|
||||
sampling_params = SamplingParams(ignore_eos=False,
|
||||
max_tokens=max_tokens,
|
||||
stop_token_ids=stop_token_ids)
|
||||
stop_token_ids=stop_token_ids,
|
||||
prompt_logprobs=prompt_logprobs)
|
||||
requests = []
|
||||
for i in range(num_requests):
|
||||
if mm_positions is not None:
|
||||
@@ -122,9 +143,18 @@ def test_get_num_unfinished_requests():
|
||||
assert scheduler.get_num_unfinished_requests() == len(requests) - i - 1
|
||||
|
||||
|
||||
def test_schedule():
|
||||
scheduler = create_scheduler()
|
||||
requests = create_requests(num_requests=10)
|
||||
@pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [
|
||||
(None, None),
|
||||
(True, 5),
|
||||
])
|
||||
def test_schedule(enable_prefix_caching: Optional[bool],
|
||||
prompt_logprobs: Optional[int]):
|
||||
'''Test scheduling.
|
||||
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
|
||||
'''
|
||||
scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching)
|
||||
requests = create_requests(num_requests=10,
|
||||
prompt_logprobs=prompt_logprobs)
|
||||
for request in requests:
|
||||
scheduler.add_request(request)
|
||||
|
||||
@@ -427,14 +457,21 @@ def test_stop_via_update_from_output():
|
||||
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]
|
||||
|
||||
|
||||
def test_schedule_concurrent_batches():
|
||||
@pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [
|
||||
(None, None),
|
||||
(True, 5),
|
||||
])
|
||||
def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
|
||||
prompt_logprobs: Optional[int]):
|
||||
scheduler = create_scheduler(
|
||||
max_num_batched_tokens=1024,
|
||||
max_num_seqs=2,
|
||||
enable_prefix_caching=enable_prefix_caching,
|
||||
)
|
||||
requests = create_requests(
|
||||
num_requests=2,
|
||||
num_tokens=512,
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
)
|
||||
|
||||
# Schedule the first request.
|
||||
|
||||
Reference in New Issue
Block a user