[Feat][EPLB] A novel static EPLB placement strategy for MoE models. (#23745)
Signed-off-by: bruceszchen <bruceszchen@tencent.com>
Signed-off-by: Chen Bruce <bruceszchen@tencent.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Chen Bruce <cszwwdz@vip.qq.com>
Co-authored-by: lemon412 <lemon412@foxmail.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -29,6 +29,7 @@ else:
 
 logger = init_logger(__name__)
 
+ExpertPlacementStrategy = Literal["linear", "round_robin"]
 DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
 
 
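For context, declaring the strategy as a `Literal` alias means the accepted string values can be recovered at runtime with `typing.get_args`, which is a common way to keep config validation in sync with the annotation. A minimal standalone sketch; the `validate_strategy` helper is illustrative and not part of this diff:

from typing import Literal, get_args

ExpertPlacementStrategy = Literal["linear", "round_robin"]

def validate_strategy(value: str) -> ExpertPlacementStrategy:
    # get_args() yields ("linear", "round_robin"), so the accepted
    # choices stay in sync with the type annotation automatically.
    if value not in get_args(ExpertPlacementStrategy):
        raise ValueError(
            f"invalid expert_placement_strategy {value!r}; "
            f"expected one of {get_args(ExpertPlacementStrategy)}")
    return value  # type: ignore[return-value]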
@@ -102,6 +103,15 @@ class ParallelConfig:
     """Enable expert parallelism load balancing for MoE layers."""
     eplb_config: EPLBConfig = field(default_factory=EPLBConfig)
     """Expert parallelism configuration."""
+    expert_placement_strategy: ExpertPlacementStrategy = "linear"
+    """The expert placement strategy for MoE layers:\n
+    - "linear": Experts are placed in a contiguous manner. For example, with 4
+      experts and 2 ranks, rank 0 will have experts [0, 1] and rank 1 will have
+      experts [2, 3].\n
+    - "round_robin": Experts are placed in a round-robin manner. For example,
+      with 4 experts and 2 ranks, rank 0 will have experts [0, 2] and rank 1
+      will have experts [1, 3]. This strategy can help improve load balancing
+      for grouped expert models with no redundant experts."""
     num_redundant_experts: Optional[int] = None
     """`num_redundant_experts` is deprecated and has been replaced with
     `eplb_config.num_redundant_experts`. This will be removed in v0.12.0.
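To make the two strategies concrete, the mapping described in the docstring above can be sketched as follows. This is an illustrative standalone helper (the name `place_experts` is hypothetical, not part of this diff), assuming the expert count divides evenly across ranks and no redundant experts:

def place_experts(num_experts: int, num_ranks: int,
                  strategy: str = "linear") -> dict[int, list[int]]:
    # Hypothetical sketch of the global-expert-id -> rank assignment.
    if strategy == "linear":
        # Contiguous blocks: rank r owns experts [r*k, ..., (r+1)*k - 1].
        k = num_experts // num_ranks
        return {r: list(range(r * k, (r + 1) * k)) for r in range(num_ranks)}
    if strategy == "round_robin":
        # Strided placement: expert e goes to rank e % num_ranks.
        return {r: list(range(r, num_experts, num_ranks))
                for r in range(num_ranks)}
    raise ValueError(f"unknown strategy: {strategy}")

# Reproduces the docstring's example with 4 experts and 2 ranks.
assert place_experts(4, 2, "linear") == {0: [0, 1], 1: [2, 3]}
assert place_experts(4, 2, "round_robin") == {0: [0, 2], 1: [1, 3]}

One plausible reading of the docstring's load-balancing note: in grouped-expert models, consecutive expert ids belong to the same group, so round-robin placement spreads each group across ranks instead of concentrating it on a single rank.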