[MoE Refactor] Create MK for TRTLLM Kernels (#32564)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
Signed-off-by: Robert Shaw <robertgshaw2@gmail.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
Robert Shaw authored 2026-03-03 13:39:50 -05:00, committed by GitHub
parent 881a6b011b
commit 97995f6376
77 changed files with 2575 additions and 2087 deletions


@@ -8,6 +8,9 @@ from tests.kernels.quant_utils import per_block_cast_to_int8
 from tests.kernels.quantization.nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe.activation import MoEActivation
+from vllm.model_executor.layers.fused_moe.all2all_utils import (
+    maybe_make_prepare_finalize,
+)
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
     FusedMoEParallelConfig,
@@ -23,10 +26,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
     TritonExperts,
     fused_experts,
 )
-from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
-from vllm.model_executor.layers.fused_moe.prepare_finalize import (
-    MoEPrepareAndFinalizeNoEP,
-)
+from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
 from vllm.model_executor.layers.fused_moe.router.fused_topk_router import fused_topk
 from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
 from vllm.utils.deep_gemm import per_block_cast_to_fp8
@@ -125,7 +125,9 @@ def batched_moe(
         a2_scale=a2_scale,
     )
 
-    fused_experts = FusedMoEModularKernel(
+    moe_config = make_dummy_moe_config()
+    fused_experts = FusedMoEKernel(
         BatchedPrepareAndFinalize(
             max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0
         ),
@@ -133,12 +135,22 @@ def batched_moe(
             max_num_tokens=max_num_tokens,
             num_dispatchers=1,
             quant_config=quant_config,
-            moe_config=make_dummy_moe_config(),
+            moe_config=moe_config,
         ),
         inplace=False,
     )
-    return fused_experts(a, w1, w2, topk_weight, topk_ids)
+    return fused_experts.apply(
+        a,
+        w1,
+        w2,
+        topk_weight,
+        topk_ids,
+        global_num_experts=w1.shape[0],
+        activation=moe_config.activation,
+        apply_router_weight_on_input=False,
+        expert_map=None,
+    )
 
 
 def naive_batched_moe(
@@ -166,8 +178,9 @@ def naive_batched_moe(
         a1_scale=a1_scale,
         a2_scale=a2_scale,
     )
 
-    fused_experts = FusedMoEModularKernel(
+    moe_config = make_dummy_moe_config()
+    fused_experts = FusedMoEKernel(
         BatchedPrepareAndFinalize(
             max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0
         ),
@@ -175,12 +188,22 @@ def naive_batched_moe(
             max_num_tokens=max_num_tokens,
             num_dispatchers=1,
             quant_config=quant_config,
-            moe_config=make_dummy_moe_config(),
+            moe_config=moe_config,
         ),
         inplace=False,
     )
-    return fused_experts(a, w1, w2, topk_weight, topk_ids)
+    return fused_experts.apply(
+        a,
+        w1,
+        w2,
+        topk_weight,
+        topk_ids,
+        global_num_experts=w1.shape[0],
+        activation=moe_config.activation,
+        apply_router_weight_on_input=False,
+        expert_map=None,
+    )
 
 
 def chunk_scales(
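
For reference, the call-site change applied in both batched_moe and naive_batched_moe above replaces a direct positional call on the kernel object with an explicit apply(). The sketch below is assembled only from argument names visible in this diff; the tensors (a, w1, w2, topk_weight, topk_ids), moe_config, and the fused_experts object are assumed to be set up exactly as in the test helpers above.

# Old interface: positional __call__ on the modular kernel
# out = fused_experts(a, w1, w2, topk_weight, topk_ids)

# New interface: explicit apply() with routing details spelled out by the caller
out = fused_experts.apply(
    a,
    w1,
    w2,
    topk_weight,
    topk_ids,
    global_num_experts=w1.shape[0],   # all experts are local in these single-rank tests
    activation=moe_config.activation,
    apply_router_weight_on_input=False,
    expert_map=None,                  # no local-to-global expert remapping
)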
@@ -581,9 +604,14 @@ def modular_triton_fused_moe(
     moe_config: FusedMoEConfig,
     quant_config: FusedMoEQuantConfig,
     shared_experts: torch.nn.Module | None = None,
-) -> FusedMoEModularKernel:
-    return FusedMoEModularKernel(
-        MoEPrepareAndFinalizeNoEP(),
+) -> FusedMoEKernel:
+    return FusedMoEKernel(
+        maybe_make_prepare_finalize(
+            moe=moe_config,
+            quant_config=quant_config,
+            allow_new_interface=True,
+            use_monolithic=False,
+        ),
         TritonExperts(moe_config, quant_config),
         shared_experts,
         inplace=False,
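
Similarly, modular_triton_fused_moe no longer instantiates MoEPrepareAndFinalizeNoEP directly; the prepare/finalize stage now comes from the maybe_make_prepare_finalize factory. Below is a minimal sketch of that construction using only the imports and keyword arguments shown in this diff; make_triton_kernel is a hypothetical helper name, and moe_config / quant_config are assumed to be built by the surrounding test utilities.

from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel


def make_triton_kernel(moe_config, quant_config, shared_experts=None):
    # The factory replaces the direct MoEPrepareAndFinalizeNoEP() construction
    # used before this refactor.
    prepare_finalize = maybe_make_prepare_finalize(
        moe=moe_config,
        quant_config=quant_config,
        allow_new_interface=True,
        use_monolithic=False,
    )
    return FusedMoEKernel(
        prepare_finalize,
        TritonExperts(moe_config, quant_config),
        shared_experts,
        inplace=False,
    )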