[DP] Support external DP Load Balancer mode (#19790)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -34,6 +34,7 @@ from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
|
||||
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
|
||||
EngineCoreRequestType, UtilityOutput)
|
||||
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
|
||||
from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.metrics.stats import SchedulerStats
|
||||
@@ -41,7 +42,6 @@ from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
from vllm.v1.utils import EngineHandshakeMetadata, EngineZmqAddresses
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -367,10 +367,11 @@ class EngineCoreProc(EngineCore):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
on_head_node: bool,
|
||||
local_client: bool,
|
||||
handshake_address: str,
|
||||
executor_class: type[Executor],
|
||||
log_stats: bool,
|
||||
client_handshake_address: Optional[str] = None,
|
||||
engine_index: int = 0,
|
||||
):
|
||||
self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
|
||||
@@ -383,12 +384,21 @@ class EngineCoreProc(EngineCore):
|
||||
identity = self.engine_index.to_bytes(length=2, byteorder="little")
|
||||
self.engines_running = False
|
||||
|
||||
with self._perform_handshake(handshake_address, identity, on_head_node,
|
||||
vllm_config) as addresses:
|
||||
with self._perform_handshakes(handshake_address, identity,
|
||||
local_client, vllm_config,
|
||||
client_handshake_address) as addresses:
|
||||
self.client_count = len(addresses.outputs)
|
||||
|
||||
# Set up data parallel environment.
|
||||
self.has_coordinator = addresses.coordinator_output is not None
|
||||
self.frontend_stats_publish_address = (
|
||||
addresses.frontend_stats_publish_address)
|
||||
# Only publish request queue stats to coordinator for "internal"
|
||||
# LB mode.
|
||||
self.publish_dp_lb_stats = (
|
||||
self.has_coordinator
|
||||
and not vllm_config.parallel_config.data_parallel_external_lb)
|
||||
|
||||
self._init_data_parallel(vllm_config)
|
||||
|
||||
super().__init__(vllm_config, executor_class, log_stats,
|
||||
@@ -414,45 +424,102 @@ class EngineCoreProc(EngineCore):
|
||||
self.output_thread.start()
|
||||
|
||||
@contextmanager
|
||||
def _perform_handshake(
|
||||
self, handshake_address: str, identity: bytes, on_head_node: bool,
|
||||
vllm_config: VllmConfig
|
||||
def _perform_handshakes(
|
||||
self,
|
||||
handshake_address: str,
|
||||
identity: bytes,
|
||||
local_client: bool,
|
||||
vllm_config: VllmConfig,
|
||||
client_handshake_address: Optional[str],
|
||||
) -> Generator[EngineZmqAddresses, None, None]:
|
||||
"""
|
||||
Perform startup handshakes.
|
||||
|
||||
For DP=1 or offline mode, this is with the colocated front-end process.
|
||||
|
||||
For DP>1 with internal loadbalancing this is with the shared front-end
|
||||
process which may reside on a different node.
|
||||
|
||||
For DP>1 with external loadbalancing, two handshakes are performed:
|
||||
- With the rank 0 front-end process which retrieves the
|
||||
DP Coordinator ZMQ addresses and DP process group address.
|
||||
- With the colocated front-end process which retrieves the
|
||||
client input/output socket addresses.
|
||||
with the exception of the rank 0 engine itself which doesn't require
|
||||
the second handshake.
|
||||
|
||||
Here, "front-end" process can mean the process containing the engine
|
||||
core client (which is the API server process in the case the API
|
||||
server is not scaled out), OR the launcher process running the
|
||||
run_multi_api_server() function in serve.py.
|
||||
"""
|
||||
input_ctx = zmq.Context()
|
||||
with make_zmq_socket(input_ctx,
|
||||
is_local = local_client and client_handshake_address is None
|
||||
handshake = self._perform_handshake(input_ctx, handshake_address,
|
||||
identity, is_local, vllm_config,
|
||||
vllm_config.parallel_config)
|
||||
if client_handshake_address is None:
|
||||
with handshake as addresses:
|
||||
yield addresses
|
||||
else:
|
||||
local_handshake = self._perform_handshake(
|
||||
input_ctx, client_handshake_address, identity, local_client,
|
||||
vllm_config)
|
||||
with handshake as addresses, local_handshake as client_addresses:
|
||||
addresses.inputs = client_addresses.inputs
|
||||
addresses.outputs = client_addresses.outputs
|
||||
yield addresses
|
||||
|
||||
# Update config which may have changed from the handshake
|
||||
vllm_config.__post_init__()
|
||||
|
||||
@contextmanager
|
||||
def _perform_handshake(
|
||||
self,
|
||||
ctx: zmq.Context,
|
||||
handshake_address: str,
|
||||
identity: bytes,
|
||||
local_client: bool,
|
||||
vllm_config: VllmConfig,
|
||||
parallel_config_to_update: Optional[ParallelConfig] = None,
|
||||
) -> Generator[EngineZmqAddresses, None, None]:
|
||||
with make_zmq_socket(ctx,
|
||||
handshake_address,
|
||||
zmq.DEALER,
|
||||
identity=identity,
|
||||
linger=5000,
|
||||
bind=False) as handshake_socket:
|
||||
# Register engine with front-end.
|
||||
addresses = self.startup_handshake(handshake_socket, on_head_node,
|
||||
vllm_config.parallel_config)
|
||||
|
||||
# Update config which may have changed from the handshake
|
||||
vllm_config.__post_init__()
|
||||
|
||||
addresses = self.startup_handshake(handshake_socket, local_client,
|
||||
parallel_config_to_update)
|
||||
yield addresses
|
||||
|
||||
# Send ready message.
|
||||
num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
|
||||
# We pass back the coordinator stats update address here for the
|
||||
# external LB case for our colocated front-end to use (coordinator
|
||||
# only runs with rank 0).
|
||||
dp_stats_address = self.frontend_stats_publish_address
|
||||
handshake_socket.send(
|
||||
msgspec.msgpack.encode({
|
||||
"status": "READY",
|
||||
"local": on_head_node,
|
||||
"local": local_client,
|
||||
"num_gpu_blocks": num_gpu_blocks,
|
||||
"dp_stats_address": dp_stats_address,
|
||||
}))
|
||||
|
||||
@staticmethod
|
||||
def startup_handshake(
|
||||
handshake_socket: zmq.Socket, on_head_node: bool,
|
||||
parallel_config: ParallelConfig) -> EngineZmqAddresses:
|
||||
handshake_socket: zmq.Socket,
|
||||
local_client: bool,
|
||||
parallel_config: Optional[ParallelConfig] = None,
|
||||
) -> EngineZmqAddresses:
|
||||
|
||||
# Send registration message.
|
||||
handshake_socket.send(
|
||||
msgspec.msgpack.encode({
|
||||
"status": "HELLO",
|
||||
"local": on_head_node,
|
||||
"local": local_client,
|
||||
}))
|
||||
|
||||
# Receive initialization message.
|
||||
@@ -466,9 +533,9 @@ class EngineCoreProc(EngineCore):
|
||||
init_bytes, type=EngineHandshakeMetadata)
|
||||
logger.debug("Received init message: %s", init_message)
|
||||
|
||||
received_parallel_config = init_message.parallel_config
|
||||
for key, value in received_parallel_config.items():
|
||||
setattr(parallel_config, key, value)
|
||||
if parallel_config is not None:
|
||||
for key, value in init_message.parallel_config.items():
|
||||
setattr(parallel_config, key, value)
|
||||
|
||||
return init_message.addresses
|
||||
|
||||
@@ -749,12 +816,12 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
on_head_node: bool,
|
||||
local_client: bool,
|
||||
handshake_address: str,
|
||||
executor_class: type[Executor],
|
||||
log_stats: bool,
|
||||
client_handshake_address: Optional[str] = None,
|
||||
):
|
||||
|
||||
self._decorate_logs()
|
||||
|
||||
# Counts forward-passes of the model so that we can synchronize
|
||||
@@ -765,8 +832,9 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
|
||||
# Initialize the engine.
|
||||
dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||
super().__init__(vllm_config, on_head_node, handshake_address,
|
||||
executor_class, log_stats, dp_rank)
|
||||
super().__init__(vllm_config, local_client, handshake_address,
|
||||
executor_class, log_stats, client_handshake_address,
|
||||
dp_rank)
|
||||
|
||||
def _decorate_logs(self):
|
||||
# Add process-specific prefix to stdout and stderr before
|
||||
@@ -799,10 +867,18 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
from vllm.platforms import current_platform
|
||||
device_control_env_var = current_platform.device_control_env_var
|
||||
world_size = vllm_config.parallel_config.world_size
|
||||
os.environ[device_control_env_var] = ",".join(
|
||||
str(current_platform.device_id_to_physical_device_id(i))
|
||||
for i in range(local_dp_rank * world_size, (local_dp_rank + 1) *
|
||||
world_size))
|
||||
# Set CUDA_VISIBLE_DEVICES or equivalent.
|
||||
try:
|
||||
os.environ[device_control_env_var] = ",".join(
|
||||
str(current_platform.device_id_to_physical_device_id(i))
|
||||
for i in range(local_dp_rank *
|
||||
world_size, (local_dp_rank + 1) * world_size))
|
||||
except IndexError as e:
|
||||
raise Exception(
|
||||
f"Error setting {device_control_env_var}: "
|
||||
f"local range: [{local_dp_rank * world_size}, "
|
||||
f"{(local_dp_rank + 1) * world_size}) "
|
||||
f"base value: \"{os.getenv(device_control_env_var)}\"") from e
|
||||
|
||||
self.dp_rank = dp_rank
|
||||
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
|
||||
@@ -839,7 +915,7 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
super()._handle_client_request(request_type, request)
|
||||
|
||||
def _maybe_publish_request_counts(self):
|
||||
if not self.has_coordinator:
|
||||
if not self.publish_dp_lb_stats:
|
||||
return
|
||||
|
||||
# Publish our request counts (if they've changed).
|
||||
@@ -892,9 +968,9 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
|
||||
def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
|
||||
|
||||
# Optimization - only perform finish-sync all-reduce every 24 steps.
|
||||
# Optimization - only perform finish-sync all-reduce every 32 steps.
|
||||
self.counter += 1
|
||||
if self.counter != 24:
|
||||
if self.counter != 32:
|
||||
return True
|
||||
self.counter = 0
|
||||
|
||||
@@ -910,7 +986,7 @@ class DPEngineCoreActor(DPEngineCoreProc):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
on_head_node: bool,
|
||||
local_client: bool,
|
||||
addresses: EngineZmqAddresses,
|
||||
executor_class: type[Executor],
|
||||
log_stats: bool,
|
||||
@@ -927,15 +1003,16 @@ class DPEngineCoreActor(DPEngineCoreProc):
|
||||
# data parallel groups.
|
||||
del os.environ['CUDA_VISIBLE_DEVICES']
|
||||
|
||||
super().__init__(vllm_config, on_head_node, "", executor_class,
|
||||
super().__init__(vllm_config, local_client, "", executor_class,
|
||||
log_stats)
|
||||
|
||||
def _decorate_logs(self):
|
||||
pass
|
||||
|
||||
@contextmanager
|
||||
def _perform_handshake(self, handshake_address: str, identity: bytes,
|
||||
on_head_node: bool, vllm_config: VllmConfig):
|
||||
def _perform_handshakes(self, handshake_address: str, identity: bytes,
|
||||
local_client: bool, vllm_config: VllmConfig,
|
||||
client_handshake_address: Optional[str]):
|
||||
"""
|
||||
For Ray, we don't need to actually perform handshake.
|
||||
All addresses information is known before the actor creation.
|
||||
|
||||
Reference in New Issue
Block a user