Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -56,7 +56,6 @@ class DPCoordinator:
|
||||
"""
|
||||
|
||||
def __init__(self, parallel_config: ParallelConfig):
|
||||
|
||||
dp_size = parallel_config.data_parallel_size
|
||||
assert dp_size > 1, "Coordinator only used for data parallel"
|
||||
|
||||
@@ -68,7 +67,8 @@ class DPCoordinator:
|
||||
# either external or hybrid DP LB mode.
|
||||
local_only = not (external_lb or hybrid_lb)
|
||||
front_publish_address = get_engine_client_zmq_addr(
|
||||
local_only=local_only, host=host)
|
||||
local_only=local_only, host=host
|
||||
)
|
||||
|
||||
local_only_eng = dp_size == parallel_config.data_parallel_size_local
|
||||
back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)
|
||||
@@ -84,7 +84,8 @@ class DPCoordinator:
|
||||
"back_output_address": back_output_address,
|
||||
"back_publish_address": back_publish_address,
|
||||
},
|
||||
daemon=True)
|
||||
daemon=True,
|
||||
)
|
||||
self.proc.start()
|
||||
|
||||
self.stats_publish_address = front_publish_address
|
||||
@@ -104,16 +105,12 @@ class DPCoordinator:
|
||||
|
||||
|
||||
class EngineState:
|
||||
|
||||
def __init__(self):
|
||||
self.request_counts = [0, 0] # [waiting, running]
|
||||
|
||||
|
||||
class DPCoordinatorProc:
|
||||
|
||||
def __init__(self,
|
||||
engine_count: int,
|
||||
min_stats_update_interval_ms: int = 100):
|
||||
def __init__(self, engine_count: int, min_stats_update_interval_ms: int = 100):
|
||||
set_process_title("DPCoordinator")
|
||||
self.ctx = zmq.Context()
|
||||
|
||||
@@ -131,7 +128,8 @@ class DPCoordinatorProc:
|
||||
):
|
||||
coordinator = DPCoordinatorProc(
|
||||
engine_count=engine_count,
|
||||
min_stats_update_interval_ms=min_stats_update_interval_ms)
|
||||
min_stats_update_interval_ms=min_stats_update_interval_ms,
|
||||
)
|
||||
try:
|
||||
coordinator.process_input_socket(
|
||||
front_publish_address,
|
||||
@@ -141,10 +139,12 @@ class DPCoordinatorProc:
|
||||
except KeyboardInterrupt:
|
||||
logger.info("DP Coordinator process exiting")
|
||||
|
||||
def process_input_socket(self, front_publish_address: str,
|
||||
back_output_address: str,
|
||||
back_publish_address: str):
|
||||
|
||||
def process_input_socket(
|
||||
self,
|
||||
front_publish_address: str,
|
||||
back_output_address: str,
|
||||
back_publish_address: str,
|
||||
):
|
||||
decoder = MsgpackDecoder(EngineCoreOutputs)
|
||||
|
||||
# For tracking request wave progression.
|
||||
@@ -157,29 +157,33 @@ class DPCoordinatorProc:
|
||||
last_stats_wave = -1
|
||||
last_step_counts: Optional[list[list[int]]] = None
|
||||
|
||||
with make_zmq_socket(
|
||||
with (
|
||||
make_zmq_socket(
|
||||
path=front_publish_address, # IPC
|
||||
ctx=self.ctx,
|
||||
socket_type=zmq.XPUB,
|
||||
bind=True,
|
||||
) as publish_front, make_zmq_socket(
|
||||
) as publish_front,
|
||||
make_zmq_socket(
|
||||
path=back_output_address, # IPC or TCP
|
||||
ctx=self.ctx,
|
||||
socket_type=zmq.PULL,
|
||||
bind=True,
|
||||
) as output_back, make_zmq_socket(
|
||||
) as output_back,
|
||||
make_zmq_socket(
|
||||
path=back_publish_address, # IPC or TCP
|
||||
ctx=self.ctx,
|
||||
socket_type=zmq.XPUB,
|
||||
bind=True,
|
||||
) as publish_back:
|
||||
|
||||
) as publish_back,
|
||||
):
|
||||
# Wait until all engines subscribe.
|
||||
for _ in self.engines:
|
||||
if publish_back.recv() != b'\x01':
|
||||
if publish_back.recv() != b"\x01":
|
||||
logger.error(
|
||||
"DP Coordinator received unexpected message while "
|
||||
"waiting for engines to subscribe")
|
||||
"waiting for engines to subscribe"
|
||||
)
|
||||
return
|
||||
# Send ready message to engines.
|
||||
publish_back.send(b"READY")
|
||||
@@ -194,15 +198,13 @@ class DPCoordinatorProc:
|
||||
elapsed = int(time.time() * 1000) - last_publish_time
|
||||
# Send at stats_update_interval_ms interval if the stats have
|
||||
# changed, or otherwise every 5 seconds.
|
||||
wait_for = (self.stats_update_interval_ms
|
||||
if stats_changed else 5000)
|
||||
wait_for = self.stats_update_interval_ms if stats_changed else 5000
|
||||
|
||||
# Wait at least 50ms to ensure we've received all stats for
|
||||
# the current step.
|
||||
min_timeout = 50 if last_step_counts is None else 0
|
||||
|
||||
events = poller.poll(timeout=max(min_timeout, wait_for -
|
||||
elapsed))
|
||||
events = poller.poll(timeout=max(min_timeout, wait_for - elapsed))
|
||||
if not events:
|
||||
# Poller timeout - publish current stats to front-ends.
|
||||
if last_step_counts is not None:
|
||||
@@ -212,8 +214,7 @@ class DPCoordinatorProc:
|
||||
engine_req_counts_list = self._get_engine_counts()
|
||||
stats_changed = False
|
||||
|
||||
to_publish = (engine_req_counts_list, current_wave,
|
||||
engines_running)
|
||||
to_publish = (engine_req_counts_list, current_wave, engines_running)
|
||||
publish_front.send(msgspec.msgpack.encode(to_publish))
|
||||
last_publish_time = int(time.time() * 1000)
|
||||
continue
|
||||
@@ -223,13 +224,16 @@ class DPCoordinatorProc:
|
||||
|
||||
if publish_front in events:
|
||||
buffer = publish_front.recv()
|
||||
if buffer in (b'\x01', b'\x00'):
|
||||
if buffer in (b"\x01", b"\x00"):
|
||||
# Ignore subscription messages.
|
||||
continue
|
||||
|
||||
decoded = msgspec.msgpack.decode(buffer)
|
||||
if isinstance(decoded, (list, tuple)) and len(
|
||||
decoded) == 2 and decoded[0] == "SCALE_ELASTIC_EP":
|
||||
if (
|
||||
isinstance(decoded, (list, tuple))
|
||||
and len(decoded) == 2
|
||||
and decoded[0] == "SCALE_ELASTIC_EP"
|
||||
):
|
||||
# Handle scale up notification
|
||||
new_engine_count = decoded[1]
|
||||
current_count = len(self.engines)
|
||||
@@ -248,13 +252,17 @@ class DPCoordinatorProc:
|
||||
# engine
|
||||
engines_running = False
|
||||
logger.info(
|
||||
"DPCoordinator scaled up from %s to %s "
|
||||
"engines", current_count, new_engine_count)
|
||||
"DPCoordinator scaled up from %s to %s engines",
|
||||
current_count,
|
||||
new_engine_count,
|
||||
)
|
||||
else:
|
||||
self.engines = self.engines[:new_engine_count]
|
||||
logger.info(
|
||||
"DPCoordinator scaled down from %s to %s "
|
||||
"engines", current_count, new_engine_count)
|
||||
"DPCoordinator scaled down from %s to %s engines",
|
||||
current_count,
|
||||
new_engine_count,
|
||||
)
|
||||
continue # Skip normal engine notification processing
|
||||
|
||||
# We received a message on the front-end XPUB socket,
|
||||
@@ -270,8 +278,9 @@ class DPCoordinatorProc:
|
||||
|
||||
engines_running = True
|
||||
wave_state_changed = True
|
||||
self._send_start_wave(publish_back, current_wave,
|
||||
engine_to_exclude)
|
||||
self._send_start_wave(
|
||||
publish_back, current_wave, engine_to_exclude
|
||||
)
|
||||
|
||||
if output_back in events:
|
||||
# We received a message from one of the engines.
|
||||
@@ -290,21 +299,28 @@ class DPCoordinatorProc:
|
||||
stats = self.engines[eng_index].request_counts
|
||||
stats_step = scheduler_stats.step_counter
|
||||
stats_wave = scheduler_stats.current_wave
|
||||
if (stats_wave > last_stats_wave
|
||||
or stats_wave == last_stats_wave
|
||||
and stats_step > last_stats_step):
|
||||
if (
|
||||
stats_wave > last_stats_wave
|
||||
or stats_wave == last_stats_wave
|
||||
and stats_step > last_stats_step
|
||||
):
|
||||
if stats_changed:
|
||||
last_step_counts = self._get_engine_counts(
|
||||
do_copy=True)
|
||||
last_step_counts = self._get_engine_counts(do_copy=True)
|
||||
last_stats_step = stats_step
|
||||
last_stats_wave = stats_wave
|
||||
elif stats_wave != last_stats_wave or (
|
||||
stats_step != last_stats_step):
|
||||
stats_step != last_stats_step
|
||||
):
|
||||
logger.warning(
|
||||
"Received stats for out-of-order "
|
||||
"step (%d, %d) from engine %d (expected "
|
||||
"> (%d, %d))", stats_wave, stats_step,
|
||||
eng_index, last_stats_wave, last_stats_step)
|
||||
"> (%d, %d))",
|
||||
stats_wave,
|
||||
stats_step,
|
||||
eng_index,
|
||||
last_stats_wave,
|
||||
last_stats_step,
|
||||
)
|
||||
stats[0] = scheduler_stats.num_waiting_reqs
|
||||
stats[1] = scheduler_stats.num_running_reqs
|
||||
stats_changed = True
|
||||
@@ -315,20 +331,24 @@ class DPCoordinatorProc:
|
||||
# (engines_running==False).
|
||||
if current_wave <= wave:
|
||||
new_wave = wave + 1
|
||||
logger.debug("Moving DP wave from %d to %d.",
|
||||
current_wave, new_wave)
|
||||
logger.debug(
|
||||
"Moving DP wave from %d to %d.", current_wave, new_wave
|
||||
)
|
||||
current_wave = new_wave
|
||||
engines_running = False
|
||||
wave_state_changed = True
|
||||
elif (wave := outputs.start_wave) is not None and (
|
||||
wave > current_wave or
|
||||
(wave == current_wave and not engines_running)):
|
||||
wave > current_wave
|
||||
or (wave == current_wave and not engines_running)
|
||||
):
|
||||
# 3. The engine received request for a non-current wave
|
||||
# so we must ensure that other engines progress to the
|
||||
# next wave (race condition handling).
|
||||
logger.debug(
|
||||
"Starting wave %d after notification of "
|
||||
"stale wave request from engine.", wave)
|
||||
"stale wave request from engine.",
|
||||
wave,
|
||||
)
|
||||
current_wave = wave
|
||||
engines_running = True
|
||||
wave_state_changed = True
|
||||
@@ -339,16 +359,16 @@ class DPCoordinatorProc:
|
||||
publish_front.send(msgspec.msgpack.encode(message))
|
||||
|
||||
@staticmethod
|
||||
def _send_start_wave(socket: zmq.Socket, wave: int,
|
||||
exclude_engine_index: Optional[int]):
|
||||
def _send_start_wave(
|
||||
socket: zmq.Socket, wave: int, exclude_engine_index: Optional[int]
|
||||
):
|
||||
"""Broadcast the START_DP_WAVE message to all the engines.
|
||||
It includes the current wave number and index of engine which
|
||||
has already received a request with this wave number and so doesn't
|
||||
require additional notification.
|
||||
"""
|
||||
wave_encoded = msgspec.msgpack.encode((wave, exclude_engine_index))
|
||||
socket.send_multipart(
|
||||
(EngineCoreRequestType.START_DP_WAVE.value, wave_encoded))
|
||||
socket.send_multipart((EngineCoreRequestType.START_DP_WAVE.value, wave_encoded))
|
||||
|
||||
def _get_engine_counts(self, do_copy=False) -> list[list[int]]:
|
||||
"""Return list of [waiting, running] count lists for each engine."""
|
||||
|
||||
Reference in New Issue
Block a user