[BugFix] Support EP/DP + EPLB with MTP (#25311)

Signed-off-by: ilmarkov <markovilya197@gmail.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Author: Ilya Markov
Date: 2025-11-05 16:22:17 +01:00
Committed by: GitHub
Parent: 5d16d0fa62
Commit: e50c454672
27 changed files with 957 additions and 529 deletions
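
In essence, the worker's elastic EP/EPLB reconfiguration path now collects and updates MoE layers on the MTP drafter as well as on the target model; as the removed hunk below shows, the old code only walked `self.model_runner.model.modules()`. A minimal, self-contained sketch of that pattern (the `ToyMoELayer`/`MoEConfig` classes are hypothetical stand-ins, not vLLM types):

```python
from dataclasses import dataclass


@dataclass
class MoEConfig:
    num_local_experts: int
    num_experts: int


class ToyMoELayer:
    """Hypothetical stand-in for a FusedMoE/SharedFusedMoE layer."""

    def __init__(self, num_local_experts: int, ep_size: int):
        self.moe_config = MoEConfig(num_local_experts, num_local_experts * ep_size)
        self.global_num_experts = self.moe_config.num_experts


def get_moe_modules(modules: list) -> list[ToyMoELayer]:
    # Collect every MoE layer, mirroring the helper the diff factors out.
    return [m for m in modules if isinstance(m, ToyMoELayer)]


def update_moe_modules(
    moe_modules: list[ToyMoELayer], num_local_experts: int, new_ep_size: int
) -> None:
    # All MoE layers must agree on the per-rank expert count before resizing.
    assert all(m.moe_config.num_local_experts == num_local_experts for m in moe_modules)
    for m in moe_modules:
        m.moe_config.num_experts = num_local_experts * new_ep_size
        m.global_num_experts = m.moe_config.num_experts


# The same update runs on the target model and, when one exists, the MTP drafter.
target_layers = [ToyMoELayer(num_local_experts=8, ep_size=2) for _ in range(4)]
drafter_layers = [ToyMoELayer(num_local_experts=8, ep_size=2)]

for layers in (target_layers, drafter_layers):
    moe = get_moe_modules(layers)
    update_moe_modules(moe, moe[0].moe_config.num_local_experts, new_ep_size=4)

print(target_layers[0].global_num_experts)  # 32: EP scaled from 2 to 4 ranks
```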


@@ -32,6 +32,7 @@ from vllm.distributed.parallel_state import (
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
+from vllm.model_executor.models.interfaces import is_mixture_of_experts
 from vllm.model_executor.warmup.kernel_warmup import kernel_warmup
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
@@ -613,7 +614,6 @@ class Worker(WorkerBase):
         }
         assert self.model_runner.eplb_state is not None
         self.model_runner.eplb_state.rearrange(
-            self.model_runner.model,
             execute_shuffle=True,
             global_expert_load=None,
             rank_mapping=rank_mapping,
@@ -626,7 +626,7 @@ class Worker(WorkerBase):
         self,
         old_ep_size: int,
         new_ep_size: int,
-        global_expert_load: torch.Tensor | None,
+        global_expert_loads: list[torch.Tensor] | None,
     ) -> None:
         from vllm.distributed.parallel_state import get_ep_group
@@ -635,9 +635,8 @@ class Worker(WorkerBase):
         rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)}
         assert self.model_runner.eplb_state is not None
         self.model_runner.eplb_state.rearrange(
-            self.model_runner.model,
             execute_shuffle=True,
-            global_expert_load=global_expert_load,
+            global_expert_loads=global_expert_loads,
             rank_mapping=rank_mapping,
         )
         if get_ep_group().rank == 0:
@@ -684,31 +683,56 @@ class Worker(WorkerBase):
             get_ep_group,
             prepare_communication_buffer_for_model,
         )
-        from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig
+        from vllm.model_executor.layers.fused_moe.layer import (
+            FusedMoE,
+            FusedMoEParallelConfig,
+        )
 
         parallel_config = self.vllm_config.parallel_config
-        moe_modules = [
-            module
-            for module in self.model_runner.model.modules()
-            if (
-                module.__class__.__name__ == "FusedMoE"
-                or module.__class__.__name__ == "SharedFusedMoE"
-            )
-        ]
-        num_local_experts = moe_modules[0].moe_config.num_local_experts
-        assert all(
-            module.moe_config.num_local_experts == num_local_experts
-            for module in moe_modules
-        ), "All MoE modules must have the same number of experts"
-        for module in moe_modules:
-            module.moe_config.num_experts = num_local_experts * new_ep_size
-            module.global_num_experts = module.moe_config.num_experts
-            module.moe_parallel_config = FusedMoEParallelConfig.make(
-                tp_size_=get_tp_group().world_size,
-                dp_size_=get_dp_group().world_size,
-                vllm_parallel_config=parallel_config,
-            )
-            module.moe_config.moe_parallel_config = module.moe_parallel_config
+
+        def get_moe_modules(model: torch.nn.Module) -> list[FusedMoE]:
+            return [
+                module
+                for module in model.modules()
+                if (
+                    module.__class__.__name__ == "FusedMoE"
+                    or module.__class__.__name__ == "SharedFusedMoE"
+                )
+            ]
+
+        def update_moe_modules(moe_modules: list[FusedMoE], num_local_experts: int):
+            assert all(
+                module.moe_config.num_local_experts == num_local_experts
+                for module in moe_modules
+            ), "All MoE modules must have the same number of experts"
+            for module in moe_modules:
+                module.moe_config.num_experts = num_local_experts * new_ep_size
+                module.global_num_experts = module.moe_config.num_experts
+                module.moe_parallel_config = FusedMoEParallelConfig.make(
+                    tp_size_=get_tp_group().world_size,
+                    dp_size_=get_dp_group().world_size,
+                    vllm_parallel_config=parallel_config,
+                )
+                module.moe_config.moe_parallel_config = module.moe_parallel_config
+            return moe_modules
+
+        model_moe_modules = get_moe_modules(self.model_runner.model)
+        num_local_experts = model_moe_modules[0].moe_config.num_local_experts
+        update_moe_modules(model_moe_modules, num_local_experts)
+
+        drafter_model = None
+        if hasattr(self.model_runner, "drafter") and hasattr(
+            self.model_runner.drafter, "model"
+        ):
+            drafter_model = self.model_runner.drafter.model
+        if drafter_model is not None and is_mixture_of_experts(drafter_model):
+            drafter_moe_modules = get_moe_modules(drafter_model)
+            # Check if drafter and model have matching configs
+            assert (
+                drafter_moe_modules[0].moe_config.num_local_experts == num_local_experts
+            ), "Drafter and model configs should be the same"
+            update_moe_modules(drafter_moe_modules, num_local_experts)
+
         if new_ep_size < old_ep_size:
             num_local_physical_experts = num_local_experts
             assert self.model_runner.eplb_state is not None
@@ -719,7 +743,7 @@ class Worker(WorkerBase):
                 new_physical_experts
                 - self.model_runner.eplb_state.logical_replica_count.shape[1]
             )
-            global_expert_load = None
+            global_expert_loads = None
         else:
             num_local_physical_experts = torch.tensor(
                 [num_local_experts], dtype=torch.int32, device="cpu"
@@ -730,18 +754,20 @@ class Worker(WorkerBase):
             num_local_physical_experts = num_local_physical_experts.item()
             new_physical_experts = num_local_physical_experts * new_ep_size
             assert self.model_runner.eplb_state is not None
-            global_expert_load = self.model_runner.eplb_state.rearrange(
-                self.model_runner.model, execute_shuffle=False
+            global_expert_loads = self.model_runner.eplb_state.rearrange(
+                execute_shuffle=False
             )
             parallel_config.eplb_config.num_redundant_experts = (
-                new_physical_experts - global_expert_load.shape[1]
+                new_physical_experts - global_expert_loads[0].shape[1]
             )
         prepare_communication_buffer_for_model(self.model_runner.model)
+        if drafter_model is not None:
+            prepare_communication_buffer_for_model(drafter_model)
         self.model_runner.model.update_physical_experts_metadata(
             num_physical_experts=new_physical_experts,
             num_local_physical_experts=num_local_physical_experts,
         )
-        return global_expert_load
+        return global_expert_loads
 
     def reinitialize_distributed(
         self, reconfig_request: ReconfigureDistributedRequest
@@ -782,11 +808,11 @@ class Worker(WorkerBase):
             self.local_rank,
         )
 
-        global_expert_load = self._reconfigure_moe(old_ep_size, new_ep_size)
+        global_expert_loads = self._reconfigure_moe(old_ep_size, new_ep_size)
 
         if new_ep_size > old_ep_size:
-            assert global_expert_load is not None
-            self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_load)
+            assert global_expert_loads is not None
+            self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_loads)
 
     def save_sharded_state(
         self,
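
Finally, a hedged sketch of the return-type change visible in `_reconfigure_moe`: `rearrange(execute_shuffle=False)` now yields a list of per-model expert-load tensors (presumably the target model first, since the diff indexes entry 0), and the redundant-expert count is derived from that first entry. The shapes and sizes below are made up for illustration:

```python
import torch

# Hypothetical expert-load tensors, one per MoE model: target model first, then
# the MTP drafter; each shaped [num_moe_layers, num_logical_experts].
global_expert_loads = [
    torch.zeros(61, 256),  # target model
    torch.zeros(1, 256),   # MTP drafter (single MoE layer)
]

new_ep_size = 8
num_local_physical_experts = 33
new_physical_experts = num_local_physical_experts * new_ep_size  # 264

# As in the diff, redundant experts come from the first (target-model) entry.
num_redundant_experts = new_physical_experts - global_expert_loads[0].shape[1]
print(num_redundant_experts)  # 264 - 256 = 8
```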