[BugFix] Support online dense model DP without overhead (#30739)
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: njhill <nickhill123@gmail.com>
@@ -75,7 +75,6 @@ class EngineHandshakeMetadata:

     addresses: EngineZmqAddresses
     parallel_config: dict[str, int | str | list[int]]
-    parallel_config_hash: str | None = None


 class CoreEngineProcManager:
@@ -249,12 +248,19 @@ class CoreEngineActorManager:
         from ray.runtime_env import RuntimeEnv
         from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

-        from vllm.v1.engine.core import DPEngineCoreActor
+        from vllm.v1.engine.core import DPMoEEngineCoreActor, EngineCoreActor
+
+        dp_size = vllm_config.parallel_config.data_parallel_size
+        actor_class = (
+            DPMoEEngineCoreActor
+            if dp_size > 1 and vllm_config.model_config.is_moe
+            else EngineCoreActor
+        )

         self.local_engine_actors: list[ray.ActorHandle] = []
         self.remote_engine_actors: list[ray.ActorHandle] = []

-        env_vars_list = get_env_vars_to_copy(destination="DPEngineCoreActor")
+        env_vars_list = get_env_vars_to_copy(destination=actor_class.__name__)
         self.env_vars_dict = {
             name: os.environ[name] for name in env_vars_list if name in os.environ
         }
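For context, the pattern this hunk applies (picking the Ray actor class at runtime and wrapping it with `ray.remote`) works like this. A minimal self-contained sketch with placeholder classes, not vLLM's actual actors:

```python
import ray

ray.init()

# Placeholder classes standing in for EngineCoreActor / DPMoEEngineCoreActor;
# the real vLLM actors take many more constructor arguments.
class EngineActor:
    def describe(self) -> str:
        return "dense engine (no DP wave coordination)"

class MoEEngineActor(EngineActor):
    def describe(self) -> str:
        return "MoE engine (DP wave coordination enabled)"

dp_size, is_moe = 2, False  # dense model: take the cheap path
actor_class = MoEEngineActor if dp_size > 1 and is_moe else EngineActor

# ray.remote() accepts a plain class, so the choice can be made at runtime;
# actor_class.__name__ stays usable for things like env-var filtering.
handle = ray.remote(actor_class).remote()
print(ray.get(handle.describe.remote()))  # -> "dense engine (...)"
```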
@@ -263,7 +269,6 @@ class CoreEngineActorManager:
         self.addresses = addresses
         self.executor_class = executor_class
         self.log_stats = log_stats
-        dp_size = vllm_config.parallel_config.data_parallel_size
         local_engine_count = vllm_config.parallel_config.data_parallel_size_local
         world_size = vllm_config.parallel_config.world_size

@@ -314,7 +319,7 @@ class CoreEngineActorManager:
             runtime_env = RuntimeEnv(env_vars=actor_env_vars)

             actor = (
-                ray.remote(DPEngineCoreActor)
+                ray.remote(actor_class)
                 .options(
                     scheduling_strategy=PlacementGroupSchedulingStrategy(
                         placement_group=pg,
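The `.options(scheduling_strategy=...)` chain is standard Ray API. Roughly, pinning an actor to a placement group looks like the following; a generic sketch with a trivial worker, assuming a local Ray cluster:

```python
import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

ray.init()

class Worker:
    def ping(self) -> str:
        return "ok"

# Reserve a 1-CPU bundle, then pin the actor to that placement group.
pg = placement_group([{"CPU": 1}])
ray.get(pg.ready())

actor = (
    ray.remote(Worker)
    .options(
        num_cpus=1,
        scheduling_strategy=PlacementGroupSchedulingStrategy(placement_group=pg),
    )
    .remote()
)
print(ray.get(actor.ping.remote()))  # -> "ok"
```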
@@ -624,7 +629,13 @@ class CoreEngineActorManager:
         from ray.runtime_env import RuntimeEnv
         from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

-        from vllm.v1.engine.core import DPEngineCoreActor
+        from vllm.v1.engine.core import DPMoEEngineCoreActor, EngineCoreActor
+
+        actor_class = (
+            DPMoEEngineCoreActor
+            if cur_vllm_config.model_config.is_moe
+            else EngineCoreActor
+        )

         cur_data_parallel_size = len(self.local_engine_actors) + len(
             self.remote_engine_actors
@@ -667,7 +678,7 @@ class CoreEngineActorManager:
         )

         actor = (
-            ray.remote(DPEngineCoreActor)
+            ray.remote(actor_class)
             .options(
                 scheduling_strategy=PlacementGroupSchedulingStrategy(
                     placement_group=pg,
@@ -804,12 +815,19 @@ def launch_core_engines(
         ],
     )

-    # Run the DP Coordinator process with rank 0 when in
-    # online DP mode.
-    run_coordinator = dp_size > 1 and not offline_mode and dp_rank == 0
+    # Run the DP Coordinator process with rank 0 when in online DP mode.
+    # The coordinator is needed for:
+    # 1. Internal/hybrid LB: collecting and publishing queue stats for load balancing
+    # 2. MoE models: wave coordination in addition to stats
+    run_coordinator = (
+        vllm_config.needs_dp_coordinator and not offline_mode and dp_rank == 0
+    )

     if run_coordinator:
-        coordinator = DPCoordinator(parallel_config)
+        coordinator = DPCoordinator(
+            parallel_config,
+            enable_wave_coordination=vllm_config.model_config.is_moe,
+        )

         addresses.coordinator_input, addresses.coordinator_output = (
             coordinator.get_engine_socket_addresses()
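`needs_dp_coordinator` is referenced here but defined outside this diff. Judging from the comment above, a plausible shape would be the following; this is purely a hypothetical reconstruction with assumed field names (`external_lb` in particular is not vLLM's):

```python
from dataclasses import dataclass

@dataclass
class ParallelCfg:
    data_parallel_size: int = 1
    external_lb: bool = False  # assumed name, stand-in for the LB mode

@dataclass
class ModelCfg:
    is_moe: bool = False

@dataclass
class Cfg:
    parallel_config: ParallelCfg
    model_config: ModelCfg

    @property
    def needs_dp_coordinator(self) -> bool:
        if self.parallel_config.data_parallel_size <= 1:
            return False  # no DP at all, nothing to coordinate
        # Internal/hybrid LB needs queue stats; MoE additionally needs wave
        # coordination. A dense model behind an external LB needs neither,
        # which is exactly the overhead this change removes.
        return self.model_config.is_moe or not self.parallel_config.external_lb

dense_external = Cfg(ParallelCfg(data_parallel_size=4, external_lb=True), ModelCfg())
assert not dense_external.needs_dp_coordinator
```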
@@ -905,6 +923,7 @@ def launch_core_engines(
             addresses,
             engines_to_handshake,
             parallel_config,
+            dp_size > 1 and vllm_config.model_config.is_moe,
             vllm_config.cache_config,
             local_engine_manager,
             coordinator.proc if coordinator else None,
@@ -916,6 +935,7 @@ def wait_for_engine_startup(
     addresses: EngineZmqAddresses,
     core_engines: list[CoreEngine],
     parallel_config: ParallelConfig,
+    coordinated_dp: bool,
     cache_config: CacheConfig,
     proc_manager: CoreEngineProcManager | None,
     coord_process: Process | None,
@@ -997,8 +1017,7 @@ def wait_for_engine_startup(
             )

             if status == "HELLO" and engine.state == CoreEngineState.NEW:
-                # Send init message with DP config info and config hash.
-                # The config hash ensures all DP workers have compatible configs.
+                # Send init message with DP config info.
                 init_message = msgspec.msgpack.encode(
                     EngineHandshakeMetadata(
                         addresses=addresses,
@@ -1010,10 +1029,9 @@ def wait_for_engine_startup(
                             "_data_parallel_master_port_list",
                             "data_parallel_size",
                         )
-                    },
-                    parallel_config_hash=parallel_config.compute_hash()
-                    if parallel_config.data_parallel_size > 1
-                    else None,
+                    }
+                    if coordinated_dp
+                    else {},
                 )
             )
             handshake_socket.send_multipart((eng_identity, init_message), copy=False)
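The handshake payload is a `msgspec` struct serialized with MessagePack. The round trip looks roughly like this; a simplified stand-in struct with reduced field types, not the real `EngineHandshakeMetadata`:

```python
import msgspec

class HandshakeMetadata(msgspec.Struct):
    # Reduced stand-in; after this change, uncoordinated engines simply
    # receive an empty parallel_config dict instead of DP config info.
    addresses: dict[str, str]
    parallel_config: dict[str, int] = msgspec.field(default_factory=dict)

payload = msgspec.msgpack.encode(
    HandshakeMetadata(
        addresses={"input": "tcp://127.0.0.1:5555"},
        parallel_config={"data_parallel_size": 2},
    )
)
# The receiving side decodes against the expected type, which validates
# field names and types as part of deserialization.
decoded = msgspec.msgpack.decode(payload, type=HandshakeMetadata)
assert decoded.parallel_config["data_parallel_size"] == 2
```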
@@ -1034,8 +1052,8 @@ def wait_for_engine_startup(
             if addresses.frontend_stats_publish_address is None:
                 addresses.frontend_stats_publish_address = msg.get("dp_stats_address")

-            # Validate config hash consistency across DP workers
-            if parallel_config.data_parallel_size > 1:
+            # Validate config hash consistency across DP workers for MoE models.
+            if coordinated_dp:
                 worker_config_hash = msg.get("parallel_config_hash")
                 expected_hash = parallel_config.compute_hash()
                 if worker_config_hash != expected_hash:
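Conceptually, `compute_hash()` only needs to be deterministic over the fields that must match across DP workers, so that equal configs always produce equal digests. The comparison pattern is illustrated below; this is not vLLM's actual hashing, just a sketch of the idea:

```python
import hashlib
import json

def config_hash(cfg: dict) -> str:
    # Canonical JSON (sorted keys) so field order can't change the digest.
    return hashlib.sha256(
        json.dumps(cfg, sort_keys=True).encode()
    ).hexdigest()

expected = config_hash({"data_parallel_size": 2, "tensor_parallel_size": 1})
reported = config_hash({"tensor_parallel_size": 1, "data_parallel_size": 2})
assert reported == expected  # same config, same hash, regardless of key order
```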