[CI] Fix realtime WebSocket timeout deadlock and unhandled model validation errors (#37483)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Andreas Karatzas
2026-03-25 05:24:33 -05:00
committed by GitHub
parent e9ae3f8077
commit 9ac2fcafbb
2 changed files with 83 additions and 3 deletions

View File

@@ -15,6 +15,13 @@ from tests.entrypoints.openai.conftest import add_attention_backend
from tests.utils import ROCM_ENV_OVERRIDES, ROCM_EXTRA_ARGS, RemoteOpenAIServer
from vllm.assets.audio import AudioAsset
# On ROCm, first-use JIT compilation can run past the default 60s engine
# iteration timeout, which shows up as a silent deadlock in feed_tokens.
# Layer a much larger timeout on top of the shared ROCm overrides.
REALTIME_ENV_OVERRIDES = dict(
    ROCM_ENV_OVERRIDES,
    VLLM_ENGINE_ITERATION_TIMEOUT_S="600",
)
MISTRAL_FORMAT_ARGS = [
"--tokenizer_mode",
"mistral",
@@ -77,7 +84,7 @@ async def test_multi_chunk_streaming(
add_attention_backend(server_args, rocm_aiter_fa_attention)
with RemoteOpenAIServer(
model_name, server_args, env_dict=ROCM_ENV_OVERRIDES
model_name, server_args, env_dict=REALTIME_ENV_OVERRIDES
) as remote_server:
ws_url = _get_websocket_url(remote_server)
async with websockets.connect(ws_url) as ws:
@@ -180,7 +187,7 @@ async def test_empty_commit_does_not_crash_engine(
add_attention_backend(server_args, rocm_aiter_fa_attention)
with RemoteOpenAIServer(
model_name, server_args, env_dict=ROCM_ENV_OVERRIDES
model_name, server_args, env_dict=REALTIME_ENV_OVERRIDES
) as remote_server:
ws_url = _get_websocket_url(remote_server)
@@ -257,3 +264,68 @@ async def test_empty_commit_does_not_crash_engine(
elif event["type"] == "error":
pytest.fail(f"Engine error after empty commit: {event}")
assert done_received
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_session_update_invalid_model_returns_error(
    model_name, rocm_aiter_fa_attention
):
    """Check that session.update with an unknown model yields an error event.

    Starts a server for ``model_name``, waits for the ``session.created``
    greeting on the realtime WebSocket, then sends a ``session.update`` naming
    a different model. The next event must be of type ``error`` and its
    ``error`` field must mention the rejected model name.
    """
    cli_args = ["--enforce-eager", "--max-model-len", "2048"]
    if model_name.startswith("mistralai"):
        cli_args.extend(MISTRAL_FORMAT_ARGS)
    add_attention_backend(cli_args, rocm_aiter_fa_attention)

    # Deliberately names a model other than the one being served.
    bad_update = {"type": "session.update", "model": "nonexistent-model"}

    with RemoteOpenAIServer(
        model_name, cli_args, env_dict=REALTIME_ENV_OVERRIDES
    ) as server:
        async with websockets.connect(_get_websocket_url(server)) as conn:
            greeting = await receive_event(conn, timeout=30.0)
            assert greeting["type"] == "session.created"

            await send_event(conn, bad_update)

            reply = await receive_event(conn, timeout=10.0)
            assert reply["type"] == "error"
            assert "nonexistent-model" in reply["error"]
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_commit_without_session_update_returns_error(
    model_name, rocm_aiter_fa_attention
):
    """Check that a commit sent before any session.update is rejected.

    Starts a server for ``model_name``, waits for the ``session.created``
    greeting, then sends a final ``input_audio_buffer.commit`` straight away.
    The next event must be of type ``error`` with a ``code`` containing
    ``model_not_validated``, i.e. the commit is refused rather than processed.
    """
    cli_args = ["--enforce-eager", "--max-model-len", "2048"]
    if model_name.startswith("mistralai"):
        cli_args.extend(MISTRAL_FORMAT_ARGS)
    add_attention_backend(cli_args, rocm_aiter_fa_attention)

    # No session.update precedes this commit, so the model is unvalidated.
    early_commit = {"type": "input_audio_buffer.commit", "final": True}

    with RemoteOpenAIServer(
        model_name, cli_args, env_dict=REALTIME_ENV_OVERRIDES
    ) as server:
        async with websockets.connect(_get_websocket_url(server)) as conn:
            greeting = await receive_event(conn, timeout=30.0)
            assert greeting["type"] == "session.created"

            await send_event(conn, early_commit)

            reply = await receive_event(conn, timeout=10.0)
            assert reply["type"] == "error"
            assert "model_not_validated" in reply.get("code", "")

View File

@@ -102,7 +102,14 @@ class RealtimeConnection:
event_type = event.get("type")
if event_type == "session.update":
logger.debug("Session updated: %s", event)
self._check_model(event["model"])
model = event.get("model")
if model is None:
await self.send_error("Missing required field: model", "invalid_event")
return
err = self._check_model(model)
if err is not None:
await self.send_error(err.error.message, "model_not_found")
return
self._is_model_validated = True
elif event_type == "input_audio_buffer.append":
append_event = InputAudioBufferAppend(**event)
@@ -140,6 +147,7 @@ class RealtimeConnection:
err_msg,
"model_not_validated",
)
return
commit_event = InputAudioBufferCommit(**event)
# final signals that the audio is finished