[MoE Refactor] Explicit construct mk for flashinfer bf16 kernel (#31504)
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
@@ -16,6 +16,9 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
biased_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
FlashInferExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
|
||||
FusedMoEMethodBase,
|
||||
)
|
||||
@@ -27,7 +30,10 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
swap_w13_to_w31,
|
||||
)
|
||||
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import CpuArchEnum
|
||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
@@ -73,18 +79,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
logger.info_once(
|
||||
"Enabling FlashInfer CUTLASS MoE for UnquantizedFusedMoEMethod"
|
||||
)
|
||||
from functools import partial
|
||||
|
||||
from .flashinfer_cutlass_moe import flashinfer_cutlass_moe
|
||||
|
||||
self.flashinfer_cutlass_moe = partial(
|
||||
flashinfer_cutlass_moe,
|
||||
quant_config=FUSED_MOE_UNQUANTIZED_CONFIG,
|
||||
tp_rank=self.moe.moe_parallel_config.tp_rank,
|
||||
tp_size=self.moe.moe_parallel_config.tp_size,
|
||||
ep_rank=self.moe.moe_parallel_config.ep_rank,
|
||||
ep_size=self.moe.moe_parallel_config.ep_size,
|
||||
)
|
||||
else:
|
||||
if (
|
||||
self.moe.moe_parallel_config.use_ep
|
||||
@@ -101,7 +95,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
"FlashInfer CUTLASS MoE is currently not available for DP.",
|
||||
scope="local",
|
||||
)
|
||||
self.flashinfer_cutlass_moe = None # type: ignore
|
||||
|
||||
@property
|
||||
def supports_eplb(self) -> bool:
|
||||
@@ -222,12 +215,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
layer.w13_weight.data = shuffled_w13
|
||||
layer.w2_weight.data = shuffled_w2
|
||||
|
||||
if self.flashinfer_cutlass_moe_enabled:
|
||||
# Swap halves to arrange as [w3; w1] (kernel expectation)
|
||||
w1_w, w3_w = torch.chunk(layer.w13_weight.data, 2, dim=1)
|
||||
w13_weight_swapped = torch.cat([w3_w, w1_w], dim=1)
|
||||
layer.w13_weight.data = w13_weight_swapped.contiguous()
|
||||
|
||||
if current_platform.is_xpu():
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
@@ -271,11 +258,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
layer.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
|
||||
elif current_platform.is_cuda_alike():
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
self.kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
TritonExperts(self.moe_quant_config),
|
||||
shared_experts=None,
|
||||
)
|
||||
if self.flashinfer_cutlass_moe_enabled:
|
||||
self.use_inplace = False
|
||||
# Swap halves to arrange as [w3; w1] (kernel expectation)
|
||||
w13_weight = swap_w13_to_w31(layer.w13_weight.data)
|
||||
replace_parameter(layer, "w13_weight", w13_weight)
|
||||
|
||||
self.kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
FlashInferExperts(
|
||||
out_dtype=layer.params_dtype,
|
||||
quant_config=self.moe_quant_config,
|
||||
tp_rank=self.moe.moe_parallel_config.tp_rank,
|
||||
tp_size=self.moe.moe_parallel_config.tp_size,
|
||||
ep_rank=self.moe.moe_parallel_config.ep_rank,
|
||||
ep_size=self.moe.moe_parallel_config.ep_size,
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.use_inplace = True
|
||||
self.kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
TritonExperts(self.moe_quant_config),
|
||||
shared_experts=None,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
@@ -320,16 +326,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
activation=layer.activation,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
elif self.flashinfer_cutlass_moe_enabled:
|
||||
return self.flashinfer_cutlass_moe(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=layer.activation,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
)
|
||||
else:
|
||||
result = self.kernel(
|
||||
hidden_states=x,
|
||||
@@ -337,7 +333,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
inplace=self.use_inplace,
|
||||
activation=layer.activation,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
|
||||
Reference in New Issue
Block a user