[V1] DP scale-out (2/N): Decouple engine process management and comms (#15977)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import json
|
||||
import os
|
||||
import queue
|
||||
import signal
|
||||
@@ -23,7 +22,7 @@ from vllm.logging_utils.dump_input import dump_engine_exception
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.transformers_utils.config import (
|
||||
maybe_register_config_serialize_by_value)
|
||||
from vllm.utils import resolve_obj_by_qualname, zmq_socket_ctx
|
||||
from vllm.utils import make_zmq_socket, resolve_obj_by_qualname, zmq_socket_ctx
|
||||
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
|
||||
unify_kv_cache_configs)
|
||||
from vllm.v1.core.sched.interface import SchedulerInterface
|
||||
@@ -43,6 +42,7 @@ from vllm.version import __version__ as VLLM_VERSION
|
||||
logger = init_logger(__name__)
|
||||
|
||||
POLLING_TIMEOUT_S = 2.5
|
||||
HANDSHAKE_TIMEOUT_MINS = 5
|
||||
|
||||
_R = TypeVar('_R') # Return type for collective_rpc
|
||||
|
||||
@@ -348,9 +348,9 @@ class EngineCoreProc(EngineCore):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_path: str,
|
||||
output_path: str,
|
||||
vllm_config: VllmConfig,
|
||||
on_head_node: bool,
|
||||
input_address: str,
|
||||
executor_class: type[Executor],
|
||||
log_stats: bool,
|
||||
engine_index: int = 0,
|
||||
@@ -360,28 +360,91 @@ class EngineCoreProc(EngineCore):
|
||||
executor_fail_callback = lambda: input_queue.put_nowait(
|
||||
(EngineCoreRequestType.EXECUTOR_FAILED, b''))
|
||||
|
||||
super().__init__(vllm_config, executor_class, log_stats,
|
||||
executor_fail_callback)
|
||||
# Create input socket.
|
||||
input_ctx = zmq.Context()
|
||||
identity = engine_index.to_bytes(length=2, byteorder="little")
|
||||
input_socket = make_zmq_socket(input_ctx,
|
||||
input_address,
|
||||
zmq.DEALER,
|
||||
identity=identity,
|
||||
bind=False)
|
||||
try:
|
||||
# Register engine with front-end.
|
||||
output_address = self.startup_handshake(
|
||||
input_socket, on_head_node, vllm_config.parallel_config)
|
||||
|
||||
self.step_fn = (self.step if self.batch_queue is None else
|
||||
self.step_with_batch_queue)
|
||||
self.engines_running = False
|
||||
# Update config which may have changed from the handshake.
|
||||
vllm_config.__post_init__()
|
||||
|
||||
# Background Threads and Queues for IO. These enable us to
|
||||
# overlap ZMQ socket IO with GPU since they release the GIL,
|
||||
# and to overlap some serialization/deserialization with the
|
||||
# model forward pass.
|
||||
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
|
||||
self.input_queue = input_queue
|
||||
self.output_queue = queue.Queue[Union[EngineCoreOutputs, bytes]]()
|
||||
threading.Thread(target=self.process_input_socket,
|
||||
args=(input_path, engine_index),
|
||||
daemon=True).start()
|
||||
self.output_thread = threading.Thread(
|
||||
target=self.process_output_socket,
|
||||
args=(output_path, engine_index),
|
||||
daemon=True)
|
||||
self.output_thread.start()
|
||||
# Set up data parallel environment.
|
||||
self._init_data_parallel(vllm_config)
|
||||
|
||||
# Initialize engine core and model.
|
||||
super().__init__(vllm_config, executor_class, log_stats,
|
||||
executor_fail_callback)
|
||||
|
||||
self.step_fn = (self.step if self.batch_queue is None else
|
||||
self.step_with_batch_queue)
|
||||
self.engines_running = False
|
||||
|
||||
# Send ready message.
|
||||
num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
|
||||
input_socket.send(
|
||||
msgspec.msgpack.encode({
|
||||
"status": "READY",
|
||||
"local": on_head_node,
|
||||
"num_gpu_blocks": num_gpu_blocks,
|
||||
}))
|
||||
|
||||
# Background Threads and Queues for IO. These enable us to
|
||||
# overlap ZMQ socket IO with GPU since they release the GIL,
|
||||
# and to overlap some serialization/deserialization with the
|
||||
# model forward pass.
|
||||
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
|
||||
self.input_queue = input_queue
|
||||
self.output_queue = queue.Queue[Union[EngineCoreOutputs, bytes]]()
|
||||
threading.Thread(target=self.process_input_socket,
|
||||
args=(input_socket, ),
|
||||
daemon=True).start()
|
||||
input_socket = None
|
||||
self.output_thread = threading.Thread(
|
||||
target=self.process_output_socket,
|
||||
args=(output_address, engine_index),
|
||||
daemon=True)
|
||||
self.output_thread.start()
|
||||
finally:
|
||||
if input_socket is not None:
|
||||
input_socket.close(linger=0)
|
||||
|
||||
@staticmethod
|
||||
def startup_handshake(input_socket: zmq.Socket, on_head_node: bool,
|
||||
parallel_config: ParallelConfig) -> str:
|
||||
|
||||
# Send registration message.
|
||||
input_socket.send(
|
||||
msgspec.msgpack.encode({
|
||||
"status": "HELLO",
|
||||
"local": on_head_node,
|
||||
}))
|
||||
|
||||
# Receive initialization message.
|
||||
logger.info("Waiting for init message from front-end.")
|
||||
if not input_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60 * 1000):
|
||||
raise RuntimeError("Did not receive response from front-end "
|
||||
f"process within {HANDSHAKE_TIMEOUT_MINS} "
|
||||
f"minutes")
|
||||
init_bytes = input_socket.recv()
|
||||
init_message = msgspec.msgpack.decode(init_bytes)
|
||||
logger.debug("Received init message: %s", init_message)
|
||||
|
||||
output_socket_address = init_message["output_socket_address"]
|
||||
#TBD(nick) maybe replace IP with configured head node address
|
||||
|
||||
received_parallel_config = init_message["parallel_config"]
|
||||
for key, value in received_parallel_config.items():
|
||||
setattr(parallel_config, key, value)
|
||||
|
||||
return output_socket_address
|
||||
|
||||
@staticmethod
|
||||
def run_engine_core(*args,
|
||||
@@ -412,7 +475,7 @@ class EngineCoreProc(EngineCore):
|
||||
try:
|
||||
parallel_config: ParallelConfig = kwargs[
|
||||
"vllm_config"].parallel_config
|
||||
if parallel_config.data_parallel_size > 1:
|
||||
if parallel_config.data_parallel_size > 1 or dp_rank > 0:
|
||||
# Set data parallel rank for this engine process.
|
||||
parallel_config.data_parallel_rank = dp_rank
|
||||
parallel_config.data_parallel_rank_local = local_dp_rank
|
||||
@@ -436,6 +499,9 @@ class EngineCoreProc(EngineCore):
|
||||
if engine_core is not None:
|
||||
engine_core.shutdown()
|
||||
|
||||
def _init_data_parallel(self, vllm_config: VllmConfig):
|
||||
pass
|
||||
|
||||
def run_busy_loop(self):
|
||||
"""Core busy loop of the EngineCore."""
|
||||
|
||||
@@ -527,40 +593,25 @@ class EngineCoreProc(EngineCore):
|
||||
logger.fatal("vLLM shutdown signal from EngineCore failed "
|
||||
"to send. Please report this issue.")
|
||||
|
||||
def process_input_socket(self, input_path: str, engine_index: int):
|
||||
def process_input_socket(self, input_socket: zmq.Socket):
|
||||
"""Input socket IO thread."""
|
||||
|
||||
# Msgpack serialization decoding.
|
||||
add_request_decoder = MsgpackDecoder(EngineCoreRequest)
|
||||
generic_decoder = MsgpackDecoder()
|
||||
identity = engine_index.to_bytes(length=2, byteorder="little")
|
||||
|
||||
with zmq_socket_ctx(input_path,
|
||||
zmq.DEALER,
|
||||
identity=identity,
|
||||
bind=False) as socket:
|
||||
while True:
|
||||
# (RequestType, RequestData)
|
||||
type_frame, *data_frames = input_socket.recv_multipart(copy=False)
|
||||
request_type = EngineCoreRequestType(bytes(type_frame.buffer))
|
||||
|
||||
# Send ready message to front-end once input socket is connected.
|
||||
message_dict = {
|
||||
'type': 'READY',
|
||||
'num_gpu_blocks': self.vllm_config.cache_config.num_gpu_blocks,
|
||||
}
|
||||
message = json.dumps(message_dict).encode('utf-8')
|
||||
socket.send(message)
|
||||
# Deserialize the request data.
|
||||
decoder = add_request_decoder if (
|
||||
request_type == EngineCoreRequestType.ADD) else generic_decoder
|
||||
request = decoder.decode(data_frames)
|
||||
|
||||
while True:
|
||||
# (RequestType, RequestData)
|
||||
type_frame, *data_frames = socket.recv_multipart(copy=False)
|
||||
request_type = EngineCoreRequestType(bytes(type_frame.buffer))
|
||||
|
||||
# Deserialize the request data.
|
||||
decoder = add_request_decoder if (
|
||||
request_type
|
||||
== EngineCoreRequestType.ADD) else generic_decoder
|
||||
request = decoder.decode(data_frames)
|
||||
|
||||
# Push to input queue for core busy loop.
|
||||
self.input_queue.put_nowait((request_type, request))
|
||||
# Push to input queue for core busy loop.
|
||||
self.input_queue.put_nowait((request_type, request))
|
||||
|
||||
def process_output_socket(self, output_path: str, engine_index: int):
|
||||
"""Output socket IO thread."""
|
||||
@@ -609,9 +660,9 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_path: str,
|
||||
output_path: str,
|
||||
vllm_config: VllmConfig,
|
||||
on_head_node: bool,
|
||||
input_address: str,
|
||||
executor_class: type[Executor],
|
||||
log_stats: bool,
|
||||
):
|
||||
@@ -623,8 +674,20 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
_add_prefix(sys.stdout, process_name, pid)
|
||||
_add_prefix(sys.stderr, process_name, pid)
|
||||
|
||||
dp_size = vllm_config.parallel_config.data_parallel_size
|
||||
# Counts forward-passes of the model so that we can synchronize
|
||||
# finished with DP peers every N steps.
|
||||
self.counter = 0
|
||||
|
||||
# Initialize the engine.
|
||||
dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||
super().__init__(vllm_config, on_head_node, input_address,
|
||||
executor_class, log_stats, dp_rank)
|
||||
|
||||
def _init_data_parallel(self, vllm_config: VllmConfig):
|
||||
|
||||
# Configure GPUs and stateless process group for data parallel.
|
||||
dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||
dp_size = vllm_config.parallel_config.data_parallel_size
|
||||
local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
|
||||
|
||||
assert dp_size > 1
|
||||
@@ -632,24 +695,16 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
device_control_env_var = current_platform.device_control_env_var
|
||||
tp_size = vllm_config.parallel_config.tensor_parallel_size
|
||||
world_size = vllm_config.parallel_config.world_size
|
||||
os.environ[device_control_env_var] = ",".join(
|
||||
str(current_platform.device_id_to_physical_device_id(i))
|
||||
for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) *
|
||||
tp_size))
|
||||
for i in range(local_dp_rank * world_size, (local_dp_rank + 1) *
|
||||
world_size))
|
||||
|
||||
self.local_dp_rank = local_dp_rank
|
||||
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
|
||||
self.current_wave = 0
|
||||
|
||||
# Initialize the engine after setting up environment.
|
||||
super().__init__(input_path, output_path, vllm_config, executor_class,
|
||||
log_stats, dp_rank)
|
||||
|
||||
# Counts forward-passes of the model so that we can synchronize
|
||||
# finished with DP peers every N steps.
|
||||
self.counter = 0
|
||||
|
||||
def shutdown(self):
|
||||
super().shutdown()
|
||||
if dp_group := getattr(self, "dp_group", None):
|
||||
|
||||
Reference in New Issue
Block a user