kernels/moe test pruning (#27053)

Signed-off-by: Fardin Hoque <kfhfar@amazon.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Author: Fardin Hoque
Date: 2025-10-29 21:10:34 -07:00 (committed by GitHub)
Parent: 17d055f527
Commit: b8c48c5d72
13 changed files with 34 additions and 56 deletions

@@ -24,23 +24,16 @@ from vllm.triton_utils import tl
MNK_FACTORS = [
    (1, 128, 128),
    (1, 128, 2048),
    (1, 512, 512),
    (1, 1024, 128),
    (1, 1024, 2048),
    (32, 128, 128),
    (32, 512, 512),
    (32, 1024, 2048),
    (45, 128, 128),
    (45, 128, 2048),
    (45, 512, 512),
    (45, 1024, 128),
    (45, 1024, 2048),
    (64, 512, 512),
    (64, 1024, 2048),
    (222, 128, 128),
    (222, 128, 2048),
    (222, 1024, 128),
    (222, 1024, 2048),
]
NUM_EXPERTS = [8, 64]
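
For illustration only (not part of this commit), a minimal sketch of how a shape grid like MNK_FACTORS and NUM_EXPERTS typically drives pytest parameterization; the decorator wiring and test name below are assumptions, not the file's exact code:

import pytest

# Illustrative subset of the grid; the real constants are defined above.
MNK_FACTORS = [(1, 128, 128), (32, 512, 512), (222, 1024, 2048)]
NUM_EXPERTS = [8, 64]

@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
def test_shapes(m: int, n: int, k: int, e: int):
    # Each (m, n, k, e) combination is collected as a separate test case,
    # so pruning entries from MNK_FACTORS directly shrinks the test matrix.
    assert m > 0 and n > 0 and k > 0

Because the parameters combine multiplicatively, removing a handful of shape factors cuts the number of collected cases substantially, which is the point of this pruning commit.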
@@ -117,10 +110,19 @@ def test_batched_mm(
    block_shape: list[int] | None,
    per_act_token_quant: bool,
):
    """Note: float8_e4m3fn is not supported on CUDA architecture < 89,
    and those tests will be skipped on unsupported hardware."""
    current_platform.seed_everything(7)
    use_fp8_w8a8 = dtype == torch.float8_e4m3fn
    if (dtype == torch.float8_e4m3fn) and not current_platform.has_device_capability(
        89
    ):
        pytest.skip(
            "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
        )
    if (per_act_token_quant or block_shape is not None) and not use_fp8_w8a8:
        pytest.skip("Don't test blocking for non-quantized types.")
@@ -244,10 +246,19 @@ def test_fused_moe_batched_experts(
    block_shape: list[int] | None,
    input_scales: bool,
):
    """Note: float8_e4m3fn is not supported on CUDA architecture < 89,
    and those tests will be skipped on unsupported hardware."""
    current_platform.seed_everything(7)
    use_fp8_w8a8 = dtype == torch.float8_e4m3fn
    if (dtype == torch.float8_e4m3fn) and not current_platform.has_device_capability(
        89
    ):
        pytest.skip(
            "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
        )
    if topk > e:
        pytest.skip("topk > e")