[MoE Refactor] Create MK for TRTLLM Kernels (#32564)
Signed-off-by: Robert Shaw <robshaw@redhat.com> Signed-off-by: Robert Shaw <rshaw@neuralmagic.com> Signed-off-by: Robert Shaw <robertgshaw2@gmail.com> Co-authored-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
This commit is contained in:
@@ -17,13 +17,13 @@ from .mk_objects import (
|
||||
|
||||
|
||||
def make_config_arg_parser(description: str):
|
||||
def to_pf_class_type(s: str) -> mk.FusedMoEPrepareAndFinalize:
|
||||
def to_pf_class_type(s: str) -> mk.FusedMoEPrepareAndFinalizeModular:
|
||||
for pf in MK_ALL_PREPARE_FINALIZE_TYPES:
|
||||
if pf.__name__ == s:
|
||||
return pf
|
||||
raise ValueError(f"Cannot find a PrepareFinalize type that matches {s}")
|
||||
|
||||
def to_experts_class_type(s: str) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
def to_experts_class_type(s: str) -> mk.FusedMoEExpertsModular:
|
||||
for fe in MK_FUSED_EXPERT_TYPES:
|
||||
if fe.__name__ == s:
|
||||
return fe
|
||||
|
||||
@@ -66,7 +66,7 @@ class Config:
|
||||
quant_config: TestMoEQuantConfig | None
|
||||
|
||||
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize
|
||||
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute
|
||||
fused_experts_type: mk.FusedMoEExperts
|
||||
|
||||
fused_moe_chunk_size: int | None
|
||||
world_size: int
|
||||
@@ -566,7 +566,7 @@ def make_modular_kernel(
|
||||
config: Config,
|
||||
vllm_config: VllmConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.FusedMoEModularKernel:
|
||||
) -> mk.FusedMoEKernel:
|
||||
def next_power_of_2(x):
|
||||
import math
|
||||
|
||||
@@ -613,7 +613,7 @@ def make_modular_kernel(
|
||||
config.N,
|
||||
)
|
||||
|
||||
modular_kernel = mk.FusedMoEModularKernel(
|
||||
modular_kernel = mk.FusedMoEKernel(
|
||||
prepare_finalize=prepare_finalize,
|
||||
fused_experts=fused_experts,
|
||||
inplace=False,
|
||||
@@ -667,6 +667,7 @@ def run_modular_kernel(
|
||||
"w2": rank_weights.w2,
|
||||
"topk_weights": rank_tensors.topk_weights,
|
||||
"topk_ids": topk_ids,
|
||||
"activation": MoEActivation.SILU,
|
||||
"expert_map": rank_tensors.expert_map,
|
||||
"global_num_experts": config.E,
|
||||
"apply_router_weight_on_input": config.topk == 1
|
||||
@@ -684,6 +685,6 @@ def run_modular_kernel(
|
||||
num_tokens=num_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
):
|
||||
out = mk.forward(**mk_kwargs)
|
||||
out = mk.apply(**mk_kwargs)
|
||||
|
||||
return out
|
||||
|
||||
@@ -20,7 +20,7 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
NaiveBatchedExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
MoEPrepareAndFinalizeNoDPEPModular,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
||||
TritonOrDeepGemmExperts,
|
||||
@@ -71,12 +71,14 @@ class ExpertInfo:
|
||||
needs_aiter: bool = False
|
||||
|
||||
|
||||
PREPARE_FINALIZE_INFO: dict[mk.FusedMoEPrepareAndFinalize, PrepareFinalizeInfo] = {}
|
||||
EXPERT_INFO: dict[mk.FusedMoEPermuteExpertsUnpermute, ExpertInfo] = {}
|
||||
MK_ALL_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
|
||||
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
|
||||
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
|
||||
MK_FUSED_EXPERT_TYPES: list[mk.FusedMoEPermuteExpertsUnpermute] = []
|
||||
PREPARE_FINALIZE_INFO: dict[
|
||||
mk.FusedMoEPrepareAndFinalizeModular, PrepareFinalizeInfo
|
||||
] = {}
|
||||
EXPERT_INFO: dict[mk.FusedMoEExpertsModular, ExpertInfo] = {}
|
||||
MK_ALL_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalizeModular] = []
|
||||
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalizeModular] = []
|
||||
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalizeModular] = []
|
||||
MK_FUSED_EXPERT_TYPES: list[mk.FusedMoEExpertsModular] = []
|
||||
|
||||
standard_format = mk.FusedMoEActivationFormat.Standard
|
||||
batched_format = mk.FusedMoEActivationFormat.BatchedExperts
|
||||
@@ -162,7 +164,7 @@ def expert_info(kind) -> ExpertInfo:
|
||||
|
||||
|
||||
register_prepare_and_finalize(
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
MoEPrepareAndFinalizeNoDPEPModular,
|
||||
standard_format,
|
||||
common_float_types,
|
||||
blocked_quantization_support=True,
|
||||
@@ -239,14 +241,14 @@ if has_mori():
|
||||
|
||||
if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability(100):
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import ( # noqa: E501
|
||||
FlashInferCutlassMoEPrepareAndFinalize,
|
||||
FlashInferA2APrepareAndFinalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
FlashInferExperts,
|
||||
)
|
||||
|
||||
register_prepare_and_finalize(
|
||||
FlashInferCutlassMoEPrepareAndFinalize,
|
||||
FlashInferA2APrepareAndFinalize,
|
||||
standard_format,
|
||||
nvfp4_types + fp8_types,
|
||||
blocked_quantization_support=True,
|
||||
@@ -430,12 +432,12 @@ def make_cutlass_strides(
|
||||
|
||||
|
||||
def make_fused_experts(
|
||||
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
|
||||
fused_experts_type: mk.FusedMoEExpertsModular,
|
||||
moe: FusedMoEConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
num_dispatchers: int,
|
||||
N: int,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
) -> mk.FusedMoEExpertsModular:
|
||||
if (
|
||||
fused_experts_type.activation_format()
|
||||
== mk.FusedMoEActivationFormat.BatchedExperts
|
||||
|
||||
@@ -72,7 +72,7 @@ def profile_modular_kernel(
|
||||
"apply_router_weight_on_input": config.topk == 1,
|
||||
}
|
||||
|
||||
do_profile(mk.forward, mk_kwargs, pgi, config)
|
||||
do_profile(mk.apply, mk_kwargs, pgi, config)
|
||||
|
||||
|
||||
def rank_worker(
|
||||
|
||||
Reference in New Issue
Block a user