[responsesAPI][5] ResponsesParser with tools for full MCP python loop (#29798)
Signed-off-by: Andrew Xia <axia@fb.com> Signed-off-by: Andrew Xia <axia@meta.com> Co-authored-by: Andrew Xia <axia@fb.com>
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
|
||||
from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
|
||||
from openai.types.responses.response_output_message import ResponseOutputMessage
|
||||
from openai.types.responses.response_output_text import ResponseOutputText
|
||||
from openai.types.responses.response_reasoning_item import (
|
||||
@@ -11,8 +12,10 @@ from openai.types.responses.response_reasoning_item import (
|
||||
)
|
||||
|
||||
from vllm.entrypoints.openai.protocol import ResponseInputOutputItem, ResponsesRequest
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser
|
||||
from vllm.outputs import CompletionOutput
|
||||
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
|
||||
from vllm.tokenizers.protocol import TokenizerLike
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
@@ -29,6 +32,7 @@ class ResponsesParser:
|
||||
reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser],
|
||||
response_messages: list[ResponseInputOutputItem],
|
||||
request: ResponsesRequest,
|
||||
tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
|
||||
):
|
||||
self.response_messages: list[ResponseInputOutputItem] = (
|
||||
# TODO: initial messages may not be properly typed
|
||||
@@ -39,6 +43,9 @@ class ResponsesParser:
|
||||
self.request = request
|
||||
|
||||
self.reasoning_parser_instance = reasoning_parser_cls(tokenizer)
|
||||
self.tool_parser_instance = None
|
||||
if tool_parser_cls is not None:
|
||||
self.tool_parser_instance = tool_parser_cls(tokenizer)
|
||||
|
||||
def process(self, output: CompletionOutput) -> "ResponsesParser":
|
||||
reasoning_content, content = self.reasoning_parser_instance.extract_reasoning(
|
||||
@@ -59,6 +66,29 @@ class ResponsesParser:
|
||||
)
|
||||
)
|
||||
|
||||
function_calls: list[ResponseFunctionToolCall] = []
|
||||
if self.tool_parser_instance is not None:
|
||||
tool_call_info = self.tool_parser_instance.extract_tool_calls(
|
||||
content if content is not None else "",
|
||||
request=self.request, # type: ignore
|
||||
)
|
||||
if tool_call_info is not None and tool_call_info.tools_called:
|
||||
# extract_tool_calls() returns a list of tool calls.
|
||||
function_calls.extend(
|
||||
ResponseFunctionToolCall(
|
||||
id=f"fc_{random_uuid()}",
|
||||
call_id=f"call_{random_uuid()}",
|
||||
type="function_call",
|
||||
status="completed",
|
||||
name=tool_call.function.name,
|
||||
arguments=tool_call.function.arguments,
|
||||
)
|
||||
for tool_call in tool_call_info.tool_calls
|
||||
)
|
||||
content = tool_call_info.content
|
||||
if content and content.strip() == "":
|
||||
content = None
|
||||
|
||||
if content:
|
||||
self.response_messages.append(
|
||||
ResponseOutputMessage(
|
||||
@@ -76,6 +106,8 @@ class ResponsesParser:
|
||||
],
|
||||
)
|
||||
)
|
||||
if len(function_calls) > 0:
|
||||
self.response_messages.extend(function_calls)
|
||||
|
||||
return self
|
||||
|
||||
@@ -86,6 +118,7 @@ def get_responses_parser_for_simple_context(
|
||||
reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser],
|
||||
response_messages: list[ResponseInputOutputItem],
|
||||
request: ResponsesRequest,
|
||||
tool_parser_cls,
|
||||
) -> ResponsesParser:
|
||||
"""Factory function to create a ResponsesParser with
|
||||
optional reasoning parser.
|
||||
@@ -98,4 +131,5 @@ def get_responses_parser_for_simple_context(
|
||||
reasoning_parser_cls=reasoning_parser_cls,
|
||||
response_messages=response_messages,
|
||||
request=request,
|
||||
tool_parser_cls=tool_parser_cls,
|
||||
)
|
||||
|
||||
@@ -18,6 +18,16 @@ from pydantic import ConfigDict, TypeAdapter
|
||||
from starlette.datastructures import Headers
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from vllm.entrypoints.context import (
|
||||
HarmonyContext,
|
||||
ParsableContext,
|
||||
StreamingHarmonyContext,
|
||||
)
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
FunctionCall,
|
||||
ResponseInputOutputItem,
|
||||
ResponsesRequest,
|
||||
)
|
||||
from vllm.entrypoints.pooling.classify.protocol import (
|
||||
ClassificationChatRequest,
|
||||
ClassificationCompletionRequest,
|
||||
@@ -39,6 +49,7 @@ from vllm.entrypoints.pooling.score.protocol import (
|
||||
ScoreRequest,
|
||||
ScoreResponse,
|
||||
)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import TypedDict
|
||||
@@ -72,9 +83,7 @@ from vllm.entrypoints.openai.protocol import (
|
||||
DetokenizeRequest,
|
||||
ErrorInfo,
|
||||
ErrorResponse,
|
||||
FunctionCall,
|
||||
FunctionDefinition,
|
||||
ResponsesRequest,
|
||||
TokenizeChatRequest,
|
||||
TokenizeCompletionRequest,
|
||||
TokenizeResponse,
|
||||
@@ -85,6 +94,9 @@ from vllm.entrypoints.openai.protocol import (
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
|
||||
from vllm.entrypoints.responses_utils import (
|
||||
construct_input_messages,
|
||||
)
|
||||
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
|
||||
from vllm.entrypoints.utils import _validate_truncation_size
|
||||
from vllm.inputs.data import PromptType
|
||||
@@ -1224,6 +1236,31 @@ class OpenAIServing:
|
||||
)
|
||||
return engine_request, tokenization_kwargs
|
||||
|
||||
async def _render_next_turn(
|
||||
self,
|
||||
request: ResponsesRequest,
|
||||
tokenizer: AnyTokenizer,
|
||||
messages: list[ResponseInputOutputItem],
|
||||
tool_dicts: list[dict[str, Any]] | None,
|
||||
tool_parser,
|
||||
chat_template: str | None,
|
||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||
):
|
||||
new_messages = construct_input_messages(
|
||||
request_input=messages,
|
||||
)
|
||||
|
||||
_, request_prompts, engine_prompts = await self._preprocess_chat(
|
||||
request,
|
||||
tokenizer,
|
||||
new_messages,
|
||||
tool_dicts=tool_dicts,
|
||||
tool_parser=tool_parser,
|
||||
chat_template=chat_template,
|
||||
chat_template_content_format=chat_template_content_format,
|
||||
)
|
||||
return request_prompts, engine_prompts
|
||||
|
||||
async def _generate_with_builtin_tools(
|
||||
self,
|
||||
request_id: str,
|
||||
@@ -1286,11 +1323,27 @@ class OpenAIServing:
|
||||
|
||||
# Create inputs for the next turn.
|
||||
# Render the next prompt token ids.
|
||||
prompt_token_ids = context.render_for_completion()
|
||||
engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids)
|
||||
request_prompt = prompt_token_ids
|
||||
if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
|
||||
prompt_token_ids = context.render_for_completion()
|
||||
engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids)
|
||||
request_prompt = prompt_token_ids
|
||||
elif isinstance(context, ParsableContext):
|
||||
request_prompts, engine_prompts = await self._render_next_turn(
|
||||
context.request,
|
||||
context.tokenizer,
|
||||
context.parser.response_messages,
|
||||
context.tool_dicts,
|
||||
context.tool_parser_cls,
|
||||
context.chat_template,
|
||||
context.chat_template_content_format,
|
||||
)
|
||||
engine_prompt = engine_prompts[0]
|
||||
request_prompt = request_prompts[0]
|
||||
|
||||
# Update the sampling params.
|
||||
sampling_params.max_tokens = self.max_model_len - len(prompt_token_ids)
|
||||
sampling_params.max_tokens = self.max_model_len - len(
|
||||
engine_prompt["prompt_token_ids"]
|
||||
)
|
||||
# OPTIMIZATION
|
||||
priority = orig_priority - 1
|
||||
sub_request += 1
|
||||
|
||||
@@ -375,7 +375,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
generators: list[AsyncGenerator[ConversationContext, None]] = []
|
||||
|
||||
builtin_tool_list: list[str] = []
|
||||
if self.use_harmony and self.tool_server is not None:
|
||||
if self.tool_server is not None:
|
||||
if self.tool_server.has_tool("browser"):
|
||||
builtin_tool_list.append("browser")
|
||||
if self.tool_server.has_tool("python"):
|
||||
@@ -423,6 +423,10 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
tokenizer=tokenizer,
|
||||
reasoning_parser_cls=self.reasoning_parser,
|
||||
request=request,
|
||||
tool_parser_cls=self.tool_parser,
|
||||
available_tools=available_tools,
|
||||
chat_template=self.chat_template,
|
||||
chat_template_content_format=self.chat_template_content_format,
|
||||
)
|
||||
else:
|
||||
context = SimpleContext()
|
||||
|
||||
Reference in New Issue
Block a user