[CI Perf] Prune tests in tests/kernels/moe/ (#22939)

Signed-off-by: mgoin <mgoin64@gmail.com>
Author: Michael Goin
Committed by: GitHub
Date: 2025-08-14 23:33:42 -04:00
Parent: 590bddbfc5
Commit: d2b0e97ea6

6 changed files with 46 additions and 31 deletions


@@ -42,6 +42,24 @@ NUM_EXPERTS = [8, 64, 192]
 EP_SIZE = [1, 4]
 TOP_KS = [2, 6]
 
+FUSED_MOE_MNK_FACTORS = [
+    (1, 128, 128),
+    (1, 2048, 128),
+    (33, 2048, 128),
+    (222, 1024, 1024),
+    (32768, 128, 128),
+    (32768, 2048, 511),
+    (40000, 1024, 1024),
+]
+
+FUSED_MOE_WN16_MNK_FACTORS = [
+    (1, 128, 128),
+    (1, 1024, 1024),
+    (32, 2048, 128),
+    (32, 1024, 1024),
+    (222, 2048, 1024),
+]
+
 vllm_config = VllmConfig()
 vllm_config.scheduler_config.max_num_seqs = 128
 vllm_config.scheduler_config.max_model_len = 8192
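
The curated FUSED_MOE_MNK_FACTORS and FUSED_MOE_WN16_MNK_FACTORS tuples above replace the full cross-product of separate m, n, and k parametrize decorators (see the hunks below). A rough count of the reduction for test_fused_moe, using only values visible in this diff (a standalone sketch, not part of the commit):

from itertools import product

# Shape/dtype combinations generated per setting of the remaining axes
# (e, topk, ep_size, padding, chunk_size), using values from this diff.
before = len(list(product(
    [1, 33, 64, 222, 32768, 40000],   # old m values
    [128, 1024, 2048],                # old n values
    [128, 511, 1024],                 # old k values
    ["float16", "bfloat16"],          # old dtype values
)))
after = 7 * 1  # seven curated (m, n, k) tuples, bfloat16 only

print(before, "->", after)  # prints: 108 -> 7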
@@ -116,13 +134,11 @@ def run_moe_test(
     return baseline_output
 
 
-@pytest.mark.parametrize("m", [1, 33, 64, 222, 32768, 40000])
-@pytest.mark.parametrize("n", [128, 1024, 2048])
-@pytest.mark.parametrize("k", [128, 511, 1024])
+@pytest.mark.parametrize("m,n,k", FUSED_MOE_MNK_FACTORS)
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("ep_size", EP_SIZE)
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
 @pytest.mark.parametrize("padding", [True, False])
 @pytest.mark.parametrize("chunk_size", [8192])
 def test_fused_moe(
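
For context on the mechanism: passing comma-separated names to a single parametrize decorator makes pytest unpack each tuple into the named arguments, so the curated list is a drop-in replacement for the stacked per-axis decorators. A minimal standalone sketch (the test body here is hypothetical, not the one from tests/kernels/moe/):

import pytest

SHAPES = [(1, 128, 128), (222, 1024, 1024)]  # illustrative (m, n, k) tuples

@pytest.mark.parametrize("m,n,k", SHAPES)
def test_shape_unpacking(m, n, k):
    # Exactly one test case per tuple; pytest binds m, n, k from each entry.
    assert m > 0 and n >= 128 and k >= 128

# Stacking three separate parametrize decorators over the same value pools
# would instead generate the full cross product (2 * 2 * 2 = 8 cases here),
# which is what this pruning avoids.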
@@ -235,13 +251,11 @@ def test_fused_moe(
         use_cudagraph=use_cudagraph)
 
 
-@pytest.mark.parametrize("m", [1, 32, 222])
-@pytest.mark.parametrize("n", [128, 1024, 2048])
-@pytest.mark.parametrize("k", [128, 1024])
+@pytest.mark.parametrize("m,n,k", FUSED_MOE_WN16_MNK_FACTORS)
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("ep_size", EP_SIZE)
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
 @pytest.mark.parametrize("group_size", [64, 128])
 @pytest.mark.parametrize("has_zp", [True, False])
 @pytest.mark.parametrize("weight_bits", [4, 8])
@@ -352,8 +366,7 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
     torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
 
 
-@pytest.mark.parametrize("dtype",
-                         [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
 @pytest.mark.parametrize("padding", [True, False])
 @pytest.mark.parametrize(
     "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])