[Core] Simplify API server handshake (#39364)

Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
Nick Hill
2026-04-09 03:56:15 -07:00
committed by GitHub
parent d87fb264df
commit c8d98f81f6
5 changed files with 53 additions and 67 deletions

View File

@@ -280,36 +280,23 @@ def run_multi_api_server(args: argparse.Namespace):
vllm_config, executor_class, log_stats, addresses, num_api_servers
) as (local_engine_manager, coordinator, addresses, tensor_queue):
# Construct common args for the APIServerProcessManager up-front.
api_server_manager_kwargs = dict(
stats_update_address = None
if coordinator:
stats_update_address = coordinator.get_stats_publish_address()
# Start API servers.
api_server_manager = APIServerProcessManager(
listen_address=listen_address,
sock=sock,
args=args,
num_servers=num_api_servers,
input_addresses=addresses.inputs,
output_addresses=addresses.outputs,
stats_update_address=coordinator.get_stats_publish_address()
if coordinator
else None,
stats_update_address=stats_update_address,
tensor_queue=tensor_queue,
)
# For dp ranks > 0 in external/hybrid DP LB modes, we must delay the
# start of the API servers until the local engine is started
# (after the launcher context manager exits),
# since we get the front-end stats update address from the coordinator
# via the handshake with the local engine.
if dp_rank == 0 or not parallel_config.local_engines_only:
# Start API servers using the manager.
api_server_manager = APIServerProcessManager(**api_server_manager_kwargs)
# Start API servers now if they weren't already started.
if api_server_manager is None:
api_server_manager_kwargs["stats_update_address"] = (
addresses.frontend_stats_publish_address
)
api_server_manager = APIServerProcessManager(**api_server_manager_kwargs)
# Wait for API servers
# Wait for API servers.
try:
wait_for_completion_or_failure(
api_server_manager=api_server_manager,

View File

@@ -4,6 +4,7 @@
import enum
import time
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any, Literal
import msgspec
@@ -63,17 +64,16 @@ class FinishReason(enum.IntEnum):
return FINISH_REASON_STRINGS[self.value]
class EngineCoreReadyResponse(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True, # type: ignore[call-arg]
):
"""Sent from EngineCore to the frontend during the ready handshake.
@dataclass
class EngineCoreReadyResponse:
"""Sent from EngineCore to each frontend at the end of engine startup.
Contains post-initialization config that may differ from the original
values (e.g. max_model_len after KV cache auto-fitting).
"""
num_gpu_blocks: int
dp_stats_address: str | None
max_model_len: int | None = None

View File

@@ -936,14 +936,20 @@ class EngineCoreProc(EngineCore):
vllm_config.parallel_config,
)
if client_handshake_address is None:
# We only need to handshake with one party.
with handshake as addresses:
yield addresses
else:
# We need to handshake with rank 0 front-end and our colocated frontend.
assert local_client
local_handshake = self._perform_handshake(
input_ctx, client_handshake_address, identity, True, False, vllm_config
)
with handshake as addresses, local_handshake as client_addresses:
# 1. Obtain DP Coordinator zmq address and DP process group address
# (addresses).
# 2. Add front-end input/output addresses from colocated front-end
# (client_addresses).
addresses.inputs = client_addresses.inputs
addresses.outputs = client_addresses.outputs
yield addresses
@@ -977,20 +983,12 @@ class EngineCoreProc(EngineCore):
yield addresses
# Send ready message.
num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
# We pass back the coordinator stats update address here for the
# external LB case for our colocated front-end to use (coordinator
# only runs with rank 0).
dp_stats_address = self.frontend_stats_publish_address
# Include config hash for DP configuration validation
ready_msg = {
"status": "READY",
"local": local_client,
"headless": headless,
"num_gpu_blocks": num_gpu_blocks,
"dp_stats_address": dp_stats_address,
}
# Include config hash for DP configuration validation
if vllm_config.parallel_config.data_parallel_size > 1:
ready_msg["parallel_config_hash"] = (
vllm_config.parallel_config.compute_hash()
@@ -1388,6 +1386,8 @@ class EngineCoreProc(EngineCore):
poller = zmq.Poller()
ready_response = EngineCoreReadyResponse(
max_model_len=self.vllm_config.model_config.max_model_len,
num_gpu_blocks=self.vllm_config.cache_config.num_gpu_blocks or 0,
dp_stats_address=self.frontend_stats_publish_address,
)
ready_payload = msgspec.msgpack.encode(ready_response)
for input_socket in input_sockets:

View File

@@ -457,22 +457,6 @@ class ElasticScalingCache:
pending_notifications: dict[EEPNotificationType, set[int]]
# Module-level decoder so the msgspec type info is built once, not per call.
_ready_response_decoder = msgspec.msgpack.Decoder(EngineCoreReadyResponse)


def _apply_ready_response(payload: bytes, vllm_config: VllmConfig) -> None:
    """Decode an EngineCoreReadyResponse and sync any post-initialization
    config changes (e.g. auto-fitted max_model_len) back to the frontend.

    No-op when the payload is empty.
    """
    if not payload:
        # Nothing to apply; leave the frontend config untouched.
        return
    response = _ready_response_decoder.decode(payload)
    if response.max_model_len is not None:
        # Only ever shrink the frontend's max_model_len to match the
        # engine's post-initialization value; never grow it.
        vllm_config.model_config.max_model_len = min(
            vllm_config.model_config.max_model_len,
            response.max_model_len,
        )
class MPClient(EngineCoreClient):
"""
MPClient: base client for multi-proc EngineCore.
@@ -608,7 +592,7 @@ class MPClient(EngineCoreClient):
)
identity, payload = sync_input_socket.recv_multipart()
identities.remove(identity)
_apply_ready_response(payload, vllm_config)
self._apply_ready_response(payload)
self.core_engine: EngineIdentity = self.core_engines[0]
self.utility_results: dict[int, AnyFuture] = {}
@@ -680,6 +664,34 @@ class MPClient(EngineCoreClient):
target=monitor_engine_cores, daemon=True, name="MPClientEngineMonitor"
).start()
def _apply_ready_response(self, payload: bytes) -> None:
    """Decode an EngineCoreReadyResponse and sync any post-initialization
    config changes (e.g. auto-fitted max_model_len) back to the frontend.

    No-op when the payload is empty.
    """
    if not payload:
        return
    vllm_config = self.vllm_config
    response = msgspec.msgpack.decode(payload, type=EngineCoreReadyResponse)
    if response.max_model_len is not None:
        # Only ever shrink the frontend's max_model_len to match the
        # engine's post-initialization value; never grow it.
        vllm_config.model_config.max_model_len = min(
            vllm_config.model_config.max_model_len,
            response.max_model_len,
        )
    # Set up KV cache config with initialization state from the engine
    # core process. Sum values from all engines in the DP case.
    num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks or 0
    num_gpu_blocks += response.num_gpu_blocks
    vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
    # In external DP LB mode, the coordinator address that the front-end
    # procs connect to is obtained by each engine via its initial
    # handshake with the rank 0 front-end.
    if response.dp_stats_address is not None:
        if self.stats_update_address is None:
            self.stats_update_address = response.dp_stats_address
        else:
            # Every engine must report the same coordinator address.
            assert response.dp_stats_address == self.stats_update_address
