[MoE Refactor] Aiter Experts for BF16 MoE (#31542)
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
@@ -30,6 +30,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
AiterExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
swap_w13_to_w31,
|
||||
)
|
||||
@@ -60,12 +63,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
super().__init__(moe)
|
||||
|
||||
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
|
||||
|
||||
self.rocm_aiter_fused_experts = rocm_aiter_fused_experts
|
||||
else:
|
||||
self.rocm_aiter_fused_experts = None # type: ignore
|
||||
|
||||
# FlashInfer CUTLASS MoE is only supported on Hopper and later GPUS
|
||||
self.flashinfer_cutlass_moe_enabled = (
|
||||
@@ -207,14 +204,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
|
||||
layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
|
||||
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||
layer.w13_weight.data, layer.w2_weight.data
|
||||
)
|
||||
|
||||
layer.w13_weight.data = shuffled_w13
|
||||
layer.w2_weight.data = shuffled_w2
|
||||
|
||||
if current_platform.is_xpu():
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
@@ -258,7 +247,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
|
||||
elif current_platform.is_cuda_alike():
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
if self.flashinfer_cutlass_moe_enabled:
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||
layer.w13_weight.data, layer.w2_weight.data
|
||||
)
|
||||
replace_parameter(layer, "w13_weight", shuffled_w13)
|
||||
replace_parameter(layer, "w2_weight", shuffled_w2)
|
||||
|
||||
self.use_inplace = True
|
||||
self.kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
AiterExperts(self.moe_quant_config),
|
||||
shared_experts=None,
|
||||
)
|
||||
|
||||
elif self.flashinfer_cutlass_moe_enabled:
|
||||
self.use_inplace = False
|
||||
# Swap halves to arrange as [w3; w1] (kernel expectation)
|
||||
w13_weight = swap_w13_to_w31(layer.w13_weight.data)
|
||||
@@ -315,30 +318,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
router_logits=router_logits,
|
||||
)
|
||||
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
result = self.rocm_aiter_fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
expert_map=layer.expert_map,
|
||||
activation=layer.activation,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
else:
|
||||
result = self.kernel(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=self.use_inplace,
|
||||
activation=layer.activation,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
)
|
||||
result = self.kernel(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=self.use_inplace,
|
||||
activation=layer.activation,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
expert_map=layer.expert_map,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
Reference in New Issue
Block a user