[Feat][EPLB] A novel static EPLB placement strategy for MoE models. (#23745)

Signed-off-by: bruceszchen <bruceszchen@tencent.com>
Signed-off-by: Chen Bruce <bruceszchen@tencent.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Chen Bruce <cszwwdz@vip.qq.com>
Co-authored-by: lemon412 <lemon412@foxmail.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Chen Bruce committed 2025-09-16 18:55:16 +08:00 (committed by GitHub)
parent 27fcfe7bcf
commit 7ea5c73ad7
4 changed files with 265 additions and 12 deletions

View File

@@ -0,0 +1,194 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest

from vllm.model_executor.layers.fused_moe.layer import determine_expert_map


def verify_round_robin_pattern(expert_map, ep_rank, ep_size,
global_num_experts):
"""Verify that the expert map follows the round_robin pattern."""
# Calculate expected local experts (supporting non-divisible cases)
base_experts = global_num_experts // ep_size
remainder = global_num_experts % ep_size
if ep_rank < remainder:
local_num_experts = base_experts + 1
else:
local_num_experts = base_experts
# Expected expert IDs for this rank in round_robin pattern
# For non-divisible cases, ranks with extra experts start earlier
expected_expert_ids = []
for expert_idx in range(local_num_experts):
global_expert_id = ep_rank + expert_idx * ep_size
expected_expert_ids.append(global_expert_id)
# Check that only expected experts are mapped to this rank
for global_expert_id in range(global_num_experts):
if global_expert_id in expected_expert_ids:
local_expert_id = expert_map[global_expert_id]
expected_local_id = expected_expert_ids.index(global_expert_id)
assert (
local_expert_id == expected_local_id
), f"Global expert {global_expert_id} should map to local expert " \
f"{expected_local_id}, got {local_expert_id}"
else:
assert (
expert_map[global_expert_id] == -1
), f"Global expert {global_expert_id} should not be mapped to " \
f"this rank"
# Verify that all local expert IDs are consecutive starting from 0
local_expert_ids = [
expert_map[global_id] for global_id in expected_expert_ids
]
expected_local_ids = list(range(local_num_experts))
assert (
local_expert_ids == expected_local_ids
), f"Expected local expert IDs {expected_local_ids}, got {local_expert_ids}"


@pytest.mark.parametrize("expert_placement_strategy", ["round_robin"])
@pytest.mark.parametrize("world_size", [2, 4])
def test_expert_placement_various_sizes(expert_placement_strategy, world_size):
"""Test round_robin expert placement with various expert counts."""
# Test with different global_num_experts values
# Include both divisible and non-divisible cases
if world_size == 2:
test_cases = [
(4, 2), # 4 experts (divisible)
(8, 2), # 8 experts (divisible)
(9, 2), # 9 experts (non-divisible)
(16, 2), # 16 experts (divisible)
(17, 2), # 17 experts (non-divisible)
]
elif world_size == 4:
test_cases = [
(8, 4), # 8 experts (divisible)
(16, 4), # 16 experts (divisible)
(18, 4), # 18 experts (non-divisible)
(32, 4), # 32 experts (divisible)
(33, 4), # 33 experts (non-divisible)
]
else:
test_cases = []
for test_global_experts, test_ep_size in test_cases:
# Ensure ep_size matches world_size
assert (test_ep_size == world_size
), f"ep_size {test_ep_size} must equal world_size {world_size}"
# Test each rank
for ep_rank in range(world_size):
# Calculate expected local experts
base_experts = test_global_experts // test_ep_size
remainder = test_global_experts % test_ep_size
if ep_rank < remainder:
expected_test_local = base_experts + 1
else:
expected_test_local = base_experts
test_local_experts, test_expert_map = determine_expert_map(
ep_size=test_ep_size,
ep_rank=ep_rank,
global_num_experts=test_global_experts,
expert_placement_strategy=expert_placement_strategy,
)
assert (
test_local_experts == expected_test_local
), f"For {test_global_experts} experts on {test_ep_size} ranks, " \
f"rank {ep_rank}: expected {expected_test_local} local" \
f"experts, got {test_local_experts}"
if test_expert_map is not None:
assert test_expert_map.shape == (
test_global_experts,
), f"Expected expert map shape ({test_global_experts},), " \
f"got {test_expert_map.shape}"
# Verify round_robin pattern for this test case
verify_round_robin_pattern(test_expert_map, ep_rank,
test_ep_size, test_global_experts)


@pytest.mark.parametrize("expert_placement_strategy", ["round_robin"])
@pytest.mark.parametrize("world_size", [2, 4])
def test_expert_placement_edge_cases(expert_placement_strategy, world_size):
"""Test edge cases for round_robin expert placement."""
# Test case 1: ep_size = 1 (should return None for expert_map)
local_num_experts, expert_map = determine_expert_map(
ep_size=1,
ep_rank=0,
global_num_experts=8,
expert_placement_strategy=expert_placement_strategy,
)
assert local_num_experts == 8, "For ep_size=1, should get all experts"
assert expert_map is None, "For ep_size=1, expert_map should be None"
# Test case 2: ep_size = 0 (should raise assertion)
with pytest.raises(AssertionError):
determine_expert_map(
ep_size=0,
ep_rank=0,
global_num_experts=8,
expert_placement_strategy=expert_placement_strategy,
)


def test_determine_expert_map_comprehensive():
    """Test determine_expert_map with various configurations."""
# Test cases: (ep_size, ep_rank, global_num_experts,
# expert_placement_strategy, expected_local, expected_map_pattern)
    test_cases = [
        # Round robin placement tests (2 ranks)
        # rank 0 gets even experts
        (2, 0, 8, "round_robin", 4, [0, -1, 1, -1, 2, -1, 3, -1]),
        # rank 1 gets odd experts
        (2, 1, 8, "round_robin", 4, [-1, 0, -1, 1, -1, 2, -1, 3]),
        # rank 0 gets 5 experts (even + last)
        (2, 0, 9, "round_robin", 5, [0, -1, 1, -1, 2, -1, 3, -1, 4]),
        # rank 1 gets 4 experts (odd)
        (2, 1, 9, "round_robin", 4, [-1, 0, -1, 1, -1, 2, -1, 3, -1]),
        # Round robin placement tests (4 ranks)
        # rank 0 gets experts 0 and 4
        (4, 0, 8, "round_robin", 2, [0, -1, -1, -1, 1, -1, -1, -1]),
        # rank 1 gets experts 1 and 5
        (4, 1, 8, "round_robin", 2, [-1, 0, -1, -1, -1, 1, -1, -1]),
        # rank 2 gets experts 2 and 6
        (4, 2, 8, "round_robin", 2, [-1, -1, 0, -1, -1, -1, 1, -1]),
        # rank 3 gets experts 3 and 7
        (4, 3, 8, "round_robin", 2, [-1, -1, -1, 0, -1, -1, -1, 1]),
    ]
for ep_size, ep_rank, global_num_experts, expert_placement_strategy, \
expected_local, expected_map_pattern in test_cases:
local_num_experts, expert_map = determine_expert_map(
ep_size=ep_size,
ep_rank=ep_rank,
global_num_experts=global_num_experts,
expert_placement_strategy=expert_placement_strategy,
)
assert local_num_experts == expected_local, \
f"ep_size={ep_size}, ep_rank={ep_rank}, " \
f"global_num_experts={global_num_experts}, " \
f"expert_placement_strategy={expert_placement_strategy}: " \
f"expected {expected_local} local experts, got {local_num_experts}"
if expected_map_pattern is None:
assert expert_map is None, "Expected expert_map to be None"
else:
assert expert_map is not None, "Expected expert_map to not be None"
actual_map = expert_map.tolist()
assert actual_map == expected_map_pattern, \
f"ep_size={ep_size}, ep_rank={ep_rank}, " \
f"global_num_experts={global_num_experts}, " \
f"expert_placement_strategy={expert_placement_strategy}: " \
f"expected map {expected_map_pattern}, got {actual_map}"

View File

@@ -29,6 +29,7 @@ else:
logger = init_logger(__name__)

ExpertPlacementStrategy = Literal["linear", "round_robin"]
DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
@@ -102,6 +103,15 @@ class ParallelConfig:
"""Enable expert parallelism load balancing for MoE layers.""" """Enable expert parallelism load balancing for MoE layers."""
eplb_config: EPLBConfig = field(default_factory=EPLBConfig) eplb_config: EPLBConfig = field(default_factory=EPLBConfig)
"""Expert parallelism configuration.""" """Expert parallelism configuration."""
expert_placement_strategy: ExpertPlacementStrategy = "linear"
"""The expert placement strategy for MoE layers:\n
- "linear": Experts are placed in a contiguous manner. For example, with 4
experts and 2 ranks, rank 0 will have experts [0, 1] and rank 1 will have
experts [2, 3].\n
- "round_robin": Experts are placed in a round-robin manner. For example,
with 4 experts and 2 ranks, rank 0 will have experts [0, 2] and rank 1
will have experts [1, 3]. This strategy can help improve load balancing
for grouped expert models with no redundant experts."""
num_redundant_experts: Optional[int] = None num_redundant_experts: Optional[int] = None
"""`num_redundant_experts` is deprecated and has been replaced with """`num_redundant_experts` is deprecated and has been replaced with
`eplb_config.num_redundant_experts`. This will be removed in v0.12.0. `eplb_config.num_redundant_experts`. This will be removed in v0.12.0.
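
For illustration, here is the docstring's 4-expert/2-rank example worked out
in plain Python (a minimal sketch independent of vLLM internals; -1 marks
experts that are not local to a rank):

# Sketch of the two placement strategies for 4 experts on 2 ranks.
GLOBAL_NUM_EXPERTS, EP_SIZE = 4, 2

for ep_rank in range(EP_SIZE):
    local = GLOBAL_NUM_EXPERTS // EP_SIZE
    linear = [-1] * GLOBAL_NUM_EXPERTS
    round_robin = [-1] * GLOBAL_NUM_EXPERTS
    # "linear": a contiguous block of experts per rank
    for i in range(local):
        linear[ep_rank * local + i] = i
    # "round_robin": every ep_size-th expert, starting at ep_rank
    for i, g in enumerate(range(ep_rank, GLOBAL_NUM_EXPERTS, EP_SIZE)):
        round_robin[g] = i
    print(ep_rank, linear, round_robin)
# rank 0 -> linear [0, 1, -1, -1], round_robin [0, -1, 1, -1]
# rank 1 -> linear [-1, -1, 0, 1], round_robin [-1, 0, -1, 1]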

View File

@@ -34,6 +34,7 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
                         SpeculativeConfig, TaskOption, TokenizerMode,
                         VllmConfig, get_attr_docs)
from vllm.config.multimodal import MMCacheType, MultiModalConfig
from vllm.config.parallel import ExpertPlacementStrategy
from vllm.config.utils import get_field
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
@@ -328,6 +329,8 @@ class EngineArgs:
    enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
    eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
    enable_eplb: bool = ParallelConfig.enable_eplb
    expert_placement_strategy: ExpertPlacementStrategy = \
        ParallelConfig.expert_placement_strategy
    num_redundant_experts: int = EPLBConfig.num_redundant_experts
    eplb_window_size: int = EPLBConfig.window_size
    eplb_step_interval: int = EPLBConfig.step_interval
@@ -696,6 +699,9 @@ class EngineArgs:
**parallel_kwargs["enable_eplb"]) **parallel_kwargs["enable_eplb"])
parallel_group.add_argument("--eplb-config", parallel_group.add_argument("--eplb-config",
**parallel_kwargs["eplb_config"]) **parallel_kwargs["eplb_config"])
parallel_group.add_argument(
"--expert-placement-strategy",
**parallel_kwargs["expert_placement_strategy"])
parallel_group.add_argument( parallel_group.add_argument(
"--num-redundant-experts", "--num-redundant-experts",
type=int, type=int,
@@ -1335,6 +1341,7 @@ class EngineArgs:
            enable_expert_parallel=self.enable_expert_parallel,
            enable_eplb=self.enable_eplb,
            eplb_config=self.eplb_config,
            expert_placement_strategy=self.expert_placement_strategy,
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,
            ray_workers_use_nsight=self.ray_workers_use_nsight,
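
With the flag wired through EngineArgs as above, the strategy can also be
selected programmatically (a hedged sketch: the model name is a placeholder,
and only the fields visible in this diff are assumed):

# Hypothetical usage sketch mirroring the --expert-placement-strategy flag.
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="org/some-moe-model",  # placeholder; any MoE checkpoint
    enable_expert_parallel=True,
    expert_placement_strategy="round_robin",
)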

View File

@@ -4,7 +4,7 @@
from abc import abstractmethod
from collections.abc import Iterable
from enum import Enum
from typing import Callable, Literal, Optional, Union, get_args, overload

import torch
import torch.nn.functional as F
@@ -12,6 +12,7 @@ from torch.nn.parameter import UninitializedParameter
import vllm.envs as envs
from vllm.config import get_current_vllm_config
from vllm.config.parallel import ExpertPlacementStrategy
from vllm.distributed import (get_dp_group, get_ep_group,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
@@ -675,8 +676,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def determine_expert_map(
    ep_size: int,
    ep_rank: int,
    global_num_experts: int,
    expert_placement_strategy: ExpertPlacementStrategy = "linear",
) -> tuple[int, Optional[torch.Tensor]]:
    """
    Calculates how many experts should be assigned to each rank for EP and
    creates a mapping from global to local expert index. Experts are
@@ -684,8 +688,11 @@ def determine_expert_map(
    last rank.

    Args:
        ep_size: The size of the expert parallel group
        ep_rank: The rank of the current process in the expert parallel
            group
        global_num_experts: The total number of experts in the model.
        expert_placement_strategy: The expert placement strategy.

    Returns:
        tuple[int, Optional[torch.Tensor]]: A tuple containing:
@@ -711,9 +718,23 @@ def determine_expert_map(
    # Create a tensor of size num_experts filled with -1
    expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
    # Create an expert map for the local experts
    if expert_placement_strategy == "linear":
        start_idx = ep_rank * base_experts + min(ep_rank, remainder)
        expert_map[start_idx:start_idx + local_num_experts] = torch.arange(
            0, local_num_experts, dtype=torch.int32)
    elif expert_placement_strategy == "round_robin":
        local_log_experts = torch.arange(ep_rank,
                                         global_num_experts,
                                         ep_size,
                                         dtype=torch.int32)
        expert_map[local_log_experts] = torch.arange(0,
                                                     local_num_experts,
                                                     dtype=torch.int32)
    else:
        raise ValueError("Unsupported expert placement strategy "
                         f"'{expert_placement_strategy}', expected one of "
                         f"{get_args(ExpertPlacementStrategy)}")
    return (local_num_experts, expert_map)
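
Concretely, for the non-divisible case of 9 experts on 2 ranks, the new
round_robin branch is expected to behave as follows (a sketch; the printed
values are taken from the test cases added in this commit):

# Expected round_robin maps for 9 experts on 2 ranks.
from vllm.model_executor.layers.fused_moe.layer import determine_expert_map

n, emap = determine_expert_map(ep_size=2, ep_rank=0, global_num_experts=9,
                               expert_placement_strategy="round_robin")
print(n, emap.tolist())  # 5 [0, -1, 1, -1, 2, -1, 3, -1, 4]

n, emap = determine_expert_map(ep_size=2, ep_rank=1, global_num_experts=9,
                               expert_placement_strategy="round_robin")
print(n, emap.tolist())  # 4 [-1, 0, -1, 1, -1, 2, -1, 3, -1]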
@@ -846,15 +867,36 @@ class FusedMoE(CustomOp):
            else:
                assert num_redundant_experts == 0, \
                    "Redundant experts are only supported with EPLB."

            expert_placement_strategy = (
                vllm_config.parallel_config.expert_placement_strategy)
            if expert_placement_strategy == "round_robin":
                # TODO(Bruce): will support round robin expert placement with
                # EPLB enabled in the future.
                round_robin_supported = ((num_expert_group is not None
                                          and num_expert_group > 1)
                                         and num_redundant_experts == 0
                                         and not self.enable_eplb)

                if not round_robin_supported:
                    logger.warning(
                        "Round-robin expert placement is only supported for "
                        "models with multiple expert groups and no redundant "
                        "experts. Falling back to linear expert placement.")
                    expert_placement_strategy = "linear"

            self.local_num_experts, self.expert_map = determine_expert_map(
                ep_size=self.ep_size,
                ep_rank=self.ep_rank,
                global_num_experts=self.global_num_experts,
                expert_placement_strategy=expert_placement_strategy,
            )
            logger.info_once(
                "[EP Rank %s/%s] Expert parallelism is enabled. Expert "
                "placement strategy: %s. Local/global"
                " number of experts: %s/%s. Experts local to global index map:"
                " %s.", self.ep_rank, self.ep_size, expert_placement_strategy,
                self.local_num_experts, self.global_num_experts,
                get_compressed_expert_map(self.expert_map))
        else:
            self.local_num_experts, self.expert_map = (self.global_num_experts,