[Feature][Responses API] Support MCP tool in background mode (#23494)

Signed-off-by: wuhang <wuhang6@huawei.com>
This commit is contained in:
wuhang
2025-08-27 09:06:58 +08:00
committed by GitHub
parent b1625dbe9c
commit 6891205b16
2 changed files with 162 additions and 134 deletions

View File

@@ -4,13 +4,15 @@ import json
import logging
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import TYPE_CHECKING, Union
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Optional, Union
from openai_harmony import Author, Message, Role, StreamState, TextContent
from vllm.entrypoints.harmony_utils import (
get_encoding, get_streamable_parser_for_assistant, render_for_completion)
from vllm.entrypoints.tool import Tool
from vllm.entrypoints.tool_server import ToolServer
from vllm.outputs import RequestOutput
if TYPE_CHECKING:
@@ -37,6 +39,11 @@ class ConversationContext(ABC):
def render_for_completion(self) -> list[int]:
pass
@abstractmethod
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
                             exit_stack: AsyncExitStack) -> None:
    """Open whatever tool sessions this context needs on *exit_stack*.

    Args:
        tool_server: Source of tool sessions; may be ``None`` when no tool
            server is configured — implementations must tolerate that.
        exit_stack: Owns the lifetime of any opened sessions; they are
            closed when the stack unwinds.
    """
    pass
class SimpleContext(ConversationContext):
@@ -55,16 +62,21 @@ class SimpleContext(ConversationContext):
def render_for_completion(self) -> list[int]:
raise NotImplementedError("Should not be called.")
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
                             exit_stack: AsyncExitStack) -> None:
    """No-op: SimpleContext never issues tool calls, so no sessions are
    opened regardless of *tool_server* or *exit_stack*."""
class HarmonyContext(ConversationContext):
def __init__(
self,
messages: list,
tool_sessions: dict[str, Tool],
available_tools: list[str],
):
self._messages = messages
self.tool_sessions = tool_sessions
self.available_tools = available_tools
self._tool_sessions: dict[str, Union[ClientSession, Tool]] = {}
self.parser = get_streamable_parser_for_assistant()
self.num_init_messages = len(messages)
@@ -116,10 +128,10 @@ class HarmonyContext(ConversationContext):
if recipient is not None:
if recipient.startswith("browser."):
return await self.call_search_tool(
self.tool_sessions["browser"], last_msg)
self._tool_sessions["browser"], last_msg)
elif recipient.startswith("python"):
return await self.call_python_tool(
self.tool_sessions["python"], last_msg)
self._tool_sessions["python"], last_msg)
raise ValueError("No tool call found")
def render_for_completion(self) -> list[int]:
@@ -161,6 +173,15 @@ class HarmonyContext(ConversationContext):
recipient=Role.ASSISTANT)
]
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
                             exit_stack: AsyncExitStack) -> None:
    """Open one session per entry in ``self.available_tools``.

    Sessions are registered on *exit_stack* so they are torn down with the
    request. Idempotent: tools already present in ``self._tool_sessions``
    are left untouched. Does nothing when *tool_server* is falsy.
    """
    if not tool_server:
        return
    missing = (name for name in self.available_tools
               if name not in self._tool_sessions)
    for name in missing:
        session = await exit_stack.enter_async_context(
            tool_server.new_session(name))
        self._tool_sessions[name] = session
class StreamingHarmonyContext(HarmonyContext):

View File

