[CI] Fix realtime WebSocket timeout deadlock and unhandled model validation errors (#37483)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -15,6 +15,13 @@ from tests.entrypoints.openai.conftest import add_attention_backend
|
||||
from tests.utils import ROCM_ENV_OVERRIDES, ROCM_EXTRA_ARGS, RemoteOpenAIServer
|
||||
from vllm.assets.audio import AudioAsset
|
||||
|
||||
# Increase engine iteration timeout for ROCm where first-use JIT compilation
|
||||
# can exceed the default 60s, causing a silent deadlock in feed_tokens.
|
||||
REALTIME_ENV_OVERRIDES = {
|
||||
**ROCM_ENV_OVERRIDES,
|
||||
"VLLM_ENGINE_ITERATION_TIMEOUT_S": "600",
|
||||
}
|
||||
|
||||
MISTRAL_FORMAT_ARGS = [
|
||||
"--tokenizer_mode",
|
||||
"mistral",
|
||||
@@ -77,7 +84,7 @@ async def test_multi_chunk_streaming(
|
||||
add_attention_backend(server_args, rocm_aiter_fa_attention)
|
||||
|
||||
with RemoteOpenAIServer(
|
||||
model_name, server_args, env_dict=ROCM_ENV_OVERRIDES
|
||||
model_name, server_args, env_dict=REALTIME_ENV_OVERRIDES
|
||||
) as remote_server:
|
||||
ws_url = _get_websocket_url(remote_server)
|
||||
async with websockets.connect(ws_url) as ws:
|
||||
@@ -180,7 +187,7 @@ async def test_empty_commit_does_not_crash_engine(
|
||||
add_attention_backend(server_args, rocm_aiter_fa_attention)
|
||||
|
||||
with RemoteOpenAIServer(
|
||||
model_name, server_args, env_dict=ROCM_ENV_OVERRIDES
|
||||
model_name, server_args, env_dict=REALTIME_ENV_OVERRIDES
|
||||
) as remote_server:
|
||||
ws_url = _get_websocket_url(remote_server)
|
||||
|
||||
@@ -257,3 +264,68 @@ async def test_empty_commit_does_not_crash_engine(
|
||||
elif event["type"] == "error":
|
||||
pytest.fail(f"Engine error after empty commit: {event}")
|
||||
assert done_received
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_session_update_invalid_model_returns_error(
|
||||
model_name, rocm_aiter_fa_attention
|
||||
):
|
||||
"""Test that session.update with an invalid model returns an error."""
|
||||
server_args = ["--enforce-eager", "--max-model-len", "2048"]
|
||||
|
||||
if model_name.startswith("mistralai"):
|
||||
server_args += MISTRAL_FORMAT_ARGS
|
||||
|
||||
add_attention_backend(server_args, rocm_aiter_fa_attention)
|
||||
|
||||
with RemoteOpenAIServer(
|
||||
model_name, server_args, env_dict=REALTIME_ENV_OVERRIDES
|
||||
) as remote_server:
|
||||
ws_url = _get_websocket_url(remote_server)
|
||||
async with websockets.connect(ws_url) as ws:
|
||||
event = await receive_event(ws, timeout=30.0)
|
||||
assert event["type"] == "session.created"
|
||||
|
||||
# Send session.update with a model that doesn't exist
|
||||
await send_event(
|
||||
ws,
|
||||
{"type": "session.update", "model": "nonexistent-model"},
|
||||
)
|
||||
|
||||
event = await receive_event(ws, timeout=10.0)
|
||||
assert event["type"] == "error"
|
||||
assert "nonexistent-model" in event["error"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_commit_without_session_update_returns_error(
|
||||
model_name, rocm_aiter_fa_attention
|
||||
):
|
||||
"""Test that committing before validating the model returns an error
|
||||
and does not fall through to processing."""
|
||||
server_args = ["--enforce-eager", "--max-model-len", "2048"]
|
||||
|
||||
if model_name.startswith("mistralai"):
|
||||
server_args += MISTRAL_FORMAT_ARGS
|
||||
|
||||
add_attention_backend(server_args, rocm_aiter_fa_attention)
|
||||
|
||||
with RemoteOpenAIServer(
|
||||
model_name, server_args, env_dict=REALTIME_ENV_OVERRIDES
|
||||
) as remote_server:
|
||||
ws_url = _get_websocket_url(remote_server)
|
||||
async with websockets.connect(ws_url) as ws:
|
||||
event = await receive_event(ws, timeout=30.0)
|
||||
assert event["type"] == "session.created"
|
||||
|
||||
# Send commit without sending session.update first
|
||||
await send_event(
|
||||
ws,
|
||||
{"type": "input_audio_buffer.commit", "final": True},
|
||||
)
|
||||
|
||||
event = await receive_event(ws, timeout=10.0)
|
||||
assert event["type"] == "error"
|
||||
assert "model_not_validated" in event.get("code", "")
|
||||
|
||||
Reference in New Issue
Block a user