[ROCm][CI] Fix cross-attention dispatch for encoder-decoder models (#38450)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Andreas Karatzas
2026-03-29 00:08:03 -05:00
committed by GitHub
parent 5b8c30d62b
commit 43cc5138e5
6 changed files with 90 additions and 19 deletions

View File

@@ -14,13 +14,62 @@ import pytest_asyncio
import soundfile as sf
from tests.utils import RemoteOpenAIServer
from vllm.platforms import current_platform
MODEL_NAME = "openai/whisper-large-v3-turbo"
# Disable prefix caching on ROCm to reduce non-determinism in
# streaming-vs-non-streaming comparisons.
_ROCM_ARGS = ["--no-enable-prefix-caching"] if current_platform.is_rocm() else []
@pytest.fixture(scope="module")
def server():
with RemoteOpenAIServer(MODEL_NAME, []) as remote_server:
def _get_attention_backend_params() -> list[str | None]:
"""Return attention backends to parametrize the server fixture with.
On ROCm, we test multiple backends explicitly:
- None: default auto-selection (ROCM_ATTN for decoder self-attention,
falls back to ROCM_AITER_UNIFIED_ATTN or TRITON_ATTN for
cross-attention since ROCM_ATTN doesn't support ENCODER_DECODER)
- TRITON_ATTN: always available on ROCm
- ROCM_AITER_UNIFIED_ATTN: only on gfx942/gfx950
On non-ROCm platforms, we just run with the default backend.
"""
try:
from vllm.platforms import current_platform
if current_platform.is_rocm():
backends: list[str | None] = [None, "TRITON_ATTN"]
from vllm.platforms.rocm import _ON_MI3XX
if _ON_MI3XX:
backends.append("ROCM_AITER_UNIFIED_ATTN")
return backends
except Exception:
pass
return [None]
# Aiter backends need VLLM_ROCM_USE_AITER=1 (and MHA=1 for ROCM_AITER_FA)
# to be enabled in the server subprocess.
_AITER_ENV = {
"VLLM_ROCM_USE_AITER": "1",
"VLLM_ROCM_USE_AITER_MHA": "1",
}
_ATTN_BACKENDS = _get_attention_backend_params()
_ATTN_IDS = [b or "default" for b in _ATTN_BACKENDS]
@pytest.fixture(scope="module", params=_ATTN_BACKENDS, ids=_ATTN_IDS)
def server(request):
args = [*_ROCM_ARGS]
env_dict = None
if request.param is not None:
args += ["--attention-backend", request.param]
if "AITER" in request.param:
env_dict = _AITER_ENV
with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server:
yield remote_server