[ROCm][CI] Fix Whisper translation test attention backend selection (#38508)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Andreas Karatzas
2026-03-31 00:21:49 -05:00
committed by GitHub
parent 3e802e8786
commit b9cdc85207

View File

@@ -16,10 +16,41 @@ import soundfile as sf
from tests.entrypoints.openai.conftest import add_attention_backend
from tests.utils import RemoteOpenAIServer
from vllm.logger import init_logger
logger = init_logger(__name__)
SERVER_ARGS = ["--enforce-eager"]
def _get_rocm_attention_config(model_name):
    """Pick the ROCm attention backend config appropriate for *model_name*.

    Whisper relies on cross-attention (ENCODER_DECODER), which the
    ROCM_AITER_FA backend does not support. On MI3xx hardware Whisper is
    served with ROCM_AITER_UNIFIED_ATTN, otherwise with TRITON_ATTN; every
    other model keeps ROCM_AITER_FA. Returns None off ROCm.
    """
    from vllm.platforms import current_platform

    # Non-ROCm platforms need no override at all.
    if not current_platform.is_rocm():
        return None

    # Only Whisper needs the cross-attention-capable backends.
    if "whisper" not in model_name.lower():
        return {"backend": "ROCM_AITER_FA"}

    chosen = "TRITON_ATTN"
    try:
        from vllm.platforms.rocm import _ON_MI3XX
    except ImportError:
        # Keep the TRITON_ATTN fallback if the platform flag is unavailable.
        logger.warning(
            "Could not import _ON_MI3XX from rocm platform, "
            "falling back to TRITON_ATTN for Whisper."
        )
    else:
        if _ON_MI3XX:
            chosen = "ROCM_AITER_UNIFIED_ATTN"
    return {"backend": chosen}
def _get_server_args(attention_config):
"""Get server args with attention backend if specified."""
args = SERVER_ARGS.copy()
@@ -30,10 +61,11 @@ def _get_server_args(attention_config):
@pytest.fixture(
scope="module", params=["openai/whisper-small", "google/gemma-3n-E2B-it"]
)
def server(request, rocm_aiter_fa_attention):
def server(request):
# Parametrize over model name
attention_config = _get_rocm_attention_config(request.param)
with RemoteOpenAIServer(
request.param, _get_server_args(rocm_aiter_fa_attention)
request.param, _get_server_args(attention_config)
) as remote_server:
yield remote_server, request.param
@@ -46,11 +78,12 @@ async def client_and_model(server):
@pytest.mark.asyncio
async def test_non_asr_model(foscolo, rocm_aiter_fa_attention):
async def test_non_asr_model(foscolo):
# text to text model
model_name = "JackFram/llama-68m"
attention_config = _get_rocm_attention_config(model_name)
with RemoteOpenAIServer(
model_name, _get_server_args(rocm_aiter_fa_attention)
model_name, _get_server_args(attention_config)
) as remote_server:
client = remote_server.get_async_client()
@@ -61,7 +94,7 @@ async def test_non_asr_model(foscolo, rocm_aiter_fa_attention):
@pytest.mark.asyncio
async def test_basic_audio_with_lora(mary_had_lamb, rocm_aiter_fa_attention):
async def test_basic_audio_with_lora(mary_had_lamb):
"""Ensure STT (translate) requests can pass LoRA through to generate."""
# ROCm SPECIFIC CONFIGURATION:
# To ensure the test passes on ROCm, we modify the max model length to 512.
@@ -85,7 +118,7 @@ async def test_basic_audio_with_lora(mary_had_lamb, rocm_aiter_fa_attention):
"1",
]
add_attention_backend(server_args, rocm_aiter_fa_attention)
add_attention_backend(server_args, _get_rocm_attention_config(model_name))
# Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
with RemoteOpenAIServer(model_name, server_args) as remote_server: