[Frontend] Enable generic structured_outputs for responses API (#33709)
Signed-off-by: Alec Solder <alecs@fb.com> Co-authored-by: Alec Solder <alecs@fb.com>
This commit is contained in:
@@ -4,8 +4,17 @@
|
||||
"""Unit tests for ResponsesRequest.to_sampling_params() parameter mapping."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from openai.types.responses.response_format_text_json_schema_config import (
|
||||
ResponseFormatTextJSONSchemaConfig,
|
||||
)
|
||||
from pydantic import ValidationError
|
||||
|
||||
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
|
||||
from vllm.entrypoints.openai.responses.protocol import (
|
||||
ResponsesRequest,
|
||||
ResponseTextConfig,
|
||||
)
|
||||
from vllm.sampling_params import StructuredOutputsParams
|
||||
|
||||
|
||||
class TestResponsesRequestSamplingParams:
|
||||
@@ -76,9 +85,6 @@ class TestResponsesRequestSamplingParams:
|
||||
|
||||
def test_seed_bounds_validation(self):
|
||||
"""Test that seed values outside torch.long bounds are rejected."""
|
||||
import torch
|
||||
from pydantic import ValidationError
|
||||
|
||||
# Test seed below minimum
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ResponsesRequest(
|
||||
@@ -111,3 +117,40 @@ class TestResponsesRequestSamplingParams:
|
||||
seed=torch.iinfo(torch.long).max,
|
||||
)
|
||||
assert request_max.seed == torch.iinfo(torch.long).max
|
||||
|
||||
def test_structured_outputs_passed_through(self):
|
||||
"""Test that structured_outputs field is passed to SamplingParams."""
|
||||
structured_outputs = StructuredOutputsParams(grammar="root ::= 'hello'")
|
||||
request = ResponsesRequest(
|
||||
model="test-model",
|
||||
input="test input",
|
||||
structured_outputs=structured_outputs,
|
||||
)
|
||||
|
||||
sampling_params = request.to_sampling_params(default_max_tokens=1000)
|
||||
|
||||
assert sampling_params.structured_outputs is not None
|
||||
assert sampling_params.structured_outputs.grammar == "root ::= 'hello'"
|
||||
|
||||
def test_structured_outputs_and_json_schema_conflict(self):
|
||||
"""Test that specifying both structured_outputs and json_schema raises."""
|
||||
structured_outputs = StructuredOutputsParams(grammar="root ::= 'hello'")
|
||||
text_config = ResponseTextConfig()
|
||||
text_config.format = ResponseFormatTextJSONSchemaConfig(
|
||||
type="json_schema",
|
||||
name="test",
|
||||
schema={"type": "object"},
|
||||
)
|
||||
request = ResponsesRequest(
|
||||
model="test-model",
|
||||
input="test input",
|
||||
structured_outputs=structured_outputs,
|
||||
text=text_config,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
request.to_sampling_params(default_max_tokens=1000)
|
||||
|
||||
assert "Cannot specify both structured_outputs and text.format" in str(
|
||||
exc_info.value
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user