[MoE Refactor] Explicit construct mk for flashinfer bf16 kernel (#31504)

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
Yongye Zhu
2026-01-02 13:54:50 -08:00
committed by GitHub
parent 5a468ff7c7
commit a3f2f40947

View File

@@ -16,6 +16,9 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
biased_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
@@ -27,7 +30,10 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
swap_w13_to_w31,
)
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
@@ -73,18 +79,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
logger.info_once(
"Enabling FlashInfer CUTLASS MoE for UnquantizedFusedMoEMethod"
)
from functools import partial
from .flashinfer_cutlass_moe import flashinfer_cutlass_moe
self.flashinfer_cutlass_moe = partial(
flashinfer_cutlass_moe,
quant_config=FUSED_MOE_UNQUANTIZED_CONFIG,
tp_rank=self.moe.moe_parallel_config.tp_rank,
tp_size=self.moe.moe_parallel_config.tp_size,
ep_rank=self.moe.moe_parallel_config.ep_rank,
ep_size=self.moe.moe_parallel_config.ep_size,
)
else:
if (
self.moe.moe_parallel_config.use_ep
@@ -101,7 +95,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"FlashInfer CUTLASS MoE is currently not available for DP.",
scope="local",
)
self.flashinfer_cutlass_moe = None # type: ignore
@property
def supports_eplb(self) -> bool:
@@ -222,12 +215,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.w13_weight.data = shuffled_w13
layer.w2_weight.data = shuffled_w2
if self.flashinfer_cutlass_moe_enabled:
# Swap halves to arrange as [w3; w1] (kernel expectation)
w1_w, w3_w = torch.chunk(layer.w13_weight.data, 2, dim=1)
w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)
layer.w13_weight.data = w13_weight_swapped.contiguous()
if current_platform.is_xpu():
import intel_extension_for_pytorch as ipex
@@ -271,11 +258,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
elif current_platform.is_cuda_alike():
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
self.kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
TritonExperts(self.moe_quant_config),
shared_experts=None,
)
if self.flashinfer_cutlass_moe_enabled:
self.use_inplace = False
# Swap halves to arrange as [w3; w1] (kernel expectation)
w13_weight = swap_w13_to_w31(layer.w13_weight.data)
replace_parameter(layer, "w13_weight", w13_weight)
self.kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
FlashInferExperts(
out_dtype=layer.params_dtype,
quant_config=self.moe_quant_config,
tp_rank=self.moe.moe_parallel_config.tp_rank,
tp_size=self.moe.moe_parallel_config.tp_size,
ep_rank=self.moe.moe_parallel_config.ep_rank,
ep_size=self.moe.moe_parallel_config.ep_size,
),
)
else:
self.use_inplace = True
self.kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
TritonExperts(self.moe_quant_config),
shared_experts=None,
)
def apply(
self,
@@ -320,16 +326,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
elif self.flashinfer_cutlass_moe_enabled:
return self.flashinfer_cutlass_moe(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
else:
result = self.kernel(
hidden_states=x,
@@ -337,7 +333,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
inplace=self.use_inplace,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,