[ROCm][CI] Guard CudaPlatform/RocmPlatform imports to fix test collection on cross-platform builds (#37617)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -14,8 +14,19 @@ from vllm.config import (
 )
 from vllm.platforms import current_platform
 from vllm.platforms.cpu import CpuPlatform
-from vllm.platforms.cuda import CudaPlatform
-from vllm.platforms.rocm import RocmPlatform
+
+# CudaPlatform and RocmPlatform import their respective compiled C extensions
+# at module level, raising ModuleNotFoundError on incompatible builds.
+try:
+    from vllm.platforms.cuda import CudaPlatform
+except (ImportError, ModuleNotFoundError):
+    CudaPlatform = None
+
+try:
+    from vllm.platforms.rocm import RocmPlatform
+except (ImportError, ModuleNotFoundError):
+    RocmPlatform = None
+
 from vllm.v1.attention.backends.registry import AttentionBackendEnum
 from vllm.v1.attention.selector import _cached_get_attn_backend, get_attn_backend
||||
@@ -101,6 +112,8 @@ def test_backend_selection(
         assert backend.get_name() == "CPU_ATTN"

     elif device == "hip":
+        if RocmPlatform is None:
+            pytest.skip("RocmPlatform not available")
         with patch("vllm.platforms.current_platform", RocmPlatform()):
             if use_mla:
                 # ROCm MLA backend logic:
@@ -126,6 +139,8 @@ def test_backend_selection(
             assert backend.get_name() == expected

     elif device == "cuda":
+        if CudaPlatform is None:
+            pytest.skip("CudaPlatform not available")
         with patch("vllm.platforms.current_platform", CudaPlatform()):
             capability = torch.cuda.get_device_capability()
             if use_mla:
@@ -214,7 +229,7 @@ def test_backend_selection(
         assert backend.get_name() == expected


-@pytest.mark.parametrize("device", ["cpu", "cuda"])
+@pytest.mark.parametrize("device", ["cpu", "cuda", "hip"])
 def test_fp32_fallback(device: str):
     """Test attention backend selection with fp32."""
     # Use default config (no backend specified)
@@ -227,10 +242,25 @@ def test_fp32_fallback(device: str):
         assert backend.get_name() == "CPU_ATTN"

     elif device == "cuda":
+        if CudaPlatform is None:
+            pytest.skip("CudaPlatform not available")
         with patch("vllm.platforms.current_platform", CudaPlatform()):
             backend = get_attn_backend(16, torch.float32, None)
             assert backend.get_name() == "FLEX_ATTENTION"

+    elif device == "hip":
+        if RocmPlatform is None:
+            pytest.skip("RocmPlatform not available")
+        # ROCm backends do not support head_size=16 (minimum is 32).
+        # No known HuggingFace transformer model uses head_size=16.
+        # Revisit if a real model with this head size is identified
+        # and accuracy-tested.
+        with (
+            patch("vllm.platforms.current_platform", RocmPlatform()),
+            pytest.raises(ValueError, match="No valid attention backend"),
+        ):
+            get_attn_backend(16, torch.float32, None)

 def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
     """Test FlashAttn validation."""
@@ -367,6 +397,8 @@ def test_per_head_quant_scales_backend_selection(
         attention_config=attention_config, cache_config=cache_config
     )

+    if CudaPlatform is None:
+        pytest.skip("CudaPlatform not available")
     with (
         set_current_vllm_config(vllm_config),
         patch("vllm.platforms.current_platform", CudaPlatform()),
||||
Reference in New Issue
Block a user