[Bugfix][ResponsesAPI] Fix crash when tool_choice=required exceeds max_output_tokens (#37258)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
@@ -134,6 +134,34 @@ async def test_function_tool_use(
|
||||
assert reasoning.type == "reasoning"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_max_tokens_with_tool_choice_required(
    client: openai.AsyncOpenAI, model_name: str
):
    """Check a `tool_choice="required"` request truncated by the token limit.

    The request is sent with `max_output_tokens=10`. The response must still
    contain at least one output item, none of them may be a `function_call`,
    and `incomplete_details.reason` must be `"max_output_tokens"`.
    """
    prompt = [
        {
            "role": "user",
            "content": "Can you tell me what the current weather is in Berlin and the "
            "forecast for the next 5 days, in fahrenheit?",
        },
    ]
    response = await client.responses.create(
        model=model_name,
        input=prompt,
        tools=tools,
        tool_choice="required",
        max_output_tokens=10,
    )
    assert len(response.output) >= 1
    for out in response.output:
        # When `tool_choice="required"` and the tool call needs more tokens
        # than `max_output_tokens` allows, no `function_call` item should be
        # emitted. This behavior should be consistent with OpenAI.
        assert out.type != "function_call"
    # Fail with a clear assertion (not an AttributeError on None) if the
    # response does not report why it is incomplete.
    assert response.incomplete_details is not None
    assert response.incomplete_details.reason == "max_output_tokens"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_named_tool_use(client: openai.AsyncOpenAI):
|
||||
def get_weather(latitude: float, longitude: float) -> str:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import contextlib
|
||||
import json
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Sequence
|
||||
@@ -18,7 +19,7 @@ from openai.types.responses.response_output_text import Logprob
|
||||
from openai.types.responses.response_reasoning_item import (
|
||||
Content as ResponseReasoningTextContent,
|
||||
)
|
||||
from pydantic import TypeAdapter
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
@@ -422,15 +423,19 @@ class DelegatingParser(Parser):
|
||||
|
||||
if request.tool_choice == "required":
|
||||
# Required tool calls - parse JSON
|
||||
assert content is not None
|
||||
tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(content)
|
||||
function_calls.extend(
|
||||
FunctionCall(
|
||||
name=tool_call.name,
|
||||
arguments=json.dumps(tool_call.parameters, ensure_ascii=False),
|
||||
tool_calls = []
|
||||
with contextlib.suppress(ValidationError):
|
||||
content = content or ""
|
||||
tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(
|
||||
content
|
||||
)
|
||||
for tool_call in tool_calls:
|
||||
function_calls.append(
|
||||
FunctionCall(
|
||||
name=tool_call.name,
|
||||
arguments=json.dumps(tool_call.parameters, ensure_ascii=False),
|
||||
)
|
||||
)
|
||||
for tool_call in tool_calls
|
||||
)
|
||||
return function_calls, None # Clear content since tool is called.
|
||||
|
||||
if (
|
||||
|
||||
Reference in New Issue
Block a user