[CI Perf] Prune tests in tests/kernels/moe/ (#22939)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -44,6 +44,14 @@ requires_pplx = pytest.mark.skipif(
|
||||
reason="Requires PPLX kernels",
|
||||
)
|
||||
|
||||
BATCHED_MOE_MNK_FACTORS = [
|
||||
(1, 128, 128),
|
||||
(33, 2048, 128),
|
||||
(64, 128, 2048),
|
||||
(222, 128, 128),
|
||||
(222, 2048, 1024),
|
||||
]
|
||||
|
||||
PPLX_COMBOS = [
|
||||
# TODO: figure out why this fails, seems to be test problem
|
||||
#(1, 128, 128),
|
||||
@@ -152,9 +160,7 @@ def torch_batched_moe(
|
||||
return torch_finalize(out, topk_weight, topk_ids)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [1, 33, 64, 222])
|
||||
@pytest.mark.parametrize("n", [128, 1024, 2048])
|
||||
@pytest.mark.parametrize("k", [128, 512, 1024])
|
||||
@pytest.mark.parametrize("m,n,k", BATCHED_MOE_MNK_FACTORS)
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
|
||||
Reference in New Issue
Block a user