diff --git a/tests/entrypoints/openai/test_serving_responses.py b/tests/entrypoints/openai/test_serving_responses.py index 1abaaad21..0ad1e1c93 100644 --- a/tests/entrypoints/openai/test_serving_responses.py +++ b/tests/entrypoints/openai/test_serving_responses.py @@ -659,9 +659,10 @@ class TestStreamingReasoningToContentTransition: # Mock the reasoning parser on the serving instance mock_parser = MagicMock() mock_parser.extract_reasoning_streaming = mock_extract_reasoning_streaming + mock_parser.extract_tool_calls_streaming = mock_extract_reasoning_streaming serving.parser = MagicMock() serving.parser.reasoning_parser_cls = MagicMock(return_value=mock_parser) - + serving.parser.tool_parser_cls = MagicMock(return_value=mock_parser) # Create contexts for each streaming chunk contexts = [ _make_simple_context_with_output("chunk1", [10]), @@ -739,8 +740,10 @@ class TestStreamingReasoningToContentTransition: mock_parser = MagicMock() mock_parser.extract_reasoning_streaming = mock_extract_reasoning_streaming + mock_parser.extract_tool_calls_streaming = mock_extract_reasoning_streaming serving.parser = MagicMock() serving.parser.reasoning_parser_cls = MagicMock(return_value=mock_parser) + serving.parser.tool_parser_cls = MagicMock(return_value=mock_parser) contexts = [ _make_simple_context_with_output("chunk1", [10]), @@ -812,8 +815,10 @@ class TestStreamingReasoningToContentTransition: mock_parser = MagicMock() mock_parser.extract_reasoning_streaming = mock_extract_reasoning_streaming + mock_parser.extract_tool_calls_streaming = mock_extract_reasoning_streaming serving.parser = MagicMock() serving.parser.reasoning_parser_cls = MagicMock(return_value=mock_parser) + serving.parser.tool_parser_cls = MagicMock(return_value=mock_parser) contexts = [ _make_simple_context_with_output("chunk1", [10]), diff --git a/tests/v1/entrypoints/openai/serving_responses/test_function_call.py b/tests/v1/entrypoints/openai/serving_responses/test_function_call.py index 90161e7c2..0b8a2e649 100644 --- a/tests/v1/entrypoints/openai/serving_responses/test_function_call.py +++ b/tests/v1/entrypoints/openai/serving_responses/test_function_call.py @@ -197,3 +197,108 @@ async def test_named_tool_use(client: openai.AsyncOpenAI): response_2 = await client.responses.create(model=MODEL_NAME, input=input_messages) # check the output assert len(response_2.output_text) > 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_function_calling_with_streaming_expected_arguments( + client: openai.AsyncOpenAI, model_name: str +): + tools = [ + { + "type": "function", + "name": "get_weather", + "description": "Get current temperature for provided location in celsius.", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string"}, + }, + "required": ["location"], + "additionalProperties": False, + }, + "strict": True, + } + ] + + stream_response = await client.responses.create( + model=model_name, + input="Can you tell me what the current weather is in Berlin?", + tools=tools, + stream=True, + ) + + tool_call_item = None + completed_event = None + async for event in stream_response: + if ( + event.type == "response.output_item.added" + and event.item.type == "function_call" + ): + tool_call_item = event.item + elif event.type == "response.function_call_arguments.delta" and tool_call_item: + tool_call_item.arguments += event.delta + elif ( + event.type == "response.output_item.done" + and event.item.type == "function_call" + ): + completed_event = event + assert tool_call_item is not None + assert tool_call_item.type == "function_call" + assert tool_call_item.name == "get_weather" + assert completed_event is not None + assert tool_call_item.arguments == completed_event.item.arguments + assert tool_call_item.name == completed_event.item.name + args = json.loads(tool_call_item.arguments) + assert "location" in args + assert args["location"] is not None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_function_calling_with_streaming_types( + client: openai.AsyncOpenAI, model_name: str +): + # this links the "done" type with the "start" type + # so every "done" type should have a corresponding "start" type + # and every open block should be closed by the end of the stream + pairs_of_event_types = { + "response.completed": "response.created", + "response.output_item.done": "response.output_item.added", + "response.output_text.done": "response.output_text.delta", + "response.content_part.done": "response.content_part.added", + "response.reasoning_text.done": "response.reasoning_text.delta", + "response.reasoning_part.done": "response.reasoning_part.added", + "response.function_call_arguments.done": "response.function_call_arguments.delta", # noqa + } + + input_list = [ + { + "role": "user", + "content": "Can you tell me what the current weather is in Berlin?", + } + ] + stream_response = await client.responses.create( + model=model_name, + input=input_list, + tools=tools, + stream=True, + ) + + stack_of_event_types = [] + async for event in stream_response: + if event.type == "response.created": + stack_of_event_types.append(event.type) + elif event.type == "response.completed": + assert stack_of_event_types[-1] == pairs_of_event_types[event.type] + stack_of_event_types.pop() + if event.type.endswith("added"): + stack_of_event_types.append(event.type) + elif event.type.endswith("delta"): + if stack_of_event_types[-1] == event.type: + continue + stack_of_event_types.append(event.type) + elif event.type.endswith("done"): + assert stack_of_event_types[-1] == pairs_of_event_types[event.type] + stack_of_event_types.pop() + assert len(stack_of_event_types) == 0 diff --git a/vllm/entrypoints/openai/responses/serving.py b/vllm/entrypoints/openai/responses/serving.py index ddd7bae04..a7eaccd83 100644 --- a/vllm/entrypoints/openai/responses/serving.py +++ b/vllm/entrypoints/openai/responses/serving.py @@ -15,7 +15,10 @@ from fastapi import Request from openai.types.responses import ( ResponseContentPartAddedEvent, ResponseContentPartDoneEvent, + ResponseFunctionCallArgumentsDeltaEvent, + ResponseFunctionCallArgumentsDoneEvent, ResponseFunctionToolCall, + ResponseFunctionToolCallItem, ResponseOutputItem, ResponseOutputItemAddedEvent, ResponseOutputItemDoneEvent, @@ -113,6 +116,7 @@ from vllm.parser import ParserManager from vllm.sampling_params import SamplingParams, StructuredOutputsParams from vllm.tokenizers import TokenizerLike from vllm.utils import random_uuid +from vllm.utils.collection_utils import as_list logger = init_logger(__name__) @@ -1236,38 +1240,134 @@ class OpenAIServingResponses(OpenAIServing): reasoning_parser = None if self.parser and self.parser.reasoning_parser_cls: reasoning_parser = self.parser.reasoning_parser_cls(tokenizer) + tool_parser = None + if self.parser and self.parser.tool_parser_cls: + tool_parser = self.parser.tool_parser_cls(tokenizer) + reasoning_ended = False + tool_call_text_started = False previous_text = "" previous_token_ids: list[int] = [] + prompt_is_reasoning_end = None first_delta_sent = False previous_delta_messages: list[DeltaMessage] = [] async for ctx in result_generator: assert isinstance(ctx, SimpleContext) if ctx.last_output is None: continue + if reasoning_parser and prompt_is_reasoning_end is None: + prompt_is_reasoning_end = reasoning_parser.is_reasoning_end( + ctx.last_output.prompt_token_ids + ) if ctx.last_output.outputs: output = ctx.last_output.outputs[0] # finish_reason='error' indicates a retryable error self._raise_if_error(output.finish_reason, request.request_id) - if reasoning_parser: + delta_text = output.text + delta_token_ids = as_list(output.token_ids) + current_text = previous_text + delta_text + current_token_ids = previous_token_ids + delta_token_ids + + if reasoning_parser and tool_parser: + if prompt_is_reasoning_end: + reasoning_ended = True + if not reasoning_ended: + delta_message = reasoning_parser.extract_reasoning_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=delta_token_ids, + ) + if reasoning_parser.is_reasoning_end(delta_token_ids): + reasoning_ended = True + current_token_ids = reasoning_parser.extract_content_ids( + delta_token_ids + ) + if delta_message and delta_message.content: + current_text = delta_message.content + delta_message.content = None + else: + current_text = "" + + if reasoning_ended: + if not tool_call_text_started: + tool_call_text_started = True + previous_text = "" + previous_token_ids = [] + delta_text = current_text + delta_token_ids = current_token_ids + + delta_message = tool_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=delta_token_ids, + request=request, # type: ignore[arg-type] + ) + elif reasoning_parser: delta_message = reasoning_parser.extract_reasoning_streaming( previous_text=previous_text, - current_text=previous_text + output.text, - delta_text=output.text, + current_text=current_text, + delta_text=delta_text, previous_token_ids=previous_token_ids, - current_token_ids=previous_token_ids + output.token_ids, - delta_token_ids=output.token_ids, + current_token_ids=current_token_ids, + delta_token_ids=delta_token_ids, + ) + elif tool_parser: + delta_message = tool_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=delta_token_ids, + request=request, # type: ignore[arg-type] ) else: delta_message = DeltaMessage( content=output.text, ) - previous_text += output.text - previous_token_ids += output.token_ids + previous_text = current_text + previous_token_ids = current_token_ids if not delta_message: continue if not first_delta_sent: - current_item_id = str(uuid.uuid4()) - if delta_message.reasoning: + current_item_id = random_uuid() + if delta_message.tool_calls: + current_tool_call_id = f"call_{random_uuid()}" + assert len(delta_message.tool_calls) == 1, ( + "Multiple tool calls in one delta is not supported" + ) + assert delta_message.tool_calls[0].function is not None, ( + "Tool call without function is not supported" + ) + assert delta_message.tool_calls[0].function.name is not None, ( + "Tool call without function name is not supported" + ) + current_tool_call_name = delta_message.tool_calls[ + 0 + ].function.name + yield _increment_sequence_number_and_return( + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=current_output_index, + item=ResponseFunctionToolCallItem( + type="function_call", + id=current_item_id, + call_id=current_tool_call_id, + name=current_tool_call_name, + arguments=delta_message.tool_calls[ + 0 + ].function.arguments, + status="in_progress", + ), + ) + ) + elif delta_message.reasoning: yield _increment_sequence_number_and_return( ResponseOutputItemAddedEvent( type="response.output_item.added", @@ -1294,7 +1394,7 @@ class OpenAIServingResponses(OpenAIServing): ), ) ) - else: + elif not delta_message.tool_calls: yield _increment_sequence_number_and_return( ResponseOutputItemAddedEvent( type="response.output_item.added", @@ -1325,7 +1425,6 @@ class OpenAIServingResponses(OpenAIServing): ) ) first_delta_sent = True - # todo(kebe7jun) tool call support # check delta message and previous delta message are # same as content or reasoning content @@ -1438,8 +1537,87 @@ class OpenAIServingResponses(OpenAIServing): ) # reset previous delta messages previous_delta_messages = [] - - if delta_message.reasoning is not None: + if delta_message.tool_calls and delta_message.tool_calls[0].function: + if delta_message.tool_calls[0].function.arguments: + yield _increment_sequence_number_and_return( + ResponseFunctionCallArgumentsDeltaEvent( + type="response.function_call_arguments.delta", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + delta=delta_message.tool_calls[0].function.arguments, + ) + ) + # tool call initiated with no arguments + elif delta_message.tool_calls[0].function.name: + # send done with current content part + # and add new function call item + yield _increment_sequence_number_and_return( + ResponseTextDoneEvent( + type="response.output_text.done", + sequence_number=-1, + output_index=current_output_index, + content_index=current_content_index, + text="", + logprobs=[], + item_id=current_item_id, + ) + ) + yield _increment_sequence_number_and_return( + ResponseContentPartDoneEvent( + type="response.content_part.done", + sequence_number=-1, + item_id=current_item_id, + output_index=current_output_index, + content_index=current_content_index, + part=ResponseOutputText( + type="output_text", + text="", + annotations=[], + logprobs=[], + ), + ) + ) + yield _increment_sequence_number_and_return( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=current_output_index, + item=ResponseOutputMessage( + id=current_item_id, + type="message", + role="assistant", + content=[], + status="completed", + ), + ) + ) + current_output_index += 1 + current_item_id = random_uuid() + assert delta_message.tool_calls[0].function is not None + current_tool_call_name = delta_message.tool_calls[ + 0 + ].function.name + current_tool_call_id = f"call_{random_uuid()}" + yield _increment_sequence_number_and_return( + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=current_output_index, + item=ResponseFunctionToolCallItem( + type="function_call", + id=current_item_id, + call_id=current_tool_call_id, + name=current_tool_call_name, + arguments="", + status="in_progress", + ), + ) + ) + # skip content part for tool call + current_content_index = 1 + continue + elif delta_message.reasoning is not None: yield _increment_sequence_number_and_return( ResponseReasoningTextDeltaEvent( type="response.reasoning_text.delta", @@ -1450,7 +1628,7 @@ class OpenAIServingResponses(OpenAIServing): delta=delta_message.reasoning, ) ) - elif delta_message.content is not None: + elif delta_message.content: yield _increment_sequence_number_and_return( ResponseTextDeltaEvent( type="response.output_text.delta", @@ -1473,8 +1651,50 @@ class OpenAIServingResponses(OpenAIServing): ) previous_delta_messages.append(delta_message) + if previous_delta_messages: - if previous_delta_messages[-1].reasoning is not None: + parts = [] + for pm in previous_delta_messages: + if pm.tool_calls: + assert len(pm.tool_calls) == 1, ( + "Multiple tool calls in one delta is not supported" + ) + assert pm.tool_calls[0].function is not None, ( + "Tool call without function is not supported" + ) + parts.append(pm.tool_calls[0].function.arguments or "") + + tool_call_arguments = "".join(parts) + if tool_call_arguments: + yield _increment_sequence_number_and_return( + ResponseFunctionCallArgumentsDoneEvent( + type="response.function_call_arguments.done", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + arguments=tool_call_arguments, + name=current_tool_call_name, + ) + ) + current_content_index = 0 + function_call_item = ResponseFunctionToolCall( + type="function_call", + name=current_tool_call_name, + arguments=tool_call_arguments, + status="completed", + id=current_item_id, + call_id=current_tool_call_id, + ) + yield _increment_sequence_number_and_return( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=current_output_index, + item=function_call_item, + ) + ) + + elif previous_delta_messages[-1].reasoning is not None: reason_content = "".join( pm.reasoning for pm in previous_delta_messages @@ -1523,11 +1743,9 @@ class OpenAIServingResponses(OpenAIServing): item=reasoning_item, ) ) - elif previous_delta_messages[-1].content is not None: + elif previous_delta_messages[-1].content: final_content = "".join( - pm.content - for pm in previous_delta_messages - if pm.content is not None + pm.content for pm in previous_delta_messages if pm.content ) yield _increment_sequence_number_and_return( ResponseTextDoneEvent(