[V1][PP] Run engine busy loop with batch queue (#13064)
@@ -4,8 +4,9 @@ import queue
 import signal
 import threading
 import time
 from concurrent.futures import Future
 from multiprocessing.connection import Connection
-from typing import Any, List, Tuple, Type
+from typing import Any, List, Optional, Tuple, Type
 
 import psutil
 import zmq
@@ -18,11 +19,12 @@ from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
 from vllm.utils import get_exception_traceback, zmq_socket_ctx
 from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
-from vllm.v1.core.scheduler import Scheduler
+from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                             EngineCoreRequestType)
 from vllm.v1.engine.mm_input_cache import MMInputCacheServer
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
 from vllm.version import __version__ as VLLM_VERSION
@@ -66,9 +68,22 @@ class EngineCore:
             log_stats=self.log_stats,
         )
 
         # Setup MM Input Mapper.
         self.mm_input_cache_server = MMInputCacheServer(
             vllm_config.model_config)
 
+        # Setup batch queue for pipeline parallelism.
+        # Batch queue for scheduled batches. This enables us to asynchronously
+        # schedule and execute batches, and is required by pipeline parallelism
+        # to eliminate pipeline bubbles.
+        self.batch_queue_size = self.model_executor.max_concurrent_batches
+        self.batch_queue: Optional[queue.Queue[Tuple[Future[ModelRunnerOutput],
+                                                     SchedulerOutput]]] = None
+        if self.batch_queue_size > 1:
+            logger.info("Batch queue is enabled with size %d",
+                        self.batch_queue_size)
+            self.batch_queue = queue.Queue(self.batch_queue_size)
+
     def _initialize_kv_caches(self,
                               vllm_config: VllmConfig) -> Tuple[int, int]:
         start = time.time()
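
Note on the setup above: the queue stores (Future, SchedulerOutput) pairs, and its maxsize (max_concurrent_batches) is what bounds how many batches may be in flight at once. Below is a minimal self-contained sketch of the same pattern, not vLLM code; Batch, fake_execute, and the thread pool are hypothetical stand-ins for SchedulerOutput and the model executor.

import queue
from concurrent.futures import Future, ThreadPoolExecutor
from typing import NamedTuple, Optional, Tuple

class Batch(NamedTuple):          # hypothetical stand-in for SchedulerOutput
    batch_id: int

def fake_execute(batch: Batch) -> str:   # stand-in for execute_model
    return f"output-of-{batch.batch_id}"

pool = ThreadPoolExecutor(max_workers=2)
batch_queue_size = 2              # plays the role of max_concurrent_batches
batch_queue: Optional[queue.Queue[Tuple[Future[str], Batch]]] = None
if batch_queue_size > 1:
    batch_queue = queue.Queue(batch_queue_size)

assert batch_queue is not None
for i in range(batch_queue_size):
    # put_nowait never blocks while there is room; the scheduler checks
    # .full() before scheduling, so queue.Full is not hit in steady state.
    batch = Batch(i)
    batch_queue.put_nowait((pool.submit(fake_execute, batch), batch))

print(batch_queue.full())  # True: no further batch may be scheduled now
pool.shutdown()
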
@@ -135,7 +150,55 @@ class EngineCore:
         scheduler_output = self.scheduler.schedule()
         output = self.model_executor.execute_model(scheduler_output)
         engine_core_outputs = self.scheduler.update_from_output(
-            scheduler_output, output)
+            scheduler_output, output)  # type: ignore
         return engine_core_outputs
 
+    def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]:
+        """Schedule and execute batches with the batch queue.
+        Note that if there is nothing to output in this step, None is
+        returned.
+
+        The execution flow is as follows:
+        1. Try to schedule a new batch if there are unscheduled requests
+        and the job queue is not full. If a new batch is scheduled, directly
+        return an empty engine core output. In other words, we won't check
+        and return model outputs before the batch queue is full.
+        2. If there is no newly scheduled batch, meaning that the batch queue
+        is full or no other requests can be scheduled, block until the first
+        batch in the job queue is finished.
+        3. Update the scheduler from the output.
+        """
+        assert self.batch_queue is not None
+
+        engine_core_outputs = None
+        scheduler_output = None
+        # If there are unscheduled requests and the job queue
+        # is not full, schedule a new batch. Note that this is not blocking.
+        if (self.scheduler.get_num_unscheduled_requests() > 0
+                and not self.batch_queue.full()):
+            scheduler_output = self.scheduler.schedule()
+            if scheduler_output.total_num_scheduled_tokens > 0:
+                future = self.model_executor.execute_model(scheduler_output)
+                self.batch_queue.put_nowait(
+                    (future, scheduler_output))  # type: ignore
+
+        # If all requests are scheduled or the job queue is full,
+        # block until the first batch in the job queue is finished.
+        if (scheduler_output is None
+                or scheduler_output.total_num_scheduled_tokens == 0):
+            try:
+                future, scheduler_output = self.batch_queue.get(
+                    timeout=POLLING_TIMEOUT_S)
+                # Block until the first result is available.
+                model_output = future.result()
+                self.batch_queue.task_done()
+                engine_core_outputs = self.scheduler.update_from_output(
+                    scheduler_output, model_output)
+            except queue.Empty:
+                # If the queue is empty (timeout at .get), return
+                # an empty EngineCoreOutputs for logging.
+                engine_core_outputs = EngineCoreOutputs(
+                    outputs=[], scheduler_stats=self.scheduler.make_stats())
+
+        return engine_core_outputs
+
     def shutdown(self):
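
The docstring's three-step flow amounts to: fill the queue greedily, and only block on the oldest future once nothing new can be scheduled. The following is a runnable simulation of that policy under stated assumptions, not the vLLM implementation: integers stand in for SchedulerOutput, run_batch for model_executor.execute_model, and a ThreadPoolExecutor for the real executor.

import queue
import time
from concurrent.futures import ThreadPoolExecutor

pool = ThreadPoolExecutor(max_workers=2)
pending = queue.Queue(maxsize=2)   # the batch queue
unscheduled = list(range(4))       # four batches worth of requests
outputs = []

def run_batch(batch_id: int) -> str:
    time.sleep(0.01)  # pretend to run the model
    return f"out-{batch_id}"

while unscheduled or not pending.empty():
    scheduled = None
    # Step 1: schedule a new batch if possible (non-blocking).
    if unscheduled and not pending.full():
        scheduled = unscheduled.pop(0)
        pending.put_nowait((pool.submit(run_batch, scheduled), scheduled))
    # Step 2: otherwise block on the oldest in-flight batch.
    if scheduled is None:
        future, batch_id = pending.get()
        outputs.append(future.result())  # Step 3: consume the result.

print(outputs)  # ['out-0', 'out-1', 'out-2', 'out-3']
pool.shutdown()

Because the queue is FIFO and result() is called on the oldest future, outputs are consumed in schedule order even though up to two batches execute concurrently; keeping the queue full is what hides the gap between batches that would otherwise show up as a pipeline bubble.
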
@@ -226,6 +289,9 @@ class EngineCoreProc(EngineCore):
     def run_busy_loop(self):
         """Core busy loop of the EngineCore."""
 
+        step_fn = (self.step
+                   if self.batch_queue is None else self.step_with_batch_queue)
+
         # Loop until process is sent a SIGINT or SIGTERM
         while True:
             # 1) Poll the input queue until there is work to do.
@@ -249,10 +315,11 @@ class EngineCoreProc(EngineCore):
                 self._handle_client_request(*req)
 
             # 3) Step the engine core.
-            outputs = self.step()
+            outputs = step_fn()
 
-            # 5) Put EngineCoreOutputs into the output queue.
-            self.output_queue.put_nowait(outputs)
+            # 4) Put EngineCoreOutputs into the output queue.
+            if outputs is not None:
+                self.output_queue.put_nowait(outputs)
 
     def _handle_client_request(self, request_type: EngineCoreRequestType,
                                request: Any) -> None:
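
The busy-loop change is just dispatch plus a None guard: step_with_batch_queue may legitimately produce no output on a step that only schedules work, so the loop must not publish in that case. A small sketch of that shape follows; all names are hypothetical, not vLLM APIs.

from typing import Callable, Optional

def make_run_once(step: Callable[[], Optional[str]],
                  step_with_batch_queue: Callable[[], Optional[str]],
                  batch_queue_enabled: bool,
                  publish: Callable[[str], None]) -> Callable[[], None]:
    # Pick the step function once, outside the loop, as run_busy_loop does.
    step_fn = step_with_batch_queue if batch_queue_enabled else step

    def run_once() -> None:
        outputs = step_fn()
        # The batch-queue step returns None when it only scheduled work,
        # so publish only when there is something to send.
        if outputs is not None:
            publish(outputs)

    return run_once

run_once = make_run_once(lambda: "outs", lambda: None, True, print)
run_once()  # prints nothing: the batch-queue step had nothing to output
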