[V1] DP scale-out (2/N): Decouple engine process management and comms (#15977)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill
2025-05-13 10:48:21 -07:00
committed by GitHub
parent 0b217da646
commit 55aa7af994
10 changed files with 516 additions and 243 deletions

View File

@@ -1,5 +1,4 @@
# SPDX-License-Identifier: Apache-2.0
import json
import os
import queue
import signal
@@ -23,7 +22,7 @@ from vllm.logging_utils.dump_input import dump_engine_exception
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.utils import resolve_obj_by_qualname, zmq_socket_ctx
from vllm.utils import make_zmq_socket, resolve_obj_by_qualname, zmq_socket_ctx
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
unify_kv_cache_configs)
from vllm.v1.core.sched.interface import SchedulerInterface
@@ -43,6 +42,7 @@ from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
POLLING_TIMEOUT_S = 2.5
HANDSHAKE_TIMEOUT_MINS = 5
_R = TypeVar('_R') # Return type for collective_rpc
@@ -348,9 +348,9 @@ class EngineCoreProc(EngineCore):
def __init__(
self,
input_path: str,
output_path: str,
vllm_config: VllmConfig,
on_head_node: bool,
input_address: str,
executor_class: type[Executor],
log_stats: bool,
engine_index: int = 0,
@@ -360,28 +360,91 @@ class EngineCoreProc(EngineCore):
executor_fail_callback = lambda: input_queue.put_nowait(
(EngineCoreRequestType.EXECUTOR_FAILED, b''))
super().__init__(vllm_config, executor_class, log_stats,
executor_fail_callback)
# Create input socket.
input_ctx = zmq.Context()
identity = engine_index.to_bytes(length=2, byteorder="little")
input_socket = make_zmq_socket(input_ctx,
input_address,
zmq.DEALER,
identity=identity,
bind=False)
try:
# Register engine with front-end.
output_address = self.startup_handshake(
input_socket, on_head_node, vllm_config.parallel_config)
self.step_fn = (self.step if self.batch_queue is None else
self.step_with_batch_queue)
self.engines_running = False
# Update config which may have changed from the handshake.
vllm_config.__post_init__()
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
self.input_queue = input_queue
self.output_queue = queue.Queue[Union[EngineCoreOutputs, bytes]]()
threading.Thread(target=self.process_input_socket,
args=(input_path, engine_index),
daemon=True).start()
self.output_thread = threading.Thread(
target=self.process_output_socket,
args=(output_path, engine_index),
daemon=True)
self.output_thread.start()
# Set up data parallel environment.
self._init_data_parallel(vllm_config)
# Initialize engine core and model.
super().__init__(vllm_config, executor_class, log_stats,
executor_fail_callback)
self.step_fn = (self.step if self.batch_queue is None else
self.step_with_batch_queue)
self.engines_running = False
# Send ready message.
num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
input_socket.send(
msgspec.msgpack.encode({
"status": "READY",
"local": on_head_node,
"num_gpu_blocks": num_gpu_blocks,
}))
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
self.input_queue = input_queue
self.output_queue = queue.Queue[Union[EngineCoreOutputs, bytes]]()
threading.Thread(target=self.process_input_socket,
args=(input_socket, ),
daemon=True).start()
input_socket = None
self.output_thread = threading.Thread(
target=self.process_output_socket,
args=(output_address, engine_index),
daemon=True)
self.output_thread.start()
finally:
if input_socket is not None:
input_socket.close(linger=0)
@staticmethod
def startup_handshake(input_socket: zmq.Socket, on_head_node: bool,
parallel_config: ParallelConfig) -> str:
# Send registration message.
input_socket.send(
msgspec.msgpack.encode({
"status": "HELLO",
"local": on_head_node,
}))
# Receive initialization message.
logger.info("Waiting for init message from front-end.")
if not input_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60 * 1000):
raise RuntimeError("Did not receive response from front-end "
f"process within {HANDSHAKE_TIMEOUT_MINS} "
f"minutes")
init_bytes = input_socket.recv()
init_message = msgspec.msgpack.decode(init_bytes)
logger.debug("Received init message: %s", init_message)
output_socket_address = init_message["output_socket_address"]
#TBD(nick) maybe replace IP with configured head node address
received_parallel_config = init_message["parallel_config"]
for key, value in received_parallel_config.items():
setattr(parallel_config, key, value)
return output_socket_address
@staticmethod
def run_engine_core(*args,
@@ -412,7 +475,7 @@ class EngineCoreProc(EngineCore):
try:
parallel_config: ParallelConfig = kwargs[
"vllm_config"].parallel_config
if parallel_config.data_parallel_size > 1:
if parallel_config.data_parallel_size > 1 or dp_rank > 0:
# Set data parallel rank for this engine process.
parallel_config.data_parallel_rank = dp_rank
parallel_config.data_parallel_rank_local = local_dp_rank
@@ -436,6 +499,9 @@ class EngineCoreProc(EngineCore):
if engine_core is not None:
engine_core.shutdown()
def _init_data_parallel(self, vllm_config: VllmConfig):
pass
def run_busy_loop(self):
"""Core busy loop of the EngineCore."""
@@ -527,40 +593,25 @@ class EngineCoreProc(EngineCore):
logger.fatal("vLLM shutdown signal from EngineCore failed "
"to send. Please report this issue.")
def process_input_socket(self, input_path: str, engine_index: int):
def process_input_socket(self, input_socket: zmq.Socket):
"""Input socket IO thread."""
# Msgpack serialization decoding.
add_request_decoder = MsgpackDecoder(EngineCoreRequest)
generic_decoder = MsgpackDecoder()
identity = engine_index.to_bytes(length=2, byteorder="little")
with zmq_socket_ctx(input_path,
zmq.DEALER,
identity=identity,
bind=False) as socket:
while True:
# (RequestType, RequestData)
type_frame, *data_frames = input_socket.recv_multipart(copy=False)
request_type = EngineCoreRequestType(bytes(type_frame.buffer))
# Send ready message to front-end once input socket is connected.
message_dict = {
'type': 'READY',
'num_gpu_blocks': self.vllm_config.cache_config.num_gpu_blocks,
}
message = json.dumps(message_dict).encode('utf-8')
socket.send(message)
# Deserialize the request data.
decoder = add_request_decoder if (
request_type == EngineCoreRequestType.ADD) else generic_decoder
request = decoder.decode(data_frames)
while True:
# (RequestType, RequestData)
type_frame, *data_frames = socket.recv_multipart(copy=False)
request_type = EngineCoreRequestType(bytes(type_frame.buffer))
# Deserialize the request data.
decoder = add_request_decoder if (
request_type
== EngineCoreRequestType.ADD) else generic_decoder
request = decoder.decode(data_frames)
# Push to input queue for core busy loop.
self.input_queue.put_nowait((request_type, request))
# Push to input queue for core busy loop.
self.input_queue.put_nowait((request_type, request))
def process_output_socket(self, output_path: str, engine_index: int):
"""Output socket IO thread."""
@@ -609,9 +660,9 @@ class DPEngineCoreProc(EngineCoreProc):
def __init__(
self,
input_path: str,
output_path: str,
vllm_config: VllmConfig,
on_head_node: bool,
input_address: str,
executor_class: type[Executor],
log_stats: bool,
):
@@ -623,8 +674,20 @@ class DPEngineCoreProc(EngineCoreProc):
_add_prefix(sys.stdout, process_name, pid)
_add_prefix(sys.stderr, process_name, pid)
dp_size = vllm_config.parallel_config.data_parallel_size
# Counts forward-passes of the model so that we can synchronize
# finished with DP peers every N steps.
self.counter = 0
# Initialize the engine.
dp_rank = vllm_config.parallel_config.data_parallel_rank
super().__init__(vllm_config, on_head_node, input_address,
executor_class, log_stats, dp_rank)
def _init_data_parallel(self, vllm_config: VllmConfig):
# Configure GPUs and stateless process group for data parallel.
dp_rank = vllm_config.parallel_config.data_parallel_rank
dp_size = vllm_config.parallel_config.data_parallel_size
local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
assert dp_size > 1
@@ -632,24 +695,16 @@ class DPEngineCoreProc(EngineCoreProc):
from vllm.platforms import current_platform
device_control_env_var = current_platform.device_control_env_var
tp_size = vllm_config.parallel_config.tensor_parallel_size
world_size = vllm_config.parallel_config.world_size
os.environ[device_control_env_var] = ",".join(
str(current_platform.device_id_to_physical_device_id(i))
for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) *
tp_size))
for i in range(local_dp_rank * world_size, (local_dp_rank + 1) *
world_size))
self.local_dp_rank = local_dp_rank
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
self.current_wave = 0
# Initialize the engine after setting up environment.
super().__init__(input_path, output_path, vllm_config, executor_class,
log_stats, dp_rank)
# Counts forward-passes of the model so that we can synchronize
# finished with DP peers every N steps.
self.counter = 0
def shutdown(self):
super().shutdown()
if dp_group := getattr(self, "dp_group", None):