[Core] Simplify API server handshake (#39364)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -280,36 +280,23 @@ def run_multi_api_server(args: argparse.Namespace):
|
|||||||
vllm_config, executor_class, log_stats, addresses, num_api_servers
|
vllm_config, executor_class, log_stats, addresses, num_api_servers
|
||||||
) as (local_engine_manager, coordinator, addresses, tensor_queue):
|
) as (local_engine_manager, coordinator, addresses, tensor_queue):
|
||||||
# Construct common args for the APIServerProcessManager up-front.
|
# Construct common args for the APIServerProcessManager up-front.
|
||||||
api_server_manager_kwargs = dict(
|
stats_update_address = None
|
||||||
|
if coordinator:
|
||||||
|
stats_update_address = coordinator.get_stats_publish_address()
|
||||||
|
|
||||||
|
# Start API servers.
|
||||||
|
api_server_manager = APIServerProcessManager(
|
||||||
listen_address=listen_address,
|
listen_address=listen_address,
|
||||||
sock=sock,
|
sock=sock,
|
||||||
args=args,
|
args=args,
|
||||||
num_servers=num_api_servers,
|
num_servers=num_api_servers,
|
||||||
input_addresses=addresses.inputs,
|
input_addresses=addresses.inputs,
|
||||||
output_addresses=addresses.outputs,
|
output_addresses=addresses.outputs,
|
||||||
stats_update_address=coordinator.get_stats_publish_address()
|
stats_update_address=stats_update_address,
|
||||||
if coordinator
|
|
||||||
else None,
|
|
||||||
tensor_queue=tensor_queue,
|
tensor_queue=tensor_queue,
|
||||||
)
|
)
|
||||||
|
|
||||||
# For dp ranks > 0 in external/hybrid DP LB modes, we must delay the
|
# Wait for API servers.
|
||||||
# start of the API servers until the local engine is started
|
|
||||||
# (after the launcher context manager exits),
|
|
||||||
# since we get the front-end stats update address from the coordinator
|
|
||||||
# via the handshake with the local engine.
|
|
||||||
if dp_rank == 0 or not parallel_config.local_engines_only:
|
|
||||||
# Start API servers using the manager.
|
|
||||||
api_server_manager = APIServerProcessManager(**api_server_manager_kwargs)
|
|
||||||
|
|
||||||
# Start API servers now if they weren't already started.
|
|
||||||
if api_server_manager is None:
|
|
||||||
api_server_manager_kwargs["stats_update_address"] = (
|
|
||||||
addresses.frontend_stats_publish_address
|
|
||||||
)
|
|
||||||
api_server_manager = APIServerProcessManager(**api_server_manager_kwargs)
|
|
||||||
|
|
||||||
# Wait for API servers
|
|
||||||
try:
|
try:
|
||||||
wait_for_completion_or_failure(
|
wait_for_completion_or_failure(
|
||||||
api_server_manager=api_server_manager,
|
api_server_manager=api_server_manager,
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
import enum
|
import enum
|
||||||
import time
|
import time
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
import msgspec
|
import msgspec
|
||||||
@@ -63,17 +64,16 @@ class FinishReason(enum.IntEnum):
|
|||||||
return FINISH_REASON_STRINGS[self.value]
|
return FINISH_REASON_STRINGS[self.value]
|
||||||
|
|
||||||
|
|
||||||
class EngineCoreReadyResponse(
|
@dataclass
|
||||||
msgspec.Struct,
|
class EngineCoreReadyResponse:
|
||||||
array_like=True, # type: ignore[call-arg]
|
"""Sent from EngineCore to each frontend at the end of engine startup.
|
||||||
omit_defaults=True, # type: ignore[call-arg]
|
|
||||||
):
|
|
||||||
"""Sent from EngineCore to the frontend during the ready handshake.
|
|
||||||
|
|
||||||
Contains post-initialization config that may differ from the original
|
Contains post-initialization config that may differ from the original
|
||||||
values (e.g. max_model_len after KV cache auto-fitting).
|
values (e.g. max_model_len after KV cache auto-fitting).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
num_gpu_blocks: int
|
||||||
|
dp_stats_address: str | None
|
||||||
max_model_len: int | None = None
|
max_model_len: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -936,14 +936,20 @@ class EngineCoreProc(EngineCore):
|
|||||||
vllm_config.parallel_config,
|
vllm_config.parallel_config,
|
||||||
)
|
)
|
||||||
if client_handshake_address is None:
|
if client_handshake_address is None:
|
||||||
|
# We only need to handshake with one party.
|
||||||
with handshake as addresses:
|
with handshake as addresses:
|
||||||
yield addresses
|
yield addresses
|
||||||
else:
|
else:
|
||||||
|
# We need to handshake with rank 0 front-end and our colocated frontend.
|
||||||
assert local_client
|
assert local_client
|
||||||
local_handshake = self._perform_handshake(
|
local_handshake = self._perform_handshake(
|
||||||
input_ctx, client_handshake_address, identity, True, False, vllm_config
|
input_ctx, client_handshake_address, identity, True, False, vllm_config
|
||||||
)
|
)
|
||||||
with handshake as addresses, local_handshake as client_addresses:
|
with handshake as addresses, local_handshake as client_addresses:
|
||||||
|
# 1. Obtain DP Coordinator zmq address and DP process group address
|
||||||
|
# (addresses).
|
||||||
|
# 2. Add front-end input/output addresses from colocated front-end
|
||||||
|
# (client_addresses).
|
||||||
addresses.inputs = client_addresses.inputs
|
addresses.inputs = client_addresses.inputs
|
||||||
addresses.outputs = client_addresses.outputs
|
addresses.outputs = client_addresses.outputs
|
||||||
yield addresses
|
yield addresses
|
||||||
@@ -977,20 +983,12 @@ class EngineCoreProc(EngineCore):
|
|||||||
yield addresses
|
yield addresses
|
||||||
|
|
||||||
# Send ready message.
|
# Send ready message.
|
||||||
num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
|
|
||||||
# We pass back the coordinator stats update address here for the
|
|
||||||
# external LB case for our colocated front-end to use (coordinator
|
|
||||||
# only runs with rank 0).
|
|
||||||
dp_stats_address = self.frontend_stats_publish_address
|
|
||||||
|
|
||||||
# Include config hash for DP configuration validation
|
|
||||||
ready_msg = {
|
ready_msg = {
|
||||||
"status": "READY",
|
"status": "READY",
|
||||||
"local": local_client,
|
"local": local_client,
|
||||||
"headless": headless,
|
"headless": headless,
|
||||||
"num_gpu_blocks": num_gpu_blocks,
|
|
||||||
"dp_stats_address": dp_stats_address,
|
|
||||||
}
|
}
|
||||||
|
# Include config hash for DP configuration validation
|
||||||
if vllm_config.parallel_config.data_parallel_size > 1:
|
if vllm_config.parallel_config.data_parallel_size > 1:
|
||||||
ready_msg["parallel_config_hash"] = (
|
ready_msg["parallel_config_hash"] = (
|
||||||
vllm_config.parallel_config.compute_hash()
|
vllm_config.parallel_config.compute_hash()
|
||||||
@@ -1388,6 +1386,8 @@ class EngineCoreProc(EngineCore):
|
|||||||
poller = zmq.Poller()
|
poller = zmq.Poller()
|
||||||
ready_response = EngineCoreReadyResponse(
|
ready_response = EngineCoreReadyResponse(
|
||||||
max_model_len=self.vllm_config.model_config.max_model_len,
|
max_model_len=self.vllm_config.model_config.max_model_len,
|
||||||
|
num_gpu_blocks=self.vllm_config.cache_config.num_gpu_blocks or 0,
|
||||||
|
dp_stats_address=self.frontend_stats_publish_address,
|
||||||
)
|
)
|
||||||
ready_payload = msgspec.msgpack.encode(ready_response)
|
ready_payload = msgspec.msgpack.encode(ready_response)
|
||||||
for input_socket in input_sockets:
|
for input_socket in input_sockets:
|
||||||
|
|||||||
@@ -457,22 +457,6 @@ class ElasticScalingCache:
|
|||||||
pending_notifications: dict[EEPNotificationType, set[int]]
|
pending_notifications: dict[EEPNotificationType, set[int]]
|
||||||
|
|
||||||
|
|
||||||
_ready_response_decoder = msgspec.msgpack.Decoder(EngineCoreReadyResponse)
|
|
||||||
|
|
||||||
|
|
||||||
def _apply_ready_response(payload: bytes, vllm_config: VllmConfig) -> None:
|
|
||||||
"""Decode an EngineCoreReadyResponse and sync any post-initialization
|
|
||||||
config changes (e.g. auto-fitted max_model_len) back to the frontend."""
|
|
||||||
if not payload:
|
|
||||||
return
|
|
||||||
response = _ready_response_decoder.decode(payload)
|
|
||||||
if response.max_model_len is not None:
|
|
||||||
vllm_config.model_config.max_model_len = min(
|
|
||||||
vllm_config.model_config.max_model_len,
|
|
||||||
response.max_model_len,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MPClient(EngineCoreClient):
|
class MPClient(EngineCoreClient):
|
||||||
"""
|
"""
|
||||||
MPClient: base client for multi-proc EngineCore.
|
MPClient: base client for multi-proc EngineCore.
|
||||||
@@ -608,7 +592,7 @@ class MPClient(EngineCoreClient):
|
|||||||
)
|
)
|
||||||
identity, payload = sync_input_socket.recv_multipart()
|
identity, payload = sync_input_socket.recv_multipart()
|
||||||
identities.remove(identity)
|
identities.remove(identity)
|
||||||
_apply_ready_response(payload, vllm_config)
|
self._apply_ready_response(payload)
|
||||||
|
|
||||||
self.core_engine: EngineIdentity = self.core_engines[0]
|
self.core_engine: EngineIdentity = self.core_engines[0]
|
||||||
self.utility_results: dict[int, AnyFuture] = {}
|
self.utility_results: dict[int, AnyFuture] = {}
|
||||||
@@ -680,6 +664,34 @@ class MPClient(EngineCoreClient):
|
|||||||
target=monitor_engine_cores, daemon=True, name="MPClientEngineMonitor"
|
target=monitor_engine_cores, daemon=True, name="MPClientEngineMonitor"
|
||||||
).start()
|
).start()
|
||||||
|
|
||||||
|
def _apply_ready_response(self, payload: bytes) -> None:
|
||||||
|
"""Decode an EngineCoreReadyResponse and sync any post-initialization
|
||||||
|
config changes (e.g. auto-fitted max_model_len) back to the frontend."""
|
||||||
|
if not payload:
|
||||||
|
return
|
||||||
|
vllm_config = self.vllm_config
|
||||||
|
response = msgspec.msgpack.decode(payload, type=EngineCoreReadyResponse)
|
||||||
|
if response.max_model_len is not None:
|
||||||
|
vllm_config.model_config.max_model_len = min(
|
||||||
|
vllm_config.model_config.max_model_len,
|
||||||
|
response.max_model_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Setup KV cache config with initialization state from
|
||||||
|
# engine core process. Sum values from all engines in DP case.
|
||||||
|
num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks or 0
|
||||||
|
num_gpu_blocks += response.num_gpu_blocks
|
||||||
|
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||||
|
|
||||||
|
# In external DP LB mode, the coordinator address that the
|
||||||
|
# front-end procs connect to is obtained by each engine via its
|
||||||
|
# initial handshake with the rank 0 front-end.
|
||||||
|
if response.dp_stats_address is not None:
|
||||||
|
if self.stats_update_address is None:
|
||||||
|
self.stats_update_address = response.dp_stats_address
|
||||||
|
else:
|
||||||
|
assert response.dp_stats_address == self.stats_update_address
|
||||||
|
|
||||||
|
|
||||||
def _process_utility_output(
|
def _process_utility_output(
|
||||||
output: UtilityOutput, utility_results: dict[int, AnyFuture]
|
output: UtilityOutput, utility_results: dict[int, AnyFuture]
|
||||||
@@ -1602,7 +1614,7 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
|
|||||||
)
|
)
|
||||||
identity, payload = sync_input_socket.recv_multipart()
|
identity, payload = sync_input_socket.recv_multipart()
|
||||||
new_engine_identities.discard(identity)
|
new_engine_identities.discard(identity)
|
||||||
_apply_ready_response(payload, self.vllm_config)
|
self._apply_ready_response(payload)
|
||||||
|
|
||||||
# NOTE(yongji): Before we schedule any requests on the new workers,
|
# NOTE(yongji): Before we schedule any requests on the new workers,
|
||||||
# we should wait for them to switch to the new setup.
|
# we should wait for them to switch to the new setup.
|
||||||
|
|||||||
@@ -1212,19 +1212,6 @@ def wait_for_engine_startup(
|
|||||||
start_pending[0 if local else 1] += 1
|
start_pending[0 if local else 1] += 1
|
||||||
engine.state = CoreEngineState.CONNECTED
|
engine.state = CoreEngineState.CONNECTED
|
||||||
elif status == "READY" and engine.state == CoreEngineState.CONNECTED:
|
elif status == "READY" and engine.state == CoreEngineState.CONNECTED:
|
||||||
# Setup KV cache config with initialization state from
|
|
||||||
# engine core process. Sum values from all engines in DP case.
|
|
||||||
num_gpu_blocks = cache_config.num_gpu_blocks or 0
|
|
||||||
num_gpu_blocks += msg["num_gpu_blocks"]
|
|
||||||
cache_config.num_gpu_blocks = num_gpu_blocks
|
|
||||||
|
|
||||||
# In external DP LB mode, the coordinator address that the
|
|
||||||
# front-end procs connect to is obtained from rank 0 via
|
|
||||||
# one of the engine handshakes, and passed to the local
|
|
||||||
# front-end process in the response from the other.
|
|
||||||
if addresses.frontend_stats_publish_address is None:
|
|
||||||
addresses.frontend_stats_publish_address = msg.get("dp_stats_address")
|
|
||||||
|
|
||||||
# Validate config hash consistency across DP workers for MoE models.
|
# Validate config hash consistency across DP workers for MoE models.
|
||||||
if coordinated_dp:
|
if coordinated_dp:
|
||||||
worker_config_hash = msg.get("parallel_config_hash")
|
worker_config_hash = msg.get("parallel_config_hash")
|
||||||
|
|||||||
Reference in New Issue
Block a user