[DP] Internal Load Balancing Per Node [one-pod-per-node] (#21238)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
Robert Shaw
2025-07-23 23:57:32 -04:00
committed by GitHub
parent eec6942014
commit d5b981f8b1
12 changed files with 486 additions and 45 deletions

View File

@@ -544,7 +544,8 @@ def launch_core_engines(
local_start_index = parallel_config.data_parallel_rank_local
dp_rank = parallel_config.data_parallel_rank
host = parallel_config.data_parallel_master_ip
external_dp_lb = parallel_config.data_parallel_external_lb
local_engines_only = (parallel_config.data_parallel_hybrid_lb
or parallel_config.data_parallel_external_lb)
# In offline mode there is an LLM instance per DP rank and
# one core engine per LLM, see
@@ -553,8 +554,8 @@ def launch_core_engines(
# client_local_only = True for cases where this front-end
# sends requests only to colocated engines.
client_local_only = offline_mode or external_dp_lb or (local_engine_count
== dp_size)
client_local_only = (offline_mode or local_engines_only
or (local_engine_count == dp_size))
# Set up input and output addresses.
addresses = EngineZmqAddresses(
@@ -598,14 +599,27 @@ def launch_core_engines(
yield engine_actor_manager, coordinator, addresses
return
if offline_mode or (external_dp_lb and dp_rank > 0):
if offline_mode:
assert local_engine_count == 1
engines_to_handshake = [CoreEngine(index=dp_rank, local=True)]
else:
elif dp_rank == 0:
# Rank 0 holds Coordinator, so it handshakes with all Cores
# in both external dplb and internal dplb mode.
# Note this also covers the case where we have zero local engines
# and rank 0 is headless.
engines_to_handshake = [
CoreEngine(index=i, local=(i < local_engine_count))
for i in range(dp_size)
]
else:
# Rank > 0 handshakes with just the local cores it is managing.
assert local_engines_only, (
"Attempting to launch core_engines from dp_rank > 0, but "
"found internal DPLB, which is incompatible.")
engines_to_handshake = [
CoreEngine(index=i, local=True)
for i in range(dp_rank, dp_rank + local_engine_count)
]
# Whether the started engines will handshake only with co-located
# front-end processes. In external_dp_lb mode, ranks > 0 handshake with
@@ -616,7 +630,7 @@ def launch_core_engines(
handshake_address = get_engine_client_zmq_addr(
handshake_local_only, host, parallel_config.data_parallel_rpc_port)
if external_dp_lb and dp_rank > 0:
if local_engines_only and dp_rank > 0:
assert not handshake_local_only
local_handshake_address = get_open_zmq_ipc_path()
client_handshake_address = local_handshake_address
@@ -631,8 +645,6 @@ def launch_core_engines(
# Start local engines.
if local_engine_count:
# In server mode, start_index and local_start_index will
# both be 0.
local_engine_manager = CoreEngineProcManager(
EngineCoreProc.run_engine_core,
vllm_config=vllm_config,
@@ -678,6 +690,9 @@ def wait_for_engine_startup(
poller = zmq.Poller()
poller.register(handshake_socket, zmq.POLLIN)
remote_should_be_headless = not parallel_config.data_parallel_hybrid_lb \
and not parallel_config.data_parallel_external_lb
if proc_manager is not None:
for sentinel in proc_manager.sentinels():
poller.register(sentinel, zmq.POLLIN)
@@ -713,13 +728,24 @@ def wait_for_engine_startup(
raise RuntimeError(f"Message from engine with unexpected data "
f"parallel rank: {eng_index}")
msg = msgspec.msgpack.decode(ready_msg_bytes)
status, local = msg["status"], msg["local"]
status, local, headless = msg["status"], msg["local"], msg["headless"]
if local != engine.local:
raise RuntimeError(f"{status} message from "
f"{'local' if local else 'remote'} "
f"engine {eng_index}, expected it to be "
f"{'local' if engine.local else 'remote'}")
# Remote engines must be headless iff we aren't in hybrid dp lb mode.
if not local and headless != remote_should_be_headless:
if headless:
raise RuntimeError(f"Remote engine {eng_index} must not use "
f"--headless in external or hybrid dp lb "
f"mode")
else:
raise RuntimeError(f"Remote engine {eng_index} must use "
f"--headless unless in external or hybrid "
f"dp lb mode")
if status == "HELLO" and engine.state == CoreEngineState.NEW:
# Send init message with DP config info.