[MoE Refactor] Oracle Select FP8+NVFP4 Kernels In Priority (#32414)

Author: Robert Shaw
Date: 2026-01-21 08:22:33 -05:00
Committed by: GitHub
Parent: e14467be43
Commit: 42135d6898
82 changed files with 2710 additions and 1563 deletions


@@ -8,7 +8,10 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEParallelConfig,
    FusedMoEQuantConfig,
    RoutingMethodType,
    fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
@@ -116,18 +119,7 @@ class TestData:
layer.w13_weight_scale = w13_weight_scale
layer.w2_weight_scale = w2_weight_scale
# Setup dummy config.
layer.moe_parallel_config = mk.FusedMoEParallelConfig(
    tp_size=1,
    pcp_size=1,
    dp_size=1,
    ep_size=1,
    tp_rank=0,
    pcp_rank=0,
    dp_rank=0,
    ep_rank=0,
    use_ep=False,
    all2all_backend="naive",
)
layer.moe_parallel_config = mk.FusedMoEParallelConfig.make_no_parallel()
# flashinfer expects swapped rows for w13
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
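The hand-built parallel config removed above spells out exactly what the new helper has to provide. Below is a minimal sketch of what FusedMoEParallelConfig.make_no_parallel() presumably returns, reconstructed from those removed keyword arguments; the real classmethod lives in vllm/model_executor/layers/fused_moe/config.py and may differ in detail.

from dataclasses import dataclass

@dataclass
class FusedMoEParallelConfig:
    # Fields mirror the keyword arguments removed from the test above.
    tp_size: int
    pcp_size: int
    dp_size: int
    ep_size: int
    tp_rank: int
    pcp_rank: int
    dp_rank: int
    ep_rank: int
    use_ep: bool
    all2all_backend: str

    @classmethod
    def make_no_parallel(cls) -> "FusedMoEParallelConfig":
        # Single-rank, no-parallelism defaults, matching what the test
        # previously constructed by hand.
        return cls(
            tp_size=1,
            pcp_size=1,
            dp_size=1,
            ep_size=1,
            tp_rank=0,
            pcp_rank=0,
            dp_rank=0,
            ep_rank=0,
            use_ep=False,
            all2all_backend="naive",
        )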
@@ -238,6 +230,8 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
):
set_random_seed(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
assert activation in ["silu", "relu2_no_mul"]
is_act_and_mul = activation == "silu"
with set_current_vllm_config(vllm_config):
    td = TestData.make_moe_tensors_8bit(
        m, k, n, e, is_trtllm=False, activation=activation
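The new is_act_and_mul flag records whether the activation is gated (w13 produces gate and up halves that are multiplied) or acts on the whole projection. Below is a hedged illustration of that distinction for the two activations the test allows; apply_moe_activation is a made-up name for this example, not a vLLM API.

import torch
import torch.nn.functional as F

def apply_moe_activation(
    gate_up: torch.Tensor, activation: str, is_act_and_mul: bool
) -> torch.Tensor:
    # Illustration only: gated "silu" splits w13's output and multiplies the
    # halves; "relu2_no_mul" applies squared ReLU to the full projection.
    if is_act_and_mul:
        assert activation == "silu"
        gate, up = gate_up.chunk(2, dim=-1)
        return F.silu(gate) * up
    assert activation == "relu2_no_mul"
    return torch.square(F.relu(gate_up))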
@@ -285,19 +279,30 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
td.layer.get_fused_moe_quant_config = get_fused_moe_quant_config
td.layer.quant_method = td.layer
moe_config = FusedMoEConfig(
    num_experts=e,
    experts_per_token=topk,
    hidden_dim=k,
    intermediate_size_per_partition=n,
    num_local_experts=e,
    activation=activation,
    device="cuda",
    moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
    in_dtype=torch.bfloat16,
    is_act_and_mul=is_act_and_mul,
    routing_method=RoutingMethodType.TopK,
)
kernel = mk.FusedMoEModularKernel(
    MoEPrepareAndFinalizeNoEP(
        defer_input_quant=quant_config.is_block_quantized
        defer_input_quant=FlashInferExperts.expects_unquantized_inputs(
            moe_config=moe_config,
            quant_config=quant_config,
        )
    ),
    FlashInferExperts(
        out_dtype=td.layer.orig_dtype,
        moe_config=moe_config,
        quant_config=quant_config,
        ep_rank=td.layer.moe_parallel_config.ep_rank,
        ep_size=td.layer.moe_parallel_config.ep_size,
        tp_rank=td.layer.moe_parallel_config.tp_rank,
        tp_size=td.layer.moe_parallel_config.tp_size,
        use_dp=False,
        use_deepseek_fp8_block_scale=False,
    ),
)
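Before this change the test decided defer_input_quant itself via quant_config.is_block_quantized; the refactor moves that decision behind FlashInferExperts.expects_unquantized_inputs. Below is a minimal sketch of what such a hook could look like, assuming it still keys off block quantization; the real method in flashinfer_cutlass_moe.py may also account for the NVFP4 path and other cases.

from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEQuantConfig,
)

class FlashInferExperts:
    @classmethod
    def expects_unquantized_inputs(
        cls,
        moe_config: FusedMoEConfig,
        quant_config: FusedMoEQuantConfig,
    ) -> bool:
        # Block-quantized weights (e.g. FP8 block scales) are quantized inside
        # the fused kernel, so prepare/finalize should pass activations through
        # unquantized; this mirrors the expression the test used before.
        return quant_config.is_block_quantized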