[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm
2025-09-17 19:43:31 -04:00
committed by GitHub
parent e6585ddb45
commit 5963b98b46
68 changed files with 2698 additions and 2526 deletions

View File

@@ -4,7 +4,7 @@
import pytest
import torch
from tests.kernels.moe.utils import make_test_weights
from tests.kernels.moe.utils import make_test_quant_config, make_test_weights
from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
native_w8a8_block_matmul)
from vllm.config import VllmConfig, set_current_vllm_config
@@ -161,22 +161,17 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
a = torch.randn((M, K), dtype=dtype) / 10
score = torch.randn((M, E), dtype=dtype)
(_, w1, w1_s, _), (_, w2, w2_s,
_) = make_test_weights(E,
N,
K,
dtype,
torch.float8_e4m3fn,
per_act_token_quant=False,
block_shape=block_size)
w1, w2, quant_config = make_test_quant_config(
E,
N,
K,
dtype,
quant_dtype=torch.float8_e4m3fn,
per_act_token_quant=False,
block_shape=block_size,
)
m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
use_mxfp4_w4a4=False,
per_act_token_quant=False,
block_shape=block_size)
m_fused_moe = modular_triton_fused_moe(quant_config)
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
@@ -186,37 +181,24 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
a,
w1,
w2,
w1_s,
w2_s,
quant_config.w1_scale,
quant_config.w2_scale,
topk_weights,
topk_ids,
block_size,
)
out = fused_experts(
a,
w1,
w2,
topk_weights,
topk_ids,
use_fp8_w8a8=True,
w1_scale=w1_s,
w2_scale=w2_s,
block_shape=block_size,
)
out = fused_experts(a,
w1,
w2,
topk_weights,
topk_ids,
quant_config=quant_config)
m_out = m_fused_moe(
a,
w1,
w2,
topk_weights,
topk_ids,
w1_scale=w1_s,
w2_scale=w2_s,
)
m_out = m_fused_moe(a, w1, w2, topk_weights, topk_ids)
# 0.039 only needed for [40000-4608-7168-2-1-block_size852-dtype852-0]
tol = 0.035 if M < 40000 else 0.039
# 0.039 only needed for M >= 8192
tol = 0.035 if M < 8192 else 0.039
torch.testing.assert_close(out, ref_out, atol=tol, rtol=tol)
torch.testing.assert_close(m_out, ref_out, atol=tol, rtol=tol)
@@ -248,14 +230,15 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
a = torch.randn((M, K), dtype=dtype) / 10
score = torch.randn((M, E), dtype=dtype)
(_, w1, w1_s, _), (_, w2, w2_s,
_) = make_test_weights(E,
N,
K,
dtype,
torch.float8_e4m3fn,
per_act_token_quant=False,
block_shape=block_size)
(_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
E,
N,
K,
dtype,
torch.float8_e4m3fn,
per_out_ch_quant=False,
block_shape=block_size,
)
# Note: for now use_compile will error out if the problem size is
# large enough to trigger chunking. I'm leaving the flag and