diff --git a/tests/entrypoints/openai/realtime/test_realtime_validation.py b/tests/entrypoints/openai/realtime/test_realtime_validation.py index 672894d0c..bb6b02f5c 100644 --- a/tests/entrypoints/openai/realtime/test_realtime_validation.py +++ b/tests/entrypoints/openai/realtime/test_realtime_validation.py @@ -15,6 +15,13 @@ from tests.entrypoints.openai.conftest import add_attention_backend from tests.utils import ROCM_ENV_OVERRIDES, ROCM_EXTRA_ARGS, RemoteOpenAIServer from vllm.assets.audio import AudioAsset +# Increase engine iteration timeout for ROCm where first-use JIT compilation +# can exceed the default 60s, causing a silent deadlock in feed_tokens. +REALTIME_ENV_OVERRIDES = { + **ROCM_ENV_OVERRIDES, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": "600", +} + MISTRAL_FORMAT_ARGS = [ "--tokenizer_mode", "mistral", @@ -77,7 +84,7 @@ async def test_multi_chunk_streaming( add_attention_backend(server_args, rocm_aiter_fa_attention) with RemoteOpenAIServer( - model_name, server_args, env_dict=ROCM_ENV_OVERRIDES + model_name, server_args, env_dict=REALTIME_ENV_OVERRIDES ) as remote_server: ws_url = _get_websocket_url(remote_server) async with websockets.connect(ws_url) as ws: @@ -180,7 +187,7 @@ async def test_empty_commit_does_not_crash_engine( add_attention_backend(server_args, rocm_aiter_fa_attention) with RemoteOpenAIServer( - model_name, server_args, env_dict=ROCM_ENV_OVERRIDES + model_name, server_args, env_dict=REALTIME_ENV_OVERRIDES ) as remote_server: ws_url = _get_websocket_url(remote_server) @@ -257,3 +264,68 @@ async def test_empty_commit_does_not_crash_engine( elif event["type"] == "error": pytest.fail(f"Engine error after empty commit: {event}") assert done_received + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_session_update_invalid_model_returns_error( + model_name, rocm_aiter_fa_attention +): + """Test that session.update with an invalid model returns an error.""" + server_args = ["--enforce-eager", "--max-model-len", "2048"] + + if model_name.startswith("mistralai"): + server_args += MISTRAL_FORMAT_ARGS + + add_attention_backend(server_args, rocm_aiter_fa_attention) + + with RemoteOpenAIServer( + model_name, server_args, env_dict=REALTIME_ENV_OVERRIDES + ) as remote_server: + ws_url = _get_websocket_url(remote_server) + async with websockets.connect(ws_url) as ws: + event = await receive_event(ws, timeout=30.0) + assert event["type"] == "session.created" + + # Send session.update with a model that doesn't exist + await send_event( + ws, + {"type": "session.update", "model": "nonexistent-model"}, + ) + + event = await receive_event(ws, timeout=10.0) + assert event["type"] == "error" + assert "nonexistent-model" in event["error"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_commit_without_session_update_returns_error( + model_name, rocm_aiter_fa_attention +): + """Test that committing before validating the model returns an error + and does not fall through to processing.""" + server_args = ["--enforce-eager", "--max-model-len", "2048"] + + if model_name.startswith("mistralai"): + server_args += MISTRAL_FORMAT_ARGS + + add_attention_backend(server_args, rocm_aiter_fa_attention) + + with RemoteOpenAIServer( + model_name, server_args, env_dict=REALTIME_ENV_OVERRIDES + ) as remote_server: + ws_url = _get_websocket_url(remote_server) + async with websockets.connect(ws_url) as ws: + event = await receive_event(ws, timeout=30.0) + assert event["type"] == "session.created" + + # Send commit without sending session.update first + await send_event( + ws, + {"type": "input_audio_buffer.commit", "final": True}, + ) + + event = await receive_event(ws, timeout=10.0) + assert event["type"] == "error" + assert "model_not_validated" in event.get("code", "") diff --git a/vllm/entrypoints/openai/realtime/connection.py b/vllm/entrypoints/openai/realtime/connection.py index c958004bb..58af32905 100644 --- a/vllm/entrypoints/openai/realtime/connection.py +++ b/vllm/entrypoints/openai/realtime/connection.py @@ -102,7 +102,14 @@ class RealtimeConnection: event_type = event.get("type") if event_type == "session.update": logger.debug("Session updated: %s", event) - self._check_model(event["model"]) + model = event.get("model") + if model is None: + await self.send_error("Missing required field: model", "invalid_event") + return + err = self._check_model(model) + if err is not None: + await self.send_error(err.error.message, "model_not_found") + return self._is_model_validated = True elif event_type == "input_audio_buffer.append": append_event = InputAudioBufferAppend(**event) @@ -140,6 +147,7 @@ class RealtimeConnection: err_msg, "model_not_validated", ) + return commit_event = InputAudioBufferCommit(**event) # final signals that the audio is finished