[MoE Refactor] Aiter Experts for BF16 MoE (#31542)

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Author: Yongye Zhu
Date: 2026-01-05 14:52:59 -08:00
Committed by: GitHub
Parent: af9a7ec255
Commit: 776ca1e187


@@ -30,6 +30,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
 from vllm.model_executor.layers.fused_moe.prepare_finalize import (
     MoEPrepareAndFinalizeNoEP,
 )
+from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+    AiterExperts,
+)
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
     swap_w13_to_w31,
 )
@@ -60,12 +63,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         super().__init__(moe)
         self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
-        if self.rocm_aiter_moe_enabled:
-            from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
-
-            self.rocm_aiter_fused_experts = rocm_aiter_fused_experts
-        else:
-            self.rocm_aiter_fused_experts = None  # type: ignore
 
         # FlashInfer CUTLASS MoE is only supported on Hopper and later GPUS
         self.flashinfer_cutlass_moe_enabled = (
@@ -207,14 +204,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
         layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
 
-        if self.rocm_aiter_moe_enabled:
-            shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
-                layer.w13_weight.data, layer.w2_weight.data
-            )
-            layer.w13_weight.data = shuffled_w13
-            layer.w2_weight.data = shuffled_w2
-
         if current_platform.is_xpu():
             import intel_extension_for_pytorch as ipex
@@ -258,7 +247,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
         elif current_platform.is_cuda_alike():
             self.moe_quant_config = self.get_fused_moe_quant_config(layer)
-            if self.flashinfer_cutlass_moe_enabled:
+            if self.rocm_aiter_moe_enabled:
+                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
+                    layer.w13_weight.data, layer.w2_weight.data
+                )
+                replace_parameter(layer, "w13_weight", shuffled_w13)
+                replace_parameter(layer, "w2_weight", shuffled_w2)
+
+                self.use_inplace = True
+                self.kernel = mk.FusedMoEModularKernel(
+                    MoEPrepareAndFinalizeNoEP(),
+                    AiterExperts(self.moe_quant_config),
+                    shared_experts=None,
+                )
+            elif self.flashinfer_cutlass_moe_enabled:
                 self.use_inplace = False
                 # Swap halves to arrange as [w3; w1] (kernel expectation)
                 w13_weight = swap_w13_to_w31(layer.w13_weight.data)
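
For readers unfamiliar with the modular-kernel layering introduced above, the sketch below shows the composition pattern in isolation: a prepare/finalize stage (a no-op when expert parallelism is disabled) wrapped around an experts implementation, mirroring how FusedMoEModularKernel is assembled from MoEPrepareAndFinalizeNoEP and AiterExperts in this hunk. All names in the sketch are illustrative stand-ins, not vLLM or AITER APIs.

# Simplified, standalone sketch of the composition pattern (hypothetical names).
from dataclasses import dataclass
from typing import Callable

import torch


class NoEPPrepareFinalize:
    """Without expert parallelism there is nothing to dispatch or combine."""

    def prepare(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states

    def finalize(self, out: torch.Tensor) -> torch.Tensor:
        return out


@dataclass
class ModularMoEKernel:
    prepare_finalize: NoEPPrepareFinalize
    experts: Callable[..., torch.Tensor]  # e.g. an AITER-backed experts impl

    def __call__(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        x = self.prepare_finalize.prepare(hidden_states)
        out = self.experts(x, **kwargs)
        return self.prepare_finalize.finalize(out)


# Usage: kernel = ModularMoEKernel(NoEPPrepareFinalize(), my_experts_fn)

One practical consequence of this structure is visible in the forward path below: the caller no longer branches on the AITER flag, because the backend choice is baked into the kernel object at weight-processing time.
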
@@ -315,30 +318,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             router_logits=router_logits,
         )
 
-        if self.rocm_aiter_moe_enabled:
-            result = self.rocm_aiter_fused_experts(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                expert_map=layer.expert_map,
-                activation=layer.activation,
-                apply_router_weight_on_input=layer.apply_router_weight_on_input,
-            )
-        else:
-            result = self.kernel(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                inplace=self.use_inplace,
-                activation=layer.activation,
-                apply_router_weight_on_input=layer.apply_router_weight_on_input,
-                global_num_experts=layer.global_num_experts,
-                expert_map=layer.expert_map,
-            )
+        result = self.kernel(
+            hidden_states=x,
+            w1=layer.w13_weight,
+            w2=layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=self.use_inplace,
+            activation=layer.activation,
+            apply_router_weight_on_input=layer.apply_router_weight_on_input,
+            global_num_experts=layer.global_num_experts,
+            expert_map=layer.expert_map,
+        )
 
         return result
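
As a rough reference for what the unified self.kernel call computes on the unquantized BF16 path, here is a naive, loop-based fused-experts implementation in plain PyTorch. It assumes the common [gate; up] layout for w13 and a SiLU/SwiGLU activation; the function name and these layout assumptions are illustrative only and are not taken from this commit.

# Naive reference: per-token, per-expert SwiGLU MoE (hypothetical helper name).
import torch
import torch.nn.functional as F


def naive_fused_experts(
    x: torch.Tensor,             # [num_tokens, hidden]
    w13: torch.Tensor,           # [num_experts, 2 * intermediate, hidden]
    w2: torch.Tensor,            # [num_experts, hidden, intermediate]
    topk_weights: torch.Tensor,  # [num_tokens, topk]
    topk_ids: torch.Tensor,      # [num_tokens, topk]
) -> torch.Tensor:
    num_tokens, topk = topk_ids.shape
    inter = w2.shape[-1]
    out = torch.zeros_like(x)
    for t in range(num_tokens):
        for k in range(topk):
            e = int(topk_ids[t, k])
            gate_up = w13[e] @ x[t]                     # [2 * intermediate]
            gate, up = gate_up[:inter], gate_up[inter:]
            h = F.silu(gate) * up                       # SwiGLU activation
            out[t] += topk_weights[t, k] * (w2[e] @ h)  # weighted down-proj
    return out


# Tiny smoke test with random data (fp32 here for numerical clarity).
x = torch.randn(4, 8)
w13 = torch.randn(3, 2 * 16, 8)
w2 = torch.randn(3, 8, 16)
logits = torch.randn(4, 3)
topk_weights, topk_ids = torch.topk(torch.softmax(logits, dim=-1), k=2, dim=-1)
print(naive_fused_experts(x, w13, w2, topk_weights, topk_ids).shape)  # torch.Size([4, 8])

The production path delegates this computation to AiterExperts inside the modular kernel, which fuses the routing, activation, and projections into AITER kernels and optionally writes the result in place (inplace=self.use_inplace).
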