[Feature] use --eplb_config to set eplb param (#20562)

Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: rongfu.leng <lenronfu@gmail.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: rongfu.leng
Date: 2025-08-21 05:07:28 +08:00 (committed by GitHub)
Commit: 4fbda0b20c (parent: 4e51fa8cba)
9 changed files with 149 additions and 52 deletions

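In short: the four flat EPLB flags (--num-redundant-experts, --eplb-window-size, --eplb-step-interval, --eplb-log-balancedness) are folded into a single --eplb-config argument that takes all EPLB settings as one JSON object; the old flags keep working, with a deprecation warning, until v0.12.0. A minimal sketch of the new CLI path (field names come from the diff below; the values and the exact JSON accepted by --eplb-config are assumptions):

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

# Build the vLLM CLI parser and exercise the new flag.
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args([
    "--enable-eplb",
    "--eplb-config", '{"window_size": 1000, "step_interval": 3000}',
])
# A dict-valued eplb_config is normalized in __post_init__ (see below).
engine_args = EngineArgs.from_cli_args(args)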

@@ -25,7 +25,7 @@ import vllm.envs as envs
 from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
                          ConfigFormat, ConfigType, ConvertOption,
                          DecodingConfig, DetailedTraceModules, Device,
-                         DeviceConfig, DistributedExecutorBackend,
+                         DeviceConfig, DistributedExecutorBackend, EPLBConfig,
                          GuidedDecodingBackend, HfOverrides, KVEventsConfig,
                          KVTransferConfig, LoadConfig, LogprobsMode,
                          LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig,
@@ -305,11 +305,12 @@ class EngineArgs:
     data_parallel_hybrid_lb: bool = False
     data_parallel_backend: str = ParallelConfig.data_parallel_backend
     enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
+    eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
     enable_eplb: bool = ParallelConfig.enable_eplb
-    num_redundant_experts: int = ParallelConfig.num_redundant_experts
-    eplb_window_size: int = ParallelConfig.eplb_window_size
-    eplb_step_interval: int = ParallelConfig.eplb_step_interval
-    eplb_log_balancedness: bool = ParallelConfig.eplb_log_balancedness
+    num_redundant_experts: int = EPLBConfig.num_redundant_experts
+    eplb_window_size: int = EPLBConfig.window_size
+    eplb_step_interval: int = EPLBConfig.step_interval
+    eplb_log_balancedness: bool = EPLBConfig.log_balancedness
     max_parallel_loading_workers: Optional[
         int] = ParallelConfig.max_parallel_loading_workers
     block_size: Optional[BlockSize] = CacheConfig.block_size
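For orientation, EPLBConfig itself lives in vllm.config; judging from the attribute accesses in this hunk it carries roughly the following fields. This is an inferred sketch, not the real class (types and defaults are illustrative assumptions):

from dataclasses import dataclass

@dataclass
class EPLBConfig:
    # Field names as referenced above; defaults are assumptions.
    num_redundant_experts: int = 0
    window_size: int = 1000
    step_interval: int = 3000
    log_balancedness: bool = False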
@@ -454,6 +455,9 @@ class EngineArgs:
         if isinstance(self.compilation_config, dict):
             self.compilation_config = CompilationConfig(
                 **self.compilation_config)
+        if isinstance(self.eplb_config, dict):
+            self.eplb_config = EPLBConfig.from_cli(json.dumps(
+                self.eplb_config))
         # Setup plugins
         from vllm.plugins import load_general_plugins
         load_general_plugins()
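The dict branch added here mirrors the CompilationConfig handling directly above: a dict (from a JSON CLI value or programmatic use) is serialized back to JSON and parsed through EPLBConfig.from_cli. A round-trip sketch with hypothetical values:

import json
from vllm.config import EPLBConfig

# A dict as it might arrive from --eplb-config or a direct EngineArgs call.
raw = {"window_size": 1000, "step_interval": 3000}
eplb_config = EPLBConfig.from_cli(json.dumps(raw))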
@@ -661,14 +665,32 @@ class EngineArgs:
                                     **parallel_kwargs["enable_expert_parallel"])
         parallel_group.add_argument("--enable-eplb",
                                     **parallel_kwargs["enable_eplb"])
-        parallel_group.add_argument("--num-redundant-experts",
-                                    **parallel_kwargs["num_redundant_experts"])
-        parallel_group.add_argument("--eplb-window-size",
-                                    **parallel_kwargs["eplb_window_size"])
-        parallel_group.add_argument("--eplb-step-interval",
-                                    **parallel_kwargs["eplb_step_interval"])
-        parallel_group.add_argument("--eplb-log-balancedness",
-                                    **parallel_kwargs["eplb_log_balancedness"])
+        parallel_group.add_argument("--eplb-config",
+                                    **parallel_kwargs["eplb_config"])
+        parallel_group.add_argument(
+            "--num-redundant-experts",
+            type=int,
+            help=
+            "[DEPRECATED] --num-redundant-experts will be removed in v0.12.0.",
+            deprecated=True)
+        parallel_group.add_argument(
+            "--eplb-window-size",
+            type=int,
+            help="[DEPRECATED] --eplb-window-size will be removed in v0.12.0.",
+            deprecated=True)
+        parallel_group.add_argument(
+            "--eplb-step-interval",
+            type=int,
+            help=
+            "[DEPRECATED] --eplb-step-interval will be removed in v0.12.0.",
+            deprecated=True)
+        parallel_group.add_argument(
+            "--eplb-log-balancedness",
+            action=argparse.BooleanOptionalAction,
+            help=
+            "[DEPRECATED] --eplb-log-balancedness will be removed in v0.12.0.",
+            deprecated=True)
         parallel_group.add_argument(
             "--max-parallel-loading-workers",
             **parallel_kwargs["max_parallel_loading_workers"])
@@ -1244,6 +1266,16 @@ class EngineArgs:
"Currently, speculative decoding is not supported with "
"async scheduling.")
# Forward the deprecated CLI args to the EPLB config.
if self.num_redundant_experts is not None:
self.eplb_config.num_redundant_experts = self.num_redundant_experts
if self.eplb_window_size is not None:
self.eplb_config.window_size = self.eplb_window_size
if self.eplb_step_interval is not None:
self.eplb_config.step_interval = self.eplb_step_interval
if self.eplb_log_balancedness is not None:
self.eplb_config.log_balancedness = self.eplb_log_balancedness
parallel_config = ParallelConfig(
pipeline_parallel_size=self.pipeline_parallel_size,
tensor_parallel_size=self.tensor_parallel_size,
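This forwarding keeps the programmatic path symmetric with the CLI: callers can hand EngineArgs a ready-made EPLBConfig (or a plain dict, which __post_init__ normalizes as shown earlier) instead of the four deprecated fields. A sketch with hypothetical values:

from vllm.config import EPLBConfig
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    enable_eplb=True,
    eplb_config=EPLBConfig(window_size=1000, step_interval=3000),
)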
@@ -1257,10 +1289,7 @@ class EngineArgs:
             data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
             enable_expert_parallel=self.enable_expert_parallel,
             enable_eplb=self.enable_eplb,
-            num_redundant_experts=self.num_redundant_experts,
-            eplb_window_size=self.eplb_window_size,
-            eplb_step_interval=self.eplb_step_interval,
-            eplb_log_balancedness=self.eplb_log_balancedness,
+            eplb_config=self.eplb_config,
             max_parallel_loading_workers=self.max_parallel_loading_workers,
             disable_custom_all_reduce=self.disable_custom_all_reduce,
             ray_workers_use_nsight=self.ray_workers_use_nsight,
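Net effect of this last hunk: ParallelConfig carries one nested EPLBConfig instead of four flat fields, so downstream code would read, for example, parallel_config.eplb_config.window_size. A minimal construction sketch (hypothetical values, assuming ParallelConfig can be built standalone):

from vllm.config import EPLBConfig, ParallelConfig

parallel_config = ParallelConfig(
    enable_eplb=True,
    eplb_config=EPLBConfig(num_redundant_experts=2),
)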