AMD CI Test - unskip moe_sum test and moe_align_block_size tests (#32039)
Signed-off-by: Hongxia Yang <hongxia.yang@amd.com>
This commit is contained in:
@@ -1161,7 +1161,6 @@ def test_batched_moe_align_block_size_opcheck():
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("k", [128, 511, 1024])
|
||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
|
||||
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
|
||||
def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
|
||||
input = torch.randn((m, topk, k), device="cuda", dtype=dtype)
|
||||
actual = torch.empty((m, k), device="cuda", dtype=dtype)
|
||||
|
||||
@@ -12,7 +12,6 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||
batched_moe_align_block_size,
|
||||
moe_align_block_size,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import round_up
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
@@ -185,7 +184,6 @@ def torch_moe_align_block_size(
|
||||
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("pad_sorted_ids", [False, True])
|
||||
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
|
||||
def test_moe_align_block_size(
|
||||
m: int, topk: int, num_experts: int, block_size: int, pad_sorted_ids: bool
|
||||
):
|
||||
@@ -245,7 +243,6 @@ def test_moe_align_block_size(
|
||||
@pytest.mark.parametrize("topk", [2, 4])
|
||||
@pytest.mark.parametrize("num_experts", [8, 64])
|
||||
@pytest.mark.parametrize("block_size", [64])
|
||||
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
|
||||
def test_moe_align_block_size_with_expert_map(
|
||||
m: int, topk: int, num_experts: int, block_size: int
|
||||
):
|
||||
|
||||
@@ -187,6 +187,17 @@ class RocmPlatform(Platform):
|
||||
if not on_gfx9():
|
||||
supported_quantization += ["bitsandbytes"]
|
||||
|
||||
@classmethod
|
||||
def import_kernels(cls) -> None:
|
||||
"""Import ROCm-specific kernels."""
|
||||
super().import_kernels()
|
||||
|
||||
import contextlib
|
||||
|
||||
# Import ROCm-specific extension
|
||||
with contextlib.suppress(ImportError):
|
||||
import vllm._rocm_C # noqa: F401
|
||||
|
||||
@classmethod
|
||||
def get_attn_backend_cls(
|
||||
cls,
|
||||
|
||||
Reference in New Issue
Block a user