[Perf] API-server scaleout with many-to-many server-engine comms (#17546)

commit 2dbe8c0774 (parent 84ec470fca)
Author:    Nick Hill
Date:      2025-05-30 08:17:00 -07:00
Committer: GitHub

26 changed files with 1828 additions and 436 deletions

vllm/v1/engine/core.py

@@ -7,6 +7,7 @@ import threading
 import time
 from collections import deque
 from concurrent.futures import Future
+from contextlib import ExitStack
 from inspect import isclass, signature
 from logging import DEBUG
 from typing import Any, Callable, Optional, TypeVar, Union
@@ -22,7 +23,7 @@ from vllm.logging_utils.dump_input import dump_engine_exception
 from vllm.lora.request import LoRARequest
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
-from vllm.utils import make_zmq_socket, resolve_obj_by_qualname, zmq_socket_ctx
+from vllm.utils import make_zmq_socket, resolve_obj_by_qualname
 from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
                                          unify_kv_cache_configs)
 from vllm.v1.core.sched.interface import SchedulerInterface
@@ -33,10 +34,12 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
 from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.metrics.stats import SchedulerStats
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
 from vllm.v1.structured_output import StructuredOutputManager
+from vllm.v1.utils import EngineHandshakeMetadata, EngineZmqAddresses
 from vllm.version import __version__ as VLLM_VERSION

 logger = init_logger(__name__)
@@ -211,7 +214,7 @@ class EngineCore:
             # Re-raise exception
             raise err

-    def step(self) -> tuple[EngineCoreOutputs, bool]:
+    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
         """Schedule, execute, and make output.

         Returns tuple of outputs and a flag indicating whether the model
@@ -221,10 +224,7 @@ class EngineCore:
         # Check for any requests remaining in the scheduler - unfinished,
         # or finished and not yet removed from the batch.
         if not self.scheduler.has_requests():
-            return EngineCoreOutputs(
-                outputs=[],
-                scheduler_stats=self.scheduler.make_stats(),
-            ), False
+            return {}, False
         scheduler_output = self.scheduler.schedule()
         model_output = self.execute_model(scheduler_output)
         engine_core_outputs = self.scheduler.update_from_output(
@@ -234,7 +234,7 @@ class EngineCore:
                 scheduler_output.total_num_scheduled_tokens > 0)

     def step_with_batch_queue(
-            self) -> tuple[Optional[EngineCoreOutputs], bool]:
+            self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]:
         """Schedule and execute batches with the batch queue.

         Note that if nothing to output in this step, None is returned.
@@ -276,8 +276,8 @@ class EngineCore:
             # Blocking until the first result is available.
             model_output = future.result()
             self.batch_queue.task_done()
-            engine_core_outputs = self.scheduler.update_from_output(
-                scheduler_output, model_output)
+            engine_core_outputs = (self.scheduler.update_from_output(
+                scheduler_output, model_output))

         return engine_core_outputs, scheduled_batch
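
With this change, both step variants return their outputs grouped by front-end client: a dict mapping the index of the owning API-server client to the EngineCoreOutputs destined for it (empty or None when there is nothing to send), plus the model-executed flag. A minimal sketch of a consumer of the new contract; `drain_one_step`, `engine`, and `output_queue` are illustrative glue, not part of the PR:

    import queue

    def drain_one_step(engine, output_queue: queue.Queue) -> bool:
        """Run one engine step and fan its outputs out per client."""
        outputs_by_client, model_executed = engine.step()
        for client_idx, outs in outputs_by_client.items():
            # client_idx picks which front-end's output socket receives
            # outs; index -1 is reserved for the DP coordinator.
            output_queue.put_nowait((client_idx, outs))
        return model_executed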
@@ -362,7 +362,7 @@ class EngineCoreProc(EngineCore):
         self,
         vllm_config: VllmConfig,
         on_head_node: bool,
-        input_address: str,
+        handshake_address: str,
         executor_class: type[Executor],
         log_stats: bool,
         engine_index: int = 0,
@@ -375,65 +375,70 @@ class EngineCoreProc(EngineCore):

         # Create input socket.
         input_ctx = zmq.Context()
         identity = engine_index.to_bytes(length=2, byteorder="little")
-        input_socket = make_zmq_socket(input_ctx,
-                                       input_address,
-                                       zmq.DEALER,
-                                       identity=identity,
-                                       bind=False)
-        try:
+        with make_zmq_socket(input_ctx,
+                             handshake_address,
+                             zmq.DEALER,
+                             identity=identity,
+                             linger=5000,
+                             bind=False) as handshake_socket:
             # Register engine with front-end.
-            output_address = self.startup_handshake(
-                input_socket, on_head_node, vllm_config.parallel_config)
+            addresses = self.startup_handshake(handshake_socket, on_head_node,
+                                               vllm_config.parallel_config)
+            self.client_count = len(addresses.outputs)

             # Update config which may have changed from the handshake.
             vllm_config.__post_init__()

             # Set up data parallel environment.
+            self.has_coordinator = addresses.coordinator_output is not None
             self._init_data_parallel(vllm_config)

             # Initialize engine core and model.
             super().__init__(vllm_config, executor_class, log_stats,
                              executor_fail_callback)

             self.engine_index = engine_index
             self.step_fn = (self.step if self.batch_queue is None else
                             self.step_with_batch_queue)
             self.engines_running = False
+            self.last_counts = (0, 0)

             # Send ready message.
             num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
-            input_socket.send(
+            handshake_socket.send(
                 msgspec.msgpack.encode({
                     "status": "READY",
                     "local": on_head_node,
                     "num_gpu_blocks": num_gpu_blocks,
                 }))

-            # Background Threads and Queues for IO. These enable us to
-            # overlap ZMQ socket IO with GPU since they release the GIL,
-            # and to overlap some serialization/deserialization with the
-            # model forward pass.
-            # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
-            self.input_queue = input_queue
-            self.output_queue = queue.Queue[Union[EngineCoreOutputs, bytes]]()
-            threading.Thread(target=self.process_input_socket,
-                             args=(input_socket, ),
-                             daemon=True).start()
-            input_socket = None
-            self.output_thread = threading.Thread(
-                target=self.process_output_socket,
-                args=(output_address, engine_index),
-                daemon=True)
-            self.output_thread.start()
-        finally:
-            if input_socket is not None:
-                input_socket.close(linger=0)
+        # Background Threads and Queues for IO. These enable us to
+        # overlap ZMQ socket IO with GPU since they release the GIL,
+        # and to overlap some serialization/deserialization with the
+        # model forward pass.
+        # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
+        self.input_queue = input_queue
+        self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
+                                              bytes]]()
+        threading.Thread(target=self.process_input_sockets,
+                         args=(addresses.inputs, addresses.coordinator_input,
+                               identity),
+                         daemon=True).start()
+        self.output_thread = threading.Thread(
+            target=self.process_output_sockets,
+            args=(addresses.outputs, addresses.coordinator_output,
+                  engine_index),
+            daemon=True)
+        self.output_thread.start()

     @staticmethod
-    def startup_handshake(input_socket: zmq.Socket, on_head_node: bool,
-                          parallel_config: ParallelConfig) -> str:
+    def startup_handshake(
+            handshake_socket: zmq.Socket, on_head_node: bool,
+            parallel_config: ParallelConfig) -> EngineZmqAddresses:
         # Send registration message.
-        input_socket.send(
+        handshake_socket.send(
             msgspec.msgpack.encode({
                 "status": "HELLO",
                 "local": on_head_node,
@@ -441,22 +446,20 @@ class EngineCoreProc(EngineCore):

         # Receive initialization message.
         logger.info("Waiting for init message from front-end.")
-        if not input_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60 * 1000):
+        if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000):
             raise RuntimeError("Did not receive response from front-end "
                                f"process within {HANDSHAKE_TIMEOUT_MINS} "
                                f"minutes")
-        init_bytes = input_socket.recv()
-        init_message = msgspec.msgpack.decode(init_bytes)
+        init_bytes = handshake_socket.recv()
+        init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
+            init_bytes, type=EngineHandshakeMetadata)
         logger.debug("Received init message: %s", init_message)

-        output_socket_address = init_message["output_socket_address"]
-        #TBD(nick) maybe replace IP with configured head node address
-        received_parallel_config = init_message["parallel_config"]
+        received_parallel_config = init_message.parallel_config
         for key, value in received_parallel_config.items():
             setattr(parallel_config, key, value)
-        return output_socket_address
+        return init_message.addresses

     @staticmethod
     def run_engine_core(*args,
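
The three-step handshake (HELLO → init metadata → READY) is easy to prototype outside vLLM. A minimal sketch with pyzmq and msgspec, assuming an inproc transport and a plain dict in place of EngineHandshakeMetadata:

    import threading

    import msgspec
    import zmq

    ADDR = "inproc://handshake"

    ctx = zmq.Context()
    router = ctx.socket(zmq.ROUTER)
    router.bind(ADDR)

    def engine_side():
        # DEALER with a fixed 2-byte identity, mirroring EngineCoreProc.
        sock = ctx.socket(zmq.DEALER)
        sock.setsockopt(zmq.IDENTITY, (0).to_bytes(2, "little"))
        sock.connect(ADDR)
        sock.send(msgspec.msgpack.encode({"status": "HELLO", "local": True}))
        init_message = msgspec.msgpack.decode(sock.recv())
        assert "addresses" in init_message
        sock.send(msgspec.msgpack.encode({
            "status": "READY", "local": True, "num_gpu_blocks": 1024}))
        sock.close()

    threading.Thread(target=engine_side, daemon=True).start()

    # Front-end: ROUTER sees the engine's identity as the first frame.
    identity, hello = router.recv_multipart()
    assert msgspec.msgpack.decode(hello)["status"] == "HELLO"
    init = {"addresses": {"inputs": [], "outputs": []}, "parallel_config": {}}
    router.send_multipart([identity, msgspec.msgpack.encode(init)])
    identity, ready = router.recv_multipart()
    print("engine ready:", msgspec.msgpack.decode(ready))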
@@ -528,7 +531,7 @@ class EngineCoreProc(EngineCore):
         """Exits when an engine step needs to be performed."""

         waited = False
-        while not self.engines_running and not (self.scheduler.has_requests()):
+        while not self.engines_running and not self.scheduler.has_requests():
             if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
                 logger.debug("EngineCore waiting for work.")
                 waited = True
@@ -549,8 +552,8 @@ class EngineCoreProc(EngineCore):
         # Step the engine core.
         outputs, model_executed = self.step_fn()

         # Put EngineCoreOutputs into the output queue.
-        if outputs is not None:
-            self.output_queue.put_nowait(outputs)
+        for output in (outputs.items() if outputs else ()):
+            self.output_queue.put_nowait(output)

         return model_executed
@@ -563,7 +566,7 @@ class EngineCoreProc(EngineCore):
         elif request_type == EngineCoreRequestType.ABORT:
             self.abort_requests(request)
         elif request_type == EngineCoreRequestType.UTILITY:
-            call_id, method_name, args = request
+            client_idx, call_id, method_name, args = request
             output = UtilityOutput(call_id)
             try:
                 method = getattr(self, method_name)
@@ -574,7 +577,7 @@ class EngineCoreProc(EngineCore):
                 output.failure_message = (f"Call to {method_name} method"
                                           f" failed: {str(e)}")
             self.output_queue.put_nowait(
                 (client_idx, EngineCoreOutputs(utility_output=output)))
         elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
             raise RuntimeError("Executor failed.")
         else:
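
With many front-ends, a utility RPC must carry the index of the issuing client so the response can be routed back over that client's output socket rather than broadcast. A sketch of the widened request tuple; the concrete values are hypothetical:

    import msgspec

    # Illustrative stand-in for the UTILITY wire format in the
    # many-client setup: (client_idx, call_id, method_name, args).
    payload = msgspec.msgpack.encode((2, 17, "add_lora", ()))

    client_idx, call_id, method_name, args = msgspec.msgpack.decode(payload)
    # The engine replies on output socket `client_idx`; the response
    # carries `call_id` so the issuing client can resolve its future.
    assert (client_idx, call_id, method_name) == (2, 17, "add_lora")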
@@ -607,27 +610,68 @@ class EngineCoreProc(EngineCore):
             logger.fatal("vLLM shutdown signal from EngineCore failed "
                          "to send. Please report this issue.")

-    def process_input_socket(self, input_socket: zmq.Socket):
+    def process_input_sockets(self, input_addresses: list[str],
+                              coord_input_address: Optional[str],
+                              identity: bytes):
         """Input socket IO thread."""

         # Msgpack serialization decoding.
         add_request_decoder = MsgpackDecoder(EngineCoreRequest)
         generic_decoder = MsgpackDecoder()

-        while True:
-            # (RequestType, RequestData)
-            type_frame, *data_frames = input_socket.recv_multipart(copy=False)
-            request_type = EngineCoreRequestType(bytes(type_frame.buffer))
+        with ExitStack() as stack, zmq.Context() as ctx:
+            input_sockets = [
+                stack.enter_context(
+                    make_zmq_socket(ctx,
+                                    input_address,
+                                    zmq.DEALER,
+                                    identity=identity,
+                                    bind=False))
+                for input_address in input_addresses
+            ]
+            if coord_input_address is None:
+                coord_socket = None
+            else:
+                coord_socket = stack.enter_context(
+                    make_zmq_socket(ctx,
+                                    coord_input_address,
+                                    zmq.XSUB,
+                                    identity=identity,
+                                    bind=False))
+                # Send subscription message to coordinator.
+                coord_socket.send(b'\x01')

-            # Deserialize the request data.
-            decoder = add_request_decoder if (
-                request_type == EngineCoreRequestType.ADD) else generic_decoder
-            request = decoder.decode(data_frames)
+            # Register sockets with poller.
+            poller = zmq.Poller()
+            for input_socket in input_sockets:
+                # Send initial message to each input socket - this is required
+                # before the front-end ROUTER socket can send input messages
+                # back to us.
+                input_socket.send(b'')
+                poller.register(input_socket, zmq.POLLIN)
+            if coord_socket is not None:
+                poller.register(coord_socket, zmq.POLLIN)

-            # Push to input queue for core busy loop.
-            self.input_queue.put_nowait((request_type, request))
+            while True:
+                for input_socket, _ in poller.poll():
+                    # (RequestType, RequestData)
+                    type_frame, *data_frames = input_socket.recv_multipart(
+                        copy=False)
+                    request_type = EngineCoreRequestType(
+                        bytes(type_frame.buffer))

-    def process_output_socket(self, output_path: str, engine_index: int):
+                    # Deserialize the request data.
+                    decoder = add_request_decoder if (
+                        request_type
+                        == EngineCoreRequestType.ADD) else generic_decoder
+                    request = decoder.decode(data_frames)

+                    # Push to input queue for core busy loop.
+                    self.input_queue.put_nowait((request_type, request))
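
Two pyzmq details here are worth isolating: an XSUB socket must send an explicit subscription frame (0x01 plus an optional topic) before an XPUB publisher will deliver anything to it, and a DEALER must send one frame so the peer ROUTER learns its identity and can route messages back. A minimal sketch of the XSUB half, assuming an inproc XPUB publisher:

    import zmq

    ctx = zmq.Context()

    pub = ctx.socket(zmq.XPUB)
    pub.bind("inproc://coord")

    sub = ctx.socket(zmq.XSUB)
    sub.connect("inproc://coord")
    # 0x01 prefix = subscribe; empty topic = all messages. This is the
    # same frame the engine sends to the coordinator socket above.
    sub.send(b"\x01")

    # XPUB surfaces the subscription as a readable message.
    assert pub.recv() == b"\x01"
    pub.send(b"hello engines")
    assert sub.recv() == b"hello engines"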
+    def process_output_sockets(self, output_paths: list[str],
+                               coord_output_path: Optional[str],
+                               engine_index: int):
         """Output socket IO thread."""

         # Msgpack serialization encoding.
@@ -641,30 +685,49 @@ class EngineCoreProc(EngineCore):

         # We must set linger to ensure the ENGINE_CORE_DEAD
         # message is sent prior to closing the socket.
-        with zmq_socket_ctx(output_path, zmq.constants.PUSH,
-                            linger=4000) as socket:
+        with ExitStack() as stack, zmq.Context() as ctx:
+            sockets = [
+                stack.enter_context(
+                    make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000))
+                for output_path in output_paths
+            ]
+            coord_socket = stack.enter_context(
+                make_zmq_socket(
+                    ctx, coord_output_path, zmq.PUSH, bind=False,
+                    linger=4000)) if coord_output_path is not None else None
+            max_reuse_bufs = len(sockets) + 1

             while True:
-                outputs = self.output_queue.get()
-                if outputs == EngineCoreProc.ENGINE_CORE_DEAD:
-                    socket.send(outputs, copy=False)
+                output = self.output_queue.get()
+                if output == EngineCoreProc.ENGINE_CORE_DEAD:
+                    for socket in sockets:
+                        socket.send(output)
                     break
-                assert not isinstance(outputs, bytes)
+                assert not isinstance(output, bytes)
+                client_index, outputs = output
                 outputs.engine_index = engine_index

+                if client_index == -1:
+                    # Don't reuse buffer for coordinator message
+                    # which will be very small.
+                    assert coord_socket is not None
+                    coord_socket.send_multipart(encoder.encode(outputs))
+                    continue
+
                 # Reclaim buffers that zmq is finished with.
                 while pending and pending[-1][0].done:
                     reuse_buffers.append(pending.pop()[2])

                 buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
                 buffers = encoder.encode_into(outputs, buffer)
-                tracker = socket.send_multipart(buffers,
-                                                copy=False,
-                                                track=True)
+                tracker = sockets[client_index].send_multipart(buffers,
+                                                               copy=False,
+                                                               track=True)
                 if not tracker.done:
                     ref = outputs if len(buffers) > 1 else None
                     pending.appendleft((tracker, ref, buffer))
-                elif len(reuse_buffers) < 2:
-                    # Keep at most 2 buffers to reuse.
+                elif len(reuse_buffers) < max_reuse_bufs:
+                    # Limit the number of buffers to reuse.
                     reuse_buffers.append(buffer)
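
The buffer-reuse scheme above relies on pyzmq's zero-copy tracking: a message sent with copy=False and track=True returns a MessageTracker whose .done flips to True once libzmq no longer references the underlying buffer, at which point the bytearray can safely be handed back to the encoder. A standalone sketch of that pattern, independent of vLLM:

    import zmq

    ctx = zmq.Context()
    push = ctx.socket(zmq.PUSH)
    push.bind("inproc://out")
    pull = ctx.socket(zmq.PULL)
    pull.connect("inproc://out")

    reuse_buffers: list[bytearray] = []
    pending: list[tuple[zmq.MessageTracker, bytearray]] = []

    for i in range(3):
        # Reclaim any buffer that libzmq has finished sending.
        while pending and pending[-1][0].done:
            reuse_buffers.append(pending.pop()[1])

        buffer = reuse_buffers.pop() if reuse_buffers else bytearray(16)
        buffer[:1] = bytes([i])
        tracker = push.send(buffer, copy=False, track=True)
        if not tracker.done:
            # The buffer must not be mutated until the tracker completes.
            pending.insert(0, (tracker, buffer))
        print(pull.recv())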
@@ -676,7 +739,7 @@ class DPEngineCoreProc(EngineCoreProc):
         self,
         vllm_config: VllmConfig,
         on_head_node: bool,
-        input_address: str,
+        handshake_address: str,
         executor_class: type[Executor],
         log_stats: bool,
     ):
@@ -691,10 +754,11 @@ class DPEngineCoreProc(EngineCoreProc):
         # Counts forward-passes of the model so that we can synchronize
         # finished with DP peers every N steps.
        self.counter = 0
+        self.current_wave = 0

         # Initialize the engine.
         dp_rank = vllm_config.parallel_config.data_parallel_rank
-        super().__init__(vllm_config, on_head_node, input_address,
+        super().__init__(vllm_config, on_head_node, handshake_address,
                          executor_class, log_stats, dp_rank)

     def _init_data_parallel(self, vllm_config: VllmConfig):
@@ -726,7 +790,6 @@ class DPEngineCoreProc(EngineCoreProc):

         self.dp_rank = dp_rank
         self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
-        self.current_wave = 0

     def shutdown(self):
         super().shutdown()
@@ -734,22 +797,23 @@ class DPEngineCoreProc(EngineCoreProc):
             stateless_destroy_torch_distributed_process_group(dp_group)

     def add_request(self, request: EngineCoreRequest):
-        if request.current_wave != self.current_wave:
+        if self.has_coordinator and request.current_wave != self.current_wave:
             if request.current_wave > self.current_wave:
                 self.current_wave = request.current_wave
             elif not self.engines_running:
                 # Request received for an already-completed wave, notify
                 # front-end that we need to start the next one.
                 self.output_queue.put_nowait(
-                    EngineCoreOutputs(start_wave=self.current_wave))
+                    (-1, EngineCoreOutputs(start_wave=self.current_wave)))

         super().add_request(request)

     def _handle_client_request(self, request_type: EngineCoreRequestType,
                                request: Any) -> None:
         if request_type == EngineCoreRequestType.START_DP_WAVE:
-            new_wave: int = request
-            if new_wave >= self.current_wave:
+            new_wave, exclude_eng_index = request
+            if exclude_eng_index != self.engine_index and (
+                    new_wave >= self.current_wave):
                 self.current_wave = new_wave
                 if not self.engines_running:
                     logger.debug("EngineCore starting idle loop for wave %d.",
@@ -758,6 +822,18 @@ class DPEngineCoreProc(EngineCoreProc):
         else:
             super()._handle_client_request(request_type, request)

+    def _maybe_publish_request_counts(self):
+        if not self.has_coordinator:
+            return
+
+        # Publish our request counts (if they've changed).
+        counts = self.scheduler.get_request_counts()
+        if counts != self.last_counts:
+            self.last_counts = counts
+            stats = SchedulerStats(*counts)
+            self.output_queue.put_nowait(
+                (-1, EngineCoreOutputs(scheduler_stats=stats)))
+
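The counts go to the coordinator (client index -1) only when they differ from the previously sent pair, so coordinator traffic scales with scheduling changes rather than with raw step rate. The same change-detection idea in isolation; `RequestCountPublisher` and its names are illustrative:

    class RequestCountPublisher:
        """Publish the scheduler's request counts only on change."""

        def __init__(self, publish):
            self.publish = publish  # e.g. output_queue.put_nowait
            self.last_counts = (0, 0)

        def maybe_publish(self, counts: tuple[int, int]) -> None:
            if counts != self.last_counts:
                self.last_counts = counts
                self.publish((-1, counts))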
     def run_busy_loop(self):
         """Core busy loop of the EngineCore for data parallel case."""

@@ -768,6 +844,8 @@ class DPEngineCoreProc(EngineCoreProc):

             # 2) Step the engine core.
             executed = self._process_engine_step()
+            self._maybe_publish_request_counts()

             local_unfinished_reqs = self.scheduler.has_unfinished_requests()

             if not executed:
                 if not local_unfinished_reqs and not self.engines_running:
@@ -788,7 +866,8 @@ class DPEngineCoreProc(EngineCoreProc):
                     logger.debug("Wave %d finished, pausing engine loop.",
                                  self.current_wave)
                     self.output_queue.put_nowait(
-                        EngineCoreOutputs(wave_complete=self.current_wave))
+                        (-1,
+                         EngineCoreOutputs(wave_complete=self.current_wave)))
                     self.current_wave += 1

     def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
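
The wave protocol can be summarized apart from the engine: each engine tracks a monotonically increasing wave counter, requests stamped with a newer wave advance it, and a request for an already-finished wave makes the engine ask the coordinator to start the next one; when all engines go idle, the wave is marked complete and the counter is bumped. A condensed sketch of that state machine; `WaveTracker` is illustrative, not the PR's exact code:

    class WaveTracker:
        def __init__(self):
            self.current_wave = 0
            self.engines_running = False
            self.notifications: list[tuple[str, int]] = []

        def on_request(self, request_wave: int) -> None:
            if request_wave > self.current_wave:
                self.current_wave = request_wave
            elif request_wave < self.current_wave and not self.engines_running:
                # Stale wave: ask the coordinator to start a new one.
                self.notifications.append(("start_wave", self.current_wave))

        def on_all_engines_idle(self) -> None:
            # Mirrors the wave_complete message sent to client index -1.
            self.notifications.append(("wave_complete", self.current_wave))
            self.current_wave += 1
            self.engines_running = False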