[BugFix] Improve internal DP load balancing (#21617)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill
2025-08-02 03:45:27 +01:00
committed by GitHub
parent 9f9c38c392
commit 8d524ce79f
7 changed files with 122 additions and 59 deletions

View File

@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import multiprocessing
import time
import weakref
@@ -65,18 +66,14 @@ class DPCoordinator:
# Assume coordinator is colocated with front-end procs when not in
# either external or hybrid DP LB mode.
local_only = not (external_lb or hybrid_lb)
front_publish_address = get_engine_client_zmq_addr(
local_only=not external_lb and not hybrid_lb, host=host)
local_only=local_only, host=host)
local_only_eng = dp_size == parallel_config.data_parallel_size_local
back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)
back_output_address = get_engine_client_zmq_addr(local_only_eng, host)
# When in external LB mode, load stats aren't published, only changes
# to request wave / running state, so we don't need to rate-limit the
# updates to the front-end proc(s).
min_stats_update_interval_ms = 0 if external_lb else 100
context = get_mp_context()
self.proc: multiprocessing.Process = context.Process(
target=DPCoordinatorProc.run_coordinator,
@@ -86,7 +83,6 @@ class DPCoordinator:
"front_publish_address": front_publish_address,
"back_output_address": back_output_address,
"back_publish_address": back_publish_address,
"min_stats_update_interval_ms": min_stats_update_interval_ms,
},
daemon=True)
self.proc.start()
@@ -125,10 +121,6 @@ class DPCoordinatorProc:
self.stats_update_interval_ms = min_stats_update_interval_ms
self.current_wave = 0
self.engines_running = False
self.stats_changed = False
@staticmethod
def run_coordinator(
engine_count: int,
@@ -155,6 +147,16 @@ class DPCoordinatorProc:
decoder = MsgpackDecoder(EngineCoreOutputs)
# For tracking request wave progression.
current_wave = 0
engines_running = False
# For tracking request counts for internal load-balancing.
stats_changed = False
last_stats_step = -1
last_stats_wave = -1
last_step_counts: Optional[list[list[int]]] = None
with make_zmq_socket(
path=front_publish_address, # IPC
ctx=self.ctx,
@@ -191,21 +193,33 @@ class DPCoordinatorProc:
while True:
elapsed = int(time.time() * 1000) - last_publish_time
# Send at stats_update_interval_ms interval if the stats have
# changed, or otherwise every 4 seconds.
# changed, or otherwise every 5 seconds.
wait_for = (self.stats_update_interval_ms
if self.stats_changed else 4000)
events = poller.poll(timeout=max(0, wait_for - elapsed))
if stats_changed else 5000)
# Wait at least 50ms to ensure we've received all stats for
# the current step.
min_timeout = 50 if last_step_counts is None else 0
events = poller.poll(timeout=max(min_timeout, wait_for -
elapsed))
if not events:
# Poller timeout - publish current stats to front-ends.
engine_req_counts_list = self._get_engine_counts()
to_publish = (engine_req_counts_list, self.current_wave,
self.engines_running)
if last_step_counts is not None:
engine_req_counts_list = last_step_counts
last_step_counts = None
else:
engine_req_counts_list = self._get_engine_counts()
stats_changed = False
to_publish = (engine_req_counts_list, current_wave,
engines_running)
publish_front.send(msgspec.msgpack.encode(to_publish))
last_publish_time = int(time.time() * 1000)
self.stats_changed = False
continue
events = dict(events)
wave_state_changed = False
if publish_front in events:
buffer = publish_front.recv()
@@ -232,7 +246,7 @@ class DPCoordinatorProc:
# current_wave
# we note that 0 is the wave number for the new
# engine
self.engines_running = False
engines_running = False
logger.info(
"DPCoordinator scaled up from %s to %s "
"engines", current_count, new_engine_count)
@@ -248,15 +262,15 @@ class DPCoordinatorProc:
# engines are paused, so that we can wake the other
# engines.
engine_to_exclude, wave = decoded
if not self.engines_running:
if wave < self.current_wave:
if not engines_running:
if wave < current_wave:
# If the wave number is stale, ensure the message
# is handled by all the engines.
engine_to_exclude = None
self.engines_running = True
self.stats_changed = True
self._send_start_wave(publish_back, self.current_wave,
engines_running = True
wave_state_changed = True
self._send_start_wave(publish_back, current_wave,
engine_to_exclude)
if output_back in events:
@@ -274,36 +288,56 @@ class DPCoordinatorProc:
# 1. Updated request load stats - update our local
# state with these.
stats = self.engines[eng_index].request_counts
stats_step = scheduler_stats.step_counter
stats_wave = scheduler_stats.current_wave
if (stats_wave > last_stats_wave
or stats_wave == last_stats_wave
and stats_step > last_stats_step):
if stats_changed:
last_step_counts = self._get_engine_counts(
do_copy=True)
last_stats_step = stats_step
last_stats_wave = stats_wave
elif stats_wave != last_stats_wave or (
stats_step != last_stats_step):
logger.warning(
"Received stats for out-of-order "
"step (%d, %d) from engine %d (expected "
"> (%d, %d))", stats_wave, stats_step,
eng_index, last_stats_wave, last_stats_step)
stats[0] = scheduler_stats.num_waiting_reqs
stats[1] = scheduler_stats.num_running_reqs
self.stats_changed = True
stats_changed = True
if (wave := outputs.wave_complete) is not None:
# 2. Notification from rank 0 engine that we've
# moved into the global paused state
# (engines_running==False).
if self.current_wave <= wave:
if current_wave <= wave:
new_wave = wave + 1
logger.debug("Moving DP wave from %d to %d.",
self.current_wave, new_wave)
self.current_wave = new_wave
self.engines_running = False
self.stats_changed = True
current_wave, new_wave)
current_wave = new_wave
engines_running = False
wave_state_changed = True
elif (wave := outputs.start_wave) is not None and (
wave > self.current_wave or
(wave == self.current_wave
and not self.engines_running)):
wave > current_wave or
(wave == current_wave and not engines_running)):
# 3. The engine received request for a non-current wave
# so we must ensure that other engines progress to the
# next wave (race condition handling).
logger.debug(
"Starting wave %d after notification of "
"stale wave request from engine.", wave)
self.current_wave = wave
self.engines_running = True
self.stats_changed = True
current_wave = wave
engines_running = True
wave_state_changed = True
self._send_start_wave(publish_back, wave, eng_index)
if wave_state_changed:
message = (None, current_wave, engines_running)
publish_front.send(msgspec.msgpack.encode(message))
@staticmethod
def _send_start_wave(socket: zmq.Socket, wave: int,
exclude_engine_index: Optional[int]):
@@ -316,6 +350,8 @@ class DPCoordinatorProc:
socket.send_multipart(
(EngineCoreRequestType.START_DP_WAVE.value, wave_encoded))
def _get_engine_counts(self) -> list[list[int]]:
def _get_engine_counts(self, do_copy=False) -> list[list[int]]:
"""Return list of [waiting, running] count lists for each engine."""
if do_copy:
return [copy.copy(e.request_counts) for e in self.engines]
return [e.request_counts for e in self.engines]