[BugFix] Support EP/DP + EPLB with MTP (#25311)

Signed-off-by: ilmarkov <markovilya197@gmail.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Author: Ilya Markov
Date: 2025-11-05 16:22:17 +01:00 (committed by GitHub)
parent 5d16d0fa62
commit e50c454672
27 changed files with 957 additions and 529 deletions


@@ -30,9 +30,11 @@ from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (
get_ep_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
@@ -46,6 +48,7 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.interfaces import MixtureOfExperts
from vllm.model_executor.models.utils import sequence_parallel_chunk
from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
@@ -56,6 +59,8 @@ from .utils import (
is_pp_missing_parameter,
)
logger = init_logger(__name__)
class Llama4MoE(nn.Module):
@staticmethod
@@ -80,6 +85,9 @@ class Llama4MoE(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size()
self.top_k = config.num_experts_per_tok
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
self.ep_group = get_ep_group().device_group
self.ep_rank = get_ep_group().rank_in_group
self.ep_size = self.ep_group.size()
intermediate_size_moe = config.intermediate_size
self.router = ReplicatedLinear(
@@ -101,6 +109,20 @@ class Llama4MoE(nn.Module):
disable_tp=self.is_sequence_parallel,
)
# Load balancing settings.
eplb_config = parallel_config.eplb_config if parallel_config else None
self.enable_eplb = parallel_config.enable_eplb if parallel_config else False
self.n_redundant_experts = (
eplb_config.num_redundant_experts if eplb_config else 0
)
self.n_routed_experts: int = config.num_local_experts
self.n_logical_experts = self.n_routed_experts
self.n_shared_experts: int = 1
self.n_local_experts: int = config.num_local_experts
self.n_physical_experts = self.n_local_experts + self.n_redundant_experts
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
self.experts = SharedFusedMoE(
shared_experts=self.shared_expert,
num_experts=config.num_local_experts,
@@ -114,6 +136,8 @@ class Llama4MoE(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.experts",
is_sequence_parallel=self.is_sequence_parallel,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
)
def forward(self, hidden_states):
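
The expert-count bookkeeping above is plain arithmetic: each EP rank hosts an equal slice of the physical experts, which are the logical (routed) experts plus any redundant copies. A minimal sketch of the same computation, with illustrative numbers that are not from this commit:

# Sketch of the EPLB expert-count arithmetic above; values are
# illustrative, not from the commit.
n_routed_experts = 16      # config.num_local_experts (logical experts)
n_redundant_experts = 2    # eplb_config.num_redundant_experts
ep_size = 2                # expert-parallel world size

n_physical_experts = n_routed_experts + n_redundant_experts  # 18
n_local_physical_experts = n_physical_experts // ep_size     # 9 per rank
# The floor division implies n_physical_experts should divide evenly
# across EP ranks.
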
@@ -378,6 +402,9 @@ class Llama4Model(LlamaModel):
layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer,
):
self.num_experts = vllm_config.model_config.hf_config.num_local_experts
self.n_redundant_experts = (
vllm_config.parallel_config.eplb_config.num_redundant_experts
)
super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)
def load_moe_expert_weights(
@@ -499,7 +526,6 @@ class Llama4Model(LlamaModel):
shard_id=shard_id,
expert_id=expert_id,
)
loaded_params.add(full_param_name)
expert_param_loaded = True
@@ -526,6 +552,7 @@ class Llama4Model(LlamaModel):
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.num_experts,
num_redundant_experts=self.n_redundant_experts,
)
# Expert parameter mapping for the case where the expert weights are
# fused into a single weight tensor.
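
Threading num_redundant_experts into the expert-parameter mapping means the loader has more physical slots than checkpoint experts, so each redundant slot must be filled with a copy of some logical expert's weights. A rough sketch of the idea, where physical_to_logical and checkpoint_expert_for_slot are illustrative names, not vLLM API, and the identity-plus-replicate placement stands in for a real load-balancing policy:

# Rough sketch: logical -> physical expert placement with redundancy.
# physical_to_logical is an illustrative name, not vLLM API.
n_logical, n_redundant = 16, 2
n_physical = n_logical + n_redundant  # 18 physical slots

# Identity placement for the first n_logical slots; the redundant slots
# here simply replicate the first experts (a real EPLB policy would
# choose replicas based on observed load).
physical_to_logical = list(range(n_logical)) + list(range(n_redundant))

def checkpoint_expert_for_slot(physical_id: int) -> int:
    # Each physical slot loads the weights of its logical expert.
    return physical_to_logical[physical_id]
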
@@ -683,7 +710,7 @@ class Llama4Model(LlamaModel):
return loaded_params
-class Llama4ForCausalLM(LlamaForCausalLM):
+class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
@@ -702,6 +729,57 @@ class Llama4ForCausalLM(LlamaForCausalLM):
super().__init__(
vllm_config=vllm_config, prefix=prefix, layer_type=Llama4DecoderLayer
)
# Set MoE hyperparameters
self.set_moe_parameters()
def set_moe_parameters(self):
self.expert_weights = []
self.moe_layers = []
example_moe = None
for layer in self.model.layers:
assert isinstance(layer, Llama4DecoderLayer)
if isinstance(layer.feed_forward, Llama4MoE):
# Pick the last MoE layer, since the first layers may be dense.
example_moe = layer.feed_forward
self.moe_layers.append(layer.feed_forward.experts)
if example_moe is None:
self.num_moe_layers = 0
self.num_expert_groups = 0
self.num_logical_experts = 0
self.num_physical_experts = 0
self.num_local_physical_experts = 0
self.num_routed_experts = 0
self.num_shared_experts = 0
self.num_redundant_experts = 0
logger.warning("No Llama4MoE layer found in model.layers.")
else:
self.num_moe_layers = len(self.moe_layers)
self.num_expert_groups = 1
self.num_logical_experts = example_moe.n_logical_experts
self.num_physical_experts = example_moe.n_physical_experts
self.num_local_physical_experts = example_moe.n_local_physical_experts
self.num_routed_experts = example_moe.n_routed_experts
self.num_shared_experts = example_moe.n_shared_experts
self.num_redundant_experts = example_moe.n_redundant_experts
def update_physical_experts_metadata(
self,
num_physical_experts: int,
num_local_physical_experts: int,
) -> None:
assert self.num_local_physical_experts == num_local_physical_experts
self.num_physical_experts = num_physical_experts
self.num_local_physical_experts = num_local_physical_experts
self.num_redundant_experts = num_physical_experts - self.num_logical_experts
for layer in self.model.layers:
if isinstance(layer.feed_forward, Llama4MoE):
moe = layer.feed_forward
moe.n_local_physical_experts = num_local_physical_experts
moe.n_physical_experts = num_physical_experts
moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map()
def _init_model(
self,
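
Note the assert in update_physical_experts_metadata: the per-rank expert count must not change, so the method covers the elastic-EP case where the EP world size grows or shrinks while each rank keeps its slice. A hedged usage sketch (rebalance_elastic_ep is a hypothetical helper, not the actual vLLM call site):

# Usage sketch for update_physical_experts_metadata; this helper is
# hypothetical and simplified, not the actual vLLM EPLB call site.
def rebalance_elastic_ep(model, new_ep_size: int) -> None:
    # The per-rank count stays fixed (enforced by the assert above),
    # so a larger EP world size means a larger total physical pool.
    local = model.num_local_physical_experts
    model.update_physical_experts_metadata(
        num_physical_experts=local * new_ep_size,
        num_local_physical_experts=local,
    )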