[CI Perf] Prune tests in tests/kernels/moe/ (#22939)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin
2025-08-14 23:33:42 -04:00
committed by GitHub
parent 590bddbfc5
commit d2b0e97ea6
6 changed files with 46 additions and 31 deletions

View File

@@ -89,14 +89,11 @@ class BatchedMMTensors:
return BatchedMMTensors(A, B, C, num_expert_tokens)
@pytest.mark.parametrize("num_experts", [8, 16, 32])
@pytest.mark.parametrize("max_tokens_per_expert",
[32, 64, 128, 192, 224, 256, 512])
@pytest.mark.parametrize("K", [128, 256, 1024])
@pytest.mark.parametrize("N", [128, 256, 1024])
@pytest.mark.parametrize(
"dtype",
[torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("num_experts", [8, 32])
@pytest.mark.parametrize("max_tokens_per_expert", [32, 224, 512])
@pytest.mark.parametrize("K", [128, 1024])
@pytest.mark.parametrize("N", [128, 1024])
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,