[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import make_test_weights
|
||||
from tests.kernels.moe.utils import make_test_quant_config
|
||||
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
|
||||
FLOAT8_E4M3_MAX,
|
||||
dequantize_nvfp4_to_dtype)
|
||||
@@ -41,7 +41,6 @@ MNK_FACTORS = [
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("e", [40, 64, 256])
|
||||
#@pytest.mark.parametrize("e", [128, 256])
|
||||
@pytest.mark.parametrize("topk", [1, 6, 8])
|
||||
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
|
||||
@torch.inference_mode()
|
||||
@@ -56,16 +55,15 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
|
||||
|
||||
quant_blocksize = 16
|
||||
|
||||
(_, w1_q, w1_blockscale,
|
||||
w1_gs), (_, w2_q, w2_blockscale, w2_gs) = make_test_weights(
|
||||
e,
|
||||
n,
|
||||
k,
|
||||
in_dtype=dtype,
|
||||
quant_dtype="nvfp4",
|
||||
block_shape=None, # use quant_blocksize?
|
||||
per_act_token_quant=False,
|
||||
)
|
||||
w1_q, w2_q, quant_config = make_test_quant_config(
|
||||
e,
|
||||
n,
|
||||
k,
|
||||
in_dtype=dtype,
|
||||
quant_dtype="nvfp4",
|
||||
block_shape=None,
|
||||
per_act_token_quant=False,
|
||||
)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
topk_weights, topk_ids, _ = fused_topk(a,
|
||||
@@ -73,35 +71,17 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
|
||||
topk,
|
||||
renormalize=False)
|
||||
|
||||
a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
|
||||
a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
|
||||
|
||||
assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q)
|
||||
|
||||
assert w1_gs is not None
|
||||
assert w2_gs is not None
|
||||
assert w1_blockscale is not None
|
||||
assert w2_blockscale is not None
|
||||
|
||||
flashinfer_experts = FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
FlashInferExperts(
|
||||
a1_gscale=a1_gs,
|
||||
g1_alphas=(1 / w1_gs),
|
||||
a2_gscale=a2_gs,
|
||||
g2_alphas=(1 / w2_gs),
|
||||
out_dtype=dtype,
|
||||
quant_dtype="nvfp4",
|
||||
))
|
||||
FlashInferExperts(out_dtype=dtype, quant_config=quant_config),
|
||||
)
|
||||
|
||||
flashinfer_output = flashinfer_experts(
|
||||
hidden_states=a,
|
||||
w1=w1_q,
|
||||
w1_scale=w1_blockscale,
|
||||
w2=w2_q,
|
||||
w2_scale=w2_blockscale,
|
||||
a1_scale=a1_gs,
|
||||
a2_scale=a2_gs,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
)
|
||||
@@ -122,18 +102,18 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
|
||||
w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
|
||||
|
||||
for idx in range(0, e):
|
||||
w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
|
||||
w1_blockscale[idx],
|
||||
w1_gs[idx],
|
||||
dtype=dtype,
|
||||
device=w1_q.device,
|
||||
block_size=quant_blocksize)
|
||||
w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
|
||||
w2_blockscale[idx],
|
||||
w2_gs[idx],
|
||||
dtype=dtype,
|
||||
device=w2_q.device,
|
||||
block_size=quant_blocksize)
|
||||
w1_d[idx] = dequantize_nvfp4_to_dtype(
|
||||
w1_q[idx],
|
||||
quant_config.w1_scale[idx], (1 / quant_config.g1_alphas[idx]),
|
||||
dtype=dtype,
|
||||
device=w1_q.device,
|
||||
block_size=quant_blocksize)
|
||||
w2_d[idx] = dequantize_nvfp4_to_dtype(
|
||||
w2_q[idx],
|
||||
quant_config.w2_scale[idx], (1 / quant_config.g2_alphas[idx]),
|
||||
dtype=dtype,
|
||||
device=w2_q.device,
|
||||
block_size=quant_blocksize)
|
||||
|
||||
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user