[ROCm][CI] Fix Whisper translation test attention backend selection (#38508)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -16,10 +16,41 @@ import soundfile as sf
|
||||
|
||||
from tests.entrypoints.openai.conftest import add_attention_backend
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
SERVER_ARGS = ["--enforce-eager"]
|
||||
|
||||
|
||||
def _get_rocm_attention_config(model_name):
    """Return the attention-backend config dict for *model_name* on ROCm.

    Whisper uses cross-attention (ENCODER_DECODER), which the ROCM_AITER_FA
    backend does not support. Whisper therefore gets ROCM_AITER_UNIFIED_ATTN
    on MI3xx hardware and TRITON_ATTN everywhere else; all other models get
    ROCM_AITER_FA. On non-ROCm platforms this returns ``None`` (no backend
    override).
    """
    # Imported lazily so non-ROCm environments never touch platform probing.
    from vllm.platforms import current_platform

    if not current_platform.is_rocm():
        return None

    # Non-Whisper models can use the AITER flash-attention backend directly.
    if "whisper" not in model_name.lower():
        return {"backend": "ROCM_AITER_FA"}

    try:
        from vllm.platforms.rocm import _ON_MI3XX
    except ImportError:
        logger.warning(
            "Could not import _ON_MI3XX from rocm platform, "
            "falling back to TRITON_ATTN for Whisper."
        )
    else:
        if _ON_MI3XX:
            return {"backend": "ROCM_AITER_UNIFIED_ATTN"}

    # Fallback for Whisper on non-MI3xx hardware or when the probe fails.
    return {"backend": "TRITON_ATTN"}
|
||||
|
||||
|
||||
def _get_server_args(attention_config):
|
||||
"""Get server args with attention backend if specified."""
|
||||
args = SERVER_ARGS.copy()
|
||||
@@ -30,10 +61,11 @@ def _get_server_args(attention_config):
|
||||
@pytest.fixture(
    scope="module", params=["openai/whisper-small", "google/gemma-3n-E2B-it"]
)
def server(request):
    """Launch a RemoteOpenAIServer for each parametrized model.

    Yields a ``(remote_server, model_name)`` tuple so tests can address
    both the running server and the model it serves. The attention backend
    is chosen per-model via :func:`_get_rocm_attention_config`.
    """
    model_name = request.param
    backend_config = _get_rocm_attention_config(model_name)
    server_args = _get_server_args(backend_config)
    with RemoteOpenAIServer(model_name, server_args) as remote_server:
        yield remote_server, model_name
||||
|
||||
@@ -46,11 +78,12 @@ async def client_and_model(server):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_asr_model(foscolo, rocm_aiter_fa_attention):
|
||||
async def test_non_asr_model(foscolo):
|
||||
# text to text model
|
||||
model_name = "JackFram/llama-68m"
|
||||
attention_config = _get_rocm_attention_config(model_name)
|
||||
with RemoteOpenAIServer(
|
||||
model_name, _get_server_args(rocm_aiter_fa_attention)
|
||||
model_name, _get_server_args(attention_config)
|
||||
) as remote_server:
|
||||
client = remote_server.get_async_client()
|
||||
|
||||
@@ -61,7 +94,7 @@ async def test_non_asr_model(foscolo, rocm_aiter_fa_attention):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_audio_with_lora(mary_had_lamb, rocm_aiter_fa_attention):
|
||||
async def test_basic_audio_with_lora(mary_had_lamb):
|
||||
"""Ensure STT (translate) requests can pass LoRA through to generate."""
|
||||
# ROCm SPECIFIC CONFIGURATION:
|
||||
# To ensure the test passes on ROCm, we modify the max model length to 512.
|
||||
@@ -85,7 +118,7 @@ async def test_basic_audio_with_lora(mary_had_lamb, rocm_aiter_fa_attention):
|
||||
"1",
|
||||
]
|
||||
|
||||
add_attention_backend(server_args, rocm_aiter_fa_attention)
|
||||
add_attention_backend(server_args, _get_rocm_attention_config(model_name))
|
||||
|
||||
# Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
|
||||
with RemoteOpenAIServer(model_name, server_args) as remote_server:
|
||||
|
||||
Reference in New Issue
Block a user