[Core] Simplify API server handshake (#39364)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -280,36 +280,23 @@ def run_multi_api_server(args: argparse.Namespace):
|
||||
vllm_config, executor_class, log_stats, addresses, num_api_servers
|
||||
) as (local_engine_manager, coordinator, addresses, tensor_queue):
|
||||
# Construct common args for the APIServerProcessManager up-front.
|
||||
api_server_manager_kwargs = dict(
|
||||
stats_update_address = None
|
||||
if coordinator:
|
||||
stats_update_address = coordinator.get_stats_publish_address()
|
||||
|
||||
# Start API servers.
|
||||
api_server_manager = APIServerProcessManager(
|
||||
listen_address=listen_address,
|
||||
sock=sock,
|
||||
args=args,
|
||||
num_servers=num_api_servers,
|
||||
input_addresses=addresses.inputs,
|
||||
output_addresses=addresses.outputs,
|
||||
stats_update_address=coordinator.get_stats_publish_address()
|
||||
if coordinator
|
||||
else None,
|
||||
stats_update_address=stats_update_address,
|
||||
tensor_queue=tensor_queue,
|
||||
)
|
||||
|
||||
# For dp ranks > 0 in external/hybrid DP LB modes, we must delay the
|
||||
# start of the API servers until the local engine is started
|
||||
# (after the launcher context manager exits),
|
||||
# since we get the front-end stats update address from the coordinator
|
||||
# via the handshake with the local engine.
|
||||
if dp_rank == 0 or not parallel_config.local_engines_only:
|
||||
# Start API servers using the manager.
|
||||
api_server_manager = APIServerProcessManager(**api_server_manager_kwargs)
|
||||
|
||||
# Start API servers now if they weren't already started.
|
||||
if api_server_manager is None:
|
||||
api_server_manager_kwargs["stats_update_address"] = (
|
||||
addresses.frontend_stats_publish_address
|
||||
)
|
||||
api_server_manager = APIServerProcessManager(**api_server_manager_kwargs)
|
||||
|
||||
# Wait for API servers
|
||||
# Wait for API servers.
|
||||
try:
|
||||
wait_for_completion_or_failure(
|
||||
api_server_manager=api_server_manager,
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
import enum
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal
|
||||
|
||||
import msgspec
|
||||
@@ -63,17 +64,16 @@ class FinishReason(enum.IntEnum):
|
||||
return FINISH_REASON_STRINGS[self.value]
|
||||
|
||||
|
||||
class EngineCoreReadyResponse(
|
||||
msgspec.Struct,
|
||||
array_like=True, # type: ignore[call-arg]
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
):
|
||||
"""Sent from EngineCore to the frontend during the ready handshake.
|
||||
@dataclass
|
||||
class EngineCoreReadyResponse:
|
||||
"""Sent from EngineCore to each frontend at the end of engine startup.
|
||||
|
||||
Contains post-initialization config that may differ from the original
|
||||
values (e.g. max_model_len after KV cache auto-fitting).
|
||||
"""
|
||||
|
||||
num_gpu_blocks: int
|
||||
dp_stats_address: str | None
|
||||
max_model_len: int | None = None
|
||||
|
||||
|
||||
|
||||
@@ -936,14 +936,20 @@ class EngineCoreProc(EngineCore):
|
||||
vllm_config.parallel_config,
|
||||
)
|
||||
if client_handshake_address is None:
|
||||
# We only need to handshake with one party.
|
||||
with handshake as addresses:
|
||||
yield addresses
|
||||
else:
|
||||
# We need to handshake with the rank 0 front-end and our colocated front-end.
|
||||
assert local_client
|
||||
local_handshake = self._perform_handshake(
|
||||
input_ctx, client_handshake_address, identity, True, False, vllm_config
|
||||
)
|
||||
with handshake as addresses, local_handshake as client_addresses:
|
||||
# 1. Obtain DP Coordinator zmq address and DP process group address
|
||||
# (addresses).
|
||||
# 2. Add front-end input/output addresses from colocated front-end
|
||||
# (client_addresses).
|
||||
addresses.inputs = client_addresses.inputs
|
||||
addresses.outputs = client_addresses.outputs
|
||||
yield addresses
|
||||
@@ -977,20 +983,12 @@ class EngineCoreProc(EngineCore):
|
||||
yield addresses
|
||||
|
||||
# Send ready message.
|
||||
num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
|
||||
# We pass back the coordinator stats update address here for the
|
||||
# external LB case for our colocated front-end to use (coordinator
|
||||
# only runs with rank 0).
|
||||
dp_stats_address = self.frontend_stats_publish_address
|
||||
|
||||
# Include config hash for DP configuration validation
|
||||
ready_msg = {
|
||||
"status": "READY",
|
||||
"local": local_client,
|
||||
"headless": headless,
|
||||
"num_gpu_blocks": num_gpu_blocks,
|
||||
"dp_stats_address": dp_stats_address,
|
||||
}
|
||||
# Include config hash for DP configuration validation
|
||||
if vllm_config.parallel_config.data_parallel_size > 1:
|
||||
ready_msg["parallel_config_hash"] = (
|
||||
vllm_config.parallel_config.compute_hash()
|
||||
@@ -1388,6 +1386,8 @@ class EngineCoreProc(EngineCore):
|
||||
poller = zmq.Poller()
|
||||
ready_response = EngineCoreReadyResponse(
|
||||
max_model_len=self.vllm_config.model_config.max_model_len,
|
||||
num_gpu_blocks=self.vllm_config.cache_config.num_gpu_blocks or 0,
|
||||
dp_stats_address=self.frontend_stats_publish_address,
|
||||
)
|
||||
ready_payload = msgspec.msgpack.encode(ready_response)
|
||||
for input_socket in input_sockets:
|
||||
|
||||
@@ -457,22 +457,6 @@ class ElasticScalingCache:
|
||||
pending_notifications: dict[EEPNotificationType, set[int]]
|
||||
|
||||
|
||||
_ready_response_decoder = msgspec.msgpack.Decoder(EngineCoreReadyResponse)
|
||||
|
||||
|
||||
def _apply_ready_response(payload: bytes, vllm_config: VllmConfig) -> None:
|
||||
"""Decode an EngineCoreReadyResponse and sync any post-initialization
|
||||
config changes (e.g. auto-fitted max_model_len) back to the frontend."""
|
||||
if not payload:
|
||||
return
|
||||
response = _ready_response_decoder.decode(payload)
|
||||
if response.max_model_len is not None:
|
||||
vllm_config.model_config.max_model_len = min(
|
||||
vllm_config.model_config.max_model_len,
|
||||
response.max_model_len,
|
||||
)
|
||||
|
||||
|
||||
class MPClient(EngineCoreClient):
|
||||
"""
|
||||
MPClient: base client for multi-proc EngineCore.
|
||||
@@ -608,7 +592,7 @@ class MPClient(EngineCoreClient):
|
||||
)
|
||||
identity, payload = sync_input_socket.recv_multipart()
|
||||
identities.remove(identity)
|
||||
_apply_ready_response(payload, vllm_config)
|
||||
self._apply_ready_response(payload)
|
||||
|
||||
self.core_engine: EngineIdentity = self.core_engines[0]
|
||||
self.utility_results: dict[int, AnyFuture] = {}
|
||||
@@ -680,6 +664,34 @@ class MPClient(EngineCoreClient):
|
||||
target=monitor_engine_cores, daemon=True, name="MPClientEngineMonitor"
|
||||
).start()
|
||||
|
||||
def _apply_ready_response(self, payload: bytes) -> None:
|
||||
"""Decode an EngineCoreReadyResponse and sync any post-initialization
|
||||
config changes (e.g. auto-fitted max_model_len) back to the frontend."""
|
||||
if not payload:
|
||||
return
|
||||
vllm_config = self.vllm_config
|
||||
response = msgspec.msgpack.decode(payload, type=EngineCoreReadyResponse)
|
||||
if response.max_model_len is not None:
|
||||
vllm_config.model_config.max_model_len = min(
|
||||
vllm_config.model_config.max_model_len,
|
||||
response.max_model_len,
|
||||
)
|
||||
|
||||
# Setup KV cache config with initialization state from
|
||||
# engine core process. Sum values from all engines in DP case.
|
||||
num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks or 0
|
||||
num_gpu_blocks += response.num_gpu_blocks
|
||||
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
|
||||
# In external DP LB mode, the coordinator address that the
|
||||
# front-end procs connect to is obtained by each engine via its
|
||||
# initial handshake with the rank 0 front-end.
|
||||
if response.dp_stats_address is not None:
|
||||
if self.stats_update_address is None:
|
||||
self.stats_update_address = response.dp_stats_address
|
||||
else:
|
||||
assert response.dp_stats_address == self.stats_update_address
|
||||
|
||||
|
||||
def _process_utility_output(
|
||||
output: UtilityOutput, utility_results: dict[int, AnyFuture]
|
||||
@@ -1602,7 +1614,7 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
|
||||
)
|
||||
identity, payload = sync_input_socket.recv_multipart()
|
||||
new_engine_identities.discard(identity)
|
||||
_apply_ready_response(payload, self.vllm_config)
|
||||
self._apply_ready_response(payload)
|
||||
|
||||
# NOTE(yongji): Before we schedule any requests on the new workers,
|
||||
# we should wait for them to switch to the new setup.
|
||||
|
||||
@@ -1212,19 +1212,6 @@ def wait_for_engine_startup(
|
||||
start_pending[0 if local else 1] += 1
|
||||
engine.state = CoreEngineState.CONNECTED
|
||||
elif status == "READY" and engine.state == CoreEngineState.CONNECTED:
|
||||
# Setup KV cache config with initialization state from
|
||||
# engine core process. Sum values from all engines in DP case.
|
||||
num_gpu_blocks = cache_config.num_gpu_blocks or 0
|
||||
num_gpu_blocks += msg["num_gpu_blocks"]
|
||||
cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
|
||||
# In external DP LB mode, the coordinator address that the
|
||||
# front-end procs connect to is obtained from rank 0 via
|
||||
# one of the engine handshakes, and passed to the local
|
||||
# front-end process in the response from the other.
|
||||
if addresses.frontend_stats_publish_address is None:
|
||||
addresses.frontend_stats_publish_address = msg.get("dp_stats_address")
|
||||
|
||||
# Validate config hash consistency across DP workers for MoE models.
|
||||
if coordinated_dp:
|
||||
worker_config_hash = msg.get("parallel_config_hash")
|
||||
|
||||
Reference in New Issue
Block a user