[CI/Build] Only use supported types and features on ROCm in MoE kernel tests (#29149)
Signed-off-by: Randall Smith <ransmith@amd.com>
Co-authored-by: Randall Smith <ransmith@amd.com>
This commit is contained in:
@@ -39,6 +39,11 @@ MNK_FACTORS = [
|
||||
NUM_EXPERTS = [8, 64]
|
||||
TOP_KS = [1, 2, 6]
|
||||
|
||||
DTYPES = [torch.bfloat16]
|
||||
|
||||
if not current_platform.is_fp8_fnuz():
|
||||
DTYPES.append(torch.float8_e4m3fn)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
|
||||
|
||||
@@ -96,7 +101,7 @@ class BatchedMMTensors:
|
||||
@pytest.mark.parametrize("max_tokens_per_expert", [32, 224, 512])
|
||||
@pytest.mark.parametrize("K", [128, 1024])
|
||||
@pytest.mark.parametrize("N", [128, 1024])
|
||||
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
|
||||
@pytest.mark.parametrize("per_act_token_quant", [False, True])
|
||||
def test_batched_mm(
|
||||
@@ -229,7 +234,7 @@ def test_batched_mm(
|
||||
@pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS)
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("per_act_token_quant", [False, True])
|
||||
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
|
||||
@pytest.mark.parametrize("input_scales", [False])
|
||||
|
||||
Reference in New Issue
Block a user