[CI/Build] Only use supported types and features on ROCm in MoE kernel tests (#29149)

Signed-off-by: Randall Smith <ransmith@amd.com>
Co-authored-by: Randall Smith <ransmith@amd.com>
Author: rasmith
Date: 2025-11-21 21:34:33 -06:00
Committed by: GitHub
Parent: 77e1c035d0
Commit: fd65015a14
7 changed files with 41 additions and 2 deletions


@@ -39,6 +39,11 @@ MNK_FACTORS = [
 NUM_EXPERTS = [8, 64]
 TOP_KS = [1, 2, 6]
+DTYPES = [torch.bfloat16]
+if not current_platform.is_fp8_fnuz():
+    DTYPES.append(torch.float8_e4m3fn)
 vllm_config = VllmConfig()
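
The gist of the change: `DTYPES` starts with `torch.bfloat16` and only picks up `torch.float8_e4m3fn` when the platform does not use the fnuz fp8 variant, which is what `current_platform.is_fp8_fnuz()` reports for ROCm GPUs that prefer `float8_e4m3fnuz`. Below is a minimal, self-contained sketch of the same gating idea; the `uses_fnuz_fp8()` helper is a stand-in assumption for illustration, not vLLM's actual platform check.

```python
# Sketch only: platform-gated dtype list, with a hypothetical
# uses_fnuz_fp8() standing in for vLLM's current_platform.is_fp8_fnuz().
import torch

def uses_fnuz_fp8() -> bool:
    # Assumption for illustration: treat any ROCm/HIP build as a platform
    # that prefers the fnuz fp8 variant over OCP float8_e4m3fn.
    return torch.version.hip is not None

DTYPES = [torch.bfloat16]
if not uses_fnuz_fp8():
    DTYPES.append(torch.float8_e4m3fn)

print(DTYPES)  # only bfloat16 on fnuz platforms; bfloat16 + fp8 elsewhere
```
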
@@ -96,7 +101,7 @@ class BatchedMMTensors:
 @pytest.mark.parametrize("max_tokens_per_expert", [32, 224, 512])
 @pytest.mark.parametrize("K", [128, 1024])
 @pytest.mark.parametrize("N", [128, 1024])
-@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
+@pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("block_shape", [None, [128, 128]])
 @pytest.mark.parametrize("per_act_token_quant", [False, True])
 def test_batched_mm(
@@ -229,7 +234,7 @@ def test_batched_mm(
 @pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS)
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
-@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
+@pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("per_act_token_quant", [False, True])
 @pytest.mark.parametrize("block_shape", [None, [128, 128]])
 @pytest.mark.parametrize("input_scales", [False])