From b9cdc85207b3af83613fb45501b666e4e76c974d Mon Sep 17 00:00:00 2001
From: Andreas Karatzas
Date: Tue, 31 Mar 2026 00:21:49 -0500
Subject: [PATCH] [ROCm][CI] Fix Whisper translation test attention backend selection (#38508)

Signed-off-by: Andreas Karatzas
---
 .../test_translation_validation.py | 45 ++++++++++++++++---
 1 file changed, 39 insertions(+), 6 deletions(-)

diff --git a/tests/entrypoints/openai/speech_to_text/test_translation_validation.py b/tests/entrypoints/openai/speech_to_text/test_translation_validation.py
index 6fb60d537..16b9614d9 100644
--- a/tests/entrypoints/openai/speech_to_text/test_translation_validation.py
+++ b/tests/entrypoints/openai/speech_to_text/test_translation_validation.py
@@ -16,10 +16,41 @@ import soundfile as sf
 
 from tests.entrypoints.openai.conftest import add_attention_backend
 from tests.utils import RemoteOpenAIServer
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
 
 SERVER_ARGS = ["--enforce-eager"]
 
 
+def _get_rocm_attention_config(model_name):
+    """Return appropriate ROCm attention config for the given model.
+
+    Whisper uses cross-attention (ENCODER_DECODER) which ROCM_AITER_FA does
+    not support. For Whisper we use ROCM_AITER_UNIFIED_ATTN (or TRITON_ATTN
+    as fallback); other models can use ROCM_AITER_FA.
+    """
+    from vllm.platforms import current_platform
+
+    if not current_platform.is_rocm():
+        return None
+
+    if "whisper" in model_name.lower():
+        try:
+            from vllm.platforms.rocm import _ON_MI3XX
+
+            if _ON_MI3XX:
+                return {"backend": "ROCM_AITER_UNIFIED_ATTN"}
+        except ImportError:
+            logger.warning(
+                "Could not import _ON_MI3XX from rocm platform, "
+                "falling back to TRITON_ATTN for Whisper."
+            )
+        return {"backend": "TRITON_ATTN"}
+
+    return {"backend": "ROCM_AITER_FA"}
+
+
 def _get_server_args(attention_config):
     """Get server args with attention backend if specified."""
     args = SERVER_ARGS.copy()
@@ -30,10 +61,11 @@ def _get_server_args(attention_config):
 @pytest.fixture(
     scope="module", params=["openai/whisper-small", "google/gemma-3n-E2B-it"]
 )
-def server(request, rocm_aiter_fa_attention):
+def server(request):
     # Parametrize over model name
+    attention_config = _get_rocm_attention_config(request.param)
     with RemoteOpenAIServer(
-        request.param, _get_server_args(rocm_aiter_fa_attention)
+        request.param, _get_server_args(attention_config)
     ) as remote_server:
         yield remote_server, request.param
 
@@ -46,11 +78,12 @@ async def client_and_model(server):
 
 
 @pytest.mark.asyncio
-async def test_non_asr_model(foscolo, rocm_aiter_fa_attention):
+async def test_non_asr_model(foscolo):
     # text to text model
     model_name = "JackFram/llama-68m"
+    attention_config = _get_rocm_attention_config(model_name)
     with RemoteOpenAIServer(
-        model_name, _get_server_args(rocm_aiter_fa_attention)
+        model_name, _get_server_args(attention_config)
     ) as remote_server:
         client = remote_server.get_async_client()
 
@@ -61,7 +94,7 @@ async def test_non_asr_model(foscolo, rocm_aiter_fa_attention):
 
 
 @pytest.mark.asyncio
-async def test_basic_audio_with_lora(mary_had_lamb, rocm_aiter_fa_attention):
+async def test_basic_audio_with_lora(mary_had_lamb):
     """Ensure STT (translate) requests can pass LoRA through to generate."""
     # ROCm SPECIFIC CONFIGURATION:
     # To ensure the test passes on ROCm, we modify the max model length to 512.
@@ -85,7 +118,7 @@ async def test_basic_audio_with_lora(mary_had_lamb, rocm_aiter_fa_attention):
         "1",
     ]
 
-    add_attention_backend(server_args, rocm_aiter_fa_attention)
+    add_attention_backend(server_args, _get_rocm_attention_config(model_name))
 
     # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
     with RemoteOpenAIServer(model_name, server_args) as remote_server: