[CI] Fix realtime WebSocket timeout deadlock and unhandled model validation errors (#37483)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -15,6 +15,13 @@ from tests.entrypoints.openai.conftest import add_attention_backend
|
||||
from tests.utils import ROCM_ENV_OVERRIDES, ROCM_EXTRA_ARGS, RemoteOpenAIServer
|
||||
from vllm.assets.audio import AudioAsset
|
||||
|
||||
# Increase engine iteration timeout for ROCm where first-use JIT compilation
|
||||
# can exceed the default 60s, causing a silent deadlock in feed_tokens.
|
||||
REALTIME_ENV_OVERRIDES = {
|
||||
**ROCM_ENV_OVERRIDES,
|
||||
"VLLM_ENGINE_ITERATION_TIMEOUT_S": "600",
|
||||
}
|
||||
|
||||
MISTRAL_FORMAT_ARGS = [
|
||||
"--tokenizer_mode",
|
||||
"mistral",
|
||||
@@ -77,7 +84,7 @@ async def test_multi_chunk_streaming(
|
||||
add_attention_backend(server_args, rocm_aiter_fa_attention)
|
||||
|
||||
with RemoteOpenAIServer(
|
||||
model_name, server_args, env_dict=ROCM_ENV_OVERRIDES
|
||||
model_name, server_args, env_dict=REALTIME_ENV_OVERRIDES
|
||||
) as remote_server:
|
||||
ws_url = _get_websocket_url(remote_server)
|
||||
async with websockets.connect(ws_url) as ws:
|
||||
@@ -180,7 +187,7 @@ async def test_empty_commit_does_not_crash_engine(
|
||||
add_attention_backend(server_args, rocm_aiter_fa_attention)
|
||||
|
||||
with RemoteOpenAIServer(
|
||||
model_name, server_args, env_dict=ROCM_ENV_OVERRIDES
|
||||
model_name, server_args, env_dict=REALTIME_ENV_OVERRIDES
|
||||
) as remote_server:
|
||||
ws_url = _get_websocket_url(remote_server)
|
||||
|
||||
@@ -257,3 +264,68 @@ async def test_empty_commit_does_not_crash_engine(
|
||||
elif event["type"] == "error":
|
||||
pytest.fail(f"Engine error after empty commit: {event}")
|
||||
assert done_received
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_session_update_invalid_model_returns_error(
|
||||
model_name, rocm_aiter_fa_attention
|
||||
):
|
||||
"""Test that session.update with an invalid model returns an error."""
|
||||
server_args = ["--enforce-eager", "--max-model-len", "2048"]
|
||||
|
||||
if model_name.startswith("mistralai"):
|
||||
server_args += MISTRAL_FORMAT_ARGS
|
||||
|
||||
add_attention_backend(server_args, rocm_aiter_fa_attention)
|
||||
|
||||
with RemoteOpenAIServer(
|
||||
model_name, server_args, env_dict=REALTIME_ENV_OVERRIDES
|
||||
) as remote_server:
|
||||
ws_url = _get_websocket_url(remote_server)
|
||||
async with websockets.connect(ws_url) as ws:
|
||||
event = await receive_event(ws, timeout=30.0)
|
||||
assert event["type"] == "session.created"
|
||||
|
||||
# Send session.update with a model that doesn't exist
|
||||
await send_event(
|
||||
ws,
|
||||
{"type": "session.update", "model": "nonexistent-model"},
|
||||
)
|
||||
|
||||
event = await receive_event(ws, timeout=10.0)
|
||||
assert event["type"] == "error"
|
||||
assert "nonexistent-model" in event["error"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_commit_without_session_update_returns_error(
|
||||
model_name, rocm_aiter_fa_attention
|
||||
):
|
||||
"""Test that committing before validating the model returns an error
|
||||
and does not fall through to processing."""
|
||||
server_args = ["--enforce-eager", "--max-model-len", "2048"]
|
||||
|
||||
if model_name.startswith("mistralai"):
|
||||
server_args += MISTRAL_FORMAT_ARGS
|
||||
|
||||
add_attention_backend(server_args, rocm_aiter_fa_attention)
|
||||
|
||||
with RemoteOpenAIServer(
|
||||
model_name, server_args, env_dict=REALTIME_ENV_OVERRIDES
|
||||
) as remote_server:
|
||||
ws_url = _get_websocket_url(remote_server)
|
||||
async with websockets.connect(ws_url) as ws:
|
||||
event = await receive_event(ws, timeout=30.0)
|
||||
assert event["type"] == "session.created"
|
||||
|
||||
# Send commit without sending session.update first
|
||||
await send_event(
|
||||
ws,
|
||||
{"type": "input_audio_buffer.commit", "final": True},
|
||||
)
|
||||
|
||||
event = await receive_event(ws, timeout=10.0)
|
||||
assert event["type"] == "error"
|
||||
assert "model_not_validated" in event.get("code", "")
|
||||
|
||||
@@ -102,7 +102,14 @@ class RealtimeConnection:
|
||||
event_type = event.get("type")
|
||||
if event_type == "session.update":
|
||||
logger.debug("Session updated: %s", event)
|
||||
self._check_model(event["model"])
|
||||
model = event.get("model")
|
||||
if model is None:
|
||||
await self.send_error("Missing required field: model", "invalid_event")
|
||||
return
|
||||
err = self._check_model(model)
|
||||
if err is not None:
|
||||
await self.send_error(err.error.message, "model_not_found")
|
||||
return
|
||||
self._is_model_validated = True
|
||||
elif event_type == "input_audio_buffer.append":
|
||||
append_event = InputAudioBufferAppend(**event)
|
||||
@@ -140,6 +147,7 @@ class RealtimeConnection:
|
||||
err_msg,
|
||||
"model_not_validated",
|
||||
)
|
||||
return
|
||||
|
||||
commit_event = InputAudioBufferCommit(**event)
|
||||
# final signals that the audio is finished
|
||||
|
||||
Reference in New Issue
Block a user