Elastic Expert Parallel Initial Support (#20775)

Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
Author: Rui Qiao
Date: 2025-07-18 17:46:09 -07:00
Committed by: GitHub
Parent: 5782581acf
Commit: 217937221b
24 changed files with 1659 additions and 68 deletions

@@ -174,16 +174,21 @@ class CoreEngineActorManager:
        self.local_engine_actors: list[ray.ActorHandle] = []
        self.remote_engine_actors: list[ray.ActorHandle] = []
        env_vars_list = get_env_vars_to_copy(destination="DPEngineCoreActor")
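        # Kept on self (not a local) so scale_up_elastic_ep() can reuse the
        # same copied environment for actors launched later during scale-up.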
        self.env_vars_dict = {
            name: os.environ[name]
            for name in env_vars_list if name in os.environ
        }
        runtime_env = RuntimeEnv(env_vars=self.env_vars_dict)
        self.addresses = addresses
        self.executor_class = executor_class
        self.log_stats = log_stats
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_engine_count = \
            vllm_config.parallel_config.data_parallel_size_local
        world_size = vllm_config.parallel_config.world_size
        env_vars_set = get_env_vars_to_copy(destination="DPEngineCoreActor")
        env_vars_dict = {
            name: os.environ[name]
            for name in env_vars_set if name in os.environ
        }
        runtime_env = RuntimeEnv(env_vars=env_vars_dict)
        if ray.is_initialized():
            logger.info(
@@ -208,6 +213,7 @@ class CoreEngineActorManager:
        assert len(placement_groups) == dp_size, (
            "Number of placement groups must match data parallel size")
        self.placement_group_is_local = []
        refs = []
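        # Collect wait_for_init() refs so all engine actors initialize in
        # parallel; the ray.get() below blocks until every actor is ready.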
        for index in range(dp_size):
            local_index = local_dp_ranks[index]
@@ -231,6 +237,7 @@ class CoreEngineActorManager:
                self.local_engine_actors.append(actor)
            else:
                self.remote_engine_actors.append(actor)
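            # Record locality per placement group so scale-down can pop from
            # the matching actor list.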
            self.placement_group_is_local.append(local_client)
            refs.append(actor.wait_for_init.remote())
        ray.get(refs)
@@ -242,6 +249,9 @@ class CoreEngineActorManager:
    def create_dp_placement_groups(
            vllm_config: VllmConfig
    ) -> tuple[list["PlacementGroup"], list[int]]:
        """
        Create placement groups for data parallel engines and return them
        together with the node-local DP rank assigned to each group.
        """
        import ray
        from ray._private.state import available_resources_per_node
@@ -250,10 +260,11 @@ class CoreEngineActorManager:
logger.info("Creating placement groups for data parallel")
dp_master_ip = \
vllm_config.parallel_config.data_parallel_master_ip
dp_size = vllm_config.parallel_config.data_parallel_size
num_pg_to_create = vllm_config.parallel_config.data_parallel_size
local_engine_count = \
vllm_config.parallel_config.data_parallel_size_local
nodes = list_nodes()
nodes = sorted(list_nodes(),
key=lambda node: node.node_ip != dp_master_ip)
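        # False sorts before True, so the node whose IP matches dp_master_ip
        # comes first.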
        assert nodes[0].node_ip == dp_master_ip, (
@@ -293,7 +304,7 @@ class CoreEngineActorManager:
                    local_dp_ranks.append(i)
            else:
                for i in range(available_engine_count):
                    if len(placement_groups) == dp_size:
                    if len(placement_groups) == num_pg_to_create:
                        break
                    bundles = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}]
                    pg = ray.util.placement_group(
@@ -305,6 +316,204 @@ class CoreEngineActorManager:
                    local_dp_ranks.append(i)
        return placement_groups, local_dp_ranks

    @staticmethod
    def add_dp_placement_groups(
        old_vllm_config: VllmConfig, new_data_parallel_size: int
    ) -> tuple[list["PlacementGroup"], list[int]]:
        """
        Add placement groups for the engines introduced by scaling up to
        new_data_parallel_size, returning the new groups and the node-local
        DP rank chosen for each.
        """
        import ray
        from ray._private.state import (available_resources_per_node,
                                        total_resources_per_node)
        from ray.util.state import list_nodes

        old_dp_size = old_vllm_config.parallel_config.data_parallel_size
        num_pg_to_create = new_data_parallel_size - old_dp_size

        if num_pg_to_create <= 0:
            return [], []

        dp_master_ip = old_vllm_config.parallel_config.data_parallel_master_ip
        world_size = old_vllm_config.parallel_config.world_size

        nodes = list_nodes()
        nodes = sorted(nodes, key=lambda node: node.node_ip != dp_master_ip)
        assert nodes[0].node_ip == dp_master_ip, (
            "The first node must be the head node")
        assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, (
            "There can only be one head node")

        available_resources = available_resources_per_node()
        total_resources = total_resources_per_node()

        placement_groups = []
        local_dp_ranks = []
        num_pg_created = 0
        for node in nodes:
            if num_pg_created >= num_pg_to_create:
                break

            node_ip = node.node_ip
            node_id = node.node_id
            available_gpus = int(available_resources[node_id]["GPU"])

            # Get total GPUs on this node from the node's resources
            # Ray stores node resources with node ID as key
            total_gpus = int(total_resources[node_id]["GPU"])

            # Calculate used GPUs and used engines on this node
            used_gpus = max(0, total_gpus - available_gpus)
            used_engines_on_node = used_gpus // world_size

            # Calculate how many new engines this node can accommodate
            available_engine_count = available_gpus // world_size
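            # Each engine replica needs world_size GPUs, so only whole
            # multiples count: e.g. 5 free GPUs with world_size=2 fit two
            # more engines.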

            # Create placement groups for new engines on this node
            for i in range(available_engine_count):
                if num_pg_created >= num_pg_to_create:
                    break

                rank = old_dp_size + num_pg_created

                # Create bundles with node constraint for master node
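                # Ray exposes "node:<ip>" as a per-node resource; requesting
                # a tiny amount (0.001) of it pins every bundle to the head
                # node.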
                if node_ip == dp_master_ip:
                    bundles = [{
                        "GPU": 1.0,
                        "node:" + dp_master_ip: 0.001
                    }] * world_size + [{
                        "CPU": 1.0
                    }]
                else:
                    bundles = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}]
                pg = ray.util.placement_group(
                    name=f"dp_rank_{rank}",
                    strategy="STRICT_PACK",
                    bundles=bundles,
                )
                placement_groups.append(pg)

                # Local rank starts from the number of engines already used
                # on this node
                local_rank = used_engines_on_node + i
                local_dp_ranks.append(local_rank)
                num_pg_created += 1

        return placement_groups, local_dp_ranks

    def scale_up_elastic_ep(self, cur_vllm_config: VllmConfig,
                            new_data_parallel_size: int) -> None:
        import copy

        import ray
        from ray.runtime_env import RuntimeEnv
        from ray.util.scheduling_strategies import (
            PlacementGroupSchedulingStrategy)

        from vllm.v1.engine.core import DPEngineCoreActor

        cur_data_parallel_size = len(self.local_engine_actors) + \
            len(self.remote_engine_actors)
        assert new_data_parallel_size > cur_data_parallel_size, (
            f"New data parallel size {new_data_parallel_size} must be greater "
            f"than current data parallel size {cur_data_parallel_size} "
            "for scale up")

        placement_groups, local_dp_ranks = \
            self.add_dp_placement_groups(
                cur_vllm_config, new_data_parallel_size)

        world_size = cur_vllm_config.parallel_config.world_size
        dp_master_ip = cur_vllm_config.parallel_config.data_parallel_master_ip
        new_local_engines = 0

        runtime_env = RuntimeEnv(env_vars=self.env_vars_dict
                                 | {"VLLM_ELASTIC_EP_SCALE_UP_LAUNCH": "1"})
        for i, (pg,
                local_rank) in enumerate(zip(placement_groups,
                                             local_dp_ranks)):
            rank = cur_data_parallel_size + i
            dp_vllm_config = copy.deepcopy(cur_vllm_config)
            dp_vllm_config.parallel_config.data_parallel_size = \
                new_data_parallel_size
            dp_vllm_config.parallel_config.placement_group = pg

            # Check if this placement group is on the head node
            local_client = any(
                bundle.get("node:" + dp_master_ip, 0) > 0
                for bundle in pg.bundle_specs)
            if local_client:
                new_local_engines += 1
                # Update data_parallel_size_local
                dp_vllm_config.parallel_config.data_parallel_size_local = (
                    cur_vllm_config.parallel_config.data_parallel_size_local +
                    new_local_engines)
            actor = ray.remote(DPEngineCoreActor).options(
                scheduling_strategy=PlacementGroupSchedulingStrategy(
                    placement_group=pg,
                    placement_group_bundle_index=world_size,
                ),
                runtime_env=runtime_env).remote(
                    vllm_config=dp_vllm_config,
                    executor_class=self.executor_class,
                    log_stats=self.log_stats,
                    local_client=local_client,
                    addresses=self.addresses,
                    dp_rank=rank,
                    local_dp_rank=local_rank)

            if local_client:
                self.local_engine_actors.append(actor)
            else:
                self.remote_engine_actors.append(actor)
            self.created_placement_groups.append(pg)
            self.placement_group_is_local.append(local_client)

        ray.get([
            actor.wait_for_init.remote()
            for actor in (self.local_engine_actors[-new_local_engines:]
                          if new_local_engines > 0 else []) +
            self.remote_engine_actors[-(len(placement_groups) -
                                        new_local_engines):]
        ])
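        # Block until every newly launched actor finishes init, then start
        # their run loops; pre-existing actors keep running untouched.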
        actors = (self.local_engine_actors[-new_local_engines:]
                  if new_local_engines > 0 else []) + \
            self.remote_engine_actors[-(len(placement_groups) -
                                        new_local_engines):]
        for actor in actors:
            self.run_refs.append(actor.run.remote())

        cur_vllm_config.parallel_config.data_parallel_size = \
            new_data_parallel_size

        # Update cur_vllm_config with the new data_parallel_size_local if any
        # new local engines were added
        if new_local_engines > 0:
            cur_vllm_config.parallel_config.data_parallel_size_local += \
                new_local_engines

    def scale_down_elastic_ep(self, cur_data_parallel_size: int,
                              new_data_parallel_size: int) -> None:
        import ray
        assert cur_data_parallel_size > new_data_parallel_size, (
            f"cur_data_parallel_size {cur_data_parallel_size} must be greater "
            f"than new_data_parallel_size {new_data_parallel_size} "
            "for scale down")

        for _ in range(cur_data_parallel_size - new_data_parallel_size):
            pg = self.created_placement_groups.pop()
            is_local = self.placement_group_is_local.pop()
            if is_local:
                self.local_engine_actors.pop()
            else:
                self.remote_engine_actors.pop()
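            # Removing the placement group tears down the actor scheduled in
            # it and frees the reserved GPUs.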
            ray.util.remove_placement_group(pg)

    def get_run_refs(self):
        return self.run_refs
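
For intuition, here is a minimal, self-contained sketch of the per-node counting that add_dp_placement_groups performs. The plan_new_engines helper, the node tuples, and the resource numbers are illustrative assumptions, not part of this commit; the real code reads live GPU totals from Ray's state APIs.

def plan_new_engines(nodes, world_size, old_dp_size, num_pg_to_create):
    """Mirror the counting logic: each new engine consumes world_size GPUs."""
    plan = []  # (node_ip, dp_rank, local_dp_rank)
    created = 0
    for node_ip, total_gpus, available_gpus in nodes:
        # Engines already running on this node occupy the low local ranks.
        used_engines = max(0, total_gpus - available_gpus) // world_size
        for i in range(available_gpus // world_size):
            if created == num_pg_to_create:
                return plan
            plan.append((node_ip, old_dp_size + created, used_engines + i))
            created += 1
    return plan

# Head node has 4 of 8 GPUs free, the worker node all 8; world_size = 2.
nodes = [("10.0.0.1", 8, 4), ("10.0.0.2", 8, 8)]
print(plan_new_engines(nodes, world_size=2, old_dp_size=2, num_pg_to_create=4))
# [('10.0.0.1', 2, 2), ('10.0.0.1', 3, 3), ('10.0.0.2', 4, 0), ('10.0.0.2', 5, 1)]

Engines already running on a node keep the low local ranks, so new engines continue the numbering instead of colliding with live ones.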