AMD CI Test - unskip moe_sum test and moe_align_block_size tests (#32039)

Signed-off-by: Hongxia Yang <hongxia.yang@amd.com>
This commit is contained in:
Hongxia Yang
2026-01-14 02:25:10 -05:00
committed by GitHub
parent 7933638051
commit 048bb59728
3 changed files with 11 additions and 4 deletions

View File

@@ -1161,7 +1161,6 @@ def test_batched_moe_align_block_size_opcheck():
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
input = torch.randn((m, topk, k), device="cuda", dtype=dtype)
actual = torch.empty((m, k), device="cuda", dtype=dtype)

View File

@@ -12,7 +12,6 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
batched_moe_align_block_size,
moe_align_block_size,
)
from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up
from vllm.utils.torch_utils import set_random_seed
@@ -185,7 +184,6 @@ def torch_moe_align_block_size(
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("pad_sorted_ids", [False, True])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_moe_align_block_size(
m: int, topk: int, num_experts: int, block_size: int, pad_sorted_ids: bool
):
@@ -245,7 +243,6 @@ def test_moe_align_block_size(
@pytest.mark.parametrize("topk", [2, 4])
@pytest.mark.parametrize("num_experts", [8, 64])
@pytest.mark.parametrize("block_size", [64])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_moe_align_block_size_with_expert_map(
m: int, topk: int, num_experts: int, block_size: int
):

View File

@@ -187,6 +187,17 @@ class RocmPlatform(Platform):
if not on_gfx9():
supported_quantization += ["bitsandbytes"]
@classmethod
def import_kernels(cls) -> None:
    """Import ROCm-specific kernels."""
    # Register the platform-agnostic kernels first.
    super().import_kernels()

    # Best-effort load of the optional ROCm extension module; it may not
    # have been built for this installation, in which case we proceed
    # without it.
    try:
        import vllm._rocm_C  # noqa: F401
    except ImportError:
        pass
@classmethod
def get_attn_backend_cls(
cls,