[ROCm][CI] Fix cross-attention dispatch for encoder-decoder models (#38450)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -173,9 +173,9 @@ Priority is **1 = highest** (tried first).
|
||||
| `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥10.0 |
|
||||
| `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ✅ | Decoder | Any |
|
||||
| `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any |
|
||||
| `ROCM_AITER_FA` | | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder, Enc-Dec | N/A |
|
||||
| `ROCM_AITER_FA` | | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder | N/A |
|
||||
| `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | %16 | Any | ✅ | ✅ | ❌ | All | N/A |
|
||||
| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ❌ | ✅ | ❌ | All | N/A |
|
||||
| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ❌ | ✅ | ❌ | Decoder, Encoder, Encoder Only | N/A |
|
||||
| `TREE_ATTN` | | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any |
|
||||
| `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ✅ | ❌ | All | Any |
|
||||
|
||||
|
||||
@@ -14,13 +14,62 @@ import pytest_asyncio
|
||||
import soundfile as sf
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
MODEL_NAME = "openai/whisper-large-v3-turbo"
|
||||
|
||||
# Disable prefix caching on ROCm to reduce non-determinism in
|
||||
# streaming-vs-non-streaming comparisons.
|
||||
_ROCM_ARGS = ["--no-enable-prefix-caching"] if current_platform.is_rocm() else []
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
with RemoteOpenAIServer(MODEL_NAME, []) as remote_server:
|
||||
|
||||
def _get_attention_backend_params() -> list[str | None]:
|
||||
"""Return attention backends to parametrize the server fixture with.
|
||||
|
||||
On ROCm, we test multiple backends explicitly:
|
||||
- None: default auto-selection (ROCM_ATTN for decoder self-attention,
|
||||
falls back to ROCM_AITER_UNIFIED_ATTN or TRITON_ATTN for
|
||||
cross-attention since ROCM_ATTN doesn't support ENCODER_DECODER)
|
||||
- TRITON_ATTN: always available on ROCm
|
||||
- ROCM_AITER_UNIFIED_ATTN: only on gfx942/gfx950
|
||||
|
||||
On non-ROCm platforms, we just run with the default backend.
|
||||
"""
|
||||
try:
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.is_rocm():
|
||||
backends: list[str | None] = [None, "TRITON_ATTN"]
|
||||
from vllm.platforms.rocm import _ON_MI3XX
|
||||
|
||||
if _ON_MI3XX:
|
||||
backends.append("ROCM_AITER_UNIFIED_ATTN")
|
||||
return backends
|
||||
except Exception:
|
||||
pass
|
||||
return [None]
|
||||
|
||||
|
||||
# Aiter backends need VLLM_ROCM_USE_AITER=1 (and MHA=1 for ROCM_AITER_FA)
|
||||
# to be enabled in the server subprocess.
|
||||
_AITER_ENV = {
|
||||
"VLLM_ROCM_USE_AITER": "1",
|
||||
"VLLM_ROCM_USE_AITER_MHA": "1",
|
||||
}
|
||||
|
||||
_ATTN_BACKENDS = _get_attention_backend_params()
|
||||
_ATTN_IDS = [b or "default" for b in _ATTN_BACKENDS]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=_ATTN_BACKENDS, ids=_ATTN_IDS)
|
||||
def server(request):
|
||||
args = [*_ROCM_ARGS]
|
||||
env_dict = None
|
||||
if request.param is not None:
|
||||
args += ["--attention-backend", request.param]
|
||||
if "AITER" in request.param:
|
||||
env_dict = _AITER_ENV
|
||||
with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
|
||||
@@ -446,7 +446,7 @@ def parse_attention_types(node: ast.ClassDef) -> str:
|
||||
|
||||
if not types:
|
||||
return "Decoder"
|
||||
return "All" if len(types) >= 3 else ", ".join(sorted(types))
|
||||
return "All" if types >= set(type_map.values()) else ", ".join(sorted(types))
|
||||
|
||||
|
||||
def parse_impl_bool_attr(
|
||||
|
||||
@@ -439,7 +439,10 @@ class RocmPlatform(Platform):
|
||||
f"this configuration. Reason: {invalid_reasons}"
|
||||
)
|
||||
else:
|
||||
logger.info("Using %s backend.", selected_backend)
|
||||
logger.info_once(
|
||||
"Using %s backend (selected via --attention-backend).",
|
||||
selected_backend.name,
|
||||
)
|
||||
return selected_backend.get_path()
|
||||
|
||||
# No selected backend or the selected backend is invalid,
|
||||
@@ -476,12 +479,25 @@ class RocmPlatform(Platform):
|
||||
)
|
||||
selected_index = sorted_indices[0]
|
||||
selected_backend = valid_backends_priorities[selected_index][0]
|
||||
logger.info_once(
|
||||
"Using %s attention backend out of potential backends: %s.",
|
||||
selected_backend.name,
|
||||
"[" + ", ".join(f"'{b[0].name}'" for b in valid_backends_priorities) + "]",
|
||||
scope="local",
|
||||
valid_str = (
|
||||
"[" + ", ".join(f"'{b[0].name}'" for b in valid_backends_priorities) + "]"
|
||||
)
|
||||
if invalid_reasons:
|
||||
rejected_str = ", ".join(b.name for b in invalid_reasons)
|
||||
logger.info(
|
||||
"Found incompatible backend(s) [%s] with %s. "
|
||||
"Overriding with %s out of potential backends: %s.",
|
||||
rejected_str,
|
||||
attn_selector_config.attn_type,
|
||||
selected_backend.name,
|
||||
valid_str,
|
||||
)
|
||||
else:
|
||||
logger.info_once(
|
||||
"Using %s backend out of potential backends: %s.",
|
||||
selected_backend.name,
|
||||
valid_str,
|
||||
)
|
||||
|
||||
return selected_backend.get_path()
|
||||
|
||||
|
||||
@@ -758,11 +758,12 @@ class AiterFlashAttentionBackend(AttentionBackend):
|
||||
|
||||
@classmethod
|
||||
def supports_attn_type(cls, attn_type: str) -> bool:
|
||||
"""ROCM AITER FA supports decoder and encoder-decoder (cross) attention."""
|
||||
return attn_type in (
|
||||
AttentionType.DECODER,
|
||||
AttentionType.ENCODER_DECODER,
|
||||
)
|
||||
"""ENCODER_DECODER is not supported because the prefill path uses
|
||||
flash_attn_varlen_func with cu_seqlens_k set to decoder
|
||||
query_start_loc (not encoder seq lens) and causal=True, both of
|
||||
which are incorrect for cross-attention layers.
|
||||
"""
|
||||
return attn_type in (AttentionType.DECODER,)
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||
|
||||
@@ -212,12 +212,17 @@ class RocmAttentionBackend(AttentionBackend):
|
||||
|
||||
@classmethod
|
||||
def supports_attn_type(cls, attn_type: str) -> bool:
|
||||
"""RocmAttention supports all attention types."""
|
||||
"""ENCODER_DECODER is not supported because
|
||||
chunked_prefill_paged_decode's prefill kernel (context_attention_fwd)
|
||||
assumes self-attention semantics: it treats passed K/V as new tokens
|
||||
to mix with cached K/V. For cross-attention layers the encoder K/V
|
||||
are already fully cached, so mixing them again produces incorrect
|
||||
results when max_query_len > 1 (e.g. beam search).
|
||||
"""
|
||||
return attn_type in (
|
||||
AttentionType.DECODER,
|
||||
AttentionType.ENCODER,
|
||||
AttentionType.ENCODER_ONLY,
|
||||
AttentionType.ENCODER_DECODER,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user