[MoE Refactor] Create MK for TRTLLM Kernels (#32564)
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
Signed-off-by: Robert Shaw <robertgshaw2@gmail.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
@@ -8,6 +8,9 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEParallelConfig,
@@ -15,16 +18,14 @@ from vllm.model_executor.layers.fused_moe.config import (
    RoutingMethodType,
    fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.experts.trtllm_fp8_moe import (
    TrtLlmFp8Experts,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
    FlashInferExperts,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
    apply_fi_trtllm_fp8_per_tensor_moe,
    register_scales_for_trtllm_fp8_per_tensor_moe,
    rotate_weights_for_fi_trtllm_fp8_per_tensor_moe,
    swap_w13_to_w31,
)
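
Taken together, these imports show the shape of the refactor: the tests now build an mk.FusedMoEKernel from a prepare/finalize stage plus an experts implementation, rather than calling kernel-specific helpers directly. A minimal consolidated sketch, using only the constructor arguments visible in the hunks below (moe_config and quant_config are assumed to be in scope):

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.experts.trtllm_fp8_moe import (
    TrtLlmFp8Experts,
)

def build_trtllm_kernel(moe_config, quant_config):
    # Pair a prepare/finalize stage with the TRTLLM fp8 experts
    # implementation. use_monolithic=True selects the single-call entry
    # point (kernel.apply_monolithic); False selects the modular flow.
    return mk.FusedMoEKernel(
        maybe_make_prepare_finalize(
            moe=moe_config,
            quant_config=quant_config,
            allow_new_interface=True,
            use_monolithic=True,
        ),
        TrtLlmFp8Experts(
            moe_config=moe_config,
            quant_config=quant_config,
        ),
    )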
@@ -115,6 +116,7 @@ class TestData:
        e: int,
        is_trtllm: bool,
        activation: MoEActivation = MoEActivation.SILU,
        topk: int = 1,
    ) -> "TestData":
        is_gated = activation.is_gated

@@ -152,13 +154,6 @@ class TestData:
            rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
                layer.w13_weight, layer.w2_weight, is_gated
            )
            register_scales_for_trtllm_fp8_per_tensor_moe(
                layer,
                layer.w13_weight_scale,
                layer.w13_input_scale,
                layer.w2_weight_scale,
                layer.w2_input_scale,
            )
        layer.custom_routing_function = Llama4MoE.custom_routing_function
        layer.routing_method_type = RoutingMethodType.Llama4
        layer.renormalize = False
@@ -166,6 +161,21 @@ class TestData:
        layer.ep_rank = 0
        layer.local_num_experts = e

        layer.moe = FusedMoEConfig(
            num_experts=e,
            experts_per_token=topk,
            hidden_dim=k,
            intermediate_size_per_partition=n,
            num_local_experts=e,
            num_logical_experts=e,
            moe_parallel_config=layer.moe_parallel_config,
            in_dtype=hidden_states.dtype,
            is_act_and_mul=is_gated,
            routing_method=layer.routing_method_type,
            activation=activation,
            device=w13_quantized.device,
        )

        return TestData(
            hidden_states=hidden_states,
            w13_quantized=w13_quantized,
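
The FusedMoEConfig built above carries the test layer's shape, dtype, and routing into the new kernel path. A standalone sketch with illustrative values (every literal here is invented for the example; only the field names come from the hunk, and parallel_config is assumed to be in scope):

import torch

from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    RoutingMethodType,
)

# 8 experts, top-2 routing, hidden dim 4096, intermediate size 1024 per
# partition; in the test these come from the e/topk/k/n parameters.
moe_config = FusedMoEConfig(
    num_experts=8,
    experts_per_token=2,
    hidden_dim=4096,
    intermediate_size_per_partition=1024,
    num_local_experts=8,
    num_logical_experts=8,
    moe_parallel_config=parallel_config,  # assumed in scope
    in_dtype=torch.bfloat16,
    is_act_and_mul=True,  # gated activation (silu-and-mul style)
    routing_method=RoutingMethodType.Llama4,
    activation=MoEActivation.SILU,
    device=torch.device("cuda"),
)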
@@ -230,16 +240,29 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
        quant_config=quant_config,
    )

    flashinfer_output = apply_fi_trtllm_fp8_per_tensor_moe(
        layer=td.layer,
    kernel = mk.FusedMoEKernel(
        maybe_make_prepare_finalize(
            moe=td.layer.moe,
            quant_config=quant_config,
            allow_new_interface=True,
            use_monolithic=True,
        ),
        TrtLlmFp8Experts(
            moe_config=td.layer.moe,
            quant_config=quant_config,
        ),
    )

    flashinfer_output = kernel.apply_monolithic(
        hidden_states=td.hidden_states,
        w1=td.layer.w13_weight,
        w2=td.layer.w2_weight,
        router_logits=score,
        routing_bias=None,
        activation=activation,
        global_num_experts=e,
        top_k=topk,
        num_expert_group=None,
        topk_group=None,
        expert_map=None,
        apply_router_weight_on_input=True,
        routed_scaling_factor=1.0,
    )

    check_accuracy(
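
The monolithic output then goes through check_accuracy against a reference. That helper's definition is not part of this diff; a plausible stand-in, assuming it is a simple elementwise comparison with tolerances loose enough for fp8 per-tensor quantization error:

import torch

def check_accuracy(ref: torch.Tensor, out: torch.Tensor,
                   atol: float = 2e-2, rtol: float = 2e-2) -> None:
    # Hypothetical stand-in for the test file's check_accuracy helper
    # (its real signature is not shown in this diff).
    torch.testing.assert_close(out, ref, atol=atol, rtol=rtol)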
@@ -329,8 +352,13 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
        routing_method=RoutingMethodType.TopK,
    )

    kernel = mk.FusedMoEModularKernel(
        MoEPrepareAndFinalizeNoEP(),
    kernel = mk.FusedMoEKernel(
        maybe_make_prepare_finalize(
            moe=moe_config,
            quant_config=quant_config,
            allow_new_interface=True,
            use_monolithic=False,
        ),
        FlashInferExperts(
            moe_config=moe_config,
            quant_config=quant_config,
@@ -338,7 +366,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
        inplace=False,
    )

    flashinfer_cutlass_output = kernel(
    flashinfer_cutlass_output = kernel.apply(
        td.hidden_states,
        td.layer.w13_weight,
        td.layer.w2_weight,
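
For contrast with the per-tensor test above: this CUTLASS test builds its prepare/finalize stage with use_monolithic=False and invokes the kernel through apply() with positional weights, while the monolithic path feeds raw router logits through apply_monolithic(). A side-by-side sketch, with argument names taken from the hunks above; the tensors (x, w13, w2, logits) and the kernel_mono/kernel_modular instances are assumed to be in scope:

# Monolithic path: routing is fused into the single kernel call.
out_mono = kernel_mono.apply_monolithic(
    hidden_states=x,
    w1=w13,
    w2=w2,
    router_logits=logits,
    routing_bias=None,
    activation=activation,  # a MoEActivation, assumed in scope
    global_num_experts=e,
    top_k=topk,
    num_expert_group=None,
    topk_group=None,
    expert_map=None,
    apply_router_weight_on_input=True,
    routed_scaling_factor=1.0,
)

# Modular path: weights passed positionally; routing is handled by the
# modular prepare/finalize flow.
out_modular = kernel_modular.apply(x, w13, w2)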