[Frontend] OpenAI Responses API supports Tool/Function calling with streaming (#29947)

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
Chauncey
2026-03-12 15:03:50 +08:00
committed by GitHub
parent 802f306cd1
commit 9fe404ed04
3 changed files with 348 additions and 20 deletions

View File

@@ -659,9 +659,10 @@ class TestStreamingReasoningToContentTransition:
# Mock the reasoning parser on the serving instance
mock_parser = MagicMock()
mock_parser.extract_reasoning_streaming = mock_extract_reasoning_streaming
mock_parser.extract_tool_calls_streaming = mock_extract_reasoning_streaming
serving.parser = MagicMock()
serving.parser.reasoning_parser_cls = MagicMock(return_value=mock_parser)
serving.parser.tool_parser_cls = MagicMock(return_value=mock_parser)
# Create contexts for each streaming chunk
contexts = [
_make_simple_context_with_output("chunk1", [10]),
@@ -739,8 +740,10 @@ class TestStreamingReasoningToContentTransition:
mock_parser = MagicMock()
mock_parser.extract_reasoning_streaming = mock_extract_reasoning_streaming
mock_parser.extract_tool_calls_streaming = mock_extract_reasoning_streaming
serving.parser = MagicMock()
serving.parser.reasoning_parser_cls = MagicMock(return_value=mock_parser)
serving.parser.tool_parser_cls = MagicMock(return_value=mock_parser)
contexts = [
_make_simple_context_with_output("chunk1", [10]),
@@ -812,8 +815,10 @@ class TestStreamingReasoningToContentTransition:
mock_parser = MagicMock()
mock_parser.extract_reasoning_streaming = mock_extract_reasoning_streaming
mock_parser.extract_tool_calls_streaming = mock_extract_reasoning_streaming
serving.parser = MagicMock()
serving.parser.reasoning_parser_cls = MagicMock(return_value=mock_parser)
serving.parser.tool_parser_cls = MagicMock(return_value=mock_parser)
contexts = [
_make_simple_context_with_output("chunk1", [10]),

View File

@@ -197,3 +197,108 @@ async def test_named_tool_use(client: openai.AsyncOpenAI):
response_2 = await client.responses.create(model=MODEL_NAME, input=input_messages)
# check the output
assert len(response_2.output_text) > 0
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_function_calling_with_streaming_expected_arguments(
    client: openai.AsyncOpenAI, model_name: str
):
    """Stream a tool call and check that the accumulated argument deltas
    reproduce the final arguments reported by the ``output_item.done`` event.
    """
    tools = [
        {
            "type": "function",
            "name": "get_weather",
            "description": "Get current temperature for provided location in celsius.",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string"},
                },
                "required": ["location"],
                "additionalProperties": False,
            },
            "strict": True,
        }
    ]
    stream = await client.responses.create(
        model=model_name,
        input="Can you tell me what the current weather is in Berlin?",
        tools=tools,
        stream=True,
    )
    streamed_call = None
    done_event = None
    async for event in stream:
        kind = event.type
        if kind == "response.output_item.added":
            if event.item.type == "function_call":
                streamed_call = event.item
        elif kind == "response.function_call_arguments.delta":
            # Accumulate argument fragments onto the in-progress call item.
            if streamed_call:
                streamed_call.arguments += event.delta
        elif kind == "response.output_item.done":
            if event.item.type == "function_call":
                done_event = event
    assert streamed_call is not None
    assert streamed_call.type == "function_call"
    assert streamed_call.name == "get_weather"
    assert done_event is not None
    # The concatenated deltas must equal the final arguments/name.
    assert streamed_call.arguments == done_event.item.arguments
    assert streamed_call.name == done_event.item.name
    parsed = json.loads(streamed_call.arguments)
    assert "location" in parsed
    assert parsed["location"] is not None
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_function_calling_with_streaming_types(
    client: openai.AsyncOpenAI, model_name: str
):
    """Verify that streamed event types are properly nested.

    Every "done"/"completed" event must close the matching "added"/"delta"/
    "created" event that opened it, and every opened block must be closed
    by the end of the stream.
    """
    # Fix: ``tools`` was referenced but never defined in this test (the
    # other test's ``tools`` is function-local), causing a NameError.
    tools = [
        {
            "type": "function",
            "name": "get_weather",
            "description": "Get current temperature for provided location in celsius.",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string"},
                },
                "required": ["location"],
                "additionalProperties": False,
            },
            "strict": True,
        }
    ]
    # this links the "done" type with the "start" type
    # so every "done" type should have a corresponding "start" type
    # and every open block should be closed by the end of the stream
    pairs_of_event_types = {
        "response.completed": "response.created",
        "response.output_item.done": "response.output_item.added",
        "response.output_text.done": "response.output_text.delta",
        "response.content_part.done": "response.content_part.added",
        "response.reasoning_text.done": "response.reasoning_text.delta",
        "response.reasoning_part.done": "response.reasoning_part.added",
        "response.function_call_arguments.done": "response.function_call_arguments.delta",  # noqa
    }
    input_list = [
        {
            "role": "user",
            "content": "Can you tell me what the current weather is in Berlin?",
        }
    ]
    stream_response = await client.responses.create(
        model=model_name,
        input=input_list,
        tools=tools,
        stream=True,
    )
    stack_of_event_types = []
    async for event in stream_response:
        if event.type == "response.created":
            stack_of_event_types.append(event.type)
        elif event.type == "response.completed":
            # Guard against popping an empty stack so a mis-ordered stream
            # fails with a clear assertion instead of an IndexError.
            assert stack_of_event_types, f"unmatched close event: {event.type}"
            assert stack_of_event_types[-1] == pairs_of_event_types[event.type]
            stack_of_event_types.pop()
        if event.type.endswith("added"):
            stack_of_event_types.append(event.type)
        elif event.type.endswith("delta"):
            # Consecutive deltas of the same type count as one open block.
            if stack_of_event_types and stack_of_event_types[-1] == event.type:
                continue
            stack_of_event_types.append(event.type)
        elif event.type.endswith("done"):
            assert stack_of_event_types, f"unmatched close event: {event.type}"
            assert stack_of_event_types[-1] == pairs_of_event_types[event.type]
            stack_of_event_types.pop()
    # All opened blocks must have been closed by the end of the stream.
    assert len(stack_of_event_types) == 0

