[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
@@ -11,6 +11,8 @@ import math
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
fp8_w8a8_moe_quant_config)
|
||||
# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
@@ -94,6 +96,13 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
|
||||
topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
|
||||
topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)
|
||||
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
a1_scale=a1_scale,
|
||||
block_shape=block_size,
|
||||
)
|
||||
|
||||
# triton reference
|
||||
out_triton = fused_experts(
|
||||
hidden_states=tokens_bf16,
|
||||
@@ -102,11 +111,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
a1_scale=a1_scale,
|
||||
block_shape=block_size,
|
||||
quant_config=quant_config,
|
||||
allow_deep_gemm=False,
|
||||
)
|
||||
|
||||
@@ -118,19 +123,14 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
a1_scale=a1_scale,
|
||||
block_shape=block_size,
|
||||
quant_config=quant_config,
|
||||
allow_deep_gemm=True,
|
||||
)
|
||||
diff = calc_diff(out_deepgemm, out_triton)
|
||||
assert diff < 0.001, f"Diff exceeded 1%: {diff}"
|
||||
|
||||
|
||||
# Note: W1 has shape (E, 2N, K), so N = 512
|
||||
# can trigger the deepgemm path.
|
||||
# Note: N <= 512 will disable the deepgemm path due to performance issues.
|
||||
MNKs = [
|
||||
(1024, 768, 128),
|
||||
(1024, 768, 512),
|
||||
@@ -144,15 +144,15 @@ TOPKS = [2, 6]
|
||||
NUM_EXPERTS = [32]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mnk", MNKs)
|
||||
@pytest.mark.parametrize(("m", "n", "k"), MNKs)
|
||||
@pytest.mark.parametrize("topk", TOPKS)
|
||||
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
|
||||
@pytest.mark.skipif(not is_deep_gemm_supported(),
|
||||
reason="Requires deep_gemm kernels")
|
||||
def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch):
|
||||
def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch):
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_DEEP_GEMM", "1")
|
||||
with monkeypatch.context() as mp:
|
||||
mp.setenv("VLLM_USE_DEEP_GEMM", "1")
|
||||
|
||||
_fused_moe_mod = importlib.import_module(
|
||||
"vllm.model_executor.layers.fused_moe.fused_moe")
|
||||
@@ -168,8 +168,6 @@ def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch):
|
||||
monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8",
|
||||
_spy_deep_gemm_moe_fp8)
|
||||
|
||||
m, n, k = mnk
|
||||
|
||||
if topk > num_experts:
|
||||
pytest.skip(f"topk={topk} > num_experts={num_experts}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user