[MoE Refactor] Create MK for TRTLLM Kernels (#32564)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
Signed-off-by: Robert Shaw <robertgshaw2@gmail.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
This commit is contained in:
Robert Shaw
2026-03-03 13:39:50 -05:00
committed by GitHub
parent 881a6b011b
commit 97995f6376
77 changed files with 2575 additions and 2087 deletions

View File

@@ -22,7 +22,7 @@ from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
from vllm.utils.deep_gemm import (
get_mk_alignment_for_contiguous_layout,
is_deep_gemm_e8m0_used,
@@ -170,7 +170,7 @@ def make_ll_modular_kernel(
q_dtype: torch.dtype | None,
test_config: TestConfig,
quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
) -> FusedMoEKernel:
assert test_config.low_latency
assert test_config.use_fp8_dispatch is not None
@@ -195,7 +195,7 @@ def make_ll_modular_kernel(
quant_config=quant_config,
moe_config=make_dummy_moe_config(),
)
return FusedMoEModularKernel(
return FusedMoEKernel(
prepare_finalize=a2a,
fused_experts=fused_experts,
inplace=False,
@@ -210,7 +210,7 @@ def make_ht_modular_kernel(
q_dtype: torch.dtype | None,
test_config: TestConfig,
quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
) -> FusedMoEKernel:
assert not test_config.low_latency
assert test_config.use_fp8_dispatch is None
@@ -228,7 +228,7 @@ def make_ht_modular_kernel(
moe_config=make_dummy_moe_config(),
quant_config=quant_config,
)
return FusedMoEModularKernel(
return FusedMoEKernel(
prepare_finalize=a2a,
fused_experts=fused_experts,
inplace=False,
@@ -242,11 +242,11 @@ def make_modular_kernel(
num_local_experts: int,
test_tensors: TestTensors,
quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
) -> FusedMoEKernel:
q_dtype = torch.float8_e4m3fn
test_config = test_tensors.config
mk: FusedMoEModularKernel
mk: FusedMoEKernel
# Make modular kernel
if test_config.low_latency:
max_tokens_per_rank = max(64, next_power_of_2(test_tensors.rank_tokens.size(0)))
@@ -307,7 +307,7 @@ def deepep_deepgemm_moe_impl(
)
# Make modular kernel
mk: FusedMoEModularKernel = make_modular_kernel(
mk: FusedMoEKernel = make_modular_kernel(
pg=pg,
pgi=pgi,
dp_size=dp_size,
@@ -319,7 +319,7 @@ def deepep_deepgemm_moe_impl(
with with_dp_metadata(
M=test_tensors.rank_tokens.size(0), world_size=pgi.world_size
):
out = mk.forward(
out = mk.apply(
hidden_states=test_tensors.rank_tokens,
w1=w1,
w2=w2,