[MoE Refactor] Oracle Select FP8+NVFP4 Kernels In Priority (#32414)

Author: Robert Shaw
Date: 2026-01-21 08:22:33 -05:00
Committed by: GitHub
Parent: e14467be43
Commit: 42135d6898
82 changed files with 2710 additions and 1563 deletions


@@ -11,10 +11,19 @@ import math
 import pytest
 import torch

-from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
+# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from tests.kernels.moe.utils import make_dummy_moe_config
+from vllm.model_executor.layers.fused_moe.config import (
+    fp8_w8a8_moe_quant_config,
+)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+    MoEPrepareAndFinalizeNoEP,
+)
+from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
+    TritonOrDeepGemmExperts,
+)
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8,
 )
@@ -100,6 +109,14 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
         block_shape=block_size,
     )

+    deep_gemm_experts = mk.FusedMoEModularKernel(
+        prepare_finalize=MoEPrepareAndFinalizeNoEP(),
+        fused_experts=TritonOrDeepGemmExperts(
+            moe_config=make_dummy_moe_config(),
+            quant_config=quant_config,
+        ),
+    )
+
     # triton reference
     out_triton = fused_experts(
         hidden_states=tokens_bf16,
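The refactor replaces the old `allow_deep_gemm=True` flag path through `fused_experts` with an explicitly composed `mk.FusedMoEModularKernel`, pairing a prepare/finalize stage (`MoEPrepareAndFinalizeNoEP`, i.e. no expert parallelism) with an experts backend (`TritonOrDeepGemmExperts`). A minimal sketch of that composition pattern, using toy stand-in classes rather than the real vLLM API (all names below are simplified illustrations):

# Illustrative sketch only: toy stand-ins for the two-stage composition used
# by FusedMoEModularKernel. Not the vLLM API; names and shapes simplified.
import torch


class NoEPPrepareFinalize:
    """Stand-in for MoEPrepareAndFinalizeNoEP: no expert-parallel shuffling."""

    def prepare(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states  # single-rank case: tokens pass through unchanged

    def finalize(self, expert_out: torch.Tensor) -> torch.Tensor:
        return expert_out


class ToyExperts:
    """Stand-in for TritonOrDeepGemmExperts: one GEMM per routed expert."""

    def __init__(self, w: torch.Tensor):
        self.w = w  # (num_experts, k, n)

    def apply(self, x: torch.Tensor, topk_ids: torch.Tensor) -> torch.Tensor:
        # Route each token to its top-1 expert and apply that expert's weight.
        return torch.einsum("mk,mkn->mn", x, self.w[topk_ids[:, 0]])


class ToyModularKernel:
    """prepare -> expert compute -> finalize, mirroring the modular design."""

    def __init__(self, prepare_finalize, fused_experts):
        self.pf = prepare_finalize
        self.experts = fused_experts

    def __call__(self, hidden_states, topk_ids):
        x = self.pf.prepare(hidden_states)
        y = self.experts.apply(x, topk_ids)
        return self.pf.finalize(y)


if __name__ == "__main__":
    m, k, n, e = 4, 8, 16, 3
    kernel = ToyModularKernel(NoEPPrepareFinalize(), ToyExperts(torch.randn(e, k, n)))
    out = kernel(torch.randn(m, k), torch.randint(0, e, (m, 1)))
    print(out.shape)  # torch.Size([4, 16])

The point of the split is that routing/communication (prepare/finalize) and the per-expert GEMMs (experts backend) vary independently, so the test selects the DeepGEMM path by construction instead of by flag.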
@@ -109,19 +126,16 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
         topk_ids=topk_ids,
         inplace=False,
         quant_config=quant_config,
-        allow_deep_gemm=False,
     )

     # DeepGemm
-    out_deepgemm = fused_experts(
+    out_deepgemm = deep_gemm_experts(
         hidden_states=tokens_bf16,
         w1=w1,
         w2=w2,
         topk_weights=topk_weights,
         topk_ids=topk_ids,
         inplace=False,
-        quant_config=quant_config,
-        allow_deep_gemm=True,
     )
     diff = calc_diff(out_deepgemm, out_triton)
     assert diff < 0.001, f"Diff exceeded 0.1%: {diff}"
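`calc_diff` is a helper defined elsewhere in this test file. A plausible implementation, following the similarity-based metric used in DeepGEMM's own test suite (an assumption, not taken from this diff), returns 0.0 for identical tensors and grows toward 1.0 as outputs diverge:

# Assumed sketch of calc_diff; the real helper lives outside this hunk.
import torch


def calc_diff(x: torch.Tensor, y: torch.Tensor) -> float:
    # 1 minus a cosine-style similarity, computed in float64 for stability.
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return 1.0 - sim.item()

Under this metric the 0.001 threshold corresponds to a 0.1% relative disagreement between the DeepGEMM and Triton outputs.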
@@ -147,20 +161,19 @@ def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch, workspace_i
     with monkeypatch.context() as mp:
         mp.setenv("VLLM_USE_DEEP_GEMM", "1")

-        _fused_moe_mod = importlib.import_module(
-            "vllm.model_executor.layers.fused_moe.fused_moe"
-        )
+        _DeepGemmExperts = importlib.import_module(
+            "vllm.model_executor.layers.fused_moe.deep_gemm_moe"
+        ).DeepGemmExperts

         call_counter = {"cnt": 0}

-        orig_fn = _fused_moe_mod.deep_gemm_moe_fp8
+        orig_fn = _DeepGemmExperts.apply

-        def _spy_deep_gemm_moe_fp8(*args, **kwargs):
+        def _spy_apply(*args, **kwargs):
             call_counter["cnt"] += 1
             return orig_fn(*args, **kwargs)

-        monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8", _spy_deep_gemm_moe_fp8)
+        monkeypatch.setattr(_DeepGemmExperts, "apply", _spy_apply)

         if topk > num_experts:
             pytest.skip(f"topk={topk} > num_experts={num_experts}")
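The spy swaps `DeepGemmExperts.apply` for a wrapper that counts invocations and then delegates to the original, so the test can assert the DeepGEMM path actually ran without changing its behavior. A self-contained sketch of the same pattern on a toy class (toy names; only pytest's `monkeypatch.setattr` is the real API):

# Minimal spy-pattern illustration, mirroring the test above on a toy class.
import pytest


class Backend:
    def apply(self, x: int) -> int:
        return x * 2


def test_apply_is_called(monkeypatch: pytest.MonkeyPatch) -> None:
    call_counter = {"cnt": 0}
    orig_fn = Backend.apply  # unbound function; `self` arrives via *args

    def _spy_apply(*args, **kwargs):
        call_counter["cnt"] += 1
        return orig_fn(*args, **kwargs)  # delegate so behavior is unchanged

    monkeypatch.setattr(Backend, "apply", _spy_apply)
    assert Backend().apply(3) == 6
    assert call_counter["cnt"] == 1  # proves the patched path executed

Patching the class attribute (rather than a module-level function, as the old `deep_gemm_moe_fp8` spy did) intercepts every instance's call, which is why the refactored test can still count DeepGEMM dispatches after the modular-kernel change.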