[ROCm][CI] Fix Whisper translation test attention backend selection (#38508)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -16,10 +16,41 @@ import soundfile as sf
|
|||||||
|
|
||||||
from tests.entrypoints.openai.conftest import add_attention_backend
|
from tests.entrypoints.openai.conftest import add_attention_backend
|
||||||
from tests.utils import RemoteOpenAIServer
|
from tests.utils import RemoteOpenAIServer
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
SERVER_ARGS = ["--enforce-eager"]
|
SERVER_ARGS = ["--enforce-eager"]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_rocm_attention_config(model_name):
|
||||||
|
"""Return appropriate ROCm attention config for the given model.
|
||||||
|
|
||||||
|
Whisper uses cross-attention (ENCODER_DECODER) which ROCM_AITER_FA does
|
||||||
|
not support. For Whisper we use ROCM_AITER_UNIFIED_ATTN (or TRITON_ATTN
|
||||||
|
as fallback); other models can use ROCM_AITER_FA.
|
||||||
|
"""
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
if not current_platform.is_rocm():
|
||||||
|
return None
|
||||||
|
|
||||||
|
if "whisper" in model_name.lower():
|
||||||
|
try:
|
||||||
|
from vllm.platforms.rocm import _ON_MI3XX
|
||||||
|
|
||||||
|
if _ON_MI3XX:
|
||||||
|
return {"backend": "ROCM_AITER_UNIFIED_ATTN"}
|
||||||
|
except ImportError:
|
||||||
|
logger.warning(
|
||||||
|
"Could not import _ON_MI3XX from rocm platform, "
|
||||||
|
"falling back to TRITON_ATTN for Whisper."
|
||||||
|
)
|
||||||
|
return {"backend": "TRITON_ATTN"}
|
||||||
|
|
||||||
|
return {"backend": "ROCM_AITER_FA"}
|
||||||
|
|
||||||
|
|
||||||
def _get_server_args(attention_config):
|
def _get_server_args(attention_config):
|
||||||
"""Get server args with attention backend if specified."""
|
"""Get server args with attention backend if specified."""
|
||||||
args = SERVER_ARGS.copy()
|
args = SERVER_ARGS.copy()
|
||||||
@@ -30,10 +61,11 @@ def _get_server_args(attention_config):
|
|||||||
@pytest.fixture(
|
@pytest.fixture(
|
||||||
scope="module", params=["openai/whisper-small", "google/gemma-3n-E2B-it"]
|
scope="module", params=["openai/whisper-small", "google/gemma-3n-E2B-it"]
|
||||||
)
|
)
|
||||||
def server(request, rocm_aiter_fa_attention):
|
def server(request):
|
||||||
# Parametrize over model name
|
# Parametrize over model name
|
||||||
|
attention_config = _get_rocm_attention_config(request.param)
|
||||||
with RemoteOpenAIServer(
|
with RemoteOpenAIServer(
|
||||||
request.param, _get_server_args(rocm_aiter_fa_attention)
|
request.param, _get_server_args(attention_config)
|
||||||
) as remote_server:
|
) as remote_server:
|
||||||
yield remote_server, request.param
|
yield remote_server, request.param
|
||||||
|
|
||||||
@@ -46,11 +78,12 @@ async def client_and_model(server):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_non_asr_model(foscolo, rocm_aiter_fa_attention):
|
async def test_non_asr_model(foscolo):
|
||||||
# text to text model
|
# text to text model
|
||||||
model_name = "JackFram/llama-68m"
|
model_name = "JackFram/llama-68m"
|
||||||
|
attention_config = _get_rocm_attention_config(model_name)
|
||||||
with RemoteOpenAIServer(
|
with RemoteOpenAIServer(
|
||||||
model_name, _get_server_args(rocm_aiter_fa_attention)
|
model_name, _get_server_args(attention_config)
|
||||||
) as remote_server:
|
) as remote_server:
|
||||||
client = remote_server.get_async_client()
|
client = remote_server.get_async_client()
|
||||||
|
|
||||||
@@ -61,7 +94,7 @@ async def test_non_asr_model(foscolo, rocm_aiter_fa_attention):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_basic_audio_with_lora(mary_had_lamb, rocm_aiter_fa_attention):
|
async def test_basic_audio_with_lora(mary_had_lamb):
|
||||||
"""Ensure STT (translate) requests can pass LoRA through to generate."""
|
"""Ensure STT (translate) requests can pass LoRA through to generate."""
|
||||||
# ROCm SPECIFIC CONFIGURATION:
|
# ROCm SPECIFIC CONFIGURATION:
|
||||||
# To ensure the test passes on ROCm, we modify the max model length to 512.
|
# To ensure the test passes on ROCm, we modify the max model length to 512.
|
||||||
@@ -85,7 +118,7 @@ async def test_basic_audio_with_lora(mary_had_lamb, rocm_aiter_fa_attention):
|
|||||||
"1",
|
"1",
|
||||||
]
|
]
|
||||||
|
|
||||||
add_attention_backend(server_args, rocm_aiter_fa_attention)
|
add_attention_backend(server_args, _get_rocm_attention_config(model_name))
|
||||||
|
|
||||||
# Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
|
# Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
|
||||||
with RemoteOpenAIServer(model_name, server_args) as remote_server:
|
with RemoteOpenAIServer(model_name, server_args) as remote_server:
|
||||||
|
|||||||
Reference in New Issue
Block a user