diff --git a/tests/entrypoints/openai/responses/test_sampling_params.py b/tests/entrypoints/openai/responses/test_sampling_params.py index b8d1aa664..87910271d 100644 --- a/tests/entrypoints/openai/responses/test_sampling_params.py +++ b/tests/entrypoints/openai/responses/test_sampling_params.py @@ -4,8 +4,17 @@ """Unit tests for ResponsesRequest.to_sampling_params() parameter mapping.""" import pytest +import torch +from openai.types.responses.response_format_text_json_schema_config import ( + ResponseFormatTextJSONSchemaConfig, +) +from pydantic import ValidationError -from vllm.entrypoints.openai.responses.protocol import ResponsesRequest +from vllm.entrypoints.openai.responses.protocol import ( + ResponsesRequest, + ResponseTextConfig, +) +from vllm.sampling_params import StructuredOutputsParams class TestResponsesRequestSamplingParams: @@ -76,9 +85,6 @@ class TestResponsesRequestSamplingParams: def test_seed_bounds_validation(self): """Test that seed values outside torch.long bounds are rejected.""" - import torch - from pydantic import ValidationError - # Test seed below minimum with pytest.raises(ValidationError) as exc_info: ResponsesRequest( @@ -111,3 +117,40 @@ class TestResponsesRequestSamplingParams: seed=torch.iinfo(torch.long).max, ) assert request_max.seed == torch.iinfo(torch.long).max + + def test_structured_outputs_passed_through(self): + """Test that structured_outputs field is passed to SamplingParams.""" + structured_outputs = StructuredOutputsParams(grammar="root ::= 'hello'") + request = ResponsesRequest( + model="test-model", + input="test input", + structured_outputs=structured_outputs, + ) + + sampling_params = request.to_sampling_params(default_max_tokens=1000) + + assert sampling_params.structured_outputs is not None + assert sampling_params.structured_outputs.grammar == "root ::= 'hello'" + + def test_structured_outputs_and_json_schema_conflict(self): + """Test that specifying both structured_outputs and json_schema raises.""" + structured_outputs = StructuredOutputsParams(grammar="root ::= 'hello'") + text_config = ResponseTextConfig() + text_config.format = ResponseFormatTextJSONSchemaConfig( + type="json_schema", + name="test", + schema={"type": "object"}, + ) + request = ResponsesRequest( + model="test-model", + input="test input", + structured_outputs=structured_outputs, + text=text_config, + ) + + with pytest.raises(ValueError) as exc_info: + request.to_sampling_params(default_max_tokens=1000) + + assert "Cannot specify both structured_outputs and text.format" in str( + exc_info.value + ) diff --git a/vllm/entrypoints/openai/responses/protocol.py b/vllm/entrypoints/openai/responses/protocol.py index 9a471852b..2b62d7dca 100644 --- a/vllm/entrypoints/openai/responses/protocol.py +++ b/vllm/entrypoints/openai/responses/protocol.py @@ -233,6 +233,10 @@ class ResponsesRequest(OpenAIBaseModel): # this cannot be used in conjunction with previous_response_id # TODO: consider supporting non harmony messages as well previous_input_messages: list[OpenAIHarmonyMessage | dict] | None = None + structured_outputs: StructuredOutputsParams | None = Field( + default=None, + description="Additional kwargs for structured outputs", + ) repetition_penalty: float | None = None seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) @@ -319,8 +323,14 @@ class ResponsesRequest(OpenAIBaseModel): stop_token_ids = default_sampling_params.get("stop_token_ids") # Structured output - structured_outputs = None + structured_outputs = self.structured_outputs + + # Also check text.format for OpenAI-style json_schema if self.text is not None and self.text.format is not None: + if structured_outputs is not None: + raise ValueError( + "Cannot specify both structured_outputs and text.format" + ) response_format = self.text.format if ( response_format.type == "json_schema" @@ -329,8 +339,6 @@ class ResponsesRequest(OpenAIBaseModel): structured_outputs = StructuredOutputsParams( json=response_format.schema_ ) - elif response_format.type == "json_object": - raise NotImplementedError("json_object is not supported") stop = self.stop if self.stop else [] if isinstance(stop, str):