[MoE Refactor] Create MK for TRTLLM Kernels (#32564)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
Signed-off-by: Robert Shaw <robertgshaw2@gmail.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
This commit is contained in:
Robert Shaw
2026-03-03 13:39:50 -05:00
committed by GitHub
parent 881a6b011b
commit 97995f6376
77 changed files with 2575 additions and 2087 deletions

View File

@@ -17,13 +17,13 @@ from .mk_objects import (
def make_config_arg_parser(description: str):
def to_pf_class_type(s: str) -> mk.FusedMoEPrepareAndFinalize:
def to_pf_class_type(s: str) -> mk.FusedMoEPrepareAndFinalizeModular:
for pf in MK_ALL_PREPARE_FINALIZE_TYPES:
if pf.__name__ == s:
return pf
raise ValueError(f"Cannot find a PrepareFinalize type that matches {s}")
def to_experts_class_type(s: str) -> mk.FusedMoEPermuteExpertsUnpermute:
def to_experts_class_type(s: str) -> mk.FusedMoEExpertsModular:
for fe in MK_FUSED_EXPERT_TYPES:
if fe.__name__ == s:
return fe

View File

@@ -66,7 +66,7 @@ class Config:
quant_config: TestMoEQuantConfig | None
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute
fused_experts_type: mk.FusedMoEExperts
fused_moe_chunk_size: int | None
world_size: int
@@ -566,7 +566,7 @@ def make_modular_kernel(
config: Config,
vllm_config: VllmConfig,
quant_config: FusedMoEQuantConfig,
) -> mk.FusedMoEModularKernel:
) -> mk.FusedMoEKernel:
def next_power_of_2(x):
import math
@@ -613,7 +613,7 @@ def make_modular_kernel(
config.N,
)
modular_kernel = mk.FusedMoEModularKernel(
modular_kernel = mk.FusedMoEKernel(
prepare_finalize=prepare_finalize,
fused_experts=fused_experts,
inplace=False,
@@ -667,6 +667,7 @@ def run_modular_kernel(
"w2": rank_weights.w2,
"topk_weights": rank_tensors.topk_weights,
"topk_ids": topk_ids,
"activation": MoEActivation.SILU,
"expert_map": rank_tensors.expert_map,
"global_num_experts": config.E,
"apply_router_weight_on_input": config.topk == 1
@@ -684,6 +685,6 @@ def run_modular_kernel(
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp,
):
out = mk.forward(**mk_kwargs)
out = mk.apply(**mk_kwargs)
return out

View File

@@ -20,7 +20,7 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
NaiveBatchedExperts,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
MoEPrepareAndFinalizeNoDPEPModular,
)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts,
@@ -71,12 +71,14 @@ class ExpertInfo:
needs_aiter: bool = False
PREPARE_FINALIZE_INFO: dict[mk.FusedMoEPrepareAndFinalize, PrepareFinalizeInfo] = {}
EXPERT_INFO: dict[mk.FusedMoEPermuteExpertsUnpermute, ExpertInfo] = {}
MK_ALL_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
MK_FUSED_EXPERT_TYPES: list[mk.FusedMoEPermuteExpertsUnpermute] = []
PREPARE_FINALIZE_INFO: dict[
mk.FusedMoEPrepareAndFinalizeModular, PrepareFinalizeInfo
] = {}
EXPERT_INFO: dict[mk.FusedMoEExpertsModular, ExpertInfo] = {}
MK_ALL_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalizeModular] = []
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalizeModular] = []
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalizeModular] = []
MK_FUSED_EXPERT_TYPES: list[mk.FusedMoEExpertsModular] = []
standard_format = mk.FusedMoEActivationFormat.Standard
batched_format = mk.FusedMoEActivationFormat.BatchedExperts
@@ -162,7 +164,7 @@ def expert_info(kind) -> ExpertInfo:
register_prepare_and_finalize(
MoEPrepareAndFinalizeNoEP,
MoEPrepareAndFinalizeNoDPEPModular,
standard_format,
common_float_types,
blocked_quantization_support=True,
@@ -239,14 +241,14 @@ if has_mori():
if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability(100):
from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import ( # noqa: E501
FlashInferCutlassMoEPrepareAndFinalize,
FlashInferA2APrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
register_prepare_and_finalize(
FlashInferCutlassMoEPrepareAndFinalize,
FlashInferA2APrepareAndFinalize,
standard_format,
nvfp4_types + fp8_types,
blocked_quantization_support=True,
@@ -430,12 +432,12 @@ def make_cutlass_strides(
def make_fused_experts(
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
fused_experts_type: mk.FusedMoEExpertsModular,
moe: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
num_dispatchers: int,
N: int,
) -> mk.FusedMoEPermuteExpertsUnpermute:
) -> mk.FusedMoEExpertsModular:
if (
fused_experts_type.activation_format()
== mk.FusedMoEActivationFormat.BatchedExperts

View File

@@ -72,7 +72,7 @@ def profile_modular_kernel(
"apply_router_weight_on_input": config.topk == 1,
}
do_profile(mk.forward, mk_kwargs, pgi, config)
do_profile(mk.apply, mk_kwargs, pgi, config)
def rank_worker(