View File

@@ -15,7 +15,10 @@ from fastapi import Request
from openai.types.responses import (
ResponseContentPartAddedEvent,
ResponseContentPartDoneEvent,
ResponseFunctionCallArgumentsDeltaEvent,
ResponseFunctionCallArgumentsDoneEvent,
ResponseFunctionToolCall,
ResponseFunctionToolCallItem,
ResponseOutputItem,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
@@ -113,6 +116,7 @@ from vllm.parser import ParserManager
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid
from vllm.utils.collection_utils import as_list
logger = init_logger(__name__)
@@ -1236,38 +1240,134 @@ class OpenAIServingResponses(OpenAIServing):
reasoning_parser = None
if self.parser and self.parser.reasoning_parser_cls:
reasoning_parser = self.parser.reasoning_parser_cls(tokenizer)
tool_parser = None
if self.parser and self.parser.tool_parser_cls:
tool_parser = self.parser.tool_parser_cls(tokenizer)
reasoning_ended = False
tool_call_text_started = False
previous_text = ""
previous_token_ids: list[int] = []
prompt_is_reasoning_end = None
first_delta_sent = False
previous_delta_messages: list[DeltaMessage] = []
async for ctx in result_generator:
assert isinstance(ctx, SimpleContext)
if ctx.last_output is None:
continue
if reasoning_parser and prompt_is_reasoning_end is None:
prompt_is_reasoning_end = reasoning_parser.is_reasoning_end(
ctx.last_output.prompt_token_ids
)
if ctx.last_output.outputs:
output = ctx.last_output.outputs[0]
# finish_reason='error' indicates a retryable error
self._raise_if_error(output.finish_reason, request.request_id)
if reasoning_parser:
delta_text = output.text
delta_token_ids = as_list(output.token_ids)
current_text = previous_text + delta_text
current_token_ids = previous_token_ids + delta_token_ids
if reasoning_parser and tool_parser:
if prompt_is_reasoning_end:
reasoning_ended = True
if not reasoning_ended:
delta_message = reasoning_parser.extract_reasoning_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=delta_token_ids,
)
if reasoning_parser.is_reasoning_end(delta_token_ids):
reasoning_ended = True
current_token_ids = reasoning_parser.extract_content_ids(
delta_token_ids
)
if delta_message and delta_message.content:
current_text = delta_message.content
delta_message.content = None
else:
current_text = ""
if reasoning_ended:
if not tool_call_text_started:
tool_call_text_started = True
previous_text = ""
previous_token_ids = []
delta_text = current_text
delta_token_ids = current_token_ids
delta_message = tool_parser.extract_tool_calls_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=delta_token_ids,
request=request, # type: ignore[arg-type]
)
elif reasoning_parser:
delta_message = reasoning_parser.extract_reasoning_streaming(
previous_text=previous_text,
current_text=previous_text + output.text,
delta_text=output.text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids=previous_token_ids,
current_token_ids=previous_token_ids + output.token_ids,
delta_token_ids=output.token_ids,
current_token_ids=current_token_ids,
delta_token_ids=delta_token_ids,
)
elif tool_parser:
delta_message = tool_parser.extract_tool_calls_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=delta_token_ids,
request=request, # type: ignore[arg-type]
)
else:
delta_message = DeltaMessage(
content=output.text,
)
previous_text += output.text
previous_token_ids += output.token_ids
previous_text = current_text
previous_token_ids = current_token_ids
if not delta_message:
continue
if not first_delta_sent:
current_item_id = str(uuid.uuid4())
if delta_message.reasoning:
current_item_id = random_uuid()
if delta_message.tool_calls:
current_tool_call_id = f"call_{random_uuid()}"
assert len(delta_message.tool_calls) == 1, (
"Multiple tool calls in one delta is not supported"
)
assert delta_message.tool_calls[0].function is not None, (
"Tool call without function is not supported"
)
assert delta_message.tool_calls[0].function.name is not None, (
"Tool call without function name is not supported"
)
current_tool_call_name = delta_message.tool_calls[
0
].function.name
yield _increment_sequence_number_and_return(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=current_output_index,
item=ResponseFunctionToolCallItem(
type="function_call",
id=current_item_id,
call_id=current_tool_call_id,
name=current_tool_call_name,
arguments=delta_message.tool_calls[
0
].function.arguments,
status="in_progress",
),
)
)
elif delta_message.reasoning:
yield _increment_sequence_number_and_return(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
@@ -1294,7 +1394,7 @@ class OpenAIServingResponses(OpenAIServing):
),
)
)
else:
elif not delta_message.tool_calls:
yield _increment_sequence_number_and_return(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
@@ -1325,7 +1425,6 @@ class OpenAIServingResponses(OpenAIServing):
)
)
first_delta_sent = True
# todo(kebe7jun) tool call support
# check delta message and previous delta message are
# same as content or reasoning content
@@ -1438,8 +1537,87 @@ class OpenAIServingResponses(OpenAIServing):
)
# reset previous delta messages
previous_delta_messages = []
if delta_message.reasoning is not None:
if delta_message.tool_calls and delta_message.tool_calls[0].function:
if delta_message.tool_calls[0].function.arguments:
yield _increment_sequence_number_and_return(
ResponseFunctionCallArgumentsDeltaEvent(
type="response.function_call_arguments.delta",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
delta=delta_message.tool_calls[0].function.arguments,
)
)
# tool call initiated with no arguments
elif delta_message.tool_calls[0].function.name:
# send done with current content part
# and add new function call item
yield _increment_sequence_number_and_return(
ResponseTextDoneEvent(
type="response.output_text.done",
sequence_number=-1,
output_index=current_output_index,
content_index=current_content_index,
text="",
logprobs=[],
item_id=current_item_id,
)
)
yield _increment_sequence_number_and_return(
ResponseContentPartDoneEvent(
type="response.content_part.done",
sequence_number=-1,
item_id=current_item_id,
output_index=current_output_index,
content_index=current_content_index,
part=ResponseOutputText(
type="output_text",
text="",
annotations=[],
logprobs=[],
),
)
)
yield _increment_sequence_number_and_return(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=current_output_index,
item=ResponseOutputMessage(
id=current_item_id,
type="message",
role="assistant",
content=[],
status="completed",
),
)
)
current_output_index += 1
current_item_id = random_uuid()
assert delta_message.tool_calls[0].function is not None
current_tool_call_name = delta_message.tool_calls[
0
].function.name
current_tool_call_id = f"call_{random_uuid()}"
yield _increment_sequence_number_and_return(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=current_output_index,
item=ResponseFunctionToolCallItem(
type="function_call",
id=current_item_id,
call_id=current_tool_call_id,
name=current_tool_call_name,
arguments="",
status="in_progress",
),
)
)
# skip content part for tool call
current_content_index = 1
continue
elif delta_message.reasoning is not None:
yield _increment_sequence_number_and_return(
ResponseReasoningTextDeltaEvent(
type="response.reasoning_text.delta",
@@ -1450,7 +1628,7 @@ class OpenAIServingResponses(OpenAIServing):
delta=delta_message.reasoning,
)
)
elif delta_message.content is not None:
elif delta_message.content:
yield _increment_sequence_number_and_return(
ResponseTextDeltaEvent(
type="response.output_text.delta",
@@ -1473,8 +1651,50 @@ class OpenAIServingResponses(OpenAIServing):
)
previous_delta_messages.append(delta_message)
if previous_delta_messages:
if previous_delta_messages[-1].reasoning is not None:
parts = []
for pm in previous_delta_messages:
if pm.tool_calls:
assert len(pm.tool_calls) == 1, (
"Multiple tool calls in one delta is not supported"
)
assert pm.tool_calls[0].function is not None, (
"Tool call without function is not supported"
)
parts.append(pm.tool_calls[0].function.arguments or "")
tool_call_arguments = "".join(parts)
if tool_call_arguments:
yield _increment_sequence_number_and_return(
ResponseFunctionCallArgumentsDoneEvent(
type="response.function_call_arguments.done",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
arguments=tool_call_arguments,
name=current_tool_call_name,
)
)
current_content_index = 0
function_call_item = ResponseFunctionToolCall(
type="function_call",
name=current_tool_call_name,
arguments=tool_call_arguments,
status="completed",
id=current_item_id,
call_id=current_tool_call_id,
)
yield _increment_sequence_number_and_return(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=current_output_index,
item=function_call_item,
)
)
elif previous_delta_messages[-1].reasoning is not None:
reason_content = "".join(
pm.reasoning
for pm in previous_delta_messages
@@ -1523,11 +1743,9 @@ class OpenAIServingResponses(OpenAIServing):
item=reasoning_item,
)
)
elif previous_delta_messages[-1].content is not None:
elif previous_delta_messages[-1].content:
final_content = "".join(
pm.content
for pm in previous_delta_messages
if pm.content is not None
pm.content for pm in previous_delta_messages if pm.content
)
yield _increment_sequence_number_and_return(
ResponseTextDoneEvent(