[Bugfix] [ROCm] [UX] Reorganize ROCm Backend Selection Logic (#26980)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
337
tests/v1/attention/test_rocm_attention_backends_selection.py
Normal file
337
tests/v1/attention/test_rocm_attention_backends_selection.py
Normal file
@@ -0,0 +1,337 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for attention backend selectors."""

from unittest.mock import MagicMock, patch

import pytest
import torch

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.platforms import current_platform

# Every test in this module is ROCm-specific; skip the whole file elsewhere.
pytestmark = pytest.mark.skipif(
    not current_platform.is_rocm(), reason="ROCm-specific tests"
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_vllm_config():
    """Create a mock VllmConfig for testing.

    Only the attributes read by the backend-selection code are populated:
    model dtype, HF architecture list, and the KV-cache block size.
    """
    cfg = MagicMock()
    cfg.model_config.dtype = torch.float16
    cfg.model_config.hf_config.architectures = ["LlamaForCausalLM"]
    cfg.cache_config.block_size = 16
    return cfg
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_on_gfx9():
    """Patch ``on_gfx9`` to report gfx9 hardware for the duration of a test."""
    with patch("vllm.platforms.rocm.on_gfx9", return_value=True):
        yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "env_vars, selected_backend, expected_backend_path",
    [
        # Case 1: no env vars, no explicit backend -> Triton default.
        ({}, None, AttentionBackendEnum.TRITON_ATTN.get_path()),
        # Case 2: explicit TRITON_ATTN request is honored.
        ({}, "TRITON_ATTN", AttentionBackendEnum.TRITON_ATTN.get_path()),
        # Case 3: explicit ROCM_ATTN request is honored.
        ({}, "ROCM_ATTN", AttentionBackendEnum.ROCM_ATTN.get_path()),
        # Case 4: explicit ROCM_AITER_FA request is honored.
        ({}, "ROCM_AITER_FA", AttentionBackendEnum.ROCM_AITER_FA.get_path()),
        # Case 5: explicit ROCM_AITER_UNIFIED_ATTN request is honored.
        (
            {},
            "ROCM_AITER_UNIFIED_ATTN",
            AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path(),
        ),
        # Case 6: AITER enabled alone defaults to AITER FA
        # (MHA not explicitly disabled).
        (
            {"VLLM_ROCM_USE_AITER": "1"},
            None,
            AttentionBackendEnum.ROCM_AITER_FA.get_path(),
        ),
        # Case 7: AITER + AITER MHA explicitly enabled -> AITER FA.
        (
            {"VLLM_ROCM_USE_AITER": "1", "VLLM_ROCM_USE_AITER_MHA": "1"},
            None,
            AttentionBackendEnum.ROCM_AITER_FA.get_path(),
        ),
        # Case 8: AITER + unified attention flag -> AITER unified attention.
        (
            {
                "VLLM_ROCM_USE_AITER": "1",
                "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION": "1",
            },
            None,
            AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path(),
        ),
        # Case 9: prefill/decode split flag -> ROCM_ATTN.
        (
            {"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1"},
            None,
            AttentionBackendEnum.ROCM_ATTN.get_path(),
        ),
        # Case 10: an explicit TRITON_ATTN request overrides AITER env.
        (
            {"VLLM_ROCM_USE_AITER": "1"},
            "TRITON_ATTN",
            AttentionBackendEnum.TRITON_ATTN.get_path(),
        ),
        # Case 11: AITER enabled but MHA explicitly disabled -> Triton.
        (
            {"VLLM_ROCM_USE_AITER": "1", "VLLM_ROCM_USE_AITER_MHA": "0"},
            None,
            AttentionBackendEnum.TRITON_ATTN.get_path(),
        ),
        # Case 12: an explicit ROCM_ATTN request overrides AITER env.
        (
            {"VLLM_ROCM_USE_AITER": "1"},
            "ROCM_ATTN",
            AttentionBackendEnum.ROCM_ATTN.get_path(),
        ),
    ],
)
def test_standard_attention_backend_selection(
    env_vars,
    selected_backend,
    expected_backend_path,
    mock_vllm_config,
    mock_on_gfx9,
    monkeypatch,
):
    """Standard (non-MLA) backend selection across env/explicit combinations."""
    import importlib

    import vllm.envs as envs
    from vllm.attention.backends.registry import _Backend

    # Apply the requested environment, then reload the env module so the
    # platform code observes the new values.
    for name, value in env_vars.items():
        monkeypatch.setenv(name, value)
    importlib.reload(envs)

    # An explicit backend request arrives as an enum member (or None).
    backend = getattr(_Backend, selected_backend) if selected_backend else None

    from vllm.platforms.rocm import RocmPlatform

    result = RocmPlatform.get_attn_backend_cls(
        selected_backend=backend,
        head_size=128,
        dtype=torch.float16,
        kv_cache_dtype="auto",
        block_size=16,
        use_mla=False,
        has_sink=False,
        use_sparse=False,
    )
    assert result == expected_backend_path
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "env_vars, selected_backend, block_size, expected_backend_path, should_raise",
    [
        # Case 1: explicit TRITON_MLA with block_size != 1 -> accepted.
        ({}, "TRITON_MLA", 16, AttentionBackendEnum.TRITON_MLA.get_path(), False),
        # Case 2: explicit TRITON_MLA with block_size == 1 -> rejected.
        ({}, "TRITON_MLA", 1, None, True),
        # Case 3: explicit ROCM_AITER_MLA with block_size == 1 -> accepted.
        (
            {},
            "ROCM_AITER_MLA",
            1,
            AttentionBackendEnum.ROCM_AITER_MLA.get_path(),
            False,
        ),
        # Case 4: explicit ROCM_AITER_MLA with block_size != 1 -> accepted.
        (
            {},
            "ROCM_AITER_MLA",
            16,
            AttentionBackendEnum.ROCM_AITER_MLA.get_path(),
            False,
        ),
        # Case 5: AITER enabled with block_size == 1 -> AITER MLA.
        (
            {"VLLM_ROCM_USE_AITER": "1"},
            None,
            1,
            AttentionBackendEnum.ROCM_AITER_MLA.get_path(),
            False,
        ),
        # Case 6: AITER enabled with block_size == 16 -> AITER MLA
        # (ROCM_AITER_MLA now supports block_size 16).
        (
            {"VLLM_ROCM_USE_AITER": "1"},
            None,
            16,
            AttentionBackendEnum.ROCM_AITER_MLA.get_path(),
            False,
        ),
        # Case 7: explicit TRITON_MLA overrides the AITER env flag.
        (
            {"VLLM_ROCM_USE_AITER": "1"},
            "TRITON_MLA",
            16,
            AttentionBackendEnum.TRITON_MLA.get_path(),
            False,
        ),
        # Case 8: explicit ROCM_AITER_TRITON_MLA request is honored.
        (
            {},
            "ROCM_AITER_TRITON_MLA",
            16,
            AttentionBackendEnum.ROCM_AITER_TRITON_MLA.get_path(),
            False,
        ),
    ],
)
def test_mla_backend_selection(
    env_vars,
    selected_backend,
    block_size,
    expected_backend_path,
    should_raise,
    mock_vllm_config,
    monkeypatch,
):
    """MLA backend selection across env configuration and block sizes."""
    import importlib

    import vllm.envs as envs
    from vllm.attention.backends.registry import _Backend

    # Apply env vars first, then reload so the flags are re-evaluated.
    for name, value in env_vars.items():
        monkeypatch.setenv(name, value)
    importlib.reload(envs)

    # Stub vllm._aiter_ops so is_mla_enabled() mirrors VLLM_ROCM_USE_AITER.
    fake_ops = MagicMock()
    fake_ops.is_mla_enabled.return_value = (
        env_vars.get("VLLM_ROCM_USE_AITER") == "1"
    )
    fake_module = MagicMock()
    fake_module.rocm_aiter_ops = fake_ops

    with patch.dict("sys.modules", {"vllm._aiter_ops": fake_module}):
        backend = (
            getattr(_Backend, selected_backend) if selected_backend else None
        )

        from vllm.platforms.rocm import RocmPlatform

        kwargs = dict(
            selected_backend=backend,
            head_size=128,
            dtype=torch.float16,
            kv_cache_dtype="auto",
            block_size=block_size,
            use_mla=True,
            has_sink=False,
            use_sparse=False,
        )

        if should_raise:
            with pytest.raises(ValueError):
                RocmPlatform.get_attn_backend_cls(**kwargs)
        else:
            result = RocmPlatform.get_attn_backend_cls(**kwargs)
            assert result == expected_backend_path
|
||||||
|
|
||||||
|
|
||||||
|
def test_aiter_fa_requires_gfx9(mock_vllm_config):
    """ROCM_AITER_FA must be rejected on non-gfx9 hardware."""
    from vllm.attention.backends.registry import _Backend
    from vllm.platforms.rocm import RocmPlatform

    # Pretend we are NOT on a gfx9 GPU; selecting AITER FA must then fail.
    with (
        patch("vllm.platforms.rocm.on_gfx9", return_value=False),
        pytest.raises(
            ValueError,
            match="only supported on gfx9",
        ),
    ):
        RocmPlatform.get_attn_backend_cls(
            selected_backend=_Backend.ROCM_AITER_FA,
            head_size=128,
            dtype=torch.float16,
            kv_cache_dtype="auto",
            block_size=16,
            use_mla=False,
            has_sink=False,
            use_sparse=False,
        )
|
||||||
|
|
||||||
|
|
||||||
|
def test_sparse_not_supported(mock_vllm_config):
    """Requesting sparse attention with block_size 16 must be rejected."""
    from vllm.platforms.rocm import RocmPlatform

    # The platform asserts on the sparse/block-size combination.
    expected_msg = "Sparse MLA backend on ROCm only supports block size 1"
    with pytest.raises(AssertionError, match=expected_msg):
        RocmPlatform.get_attn_backend_cls(
            selected_backend=None,
            head_size=128,
            dtype=torch.float16,
            kv_cache_dtype="auto",
            block_size=16,
            use_mla=False,
            has_sink=False,
            use_sparse=True,
        )
|
||||||
@@ -262,30 +262,64 @@ class RocmPlatform(Platform):
|
|||||||
f"is not MLA type while requested for MLA backend."
|
f"is not MLA type while requested for MLA backend."
|
||||||
)
|
)
|
||||||
|
|
||||||
if selected_backend == AttentionBackendEnum.FLEX_ATTENTION:
|
if selected_backend == AttentionBackendEnum.TRITON_ATTN:
|
||||||
logger.info("Using FlexAttention backend.")
|
logger.info("Using Triton Attention backend on V1 engine.")
|
||||||
return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
|
return AttentionBackendEnum.TRITON_ATTN.get_path()
|
||||||
if (
|
|
||||||
rocm_aiter_ops.is_mha_enabled()
|
if selected_backend == AttentionBackendEnum.ROCM_ATTN:
|
||||||
) or selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
|
logger.info("Using Rocm Attention backend on V1 engine.")
|
||||||
logger.info("Using Aiter Flash Attention backend.")
|
|
||||||
return AttentionBackendEnum.ROCM_AITER_FA.get_path()
|
|
||||||
if (
|
|
||||||
rocm_aiter_ops.is_triton_unified_attn_enabled()
|
|
||||||
) or selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
|
|
||||||
logger.info("Using Aiter Unified Attention backend.")
|
|
||||||
return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
|
|
||||||
if (
|
|
||||||
envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
|
|
||||||
or selected_backend == AttentionBackendEnum.ROCM_ATTN
|
|
||||||
):
|
|
||||||
# rocm specific backend, with aiter and/or
|
|
||||||
# triton prefix-prefill
|
|
||||||
logger.info("Using Rocm Attention backend.")
|
|
||||||
return AttentionBackendEnum.ROCM_ATTN.get_path()
|
return AttentionBackendEnum.ROCM_ATTN.get_path()
|
||||||
# default case, using triton unified attention
|
|
||||||
logger.info("Using Triton Attention backend.")
|
if selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
|
||||||
return AttentionBackendEnum.TRITON_ATTN.get_path()
|
if on_gfx9():
|
||||||
|
logger.info("Using Aiter Flash Attention backend on V1 engine.")
|
||||||
|
return AttentionBackendEnum.ROCM_AITER_FA.get_path()
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"The selected backend, {selected_backend.name}, "
|
||||||
|
"is only supported on gfx9 architectures."
|
||||||
|
)
|
||||||
|
|
||||||
|
if selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
|
||||||
|
logger.info("Using Aiter Unified Attention backend on V1 engine.")
|
||||||
|
return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
|
||||||
|
|
||||||
|
# Handle automatic backend selection based on environment variables
|
||||||
|
if selected_backend is None:
|
||||||
|
# Priority 1: Check for AITER Unified Attention (must check before MHA)
|
||||||
|
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION:
|
||||||
|
logger.info("Using Aiter Unified Attention backend on V1 engine.")
|
||||||
|
return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
|
||||||
|
|
||||||
|
# Priority 2: Check for AITER MHA (Flash Attention)
|
||||||
|
# Only use if explicitly enabled (not just VLLM_ROCM_USE_AITER=1)
|
||||||
|
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
|
||||||
|
logger.info("Using Aiter Flash Attention backend on V1 engine.")
|
||||||
|
return AttentionBackendEnum.ROCM_AITER_FA.get_path()
|
||||||
|
|
||||||
|
# Priority 3: Check for ROCM_ATTN (prefill-decode split)
|
||||||
|
if envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION:
|
||||||
|
logger.info("Using Rocm Attention backend on V1 engine.")
|
||||||
|
return AttentionBackendEnum.ROCM_ATTN.get_path()
|
||||||
|
|
||||||
|
# Priority 4: Check for AITER enabled without specific flags
|
||||||
|
# This defaults to AITER FA only if MHA is not explicitly disabled
|
||||||
|
if (
|
||||||
|
envs.VLLM_ROCM_USE_AITER
|
||||||
|
and on_gfx9()
|
||||||
|
and envs.VLLM_ROCM_USE_AITER_MHA is not False
|
||||||
|
):
|
||||||
|
logger.info("Using Aiter Flash Attention backend on V1 engine.")
|
||||||
|
return AttentionBackendEnum.ROCM_AITER_FA.get_path()
|
||||||
|
|
||||||
|
# Default: Triton Unified Attention
|
||||||
|
logger.info("Using Triton Attention backend on V1 engine.")
|
||||||
|
return AttentionBackendEnum.TRITON_ATTN.get_path()
|
||||||
|
|
||||||
|
raise RuntimeError(
|
||||||
|
"V0 attention backends have been removed. Set VLLM_USE_V1=1 "
|
||||||
|
"to select a supported backend."
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def set_device(cls, device: torch.device) -> None:
|
def set_device(cls, device: torch.device) -> None:
|
||||||
|
|||||||
Reference in New Issue
Block a user