kernels/moe test pruning (#27053)
Signed-off-by: Fardin Hoque <kfhfar@amazon.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
@@ -24,23 +24,16 @@ from vllm.triton_utils import tl
|
||||
|
||||
MNK_FACTORS = [
|
||||
(1, 128, 128),
|
||||
(1, 128, 2048),
|
||||
(1, 512, 512),
|
||||
(1, 1024, 128),
|
||||
(1, 1024, 2048),
|
||||
(32, 128, 128),
|
||||
(32, 512, 512),
|
||||
(32, 1024, 2048),
|
||||
(45, 128, 128),
|
||||
(45, 128, 2048),
|
||||
(45, 512, 512),
|
||||
(45, 1024, 128),
|
||||
(45, 1024, 2048),
|
||||
(64, 512, 512),
|
||||
(64, 1024, 2048),
|
||||
(222, 128, 128),
|
||||
(222, 128, 2048),
|
||||
(222, 1024, 128),
|
||||
(222, 1024, 2048),
|
||||
]
|
||||
NUM_EXPERTS = [8, 64]
|
||||
@@ -117,10 +110,19 @@ def test_batched_mm(
|
||||
block_shape: list[int] | None,
|
||||
per_act_token_quant: bool,
|
||||
):
|
||||
"""Note: float8_e4m3fn is not supported on CUDA architecture < 89,
|
||||
and those tests will be skipped on unsupported hardware."""
|
||||
current_platform.seed_everything(7)
|
||||
|
||||
use_fp8_w8a8 = dtype == torch.float8_e4m3fn
|
||||
|
||||
if (dtype == torch.float8_e4m3fn) and not current_platform.has_device_capability(
|
||||
89
|
||||
):
|
||||
pytest.skip(
|
||||
"Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
|
||||
)
|
||||
|
||||
if (per_act_token_quant or block_shape is not None) and not use_fp8_w8a8:
|
||||
pytest.skip("Don't test blocking for non-quantized types.")
|
||||
|
||||
@@ -244,10 +246,19 @@ def test_fused_moe_batched_experts(
|
||||
block_shape: list[int] | None,
|
||||
input_scales: bool,
|
||||
):
|
||||
"""Note: float8_e4m3fn is not supported on CUDA architecture < 89,
|
||||
and those tests will be skipped on unsupported hardware."""
|
||||
current_platform.seed_everything(7)
|
||||
|
||||
use_fp8_w8a8 = dtype == torch.float8_e4m3fn
|
||||
|
||||
if (dtype == torch.float8_e4m3fn) and not current_platform.has_device_capability(
|
||||
89
|
||||
):
|
||||
pytest.skip(
|
||||
"Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
|
||||
)
|
||||
|
||||
if topk > e:
|
||||
pytest.skip("topk > e")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user