Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -56,7 +56,6 @@ class DPCoordinator:
"""
def __init__(self, parallel_config: ParallelConfig):
dp_size = parallel_config.data_parallel_size
assert dp_size > 1, "Coordinator only used for data parallel"
@@ -68,7 +67,8 @@ class DPCoordinator:
# either external or hybrid DP LB mode.
local_only = not (external_lb or hybrid_lb)
front_publish_address = get_engine_client_zmq_addr(
local_only=local_only, host=host)
local_only=local_only, host=host
)
local_only_eng = dp_size == parallel_config.data_parallel_size_local
back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)
@@ -84,7 +84,8 @@ class DPCoordinator:
"back_output_address": back_output_address,
"back_publish_address": back_publish_address,
},
daemon=True)
daemon=True,
)
self.proc.start()
self.stats_publish_address = front_publish_address
@@ -104,16 +105,12 @@ class DPCoordinator:
class EngineState:
    """Per-engine request-count bookkeeping used by the DP coordinator."""

    def __init__(self):
        # request_counts[0] = number of waiting requests,
        # request_counts[1] = number of running requests.
        self.request_counts = [0, 0]
class DPCoordinatorProc:
def __init__(self,
engine_count: int,
min_stats_update_interval_ms: int = 100):
def __init__(self, engine_count: int, min_stats_update_interval_ms: int = 100):
set_process_title("DPCoordinator")
self.ctx = zmq.Context()
@@ -131,7 +128,8 @@ class DPCoordinatorProc:
):
coordinator = DPCoordinatorProc(
engine_count=engine_count,
min_stats_update_interval_ms=min_stats_update_interval_ms)
min_stats_update_interval_ms=min_stats_update_interval_ms,
)
try:
coordinator.process_input_socket(
front_publish_address,
@@ -141,10 +139,12 @@ class DPCoordinatorProc:
except KeyboardInterrupt:
logger.info("DP Coordinator process exiting")
def process_input_socket(self, front_publish_address: str,
back_output_address: str,
back_publish_address: str):
def process_input_socket(
self,
front_publish_address: str,
back_output_address: str,
back_publish_address: str,
):
decoder = MsgpackDecoder(EngineCoreOutputs)
# For tracking request wave progression.
@@ -157,29 +157,33 @@ class DPCoordinatorProc:
last_stats_wave = -1
last_step_counts: Optional[list[list[int]]] = None
with make_zmq_socket(
with (
make_zmq_socket(
path=front_publish_address, # IPC
ctx=self.ctx,
socket_type=zmq.XPUB,
bind=True,
) as publish_front, make_zmq_socket(
) as publish_front,
make_zmq_socket(
path=back_output_address, # IPC or TCP
ctx=self.ctx,
socket_type=zmq.PULL,
bind=True,
) as output_back, make_zmq_socket(
) as output_back,
make_zmq_socket(
path=back_publish_address, # IPC or TCP
ctx=self.ctx,
socket_type=zmq.XPUB,
bind=True,
) as publish_back:
) as publish_back,
):
# Wait until all engines subscribe.
for _ in self.engines:
if publish_back.recv() != b'\x01':
if publish_back.recv() != b"\x01":
logger.error(
"DP Coordinator received unexpected message while "
"waiting for engines to subscribe")
"waiting for engines to subscribe"
)
return
# Send ready message to engines.
publish_back.send(b"READY")
@@ -194,15 +198,13 @@ class DPCoordinatorProc:
elapsed = int(time.time() * 1000) - last_publish_time
# Send at stats_update_interval_ms interval if the stats have
# changed, or otherwise every 5 seconds.
wait_for = (self.stats_update_interval_ms
if stats_changed else 5000)
wait_for = self.stats_update_interval_ms if stats_changed else 5000
# Wait at least 50ms to ensure we've received all stats for
# the current step.
min_timeout = 50 if last_step_counts is None else 0
events = poller.poll(timeout=max(min_timeout, wait_for -
elapsed))
events = poller.poll(timeout=max(min_timeout, wait_for - elapsed))
if not events:
# Poller timeout - publish current stats to front-ends.
if last_step_counts is not None:
@@ -212,8 +214,7 @@ class DPCoordinatorProc:
engine_req_counts_list = self._get_engine_counts()
stats_changed = False
to_publish = (engine_req_counts_list, current_wave,
engines_running)
to_publish = (engine_req_counts_list, current_wave, engines_running)
publish_front.send(msgspec.msgpack.encode(to_publish))
last_publish_time = int(time.time() * 1000)
continue
@@ -223,13 +224,16 @@ class DPCoordinatorProc:
if publish_front in events:
buffer = publish_front.recv()
if buffer in (b'\x01', b'\x00'):
if buffer in (b"\x01", b"\x00"):
# Ignore subscription messages.
continue
decoded = msgspec.msgpack.decode(buffer)
if isinstance(decoded, (list, tuple)) and len(
decoded) == 2 and decoded[0] == "SCALE_ELASTIC_EP":
if (
isinstance(decoded, (list, tuple))
and len(decoded) == 2
and decoded[0] == "SCALE_ELASTIC_EP"
):
# Handle scale up notification
new_engine_count = decoded[1]
current_count = len(self.engines)
@@ -248,13 +252,17 @@ class DPCoordinatorProc:
# engine
engines_running = False
logger.info(
"DPCoordinator scaled up from %s to %s "
"engines", current_count, new_engine_count)
"DPCoordinator scaled up from %s to %s engines",
current_count,
new_engine_count,
)
else:
self.engines = self.engines[:new_engine_count]
logger.info(
"DPCoordinator scaled down from %s to %s "
"engines", current_count, new_engine_count)
"DPCoordinator scaled down from %s to %s engines",
current_count,
new_engine_count,
)
continue # Skip normal engine notification processing
# We received a message on the front-end XPUB socket,
@@ -270,8 +278,9 @@ class DPCoordinatorProc:
engines_running = True
wave_state_changed = True
self._send_start_wave(publish_back, current_wave,
engine_to_exclude)
self._send_start_wave(
publish_back, current_wave, engine_to_exclude
)
if output_back in events:
# We received a message from one of the engines.
@@ -290,21 +299,28 @@ class DPCoordinatorProc:
stats = self.engines[eng_index].request_counts
stats_step = scheduler_stats.step_counter
stats_wave = scheduler_stats.current_wave
if (stats_wave > last_stats_wave
or stats_wave == last_stats_wave
and stats_step > last_stats_step):
if (
stats_wave > last_stats_wave
or stats_wave == last_stats_wave
and stats_step > last_stats_step
):
if stats_changed:
last_step_counts = self._get_engine_counts(
do_copy=True)
last_step_counts = self._get_engine_counts(do_copy=True)
last_stats_step = stats_step
last_stats_wave = stats_wave
elif stats_wave != last_stats_wave or (
stats_step != last_stats_step):
stats_step != last_stats_step
):
logger.warning(
"Received stats for out-of-order "
"step (%d, %d) from engine %d (expected "
"> (%d, %d))", stats_wave, stats_step,
eng_index, last_stats_wave, last_stats_step)
"> (%d, %d))",
stats_wave,
stats_step,
eng_index,
last_stats_wave,
last_stats_step,
)
stats[0] = scheduler_stats.num_waiting_reqs
stats[1] = scheduler_stats.num_running_reqs
stats_changed = True
@@ -315,20 +331,24 @@ class DPCoordinatorProc:
# (engines_running==False).
if current_wave <= wave:
new_wave = wave + 1
logger.debug("Moving DP wave from %d to %d.",
current_wave, new_wave)
logger.debug(
"Moving DP wave from %d to %d.", current_wave, new_wave
)
current_wave = new_wave
engines_running = False
wave_state_changed = True
elif (wave := outputs.start_wave) is not None and (
wave > current_wave or
(wave == current_wave and not engines_running)):
wave > current_wave
or (wave == current_wave and not engines_running)
):
# 3. The engine received request for a non-current wave
# so we must ensure that other engines progress to the
# next wave (race condition handling).
logger.debug(
"Starting wave %d after notification of "
"stale wave request from engine.", wave)
"stale wave request from engine.",
wave,
)
current_wave = wave
engines_running = True
wave_state_changed = True
@@ -339,16 +359,16 @@ class DPCoordinatorProc:
publish_front.send(msgspec.msgpack.encode(message))
@staticmethod
def _send_start_wave(
    socket: zmq.Socket, wave: int, exclude_engine_index: Optional[int]
):
    """Broadcast the START_DP_WAVE message to all the engines.

    It includes the current wave number and index of engine which
    has already received a request with this wave number and so doesn't
    require additional notification.
    """
    # Payload is a msgpack-encoded (wave, exclude_engine_index) tuple;
    # the request-type frame lets subscribers dispatch on message kind.
    wave_encoded = msgspec.msgpack.encode((wave, exclude_engine_index))
    socket.send_multipart((EngineCoreRequestType.START_DP_WAVE.value, wave_encoded))
def _get_engine_counts(self, do_copy=False) -> list[list[int]]:
"""Return list of [waiting, running] count lists for each engine."""