[DP] Support external DP Load Balancer mode (#19790)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -10,7 +10,7 @@ import zmq
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import get_mp_context, get_open_zmq_ipc_path, make_zmq_socket
|
||||
from vllm.utils import get_mp_context, make_zmq_socket
|
||||
from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType
|
||||
from vllm.v1.serial_utils import MsgpackDecoder
|
||||
from vllm.v1.utils import get_engine_client_zmq_addr, shutdown
|
||||
@@ -48,20 +48,33 @@ class DPCoordinator:
|
||||
|
||||
Engines will move into running state when receiving a new request or
|
||||
START_DP_WAVE message.
|
||||
|
||||
Note that when deployed in External LB mode, no stats will be published by
|
||||
the engines and thus updates will only be sent to front-ends when the
|
||||
request wave / running state changes.
|
||||
"""
|
||||
|
||||
def __init__(self, parallel_config: ParallelConfig):
|
||||
|
||||
# Assume coordinator is colocated with front-end procs.
|
||||
front_publish_address = get_open_zmq_ipc_path()
|
||||
|
||||
dp_size = parallel_config.data_parallel_size
|
||||
assert dp_size > 1, "Coordinator only used for data parallel"
|
||||
|
||||
local_only = dp_size == parallel_config.data_parallel_size_local
|
||||
host = parallel_config.data_parallel_master_ip
|
||||
back_publish_address = get_engine_client_zmq_addr(local_only, host)
|
||||
back_output_address = get_engine_client_zmq_addr(local_only, host)
|
||||
external_lb = parallel_config.data_parallel_external_lb
|
||||
|
||||
# Assume coordinator is colocated with front-end procs when not in
|
||||
# external DP LB mode.
|
||||
front_publish_address = get_engine_client_zmq_addr(
|
||||
local_only=not external_lb, host=host)
|
||||
|
||||
local_only_eng = dp_size == parallel_config.data_parallel_size_local
|
||||
back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)
|
||||
back_output_address = get_engine_client_zmq_addr(local_only_eng, host)
|
||||
|
||||
# When in external LB mode, load stats aren't published, only changes
|
||||
# to request wave / running state, so we don't need to rate-limit the
|
||||
# updates to the front-end proc(s).
|
||||
min_stats_update_interval_ms = 0 if external_lb else 100
|
||||
|
||||
context = get_mp_context()
|
||||
self.proc: multiprocessing.Process = context.Process(
|
||||
@@ -72,6 +85,7 @@ class DPCoordinator:
|
||||
"front_publish_address": front_publish_address,
|
||||
"back_output_address": back_output_address,
|
||||
"back_publish_address": back_publish_address,
|
||||
"min_stats_update_interval_ms": min_stats_update_interval_ms,
|
||||
},
|
||||
daemon=True)
|
||||
self.proc.start()
|
||||
@@ -100,12 +114,16 @@ class EngineState:
|
||||
|
||||
class CoordinatorProc:
|
||||
|
||||
def __init__(self, engine_count: int):
|
||||
def __init__(self,
|
||||
engine_count: int,
|
||||
min_stats_update_interval_ms: int = 100):
|
||||
|
||||
self.ctx = zmq.Context()
|
||||
|
||||
self.engines = [EngineState() for _ in range(engine_count)]
|
||||
|
||||
self.stats_update_interval_ms = min_stats_update_interval_ms
|
||||
|
||||
self.current_wave = 0
|
||||
self.engines_running = False
|
||||
self.stats_changed = False
|
||||
@@ -116,8 +134,11 @@ class CoordinatorProc:
|
||||
front_publish_address: str,
|
||||
back_output_address: str,
|
||||
back_publish_address: str,
|
||||
min_stats_update_interval_ms: int = 100,
|
||||
):
|
||||
coordinator = CoordinatorProc(engine_count=engine_count)
|
||||
coordinator = CoordinatorProc(
|
||||
engine_count=engine_count,
|
||||
min_stats_update_interval_ms=min_stats_update_interval_ms)
|
||||
try:
|
||||
coordinator.process_input_socket(
|
||||
front_publish_address,
|
||||
@@ -156,9 +177,10 @@ class CoordinatorProc:
|
||||
last_publish_time = 0
|
||||
while True:
|
||||
elapsed = int(time.time() * 1000) - last_publish_time
|
||||
# Send at 100 ms interval if the stats have changed,
|
||||
# or otherwise every 3 seconds.
|
||||
wait_for = 100 if self.stats_changed else 3000
|
||||
# Send at stats_update_interval_ms interval if the stats have
|
||||
# changed, or otherwise every 4 seconds.
|
||||
wait_for = (self.stats_update_interval_ms
|
||||
if self.stats_changed else 4000)
|
||||
events = poller.poll(timeout=max(0, wait_for - elapsed))
|
||||
if not events:
|
||||
# Poller timeout - publish current stats to front-ends.
|
||||
@@ -174,7 +196,7 @@ class CoordinatorProc:
|
||||
|
||||
if publish_front in events:
|
||||
buffer = publish_front.recv()
|
||||
if buffer == b'\x01':
|
||||
if buffer in (b'\x01', b'\x00'):
|
||||
# Ignore subscription and unsubscription messages.
|
||||
continue
|
||||
|
||||
|
||||
Reference in New Issue
Block a user