kernels/moe test pruning (#27053)

Signed-off-by: Fardin Hoque <kfhfar@amazon.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
Fardin Hoque
2025-10-29 21:10:34 -07:00
committed by GitHub
parent 17d055f527
commit b8c48c5d72
13 changed files with 34 additions and 56 deletions

View File

@@ -66,8 +66,6 @@ FUSED_MOE_MNK_FACTORS = [
(1, 128, 128),
(1, 2048, 128),
(33, 2048, 128),
(222, 1024, 1024),
(32768, 128, 128),
(32768, 2048, 511),
(40000, 1024, 1024),
]
@@ -76,7 +74,6 @@ FUSED_MOE_WN16_MNK_FACTORS = [
(1, 128, 128),
(1, 1024, 1024),
(32, 2048, 128),
(32, 1024, 1024),
(222, 2048, 1024),
]
@@ -512,8 +509,8 @@ def marlin_moe_generate_valid_test_cases():
e_list = [4, 12]
topk_list = [2, 3]
ep_size_list = [1, 4]
dtype_list = [torch.half, torch.bfloat16]
group_size_list = [-1, 16, 32, 128]
dtype_list = [torch.bfloat16]
group_size_list = [-1, 32, 128]
act_order_list = [True, False]
quant_type_list = [
scalar_types.float4_e2m1f,
@@ -885,10 +882,10 @@ def test_batched_moe_align_block_size_opcheck():
)
@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("m", [1, 33, 222])
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
input = torch.randn((m, topk, k), device="cuda", dtype=dtype)