[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import make_test_weights
|
||||
from tests.kernels.moe.utils import make_test_quant_config, make_test_weights
|
||||
from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
|
||||
native_w8a8_block_matmul)
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
@@ -161,22 +161,17 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
|
||||
a = torch.randn((M, K), dtype=dtype) / 10
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
|
||||
(_, w1, w1_s, _), (_, w2, w2_s,
|
||||
_) = make_test_weights(E,
|
||||
N,
|
||||
K,
|
||||
dtype,
|
||||
torch.float8_e4m3fn,
|
||||
per_act_token_quant=False,
|
||||
block_shape=block_size)
|
||||
w1, w2, quant_config = make_test_quant_config(
|
||||
E,
|
||||
N,
|
||||
K,
|
||||
dtype,
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
per_act_token_quant=False,
|
||||
block_shape=block_size,
|
||||
)
|
||||
|
||||
m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True,
|
||||
use_int8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
use_mxfp4_w4a4=False,
|
||||
per_act_token_quant=False,
|
||||
block_shape=block_size)
|
||||
m_fused_moe = modular_triton_fused_moe(quant_config)
|
||||
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
|
||||
|
||||
@@ -186,37 +181,24 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
w1_s,
|
||||
w2_s,
|
||||
quant_config.w1_scale,
|
||||
quant_config.w2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
block_size,
|
||||
)
|
||||
|
||||
out = fused_experts(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
block_shape=block_size,
|
||||
)
|
||||
out = fused_experts(a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_config=quant_config)
|
||||
|
||||
m_out = m_fused_moe(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
)
|
||||
m_out = m_fused_moe(a, w1, w2, topk_weights, topk_ids)
|
||||
|
||||
# 0.039 only needed for [40000-4608-7168-2-1-block_size852-dtype852-0]
|
||||
tol = 0.035 if M < 40000 else 0.039
|
||||
# 0.039 only needed for M >= 8192
|
||||
tol = 0.035 if M < 8192 else 0.039
|
||||
torch.testing.assert_close(out, ref_out, atol=tol, rtol=tol)
|
||||
torch.testing.assert_close(m_out, ref_out, atol=tol, rtol=tol)
|
||||
|
||||
@@ -248,14 +230,15 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
|
||||
a = torch.randn((M, K), dtype=dtype) / 10
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
|
||||
(_, w1, w1_s, _), (_, w2, w2_s,
|
||||
_) = make_test_weights(E,
|
||||
N,
|
||||
K,
|
||||
dtype,
|
||||
torch.float8_e4m3fn,
|
||||
per_act_token_quant=False,
|
||||
block_shape=block_size)
|
||||
(_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
|
||||
E,
|
||||
N,
|
||||
K,
|
||||
dtype,
|
||||
torch.float8_e4m3fn,
|
||||
per_out_ch_quant=False,
|
||||
block_shape=block_size,
|
||||
)
|
||||
|
||||
# Note: for now use_compile will error out if the problem size is
|
||||
# large enough to trigger chunking. I'm leaving the flag and
|
||||
|
||||
Reference in New Issue
Block a user