[Frontend] OpenAI Responses API supports Tool/Function calling with streaming (#29947)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
@@ -659,9 +659,10 @@ class TestStreamingReasoningToContentTransition:
|
||||
# Mock the reasoning parser on the serving instance
|
||||
mock_parser = MagicMock()
|
||||
mock_parser.extract_reasoning_streaming = mock_extract_reasoning_streaming
|
||||
mock_parser.extract_tool_calls_streaming = mock_extract_reasoning_streaming
|
||||
serving.parser = MagicMock()
|
||||
serving.parser.reasoning_parser_cls = MagicMock(return_value=mock_parser)
|
||||
|
||||
serving.parser.tool_parser_cls = MagicMock(return_value=mock_parser)
|
||||
# Create contexts for each streaming chunk
|
||||
contexts = [
|
||||
_make_simple_context_with_output("chunk1", [10]),
|
||||
@@ -739,8 +740,10 @@ class TestStreamingReasoningToContentTransition:
|
||||
|
||||
mock_parser = MagicMock()
|
||||
mock_parser.extract_reasoning_streaming = mock_extract_reasoning_streaming
|
||||
mock_parser.extract_tool_calls_streaming = mock_extract_reasoning_streaming
|
||||
serving.parser = MagicMock()
|
||||
serving.parser.reasoning_parser_cls = MagicMock(return_value=mock_parser)
|
||||
serving.parser.tool_parser_cls = MagicMock(return_value=mock_parser)
|
||||
|
||||
contexts = [
|
||||
_make_simple_context_with_output("chunk1", [10]),
|
||||
@@ -812,8 +815,10 @@ class TestStreamingReasoningToContentTransition:
|
||||
|
||||
mock_parser = MagicMock()
|
||||
mock_parser.extract_reasoning_streaming = mock_extract_reasoning_streaming
|
||||
mock_parser.extract_tool_calls_streaming = mock_extract_reasoning_streaming
|
||||
serving.parser = MagicMock()
|
||||
serving.parser.reasoning_parser_cls = MagicMock(return_value=mock_parser)
|
||||
serving.parser.tool_parser_cls = MagicMock(return_value=mock_parser)
|
||||
|
||||
contexts = [
|
||||
_make_simple_context_with_output("chunk1", [10]),
|
||||
|
||||
@@ -197,3 +197,108 @@ async def test_named_tool_use(client: openai.AsyncOpenAI):
|
||||
response_2 = await client.responses.create(model=MODEL_NAME, input=input_messages)
|
||||
# check the output
|
||||
assert len(response_2.output_text) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_function_calling_with_streaming_expected_arguments(
|
||||
client: openai.AsyncOpenAI, model_name: str
|
||||
):
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"name": "get_weather",
|
||||
"description": "Get current temperature for provided location in celsius.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string"},
|
||||
},
|
||||
"required": ["location"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"strict": True,
|
||||
}
|
||||
]
|
||||
|
||||
stream_response = await client.responses.create(
|
||||
model=model_name,
|
||||
input="Can you tell me what the current weather is in Berlin?",
|
||||
tools=tools,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
tool_call_item = None
|
||||
completed_event = None
|
||||
async for event in stream_response:
|
||||
if (
|
||||
event.type == "response.output_item.added"
|
||||
and event.item.type == "function_call"
|
||||
):
|
||||
tool_call_item = event.item
|
||||
elif event.type == "response.function_call_arguments.delta" and tool_call_item:
|
||||
tool_call_item.arguments += event.delta
|
||||
elif (
|
||||
event.type == "response.output_item.done"
|
||||
and event.item.type == "function_call"
|
||||
):
|
||||
completed_event = event
|
||||
assert tool_call_item is not None
|
||||
assert tool_call_item.type == "function_call"
|
||||
assert tool_call_item.name == "get_weather"
|
||||
assert completed_event is not None
|
||||
assert tool_call_item.arguments == completed_event.item.arguments
|
||||
assert tool_call_item.name == completed_event.item.name
|
||||
args = json.loads(tool_call_item.arguments)
|
||||
assert "location" in args
|
||||
assert args["location"] is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_function_calling_with_streaming_types(
|
||||
client: openai.AsyncOpenAI, model_name: str
|
||||
):
|
||||
# this links the "done" type with the "start" type
|
||||
# so every "done" type should have a corresponding "start" type
|
||||
# and every open block should be closed by the end of the stream
|
||||
pairs_of_event_types = {
|
||||
"response.completed": "response.created",
|
||||
"response.output_item.done": "response.output_item.added",
|
||||
"response.output_text.done": "response.output_text.delta",
|
||||
"response.content_part.done": "response.content_part.added",
|
||||
"response.reasoning_text.done": "response.reasoning_text.delta",
|
||||
"response.reasoning_part.done": "response.reasoning_part.added",
|
||||
"response.function_call_arguments.done": "response.function_call_arguments.delta", # noqa
|
||||
}
|
||||
|
||||
input_list = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Can you tell me what the current weather is in Berlin?",
|
||||
}
|
||||
]
|
||||
stream_response = await client.responses.create(
|
||||
model=model_name,
|
||||
input=input_list,
|
||||
tools=tools,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
stack_of_event_types = []
|
||||
async for event in stream_response:
|
||||
if event.type == "response.created":
|
||||
stack_of_event_types.append(event.type)
|
||||
elif event.type == "response.completed":
|
||||
assert stack_of_event_types[-1] == pairs_of_event_types[event.type]
|
||||
stack_of_event_types.pop()
|
||||
if event.type.endswith("added"):
|
||||
stack_of_event_types.append(event.type)
|
||||
elif event.type.endswith("delta"):
|
||||
if stack_of_event_types[-1] == event.type:
|
||||
continue
|
||||
stack_of_event_types.append(event.type)
|
||||
elif event.type.endswith("done"):
|
||||
assert stack_of_event_types[-1] == pairs_of_event_types[event.type]
|
||||
stack_of_event_types.pop()
|
||||
assert len(stack_of_event_types) == 0
|
||||
|
||||
@@ -15,7 +15,10 @@ from fastapi import Request
|
||||
from openai.types.responses import (
|
||||
ResponseContentPartAddedEvent,
|
||||
ResponseContentPartDoneEvent,
|
||||
ResponseFunctionCallArgumentsDeltaEvent,
|
||||
ResponseFunctionCallArgumentsDoneEvent,
|
||||
ResponseFunctionToolCall,
|
||||
ResponseFunctionToolCallItem,
|
||||
ResponseOutputItem,
|
||||
ResponseOutputItemAddedEvent,
|
||||
ResponseOutputItemDoneEvent,
|
||||
@@ -113,6 +116,7 @@ from vllm.parser import ParserManager
|
||||
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.collection_utils import as_list
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -1236,38 +1240,134 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
reasoning_parser = None
|
||||
if self.parser and self.parser.reasoning_parser_cls:
|
||||
reasoning_parser = self.parser.reasoning_parser_cls(tokenizer)
|
||||
tool_parser = None
|
||||
if self.parser and self.parser.tool_parser_cls:
|
||||
tool_parser = self.parser.tool_parser_cls(tokenizer)
|
||||
reasoning_ended = False
|
||||
tool_call_text_started = False
|
||||
previous_text = ""
|
||||
previous_token_ids: list[int] = []
|
||||
prompt_is_reasoning_end = None
|
||||
first_delta_sent = False
|
||||
previous_delta_messages: list[DeltaMessage] = []
|
||||
async for ctx in result_generator:
|
||||
assert isinstance(ctx, SimpleContext)
|
||||
if ctx.last_output is None:
|
||||
continue
|
||||
if reasoning_parser and prompt_is_reasoning_end is None:
|
||||
prompt_is_reasoning_end = reasoning_parser.is_reasoning_end(
|
||||
ctx.last_output.prompt_token_ids
|
||||
)
|
||||
if ctx.last_output.outputs:
|
||||
output = ctx.last_output.outputs[0]
|
||||
# finish_reason='error' indicates a retryable error
|
||||
self._raise_if_error(output.finish_reason, request.request_id)
|
||||
if reasoning_parser:
|
||||
delta_text = output.text
|
||||
delta_token_ids = as_list(output.token_ids)
|
||||
current_text = previous_text + delta_text
|
||||
current_token_ids = previous_token_ids + delta_token_ids
|
||||
|
||||
if reasoning_parser and tool_parser:
|
||||
if prompt_is_reasoning_end:
|
||||
reasoning_ended = True
|
||||
if not reasoning_ended:
|
||||
delta_message = reasoning_parser.extract_reasoning_streaming(
|
||||
previous_text=previous_text,
|
||||
current_text=current_text,
|
||||
delta_text=delta_text,
|
||||
previous_token_ids=previous_token_ids,
|
||||
current_token_ids=current_token_ids,
|
||||
delta_token_ids=delta_token_ids,
|
||||
)
|
||||
if reasoning_parser.is_reasoning_end(delta_token_ids):
|
||||
reasoning_ended = True
|
||||
current_token_ids = reasoning_parser.extract_content_ids(
|
||||
delta_token_ids
|
||||
)
|
||||
if delta_message and delta_message.content:
|
||||
current_text = delta_message.content
|
||||
delta_message.content = None
|
||||
else:
|
||||
current_text = ""
|
||||
|
||||
if reasoning_ended:
|
||||
if not tool_call_text_started:
|
||||
tool_call_text_started = True
|
||||
previous_text = ""
|
||||
previous_token_ids = []
|
||||
delta_text = current_text
|
||||
delta_token_ids = current_token_ids
|
||||
|
||||
delta_message = tool_parser.extract_tool_calls_streaming(
|
||||
previous_text=previous_text,
|
||||
current_text=current_text,
|
||||
delta_text=delta_text,
|
||||
previous_token_ids=previous_token_ids,
|
||||
current_token_ids=current_token_ids,
|
||||
delta_token_ids=delta_token_ids,
|
||||
request=request, # type: ignore[arg-type]
|
||||
)
|
||||
elif reasoning_parser:
|
||||
delta_message = reasoning_parser.extract_reasoning_streaming(
|
||||
previous_text=previous_text,
|
||||
current_text=previous_text + output.text,
|
||||
delta_text=output.text,
|
||||
current_text=current_text,
|
||||
delta_text=delta_text,
|
||||
previous_token_ids=previous_token_ids,
|
||||
current_token_ids=previous_token_ids + output.token_ids,
|
||||
delta_token_ids=output.token_ids,
|
||||
current_token_ids=current_token_ids,
|
||||
delta_token_ids=delta_token_ids,
|
||||
)
|
||||
elif tool_parser:
|
||||
delta_message = tool_parser.extract_tool_calls_streaming(
|
||||
previous_text=previous_text,
|
||||
current_text=current_text,
|
||||
delta_text=delta_text,
|
||||
previous_token_ids=previous_token_ids,
|
||||
current_token_ids=current_token_ids,
|
||||
delta_token_ids=delta_token_ids,
|
||||
request=request, # type: ignore[arg-type]
|
||||
)
|
||||
else:
|
||||
delta_message = DeltaMessage(
|
||||
content=output.text,
|
||||
)
|
||||
previous_text += output.text
|
||||
previous_token_ids += output.token_ids
|
||||
previous_text = current_text
|
||||
previous_token_ids = current_token_ids
|
||||
if not delta_message:
|
||||
continue
|
||||
if not first_delta_sent:
|
||||
current_item_id = str(uuid.uuid4())
|
||||
if delta_message.reasoning:
|
||||
current_item_id = random_uuid()
|
||||
if delta_message.tool_calls:
|
||||
current_tool_call_id = f"call_{random_uuid()}"
|
||||
assert len(delta_message.tool_calls) == 1, (
|
||||
"Multiple tool calls in one delta is not supported"
|
||||
)
|
||||
assert delta_message.tool_calls[0].function is not None, (
|
||||
"Tool call without function is not supported"
|
||||
)
|
||||
assert delta_message.tool_calls[0].function.name is not None, (
|
||||
"Tool call without function name is not supported"
|
||||
)
|
||||
current_tool_call_name = delta_message.tool_calls[
|
||||
0
|
||||
].function.name
|
||||
yield _increment_sequence_number_and_return(
|
||||
ResponseOutputItemAddedEvent(
|
||||
type="response.output_item.added",
|
||||
sequence_number=-1,
|
||||
output_index=current_output_index,
|
||||
item=ResponseFunctionToolCallItem(
|
||||
type="function_call",
|
||||
id=current_item_id,
|
||||
call_id=current_tool_call_id,
|
||||
name=current_tool_call_name,
|
||||
arguments=delta_message.tool_calls[
|
||||
0
|
||||
].function.arguments,
|
||||
status="in_progress",
|
||||
),
|
||||
)
|
||||
)
|
||||
elif delta_message.reasoning:
|
||||
yield _increment_sequence_number_and_return(
|
||||
ResponseOutputItemAddedEvent(
|
||||
type="response.output_item.added",
|
||||
@@ -1294,7 +1394,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
elif not delta_message.tool_calls:
|
||||
yield _increment_sequence_number_and_return(
|
||||
ResponseOutputItemAddedEvent(
|
||||
type="response.output_item.added",
|
||||
@@ -1325,7 +1425,6 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
)
|
||||
)
|
||||
first_delta_sent = True
|
||||
# todo(kebe7jun) tool call support
|
||||
|
||||
# check delta message and previous delta message are
|
||||
# same as content or reasoning content
|
||||
@@ -1438,8 +1537,87 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
)
|
||||
# reset previous delta messages
|
||||
previous_delta_messages = []
|
||||
|
||||
if delta_message.reasoning is not None:
|
||||
if delta_message.tool_calls and delta_message.tool_calls[0].function:
|
||||
if delta_message.tool_calls[0].function.arguments:
|
||||
yield _increment_sequence_number_and_return(
|
||||
ResponseFunctionCallArgumentsDeltaEvent(
|
||||
type="response.function_call_arguments.delta",
|
||||
sequence_number=-1,
|
||||
output_index=current_output_index,
|
||||
item_id=current_item_id,
|
||||
delta=delta_message.tool_calls[0].function.arguments,
|
||||
)
|
||||
)
|
||||
# tool call initiated with no arguments
|
||||
elif delta_message.tool_calls[0].function.name:
|
||||
# send done with current content part
|
||||
# and add new function call item
|
||||
yield _increment_sequence_number_and_return(
|
||||
ResponseTextDoneEvent(
|
||||
type="response.output_text.done",
|
||||
sequence_number=-1,
|
||||
output_index=current_output_index,
|
||||
content_index=current_content_index,
|
||||
text="",
|
||||
logprobs=[],
|
||||
item_id=current_item_id,
|
||||
)
|
||||
)
|
||||
yield _increment_sequence_number_and_return(
|
||||
ResponseContentPartDoneEvent(
|
||||
type="response.content_part.done",
|
||||
sequence_number=-1,
|
||||
item_id=current_item_id,
|
||||
output_index=current_output_index,
|
||||
content_index=current_content_index,
|
||||
part=ResponseOutputText(
|
||||
type="output_text",
|
||||
text="",
|
||||
annotations=[],
|
||||
logprobs=[],
|
||||
),
|
||||
)
|
||||
)
|
||||
yield _increment_sequence_number_and_return(
|
||||
ResponseOutputItemDoneEvent(
|
||||
type="response.output_item.done",
|
||||
sequence_number=-1,
|
||||
output_index=current_output_index,
|
||||
item=ResponseOutputMessage(
|
||||
id=current_item_id,
|
||||
type="message",
|
||||
role="assistant",
|
||||
content=[],
|
||||
status="completed",
|
||||
),
|
||||
)
|
||||
)
|
||||
current_output_index += 1
|
||||
current_item_id = random_uuid()
|
||||
assert delta_message.tool_calls[0].function is not None
|
||||
current_tool_call_name = delta_message.tool_calls[
|
||||
0
|
||||
].function.name
|
||||
current_tool_call_id = f"call_{random_uuid()}"
|
||||
yield _increment_sequence_number_and_return(
|
||||
ResponseOutputItemAddedEvent(
|
||||
type="response.output_item.added",
|
||||
sequence_number=-1,
|
||||
output_index=current_output_index,
|
||||
item=ResponseFunctionToolCallItem(
|
||||
type="function_call",
|
||||
id=current_item_id,
|
||||
call_id=current_tool_call_id,
|
||||
name=current_tool_call_name,
|
||||
arguments="",
|
||||
status="in_progress",
|
||||
),
|
||||
)
|
||||
)
|
||||
# skip content part for tool call
|
||||
current_content_index = 1
|
||||
continue
|
||||
elif delta_message.reasoning is not None:
|
||||
yield _increment_sequence_number_and_return(
|
||||
ResponseReasoningTextDeltaEvent(
|
||||
type="response.reasoning_text.delta",
|
||||
@@ -1450,7 +1628,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
delta=delta_message.reasoning,
|
||||
)
|
||||
)
|
||||
elif delta_message.content is not None:
|
||||
elif delta_message.content:
|
||||
yield _increment_sequence_number_and_return(
|
||||
ResponseTextDeltaEvent(
|
||||
type="response.output_text.delta",
|
||||
@@ -1473,8 +1651,50 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
)
|
||||
|
||||
previous_delta_messages.append(delta_message)
|
||||
|
||||
if previous_delta_messages:
|
||||
if previous_delta_messages[-1].reasoning is not None:
|
||||
parts = []
|
||||
for pm in previous_delta_messages:
|
||||
if pm.tool_calls:
|
||||
assert len(pm.tool_calls) == 1, (
|
||||
"Multiple tool calls in one delta is not supported"
|
||||
)
|
||||
assert pm.tool_calls[0].function is not None, (
|
||||
"Tool call without function is not supported"
|
||||
)
|
||||
parts.append(pm.tool_calls[0].function.arguments or "")
|
||||
|
||||
tool_call_arguments = "".join(parts)
|
||||
if tool_call_arguments:
|
||||
yield _increment_sequence_number_and_return(
|
||||
ResponseFunctionCallArgumentsDoneEvent(
|
||||
type="response.function_call_arguments.done",
|
||||
sequence_number=-1,
|
||||
output_index=current_output_index,
|
||||
item_id=current_item_id,
|
||||
arguments=tool_call_arguments,
|
||||
name=current_tool_call_name,
|
||||
)
|
||||
)
|
||||
current_content_index = 0
|
||||
function_call_item = ResponseFunctionToolCall(
|
||||
type="function_call",
|
||||
name=current_tool_call_name,
|
||||
arguments=tool_call_arguments,
|
||||
status="completed",
|
||||
id=current_item_id,
|
||||
call_id=current_tool_call_id,
|
||||
)
|
||||
yield _increment_sequence_number_and_return(
|
||||
ResponseOutputItemDoneEvent(
|
||||
type="response.output_item.done",
|
||||
sequence_number=-1,
|
||||
output_index=current_output_index,
|
||||
item=function_call_item,
|
||||
)
|
||||
)
|
||||
|
||||
elif previous_delta_messages[-1].reasoning is not None:
|
||||
reason_content = "".join(
|
||||
pm.reasoning
|
||||
for pm in previous_delta_messages
|
||||
@@ -1523,11 +1743,9 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
item=reasoning_item,
|
||||
)
|
||||
)
|
||||
elif previous_delta_messages[-1].content is not None:
|
||||
elif previous_delta_messages[-1].content:
|
||||
final_content = "".join(
|
||||
pm.content
|
||||
for pm in previous_delta_messages
|
||||
if pm.content is not None
|
||||
pm.content for pm in previous_delta_messages if pm.content
|
||||
)
|
||||
yield _increment_sequence_number_and_return(
|
||||
ResponseTextDoneEvent(
|
||||
|
||||
Reference in New Issue
Block a user