[Core] Simplify API server handshake (#39364)

Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
Nick Hill
2026-04-09 03:56:15 -07:00
committed by GitHub
parent d87fb264df
commit c8d98f81f6
5 changed files with 53 additions and 67 deletions

View File

@@ -280,36 +280,23 @@ def run_multi_api_server(args: argparse.Namespace):
vllm_config, executor_class, log_stats, addresses, num_api_servers vllm_config, executor_class, log_stats, addresses, num_api_servers
) as (local_engine_manager, coordinator, addresses, tensor_queue): ) as (local_engine_manager, coordinator, addresses, tensor_queue):
# Construct common args for the APIServerProcessManager up-front. # Construct common args for the APIServerProcessManager up-front.
api_server_manager_kwargs = dict( stats_update_address = None
if coordinator:
stats_update_address = coordinator.get_stats_publish_address()
# Start API servers.
api_server_manager = APIServerProcessManager(
listen_address=listen_address, listen_address=listen_address,
sock=sock, sock=sock,
args=args, args=args,
num_servers=num_api_servers, num_servers=num_api_servers,
input_addresses=addresses.inputs, input_addresses=addresses.inputs,
output_addresses=addresses.outputs, output_addresses=addresses.outputs,
stats_update_address=coordinator.get_stats_publish_address() stats_update_address=stats_update_address,
if coordinator
else None,
tensor_queue=tensor_queue, tensor_queue=tensor_queue,
) )
# For dp ranks > 0 in external/hybrid DP LB modes, we must delay the # Wait for API servers.
# start of the API servers until the local engine is started
# (after the launcher context manager exits),
# since we get the front-end stats update address from the coordinator
# via the handshake with the local engine.
if dp_rank == 0 or not parallel_config.local_engines_only:
# Start API servers using the manager.
api_server_manager = APIServerProcessManager(**api_server_manager_kwargs)
# Start API servers now if they weren't already started.
if api_server_manager is None:
api_server_manager_kwargs["stats_update_address"] = (
addresses.frontend_stats_publish_address
)
api_server_manager = APIServerProcessManager(**api_server_manager_kwargs)
# Wait for API servers
try: try:
wait_for_completion_or_failure( wait_for_completion_or_failure(
api_server_manager=api_server_manager, api_server_manager=api_server_manager,

View File

@@ -4,6 +4,7 @@
import enum import enum
import time import time
from collections.abc import Mapping from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any, Literal from typing import Any, Literal
import msgspec import msgspec
@@ -63,17 +64,16 @@ class FinishReason(enum.IntEnum):
return FINISH_REASON_STRINGS[self.value] return FINISH_REASON_STRINGS[self.value]
class EngineCoreReadyResponse( @dataclass
msgspec.Struct, class EngineCoreReadyResponse:
array_like=True, # type: ignore[call-arg] """Sent from EngineCore to each frontend at the end of engine startup.
omit_defaults=True, # type: ignore[call-arg]
):
"""Sent from EngineCore to the frontend during the ready handshake.
Contains post-initialization config that may differ from the original Contains post-initialization config that may differ from the original
values (e.g. max_model_len after KV cache auto-fitting). values (e.g. max_model_len after KV cache auto-fitting).
""" """
num_gpu_blocks: int
dp_stats_address: str | None
max_model_len: int | None = None max_model_len: int | None = None

View File

@@ -936,14 +936,20 @@ class EngineCoreProc(EngineCore):
vllm_config.parallel_config, vllm_config.parallel_config,
) )
if client_handshake_address is None: if client_handshake_address is None:
# We only need to handshake with one party.
with handshake as addresses: with handshake as addresses:
yield addresses yield addresses
else: else:
# We need to handshake with rank 0 front-end and our colocated frontend.
assert local_client assert local_client
local_handshake = self._perform_handshake( local_handshake = self._perform_handshake(
input_ctx, client_handshake_address, identity, True, False, vllm_config input_ctx, client_handshake_address, identity, True, False, vllm_config
) )
with handshake as addresses, local_handshake as client_addresses: with handshake as addresses, local_handshake as client_addresses:
# 1. Obtain DP Coordinator zmq address and DP process group address
# (addresses).
# 2. Add front-end input/output addresses from colocated front-end
# (client_addresses).
addresses.inputs = client_addresses.inputs addresses.inputs = client_addresses.inputs
addresses.outputs = client_addresses.outputs addresses.outputs = client_addresses.outputs
yield addresses yield addresses
@@ -977,20 +983,12 @@ class EngineCoreProc(EngineCore):
yield addresses yield addresses
# Send ready message. # Send ready message.
num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
# We pass back the coordinator stats update address here for the
# external LB case for our colocated front-end to use (coordinator
# only runs with rank 0).
dp_stats_address = self.frontend_stats_publish_address
# Include config hash for DP configuration validation
ready_msg = { ready_msg = {
"status": "READY", "status": "READY",
"local": local_client, "local": local_client,
"headless": headless, "headless": headless,
"num_gpu_blocks": num_gpu_blocks,
"dp_stats_address": dp_stats_address,
} }
# Include config hash for DP configuration validation
if vllm_config.parallel_config.data_parallel_size > 1: if vllm_config.parallel_config.data_parallel_size > 1:
ready_msg["parallel_config_hash"] = ( ready_msg["parallel_config_hash"] = (
vllm_config.parallel_config.compute_hash() vllm_config.parallel_config.compute_hash()
@@ -1388,6 +1386,8 @@ class EngineCoreProc(EngineCore):
poller = zmq.Poller() poller = zmq.Poller()
ready_response = EngineCoreReadyResponse( ready_response = EngineCoreReadyResponse(
max_model_len=self.vllm_config.model_config.max_model_len, max_model_len=self.vllm_config.model_config.max_model_len,
num_gpu_blocks=self.vllm_config.cache_config.num_gpu_blocks or 0,
dp_stats_address=self.frontend_stats_publish_address,
) )
ready_payload = msgspec.msgpack.encode(ready_response) ready_payload = msgspec.msgpack.encode(ready_response)
for input_socket in input_sockets: for input_socket in input_sockets:

View File

@@ -457,22 +457,6 @@ class ElasticScalingCache:
pending_notifications: dict[EEPNotificationType, set[int]] pending_notifications: dict[EEPNotificationType, set[int]]
_ready_response_decoder = msgspec.msgpack.Decoder(EngineCoreReadyResponse)
def _apply_ready_response(payload: bytes, vllm_config: VllmConfig) -> None:
"""Decode an EngineCoreReadyResponse and sync any post-initialization
config changes (e.g. auto-fitted max_model_len) back to the frontend."""
if not payload:
return
response = _ready_response_decoder.decode(payload)
if response.max_model_len is not None:
vllm_config.model_config.max_model_len = min(
vllm_config.model_config.max_model_len,
response.max_model_len,
)
class MPClient(EngineCoreClient): class MPClient(EngineCoreClient):
""" """
MPClient: base client for multi-proc EngineCore. MPClient: base client for multi-proc EngineCore.
@@ -608,7 +592,7 @@ class MPClient(EngineCoreClient):
) )
identity, payload = sync_input_socket.recv_multipart() identity, payload = sync_input_socket.recv_multipart()
identities.remove(identity) identities.remove(identity)
_apply_ready_response(payload, vllm_config) self._apply_ready_response(payload)
self.core_engine: EngineIdentity = self.core_engines[0] self.core_engine: EngineIdentity = self.core_engines[0]
self.utility_results: dict[int, AnyFuture] = {} self.utility_results: dict[int, AnyFuture] = {}
@@ -680,6 +664,34 @@ class MPClient(EngineCoreClient):
target=monitor_engine_cores, daemon=True, name="MPClientEngineMonitor" target=monitor_engine_cores, daemon=True, name="MPClientEngineMonitor"
).start() ).start()
def _apply_ready_response(self, payload: bytes) -> None:
"""Decode an EngineCoreReadyResponse and sync any post-initialization
config changes (e.g. auto-fitted max_model_len) back to the frontend."""
if not payload:
return
vllm_config = self.vllm_config
response = msgspec.msgpack.decode(payload, type=EngineCoreReadyResponse)
if response.max_model_len is not None:
vllm_config.model_config.max_model_len = min(
vllm_config.model_config.max_model_len,
response.max_model_len,
)
# Setup KV cache config with initialization state from
# engine core process. Sum values from all engines in DP case.
num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks or 0
num_gpu_blocks += response.num_gpu_blocks
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
# In external DP LB mode, the coordinator address that the
# front-end procs connect to is obtained by each engine via it's
# initial handshake with the rank 0 front-end.
if response.dp_stats_address is not None:
if self.stats_update_address is None:
self.stats_update_address = response.dp_stats_address
else:
assert response.dp_stats_address == self.stats_update_address
def _process_utility_output( def _process_utility_output(
output: UtilityOutput, utility_results: dict[int, AnyFuture] output: UtilityOutput, utility_results: dict[int, AnyFuture]
@@ -1602,7 +1614,7 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
) )
identity, payload = sync_input_socket.recv_multipart() identity, payload = sync_input_socket.recv_multipart()
new_engine_identities.discard(identity) new_engine_identities.discard(identity)
_apply_ready_response(payload, self.vllm_config) self._apply_ready_response(payload)
# NOTE(yongji): Before we schedule any requests on the new workers, # NOTE(yongji): Before we schedule any requests on the new workers,
# we should wait for them to switch to the new setup. # we should wait for them to switch to the new setup.

View File

@@ -1212,19 +1212,6 @@ def wait_for_engine_startup(
start_pending[0 if local else 1] += 1 start_pending[0 if local else 1] += 1
engine.state = CoreEngineState.CONNECTED engine.state = CoreEngineState.CONNECTED
elif status == "READY" and engine.state == CoreEngineState.CONNECTED: elif status == "READY" and engine.state == CoreEngineState.CONNECTED:
# Setup KV cache config with initialization state from
# engine core process. Sum values from all engines in DP case.
num_gpu_blocks = cache_config.num_gpu_blocks or 0
num_gpu_blocks += msg["num_gpu_blocks"]
cache_config.num_gpu_blocks = num_gpu_blocks
# In external DP LB mode, the coordinator address that the
# front-end procs connect to is obtained from rank 0 via
# one of the engine handshakes, and passed to the local
# front-end process in the response from the other.
if addresses.frontend_stats_publish_address is None:
addresses.frontend_stats_publish_address = msg.get("dp_stats_address")
# Validate config hash consistency across DP workers for MoE models. # Validate config hash consistency across DP workers for MoE models.
if coordinated_dp: if coordinated_dp:
worker_config_hash = msg.get("parallel_config_hash") worker_config_hash = msg.get("parallel_config_hash")