[ROCm][CI] Fix cross-attention dispatch for encoder-decoder models (#38450)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -14,13 +14,62 @@ import pytest_asyncio
|
||||
import soundfile as sf
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# Model served by the `server` fixture for every test in this module.
MODEL_NAME = "openai/whisper-large-v3-turbo"

# Extra server CLI flags used only on ROCm: prefix caching is switched off
# there to reduce non-determinism in streaming-vs-non-streaming comparisons.
# On every other platform no extra flags are passed.
if current_platform.is_rocm():
    _ROCM_ARGS = ["--no-enable-prefix-caching"]
else:
    _ROCM_ARGS = []
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
with RemoteOpenAIServer(MODEL_NAME, []) as remote_server:
|
||||
|
||||
def _get_attention_backend_params() -> list[str | None]:
|
||||
"""Return attention backends to parametrize the server fixture with.
|
||||
|
||||
On ROCm, we test multiple backends explicitly:
|
||||
- None: default auto-selection (ROCM_ATTN for decoder self-attention,
|
||||
falls back to ROCM_AITER_UNIFIED_ATTN or TRITON_ATTN for
|
||||
cross-attention since ROCM_ATTN doesn't support ENCODER_DECODER)
|
||||
- TRITON_ATTN: always available on ROCm
|
||||
- ROCM_AITER_UNIFIED_ATTN: only on gfx942/gfx950
|
||||
|
||||
On non-ROCm platforms, we just run with the default backend.
|
||||
"""
|
||||
try:
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.is_rocm():
|
||||
backends: list[str | None] = [None, "TRITON_ATTN"]
|
||||
from vllm.platforms.rocm import _ON_MI3XX
|
||||
|
||||
if _ON_MI3XX:
|
||||
backends.append("ROCM_AITER_UNIFIED_ATTN")
|
||||
return backends
|
||||
except Exception:
|
||||
pass
|
||||
return [None]
|
||||
|
||||
|
||||
# Environment overrides for the server subprocess when an Aiter backend is
# selected: those backends are only enabled with VLLM_ROCM_USE_AITER=1
# (ROCM_AITER_FA additionally needs VLLM_ROCM_USE_AITER_MHA=1).
_AITER_ENV = {"VLLM_ROCM_USE_AITER": "1", "VLLM_ROCM_USE_AITER_MHA": "1"}

# Backends the `server` fixture is parametrized over, computed once at import.
_ATTN_BACKENDS = _get_attention_backend_params()

# Human-readable pytest ids: the backend name, or "default" for the
# auto-selected (None) entry.
_ATTN_IDS = [
    backend if backend else "default" for backend in _ATTN_BACKENDS
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=_ATTN_BACKENDS, ids=_ATTN_IDS)
def server(request):
    """Start one server per attention backend and yield it to the tests.

    ``request.param`` is either ``None`` (no ``--attention-backend`` flag is
    passed, so the server uses its default) or a backend name that is
    forwarded via ``--attention-backend``. Backend names containing
    ``"AITER"`` additionally get the ``_AITER_ENV`` environment overrides.
    The ROCm-specific flags in ``_ROCM_ARGS`` are always included.
    """
    backend = request.param

    # Copy so the shared module-level list is never mutated.
    server_args = list(_ROCM_ARGS)
    if backend is not None:
        server_args.extend(["--attention-backend", backend])

    wants_aiter = backend is not None and "AITER" in backend
    server_env = _AITER_ENV if wants_aiter else None

    with RemoteOpenAIServer(
        MODEL_NAME, server_args, env_dict=server_env
    ) as remote_server:
        yield remote_server
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user