apply _validate_input to MistralTokenizer token-id chat prompts (#32448)

Signed-off-by: Vanshil Shah <vanshilshah@gmail.com>
This commit is contained in:
Vanshil Shah
2026-01-16 19:23:45 -08:00
committed by GitHub
parent 5a3050a089
commit 037a6487af
2 changed files with 100 additions and 3 deletions

View File

@@ -731,6 +731,101 @@ async def test_serving_chat_should_set_correct_max_tokens():
assert mock_engine.generate.call_args.args[1].max_tokens == 5
@pytest.mark.asyncio
async def test_serving_chat_mistral_token_ids_prompt_is_validated(monkeypatch_module):
    """Regression test: when the Mistral tokenizer path returns token IDs
    directly, we must still apply input length + max_tokens validation.
    """
    mock_engine = MagicMock(spec=AsyncLLM)
    mock_engine.errored = False
    mock_engine.model_config = MockModelConfig()
    mock_engine.input_processor = MagicMock()
    mock_engine.io_processor = MagicMock()

    class DummyMistralTokenizer:
        def decode(self, token_ids):
            # Only used for logging/validation error messages.
            return "dummy"

    dummy_tokenizer = DummyMistralTokenizer()
    mock_engine.get_tokenizer.return_value = dummy_tokenizer

    # Patch the OpenAI engine serving module to treat our dummy tokenizer
    # as a MistralTokenizer. This forces the code path where chat template
    # rendering can return a list[int] (token IDs).
    import vllm.entrypoints.openai.engine.serving as engine_serving

    monkeypatch_module.setattr(
        engine_serving, "MistralTokenizer", DummyMistralTokenizer
    )

    serving_chat = _build_serving_chat(mock_engine)

    # Force the Mistral chat template renderer to return token IDs.
    # Derive the prompt length from the mock config's max_model_len rather
    # than hard-coding it: the prompt alone fits (< max_model_len), but
    # prompt + max_tokens overflows the context window, which must be
    # reported as a max_tokens validation error.
    max_model_len = mock_engine.model_config.max_model_len
    max_tokens = 10
    serving_chat._apply_mistral_chat_template_async = AsyncMock(
        return_value=list(range(max_model_len - max_tokens // 2))
    )

    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "what is 1+1?"}],
        max_tokens=max_tokens,
    )

    resp = await serving_chat.create_chat_completion(req)
    assert isinstance(resp, ErrorResponse)
    assert "max_tokens" in resp.error.message
@pytest.mark.asyncio
async def test_serving_chat_mistral_token_ids_prompt_too_long_is_rejected(
    monkeypatch_module,
):
    """Regression test: MistralTokenizer token-id prompts must still enforce
    the max context length for the input itself (token_num >= max_model_len).
    """
    engine = MagicMock(spec=AsyncLLM)
    engine.errored = False
    engine.model_config = MockModelConfig()
    engine.input_processor = MagicMock()
    engine.io_processor = MagicMock()

    class DummyMistralTokenizer:
        def decode(self, token_ids):
            return "dummy"

    engine.get_tokenizer.return_value = DummyMistralTokenizer()

    # Swap the real MistralTokenizer for the dummy so isinstance checks in
    # the serving module route through the token-id (list[int]) chat path.
    import vllm.entrypoints.openai.engine.serving as engine_serving

    monkeypatch_module.setattr(
        engine_serving, "MistralTokenizer", DummyMistralTokenizer
    )

    chat = _build_serving_chat(engine)

    # A prompt of exactly max_model_len tokens must be rejected for
    # completion-like requests (ChatCompletionRequest).
    context_limit = engine.model_config.max_model_len
    chat._apply_mistral_chat_template_async = AsyncMock(
        return_value=list(range(context_limit))
    )

    request = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": "what is 1+1?"}],
        max_tokens=1,
    )

    response = await chat.create_chat_completion(request)
    assert isinstance(response, ErrorResponse)
    assert "maximum context length" in response.error.message
@pytest.mark.asyncio
async def test_serving_chat_could_load_correct_generation_config():
mock_model_config = MockModelConfig()