[responsesAPI][5] ResponsesParser with tools for full MCP python loop (#29798)

Signed-off-by: Andrew Xia <axia@fb.com>
Signed-off-by: Andrew Xia <axia@meta.com>
Co-authored-by: Andrew Xia <axia@fb.com>
Author: Andrew Xia
Date: 2025-12-05 08:11:50 -08:00
Committed by: GitHub
Parent: 949a6a19d2
Commit: da7bc54ea8

8 changed files with 347 additions and 16 deletions


@@ -3,6 +3,7 @@
 import logging
 from collections.abc import Callable

+from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
 from openai.types.responses.response_output_message import ResponseOutputMessage
 from openai.types.responses.response_output_text import ResponseOutputText
 from openai.types.responses.response_reasoning_item import (
@@ -11,8 +12,10 @@ from openai.types.responses.response_reasoning_item import (
 )

 from vllm.entrypoints.openai.protocol import ResponseInputOutputItem, ResponsesRequest
+from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser
 from vllm.outputs import CompletionOutput
 from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
+from vllm.tokenizers.protocol import TokenizerLike
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import random_uuid
@@ -29,6 +32,7 @@ class ResponsesParser:
         reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser],
         response_messages: list[ResponseInputOutputItem],
         request: ResponsesRequest,
+        tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
     ):
         self.response_messages: list[ResponseInputOutputItem] = (
             # TODO: initial messages may not be properly typed
@@ -39,6 +43,9 @@
         self.request = request
         self.reasoning_parser_instance = reasoning_parser_cls(tokenizer)
+        self.tool_parser_instance = None
+        if tool_parser_cls is not None:
+            self.tool_parser_instance = tool_parser_cls(tokenizer)

     def process(self, output: CompletionOutput) -> "ResponsesParser":
         reasoning_content, content = self.reasoning_parser_instance.extract_reasoning(
@@ -59,6 +66,29 @@
                 )
             )

+        function_calls: list[ResponseFunctionToolCall] = []
+        if self.tool_parser_instance is not None:
+            tool_call_info = self.tool_parser_instance.extract_tool_calls(
+                content if content is not None else "",
+                request=self.request,  # type: ignore
+            )
+            if tool_call_info is not None and tool_call_info.tools_called:
+                # extract_tool_calls() returns a list of tool calls.
+                function_calls.extend(
+                    ResponseFunctionToolCall(
+                        id=f"fc_{random_uuid()}",
+                        call_id=f"call_{random_uuid()}",
+                        type="function_call",
+                        status="completed",
+                        name=tool_call.function.name,
+                        arguments=tool_call.function.arguments,
+                    )
+                    for tool_call in tool_call_info.tool_calls
+                )
+                content = tool_call_info.content
+                if content and content.strip() == "":
+                    content = None
+
         if content:
             self.response_messages.append(
                 ResponseOutputMessage(
@@ -76,6 +106,8 @@
                     ],
                 )
             )
+
+        if len(function_calls) > 0:
+            self.response_messages.extend(function_calls)

         return self
@@ -86,6 +118,7 @@ def get_responses_parser_for_simple_context(
     reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser],
     response_messages: list[ResponseInputOutputItem],
     request: ResponsesRequest,
+    tool_parser_cls,
 ) -> ResponsesParser:
     """Factory function to create a ResponsesParser with
     optional reasoning parser.
@@ -98,4 +131,5 @@
         reasoning_parser_cls=reasoning_parser_cls,
         response_messages=response_messages,
         request=request,
+        tool_parser_cls=tool_parser_cls,
     )

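A quick illustration of the behavior the hunks above add to `ResponsesParser.process`: each parsed tool call becomes a Responses-API `function_call` item with fresh `fc_`/`call_` ids, and any leftover content that is only whitespace collapses to `None`. This is a self-contained sketch, not the diff's code: the stub dataclasses stand in for vLLM's real `ToolParser` output types, and `uuid4().hex` stands in for `random_uuid()`.

```python
import uuid
from dataclasses import dataclass


@dataclass
class FunctionCallStub:
    name: str
    arguments: str


@dataclass
class ToolCallStub:
    function: FunctionCallStub


@dataclass
class ToolCallInfoStub:
    tools_called: bool
    tool_calls: list[ToolCallStub]
    content: str | None  # text left over once tool-call markup is stripped


def to_function_call_items(info: ToolCallInfoStub) -> list[dict]:
    # Mirror the diff: one function_call item per parsed tool call,
    # each with freshly generated fc_/call_ identifiers.
    return [
        {
            "id": f"fc_{uuid.uuid4().hex}",
            "call_id": f"call_{uuid.uuid4().hex}",
            "type": "function_call",
            "status": "completed",
            "name": call.function.name,
            "arguments": call.function.arguments,
        }
        for call in info.tool_calls
    ]


info = ToolCallInfoStub(
    tools_called=True,
    tool_calls=[ToolCallStub(FunctionCallStub("get_weather", '{"city": "SF"}'))],
    content="  ",  # whitespace-only remainder collapses to None, as in the diff
)
items = to_function_call_items(info)
content = info.content
if content and content.strip() == "":
    content = None
assert content is None and items[0]["name"] == "get_weather"
```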

@@ -18,6 +18,16 @@ from pydantic import ConfigDict, TypeAdapter
 from starlette.datastructures import Headers
 from typing_extensions import TypeIs

+from vllm.entrypoints.context import (
+    HarmonyContext,
+    ParsableContext,
+    StreamingHarmonyContext,
+)
+from vllm.entrypoints.openai.protocol import (
+    FunctionCall,
+    ResponseInputOutputItem,
+    ResponsesRequest,
+)
 from vllm.entrypoints.pooling.classify.protocol import (
     ClassificationChatRequest,
     ClassificationCompletionRequest,
@@ -39,6 +49,7 @@ from vllm.entrypoints.pooling.score.protocol import (
     ScoreRequest,
     ScoreResponse,
 )
+from vllm.transformers_utils.tokenizer import AnyTokenizer

 if sys.version_info >= (3, 12):
     from typing import TypedDict
@@ -72,9 +83,7 @@ from vllm.entrypoints.openai.protocol import (
     DetokenizeRequest,
     ErrorInfo,
     ErrorResponse,
-    FunctionCall,
     FunctionDefinition,
-    ResponsesRequest,
     TokenizeChatRequest,
     TokenizeCompletionRequest,
     TokenizeResponse,
@@ -85,6 +94,9 @@ from vllm.entrypoints.openai.protocol import (
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
 from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
+from vllm.entrypoints.responses_utils import (
+    construct_input_messages,
+)
 from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
 from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.inputs.data import PromptType
@@ -1224,6 +1236,31 @@ class OpenAIServing:
         )
         return engine_request, tokenization_kwargs

+    async def _render_next_turn(
+        self,
+        request: ResponsesRequest,
+        tokenizer: AnyTokenizer,
+        messages: list[ResponseInputOutputItem],
+        tool_dicts: list[dict[str, Any]] | None,
+        tool_parser,
+        chat_template: str | None,
+        chat_template_content_format: ChatTemplateContentFormatOption,
+    ):
+        new_messages = construct_input_messages(
+            request_input=messages,
+        )
+        _, request_prompts, engine_prompts = await self._preprocess_chat(
+            request,
+            tokenizer,
+            new_messages,
+            tool_dicts=tool_dicts,
+            tool_parser=tool_parser,
+            chat_template=chat_template,
+            chat_template_content_format=chat_template_content_format,
+        )
+        return request_prompts, engine_prompts
+
     async def _generate_with_builtin_tools(
         self,
         request_id: str,
@@ -1286,11 +1323,27 @@
             # Create inputs for the next turn.
             # Render the next prompt token ids.
-            prompt_token_ids = context.render_for_completion()
-            engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids)
-            request_prompt = prompt_token_ids
+            if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
+                prompt_token_ids = context.render_for_completion()
+                engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids)
+                request_prompt = prompt_token_ids
+            elif isinstance(context, ParsableContext):
+                request_prompts, engine_prompts = await self._render_next_turn(
+                    context.request,
+                    context.tokenizer,
+                    context.parser.response_messages,
+                    context.tool_dicts,
+                    context.tool_parser_cls,
+                    context.chat_template,
+                    context.chat_template_content_format,
+                )
+                engine_prompt = engine_prompts[0]
+                request_prompt = request_prompts[0]

             # Update the sampling params.
-            sampling_params.max_tokens = self.max_model_len - len(prompt_token_ids)
+            sampling_params.max_tokens = self.max_model_len - len(
+                engine_prompt["prompt_token_ids"]
+            )
             # OPTIMIZATION
             priority = orig_priority - 1
             sub_request += 1

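The net effect of the two hunks above: `_generate_with_builtin_tools` now dispatches on context type when rendering the next turn. Harmony contexts still render token ids directly, while a `ParsableContext` re-runs the chat template over its accumulated response messages via `_render_next_turn`; since the elif branch never defines `prompt_token_ids`, `max_tokens` is now derived from `engine_prompt["prompt_token_ids"]` in either branch. A toy sketch of that dispatch, with simplified stand-ins (not the real `HarmonyContext`/`ParsableContext` classes):

```python
class HarmonyLikeContext:
    def render_for_completion(self) -> list[int]:
        # Harmony contexts already hold a token-level transcript.
        return [101, 2023, 2003, 102]


class ParsableLikeContext:
    def __init__(self, messages: list[dict]):
        # ParsableContext accumulates typed response messages instead.
        self.messages = messages


def next_turn_prompt(ctx, max_model_len: int, tokenize) -> tuple[list[int], int]:
    if isinstance(ctx, HarmonyLikeContext):
        prompt_token_ids = ctx.render_for_completion()
    else:
        # Mirror _render_next_turn: re-render the chat template over the
        # full message history to get the next engine prompt.
        prompt_token_ids = tokenize(ctx.messages)
    # As in the diff, the token budget comes from the rendered prompt length.
    return prompt_token_ids, max_model_len - len(prompt_token_ids)


def fake_tokenize(messages: list[dict]) -> list[int]:
    # Hypothetical stand-in for chat-template rendering plus tokenization.
    return list(range(8 * len(messages)))


ids, budget = next_turn_prompt(ParsableLikeContext([{"role": "user"}]), 4096, fake_tokenize)
assert budget == 4096 - len(ids)
```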

@@ -375,7 +375,7 @@ class OpenAIServingResponses(OpenAIServing):
         generators: list[AsyncGenerator[ConversationContext, None]] = []

         builtin_tool_list: list[str] = []
-        if self.use_harmony and self.tool_server is not None:
+        if self.tool_server is not None:
             if self.tool_server.has_tool("browser"):
                 builtin_tool_list.append("browser")
             if self.tool_server.has_tool("python"):
@@ -423,6 +423,10 @@
                 tokenizer=tokenizer,
                 reasoning_parser_cls=self.reasoning_parser,
                 request=request,
+                tool_parser_cls=self.tool_parser,
+                available_tools=available_tools,
+                chat_template=self.chat_template,
+                chat_template_content_format=self.chat_template_content_format,
             )
         else:
             context = SimpleContext()
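Here the builtin-tool probe is no longer gated on `use_harmony`, so `browser` and `python` tools are discovered whenever a tool server is configured, and the `ParsableContext` is wired with the tool parser, available tools, and chat-template settings that the loop above consumes. A toy sketch of the discovery logic; `ToolServerStub` is a hypothetical stand-in for vLLM's tool server interface:

```python
class ToolServerStub:
    def __init__(self, tools: set[str]):
        self._tools = tools

    def has_tool(self, name: str) -> bool:
        return name in self._tools


def builtin_tools(tool_server: ToolServerStub | None) -> list[str]:
    tools: list[str] = []
    # No longer gated on use_harmony: any configured tool server is probed.
    if tool_server is not None:
        for name in ("browser", "python"):
            if tool_server.has_tool(name):
                tools.append(name)
    return tools


assert builtin_tools(None) == []
assert builtin_tools(ToolServerStub({"python", "browser"})) == ["browser", "python"]
```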