[Frontend] Add sampling parameters to Responses API (#32609)
Signed-off-by: Daniel Mescheder <dmesch@amazon.com> Co-authored-by: Daniel Mescheder <dmesch@amazon.com>
This commit is contained in:
113
tests/entrypoints/openai/responses/test_sampling_params.py
Normal file
113
tests/entrypoints/openai/responses/test_sampling_params.py
Normal file
@@ -0,0 +1,113 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""Unit tests for ResponsesRequest.to_sampling_params() parameter mapping."""
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
|
||||
|
||||
|
||||
class TestResponsesRequestSamplingParams:
|
||||
"""Test that ResponsesRequest correctly maps parameters to SamplingParams."""
|
||||
|
||||
def test_basic_sampling_params(self):
|
||||
"""Test basic sampling parameters are correctly mapped."""
|
||||
request = ResponsesRequest(
|
||||
model="test-model",
|
||||
input="test input",
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
top_k=50,
|
||||
max_output_tokens=100,
|
||||
)
|
||||
|
||||
sampling_params = request.to_sampling_params(default_max_tokens=1000)
|
||||
|
||||
assert sampling_params.temperature == 0.8
|
||||
assert sampling_params.top_p == 0.95
|
||||
assert sampling_params.top_k == 50
|
||||
assert sampling_params.max_tokens == 100
|
||||
|
||||
def test_extra_sampling_params(self):
|
||||
"""Test extra sampling parameters are correctly mapped."""
|
||||
request = ResponsesRequest(
|
||||
model="test-model",
|
||||
input="test input",
|
||||
repetition_penalty=1.2,
|
||||
seed=42,
|
||||
stop=["END", "STOP"],
|
||||
ignore_eos=True,
|
||||
vllm_xargs={"custom": "value"},
|
||||
)
|
||||
|
||||
sampling_params = request.to_sampling_params(default_max_tokens=1000)
|
||||
|
||||
assert sampling_params.repetition_penalty == 1.2
|
||||
assert sampling_params.seed == 42
|
||||
assert sampling_params.stop == ["END", "STOP"]
|
||||
assert sampling_params.ignore_eos is True
|
||||
assert sampling_params.extra_args == {"custom": "value"}
|
||||
|
||||
def test_stop_string_conversion(self):
|
||||
"""Test that single stop string is converted to list."""
|
||||
request = ResponsesRequest(
|
||||
model="test-model",
|
||||
input="test input",
|
||||
stop="STOP",
|
||||
)
|
||||
|
||||
sampling_params = request.to_sampling_params(default_max_tokens=1000)
|
||||
|
||||
assert sampling_params.stop == ["STOP"]
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values for optional parameters."""
|
||||
request = ResponsesRequest(
|
||||
model="test-model",
|
||||
input="test input",
|
||||
)
|
||||
|
||||
sampling_params = request.to_sampling_params(default_max_tokens=1000)
|
||||
|
||||
assert sampling_params.repetition_penalty == 1.0 # None → 1.0
|
||||
assert sampling_params.stop == [] # Empty list
|
||||
assert sampling_params.extra_args == {} # Empty dict
|
||||
|
||||
def test_seed_bounds_validation(self):
|
||||
"""Test that seed values outside torch.long bounds are rejected."""
|
||||
import torch
|
||||
from pydantic import ValidationError
|
||||
|
||||
# Test seed below minimum
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ResponsesRequest(
|
||||
model="test-model",
|
||||
input="test input",
|
||||
seed=torch.iinfo(torch.long).min - 1,
|
||||
)
|
||||
assert "greater_than_equal" in str(exc_info.value).lower()
|
||||
|
||||
# Test seed above maximum
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ResponsesRequest(
|
||||
model="test-model",
|
||||
input="test input",
|
||||
seed=torch.iinfo(torch.long).max + 1,
|
||||
)
|
||||
assert "less_than_equal" in str(exc_info.value).lower()
|
||||
|
||||
# Test valid seed at boundaries
|
||||
request_min = ResponsesRequest(
|
||||
model="test-model",
|
||||
input="test input",
|
||||
seed=torch.iinfo(torch.long).min,
|
||||
)
|
||||
assert request_min.seed == torch.iinfo(torch.long).min
|
||||
|
||||
request_max = ResponsesRequest(
|
||||
model="test-model",
|
||||
input="test input",
|
||||
seed=torch.iinfo(torch.long).max,
|
||||
)
|
||||
assert request_max.seed == torch.iinfo(torch.long).max
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from openai import OpenAI
|
||||
@@ -147,3 +146,27 @@ async def test_max_tokens(client: OpenAI, model_name: str):
|
||||
assert response is not None
|
||||
assert response.status == "incomplete"
|
||||
assert response.incomplete_details.reason == "max_output_tokens"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_extra_sampling_params(client: OpenAI, model_name: str):
|
||||
"""Test that extra sampling parameters are accepted and work."""
|
||||
# Test with multiple sampling parameters - just verify they're accepted
|
||||
response = await client.responses.create(
|
||||
model=model_name,
|
||||
input="Write a short sentence",
|
||||
max_output_tokens=50,
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
extra_body={
|
||||
"top_k": 40,
|
||||
"repetition_penalty": 1.2,
|
||||
"seed": 42,
|
||||
},
|
||||
)
|
||||
|
||||
# Verify request succeeded and parameters were accepted
|
||||
assert response.status in ["completed", "incomplete"]
|
||||
assert len(response.output) > 0
|
||||
assert response.output[0].content[0].text # Has text output
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
import time
|
||||
from typing import Any, Literal, TypeAlias
|
||||
|
||||
import torch
|
||||
from openai.types.responses import (
|
||||
ResponseCodeInterpreterCallCodeDeltaEvent,
|
||||
ResponseCodeInterpreterCallCodeDoneEvent,
|
||||
@@ -77,6 +78,8 @@ from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_LONG_INFO = torch.iinfo(torch.long)
|
||||
|
||||
|
||||
class InputTokensDetails(OpenAIBaseModel):
|
||||
cached_tokens: int
|
||||
@@ -230,6 +233,18 @@ class ResponsesRequest(OpenAIBaseModel):
|
||||
# this cannot be used in conjunction with previous_response_id
|
||||
# TODO: consider supporting non harmony messages as well
|
||||
previous_input_messages: list[OpenAIHarmonyMessage | dict] | None = None
|
||||
|
||||
repetition_penalty: float | None = None
|
||||
seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
|
||||
stop: str | list[str] | None = []
|
||||
ignore_eos: bool = False
|
||||
vllm_xargs: dict[str, str | int | float | list[str | int | float]] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Additional request parameters with (list of) string or "
|
||||
"numeric values, used by custom extensions."
|
||||
),
|
||||
)
|
||||
# --8<-- [end:responses-extra-params]
|
||||
|
||||
def build_chat_params(
|
||||
@@ -297,6 +312,10 @@ class ResponsesRequest(OpenAIBaseModel):
|
||||
top_k = default_sampling_params.get(
|
||||
"top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
|
||||
)
|
||||
|
||||
if (repetition_penalty := self.repetition_penalty) is None:
|
||||
repetition_penalty = default_sampling_params.get("repetition_penalty", 1.0)
|
||||
|
||||
stop_token_ids = default_sampling_params.get("stop_token_ids")
|
||||
|
||||
# Structured output
|
||||
@@ -313,7 +332,10 @@ class ResponsesRequest(OpenAIBaseModel):
|
||||
elif response_format.type == "json_object":
|
||||
raise NotImplementedError("json_object is not supported")
|
||||
|
||||
# TODO: add more parameters
|
||||
stop = self.stop if self.stop else []
|
||||
if isinstance(stop, str):
|
||||
stop = [stop]
|
||||
|
||||
return SamplingParams.from_optional(
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
@@ -321,11 +343,16 @@ class ResponsesRequest(OpenAIBaseModel):
|
||||
max_tokens=max_tokens,
|
||||
logprobs=self.top_logprobs if self.is_include_output_logprobs() else None,
|
||||
stop_token_ids=stop_token_ids,
|
||||
stop=stop,
|
||||
repetition_penalty=repetition_penalty,
|
||||
seed=self.seed,
|
||||
ignore_eos=self.ignore_eos,
|
||||
output_kind=(
|
||||
RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY
|
||||
),
|
||||
structured_outputs=structured_outputs,
|
||||
logit_bias=self.logit_bias,
|
||||
extra_args=self.vllm_xargs or {},
|
||||
skip_clone=True, # Created fresh per request, safe to skip clone
|
||||
skip_special_tokens=self.skip_special_tokens,
|
||||
include_stop_str_in_output=self.include_stop_str_in_output,
|
||||
|
||||
Reference in New Issue
Block a user