diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 6b35c18dc..fd759f22b 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -1342,22 +1342,41 @@ class FusedMoE(CustomOp):
                 weight_name = qual_name.replace(weight_name, param_name)
                 param_name = weight_name.removeprefix(f"{self.layer_name}.")
                 param = getattr(self, param_name)
-                success = self.weight_loader(
-                    param=param,
-                    loaded_weight=loaded_weight,
-                    weight_name=weight_name,
-                    shard_id=shard_id,
-                    expert_id=expert_id,
-                    return_success=True,
-                )
-                if success:
-                    logger.debug(
-                        "Loaded %s for expert %d into %s",
-                        param_name,
-                        expert_id,
-                        self.layer_name,
+                # Fused expert weights can be identified by their 3D tensors
+                if loaded_weight.dim() == 3:
+                    # Repurpose expert_id as shard_idx for deconcatenating w1 and w3
+                    if shard_id in {"w1", "w3"}:
+                        shard_idx = expert_id
+                        experts_shard = loaded_weight.chunk(2, dim=1)[shard_idx]
+                    else:
+                        experts_shard = loaded_weight
+                    start = 0
+                else:
+                    # loaded_weight is a single expert weight, so we add a dummy expert
+                    # dimension to unify the loading logic with the fused case
+                    experts_shard = loaded_weight.unsqueeze(0)
+                    start = expert_id
+
+                # Unified loading logic for fused and non-fused experts
+                loaded_experts = experts_shard.unbind()
+                for expert_id, loaded_expert in enumerate(loaded_experts, start=start):
+                    success = self.weight_loader(
+                        param=param,
+                        loaded_weight=loaded_expert,
+                        weight_name=weight_name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                        return_success=True,
                     )
-                    yield param_name
+                    if success:
+                        logger.debug(
+                            "Loaded expert %d of shard %s into %s for layer %s",
+                            expert_id,
+                            shard_id,
+                            param_name,
+                            self.layer_name,
+                        )
+                        yield param_name
 
     def get_expert_weights(self) -> Iterable[torch.Tensor]:
         def _maybe_make_contiguous(
diff --git a/vllm/model_executor/models/transformers/moe.py b/vllm/model_executor/models/transformers/moe.py
index 320bbab08..5f8352fae 100644
--- a/vllm/model_executor/models/transformers/moe.py
+++ b/vllm/model_executor/models/transformers/moe.py
@@ -156,6 +156,17 @@ class MoEMixin(MixtureOfExperts):
             Params for weights, fp8 weight scales, fp8 activation scales
             (param_name, weight_name, expert_id, shard_id)
         """
+        # Models saved with fused experts. These are checkpoints released:
+        # - After Transformers v5
+        # - Before Transformers v5, but re-saved with save_original_format=False
+        # In the fused experts case, we repurpose the expert_id as shard_idx for
+        # deconcatenating w1 and w3 in FusedMoE.load_weights.
+        expert_mapping = [
+            ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
+            ("experts.w13_weight", "experts.gate_up_proj", 1, "w3"),
+            ("experts.w2_weight", "experts.down_proj", 0, "w2"),
+        ]
+        # Models saved with ModuleList experts
         ckpt_names = [
             # (ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name)
             ("gate_proj", "down_proj", "up_proj"),  # Most common MoE style
@@ -164,7 +175,6 @@ class MoEMixin(MixtureOfExperts):
         ]
         num_experts = self.model_config.get_num_experts()
         num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts
-        expert_mapping = []
        for gate_proj, down_proj, up_proj in ckpt_names:
            expert_mapping.extend(
                FusedMoE.make_expert_params_mapping(
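
The core of the change is the fused-vs-per-expert branching added to FusedMoE.load_weights. The snippet below is a minimal, standalone sketch of that splitting logic only: the helper name iter_expert_shards and the example shapes are illustrative, not part of the vLLM API, and the [num_experts, 2 * intermediate_size, hidden_size] layout is assumed from the chunk(2, dim=1) call in the diff.

    # Standalone sketch (assumed names/shapes, not vLLM API).
    import torch


    def iter_expert_shards(loaded_weight, shard_id, expert_id):
        """Yield (expert_id, 2D weight) pairs for fused or per-expert tensors."""
        if loaded_weight.dim() == 3:
            # Fused checkpoint: all experts in one tensor. For the concatenated
            # gate/up projection, expert_id is repurposed as the shard index
            # (0 -> w1/gate, 1 -> w3/up) and selects one half along dim=1.
            if shard_id in {"w1", "w3"}:
                experts_shard = loaded_weight.chunk(2, dim=1)[expert_id]
            else:
                experts_shard = loaded_weight
            start = 0
        else:
            # Per-expert checkpoint: add a dummy expert dimension so both cases
            # go through the same unbind/enumerate loop.
            experts_shard = loaded_weight.unsqueeze(0)
            start = expert_id
        yield from enumerate(experts_shard.unbind(), start=start)


    # Example: a fused gate_up_proj for 4 experts, intermediate size 8, hidden 16.
    fused = torch.randn(4, 2 * 8, 16)
    for eid, w in iter_expert_shards(fused, shard_id="w1", expert_id=0):
        assert w.shape == (8, 16)  # gate half of expert `eid`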