[MoE Refactor] Create MK for TRTLLM Kernels (#32564)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
Signed-off-by: Robert Shaw <robertgshaw2@gmail.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
Committed by Robert Shaw on 2026-03-03 13:39:50 -05:00 (via GitHub)
parent 881a6b011b
commit 97995f6376
77 changed files with 2575 additions and 2087 deletions


@@ -8,6 +8,9 @@ import torch
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.fused_moe.activation import MoEActivation
+from vllm.model_executor.layers.fused_moe.all2all_utils import (
+    maybe_make_prepare_finalize,
+)
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
     FusedMoEParallelConfig,
@@ -15,16 +18,14 @@ from vllm.model_executor.layers.fused_moe.config import (
     RoutingMethodType,
     fp8_w8a8_moe_quant_config,
 )
+from vllm.model_executor.layers.fused_moe.experts.trtllm_fp8_moe import (
+    TrtLlmFp8Experts,
+)
 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
     FlashInferExperts,
 )
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
-from vllm.model_executor.layers.fused_moe.prepare_finalize import (
-    MoEPrepareAndFinalizeNoEP,
-)
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
-    apply_fi_trtllm_fp8_per_tensor_moe,
-    register_scales_for_trtllm_fp8_per_tensor_moe,
     rotate_weights_for_fi_trtllm_fp8_per_tensor_moe,
     swap_w13_to_w31,
 )
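
Taken together, these two hunks swap the old direct-call helpers for modular-kernel (MK) wiring. For reference, a condensed view of the import surface the rewritten tests rely on after this patch; every path is copied verbatim from the hunks above:

```python
# Condensed post-patch imports (paths verbatim from the hunks above).
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,  # factory replacing direct MoEPrepareAndFinalizeNoEP()
)
from vllm.model_executor.layers.fused_moe.experts.trtllm_fp8_moe import (
    TrtLlmFp8Experts,  # new MK experts implementation for the TRTLLM FP8 kernels
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
    FlashInferExperts,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
    rotate_weights_for_fi_trtllm_fp8_per_tensor_moe,  # weight-layout prep, still used
    swap_w13_to_w31,
)
```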
@@ -115,6 +116,7 @@ class TestData:
         e: int,
         is_trtllm: bool,
         activation: MoEActivation = MoEActivation.SILU,
+        topk: int = 1,
     ) -> "TestData":
         is_gated = activation.is_gated
@@ -152,13 +154,6 @@ class TestData:
             rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
                 layer.w13_weight, layer.w2_weight, is_gated
             )
-            register_scales_for_trtllm_fp8_per_tensor_moe(
-                layer,
-                layer.w13_weight_scale,
-                layer.w13_input_scale,
-                layer.w2_weight_scale,
-                layer.w2_input_scale,
-            )
         layer.custom_routing_function = Llama4MoE.custom_routing_function
         layer.routing_method_type = RoutingMethodType.Llama4
         layer.renormalize = False
@@ -166,6 +161,21 @@ class TestData:
         layer.ep_rank = 0
         layer.local_num_experts = e
+        layer.moe = FusedMoEConfig(
+            num_experts=e,
+            experts_per_token=topk,
+            hidden_dim=k,
+            intermediate_size_per_partition=n,
+            num_local_experts=e,
+            num_logical_experts=e,
+            moe_parallel_config=layer.moe_parallel_config,
+            in_dtype=hidden_states.dtype,
+            is_act_and_mul=is_gated,
+            routing_method=layer.routing_method_type,
+            activation=activation,
+            device=w13_quantized.device,
+        )
         return TestData(
             hidden_states=hidden_states,
             w13_quantized=w13_quantized,
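
The test now hangs a fully populated FusedMoEConfig off the layer because the MK factory used below reads it. A minimal standalone sketch, assuming hypothetical shapes, dtype, and a FusedMoEParallelConfig built elsewhere; only the field names are taken from the hunk:

```python
import torch

from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    RoutingMethodType,
)

# Hypothetical test-style parameters: experts, top-k, hidden dim, intermediate dim.
e, topk, k, n = 8, 1, 1024, 2048

moe = FusedMoEConfig(
    num_experts=e,
    experts_per_token=topk,
    hidden_dim=k,
    intermediate_size_per_partition=n,
    num_local_experts=e,      # single-rank test: all experts live on this rank
    num_logical_experts=e,    # no redundant-expert mapping in the test
    moe_parallel_config=parallel_config,  # assumption: a FusedMoEParallelConfig from elsewhere
    in_dtype=torch.bfloat16,
    is_act_and_mul=True,      # gated (act-and-mul) expert MLP
    routing_method=RoutingMethodType.Llama4,
    activation=MoEActivation.SILU,
    device=torch.device("cuda"),
)
```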
@@ -230,16 +240,29 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
         quant_config=quant_config,
     )
-    flashinfer_output = apply_fi_trtllm_fp8_per_tensor_moe(
-        layer=td.layer,
+    kernel = mk.FusedMoEKernel(
+        maybe_make_prepare_finalize(
+            moe=td.layer.moe,
+            quant_config=quant_config,
+            allow_new_interface=True,
+            use_monolithic=True,
+        ),
+        TrtLlmFp8Experts(
+            moe_config=td.layer.moe,
+            quant_config=quant_config,
+        ),
+    )
+    flashinfer_output = kernel.apply_monolithic(
         hidden_states=td.hidden_states,
         w1=td.layer.w13_weight,
         w2=td.layer.w2_weight,
         router_logits=score,
+        routing_bias=None,
         activation=activation,
         global_num_experts=e,
         top_k=topk,
+        num_expert_group=None,
+        topk_group=None,
         expert_map=None,
         apply_router_weight_on_input=True,
+        routed_scaling_factor=1.0,
     )
     check_accuracy(
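
Reassembled from the interleaved context and additions, the new monolithic path reads as below. The old apply_fi_trtllm_fp8_per_tensor_moe helper disappears; use_monolithic=True asks the factory for a prepare/finalize that defers routing to the kernel itself, and apply_monolithic takes raw router logits plus the routing knobs the helper used to own. All names are verbatim from the hunk; td, score, quant_config, activation, e, and topk are the test's fixtures.

```python
kernel = mk.FusedMoEKernel(
    maybe_make_prepare_finalize(
        moe=td.layer.moe,
        quant_config=quant_config,
        allow_new_interface=True,
        use_monolithic=True,   # routing + experts fused into one kernel call
    ),
    TrtLlmFp8Experts(moe_config=td.layer.moe, quant_config=quant_config),
)
flashinfer_output = kernel.apply_monolithic(
    hidden_states=td.hidden_states,
    w1=td.layer.w13_weight,
    w2=td.layer.w2_weight,
    router_logits=score,       # raw logits: the kernel performs routing internally
    routing_bias=None,
    activation=activation,
    global_num_experts=e,
    top_k=topk,
    num_expert_group=None,
    topk_group=None,
    expert_map=None,
    apply_router_weight_on_input=True,
    routed_scaling_factor=1.0,
)
```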
@@ -329,8 +352,13 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
         routing_method=RoutingMethodType.TopK,
     )
-    kernel = mk.FusedMoEModularKernel(
-        MoEPrepareAndFinalizeNoEP(),
+    kernel = mk.FusedMoEKernel(
+        maybe_make_prepare_finalize(
+            moe=moe_config,
+            quant_config=quant_config,
+            allow_new_interface=True,
+            use_monolithic=False,
+        ),
         FlashInferExperts(
             moe_config=moe_config,
             quant_config=quant_config,
@@ -338,7 +366,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
         inplace=False,
     )
-    flashinfer_cutlass_output = kernel(
+    flashinfer_cutlass_output = kernel.apply(
         td.hidden_states,
         td.layer.w13_weight,
         td.layer.w2_weight,
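
For contrast, the modular CUTLASS path changes only its wiring: FusedMoEModularKernel with an explicit MoEPrepareAndFinalizeNoEP() becomes FusedMoEKernel fed by the same factory with use_monolithic=False, and the bare kernel(...) call becomes kernel.apply(...). A condensed before/after sketch; the trailing arguments are elided with `...` exactly as the hunk truncates them:

```python
# Before: hand-built no-EP prepare/finalize, kernel invoked via __call__.
kernel = mk.FusedMoEModularKernel(
    MoEPrepareAndFinalizeNoEP(),
    FlashInferExperts(moe_config=moe_config, quant_config=quant_config),
)
out = kernel(td.hidden_states, td.layer.w13_weight, td.layer.w2_weight, ...)

# After: the factory selects the prepare/finalize; apply() is the modular entry point.
kernel = mk.FusedMoEKernel(
    maybe_make_prepare_finalize(
        moe=moe_config,
        quant_config=quant_config,
        allow_new_interface=True,
        use_monolithic=False,  # modular path: prepare/finalize handles dispatch
    ),
    FlashInferExperts(moe_config=moe_config, quant_config=quant_config),
)
out = kernel.apply(td.hidden_states, td.layer.w13_weight, td.layer.w2_weight, ...)
```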