[ROCm][CI] Fix cross-attention dispatch for encoder-decoder models (#38450)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Andreas Karatzas
2026-03-29 00:08:03 -05:00
committed by GitHub
parent 5b8c30d62b
commit 43cc5138e5
6 changed files with 90 additions and 19 deletions

View File

@@ -173,9 +173,9 @@ Priority is **1 = highest** (tried first).
| `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥10.0 |
| `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ✅ | Decoder | Any |
| `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any |
| `ROCM_AITER_FA` | | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder, Enc-Dec | N/A |
| `ROCM_AITER_FA` | | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | %16 | Any | ✅ | ✅ | ❌ | All | N/A |
| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ❌ | ✅ | ❌ | All | N/A |
| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ❌ | ✅ | ❌ | Decoder, Encoder, Encoder Only | N/A |
| `TREE_ATTN` | | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any |
| `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ✅ | ❌ | All | Any |

View File

@@ -14,13 +14,62 @@ import pytest_asyncio
import soundfile as sf
from tests.utils import RemoteOpenAIServer
from vllm.platforms import current_platform
MODEL_NAME = "openai/whisper-large-v3-turbo"
# Disable prefix caching on ROCm to reduce non-determinism in
# streaming-vs-non-streaming comparisons.
_ROCM_ARGS = ["--no-enable-prefix-caching"] if current_platform.is_rocm() else []
@pytest.fixture(scope="module")
def server():
with RemoteOpenAIServer(MODEL_NAME, []) as remote_server:
def _get_attention_backend_params() -> list[str | None]:
"""Return attention backends to parametrize the server fixture with.
On ROCm, we test multiple backends explicitly:
- None: default auto-selection (ROCM_ATTN for decoder self-attention,
falls back to ROCM_AITER_UNIFIED_ATTN or TRITON_ATTN for
cross-attention since ROCM_ATTN doesn't support ENCODER_DECODER)
- TRITON_ATTN: always available on ROCm
- ROCM_AITER_UNIFIED_ATTN: only on gfx942/gfx950
On non-ROCm platforms, we just run with the default backend.
"""
try:
from vllm.platforms import current_platform
if current_platform.is_rocm():
backends: list[str | None] = [None, "TRITON_ATTN"]
from vllm.platforms.rocm import _ON_MI3XX
if _ON_MI3XX:
backends.append("ROCM_AITER_UNIFIED_ATTN")
return backends
except Exception:
pass
return [None]
# Aiter backends need VLLM_ROCM_USE_AITER=1 (and MHA=1 for ROCM_AITER_FA)
# to be enabled in the server subprocess.
_AITER_ENV = {
"VLLM_ROCM_USE_AITER": "1",
"VLLM_ROCM_USE_AITER_MHA": "1",
}
_ATTN_BACKENDS = _get_attention_backend_params()
_ATTN_IDS = [b or "default" for b in _ATTN_BACKENDS]
@pytest.fixture(scope="module", params=_ATTN_BACKENDS, ids=_ATTN_IDS)
def server(request):
args = [*_ROCM_ARGS]
env_dict = None
if request.param is not None:
args += ["--attention-backend", request.param]
if "AITER" in request.param:
env_dict = _AITER_ENV
with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server:
yield remote_server

View File

@@ -446,7 +446,7 @@ def parse_attention_types(node: ast.ClassDef) -> str:
if not types:
return "Decoder"
return "All" if len(types) >= 3 else ", ".join(sorted(types))
return "All" if types >= set(type_map.values()) else ", ".join(sorted(types))
def parse_impl_bool_attr(

View File

@@ -439,7 +439,10 @@ class RocmPlatform(Platform):
f"this configuration. Reason: {invalid_reasons}"
)
else:
logger.info("Using %s backend.", selected_backend)
logger.info_once(
"Using %s backend (selected via --attention-backend).",
selected_backend.name,
)
return selected_backend.get_path()
# No selected backend or the selected backend is invalid,
@@ -476,12 +479,25 @@ class RocmPlatform(Platform):
)
selected_index = sorted_indices[0]
selected_backend = valid_backends_priorities[selected_index][0]
logger.info_once(
"Using %s attention backend out of potential backends: %s.",
selected_backend.name,
"[" + ", ".join(f"'{b[0].name}'" for b in valid_backends_priorities) + "]",
scope="local",
valid_str = (
"[" + ", ".join(f"'{b[0].name}'" for b in valid_backends_priorities) + "]"
)
if invalid_reasons:
rejected_str = ", ".join(b.name for b in invalid_reasons)
logger.info(
"Found incompatible backend(s) [%s] with %s. "
"Overriding with %s out of potential backends: %s.",
rejected_str,
attn_selector_config.attn_type,
selected_backend.name,
valid_str,
)
else:
logger.info_once(
"Using %s backend out of potential backends: %s.",
selected_backend.name,
valid_str,
)
return selected_backend.get_path()

View File

@@ -758,11 +758,12 @@ class AiterFlashAttentionBackend(AttentionBackend):
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""ROCM AITER FA supports decoder and encoder-decoder (cross) attention."""
return attn_type in (
AttentionType.DECODER,
AttentionType.ENCODER_DECODER,
)
"""ENCODER_DECODER is not supported because the prefill path uses
flash_attn_varlen_func with cu_seqlens_k set to decoder
query_start_loc (not encoder seq lens) and causal=True, both of
which are incorrect for cross-attention layers.
"""
return attn_type in (AttentionType.DECODER,)
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:

View File

@@ -212,12 +212,17 @@ class RocmAttentionBackend(AttentionBackend):
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""RocmAttention supports all attention types."""
"""ENCODER_DECODER is not supported because
chunked_prefill_paged_decode's prefill kernel (context_attention_fwd)
assumes self-attention semantics: it treats passed K/V as new tokens
to mix with cached K/V. For cross-attention layers the encoder K/V
are already fully cached, so mixing them again produces incorrect
results when max_query_len > 1 (e.g. beam search).
"""
return attn_type in (
AttentionType.DECODER,
AttentionType.ENCODER,
AttentionType.ENCODER_ONLY,
AttentionType.ENCODER_DECODER,
)
@staticmethod