diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py
index a6995e4ca..dca5512c0 100644
--- a/tests/entrypoints/openai/test_serving_chat.py
+++ b/tests/entrypoints/openai/test_serving_chat.py
@@ -731,6 +731,101 @@ async def test_serving_chat_should_set_correct_max_tokens():
     assert mock_engine.generate.call_args.args[1].max_tokens == 5
 
 
+@pytest.mark.asyncio
+async def test_serving_chat_mistral_token_ids_prompt_is_validated(monkeypatch_module):
+    """Regression test: when the Mistral tokenizer path returns token IDs
+    directly, we must still apply input length + max_tokens validation.
+    """
+
+    mock_engine = MagicMock(spec=AsyncLLM)
+    mock_engine.errored = False
+    mock_engine.model_config = MockModelConfig()
+    mock_engine.input_processor = MagicMock()
+    mock_engine.io_processor = MagicMock()
+
+    class DummyMistralTokenizer:
+        def decode(self, token_ids):
+            # Only used for logging/validation error messages.
+            return "dummy"
+
+    dummy_tokenizer = DummyMistralTokenizer()
+    mock_engine.get_tokenizer.return_value = dummy_tokenizer
+
+    # Patch the OpenAI engine serving module to treat our dummy tokenizer
+    # as a MistralTokenizer. This forces the code path where chat template
+    # rendering can return a list[int] (token IDs).
+    import vllm.entrypoints.openai.engine.serving as engine_serving
+
+    monkeypatch_module.setattr(
+        engine_serving, "MistralTokenizer", DummyMistralTokenizer
+    )
+
+    serving_chat = _build_serving_chat(mock_engine)
+
+    # Force the Mistral chat template renderer to return token IDs.
+    # Choose a prompt length that is < max_model_len, but large enough that
+    # adding max_tokens should exceed the model context window.
+    serving_chat._apply_mistral_chat_template_async = AsyncMock(
+        return_value=list(range(95))
+    )
+
+    req = ChatCompletionRequest(
+        model=MODEL_NAME,
+        messages=[{"role": "user", "content": "what is 1+1?"}],
+        max_tokens=10,
+    )
+
+    resp = await serving_chat.create_chat_completion(req)
+    assert isinstance(resp, ErrorResponse)
+    assert "max_tokens" in resp.error.message
+
+
+@pytest.mark.asyncio
+async def test_serving_chat_mistral_token_ids_prompt_too_long_is_rejected(
+    monkeypatch_module,
+):
+    """Regression test: MistralTokenizer token-id prompts must still enforce
+    the max context length for the input itself (token_num >= max_model_len).
+    """
+
+    mock_engine = MagicMock(spec=AsyncLLM)
+    mock_engine.errored = False
+    mock_engine.model_config = MockModelConfig()
+    mock_engine.input_processor = MagicMock()
+    mock_engine.io_processor = MagicMock()
+
+    class DummyMistralTokenizer:
+        def decode(self, token_ids):
+            return "dummy"
+
+    dummy_tokenizer = DummyMistralTokenizer()
+    mock_engine.get_tokenizer.return_value = dummy_tokenizer
+
+    import vllm.entrypoints.openai.engine.serving as engine_serving
+
+    monkeypatch_module.setattr(
+        engine_serving, "MistralTokenizer", DummyMistralTokenizer
+    )
+
+    serving_chat = _build_serving_chat(mock_engine)
+
+    # prompt_token_ids length == max_model_len should be rejected for
+    # completion-like requests (ChatCompletionRequest).
+    serving_chat._apply_mistral_chat_template_async = AsyncMock(
+        return_value=list(range(mock_engine.model_config.max_model_len))
+    )
+
+    req = ChatCompletionRequest(
+        model=MODEL_NAME,
+        messages=[{"role": "user", "content": "what is 1+1?"}],
+        max_tokens=1,
+    )
+
+    resp = await serving_chat.create_chat_completion(req)
+    assert isinstance(resp, ErrorResponse)
+    assert "maximum context length" in resp.error.message
+
+
 @pytest.mark.asyncio
 async def test_serving_chat_could_load_correct_generation_config():
     mock_model_config = MockModelConfig()
diff --git a/vllm/entrypoints/openai/engine/serving.py b/vllm/entrypoints/openai/engine/serving.py
index 26f4b725a..3d535f72d 100644
--- a/vllm/entrypoints/openai/engine/serving.py
+++ b/vllm/entrypoints/openai/engine/serving.py
@@ -1277,9 +1277,11 @@ class OpenAIServing:
             assert is_list_of(request_prompt, int), (
                 "Prompt has to be either a string or a list of token ids"
             )
-            prompt_inputs = TokensPrompt(
-                prompt=tokenizer.decode(request_prompt),
-                prompt_token_ids=request_prompt,
+            input_text = tokenizer.decode(request_prompt)
+            prompt_inputs = self._validate_input(
+                request=request,
+                input_ids=request_prompt,
+                input_text=input_text,
             )
 
         engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["prompt_token_ids"])