[V1][PP] Run engine busy loop with batch queue (#13064)
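For context on the change under test: EngineCore gains a bounded batch_queue of in-flight execution futures, and step_with_batch_queue() either schedules a new batch (returning None) or, once the queue is saturated, blocks on the oldest batch's result so that up to max_concurrent_batches model passes can overlap. Below is a minimal, illustrative sketch of one iteration of that busy loop; it is not vLLM's actual implementation, and every name except step_with_batch_queue and max_concurrent_batches (both visible in the test diff that follows) is a placeholder.

# Illustrative sketch only; `scheduler`, `executor`, and their methods are placeholders.
import queue
from concurrent.futures import Future


def step_with_batch_queue_sketch(scheduler, executor,
                                 batch_queue: queue.Queue):
    """One iteration of a hypothetical busy loop over a bounded batch queue."""
    if not batch_queue.full():
        # Room in the queue: form another batch and submit it without
        # waiting for earlier batches to finish.
        scheduler_output = scheduler.schedule()
        future: Future = executor.execute_model(scheduler_output)
        batch_queue.put_nowait((future, scheduler_output))
        return None  # nothing has finished yet
    # Queue saturated: block on the oldest in-flight batch and feed its
    # output back to the scheduler.
    future, scheduler_output = batch_queue.get_nowait()
    model_output = future.result()
    return scheduler.update_from_output(scheduler_output, model_output)

The DummyExecutor in the test plays the executor role in this sketch: its execute_model returns a Future immediately and resolves it from a worker thread, while max_concurrent_batches = 2 bounds the queue.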
@@ -1,7 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import copy
+import threading
 import time
 import uuid
+from concurrent.futures import Future
 
 import pytest
 from transformers import AutoTokenizer
@@ -12,7 +15,9 @@ from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.core import EngineCore
-from vllm.v1.executor.abstract import Executor
+from vllm.v1.executor.abstract import Executor, UniProcExecutor
+from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.outputs import ModelRunnerOutput
 
 if not current_platform.is_cuda():
     pytest.skip(reason="V1 currently only supported on CUDA.",
@@ -191,3 +196,85 @@ def test_engine_core_advanced_sampling(monkeypatch):
     )
     engine_core.add_request(request2)
     _check_engine_state()
+
+
+@fork_new_process_for_each_test
+def test_engine_core_concurrent_batches(monkeypatch):
+    """
+    Test that the engine can handle multiple concurrent batches.
+    """
+
+    def make_request_with_max_tokens(max_tokens: int) -> EngineCoreRequest:
+        request = make_request()
+        request.sampling_params.max_tokens = max_tokens
+        return request
+
+    class DummyExecutor(UniProcExecutor):
+
+        def initialize(self, kv_cache_config: KVCacheConfig) -> None:
+            super().initialize(kv_cache_config)
+
+            # This executor can actually only run 1 batch at a time.
+            self.semaphore = threading.Semaphore(1)
+
+        def execute_model(
+            self,
+            scheduler_output,
+        ) -> Future[ModelRunnerOutput]:
+            """Make execute_model non-blocking."""
+            future: Future[ModelRunnerOutput] = Future()
+
+            def _thread_wrapper(scheduler_output, future):
+                with self.semaphore:
+                    output = self.collective_rpc("execute_model",
+                                                 args=(scheduler_output, ))
+                    # Make a copy because output[0] may be reused
+                    # by the next batch.
+                    output = copy.deepcopy(output[0])
+                    future.set_result(output)
+
+            threading.Thread(target=_thread_wrapper,
+                             args=(scheduler_output, future)).start()
+            return future
+
+        @property
+        def max_concurrent_batches(self) -> int:
+            return 2
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        engine_args = EngineArgs(
+            model=MODEL_NAME,
+            # To test concurrent batches.
+            max_num_seqs=2,
+            # Avoid all requests being scheduled at once.
+            enable_prefix_caching=False,
+            max_num_batched_tokens=10,
+        )
+        vllm_config = engine_args.create_engine_config()
+        engine_core = EngineCore(vllm_config=vllm_config,
+                                 log_stats=False,
+                                 executor_class=DummyExecutor)
+        assert engine_core.batch_queue is not None
+
+        # Add two requests in a row.
+        req = make_request_with_max_tokens(5)
+        engine_core.add_request(req)
+        req = make_request_with_max_tokens(5)
+        engine_core.add_request(req)
+
+        # First saturate the batch queue.
+        assert engine_core.step_with_batch_queue() is None
+        assert engine_core.batch_queue.qsize() == 1
+        assert engine_core.step_with_batch_queue() is None
+        assert engine_core.batch_queue.qsize() == 2
+        assert engine_core.scheduler.get_num_unfinished_requests() == 2
+
+        # Loop through both requests.
+        while engine_core.scheduler.get_num_unfinished_requests() == 2:
+            engine_core.step_with_batch_queue()
+
+        # We reach here once the first request has finished; step until the second does.
+        while engine_core.scheduler.get_num_unfinished_requests() == 1:
+            engine_core.step_with_batch_queue()