[MoE Refactor] Create MK for TRTLLM Kernels (#32564)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
Signed-off-by: Robert Shaw <robertgshaw2@gmail.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
This commit is contained in:
Robert Shaw
2026-03-03 13:39:50 -05:00
committed by GitHub
parent 881a6b011b
commit 97995f6376
77 changed files with 2575 additions and 2087 deletions

View File

@@ -4,6 +4,7 @@
import pytest
import torch
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts,
)
@@ -12,7 +13,7 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize,
BatchedTritonExperts,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
from vllm.utils.deep_gemm import calc_diff, is_deep_gemm_supported
from .test_deepgemm import make_block_quant_fp8_weights
@@ -74,19 +75,22 @@ def test_batched_deepgemm_vs_triton(
quant_config=quant_config,
moe_config=make_dummy_moe_config(),
)
mk_triton = FusedMoEModularKernel(
mk_triton = FusedMoEKernel(
prep_finalize,
triton_experts,
inplace=False,
)
out_triton = mk_triton(
out_triton = mk_triton.apply(
hidden_states=a,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=MoEActivation.SILU,
global_num_experts=E,
expert_map=None,
apply_router_weight_on_input=False,
)
# deepgemm
@@ -96,19 +100,22 @@ def test_batched_deepgemm_vs_triton(
quant_config=quant_config,
moe_config=make_dummy_moe_config(),
)
mk_deepgemm = FusedMoEModularKernel(
mk_deepgemm = FusedMoEKernel(
prep_finalize,
deepgemm_experts,
inplace=False,
)
out_deepgemm = mk_deepgemm(
out_deepgemm = mk_deepgemm.apply(
hidden_states=a,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=MoEActivation.SILU,
global_num_experts=E,
expert_map=None,
apply_router_weight_on_input=False,
)
diff = calc_diff(out_deepgemm, out_triton)