[Perf][V1] Fully overlap model execution (#23569)
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
This commit is contained in:
committed by
GitHub
parent
c954c6629c
commit
cee182b297
@@ -3,6 +3,7 @@
|
||||
import multiprocessing
|
||||
import os
|
||||
import pickle
|
||||
import queue
|
||||
import signal
|
||||
import threading
|
||||
import time
|
||||
@@ -33,7 +34,8 @@ from vllm.utils import (decorate_logs, get_distributed_init_method,
|
||||
get_loopback_ip, get_mp_context, get_open_port,
|
||||
set_process_title)
|
||||
from vllm.v1.executor.abstract import Executor, FailureCallback
|
||||
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
|
||||
from vllm.v1.outputs import (AsyncModelRunnerOutput, DraftTokenIds,
|
||||
ModelRunnerOutput)
|
||||
from vllm.worker.worker_base import WorkerWrapperBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -414,6 +416,16 @@ class WorkerProc:
|
||||
# Initializes a message queue for sending the model output
|
||||
self.worker_response_mq = MessageQueue(1, 1)
|
||||
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
self.use_async_scheduling = scheduler_config.async_scheduling
|
||||
if self.use_async_scheduling:
|
||||
self.async_output_queue: queue.Queue = queue.Queue()
|
||||
self.async_output_copy_thread = Thread(
|
||||
target=self.async_output_busy_loop,
|
||||
daemon=True,
|
||||
name="WorkerAsyncOutputCopy")
|
||||
self.async_output_copy_thread.start()
|
||||
|
||||
# Initialize device and loads weights
|
||||
self.worker.init_device()
|
||||
self.worker.load_model()
|
||||
@@ -595,6 +607,36 @@ class WorkerProc:
|
||||
SUCCESS = auto()
|
||||
FAILURE = auto()
|
||||
|
||||
def enqueue_output(self, output: Any):
    """Prepare a result from the worker and push it to worker_response_mq.

    An AsyncModelRunnerOutput is first resolved to its concrete output
    via get_output(). An Exception is sent as a FAILURE response carrying
    only its string form (the exception object itself may not be
    serializable); anything else is sent as a SUCCESS response.
    """
    if isinstance(output, AsyncModelRunnerOutput):
        output = output.get_output()

    if isinstance(output, Exception):
        self.worker_response_mq.enqueue(
            (WorkerProc.ResponseStatus.FAILURE, str(output)))
    else:
        self.worker_response_mq.enqueue(
            (WorkerProc.ResponseStatus.SUCCESS, output))
|
||||
|
||||
def handle_output(self, output: Any):
    """Route a result from the worker toward the response queue.

    When async scheduling is enabled the result is handed off to the
    async-output copy thread (async_output_busy_loop); otherwise it is
    enqueued to worker_response_mq inline on the calling thread.
    """
    if not self.use_async_scheduling:
        self.enqueue_output(output)
        return
    self.async_output_queue.put(output)
|
||||
|
||||
def async_output_busy_loop(self):
    """Entrypoint for the async-output copy thread.

    Drains async_output_queue forever, forwarding each item to
    enqueue_output; queue.Queue.get blocks while the queue is empty.
    """
    # Bind the getter once; the loop body then does no attribute lookups
    # on the queue itself.
    get_next = self.async_output_queue.get
    while True:
        self.enqueue_output(get_next())
|
||||
|
||||
def worker_busy_loop(self):
|
||||
"""Main busy loop for Multiprocessing Workers"""
|
||||
while True:
|
||||
@@ -614,10 +656,8 @@ class WorkerProc:
|
||||
# exception might not be serializable, so we convert it to
|
||||
# string, only for logging purpose.
|
||||
if output_rank is None or self.rank == output_rank:
|
||||
self.worker_response_mq.enqueue(
|
||||
(WorkerProc.ResponseStatus.FAILURE, str(e)))
|
||||
self.handle_output(e)
|
||||
continue
|
||||
|
||||
if output_rank is None or self.rank == output_rank:
|
||||
self.worker_response_mq.enqueue(
|
||||
(WorkerProc.ResponseStatus.SUCCESS, output))
|
||||
self.handle_output(output)
|
||||
|
||||
Reference in New Issue
Block a user