[MoE Refactor] Oracle Select FP8+NVFP4 Kernels In Priority (#32414)

Robert Shaw
2026-01-21 08:22:33 -05:00
committed by GitHub
parent e14467be43
commit 42135d6898
82 changed files with 2710 additions and 1563 deletions
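
The test changes below illustrate the API shift in this refactor: the modular MoE kernels now take an explicit MoE config in addition to the quantization config, which the tests supply through a new make_dummy_moe_config() helper. A minimal sketch of the new call pattern, assembled from the hunks below (make_dummy_moe_config is a test-only helper introduced by this PR, and w1_s, w2_s, block_size stand in for the test fixtures built elsewhere in the file):

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.moe.utils import make_dummy_moe_config
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.fused_moe import modular_triton_fused_moe
from vllm.model_executor.layers.fused_moe.prepare_finalize import MoEPrepareAndFinalizeNoEP
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import TritonOrDeepGemmExperts

# Block-quantized FP8 config; w1_s, w2_s, block_size are placeholders for the
# weight scales and block shape produced by the test setup.
quant_config = fp8_w8a8_moe_quant_config(
    w1_scale=w1_s,
    w2_scale=w2_s,
    block_shape=block_size,
)

# Triton-only modular kernel: the MoE config is now the first argument.
m_fused_moe = modular_triton_fused_moe(make_dummy_moe_config(), quant_config)

# Triton/DeepGEMM experts wrapped in a modular kernel: the experts
# implementation now receives moe_config alongside quant_config.
deep_gemm_experts = mk.FusedMoEModularKernel(
    prepare_finalize=MoEPrepareAndFinalizeNoEP(),
    fused_experts=TritonOrDeepGemmExperts(
        moe_config=make_dummy_moe_config(),
        quant_config=quant_config,
    ),
)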


@@ -4,7 +4,12 @@
import pytest
import torch
from tests.kernels.moe.utils import make_test_quant_config, make_test_weights
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.moe.utils import (
    make_dummy_moe_config,
    make_test_quant_config,
    make_test_weights,
)
from tests.kernels.quant_utils import (
    native_per_token_group_quant_fp8,
    native_w8a8_block_matmul,
@@ -15,13 +20,21 @@ from vllm.model_executor.layers.fused_moe import (
    fused_experts,
    fused_topk,
)
from vllm.model_executor.layers.fused_moe.config import (
    fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
    _valid_deep_gemm_shape,
    deep_gemm_moe_fp8,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
    modular_triton_fused_moe,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
    TritonOrDeepGemmExperts,
)
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import (
    get_mk_alignment_for_contiguous_layout,
@@ -161,7 +174,7 @@ def test_w8a8_block_fp8_fused_moe(
        block_shape=block_size,
    )

    m_fused_moe = modular_triton_fused_moe(quant_config)
    m_fused_moe = modular_triton_fused_moe(make_dummy_moe_config(), quant_config)

    topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
@@ -236,6 +249,29 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch)
    topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)

    quant_config = fp8_w8a8_moe_quant_config(
        w1_scale=w1_s,
        w2_scale=w2_s,
        block_shape=block_size,
    )

    deep_gemm_experts = mk.FusedMoEModularKernel(
        prepare_finalize=MoEPrepareAndFinalizeNoEP(),
        fused_experts=TritonOrDeepGemmExperts(
            moe_config=make_dummy_moe_config(),
            quant_config=quant_config,
        ),
    )

    def deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids):
        return deep_gemm_experts(
            hidden_states=a,
            w1=w1,
            w2=w2,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
        )

    # Set the context to avoid lots of warning spam.
    with set_current_vllm_config(vllm_config):
        ref_out = torch_w8a8_block_fp8_moe(