@@ -8,7 +8,7 @@ from collections.abc import AsyncGenerator, AsyncIterator, Sequence
from contextlib import AsyncExitStack
from copy import copy
from http import HTTPStatus
from typing import Any, Callable, Final, Optional, Union
from typing import Callable, Final, Optional, Union
import jinja2
import openai.types.responses as openai_responses_types
@@ -248,8 +248,8 @@ class OpenAIServingResponses(OpenAIServing):
raw_request.state.request_metadata = request_metadata
if self.tool_server is not None and isinstance(
self.tool_server, MCPToolServer
) and (request.background or request.stream) and request.tools and any(
self.tool_server,
MCPToolServer) and request.stream and request.tools and any(
tool.type in ["web_search_preview", "code_interpreter"]
for tool in request.tools):
return self.create_error_response(
@@ -265,24 +265,13 @@ class OpenAIServingResponses(OpenAIServing):
builtin_tool_list.append("browser")
if self.tool_server.has_tool("python"):
builtin_tool_list.append("python")
async with AsyncExitStack() as exit_stack:
try:
if self.tool_server is not None:
# TODO: initialize tool sessions lazily when the session
# is actually used.
tool_session_ctxs: dict[str, Any] = {
tool_name:
exit_stack.enter_async_context(
self.tool_server.new_session(tool_name))
for tool_name in builtin_tool_list
}
tool_sessions = {}
for tool_name in builtin_tool_list:
tool_sessions[tool_name] = (
await tool_session_ctxs[tool_name])
available_tools = builtin_tool_list
else:
assert len(builtin_tool_list) == 0
tool_sessions = {}
available_tools = []
try:
for i, engine_prompt in enumerate(engine_prompts):
default_max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"])
@@ -290,16 +279,15 @@ class OpenAIServingResponses(OpenAIServing):
default_max_tokens, self.default_sampling_params)
trace_headers = (None if raw_request is None else await
self._get_trace_headers(
raw_request.headers))
self._get_trace_headers(raw_request.headers))
context: ConversationContext
if self.use_harmony:
if request.stream:
context = StreamingHarmonyContext(
messages, tool_sessions)
messages, available_tools)
else:
context = HarmonyContext(messages, tool_sessions)
context = HarmonyContext(messages, available_tools)
else:
context = SimpleContext()
generator = self._generate_with_builtin_tools(
@@ -383,7 +371,6 @@ class OpenAIServingResponses(OpenAIServing):
)
except Exception as e:
return self.create_error_response(str(e))
return self.create_error_response("Should not reach here")
async def _make_request(
self,
@@ -439,7 +426,9 @@ class OpenAIServingResponses(OpenAIServing):
if created_time is None:
created_time = int(time.time())
async with AsyncExitStack() as exit_stack:
try:
await context.init_tool_sessions(self.tool_server, exit_stack)
async for _ in result_generator:
pass
except asyncio.CancelledError:
@@ -838,7 +827,7 @@ class OpenAIServingResponses(OpenAIServing):
status_code=HTTPStatus.BAD_REQUEST,
)
async def responses_stream_generator(
async def _process_streaming_events(
self,
request: ResponsesRequest,
sampling_params: SamplingParams,
@@ -847,18 +836,8 @@ class OpenAIServingResponses(OpenAIServing):
model_name: str,
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
created_time: Optional[int] = None,
created_time: int,
) -> AsyncGenerator[str, None]:
# TODO:
# 1. Handle disconnect
if not isinstance(context, StreamingHarmonyContext):
raise NotImplementedError(
"Streaming is not supported for responses API without Harmony."
)
created_time = created_time or int(time.time())
sequence_number = 0
def _send_event(event: BaseModel):
@@ -1270,3 +1249,31 @@ class OpenAIServingResponses(OpenAIServing):
sequence_number=-1,
response=final_response.model_dump(),
))
async def responses_stream_generator(
    self,
    request: ResponsesRequest,
    sampling_params: SamplingParams,
    result_generator: AsyncIterator[Optional[ConversationContext]],
    context: ConversationContext,
    model_name: str,
    tokenizer: AnyTokenizer,
    request_metadata: RequestResponseMetadata,
    created_time: Optional[int] = None,
) -> AsyncGenerator[str, None]:
    """Stream response events, holding tool sessions open for the duration.

    Opens the context's tool sessions on a local ``AsyncExitStack`` (so
    they are closed when streaming ends, even on error) and then delegates
    event production to ``_process_streaming_events``.

    Raises:
        NotImplementedError: if *context* is not a
            ``StreamingHarmonyContext`` — streaming is only supported with
            Harmony.
    """
    # TODO:
    # 1. Handle disconnect
    if not isinstance(context, StreamingHarmonyContext):
        raise NotImplementedError(
            "Streaming is not supported for responses API without Harmony."
        )
    # Use an explicit None test: `created_time or ...` would silently
    # discard a legitimate caller-supplied 0, and the non-streaming path
    # already tests `if created_time is None`.
    if created_time is None:
        created_time = int(time.time())
    async with AsyncExitStack() as exit_stack:
        await context.init_tool_sessions(self.tool_server, exit_stack)
        async for event_data in self._process_streaming_events(
                request, sampling_params, result_generator, context,
                model_name, tokenizer, request_metadata, created_time):
            yield event_data