[Kernels] Clean up FusedMoeMethodBase and modular kernel setup. Remove extra arguments from modular kernel methods. (#22035)
Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
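The test diff below shows one visible effect of the cleanup: make_test_weights now returns one tuple per expert weight rather than a single flat tuple. A minimal sketch of that calling convention follows; the function body, shapes, and tuple contents are illustrative assumptions, not vLLM's actual implementation:

import torch

def make_test_weights_sketch(E, N, K, dtype):
    # Stand-in only: returns (w1 group, w2 group), each group laid out as
    # (raw, weight, scale, extra) so callers can unpack it exactly as the
    # diff does: (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(...)
    w1 = torch.randn((E, 2 * N, K), dtype=dtype)
    w2 = torch.randn((E, K, N), dtype=dtype)
    w1_s = torch.ones(E, dtype=torch.float32)
    w2_s = torch.ones(E, dtype=torch.float32)
    return (None, w1, w1_s, None), (None, w2, w2_s, None)

(_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights_sketch(
    8, 128, 256, torch.bfloat16)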
@@ -161,18 +161,20 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
     a = torch.randn((M, K), dtype=dtype) / 10
     score = torch.randn((M, E), dtype=dtype)
 
-    _, w1, w1_s, _, w2, w2_s = make_test_weights(E,
-                                                 N,
-                                                 K,
-                                                 dtype,
-                                                 torch.float8_e4m3fn,
-                                                 per_act_token_quant=False,
-                                                 block_shape=block_size)
+    (_, w1, w1_s, _), (_, w2, w2_s,
+                       _) = make_test_weights(E,
+                                              N,
+                                              K,
+                                              dtype,
+                                              torch.float8_e4m3fn,
+                                              per_act_token_quant=False,
+                                              block_shape=block_size)
 
     m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True,
                                            use_int8_w8a8=False,
                                            use_int8_w8a16=False,
                                            use_int4_w4a16=False,
                                            use_mxfp4_w4a4=False,
                                            per_act_token_quant=False,
                                            block_shape=block_size)
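The modular_triton_fused_moe setup above selects exactly one quantization mode via boolean flags. A hedged sketch of the kind of mutual-exclusion check such a constructor typically performs; the helper name and message are hypothetical, not vLLM's code:

def check_quant_flags(**flags):
    # At most one use_* quantization flag may be enabled at a time.
    enabled = [name for name, on in flags.items() if on]
    if len(enabled) > 1:
        raise ValueError(f"conflicting quantization modes: {enabled}")

check_quant_flags(use_fp8_w8a8=True, use_int8_w8a8=False,
                  use_int8_w8a16=False, use_int4_w4a16=False,
                  use_mxfp4_w4a4=False)  # ok: exactly one mode enabled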
@@ -247,13 +249,14 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
     a = torch.randn((M, K), dtype=dtype) / 10
     score = torch.randn((M, E), dtype=dtype)
 
-    _, w1, w1_s, _, w2, w2_s = make_test_weights(E,
-                                                 N,
-                                                 K,
-                                                 dtype,
-                                                 torch.float8_e4m3fn,
-                                                 per_act_token_quant=False,
-                                                 block_shape=block_size)
+    (_, w1, w1_s, _), (_, w2, w2_s,
+                       _) = make_test_weights(E,
+                                              N,
+                                              K,
+                                              dtype,
+                                              torch.float8_e4m3fn,
+                                              per_act_token_quant=False,
+                                              block_shape=block_size)
 
     # Note: for now use_compile will error out if the problem size is
     # large enough to trigger chunking. I'm leaving the flag and
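The truncated note above refers to chunking: when the token count M is large enough, the fused-MoE path processes the input in fixed-size row slices, which currently conflicts with use_compile. An illustrative sketch of such chunked dispatch; the function, the stand-in kernel, and the chunk size are assumptions, not vLLM internals:

import torch

def run_chunked(a, kernel, chunk_size=1024):
    # Apply `kernel` to successive row slices of `a`, then stitch the
    # partial outputs back together along the token dimension.
    outs = [kernel(a[i:i + chunk_size])
            for i in range(0, a.shape[0], chunk_size)]
    return torch.cat(outs, dim=0)

out = run_chunked(torch.randn(4096, 512), kernel=lambda x: x * 2.0)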