[CI/ROCm] Fixing "V1 Test attention (H100)" test group. (#31187)
Signed-off-by: DCCS-4560 <alivanov@chi-mi325x-pod1-108.ord.vultr.cpe.ice.amd.com> Signed-off-by: <> Co-authored-by: DCCS-4560 <alivanov@chi-mi325x-pod1-108.ord.vultr.cpe.ice.amd.com> Co-authored-by: root <root@chi-mi325x-pod1-108.ord.vultr.cpe.ice.amd.com>
This commit is contained in:
committed by
GitHub
parent
56f516254c
commit
d63b969675
@@ -557,9 +557,21 @@ def test_causal_backend_correctness(
|
||||
if is_torch_equal_or_newer("2.9.0.dev0")
|
||||
else []
|
||||
)
|
||||
SMALL_BLOCK_BACKENDS = [
|
||||
x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
|
||||
]
|
||||
|
||||
if current_platform.is_rocm():
|
||||
SMALL_BLOCK_BACKENDS = [
|
||||
x
|
||||
for x in BACKENDS_TO_TEST
|
||||
if (
|
||||
x not in LARGE_BLOCK_BACKENDS
|
||||
and x is not AttentionBackendEnum.FLASH_ATTN
|
||||
)
|
||||
]
|
||||
else:
|
||||
SMALL_BLOCK_BACKENDS = [
|
||||
x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
|
||||
]
|
||||
|
||||
_test_backend_correctness(
|
||||
batch_spec,
|
||||
model,
|
||||
@@ -580,12 +592,20 @@ def test_causal_backend_correctness(
|
||||
)
|
||||
|
||||
|
||||
SLIDING_WINDOW_BACKENDS_TO_TEST = [
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.FLEX_ATTENTION,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
"FLEX_ATTENTION_SLOW",
|
||||
]
|
||||
if current_platform.is_rocm():
|
||||
# FLASH_ATTN is not supported on ROCm
|
||||
SLIDING_WINDOW_BACKENDS_TO_TEST = [
|
||||
AttentionBackendEnum.FLEX_ATTENTION,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
"FLEX_ATTENTION_SLOW",
|
||||
]
|
||||
else:
|
||||
SLIDING_WINDOW_BACKENDS_TO_TEST = [
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.FLEX_ATTENTION,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
"FLEX_ATTENTION_SLOW",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -8,6 +8,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.selector import AttentionSelectorConfig
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# ROCm-specific attention backend selection tests
|
||||
@@ -144,8 +145,7 @@ def test_standard_attention_backend_selection(
|
||||
# Get the backend class path
|
||||
from vllm.platforms.rocm import RocmPlatform
|
||||
|
||||
backend_path = RocmPlatform.get_attn_backend_cls(
|
||||
selected_backend=backend_enum,
|
||||
attn_selector_config = AttentionSelectorConfig(
|
||||
head_size=128,
|
||||
dtype=torch.float16,
|
||||
kv_cache_dtype="auto",
|
||||
@@ -154,6 +154,11 @@ def test_standard_attention_backend_selection(
|
||||
has_sink=False,
|
||||
use_sparse=False,
|
||||
)
|
||||
|
||||
backend_path = RocmPlatform.get_attn_backend_cls(
|
||||
selected_backend=backend_enum, attn_selector_config=attn_selector_config
|
||||
)
|
||||
|
||||
assert backend_path == expected_backend_path
|
||||
|
||||
|
||||
@@ -267,8 +272,7 @@ def test_mla_backend_selection(
|
||||
|
||||
if should_raise:
|
||||
with pytest.raises(ValueError):
|
||||
RocmPlatform.get_attn_backend_cls(
|
||||
selected_backend=backend_enum,
|
||||
attn_selector_config = AttentionSelectorConfig(
|
||||
head_size=128,
|
||||
dtype=torch.float16,
|
||||
kv_cache_dtype="auto",
|
||||
@@ -277,9 +281,22 @@ def test_mla_backend_selection(
|
||||
has_sink=False,
|
||||
use_sparse=False,
|
||||
)
|
||||
attn_selector_config = AttentionSelectorConfig(
|
||||
head_size=128,
|
||||
dtype=torch.float16,
|
||||
kv_cache_dtype="auto",
|
||||
block_size=block_size,
|
||||
use_mla=True,
|
||||
has_sink=False,
|
||||
use_sparse=False,
|
||||
)
|
||||
backend_path = RocmPlatform.get_attn_backend_cls(
|
||||
selected_backend=backend_enum,
|
||||
attn_selector_config=attn_selector_config,
|
||||
)
|
||||
|
||||
else:
|
||||
backend_path = RocmPlatform.get_attn_backend_cls(
|
||||
selected_backend=backend_enum,
|
||||
attn_selector_config = AttentionSelectorConfig(
|
||||
head_size=128,
|
||||
dtype=torch.float16,
|
||||
kv_cache_dtype="auto",
|
||||
@@ -288,6 +305,11 @@ def test_mla_backend_selection(
|
||||
has_sink=False,
|
||||
use_sparse=False,
|
||||
)
|
||||
|
||||
backend_path = RocmPlatform.get_attn_backend_cls(
|
||||
selected_backend=backend_enum, attn_selector_config=attn_selector_config
|
||||
)
|
||||
|
||||
assert backend_path == expected_backend_path
|
||||
|
||||
|
||||
@@ -303,8 +325,7 @@ def test_aiter_fa_requires_gfx9(mock_vllm_config):
|
||||
match="only supported on gfx9",
|
||||
),
|
||||
):
|
||||
RocmPlatform.get_attn_backend_cls(
|
||||
selected_backend=AttentionBackendEnum.ROCM_AITER_FA,
|
||||
attn_selector_config = AttentionSelectorConfig(
|
||||
head_size=128,
|
||||
dtype=torch.float16,
|
||||
kv_cache_dtype="auto",
|
||||
@@ -314,6 +335,11 @@ def test_aiter_fa_requires_gfx9(mock_vllm_config):
|
||||
use_sparse=False,
|
||||
)
|
||||
|
||||
RocmPlatform.get_attn_backend_cls(
|
||||
selected_backend=AttentionBackendEnum.ROCM_AITER_FA,
|
||||
attn_selector_config=attn_selector_config,
|
||||
)
|
||||
|
||||
|
||||
def test_sparse_not_supported(mock_vllm_config):
|
||||
"""Test that sparse attention is not supported on ROCm."""
|
||||
@@ -322,8 +348,7 @@ def test_sparse_not_supported(mock_vllm_config):
|
||||
with pytest.raises(
|
||||
AssertionError, match="Sparse MLA backend on ROCm only supports block size 1"
|
||||
):
|
||||
RocmPlatform.get_attn_backend_cls(
|
||||
selected_backend=None,
|
||||
attn_selector_config = AttentionSelectorConfig(
|
||||
head_size=128,
|
||||
dtype=torch.float16,
|
||||
kv_cache_dtype="auto",
|
||||
@@ -332,3 +357,7 @@ def test_sparse_not_supported(mock_vllm_config):
|
||||
has_sink=False,
|
||||
use_sparse=True,
|
||||
)
|
||||
|
||||
RocmPlatform.get_attn_backend_cls(
|
||||
selected_backend=None, attn_selector_config=attn_selector_config
|
||||
)
|
||||
|
||||
@@ -24,6 +24,7 @@ from vllm import _custom_ops as ops
|
||||
from vllm.attention.ops import flashmla
|
||||
from vllm.config import set_current_vllm_config
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backends.mla.flashmla_sparse import (
|
||||
FlashMLASparseBackend,
|
||||
@@ -125,6 +126,9 @@ def _quantize_dequantize_fp8_ds_mla(
|
||||
def test_sparse_backend_decode_correctness(
|
||||
dist_init, batch_name, kv_cache_dtype, tensor_parallel_size, workspace_init
|
||||
):
|
||||
if current_platform.is_rocm():
|
||||
pytest.skip("ROCm does not support fp8_ds_mla data type for kv cache.")
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA is required for sparse MLA decode test")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user