[V1] Structured Outputs + Thinking compatibility (#16577)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
@@ -1,3 +1,4 @@
# ruff: noqa: E501
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

@@ -5,17 +6,22 @@ from __future__ import annotations
import json
import re
from enum import Enum
from typing import Any
from typing import TYPE_CHECKING, Any

import jsonschema
import pytest
from pydantic import BaseModel

from tests.reasoning.utils import run_reasoning_extraction
from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

if TYPE_CHECKING:
    from vllm.config import TokenizerMode

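# Speculative decoding setup shared by the spec-decode test cases below;
# "[ngram]" selects vLLM's draft-model-free ngram proposer.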
NGRAM_SPEC_CONFIG = {
    "model": "[ngram]",
    "num_speculative_tokens": 5,
@@ -444,7 +450,7 @@ def test_structured_output(

    prompt = """
You have access to the following function to retrieve the weather in a city:

{
    "name": "get_weather",
    "parameters": {
@@ -455,7 +461,7 @@ You have access to the following function to retrieve the weather in a city:
        }
    }
}

If you choose to call a function, ONLY reply in the following format:
<{start_tag}={function_name}>{parameters}{end_tag}
where
@@ -476,7 +482,7 @@ Reminder:
- Always add your sources when using search results to answer the user query

You are a helpful assistant.

Given the previous instructions, what is the weather in New York City? \
Make the response as short as possible.
"""
@@ -514,6 +520,88 @@ Make the response as short as possible.
                        f"{generated_text!r}\nError: {str(e)}")


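# Each case is (model_name, guided_decoding_backend, tokenizer_mode,
# reasoning_parser, speculative_config); xgrammar is exercised both with and
# without ngram speculative decoding.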
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize(
    "model_name, guided_decoding_backend, tokenizer_mode, reasoning_parser, speculative_config",  # noqa: E501
    [
        ("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", "xgrammar", "auto",
         "deepseek_r1", NGRAM_SPEC_CONFIG),
        ("Qwen/Qwen3-1.7B", "xgrammar", "auto", "deepseek_r1", None),
    ],
)
def test_structured_output_with_reasoning_matrices(
    monkeypatch: pytest.MonkeyPatch,
    guided_decoding_backend: str,
    tokenizer_mode: TokenizerMode,
    reasoning_parser: str,
    model_name: str,
    speculative_config: dict[str, Any] | None,
):
    monkeypatch.setenv("VLLM_USE_V1", "1")

    if current_platform.is_tpu() and speculative_config:
        pytest.skip("TPU does not support speculative decoding")

    # Use a single LLM instance for several scenarios to
    # speed up the test suite.
    llm = LLM(
        model=model_name,
        # Don't use eager execution on TPUs because we want to test for no
        # recompilation at runtime
        enforce_eager=bool(not current_platform.is_tpu()),
        max_model_len=1024,
        max_num_seqs=16,
        guided_decoding_backend=guided_decoding_backend,
        guided_decoding_disable_any_whitespace=True,
        tokenizer_mode=tokenizer_mode,
        reasoning_parser=reasoning_parser,
        speculative_config=speculative_config,
    )
    tokenizer = llm.get_tokenizer(None)
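    # Build the configured reasoning parser so the generated text can be
    # split into a thinking segment and the final answer.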
    reasoner = ReasoningParserManager.get_reasoning_parser(reasoning_parser)(
        tokenizer=tokenizer)

    reasoning_prompt = "Solve the following math problem step-by-step, then provide the final answer as a JSON object with a single key 'result'. Make sure to correct your reasoning if any issues arise.\nProblem: What is 5 * 8 + 2?"  # noqa: E501
    reasoning_schema = {
        "type": "object",
        "properties": {
            "result": {
                "type": "integer"
            }
        },
        "required": ["result"],
        "additionalProperties": False
    }
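    # Qwen3 models emit their reasoning after an opening <think> tag, so the
    # tag is appended to steer the raw-prompt completion into reasoning mode.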
    if "Qwen3" in model_name:
        reasoning_prompt += "<think>\n"

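    # Guided decoding should constrain only the final answer; reasoning
    # tokens must stream unconstrained until the parser marks the end of
    # thinking. That interplay is what this test verifies.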
    sampling_params = SamplingParams(
        temperature=0.1,
        max_tokens=8192,
        guided_decoding=GuidedDecodingParams(json=reasoning_schema),
    )
    outputs = llm.generate(
        [reasoning_prompt],
        sampling_params=sampling_params,
        use_tqdm=True,
    )

    assert outputs is not None
    output = outputs[0]
    assert output is not None and isinstance(output, RequestOutput)
    prompt = output.prompt
    generated_text = output.outputs[0].text
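    # Separate the reasoning segment from the final content using the
    # parser's think-token delimiters.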
    reasoning_content, content = run_reasoning_extraction(
        reasoner, [generated_text])
    print(
        f"Prompt: {prompt!r}\nReasoning: {reasoning_content!r}\nContent: {content!r}"
    )

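    # Only the post-reasoning content must parse as schema-conformant JSON;
    # the reasoning text itself is unconstrained.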
    assert content is not None and reasoning_content is not None
    output_json = json.loads(content)
    jsonschema.validate(instance=output_json, schema=reasoning_schema)


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("model_name, tokenizer_mode",
                         PARAMS_MODELS_TOKENIZER_MODE)