[BugFix] Support EP/DP + EPLB with MTP (#25311)
Signed-off-by: ilmarkov <markovilya197@gmail.com> Signed-off-by: Sage Moore <sage@neuralmagic.com> Co-authored-by: Sage Moore <sage@neuralmagic.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
This commit is contained in:
@@ -30,9 +30,11 @@ from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (
|
||||
get_ep_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
@@ -46,6 +48,7 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
maybe_remap_kv_scale_name,
|
||||
)
|
||||
from vllm.model_executor.models.interfaces import MixtureOfExperts
|
||||
from vllm.model_executor.models.utils import sequence_parallel_chunk
|
||||
|
||||
from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
|
||||
@@ -56,6 +59,8 @@ from .utils import (
|
||||
is_pp_missing_parameter,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Llama4MoE(nn.Module):
|
||||
@staticmethod
|
||||
@@ -80,6 +85,9 @@ class Llama4MoE(nn.Module):
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
|
||||
self.ep_group = get_ep_group().device_group
|
||||
self.ep_rank = get_ep_group().rank_in_group
|
||||
self.ep_size = self.ep_group.size()
|
||||
|
||||
intermediate_size_moe = config.intermediate_size
|
||||
self.router = ReplicatedLinear(
|
||||
@@ -101,6 +109,20 @@ class Llama4MoE(nn.Module):
|
||||
disable_tp=self.is_sequence_parallel,
|
||||
)
|
||||
|
||||
# Load balancing settings.
|
||||
eplb_config = parallel_config.eplb_config if parallel_config else None
|
||||
self.enable_eplb = parallel_config.enable_eplb if parallel_config else False
|
||||
self.n_redundant_experts = (
|
||||
eplb_config.num_redundant_experts if eplb_config else 0
|
||||
)
|
||||
|
||||
self.n_routed_experts: int = config.num_local_experts
|
||||
self.n_logical_experts = self.n_routed_experts
|
||||
self.n_shared_experts: int = 1
|
||||
self.n_local_experts: int = config.num_local_experts
|
||||
self.n_physical_experts = self.n_local_experts + self.n_redundant_experts
|
||||
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
|
||||
|
||||
self.experts = SharedFusedMoE(
|
||||
shared_experts=self.shared_expert,
|
||||
num_experts=config.num_local_experts,
|
||||
@@ -114,6 +136,8 @@ class Llama4MoE(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.experts",
|
||||
is_sequence_parallel=self.is_sequence_parallel,
|
||||
enable_eplb=self.enable_eplb,
|
||||
num_redundant_experts=self.n_redundant_experts,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
@@ -378,6 +402,9 @@ class Llama4Model(LlamaModel):
|
||||
layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer,
|
||||
):
|
||||
self.num_experts = vllm_config.model_config.hf_config.num_local_experts
|
||||
self.n_redundant_experts = (
|
||||
vllm_config.parallel_config.eplb_config.num_redundant_experts
|
||||
)
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)
|
||||
|
||||
def load_moe_expert_weights(
|
||||
@@ -499,7 +526,6 @@ class Llama4Model(LlamaModel):
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id,
|
||||
)
|
||||
|
||||
loaded_params.add(full_param_name)
|
||||
expert_param_loaded = True
|
||||
|
||||
@@ -526,6 +552,7 @@ class Llama4Model(LlamaModel):
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=self.num_experts,
|
||||
num_redundant_experts=self.n_redundant_experts,
|
||||
)
|
||||
# Expert parameter mapping for the case where the expert weights are
|
||||
# fused into a single weight tensor.
|
||||
@@ -683,7 +710,7 @@ class Llama4Model(LlamaModel):
|
||||
return loaded_params
|
||||
|
||||
|
||||
class Llama4ForCausalLM(LlamaForCausalLM):
|
||||
class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
@@ -702,6 +729,57 @@ class Llama4ForCausalLM(LlamaForCausalLM):
|
||||
super().__init__(
|
||||
vllm_config=vllm_config, prefix=prefix, layer_type=Llama4DecoderLayer
|
||||
)
|
||||
# Set MoE hyperparameters
|
||||
self.set_moe_parameters()
|
||||
|
||||
def set_moe_parameters(self):
    """Discover the model's MoE layers and record expert-count metadata.

    Walks ``self.model.layers``, collecting the ``experts`` module of every
    Llama4MoE feed-forward block into ``self.moe_layers``, and copies the
    expert-count hyperparameters (logical/physical/local/redundant counts)
    from one representative MoE layer onto ``self`` for the
    MixtureOfExperts interface. If the model has no MoE layer at all, every
    count is zeroed and a warning is logged.
    """
    self.expert_weights = []
    self.moe_layers = []

    last_moe = None
    for layer in self.model.layers:
        assert isinstance(layer, Llama4DecoderLayer)
        feed_forward = layer.feed_forward
        if not isinstance(feed_forward, Llama4MoE):
            continue
        # Remember the most recent MoE block: the leading layers may be
        # dense, so only a later layer is guaranteed to be MoE.
        last_moe = feed_forward
        self.moe_layers.append(feed_forward.experts)

    if last_moe is not None:
        self.num_moe_layers = len(self.moe_layers)
        self.num_expert_groups = 1
        self.num_logical_experts = last_moe.n_logical_experts
        self.num_physical_experts = last_moe.n_physical_experts
        self.num_local_physical_experts = last_moe.n_local_physical_experts
        self.num_routed_experts = last_moe.n_routed_experts
        self.num_shared_experts = last_moe.n_shared_experts
        self.num_redundant_experts = last_moe.n_redundant_experts
    else:
        # Dense-only model: publish zeroed MoE metadata so the
        # MixtureOfExperts interface remains well-defined.
        self.num_moe_layers = 0
        self.num_expert_groups = 0
        self.num_logical_experts = 0
        self.num_physical_experts = 0
        self.num_local_physical_experts = 0
        self.num_routed_experts = 0
        self.num_shared_experts = 0
        self.num_redundant_experts = 0
        logger.warning("No Llama4MoE layer found in model.layers.")
|
||||
|
||||
def update_physical_experts_metadata(
    self,
    num_physical_experts: int,
    num_local_physical_experts: int,
) -> None:
    """Propagate new physical-expert counts after an EPLB rescale.

    Updates the model-level expert bookkeeping and pushes the new counts
    down into every Llama4MoE layer, then asks each layer's fused-MoE
    module to rebuild its expert map.

    Args:
        num_physical_experts: New global number of physical experts.
        num_local_physical_experts: Physical experts hosted on this rank;
            must match the current value (rescaling does not change the
            per-rank count).
    """
    # EPLB rescaling keeps the per-rank expert count fixed; only the
    # global (and hence redundant) counts may change.
    assert self.num_local_physical_experts == num_local_physical_experts
    self.num_physical_experts = num_physical_experts
    self.num_local_physical_experts = num_local_physical_experts
    self.num_redundant_experts = num_physical_experts - self.num_logical_experts
    for layer in self.model.layers:
        feed_forward = layer.feed_forward
        if not isinstance(feed_forward, Llama4MoE):
            continue
        feed_forward.n_local_physical_experts = num_local_physical_experts
        feed_forward.n_physical_experts = num_physical_experts
        feed_forward.n_redundant_experts = self.num_redundant_experts
        # Expert placement changed; regenerate the logical->physical map.
        feed_forward.experts.update_expert_map()
|
||||
|
||||
def _init_model(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user