[MoE Refactor] Oracle Select FP8+NVFP4 Kernels In Priority (#32414)
@@ -11,10 +11,19 @@ import math
 import pytest
 import torch
 
-from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
-
 # vLLM fused-expert reference (Triton fallback + DeepGEMM option)
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from tests.kernels.moe.utils import make_dummy_moe_config
+from vllm.model_executor.layers.fused_moe.config import (
+    fp8_w8a8_moe_quant_config,
+)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+    MoEPrepareAndFinalizeNoEP,
+)
+from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
+    TritonOrDeepGemmExperts,
+)
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8,
 )
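Note: the fp8_w8a8_moe_quant_config and per_token_group_quant_fp8 imports feed the quant config built later in run_single_case. A minimal sketch of that kind of setup, assuming block-quantized FP8 weights with per-token-group activation quantization (the tensor shapes, scale layouts, and exact factory kwargs here are assumptions, not taken from this diff):

```python
import torch

from vllm.model_executor.layers.fused_moe.config import (
    fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8,
)

# Hypothetical shapes: E experts, hidden size k, intermediate size n,
# 128x128 weight blocks; the scale shapes below are illustrative only.
E, k, n, block = 8, 512, 256, 128
w1_scale = torch.rand(E, 2 * n // block, k // block, device="cuda")
w2_scale = torch.rand(E, k // block, n // block, device="cuda")

# Quantize bf16 activations to FP8 in groups of `block` along the hidden dim.
a = torch.randn(64, k, dtype=torch.bfloat16, device="cuda")
a_fp8, a_scales = per_token_group_quant_fp8(a, block)

quant_config = fp8_w8a8_moe_quant_config(
    w1_scale=w1_scale,
    w2_scale=w2_scale,
    block_shape=[block, block],
)
```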
@@ -100,6 +109,14 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
         block_shape=block_size,
     )
 
+    deep_gemm_experts = mk.FusedMoEModularKernel(
+        prepare_finalize=MoEPrepareAndFinalizeNoEP(),
+        fused_experts=TritonOrDeepGemmExperts(
+            moe_config=make_dummy_moe_config(),
+            quant_config=quant_config,
+        ),
+    )
+
     # triton reference
     out_triton = fused_experts(
         hidden_states=tokens_bf16,
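The FusedMoEModularKernel built above replaces the old per-call allow_deep_gemm flag: per the commit title, TritonOrDeepGemmExperts now makes an oracle-style, priority-ordered backend choice up front from the MoE and quant configs. A conceptual sketch of that dispatch pattern (every name in this snippet is illustrative, not the vLLM API):

```python
from dataclasses import dataclass


@dataclass
class QuantCfg:
    # Illustrative stand-in for the real quant config.
    block_shape: tuple | None = None


def deep_gemm_supported(cfg: QuantCfg) -> bool:
    # Stand-in oracle: prefer DeepGEMM whenever block-quantized FP8 applies.
    return cfg.block_shape is not None


class TritonOrDeepGemmSketch:
    def __init__(self, cfg: QuantCfg):
        # Backend selected once, by priority, at construction time;
        # callers no longer pass allow_deep_gemm on every invocation.
        self.backend = "deep_gemm" if deep_gemm_supported(cfg) else "triton"

    def apply(self, x: str) -> str:
        return f"{self.backend}({x})"


print(TritonOrDeepGemmSketch(QuantCfg(block_shape=(128, 128))).apply("tokens"))
# -> deep_gemm(tokens)
```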
@@ -109,19 +126,16 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
         topk_ids=topk_ids,
         inplace=False,
         quant_config=quant_config,
         allow_deep_gemm=False,
     )
 
     # DeepGemm
-    out_deepgemm = fused_experts(
+    out_deepgemm = deep_gemm_experts(
         hidden_states=tokens_bf16,
         w1=w1,
         w2=w2,
         topk_weights=topk_weights,
         topk_ids=topk_ids,
-        inplace=False,
-        quant_config=quant_config,
-        allow_deep_gemm=True,
     )
     diff = calc_diff(out_deepgemm, out_triton)
     assert diff < 0.001, f"Diff exceeded threshold 0.001: {diff}"
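calc_diff itself is not part of this diff. In DeepGEMM-style kernel tests it is conventionally a similarity-based relative error rather than a max-abs error; a sketch under that assumption (the repo's actual helper may differ):

```python
import torch


def calc_diff(x: torch.Tensor, y: torch.Tensor) -> float:
    # 1 - 2<x, y> / (|x|^2 + |y|^2): 0 for identical tensors, grows with error.
    x, y = x.double(), y.double()
    denom = (x * x + y * y).sum()
    return float(1 - 2 * (x * y).sum() / denom)


# Identical inputs give ~0, so an assert like `diff < 0.001` passes trivially.
t = torch.randn(4, 8)
assert calc_diff(t, t.clone()) < 1e-9
```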
@@ -147,20 +161,19 @@ def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch, workspace_i
     with monkeypatch.context() as mp:
         mp.setenv("VLLM_USE_DEEP_GEMM", "1")
 
-        _fused_moe_mod = importlib.import_module(
-            "vllm.model_executor.layers.fused_moe.fused_moe"
-        )
+        _DeepGemmExperts = importlib.import_module(
+            "vllm.model_executor.layers.fused_moe.deep_gemm_moe"
+        ).DeepGemmExperts
 
         call_counter = {"cnt": 0}
 
-        orig_fn = _fused_moe_mod.deep_gemm_moe_fp8
+        orig_fn = _DeepGemmExperts.apply
 
-        def _spy_deep_gemm_moe_fp8(*args, **kwargs):
+        def _spy_apply(*args, **kwargs):
             call_counter["cnt"] += 1
             return orig_fn(*args, **kwargs)
 
-        monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8", _spy_deep_gemm_moe_fp8)
-
+        monkeypatch.setattr(_DeepGemmExperts, "apply", _spy_apply)
         if topk > num_experts:
             pytest.skip(f"topk={topk} > num_experts={num_experts}")
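The hunk ends before the spy is consulted, but the pattern implies the test later asserts call_counter["cnt"] > 0 after running a case, i.e. that the DeepGemmExperts.apply path really executed. A self-contained demo of the same spy pattern (plain Python standing in for monkeypatch.setattr):

```python
class Experts:
    def apply(self, x: int) -> int:
        return 2 * x


call_counter = {"cnt": 0}
orig_fn = Experts.apply


def _spy_apply(*args, **kwargs):
    # Count the call, then defer to the original implementation.
    call_counter["cnt"] += 1
    return orig_fn(*args, **kwargs)


# monkeypatch.setattr(Experts, "apply", _spy_apply) does this with auto-undo.
Experts.apply = _spy_apply

assert Experts().apply(3) == 6
assert call_counter["cnt"] > 0, "apply was never called"
```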