diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
index 029edc44c..41762b7f6 100644
--- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
+++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
@@ -30,6 +30,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
 from vllm.model_executor.layers.fused_moe.prepare_finalize import (
     MoEPrepareAndFinalizeNoEP,
 )
+from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+    AiterExperts,
+)
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
     swap_w13_to_w31,
 )
@@ -60,12 +63,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         super().__init__(moe)
 
         self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
-        if self.rocm_aiter_moe_enabled:
-            from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
-
-            self.rocm_aiter_fused_experts = rocm_aiter_fused_experts
-        else:
-            self.rocm_aiter_fused_experts = None  # type: ignore
 
         # FlashInfer CUTLASS MoE is only supported on Hopper and later GPUS
         self.flashinfer_cutlass_moe_enabled = (
@@ -207,14 +204,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
         layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
 
-        if self.rocm_aiter_moe_enabled:
-            shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
-                layer.w13_weight.data, layer.w2_weight.data
-            )
-
-            layer.w13_weight.data = shuffled_w13
-            layer.w2_weight.data = shuffled_w2
-
         if current_platform.is_xpu():
             import intel_extension_for_pytorch as ipex
 
@@ -258,7 +247,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
         elif current_platform.is_cuda_alike():
             self.moe_quant_config = self.get_fused_moe_quant_config(layer)
-            if self.flashinfer_cutlass_moe_enabled:
+            if self.rocm_aiter_moe_enabled:
+                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
+                    layer.w13_weight.data, layer.w2_weight.data
+                )
+                replace_parameter(layer, "w13_weight", shuffled_w13)
+                replace_parameter(layer, "w2_weight", shuffled_w2)
+
+                self.use_inplace = True
+                self.kernel = mk.FusedMoEModularKernel(
+                    MoEPrepareAndFinalizeNoEP(),
+                    AiterExperts(self.moe_quant_config),
+                    shared_experts=None,
+                )
+
+            elif self.flashinfer_cutlass_moe_enabled:
                 self.use_inplace = False
                 # Swap halves to arrange as [w3; w1] (kernel expectation)
                 w13_weight = swap_w13_to_w31(layer.w13_weight.data)
@@ -315,30 +318,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             router_logits=router_logits,
         )
 
-        if self.rocm_aiter_moe_enabled:
-            result = self.rocm_aiter_fused_experts(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                expert_map=layer.expert_map,
-                activation=layer.activation,
-                apply_router_weight_on_input=layer.apply_router_weight_on_input,
-            )
-        else:
-            result = self.kernel(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                inplace=self.use_inplace,
-                activation=layer.activation,
-                apply_router_weight_on_input=layer.apply_router_weight_on_input,
-                global_num_experts=layer.global_num_experts,
-                expert_map=layer.expert_map,
-            )
+        result = self.kernel(
+            hidden_states=x,
+            w1=layer.w13_weight,
+            w2=layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=self.use_inplace,
+            activation=layer.activation,
+            apply_router_weight_on_input=layer.apply_router_weight_on_input,
+            global_num_experts=layer.global_num_experts,
+            expert_map=layer.expert_map,
+        )
 
         return result