[BugFix] Support online dense model DP without overhead (#30739)
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: njhill <nickhill123@gmail.com>
This commit is contained in:
@@ -55,7 +55,9 @@ class DPCoordinator:
|
||||
request wave / running state changes.
|
||||
"""
|
||||
|
||||
def __init__(self, parallel_config: ParallelConfig):
|
||||
def __init__(
|
||||
self, parallel_config: ParallelConfig, enable_wave_coordination: bool = True
|
||||
):
|
||||
dp_size = parallel_config.data_parallel_size
|
||||
assert dp_size > 1, "Coordinator only used for data parallel"
|
||||
|
||||
@@ -83,6 +85,7 @@ class DPCoordinator:
|
||||
"front_publish_address": front_publish_address,
|
||||
"back_output_address": back_output_address,
|
||||
"back_publish_address": back_publish_address,
|
||||
"enable_wave_coordination": enable_wave_coordination,
|
||||
},
|
||||
daemon=True,
|
||||
)
|
||||
@@ -110,13 +113,19 @@ class EngineState:
|
||||
|
||||
|
||||
class DPCoordinatorProc:
|
||||
def __init__(self, engine_count: int, min_stats_update_interval_ms: int = 100):
|
||||
def __init__(
|
||||
self,
|
||||
engine_count: int,
|
||||
min_stats_update_interval_ms: int = 100,
|
||||
enable_wave_coordination: bool = True,
|
||||
):
|
||||
set_process_title("DPCoordinator")
|
||||
self.ctx = zmq.Context()
|
||||
|
||||
self.engines = [EngineState() for _ in range(engine_count)]
|
||||
|
||||
self.stats_update_interval_ms = min_stats_update_interval_ms
|
||||
self.enable_wave_coordination = enable_wave_coordination
|
||||
|
||||
@staticmethod
|
||||
def run_coordinator(
|
||||
@@ -125,10 +134,12 @@ class DPCoordinatorProc:
|
||||
back_output_address: str,
|
||||
back_publish_address: str,
|
||||
min_stats_update_interval_ms: int = 100,
|
||||
enable_wave_coordination: bool = True,
|
||||
):
|
||||
coordinator = DPCoordinatorProc(
|
||||
engine_count=engine_count,
|
||||
min_stats_update_interval_ms=min_stats_update_interval_ms,
|
||||
enable_wave_coordination=enable_wave_coordination,
|
||||
)
|
||||
try:
|
||||
coordinator.process_input_socket(
|
||||
@@ -265,22 +276,25 @@ class DPCoordinatorProc:
|
||||
)
|
||||
continue # Skip normal engine notification processing
|
||||
|
||||
# We received a message on the front-end XPUB socket,
|
||||
# from an API server sending a new request while the
|
||||
# engines are paused, so that we can wake the other
|
||||
# engines.
|
||||
engine_to_exclude, wave = decoded
|
||||
if not engines_running:
|
||||
if wave < current_wave:
|
||||
# If the wave number is stale, ensure the message
|
||||
# is handled by all the engines.
|
||||
engine_to_exclude = None
|
||||
# Wave coordination: handle new-request messages from front-end.
|
||||
# Only process these when wave coordination is enabled
|
||||
if self.enable_wave_coordination:
|
||||
# We received a message on the front-end XPUB socket,
|
||||
# from an API server sending a new request while the
|
||||
# engines are paused, so that we can wake the other
|
||||
# engines.
|
||||
engine_to_exclude, wave = decoded
|
||||
if not engines_running:
|
||||
if wave < current_wave:
|
||||
# If the wave number is stale, ensure the message
|
||||
# is handled by all the engines.
|
||||
engine_to_exclude = None
|
||||
|
||||
engines_running = True
|
||||
wave_state_changed = True
|
||||
self._send_start_wave(
|
||||
publish_back, current_wave, engine_to_exclude
|
||||
)
|
||||
engines_running = True
|
||||
wave_state_changed = True
|
||||
self._send_start_wave(
|
||||
publish_back, current_wave, engine_to_exclude
|
||||
)
|
||||
|
||||
if output_back in events:
|
||||
# We received a message from one of the engines.
|
||||
@@ -325,34 +339,39 @@ class DPCoordinatorProc:
|
||||
stats[1] = scheduler_stats.num_running_reqs
|
||||
stats_changed = True
|
||||
|
||||
if (wave := outputs.wave_complete) is not None:
|
||||
# 2. Notification from rank 0 engine that we've
|
||||
# moved into the global paused state
|
||||
# (engines_running==False).
|
||||
if current_wave <= wave:
|
||||
new_wave = wave + 1
|
||||
# Wave coordination: handle wave completion and start notifications
|
||||
# Only process these when wave coordination is enabled
|
||||
if self.enable_wave_coordination:
|
||||
if (wave := outputs.wave_complete) is not None:
|
||||
# 2. Notification from rank 0 engine that we've
|
||||
# moved into the global paused state
|
||||
# (engines_running==False).
|
||||
if current_wave <= wave:
|
||||
new_wave = wave + 1
|
||||
logger.debug(
|
||||
"Moving DP wave from %d to %d.",
|
||||
current_wave,
|
||||
new_wave,
|
||||
)
|
||||
current_wave = new_wave
|
||||
engines_running = False
|
||||
wave_state_changed = True
|
||||
elif (wave := outputs.start_wave) is not None and (
|
||||
wave > current_wave
|
||||
or (wave == current_wave and not engines_running)
|
||||
):
|
||||
# 3. The engine received request for a non-current wave
|
||||
# so we must ensure that other engines progress to the
|
||||
# next wave (race condition handling).
|
||||
logger.debug(
|
||||
"Moving DP wave from %d to %d.", current_wave, new_wave
|
||||
"Starting wave %d after notification of "
|
||||
"stale wave request from engine.",
|
||||
wave,
|
||||
)
|
||||
current_wave = new_wave
|
||||
engines_running = False
|
||||
current_wave = wave
|
||||
engines_running = True
|
||||
wave_state_changed = True
|
||||
elif (wave := outputs.start_wave) is not None and (
|
||||
wave > current_wave
|
||||
or (wave == current_wave and not engines_running)
|
||||
):
|
||||
# 3. The engine received request for a non-current wave
|
||||
# so we must ensure that other engines progress to the
|
||||
# next wave (race condition handling).
|
||||
logger.debug(
|
||||
"Starting wave %d after notification of "
|
||||
"stale wave request from engine.",
|
||||
wave,
|
||||
)
|
||||
current_wave = wave
|
||||
engines_running = True
|
||||
wave_state_changed = True
|
||||
self._send_start_wave(publish_back, wave, eng_index)
|
||||
self._send_start_wave(publish_back, wave, eng_index)
|
||||
|
||||
if wave_state_changed:
|
||||
message = (None, current_wave, engines_running)
|
||||
|
||||
Reference in New Issue
Block a user