def _process_utility_output(
output: UtilityOutput, utility_results: dict[int, AnyFuture]
@@ -1602,7 +1614,7 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
)
identity, payload = sync_input_socket.recv_multipart()
new_engine_identities.discard(identity)
_apply_ready_response(payload, self.vllm_config)
self._apply_ready_response(payload)
# NOTE(yongji): Before we schedule any requests on the new workers,
# we should wait for them to switch to the new setup.

View File

@@ -1212,19 +1212,6 @@ def wait_for_engine_startup(
start_pending[0 if local else 1] += 1
engine.state = CoreEngineState.CONNECTED
elif status == "READY" and engine.state == CoreEngineState.CONNECTED:
# Setup KV cache config with initialization state from
# engine core process. Sum values from all engines in DP case.
num_gpu_blocks = cache_config.num_gpu_blocks or 0
num_gpu_blocks += msg["num_gpu_blocks"]
cache_config.num_gpu_blocks = num_gpu_blocks
# In external DP LB mode, the coordinator address that the
# front-end procs connect to is obtained from rank 0 via
# one of the engine handshakes, and passed to the local
# front-end process in the response from the other.
if addresses.frontend_stats_publish_address is None:
addresses.frontend_stats_publish_address = msg.get("dp_stats_address")
# Validate config hash consistency across DP workers for MoE models.
if coordinated_dp:
worker_config_hash = msg.get("parallel_config_hash")