[V1] AsyncLLM data parallel (#13923)
Signed-off-by: Nick Hill <nhill@redhat.com>
@@ -1,12 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0

 import os
 import queue
 import signal
 import sys
 import threading
 import time
 from concurrent.futures import Future
 from inspect import isclass, signature
+from logging import DEBUG
 from multiprocessing.connection import Connection
 from typing import Any, Optional

 import msgspec
@@ -14,7 +15,9 @@ import psutil
 import zmq
 import zmq.asyncio

-from vllm.config import VllmConfig
+from vllm.config import ParallelConfig, VllmConfig
+from vllm.distributed import stateless_destroy_torch_distributed_process_group
+from vllm.executor.multiproc_worker_utils import _add_prefix
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.transformers_utils.config import (
@@ -91,6 +94,8 @@ class EngineCore:
             cache_config=vllm_config.cache_config,
             lora_config=vllm_config.lora_config,
             speculative_config=vllm_config.speculative_config,
+            include_finished_set=vllm_config.parallel_config.data_parallel_size
+            > 1,
             log_stats=self.log_stats,
             structured_output_manager=self.structured_output_manager,
         )
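Note: include_finished_set is enabled only when data_parallel_size > 1. A client that multiplexes several engines presumably needs to know when each request has fully finished on its owning engine so it can release per-request state. A minimal sketch of that client-side bookkeeping (names hypothetical, not vLLM's actual API), assuming each outputs batch carries the set of finished request ids:

    # Hypothetical client-side bookkeeping for finished-request sets.
    class RequestTracker:
        def __init__(self) -> None:
            self.owner: dict[str, int] = {}  # request id -> owning engine

        def add(self, req_id: str, engine_index: int) -> None:
            self.owner[req_id] = engine_index

        def on_outputs(self, engine_index: int, finished: set[str]) -> None:
            # Retire a request only when its owning engine reports it done.
            for req_id in finished:
                if self.owner.get(req_id) == engine_index:
                    del self.owner[req_id]

    tracker = RequestTracker()
    tracker.add("req-0", engine_index=1)
    tracker.on_outputs(engine_index=1, finished={"req-0"})
    assert not tracker.owner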
@@ -283,10 +288,10 @@ class EngineCoreProc(EngineCore):
         self,
         input_path: str,
         output_path: str,
-        ready_pipe: Connection,
         vllm_config: VllmConfig,
         executor_class: type[Executor],
         log_stats: bool,
+        engine_index: int = 0,
     ):
         super().__init__(vllm_config, executor_class, log_stats)

@@ -302,14 +307,20 @@ class EngineCoreProc(EngineCore):
                          args=(input_path, ),
                          daemon=True).start()
         threading.Thread(target=self.process_output_socket,
-                         args=(output_path, ),
+                         args=(output_path, engine_index),
                          daemon=True).start()

-        # Send Readiness signal to EngineClient.
-        ready_pipe.send({"status": "READY"})
+        self.global_unfinished_reqs = False
+
+        self.step_fn = (self.step if self.batch_queue is None else
+                        self.step_with_batch_queue)

     @staticmethod
-    def run_engine_core(*args, **kwargs):
+    def run_engine_core(*args,
+                        dp_rank: int = 0,
+                        local_dp_rank: int = 0,
+                        ready_pipe,
+                        **kwargs):
         """Launch EngineCore busy loop in background process."""

         # Signal handler used for graceful termination.
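Note: two changes here. The readiness handshake moves out of __init__ (it now happens in run_engine_core, below), and the step function is bound once as self.step_fn so the helper methods introduced later can call it. self.global_unfinished_reqs starts as False: the engine assumes no DP peer has work until told otherwise. A standalone sketch of the bind-once pattern (not vLLM's class):

    import queue

    class Core:
        def __init__(self, use_batch_queue: bool) -> None:
            self.batch_queue = queue.Queue() if use_batch_queue else None
            # Decide the step implementation once instead of per iteration.
            self.step_fn = (self.step if self.batch_queue is None
                            else self.step_with_batch_queue)

        def step(self) -> str:
            return "step"

        def step_with_batch_queue(self) -> str:
            return "step_with_batch_queue"

    assert Core(use_batch_queue=True).step_fn() == "step_with_batch_queue"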
@@ -331,9 +342,21 @@ class EngineCoreProc(EngineCore):
         signal.signal(signal.SIGINT, signal_handler)

         parent_process = psutil.Process().parent()
-        engine_core = None
+        engine_core: Optional[EngineCoreProc] = None
         try:
-            engine_core = EngineCoreProc(*args, **kwargs)
+            parallel_config: ParallelConfig = kwargs[
+                "vllm_config"].parallel_config
+            if parallel_config.data_parallel_size > 1:
+                # Set data parallel rank for this engine process.
+                parallel_config.data_parallel_rank = dp_rank
+                parallel_config.data_parallel_rank_local = local_dp_rank
+                engine_core = DPEngineCoreProc(*args, **kwargs)
+            else:
+                engine_core = EngineCoreProc(*args, **kwargs)
+
+            # Send Readiness signal to EngineClient.
+            ready_pipe.send({"status": "READY"})
+
             engine_core.run_busy_loop()

         except SystemExit:
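Note: run_engine_core now selects the engine class from the config: with data_parallel_size > 1 it stamps the process's DP rank into parallel_config and builds a DPEngineCoreProc. The READY message is sent only after construction succeeds, so a client will not start streaming requests to an engine whose DP process group failed to initialize. A hedged sketch of the launcher side of this handshake (hypothetical launcher, not vLLM's):

    from multiprocessing import Pipe, Process

    def engine_main(*, dp_rank: int, local_dp_rank: int, ready_pipe) -> None:
        # ... construct EngineCoreProc / DPEngineCoreProc here ...
        ready_pipe.send({"status": "READY"})  # only after init succeeds

    def launch(dp_size: int) -> list[Process]:
        procs = []
        for rank in range(dp_size):
            reader, writer = Pipe(duplex=False)
            proc = Process(target=engine_main,
                           kwargs=dict(dp_rank=rank, local_dp_rank=rank,
                                       ready_pipe=writer))
            proc.start()
            assert reader.recv() == {"status": "READY"}  # block until ready
            procs.append(proc)
        return procs

    if __name__ == "__main__":
        for proc in launch(dp_size=2):
            proc.join()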
@@ -351,28 +374,44 @@ class EngineCoreProc(EngineCore):
     def run_busy_loop(self):
         """Core busy loop of the EngineCore."""

-        step_fn = (self.step
-                   if self.batch_queue is None else self.step_with_batch_queue)
-
         # Loop until process is sent a SIGINT or SIGTERM
         while True:
             # 1) Poll the input queue until there is work to do.
-            while not self.scheduler.has_requests():
-                logger.debug("EngineCore busy loop waiting.")
-                req = self.input_queue.get()
-                self._handle_client_request(*req)
+            self._process_input_queue()
+            # 2) Step the engine core and return the outputs.
+            self._process_engine_step()

-            # 2) Handle any new client requests.
-            while not self.input_queue.empty():
-                req = self.input_queue.get_nowait()
-                self._handle_client_request(*req)
+    def _process_input_queue(self):
+        """Exits when an engine step needs to be performed."""

-            # 3) Step the engine core.
-            outputs = step_fn()
+        waited = False
+        while not self.global_unfinished_reqs and not (
+                self.scheduler.has_requests()):
+            if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
+                logger.debug("EngineCore waiting for work.")
+                waited = True
+            req = self.input_queue.get()
+            self._handle_client_request(*req)

-            # 4) Put EngineCoreOutputs into the output queue.
-            if outputs is not None:
-                self.output_queue.put_nowait(outputs)
+        if waited:
+            logger.debug(
+                "EngineCore loop active - local unfinished: %s, finished: %s.",
+                self.scheduler.has_unfinished_requests(),
+                self.scheduler.has_finished_requests())
+
+        # Handle any more client requests.
+        while not self.input_queue.empty():
+            req = self.input_queue.get_nowait()
+            self._handle_client_request(*req)
+
+    def _process_engine_step(self):
+        """Called only when there are unfinished local requests."""
+
+        # Step the engine core.
+        outputs = self.step_fn()
+        # Put EngineCoreOutputs into the output queue.
+        if outputs is not None:
+            self.output_queue.put_nowait(outputs)

     def _handle_client_request(self, request_type: EngineCoreRequestType,
                                request: Any) -> None:
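Note: the busy loop is factored into two phases so the DP subclass can reuse them. _process_input_queue blocks on the input queue only while there is neither local work nor known work on any DP peer (global_unfinished_reqs), then drains whatever else is already queued; _process_engine_step performs one step and forwards outputs. A standalone sketch of the wait predicate, assuming a handle callback in place of _handle_client_request:

    import queue

    def wait_for_work(input_q: queue.Queue, has_local_work,
                      global_unfinished: bool, handle) -> None:
        # Block only while neither this engine nor any DP peer has work.
        while not global_unfinished and not has_local_work():
            handle(input_q.get())
        # Drain, without blocking, any requests already queued.
        while True:
            try:
                handle(input_q.get_nowait())
            except queue.Empty:
                break

    q: queue.Queue = queue.Queue()
    q.put(("ADD", "req-0"))
    seen: list = []
    wait_for_work(q, has_local_work=lambda: bool(seen),
                  global_unfinished=False, handle=seen.append)
    assert seen == [("ADD", "req-0")]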
@@ -382,6 +421,10 @@ class EngineCoreProc(EngineCore):
             self.add_request(request)
         elif request_type == EngineCoreRequestType.ABORT:
             self.abort_requests(request)
+        elif request_type == EngineCoreRequestType.START_DP:
+            if not self.global_unfinished_reqs:
+                logger.debug("EngineCore starting idle loop.")
+                self.global_unfinished_reqs = True
         elif request_type == EngineCoreRequestType.UTILITY:
             call_id, method_name, args = request
             output = UtilityOutput(call_id)
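Note: START_DP is the wake-up half of the pause/resume protocol. When the whole DP group has gone idle, every rank blocks in _process_input_queue; when the client gives any engine new work it also broadcasts START_DP, so idle ranks flip global_unfinished_reqs and rejoin the lock-step loop, running dummy batches until the group is globally finished again (see DPEngineCoreProc below). A minimal sketch of the dispatch (request-type enum values hypothetical):

    from enum import Enum, auto

    class RequestType(Enum):
        ADD = auto()
        ABORT = auto()
        START_DP = auto()

    class Engine:
        def __init__(self) -> None:
            self.global_unfinished_reqs = False

        def handle(self, request_type: RequestType, request) -> None:
            if request_type == RequestType.START_DP:
                # A DP peer received work: resume lock-step stepping.
                self.global_unfinished_reqs = True

    engine = Engine()
    engine.handle(RequestType.START_DP, None)
    assert engine.global_unfinished_reqs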
@@ -432,7 +475,7 @@ class EngineCoreProc(EngineCore):
         # Push to input queue for core busy loop.
         self.input_queue.put_nowait((request_type, request))

-    def process_output_socket(self, output_path: str):
+    def process_output_socket(self, output_path: str, engine_index: int):
         """Output socket IO thread."""

         # Msgpack serialization encoding.
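Note: process_output_socket now takes the engine's index so each serialized batch can be stamped before it is sent (see the next hunk); a client consuming several DP engines over one socket can then demultiplex by that field. A small sketch of the client-side routing (types hypothetical, standing in for EngineCoreOutputs):

    from dataclasses import dataclass, field

    @dataclass
    class OutputsBatch:  # stand-in for EngineCoreOutputs
        engine_index: int = 0
        outputs: list = field(default_factory=list)

    def route(batch: OutputsBatch, state: dict) -> None:
        # Group batches by the engine that produced them.
        state.setdefault(batch.engine_index, []).append(batch.outputs)

    state: dict = {}
    route(OutputsBatch(engine_index=1, outputs=["tok"]), state)
    assert state == {1: [["tok"]]}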
@@ -443,5 +486,114 @@ class EngineCoreProc(EngineCore):
         with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket:
             while True:
                 outputs = self.output_queue.get()
+                outputs.engine_index = engine_index
                 encoder.encode_into(outputs, buffer)
-                socket.send_multipart((buffer, ), copy=False)
+                socket.send(buffer, copy=False)
+
+
+ENGINE_PAUSED_OUTPUTS = EngineCoreOutputs(engine_paused=True)
+
+
+class DPEngineCoreProc(EngineCoreProc):
+    """ZMQ-wrapper for running EngineCore in background process
+    in a data parallel context."""
+
+    def __init__(
+        self,
+        input_path: str,
+        output_path: str,
+        vllm_config: VllmConfig,
+        executor_class: type[Executor],
+        log_stats: bool,
+    ):
+        # Add process-specific prefix to stdout and stderr before
+        # we initialize the engine.
+        from multiprocessing import current_process
+        process_name = current_process().name
+        pid = os.getpid()
+        _add_prefix(sys.stdout, process_name, pid)
+        _add_prefix(sys.stderr, process_name, pid)
+
+        dp_size = vllm_config.parallel_config.data_parallel_size
+        dp_rank = vllm_config.parallel_config.data_parallel_rank
+        local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
+
+        assert dp_size > 1
+        assert 0 <= local_dp_rank <= dp_rank < dp_size
+
+        from vllm.platforms import current_platform
+        if current_platform.is_cuda_alike():
+            from vllm.platforms.cuda import device_id_to_physical_device_id
+            tp_size = vllm_config.parallel_config.tensor_parallel_size
+            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
+                str(device_id_to_physical_device_id(i))
+                for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) *
+                               tp_size))
+
+        self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
+
+        # Initialize the engine after setting up environment.
+        super().__init__(input_path, output_path, vllm_config, executor_class,
+                         log_stats, dp_rank)
+
+        # Counts forward-passes of the model so that we can synchronize
+        # finished with DP peers every N steps.
+        self.counter = 0
+
+    def shutdown(self):
+        super().shutdown()
+        if dp_group := getattr(self, "dp_group", None):
+            stateless_destroy_torch_distributed_process_group(dp_group)
+
+    def run_busy_loop(self):
+        """Core busy loop of the EngineCore for data parallel case."""
+
+        # Loop until process is sent a SIGINT or SIGTERM
+        while True:
+            # 1) Poll the input queue until there is work to do.
+            self._process_input_queue()
+
+            local_unfinished_reqs = self.scheduler.has_unfinished_requests()
+
+            if local_unfinished_reqs:
+                # 2) Step the engine core.
+                self._process_engine_step()
+
+                # Check if we have now finished all requests.
+                local_unfinished_reqs = (
+                    self.scheduler.has_unfinished_requests())
+            else:
+                if self.scheduler.has_finished_requests():
+                    # There are no unfinished requests, but there are some
+                    # finished requests remaining to be removed from the
+                    # batch state. This engine step won't perform a forward
+                    # pass but will flush the finished requests to ensure
+                    # up-to-date state is returned in the engine outputs.
+                    self._process_engine_step()
+
+                if not self.global_unfinished_reqs:
+                    # All engines are idle.
+                    continue
+
+                # There must be unfinished requests in DP peers, run a
+                # dummy forward pass.
+                self.execute_dummy_batch()
+
+            # 3) All-reduce operation to determine global unfinished reqs.
+            self.global_unfinished_reqs = self._has_global_unfinished_reqs(
+                local_unfinished_reqs)
+
+            if not self.global_unfinished_reqs:
+                # Notify client that we are pausing the loop.
+                self.output_queue.put_nowait(ENGINE_PAUSED_OUTPUTS)
+
+    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
+
+        # Optimization - only perform finish-sync all-reduce every 16 steps.
+        self.counter += 1
+        if self.counter != 16:
+            return True
+        self.counter = 0
+
+        return ParallelConfig.has_unfinished_dp(self.dp_group,
+                                                local_unfinished)
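Note: two mechanisms in DPEngineCoreProc are worth spelling out. First, each local DP rank is confined to its own contiguous slice of GPUs via CUDA_VISIBLE_DEVICES: rank r with tensor-parallel size t owns devices [r*t, (r+1)*t) (the real code additionally maps logical to physical device ids with device_id_to_physical_device_id; the sketch below skips that step). Second, _has_global_unfinished_reqs debounces the group-wide finish check: between syncs it optimistically reports "unfinished" so all ranks keep stepping together, and only every 16th step pays for the all-reduce. A standalone sketch of both, with the all-reduce stubbed:

    def visible_devices(local_dp_rank: int, tp_size: int) -> str:
        # DP rank r owns GPUs [r * tp_size, (r + 1) * tp_size).
        first = local_dp_rank * tp_size
        return ",".join(str(i) for i in range(first, first + tp_size))

    assert visible_devices(local_dp_rank=1, tp_size=2) == "2,3"

    class FinishSync:
        """Assumes all_reduce ORs `local` across DP ranks (stubbed here)."""

        def __init__(self, all_reduce) -> None:
            self.counter = 0
            self.all_reduce = all_reduce

        def has_global_unfinished(self, local: bool) -> bool:
            # Between syncs, report "unfinished" so idle ranks keep running
            # dummy batches in lock-step with busy peers.
            self.counter += 1
            if self.counter != 16:
                return True
            self.counter = 0
            return self.all_reduce(local)

    sync = FinishSync(all_reduce=lambda local: local)
    assert all(sync.has_global_unfinished(False) for _ in range(15))
    assert not sync.has_global_unfinished(False)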