[CI Perf] Prune tests in tests/kernels/moe/ (#22939)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -42,6 +42,24 @@ NUM_EXPERTS = [8, 64, 192]
|
||||
EP_SIZE = [1, 4]
|
||||
TOP_KS = [2, 6]
|
||||
|
||||
FUSED_MOE_MNK_FACTORS = [
|
||||
(1, 128, 128),
|
||||
(1, 2048, 128),
|
||||
(33, 2048, 128),
|
||||
(222, 1024, 1024),
|
||||
(32768, 128, 128),
|
||||
(32768, 2048, 511),
|
||||
(40000, 1024, 1024),
|
||||
]
|
||||
|
||||
FUSED_MOE_WN16_MNK_FACTORS = [
|
||||
(1, 128, 128),
|
||||
(1, 1024, 1024),
|
||||
(32, 2048, 128),
|
||||
(32, 1024, 1024),
|
||||
(222, 2048, 1024),
|
||||
]
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.scheduler_config.max_num_seqs = 128
|
||||
vllm_config.scheduler_config.max_model_len = 8192
|
||||
@@ -116,13 +134,11 @@ def run_moe_test(
|
||||
return baseline_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [1, 33, 64, 222, 32768, 40000])
|
||||
@pytest.mark.parametrize("n", [128, 1024, 2048])
|
||||
@pytest.mark.parametrize("k", [128, 511, 1024])
|
||||
@pytest.mark.parametrize("m,n,k", FUSED_MOE_MNK_FACTORS)
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("ep_size", EP_SIZE)
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("padding", [True, False])
|
||||
@pytest.mark.parametrize("chunk_size", [8192])
|
||||
def test_fused_moe(
|
||||
@@ -235,13 +251,11 @@ def test_fused_moe(
|
||||
use_cudagraph=use_cudagraph)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [1, 32, 222])
|
||||
@pytest.mark.parametrize("n", [128, 1024, 2048])
|
||||
@pytest.mark.parametrize("k", [128, 1024])
|
||||
@pytest.mark.parametrize("m,n,k", FUSED_MOE_WN16_MNK_FACTORS)
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("ep_size", EP_SIZE)
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("group_size", [64, 128])
|
||||
@pytest.mark.parametrize("has_zp", [True, False])
|
||||
@pytest.mark.parametrize("weight_bits", [4, 8])
|
||||
@@ -352,8 +366,7 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
|
||||
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype",
|
||||
[torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("padding", [True, False])
|
||||
@pytest.mark.parametrize(
|
||||
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
|
||||
|
||||
Reference in New Issue
Block a user